1use super::{
6 CompletionsClient as Client,
7 client::{ApiErrorResponse, ApiResponse},
8 streaming::StreamingCompletionResponse,
9};
10use crate::completion::{
11 CompletionError, CompletionRequest as CoreCompletionRequest, GetTokenUsage,
12};
13use crate::http_client::{self, HttpClientExt};
14use crate::message::{AudioMediaType, DocumentSourceKind, ImageDetail, MimeType};
15use crate::one_or_many::string_or_one_or_many;
16use crate::telemetry::{ProviderResponseExt, SpanCombinator};
17use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
18use crate::{OneOrMany, completion, json_utils, message};
19use serde::{Deserialize, Serialize, Serializer};
20use std::convert::Infallible;
21use std::fmt;
22use tracing::{Instrument, Level, enabled, info_span};
23
24use std::str::FromStr;
25
26pub mod streaming;
27
28fn serialize_user_content<S>(
31 content: &OneOrMany<UserContent>,
32 serializer: S,
33) -> Result<S::Ok, S::Error>
34where
35 S: Serializer,
36{
37 if content.len() == 1
38 && let UserContent::Text { text } = content.first_ref()
39 {
40 return serializer.serialize_str(text);
41 }
42 content.serialize(serializer)
43}
44
// GPT-5 family model identifiers.
pub const GPT_5_2: &str = "gpt-5.2";

pub const GPT_5_1: &str = "gpt-5.1";

pub const GPT_5: &str = "gpt-5";
pub const GPT_5_MINI: &str = "gpt-5-mini";
pub const GPT_5_NANO: &str = "gpt-5-nano";

// GPT-4 family (4.5 preview, 4o, turbo, vision and dated snapshots).
pub const GPT_4_5_PREVIEW: &str = "gpt-4.5-preview";
pub const GPT_4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27";
pub const GPT_4O_2024_11_20: &str = "gpt-4o-2024-11-20";
pub const GPT_4O: &str = "gpt-4o";
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
pub const GPT_4: &str = "gpt-4";
pub const GPT_4_0613: &str = "gpt-4-0613";
pub const GPT_4_32K: &str = "gpt-4-32k";
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";

// Reasoning ("o"-series) model identifiers.
pub const O4_MINI_2025_04_16: &str = "o4-mini-2025-04-16";
pub const O4_MINI: &str = "o4-mini";
pub const O3: &str = "o3";
pub const O3_MINI: &str = "o3-mini";
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
pub const O1_PRO: &str = "o1-pro";
pub const O1: &str = "o1";
pub const O1_2024_12_17: &str = "o1-2024-12-17";
pub const O1_PREVIEW: &str = "o1-preview";
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
pub const O1_MINI: &str = "o1-mini";
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";

// GPT-4.1 family model identifiers.
pub const GPT_4_1_MINI: &str = "gpt-4.1-mini";
pub const GPT_4_1_NANO: &str = "gpt-4.1-nano";
pub const GPT_4_1_2025_04_14: &str = "gpt-4.1-2025-04-14";
pub const GPT_4_1: &str = "gpt-4.1";
126
127impl From<ApiErrorResponse> for CompletionError {
128 fn from(err: ApiErrorResponse) -> Self {
129 CompletionError::ProviderError(err.message)
130 }
131}
132
/// Wire-format chat message for the OpenAI Chat Completions API, tagged by
/// the JSON `role` field (`system`/`developer`, `user`, `assistant`, `tool`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// System instructions; also accepts the newer `developer` role on input.
    #[serde(alias = "developer")]
    System {
        // Accepts either a bare string or an array of content parts.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// End-user input: text, images, or audio.
    User {
        // Deserializes from string or array; a single text part is
        // serialized back to the compact string form (see
        // `serialize_user_content`).
        #[serde(
            deserialize_with = "string_or_one_or_many",
            serialize_with = "serialize_user_content"
        )]
        content: OneOrMany<UserContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Model output: text/refusal parts plus any requested tool calls.
    Assistant {
        #[serde(
            default,
            deserialize_with = "json_utils::string_or_vec",
            skip_serializing_if = "Vec::is_empty",
            serialize_with = "serialize_assistant_content_vec"
        )]
        content: Vec<AssistantContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        // Reference (by id) to a previously generated audio response.
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<AudioAssistant>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // API may send `null` here; treated as an empty list.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    /// Result of a tool invocation, keyed back to its originating call id.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: ToolResultContentValue,
    },
}
179
180impl Message {
181 pub fn system(content: &str) -> Self {
182 Message::System {
183 content: OneOrMany::one(content.to_owned().into()),
184 name: None,
185 }
186 }
187}
188
189fn history_contains_tool_result(messages: &[Message]) -> bool {
190 messages
191 .iter()
192 .any(|message| matches!(message, Message::ToolResult { .. }))
193}
194
/// Reference to a previously generated assistant audio response.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AudioAssistant {
    // Provider-assigned id of the audio output.
    pub id: String,
}
199
/// One part of a system message; currently always plain text.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    // Defaults to `text` when the field is omitted in JSON.
    #[serde(default)]
    pub r#type: SystemContentType,
    pub text: String,
}
206
/// Content-type tag for [`SystemContent`]; `text` is the only variant.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
213
/// One part of an assistant message: generated text or a safety refusal.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
    Refusal { refusal: String },
}
220
221impl From<AssistantContent> for completion::AssistantContent {
222 fn from(value: AssistantContent) -> Self {
223 match value {
224 AssistantContent::Text { text } => completion::AssistantContent::text(text),
225 AssistantContent::Refusal { refusal } => completion::AssistantContent::text(refusal),
226 }
227 }
228}
229
/// One part of a user message, tagged by `type` on the wire.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text {
        text: String,
    },
    /// Image supplied by URL (including `data:` URIs for inline payloads).
    #[serde(rename = "image_url")]
    Image {
        image_url: ImageUrl,
    },
    /// Base64-encoded audio input.
    Audio {
        input_audio: InputAudio,
    },
}
244
/// Image reference for user content; `detail` controls vision fidelity.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    #[serde(default)]
    pub detail: ImageDetail,
}
251
/// Inline audio payload: base64 `data` plus its encoding `format`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct InputAudio {
    pub data: String,
    pub format: AudioMediaType,
}
257
/// One text part of a tool result when the array content form is used.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    // Always `text`; defaulted when missing from JSON.
    #[serde(default)]
    r#type: ToolResultContentType,
    pub text: String,
}
264
/// Content-type tag for [`ToolResultContent`]; `text` is the only variant.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    #[default]
    Text,
}
271
272impl FromStr for ToolResultContent {
273 type Err = Infallible;
274
275 fn from_str(s: &str) -> Result<Self, Self::Err> {
276 Ok(s.to_owned().into())
277 }
278}
279
280impl From<String> for ToolResultContent {
281 fn from(s: String) -> Self {
282 ToolResultContent {
283 r#type: ToolResultContentType::default(),
284 text: s,
285 }
286 }
287}
288
/// Tool result `content` field, which providers accept either as a bare
/// string or as an array of typed parts; untagged so both forms round-trip.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum ToolResultContentValue {
    Array(Vec<ToolResultContent>),
    String(String),
}
295
296impl ToolResultContentValue {
297 pub fn from_string(s: String, use_array_format: bool) -> Self {
298 if use_array_format {
299 ToolResultContentValue::Array(vec![ToolResultContent::from(s)])
300 } else {
301 ToolResultContentValue::String(s)
302 }
303 }
304
305 pub fn as_text(&self) -> String {
306 match self {
307 ToolResultContentValue::Array(arr) => arr
308 .iter()
309 .map(|c| c.text.clone())
310 .collect::<Vec<_>>()
311 .join("\n"),
312 ToolResultContentValue::String(s) => s.clone(),
313 }
314 }
315
316 pub fn to_array(&self) -> Self {
317 match self {
318 ToolResultContentValue::Array(_) => self.clone(),
319 ToolResultContentValue::String(s) => {
320 ToolResultContentValue::Array(vec![ToolResultContent::from(s.clone())])
321 }
322 }
323 }
324}
325
/// A tool invocation requested by the assistant.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    // Always `function`; defaulted when missing from JSON.
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}
333
/// Tool-call type tag; `function` is the only variant the API defines here.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
340
/// JSON-schema description of a callable function tool.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: String,
    // JSON schema describing the function's arguments.
    pub parameters: serde_json::Value,
    // Enables OpenAI strict structured-output validation when `Some(true)`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
350
/// Wire-format tool definition: a `type` tag (always "function") plus the
/// function schema itself.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: FunctionDefinition,
}
356
357impl From<completion::ToolDefinition> for ToolDefinition {
358 fn from(tool: completion::ToolDefinition) -> Self {
359 Self {
360 r#type: "function".into(),
361 function: FunctionDefinition {
362 name: tool.name,
363 description: tool.description,
364 parameters: tool.parameters,
365 strict: None,
366 },
367 }
368 }
369}
370
impl ToolDefinition {
    /// Marks the function as strict and sanitizes its JSON schema so it is
    /// accepted by OpenAI's structured-output validation.
    pub fn with_strict(mut self) -> Self {
        self.function.strict = Some(true);
        super::sanitize_schema(&mut self.function.parameters);
        self
    }
}
380
/// Tool-choice modes supported by this provider; serialized in snake_case.
/// Note: a "specific tool" choice has no representation here (see the
/// `TryFrom` impl below, which rejects it).
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    None,
    Required,
}
389
390impl TryFrom<crate::message::ToolChoice> for ToolChoice {
391 type Error = CompletionError;
392 fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
393 let res = match value {
394 message::ToolChoice::Specific { .. } => {
395 return Err(CompletionError::ProviderError(
396 "Provider doesn't support only using specific tools".to_string(),
397 ));
398 }
399 message::ToolChoice::Auto => Self::Auto,
400 message::ToolChoice::None => Self::None,
401 message::ToolChoice::Required => Self::Required,
402 };
403
404 Ok(res)
405 }
406}
407
/// Function name plus arguments for a tool call. Arguments are held as a
/// JSON value but serialized to the stringified form OpenAI expects, and
/// deserialized from either form.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    #[serde(
        serialize_with = "json_utils::stringified_json::serialize",
        deserialize_with = "json_utils::stringified_json::deserialize_maybe_stringified"
    )]
    pub arguments: serde_json::Value,
}
417
418impl TryFrom<message::ToolResult> for Message {
419 type Error = message::MessageError;
420
421 fn try_from(value: message::ToolResult) -> Result<Self, Self::Error> {
422 let text = value
423 .content
424 .into_iter()
425 .map(|content| {
426 match content {
427 message::ToolResultContent::Text(message::Text { text }) => Ok(text),
428 message::ToolResultContent::Image(_) => Err(message::MessageError::ConversionError(
429 "OpenAI does not support images in tool results. Tool results must be text."
430 .into(),
431 )),
432 }
433 })
434 .collect::<Result<Vec<_>, _>>()?
435 .join("\n");
436
437 Ok(Message::ToolResult {
438 tool_call_id: value.id,
439 content: ToolResultContentValue::String(text),
440 })
441 }
442}
443
impl TryFrom<message::UserContent> for UserContent {
    type Error = message::MessageError;

    /// Converts core user content into the OpenAI wire format.
    ///
    /// Images must be URLs or base64 (rendered as `data:` URIs); documents
    /// are inlined as text; audio must be base64 (MP3 assumed when untyped).
    /// Tool results and video are rejected here — tool results are handled
    /// at the message level instead (see the `OneOrMany` conversion below).
    fn try_from(value: message::UserContent) -> Result<Self, Self::Error> {
        match value {
            message::UserContent::Text(message::Text { text }) => Ok(UserContent::Text { text }),
            message::UserContent::Image(message::Image {
                data,
                detail,
                media_type,
                ..
            }) => match data {
                // Remote image: pass the URL through untouched.
                DocumentSourceKind::Url(url) => Ok(UserContent::Image {
                    image_url: ImageUrl {
                        url,
                        detail: detail.unwrap_or_default(),
                    },
                }),
                // Inline image: build a data URI; media type and detail are
                // both mandatory in this form.
                DocumentSourceKind::Base64(data) => {
                    let url = format!(
                        "data:{};base64,{}",
                        media_type.map(|i| i.to_mime_type()).ok_or(
                            message::MessageError::ConversionError(
                                "OpenAI Image URI must have media type".into()
                            )
                        )?,
                        data
                    );

                    let detail = detail.ok_or(message::MessageError::ConversionError(
                        "OpenAI image URI must have image detail".into(),
                    ))?;

                    Ok(UserContent::Image {
                        image_url: ImageUrl { url, detail },
                    })
                }
                DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
                    "Raw files not supported, encode as base64 first".into(),
                )),
                DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
                    "Document has no body".into(),
                )),
                // Catch-all for source kinds added in the future.
                doc => Err(message::MessageError::ConversionError(format!(
                    "Unsupported document type: {doc:?}"
                ))),
            },
            // Documents are flattened to their textual body.
            message::UserContent::Document(message::Document { data, .. }) => {
                if let DocumentSourceKind::Base64(text) | DocumentSourceKind::String(text) = data {
                    Ok(UserContent::Text { text })
                } else {
                    Err(message::MessageError::ConversionError(
                        "Documents must be base64 or a string".into(),
                    ))
                }
            }
            message::UserContent::Audio(message::Audio {
                data, media_type, ..
            }) => match data {
                DocumentSourceKind::Base64(data) => Ok(UserContent::Audio {
                    input_audio: InputAudio {
                        data,
                        // MP3 is assumed when no media type was supplied.
                        format: match media_type {
                            Some(media_type) => media_type,
                            None => AudioMediaType::MP3,
                        },
                    },
                }),
                DocumentSourceKind::Url(_) => Err(message::MessageError::ConversionError(
                    "URLs are not supported for audio".into(),
                )),
                DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
                    "Raw files are not supported for audio".into(),
                )),
                DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
                    "Audio has no body".into(),
                )),
                audio => Err(message::MessageError::ConversionError(format!(
                    "Unsupported audio type: {audio:?}"
                ))),
            },
            message::UserContent::ToolResult(_) => Err(message::MessageError::ConversionError(
                "Tool result is in unsupported format".into(),
            )),
            message::UserContent::Video(_) => Err(message::MessageError::ConversionError(
                "Video is in unsupported format".into(),
            )),
        }
    }
}
534
impl TryFrom<OneOrMany<message::UserContent>> for Vec<Message> {
    type Error = message::MessageError;

    /// Splits a user turn into wire messages: each tool result becomes its
    /// own `tool` role message, otherwise the turn becomes a single user
    /// message.
    ///
    /// NOTE(review): when tool results are present, any non-tool-result
    /// content in the same turn is silently discarded — confirm this is the
    /// intended behavior upstream.
    fn try_from(value: OneOrMany<message::UserContent>) -> Result<Self, Self::Error> {
        let (tool_results, other_content): (Vec<_>, Vec<_>) = value
            .into_iter()
            .partition(|content| matches!(content, message::UserContent::ToolResult(_)));

        if !tool_results.is_empty() {
            tool_results
                .into_iter()
                .map(|content| match content {
                    message::UserContent::ToolResult(tool_result) => tool_result.try_into(),
                    // Unreachable: the partition above kept only tool results.
                    _ => unreachable!(),
                })
                .collect::<Result<Vec<_>, _>>()
        } else {
            let other_content: Vec<UserContent> = other_content
                .into_iter()
                .map(|content| content.try_into())
                .collect::<Result<Vec<_>, _>>()?;

            // `value` is non-empty and nothing went to the tool-result side,
            // so `other_content` cannot be empty here.
            let other_content = OneOrMany::many(other_content)
                .expect("There must be other content here if there were no tool result content");

            Ok(vec![Message::User {
                content: other_content,
                name: None,
            }])
        }
    }
}
569
570impl TryFrom<OneOrMany<message::AssistantContent>> for Vec<Message> {
571 type Error = message::MessageError;
572
573 fn try_from(value: OneOrMany<message::AssistantContent>) -> Result<Self, Self::Error> {
574 let mut text_content = Vec::new();
575 let mut tool_calls = Vec::new();
576
577 for content in value {
578 match content {
579 message::AssistantContent::Text(text) => text_content.push(text),
580 message::AssistantContent::ToolCall(tool_call) => tool_calls.push(tool_call),
581 message::AssistantContent::Reasoning(_) => {
582 }
585 message::AssistantContent::Image(_) => {
586 panic!(
587 "The OpenAI Completions API doesn't support image content in assistant messages!"
588 );
589 }
590 }
591 }
592
593 if text_content.is_empty() && tool_calls.is_empty() {
594 return Ok(vec![]);
595 }
596
597 Ok(vec![Message::Assistant {
598 content: text_content
599 .into_iter()
600 .map(|content| content.text.into())
601 .collect::<Vec<_>>(),
602 refusal: None,
603 audio: None,
604 name: None,
605 tool_calls: tool_calls
606 .into_iter()
607 .map(|tool_call| tool_call.into())
608 .collect::<Vec<_>>(),
609 }])
610 }
611}
612
613impl TryFrom<message::Message> for Vec<Message> {
614 type Error = message::MessageError;
615
616 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
617 match message {
618 message::Message::System { content } => Ok(vec![Message::system(&content)]),
619 message::Message::User { content } => content.try_into(),
620 message::Message::Assistant { content, .. } => content.try_into(),
621 }
622 }
623}
624
625impl From<message::ToolCall> for ToolCall {
626 fn from(tool_call: message::ToolCall) -> Self {
627 Self {
628 id: tool_call.id,
629 r#type: ToolType::default(),
630 function: Function {
631 name: tool_call.function.name,
632 arguments: tool_call.function.arguments,
633 },
634 }
635 }
636}
637
638impl From<ToolCall> for message::ToolCall {
639 fn from(tool_call: ToolCall) -> Self {
640 Self {
641 id: tool_call.id,
642 call_id: None,
643 function: message::ToolFunction {
644 name: tool_call.function.name,
645 arguments: tool_call.function.arguments,
646 },
647 signature: None,
648 additional_params: None,
649 }
650 }
651}
652
653impl TryFrom<Message> for message::Message {
654 type Error = message::MessageError;
655
656 fn try_from(message: Message) -> Result<Self, Self::Error> {
657 Ok(match message {
658 Message::User { content, .. } => message::Message::User {
659 content: content.map(|content| content.into()),
660 },
661 Message::Assistant {
662 content,
663 tool_calls,
664 ..
665 } => {
666 let mut content = content
667 .into_iter()
668 .map(|content| match content {
669 AssistantContent::Text { text } => message::AssistantContent::text(text),
670
671 AssistantContent::Refusal { refusal } => {
674 message::AssistantContent::text(refusal)
675 }
676 })
677 .collect::<Vec<_>>();
678
679 content.extend(
680 tool_calls
681 .into_iter()
682 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
683 .collect::<Result<Vec<_>, _>>()?,
684 );
685
686 message::Message::Assistant {
687 id: None,
688 content: OneOrMany::many(content).map_err(|_| {
689 message::MessageError::ConversionError(
690 "Neither `content` nor `tool_calls` was provided to the Message"
691 .to_owned(),
692 )
693 })?,
694 }
695 }
696
697 Message::ToolResult {
698 tool_call_id,
699 content,
700 } => message::Message::User {
701 content: OneOrMany::one(message::UserContent::tool_result(
702 tool_call_id,
703 OneOrMany::one(message::ToolResultContent::text(content.as_text())),
704 )),
705 },
706
707 Message::System { content, .. } => message::Message::User {
710 content: content.map(|content| message::UserContent::text(content.text)),
711 },
712 })
713 }
714}
715
716impl From<UserContent> for message::UserContent {
717 fn from(content: UserContent) -> Self {
718 match content {
719 UserContent::Text { text } => message::UserContent::text(text),
720 UserContent::Image { image_url } => {
721 message::UserContent::image_url(image_url.url, None, Some(image_url.detail))
722 }
723 UserContent::Audio { input_audio } => {
724 message::UserContent::audio(input_audio.data, Some(input_audio.format))
725 }
726 }
727 }
728}
729
730impl From<String> for UserContent {
731 fn from(s: String) -> Self {
732 UserContent::Text { text: s }
733 }
734}
735
736impl FromStr for UserContent {
737 type Err = Infallible;
738
739 fn from_str(s: &str) -> Result<Self, Self::Err> {
740 Ok(UserContent::Text {
741 text: s.to_string(),
742 })
743 }
744}
745
746impl From<String> for AssistantContent {
747 fn from(s: String) -> Self {
748 AssistantContent::Text { text: s }
749 }
750}
751
752impl FromStr for AssistantContent {
753 type Err = Infallible;
754
755 fn from_str(s: &str) -> Result<Self, Self::Err> {
756 Ok(AssistantContent::Text {
757 text: s.to_string(),
758 })
759 }
760}
761impl From<String> for SystemContent {
762 fn from(s: String) -> Self {
763 SystemContent {
764 r#type: SystemContentType::default(),
765 text: s,
766 }
767 }
768}
769
770impl FromStr for SystemContent {
771 type Err = Infallible;
772
773 fn from_str(s: &str) -> Result<Self, Self::Err> {
774 Ok(SystemContent {
775 r#type: SystemContentType::default(),
776 text: s.to_string(),
777 })
778 }
779}
780
/// Raw Chat Completions response body as returned by the API.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    // Unix timestamp (seconds) of creation.
    pub created: u64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    pub choices: Vec<Choice>,
    // Absent on some OpenAI-compatible providers.
    pub usage: Option<Usage>,
}
791
792impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
793 type Error = CompletionError;
794
795 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
796 let choice = response.choices.first().ok_or_else(|| {
797 CompletionError::ResponseError("Response contained no choices".to_owned())
798 })?;
799
800 let content = match &choice.message {
801 Message::Assistant {
802 content,
803 tool_calls,
804 ..
805 } => {
806 let mut content = content
807 .iter()
808 .filter_map(|c| {
809 let s = match c {
810 AssistantContent::Text { text } => text,
811 AssistantContent::Refusal { refusal } => refusal,
812 };
813 if s.is_empty() {
814 None
815 } else {
816 Some(completion::AssistantContent::text(s))
817 }
818 })
819 .collect::<Vec<_>>();
820
821 content.extend(
822 tool_calls
823 .iter()
824 .map(|call| {
825 completion::AssistantContent::tool_call(
826 &call.id,
827 &call.function.name,
828 call.function.arguments.clone(),
829 )
830 })
831 .collect::<Vec<_>>(),
832 );
833 Ok(content)
834 }
835 _ => Err(CompletionError::ResponseError(
836 "Response did not contain a valid message or tool call".into(),
837 )),
838 }?;
839
840 let choice = OneOrMany::many(content).map_err(|_| {
841 CompletionError::ResponseError(
842 "Response contained no message or tool call (empty)".to_owned(),
843 )
844 })?;
845
846 let usage = response
847 .usage
848 .as_ref()
849 .map(|usage| completion::Usage {
850 input_tokens: usage.prompt_tokens as u64,
851 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
852 total_tokens: usage.total_tokens as u64,
853 cached_input_tokens: usage
854 .prompt_tokens_details
855 .as_ref()
856 .map(|d| d.cached_tokens as u64)
857 .unwrap_or(0),
858 cache_creation_input_tokens: 0,
859 })
860 .unwrap_or_default();
861
862 Ok(completion::CompletionResponse {
863 choice,
864 usage,
865 raw_response: response,
866 message_id: None,
867 })
868 }
869}
870
871impl ProviderResponseExt for CompletionResponse {
872 type OutputMessage = Choice;
873 type Usage = Usage;
874
875 fn get_response_id(&self) -> Option<String> {
876 Some(self.id.to_owned())
877 }
878
879 fn get_response_model_name(&self) -> Option<String> {
880 Some(self.model.to_owned())
881 }
882
883 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
884 self.choices.clone()
885 }
886
887 fn get_text_response(&self) -> Option<String> {
888 let Message::User { ref content, .. } = self.choices.last()?.message.clone() else {
889 return None;
890 };
891
892 let UserContent::Text { text } = content.first() else {
893 return None;
894 };
895
896 Some(text)
897 }
898
899 fn get_usage(&self) -> Option<Self::Usage> {
900 self.usage.clone()
901 }
902}
903
/// One generated alternative within a completion response.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    // Kept as raw JSON; rig does not interpret logprobs.
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}
911
/// Breakdown of prompt tokens; currently only the cached-token count.
#[derive(Clone, Debug, Deserialize, Serialize, Default)]
pub struct PromptTokensDetails {
    #[serde(default)]
    pub cached_tokens: usize,
}
918
/// Token accounting for a completion. There is no separate completion-token
/// field here; consumers derive output tokens as `total - prompt`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
926
impl Usage {
    /// Zeroed usage, used when the provider omits the `usage` object.
    pub fn new() -> Self {
        Self {
            prompt_tokens: 0,
            total_tokens: 0,
            prompt_tokens_details: None,
        }
    }
}
936
impl Default for Usage {
    /// Defaults to all-zero counts via [`Usage::new`].
    fn default() -> Self {
        Self::new()
    }
}
942
943impl fmt::Display for Usage {
944 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
945 let Usage {
946 prompt_tokens,
947 total_tokens,
948 ..
949 } = self;
950 write!(
951 f,
952 "Prompt tokens: {prompt_tokens} Total tokens: {total_tokens}"
953 )
954 }
955}
956
957impl GetTokenUsage for Usage {
958 fn token_usage(&self) -> Option<crate::completion::Usage> {
959 let mut usage = crate::completion::Usage::new();
960 usage.input_tokens = self.prompt_tokens as u64;
961 usage.output_tokens = (self.total_tokens - self.prompt_tokens) as u64;
962 usage.total_tokens = self.total_tokens as u64;
963 usage.cached_input_tokens = self
964 .prompt_tokens_details
965 .as_ref()
966 .map(|d| d.cached_tokens as u64)
967 .unwrap_or(0);
968
969 Some(usage)
970 }
971}
972
/// OpenAI Chat Completions model handle, generic over the HTTP client.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    // Model identifier, e.g. "gpt-4o".
    pub model: String,
    // When true, tool schemas are sanitized and sent with `strict: true`.
    pub strict_tools: bool,
    // When true, tool results are sent as content arrays instead of strings.
    pub tool_result_array_content: bool,
}
980
981impl<T> CompletionModel<T>
982where
983 T: Default + std::fmt::Debug + Clone + 'static,
984{
985 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
986 Self {
987 client,
988 model: model.into(),
989 strict_tools: false,
990 tool_result_array_content: false,
991 }
992 }
993
994 pub fn with_model(client: Client<T>, model: &str) -> Self {
995 Self {
996 client,
997 model: model.into(),
998 strict_tools: false,
999 tool_result_array_content: false,
1000 }
1001 }
1002
1003 pub fn with_strict_tools(mut self) -> Self {
1012 self.strict_tools = true;
1013 self
1014 }
1015
1016 pub fn with_tool_result_array_content(mut self) -> Self {
1017 self.tool_result_array_content = true;
1018 self
1019 }
1020}
1021
/// Serialized request body for `POST /chat/completions`. Extra provider
/// parameters are flattened into the top-level JSON object.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CompletionRequest {
    model: String,
    messages: Vec<Message>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    // Arbitrary extra fields merged into the request object.
    #[serde(flatten)]
    additional_params: Option<serde_json::Value>,
}
1037
/// Inputs needed to build a [`CompletionRequest`] from a core request.
pub struct OpenAIRequestParams {
    pub model: String,
    pub request: CoreCompletionRequest,
    pub strict_tools: bool,
    pub tool_result_array_content: bool,
}
1044
impl TryFrom<OpenAIRequestParams> for CompletionRequest {
    type Error = CompletionError;

    /// Builds the wire-format request: documents and chat history are
    /// flattened into OpenAI messages, tools are converted (optionally
    /// strict), and an output schema becomes a `response_format` entry in
    /// the additional params.
    fn try_from(params: OpenAIRequestParams) -> Result<Self, Self::Error> {
        let OpenAIRequestParams {
            model,
            request: req,
            strict_tools,
            tool_result_array_content,
        } = params;

        // Normalized documents (RAG context) go first so they precede the
        // chat history.
        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        let CoreCompletionRequest {
            model: request_model,
            preamble,
            chat_history,
            tools,
            temperature,
            max_tokens,
            additional_params,
            tool_choice,
            output_schema,
            ..
        } = req;

        partial_history.extend(chat_history);

        // The preamble, when present, becomes the leading system message.
        let mut full_history: Vec<Message> =
            preamble.map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);

        // Each core message may expand to several wire messages.
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        if full_history.is_empty() {
            return Err(CompletionError::RequestError(
                std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "OpenAI Chat Completions request has no provider-compatible messages after conversion",
                )
                .into(),
            ));
        }

        // Some OpenAI-compatible providers require tool results in array form.
        if tool_result_array_content {
            for msg in &mut full_history {
                if let Message::ToolResult { content, .. } = msg {
                    *content = content.to_array();
                }
            }
        }

        let history_has_tool_result = history_contains_tool_result(&full_history);

        let tool_choice = tool_choice.map(ToolChoice::try_from).transpose()?;

        let tools: Vec<ToolDefinition> = tools
            .into_iter()
            .map(|tool| {
                let def = ToolDefinition::from(tool);
                if strict_tools { def.with_strict() } else { def }
            })
            .collect();

        // Only force a JSON response format when there are no tools, or once
        // a tool result is already in the history (i.e. tool use resolved).
        let should_apply_response_format =
            output_schema.is_some() && (tools.is_empty() || history_has_tool_result);

        let additional_params = if let Some(schema) = output_schema
            && should_apply_response_format
        {
            // The schema's `title` doubles as the response-format name.
            let name = schema
                .as_object()
                .and_then(|o| o.get("title"))
                .and_then(|v| v.as_str())
                .unwrap_or("response_schema")
                .to_string();
            let mut schema_value = schema.to_value();
            super::sanitize_schema(&mut schema_value);
            let response_format = serde_json::json!({
                "response_format": {
                    "type": "json_schema",
                    "json_schema": {
                        "name": name,
                        "strict": true,
                        "schema": schema_value
                    }
                }
            });
            Some(match additional_params {
                Some(existing) => json_utils::merge(existing, response_format),
                None => response_format,
            })
        } else {
            additional_params
        };

        let res = Self {
            // An explicit per-request model overrides the client default.
            model: request_model.unwrap_or(model),
            messages: full_history,
            tools,
            tool_choice,
            temperature,
            max_tokens,
            additional_params,
        };

        Ok(res)
    }
}
1167
1168impl TryFrom<(String, CoreCompletionRequest)> for CompletionRequest {
1169 type Error = CompletionError;
1170
1171 fn try_from((model, req): (String, CoreCompletionRequest)) -> Result<Self, Self::Error> {
1172 CompletionRequest::try_from(OpenAIRequestParams {
1173 model,
1174 request: req,
1175 strict_tools: false,
1176 tool_result_array_content: false,
1177 })
1178 }
1179}
1180
impl crate::telemetry::ProviderRequestExt for CompletionRequest {
    type InputMessage = Message;

    /// Full message list as it will be sent to the provider.
    fn get_input_messages(&self) -> Vec<Self::InputMessage> {
        self.messages.clone()
    }

    /// Text of the leading system message, if the request starts with one.
    fn get_system_prompt(&self) -> Option<String> {
        let first_message = self.messages.first()?;

        let Message::System { ref content, .. } = first_message.clone() else {
            return None;
        };

        let SystemContent { text, .. } = content.first();

        Some(text)
    }

    /// Text of the final user message, if the request ends with one whose
    /// first content part is plain text.
    fn get_prompt(&self) -> Option<String> {
        let last_message = self.messages.last()?;

        let Message::User { ref content, .. } = last_message.clone() else {
            return None;
        };

        let UserContent::Text { text } = content.first() else {
            return None;
        };

        Some(text)
    }

    fn get_model_name(&self) -> String {
        self.model.clone()
    }
}
1218
impl CompletionModel<reqwest::Client> {
    /// Wraps this model in an agent builder for fluent agent construction.
    pub fn into_agent_builder(self) -> crate::agent::AgentBuilder<Self> {
        crate::agent::AgentBuilder::new(self)
    }
}
1224
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt
        + Default
        + std::fmt::Debug
        + Clone
        + WasmCompatSend
        + WasmCompatSync
        + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = super::CompletionsClient<T>;

    /// Builds a model handle from a client and a model identifier.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Sends a non-streaming chat-completion request to `/chat/completions`
    /// and converts the provider response into the core response type.
    ///
    /// # Errors
    /// Returns a [`CompletionError`] on request-conversion failure, HTTP or
    /// (de)serialization errors, or an API-level error from the provider.
    async fn completion(
        &self,
        completion_request: CoreCompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Reuse the caller's active span when one exists; otherwise open a
        // fresh span following GenAI telemetry field conventions, with
        // response/usage fields left empty to be recorded later.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        // Translate the core request into the OpenAI wire format, applying
        // this model's strict-tools / tool-result-content configuration.
        let request = CompletionRequest::try_from(OpenAIRequestParams {
            model: self.model.to_owned(),
            request: completion_request,
            strict_tools: self.strict_tools,
            tool_result_array_content: self.tool_result_array_content,
        })?;

        // Pretty-printing the payload is gated so the serialization cost is
        // only paid when TRACE logging is actually enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "OpenAI Chat Completions completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        // The send/parse phase is instrumented with the span chosen above.
        async move {
            let response = self.client.send(req).await?;

            if response.status().is_success() {
                let text = http_client::text(response).await?;

                // A 2xx body may still carry an API-level error object, so
                // parse into the Ok/Err envelope before converting.
                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
                    ApiResponse::Ok(response) => {
                        // Record response metadata and token usage on the
                        // current (instrumented) telemetry span.
                        let span = tracing::Span::current();
                        span.record_response_metadata(&response);
                        span.record_token_usage(&response.usage);

                        if enabled!(Level::TRACE) {
                            tracing::trace!(
                                target: "rig::completions",
                                "OpenAI Chat Completions completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                // Non-success status: surface the raw body as a provider error.
                let text = http_client::text(response).await?;
                Err(CompletionError::ProviderError(text))
            }
        }
        .instrument(span)
        .await
    }

    /// Streaming variant; delegates to the inherent `stream` implementation.
    async fn stream(
        &self,
        request: CoreCompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        Self::stream(self, request).await
    }
}
1332
1333fn serialize_assistant_content_vec<S>(
1334 value: &Vec<AssistantContent>,
1335 serializer: S,
1336) -> Result<S::Ok, S::Error>
1337where
1338 S: Serializer,
1339{
1340 if value.is_empty() {
1341 serializer.serialize_str("")
1342 } else {
1343 value.serialize(serializer)
1344 }
1345}
1346
#[cfg(test)]
mod tests {
    use super::*;

    // A per-request model override must take precedence over the model the
    // provider handle was configured with.
    #[test]
    fn test_openai_request_uses_request_model_override() {
        let request = crate::completion::CompletionRequest {
            model: Some("gpt-4.1".to_string()),
            preamble: None,
            chat_history: crate::OneOrMany::one("Hello".into()),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let openai_request = CompletionRequest::try_from(OpenAIRequestParams {
            model: "gpt-4o-mini".to_string(),
            request,
            strict_tools: false,
            tool_result_array_content: false,
        })
        .expect("request conversion should succeed");
        let serialized =
            serde_json::to_value(openai_request).expect("serialization should succeed");

        assert_eq!(serialized["model"], "gpt-4.1");
    }

    // Without a per-request override, the configured default model is used.
    #[test]
    fn test_openai_request_uses_default_model_when_override_unset() {
        let request = crate::completion::CompletionRequest {
            model: None,
            preamble: None,
            chat_history: crate::OneOrMany::one("Hello".into()),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let openai_request = CompletionRequest::try_from(OpenAIRequestParams {
            model: "gpt-4o-mini".to_string(),
            request,
            strict_tools: false,
            tool_result_array_content: false,
        })
        .expect("request conversion should succeed");
        let serialized =
            serde_json::to_value(openai_request).expect("serialization should succeed");

        assert_eq!(serialized["model"], "gpt-4o-mini");
    }

    // Reasoning-only assistant content produces no outbound messages: the
    // Chat Completions API has no reasoning slot, so it is dropped silently.
    #[test]
    fn assistant_reasoning_is_silently_skipped() {
        let assistant_content = OneOrMany::one(message::AssistantContent::reasoning("hidden"));

        let converted: Vec<Message> = assistant_content
            .try_into()
            .expect("conversion should work");

        assert!(converted.is_empty());
    }

    // Dropping reasoning content must not discard the sibling text and tool
    // calls in the same assistant turn.
    #[test]
    fn assistant_text_and_tool_call_are_preserved_when_reasoning_is_present() {
        let assistant_content = OneOrMany::many(vec![
            message::AssistantContent::reasoning("hidden"),
            message::AssistantContent::text("visible"),
            message::AssistantContent::tool_call(
                "call_1",
                "subtract",
                serde_json::json!({"x": 2, "y": 1}),
            ),
        ])
        .expect("non-empty assistant content");

        let converted: Vec<Message> = assistant_content
            .try_into()
            .expect("conversion should work");
        assert_eq!(converted.len(), 1);

        match &converted[0] {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content,
                    &vec![AssistantContent::Text {
                        text: "visible".to_string()
                    }]
                );
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].id, "call_1");
                assert_eq!(tool_calls[0].function.name, "subtract");
                assert_eq!(
                    tool_calls[0].function.arguments,
                    serde_json::json!({"x": 2, "y": 1})
                );
            }
            _ => panic!("expected assistant message"),
        }
    }

    // `max_tokens` from the core request is carried through to the wire
    // request when set.
    #[test]
    fn test_max_tokens_is_forwarded_to_request() {
        let request = crate::completion::CompletionRequest {
            model: None,
            preamble: None,
            chat_history: crate::OneOrMany::one("Hello".into()),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: Some(4096),
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let openai_request = CompletionRequest::try_from(OpenAIRequestParams {
            model: "gpt-4o-mini".to_string(),
            request,
            strict_tools: false,
            tool_result_array_content: false,
        })
        .expect("request conversion should succeed");
        let serialized =
            serde_json::to_value(openai_request).expect("serialization should succeed");

        assert_eq!(serialized["max_tokens"], 4096);
    }

    // An unset `max_tokens` must be omitted from the serialized payload, not
    // serialized as null.
    #[test]
    fn test_max_tokens_omitted_when_none() {
        let request = crate::completion::CompletionRequest {
            model: None,
            preamble: None,
            chat_history: crate::OneOrMany::one("Hello".into()),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let openai_request = CompletionRequest::try_from(OpenAIRequestParams {
            model: "gpt-4o-mini".to_string(),
            request,
            strict_tools: false,
            tool_result_array_content: false,
        })
        .expect("request conversion should succeed");
        let serialized =
            serde_json::to_value(openai_request).expect("serialization should succeed");

        assert!(serialized.get("max_tokens").is_none());
    }

    // If filtering (e.g. of reasoning-only content) leaves no messages at
    // all, conversion must fail rather than send an empty message list.
    #[test]
    fn request_conversion_errors_when_all_messages_are_filtered() {
        let request = CoreCompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one(message::Message::Assistant {
                id: None,
                content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
            }),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let result = CompletionRequest::try_from(OpenAIRequestParams {
            model: "gpt-4o-mini".to_string(),
            request,
            strict_tools: false,
            tool_result_array_content: false,
        });

        assert!(matches!(result, Err(CompletionError::RequestError(_))));
    }

    // On the first turn of a tool-using conversation (tools present, no tool
    // results yet), `response_format` must be left out so the model is free
    // to emit a tool call instead of structured output.
    #[test]
    fn request_conversion_omits_response_format_on_initial_tool_turn() {
        let request = CoreCompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one(message::Message::user(
                "Hello, whats the weather in London?",
            )),
            documents: vec![],
            tools: vec![completion::ToolDefinition {
                name: "weather".to_string(),
                description: "Get the weather".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "city": { "type": "string" }
                    },
                    "required": ["city"]
                }),
            }],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: Some(
                serde_json::from_value(serde_json::json!({
                    "title": "WeatherResponse",
                    "type": "object",
                    "properties": {
                        "city": { "type": "string" },
                        "weather": { "type": "string" }
                    },
                    "required": ["city", "weather"]
                }))
                .expect("schema should deserialize"),
            ),
        };

        let openai_request = CompletionRequest::try_from(OpenAIRequestParams {
            model: "gpt-4o-mini".to_string(),
            request,
            strict_tools: false,
            tool_result_array_content: false,
        })
        .expect("request conversion should succeed");

        let serialized =
            serde_json::to_value(openai_request).expect("serialization should succeed");

        assert!(
            serialized.get("response_format").is_none(),
            "initial tool turn should omit response_format: {serialized:?}"
        );
    }

    // Once the history contains a tool result, the follow-up turn should
    // restore `response_format` so the final answer is structured.
    #[test]
    fn request_conversion_restores_response_format_after_tool_result() {
        let request = CoreCompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::many(vec![
                message::Message::user("Hello, whats the weather in London?"),
                message::Message::Assistant {
                    id: None,
                    content: OneOrMany::one(message::AssistantContent::tool_call(
                        "call_1",
                        "weather",
                        serde_json::json!({ "city": "London" }),
                    )),
                },
                message::Message::tool_result(
                    "call_1",
                    "The weather in London is all fire and brimstone",
                ),
            ])
            .expect("history should be non-empty"),
            documents: vec![],
            tools: vec![completion::ToolDefinition {
                name: "weather".to_string(),
                description: "Get the weather".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "city": { "type": "string" }
                    },
                    "required": ["city"]
                }),
            }],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: Some(
                serde_json::from_value(serde_json::json!({
                    "title": "WeatherResponse",
                    "type": "object",
                    "properties": {
                        "city": { "type": "string" },
                        "weather": { "type": "string" }
                    },
                    "required": ["city", "weather"]
                }))
                .expect("schema should deserialize"),
            ),
        };

        let openai_request = CompletionRequest::try_from(OpenAIRequestParams {
            model: "gpt-4o-mini".to_string(),
            request,
            strict_tools: false,
            tool_result_array_content: false,
        })
        .expect("request conversion should succeed");

        let serialized =
            serde_json::to_value(openai_request).expect("serialization should succeed");

        assert!(
            serialized.get("response_format").is_some(),
            "follow-up turn should restore response_format: {serialized:?}"
        );
    }

    // llama.cpp emits tool-call `arguments` as a raw JSON object; the
    // deserializer must accept that form.
    #[test]
    fn deserialize_llama_cpp_tool_call() {
        let request = r#"{
            "choices": [{
                "finish_reason": "tool_calls",
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "",
                    "tool_calls": [{ "type": "function", "function": { "name": "hello_world", "arguments": { "city": "Paris" } }, "id": "xxx" }]
                }
            }],
            "created": 0,
            "model": "gpt-4o-mini",
            "system_fingerprint": "fp_xxx",
            "object": "chat.completion",
            "usage": { "completion_tokens": 13, "prompt_tokens": 255, "total_tokens": 268 },
            "id": "xxx"
        }
        "#;
        let response = serde_json::from_str::<ApiResponse<CompletionResponse>>(request).unwrap();

        let ApiResponse::Ok(response) = response else {
            panic!("expected successful completion response");
        };
        assert_eq!(response.choices.len(), 1);

        let Message::Assistant { tool_calls, .. } = &response.choices[0].message else {
            panic!("expected assistant message");
        };
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "xxx");
        assert_eq!(tool_calls[0].function.name, "hello_world");
        assert_eq!(
            tool_calls[0].function.arguments,
            serde_json::json!({"city": "Paris"})
        );
    }

    // OpenAI proper emits `arguments` as a JSON-encoded string; both forms
    // must deserialize to the same structured arguments value.
    #[test]
    fn deserialize_openai_stringified_tool_call() {
        let request = r#"{
            "choices": [{
                "finish_reason": "tool_calls",
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "",
                    "tool_calls": [{ "type": "function", "function": { "name": "hello_world", "arguments": "{\"city\":\"Paris\"}" }, "id": "xxx" }]
                }
            }],
            "created": 0,
            "model": "gpt-4o-mini",
            "system_fingerprint": "fp_xxx",
            "object": "chat.completion",
            "usage": { "completion_tokens": 13, "prompt_tokens": 255, "total_tokens": 268 },
            "id": "xxx"
        }
        "#;
        let response = serde_json::from_str::<ApiResponse<CompletionResponse>>(request).unwrap();

        let ApiResponse::Ok(response) = response else {
            panic!("expected successful completion response");
        };
        assert_eq!(response.choices.len(), 1);

        let Message::Assistant { tool_calls, .. } = &response.choices[0].message else {
            panic!("expected assistant message");
        };
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "xxx");
        assert_eq!(tool_calls[0].function.name, "hello_world");
        assert_eq!(
            tool_calls[0].function.arguments,
            serde_json::json!({"city": "Paris"})
        );
    }
}