1use super::{Client, responses_api::streaming::StreamingCompletionResponse};
11use super::{ImageUrl, InputAudio, SystemContent};
12use crate::completion::CompletionError;
13use crate::json_utils;
14use crate::message::{AudioMediaType, Document, MessageError, Text};
15use crate::one_or_many::string_or_one_or_many;
16
17use crate::{OneOrMany, completion, message};
18use serde::{Deserialize, Serialize};
19use serde_json::{Map, Value};
20
21use std::convert::Infallible;
22use std::ops::Add;
23use std::str::FromStr;
24
25pub mod streaming;
26
27#[derive(Debug, Deserialize, Serialize, Clone)]
30pub struct CompletionRequest {
31 pub input: OneOrMany<InputItem>,
33 pub model: String,
35 #[serde(skip_serializing_if = "Option::is_none")]
37 pub instructions: Option<String>,
38 #[serde(skip_serializing_if = "Option::is_none")]
40 pub max_output_tokens: Option<u64>,
41 #[serde(skip_serializing_if = "Option::is_none")]
43 pub stream: Option<bool>,
44 #[serde(skip_serializing_if = "Option::is_none")]
46 pub temperature: Option<f64>,
47 #[serde(skip_serializing_if = "Vec::is_empty")]
51 pub tools: Vec<ResponsesToolDefinition>,
52 #[serde(flatten)]
54 pub additional_parameters: AdditionalParameters,
55}
56
57impl CompletionRequest {
58 pub fn with_structured_outputs<S>(mut self, schema_name: S, schema: serde_json::Value) -> Self
59 where
60 S: Into<String>,
61 {
62 self.additional_parameters.text = Some(TextConfig::structured_output(schema_name, schema));
63
64 self
65 }
66
67 pub fn with_reasoning(mut self, reasoning: Reasoning) -> Self {
68 self.additional_parameters.reasoning = Some(reasoning);
69
70 self
71 }
72}
73
74#[derive(Debug, Deserialize, Serialize, Clone)]
76pub struct InputItem {
77 #[serde(skip_serializing_if = "Option::is_none")]
81 role: Option<Role>,
82 #[serde(flatten)]
84 input: InputContent,
85}
86
87#[derive(Debug, Deserialize, Serialize, Clone)]
89#[serde(rename_all = "lowercase")]
90pub enum Role {
91 User,
92 Assistant,
93 System,
94}
95
96#[derive(Debug, Deserialize, Serialize, Clone)]
98#[serde(tag = "type", rename_all = "snake_case")]
99pub enum InputContent {
100 Message(Message),
101 Reasoning(OpenAIReasoning),
102 FunctionCall(OutputFunctionCall),
103 FunctionCallOutput(ToolResult),
104}
105
106#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
107pub struct OpenAIReasoning {
108 pub summary: Vec<ReasoningSummary>,
109 pub encrypted_content: Option<String>,
110 pub status: ToolStatus,
111}
112
113#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
114#[serde(tag = "type", rename_all = "snake_case")]
115pub enum ReasoningSummary {
116 SummaryText { text: String },
117}
118
119impl ReasoningSummary {
120 fn new(input: &str) -> Self {
121 Self::SummaryText {
122 text: input.to_string(),
123 }
124 }
125}
126
127#[derive(Debug, Deserialize, Serialize, Clone)]
129pub struct ToolResult {
130 call_id: String,
132 output: String,
134 status: ToolStatus,
136}
137
138impl From<Message> for InputItem {
139 fn from(value: Message) -> Self {
140 match value {
141 Message::User { .. } => Self {
142 role: Some(Role::User),
143 input: InputContent::Message(value),
144 },
145 Message::Assistant { .. } => Self {
146 role: Some(Role::Assistant),
147 input: InputContent::Message(value),
148 },
149 Message::System { .. } => Self {
150 role: Some(Role::System),
151 input: InputContent::Message(value),
152 },
153 Message::ToolResult {
154 tool_call_id,
155 output,
156 } => Self {
157 role: None,
158 input: InputContent::FunctionCallOutput(ToolResult {
159 call_id: tool_call_id,
160 output,
161 status: ToolStatus::Completed,
162 }),
163 },
164 }
165 }
166}
167
168impl TryFrom<crate::completion::Message> for Vec<InputItem> {
169 type Error = CompletionError;
170
171 fn try_from(value: crate::completion::Message) -> Result<Self, Self::Error> {
172 match value {
173 crate::completion::Message::User { content } => {
174 let mut items = Vec::new();
175
176 for user_content in content {
177 match user_content {
178 crate::message::UserContent::Text(Text { text }) => {
179 items.push(InputItem {
180 role: Some(Role::User),
181 input: InputContent::Message(Message::User {
182 content: OneOrMany::one(UserContent::InputText { text }),
183 name: None,
184 }),
185 });
186 }
187 crate::message::UserContent::ToolResult(
188 crate::completion::message::ToolResult {
189 call_id,
190 content: tool_content,
191 ..
192 },
193 ) => {
194 for tool_result_content in tool_content {
195 let crate::completion::message::ToolResultContent::Text(Text {
196 text,
197 }) = tool_result_content
198 else {
199 return Err(CompletionError::ProviderError(
200 "This thing only supports text!".to_string(),
201 ));
202 };
203 items.push(InputItem {
205 role: None,
206 input: InputContent::FunctionCallOutput(ToolResult {
207 call_id: call_id
208 .clone()
209 .expect("The call ID of this tool should exist!"),
210 output: text,
211 status: ToolStatus::Completed,
212 }),
213 });
214 }
215 }
216 crate::message::UserContent::Document(Document { data, .. }) => {
218 items.push(InputItem {
219 role: Some(Role::User),
220 input: InputContent::Message(Message::User {
221 content: OneOrMany::one(UserContent::InputText { text: data }),
222 name: None,
223 }),
224 })
225 }
226 _ => {
227 return Err(CompletionError::ProviderError(
228 "This API only supports text and tool results at the moment"
229 .to_string(),
230 ));
231 }
232 }
233 }
234
235 Ok(items)
236 }
237 crate::completion::Message::Assistant { id, content } => {
238 let mut items = Vec::new();
239
240 for assistant_content in content {
241 match assistant_content {
242 crate::message::AssistantContent::Text(Text { text }) => {
243 let id = id.as_ref().unwrap_or(&String::default()).clone();
244 items.push(InputItem {
245 role: Some(Role::Assistant),
246 input: InputContent::Message(Message::Assistant {
247 content: OneOrMany::one(AssistantContentType::Text(
248 AssistantContent::OutputText(Text { text }),
249 )),
250 id,
251 name: None,
252 status: ToolStatus::Completed,
253 }),
254 });
255 }
256 crate::message::AssistantContent::ToolCall(crate::message::ToolCall {
257 id: tool_id,
258 call_id,
259 function,
260 }) => {
261 items.push(InputItem {
262 role: None,
263 input: InputContent::FunctionCall(OutputFunctionCall {
264 arguments: function.arguments,
265 call_id: call_id.expect("The tool call ID should exist!"),
266 id: tool_id,
267 name: function.name,
268 status: ToolStatus::Completed,
269 }),
270 });
271 }
272 crate::message::AssistantContent::Reasoning(
273 crate::message::Reasoning { reasoning },
274 ) => {
275 items.push(InputItem {
276 role: Some(Role::Assistant),
277 input: InputContent::Reasoning(OpenAIReasoning {
278 summary: vec![ReasoningSummary::new(&reasoning)],
279 encrypted_content: None,
280 status: ToolStatus::Completed,
281 }),
282 });
283 }
284 }
285 }
286
287 Ok(items)
288 }
289 }
290 }
291}
292
293#[derive(Debug, Deserialize, Serialize, Clone)]
295pub struct ResponsesToolDefinition {
296 pub name: String,
298 pub parameters: serde_json::Value,
300 pub strict: bool,
302 #[serde(rename = "type")]
304 pub kind: String,
305 pub description: String,
307}
308
309impl From<completion::ToolDefinition> for ResponsesToolDefinition {
310 fn from(value: completion::ToolDefinition) -> Self {
311 let completion::ToolDefinition {
312 name,
313 mut parameters,
314 description,
315 } = value;
316
317 let parameters = parameters
318 .as_object_mut()
319 .expect("parameters should be a JSON object");
320 parameters.insert(
321 "additionalProperties".to_string(),
322 serde_json::Value::Bool(false),
323 );
324
325 let parameters = serde_json::Value::Object(parameters.clone());
326
327 Self {
328 name,
329 parameters,
330 description,
331 kind: "function".to_string(),
332 strict: true,
333 }
334 }
335}
336
337#[derive(Clone, Debug, Serialize, Deserialize)]
340pub struct ResponsesUsage {
341 pub input_tokens: u64,
343 #[serde(skip_serializing_if = "Option::is_none")]
345 pub input_tokens_details: Option<InputTokensDetails>,
346 pub output_tokens: u64,
348 pub output_tokens_details: OutputTokensDetails,
350 pub total_tokens: u64,
352}
353
354impl ResponsesUsage {
355 pub(crate) fn new() -> Self {
357 Self {
358 input_tokens: 0,
359 input_tokens_details: Some(InputTokensDetails::new()),
360 output_tokens: 0,
361 output_tokens_details: OutputTokensDetails::new(),
362 total_tokens: 0,
363 }
364 }
365}
366
367impl Add for ResponsesUsage {
368 type Output = Self;
369
370 fn add(self, rhs: Self) -> Self::Output {
371 let input_tokens = self.input_tokens + rhs.input_tokens;
372 let input_tokens_details = self.input_tokens_details.map(|lhs| {
373 if let Some(tokens) = rhs.input_tokens_details {
374 lhs + tokens
375 } else {
376 lhs
377 }
378 });
379 let output_tokens = self.output_tokens + rhs.output_tokens;
380 let output_tokens_details = self.output_tokens_details + rhs.output_tokens_details;
381 let total_tokens = self.total_tokens + rhs.total_tokens;
382 Self {
383 input_tokens,
384 input_tokens_details,
385 output_tokens,
386 output_tokens_details,
387 total_tokens,
388 }
389 }
390}
391
392#[derive(Clone, Debug, Serialize, Deserialize)]
394pub struct InputTokensDetails {
395 pub cached_tokens: u64,
397}
398
399impl InputTokensDetails {
400 pub(crate) fn new() -> Self {
401 Self { cached_tokens: 0 }
402 }
403}
404
405impl Add for InputTokensDetails {
406 type Output = Self;
407 fn add(self, rhs: Self) -> Self::Output {
408 Self {
409 cached_tokens: self.cached_tokens + rhs.cached_tokens,
410 }
411 }
412}
413
414#[derive(Clone, Debug, Serialize, Deserialize)]
416pub struct OutputTokensDetails {
417 pub reasoning_tokens: u64,
419}
420
421impl OutputTokensDetails {
422 pub(crate) fn new() -> Self {
423 Self {
424 reasoning_tokens: 0,
425 }
426 }
427}
428
429impl Add for OutputTokensDetails {
430 type Output = Self;
431 fn add(self, rhs: Self) -> Self::Output {
432 Self {
433 reasoning_tokens: self.reasoning_tokens + rhs.reasoning_tokens,
434 }
435 }
436}
437
438#[derive(Clone, Debug, Default, Serialize, Deserialize)]
440pub struct IncompleteDetailsReason {
441 pub reason: String,
443}
444
445#[derive(Clone, Debug, Default, Serialize, Deserialize)]
447pub struct ResponseError {
448 pub code: String,
450 pub message: String,
452}
453
454#[derive(Clone, Debug, Deserialize, Serialize)]
456#[serde(rename_all = "snake_case")]
457pub enum ResponseObject {
458 Response,
459}
460
461#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
463#[serde(rename_all = "snake_case")]
464pub enum ResponseStatus {
465 InProgress,
466 Completed,
467 Failed,
468 Cancelled,
469 Queued,
470 Incomplete,
471}
472
473impl TryFrom<(String, crate::completion::CompletionRequest)> for CompletionRequest {
475 type Error = CompletionError;
476 fn try_from(
477 (model, req): (String, crate::completion::CompletionRequest),
478 ) -> Result<Self, Self::Error> {
479 let input = {
480 let mut partial_history = vec![];
481 if let Some(docs) = req.normalized_documents() {
482 partial_history.push(docs);
483 }
484 partial_history.extend(req.chat_history);
485
486 let mut full_history: Vec<InputItem> = Vec::new();
488
489 full_history.extend(
491 partial_history
492 .into_iter()
493 .map(|x| <Vec<InputItem>>::try_from(x).unwrap())
494 .collect::<Vec<Vec<InputItem>>>()
495 .into_iter()
496 .flatten()
497 .collect::<Vec<InputItem>>(),
498 );
499
500 full_history
501 };
502
503 let input = OneOrMany::many(input)
504 .expect("This should never panic - if it does, please file a bug report");
505
506 let stream = req
507 .additional_params
508 .clone()
509 .unwrap_or(Value::Null)
510 .as_bool();
511
512 let additional_parameters = if let Some(map) = req.additional_params {
513 serde_json::from_value::<AdditionalParameters>(map).expect("Converting additional parameters to AdditionalParameters should never fail as every field is an Option")
514 } else {
515 AdditionalParameters::default()
517 };
518
519 Ok(Self {
520 input,
521 model,
522 instructions: req.preamble,
523 max_output_tokens: req.max_tokens,
524 stream,
525 tools: req
526 .tools
527 .into_iter()
528 .map(ResponsesToolDefinition::from)
529 .collect(),
530 temperature: req.temperature,
531 additional_parameters,
532 })
533 }
534}
535
536#[derive(Clone)]
538pub struct ResponsesCompletionModel {
539 pub(crate) client: Client,
541 pub model: String,
543}
544
545impl ResponsesCompletionModel {
546 pub fn new(client: Client, model: &str) -> Self {
548 Self {
549 client,
550 model: model.to_string(),
551 }
552 }
553
554 pub fn completions_api(self) -> crate::providers::openai::completion::CompletionModel {
556 crate::providers::openai::completion::CompletionModel::new(self.client, &self.model)
557 }
558
559 pub(crate) fn create_completion_request(
561 &self,
562 completion_request: crate::completion::CompletionRequest,
563 ) -> Result<CompletionRequest, CompletionError> {
564 let req = CompletionRequest::try_from((self.model.clone(), completion_request))?;
565
566 Ok(req)
567 }
568}
569
570#[derive(Clone, Debug, Serialize, Deserialize)]
572pub struct CompletionResponse {
573 pub id: String,
575 pub object: ResponseObject,
577 pub created_at: u64,
579 pub status: ResponseStatus,
581 pub error: Option<ResponseError>,
583 pub incomplete_details: Option<IncompleteDetailsReason>,
585 pub instructions: Option<String>,
587 pub max_output_tokens: Option<u64>,
589 pub model: String,
591 pub usage: Option<ResponsesUsage>,
593 pub output: Vec<Output>,
595 pub tools: Vec<ResponsesToolDefinition>,
597 #[serde(flatten)]
599 pub additional_parameters: AdditionalParameters,
600}
601
602#[derive(Clone, Debug, Deserialize, Serialize, Default)]
605pub struct AdditionalParameters {
606 #[serde(skip_serializing_if = "Option::is_none")]
608 pub background: Option<bool>,
609 #[serde(skip_serializing_if = "Option::is_none")]
611 pub text: Option<TextConfig>,
612 #[serde(skip_serializing_if = "Option::is_none")]
614 pub include: Option<Vec<Include>>,
615 #[serde(skip_serializing_if = "Option::is_none")]
617 pub top_p: Option<f64>,
618 #[serde(skip_serializing_if = "Option::is_none")]
620 pub truncation: Option<TruncationStrategy>,
621 #[serde(skip_serializing_if = "Option::is_none")]
623 pub user: Option<String>,
624 #[serde(skip_serializing_if = "Map::is_empty", default)]
626 pub metadata: serde_json::Map<String, serde_json::Value>,
627 #[serde(skip_serializing_if = "Option::is_none")]
629 pub parallel_tool_calls: Option<bool>,
630 #[serde(skip_serializing_if = "Option::is_none")]
632 pub previous_response_id: Option<String>,
633 #[serde(skip_serializing_if = "Option::is_none")]
635 pub reasoning: Option<Reasoning>,
636 #[serde(skip_serializing_if = "Option::is_none")]
638 pub service_tier: Option<OpenAIServiceTier>,
639 #[serde(skip_serializing_if = "Option::is_none")]
641 pub store: Option<bool>,
642}
643
644impl AdditionalParameters {
645 pub fn to_json(self) -> serde_json::Value {
646 serde_json::to_value(self).expect("this should never fail since a struct that impls Deserialize will always be valid JSON")
647 }
648}
649
650#[derive(Clone, Debug, Default, Serialize, Deserialize)]
654#[serde(rename_all = "snake_case")]
655pub enum TruncationStrategy {
656 Auto,
657 #[default]
658 Disabled,
659}
660
661#[derive(Clone, Debug, Serialize, Deserialize)]
664pub struct TextConfig {
665 pub format: TextFormat,
666}
667
668impl TextConfig {
669 pub(crate) fn structured_output<S>(name: S, schema: serde_json::Value) -> Self
670 where
671 S: Into<String>,
672 {
673 Self {
674 format: TextFormat::JsonSchema(StructuredOutputsInput {
675 name: name.into(),
676 schema,
677 strict: true,
678 }),
679 }
680 }
681}
682
683#[derive(Clone, Debug, Serialize, Deserialize, Default)]
686#[serde(tag = "type")]
687#[serde(rename_all = "snake_case")]
688pub enum TextFormat {
689 JsonSchema(StructuredOutputsInput),
690 #[default]
691 Text,
692}
693
694#[derive(Clone, Debug, Serialize, Deserialize)]
696pub struct StructuredOutputsInput {
697 pub name: String,
699 pub schema: serde_json::Value,
701 pub strict: bool,
703}
704
705#[derive(Clone, Debug, Default, Serialize, Deserialize)]
707pub struct Reasoning {
708 pub effort: Option<ReasoningEffort>,
710 #[serde(skip_serializing_if = "Option::is_none")]
712 pub summary: Option<ReasoningSummaryLevel>,
713}
714
715impl Reasoning {
716 pub fn new() -> Self {
718 Self {
719 effort: None,
720 summary: None,
721 }
722 }
723
724 pub fn with_effort(mut self, reasoning_effort: ReasoningEffort) -> Self {
726 self.effort = Some(reasoning_effort);
727
728 self
729 }
730
731 pub fn with_summary_level(mut self, reasoning_summary_level: ReasoningSummaryLevel) -> Self {
733 self.summary = Some(reasoning_summary_level);
734
735 self
736 }
737}
738
739#[derive(Clone, Debug, Default, Serialize, Deserialize)]
741#[serde(rename_all = "snake_case")]
742pub enum OpenAIServiceTier {
743 #[default]
744 Auto,
745 Default,
746 Flex,
747}
748
749#[derive(Clone, Debug, Default, Serialize, Deserialize)]
751#[serde(rename_all = "snake_case")]
752pub enum ReasoningEffort {
753 Low,
754 #[default]
755 Medium,
756 High,
757}
758
759#[derive(Clone, Debug, Default, Serialize, Deserialize)]
761#[serde(rename_all = "snake_case")]
762pub enum ReasoningSummaryLevel {
763 #[default]
764 Auto,
765 Concise,
766 Detailed,
767}
768
769#[derive(Clone, Debug, Deserialize, Serialize)]
772pub enum Include {
773 #[serde(rename = "file_search_call.results")]
774 FileSearchCallResults,
775 #[serde(rename = "message.input_image.image_url")]
776 MessageInputImageImageUrl,
777 #[serde(rename = "computer_call.output.image_url")]
778 ComputerCallOutputOutputImageUrl,
779 #[serde(rename = "reasoning.encrypted_content")]
780 ReasoningEncryptedContent,
781 #[serde(rename = "code_interpreter_call.outputs")]
782 CodeInterpreterCallOutputs,
783}
784
785#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
787#[serde(tag = "type")]
788#[serde(rename_all = "snake_case")]
789pub enum Output {
790 Message(OutputMessage),
791 #[serde(alias = "function_call")]
792 FunctionCall(OutputFunctionCall),
793 Reasoning {
794 summary: Vec<ReasoningSummary>,
795 },
796}
797
798impl From<Output> for Vec<completion::AssistantContent> {
799 fn from(value: Output) -> Self {
800 let res: Vec<completion::AssistantContent> = match value {
801 Output::Message(OutputMessage { content, .. }) => content
802 .into_iter()
803 .map(completion::AssistantContent::from)
804 .collect(),
805 Output::FunctionCall(OutputFunctionCall {
806 id,
807 arguments,
808 call_id,
809 name,
810 ..
811 }) => vec![completion::AssistantContent::tool_call_with_call_id(
812 id, call_id, name, arguments,
813 )],
814 Output::Reasoning { summary } => {
815 let text_joined = summary
816 .into_iter()
817 .map(|x| {
818 let ReasoningSummary::SummaryText { text } = x;
819 text
820 })
821 .collect::<Vec<String>>()
822 .join("\n");
823 vec![completion::AssistantContent::Reasoning(
824 crate::message::Reasoning {
825 reasoning: text_joined,
826 },
827 )]
828 }
829 };
830
831 res
832 }
833}
834
835#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
836pub struct OutputReasoning {
837 id: String,
838 summary: Vec<ReasoningSummary>,
839 status: ToolStatus,
840}
841
842#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
844pub struct OutputFunctionCall {
845 pub id: String,
846 #[serde(with = "json_utils::stringified_json")]
847 pub arguments: serde_json::Value,
848 pub call_id: String,
849 pub name: String,
850 pub status: ToolStatus,
851}
852
853#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
855#[serde(rename_all = "snake_case")]
856pub enum ToolStatus {
857 InProgress,
858 Completed,
859 Incomplete,
860}
861
862#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
864pub struct OutputMessage {
865 pub id: String,
867 pub role: OutputRole,
869 pub status: ResponseStatus,
871 pub content: Vec<AssistantContent>,
873}
874
875#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
877#[serde(rename_all = "snake_case")]
878pub enum OutputRole {
879 Assistant,
880}
881
882impl completion::CompletionModel for ResponsesCompletionModel {
883 type Response = CompletionResponse;
884 type StreamingResponse = StreamingCompletionResponse;
885
886 #[cfg_attr(feature = "worker", worker::send)]
887 async fn completion(
888 &self,
889 completion_request: crate::completion::CompletionRequest,
890 ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
891 let request = self.create_completion_request(completion_request)?;
892 let request = serde_json::to_value(request)?;
893
894 tracing::debug!("OpenAI input: {}", serde_json::to_string_pretty(&request)?);
895
896 let response = self.client.post("/responses").json(&request).send().await?;
897
898 if response.status().is_success() {
899 let t = response.text().await?;
900 tracing::debug!(target: "rig", "OpenAI response: {}", t);
901
902 let response = serde_json::from_str::<Self::Response>(&t)?;
903 response.try_into()
904 } else {
905 Err(CompletionError::ProviderError(response.text().await?))
906 }
907 }
908
909 #[cfg_attr(feature = "worker", worker::send)]
910 async fn stream(
911 &self,
912 request: crate::completion::CompletionRequest,
913 ) -> Result<
914 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
915 CompletionError,
916 > {
917 Self::stream(self, request).await
918 }
919}
920
921impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
922 type Error = CompletionError;
923
924 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
925 if response.output.is_empty() {
926 return Err(CompletionError::ResponseError(
927 "Response contained no parts".to_owned(),
928 ));
929 }
930
931 let content: Vec<completion::AssistantContent> = response
932 .output
933 .iter()
934 .cloned()
935 .flat_map(<Vec<completion::AssistantContent>>::from)
936 .collect();
937
938 let choice = OneOrMany::many(content).map_err(|_| {
939 CompletionError::ResponseError(
940 "Response contained no message or tool call (empty)".to_owned(),
941 )
942 })?;
943
944 let usage = response
945 .usage
946 .as_ref()
947 .map(|usage| completion::Usage {
948 input_tokens: usage.input_tokens,
949 output_tokens: usage.output_tokens,
950 total_tokens: usage.total_tokens,
951 })
952 .unwrap_or_default();
953
954 Ok(completion::CompletionResponse {
955 choice,
956 usage,
957 raw_response: response,
958 })
959 }
960}
961
962#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
964#[serde(tag = "role", rename_all = "lowercase")]
965pub enum Message {
966 #[serde(alias = "developer")]
967 System {
968 #[serde(deserialize_with = "string_or_one_or_many")]
969 content: OneOrMany<SystemContent>,
970 #[serde(skip_serializing_if = "Option::is_none")]
971 name: Option<String>,
972 },
973 User {
974 #[serde(deserialize_with = "string_or_one_or_many")]
975 content: OneOrMany<UserContent>,
976 #[serde(skip_serializing_if = "Option::is_none")]
977 name: Option<String>,
978 },
979 Assistant {
980 content: OneOrMany<AssistantContentType>,
981 #[serde(skip_serializing_if = "String::is_empty")]
982 id: String,
983 #[serde(skip_serializing_if = "Option::is_none")]
984 name: Option<String>,
985 status: ToolStatus,
986 },
987 #[serde(rename = "tool")]
988 ToolResult {
989 tool_call_id: String,
990 output: String,
991 },
992}
993
994#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
996#[serde(rename_all = "lowercase")]
997pub enum ToolResultContentType {
998 #[default]
999 Text,
1000}
1001
1002impl Message {
1003 pub fn system(content: &str) -> Self {
1004 Message::System {
1005 content: OneOrMany::one(content.to_owned().into()),
1006 name: None,
1007 }
1008 }
1009}
1010
1011#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
1014#[serde(tag = "type", rename_all = "snake_case")]
1015pub enum AssistantContent {
1016 OutputText(Text),
1017 Refusal { refusal: String },
1018}
1019
1020impl From<AssistantContent> for completion::AssistantContent {
1021 fn from(value: AssistantContent) -> Self {
1022 match value {
1023 AssistantContent::Refusal { refusal } => {
1024 completion::AssistantContent::Text(Text { text: refusal })
1025 }
1026 AssistantContent::OutputText(Text { text }) => {
1027 completion::AssistantContent::Text(Text { text })
1028 }
1029 }
1030 }
1031}
1032
1033#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
1035#[serde(untagged)]
1036pub enum AssistantContentType {
1037 Text(AssistantContent),
1038 ToolCall(OutputFunctionCall),
1039 Reasoning(OpenAIReasoning),
1040}
1041
1042#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
1044#[serde(tag = "type", rename_all = "snake_case")]
1045pub enum UserContent {
1046 InputText {
1047 text: String,
1048 },
1049 #[serde(rename = "image_url")]
1050 Image {
1051 image_url: ImageUrl,
1052 },
1053 Audio {
1054 input_audio: InputAudio,
1055 },
1056 #[serde(rename = "tool")]
1057 ToolResult {
1058 tool_call_id: String,
1059 output: String,
1060 },
1061}
1062
1063impl TryFrom<message::Message> for Vec<Message> {
1064 type Error = message::MessageError;
1065
1066 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
1067 match message {
1068 message::Message::User { content } => {
1069 let (tool_results, other_content): (Vec<_>, Vec<_>) = content
1070 .into_iter()
1071 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
1072
1073 if !tool_results.is_empty() {
1076 tool_results
1077 .into_iter()
1078 .map(|content| match content {
1079 message::UserContent::ToolResult(message::ToolResult {
1080 call_id,
1081 content,
1082 ..
1083 }) => Ok::<_, message::MessageError>(Message::ToolResult {
1084 tool_call_id: call_id.expect("The tool call ID should exist"),
1085 output: {
1086 let res = content.first();
1087 match res {
1088 completion::message::ToolResultContent::Text(Text {
1089 text,
1090 }) => text,
1091 _ => return Err(MessageError::ConversionError("This API only currently supports text tool results".into()))
1092 }
1093 },
1094 }),
1095 _ => unreachable!(),
1096 })
1097 .collect::<Result<Vec<_>, _>>()
1098 } else {
1099 let other_content = OneOrMany::many(other_content).expect(
1100 "There must be other content here if there were no tool result content",
1101 );
1102
1103 Ok(vec![Message::User {
1104 content: other_content.map(|content| match content {
1105 message::UserContent::Text(message::Text { text }) => {
1106 UserContent::InputText { text }
1107 }
1108 message::UserContent::Image(message::Image {
1109 data, detail, ..
1110 }) => UserContent::Image {
1111 image_url: ImageUrl {
1112 url: data,
1113 detail: detail.unwrap_or_default(),
1114 },
1115 },
1116 message::UserContent::Document(message::Document { data, .. }) => {
1117 UserContent::InputText { text: data }
1118 }
1119 message::UserContent::Audio(message::Audio {
1120 data,
1121 media_type,
1122 ..
1123 }) => UserContent::Audio {
1124 input_audio: InputAudio {
1125 data,
1126 format: match media_type {
1127 Some(media_type) => media_type,
1128 None => AudioMediaType::MP3,
1129 },
1130 },
1131 },
1132 _ => unreachable!(),
1133 }),
1134 name: None,
1135 }])
1136 }
1137 }
1138 message::Message::Assistant { content, id } => {
1139 let assistant_message_id = id;
1140
1141 match content.first() {
1142 crate::message::AssistantContent::Text(Text { text }) => {
1143 Ok(vec![Message::Assistant {
1144 id: assistant_message_id
1145 .expect("The assistant message ID should exist"),
1146 status: ToolStatus::Completed,
1147 content: OneOrMany::one(AssistantContentType::Text(
1148 AssistantContent::OutputText(Text { text }),
1149 )),
1150 name: None,
1151 }])
1152 }
1153 crate::message::AssistantContent::ToolCall(crate::message::ToolCall {
1154 id,
1155 call_id,
1156 function,
1157 }) => Ok(vec![Message::Assistant {
1158 content: OneOrMany::one(AssistantContentType::ToolCall(
1159 OutputFunctionCall {
1160 call_id: call_id.expect("The call ID should exist"),
1161 arguments: function.arguments,
1162 id,
1163 name: function.name,
1164 status: ToolStatus::Completed,
1165 },
1166 )),
1167 id: assistant_message_id.expect("The assistant message ID should exist!"),
1168 name: None,
1169 status: ToolStatus::Completed,
1170 }]),
1171 crate::message::AssistantContent::Reasoning(crate::message::Reasoning {
1172 reasoning,
1173 }) => Ok(vec![Message::Assistant {
1174 content: OneOrMany::one(AssistantContentType::Reasoning(OpenAIReasoning {
1175 summary: vec![ReasoningSummary::new(&reasoning)],
1176 encrypted_content: None,
1177 status: ToolStatus::Completed,
1178 })),
1179 id: assistant_message_id.expect("The assistant message ID should exist!"),
1180 name: None,
1181 status: (ToolStatus::Completed),
1182 }]),
1183 }
1184 }
1185 }
1186 }
1187}
1188
1189impl FromStr for UserContent {
1190 type Err = Infallible;
1191
1192 fn from_str(s: &str) -> Result<Self, Self::Err> {
1193 Ok(UserContent::InputText {
1194 text: s.to_string(),
1195 })
1196 }
1197}