1use super::{
6 CompletionsClient as Client,
7 client::{ApiErrorResponse, ApiResponse},
8 streaming::StreamingCompletionResponse,
9};
10use crate::completion::{
11 CompletionError, CompletionRequest as CoreCompletionRequest, GetTokenUsage,
12};
13use crate::http_client::{self, HttpClientExt};
14use crate::message::{AudioMediaType, DocumentSourceKind, ImageDetail, MimeType};
15use crate::one_or_many::string_or_one_or_many;
16use crate::telemetry::{ProviderResponseExt, SpanCombinator};
17use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
18use crate::{OneOrMany, completion, json_utils, message};
19use serde::{Deserialize, Serialize};
20use std::convert::Infallible;
21use std::fmt;
22use tracing::{Instrument, Level, enabled, info_span};
23
24use std::str::FromStr;
25
26pub mod streaming;
27
// Model identifiers accepted by the Chat Completions endpoint. Each constant is the exact
// string sent in the request's `model` field.

// --- GPT-5 family ---

/// `gpt-5.1` model.
pub const GPT_5_1: &str = "gpt-5.1";

/// `gpt-5` model.
pub const GPT_5: &str = "gpt-5";
/// `gpt-5-mini` model.
pub const GPT_5_MINI: &str = "gpt-5-mini";
/// `gpt-5-nano` model.
pub const GPT_5_NANO: &str = "gpt-5-nano";

// --- GPT-4.5 / GPT-4o / GPT-4 family ---

/// `gpt-4.5-preview` model.
pub const GPT_4_5_PREVIEW: &str = "gpt-4.5-preview";
/// `gpt-4.5-preview-2025-02-27` dated snapshot.
pub const GPT_4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27";
/// `gpt-4o-2024-11-20` dated snapshot.
pub const GPT_4O_2024_11_20: &str = "gpt-4o-2024-11-20";
/// `gpt-4o` model.
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` model.
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-2024-05-13` dated snapshot.
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` model.
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` dated snapshot.
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` model.
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` model.
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` model.
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` model.
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` model.
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` model.
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` dated snapshot.
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` model.
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` dated snapshot.
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";

// --- o-series models ---

/// `o4-mini-2025-04-16` dated snapshot.
pub const O4_MINI_2025_04_16: &str = "o4-mini-2025-04-16";
/// `o4-mini` model.
pub const O4_MINI: &str = "o4-mini";
/// `o3` model.
pub const O3: &str = "o3";
/// `o3-mini` model.
pub const O3_MINI: &str = "o3-mini";
/// `o3-mini-2025-01-31` dated snapshot.
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
/// `o1-pro` model.
pub const O1_PRO: &str = "o1-pro";
/// `o1` model.
pub const O1: &str = "o1";
/// `o1-2024-12-17` dated snapshot.
pub const O1_2024_12_17: &str = "o1-2024-12-17";
/// `o1-preview` model.
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` dated snapshot.
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` model.
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` dated snapshot.
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";

// --- GPT-4.1 family ---

/// `gpt-4.1-mini` model.
pub const GPT_4_1_MINI: &str = "gpt-4.1-mini";
/// `gpt-4.1-nano` model.
pub const GPT_4_1_NANO: &str = "gpt-4.1-nano";
/// `gpt-4.1-2025-04-14` dated snapshot.
pub const GPT_4_1_2025_04_14: &str = "gpt-4.1-2025-04-14";
/// `gpt-4.1` model.
pub const GPT_4_1: &str = "gpt-4.1";
106
107impl From<ApiErrorResponse> for CompletionError {
108 fn from(err: ApiErrorResponse) -> Self {
109 CompletionError::ProviderError(err.message)
110 }
111}
112
/// A Chat Completions message, internally tagged on the wire by its `role` field.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// `role: "system"`; messages with `role: "developer"` also deserialize into this variant.
    #[serde(alias = "developer")]
    System {
        // Custom deserializer; judging by its name it accepts a bare string as well as one or
        // many content parts (NOTE(review): confirm in `crate::one_or_many`).
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// `role: "user"`.
    User {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<UserContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// `role: "assistant"`; may carry text parts, tool calls, or both.
    Assistant {
        // A missing `content` key deserializes to an empty list.
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<AssistantContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<AudioAssistant>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // Missing (and, per the deserializer's name, `null`) becomes empty; the key is
        // omitted from serialized output when there are no tool calls.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    /// `role: "tool"`: the result of the tool call identified by `tool_call_id`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: ToolResultContentValue,
    },
}
151
152impl Message {
153 pub fn system(content: &str) -> Self {
154 Message::System {
155 content: OneOrMany::one(content.to_owned().into()),
156 name: None,
157 }
158 }
159}
160
/// Reference to audio attached to an assistant message, identified only by `id`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AudioAssistant {
    pub id: String,
}
165
/// One text part of a system message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    // Defaults to `text` when the `type` key is absent on deserialization.
    #[serde(default)]
    pub r#type: SystemContentType,
    pub text: String,
}
172
/// Part type for [`SystemContent`]; only `"text"` exists.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
179
/// One content part of an assistant message, tagged by `type` (`"text"` or `"refusal"`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
    Refusal { refusal: String },
}
186
187impl From<AssistantContent> for completion::AssistantContent {
188 fn from(value: AssistantContent) -> Self {
189 match value {
190 AssistantContent::Text { text } => completion::AssistantContent::text(text),
191 AssistantContent::Refusal { refusal } => completion::AssistantContent::text(refusal),
192 }
193 }
194}
195
196#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
197#[serde(tag = "type", rename_all = "lowercase")]
198pub enum UserContent {
199 Text {
200 text: String,
201 },
202 #[serde(rename = "image_url")]
203 Image {
204 image_url: ImageUrl,
205 },
206 Audio {
207 input_audio: InputAudio,
208 },
209}
210
/// Payload of an `image_url` user-content part.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    /// An http(s) URL or a `data:` URI (see the `TryFrom<message::UserContent>` conversion).
    pub url: String,
    // Falls back to `ImageDetail::default()` when absent on deserialization.
    #[serde(default)]
    pub detail: ImageDetail,
}
217
/// Payload of an audio user-content part.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct InputAudio {
    /// Base64-encoded audio (only base64 sources are converted into this type).
    pub data: String,
    pub format: AudioMediaType,
}
223
/// One text part of a tool result's array-form content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    // Defaults to `text` when the `type` key is absent on deserialization.
    #[serde(default)]
    r#type: ToolResultContentType,
    pub text: String,
}
230
/// Part type for [`ToolResultContent`]; only `"text"` exists.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    #[default]
    Text,
}
237
238impl FromStr for ToolResultContent {
239 type Err = Infallible;
240
241 fn from_str(s: &str) -> Result<Self, Self::Err> {
242 Ok(s.to_owned().into())
243 }
244}
245
246impl From<String> for ToolResultContent {
247 fn from(s: String) -> Self {
248 ToolResultContent {
249 r#type: ToolResultContentType::default(),
250 text: s,
251 }
252 }
253}
254
/// Tool-result message content: either an array of text parts or a bare string.
///
/// `untagged`: deserialization tries `Array` first, then `String`; serialization emits the
/// inner value directly.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum ToolResultContentValue {
    Array(Vec<ToolResultContent>),
    String(String),
}
261
262impl ToolResultContentValue {
263 pub fn from_string(s: String, use_array_format: bool) -> Self {
264 if use_array_format {
265 ToolResultContentValue::Array(vec![ToolResultContent::from(s)])
266 } else {
267 ToolResultContentValue::String(s)
268 }
269 }
270
271 pub fn as_text(&self) -> String {
272 match self {
273 ToolResultContentValue::Array(arr) => arr
274 .iter()
275 .map(|c| c.text.clone())
276 .collect::<Vec<_>>()
277 .join("\n"),
278 ToolResultContentValue::String(s) => s.clone(),
279 }
280 }
281
282 pub fn to_array(&self) -> Self {
283 match self {
284 ToolResultContentValue::Array(_) => self.clone(),
285 ToolResultContentValue::String(s) => {
286 ToolResultContentValue::Array(vec![ToolResultContent::from(s.clone())])
287 }
288 }
289 }
290}
291
/// A tool call emitted by the assistant.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    // Defaults to `function` when absent on deserialization.
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}
299
/// Tool-call kind; only `"function"` exists.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
306
/// Function description sent inside a [`ToolDefinition`].
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: String,
    /// JSON schema of the function's arguments.
    pub parameters: serde_json::Value,
    // Omitted when `None`; set to `Some(true)` by `ToolDefinition::with_strict`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
316
/// A tool entry in the request's `tools` array.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    /// Tool kind; conversions in this module always set `"function"`.
    pub r#type: String,
    pub function: FunctionDefinition,
}
322
323impl From<completion::ToolDefinition> for ToolDefinition {
324 fn from(tool: completion::ToolDefinition) -> Self {
325 Self {
326 r#type: "function".into(),
327 function: FunctionDefinition {
328 name: tool.name,
329 description: tool.description,
330 parameters: tool.parameters,
331 strict: None,
332 },
333 }
334 }
335}
336
337impl ToolDefinition {
338 pub fn with_strict(mut self) -> Self {
341 self.function.strict = Some(true);
342 super::sanitize_schema(&mut self.function.parameters);
343 self
344 }
345}
346
/// Value of the request's `tool_choice` field, serialized as `"auto"`, `"none"` or `"required"`.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    None,
    Required,
}
355
356impl TryFrom<crate::message::ToolChoice> for ToolChoice {
357 type Error = CompletionError;
358 fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
359 let res = match value {
360 message::ToolChoice::Specific { .. } => {
361 return Err(CompletionError::ProviderError(
362 "Provider doesn't support only using specific tools".to_string(),
363 ));
364 }
365 message::ToolChoice::Auto => Self::Auto,
366 message::ToolChoice::None => Self::None,
367 message::ToolChoice::Required => Self::Required,
368 };
369
370 Ok(res)
371 }
372}
373
/// Function name and arguments of a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    // Carried on the wire as a JSON-encoded string and held here as parsed JSON
    // (handled by `json_utils::stringified_json`).
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
380
381impl TryFrom<message::ToolResult> for Message {
382 type Error = message::MessageError;
383
384 fn try_from(value: message::ToolResult) -> Result<Self, Self::Error> {
385 let text = value
386 .content
387 .into_iter()
388 .map(|content| match content {
389 message::ToolResultContent::Text(message::Text { text }) => Ok(text),
390 _ => Err(message::MessageError::ConversionError(
391 "Tool result content does not support non-text".into(),
392 )),
393 })
394 .collect::<Result<Vec<_>, _>>()?
395 .join("\n");
396
397 Ok(Message::ToolResult {
398 tool_call_id: value.id,
399 content: ToolResultContentValue::String(text),
400 })
401 }
402}
403
404impl TryFrom<message::UserContent> for UserContent {
405 type Error = message::MessageError;
406
407 fn try_from(value: message::UserContent) -> Result<Self, Self::Error> {
408 match value {
409 message::UserContent::Text(message::Text { text }) => Ok(UserContent::Text { text }),
410 message::UserContent::Image(message::Image {
411 data,
412 detail,
413 media_type,
414 ..
415 }) => match data {
416 DocumentSourceKind::Url(url) => Ok(UserContent::Image {
417 image_url: ImageUrl {
418 url,
419 detail: detail.unwrap_or_default(),
420 },
421 }),
422 DocumentSourceKind::Base64(data) => {
423 let url = format!(
424 "data:{};base64,{}",
425 media_type.map(|i| i.to_mime_type()).ok_or(
426 message::MessageError::ConversionError(
427 "OpenAI Image URI must have media type".into()
428 )
429 )?,
430 data
431 );
432
433 let detail = detail.ok_or(message::MessageError::ConversionError(
434 "OpenAI image URI must have image detail".into(),
435 ))?;
436
437 Ok(UserContent::Image {
438 image_url: ImageUrl { url, detail },
439 })
440 }
441 DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
442 "Raw files not supported, encode as base64 first".into(),
443 )),
444 DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
445 "Document has no body".into(),
446 )),
447 doc => Err(message::MessageError::ConversionError(format!(
448 "Unsupported document type: {doc:?}"
449 ))),
450 },
451 message::UserContent::Document(message::Document { data, .. }) => {
452 if let DocumentSourceKind::Base64(text) | DocumentSourceKind::String(text) = data {
453 Ok(UserContent::Text { text })
454 } else {
455 Err(message::MessageError::ConversionError(
456 "Documents must be base64 or a string".into(),
457 ))
458 }
459 }
460 message::UserContent::Audio(message::Audio {
461 data, media_type, ..
462 }) => match data {
463 DocumentSourceKind::Base64(data) => Ok(UserContent::Audio {
464 input_audio: InputAudio {
465 data,
466 format: match media_type {
467 Some(media_type) => media_type,
468 None => AudioMediaType::MP3,
469 },
470 },
471 }),
472 DocumentSourceKind::Url(_) => Err(message::MessageError::ConversionError(
473 "URLs are not supported for audio".into(),
474 )),
475 DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
476 "Raw files are not supported for audio".into(),
477 )),
478 DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
479 "Audio has no body".into(),
480 )),
481 audio => Err(message::MessageError::ConversionError(format!(
482 "Unsupported audio type: {audio:?}"
483 ))),
484 },
485 message::UserContent::ToolResult(_) => Err(message::MessageError::ConversionError(
486 "Tool result is in unsupported format".into(),
487 )),
488 message::UserContent::Video(_) => Err(message::MessageError::ConversionError(
489 "Video is in unsupported format".into(),
490 )),
491 }
492 }
493}
494
495impl TryFrom<OneOrMany<message::UserContent>> for Vec<Message> {
496 type Error = message::MessageError;
497
498 fn try_from(value: OneOrMany<message::UserContent>) -> Result<Self, Self::Error> {
499 let (tool_results, other_content): (Vec<_>, Vec<_>) = value
500 .into_iter()
501 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
502
503 if !tool_results.is_empty() {
506 tool_results
507 .into_iter()
508 .map(|content| match content {
509 message::UserContent::ToolResult(tool_result) => tool_result.try_into(),
510 _ => unreachable!(),
511 })
512 .collect::<Result<Vec<_>, _>>()
513 } else {
514 let other_content: Vec<UserContent> = other_content
515 .into_iter()
516 .map(|content| content.try_into())
517 .collect::<Result<Vec<_>, _>>()?;
518
519 let other_content = OneOrMany::many(other_content)
520 .expect("There must be other content here if there were no tool result content");
521
522 Ok(vec![Message::User {
523 content: other_content,
524 name: None,
525 }])
526 }
527 }
528}
529
530impl TryFrom<OneOrMany<message::AssistantContent>> for Vec<Message> {
531 type Error = message::MessageError;
532
533 fn try_from(value: OneOrMany<message::AssistantContent>) -> Result<Self, Self::Error> {
534 let (text_content, tool_calls) = value.into_iter().fold(
535 (Vec::new(), Vec::new()),
536 |(mut texts, mut tools), content| {
537 match content {
538 message::AssistantContent::Text(text) => texts.push(text),
539 message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
540 message::AssistantContent::Reasoning(_) => {
541 panic!("The OpenAI Completions API doesn't support reasoning!");
542 }
543 message::AssistantContent::Image(_) => {
544 panic!(
545 "The OpenAI Completions API doesn't support image content in assistant messages!"
546 );
547 }
548 }
549 (texts, tools)
550 },
551 );
552
553 Ok(vec![Message::Assistant {
556 content: text_content
557 .into_iter()
558 .map(|content| content.text.into())
559 .collect::<Vec<_>>(),
560 refusal: None,
561 audio: None,
562 name: None,
563 tool_calls: tool_calls
564 .into_iter()
565 .map(|tool_call| tool_call.into())
566 .collect::<Vec<_>>(),
567 }])
568 }
569}
570
571impl TryFrom<message::Message> for Vec<Message> {
572 type Error = message::MessageError;
573
574 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
575 match message {
576 message::Message::User { content } => content.try_into(),
577 message::Message::Assistant { content, .. } => content.try_into(),
578 }
579 }
580}
581
582impl From<message::ToolCall> for ToolCall {
583 fn from(tool_call: message::ToolCall) -> Self {
584 Self {
585 id: tool_call.id,
586 r#type: ToolType::default(),
587 function: Function {
588 name: tool_call.function.name,
589 arguments: tool_call.function.arguments,
590 },
591 }
592 }
593}
594
595impl From<ToolCall> for message::ToolCall {
596 fn from(tool_call: ToolCall) -> Self {
597 Self {
598 id: tool_call.id,
599 call_id: None,
600 function: message::ToolFunction {
601 name: tool_call.function.name,
602 arguments: tool_call.function.arguments,
603 },
604 signature: None,
605 additional_params: None,
606 }
607 }
608}
609
610impl TryFrom<Message> for message::Message {
611 type Error = message::MessageError;
612
613 fn try_from(message: Message) -> Result<Self, Self::Error> {
614 Ok(match message {
615 Message::User { content, .. } => message::Message::User {
616 content: content.map(|content| content.into()),
617 },
618 Message::Assistant {
619 content,
620 tool_calls,
621 ..
622 } => {
623 let mut content = content
624 .into_iter()
625 .map(|content| match content {
626 AssistantContent::Text { text } => message::AssistantContent::text(text),
627
628 AssistantContent::Refusal { refusal } => {
631 message::AssistantContent::text(refusal)
632 }
633 })
634 .collect::<Vec<_>>();
635
636 content.extend(
637 tool_calls
638 .into_iter()
639 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
640 .collect::<Result<Vec<_>, _>>()?,
641 );
642
643 message::Message::Assistant {
644 id: None,
645 content: OneOrMany::many(content).map_err(|_| {
646 message::MessageError::ConversionError(
647 "Neither `content` nor `tool_calls` was provided to the Message"
648 .to_owned(),
649 )
650 })?,
651 }
652 }
653
654 Message::ToolResult {
655 tool_call_id,
656 content,
657 } => message::Message::User {
658 content: OneOrMany::one(message::UserContent::tool_result(
659 tool_call_id,
660 OneOrMany::one(message::ToolResultContent::text(content.as_text())),
661 )),
662 },
663
664 Message::System { content, .. } => message::Message::User {
667 content: content.map(|content| message::UserContent::text(content.text)),
668 },
669 })
670 }
671}
672
673impl From<UserContent> for message::UserContent {
674 fn from(content: UserContent) -> Self {
675 match content {
676 UserContent::Text { text } => message::UserContent::text(text),
677 UserContent::Image { image_url } => {
678 message::UserContent::image_url(image_url.url, None, Some(image_url.detail))
679 }
680 UserContent::Audio { input_audio } => {
681 message::UserContent::audio(input_audio.data, Some(input_audio.format))
682 }
683 }
684 }
685}
686
687impl From<String> for UserContent {
688 fn from(s: String) -> Self {
689 UserContent::Text { text: s }
690 }
691}
692
693impl FromStr for UserContent {
694 type Err = Infallible;
695
696 fn from_str(s: &str) -> Result<Self, Self::Err> {
697 Ok(UserContent::Text {
698 text: s.to_string(),
699 })
700 }
701}
702
703impl From<String> for AssistantContent {
704 fn from(s: String) -> Self {
705 AssistantContent::Text { text: s }
706 }
707}
708
709impl FromStr for AssistantContent {
710 type Err = Infallible;
711
712 fn from_str(s: &str) -> Result<Self, Self::Err> {
713 Ok(AssistantContent::Text {
714 text: s.to_string(),
715 })
716 }
717}
718impl From<String> for SystemContent {
719 fn from(s: String) -> Self {
720 SystemContent {
721 r#type: SystemContentType::default(),
722 text: s,
723 }
724 }
725}
726
727impl FromStr for SystemContent {
728 type Err = Infallible;
729
730 fn from_str(s: &str) -> Result<Self, Self::Err> {
731 Ok(SystemContent {
732 r#type: SystemContentType::default(),
733 text: s.to_string(),
734 })
735 }
736}
737
/// Raw body of a successful `POST /chat/completions` response.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    /// Creation timestamp as reported by the API (NOTE(review): presumably Unix seconds).
    pub created: u64,
    /// Model that actually served the request.
    pub model: String,
    pub system_fingerprint: Option<String>,
    /// Candidate completions; conversion to the core response uses only the first.
    pub choices: Vec<Choice>,
    pub usage: Option<Usage>,
}
748
749impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
750 type Error = CompletionError;
751
752 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
753 let choice = response.choices.first().ok_or_else(|| {
754 CompletionError::ResponseError("Response contained no choices".to_owned())
755 })?;
756
757 let content = match &choice.message {
758 Message::Assistant {
759 content,
760 tool_calls,
761 ..
762 } => {
763 let mut content = content
764 .iter()
765 .filter_map(|c| {
766 let s = match c {
767 AssistantContent::Text { text } => text,
768 AssistantContent::Refusal { refusal } => refusal,
769 };
770 if s.is_empty() {
771 None
772 } else {
773 Some(completion::AssistantContent::text(s))
774 }
775 })
776 .collect::<Vec<_>>();
777
778 content.extend(
779 tool_calls
780 .iter()
781 .map(|call| {
782 completion::AssistantContent::tool_call(
783 &call.id,
784 &call.function.name,
785 call.function.arguments.clone(),
786 )
787 })
788 .collect::<Vec<_>>(),
789 );
790 Ok(content)
791 }
792 _ => Err(CompletionError::ResponseError(
793 "Response did not contain a valid message or tool call".into(),
794 )),
795 }?;
796
797 let choice = OneOrMany::many(content).map_err(|_| {
798 CompletionError::ResponseError(
799 "Response contained no message or tool call (empty)".to_owned(),
800 )
801 })?;
802
803 let usage = response
804 .usage
805 .as_ref()
806 .map(|usage| completion::Usage {
807 input_tokens: usage.prompt_tokens as u64,
808 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
809 total_tokens: usage.total_tokens as u64,
810 })
811 .unwrap_or_default();
812
813 Ok(completion::CompletionResponse {
814 choice,
815 usage,
816 raw_response: response,
817 })
818 }
819}
820
821impl ProviderResponseExt for CompletionResponse {
822 type OutputMessage = Choice;
823 type Usage = Usage;
824
825 fn get_response_id(&self) -> Option<String> {
826 Some(self.id.to_owned())
827 }
828
829 fn get_response_model_name(&self) -> Option<String> {
830 Some(self.model.to_owned())
831 }
832
833 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
834 self.choices.clone()
835 }
836
837 fn get_text_response(&self) -> Option<String> {
838 let Message::User { ref content, .. } = self.choices.last()?.message.clone() else {
839 return None;
840 };
841
842 let UserContent::Text { text } = content.first() else {
843 return None;
844 };
845
846 Some(text)
847 }
848
849 fn get_usage(&self) -> Option<Self::Usage> {
850 self.usage.clone()
851 }
852}
853
/// One candidate completion in a [`CompletionResponse`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    // Kept as untyped JSON; this module does not inspect it.
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}
861
/// Token counts reported by the API.
///
/// Only prompt and total counts are captured; output tokens are derived as the difference
/// wherever this is converted into rig's usage type.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
867
868impl Usage {
869 pub fn new() -> Self {
870 Self {
871 prompt_tokens: 0,
872 total_tokens: 0,
873 }
874 }
875}
876
877impl Default for Usage {
878 fn default() -> Self {
879 Self::new()
880 }
881}
882
883impl fmt::Display for Usage {
884 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
885 let Usage {
886 prompt_tokens,
887 total_tokens,
888 } = self;
889 write!(
890 f,
891 "Prompt tokens: {prompt_tokens} Total tokens: {total_tokens}"
892 )
893 }
894}
895
896impl GetTokenUsage for Usage {
897 fn token_usage(&self) -> Option<crate::completion::Usage> {
898 let mut usage = crate::completion::Usage::new();
899 usage.input_tokens = self.prompt_tokens as u64;
900 usage.output_tokens = (self.total_tokens - self.prompt_tokens) as u64;
901 usage.total_tokens = self.total_tokens as u64;
902
903 Some(usage)
904 }
905}
906
/// A Chat Completions model handle: a client plus the model name and request options.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    /// Model identifier sent as the request's `model` field.
    pub model: String,
    /// When `true`, tool definitions are sent with `strict: true` and a sanitized schema.
    pub strict_tools: bool,
    /// When `true`, tool-result message content is sent as an array of text parts rather
    /// than a bare string.
    pub tool_result_array_content: bool,
}
914
915impl<T> CompletionModel<T>
916where
917 T: Default + std::fmt::Debug + Clone + 'static,
918{
919 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
920 Self {
921 client,
922 model: model.into(),
923 strict_tools: false,
924 tool_result_array_content: false,
925 }
926 }
927
928 pub fn with_model(client: Client<T>, model: &str) -> Self {
929 Self {
930 client,
931 model: model.into(),
932 strict_tools: false,
933 tool_result_array_content: false,
934 }
935 }
936
937 pub fn with_strict_tools(mut self) -> Self {
946 self.strict_tools = true;
947 self
948 }
949
950 pub fn with_tool_result_array_content(mut self) -> Self {
951 self.tool_result_array_content = true;
952 self
953 }
954}
955
/// JSON body sent to `POST /chat/completions`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CompletionRequest {
    model: String,
    messages: Vec<Message>,
    // The `tools` key is omitted entirely when no tools are supplied.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Flattened into the top level of the body. NOTE(review): presumably expected to be a
    // JSON object; confirm how other JSON values behave under `serde(flatten)`.
    #[serde(flatten)]
    additional_params: Option<serde_json::Value>,
}
969
/// Inputs for building a [`CompletionRequest`]: the core request plus provider options.
pub struct OpenAIRequestParams {
    /// Model identifier for the `model` field.
    pub model: String,
    /// Provider-agnostic request to translate.
    pub request: CoreCompletionRequest,
    /// Send tool definitions with `strict: true` (see `ToolDefinition::with_strict`).
    pub strict_tools: bool,
    /// Convert tool-result content to array-of-parts form.
    pub tool_result_array_content: bool,
}
976
977impl TryFrom<OpenAIRequestParams> for CompletionRequest {
978 type Error = CompletionError;
979
980 fn try_from(params: OpenAIRequestParams) -> Result<Self, Self::Error> {
981 let OpenAIRequestParams {
982 model,
983 request: req,
984 strict_tools,
985 tool_result_array_content,
986 } = params;
987
988 let mut partial_history = vec![];
989 if let Some(docs) = req.normalized_documents() {
990 partial_history.push(docs);
991 }
992 let CoreCompletionRequest {
993 preamble,
994 chat_history,
995 tools,
996 temperature,
997 additional_params,
998 tool_choice,
999 ..
1000 } = req;
1001
1002 partial_history.extend(chat_history);
1003
1004 let mut full_history: Vec<Message> =
1005 preamble.map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
1006
1007 full_history.extend(
1008 partial_history
1009 .into_iter()
1010 .map(message::Message::try_into)
1011 .collect::<Result<Vec<Vec<Message>>, _>>()?
1012 .into_iter()
1013 .flatten()
1014 .collect::<Vec<_>>(),
1015 );
1016
1017 if tool_result_array_content {
1018 for msg in &mut full_history {
1019 if let Message::ToolResult { content, .. } = msg {
1020 *content = content.to_array();
1021 }
1022 }
1023 }
1024
1025 let tool_choice = tool_choice.map(ToolChoice::try_from).transpose()?;
1026
1027 let tools: Vec<ToolDefinition> = tools
1028 .into_iter()
1029 .map(|tool| {
1030 let def = ToolDefinition::from(tool);
1031 if strict_tools { def.with_strict() } else { def }
1032 })
1033 .collect();
1034
1035 let res = Self {
1036 model,
1037 messages: full_history,
1038 tools,
1039 tool_choice,
1040 temperature,
1041 additional_params,
1042 };
1043
1044 Ok(res)
1045 }
1046}
1047
1048impl TryFrom<(String, CoreCompletionRequest)> for CompletionRequest {
1049 type Error = CompletionError;
1050
1051 fn try_from((model, req): (String, CoreCompletionRequest)) -> Result<Self, Self::Error> {
1052 CompletionRequest::try_from(OpenAIRequestParams {
1053 model,
1054 request: req,
1055 strict_tools: false,
1056 tool_result_array_content: false,
1057 })
1058 }
1059}
1060
1061impl crate::telemetry::ProviderRequestExt for CompletionRequest {
1062 type InputMessage = Message;
1063
1064 fn get_input_messages(&self) -> Vec<Self::InputMessage> {
1065 self.messages.clone()
1066 }
1067
1068 fn get_system_prompt(&self) -> Option<String> {
1069 let first_message = self.messages.first()?;
1070
1071 let Message::System { ref content, .. } = first_message.clone() else {
1072 return None;
1073 };
1074
1075 let SystemContent { text, .. } = content.first();
1076
1077 Some(text)
1078 }
1079
1080 fn get_prompt(&self) -> Option<String> {
1081 let last_message = self.messages.last()?;
1082
1083 let Message::User { ref content, .. } = last_message.clone() else {
1084 return None;
1085 };
1086
1087 let UserContent::Text { text } = content.first() else {
1088 return None;
1089 };
1090
1091 Some(text)
1092 }
1093
1094 fn get_model_name(&self) -> String {
1095 self.model.clone()
1096 }
1097}
1098
impl CompletionModel<reqwest::Client> {
    /// Consumes this model and starts a [`crate::agent::AgentBuilder`] around it.
    pub fn into_agent_builder(self) -> crate::agent::AgentBuilder<Self> {
        crate::agent::AgentBuilder::new(self)
    }
}
1104
/// rig's [`completion::CompletionModel`] implementation backed by `POST /chat/completions`.
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt
        + Default
        + std::fmt::Debug
        + Clone
        + WasmCompatSend
        + WasmCompatSync
        + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = super::CompletionsClient<T>;

    /// Builds a model handle for `model` from a clone of `client`.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Sends a non-streaming completion request and converts the reply.
    ///
    /// # Errors
    /// - request translation failures (see `CompletionRequest::try_from`);
    /// - HTTP client / body-read errors;
    /// - `ProviderError` with the raw body for non-success statuses, or with the API's
    ///   message when a success-status body decodes as `ApiResponse::Err`;
    /// - JSON (de)serialization errors;
    /// - response conversion failures (no choices, no content).
    async fn completion(
        &self,
        completion_request: CoreCompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Reuse the caller's span when one is active; otherwise open a `chat` span. The
        // `Empty` fields are placeholders filled in after the response arrives. The field
        // names appear to follow the OpenTelemetry GenAI conventions — NOTE(review): confirm.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let request = CompletionRequest::try_from(OpenAIRequestParams {
            model: self.model.to_owned(),
            request: completion_request,
            strict_tools: self.strict_tools,
            tool_result_array_content: self.tool_result_array_content,
        })?;

        // Pretty-printing is only done when TRACE is enabled, to avoid the cost otherwise.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "OpenAI Chat Completions completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        // The network round-trip runs inside `span` so the response metadata recorded via
        // `Span::current()` below lands on that span.
        async move {
            let response = self.client.send(req).await?;

            if response.status().is_success() {
                let text = http_client::text(response).await?;

                // A success status can still carry an API error body; both shapes are
                // handled through `ApiResponse`.
                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record_response_metadata(&response);
                        span.record_token_usage(&response.usage);

                        if enabled!(Level::TRACE) {
                            tracing::trace!(
                                target: "rig::completions",
                                "OpenAI Chat Completions completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                // Non-success status: the raw body text is surfaced as the error message.
                let text = http_client::text(response).await?;
                Err(CompletionError::ProviderError(text))
            }
        }
        .instrument(span)
        .await
    }

    /// Streaming completion. Delegates to an inherent `stream` method on this type
    /// (NOTE(review): presumably defined in the `streaming` submodule — confirm).
    async fn stream(
        &self,
        request: CoreCompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        Self::stream(self, request).await
    }
}