1use super::{
6 CompletionsClient as Client,
7 client::{ApiErrorResponse, ApiResponse},
8 streaming::StreamingCompletionResponse,
9};
10use crate::completion::{
11 CompletionError, CompletionRequest as CoreCompletionRequest, GetTokenUsage,
12};
13use crate::http_client::{self, HttpClientExt};
14use crate::message::{AudioMediaType, DocumentSourceKind, ImageDetail, MimeType};
15use crate::one_or_many::string_or_one_or_many;
16use crate::telemetry::{ProviderResponseExt, SpanCombinator};
17use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
18use crate::{OneOrMany, completion, json_utils, message};
19use serde::{Deserialize, Serialize};
20use std::convert::Infallible;
21use std::fmt;
22use tracing::{Instrument, Level, enabled, info_span};
23
24use std::str::FromStr;
25
26pub mod streaming;
27
// ---------------------------------------------------------------------------
// GPT-5 family
// ---------------------------------------------------------------------------

/// `gpt-5.1` completion model.
pub const GPT_5_1: &str = "gpt-5.1";
/// `gpt-5` completion model.
pub const GPT_5: &str = "gpt-5";
/// `gpt-5-mini` completion model.
pub const GPT_5_MINI: &str = "gpt-5-mini";
/// `gpt-5-nano` completion model.
pub const GPT_5_NANO: &str = "gpt-5-nano";

// ---------------------------------------------------------------------------
// GPT-4.5 / GPT-4o / GPT-4 family
// ---------------------------------------------------------------------------

/// `gpt-4.5-preview` completion model.
pub const GPT_4_5_PREVIEW: &str = "gpt-4.5-preview";
/// `gpt-4.5-preview-2025-02-27` dated snapshot.
pub const GPT_4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27";
/// `gpt-4o-2024-11-20` dated snapshot.
pub const GPT_4O_2024_11_20: &str = "gpt-4o-2024-11-20";
/// `gpt-4o` completion model.
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model.
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-2024-05-13` dated snapshot.
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model.
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` dated snapshot.
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model.
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` dated snapshot.
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` dated snapshot.
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model.
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` dated snapshot.
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model.
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` dated snapshot.
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model.
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` dated snapshot.
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";

// ---------------------------------------------------------------------------
// o-series models
// ---------------------------------------------------------------------------

/// `o4-mini-2025-04-16` dated snapshot.
pub const O4_MINI_2025_04_16: &str = "o4-mini-2025-04-16";
/// `o4-mini` completion model.
pub const O4_MINI: &str = "o4-mini";
/// `o3` completion model.
pub const O3: &str = "o3";
/// `o3-mini` completion model.
pub const O3_MINI: &str = "o3-mini";
/// `o3-mini-2025-01-31` dated snapshot.
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
/// `o1-pro` completion model.
pub const O1_PRO: &str = "o1-pro";
/// `o1` completion model.
pub const O1: &str = "o1";
/// `o1-2024-12-17` dated snapshot.
pub const O1_2024_12_17: &str = "o1-2024-12-17";
/// `o1-preview` completion model.
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` dated snapshot.
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model.
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` dated snapshot.
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";

// ---------------------------------------------------------------------------
// GPT-4.1 family
// ---------------------------------------------------------------------------

/// `gpt-4.1-mini` completion model.
pub const GPT_4_1_MINI: &str = "gpt-4.1-mini";
/// `gpt-4.1-nano` completion model.
pub const GPT_4_1_NANO: &str = "gpt-4.1-nano";
/// `gpt-4.1-2025-04-14` dated snapshot.
pub const GPT_4_1_2025_04_14: &str = "gpt-4.1-2025-04-14";
/// `gpt-4.1` completion model.
pub const GPT_4_1: &str = "gpt-4.1";
106
107impl From<ApiErrorResponse> for CompletionError {
108 fn from(err: ApiErrorResponse) -> Self {
109 CompletionError::ProviderError(err.message)
110 }
111}
112
113#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
114#[serde(tag = "role", rename_all = "lowercase")]
115pub enum Message {
116 #[serde(alias = "developer")]
117 System {
118 #[serde(deserialize_with = "string_or_one_or_many")]
119 content: OneOrMany<SystemContent>,
120 #[serde(skip_serializing_if = "Option::is_none")]
121 name: Option<String>,
122 },
123 User {
124 #[serde(deserialize_with = "string_or_one_or_many")]
125 content: OneOrMany<UserContent>,
126 #[serde(skip_serializing_if = "Option::is_none")]
127 name: Option<String>,
128 },
129 Assistant {
130 #[serde(default, deserialize_with = "json_utils::string_or_vec")]
131 content: Vec<AssistantContent>,
132 #[serde(skip_serializing_if = "Option::is_none")]
133 refusal: Option<String>,
134 #[serde(skip_serializing_if = "Option::is_none")]
135 audio: Option<AudioAssistant>,
136 #[serde(skip_serializing_if = "Option::is_none")]
137 name: Option<String>,
138 #[serde(
139 default,
140 deserialize_with = "json_utils::null_or_vec",
141 skip_serializing_if = "Vec::is_empty"
142 )]
143 tool_calls: Vec<ToolCall>,
144 },
145 #[serde(rename = "tool")]
146 ToolResult {
147 tool_call_id: String,
148 content: ToolResultContentValue,
149 },
150}
151
152impl Message {
153 pub fn system(content: &str) -> Self {
154 Message::System {
155 content: OneOrMany::one(content.to_owned().into()),
156 name: None,
157 }
158 }
159}
160
161#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
162pub struct AudioAssistant {
163 pub id: String,
164}
165
166#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
167pub struct SystemContent {
168 #[serde(default)]
169 pub r#type: SystemContentType,
170 pub text: String,
171}
172
173#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
174#[serde(rename_all = "lowercase")]
175pub enum SystemContentType {
176 #[default]
177 Text,
178}
179
180#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
181#[serde(tag = "type", rename_all = "lowercase")]
182pub enum AssistantContent {
183 Text { text: String },
184 Refusal { refusal: String },
185}
186
187impl From<AssistantContent> for completion::AssistantContent {
188 fn from(value: AssistantContent) -> Self {
189 match value {
190 AssistantContent::Text { text } => completion::AssistantContent::text(text),
191 AssistantContent::Refusal { refusal } => completion::AssistantContent::text(refusal),
192 }
193 }
194}
195
196#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
197#[serde(tag = "type", rename_all = "lowercase")]
198pub enum UserContent {
199 Text {
200 text: String,
201 },
202 #[serde(rename = "image_url")]
203 Image {
204 image_url: ImageUrl,
205 },
206 Audio {
207 input_audio: InputAudio,
208 },
209}
210
211#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
212pub struct ImageUrl {
213 pub url: String,
214 #[serde(default)]
215 pub detail: ImageDetail,
216}
217
218#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
219pub struct InputAudio {
220 pub data: String,
221 pub format: AudioMediaType,
222}
223
224#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
225pub struct ToolResultContent {
226 #[serde(default)]
227 r#type: ToolResultContentType,
228 pub text: String,
229}
230
231#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
232#[serde(rename_all = "lowercase")]
233pub enum ToolResultContentType {
234 #[default]
235 Text,
236}
237
238impl FromStr for ToolResultContent {
239 type Err = Infallible;
240
241 fn from_str(s: &str) -> Result<Self, Self::Err> {
242 Ok(s.to_owned().into())
243 }
244}
245
246impl From<String> for ToolResultContent {
247 fn from(s: String) -> Self {
248 ToolResultContent {
249 r#type: ToolResultContentType::default(),
250 text: s,
251 }
252 }
253}
254
255#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
256#[serde(untagged)]
257pub enum ToolResultContentValue {
258 Array(Vec<ToolResultContent>),
259 String(String),
260}
261
262impl ToolResultContentValue {
263 pub fn from_string(s: String, use_array_format: bool) -> Self {
264 if use_array_format {
265 ToolResultContentValue::Array(vec![ToolResultContent::from(s)])
266 } else {
267 ToolResultContentValue::String(s)
268 }
269 }
270
271 pub fn as_text(&self) -> String {
272 match self {
273 ToolResultContentValue::Array(arr) => arr
274 .iter()
275 .map(|c| c.text.clone())
276 .collect::<Vec<_>>()
277 .join("\n"),
278 ToolResultContentValue::String(s) => s.clone(),
279 }
280 }
281
282 pub fn to_array(&self) -> Self {
283 match self {
284 ToolResultContentValue::Array(_) => self.clone(),
285 ToolResultContentValue::String(s) => {
286 ToolResultContentValue::Array(vec![ToolResultContent::from(s.clone())])
287 }
288 }
289 }
290}
291
292#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
293pub struct ToolCall {
294 pub id: String,
295 #[serde(default)]
296 pub r#type: ToolType,
297 pub function: Function,
298}
299
300#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
301#[serde(rename_all = "lowercase")]
302pub enum ToolType {
303 #[default]
304 Function,
305}
306
307#[derive(Debug, Deserialize, Serialize, Clone)]
309pub struct FunctionDefinition {
310 pub name: String,
311 pub description: String,
312 pub parameters: serde_json::Value,
313 #[serde(skip_serializing_if = "Option::is_none")]
314 pub strict: Option<bool>,
315}
316
317#[derive(Debug, Deserialize, Serialize, Clone)]
318pub struct ToolDefinition {
319 pub r#type: String,
320 pub function: FunctionDefinition,
321}
322
323impl From<completion::ToolDefinition> for ToolDefinition {
324 fn from(tool: completion::ToolDefinition) -> Self {
325 Self {
326 r#type: "function".into(),
327 function: FunctionDefinition {
328 name: tool.name,
329 description: tool.description,
330 parameters: tool.parameters,
331 strict: None,
332 },
333 }
334 }
335}
336
337impl ToolDefinition {
338 pub fn with_strict(mut self) -> Self {
341 self.function.strict = Some(true);
342 super::sanitize_schema(&mut self.function.parameters);
343 self
344 }
345}
346
347#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
348#[serde(rename_all = "snake_case")]
349pub enum ToolChoice {
350 #[default]
351 Auto,
352 None,
353 Required,
354}
355
356impl TryFrom<crate::message::ToolChoice> for ToolChoice {
357 type Error = CompletionError;
358 fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
359 let res = match value {
360 message::ToolChoice::Specific { .. } => {
361 return Err(CompletionError::ProviderError(
362 "Provider doesn't support only using specific tools".to_string(),
363 ));
364 }
365 message::ToolChoice::Auto => Self::Auto,
366 message::ToolChoice::None => Self::None,
367 message::ToolChoice::Required => Self::Required,
368 };
369
370 Ok(res)
371 }
372}
373
374#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
375pub struct Function {
376 pub name: String,
377 #[serde(with = "json_utils::stringified_json")]
378 pub arguments: serde_json::Value,
379}
380
381impl TryFrom<message::ToolResult> for Message {
382 type Error = message::MessageError;
383
384 fn try_from(value: message::ToolResult) -> Result<Self, Self::Error> {
385 let text = value
386 .content
387 .into_iter()
388 .map(|content| {
389 match content {
390 message::ToolResultContent::Text(message::Text { text }) => Ok(text),
391 message::ToolResultContent::Image(_) => Err(message::MessageError::ConversionError(
392 "OpenAI does not support images in tool results. Tool results must be text."
393 .into(),
394 )),
395 }
396 })
397 .collect::<Result<Vec<_>, _>>()?
398 .join("\n");
399
400 Ok(Message::ToolResult {
401 tool_call_id: value.id,
402 content: ToolResultContentValue::String(text),
403 })
404 }
405}
406
407impl TryFrom<message::UserContent> for UserContent {
408 type Error = message::MessageError;
409
410 fn try_from(value: message::UserContent) -> Result<Self, Self::Error> {
411 match value {
412 message::UserContent::Text(message::Text { text }) => Ok(UserContent::Text { text }),
413 message::UserContent::Image(message::Image {
414 data,
415 detail,
416 media_type,
417 ..
418 }) => match data {
419 DocumentSourceKind::Url(url) => Ok(UserContent::Image {
420 image_url: ImageUrl {
421 url,
422 detail: detail.unwrap_or_default(),
423 },
424 }),
425 DocumentSourceKind::Base64(data) => {
426 let url = format!(
427 "data:{};base64,{}",
428 media_type.map(|i| i.to_mime_type()).ok_or(
429 message::MessageError::ConversionError(
430 "OpenAI Image URI must have media type".into()
431 )
432 )?,
433 data
434 );
435
436 let detail = detail.ok_or(message::MessageError::ConversionError(
437 "OpenAI image URI must have image detail".into(),
438 ))?;
439
440 Ok(UserContent::Image {
441 image_url: ImageUrl { url, detail },
442 })
443 }
444 DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
445 "Raw files not supported, encode as base64 first".into(),
446 )),
447 DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
448 "Document has no body".into(),
449 )),
450 doc => Err(message::MessageError::ConversionError(format!(
451 "Unsupported document type: {doc:?}"
452 ))),
453 },
454 message::UserContent::Document(message::Document { data, .. }) => {
455 if let DocumentSourceKind::Base64(text) | DocumentSourceKind::String(text) = data {
456 Ok(UserContent::Text { text })
457 } else {
458 Err(message::MessageError::ConversionError(
459 "Documents must be base64 or a string".into(),
460 ))
461 }
462 }
463 message::UserContent::Audio(message::Audio {
464 data, media_type, ..
465 }) => match data {
466 DocumentSourceKind::Base64(data) => Ok(UserContent::Audio {
467 input_audio: InputAudio {
468 data,
469 format: match media_type {
470 Some(media_type) => media_type,
471 None => AudioMediaType::MP3,
472 },
473 },
474 }),
475 DocumentSourceKind::Url(_) => Err(message::MessageError::ConversionError(
476 "URLs are not supported for audio".into(),
477 )),
478 DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
479 "Raw files are not supported for audio".into(),
480 )),
481 DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
482 "Audio has no body".into(),
483 )),
484 audio => Err(message::MessageError::ConversionError(format!(
485 "Unsupported audio type: {audio:?}"
486 ))),
487 },
488 message::UserContent::ToolResult(_) => Err(message::MessageError::ConversionError(
489 "Tool result is in unsupported format".into(),
490 )),
491 message::UserContent::Video(_) => Err(message::MessageError::ConversionError(
492 "Video is in unsupported format".into(),
493 )),
494 }
495 }
496}
497
498impl TryFrom<OneOrMany<message::UserContent>> for Vec<Message> {
499 type Error = message::MessageError;
500
501 fn try_from(value: OneOrMany<message::UserContent>) -> Result<Self, Self::Error> {
502 let (tool_results, other_content): (Vec<_>, Vec<_>) = value
503 .into_iter()
504 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
505
506 if !tool_results.is_empty() {
509 tool_results
510 .into_iter()
511 .map(|content| match content {
512 message::UserContent::ToolResult(tool_result) => tool_result.try_into(),
513 _ => unreachable!(),
514 })
515 .collect::<Result<Vec<_>, _>>()
516 } else {
517 let other_content: Vec<UserContent> = other_content
518 .into_iter()
519 .map(|content| content.try_into())
520 .collect::<Result<Vec<_>, _>>()?;
521
522 let other_content = OneOrMany::many(other_content)
523 .expect("There must be other content here if there were no tool result content");
524
525 Ok(vec![Message::User {
526 content: other_content,
527 name: None,
528 }])
529 }
530 }
531}
532
533impl TryFrom<OneOrMany<message::AssistantContent>> for Vec<Message> {
534 type Error = message::MessageError;
535
536 fn try_from(value: OneOrMany<message::AssistantContent>) -> Result<Self, Self::Error> {
537 let (text_content, tool_calls) = value.into_iter().fold(
538 (Vec::new(), Vec::new()),
539 |(mut texts, mut tools), content| {
540 match content {
541 message::AssistantContent::Text(text) => texts.push(text),
542 message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
543 message::AssistantContent::Reasoning(_) => {
544 panic!("The OpenAI Completions API doesn't support reasoning!");
545 }
546 message::AssistantContent::Image(_) => {
547 panic!(
548 "The OpenAI Completions API doesn't support image content in assistant messages!"
549 );
550 }
551 }
552 (texts, tools)
553 },
554 );
555
556 Ok(vec![Message::Assistant {
559 content: text_content
560 .into_iter()
561 .map(|content| content.text.into())
562 .collect::<Vec<_>>(),
563 refusal: None,
564 audio: None,
565 name: None,
566 tool_calls: tool_calls
567 .into_iter()
568 .map(|tool_call| tool_call.into())
569 .collect::<Vec<_>>(),
570 }])
571 }
572}
573
574impl TryFrom<message::Message> for Vec<Message> {
575 type Error = message::MessageError;
576
577 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
578 match message {
579 message::Message::User { content } => content.try_into(),
580 message::Message::Assistant { content, .. } => content.try_into(),
581 }
582 }
583}
584
585impl From<message::ToolCall> for ToolCall {
586 fn from(tool_call: message::ToolCall) -> Self {
587 Self {
588 id: tool_call.id,
589 r#type: ToolType::default(),
590 function: Function {
591 name: tool_call.function.name,
592 arguments: tool_call.function.arguments,
593 },
594 }
595 }
596}
597
598impl From<ToolCall> for message::ToolCall {
599 fn from(tool_call: ToolCall) -> Self {
600 Self {
601 id: tool_call.id,
602 call_id: None,
603 function: message::ToolFunction {
604 name: tool_call.function.name,
605 arguments: tool_call.function.arguments,
606 },
607 signature: None,
608 additional_params: None,
609 }
610 }
611}
612
613impl TryFrom<Message> for message::Message {
614 type Error = message::MessageError;
615
616 fn try_from(message: Message) -> Result<Self, Self::Error> {
617 Ok(match message {
618 Message::User { content, .. } => message::Message::User {
619 content: content.map(|content| content.into()),
620 },
621 Message::Assistant {
622 content,
623 tool_calls,
624 ..
625 } => {
626 let mut content = content
627 .into_iter()
628 .map(|content| match content {
629 AssistantContent::Text { text } => message::AssistantContent::text(text),
630
631 AssistantContent::Refusal { refusal } => {
634 message::AssistantContent::text(refusal)
635 }
636 })
637 .collect::<Vec<_>>();
638
639 content.extend(
640 tool_calls
641 .into_iter()
642 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
643 .collect::<Result<Vec<_>, _>>()?,
644 );
645
646 message::Message::Assistant {
647 id: None,
648 content: OneOrMany::many(content).map_err(|_| {
649 message::MessageError::ConversionError(
650 "Neither `content` nor `tool_calls` was provided to the Message"
651 .to_owned(),
652 )
653 })?,
654 }
655 }
656
657 Message::ToolResult {
658 tool_call_id,
659 content,
660 } => message::Message::User {
661 content: OneOrMany::one(message::UserContent::tool_result(
662 tool_call_id,
663 OneOrMany::one(message::ToolResultContent::text(content.as_text())),
664 )),
665 },
666
667 Message::System { content, .. } => message::Message::User {
670 content: content.map(|content| message::UserContent::text(content.text)),
671 },
672 })
673 }
674}
675
676impl From<UserContent> for message::UserContent {
677 fn from(content: UserContent) -> Self {
678 match content {
679 UserContent::Text { text } => message::UserContent::text(text),
680 UserContent::Image { image_url } => {
681 message::UserContent::image_url(image_url.url, None, Some(image_url.detail))
682 }
683 UserContent::Audio { input_audio } => {
684 message::UserContent::audio(input_audio.data, Some(input_audio.format))
685 }
686 }
687 }
688}
689
690impl From<String> for UserContent {
691 fn from(s: String) -> Self {
692 UserContent::Text { text: s }
693 }
694}
695
696impl FromStr for UserContent {
697 type Err = Infallible;
698
699 fn from_str(s: &str) -> Result<Self, Self::Err> {
700 Ok(UserContent::Text {
701 text: s.to_string(),
702 })
703 }
704}
705
706impl From<String> for AssistantContent {
707 fn from(s: String) -> Self {
708 AssistantContent::Text { text: s }
709 }
710}
711
712impl FromStr for AssistantContent {
713 type Err = Infallible;
714
715 fn from_str(s: &str) -> Result<Self, Self::Err> {
716 Ok(AssistantContent::Text {
717 text: s.to_string(),
718 })
719 }
720}
721impl From<String> for SystemContent {
722 fn from(s: String) -> Self {
723 SystemContent {
724 r#type: SystemContentType::default(),
725 text: s,
726 }
727 }
728}
729
730impl FromStr for SystemContent {
731 type Err = Infallible;
732
733 fn from_str(s: &str) -> Result<Self, Self::Err> {
734 Ok(SystemContent {
735 r#type: SystemContentType::default(),
736 text: s.to_string(),
737 })
738 }
739}
740
741#[derive(Debug, Deserialize, Serialize)]
742pub struct CompletionResponse {
743 pub id: String,
744 pub object: String,
745 pub created: u64,
746 pub model: String,
747 pub system_fingerprint: Option<String>,
748 pub choices: Vec<Choice>,
749 pub usage: Option<Usage>,
750}
751
752impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
753 type Error = CompletionError;
754
755 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
756 let choice = response.choices.first().ok_or_else(|| {
757 CompletionError::ResponseError("Response contained no choices".to_owned())
758 })?;
759
760 let content = match &choice.message {
761 Message::Assistant {
762 content,
763 tool_calls,
764 ..
765 } => {
766 let mut content = content
767 .iter()
768 .filter_map(|c| {
769 let s = match c {
770 AssistantContent::Text { text } => text,
771 AssistantContent::Refusal { refusal } => refusal,
772 };
773 if s.is_empty() {
774 None
775 } else {
776 Some(completion::AssistantContent::text(s))
777 }
778 })
779 .collect::<Vec<_>>();
780
781 content.extend(
782 tool_calls
783 .iter()
784 .map(|call| {
785 completion::AssistantContent::tool_call(
786 &call.id,
787 &call.function.name,
788 call.function.arguments.clone(),
789 )
790 })
791 .collect::<Vec<_>>(),
792 );
793 Ok(content)
794 }
795 _ => Err(CompletionError::ResponseError(
796 "Response did not contain a valid message or tool call".into(),
797 )),
798 }?;
799
800 let choice = OneOrMany::many(content).map_err(|_| {
801 CompletionError::ResponseError(
802 "Response contained no message or tool call (empty)".to_owned(),
803 )
804 })?;
805
806 let usage = response
807 .usage
808 .as_ref()
809 .map(|usage| completion::Usage {
810 input_tokens: usage.prompt_tokens as u64,
811 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
812 total_tokens: usage.total_tokens as u64,
813 cached_input_tokens: usage
814 .prompt_tokens_details
815 .as_ref()
816 .map(|d| d.cached_tokens as u64)
817 .unwrap_or(0),
818 })
819 .unwrap_or_default();
820
821 Ok(completion::CompletionResponse {
822 choice,
823 usage,
824 raw_response: response,
825 })
826 }
827}
828
829impl ProviderResponseExt for CompletionResponse {
830 type OutputMessage = Choice;
831 type Usage = Usage;
832
833 fn get_response_id(&self) -> Option<String> {
834 Some(self.id.to_owned())
835 }
836
837 fn get_response_model_name(&self) -> Option<String> {
838 Some(self.model.to_owned())
839 }
840
841 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
842 self.choices.clone()
843 }
844
845 fn get_text_response(&self) -> Option<String> {
846 let Message::User { ref content, .. } = self.choices.last()?.message.clone() else {
847 return None;
848 };
849
850 let UserContent::Text { text } = content.first() else {
851 return None;
852 };
853
854 Some(text)
855 }
856
857 fn get_usage(&self) -> Option<Self::Usage> {
858 self.usage.clone()
859 }
860}
861
862#[derive(Clone, Debug, Serialize, Deserialize)]
863pub struct Choice {
864 pub index: usize,
865 pub message: Message,
866 pub logprobs: Option<serde_json::Value>,
867 pub finish_reason: String,
868}
869
870#[derive(Clone, Debug, Deserialize, Serialize, Default)]
871pub struct PromptTokensDetails {
872 #[serde(default)]
874 pub cached_tokens: usize,
875}
876
877#[derive(Clone, Debug, Deserialize, Serialize)]
878pub struct Usage {
879 pub prompt_tokens: usize,
880 pub total_tokens: usize,
881 #[serde(skip_serializing_if = "Option::is_none")]
882 pub prompt_tokens_details: Option<PromptTokensDetails>,
883}
884
885impl Usage {
886 pub fn new() -> Self {
887 Self {
888 prompt_tokens: 0,
889 total_tokens: 0,
890 prompt_tokens_details: None,
891 }
892 }
893}
894
895impl Default for Usage {
896 fn default() -> Self {
897 Self::new()
898 }
899}
900
901impl fmt::Display for Usage {
902 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
903 let Usage {
904 prompt_tokens,
905 total_tokens,
906 ..
907 } = self;
908 write!(
909 f,
910 "Prompt tokens: {prompt_tokens} Total tokens: {total_tokens}"
911 )
912 }
913}
914
915impl GetTokenUsage for Usage {
916 fn token_usage(&self) -> Option<crate::completion::Usage> {
917 let mut usage = crate::completion::Usage::new();
918 usage.input_tokens = self.prompt_tokens as u64;
919 usage.output_tokens = (self.total_tokens - self.prompt_tokens) as u64;
920 usage.total_tokens = self.total_tokens as u64;
921 usage.cached_input_tokens = self
922 .prompt_tokens_details
923 .as_ref()
924 .map(|d| d.cached_tokens as u64)
925 .unwrap_or(0);
926
927 Some(usage)
928 }
929}
930
931#[derive(Clone)]
932pub struct CompletionModel<T = reqwest::Client> {
933 pub(crate) client: Client<T>,
934 pub model: String,
935 pub strict_tools: bool,
936 pub tool_result_array_content: bool,
937}
938
939impl<T> CompletionModel<T>
940where
941 T: Default + std::fmt::Debug + Clone + 'static,
942{
943 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
944 Self {
945 client,
946 model: model.into(),
947 strict_tools: false,
948 tool_result_array_content: false,
949 }
950 }
951
952 pub fn with_model(client: Client<T>, model: &str) -> Self {
953 Self {
954 client,
955 model: model.into(),
956 strict_tools: false,
957 tool_result_array_content: false,
958 }
959 }
960
961 pub fn with_strict_tools(mut self) -> Self {
970 self.strict_tools = true;
971 self
972 }
973
974 pub fn with_tool_result_array_content(mut self) -> Self {
975 self.tool_result_array_content = true;
976 self
977 }
978}
979
980#[derive(Debug, Serialize, Deserialize, Clone)]
981pub struct CompletionRequest {
982 model: String,
983 messages: Vec<Message>,
984 #[serde(skip_serializing_if = "Vec::is_empty")]
985 tools: Vec<ToolDefinition>,
986 #[serde(skip_serializing_if = "Option::is_none")]
987 tool_choice: Option<ToolChoice>,
988 #[serde(skip_serializing_if = "Option::is_none")]
989 temperature: Option<f64>,
990 #[serde(flatten)]
991 additional_params: Option<serde_json::Value>,
992}
993
994pub struct OpenAIRequestParams {
995 pub model: String,
996 pub request: CoreCompletionRequest,
997 pub strict_tools: bool,
998 pub tool_result_array_content: bool,
999}
1000
1001impl TryFrom<OpenAIRequestParams> for CompletionRequest {
1002 type Error = CompletionError;
1003
1004 fn try_from(params: OpenAIRequestParams) -> Result<Self, Self::Error> {
1005 let OpenAIRequestParams {
1006 model,
1007 request: req,
1008 strict_tools,
1009 tool_result_array_content,
1010 } = params;
1011
1012 let mut partial_history = vec![];
1013 if let Some(docs) = req.normalized_documents() {
1014 partial_history.push(docs);
1015 }
1016 let CoreCompletionRequest {
1017 preamble,
1018 chat_history,
1019 tools,
1020 temperature,
1021 additional_params,
1022 tool_choice,
1023 ..
1024 } = req;
1025
1026 partial_history.extend(chat_history);
1027
1028 let mut full_history: Vec<Message> =
1029 preamble.map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
1030
1031 full_history.extend(
1032 partial_history
1033 .into_iter()
1034 .map(message::Message::try_into)
1035 .collect::<Result<Vec<Vec<Message>>, _>>()?
1036 .into_iter()
1037 .flatten()
1038 .collect::<Vec<_>>(),
1039 );
1040
1041 if tool_result_array_content {
1042 for msg in &mut full_history {
1043 if let Message::ToolResult { content, .. } = msg {
1044 *content = content.to_array();
1045 }
1046 }
1047 }
1048
1049 let tool_choice = tool_choice.map(ToolChoice::try_from).transpose()?;
1050
1051 let tools: Vec<ToolDefinition> = tools
1052 .into_iter()
1053 .map(|tool| {
1054 let def = ToolDefinition::from(tool);
1055 if strict_tools { def.with_strict() } else { def }
1056 })
1057 .collect();
1058
1059 let res = Self {
1060 model,
1061 messages: full_history,
1062 tools,
1063 tool_choice,
1064 temperature,
1065 additional_params,
1066 };
1067
1068 Ok(res)
1069 }
1070}
1071
1072impl TryFrom<(String, CoreCompletionRequest)> for CompletionRequest {
1073 type Error = CompletionError;
1074
1075 fn try_from((model, req): (String, CoreCompletionRequest)) -> Result<Self, Self::Error> {
1076 CompletionRequest::try_from(OpenAIRequestParams {
1077 model,
1078 request: req,
1079 strict_tools: false,
1080 tool_result_array_content: false,
1081 })
1082 }
1083}
1084
1085impl crate::telemetry::ProviderRequestExt for CompletionRequest {
1086 type InputMessage = Message;
1087
1088 fn get_input_messages(&self) -> Vec<Self::InputMessage> {
1089 self.messages.clone()
1090 }
1091
1092 fn get_system_prompt(&self) -> Option<String> {
1093 let first_message = self.messages.first()?;
1094
1095 let Message::System { ref content, .. } = first_message.clone() else {
1096 return None;
1097 };
1098
1099 let SystemContent { text, .. } = content.first();
1100
1101 Some(text)
1102 }
1103
1104 fn get_prompt(&self) -> Option<String> {
1105 let last_message = self.messages.last()?;
1106
1107 let Message::User { ref content, .. } = last_message.clone() else {
1108 return None;
1109 };
1110
1111 let UserContent::Text { text } = content.first() else {
1112 return None;
1113 };
1114
1115 Some(text)
1116 }
1117
1118 fn get_model_name(&self) -> String {
1119 self.model.clone()
1120 }
1121}
1122
1123impl CompletionModel<reqwest::Client> {
1124 pub fn into_agent_builder(self) -> crate::agent::AgentBuilder<Self> {
1125 crate::agent::AgentBuilder::new(self)
1126 }
1127}
1128
1129impl<T> completion::CompletionModel for CompletionModel<T>
1130where
1131 T: HttpClientExt
1132 + Default
1133 + std::fmt::Debug
1134 + Clone
1135 + WasmCompatSend
1136 + WasmCompatSync
1137 + 'static,
1138{
1139 type Response = CompletionResponse;
1140 type StreamingResponse = StreamingCompletionResponse;
1141
1142 type Client = super::CompletionsClient<T>;
1143
1144 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
1145 Self::new(client.clone(), model)
1146 }
1147
1148 async fn completion(
1149 &self,
1150 completion_request: CoreCompletionRequest,
1151 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
1152 let span = if tracing::Span::current().is_disabled() {
1153 info_span!(
1154 target: "rig::completions",
1155 "chat",
1156 gen_ai.operation.name = "chat",
1157 gen_ai.provider.name = "openai",
1158 gen_ai.request.model = self.model,
1159 gen_ai.system_instructions = &completion_request.preamble,
1160 gen_ai.response.id = tracing::field::Empty,
1161 gen_ai.response.model = tracing::field::Empty,
1162 gen_ai.usage.output_tokens = tracing::field::Empty,
1163 gen_ai.usage.input_tokens = tracing::field::Empty,
1164 )
1165 } else {
1166 tracing::Span::current()
1167 };
1168
1169 let request = CompletionRequest::try_from(OpenAIRequestParams {
1170 model: self.model.to_owned(),
1171 request: completion_request,
1172 strict_tools: self.strict_tools,
1173 tool_result_array_content: self.tool_result_array_content,
1174 })?;
1175
1176 if enabled!(Level::TRACE) {
1177 tracing::trace!(
1178 target: "rig::completions",
1179 "OpenAI Chat Completions completion request: {}",
1180 serde_json::to_string_pretty(&request)?
1181 );
1182 }
1183
1184 let body = serde_json::to_vec(&request)?;
1185
1186 let req = self
1187 .client
1188 .post("/chat/completions")?
1189 .body(body)
1190 .map_err(|e| CompletionError::HttpError(e.into()))?;
1191
1192 async move {
1193 let response = self.client.send(req).await?;
1194
1195 if response.status().is_success() {
1196 let text = http_client::text(response).await?;
1197
1198 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
1199 ApiResponse::Ok(response) => {
1200 let span = tracing::Span::current();
1201 span.record_response_metadata(&response);
1202 span.record_token_usage(&response.usage);
1203
1204 if enabled!(Level::TRACE) {
1205 tracing::trace!(
1206 target: "rig::completions",
1207 "OpenAI Chat Completions completion response: {}",
1208 serde_json::to_string_pretty(&response)?
1209 );
1210 }
1211
1212 response.try_into()
1213 }
1214 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
1215 }
1216 } else {
1217 let text = http_client::text(response).await?;
1218 Err(CompletionError::ProviderError(text))
1219 }
1220 }
1221 .instrument(span)
1222 .await
1223 }
1224
1225 async fn stream(
1226 &self,
1227 request: CoreCompletionRequest,
1228 ) -> Result<
1229 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
1230 CompletionError,
1231 > {
1232 Self::stream(self, request).await
1233 }
1234}