1use super::{ApiErrorResponse, ApiResponse, Client, Usage};
6use crate::completion::{CompletionError, CompletionRequest};
7use crate::message::{AudioMediaType, ImageDetail};
8use crate::one_or_many::string_or_one_or_many;
9use crate::{completion, json_utils, message, OneOrMany};
10use serde::{Deserialize, Serialize};
11use serde_json::{json, Value};
12use std::convert::Infallible;
13use std::str::FromStr;
14
/// Model identifier string `o4-mini-2025-04-16`.
pub const O4_MINI_2025_04_16: &str = "o4-mini-2025-04-16";
/// Model identifier string `o4-mini`.
pub const O4_MINI: &str = "o4-mini";
/// Model identifier string `o3`.
pub const O3: &str = "o3";
/// Model identifier string `o3-mini`.
pub const O3_MINI: &str = "o3-mini";
/// Model identifier string `o3-mini-2025-01-31`.
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
/// Model identifier string `o1-pro`.
pub const O1_PRO: &str = "o1-pro";
/// Model identifier string `o1`.
pub const O1: &str = "o1";
/// Model identifier string `o1-2024-12-17`.
pub const O1_2024_12_17: &str = "o1-2024-12-17";
/// Model identifier string `o1-preview`.
pub const O1_PREVIEW: &str = "o1-preview";
/// Model identifier string `o1-preview-2024-09-12`.
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// Model identifier string `o1-mini`.
pub const O1_MINI: &str = "o1-mini";
/// Model identifier string `o1-mini-2024-09-12`.
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
39
/// Model identifier string `gpt-4.1-mini`.
pub const GPT_4_1_MINI: &str = "gpt-4.1-mini";
/// Model identifier string `gpt-4.1-nano`.
pub const GPT_4_1_NANO: &str = "gpt-4.1-nano";
/// Model identifier string `gpt-4.1-2025-04-14`.
pub const GPT_4_1_2025_04_14: &str = "gpt-4.1-2025-04-14";
/// Model identifier string `gpt-4.1`.
pub const GPT_4_1: &str = "gpt-4.1";
/// Model identifier string `gpt-4.5-preview`.
pub const GPT_4_5_PREVIEW: &str = "gpt-4.5-preview";
/// Model identifier string `gpt-4.5-preview-2025-02-27`.
pub const GPT_4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27";
/// Model identifier string `gpt-4o-2024-11-20`.
pub const GPT_4O_2024_11_20: &str = "gpt-4o-2024-11-20";
/// Model identifier string `gpt-4o`.
pub const GPT_4O: &str = "gpt-4o";
/// Model identifier string `gpt-4o-mini`.
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// Model identifier string `gpt-4o-2024-05-13`.
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// Model identifier string `gpt-4-turbo`.
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// Model identifier string `gpt-4-turbo-2024-04-09`.
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// Model identifier string `gpt-4-turbo-preview`.
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// Model identifier string `gpt-4-0125-preview`.
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// Model identifier string `gpt-4-1106-preview`.
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// Model identifier string `gpt-4-vision-preview`.
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// Model identifier string `gpt-4-1106-vision-preview`.
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// Model identifier string `gpt-4`.
pub const GPT_4: &str = "gpt-4";
/// Model identifier string `gpt-4-0613`.
pub const GPT_4_0613: &str = "gpt-4-0613";
/// Model identifier string `gpt-4-32k`.
pub const GPT_4_32K: &str = "gpt-4-32k";
/// Model identifier string `gpt-4-32k-0613`.
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// Model identifier string `gpt-3.5-turbo`.
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// Model identifier string `gpt-3.5-turbo-0125`.
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// Model identifier string `gpt-3.5-turbo-1106`.
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// Model identifier string `gpt-3.5-turbo-instruct`.
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
90
91#[derive(Debug, Deserialize)]
92pub struct CompletionResponse {
93 pub id: String,
94 pub object: String,
95 pub created: u64,
96 pub model: String,
97 pub system_fingerprint: Option<String>,
98 pub choices: Vec<Choice>,
99 pub usage: Option<Usage>,
100}
101
102impl From<ApiErrorResponse> for CompletionError {
103 fn from(err: ApiErrorResponse) -> Self {
104 CompletionError::ProviderError(err.message)
105 }
106}
107
108impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
109 type Error = CompletionError;
110
111 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
112 let choice = response.choices.first().ok_or_else(|| {
113 CompletionError::ResponseError("Response contained no choices".to_owned())
114 })?;
115
116 let content = match &choice.message {
117 Message::Assistant {
118 content,
119 tool_calls,
120 ..
121 } => {
122 let mut content = content
123 .iter()
124 .filter_map(|c| {
125 let s = match c {
126 AssistantContent::Text { text } => text,
127 AssistantContent::Refusal { refusal } => refusal,
128 };
129 if s.is_empty() {
130 None
131 } else {
132 Some(completion::AssistantContent::text(s))
133 }
134 })
135 .collect::<Vec<_>>();
136
137 content.extend(
138 tool_calls
139 .iter()
140 .map(|call| {
141 completion::AssistantContent::tool_call(
142 &call.id,
143 &call.function.name,
144 call.function.arguments.clone(),
145 )
146 })
147 .collect::<Vec<_>>(),
148 );
149 Ok(content)
150 }
151 _ => Err(CompletionError::ResponseError(
152 "Response did not contain a valid message or tool call".into(),
153 )),
154 }?;
155
156 let choice = OneOrMany::many(content).map_err(|_| {
157 CompletionError::ResponseError(
158 "Response contained no message or tool call (empty)".to_owned(),
159 )
160 })?;
161
162 Ok(completion::CompletionResponse {
163 choice,
164 raw_response: response,
165 })
166 }
167}
168
169#[derive(Debug, Serialize, Deserialize)]
170pub struct Choice {
171 pub index: usize,
172 pub message: Message,
173 pub logprobs: Option<serde_json::Value>,
174 pub finish_reason: String,
175}
176
177#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
178#[serde(tag = "role", rename_all = "lowercase")]
179pub enum Message {
180 #[serde(alias = "developer")]
181 System {
182 #[serde(deserialize_with = "string_or_one_or_many")]
183 content: OneOrMany<SystemContent>,
184 #[serde(skip_serializing_if = "Option::is_none")]
185 name: Option<String>,
186 },
187 User {
188 #[serde(deserialize_with = "string_or_one_or_many")]
189 content: OneOrMany<UserContent>,
190 #[serde(skip_serializing_if = "Option::is_none")]
191 name: Option<String>,
192 },
193 Assistant {
194 #[serde(default, deserialize_with = "json_utils::string_or_vec")]
195 content: Vec<AssistantContent>,
196 #[serde(skip_serializing_if = "Option::is_none")]
197 refusal: Option<String>,
198 #[serde(skip_serializing_if = "Option::is_none")]
199 audio: Option<AudioAssistant>,
200 #[serde(skip_serializing_if = "Option::is_none")]
201 name: Option<String>,
202 #[serde(
203 default,
204 deserialize_with = "json_utils::null_or_vec",
205 skip_serializing_if = "Vec::is_empty"
206 )]
207 tool_calls: Vec<ToolCall>,
208 },
209 #[serde(rename = "tool")]
210 ToolResult {
211 tool_call_id: String,
212 content: OneOrMany<ToolResultContent>,
213 },
214}
215
216impl Message {
217 pub fn system(content: &str) -> Self {
218 Message::System {
219 content: OneOrMany::one(content.to_owned().into()),
220 name: None,
221 }
222 }
223}
224
225#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
226pub struct AudioAssistant {
227 id: String,
228}
229
230#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
231pub struct SystemContent {
232 #[serde(default)]
233 r#type: SystemContentType,
234 text: String,
235}
236
237#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
238#[serde(rename_all = "lowercase")]
239pub enum SystemContentType {
240 #[default]
241 Text,
242}
243
244#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
245#[serde(tag = "type", rename_all = "lowercase")]
246pub enum AssistantContent {
247 Text { text: String },
248 Refusal { refusal: String },
249}
250
251#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
252#[serde(tag = "type", rename_all = "lowercase")]
253pub enum UserContent {
254 Text {
255 text: String,
256 },
257 #[serde(rename = "image_url")]
258 Image {
259 image_url: ImageUrl,
260 },
261 Audio {
262 input_audio: InputAudio,
263 },
264}
265
266#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
267pub struct ImageUrl {
268 pub url: String,
269 #[serde(default)]
270 pub detail: ImageDetail,
271}
272
273#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
274pub struct InputAudio {
275 pub data: String,
276 pub format: AudioMediaType,
277}
278
279#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
280pub struct ToolResultContent {
281 #[serde(default)]
282 r#type: ToolResultContentType,
283 text: String,
284}
285
286#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
287#[serde(rename_all = "lowercase")]
288pub enum ToolResultContentType {
289 #[default]
290 Text,
291}
292
293impl FromStr for ToolResultContent {
294 type Err = Infallible;
295
296 fn from_str(s: &str) -> Result<Self, Self::Err> {
297 Ok(s.to_owned().into())
298 }
299}
300
301impl From<String> for ToolResultContent {
302 fn from(s: String) -> Self {
303 ToolResultContent {
304 r#type: ToolResultContentType::default(),
305 text: s,
306 }
307 }
308}
309
310#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
311pub struct ToolCall {
312 pub id: String,
313 #[serde(default)]
314 pub r#type: ToolType,
315 pub function: Function,
316}
317
318#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
319#[serde(rename_all = "lowercase")]
320pub enum ToolType {
321 #[default]
322 Function,
323}
324
325#[derive(Debug, Deserialize, Serialize, Clone)]
326pub struct ToolDefinition {
327 pub r#type: String,
328 pub function: completion::ToolDefinition,
329}
330
331impl From<completion::ToolDefinition> for ToolDefinition {
332 fn from(tool: completion::ToolDefinition) -> Self {
333 Self {
334 r#type: "function".into(),
335 function: tool,
336 }
337 }
338}
339
340#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
341pub struct Function {
342 pub name: String,
343 #[serde(with = "json_utils::stringified_json")]
344 pub arguments: serde_json::Value,
345}
346
347impl TryFrom<message::Message> for Vec<Message> {
348 type Error = message::MessageError;
349
350 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
351 match message {
352 message::Message::User { content } => {
353 let (tool_results, other_content): (Vec<_>, Vec<_>) = content
354 .into_iter()
355 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
356
357 if !tool_results.is_empty() {
360 tool_results
361 .into_iter()
362 .map(|content| match content {
363 message::UserContent::ToolResult(message::ToolResult {
364 id,
365 content,
366 }) => Ok::<_, message::MessageError>(Message::ToolResult {
367 tool_call_id: id,
368 content: content.try_map(|content| match content {
369 message::ToolResultContent::Text(message::Text { text }) => {
370 Ok(text.into())
371 }
372 _ => Err(message::MessageError::ConversionError(
373 "Tool result content does not support non-text".into(),
374 )),
375 })?,
376 }),
377 _ => unreachable!(),
378 })
379 .collect::<Result<Vec<_>, _>>()
380 } else {
381 let other_content = OneOrMany::many(other_content).expect(
382 "There must be other content here if there were no tool result content",
383 );
384
385 Ok(vec![Message::User {
386 content: other_content.map(|content| match content {
387 message::UserContent::Text(message::Text { text }) => {
388 UserContent::Text { text }
389 }
390 message::UserContent::Image(message::Image {
391 data, detail, ..
392 }) => UserContent::Image {
393 image_url: ImageUrl {
394 url: data,
395 detail: detail.unwrap_or_default(),
396 },
397 },
398 message::UserContent::Document(message::Document { data, .. }) => {
399 UserContent::Text { text: data }
400 }
401 message::UserContent::Audio(message::Audio {
402 data,
403 media_type,
404 ..
405 }) => UserContent::Audio {
406 input_audio: InputAudio {
407 data,
408 format: match media_type {
409 Some(media_type) => media_type,
410 None => AudioMediaType::MP3,
411 },
412 },
413 },
414 _ => unreachable!(),
415 }),
416 name: None,
417 }])
418 }
419 }
420 message::Message::Assistant { content } => {
421 let (text_content, tool_calls) = content.into_iter().fold(
422 (Vec::new(), Vec::new()),
423 |(mut texts, mut tools), content| {
424 match content {
425 message::AssistantContent::Text(text) => texts.push(text),
426 message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
427 }
428 (texts, tools)
429 },
430 );
431
432 Ok(vec![Message::Assistant {
435 content: text_content
436 .into_iter()
437 .map(|content| content.text.into())
438 .collect::<Vec<_>>(),
439 refusal: None,
440 audio: None,
441 name: None,
442 tool_calls: tool_calls
443 .into_iter()
444 .map(|tool_call| tool_call.into())
445 .collect::<Vec<_>>(),
446 }])
447 }
448 }
449 }
450}
451
452impl From<message::ToolCall> for ToolCall {
453 fn from(tool_call: message::ToolCall) -> Self {
454 Self {
455 id: tool_call.id,
456 r#type: ToolType::default(),
457 function: Function {
458 name: tool_call.function.name,
459 arguments: tool_call.function.arguments,
460 },
461 }
462 }
463}
464
465impl From<ToolCall> for message::ToolCall {
466 fn from(tool_call: ToolCall) -> Self {
467 Self {
468 id: tool_call.id,
469 function: message::ToolFunction {
470 name: tool_call.function.name,
471 arguments: tool_call.function.arguments,
472 },
473 }
474 }
475}
476
477impl TryFrom<Message> for message::Message {
478 type Error = message::MessageError;
479
480 fn try_from(message: Message) -> Result<Self, Self::Error> {
481 Ok(match message {
482 Message::User { content, .. } => message::Message::User {
483 content: content.map(|content| content.into()),
484 },
485 Message::Assistant {
486 content,
487 tool_calls,
488 ..
489 } => {
490 let mut content = content
491 .into_iter()
492 .map(|content| match content {
493 AssistantContent::Text { text } => message::AssistantContent::text(text),
494
495 AssistantContent::Refusal { refusal } => {
498 message::AssistantContent::text(refusal)
499 }
500 })
501 .collect::<Vec<_>>();
502
503 content.extend(
504 tool_calls
505 .into_iter()
506 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
507 .collect::<Result<Vec<_>, _>>()?,
508 );
509
510 message::Message::Assistant {
511 content: OneOrMany::many(content).map_err(|_| {
512 message::MessageError::ConversionError(
513 "Neither `content` nor `tool_calls` was provided to the Message"
514 .to_owned(),
515 )
516 })?,
517 }
518 }
519
520 Message::ToolResult {
521 tool_call_id,
522 content,
523 } => message::Message::User {
524 content: OneOrMany::one(message::UserContent::tool_result(
525 tool_call_id,
526 content.map(|content| message::ToolResultContent::text(content.text)),
527 )),
528 },
529
530 Message::System { content, .. } => message::Message::User {
533 content: content.map(|content| message::UserContent::text(content.text)),
534 },
535 })
536 }
537}
538
539impl From<UserContent> for message::UserContent {
540 fn from(content: UserContent) -> Self {
541 match content {
542 UserContent::Text { text } => message::UserContent::text(text),
543 UserContent::Image { image_url } => message::UserContent::image(
544 image_url.url,
545 Some(message::ContentFormat::default()),
546 None,
547 Some(image_url.detail),
548 ),
549 UserContent::Audio { input_audio } => message::UserContent::audio(
550 input_audio.data,
551 Some(message::ContentFormat::default()),
552 Some(input_audio.format),
553 ),
554 }
555 }
556}
557
558impl From<String> for UserContent {
559 fn from(s: String) -> Self {
560 UserContent::Text { text: s }
561 }
562}
563
564impl FromStr for UserContent {
565 type Err = Infallible;
566
567 fn from_str(s: &str) -> Result<Self, Self::Err> {
568 Ok(UserContent::Text {
569 text: s.to_string(),
570 })
571 }
572}
573
574impl From<String> for AssistantContent {
575 fn from(s: String) -> Self {
576 AssistantContent::Text { text: s }
577 }
578}
579
580impl FromStr for AssistantContent {
581 type Err = Infallible;
582
583 fn from_str(s: &str) -> Result<Self, Self::Err> {
584 Ok(AssistantContent::Text {
585 text: s.to_string(),
586 })
587 }
588}
589impl From<String> for SystemContent {
590 fn from(s: String) -> Self {
591 SystemContent {
592 r#type: SystemContentType::default(),
593 text: s,
594 }
595 }
596}
597
598impl FromStr for SystemContent {
599 type Err = Infallible;
600
601 fn from_str(s: &str) -> Result<Self, Self::Err> {
602 Ok(SystemContent {
603 r#type: SystemContentType::default(),
604 text: s.to_string(),
605 })
606 }
607}
608
609#[derive(Clone)]
610pub struct CompletionModel {
611 pub(crate) client: Client,
612 pub model: String,
614}
615
616impl CompletionModel {
617 pub fn new(client: Client, model: &str) -> Self {
618 Self {
619 client,
620 model: model.to_string(),
621 }
622 }
623
624 pub(crate) fn create_completion_request(
625 &self,
626 completion_request: CompletionRequest,
627 ) -> Result<Value, CompletionError> {
628 let mut partial_history = vec![];
630 if let Some(docs) = completion_request.normalized_documents() {
631 partial_history.push(docs);
632 }
633 partial_history.extend(completion_request.chat_history);
634
635 let mut full_history: Vec<Message> = completion_request
637 .preamble
638 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
639
640 full_history.extend(
642 partial_history
643 .into_iter()
644 .map(message::Message::try_into)
645 .collect::<Result<Vec<Vec<Message>>, _>>()?
646 .into_iter()
647 .flatten()
648 .collect::<Vec<_>>(),
649 );
650
651 let request = if completion_request.tools.is_empty() {
652 json!({
653 "model": self.model,
654 "messages": full_history,
655
656 })
657 } else {
658 json!({
659 "model": self.model,
660 "messages": full_history,
661 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
662 "tool_choice": "auto",
663 })
664 };
665
666 let request = if let Some(temperature) = completion_request.temperature {
669 json_utils::merge(
670 request,
671 json!({
672 "temperature": temperature,
673 }),
674 )
675 } else {
676 request
677 };
678
679 let request = if let Some(params) = completion_request.additional_params {
680 json_utils::merge(request, params)
681 } else {
682 request
683 };
684
685 Ok(request)
686 }
687}
688
689impl completion::CompletionModel for CompletionModel {
690 type Response = CompletionResponse;
691
692 #[cfg_attr(feature = "worker", worker::send)]
693 async fn completion(
694 &self,
695 completion_request: CompletionRequest,
696 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
697 let request = self.create_completion_request(completion_request)?;
698
699 let response = self
700 .client
701 .post("/chat/completions")
702 .json(&request)
703 .send()
704 .await?;
705
706 if response.status().is_success() {
707 let t = response.text().await?;
708 tracing::debug!(target: "rig", "OpenAI completion error: {}", t);
709
710 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
711 ApiResponse::Ok(response) => {
712 tracing::info!(target: "rig",
713 "OpenAI completion token usage: {:?}",
714 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
715 );
716 response.try_into()
717 }
718 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
719 }
720 } else {
721 Err(CompletionError::ProviderError(response.text().await?))
722 }
723 }
724}