//! OpenAI Chat Completions API support: provider-specific message, request, and
//! response types plus the [`CompletionModel`] implementation built on top of them.

use super::{
    CompletionsClient as Client,
    client::{ApiErrorResponse, ApiResponse},
    streaming::StreamingCompletionResponse,
};
use crate::completion::{
    CompletionError, CompletionRequest as CoreCompletionRequest, GetTokenUsage,
};
use crate::http_client::{self, HttpClientExt};
use crate::message::{AudioMediaType, DocumentSourceKind, ImageDetail, MimeType};
use crate::one_or_many::string_or_one_or_many;
use crate::telemetry::{ProviderResponseExt, SpanCombinator};
use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
use crate::{OneOrMany, completion, json_utils, message};
use serde::{Deserialize, Serialize, Serializer};
use std::convert::Infallible;
use std::fmt;
use std::str::FromStr;
use tracing::{Instrument, Level, enabled, info_span};

pub mod streaming;

/// Serializes user content as a bare string when it consists of a single text part
/// (the compact form accepted by the OpenAI Chat Completions API); otherwise
/// serializes the full content array.
fn serialize_user_content<S>(
    content: &OneOrMany<UserContent>,
    serializer: S,
) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    if content.len() == 1
        && let UserContent::Text { text } = content.first_ref()
    {
        return serializer.serialize_str(text);
    }
    content.serialize(serializer)
}

// GPT-5 family
pub const GPT_5_2: &str = "gpt-5.2";
pub const GPT_5_1: &str = "gpt-5.1";
pub const GPT_5: &str = "gpt-5";
pub const GPT_5_MINI: &str = "gpt-5-mini";
pub const GPT_5_NANO: &str = "gpt-5-nano";

// GPT-4.5, GPT-4o and GPT-4 family
pub const GPT_4_5_PREVIEW: &str = "gpt-4.5-preview";
pub const GPT_4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27";
pub const GPT_4O_2024_11_20: &str = "gpt-4o-2024-11-20";
pub const GPT_4O: &str = "gpt-4o";
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
pub const GPT_4: &str = "gpt-4";
pub const GPT_4_0613: &str = "gpt-4-0613";
pub const GPT_4_32K: &str = "gpt-4-32k";
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";

// o-series reasoning models
pub const O4_MINI_2025_04_16: &str = "o4-mini-2025-04-16";
pub const O4_MINI: &str = "o4-mini";
pub const O3: &str = "o3";
pub const O3_MINI: &str = "o3-mini";
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
pub const O1_PRO: &str = "o1-pro";
pub const O1: &str = "o1";
pub const O1_2024_12_17: &str = "o1-2024-12-17";
pub const O1_PREVIEW: &str = "o1-preview";
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
pub const O1_MINI: &str = "o1-mini";
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";

// GPT-4.1 family
pub const GPT_4_1_MINI: &str = "gpt-4.1-mini";
pub const GPT_4_1_NANO: &str = "gpt-4.1-nano";
pub const GPT_4_1_2025_04_14: &str = "gpt-4.1-2025-04-14";
pub const GPT_4_1: &str = "gpt-4.1";

impl From<ApiErrorResponse> for CompletionError {
    fn from(err: ApiErrorResponse) -> Self {
        CompletionError::ProviderError(err.message)
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    #[serde(alias = "developer")]
    System {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    User {
        #[serde(
            deserialize_with = "string_or_one_or_many",
            serialize_with = "serialize_user_content"
        )]
        content: OneOrMany<UserContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        #[serde(
            default,
            deserialize_with = "json_utils::string_or_vec",
            skip_serializing_if = "Vec::is_empty",
            serialize_with = "serialize_assistant_content_vec"
        )]
        content: Vec<AssistantContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<AudioAssistant>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: ToolResultContentValue,
    },
}

impl Message {
    pub fn system(content: &str) -> Self {
        Message::System {
            content: OneOrMany::one(content.to_owned().into()),
            name: None,
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AudioAssistant {
    pub id: String,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    #[serde(default)]
    pub r#type: SystemContentType,
    pub text: String,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
    Refusal { refusal: String },
}

impl From<AssistantContent> for completion::AssistantContent {
    fn from(value: AssistantContent) -> Self {
        match value {
            AssistantContent::Text { text } => completion::AssistantContent::text(text),
            AssistantContent::Refusal { refusal } => completion::AssistantContent::text(refusal),
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text {
        text: String,
    },
    #[serde(rename = "image_url")]
    Image {
        image_url: ImageUrl,
    },
    Audio {
        input_audio: InputAudio,
    },
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    #[serde(default)]
    pub detail: ImageDetail,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct InputAudio {
    pub data: String,
    pub format: AudioMediaType,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    #[serde(default)]
    r#type: ToolResultContentType,
    pub text: String,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    #[default]
    Text,
}

impl FromStr for ToolResultContent {
    type Err = Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(s.to_owned().into())
    }
}

impl From<String> for ToolResultContent {
    fn from(s: String) -> Self {
        ToolResultContent {
            r#type: ToolResultContentType::default(),
            text: s,
        }
    }
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum ToolResultContentValue {
    Array(Vec<ToolResultContent>),
    String(String),
}

impl ToolResultContentValue {
    /// Wraps a tool-result string either as a plain string or as a single-element
    /// content array, depending on the requested format.
    pub fn from_string(s: String, use_array_format: bool) -> Self {
        if use_array_format {
            ToolResultContentValue::Array(vec![ToolResultContent::from(s)])
        } else {
            ToolResultContentValue::String(s)
        }
    }

    /// Returns the text of the tool result, joining array parts with newlines.
    pub fn as_text(&self) -> String {
        match self {
            ToolResultContentValue::Array(arr) => arr
                .iter()
                .map(|c| c.text.clone())
                .collect::<Vec<_>>()
                .join("\n"),
            ToolResultContentValue::String(s) => s.clone(),
        }
    }

    /// Converts the value into the array form, leaving array values untouched.
    pub fn to_array(&self) -> Self {
        match self {
            ToolResultContentValue::Array(_) => self.clone(),
            ToolResultContentValue::String(s) => {
                ToolResultContentValue::Array(vec![ToolResultContent::from(s.clone())])
            }
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: FunctionDefinition,
}

impl From<completion::ToolDefinition> for ToolDefinition {
    fn from(tool: completion::ToolDefinition) -> Self {
        Self {
            r#type: "function".into(),
            function: FunctionDefinition {
                name: tool.name,
                description: tool.description,
                parameters: tool.parameters,
                strict: None,
            },
        }
    }
}

impl ToolDefinition {
    /// Enables strict mode for this tool's function definition and sanitizes its
    /// parameter schema in place.
    pub fn with_strict(mut self) -> Self {
        self.function.strict = Some(true);
        super::sanitize_schema(&mut self.function.parameters);
        self
    }
}

#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    None,
    Required,
}

impl TryFrom<crate::message::ToolChoice> for ToolChoice {
    type Error = CompletionError;

    fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
        let res = match value {
            message::ToolChoice::Specific { .. } => {
                return Err(CompletionError::ProviderError(
                    "Provider doesn't support only using specific tools".to_string(),
                ));
            }
            message::ToolChoice::Auto => Self::Auto,
            message::ToolChoice::None => Self::None,
            message::ToolChoice::Required => Self::Required,
        };

        Ok(res)
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}

impl TryFrom<message::ToolResult> for Message {
    type Error = message::MessageError;

    fn try_from(value: message::ToolResult) -> Result<Self, Self::Error> {
        let text = value
            .content
            .into_iter()
            .map(|content| {
                match content {
                    message::ToolResultContent::Text(message::Text { text }) => Ok(text),
                    message::ToolResultContent::Image(_) => Err(message::MessageError::ConversionError(
                        "OpenAI does not support images in tool results. Tool results must be text."
                            .into(),
                    )),
                }
            })
            .collect::<Result<Vec<_>, _>>()?
            .join("\n");

        Ok(Message::ToolResult {
            tool_call_id: value.id,
            content: ToolResultContentValue::String(text),
        })
    }
}

impl TryFrom<message::UserContent> for UserContent {
    type Error = message::MessageError;

    fn try_from(value: message::UserContent) -> Result<Self, Self::Error> {
        match value {
            message::UserContent::Text(message::Text { text }) => Ok(UserContent::Text { text }),
            message::UserContent::Image(message::Image {
                data,
                detail,
                media_type,
                ..
            }) => match data {
                DocumentSourceKind::Url(url) => Ok(UserContent::Image {
                    image_url: ImageUrl {
                        url,
                        detail: detail.unwrap_or_default(),
                    },
                }),
                DocumentSourceKind::Base64(data) => {
                    let url = format!(
                        "data:{};base64,{}",
                        media_type.map(|i| i.to_mime_type()).ok_or(
                            message::MessageError::ConversionError(
                                "OpenAI Image URI must have media type".into()
                            )
                        )?,
                        data
                    );

                    let detail = detail.ok_or(message::MessageError::ConversionError(
                        "OpenAI image URI must have image detail".into(),
                    ))?;

                    Ok(UserContent::Image {
                        image_url: ImageUrl { url, detail },
                    })
                }
                DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
                    "Raw files not supported, encode as base64 first".into(),
                )),
                DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
                    "Document has no body".into(),
                )),
                doc => Err(message::MessageError::ConversionError(format!(
                    "Unsupported document type: {doc:?}"
                ))),
            },
            message::UserContent::Document(message::Document { data, .. }) => {
                if let DocumentSourceKind::Base64(text) | DocumentSourceKind::String(text) = data {
                    Ok(UserContent::Text { text })
                } else {
                    Err(message::MessageError::ConversionError(
                        "Documents must be base64 or a string".into(),
                    ))
                }
            }
            message::UserContent::Audio(message::Audio {
                data, media_type, ..
            }) => match data {
                DocumentSourceKind::Base64(data) => Ok(UserContent::Audio {
                    input_audio: InputAudio {
                        data,
                        format: match media_type {
                            Some(media_type) => media_type,
                            None => AudioMediaType::MP3,
                        },
                    },
                }),
                DocumentSourceKind::Url(_) => Err(message::MessageError::ConversionError(
                    "URLs are not supported for audio".into(),
                )),
                DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
                    "Raw files are not supported for audio".into(),
                )),
                DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
                    "Audio has no body".into(),
                )),
                audio => Err(message::MessageError::ConversionError(format!(
                    "Unsupported audio type: {audio:?}"
                ))),
            },
            message::UserContent::ToolResult(_) => Err(message::MessageError::ConversionError(
                "Tool result is in unsupported format".into(),
            )),
            message::UserContent::Video(_) => Err(message::MessageError::ConversionError(
                "Video is in unsupported format".into(),
            )),
        }
    }
}

impl TryFrom<OneOrMany<message::UserContent>> for Vec<Message> {
    type Error = message::MessageError;

    fn try_from(value: OneOrMany<message::UserContent>) -> Result<Self, Self::Error> {
        let (tool_results, other_content): (Vec<_>, Vec<_>) = value
            .into_iter()
            .partition(|content| matches!(content, message::UserContent::ToolResult(_)));

        if !tool_results.is_empty() {
            // Each tool result becomes its own `tool` role message.
            tool_results
                .into_iter()
                .map(|content| match content {
                    message::UserContent::ToolResult(tool_result) => tool_result.try_into(),
                    _ => unreachable!(),
                })
                .collect::<Result<Vec<_>, _>>()
        } else {
            let other_content: Vec<UserContent> = other_content
                .into_iter()
                .map(|content| content.try_into())
                .collect::<Result<Vec<_>, _>>()?;

            let other_content = OneOrMany::many(other_content)
                .expect("There must be other content if there was no tool result content");

            Ok(vec![Message::User {
                content: other_content,
                name: None,
            }])
        }
    }
}

impl TryFrom<OneOrMany<message::AssistantContent>> for Vec<Message> {
    type Error = message::MessageError;

    fn try_from(value: OneOrMany<message::AssistantContent>) -> Result<Self, Self::Error> {
        let mut text_content = Vec::new();
        let mut tool_calls = Vec::new();

        for content in value {
            match content {
                message::AssistantContent::Text(text) => text_content.push(text),
                message::AssistantContent::ToolCall(tool_call) => tool_calls.push(tool_call),
                message::AssistantContent::Reasoning(_) => {
                    // The Chat Completions API has no representation for reasoning
                    // content, so it is silently skipped.
                }
                message::AssistantContent::Image(_) => {
                    panic!(
                        "The OpenAI Completions API doesn't support image content in assistant messages!"
                    );
                }
            }
        }

        if text_content.is_empty() && tool_calls.is_empty() {
            return Ok(vec![]);
        }

        Ok(vec![Message::Assistant {
            content: text_content
                .into_iter()
                .map(|content| content.text.into())
                .collect::<Vec<_>>(),
            refusal: None,
            audio: None,
            name: None,
            tool_calls: tool_calls
                .into_iter()
                .map(|tool_call| tool_call.into())
                .collect::<Vec<_>>(),
        }])
    }
}

impl TryFrom<message::Message> for Vec<Message> {
    type Error = message::MessageError;

    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        match message {
            message::Message::User { content } => content.try_into(),
            message::Message::Assistant { content, .. } => content.try_into(),
        }
    }
}

impl From<message::ToolCall> for ToolCall {
    fn from(tool_call: message::ToolCall) -> Self {
        Self {
            id: tool_call.id,
            r#type: ToolType::default(),
            function: Function {
                name: tool_call.function.name,
                arguments: tool_call.function.arguments,
            },
        }
    }
}

impl From<ToolCall> for message::ToolCall {
    fn from(tool_call: ToolCall) -> Self {
        Self {
            id: tool_call.id,
            call_id: None,
            function: message::ToolFunction {
                name: tool_call.function.name,
                arguments: tool_call.function.arguments,
            },
            signature: None,
            additional_params: None,
        }
    }
}

impl TryFrom<Message> for message::Message {
    type Error = message::MessageError;

    fn try_from(message: Message) -> Result<Self, Self::Error> {
        Ok(match message {
            Message::User { content, .. } => message::Message::User {
                content: content.map(|content| content.into()),
            },
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let mut content = content
                    .into_iter()
                    .map(|content| match content {
                        AssistantContent::Text { text } => message::AssistantContent::text(text),
                        // Refusals have no dedicated core representation, so they are
                        // carried over as plain text.
                        AssistantContent::Refusal { refusal } => {
                            message::AssistantContent::text(refusal)
                        }
                    })
                    .collect::<Vec<_>>();

                content.extend(
                    tool_calls
                        .into_iter()
                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
                        .collect::<Result<Vec<_>, _>>()?,
                );

                message::Message::Assistant {
                    id: None,
                    content: OneOrMany::many(content).map_err(|_| {
                        message::MessageError::ConversionError(
                            "Neither `content` nor `tool_calls` was provided to the Message"
                                .to_owned(),
                        )
                    })?,
                }
            }

            Message::ToolResult {
                tool_call_id,
                content,
            } => message::Message::User {
                content: OneOrMany::one(message::UserContent::tool_result(
                    tool_call_id,
                    OneOrMany::one(message::ToolResultContent::text(content.as_text())),
                )),
            },

            // The core message type has no system variant, so system messages are
            // folded back into user messages when converting.
            Message::System { content, .. } => message::Message::User {
                content: content.map(|content| message::UserContent::text(content.text)),
            },
        })
    }
}

impl From<UserContent> for message::UserContent {
    fn from(content: UserContent) -> Self {
        match content {
            UserContent::Text { text } => message::UserContent::text(text),
            UserContent::Image { image_url } => {
                message::UserContent::image_url(image_url.url, None, Some(image_url.detail))
            }
            UserContent::Audio { input_audio } => {
                message::UserContent::audio(input_audio.data, Some(input_audio.format))
            }
        }
    }
}

impl From<String> for UserContent {
    fn from(s: String) -> Self {
        UserContent::Text { text: s }
    }
}

impl FromStr for UserContent {
    type Err = Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(UserContent::Text {
            text: s.to_string(),
        })
    }
}

impl From<String> for AssistantContent {
    fn from(s: String) -> Self {
        AssistantContent::Text { text: s }
    }
}

impl FromStr for AssistantContent {
    type Err = Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(AssistantContent::Text {
            text: s.to_string(),
        })
    }
}

impl From<String> for SystemContent {
    fn from(s: String) -> Self {
        SystemContent {
            r#type: SystemContentType::default(),
            text: s,
        }
    }
}

impl FromStr for SystemContent {
    type Err = Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(SystemContent {
            r#type: SystemContentType::default(),
            text: s.to_string(),
        })
    }
}

#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    pub choices: Vec<Choice>,
    pub usage: Option<Usage>,
}

impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
        let choice = response.choices.first().ok_or_else(|| {
            CompletionError::ResponseError("Response contained no choices".to_owned())
        })?;

        let content = match &choice.message {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let mut content = content
                    .iter()
                    .filter_map(|c| {
                        let s = match c {
                            AssistantContent::Text { text } => text,
                            AssistantContent::Refusal { refusal } => refusal,
                        };
                        if s.is_empty() {
                            None
                        } else {
                            Some(completion::AssistantContent::text(s))
                        }
                    })
                    .collect::<Vec<_>>();

                content.extend(
                    tool_calls
                        .iter()
                        .map(|call| {
                            completion::AssistantContent::tool_call(
                                &call.id,
                                &call.function.name,
                                call.function.arguments.clone(),
                            )
                        })
                        .collect::<Vec<_>>(),
                );
                Ok(content)
            }
            _ => Err(CompletionError::ResponseError(
                "Response did not contain a valid message or tool call".into(),
            )),
        }?;

        let choice = OneOrMany::many(content).map_err(|_| {
            CompletionError::ResponseError(
                "Response contained no message or tool call (empty)".to_owned(),
            )
        })?;

        let usage = response
            .usage
            .as_ref()
            .map(|usage| completion::Usage {
                input_tokens: usage.prompt_tokens as u64,
                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
                total_tokens: usage.total_tokens as u64,
                cached_input_tokens: usage
                    .prompt_tokens_details
                    .as_ref()
                    .map(|d| d.cached_tokens as u64)
                    .unwrap_or(0),
            })
            .unwrap_or_default();

        Ok(completion::CompletionResponse {
            choice,
            usage,
            raw_response: response,
            message_id: None,
        })
    }
}

impl ProviderResponseExt for CompletionResponse {
    type OutputMessage = Choice;
    type Usage = Usage;

    fn get_response_id(&self) -> Option<String> {
        Some(self.id.to_owned())
    }

    fn get_response_model_name(&self) -> Option<String> {
        Some(self.model.to_owned())
    }

    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
        self.choices.clone()
    }

    fn get_text_response(&self) -> Option<String> {
        let Message::User { ref content, .. } = self.choices.last()?.message.clone() else {
            return None;
        };

        let UserContent::Text { text } = content.first() else {
            return None;
        };

        Some(text)
    }

    fn get_usage(&self) -> Option<Self::Usage> {
        self.usage.clone()
    }
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}

#[derive(Clone, Debug, Deserialize, Serialize, Default)]
pub struct PromptTokensDetails {
    #[serde(default)]
    pub cached_tokens: usize,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}

impl Usage {
    pub fn new() -> Self {
        Self {
            prompt_tokens: 0,
            total_tokens: 0,
            prompt_tokens_details: None,
        }
    }
}

impl Default for Usage {
    fn default() -> Self {
        Self::new()
    }
}

impl fmt::Display for Usage {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Usage {
            prompt_tokens,
            total_tokens,
            ..
        } = self;
        write!(
            f,
            "Prompt tokens: {prompt_tokens} Total tokens: {total_tokens}"
        )
    }
}

impl GetTokenUsage for Usage {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();
        usage.input_tokens = self.prompt_tokens as u64;
        usage.output_tokens = (self.total_tokens - self.prompt_tokens) as u64;
        usage.total_tokens = self.total_tokens as u64;
        usage.cached_input_tokens = self
            .prompt_tokens_details
            .as_ref()
            .map(|d| d.cached_tokens as u64)
            .unwrap_or(0);

        Some(usage)
    }
}

#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    pub model: String,
    pub strict_tools: bool,
    pub tool_result_array_content: bool,
}

impl<T> CompletionModel<T>
where
    T: Default + std::fmt::Debug + Clone + 'static,
{
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
            strict_tools: false,
            tool_result_array_content: false,
        }
    }

    pub fn with_model(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.into(),
            strict_tools: false,
            tool_result_array_content: false,
        }
    }

    /// Enables strict mode for all tool definitions sent with completion requests.
    pub fn with_strict_tools(mut self) -> Self {
        self.strict_tools = true;
        self
    }

    /// Sends tool results as an array of content parts instead of a plain string.
    pub fn with_tool_result_array_content(mut self) -> Self {
        self.tool_result_array_content = true;
        self
    }
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CompletionRequest {
    model: String,
    messages: Vec<Message>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(flatten)]
    additional_params: Option<serde_json::Value>,
}

/// Parameters used to convert a core [`CoreCompletionRequest`] into a
/// provider-specific [`CompletionRequest`].
pub struct OpenAIRequestParams {
    pub model: String,
    pub request: CoreCompletionRequest,
    pub strict_tools: bool,
    pub tool_result_array_content: bool,
}

impl TryFrom<OpenAIRequestParams> for CompletionRequest {
    type Error = CompletionError;

    fn try_from(params: OpenAIRequestParams) -> Result<Self, Self::Error> {
        let OpenAIRequestParams {
            model,
            request: req,
            strict_tools,
            tool_result_array_content,
        } = params;

        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        let CoreCompletionRequest {
            model: request_model,
            preamble,
            chat_history,
            tools,
            temperature,
            additional_params,
            tool_choice,
            output_schema,
            ..
        } = req;

        partial_history.extend(chat_history);

        let mut full_history: Vec<Message> =
            preamble.map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);

        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        if full_history.is_empty() {
            return Err(CompletionError::RequestError(
                std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "OpenAI Chat Completions request has no provider-compatible messages after conversion",
                )
                .into(),
            ));
        }

        if tool_result_array_content {
            for msg in &mut full_history {
                if let Message::ToolResult { content, .. } = msg {
                    *content = content.to_array();
                }
            }
        }

        let tool_choice = tool_choice.map(ToolChoice::try_from).transpose()?;

        let tools: Vec<ToolDefinition> = tools
            .into_iter()
            .map(|tool| {
                let def = ToolDefinition::from(tool);
                if strict_tools { def.with_strict() } else { def }
            })
            .collect();

        // When an output schema is provided, request structured output by merging a
        // `response_format` block into the additional parameters.
        let additional_params = if let Some(schema) = output_schema {
            let name = schema
                .as_object()
                .and_then(|o| o.get("title"))
                .and_then(|v| v.as_str())
                .unwrap_or("response_schema")
                .to_string();
            let mut schema_value = schema.to_value();
            super::sanitize_schema(&mut schema_value);
            let response_format = serde_json::json!({
                "response_format": {
                    "type": "json_schema",
                    "json_schema": {
                        "name": name,
                        "strict": true,
                        "schema": schema_value
                    }
                }
            });
            Some(match additional_params {
                Some(existing) => json_utils::merge(existing, response_format),
                None => response_format,
            })
        } else {
            additional_params
        };

        let res = Self {
            model: request_model.unwrap_or(model),
            messages: full_history,
            tools,
            tool_choice,
            temperature,
            additional_params,
        };

        Ok(res)
    }
}

impl TryFrom<(String, CoreCompletionRequest)> for CompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (String, CoreCompletionRequest)) -> Result<Self, Self::Error> {
        CompletionRequest::try_from(OpenAIRequestParams {
            model,
            request: req,
            strict_tools: false,
            tool_result_array_content: false,
        })
    }
}

impl crate::telemetry::ProviderRequestExt for CompletionRequest {
    type InputMessage = Message;

    fn get_input_messages(&self) -> Vec<Self::InputMessage> {
        self.messages.clone()
    }

    fn get_system_prompt(&self) -> Option<String> {
        let first_message = self.messages.first()?;

        let Message::System { ref content, .. } = first_message.clone() else {
            return None;
        };

        let SystemContent { text, .. } = content.first();

        Some(text)
    }

    fn get_prompt(&self) -> Option<String> {
        let last_message = self.messages.last()?;

        let Message::User { ref content, .. } = last_message.clone() else {
            return None;
        };

        let UserContent::Text { text } = content.first() else {
            return None;
        };

        Some(text)
    }

    fn get_model_name(&self) -> String {
        self.model.clone()
    }
}

impl CompletionModel<reqwest::Client> {
    pub fn into_agent_builder(self) -> crate::agent::AgentBuilder<Self> {
        crate::agent::AgentBuilder::new(self)
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt
        + Default
        + std::fmt::Debug
        + Clone
        + WasmCompatSend
        + WasmCompatSync
        + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = super::CompletionsClient<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    async fn completion(
        &self,
        completion_request: CoreCompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Reuse the caller's span if one is active; otherwise create a new `chat`
        // span for this request.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let request = CompletionRequest::try_from(OpenAIRequestParams {
            model: self.model.to_owned(),
            request: completion_request,
            strict_tools: self.strict_tools,
            tool_result_array_content: self.tool_result_array_content,
        })?;

        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "OpenAI Chat Completions completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        async move {
            let response = self.client.send(req).await?;

            if response.status().is_success() {
                let text = http_client::text(response).await?;

                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record_response_metadata(&response);
                        span.record_token_usage(&response.usage);

                        if enabled!(Level::TRACE) {
                            tracing::trace!(
                                target: "rig::completions",
                                "OpenAI Chat Completions completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                let text = http_client::text(response).await?;
                Err(CompletionError::ProviderError(text))
            }
        }
        .instrument(span)
        .await
    }

    async fn stream(
        &self,
        request: CoreCompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        Self::stream(self, request).await
    }
}

/// Serializes an assistant content vector, emitting an empty string when the vector
/// is empty and the full content array otherwise.
fn serialize_assistant_content_vec<S>(
    value: &Vec<AssistantContent>,
    serializer: S,
) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    if value.is_empty() {
        serializer.serialize_str("")
    } else {
        value.serialize(serializer)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_openai_request_uses_request_model_override() {
        let request = crate::completion::CompletionRequest {
            model: Some("gpt-4.1".to_string()),
            preamble: None,
            chat_history: crate::OneOrMany::one("Hello".into()),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let openai_request = CompletionRequest::try_from(OpenAIRequestParams {
            model: "gpt-4o-mini".to_string(),
            request,
            strict_tools: false,
            tool_result_array_content: false,
        })
        .expect("request conversion should succeed");
        let serialized =
            serde_json::to_value(openai_request).expect("serialization should succeed");

        assert_eq!(serialized["model"], "gpt-4.1");
    }
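
    // A hedged example test (not from the original source): a minimal sketch of the
    // token accounting performed by the `GetTokenUsage` impl above, which derives
    // output tokens as `total_tokens - prompt_tokens`.
    #[test]
    fn usage_token_accounting_derives_output_and_cached_tokens() {
        let usage = Usage {
            prompt_tokens: 10,
            total_tokens: 25,
            prompt_tokens_details: Some(PromptTokensDetails { cached_tokens: 4 }),
        };

        let tokens = usage.token_usage().expect("usage should be present");
        assert_eq!(tokens.input_tokens, 10);
        assert_eq!(tokens.output_tokens, 15);
        assert_eq!(tokens.total_tokens, 25);
        assert_eq!(tokens.cached_input_tokens, 4);
    }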

    #[test]
    fn test_openai_request_uses_default_model_when_override_unset() {
        let request = crate::completion::CompletionRequest {
            model: None,
            preamble: None,
            chat_history: crate::OneOrMany::one("Hello".into()),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let openai_request = CompletionRequest::try_from(OpenAIRequestParams {
            model: "gpt-4o-mini".to_string(),
            request,
            strict_tools: false,
            tool_result_array_content: false,
        })
        .expect("request conversion should succeed");
        let serialized =
            serde_json::to_value(openai_request).expect("serialization should succeed");

        assert_eq!(serialized["model"], "gpt-4o-mini");
    }
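
    // A hedged example test (not from the original source): a small sketch of the
    // `serialize_user_content` helper, which should emit a bare string when a user
    // message contains a single text part.
    #[test]
    fn user_message_with_single_text_serializes_content_as_string() {
        let message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let value = serde_json::to_value(&message).expect("serialization should succeed");
        assert_eq!(value["role"], "user");
        assert_eq!(value["content"], "Hello");
    }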

    #[test]
    fn assistant_reasoning_is_silently_skipped() {
        let assistant_content = OneOrMany::one(message::AssistantContent::reasoning("hidden"));

        let converted: Vec<Message> = assistant_content
            .try_into()
            .expect("conversion should work");

        assert!(converted.is_empty());
    }
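
    // A hedged example test (not from the original source): a minimal sketch of the
    // tool-result content helpers, covering the plain-string form, the array form,
    // and the `to_array` conversion.
    #[test]
    fn tool_result_content_value_string_and_array_forms() {
        let as_string = ToolResultContentValue::from_string("done".to_string(), false);
        assert_eq!(as_string.as_text(), "done");

        let as_array = ToolResultContentValue::from_string("done".to_string(), true);
        assert!(matches!(&as_array, ToolResultContentValue::Array(parts) if parts.len() == 1));
        assert_eq!(as_array.as_text(), "done");

        // Converting the string form to the array form preserves the text.
        assert_eq!(as_string.to_array().as_text(), "done");
    }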

    #[test]
    fn assistant_text_and_tool_call_are_preserved_when_reasoning_is_present() {
        let assistant_content = OneOrMany::many(vec![
            message::AssistantContent::reasoning("hidden"),
            message::AssistantContent::text("visible"),
            message::AssistantContent::tool_call(
                "call_1",
                "subtract",
                serde_json::json!({"x": 2, "y": 1}),
            ),
        ])
        .expect("non-empty assistant content");

        let converted: Vec<Message> = assistant_content
            .try_into()
            .expect("conversion should work");
        assert_eq!(converted.len(), 1);

        match &converted[0] {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content,
                    &vec![AssistantContent::Text {
                        text: "visible".to_string()
                    }]
                );
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].id, "call_1");
                assert_eq!(tool_calls[0].function.name, "subtract");
                assert_eq!(
                    tool_calls[0].function.arguments,
                    serde_json::json!({"x": 2, "y": 1})
                );
            }
            _ => panic!("expected assistant message"),
        }
    }
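
    // A hedged example test (not from the original source): a brief sketch of
    // `Message::system` together with the lossy fallback in `TryFrom<Message> for
    // message::Message`, which maps a provider system message onto a core user message.
    #[test]
    fn system_message_constructor_and_core_conversion() {
        let system = Message::system("You are a helpful assistant");

        match &system {
            Message::System { content, .. } => {
                assert_eq!(content.first().text, "You are a helpful assistant");
            }
            other => panic!("expected a system message, got {other:?}"),
        }

        let core: message::Message = system.try_into().expect("conversion should succeed");
        assert!(matches!(core, message::Message::User { .. }));
    }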

    #[test]
    fn request_conversion_errors_when_all_messages_are_filtered() {
        let request = CoreCompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one(message::Message::Assistant {
                id: None,
                content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
            }),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let result = CompletionRequest::try_from(OpenAIRequestParams {
            model: "gpt-4o-mini".to_string(),
            request,
            strict_tools: false,
            tool_result_array_content: false,
        });

        assert!(matches!(result, Err(CompletionError::RequestError(_))));
    }
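
    // A hedged example test (not from the original source): a minimal sketch of
    // `ToolDefinition::with_strict`, assuming the core `completion::ToolDefinition`
    // struct has exactly the `name`, `description`, and `parameters` fields used in
    // the `From` impl above. It only checks the strict flag; `sanitize_schema` may
    // adjust the parameter schema in place.
    #[test]
    fn tool_definition_with_strict_sets_strict_flag() {
        let tool = completion::ToolDefinition {
            name: "subtract".to_string(),
            description: "Subtract y from x".to_string(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "x": { "type": "number" },
                    "y": { "type": "number" }
                },
                "required": ["x", "y"]
            }),
        };

        let strict = ToolDefinition::from(tool).with_strict();
        assert_eq!(strict.function.strict, Some(true));
        assert_eq!(strict.function.name, "subtract");
    }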
}