1use super::{ApiErrorResponse, ApiResponse, Client, streaming::StreamingCompletionResponse};
6use crate::completion::{
7 CompletionError, CompletionRequest as CoreCompletionRequest, GetTokenUsage,
8};
9use crate::http_client::{self, HttpClientExt};
10use crate::message::{AudioMediaType, DocumentSourceKind, ImageDetail, MimeType};
11use crate::one_or_many::string_or_one_or_many;
12use crate::telemetry::{ProviderResponseExt, SpanCombinator};
13use crate::{OneOrMany, completion, json_utils, message};
14use serde::{Deserialize, Serialize};
15use std::convert::Infallible;
16use std::fmt;
17use tracing::{Instrument, info_span};
18
19use std::str::FromStr;
20
21pub mod streaming;
22
/// The `o4-mini-2025-04-16` model identifier.
pub const O4_MINI_2025_04_16: &str = "o4-mini-2025-04-16";
/// The `o4-mini` model identifier.
pub const O4_MINI: &str = "o4-mini";
/// The `o3` model identifier.
pub const O3: &str = "o3";
/// The `o3-mini` model identifier.
pub const O3_MINI: &str = "o3-mini";
/// The `o3-mini-2025-01-31` model identifier.
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
/// The `o1-pro` model identifier.
pub const O1_PRO: &str = "o1-pro";
/// The `o1` model identifier.
pub const O1: &str = "o1";
/// The `o1-2024-12-17` model identifier.
pub const O1_2024_12_17: &str = "o1-2024-12-17";
/// The `o1-preview` model identifier.
pub const O1_PREVIEW: &str = "o1-preview";
/// The `o1-preview-2024-09-12` model identifier.
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// The `o1-mini` model identifier.
pub const O1_MINI: &str = "o1-mini";
/// The `o1-mini-2024-09-12` model identifier.
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
47
/// The `gpt-4.1-mini` model identifier.
pub const GPT_4_1_MINI: &str = "gpt-4.1-mini";
/// The `gpt-4.1-nano` model identifier.
pub const GPT_4_1_NANO: &str = "gpt-4.1-nano";
/// The `gpt-4.1-2025-04-14` model identifier.
pub const GPT_4_1_2025_04_14: &str = "gpt-4.1-2025-04-14";
/// The `gpt-4.1` model identifier.
pub const GPT_4_1: &str = "gpt-4.1";
/// The `gpt-4.5-preview` model identifier.
pub const GPT_4_5_PREVIEW: &str = "gpt-4.5-preview";
/// The `gpt-4.5-preview-2025-02-27` model identifier.
pub const GPT_4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27";
/// The `gpt-4o-2024-11-20` model identifier.
pub const GPT_4O_2024_11_20: &str = "gpt-4o-2024-11-20";
/// The `gpt-4o` model identifier.
pub const GPT_4O: &str = "gpt-4o";
/// The `gpt-4o-mini` model identifier.
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// The `gpt-4o-2024-05-13` model identifier.
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// The `gpt-4-turbo` model identifier.
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// The `gpt-4-turbo-2024-04-09` model identifier.
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// The `gpt-4-turbo-preview` model identifier.
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// The `gpt-4-0125-preview` model identifier.
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// The `gpt-4-1106-preview` model identifier.
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// The `gpt-4-vision-preview` model identifier.
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// The `gpt-4-1106-vision-preview` model identifier.
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// The `gpt-4` model identifier.
pub const GPT_4: &str = "gpt-4";
/// The `gpt-4-0613` model identifier.
pub const GPT_4_0613: &str = "gpt-4-0613";
/// The `gpt-4-32k` model identifier.
pub const GPT_4_32K: &str = "gpt-4-32k";
/// The `gpt-4-32k-0613` model identifier.
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// The `gpt-3.5-turbo` model identifier.
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// The `gpt-3.5-turbo-0125` model identifier.
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// The `gpt-3.5-turbo-1106` model identifier.
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// The `gpt-3.5-turbo-instruct` model identifier.
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
98
99impl From<ApiErrorResponse> for CompletionError {
100 fn from(err: ApiErrorResponse) -> Self {
101 CompletionError::ProviderError(err.message)
102 }
103}
104
/// A chat message in the provider wire format.
///
/// The enum is internally tagged by a `role` field whose value is the
/// lowercase variant name, except where a variant overrides it below.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// A system message. The role `developer` is also accepted when
    /// deserializing.
    #[serde(alias = "developer")]
    System {
        /// One or more system content parts (custom deserializer:
        /// `string_or_one_or_many`).
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
        /// Optional name; omitted from the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// A user message.
    User {
        /// One or more user content parts (custom deserializer:
        /// `string_or_one_or_many`).
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<UserContent>,
        /// Optional name; omitted from the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// An assistant message, possibly carrying tool calls.
    Assistant {
        /// Content parts; defaults to an empty list when the field is absent.
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<AssistantContent>,
        /// Optional refusal text; omitted from the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        /// Optional audio reference; omitted from the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<AudioAssistant>,
        /// Optional name; omitted from the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        /// Tool calls; defaults to empty when absent and is omitted from the
        /// serialized form when empty.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    /// The result of a tool call, serialized with role `tool`.
    #[serde(rename = "tool")]
    ToolResult {
        /// Identifier of the tool call this result answers.
        tool_call_id: String,
        /// One or more result content parts.
        content: OneOrMany<ToolResultContent>,
    },
}
143
144impl Message {
145 pub fn system(content: &str) -> Self {
146 Message::System {
147 content: OneOrMany::one(content.to_owned().into()),
148 name: None,
149 }
150 }
151}
152
/// Audio reference attached to an assistant message; holds only an `id`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AudioAssistant {
    /// Identifier of the audio item.
    pub id: String,
}
157
/// A single content part of a system message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    /// Kind of the part; falls back to the default variant when the field is
    /// absent during deserialization.
    #[serde(default)]
    pub r#type: SystemContentType,
    /// The text of the part.
    pub text: String,
}
164
/// Kind tag for [`SystemContent`]; serialized in lowercase.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    /// Plain text (the default and only variant).
    #[default]
    Text,
}
171
/// A single content part of an assistant message, internally tagged by a
/// lowercase `type` field.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    /// A text part.
    Text { text: String },
    /// A refusal part carrying the refusal message.
    Refusal { refusal: String },
}
178
179impl From<AssistantContent> for completion::AssistantContent {
180 fn from(value: AssistantContent) -> Self {
181 match value {
182 AssistantContent::Text { text } => completion::AssistantContent::text(text),
183 AssistantContent::Refusal { refusal } => completion::AssistantContent::text(refusal),
184 }
185 }
186}
187
/// A single content part of a user message, internally tagged by a `type`
/// field (lowercase variant name unless renamed).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// A text part.
    Text {
        text: String,
    },
    /// An image part, tagged as `image_url`.
    #[serde(rename = "image_url")]
    Image {
        image_url: ImageUrl,
    },
    /// An audio part.
    Audio {
        input_audio: InputAudio,
    },
}
202
/// Image reference used by [`UserContent::Image`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    /// The image URL; elsewhere in this file it is also filled with a
    /// `data:<mime>;base64,<data>` URI for inline images.
    pub url: String,
    /// Detail setting; uses the type's default when absent on deserialization.
    #[serde(default)]
    pub detail: ImageDetail,
}
209
/// Audio payload used by [`UserContent::Audio`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct InputAudio {
    /// The audio data as a string (the conversions in this file pass base64
    /// data here).
    pub data: String,
    /// The audio media type.
    pub format: AudioMediaType,
}
215
/// A single content part of a tool result message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    /// Kind of the part; falls back to the default variant when the field is
    /// absent during deserialization.
    #[serde(default)]
    r#type: ToolResultContentType,
    /// The text of the part.
    pub text: String,
}
222
/// Kind tag for [`ToolResultContent`]; serialized in lowercase.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    /// Plain text (the default and only variant).
    #[default]
    Text,
}
229
230impl FromStr for ToolResultContent {
231 type Err = Infallible;
232
233 fn from_str(s: &str) -> Result<Self, Self::Err> {
234 Ok(s.to_owned().into())
235 }
236}
237
238impl From<String> for ToolResultContent {
239 fn from(s: String) -> Self {
240 ToolResultContent {
241 r#type: ToolResultContentType::default(),
242 text: s,
243 }
244 }
245}
246
/// A tool call attached to an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Identifier of the call.
    pub id: String,
    /// Kind of tool; falls back to the default variant when the field is
    /// absent during deserialization.
    #[serde(default)]
    pub r#type: ToolType,
    /// The function name and arguments of the call.
    pub function: Function,
}
254
/// Kind tag for [`ToolCall`]; serialized in lowercase.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// A function tool (the default and only variant).
    #[default]
    Function,
}
261
/// A tool definition in the request wire format: a generic
/// [`completion::ToolDefinition`] paired with a `type` string.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    /// The tool kind; the `From` conversion in this file always sets it to
    /// `"function"`.
    pub r#type: String,
    /// The wrapped generic tool definition.
    pub function: completion::ToolDefinition,
}
267
268impl From<completion::ToolDefinition> for ToolDefinition {
269 fn from(tool: completion::ToolDefinition) -> Self {
270 Self {
271 r#type: "function".into(),
272 function: tool,
273 }
274 }
275}
276
/// Tool-choice setting sent with a request; serialized in snake_case
/// (`auto`, `none`, `required`).
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    /// The default variant.
    #[default]
    Auto,
    /// Serialized as `none`.
    None,
    /// Serialized as `required`.
    Required,
}
285
286impl TryFrom<crate::message::ToolChoice> for ToolChoice {
287 type Error = CompletionError;
288 fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
289 let res = match value {
290 message::ToolChoice::Specific { .. } => {
291 return Err(CompletionError::ProviderError(
292 "Provider doesn't support only using specific tools".to_string(),
293 ));
294 }
295 message::ToolChoice::Auto => Self::Auto,
296 message::ToolChoice::None => Self::None,
297 message::ToolChoice::Required => Self::Required,
298 };
299
300 Ok(res)
301 }
302}
303
/// The function part of a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the function to call.
    pub name: String,
    /// Call arguments as a JSON value, (de)serialized through
    /// `json_utils::stringified_json` — presumably as a JSON-encoded string
    /// on the wire; confirm against that helper.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
310
311impl TryFrom<message::Message> for Vec<Message> {
312 type Error = message::MessageError;
313
314 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
315 match message {
316 message::Message::User { content } => {
317 let (tool_results, other_content): (Vec<_>, Vec<_>) = content
318 .into_iter()
319 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
320
321 if !tool_results.is_empty() {
324 tool_results
325 .into_iter()
326 .map(|content| match content {
327 message::UserContent::ToolResult(message::ToolResult {
328 id,
329 content,
330 ..
331 }) => Ok::<_, message::MessageError>(Message::ToolResult {
332 tool_call_id: id,
333 content: content.try_map(|content| match content {
334 message::ToolResultContent::Text(message::Text { text }) => {
335 Ok(text.into())
336 }
337 _ => Err(message::MessageError::ConversionError(
338 "Tool result content does not support non-text".into(),
339 )),
340 })?,
341 }),
342 _ => unreachable!(),
343 })
344 .collect::<Result<Vec<_>, _>>()
345 } else {
346 let other_content: Vec<UserContent> = other_content
347 .into_iter()
348 .map(|content| match content {
349 message::UserContent::Text(message::Text { text }) => {
350 Ok(UserContent::Text { text })
351 }
352 message::UserContent::Image(message::Image {
353 data,
354 detail,
355 media_type,
356 ..
357 }) => match data {
358 DocumentSourceKind::Url(url) => Ok(UserContent::Image {
359 image_url: ImageUrl {
360 url,
361 detail: detail.unwrap_or_default(),
362 },
363 }),
364 DocumentSourceKind::Base64(data) => {
365 let url = format!(
366 "data:{};base64,{}",
367 media_type.map(|i| i.to_mime_type()).ok_or(
368 message::MessageError::ConversionError(
369 "OpenAI Image URI must have media type".into()
370 )
371 )?,
372 data
373 );
374
375 let detail =
376 detail.ok_or(message::MessageError::ConversionError(
377 "OpenAI image URI must have image detail".into(),
378 ))?;
379
380 Ok(UserContent::Image {
381 image_url: ImageUrl { url, detail },
382 })
383 }
384 DocumentSourceKind::Raw(_) => {
385 Err(message::MessageError::ConversionError(
386 "Raw files not supported, encode as base64 first".into(),
387 ))
388 }
389 DocumentSourceKind::Unknown => {
390 Err(message::MessageError::ConversionError(
391 "Document has no body".into(),
392 ))
393 }
394 doc => Err(message::MessageError::ConversionError(format!(
395 "Unsupported document type: {doc:?}"
396 ))),
397 },
398 message::UserContent::Document(message::Document { data, .. }) => {
399 if let DocumentSourceKind::Base64(text)
400 | DocumentSourceKind::String(text) = data
401 {
402 Ok(UserContent::Text { text })
403 } else {
404 Err(message::MessageError::ConversionError(
405 "Documents must be base64 or a string".into(),
406 ))
407 }
408 }
409 message::UserContent::Audio(message::Audio {
410 data: DocumentSourceKind::Base64(data),
411 media_type,
412 ..
413 }) => Ok(UserContent::Audio {
414 input_audio: InputAudio {
415 data,
416 format: match media_type {
417 Some(media_type) => media_type,
418 None => AudioMediaType::MP3,
419 },
420 },
421 }),
422 _ => Err(message::MessageError::ConversionError(
423 "Tool result is in unsupported format".into(),
424 )),
425 })
426 .collect::<Result<Vec<_>, _>>()?;
427
428 let other_content = OneOrMany::many(other_content).expect(
429 "There must be other content here if there were no tool result content",
430 );
431
432 Ok(vec![Message::User {
433 content: other_content,
434 name: None,
435 }])
436 }
437 }
438 message::Message::Assistant { content, .. } => {
439 let (text_content, tool_calls) = content.into_iter().fold(
440 (Vec::new(), Vec::new()),
441 |(mut texts, mut tools), content| {
442 match content {
443 message::AssistantContent::Text(text) => texts.push(text),
444 message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
445 message::AssistantContent::Reasoning(_) => {
446 unimplemented!(
447 "The OpenAI Completions API doesn't support reasoning!"
448 );
449 }
450 }
451 (texts, tools)
452 },
453 );
454
455 Ok(vec![Message::Assistant {
458 content: text_content
459 .into_iter()
460 .map(|content| content.text.into())
461 .collect::<Vec<_>>(),
462 refusal: None,
463 audio: None,
464 name: None,
465 tool_calls: tool_calls
466 .into_iter()
467 .map(|tool_call| tool_call.into())
468 .collect::<Vec<_>>(),
469 }])
470 }
471 }
472 }
473}
474
475impl From<message::ToolCall> for ToolCall {
476 fn from(tool_call: message::ToolCall) -> Self {
477 Self {
478 id: tool_call.id,
479 r#type: ToolType::default(),
480 function: Function {
481 name: tool_call.function.name,
482 arguments: tool_call.function.arguments,
483 },
484 }
485 }
486}
487
488impl From<ToolCall> for message::ToolCall {
489 fn from(tool_call: ToolCall) -> Self {
490 Self {
491 id: tool_call.id,
492 call_id: None,
493 function: message::ToolFunction {
494 name: tool_call.function.name,
495 arguments: tool_call.function.arguments,
496 },
497 }
498 }
499}
500
501impl TryFrom<Message> for message::Message {
502 type Error = message::MessageError;
503
504 fn try_from(message: Message) -> Result<Self, Self::Error> {
505 Ok(match message {
506 Message::User { content, .. } => message::Message::User {
507 content: content.map(|content| content.into()),
508 },
509 Message::Assistant {
510 content,
511 tool_calls,
512 ..
513 } => {
514 let mut content = content
515 .into_iter()
516 .map(|content| match content {
517 AssistantContent::Text { text } => message::AssistantContent::text(text),
518
519 AssistantContent::Refusal { refusal } => {
522 message::AssistantContent::text(refusal)
523 }
524 })
525 .collect::<Vec<_>>();
526
527 content.extend(
528 tool_calls
529 .into_iter()
530 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
531 .collect::<Result<Vec<_>, _>>()?,
532 );
533
534 message::Message::Assistant {
535 id: None,
536 content: OneOrMany::many(content).map_err(|_| {
537 message::MessageError::ConversionError(
538 "Neither `content` nor `tool_calls` was provided to the Message"
539 .to_owned(),
540 )
541 })?,
542 }
543 }
544
545 Message::ToolResult {
546 tool_call_id,
547 content,
548 } => message::Message::User {
549 content: OneOrMany::one(message::UserContent::tool_result(
550 tool_call_id,
551 content.map(|content| message::ToolResultContent::text(content.text)),
552 )),
553 },
554
555 Message::System { content, .. } => message::Message::User {
558 content: content.map(|content| message::UserContent::text(content.text)),
559 },
560 })
561 }
562}
563
564impl From<UserContent> for message::UserContent {
565 fn from(content: UserContent) -> Self {
566 match content {
567 UserContent::Text { text } => message::UserContent::text(text),
568 UserContent::Image { image_url } => {
569 message::UserContent::image_url(image_url.url, None, Some(image_url.detail))
570 }
571 UserContent::Audio { input_audio } => {
572 message::UserContent::audio(input_audio.data, Some(input_audio.format))
573 }
574 }
575 }
576}
577
578impl From<String> for UserContent {
579 fn from(s: String) -> Self {
580 UserContent::Text { text: s }
581 }
582}
583
584impl FromStr for UserContent {
585 type Err = Infallible;
586
587 fn from_str(s: &str) -> Result<Self, Self::Err> {
588 Ok(UserContent::Text {
589 text: s.to_string(),
590 })
591 }
592}
593
594impl From<String> for AssistantContent {
595 fn from(s: String) -> Self {
596 AssistantContent::Text { text: s }
597 }
598}
599
600impl FromStr for AssistantContent {
601 type Err = Infallible;
602
603 fn from_str(s: &str) -> Result<Self, Self::Err> {
604 Ok(AssistantContent::Text {
605 text: s.to_string(),
606 })
607 }
608}
609impl From<String> for SystemContent {
610 fn from(s: String) -> Self {
611 SystemContent {
612 r#type: SystemContentType::default(),
613 text: s,
614 }
615 }
616}
617
618impl FromStr for SystemContent {
619 type Err = Infallible;
620
621 fn from_str(s: &str) -> Result<Self, Self::Err> {
622 Ok(SystemContent {
623 r#type: SystemContentType::default(),
624 text: s.to_string(),
625 })
626 }
627}
628
/// The body of a successful chat completion response.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Identifier of the response.
    pub id: String,
    /// The `object` field of the response, kept as a plain string.
    pub object: String,
    /// The `created` field of the response — presumably a Unix timestamp;
    /// confirm against the provider documentation.
    pub created: u64,
    /// Name of the model reported by the provider.
    pub model: String,
    /// Optional `system_fingerprint` field of the response.
    pub system_fingerprint: Option<String>,
    /// The returned choices; the conversion in this file reads only the first.
    pub choices: Vec<Choice>,
    /// Token usage, when the provider reported it.
    pub usage: Option<Usage>,
}
639
640impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
641 type Error = CompletionError;
642
643 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
644 let choice = response.choices.first().ok_or_else(|| {
645 CompletionError::ResponseError("Response contained no choices".to_owned())
646 })?;
647
648 let content = match &choice.message {
649 Message::Assistant {
650 content,
651 tool_calls,
652 ..
653 } => {
654 let mut content = content
655 .iter()
656 .filter_map(|c| {
657 let s = match c {
658 AssistantContent::Text { text } => text,
659 AssistantContent::Refusal { refusal } => refusal,
660 };
661 if s.is_empty() {
662 None
663 } else {
664 Some(completion::AssistantContent::text(s))
665 }
666 })
667 .collect::<Vec<_>>();
668
669 content.extend(
670 tool_calls
671 .iter()
672 .map(|call| {
673 completion::AssistantContent::tool_call(
674 &call.id,
675 &call.function.name,
676 call.function.arguments.clone(),
677 )
678 })
679 .collect::<Vec<_>>(),
680 );
681 Ok(content)
682 }
683 _ => Err(CompletionError::ResponseError(
684 "Response did not contain a valid message or tool call".into(),
685 )),
686 }?;
687
688 let choice = OneOrMany::many(content).map_err(|_| {
689 CompletionError::ResponseError(
690 "Response contained no message or tool call (empty)".to_owned(),
691 )
692 })?;
693
694 let usage = response
695 .usage
696 .as_ref()
697 .map(|usage| completion::Usage {
698 input_tokens: usage.prompt_tokens as u64,
699 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
700 total_tokens: usage.total_tokens as u64,
701 })
702 .unwrap_or_default();
703
704 Ok(completion::CompletionResponse {
705 choice,
706 usage,
707 raw_response: response,
708 })
709 }
710}
711
712impl ProviderResponseExt for CompletionResponse {
713 type OutputMessage = Choice;
714 type Usage = Usage;
715
716 fn get_response_id(&self) -> Option<String> {
717 Some(self.id.to_owned())
718 }
719
720 fn get_response_model_name(&self) -> Option<String> {
721 Some(self.model.to_owned())
722 }
723
724 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
725 self.choices.clone()
726 }
727
728 fn get_text_response(&self) -> Option<String> {
729 let Message::User { ref content, .. } = self.choices.last()?.message.clone() else {
730 return None;
731 };
732
733 let UserContent::Text { text } = content.first() else {
734 return None;
735 };
736
737 Some(text)
738 }
739
740 fn get_usage(&self) -> Option<Self::Usage> {
741 self.usage.clone()
742 }
743}
744
/// One choice of a completion response.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Choice {
    /// Index of the choice as reported in the response.
    pub index: usize,
    /// The message of the choice.
    pub message: Message,
    /// Optional `logprobs` field, kept as untyped JSON.
    pub logprobs: Option<serde_json::Value>,
    /// The `finish_reason` field, kept as a plain string.
    pub finish_reason: String,
}
752
/// Token usage reported with a completion response.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Number of tokens in the prompt.
    pub prompt_tokens: usize,
    /// Total number of tokens; elsewhere in this file the output token count
    /// is derived as `total_tokens - prompt_tokens`.
    pub total_tokens: usize,
}
758
759impl Usage {
760 pub fn new() -> Self {
761 Self {
762 prompt_tokens: 0,
763 total_tokens: 0,
764 }
765 }
766}
767
768impl Default for Usage {
769 fn default() -> Self {
770 Self::new()
771 }
772}
773
774impl fmt::Display for Usage {
775 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
776 let Usage {
777 prompt_tokens,
778 total_tokens,
779 } = self;
780 write!(
781 f,
782 "Prompt tokens: {prompt_tokens} Total tokens: {total_tokens}"
783 )
784 }
785}
786
787impl GetTokenUsage for Usage {
788 fn token_usage(&self) -> Option<crate::completion::Usage> {
789 let mut usage = crate::completion::Usage::new();
790 usage.input_tokens = self.prompt_tokens as u64;
791 usage.output_tokens = (self.total_tokens - self.prompt_tokens) as u64;
792 usage.total_tokens = self.total_tokens as u64;
793
794 Some(usage)
795 }
796}
797
/// A completion model bound to a provider client.
///
/// `T` is the HTTP client type and defaults to `reqwest::Client`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// The provider client used to build and send requests.
    pub(crate) client: Client<T>,
    /// Name of the model sent with each request.
    pub model: String,
}
804
805impl<T> CompletionModel<T>
806where
807 T: HttpClientExt + Default + std::fmt::Debug + Clone + 'static,
808{
809 pub fn new(client: Client<T>, model: &str) -> Self {
810 Self {
811 client,
812 model: model.to_string(),
813 }
814 }
815}
816
/// The JSON body of a chat completion request.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CompletionRequest {
    /// Name of the model to use.
    model: String,
    /// The conversation, in order.
    messages: Vec<Message>,
    /// Tool definitions; omitted from the serialized form when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    /// Tool-choice setting; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    /// Sampling temperature; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Extra parameters, flattened into the top level of the request body.
    #[serde(flatten)]
    additional_params: Option<serde_json::Value>,
}
830
831impl TryFrom<(String, CoreCompletionRequest)> for CompletionRequest {
832 type Error = CompletionError;
833
834 fn try_from((model, req): (String, CoreCompletionRequest)) -> Result<Self, Self::Error> {
835 let mut partial_history = vec![];
836 if let Some(docs) = req.normalized_documents() {
837 partial_history.push(docs);
838 }
839 let CoreCompletionRequest {
840 preamble,
841 chat_history,
842 tools,
843 temperature,
844 additional_params,
845 tool_choice,
846 ..
847 } = req;
848
849 partial_history.extend(chat_history);
850
851 let mut full_history: Vec<Message> =
852 preamble.map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
853
854 full_history.extend(
856 partial_history
857 .into_iter()
858 .map(message::Message::try_into)
859 .collect::<Result<Vec<Vec<Message>>, _>>()?
860 .into_iter()
861 .flatten()
862 .collect::<Vec<_>>(),
863 );
864
865 let tool_choice = tool_choice.map(ToolChoice::try_from).transpose()?;
866
867 let res = Self {
868 model,
869 messages: full_history,
870 tools: tools
871 .into_iter()
872 .map(ToolDefinition::from)
873 .collect::<Vec<_>>(),
874 tool_choice,
875 temperature,
876 additional_params,
877 };
878
879 Ok(res)
880 }
881}
882
883impl crate::telemetry::ProviderRequestExt for CompletionRequest {
884 type InputMessage = Message;
885
886 fn get_input_messages(&self) -> Vec<Self::InputMessage> {
887 self.messages.clone()
888 }
889
890 fn get_system_prompt(&self) -> Option<String> {
891 let first_message = self.messages.first()?;
892
893 let Message::System { ref content, .. } = first_message.clone() else {
894 return None;
895 };
896
897 let SystemContent { text, .. } = content.first();
898
899 Some(text)
900 }
901
902 fn get_prompt(&self) -> Option<String> {
903 let last_message = self.messages.last()?;
904
905 let Message::User { ref content, .. } = last_message.clone() else {
906 return None;
907 };
908
909 let UserContent::Text { text } = content.first() else {
910 return None;
911 };
912
913 Some(text)
914 }
915
916 fn get_model_name(&self) -> String {
917 self.model.clone()
918 }
919}
920
impl CompletionModel<reqwest::Client> {
    /// Consumes the model and wraps it in a new
    /// [`crate::agent::AgentBuilder`].
    pub fn into_agent_builder(self) -> crate::agent::AgentBuilder<Self> {
        crate::agent::AgentBuilder::new(self)
    }
}
926
/// Completion and streaming entry points for the `reqwest`-backed model.
impl completion::CompletionModel for CompletionModel<reqwest::Client> {
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    /// Sends a completion request to `/chat/completions` and converts the
    /// response into the generic completion response.
    ///
    /// # Errors
    ///
    /// Fails when the request cannot be built or serialized, when sending or
    /// reading the response fails, when the status is not a success (the
    /// response body becomes a [`CompletionError::ProviderError`]), when the
    /// body cannot be parsed, or when the parsed body is an API error.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CoreCompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Open a dedicated span only when the current span is disabled;
        // otherwise record onto the span that is already current.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let request = CompletionRequest::try_from((self.model.to_owned(), completion_request))?;

        span.record_model_input(&request.messages);

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/chat/completions")?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        // The send/parse step runs inside `span` via `.instrument(span)`.
        async move {
            let response = self.client.send(req).await?;

            if response.status().is_success() {
                let text = http_client::text(response).await?;

                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
                    ApiResponse::Ok(response) => {
                        // Inside the instrumented future, so the current span
                        // is the one chosen above.
                        let span = tracing::Span::current();
                        span.record_model_output(&response.choices);
                        span.record_response_metadata(&response);
                        span.record_token_usage(&response.usage);
                        tracing::debug!("OpenAI response: {response:?}");
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                // Non-success status: return the raw body as the error text.
                let text = http_client::text(response).await?;
                Err(CompletionError::ProviderError(text))
            }
        }
        .instrument(span)
        .await
    }

    /// Streams a completion by delegating to `CompletionModel::stream`.
    // NOTE(review): this presumably resolves to an inherent `stream` method
    // defined outside this view (e.g. in the `streaming` module); confirm,
    // since otherwise the call would recurse into this trait method.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        request: CoreCompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        CompletionModel::stream(self, request).await
    }
}