1use super::{
6 CompletionsClient as Client,
7 client::{ApiErrorResponse, ApiResponse},
8 streaming::StreamingCompletionResponse,
9};
10use crate::completion::{
11 CompletionError, CompletionRequest as CoreCompletionRequest, GetTokenUsage,
12};
13use crate::http_client::{self, HttpClientExt};
14use crate::message::{AudioMediaType, DocumentSourceKind, ImageDetail, MimeType};
15use crate::one_or_many::string_or_one_or_many;
16use crate::telemetry::{ProviderResponseExt, SpanCombinator};
17use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
18use crate::{OneOrMany, completion, json_utils, message};
19use serde::{Deserialize, Serialize};
20use std::convert::Infallible;
21use std::fmt;
22use tracing::{Instrument, info_span};
23
24use std::str::FromStr;
25
pub mod streaming;

/// `gpt-5.1` completion model.
pub const GPT_5_1: &str = "gpt-5.1";

/// `gpt-5` completion model.
pub const GPT_5: &str = "gpt-5";
/// `gpt-5-mini` completion model.
pub const GPT_5_MINI: &str = "gpt-5-mini";
/// `gpt-5-nano` completion model.
pub const GPT_5_NANO: &str = "gpt-5-nano";

/// `gpt-4.5-preview` completion model.
pub const GPT_4_5_PREVIEW: &str = "gpt-4.5-preview";
/// `gpt-4.5-preview-2025-02-27` completion model (dated snapshot).
pub const GPT_4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27";
/// `gpt-4o-2024-11-20` completion model (dated snapshot).
pub const GPT_4O_2024_11_20: &str = "gpt-4o-2024-11-20";
/// `gpt-4o` completion model.
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model.
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-2024-05-13` completion model (dated snapshot).
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model.
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model (dated snapshot).
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model.
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model.
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model.
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model.
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model.
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model.
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model (dated snapshot).
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model.
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model (dated snapshot).
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";

/// `o4-mini-2025-04-16` completion model (dated snapshot).
pub const O4_MINI_2025_04_16: &str = "o4-mini-2025-04-16";
/// `o4-mini` completion model.
pub const O4_MINI: &str = "o4-mini";
/// `o3` completion model.
pub const O3: &str = "o3";
/// `o3-mini` completion model.
pub const O3_MINI: &str = "o3-mini";
/// `o3-mini-2025-01-31` completion model (dated snapshot).
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
/// `o1-pro` completion model.
pub const O1_PRO: &str = "o1-pro";
/// `o1` completion model.
pub const O1: &str = "o1";
/// `o1-2024-12-17` completion model (dated snapshot).
pub const O1_2024_12_17: &str = "o1-2024-12-17";
/// `o1-preview` completion model.
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model (dated snapshot).
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model.
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model (dated snapshot).
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";

/// `gpt-4.1-mini` completion model.
pub const GPT_4_1_MINI: &str = "gpt-4.1-mini";
/// `gpt-4.1-nano` completion model.
pub const GPT_4_1_NANO: &str = "gpt-4.1-nano";
/// `gpt-4.1-2025-04-14` completion model (dated snapshot).
pub const GPT_4_1_2025_04_14: &str = "gpt-4.1-2025-04-14";
/// `gpt-4.1` completion model.
pub const GPT_4_1: &str = "gpt-4.1";
106
107impl From<ApiErrorResponse> for CompletionError {
108 fn from(err: ApiErrorResponse) -> Self {
109 CompletionError::ProviderError(err.message)
110 }
111}
112
/// A chat message in the OpenAI Chat Completions wire format.
///
/// Serialized with an internal `role` tag holding the lowercased variant name
/// (`system`, `user`, `assistant`), except `ToolResult`, which uses `tool`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// System instructions. The `developer` role is also accepted when deserializing.
    #[serde(alias = "developer")]
    System {
        // NOTE(review): `string_or_one_or_many` presumably accepts either a bare string or
        // one-or-many content parts — helper lives in `crate::one_or_many`, confirm there.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// A user turn made of one or more content parts.
    User {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<UserContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// An assistant turn: optional text/refusal parts plus any tool calls it issued.
    Assistant {
        // Missing `content` deserializes to an empty list; `json_utils::string_or_vec`
        // presumably also accepts a bare string — confirm in `json_utils`.
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<AssistantContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<AudioAssistant>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // Missing tool calls default to empty (and `null` is presumably handled by
        // `json_utils::null_or_vec`); an empty list is omitted when serializing.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    /// The output of a tool call, linked back to the call by `tool_call_id`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: OneOrMany<ToolResultContent>,
    },
}
151
152impl Message {
153 pub fn system(content: &str) -> Self {
154 Message::System {
155 content: OneOrMany::one(content.to_owned().into()),
156 name: None,
157 }
158 }
159}
160
/// Audio attached to an assistant message, referenced only by its `id`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AudioAssistant {
    pub id: String,
}
165
/// A single text part of a system message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    // Defaults to `text` when the `type` field is absent on input.
    #[serde(default)]
    pub r#type: SystemContentType,
    pub text: String,
}

/// Content-part type for system messages; only `text` exists.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
179
/// A content part of an assistant message, tagged by `type` as `text` or `refusal`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
    Refusal { refusal: String },
}
186
187impl From<AssistantContent> for completion::AssistantContent {
188 fn from(value: AssistantContent) -> Self {
189 match value {
190 AssistantContent::Text { text } => completion::AssistantContent::text(text),
191 AssistantContent::Refusal { refusal } => completion::AssistantContent::text(refusal),
192 }
193 }
194}
195
196#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
197#[serde(tag = "type", rename_all = "lowercase")]
198pub enum UserContent {
199 Text {
200 text: String,
201 },
202 #[serde(rename = "image_url")]
203 Image {
204 image_url: ImageUrl,
205 },
206 Audio {
207 input_audio: InputAudio,
208 },
209}
210
/// Location of an image content part: a URL or `data:` URI plus the requested detail.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    // Falls back to `ImageDetail::default()` when absent on input.
    #[serde(default)]
    pub detail: ImageDetail,
}

/// Inline audio input: encoded `data` and its media `format`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct InputAudio {
    pub data: String,
    pub format: AudioMediaType,
}
223
/// A text part of a `tool` message's content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    // Defaults to `text` when the `type` field is absent on input.
    #[serde(default)]
    r#type: ToolResultContentType,
    pub text: String,
}

/// Content-part type for tool results; only `text` exists.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    #[default]
    Text,
}
237
238impl FromStr for ToolResultContent {
239 type Err = Infallible;
240
241 fn from_str(s: &str) -> Result<Self, Self::Err> {
242 Ok(s.to_owned().into())
243 }
244}
245
246impl From<String> for ToolResultContent {
247 fn from(s: String) -> Self {
248 ToolResultContent {
249 r#type: ToolResultContentType::default(),
250 text: s,
251 }
252 }
253}
254
/// A tool invocation requested by the assistant.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    // Defaults to `function` when the `type` field is absent on input.
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}

/// Kind of tool call; only `function` exists.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
269
/// A tool advertised to the model: a `type` string paired with the function definition.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
275
276impl From<completion::ToolDefinition> for ToolDefinition {
277 fn from(tool: completion::ToolDefinition) -> Self {
278 Self {
279 r#type: "function".into(),
280 function: tool,
281 }
282 }
283}
284
/// Tool-selection policy sent as `tool_choice`: `auto` (default), `none`, or `required`.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    None,
    Required,
}
293
294impl TryFrom<crate::message::ToolChoice> for ToolChoice {
295 type Error = CompletionError;
296 fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
297 let res = match value {
298 message::ToolChoice::Specific { .. } => {
299 return Err(CompletionError::ProviderError(
300 "Provider doesn't support only using specific tools".to_string(),
301 ));
302 }
303 message::ToolChoice::Auto => Self::Auto,
304 message::ToolChoice::None => Self::None,
305 message::ToolChoice::Required => Self::Required,
306 };
307
308 Ok(res)
309 }
310}
311
/// The function part of a tool call: its name and arguments.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    // (De)serialized through `json_utils::stringified_json`; presumably the wire form is a
    // JSON-encoded string while this field holds the parsed value — confirm in json_utils.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
318
319impl TryFrom<message::ToolResult> for Message {
320 type Error = message::MessageError;
321
322 fn try_from(value: message::ToolResult) -> Result<Self, Self::Error> {
323 Ok(Message::ToolResult {
324 tool_call_id: value.id,
325 content: value.content.try_map(|content| match content {
326 message::ToolResultContent::Text(message::Text { text }) => Ok(text.into()),
327 _ => Err(message::MessageError::ConversionError(
328 "Tool result content does not support non-text".into(),
329 )),
330 })?,
331 })
332 }
333}
334
335impl TryFrom<message::UserContent> for UserContent {
336 type Error = message::MessageError;
337
338 fn try_from(value: message::UserContent) -> Result<Self, Self::Error> {
339 match value {
340 message::UserContent::Text(message::Text { text }) => Ok(UserContent::Text { text }),
341 message::UserContent::Image(message::Image {
342 data,
343 detail,
344 media_type,
345 ..
346 }) => match data {
347 DocumentSourceKind::Url(url) => Ok(UserContent::Image {
348 image_url: ImageUrl {
349 url,
350 detail: detail.unwrap_or_default(),
351 },
352 }),
353 DocumentSourceKind::Base64(data) => {
354 let url = format!(
355 "data:{};base64,{}",
356 media_type.map(|i| i.to_mime_type()).ok_or(
357 message::MessageError::ConversionError(
358 "OpenAI Image URI must have media type".into()
359 )
360 )?,
361 data
362 );
363
364 let detail = detail.ok_or(message::MessageError::ConversionError(
365 "OpenAI image URI must have image detail".into(),
366 ))?;
367
368 Ok(UserContent::Image {
369 image_url: ImageUrl { url, detail },
370 })
371 }
372 DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
373 "Raw files not supported, encode as base64 first".into(),
374 )),
375 DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
376 "Document has no body".into(),
377 )),
378 doc => Err(message::MessageError::ConversionError(format!(
379 "Unsupported document type: {doc:?}"
380 ))),
381 },
382 message::UserContent::Document(message::Document { data, .. }) => {
383 if let DocumentSourceKind::Base64(text) | DocumentSourceKind::String(text) = data {
384 Ok(UserContent::Text { text })
385 } else {
386 Err(message::MessageError::ConversionError(
387 "Documents must be base64 or a string".into(),
388 ))
389 }
390 }
391 message::UserContent::Audio(message::Audio {
392 data, media_type, ..
393 }) => match data {
394 DocumentSourceKind::Base64(data) => Ok(UserContent::Audio {
395 input_audio: InputAudio {
396 data,
397 format: match media_type {
398 Some(media_type) => media_type,
399 None => AudioMediaType::MP3,
400 },
401 },
402 }),
403 DocumentSourceKind::Url(_) => Err(message::MessageError::ConversionError(
404 "URLs are not supported for audio".into(),
405 )),
406 DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
407 "Raw files are not supported for audio".into(),
408 )),
409 DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
410 "Audio has no body".into(),
411 )),
412 audio => Err(message::MessageError::ConversionError(format!(
413 "Unsupported audio type: {audio:?}"
414 ))),
415 },
416 message::UserContent::ToolResult(_) => Err(message::MessageError::ConversionError(
417 "Tool result is in unsupported format".into(),
418 )),
419 message::UserContent::Video(_) => Err(message::MessageError::ConversionError(
420 "Video is in unsupported format".into(),
421 )),
422 }
423 }
424}
425
426impl TryFrom<OneOrMany<message::UserContent>> for Vec<Message> {
427 type Error = message::MessageError;
428
429 fn try_from(value: OneOrMany<message::UserContent>) -> Result<Self, Self::Error> {
430 let (tool_results, other_content): (Vec<_>, Vec<_>) = value
431 .into_iter()
432 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
433
434 if !tool_results.is_empty() {
437 tool_results
438 .into_iter()
439 .map(|content| match content {
440 message::UserContent::ToolResult(tool_result) => tool_result.try_into(),
441 _ => unreachable!(),
442 })
443 .collect::<Result<Vec<_>, _>>()
444 } else {
445 let other_content: Vec<UserContent> = other_content
446 .into_iter()
447 .map(|content| content.try_into())
448 .collect::<Result<Vec<_>, _>>()?;
449
450 let other_content = OneOrMany::many(other_content)
451 .expect("There must be other content here if there were no tool result content");
452
453 Ok(vec![Message::User {
454 content: other_content,
455 name: None,
456 }])
457 }
458 }
459}
460
461impl TryFrom<OneOrMany<message::AssistantContent>> for Vec<Message> {
462 type Error = message::MessageError;
463
464 fn try_from(value: OneOrMany<message::AssistantContent>) -> Result<Self, Self::Error> {
465 let (text_content, tool_calls) = value.into_iter().fold(
466 (Vec::new(), Vec::new()),
467 |(mut texts, mut tools), content| {
468 match content {
469 message::AssistantContent::Text(text) => texts.push(text),
470 message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
471 message::AssistantContent::Reasoning(_) => {
472 unimplemented!("The OpenAI Completions API doesn't support reasoning!");
473 }
474 message::AssistantContent::Image(_) => {
475 unimplemented!(
476 "The OpenAI Completions API doesn't support image content in assistant messages!"
477 );
478 }
479 }
480 (texts, tools)
481 },
482 );
483
484 Ok(vec![Message::Assistant {
487 content: text_content
488 .into_iter()
489 .map(|content| content.text.into())
490 .collect::<Vec<_>>(),
491 refusal: None,
492 audio: None,
493 name: None,
494 tool_calls: tool_calls
495 .into_iter()
496 .map(|tool_call| tool_call.into())
497 .collect::<Vec<_>>(),
498 }])
499 }
500}
501
502impl TryFrom<message::Message> for Vec<Message> {
503 type Error = message::MessageError;
504
505 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
506 match message {
507 message::Message::User { content } => content.try_into(),
508 message::Message::Assistant { content, .. } => content.try_into(),
509 }
510 }
511}
512
513impl From<message::ToolCall> for ToolCall {
514 fn from(tool_call: message::ToolCall) -> Self {
515 Self {
516 id: tool_call.id,
517 r#type: ToolType::default(),
518 function: Function {
519 name: tool_call.function.name,
520 arguments: tool_call.function.arguments,
521 },
522 }
523 }
524}
525
526impl From<ToolCall> for message::ToolCall {
527 fn from(tool_call: ToolCall) -> Self {
528 Self {
529 id: tool_call.id,
530 call_id: None,
531 function: message::ToolFunction {
532 name: tool_call.function.name,
533 arguments: tool_call.function.arguments,
534 },
535 }
536 }
537}
538
539impl TryFrom<Message> for message::Message {
540 type Error = message::MessageError;
541
542 fn try_from(message: Message) -> Result<Self, Self::Error> {
543 Ok(match message {
544 Message::User { content, .. } => message::Message::User {
545 content: content.map(|content| content.into()),
546 },
547 Message::Assistant {
548 content,
549 tool_calls,
550 ..
551 } => {
552 let mut content = content
553 .into_iter()
554 .map(|content| match content {
555 AssistantContent::Text { text } => message::AssistantContent::text(text),
556
557 AssistantContent::Refusal { refusal } => {
560 message::AssistantContent::text(refusal)
561 }
562 })
563 .collect::<Vec<_>>();
564
565 content.extend(
566 tool_calls
567 .into_iter()
568 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
569 .collect::<Result<Vec<_>, _>>()?,
570 );
571
572 message::Message::Assistant {
573 id: None,
574 content: OneOrMany::many(content).map_err(|_| {
575 message::MessageError::ConversionError(
576 "Neither `content` nor `tool_calls` was provided to the Message"
577 .to_owned(),
578 )
579 })?,
580 }
581 }
582
583 Message::ToolResult {
584 tool_call_id,
585 content,
586 } => message::Message::User {
587 content: OneOrMany::one(message::UserContent::tool_result(
588 tool_call_id,
589 content.map(|content| message::ToolResultContent::text(content.text)),
590 )),
591 },
592
593 Message::System { content, .. } => message::Message::User {
596 content: content.map(|content| message::UserContent::text(content.text)),
597 },
598 })
599 }
600}
601
602impl From<UserContent> for message::UserContent {
603 fn from(content: UserContent) -> Self {
604 match content {
605 UserContent::Text { text } => message::UserContent::text(text),
606 UserContent::Image { image_url } => {
607 message::UserContent::image_url(image_url.url, None, Some(image_url.detail))
608 }
609 UserContent::Audio { input_audio } => {
610 message::UserContent::audio(input_audio.data, Some(input_audio.format))
611 }
612 }
613 }
614}
615
616impl From<String> for UserContent {
617 fn from(s: String) -> Self {
618 UserContent::Text { text: s }
619 }
620}
621
622impl FromStr for UserContent {
623 type Err = Infallible;
624
625 fn from_str(s: &str) -> Result<Self, Self::Err> {
626 Ok(UserContent::Text {
627 text: s.to_string(),
628 })
629 }
630}
631
632impl From<String> for AssistantContent {
633 fn from(s: String) -> Self {
634 AssistantContent::Text { text: s }
635 }
636}
637
638impl FromStr for AssistantContent {
639 type Err = Infallible;
640
641 fn from_str(s: &str) -> Result<Self, Self::Err> {
642 Ok(AssistantContent::Text {
643 text: s.to_string(),
644 })
645 }
646}
647impl From<String> for SystemContent {
648 fn from(s: String) -> Self {
649 SystemContent {
650 r#type: SystemContentType::default(),
651 text: s,
652 }
653 }
654}
655
656impl FromStr for SystemContent {
657 type Err = Infallible;
658
659 fn from_str(s: &str) -> Result<Self, Self::Err> {
660 Ok(SystemContent {
661 r#type: SystemContentType::default(),
662 text: s.to_string(),
663 })
664 }
665}
666
/// Raw body of a successful Chat Completions response.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    pub choices: Vec<Choice>,
    // Absent usage is tolerated; conversions fall back to default (zeroed) usage.
    pub usage: Option<Usage>,
}
677
678impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
679 type Error = CompletionError;
680
681 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
682 let choice = response.choices.first().ok_or_else(|| {
683 CompletionError::ResponseError("Response contained no choices".to_owned())
684 })?;
685
686 let content = match &choice.message {
687 Message::Assistant {
688 content,
689 tool_calls,
690 ..
691 } => {
692 let mut content = content
693 .iter()
694 .filter_map(|c| {
695 let s = match c {
696 AssistantContent::Text { text } => text,
697 AssistantContent::Refusal { refusal } => refusal,
698 };
699 if s.is_empty() {
700 None
701 } else {
702 Some(completion::AssistantContent::text(s))
703 }
704 })
705 .collect::<Vec<_>>();
706
707 content.extend(
708 tool_calls
709 .iter()
710 .map(|call| {
711 completion::AssistantContent::tool_call(
712 &call.id,
713 &call.function.name,
714 call.function.arguments.clone(),
715 )
716 })
717 .collect::<Vec<_>>(),
718 );
719 Ok(content)
720 }
721 _ => Err(CompletionError::ResponseError(
722 "Response did not contain a valid message or tool call".into(),
723 )),
724 }?;
725
726 let choice = OneOrMany::many(content).map_err(|_| {
727 CompletionError::ResponseError(
728 "Response contained no message or tool call (empty)".to_owned(),
729 )
730 })?;
731
732 let usage = response
733 .usage
734 .as_ref()
735 .map(|usage| completion::Usage {
736 input_tokens: usage.prompt_tokens as u64,
737 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
738 total_tokens: usage.total_tokens as u64,
739 })
740 .unwrap_or_default();
741
742 Ok(completion::CompletionResponse {
743 choice,
744 usage,
745 raw_response: response,
746 })
747 }
748}
749
750impl ProviderResponseExt for CompletionResponse {
751 type OutputMessage = Choice;
752 type Usage = Usage;
753
754 fn get_response_id(&self) -> Option<String> {
755 Some(self.id.to_owned())
756 }
757
758 fn get_response_model_name(&self) -> Option<String> {
759 Some(self.model.to_owned())
760 }
761
762 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
763 self.choices.clone()
764 }
765
766 fn get_text_response(&self) -> Option<String> {
767 let Message::User { ref content, .. } = self.choices.last()?.message.clone() else {
768 return None;
769 };
770
771 let UserContent::Text { text } = content.first() else {
772 return None;
773 };
774
775 Some(text)
776 }
777
778 fn get_usage(&self) -> Option<Self::Usage> {
779 self.usage.clone()
780 }
781}
782
/// One candidate completion within a response.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    // Kept as untyped JSON; not interpreted in this module.
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}
790
/// Token accounting reported by the API. Output tokens are derived elsewhere in this
/// module as `total_tokens - prompt_tokens`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
796
797impl Usage {
798 pub fn new() -> Self {
799 Self {
800 prompt_tokens: 0,
801 total_tokens: 0,
802 }
803 }
804}
805
806impl Default for Usage {
807 fn default() -> Self {
808 Self::new()
809 }
810}
811
812impl fmt::Display for Usage {
813 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
814 let Usage {
815 prompt_tokens,
816 total_tokens,
817 } = self;
818 write!(
819 f,
820 "Prompt tokens: {prompt_tokens} Total tokens: {total_tokens}"
821 )
822 }
823}
824
825impl GetTokenUsage for Usage {
826 fn token_usage(&self) -> Option<crate::completion::Usage> {
827 let mut usage = crate::completion::Usage::new();
828 usage.input_tokens = self.prompt_tokens as u64;
829 usage.output_tokens = (self.total_tokens - self.prompt_tokens) as u64;
830 usage.total_tokens = self.total_tokens as u64;
831
832 Some(usage)
833 }
834}
835
/// A Chat Completions model bound to a client; `T` is the HTTP client (default `reqwest::Client`).
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    /// Model identifier sent as `model` in requests.
    pub model: String,
}
842
843impl<T> CompletionModel<T>
844where
845 T: Default + std::fmt::Debug + Clone + 'static,
846{
847 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
848 Self {
849 client,
850 model: model.into(),
851 }
852 }
853
854 pub fn with_model(client: Client<T>, model: &str) -> Self {
855 Self {
856 client,
857 model: model.into(),
858 }
859 }
860}
861
/// Request body for `POST /chat/completions`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CompletionRequest {
    model: String,
    messages: Vec<Message>,
    // Omitted from the body when no tools are offered.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Flattened into the top-level JSON object, so extra provider parameters sit beside
    // `model`, `messages`, etc. NOTE(review): presumably must be a JSON object to serialize.
    #[serde(flatten)]
    additional_params: Option<serde_json::Value>,
}
875
876impl TryFrom<(String, CoreCompletionRequest)> for CompletionRequest {
877 type Error = CompletionError;
878
879 fn try_from((model, req): (String, CoreCompletionRequest)) -> Result<Self, Self::Error> {
880 let mut partial_history = vec![];
881 if let Some(docs) = req.normalized_documents() {
882 partial_history.push(docs);
883 }
884 let CoreCompletionRequest {
885 preamble,
886 chat_history,
887 tools,
888 temperature,
889 additional_params,
890 tool_choice,
891 ..
892 } = req;
893
894 partial_history.extend(chat_history);
895
896 let mut full_history: Vec<Message> =
897 preamble.map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
898
899 full_history.extend(
901 partial_history
902 .into_iter()
903 .map(message::Message::try_into)
904 .collect::<Result<Vec<Vec<Message>>, _>>()?
905 .into_iter()
906 .flatten()
907 .collect::<Vec<_>>(),
908 );
909
910 let tool_choice = tool_choice.map(ToolChoice::try_from).transpose()?;
911
912 let res = Self {
913 model,
914 messages: full_history,
915 tools: tools
916 .into_iter()
917 .map(ToolDefinition::from)
918 .collect::<Vec<_>>(),
919 tool_choice,
920 temperature,
921 additional_params,
922 };
923
924 Ok(res)
925 }
926}
927
928impl crate::telemetry::ProviderRequestExt for CompletionRequest {
929 type InputMessage = Message;
930
931 fn get_input_messages(&self) -> Vec<Self::InputMessage> {
932 self.messages.clone()
933 }
934
935 fn get_system_prompt(&self) -> Option<String> {
936 let first_message = self.messages.first()?;
937
938 let Message::System { ref content, .. } = first_message.clone() else {
939 return None;
940 };
941
942 let SystemContent { text, .. } = content.first();
943
944 Some(text)
945 }
946
947 fn get_prompt(&self) -> Option<String> {
948 let last_message = self.messages.last()?;
949
950 let Message::User { ref content, .. } = last_message.clone() else {
951 return None;
952 };
953
954 let UserContent::Text { text } = content.first() else {
955 return None;
956 };
957
958 Some(text)
959 }
960
961 fn get_model_name(&self) -> String {
962 self.model.clone()
963 }
964}
965
impl CompletionModel<reqwest::Client> {
    /// Consumes this model and starts an [`crate::agent::AgentBuilder`] around it.
    pub fn into_agent_builder(self) -> crate::agent::AgentBuilder<Self> {
        crate::agent::AgentBuilder::new(self)
    }
}
971
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt
        + Default
        + std::fmt::Debug
        + Clone
        + WasmCompatSend
        + WasmCompatSync
        + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = super::CompletionsClient<T>;

    /// Creates a model for `model` from a clone of `client`.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Sends a non-streaming `POST /chat/completions` request and converts the reply.
    ///
    /// # Errors
    /// - Request conversion and JSON (de)serialization failures.
    /// - HTTP transport errors.
    /// - `ProviderError` carrying the body text for a non-success status, or the API error
    ///   message for an error payload.
    /// - Conversion errors from the response `TryFrom`.
    async fn completion(
        &self,
        completion_request: CoreCompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Open a dedicated "chat" span only when no enabled span is active; otherwise record
        // into the caller's span. Response fields start Empty and are filled in later.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let request = CompletionRequest::try_from((self.model.to_owned(), completion_request))?;

        span.record_model_input(&request.messages);

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        // The network round-trip runs inside `span`, so `Span::current()` below refers to it.
        async move {
            let response = self.client.send(req).await?;

            if response.status().is_success() {
                let text = http_client::text(response).await?;

                // A 2xx body may still hold an API error payload, hence `ApiResponse`.
                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record_model_output(&response.choices);
                        span.record_response_metadata(&response);
                        span.record_token_usage(&response.usage);
                        tracing::debug!("OpenAI response: {response:?}");
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                // Non-success status: surface the raw body text as the provider error.
                let text = http_client::text(response).await?;
                Err(CompletionError::ProviderError(text))
            }
        }
        .instrument(span)
        .await
    }

    /// Streaming variant of [`completion`](Self::completion).
    async fn stream(
        &self,
        request: CoreCompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        // NOTE(review): path resolution prefers inherent methods, so this presumably calls an
        // inherent `CompletionModel::stream` (likely in the `streaming` submodule, not visible
        // here) rather than recursing into this trait method — confirm it exists.
        Self::stream(self, request).await
    }
}