1use super::{ApiErrorResponse, ApiResponse, Client, streaming::StreamingCompletionResponse};
6use crate::completion::{
7 CompletionError, CompletionRequest as CoreCompletionRequest, GetTokenUsage,
8};
9use crate::http_client::{self, HttpClientExt};
10use crate::message::{AudioMediaType, DocumentSourceKind, ImageDetail, MimeType};
11use crate::one_or_many::string_or_one_or_many;
12use crate::telemetry::{ProviderResponseExt, SpanCombinator};
13use crate::{OneOrMany, completion, json_utils, message};
14use serde::{Deserialize, Serialize};
15use std::convert::Infallible;
16use std::fmt;
17use tracing::{Instrument, info_span};
18
19use std::str::FromStr;
20
/// Streaming chat-completion support; provides `StreamingCompletionResponse` used by this module.
pub mod streaming;
22
// Model identifiers in the `o1`, `o3` and `o4-mini` families.

/// `o4-mini-2025-04-16` model identifier.
pub const O4_MINI_2025_04_16: &str = "o4-mini-2025-04-16";
/// `o4-mini` model identifier.
pub const O4_MINI: &str = "o4-mini";
/// `o3` model identifier.
pub const O3: &str = "o3";
/// `o3-mini` model identifier.
pub const O3_MINI: &str = "o3-mini";
/// `o3-mini-2025-01-31` model identifier.
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
/// `o1-pro` model identifier.
pub const O1_PRO: &str = "o1-pro";
/// `o1` model identifier.
pub const O1: &str = "o1";
/// `o1-2024-12-17` model identifier.
pub const O1_2024_12_17: &str = "o1-2024-12-17";
/// `o1-preview` model identifier.
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` model identifier.
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` model identifier.
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` model identifier.
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";

// GPT-3.5 / GPT-4 family model identifiers.

/// `gpt-4.1-mini` model identifier.
pub const GPT_4_1_MINI: &str = "gpt-4.1-mini";
/// `gpt-4.1-nano` model identifier.
pub const GPT_4_1_NANO: &str = "gpt-4.1-nano";
/// `gpt-4.1-2025-04-14` model identifier.
pub const GPT_4_1_2025_04_14: &str = "gpt-4.1-2025-04-14";
/// `gpt-4.1` model identifier.
pub const GPT_4_1: &str = "gpt-4.1";
/// `gpt-4.5-preview` model identifier.
pub const GPT_4_5_PREVIEW: &str = "gpt-4.5-preview";
/// `gpt-4.5-preview-2025-02-27` model identifier.
pub const GPT_4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27";
/// `gpt-4o-2024-11-20` model identifier.
pub const GPT_4O_2024_11_20: &str = "gpt-4o-2024-11-20";
/// `gpt-4o` model identifier.
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` model identifier.
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-2024-05-13` model identifier.
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` model identifier.
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` model identifier.
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` model identifier.
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` model identifier.
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` model identifier.
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` model identifier.
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` model identifier.
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` model identifier.
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` model identifier.
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` model identifier.
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` model identifier.
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` model identifier.
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` model identifier.
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` model identifier.
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` model identifier.
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
98
99impl From<ApiErrorResponse> for CompletionError {
100 fn from(err: ApiErrorResponse) -> Self {
101 CompletionError::ProviderError(err.message)
102 }
103}
104
105#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
106#[serde(tag = "role", rename_all = "lowercase")]
107pub enum Message {
108 #[serde(alias = "developer")]
109 System {
110 #[serde(deserialize_with = "string_or_one_or_many")]
111 content: OneOrMany<SystemContent>,
112 #[serde(skip_serializing_if = "Option::is_none")]
113 name: Option<String>,
114 },
115 User {
116 #[serde(deserialize_with = "string_or_one_or_many")]
117 content: OneOrMany<UserContent>,
118 #[serde(skip_serializing_if = "Option::is_none")]
119 name: Option<String>,
120 },
121 Assistant {
122 #[serde(default, deserialize_with = "json_utils::string_or_vec")]
123 content: Vec<AssistantContent>,
124 #[serde(skip_serializing_if = "Option::is_none")]
125 refusal: Option<String>,
126 #[serde(skip_serializing_if = "Option::is_none")]
127 audio: Option<AudioAssistant>,
128 #[serde(skip_serializing_if = "Option::is_none")]
129 name: Option<String>,
130 #[serde(
131 default,
132 deserialize_with = "json_utils::null_or_vec",
133 skip_serializing_if = "Vec::is_empty"
134 )]
135 tool_calls: Vec<ToolCall>,
136 },
137 #[serde(rename = "tool")]
138 ToolResult {
139 tool_call_id: String,
140 content: OneOrMany<ToolResultContent>,
141 },
142}
143
144impl Message {
145 pub fn system(content: &str) -> Self {
146 Message::System {
147 content: OneOrMany::one(content.to_owned().into()),
148 name: None,
149 }
150 }
151}
152
/// Value of an assistant message's `audio` field; only the audio `id` is carried.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AudioAssistant {
    /// Identifier of the audio object.
    pub id: String,
}
157
/// One text segment of a `system` message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    /// Part type; defaults to `text` when absent in the input.
    #[serde(default)]
    pub r#type: SystemContentType,
    /// The segment's text.
    pub text: String,
}

/// Type tag of a [`SystemContent`] part; only `text` exists (serialized lowercase).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
171
/// A content part of an assistant message, tagged by `type` as `text` or `refusal`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    /// Ordinary response text.
    Text { text: String },
    /// Refusal text returned in place of an answer.
    Refusal { refusal: String },
}
178
179impl From<AssistantContent> for completion::AssistantContent {
180 fn from(value: AssistantContent) -> Self {
181 match value {
182 AssistantContent::Text { text } => completion::AssistantContent::text(text),
183 AssistantContent::Refusal { refusal } => completion::AssistantContent::text(refusal),
184 }
185 }
186}
187
188#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
189#[serde(tag = "type", rename_all = "lowercase")]
190pub enum UserContent {
191 Text {
192 text: String,
193 },
194 #[serde(rename = "image_url")]
195 Image {
196 image_url: ImageUrl,
197 },
198 Audio {
199 input_audio: InputAudio,
200 },
201}
202
/// Payload of an `image_url` content part.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    /// Image location: a regular URL, or a `data:<mime>;base64,<data>` URI for inline images.
    pub url: String,
    /// Requested detail level; falls back to `ImageDetail::default()` when absent.
    #[serde(default)]
    pub detail: ImageDetail,
}
209
/// Payload of an audio input content part.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct InputAudio {
    /// Encoded audio data (the conversions in this module fill it from base64 sources).
    pub data: String,
    /// Audio container/codec of `data`.
    pub format: AudioMediaType,
}
215
/// One text part of a `tool` message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    /// Part type; defaults to `text` when absent in the input.
    #[serde(default)]
    r#type: ToolResultContentType,
    /// The tool output text.
    pub text: String,
}

/// Type tag of a [`ToolResultContent`] part; only `text` exists (serialized lowercase).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    #[default]
    Text,
}
229
230impl FromStr for ToolResultContent {
231 type Err = Infallible;
232
233 fn from_str(s: &str) -> Result<Self, Self::Err> {
234 Ok(s.to_owned().into())
235 }
236}
237
238impl From<String> for ToolResultContent {
239 fn from(s: String) -> Self {
240 ToolResultContent {
241 r#type: ToolResultContentType::default(),
242 text: s,
243 }
244 }
245}
246
/// A tool call listed in an assistant message's `tool_calls`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Call identifier, echoed back as `tool_call_id` in the matching `tool` message.
    pub id: String,
    /// Call kind; defaults to `function` when absent in the input.
    #[serde(default)]
    pub r#type: ToolType,
    /// The function name and arguments being invoked.
    pub function: Function,
}

/// Kind of a [`ToolCall`]; only `function` exists (serialized lowercase).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
261
/// A tool declaration sent in a request's `tools` list.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    /// Tool kind tag; the `From<completion::ToolDefinition>` impl sets it to `"function"`.
    pub r#type: String,
    /// The provider-agnostic function definition, serialized as-is.
    pub function: completion::ToolDefinition,
}
267
268impl From<completion::ToolDefinition> for ToolDefinition {
269 fn from(tool: completion::ToolDefinition) -> Self {
270 Self {
271 r#type: "function".into(),
272 function: tool,
273 }
274 }
275}
276
/// Request-level `tool_choice`, serialized as `auto`, `none` or `required`; defaults to `Auto`.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    None,
    Required,
}
285
286impl TryFrom<crate::message::ToolChoice> for ToolChoice {
287 type Error = CompletionError;
288 fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
289 let res = match value {
290 message::ToolChoice::Specific { .. } => {
291 return Err(CompletionError::ProviderError(
292 "Provider doesn't support only using specific tools".to_string(),
293 ));
294 }
295 message::ToolChoice::Auto => Self::Auto,
296 message::ToolChoice::None => Self::None,
297 message::ToolChoice::Required => Self::Required,
298 };
299
300 Ok(res)
301 }
302}
303
/// Function name and arguments of a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the function to invoke.
    pub name: String,
    /// Call arguments as JSON; (de)serialized through `json_utils::stringified_json`
    /// (presumably a JSON-encoded string on the wire — confirm in `json_utils`).
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
310
311impl TryFrom<message::Message> for Vec<Message> {
312 type Error = message::MessageError;
313
314 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
315 match message {
316 message::Message::User { content } => {
317 let (tool_results, other_content): (Vec<_>, Vec<_>) = content
318 .into_iter()
319 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
320
321 if !tool_results.is_empty() {
324 tool_results
325 .into_iter()
326 .map(|content| match content {
327 message::UserContent::ToolResult(message::ToolResult {
328 id,
329 content,
330 ..
331 }) => Ok::<_, message::MessageError>(Message::ToolResult {
332 tool_call_id: id,
333 content: content.try_map(|content| match content {
334 message::ToolResultContent::Text(message::Text { text }) => {
335 Ok(text.into())
336 }
337 _ => Err(message::MessageError::ConversionError(
338 "Tool result content does not support non-text".into(),
339 )),
340 })?,
341 }),
342 _ => unreachable!(),
343 })
344 .collect::<Result<Vec<_>, _>>()
345 } else {
346 let other_content: Vec<UserContent> = other_content
347 .into_iter()
348 .map(|content| match content {
349 message::UserContent::Text(message::Text { text }) => {
350 Ok(UserContent::Text { text })
351 }
352 message::UserContent::Image(message::Image {
353 data,
354 detail,
355 media_type,
356 ..
357 }) => match data {
358 DocumentSourceKind::Url(url) => Ok(UserContent::Image {
359 image_url: ImageUrl {
360 url,
361 detail: detail.unwrap_or_default(),
362 },
363 }),
364 DocumentSourceKind::Base64(data) => {
365 let url = format!(
366 "data:{};base64,{}",
367 media_type.map(|i| i.to_mime_type()).ok_or(
368 message::MessageError::ConversionError(
369 "OpenAI Image URI must have media type".into()
370 )
371 )?,
372 data
373 );
374
375 let detail =
376 detail.ok_or(message::MessageError::ConversionError(
377 "OpenAI image URI must have image detail".into(),
378 ))?;
379
380 Ok(UserContent::Image {
381 image_url: ImageUrl { url, detail },
382 })
383 }
384 DocumentSourceKind::Raw(_) => {
385 Err(message::MessageError::ConversionError(
386 "Raw files not supported, encode as base64 first".into(),
387 ))
388 }
389 DocumentSourceKind::Unknown => {
390 Err(message::MessageError::ConversionError(
391 "Document has no body".into(),
392 ))
393 }
394 doc => Err(message::MessageError::ConversionError(format!(
395 "Unsupported document type: {doc:?}"
396 ))),
397 },
398 message::UserContent::Document(message::Document { data, .. }) => {
399 if let DocumentSourceKind::Base64(text) = data {
400 Ok(UserContent::Text { text })
401 } else {
402 Err(message::MessageError::ConversionError(
403 "Documents must be base64".into(),
404 ))
405 }
406 }
407 message::UserContent::Audio(message::Audio {
408 data: DocumentSourceKind::Base64(data),
409 media_type,
410 ..
411 }) => Ok(UserContent::Audio {
412 input_audio: InputAudio {
413 data,
414 format: match media_type {
415 Some(media_type) => media_type,
416 None => AudioMediaType::MP3,
417 },
418 },
419 }),
420 _ => Err(message::MessageError::ConversionError(
421 "Tool result is in unsupported format".into(),
422 )),
423 })
424 .collect::<Result<Vec<_>, _>>()?;
425
426 let other_content = OneOrMany::many(other_content).expect(
427 "There must be other content here if there were no tool result content",
428 );
429
430 Ok(vec![Message::User {
431 content: other_content,
432 name: None,
433 }])
434 }
435 }
436 message::Message::Assistant { content, .. } => {
437 let (text_content, tool_calls) = content.into_iter().fold(
438 (Vec::new(), Vec::new()),
439 |(mut texts, mut tools), content| {
440 match content {
441 message::AssistantContent::Text(text) => texts.push(text),
442 message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
443 message::AssistantContent::Reasoning(_) => {
444 unimplemented!(
445 "The OpenAI Completions API doesn't support reasoning!"
446 );
447 }
448 }
449 (texts, tools)
450 },
451 );
452
453 Ok(vec![Message::Assistant {
456 content: text_content
457 .into_iter()
458 .map(|content| content.text.into())
459 .collect::<Vec<_>>(),
460 refusal: None,
461 audio: None,
462 name: None,
463 tool_calls: tool_calls
464 .into_iter()
465 .map(|tool_call| tool_call.into())
466 .collect::<Vec<_>>(),
467 }])
468 }
469 }
470 }
471}
472
473impl From<message::ToolCall> for ToolCall {
474 fn from(tool_call: message::ToolCall) -> Self {
475 Self {
476 id: tool_call.id,
477 r#type: ToolType::default(),
478 function: Function {
479 name: tool_call.function.name,
480 arguments: tool_call.function.arguments,
481 },
482 }
483 }
484}
485
486impl From<ToolCall> for message::ToolCall {
487 fn from(tool_call: ToolCall) -> Self {
488 Self {
489 id: tool_call.id,
490 call_id: None,
491 function: message::ToolFunction {
492 name: tool_call.function.name,
493 arguments: tool_call.function.arguments,
494 },
495 }
496 }
497}
498
499impl TryFrom<Message> for message::Message {
500 type Error = message::MessageError;
501
502 fn try_from(message: Message) -> Result<Self, Self::Error> {
503 Ok(match message {
504 Message::User { content, .. } => message::Message::User {
505 content: content.map(|content| content.into()),
506 },
507 Message::Assistant {
508 content,
509 tool_calls,
510 ..
511 } => {
512 let mut content = content
513 .into_iter()
514 .map(|content| match content {
515 AssistantContent::Text { text } => message::AssistantContent::text(text),
516
517 AssistantContent::Refusal { refusal } => {
520 message::AssistantContent::text(refusal)
521 }
522 })
523 .collect::<Vec<_>>();
524
525 content.extend(
526 tool_calls
527 .into_iter()
528 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
529 .collect::<Result<Vec<_>, _>>()?,
530 );
531
532 message::Message::Assistant {
533 id: None,
534 content: OneOrMany::many(content).map_err(|_| {
535 message::MessageError::ConversionError(
536 "Neither `content` nor `tool_calls` was provided to the Message"
537 .to_owned(),
538 )
539 })?,
540 }
541 }
542
543 Message::ToolResult {
544 tool_call_id,
545 content,
546 } => message::Message::User {
547 content: OneOrMany::one(message::UserContent::tool_result(
548 tool_call_id,
549 content.map(|content| message::ToolResultContent::text(content.text)),
550 )),
551 },
552
553 Message::System { content, .. } => message::Message::User {
556 content: content.map(|content| message::UserContent::text(content.text)),
557 },
558 })
559 }
560}
561
562impl From<UserContent> for message::UserContent {
563 fn from(content: UserContent) -> Self {
564 match content {
565 UserContent::Text { text } => message::UserContent::text(text),
566 UserContent::Image { image_url } => {
567 message::UserContent::image_url(image_url.url, None, Some(image_url.detail))
568 }
569 UserContent::Audio { input_audio } => {
570 message::UserContent::audio(input_audio.data, Some(input_audio.format))
571 }
572 }
573 }
574}
575
576impl From<String> for UserContent {
577 fn from(s: String) -> Self {
578 UserContent::Text { text: s }
579 }
580}
581
582impl FromStr for UserContent {
583 type Err = Infallible;
584
585 fn from_str(s: &str) -> Result<Self, Self::Err> {
586 Ok(UserContent::Text {
587 text: s.to_string(),
588 })
589 }
590}
591
592impl From<String> for AssistantContent {
593 fn from(s: String) -> Self {
594 AssistantContent::Text { text: s }
595 }
596}
597
598impl FromStr for AssistantContent {
599 type Err = Infallible;
600
601 fn from_str(s: &str) -> Result<Self, Self::Err> {
602 Ok(AssistantContent::Text {
603 text: s.to_string(),
604 })
605 }
606}
607impl From<String> for SystemContent {
608 fn from(s: String) -> Self {
609 SystemContent {
610 r#type: SystemContentType::default(),
611 text: s,
612 }
613 }
614}
615
616impl FromStr for SystemContent {
617 type Err = Infallible;
618
619 fn from_str(s: &str) -> Result<Self, Self::Err> {
620 Ok(SystemContent {
621 r#type: SystemContentType::default(),
622 text: s.to_string(),
623 })
624 }
625}
626
/// Body of a successful `/chat/completions` response.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Response identifier.
    pub id: String,
    /// Object type string reported by the API.
    pub object: String,
    /// Creation timestamp as reported by the API (unit not checked here — presumably seconds).
    pub created: u64,
    /// Model that produced the response.
    pub model: String,
    pub system_fingerprint: Option<String>,
    /// Returned choices; only the first one is converted into a core response.
    pub choices: Vec<Choice>,
    /// Token usage, when reported.
    pub usage: Option<Usage>,
}
637
638impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
639 type Error = CompletionError;
640
641 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
642 let choice = response.choices.first().ok_or_else(|| {
643 CompletionError::ResponseError("Response contained no choices".to_owned())
644 })?;
645
646 let content = match &choice.message {
647 Message::Assistant {
648 content,
649 tool_calls,
650 ..
651 } => {
652 let mut content = content
653 .iter()
654 .filter_map(|c| {
655 let s = match c {
656 AssistantContent::Text { text } => text,
657 AssistantContent::Refusal { refusal } => refusal,
658 };
659 if s.is_empty() {
660 None
661 } else {
662 Some(completion::AssistantContent::text(s))
663 }
664 })
665 .collect::<Vec<_>>();
666
667 content.extend(
668 tool_calls
669 .iter()
670 .map(|call| {
671 completion::AssistantContent::tool_call(
672 &call.id,
673 &call.function.name,
674 call.function.arguments.clone(),
675 )
676 })
677 .collect::<Vec<_>>(),
678 );
679 Ok(content)
680 }
681 _ => Err(CompletionError::ResponseError(
682 "Response did not contain a valid message or tool call".into(),
683 )),
684 }?;
685
686 let choice = OneOrMany::many(content).map_err(|_| {
687 CompletionError::ResponseError(
688 "Response contained no message or tool call (empty)".to_owned(),
689 )
690 })?;
691
692 let usage = response
693 .usage
694 .as_ref()
695 .map(|usage| completion::Usage {
696 input_tokens: usage.prompt_tokens as u64,
697 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
698 total_tokens: usage.total_tokens as u64,
699 })
700 .unwrap_or_default();
701
702 Ok(completion::CompletionResponse {
703 choice,
704 usage,
705 raw_response: response,
706 })
707 }
708}
709
710impl ProviderResponseExt for CompletionResponse {
711 type OutputMessage = Choice;
712 type Usage = Usage;
713
714 fn get_response_id(&self) -> Option<String> {
715 Some(self.id.to_owned())
716 }
717
718 fn get_response_model_name(&self) -> Option<String> {
719 Some(self.model.to_owned())
720 }
721
722 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
723 self.choices.clone()
724 }
725
726 fn get_text_response(&self) -> Option<String> {
727 let Message::User { ref content, .. } = self.choices.last()?.message.clone() else {
728 return None;
729 };
730
731 let UserContent::Text { text } = content.first() else {
732 return None;
733 };
734
735 Some(text)
736 }
737
738 fn get_usage(&self) -> Option<Self::Usage> {
739 self.usage.clone()
740 }
741}
742
/// One completion choice within a [`CompletionResponse`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Choice {
    /// Position of this choice in the response.
    pub index: usize,
    /// The generated message (an assistant message in successful responses).
    pub message: Message,
    /// Log-probability data, kept as untyped JSON.
    pub logprobs: Option<serde_json::Value>,
    /// Reason the model stopped generating, as reported by the API.
    pub finish_reason: String,
}
750
/// Token counts reported with a response.
///
/// Only prompt and total counts are stored; the conversions in this module derive output
/// tokens as their difference.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: usize,
    /// Tokens consumed by prompt and completion together.
    pub total_tokens: usize,
}
756
757impl Usage {
758 pub fn new() -> Self {
759 Self {
760 prompt_tokens: 0,
761 total_tokens: 0,
762 }
763 }
764}
765
766impl Default for Usage {
767 fn default() -> Self {
768 Self::new()
769 }
770}
771
772impl fmt::Display for Usage {
773 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
774 let Usage {
775 prompt_tokens,
776 total_tokens,
777 } = self;
778 write!(
779 f,
780 "Prompt tokens: {prompt_tokens} Total tokens: {total_tokens}"
781 )
782 }
783}
784
785impl GetTokenUsage for Usage {
786 fn token_usage(&self) -> Option<crate::completion::Usage> {
787 let mut usage = crate::completion::Usage::new();
788 usage.input_tokens = self.prompt_tokens as u64;
789 usage.output_tokens = (self.total_tokens - self.prompt_tokens) as u64;
790 usage.total_tokens = self.total_tokens as u64;
791
792 Some(usage)
793 }
794}
795
/// Chat Completions model handle: an HTTP client plus the model name sent with each request.
///
/// Generic over the HTTP client type, defaulting to `reqwest::Client`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    /// Model identifier placed in the request body's `model` field.
    pub model: String,
}
802
803impl<T> CompletionModel<T>
804where
805 T: HttpClientExt + Default + std::fmt::Debug + Clone + 'static,
806{
807 pub fn new(client: Client<T>, model: &str) -> Self {
808 Self {
809 client,
810 model: model.to_string(),
811 }
812 }
813}
814
815#[derive(Debug, Serialize, Deserialize, Clone)]
816pub struct CompletionRequest {
817 model: String,
818 messages: Vec<Message>,
819 tools: Vec<ToolDefinition>,
820 tool_choice: Option<ToolChoice>,
821 temperature: Option<f64>,
822 #[serde(flatten)]
823 additional_params: Option<serde_json::Value>,
824}
825
826impl TryFrom<(String, CoreCompletionRequest)> for CompletionRequest {
827 type Error = CompletionError;
828
829 fn try_from((model, req): (String, CoreCompletionRequest)) -> Result<Self, Self::Error> {
830 let mut partial_history = vec![];
831 if let Some(docs) = req.normalized_documents() {
832 partial_history.push(docs);
833 }
834 let CoreCompletionRequest {
835 preamble,
836 chat_history,
837 tools,
838 temperature,
839 additional_params,
840 tool_choice,
841 ..
842 } = req;
843
844 partial_history.extend(chat_history);
845
846 let mut full_history: Vec<Message> =
847 preamble.map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
848
849 full_history.extend(
851 partial_history
852 .into_iter()
853 .map(message::Message::try_into)
854 .collect::<Result<Vec<Vec<Message>>, _>>()?
855 .into_iter()
856 .flatten()
857 .collect::<Vec<_>>(),
858 );
859
860 let tool_choice = tool_choice.map(ToolChoice::try_from).transpose()?;
861
862 let res = Self {
863 model,
864 messages: full_history,
865 tools: tools
866 .into_iter()
867 .map(ToolDefinition::from)
868 .collect::<Vec<_>>(),
869 tool_choice,
870 temperature,
871 additional_params,
872 };
873
874 Ok(res)
875 }
876}
877
878impl crate::telemetry::ProviderRequestExt for CompletionRequest {
879 type InputMessage = Message;
880
881 fn get_input_messages(&self) -> Vec<Self::InputMessage> {
882 self.messages.clone()
883 }
884
885 fn get_system_prompt(&self) -> Option<String> {
886 let first_message = self.messages.first()?;
887
888 let Message::System { ref content, .. } = first_message.clone() else {
889 return None;
890 };
891
892 let SystemContent { text, .. } = content.first();
893
894 Some(text)
895 }
896
897 fn get_prompt(&self) -> Option<String> {
898 let last_message = self.messages.last()?;
899
900 let Message::User { ref content, .. } = last_message.clone() else {
901 return None;
902 };
903
904 let UserContent::Text { text } = content.first() else {
905 return None;
906 };
907
908 Some(text)
909 }
910
911 fn get_model_name(&self) -> String {
912 self.model.clone()
913 }
914}
915
916impl CompletionModel<reqwest::Client> {
917 pub fn into_agent_builder(self) -> crate::agent::AgentBuilder<Self> {
918 crate::agent::AgentBuilder::new(self)
919 }
920}
921
impl completion::CompletionModel for CompletionModel<reqwest::Client> {
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    /// Sends one non-streaming chat completion request to `/chat/completions`.
    ///
    /// # Errors
    /// Returns an error when the core request cannot be converted, the body cannot be
    /// serialized, the HTTP call fails, the status is not successful (`ProviderError` with the
    /// raw body), the body is an API error payload (`ProviderError` with its message), or the
    /// parsed response cannot be converted into a core response.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CoreCompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Open a dedicated `chat` span only when the current span is disabled; otherwise the
        // fields below are recorded on the caller's current span.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let request = CompletionRequest::try_from((self.model.to_owned(), completion_request))?;

        span.record_model_input(&request.messages);

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/chat/completions")?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        // The send/parse phase runs inside `span` (see `.instrument(span)` below), so
        // `Span::current()` in here refers to that span.
        async move {
            let response = self.client.send(req).await?;

            if response.status().is_success() {
                let text = http_client::text(response).await?;

                // A successful status can still carry an API error payload.
                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record_model_output(&response.choices);
                        span.record_response_metadata(&response);
                        span.record_token_usage(&response.usage);
                        tracing::debug!("OpenAI response: {response:?}");
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                let text = http_client::text(response).await?;
                Err(CompletionError::ProviderError(text))
            }
        }
        .instrument(span)
        .await
    }

    /// Streams a chat completion by delegating to `CompletionModel::stream`.
    // NOTE(review): presumably an inherent method defined in the `streaming` submodule; if no
    // inherent `stream` exists, this path would resolve back to this trait method — confirm.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        request: CoreCompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        CompletionModel::stream(self, request).await
    }
}