1use super::{
6 CompletionsClient as Client,
7 client::{ApiErrorResponse, ApiResponse},
8 streaming::StreamingCompletionResponse,
9};
10use crate::completion::{
11 CompletionError, CompletionRequest as CoreCompletionRequest, GetTokenUsage,
12};
13use crate::http_client::{self, HttpClientExt};
14use crate::message::{AudioMediaType, DocumentSourceKind, ImageDetail, MimeType};
15use crate::one_or_many::string_or_one_or_many;
16use crate::telemetry::{ProviderResponseExt, SpanCombinator};
17use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
18use crate::{OneOrMany, completion, json_utils, message};
19use serde::{Deserialize, Serialize, Serializer};
20use std::convert::Infallible;
21use std::fmt;
22use tracing::{Instrument, Level, enabled, info_span};
23
24use std::str::FromStr;
25
26pub mod streaming;
27
28fn serialize_user_content<S>(
31 content: &OneOrMany<UserContent>,
32 serializer: S,
33) -> Result<S::Ok, S::Error>
34where
35 S: Serializer,
36{
37 if content.len() == 1
38 && let UserContent::Text { text } = content.first_ref()
39 {
40 return serializer.serialize_str(text);
41 }
42 content.serialize(serializer)
43}
44
// Model identifiers accepted by the OpenAI Chat Completions API.
// Grouped by family; values are the exact strings sent in the `model` field.

// GPT-5 family
pub const GPT_5_2: &str = "gpt-5.2";

pub const GPT_5_1: &str = "gpt-5.1";

pub const GPT_5: &str = "gpt-5";
pub const GPT_5_MINI: &str = "gpt-5-mini";
pub const GPT_5_NANO: &str = "gpt-5-nano";

// GPT-4 family (including 4o and dated/preview snapshots)
pub const GPT_4_5_PREVIEW: &str = "gpt-4.5-preview";
pub const GPT_4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27";
pub const GPT_4O_2024_11_20: &str = "gpt-4o-2024-11-20";
pub const GPT_4O: &str = "gpt-4o";
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
pub const GPT_4: &str = "gpt-4";
pub const GPT_4_0613: &str = "gpt-4-0613";
pub const GPT_4_32K: &str = "gpt-4-32k";
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";

// Reasoning ("o"-series) models
pub const O4_MINI_2025_04_16: &str = "o4-mini-2025-04-16";
pub const O4_MINI: &str = "o4-mini";
pub const O3: &str = "o3";
pub const O3_MINI: &str = "o3-mini";
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
pub const O1_PRO: &str = "o1-pro";
pub const O1: &str = "o1";
pub const O1_2024_12_17: &str = "o1-2024-12-17";
pub const O1_PREVIEW: &str = "o1-preview";
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
pub const O1_MINI: &str = "o1-mini";
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";

// GPT-4.1 family
pub const GPT_4_1_MINI: &str = "gpt-4.1-mini";
pub const GPT_4_1_NANO: &str = "gpt-4.1-nano";
pub const GPT_4_1_2025_04_14: &str = "gpt-4.1-2025-04-14";
pub const GPT_4_1: &str = "gpt-4.1";
126
127impl From<ApiErrorResponse> for CompletionError {
128 fn from(err: ApiErrorResponse) -> Self {
129 CompletionError::ProviderError(err.message)
130 }
131}
132
/// A Chat Completions message in OpenAI wire format, tagged by `role`.
///
/// Serde attributes mirror the API's flexible input shapes (bare strings or
/// content-part arrays) and its compact output shapes.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// System instructions; also accepts the newer `developer` role on input.
    #[serde(alias = "developer")]
    System {
        // Accepts either a bare string or an array of content parts.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// End-user input: text, images, or audio.
    User {
        // Deserializes from string-or-array; re-serializes a lone text part
        // back to the compact bare-string form.
        #[serde(
            deserialize_with = "string_or_one_or_many",
            serialize_with = "serialize_user_content"
        )]
        content: OneOrMany<UserContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Model output: text/refusal parts plus any requested tool calls.
    Assistant {
        #[serde(
            default,
            deserialize_with = "json_utils::string_or_vec",
            skip_serializing_if = "Vec::is_empty",
            serialize_with = "serialize_assistant_content_vec"
        )]
        content: Vec<AssistantContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<AudioAssistant>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // The API may send `null` instead of omitting the field; both map to
        // an empty vec.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    /// Result of a tool invocation, correlated to a call via `tool_call_id`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: ToolResultContentValue,
    },
}
179
180impl Message {
181 pub fn system(content: &str) -> Self {
182 Message::System {
183 content: OneOrMany::one(content.to_owned().into()),
184 name: None,
185 }
186 }
187}
188
/// Reference to a previously generated assistant audio output, carried by id.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AudioAssistant {
    pub id: String,
}
193
/// A single content part of a system message (currently always text).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    // Defaults to `text` when the API omits the discriminator.
    #[serde(default)]
    pub r#type: SystemContentType,
    pub text: String,
}
200
/// Discriminator for [`SystemContent`]; only `text` exists today.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
207
/// A content part of an assistant message: generated text, or a refusal
/// explaining why the model declined to answer.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
    Refusal { refusal: String },
}
214
215impl From<AssistantContent> for completion::AssistantContent {
216 fn from(value: AssistantContent) -> Self {
217 match value {
218 AssistantContent::Text { text } => completion::AssistantContent::text(text),
219 AssistantContent::Refusal { refusal } => completion::AssistantContent::text(refusal),
220 }
221 }
222}
223
/// A content part of a user message: text, an image reference, or inline audio.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text {
        text: String,
    },
    // Wire tag is `image_url`, matching the API's content-part type name.
    #[serde(rename = "image_url")]
    Image {
        image_url: ImageUrl,
    },
    Audio {
        input_audio: InputAudio,
    },
}
238
/// Image reference for a user message: an http(s) URL or a `data:` URI,
/// plus the requested detail level.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    #[serde(default)]
    pub detail: ImageDetail,
}
245
/// Inline audio for a user message: base64 payload plus its media format.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct InputAudio {
    // Base64-encoded audio bytes (see the `TryFrom<message::UserContent>` impl).
    pub data: String,
    pub format: AudioMediaType,
}
251
/// A single content part of a tool-result message (always text today).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    #[serde(default)]
    r#type: ToolResultContentType,
    pub text: String,
}
258
/// Discriminator for [`ToolResultContent`]; only `text` exists today.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    #[default]
    Text,
}
265
266impl FromStr for ToolResultContent {
267 type Err = Infallible;
268
269 fn from_str(s: &str) -> Result<Self, Self::Err> {
270 Ok(s.to_owned().into())
271 }
272}
273
274impl From<String> for ToolResultContent {
275 fn from(s: String) -> Self {
276 ToolResultContent {
277 r#type: ToolResultContentType::default(),
278 text: s,
279 }
280 }
281}
282
/// Tool-result content in either wire shape the API accepts: an array of
/// content parts, or a bare string. Untagged so serde picks by JSON shape.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum ToolResultContentValue {
    Array(Vec<ToolResultContent>),
    String(String),
}
289
290impl ToolResultContentValue {
291 pub fn from_string(s: String, use_array_format: bool) -> Self {
292 if use_array_format {
293 ToolResultContentValue::Array(vec![ToolResultContent::from(s)])
294 } else {
295 ToolResultContentValue::String(s)
296 }
297 }
298
299 pub fn as_text(&self) -> String {
300 match self {
301 ToolResultContentValue::Array(arr) => arr
302 .iter()
303 .map(|c| c.text.clone())
304 .collect::<Vec<_>>()
305 .join("\n"),
306 ToolResultContentValue::String(s) => s.clone(),
307 }
308 }
309
310 pub fn to_array(&self) -> Self {
311 match self {
312 ToolResultContentValue::Array(_) => self.clone(),
313 ToolResultContentValue::String(s) => {
314 ToolResultContentValue::Array(vec![ToolResultContent::from(s.clone())])
315 }
316 }
317 }
318}
319
/// A tool invocation requested by the model.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    // Correlates the call with its later tool-result message.
    pub id: String,
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}
327
/// Kind of tool call; only `function` is supported by this API surface.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
334
/// Definition of a callable function tool sent with a request.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: String,
    // JSON Schema for the function's arguments.
    pub parameters: serde_json::Value,
    // `Some(true)` enables OpenAI strict schema mode (see `with_strict`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
344
/// Provider-shaped tool definition: a type tag (always "function" in
/// practice) wrapping a [`FunctionDefinition`].
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: FunctionDefinition,
}
350
351impl From<completion::ToolDefinition> for ToolDefinition {
352 fn from(tool: completion::ToolDefinition) -> Self {
353 Self {
354 r#type: "function".into(),
355 function: FunctionDefinition {
356 name: tool.name,
357 description: tool.description,
358 parameters: tool.parameters,
359 strict: None,
360 },
361 }
362 }
363}
364
impl ToolDefinition {
    /// Marks this tool for OpenAI strict mode and sanitizes its parameter
    /// schema so it satisfies strict-mode constraints.
    pub fn with_strict(mut self) -> Self {
        self.function.strict = Some(true);
        super::sanitize_schema(&mut self.function.parameters);
        self
    }
}
374
/// Provider tool-choice policy. Unlike the core type, this API variant has
/// no way to force one specific tool (see the `TryFrom` below).
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    None,
    Required,
}
383
384impl TryFrom<crate::message::ToolChoice> for ToolChoice {
385 type Error = CompletionError;
386 fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
387 let res = match value {
388 message::ToolChoice::Specific { .. } => {
389 return Err(CompletionError::ProviderError(
390 "Provider doesn't support only using specific tools".to_string(),
391 ));
392 }
393 message::ToolChoice::Auto => Self::Auto,
394 message::ToolChoice::None => Self::None,
395 message::ToolChoice::Required => Self::Required,
396 };
397
398 Ok(res)
399 }
400}
401
/// Function name plus arguments for a tool call. The API transmits arguments
/// as a JSON-encoded string; `stringified_json` converts to/from a `Value`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
408
409impl TryFrom<message::ToolResult> for Message {
410 type Error = message::MessageError;
411
412 fn try_from(value: message::ToolResult) -> Result<Self, Self::Error> {
413 let text = value
414 .content
415 .into_iter()
416 .map(|content| {
417 match content {
418 message::ToolResultContent::Text(message::Text { text }) => Ok(text),
419 message::ToolResultContent::Image(_) => Err(message::MessageError::ConversionError(
420 "OpenAI does not support images in tool results. Tool results must be text."
421 .into(),
422 )),
423 }
424 })
425 .collect::<Result<Vec<_>, _>>()?
426 .join("\n");
427
428 Ok(Message::ToolResult {
429 tool_call_id: value.id,
430 content: ToolResultContentValue::String(text),
431 })
432 }
433}
434
impl TryFrom<message::UserContent> for UserContent {
    type Error = message::MessageError;

    /// Converts core user content into a provider content part.
    ///
    /// Images may arrive as URLs (passed through) or base64 (rewritten as a
    /// `data:` URI). Documents are flattened to text. Audio must be base64.
    /// Tool results and video are rejected here; tool results are handled at
    /// the message level instead.
    fn try_from(value: message::UserContent) -> Result<Self, Self::Error> {
        match value {
            message::UserContent::Text(message::Text { text }) => Ok(UserContent::Text { text }),
            message::UserContent::Image(message::Image {
                data,
                detail,
                media_type,
                ..
            }) => match data {
                // URL images pass straight through; missing detail falls back
                // to the API default.
                DocumentSourceKind::Url(url) => Ok(UserContent::Image {
                    image_url: ImageUrl {
                        url,
                        detail: detail.unwrap_or_default(),
                    },
                }),
                // Base64 images become data URIs; this path requires an
                // explicit media type and detail level.
                // NOTE(review): the URL branch defaults a missing detail but
                // this branch errors on it — confirm the asymmetry is intended.
                DocumentSourceKind::Base64(data) => {
                    let url = format!(
                        "data:{};base64,{}",
                        media_type.map(|i| i.to_mime_type()).ok_or(
                            message::MessageError::ConversionError(
                                "OpenAI Image URI must have media type".into()
                            )
                        )?,
                        data
                    );

                    let detail = detail.ok_or(message::MessageError::ConversionError(
                        "OpenAI image URI must have image detail".into(),
                    ))?;

                    Ok(UserContent::Image {
                        image_url: ImageUrl { url, detail },
                    })
                }
                DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
                    "Raw files not supported, encode as base64 first".into(),
                )),
                DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
                    "Document has no body".into(),
                )),
                // Catch-all for any other source kinds.
                doc => Err(message::MessageError::ConversionError(format!(
                    "Unsupported document type: {doc:?}"
                ))),
            },
            // Documents are sent as plain text content.
            message::UserContent::Document(message::Document { data, .. }) => {
                if let DocumentSourceKind::Base64(text) | DocumentSourceKind::String(text) = data {
                    Ok(UserContent::Text { text })
                } else {
                    Err(message::MessageError::ConversionError(
                        "Documents must be base64 or a string".into(),
                    ))
                }
            }
            message::UserContent::Audio(message::Audio {
                data, media_type, ..
            }) => match data {
                DocumentSourceKind::Base64(data) => Ok(UserContent::Audio {
                    input_audio: InputAudio {
                        data,
                        // Missing media type defaults to MP3.
                        format: match media_type {
                            Some(media_type) => media_type,
                            None => AudioMediaType::MP3,
                        },
                    },
                }),
                DocumentSourceKind::Url(_) => Err(message::MessageError::ConversionError(
                    "URLs are not supported for audio".into(),
                )),
                DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
                    "Raw files are not supported for audio".into(),
                )),
                DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
                    "Audio has no body".into(),
                )),
                audio => Err(message::MessageError::ConversionError(format!(
                    "Unsupported audio type: {audio:?}"
                ))),
            },
            message::UserContent::ToolResult(_) => Err(message::MessageError::ConversionError(
                "Tool result is in unsupported format".into(),
            )),
            message::UserContent::Video(_) => Err(message::MessageError::ConversionError(
                "Video is in unsupported format".into(),
            )),
        }
    }
}
525
impl TryFrom<OneOrMany<message::UserContent>> for Vec<Message> {
    type Error = message::MessageError;

    /// Splits user content into provider messages: tool results fan out into
    /// separate `tool`-role messages; everything else bundles into a single
    /// user message.
    ///
    /// NOTE(review): when tool results are present, any *other* content in
    /// the same batch is silently dropped — confirm this is intentional.
    fn try_from(value: OneOrMany<message::UserContent>) -> Result<Self, Self::Error> {
        let (tool_results, other_content): (Vec<_>, Vec<_>) = value
            .into_iter()
            .partition(|content| matches!(content, message::UserContent::ToolResult(_)));

        if !tool_results.is_empty() {
            tool_results
                .into_iter()
                .map(|content| match content {
                    message::UserContent::ToolResult(tool_result) => tool_result.try_into(),
                    // The partition above guarantees only tool results here.
                    _ => unreachable!(),
                })
                .collect::<Result<Vec<_>, _>>()
        } else {
            let other_content: Vec<UserContent> = other_content
                .into_iter()
                .map(|content| content.try_into())
                .collect::<Result<Vec<_>, _>>()?;

            // Safe: `value` is non-empty and held no tool results, so all of
            // it landed in `other_content`.
            let other_content = OneOrMany::many(other_content)
                .expect("There must be other content here if there were no tool result content");

            Ok(vec![Message::User {
                content: other_content,
                name: None,
            }])
        }
    }
}
560
impl TryFrom<OneOrMany<message::AssistantContent>> for Vec<Message> {
    type Error = message::MessageError;

    /// Folds core assistant content into at most one provider assistant
    /// message, separating text parts from tool calls. Reasoning content is
    /// dropped (this API has no slot for it); an all-empty input yields an
    /// empty vec rather than an empty message.
    fn try_from(value: OneOrMany<message::AssistantContent>) -> Result<Self, Self::Error> {
        let mut text_content = Vec::new();
        let mut tool_calls = Vec::new();

        for content in value {
            match content {
                message::AssistantContent::Text(text) => text_content.push(text),
                message::AssistantContent::ToolCall(tool_call) => tool_calls.push(tool_call),
                // Intentionally discarded: Chat Completions cannot round-trip
                // reasoning content.
                message::AssistantContent::Reasoning(_) => {
                }
                // NOTE(review): this panics instead of returning a
                // ConversionError like the sibling impls — consider erroring.
                message::AssistantContent::Image(_) => {
                    panic!(
                        "The OpenAI Completions API doesn't support image content in assistant messages!"
                    );
                }
            }
        }

        if text_content.is_empty() && tool_calls.is_empty() {
            return Ok(vec![]);
        }

        Ok(vec![Message::Assistant {
            content: text_content
                .into_iter()
                .map(|content| content.text.into())
                .collect::<Vec<_>>(),
            refusal: None,
            audio: None,
            name: None,
            tool_calls: tool_calls
                .into_iter()
                .map(|tool_call| tool_call.into())
                .collect::<Vec<_>>(),
        }])
    }
}
603
604impl TryFrom<message::Message> for Vec<Message> {
605 type Error = message::MessageError;
606
607 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
608 match message {
609 message::Message::System { content } => Ok(vec![Message::system(&content)]),
610 message::Message::User { content } => content.try_into(),
611 message::Message::Assistant { content, .. } => content.try_into(),
612 }
613 }
614}
615
616impl From<message::ToolCall> for ToolCall {
617 fn from(tool_call: message::ToolCall) -> Self {
618 Self {
619 id: tool_call.id,
620 r#type: ToolType::default(),
621 function: Function {
622 name: tool_call.function.name,
623 arguments: tool_call.function.arguments,
624 },
625 }
626 }
627}
628
629impl From<ToolCall> for message::ToolCall {
630 fn from(tool_call: ToolCall) -> Self {
631 Self {
632 id: tool_call.id,
633 call_id: None,
634 function: message::ToolFunction {
635 name: tool_call.function.name,
636 arguments: tool_call.function.arguments,
637 },
638 signature: None,
639 additional_params: None,
640 }
641 }
642}
643
644impl TryFrom<Message> for message::Message {
645 type Error = message::MessageError;
646
647 fn try_from(message: Message) -> Result<Self, Self::Error> {
648 Ok(match message {
649 Message::User { content, .. } => message::Message::User {
650 content: content.map(|content| content.into()),
651 },
652 Message::Assistant {
653 content,
654 tool_calls,
655 ..
656 } => {
657 let mut content = content
658 .into_iter()
659 .map(|content| match content {
660 AssistantContent::Text { text } => message::AssistantContent::text(text),
661
662 AssistantContent::Refusal { refusal } => {
665 message::AssistantContent::text(refusal)
666 }
667 })
668 .collect::<Vec<_>>();
669
670 content.extend(
671 tool_calls
672 .into_iter()
673 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
674 .collect::<Result<Vec<_>, _>>()?,
675 );
676
677 message::Message::Assistant {
678 id: None,
679 content: OneOrMany::many(content).map_err(|_| {
680 message::MessageError::ConversionError(
681 "Neither `content` nor `tool_calls` was provided to the Message"
682 .to_owned(),
683 )
684 })?,
685 }
686 }
687
688 Message::ToolResult {
689 tool_call_id,
690 content,
691 } => message::Message::User {
692 content: OneOrMany::one(message::UserContent::tool_result(
693 tool_call_id,
694 OneOrMany::one(message::ToolResultContent::text(content.as_text())),
695 )),
696 },
697
698 Message::System { content, .. } => message::Message::User {
701 content: content.map(|content| message::UserContent::text(content.text)),
702 },
703 })
704 }
705}
706
707impl From<UserContent> for message::UserContent {
708 fn from(content: UserContent) -> Self {
709 match content {
710 UserContent::Text { text } => message::UserContent::text(text),
711 UserContent::Image { image_url } => {
712 message::UserContent::image_url(image_url.url, None, Some(image_url.detail))
713 }
714 UserContent::Audio { input_audio } => {
715 message::UserContent::audio(input_audio.data, Some(input_audio.format))
716 }
717 }
718 }
719}
720
721impl From<String> for UserContent {
722 fn from(s: String) -> Self {
723 UserContent::Text { text: s }
724 }
725}
726
727impl FromStr for UserContent {
728 type Err = Infallible;
729
730 fn from_str(s: &str) -> Result<Self, Self::Err> {
731 Ok(UserContent::Text {
732 text: s.to_string(),
733 })
734 }
735}
736
737impl From<String> for AssistantContent {
738 fn from(s: String) -> Self {
739 AssistantContent::Text { text: s }
740 }
741}
742
743impl FromStr for AssistantContent {
744 type Err = Infallible;
745
746 fn from_str(s: &str) -> Result<Self, Self::Err> {
747 Ok(AssistantContent::Text {
748 text: s.to_string(),
749 })
750 }
751}
752impl From<String> for SystemContent {
753 fn from(s: String) -> Self {
754 SystemContent {
755 r#type: SystemContentType::default(),
756 text: s,
757 }
758 }
759}
760
761impl FromStr for SystemContent {
762 type Err = Infallible;
763
764 fn from_str(s: &str) -> Result<Self, Self::Err> {
765 Ok(SystemContent {
766 r#type: SystemContentType::default(),
767 text: s.to_string(),
768 })
769 }
770}
771
/// Raw Chat Completions response body as returned by the API.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    // Unix timestamp of creation.
    pub created: u64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    pub choices: Vec<Choice>,
    pub usage: Option<Usage>,
}
782
783impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
784 type Error = CompletionError;
785
786 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
787 let choice = response.choices.first().ok_or_else(|| {
788 CompletionError::ResponseError("Response contained no choices".to_owned())
789 })?;
790
791 let content = match &choice.message {
792 Message::Assistant {
793 content,
794 tool_calls,
795 ..
796 } => {
797 let mut content = content
798 .iter()
799 .filter_map(|c| {
800 let s = match c {
801 AssistantContent::Text { text } => text,
802 AssistantContent::Refusal { refusal } => refusal,
803 };
804 if s.is_empty() {
805 None
806 } else {
807 Some(completion::AssistantContent::text(s))
808 }
809 })
810 .collect::<Vec<_>>();
811
812 content.extend(
813 tool_calls
814 .iter()
815 .map(|call| {
816 completion::AssistantContent::tool_call(
817 &call.id,
818 &call.function.name,
819 call.function.arguments.clone(),
820 )
821 })
822 .collect::<Vec<_>>(),
823 );
824 Ok(content)
825 }
826 _ => Err(CompletionError::ResponseError(
827 "Response did not contain a valid message or tool call".into(),
828 )),
829 }?;
830
831 let choice = OneOrMany::many(content).map_err(|_| {
832 CompletionError::ResponseError(
833 "Response contained no message or tool call (empty)".to_owned(),
834 )
835 })?;
836
837 let usage = response
838 .usage
839 .as_ref()
840 .map(|usage| completion::Usage {
841 input_tokens: usage.prompt_tokens as u64,
842 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
843 total_tokens: usage.total_tokens as u64,
844 cached_input_tokens: usage
845 .prompt_tokens_details
846 .as_ref()
847 .map(|d| d.cached_tokens as u64)
848 .unwrap_or(0),
849 })
850 .unwrap_or_default();
851
852 Ok(completion::CompletionResponse {
853 choice,
854 usage,
855 raw_response: response,
856 message_id: None,
857 })
858 }
859}
860
861impl ProviderResponseExt for CompletionResponse {
862 type OutputMessage = Choice;
863 type Usage = Usage;
864
865 fn get_response_id(&self) -> Option<String> {
866 Some(self.id.to_owned())
867 }
868
869 fn get_response_model_name(&self) -> Option<String> {
870 Some(self.model.to_owned())
871 }
872
873 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
874 self.choices.clone()
875 }
876
877 fn get_text_response(&self) -> Option<String> {
878 let Message::User { ref content, .. } = self.choices.last()?.message.clone() else {
879 return None;
880 };
881
882 let UserContent::Text { text } = content.first() else {
883 return None;
884 };
885
886 Some(text)
887 }
888
889 fn get_usage(&self) -> Option<Self::Usage> {
890 self.usage.clone()
891 }
892}
893
/// One candidate completion returned by the API.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    // Kept as raw JSON; this crate does not interpret logprobs.
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}
901
/// Breakdown of prompt-token accounting; currently only cache hits.
#[derive(Clone, Debug, Deserialize, Serialize, Default)]
pub struct PromptTokensDetails {
    // Defaults to 0 when the API omits the field.
    #[serde(default)]
    pub cached_tokens: usize,
}
908
/// Token usage as reported by the API. Output tokens are not stored directly;
/// consumers derive them from `total_tokens - prompt_tokens`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
916
917impl Usage {
918 pub fn new() -> Self {
919 Self {
920 prompt_tokens: 0,
921 total_tokens: 0,
922 prompt_tokens_details: None,
923 }
924 }
925}
926
927impl Default for Usage {
928 fn default() -> Self {
929 Self::new()
930 }
931}
932
933impl fmt::Display for Usage {
934 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
935 let Usage {
936 prompt_tokens,
937 total_tokens,
938 ..
939 } = self;
940 write!(
941 f,
942 "Prompt tokens: {prompt_tokens} Total tokens: {total_tokens}"
943 )
944 }
945}
946
947impl GetTokenUsage for Usage {
948 fn token_usage(&self) -> Option<crate::completion::Usage> {
949 let mut usage = crate::completion::Usage::new();
950 usage.input_tokens = self.prompt_tokens as u64;
951 usage.output_tokens = (self.total_tokens - self.prompt_tokens) as u64;
952 usage.total_tokens = self.total_tokens as u64;
953 usage.cached_input_tokens = self
954 .prompt_tokens_details
955 .as_ref()
956 .map(|d| d.cached_tokens as u64)
957 .unwrap_or(0);
958
959 Some(usage)
960 }
961}
962
/// A Chat Completions model handle bound to a client.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    /// Model identifier sent in requests (e.g. one of the constants above).
    pub model: String,
    /// When true, tool definitions are sent in OpenAI strict schema mode.
    pub strict_tools: bool,
    /// When true, tool results serialize as arrays of content parts instead
    /// of bare strings.
    pub tool_result_array_content: bool,
}
970
971impl<T> CompletionModel<T>
972where
973 T: Default + std::fmt::Debug + Clone + 'static,
974{
975 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
976 Self {
977 client,
978 model: model.into(),
979 strict_tools: false,
980 tool_result_array_content: false,
981 }
982 }
983
984 pub fn with_model(client: Client<T>, model: &str) -> Self {
985 Self {
986 client,
987 model: model.into(),
988 strict_tools: false,
989 tool_result_array_content: false,
990 }
991 }
992
993 pub fn with_strict_tools(mut self) -> Self {
1002 self.strict_tools = true;
1003 self
1004 }
1005
1006 pub fn with_tool_result_array_content(mut self) -> Self {
1007 self.tool_result_array_content = true;
1008 self
1009 }
1010}
1011
/// Wire-format body for a Chat Completions request.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CompletionRequest {
    model: String,
    messages: Vec<Message>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    // Arbitrary extra fields (e.g. `response_format`) are merged into the
    // top-level JSON object via `flatten`.
    #[serde(flatten)]
    additional_params: Option<serde_json::Value>,
}
1027
/// Inputs for building a [`CompletionRequest`]: the default model name, the
/// core request, and the per-model formatting flags.
pub struct OpenAIRequestParams {
    pub model: String,
    pub request: CoreCompletionRequest,
    pub strict_tools: bool,
    pub tool_result_array_content: bool,
}
1034
impl TryFrom<OpenAIRequestParams> for CompletionRequest {
    type Error = CompletionError;

    /// Builds the wire-format request from a core request.
    ///
    /// Assembles history as: normalized documents, then preamble (as a system
    /// message), then chat history; applies the strict-tools and array
    /// tool-result options; and, if an output schema is set, injects a
    /// `response_format` JSON-schema block into the additional params.
    fn try_from(params: OpenAIRequestParams) -> Result<Self, Self::Error> {
        let OpenAIRequestParams {
            model,
            request: req,
            strict_tools,
            tool_result_array_content,
        } = params;

        // Documents must come first so they precede the chat history.
        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        let CoreCompletionRequest {
            model: request_model,
            preamble,
            chat_history,
            tools,
            temperature,
            max_tokens,
            additional_params,
            tool_choice,
            output_schema,
            ..
        } = req;

        partial_history.extend(chat_history);

        // The preamble, when present, becomes the leading system message.
        let mut full_history: Vec<Message> =
            preamble.map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);

        // Each core message may fan out into several provider messages.
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        if full_history.is_empty() {
            return Err(CompletionError::RequestError(
                std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "OpenAI Chat Completions request has no provider-compatible messages after conversion",
                )
                .into(),
            ));
        }

        // Rewrite tool results into array form for providers that require it.
        if tool_result_array_content {
            for msg in &mut full_history {
                if let Message::ToolResult { content, .. } = msg {
                    *content = content.to_array();
                }
            }
        }

        let tool_choice = tool_choice.map(ToolChoice::try_from).transpose()?;

        let tools: Vec<ToolDefinition> = tools
            .into_iter()
            .map(|tool| {
                let def = ToolDefinition::from(tool);
                if strict_tools { def.with_strict() } else { def }
            })
            .collect();

        // Structured output: wrap the schema in a `response_format` block and
        // merge it into any caller-provided additional params.
        let additional_params = if let Some(schema) = output_schema {
            // Schema name comes from the JSON-schema `title`, if any.
            let name = schema
                .as_object()
                .and_then(|o| o.get("title"))
                .and_then(|v| v.as_str())
                .unwrap_or("response_schema")
                .to_string();
            let mut schema_value = schema.to_value();
            super::sanitize_schema(&mut schema_value);
            let response_format = serde_json::json!({
                "response_format": {
                    "type": "json_schema",
                    "json_schema": {
                        "name": name,
                        "strict": true,
                        "schema": schema_value
                    }
                }
            });
            Some(match additional_params {
                Some(existing) => json_utils::merge(existing, response_format),
                None => response_format,
            })
        } else {
            additional_params
        };

        let res = Self {
            // A per-request model override wins over the handle's model.
            model: request_model.unwrap_or(model),
            messages: full_history,
            tools,
            tool_choice,
            temperature,
            max_tokens,
            additional_params,
        };

        Ok(res)
    }
}
1147
1148impl TryFrom<(String, CoreCompletionRequest)> for CompletionRequest {
1149 type Error = CompletionError;
1150
1151 fn try_from((model, req): (String, CoreCompletionRequest)) -> Result<Self, Self::Error> {
1152 CompletionRequest::try_from(OpenAIRequestParams {
1153 model,
1154 request: req,
1155 strict_tools: false,
1156 tool_result_array_content: false,
1157 })
1158 }
1159}
1160
1161impl crate::telemetry::ProviderRequestExt for CompletionRequest {
1162 type InputMessage = Message;
1163
1164 fn get_input_messages(&self) -> Vec<Self::InputMessage> {
1165 self.messages.clone()
1166 }
1167
1168 fn get_system_prompt(&self) -> Option<String> {
1169 let first_message = self.messages.first()?;
1170
1171 let Message::System { ref content, .. } = first_message.clone() else {
1172 return None;
1173 };
1174
1175 let SystemContent { text, .. } = content.first();
1176
1177 Some(text)
1178 }
1179
1180 fn get_prompt(&self) -> Option<String> {
1181 let last_message = self.messages.last()?;
1182
1183 let Message::User { ref content, .. } = last_message.clone() else {
1184 return None;
1185 };
1186
1187 let UserContent::Text { text } = content.first() else {
1188 return None;
1189 };
1190
1191 Some(text)
1192 }
1193
1194 fn get_model_name(&self) -> String {
1195 self.model.clone()
1196 }
1197}
1198
impl CompletionModel<reqwest::Client> {
    /// Wraps this model in an agent builder for fluent agent construction.
    pub fn into_agent_builder(self) -> crate::agent::AgentBuilder<Self> {
        crate::agent::AgentBuilder::new(self)
    }
}
1204
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt
        + Default
        + std::fmt::Debug
        + Clone
        + WasmCompatSend
        + WasmCompatSync
        + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = super::CompletionsClient<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Sends a non-streaming completion request.
    ///
    /// Builds the wire request, POSTs it to `/chat/completions`, and converts
    /// the response into the core type, recording GenAI telemetry (response
    /// id/model and token usage) on the active span.
    async fn completion(
        &self,
        completion_request: CoreCompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Only open a fresh span if the caller has not already set one up;
        // otherwise record telemetry onto the caller's span.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let request = CompletionRequest::try_from(OpenAIRequestParams {
            model: self.model.to_owned(),
            request: completion_request,
            strict_tools: self.strict_tools,
            tool_result_array_content: self.tool_result_array_content,
        })?;

        // Guarded so the pretty-printing cost is only paid at TRACE level.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "OpenAI Chat Completions completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        async move {
            let response = self.client.send(req).await?;

            if response.status().is_success() {
                let text = http_client::text(response).await?;

                // A 2xx body can still carry an API-level error envelope.
                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record_response_metadata(&response);
                        span.record_token_usage(&response.usage);

                        if enabled!(Level::TRACE) {
                            tracing::trace!(
                                target: "rig::completions",
                                "OpenAI Chat Completions completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                // Non-2xx: surface the raw body as the provider error.
                let text = http_client::text(response).await?;
                Err(CompletionError::ProviderError(text))
            }
        }
        .instrument(span)
        .await
    }

    /// Delegates to the inherent streaming implementation (see `streaming`).
    async fn stream(
        &self,
        request: CoreCompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        Self::stream(self, request).await
    }
}
1312
1313fn serialize_assistant_content_vec<S>(
1314 value: &Vec<AssistantContent>,
1315 serializer: S,
1316) -> Result<S::Ok, S::Error>
1317where
1318 S: Serializer,
1319{
1320 if value.is_empty() {
1321 serializer.serialize_str("")
1322 } else {
1323 value.serialize(serializer)
1324 }
1325}
1326
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline core request: a single "Hello" user message with every
    /// optional field unset. Tests override individual fields via struct
    /// update syntax (`..base_request()`).
    fn base_request() -> CoreCompletionRequest {
        CoreCompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one("Hello".into()),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        }
    }

    /// Converts a core request using the default test model name with both
    /// tool-related flags disabled.
    fn convert(request: CoreCompletionRequest) -> Result<CompletionRequest, CompletionError> {
        CompletionRequest::try_from(OpenAIRequestParams {
            model: "gpt-4o-mini".to_string(),
            request,
            strict_tools: false,
            tool_result_array_content: false,
        })
    }

    #[test]
    fn test_openai_request_uses_request_model_override() {
        let request = CoreCompletionRequest {
            model: Some("gpt-4.1".to_string()),
            ..base_request()
        };

        let openai_request = convert(request).expect("request conversion should succeed");
        let serialized =
            serde_json::to_value(openai_request).expect("serialization should succeed");

        assert_eq!(serialized["model"], "gpt-4.1");
    }

    #[test]
    fn test_openai_request_uses_default_model_when_override_unset() {
        let openai_request = convert(base_request()).expect("request conversion should succeed");
        let serialized =
            serde_json::to_value(openai_request).expect("serialization should succeed");

        assert_eq!(serialized["model"], "gpt-4o-mini");
    }

    #[test]
    fn assistant_reasoning_is_silently_skipped() {
        let assistant_content = OneOrMany::one(message::AssistantContent::reasoning("hidden"));

        let converted: Vec<Message> = assistant_content
            .try_into()
            .expect("conversion should work");

        // Reasoning-only content yields no OpenAI messages at all.
        assert!(converted.is_empty());
    }

    #[test]
    fn assistant_text_and_tool_call_are_preserved_when_reasoning_is_present() {
        let assistant_content = OneOrMany::many(vec![
            message::AssistantContent::reasoning("hidden"),
            message::AssistantContent::text("visible"),
            message::AssistantContent::tool_call(
                "call_1",
                "subtract",
                serde_json::json!({"x": 2, "y": 1}),
            ),
        ])
        .expect("non-empty assistant content");

        let converted: Vec<Message> = assistant_content
            .try_into()
            .expect("conversion should work");
        assert_eq!(converted.len(), 1);

        match &converted[0] {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                // The reasoning item is dropped; text and tool call survive.
                assert_eq!(
                    content,
                    &vec![AssistantContent::Text {
                        text: "visible".to_string()
                    }]
                );
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].id, "call_1");
                assert_eq!(tool_calls[0].function.name, "subtract");
                assert_eq!(
                    tool_calls[0].function.arguments,
                    serde_json::json!({"x": 2, "y": 1})
                );
            }
            _ => panic!("expected assistant message"),
        }
    }

    #[test]
    fn test_max_tokens_is_forwarded_to_request() {
        let request = CoreCompletionRequest {
            max_tokens: Some(4096),
            ..base_request()
        };

        let openai_request = convert(request).expect("request conversion should succeed");
        let serialized =
            serde_json::to_value(openai_request).expect("serialization should succeed");

        assert_eq!(serialized["max_tokens"], 4096);
    }

    #[test]
    fn test_max_tokens_omitted_when_none() {
        let openai_request = convert(base_request()).expect("request conversion should succeed");
        let serialized =
            serde_json::to_value(openai_request).expect("serialization should succeed");

        assert!(serialized.get("max_tokens").is_none());
    }

    #[test]
    fn request_conversion_errors_when_all_messages_are_filtered() {
        // A history consisting solely of reasoning content is filtered out
        // entirely, which must surface as a request error.
        let request = CoreCompletionRequest {
            chat_history: OneOrMany::one(message::Message::Assistant {
                id: None,
                content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
            }),
            ..base_request()
        };

        let result = convert(request);

        assert!(matches!(result, Err(CompletionError::RequestError(_))));
    }
}