1use super::{ApiErrorResponse, ApiResponse, Client, streaming::StreamingCompletionResponse};
6use crate::completion::{CompletionError, CompletionRequest};
7use crate::message::{AudioMediaType, DocumentSourceKind, ImageDetail};
8use crate::one_or_many::string_or_one_or_many;
9use crate::{OneOrMany, completion, json_utils, message};
10use serde::{Deserialize, Serialize};
11use serde_json::{Value, json};
12use std::convert::Infallible;
13use std::fmt;
14
15use std::str::FromStr;
16
17pub mod streaming;
18
/// `o4-mini-2025-04-16` model identifier.
pub const O4_MINI_2025_04_16: &str = "o4-mini-2025-04-16";
/// `o4-mini` model identifier.
pub const O4_MINI: &str = "o4-mini";
/// `o3` model identifier.
pub const O3: &str = "o3";
/// `o3-mini` model identifier.
pub const O3_MINI: &str = "o3-mini";
/// `o3-mini-2025-01-31` model identifier.
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
/// `o1-pro` model identifier.
pub const O1_PRO: &str = "o1-pro";
/// `o1` model identifier.
pub const O1: &str = "o1";
/// `o1-2024-12-17` model identifier.
pub const O1_2024_12_17: &str = "o1-2024-12-17";
/// `o1-preview` model identifier.
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` model identifier.
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` model identifier.
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` model identifier.
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";

/// `gpt-4.1-mini` model identifier.
pub const GPT_4_1_MINI: &str = "gpt-4.1-mini";
/// `gpt-4.1-nano` model identifier.
pub const GPT_4_1_NANO: &str = "gpt-4.1-nano";
/// `gpt-4.1-2025-04-14` model identifier.
pub const GPT_4_1_2025_04_14: &str = "gpt-4.1-2025-04-14";
/// `gpt-4.1` model identifier.
pub const GPT_4_1: &str = "gpt-4.1";
/// `gpt-4.5-preview` model identifier.
pub const GPT_4_5_PREVIEW: &str = "gpt-4.5-preview";
/// `gpt-4.5-preview-2025-02-27` model identifier.
pub const GPT_4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27";
/// `gpt-4o-2024-11-20` model identifier.
pub const GPT_4O_2024_11_20: &str = "gpt-4o-2024-11-20";
/// `gpt-4o` model identifier.
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` model identifier.
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-2024-05-13` model identifier.
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` model identifier.
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` model identifier.
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` model identifier.
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` model identifier.
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` model identifier.
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` model identifier.
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` model identifier.
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` model identifier.
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` model identifier.
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` model identifier.
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` model identifier.
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` model identifier.
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` model identifier.
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` model identifier.
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` model identifier.
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
94
95impl From<ApiErrorResponse> for CompletionError {
96 fn from(err: ApiErrorResponse) -> Self {
97 CompletionError::ProviderError(err.message)
98 }
99}
100
/// A single chat message, discriminated on the wire by its `role` field.
///
/// Variant names are (de)serialized in lowercase, except `ToolResult`, which is renamed
/// to `tool`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// System message. The `developer` role is also accepted when deserializing.
    #[serde(alias = "developer")]
    System {
        /// Content parts, deserialized through `string_or_one_or_many`
        /// (presumably accepts either a bare string or a list of parts — confirm
        /// against the helper).
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
        /// Optional name; left out of the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// User message.
    User {
        /// Content parts, deserialized through `string_or_one_or_many`.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<UserContent>,
        /// Optional name; left out of the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Assistant message.
    Assistant {
        /// Text/refusal parts. Defaults to empty when the field is absent and is
        /// deserialized through `json_utils::string_or_vec`.
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<AssistantContent>,
        /// Optional refusal text; left out of the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        /// Optional audio reference; left out of the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<AudioAssistant>,
        /// Optional name; left out of the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        /// Tool calls requested by the assistant. Defaults to empty when absent, is
        /// deserialized through `json_utils::null_or_vec`, and is left out of the
        /// serialized form when empty.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    /// Result of a tool call, serialized with role `tool`.
    #[serde(rename = "tool")]
    ToolResult {
        /// Identifier of the tool call this result answers.
        tool_call_id: String,
        /// Result content parts.
        content: OneOrMany<ToolResultContent>,
    },
}
139
140impl Message {
141 pub fn system(content: &str) -> Self {
142 Message::System {
143 content: OneOrMany::one(content.to_owned().into()),
144 name: None,
145 }
146 }
147}
148
/// Audio reference attached to an assistant message; carries only an `id`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AudioAssistant {
    /// Identifier of the audio item.
    pub id: String,
}
153
/// One content part of a system message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    /// Part type; falls back to the default (`Text`) when absent from the input.
    #[serde(default)]
    pub r#type: SystemContentType,
    /// The text of this part.
    pub text: String,
}
160
/// Type tag of a [`SystemContent`] part, (de)serialized in lowercase.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    /// Plain text; this is the default and the only variant.
    #[default]
    Text,
}
167
/// One content part of an assistant message, discriminated by a lowercase `type` field.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    /// Text produced by the assistant.
    Text { text: String },
    /// Refusal text produced by the assistant.
    Refusal { refusal: String },
}
174
175impl From<AssistantContent> for completion::AssistantContent {
176 fn from(value: AssistantContent) -> Self {
177 match value {
178 AssistantContent::Text { text } => completion::AssistantContent::text(text),
179 AssistantContent::Refusal { refusal } => completion::AssistantContent::text(refusal),
180 }
181 }
182}
183
/// One content part of a user message, discriminated by a `type` field.
///
/// Variant names are (de)serialized in lowercase, except `Image`, which is renamed to
/// `image_url`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text.
    Text {
        text: String,
    },
    /// Image referenced by URL; serialized with `type` set to `image_url`.
    #[serde(rename = "image_url")]
    Image {
        image_url: ImageUrl,
    },
    /// Audio input; serialized with `type` set to `audio`.
    // NOTE(review): confirm that `audio` is the type tag the API expects for this part.
    Audio {
        input_audio: InputAudio,
    },
}
198
/// Image reference used by [`UserContent::Image`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    /// Location of the image.
    pub url: String,
    /// Requested detail level; falls back to `ImageDetail::default()` when absent.
    #[serde(default)]
    pub detail: ImageDetail,
}
205
/// Audio payload used by [`UserContent::Audio`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct InputAudio {
    /// Audio data as a string (presumably base64-encoded — confirm against callers).
    pub data: String,
    /// Media type of `data`.
    pub format: AudioMediaType,
}
211
/// One content part of a tool-result message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    /// Part type (private field); falls back to the default (`Text`) when absent.
    #[serde(default)]
    r#type: ToolResultContentType,
    /// The text of this part.
    pub text: String,
}
218
/// Type tag of a [`ToolResultContent`] part, (de)serialized in lowercase.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    /// Plain text; this is the default and the only variant.
    #[default]
    Text,
}
225
226impl FromStr for ToolResultContent {
227 type Err = Infallible;
228
229 fn from_str(s: &str) -> Result<Self, Self::Err> {
230 Ok(s.to_owned().into())
231 }
232}
233
234impl From<String> for ToolResultContent {
235 fn from(s: String) -> Self {
236 ToolResultContent {
237 r#type: ToolResultContentType::default(),
238 text: s,
239 }
240 }
241}
242
/// A tool call attached to an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Identifier of this call.
    pub id: String,
    /// Kind of tool; falls back to the default (`Function`) when absent.
    #[serde(default)]
    pub r#type: ToolType,
    /// The function being called and its arguments.
    pub function: Function,
}
250
/// Kind of tool referenced by a [`ToolCall`], (de)serialized in lowercase.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// A function tool; this is the default and the only variant.
    #[default]
    Function,
}
257
/// Tool definition as placed in the request's `tools` list: a generic definition plus a
/// type tag.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    /// Tool type tag; the `From<completion::ToolDefinition>` conversion sets it to
    /// `"function"`.
    pub r#type: String,
    /// The wrapped generic tool definition.
    pub function: completion::ToolDefinition,
}
263
264impl From<completion::ToolDefinition> for ToolDefinition {
265 fn from(tool: completion::ToolDefinition) -> Self {
266 Self {
267 r#type: "function".into(),
268 function: tool,
269 }
270 }
271}
272
/// Function name and arguments carried by a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the function to call.
    pub name: String,
    /// Call arguments, (de)serialized through `json_utils::stringified_json`
    /// (presumably a JSON document encoded as a string on the wire — confirm against
    /// the helper).
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
279
280impl TryFrom<message::Message> for Vec<Message> {
281 type Error = message::MessageError;
282
283 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
284 match message {
285 message::Message::User { content } => {
286 let (tool_results, other_content): (Vec<_>, Vec<_>) = content
287 .into_iter()
288 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
289
290 if !tool_results.is_empty() {
293 tool_results
294 .into_iter()
295 .map(|content| match content {
296 message::UserContent::ToolResult(message::ToolResult {
297 id,
298 content,
299 ..
300 }) => Ok::<_, message::MessageError>(Message::ToolResult {
301 tool_call_id: id,
302 content: content.try_map(|content| match content {
303 message::ToolResultContent::Text(message::Text { text }) => {
304 Ok(text.into())
305 }
306 _ => Err(message::MessageError::ConversionError(
307 "Tool result content does not support non-text".into(),
308 )),
309 })?,
310 }),
311 _ => unreachable!(),
312 })
313 .collect::<Result<Vec<_>, _>>()
314 } else {
315 let other_content: Vec<UserContent> = other_content.into_iter().map(|content| match content {
316 message::UserContent::Text(message::Text { text }) => {
317 Ok(UserContent::Text { text })
318 }
319 message::UserContent::Image(message::Image {
320 data, detail, ..
321 }) => {
322 let DocumentSourceKind::Url(url) = data else { return Err(message::MessageError::ConversionError(
323 "Only image URL user content is accepted with OpenAI Chat Completions API".to_string()
324 ))};
325
326 Ok(UserContent::Image {
327 image_url: ImageUrl {
328 url,
329 detail: detail.unwrap_or_default(),
330 }
331 }
332 )
333
334 },
335 message::UserContent::Document(message::Document { data, .. }) => {
336 Ok(UserContent::Text { text: data })
337 }
338 message::UserContent::Audio(message::Audio {
339 data,
340 media_type,
341 ..
342 }) => Ok(UserContent::Audio {
343 input_audio: InputAudio {
344 data,
345 format: match media_type {
346 Some(media_type) => media_type,
347 None => AudioMediaType::MP3,
348 },
349 },
350 }),
351 _ => unreachable!(),
352 }).collect::<Result<Vec<_>, _>>()?;
353
354 let other_content = OneOrMany::many(other_content).expect(
355 "There must be other content here if there were no tool result content",
356 );
357
358 Ok(vec![Message::User {
359 content: other_content,
360 name: None,
361 }])
362 }
363 }
364 message::Message::Assistant { content, .. } => {
365 let (text_content, tool_calls) = content.into_iter().fold(
366 (Vec::new(), Vec::new()),
367 |(mut texts, mut tools), content| {
368 match content {
369 message::AssistantContent::Text(text) => texts.push(text),
370 message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
371 message::AssistantContent::Reasoning(_) => {
372 unimplemented!(
373 "The OpenAI Completions API doesn't support reasoning!"
374 );
375 }
376 }
377 (texts, tools)
378 },
379 );
380
381 Ok(vec![Message::Assistant {
384 content: text_content
385 .into_iter()
386 .map(|content| content.text.into())
387 .collect::<Vec<_>>(),
388 refusal: None,
389 audio: None,
390 name: None,
391 tool_calls: tool_calls
392 .into_iter()
393 .map(|tool_call| tool_call.into())
394 .collect::<Vec<_>>(),
395 }])
396 }
397 }
398 }
399}
400
401impl From<message::ToolCall> for ToolCall {
402 fn from(tool_call: message::ToolCall) -> Self {
403 Self {
404 id: tool_call.id,
405 r#type: ToolType::default(),
406 function: Function {
407 name: tool_call.function.name,
408 arguments: tool_call.function.arguments,
409 },
410 }
411 }
412}
413
414impl From<ToolCall> for message::ToolCall {
415 fn from(tool_call: ToolCall) -> Self {
416 Self {
417 id: tool_call.id,
418 call_id: None,
419 function: message::ToolFunction {
420 name: tool_call.function.name,
421 arguments: tool_call.function.arguments,
422 },
423 }
424 }
425}
426
427impl TryFrom<Message> for message::Message {
428 type Error = message::MessageError;
429
430 fn try_from(message: Message) -> Result<Self, Self::Error> {
431 Ok(match message {
432 Message::User { content, .. } => message::Message::User {
433 content: content.map(|content| content.into()),
434 },
435 Message::Assistant {
436 content,
437 tool_calls,
438 ..
439 } => {
440 let mut content = content
441 .into_iter()
442 .map(|content| match content {
443 AssistantContent::Text { text } => message::AssistantContent::text(text),
444
445 AssistantContent::Refusal { refusal } => {
448 message::AssistantContent::text(refusal)
449 }
450 })
451 .collect::<Vec<_>>();
452
453 content.extend(
454 tool_calls
455 .into_iter()
456 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
457 .collect::<Result<Vec<_>, _>>()?,
458 );
459
460 message::Message::Assistant {
461 id: None,
462 content: OneOrMany::many(content).map_err(|_| {
463 message::MessageError::ConversionError(
464 "Neither `content` nor `tool_calls` was provided to the Message"
465 .to_owned(),
466 )
467 })?,
468 }
469 }
470
471 Message::ToolResult {
472 tool_call_id,
473 content,
474 } => message::Message::User {
475 content: OneOrMany::one(message::UserContent::tool_result(
476 tool_call_id,
477 content.map(|content| message::ToolResultContent::text(content.text)),
478 )),
479 },
480
481 Message::System { content, .. } => message::Message::User {
484 content: content.map(|content| message::UserContent::text(content.text)),
485 },
486 })
487 }
488}
489
490impl From<UserContent> for message::UserContent {
491 fn from(content: UserContent) -> Self {
492 match content {
493 UserContent::Text { text } => message::UserContent::text(text),
494 UserContent::Image { image_url } => {
495 message::UserContent::image_url(image_url.url, None, Some(image_url.detail))
496 }
497 UserContent::Audio { input_audio } => message::UserContent::audio(
498 input_audio.data,
499 Some(message::ContentFormat::default()),
500 Some(input_audio.format),
501 ),
502 }
503 }
504}
505
506impl From<String> for UserContent {
507 fn from(s: String) -> Self {
508 UserContent::Text { text: s }
509 }
510}
511
512impl FromStr for UserContent {
513 type Err = Infallible;
514
515 fn from_str(s: &str) -> Result<Self, Self::Err> {
516 Ok(UserContent::Text {
517 text: s.to_string(),
518 })
519 }
520}
521
522impl From<String> for AssistantContent {
523 fn from(s: String) -> Self {
524 AssistantContent::Text { text: s }
525 }
526}
527
528impl FromStr for AssistantContent {
529 type Err = Infallible;
530
531 fn from_str(s: &str) -> Result<Self, Self::Err> {
532 Ok(AssistantContent::Text {
533 text: s.to_string(),
534 })
535 }
536}
537impl From<String> for SystemContent {
538 fn from(s: String) -> Self {
539 SystemContent {
540 r#type: SystemContentType::default(),
541 text: s,
542 }
543 }
544}
545
546impl FromStr for SystemContent {
547 type Err = Infallible;
548
549 fn from_str(s: &str) -> Result<Self, Self::Err> {
550 Ok(SystemContent {
551 r#type: SystemContentType::default(),
552 text: s.to_string(),
553 })
554 }
555}
556
/// Body of a successful chat-completions response, as deserialized from the provider.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Response identifier.
    pub id: String,
    /// Object type string reported by the provider.
    pub object: String,
    /// Creation timestamp reported by the provider (presumably Unix seconds — confirm).
    pub created: u64,
    /// Model name reported by the provider.
    pub model: String,
    /// Optional system fingerprint.
    pub system_fingerprint: Option<String>,
    /// Returned choices; only the first is used when converting to the generic response.
    pub choices: Vec<Choice>,
    /// Token usage, when the provider reports it.
    pub usage: Option<Usage>,
}
567
568impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
569 type Error = CompletionError;
570
571 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
572 let choice = response.choices.first().ok_or_else(|| {
573 CompletionError::ResponseError("Response contained no choices".to_owned())
574 })?;
575
576 let content = match &choice.message {
577 Message::Assistant {
578 content,
579 tool_calls,
580 ..
581 } => {
582 let mut content = content
583 .iter()
584 .filter_map(|c| {
585 let s = match c {
586 AssistantContent::Text { text } => text,
587 AssistantContent::Refusal { refusal } => refusal,
588 };
589 if s.is_empty() {
590 None
591 } else {
592 Some(completion::AssistantContent::text(s))
593 }
594 })
595 .collect::<Vec<_>>();
596
597 content.extend(
598 tool_calls
599 .iter()
600 .map(|call| {
601 completion::AssistantContent::tool_call(
602 &call.id,
603 &call.function.name,
604 call.function.arguments.clone(),
605 )
606 })
607 .collect::<Vec<_>>(),
608 );
609 Ok(content)
610 }
611 _ => Err(CompletionError::ResponseError(
612 "Response did not contain a valid message or tool call".into(),
613 )),
614 }?;
615
616 let choice = OneOrMany::many(content).map_err(|_| {
617 CompletionError::ResponseError(
618 "Response contained no message or tool call (empty)".to_owned(),
619 )
620 })?;
621
622 let usage = response
623 .usage
624 .as_ref()
625 .map(|usage| completion::Usage {
626 input_tokens: usage.prompt_tokens as u64,
627 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
628 total_tokens: usage.total_tokens as u64,
629 })
630 .unwrap_or_default();
631
632 Ok(completion::CompletionResponse {
633 choice,
634 usage,
635 raw_response: response,
636 })
637 }
638}
639
/// One entry of [`CompletionResponse::choices`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Choice {
    /// Position of this choice in the response.
    pub index: usize,
    /// The message produced for this choice.
    pub message: Message,
    /// Log-probability data, kept as raw JSON when present.
    pub logprobs: Option<serde_json::Value>,
    /// Finish reason string reported by the provider.
    pub finish_reason: String,
}
647
/// Token counts reported with a completion response.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Number of tokens in the prompt.
    pub prompt_tokens: usize,
    /// Total number of tokens reported for the request.
    pub total_tokens: usize,
}
653
654impl Usage {
655 pub fn new() -> Self {
656 Self {
657 prompt_tokens: 0,
658 total_tokens: 0,
659 }
660 }
661}
662
663impl Default for Usage {
664 fn default() -> Self {
665 Self::new()
666 }
667}
668
669impl fmt::Display for Usage {
670 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
671 let Usage {
672 prompt_tokens,
673 total_tokens,
674 } = self;
675 write!(
676 f,
677 "Prompt tokens: {prompt_tokens} Total tokens: {total_tokens}"
678 )
679 }
680}
681
/// A completion model handle: a client plus the model name sent with each request.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send requests.
    pub(crate) client: Client,
    /// Model name placed in the request's `model` field.
    pub model: String,
}
688
689impl CompletionModel {
690 pub fn new(client: Client, model: &str) -> Self {
691 Self {
692 client,
693 model: model.to_string(),
694 }
695 }
696
697 pub fn into_agent_builder(self) -> crate::agent::AgentBuilder<Self> {
698 crate::agent::AgentBuilder::new(self)
699 }
700
701 pub(crate) fn create_completion_request(
702 &self,
703 completion_request: CompletionRequest,
704 ) -> Result<Value, CompletionError> {
705 let mut partial_history = vec![];
707 if let Some(docs) = completion_request.normalized_documents() {
708 partial_history.push(docs);
709 }
710 partial_history.extend(completion_request.chat_history);
711
712 let mut full_history: Vec<Message> = completion_request
714 .preamble
715 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
716
717 full_history.extend(
719 partial_history
720 .into_iter()
721 .map(message::Message::try_into)
722 .collect::<Result<Vec<Vec<Message>>, _>>()?
723 .into_iter()
724 .flatten()
725 .collect::<Vec<_>>(),
726 );
727
728 let request = if completion_request.tools.is_empty() {
729 serde_json::json!({
730 "model": self.model,
731 "messages": full_history,
732
733 })
734 } else {
735 json!({
736 "model": self.model,
737 "messages": full_history,
738 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
739 "tool_choice": "auto",
740 })
741 };
742
743 let request = if let Some(temperature) = completion_request.temperature {
746 json_utils::merge(
747 request,
748 json!({
749 "temperature": temperature,
750 }),
751 )
752 } else {
753 request
754 };
755
756 let request = if let Some(params) = completion_request.additional_params {
757 json_utils::merge(request, params)
758 } else {
759 request
760 };
761
762 Ok(request)
763 }
764}
765
766impl completion::CompletionModel for CompletionModel {
767 type Response = CompletionResponse;
768 type StreamingResponse = StreamingCompletionResponse;
769
770 #[cfg_attr(feature = "worker", worker::send)]
771 async fn completion(
772 &self,
773 completion_request: CompletionRequest,
774 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
775 let request = self.create_completion_request(completion_request)?;
776
777 tracing::debug!(
778 "OpenAI request: {request}",
779 request = serde_json::to_string_pretty(&request).unwrap()
780 );
781
782 let response = self
783 .client
784 .post("/chat/completions")
785 .json(&request)
786 .send()
787 .await?;
788
789 if response.status().is_success() {
790 let t = response.text().await?;
791 tracing::debug!(target: "rig", "OpenAI completion error: {}", t);
792
793 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
794 ApiResponse::Ok(response) => {
795 tracing::info!(target: "rig",
796 "OpenAI completion token usage: {:?}",
797 response.usage.clone().map(|usage| format!("{}", usage.total_tokens)).unwrap_or("N/A".to_string())
798 );
799 response.try_into()
800 }
801 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
802 }
803 } else {
804 Err(CompletionError::ProviderError(response.text().await?))
805 }
806 }
807
808 #[cfg_attr(feature = "worker", worker::send)]
809 async fn stream(
810 &self,
811 request: CompletionRequest,
812 ) -> Result<
813 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
814 CompletionError,
815 > {
816 CompletionModel::stream(self, request).await
817 }
818}