1use super::{
6 CompletionsClient as Client,
7 client::{ApiErrorResponse, ApiResponse},
8 streaming::StreamingCompletionResponse,
9};
10use crate::completion::{
11 CompletionError, CompletionRequest as CoreCompletionRequest, GetTokenUsage,
12};
13use crate::http_client::{self, HttpClientExt};
14use crate::message::{AudioMediaType, DocumentSourceKind, ImageDetail, MimeType};
15use crate::one_or_many::string_or_one_or_many;
16use crate::telemetry::{ProviderResponseExt, SpanCombinator};
17use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
18use crate::{OneOrMany, completion, json_utils, message};
19use serde::{Deserialize, Serialize, Serializer};
20use std::convert::Infallible;
21use std::fmt;
22use tracing::{Instrument, Level, enabled, info_span};
23
24use std::str::FromStr;
25
26pub mod streaming;
27
28fn serialize_user_content<S>(
31 content: &OneOrMany<UserContent>,
32 serializer: S,
33) -> Result<S::Ok, S::Error>
34where
35 S: Serializer,
36{
37 if content.len() == 1
38 && let UserContent::Text { text } = content.first_ref()
39 {
40 return serializer.serialize_str(text);
41 }
42 content.serialize(serializer)
43}
44
// ---------------------------------------------------------------------------
// Model identifiers accepted by the OpenAI Chat Completions endpoint.
// Pass one of these as the model name when constructing a `CompletionModel`.
// ---------------------------------------------------------------------------

// GPT-5 family.
pub const GPT_5_2: &str = "gpt-5.2";

pub const GPT_5_1: &str = "gpt-5.1";

pub const GPT_5: &str = "gpt-5";
pub const GPT_5_MINI: &str = "gpt-5-mini";
pub const GPT_5_NANO: &str = "gpt-5-nano";

// GPT-4 family (incl. 4.5 preview, 4o, turbo, and vision previews).
pub const GPT_4_5_PREVIEW: &str = "gpt-4.5-preview";
pub const GPT_4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27";
pub const GPT_4O_2024_11_20: &str = "gpt-4o-2024-11-20";
pub const GPT_4O: &str = "gpt-4o";
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
pub const GPT_4: &str = "gpt-4";
pub const GPT_4_0613: &str = "gpt-4-0613";
pub const GPT_4_32K: &str = "gpt-4-32k";
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";

// Reasoning ("o") series.
pub const O4_MINI_2025_04_16: &str = "o4-mini-2025-04-16";
pub const O4_MINI: &str = "o4-mini";
pub const O3: &str = "o3";
pub const O3_MINI: &str = "o3-mini";
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
pub const O1_PRO: &str = "o1-pro";
pub const O1: &str = "o1";
pub const O1_2024_12_17: &str = "o1-2024-12-17";
pub const O1_PREVIEW: &str = "o1-preview";
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
pub const O1_MINI: &str = "o1-mini";
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";

// GPT-4.1 family.
pub const GPT_4_1_MINI: &str = "gpt-4.1-mini";
pub const GPT_4_1_NANO: &str = "gpt-4.1-nano";
pub const GPT_4_1_2025_04_14: &str = "gpt-4.1-2025-04-14";
pub const GPT_4_1: &str = "gpt-4.1";
126
127impl From<ApiErrorResponse> for CompletionError {
128 fn from(err: ApiErrorResponse) -> Self {
129 CompletionError::ProviderError(err.message)
130 }
131}
132
/// A single message in an OpenAI Chat Completions conversation, tagged by
/// its `role` field on the wire.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// System instructions (`role: "system"`; also accepts the newer
    /// `"developer"` role when deserializing).
    #[serde(alias = "developer")]
    System {
        // Accepts either a bare string or an array of content parts.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// End-user input (`role: "user"`).
    User {
        // Deserializes from a string or an array; a lone text part is
        // serialized back to the compact string form.
        #[serde(
            deserialize_with = "string_or_one_or_many",
            serialize_with = "serialize_user_content"
        )]
        content: OneOrMany<UserContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Model output (`role: "assistant"`), possibly carrying tool calls.
    Assistant {
        #[serde(
            default,
            deserialize_with = "json_utils::string_or_vec",
            skip_serializing_if = "Vec::is_empty",
            serialize_with = "serialize_assistant_content_vec"
        )]
        content: Vec<AssistantContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<AudioAssistant>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // A JSON `null` deserializes to an empty vec.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    /// Result of a prior tool call (`role: "tool"`).
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: ToolResultContentValue,
    },
}
179
180impl Message {
181 pub fn system(content: &str) -> Self {
182 Message::System {
183 content: OneOrMany::one(content.to_owned().into()),
184 name: None,
185 }
186 }
187}
188
/// Reference to assistant-generated audio output; only the identifier is
/// carried here.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AudioAssistant {
    pub id: String,
}
193
/// One part of a system message; currently always plain text.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    #[serde(default)]
    pub r#type: SystemContentType,
    pub text: String,
}
200
/// Discriminant for [`SystemContent`]; `text` is the only variant OpenAI
/// defines today.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
207
/// One part of an assistant message: either generated text or a refusal
/// explanation, tagged by `type` on the wire.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
    Refusal { refusal: String },
}
214
215impl From<AssistantContent> for completion::AssistantContent {
216 fn from(value: AssistantContent) -> Self {
217 match value {
218 AssistantContent::Text { text } => completion::AssistantContent::text(text),
219 AssistantContent::Refusal { refusal } => completion::AssistantContent::text(refusal),
220 }
221 }
222}
223
/// One part of a user message, tagged by `type` on the wire: text, an image
/// URL, or inline base64 audio.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text {
        text: String,
    },
    #[serde(rename = "image_url")]
    Image {
        image_url: ImageUrl,
    },
    Audio {
        input_audio: InputAudio,
    },
}
238
/// An image reference for a user message: an `http(s)` or `data:` URL plus
/// the requested processing detail level.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    #[serde(default)]
    pub detail: ImageDetail,
}
245
/// Inline audio for a user message: base64-encoded bytes plus the audio
/// format.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct InputAudio {
    pub data: String,
    pub format: AudioMediaType,
}
251
/// One text part of a tool result message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    #[serde(default)]
    r#type: ToolResultContentType,
    pub text: String,
}
258
/// Discriminant for [`ToolResultContent`]; only `text` is supported.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    #[default]
    Text,
}
265
266impl FromStr for ToolResultContent {
267 type Err = Infallible;
268
269 fn from_str(s: &str) -> Result<Self, Self::Err> {
270 Ok(s.to_owned().into())
271 }
272}
273
274impl From<String> for ToolResultContent {
275 fn from(s: String) -> Self {
276 ToolResultContent {
277 r#type: ToolResultContentType::default(),
278 text: s,
279 }
280 }
281}
282
/// Tool-result `content` field, which providers accept either as an array
/// of typed parts or as a bare string; `untagged` lets serde pick whichever
/// shape matches.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum ToolResultContentValue {
    Array(Vec<ToolResultContent>),
    String(String),
}
289
290impl ToolResultContentValue {
291 pub fn from_string(s: String, use_array_format: bool) -> Self {
292 if use_array_format {
293 ToolResultContentValue::Array(vec![ToolResultContent::from(s)])
294 } else {
295 ToolResultContentValue::String(s)
296 }
297 }
298
299 pub fn as_text(&self) -> String {
300 match self {
301 ToolResultContentValue::Array(arr) => arr
302 .iter()
303 .map(|c| c.text.clone())
304 .collect::<Vec<_>>()
305 .join("\n"),
306 ToolResultContentValue::String(s) => s.clone(),
307 }
308 }
309
310 pub fn to_array(&self) -> Self {
311 match self {
312 ToolResultContentValue::Array(_) => self.clone(),
313 ToolResultContentValue::String(s) => {
314 ToolResultContentValue::Array(vec![ToolResultContent::from(s.clone())])
315 }
316 }
317 }
318}
319
/// A tool invocation emitted by the model inside an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}
327
/// Discriminant for [`ToolCall`]; `function` is the only kind defined here.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
334
/// Declaration of a callable function sent to the model: name, description,
/// JSON-schema parameters, and the optional strict-mode flag.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
344
/// Wire-format tool definition: a type tag (always "function" when built by
/// this module) plus the function declaration.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: FunctionDefinition,
}
350
351impl From<completion::ToolDefinition> for ToolDefinition {
352 fn from(tool: completion::ToolDefinition) -> Self {
353 Self {
354 r#type: "function".into(),
355 function: FunctionDefinition {
356 name: tool.name,
357 description: tool.description,
358 parameters: tool.parameters,
359 strict: None,
360 },
361 }
362 }
363}
364
impl ToolDefinition {
    /// Enables OpenAI strict structured-output mode for this tool, setting
    /// `strict: true` and sanitizing the parameter schema so it meets the
    /// strict-mode requirements.
    pub fn with_strict(mut self) -> Self {
        self.function.strict = Some(true);
        super::sanitize_schema(&mut self.function.parameters);
        self
    }
}
374
/// OpenAI `tool_choice` setting; defaults to letting the model decide.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    None,
    Required,
}
383
384impl TryFrom<crate::message::ToolChoice> for ToolChoice {
385 type Error = CompletionError;
386 fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
387 let res = match value {
388 message::ToolChoice::Specific { .. } => {
389 return Err(CompletionError::ProviderError(
390 "Provider doesn't support only using specific tools".to_string(),
391 ));
392 }
393 message::ToolChoice::Auto => Self::Auto,
394 message::ToolChoice::None => Self::None,
395 message::ToolChoice::Required => Self::Required,
396 };
397
398 Ok(res)
399 }
400}
401
/// The function half of a tool call; `arguments` travels on the wire as a
/// JSON-encoded string and is (de)stringified transparently.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
408
409impl TryFrom<message::ToolResult> for Message {
410 type Error = message::MessageError;
411
412 fn try_from(value: message::ToolResult) -> Result<Self, Self::Error> {
413 let text = value
414 .content
415 .into_iter()
416 .map(|content| {
417 match content {
418 message::ToolResultContent::Text(message::Text { text }) => Ok(text),
419 message::ToolResultContent::Image(_) => Err(message::MessageError::ConversionError(
420 "OpenAI does not support images in tool results. Tool results must be text."
421 .into(),
422 )),
423 }
424 })
425 .collect::<Result<Vec<_>, _>>()?
426 .join("\n");
427
428 Ok(Message::ToolResult {
429 tool_call_id: value.id,
430 content: ToolResultContentValue::String(text),
431 })
432 }
433}
434
impl TryFrom<message::UserContent> for UserContent {
    type Error = message::MessageError;

    /// Maps provider-agnostic user content onto the OpenAI wire format.
    ///
    /// - Text passes through unchanged.
    /// - Images become `image_url` parts; base64 images are wrapped in a
    ///   `data:` URI and require both a media type and a detail level.
    /// - Documents are flattened to plain text (base64 or string bodies).
    /// - Audio must be base64; a missing media type falls back to MP3.
    /// - Tool results, video, raw bytes, and bodiless sources are rejected
    ///   with a `ConversionError`.
    fn try_from(value: message::UserContent) -> Result<Self, Self::Error> {
        match value {
            message::UserContent::Text(message::Text { text }) => Ok(UserContent::Text { text }),
            message::UserContent::Image(message::Image {
                data,
                detail,
                media_type,
                ..
            }) => match data {
                // Plain URL: pass through, defaulting the detail level.
                DocumentSourceKind::Url(url) => Ok(UserContent::Image {
                    image_url: ImageUrl {
                        url,
                        detail: detail.unwrap_or_default(),
                    },
                }),
                // Base64: build a data: URI; media type is mandatory here.
                DocumentSourceKind::Base64(data) => {
                    let url = format!(
                        "data:{};base64,{}",
                        media_type.map(|i| i.to_mime_type()).ok_or(
                            message::MessageError::ConversionError(
                                "OpenAI Image URI must have media type".into()
                            )
                        )?,
                        data
                    );

                    // NOTE: unlike the URL arm, a missing detail is an error
                    // here rather than a default.
                    let detail = detail.ok_or(message::MessageError::ConversionError(
                        "OpenAI image URI must have image detail".into(),
                    ))?;

                    Ok(UserContent::Image {
                        image_url: ImageUrl { url, detail },
                    })
                }
                DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
                    "Raw files not supported, encode as base64 first".into(),
                )),
                DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
                    "Document has no body".into(),
                )),
                // Catch-all for any other source kinds.
                doc => Err(message::MessageError::ConversionError(format!(
                    "Unsupported document type: {doc:?}"
                ))),
            },
            // Documents are sent as inline text, not attachments.
            message::UserContent::Document(message::Document { data, .. }) => {
                if let DocumentSourceKind::Base64(text) | DocumentSourceKind::String(text) = data {
                    Ok(UserContent::Text { text })
                } else {
                    Err(message::MessageError::ConversionError(
                        "Documents must be base64 or a string".into(),
                    ))
                }
            }
            message::UserContent::Audio(message::Audio {
                data, media_type, ..
            }) => match data {
                DocumentSourceKind::Base64(data) => Ok(UserContent::Audio {
                    input_audio: InputAudio {
                        data,
                        // Default to MP3 when no media type was provided.
                        format: match media_type {
                            Some(media_type) => media_type,
                            None => AudioMediaType::MP3,
                        },
                    },
                }),
                DocumentSourceKind::Url(_) => Err(message::MessageError::ConversionError(
                    "URLs are not supported for audio".into(),
                )),
                DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
                    "Raw files are not supported for audio".into(),
                )),
                DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
                    "Audio has no body".into(),
                )),
                audio => Err(message::MessageError::ConversionError(format!(
                    "Unsupported audio type: {audio:?}"
                ))),
            },
            // Tool results are handled at the message level, not as inline
            // user content (see TryFrom<OneOrMany<message::UserContent>>).
            message::UserContent::ToolResult(_) => Err(message::MessageError::ConversionError(
                "Tool result is in unsupported format".into(),
            )),
            message::UserContent::Video(_) => Err(message::MessageError::ConversionError(
                "Video is in unsupported format".into(),
            )),
        }
    }
}
525
impl TryFrom<OneOrMany<message::UserContent>> for Vec<Message> {
    type Error = message::MessageError;

    /// Splits mixed user content into wire messages: tool results become one
    /// `role: "tool"` message each; everything else is bundled into a single
    /// user message.
    ///
    /// NOTE(review): when tool results are present, any non-tool-result
    /// content in the same bundle is silently discarded — confirm this is
    /// the intended precedence.
    fn try_from(value: OneOrMany<message::UserContent>) -> Result<Self, Self::Error> {
        let (tool_results, other_content): (Vec<_>, Vec<_>) = value
            .into_iter()
            .partition(|content| matches!(content, message::UserContent::ToolResult(_)));

        if !tool_results.is_empty() {
            // One tool message per tool result.
            tool_results
                .into_iter()
                .map(|content| match content {
                    message::UserContent::ToolResult(tool_result) => tool_result.try_into(),
                    // The partition predicate guarantees only ToolResult here.
                    _ => unreachable!(),
                })
                .collect::<Result<Vec<_>, _>>()
        } else {
            let other_content: Vec<UserContent> = other_content
                .into_iter()
                .map(|content| content.try_into())
                .collect::<Result<Vec<_>, _>>()?;

            // Safe: the source OneOrMany is non-empty and no items were
            // diverted into tool_results in this branch.
            let other_content = OneOrMany::many(other_content)
                .expect("There must be other content here if there were no tool result content");

            Ok(vec![Message::User {
                content: other_content,
                name: None,
            }])
        }
    }
}
560
561impl TryFrom<OneOrMany<message::AssistantContent>> for Vec<Message> {
562 type Error = message::MessageError;
563
564 fn try_from(value: OneOrMany<message::AssistantContent>) -> Result<Self, Self::Error> {
565 let mut text_content = Vec::new();
566 let mut tool_calls = Vec::new();
567
568 for content in value {
569 match content {
570 message::AssistantContent::Text(text) => text_content.push(text),
571 message::AssistantContent::ToolCall(tool_call) => tool_calls.push(tool_call),
572 message::AssistantContent::Reasoning(_) => {
573 }
576 message::AssistantContent::Image(_) => {
577 panic!(
578 "The OpenAI Completions API doesn't support image content in assistant messages!"
579 );
580 }
581 }
582 }
583
584 if text_content.is_empty() && tool_calls.is_empty() {
585 return Ok(vec![]);
586 }
587
588 Ok(vec![Message::Assistant {
589 content: text_content
590 .into_iter()
591 .map(|content| content.text.into())
592 .collect::<Vec<_>>(),
593 refusal: None,
594 audio: None,
595 name: None,
596 tool_calls: tool_calls
597 .into_iter()
598 .map(|tool_call| tool_call.into())
599 .collect::<Vec<_>>(),
600 }])
601 }
602}
603
604impl TryFrom<message::Message> for Vec<Message> {
605 type Error = message::MessageError;
606
607 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
608 match message {
609 message::Message::System { content } => Ok(vec![Message::system(&content)]),
610 message::Message::User { content } => content.try_into(),
611 message::Message::Assistant { content, .. } => content.try_into(),
612 }
613 }
614}
615
616impl From<message::ToolCall> for ToolCall {
617 fn from(tool_call: message::ToolCall) -> Self {
618 Self {
619 id: tool_call.id,
620 r#type: ToolType::default(),
621 function: Function {
622 name: tool_call.function.name,
623 arguments: tool_call.function.arguments,
624 },
625 }
626 }
627}
628
629impl From<ToolCall> for message::ToolCall {
630 fn from(tool_call: ToolCall) -> Self {
631 Self {
632 id: tool_call.id,
633 call_id: None,
634 function: message::ToolFunction {
635 name: tool_call.function.name,
636 arguments: tool_call.function.arguments,
637 },
638 signature: None,
639 additional_params: None,
640 }
641 }
642}
643
impl TryFrom<Message> for message::Message {
    type Error = message::MessageError;

    /// Converts a wire message back into the internal representation.
    ///
    /// Tool-result and system messages are both folded into user messages,
    /// since the internal model routes them through the user role.
    ///
    /// # Errors
    /// Returns `ConversionError` for an assistant message with neither
    /// content nor tool calls.
    fn try_from(message: Message) -> Result<Self, Self::Error> {
        Ok(match message {
            Message::User { content, .. } => message::Message::User {
                content: content.map(|content| content.into()),
            },
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                // Text and refusal segments both become plain text.
                let mut content = content
                    .into_iter()
                    .map(|content| match content {
                        AssistantContent::Text { text } => message::AssistantContent::text(text),

                        AssistantContent::Refusal { refusal } => {
                            message::AssistantContent::text(refusal)
                        }
                    })
                    .collect::<Vec<_>>();

                // Tool calls are appended after the text segments.
                content.extend(
                    tool_calls
                        .into_iter()
                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
                        .collect::<Result<Vec<_>, _>>()?,
                );

                message::Message::Assistant {
                    id: None,
                    // many() fails only on an empty vec, i.e. no content and
                    // no tool calls.
                    content: OneOrMany::many(content).map_err(|_| {
                        message::MessageError::ConversionError(
                            "Neither `content` nor `tool_calls` was provided to the Message"
                                .to_owned(),
                        )
                    })?,
                }
            }

            Message::ToolResult {
                tool_call_id,
                content,
            } => message::Message::User {
                content: OneOrMany::one(message::UserContent::tool_result(
                    tool_call_id,
                    OneOrMany::one(message::ToolResultContent::text(content.as_text())),
                )),
            },

            // System messages round-trip as user text.
            Message::System { content, .. } => message::Message::User {
                content: content.map(|content| message::UserContent::text(content.text)),
            },
        })
    }
}
706
707impl From<UserContent> for message::UserContent {
708 fn from(content: UserContent) -> Self {
709 match content {
710 UserContent::Text { text } => message::UserContent::text(text),
711 UserContent::Image { image_url } => {
712 message::UserContent::image_url(image_url.url, None, Some(image_url.detail))
713 }
714 UserContent::Audio { input_audio } => {
715 message::UserContent::audio(input_audio.data, Some(input_audio.format))
716 }
717 }
718 }
719}
720
721impl From<String> for UserContent {
722 fn from(s: String) -> Self {
723 UserContent::Text { text: s }
724 }
725}
726
727impl FromStr for UserContent {
728 type Err = Infallible;
729
730 fn from_str(s: &str) -> Result<Self, Self::Err> {
731 Ok(UserContent::Text {
732 text: s.to_string(),
733 })
734 }
735}
736
737impl From<String> for AssistantContent {
738 fn from(s: String) -> Self {
739 AssistantContent::Text { text: s }
740 }
741}
742
743impl FromStr for AssistantContent {
744 type Err = Infallible;
745
746 fn from_str(s: &str) -> Result<Self, Self::Err> {
747 Ok(AssistantContent::Text {
748 text: s.to_string(),
749 })
750 }
751}
752impl From<String> for SystemContent {
753 fn from(s: String) -> Self {
754 SystemContent {
755 r#type: SystemContentType::default(),
756 text: s,
757 }
758 }
759}
760
761impl FromStr for SystemContent {
762 type Err = Infallible;
763
764 fn from_str(s: &str) -> Result<Self, Self::Err> {
765 Ok(SystemContent {
766 r#type: SystemContentType::default(),
767 text: s.to_string(),
768 })
769 }
770}
771
/// Raw (non-streaming) response payload from the `/chat/completions`
/// endpoint.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    // Object type tag reported by the API.
    pub object: String,
    // Unix creation timestamp.
    pub created: u64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    pub choices: Vec<Choice>,
    pub usage: Option<Usage>,
}
782
783impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
784 type Error = CompletionError;
785
786 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
787 let choice = response.choices.first().ok_or_else(|| {
788 CompletionError::ResponseError("Response contained no choices".to_owned())
789 })?;
790
791 let content = match &choice.message {
792 Message::Assistant {
793 content,
794 tool_calls,
795 ..
796 } => {
797 let mut content = content
798 .iter()
799 .filter_map(|c| {
800 let s = match c {
801 AssistantContent::Text { text } => text,
802 AssistantContent::Refusal { refusal } => refusal,
803 };
804 if s.is_empty() {
805 None
806 } else {
807 Some(completion::AssistantContent::text(s))
808 }
809 })
810 .collect::<Vec<_>>();
811
812 content.extend(
813 tool_calls
814 .iter()
815 .map(|call| {
816 completion::AssistantContent::tool_call(
817 &call.id,
818 &call.function.name,
819 call.function.arguments.clone(),
820 )
821 })
822 .collect::<Vec<_>>(),
823 );
824 Ok(content)
825 }
826 _ => Err(CompletionError::ResponseError(
827 "Response did not contain a valid message or tool call".into(),
828 )),
829 }?;
830
831 let choice = OneOrMany::many(content).map_err(|_| {
832 CompletionError::ResponseError(
833 "Response contained no message or tool call (empty)".to_owned(),
834 )
835 })?;
836
837 let usage = response
838 .usage
839 .as_ref()
840 .map(|usage| completion::Usage {
841 input_tokens: usage.prompt_tokens as u64,
842 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
843 total_tokens: usage.total_tokens as u64,
844 cached_input_tokens: usage
845 .prompt_tokens_details
846 .as_ref()
847 .map(|d| d.cached_tokens as u64)
848 .unwrap_or(0),
849 cache_creation_input_tokens: 0,
850 })
851 .unwrap_or_default();
852
853 Ok(completion::CompletionResponse {
854 choice,
855 usage,
856 raw_response: response,
857 message_id: None,
858 })
859 }
860}
861
862impl ProviderResponseExt for CompletionResponse {
863 type OutputMessage = Choice;
864 type Usage = Usage;
865
866 fn get_response_id(&self) -> Option<String> {
867 Some(self.id.to_owned())
868 }
869
870 fn get_response_model_name(&self) -> Option<String> {
871 Some(self.model.to_owned())
872 }
873
874 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
875 self.choices.clone()
876 }
877
878 fn get_text_response(&self) -> Option<String> {
879 let Message::User { ref content, .. } = self.choices.last()?.message.clone() else {
880 return None;
881 };
882
883 let UserContent::Text { text } = content.first() else {
884 return None;
885 };
886
887 Some(text)
888 }
889
890 fn get_usage(&self) -> Option<Self::Usage> {
891 self.usage.clone()
892 }
893}
894
/// One candidate completion within a response.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    // Opaque log-probability payload; kept as raw JSON.
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}
902
/// Breakdown of prompt-token accounting; currently only cache hits.
#[derive(Clone, Debug, Deserialize, Serialize, Default)]
pub struct PromptTokensDetails {
    // Tokens served from OpenAI's prompt cache; 0 when absent.
    #[serde(default)]
    pub cached_tokens: usize,
}
909
/// Token usage reported by the API.
///
/// NOTE(review): the API's `completion_tokens` field is not captured here;
/// output tokens are derived elsewhere as `total - prompt`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
917
918impl Usage {
919 pub fn new() -> Self {
920 Self {
921 prompt_tokens: 0,
922 total_tokens: 0,
923 prompt_tokens_details: None,
924 }
925 }
926}
927
928impl Default for Usage {
929 fn default() -> Self {
930 Self::new()
931 }
932}
933
934impl fmt::Display for Usage {
935 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
936 let Usage {
937 prompt_tokens,
938 total_tokens,
939 ..
940 } = self;
941 write!(
942 f,
943 "Prompt tokens: {prompt_tokens} Total tokens: {total_tokens}"
944 )
945 }
946}
947
948impl GetTokenUsage for Usage {
949 fn token_usage(&self) -> Option<crate::completion::Usage> {
950 let mut usage = crate::completion::Usage::new();
951 usage.input_tokens = self.prompt_tokens as u64;
952 usage.output_tokens = (self.total_tokens - self.prompt_tokens) as u64;
953 usage.total_tokens = self.total_tokens as u64;
954 usage.cached_input_tokens = self
955 .prompt_tokens_details
956 .as_ref()
957 .map(|d| d.cached_tokens as u64)
958 .unwrap_or(0);
959
960 Some(usage)
961 }
962}
963
/// An OpenAI Chat Completions model handle bound to a client.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    // Model identifier sent in the `model` field of each request.
    pub model: String,
    // When true, tool definitions are sent with `strict: true` and
    // sanitized schemas.
    pub strict_tools: bool,
    // When true, tool results are serialized as arrays of content parts
    // instead of bare strings.
    pub tool_result_array_content: bool,
}
971
972impl<T> CompletionModel<T>
973where
974 T: Default + std::fmt::Debug + Clone + 'static,
975{
976 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
977 Self {
978 client,
979 model: model.into(),
980 strict_tools: false,
981 tool_result_array_content: false,
982 }
983 }
984
985 pub fn with_model(client: Client<T>, model: &str) -> Self {
986 Self {
987 client,
988 model: model.into(),
989 strict_tools: false,
990 tool_result_array_content: false,
991 }
992 }
993
994 pub fn with_strict_tools(mut self) -> Self {
1003 self.strict_tools = true;
1004 self
1005 }
1006
1007 pub fn with_tool_result_array_content(mut self) -> Self {
1008 self.tool_result_array_content = true;
1009 self
1010 }
1011}
1012
/// Serialized request body for the `/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CompletionRequest {
    model: String,
    messages: Vec<Message>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    // Extra provider parameters (e.g. response_format) merged into the
    // top-level JSON object.
    #[serde(flatten)]
    additional_params: Option<serde_json::Value>,
}
1028
/// Inputs for building a [`CompletionRequest`]: the default model name, the
/// core request, and the per-model serialization flags.
pub struct OpenAIRequestParams {
    pub model: String,
    pub request: CoreCompletionRequest,
    pub strict_tools: bool,
    pub tool_result_array_content: bool,
}
1035
impl TryFrom<OpenAIRequestParams> for CompletionRequest {
    type Error = CompletionError;

    /// Assembles the wire request: preamble becomes a leading system
    /// message, documents and chat history are converted and flattened,
    /// tool results are optionally rewritten to array form, tools are
    /// optionally made strict, and an `output_schema` is folded into
    /// `additional_params` as a `response_format` block.
    ///
    /// # Errors
    /// Fails when message or tool-choice conversion fails, or when the
    /// converted history ends up empty.
    fn try_from(params: OpenAIRequestParams) -> Result<Self, Self::Error> {
        let OpenAIRequestParams {
            model,
            request: req,
            strict_tools,
            tool_result_array_content,
        } = params;

        // Normalized documents (if any) are injected ahead of the chat
        // history.
        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        let CoreCompletionRequest {
            model: request_model,
            preamble,
            chat_history,
            tools,
            temperature,
            max_tokens,
            additional_params,
            tool_choice,
            output_schema,
            ..
        } = req;

        partial_history.extend(chat_history);

        // The preamble, when present, becomes the leading system message.
        let mut full_history: Vec<Message> =
            preamble.map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);

        // Each core message may expand into several wire messages; flatten.
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        if full_history.is_empty() {
            return Err(CompletionError::RequestError(
                std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "OpenAI Chat Completions request has no provider-compatible messages after conversion",
                )
                .into(),
            ));
        }

        // Rewrite tool results to array form for providers that require it.
        if tool_result_array_content {
            for msg in &mut full_history {
                if let Message::ToolResult { content, .. } = msg {
                    *content = content.to_array();
                }
            }
        }

        let tool_choice = tool_choice.map(ToolChoice::try_from).transpose()?;

        let tools: Vec<ToolDefinition> = tools
            .into_iter()
            .map(|tool| {
                let def = ToolDefinition::from(tool);
                if strict_tools { def.with_strict() } else { def }
            })
            .collect();

        // Fold an output schema into additional_params as a strict
        // json_schema response_format, using the schema's `title` as the
        // format name when available.
        let additional_params = if let Some(schema) = output_schema {
            let name = schema
                .as_object()
                .and_then(|o| o.get("title"))
                .and_then(|v| v.as_str())
                .unwrap_or("response_schema")
                .to_string();
            let mut schema_value = schema.to_value();
            super::sanitize_schema(&mut schema_value);
            let response_format = serde_json::json!({
                "response_format": {
                    "type": "json_schema",
                    "json_schema": {
                        "name": name,
                        "strict": true,
                        "schema": schema_value
                    }
                }
            });
            Some(match additional_params {
                Some(existing) => json_utils::merge(existing, response_format),
                None => response_format,
            })
        } else {
            additional_params
        };

        let res = Self {
            // A per-request model override wins over the handle's default.
            model: request_model.unwrap_or(model),
            messages: full_history,
            tools,
            tool_choice,
            temperature,
            max_tokens,
            additional_params,
        };

        Ok(res)
    }
}
1148
1149impl TryFrom<(String, CoreCompletionRequest)> for CompletionRequest {
1150 type Error = CompletionError;
1151
1152 fn try_from((model, req): (String, CoreCompletionRequest)) -> Result<Self, Self::Error> {
1153 CompletionRequest::try_from(OpenAIRequestParams {
1154 model,
1155 request: req,
1156 strict_tools: false,
1157 tool_result_array_content: false,
1158 })
1159 }
1160}
1161
1162impl crate::telemetry::ProviderRequestExt for CompletionRequest {
1163 type InputMessage = Message;
1164
1165 fn get_input_messages(&self) -> Vec<Self::InputMessage> {
1166 self.messages.clone()
1167 }
1168
1169 fn get_system_prompt(&self) -> Option<String> {
1170 let first_message = self.messages.first()?;
1171
1172 let Message::System { ref content, .. } = first_message.clone() else {
1173 return None;
1174 };
1175
1176 let SystemContent { text, .. } = content.first();
1177
1178 Some(text)
1179 }
1180
1181 fn get_prompt(&self) -> Option<String> {
1182 let last_message = self.messages.last()?;
1183
1184 let Message::User { ref content, .. } = last_message.clone() else {
1185 return None;
1186 };
1187
1188 let UserContent::Text { text } = content.first() else {
1189 return None;
1190 };
1191
1192 Some(text)
1193 }
1194
1195 fn get_model_name(&self) -> String {
1196 self.model.clone()
1197 }
1198}
1199
impl CompletionModel<reqwest::Client> {
    /// Wraps this model in an agent builder for fluent agent construction.
    pub fn into_agent_builder(self) -> crate::agent::AgentBuilder<Self> {
        crate::agent::AgentBuilder::new(self)
    }
}
1205
/// OpenAI Chat Completions implementation of the core
/// [`completion::CompletionModel`] trait, generic over the HTTP client `T`.
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt
        + Default
        + std::fmt::Debug
        + Clone
        + WasmCompatSend
        + WasmCompatSync
        + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = super::CompletionsClient<T>;

    /// Builds a model handle from a client and a default model name.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Sends a non-streaming chat completion request to `/chat/completions`.
    ///
    /// Converts the provider-agnostic `completion_request` into the OpenAI
    /// wire format, POSTs it, and converts the reply into the core response.
    ///
    /// # Errors
    /// Returns [`CompletionError`] on JSON (de)serialization failures, HTTP
    /// transport errors, API-level error payloads, or any non-success HTTP
    /// status (whose raw body is forwarded as a `ProviderError`).
    async fn completion(
        &self,
        completion_request: CoreCompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Reuse the caller's active span if one exists; otherwise open a fresh
        // GenAI-style "chat" span. The `Empty` fields are recorded later, once
        // the response arrives.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        // Translate to the OpenAI request shape, applying this model's default
        // name and its tool-related configuration.
        let request = CompletionRequest::try_from(OpenAIRequestParams {
            model: self.model.to_owned(),
            request: completion_request,
            strict_tools: self.strict_tools,
            tool_result_array_content: self.tool_result_array_content,
        })?;

        // Pretty-printing the request is gated so the serialization cost is
        // only paid when TRACE logging is actually enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "OpenAI Chat Completions completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        // The whole round-trip runs inside the telemetry span (see
        // `.instrument(span)` below), so `Span::current()` inside refers to it.
        async move {
            let response = self.client.send(req).await?;

            if response.status().is_success() {
                let text = http_client::text(response).await?;

                // A success status can still carry an API-level error payload,
                // so the body is parsed as an Ok/Err envelope.
                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
                    ApiResponse::Ok(response) => {
                        // Fill in the span fields left `Empty` above.
                        let span = tracing::Span::current();
                        span.record_response_metadata(&response);
                        span.record_token_usage(&response.usage);

                        if enabled!(Level::TRACE) {
                            tracing::trace!(
                                target: "rig::completions",
                                "OpenAI Chat Completions completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                // Non-success status: surface the raw body as a provider error.
                let text = http_client::text(response).await?;
                Err(CompletionError::ProviderError(text))
            }
        }
        .instrument(span)
        .await
    }

    /// Streams a chat completion; delegates to the inherent `stream` method
    /// on `CompletionModel` (fully-qualified to avoid recursing into itself).
    async fn stream(
        &self,
        request: CoreCompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        Self::stream(self, request).await
    }
}
1313
1314fn serialize_assistant_content_vec<S>(
1315 value: &Vec<AssistantContent>,
1316 serializer: S,
1317) -> Result<S::Ok, S::Error>
1318where
1319 S: Serializer,
1320{
1321 if value.is_empty() {
1322 serializer.serialize_str("")
1323 } else {
1324 value.serialize(serializer)
1325 }
1326}
1327
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal core request holding a single "Hello" user message.
    ///
    /// `model` and `max_tokens` are the only fields the tests below vary;
    /// everything else is empty/`None`. Centralizing the literal removes the
    /// five near-identical copies the tests previously carried.
    fn base_request(model: Option<String>, max_tokens: Option<u64>) -> CoreCompletionRequest {
        CoreCompletionRequest {
            model,
            preamble: None,
            chat_history: OneOrMany::one("Hello".into()),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        }
    }

    /// Converts `request` (with `default_model` as the client-level default)
    /// into the OpenAI request shape and serializes it to JSON, panicking on
    /// any failure so tests read as plain assertions.
    fn to_openai_json(default_model: &str, request: CoreCompletionRequest) -> serde_json::Value {
        let openai_request = CompletionRequest::try_from(OpenAIRequestParams {
            model: default_model.to_string(),
            request,
            strict_tools: false,
            tool_result_array_content: false,
        })
        .expect("request conversion should succeed");
        serde_json::to_value(openai_request).expect("serialization should succeed")
    }

    #[test]
    fn test_openai_request_uses_request_model_override() {
        let serialized =
            to_openai_json("gpt-4o-mini", base_request(Some("gpt-4.1".to_string()), None));

        // Per-request model override wins over the client default.
        assert_eq!(serialized["model"], "gpt-4.1");
    }

    #[test]
    fn test_openai_request_uses_default_model_when_override_unset() {
        let serialized = to_openai_json("gpt-4o-mini", base_request(None, None));

        assert_eq!(serialized["model"], "gpt-4o-mini");
    }

    #[test]
    fn assistant_reasoning_is_silently_skipped() {
        let assistant_content = OneOrMany::one(message::AssistantContent::reasoning("hidden"));

        let converted: Vec<Message> = assistant_content
            .try_into()
            .expect("conversion should work");

        // Reasoning-only content produces no OpenAI messages at all.
        assert!(converted.is_empty());
    }

    #[test]
    fn assistant_text_and_tool_call_are_preserved_when_reasoning_is_present() {
        let assistant_content = OneOrMany::many(vec![
            message::AssistantContent::reasoning("hidden"),
            message::AssistantContent::text("visible"),
            message::AssistantContent::tool_call(
                "call_1",
                "subtract",
                serde_json::json!({"x": 2, "y": 1}),
            ),
        ])
        .expect("non-empty assistant content");

        let converted: Vec<Message> = assistant_content
            .try_into()
            .expect("conversion should work");
        assert_eq!(converted.len(), 1);

        // Reasoning is dropped, but text and tool calls survive intact.
        match &converted[0] {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content,
                    &vec![AssistantContent::Text {
                        text: "visible".to_string()
                    }]
                );
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].id, "call_1");
                assert_eq!(tool_calls[0].function.name, "subtract");
                assert_eq!(
                    tool_calls[0].function.arguments,
                    serde_json::json!({"x": 2, "y": 1})
                );
            }
            _ => panic!("expected assistant message"),
        }
    }

    #[test]
    fn test_max_tokens_is_forwarded_to_request() {
        let serialized = to_openai_json("gpt-4o-mini", base_request(None, Some(4096)));

        assert_eq!(serialized["max_tokens"], 4096);
    }

    #[test]
    fn test_max_tokens_omitted_when_none() {
        let serialized = to_openai_json("gpt-4o-mini", base_request(None, None));

        // `None` must be omitted from the payload, not serialized as null.
        assert!(serialized.get("max_tokens").is_none());
    }

    #[test]
    fn request_conversion_errors_when_all_messages_are_filtered() {
        // A history consisting solely of reasoning content is filtered away
        // entirely; that must surface as a request error instead of sending
        // an empty message list to the API.
        let request = CoreCompletionRequest {
            chat_history: OneOrMany::one(message::Message::Assistant {
                id: None,
                content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
            }),
            ..base_request(None, None)
        };

        let result = CompletionRequest::try_from(OpenAIRequestParams {
            model: "gpt-4o-mini".to_string(),
            request,
            strict_tools: false,
            tool_result_array_content: false,
        });

        assert!(matches!(result, Err(CompletionError::RequestError(_))));
    }
}