1use super::{Client, responses_api::streaming::StreamingCompletionResponse};
11use super::{ImageUrl, InputAudio, SystemContent};
12use crate::completion::CompletionError;
13use crate::json_utils;
14use crate::message::{AudioMediaType, MessageError, Text};
15use crate::one_or_many::string_or_one_or_many;
16
17use crate::{OneOrMany, completion, message};
18use serde::{Deserialize, Serialize};
19use serde_json::{Map, Value};
20
21use std::convert::Infallible;
22use std::ops::Add;
23use std::str::FromStr;
24
25pub mod streaming;
26
/// Request body sent to the OpenAI Responses API (`POST /responses`).
///
/// Normally built from a generic [`crate::completion::CompletionRequest`] through the
/// `TryFrom<(String, crate::completion::CompletionRequest)>` impl in this module.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct CompletionRequest {
    /// Conversation items: messages, function calls and function call outputs.
    pub input: OneOrMany<InputItem>,
    /// Model identifier.
    pub model: String,
    /// Top-level instructions; filled from the generic request's preamble.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub instructions: Option<String>,
    /// Upper bound on generated tokens; filled from the generic request's `max_tokens`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<u64>,
    /// Whether a streamed response is requested; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// Function tool definitions; omitted from the JSON when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ResponsesToolDefinition>,
    /// Extra optional parameters, flattened into this same JSON object.
    #[serde(flatten)]
    pub additional_parameters: AdditionalParameters,
}
56
57impl CompletionRequest {
58 pub fn with_structured_outputs<S>(mut self, schema_name: S, schema: serde_json::Value) -> Self
59 where
60 S: Into<String>,
61 {
62 self.additional_parameters.text = Some(TextConfig::structured_output(schema_name, schema));
63
64 self
65 }
66
67 pub fn with_reasoning(mut self, reasoning: Reasoning) -> Self {
68 self.additional_parameters.reasoning = Some(reasoning);
69
70 self
71 }
72}
73
/// One element of the `input` array of a Responses API request.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct InputItem {
    /// Author role. Conversions in this module leave it `None` (and therefore
    /// omitted) for function call and function call output items.
    #[serde(skip_serializing_if = "Option::is_none")]
    role: Option<Role>,
    /// Item payload; its `type` tag and fields are flattened into this object.
    #[serde(flatten)]
    input: InputContent,
}

/// Author role of an input item, serialized in lowercase.
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

/// Payload of an [`InputItem`], internally tagged by `type` (`message`,
/// `output_message`, `function_call`, `function_call_output`).
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
enum InputContent {
    Message(Message),
    OutputMessage(Message),
    FunctionCall(OutputFunctionCall),
    FunctionCallOutput(ToolResult),
}

/// Output of a function (tool) call sent back to the model.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolResult {
    /// Identifier of the function call this output answers.
    call_id: String,
    /// Textual tool output.
    output: String,
    /// Item status; conversions in this module always set `Completed`.
    status: ToolStatus,
}
116
117impl From<Message> for InputItem {
118 fn from(value: Message) -> Self {
119 match value {
120 Message::User { .. } => Self {
121 role: Some(Role::User),
122 input: InputContent::Message(value),
123 },
124 Message::Assistant { .. } => Self {
125 role: Some(Role::Assistant),
126 input: InputContent::OutputMessage(value),
127 },
128 Message::System { .. } => Self {
129 role: Some(Role::System),
130 input: InputContent::Message(value),
131 },
132 Message::ToolResult {
133 tool_call_id,
134 output,
135 } => Self {
136 role: None,
137 input: InputContent::FunctionCallOutput(ToolResult {
138 call_id: tool_call_id,
139 output,
140 status: ToolStatus::Completed,
141 }),
142 },
143 }
144 }
145}
146
147impl TryFrom<crate::completion::Message> for Vec<InputItem> {
148 type Error = CompletionError;
149
150 fn try_from(value: crate::completion::Message) -> Result<Self, Self::Error> {
151 match value {
152 crate::completion::Message::User { content } => {
153 let mut items = Vec::new();
154
155 for user_content in content {
156 match user_content {
157 crate::message::UserContent::Text(Text { text }) => {
158 items.push(InputItem {
159 role: Some(Role::User),
160 input: InputContent::Message(Message::User {
161 content: OneOrMany::one(UserContent::InputText { text }),
162 name: None,
163 }),
164 });
165 }
166 crate::message::UserContent::ToolResult(
167 crate::completion::message::ToolResult {
168 call_id,
169 content: tool_content,
170 ..
171 },
172 ) => {
173 for tool_result_content in tool_content {
174 let crate::completion::message::ToolResultContent::Text(Text {
175 text,
176 }) = tool_result_content
177 else {
178 return Err(CompletionError::ProviderError(
179 "This thing only supports text!".to_string(),
180 ));
181 };
182 items.push(InputItem {
184 role: None,
185 input: InputContent::FunctionCallOutput(ToolResult {
186 call_id: call_id
187 .clone()
188 .expect("The call ID of this tool should exist!"),
189 output: text,
190 status: ToolStatus::Completed,
191 }),
192 });
193 }
194 }
195 _ => {
196 return Err(CompletionError::ProviderError(
197 "This API only supports text and tool results at the moment"
198 .to_string(),
199 ));
200 }
201 }
202 }
203
204 Ok(items)
205 }
206 crate::completion::Message::Assistant { id, content } => {
207 let mut items = Vec::new();
208
209 for assistant_content in content {
210 match assistant_content {
211 crate::message::AssistantContent::Text(Text { text }) => {
212 let id = id.as_ref().unwrap_or(&String::default()).clone();
213 items.push(InputItem {
214 role: Some(Role::Assistant),
215 input: InputContent::OutputMessage(Message::Assistant {
216 content: OneOrMany::one(AssistantContentType::Text(
217 AssistantContent::OutputText(Text { text }),
218 )),
219 id,
220 name: None,
221 status: ToolStatus::Completed,
222 }),
223 });
224 }
225 crate::message::AssistantContent::ToolCall(crate::message::ToolCall {
226 id: tool_id,
227 call_id,
228 function,
229 }) => {
230 items.push(InputItem {
231 role: None,
232 input: InputContent::FunctionCall(OutputFunctionCall {
233 arguments: function.arguments,
234 call_id: call_id.expect("The tool call ID should exist!"),
235 id: tool_id,
236 name: function.name,
237 status: ToolStatus::Completed,
238 }),
239 });
240 }
241 }
242 }
243
244 Ok(items)
245 }
246 }
247 }
248}
249
/// A function tool definition in the shape sent in a Responses API request.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ResponsesToolDefinition {
    /// Function name.
    pub name: String,
    /// JSON schema describing the function's arguments.
    pub parameters: serde_json::Value,
    /// Strict-schema flag. Definitions converted from
    /// [`completion::ToolDefinition`] in this module always set it to `true`.
    pub strict: bool,
    /// Tool type, serialized as `type`. Converted definitions use `"function"`.
    #[serde(rename = "type")]
    pub kind: String,
    /// Human-readable description of the function.
    pub description: String,
}
265
266impl From<completion::ToolDefinition> for ResponsesToolDefinition {
267 fn from(value: completion::ToolDefinition) -> Self {
268 let completion::ToolDefinition {
269 name,
270 mut parameters,
271 description,
272 } = value;
273
274 let parameters = parameters
275 .as_object_mut()
276 .expect("parameters should be a JSON object");
277 parameters.insert(
278 "additionalProperties".to_string(),
279 serde_json::Value::Bool(false),
280 );
281
282 let parameters = serde_json::Value::Object(parameters.clone());
283
284 Self {
285 name,
286 parameters,
287 description,
288 kind: "function".to_string(),
289 strict: true,
290 }
291 }
292}
293
/// Token usage reported for a response.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ResponsesUsage {
    /// Number of input tokens.
    pub input_tokens: u64,
    /// Breakdown of input tokens; omitted when serializing if `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens_details: Option<InputTokensDetails>,
    /// Number of output tokens.
    pub output_tokens: u64,
    /// Breakdown of output tokens.
    pub output_tokens_details: OutputTokensDetails,
    /// Total token count as reported.
    pub total_tokens: u64,
}
310
311impl ResponsesUsage {
312 pub(crate) fn new() -> Self {
314 Self {
315 input_tokens: 0,
316 input_tokens_details: Some(InputTokensDetails::new()),
317 output_tokens: 0,
318 output_tokens_details: OutputTokensDetails::new(),
319 total_tokens: 0,
320 }
321 }
322}
323
324impl Add for ResponsesUsage {
325 type Output = Self;
326
327 fn add(self, rhs: Self) -> Self::Output {
328 let input_tokens = self.input_tokens + rhs.input_tokens;
329 let input_tokens_details = self.input_tokens_details.map(|lhs| {
330 if let Some(tokens) = rhs.input_tokens_details {
331 lhs + tokens
332 } else {
333 lhs
334 }
335 });
336 let output_tokens = self.output_tokens + rhs.output_tokens;
337 let output_tokens_details = self.output_tokens_details + rhs.output_tokens_details;
338 let total_tokens = self.total_tokens + rhs.total_tokens;
339 Self {
340 input_tokens,
341 input_tokens_details,
342 output_tokens,
343 output_tokens_details,
344 total_tokens,
345 }
346 }
347}
348
/// Breakdown of input token usage.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct InputTokensDetails {
    /// Number of input tokens reported as cached.
    pub cached_tokens: u64,
}
355
356impl InputTokensDetails {
357 pub(crate) fn new() -> Self {
358 Self { cached_tokens: 0 }
359 }
360}
361
362impl Add for InputTokensDetails {
363 type Output = Self;
364 fn add(self, rhs: Self) -> Self::Output {
365 Self {
366 cached_tokens: self.cached_tokens + rhs.cached_tokens,
367 }
368 }
369}
370
/// Breakdown of output token usage.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct OutputTokensDetails {
    /// Number of output tokens reported as reasoning tokens.
    pub reasoning_tokens: u64,
}
377
378impl OutputTokensDetails {
379 pub(crate) fn new() -> Self {
380 Self {
381 reasoning_tokens: 0,
382 }
383 }
384}
385
386impl Add for OutputTokensDetails {
387 type Output = Self;
388 fn add(self, rhs: Self) -> Self::Output {
389 Self {
390 reasoning_tokens: self.reasoning_tokens + rhs.reasoning_tokens,
391 }
392 }
393}
394
/// Reason given for a response being incomplete.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct IncompleteDetailsReason {
    pub reason: String,
}

/// Error object attached to a response.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct ResponseError {
    pub code: String,
    pub message: String,
}

/// Object type tag of a response; the only accepted value is `"response"`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseObject {
    Response,
}

/// Lifecycle status of a response, serialized in snake_case
/// (e.g. `in_progress`, `completed`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ResponseStatus {
    InProgress,
    Completed,
    Failed,
    Cancelled,
    Queued,
    Incomplete,
}
429
430impl TryFrom<(String, crate::completion::CompletionRequest)> for CompletionRequest {
432 type Error = CompletionError;
433 fn try_from(
434 (model, req): (String, crate::completion::CompletionRequest),
435 ) -> Result<Self, Self::Error> {
436 let input = {
437 let mut partial_history = vec![];
438 if let Some(docs) = req.normalized_documents() {
439 partial_history.push(docs);
440 }
441 partial_history.extend(req.chat_history);
442
443 let mut full_history: Vec<InputItem> = Vec::new();
445
446 full_history.extend(
448 partial_history
449 .into_iter()
450 .map(|x| <Vec<InputItem>>::try_from(x).unwrap())
451 .collect::<Vec<Vec<InputItem>>>()
452 .into_iter()
453 .flatten()
454 .collect::<Vec<InputItem>>(),
455 );
456
457 full_history
458 };
459
460 let input = OneOrMany::many(input)
461 .expect("This should never panic - if it does, please file a bug report");
462
463 let stream = req
464 .additional_params
465 .clone()
466 .unwrap_or(Value::Null)
467 .as_bool();
468
469 let additional_parameters = if let Some(map) = req.additional_params {
470 serde_json::from_value::<AdditionalParameters>(map).expect("Converting additional parameters to AdditionalParameters should never fail as every field is an Option")
471 } else {
472 AdditionalParameters::default()
474 };
475
476 Ok(Self {
477 input,
478 model,
479 instructions: req.preamble,
480 max_output_tokens: req.max_tokens,
481 stream,
482 tools: req
483 .tools
484 .into_iter()
485 .map(ResponsesToolDefinition::from)
486 .collect(),
487 temperature: req.temperature,
488 additional_parameters,
489 })
490 }
491}
492
/// Completion model backed by the OpenAI Responses API (`/responses`).
#[derive(Clone)]
pub struct ResponsesCompletionModel {
    /// Client used to send HTTP requests.
    pub(crate) client: Client,
    /// Name of the model to request.
    pub model: String,
}
501
502impl ResponsesCompletionModel {
503 pub fn new(client: Client, model: &str) -> Self {
505 Self {
506 client,
507 model: model.to_string(),
508 }
509 }
510
511 pub fn completions_api(self) -> crate::providers::openai::completion::CompletionModel {
513 crate::providers::openai::completion::CompletionModel::new(self.client, &self.model)
514 }
515
516 pub(crate) fn create_completion_request(
518 &self,
519 completion_request: crate::completion::CompletionRequest,
520 ) -> Result<CompletionRequest, CompletionError> {
521 let req = CompletionRequest::try_from((self.model.clone(), completion_request))?;
522
523 Ok(req)
524 }
525}
526
/// Non-streaming reply body returned by `POST /responses`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Response identifier.
    pub id: String,
    /// Object type tag (`"response"`).
    pub object: ResponseObject,
    /// Creation time; presumably a Unix timestamp in seconds — TODO confirm.
    pub created_at: u64,
    /// Status of the response.
    pub status: ResponseStatus,
    /// Error details, if any.
    pub error: Option<ResponseError>,
    /// Reason the response is incomplete, if it is.
    pub incomplete_details: Option<IncompleteDetailsReason>,
    /// Instructions as reported in the reply.
    pub instructions: Option<String>,
    /// Output token limit as reported in the reply.
    pub max_output_tokens: Option<u64>,
    /// Model that produced the response.
    pub model: String,
    /// Token usage, when reported.
    pub usage: Option<ResponsesUsage>,
    /// Output items (messages and function calls), converted into generic
    /// assistant content by `TryFrom<CompletionResponse>`.
    pub output: Vec<Output>,
    /// Tool definitions as reported in the reply.
    pub tools: Vec<ResponsesToolDefinition>,
    /// Remaining optional top-level fields, flattened into the same JSON object.
    #[serde(flatten)]
    pub additional_parameters: AdditionalParameters,
}
558
559#[derive(Clone, Debug, Deserialize, Serialize, Default)]
562pub struct AdditionalParameters {
563 #[serde(skip_serializing_if = "Option::is_none")]
565 pub background: Option<bool>,
566 #[serde(skip_serializing_if = "Option::is_none")]
568 pub text: Option<TextConfig>,
569 #[serde(skip_serializing_if = "Option::is_none")]
571 pub include: Option<Vec<Include>>,
572 #[serde(skip_serializing_if = "Option::is_none")]
574 pub top_p: Option<f64>,
575 #[serde(skip_serializing_if = "Option::is_none")]
577 pub truncation: Option<TruncationStrategy>,
578 #[serde(skip_serializing_if = "Option::is_none")]
580 pub user: Option<String>,
581 #[serde(skip_serializing_if = "Map::is_empty")]
583 pub metadata: serde_json::Map<String, serde_json::Value>,
584 #[serde(skip_serializing_if = "Option::is_none")]
586 pub parallel_tool_calls: Option<bool>,
587 #[serde(skip_serializing_if = "Option::is_none")]
589 pub previous_response_id: Option<String>,
590 #[serde(skip_serializing_if = "Option::is_none")]
592 pub reasoning: Option<Reasoning>,
593 #[serde(skip_serializing_if = "Option::is_none")]
595 pub service_tier: Option<OpenAIServiceTier>,
596 #[serde(skip_serializing_if = "Option::is_none")]
598 pub store: Option<bool>,
599}
600
/// Truncation strategy, serialized in snake_case; `Disabled` is the `Default` value.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TruncationStrategy {
    Auto,
    #[default]
    Disabled,
}

/// Text output configuration, sent under the `text` key.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TextConfig {
    /// Output format: plain text or a JSON schema.
    pub format: TextFormat,
}
618
619impl TextConfig {
620 pub(crate) fn structured_output<S>(name: S, schema: serde_json::Value) -> Self
621 where
622 S: Into<String>,
623 {
624 Self {
625 format: TextFormat::JsonSchema(StructuredOutputsInput {
626 name: name.into(),
627 schema,
628 strict: true,
629 }),
630 }
631 }
632}
633
/// Text output format, internally tagged by `type` (`json_schema` or `text`);
/// `Text` is the `Default` value.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum TextFormat {
    JsonSchema(StructuredOutputsInput),
    #[default]
    Text,
}

/// JSON-schema structured output settings.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StructuredOutputsInput {
    /// Schema name.
    pub name: String,
    /// The JSON schema.
    pub schema: serde_json::Value,
    /// Strict-schema flag; `TextConfig::structured_output` always sets `true`.
    pub strict: bool,
}
655
656#[derive(Clone, Debug, Default, Serialize, Deserialize)]
658pub struct Reasoning {
659 pub effort: Option<ReasoningEffort>,
661 #[serde(skip_serializing_if = "Option::is_none")]
663 pub summary: Option<ReasoningSummaryLevel>,
664}
665
666impl Reasoning {
667 pub fn new() -> Self {
669 Self {
670 effort: None,
671 summary: None,
672 }
673 }
674
675 pub fn with_effort(mut self, reasoning_effort: ReasoningEffort) -> Self {
677 self.effort = Some(reasoning_effort);
678
679 self
680 }
681
682 pub fn with_summary_level(mut self, reasoning_summary_level: ReasoningSummaryLevel) -> Self {
684 self.summary = Some(reasoning_summary_level);
685
686 self
687 }
688}
689
/// Requested service tier, serialized in snake_case; `Auto` is the `Default` value.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OpenAIServiceTier {
    #[default]
    Auto,
    Default,
    Flex,
}

/// Reasoning effort level, serialized in snake_case; `Medium` is the `Default` value.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ReasoningEffort {
    Low,
    #[default]
    Medium,
    High,
}

/// Reasoning summary level, serialized in snake_case; `Auto` is the `Default` value.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ReasoningSummaryLevel {
    #[default]
    Auto,
    Concise,
    Detailed,
}

/// Extra output data that can be requested via `include`; each variant
/// serializes to the dotted path given in its `rename`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum Include {
    #[serde(rename = "file_search_call.results")]
    FileSearchCallResults,
    #[serde(rename = "message.input_image.image_url")]
    MessageInputImageImageUrl,
    #[serde(rename = "computer_call.output.image_url")]
    ComputerCallOutputOutputImageUrl,
    #[serde(rename = "reasoning.encrypted_content")]
    ReasoningEncryptedContent,
    #[serde(rename = "code_interpreter_call.outputs")]
    CodeInterpreterCallOutputs,
}

/// One item of a response's `output` array, internally tagged by `type`.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum Output {
    /// An assistant message (`"message"`).
    Message(OutputMessage),
    /// A function call (`"function_call"`). The alias repeats the name that
    /// `rename_all` already produces.
    #[serde(alias = "function_call")]
    FunctionCall(OutputFunctionCall),
}
745
746impl From<Output> for Vec<completion::AssistantContent> {
747 fn from(value: Output) -> Self {
748 let res: Vec<completion::AssistantContent> = match value {
749 Output::Message(OutputMessage { content, .. }) => content
750 .into_iter()
751 .map(completion::AssistantContent::from)
752 .collect(),
753 Output::FunctionCall(OutputFunctionCall {
754 id,
755 arguments,
756 call_id,
757 name,
758 ..
759 }) => vec![completion::AssistantContent::tool_call_with_call_id(
760 id, call_id, name, arguments,
761 )],
762 };
763
764 res
765 }
766}
767
/// A function call item. It appears in response output and is also sent back
/// as a `function_call` input item.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct OutputFunctionCall {
    /// Item identifier.
    pub id: String,
    /// Call arguments, (de)serialized through `json_utils::stringified_json`
    /// (presumably a JSON-encoded string on the wire — TODO confirm).
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
    /// Call identifier referenced by the matching function call output.
    pub call_id: String,
    /// Name of the function to call.
    pub name: String,
    /// Item status.
    pub status: ToolStatus,
}

/// Status of an item, serialized in snake_case.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ToolStatus {
    InProgress,
    Completed,
    Incomplete,
}

/// An assistant message from a response's `output` array.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct OutputMessage {
    /// Message identifier.
    pub id: String,
    /// Author role; only `assistant` is accepted.
    pub role: OutputRole,
    /// Message status.
    pub status: ResponseStatus,
    /// Content parts (output text or refusal).
    pub content: Vec<AssistantContent>,
}

/// Role of an output message; only `"assistant"` is accepted.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum OutputRole {
    Assistant,
}
807
808impl completion::CompletionModel for ResponsesCompletionModel {
809 type Response = CompletionResponse;
810 type StreamingResponse = StreamingCompletionResponse;
811
812 #[cfg_attr(feature = "worker", worker::send)]
813 async fn completion(
814 &self,
815 completion_request: crate::completion::CompletionRequest,
816 ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
817 let request = self.create_completion_request(completion_request)?;
818 let request = serde_json::to_value(request)?;
819
820 tracing::debug!("Input: {}", serde_json::to_string_pretty(&request)?);
821
822 let response = self.client.post("/responses").json(&request).send().await?;
823
824 if response.status().is_success() {
825 let t = response.text().await?;
826 tracing::debug!(target: "rig", "OpenAI response: {}", t);
827
828 let response = serde_json::from_str::<Self::Response>(&t)?;
829 response.try_into()
830 } else {
831 Err(CompletionError::ProviderError(response.text().await?))
832 }
833 }
834
835 #[cfg_attr(feature = "worker", worker::send)]
836 async fn stream(
837 &self,
838 request: crate::completion::CompletionRequest,
839 ) -> Result<
840 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
841 CompletionError,
842 > {
843 Self::stream(self, request).await
844 }
845}
846
847impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
848 type Error = CompletionError;
849
850 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
851 if response.output.is_empty() {
852 return Err(CompletionError::ResponseError(
853 "Response contained no parts".to_owned(),
854 ));
855 }
856
857 let content: Vec<completion::AssistantContent> = response
858 .output
859 .iter()
860 .cloned()
861 .flat_map(<Vec<completion::AssistantContent>>::from)
862 .collect();
863
864 let choice = OneOrMany::many(content).map_err(|_| {
865 CompletionError::ResponseError(
866 "Response contained no message or tool call (empty)".to_owned(),
867 )
868 })?;
869
870 Ok(completion::CompletionResponse {
871 choice,
872 raw_response: response,
873 })
874 }
875}
876
/// A role-tagged message, internally tagged by `role` (lowercase).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// System message; `"developer"` is also accepted when deserializing.
    #[serde(alias = "developer")]
    System {
        /// Accepts a plain string or one/many content parts when deserializing.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// User message.
    User {
        /// Accepts a plain string or one/many content parts when deserializing.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<UserContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Assistant message with text and/or tool call content.
    Assistant {
        content: OneOrMany<AssistantContentType>,
        id: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        status: ToolStatus,
    },
    /// Tool result, tagged with role `"tool"`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        output: String,
    },
}

/// Content type of a tool result; only `"text"` exists and it is the `Default` value.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    #[default]
    Text,
}
915
916impl Message {
917 pub fn system(content: &str) -> Self {
918 Message::System {
919 content: OneOrMany::one(content.to_owned().into()),
920 name: None,
921 }
922 }
923}
924
/// Assistant content part, internally tagged by `type`
/// (`output_text` or `refusal`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AssistantContent {
    /// Generated text.
    OutputText(Text),
    /// Refusal text.
    Refusal { refusal: String },
}
933
934impl From<AssistantContent> for completion::AssistantContent {
935 fn from(value: AssistantContent) -> Self {
936 match value {
937 AssistantContent::Refusal { refusal } => {
938 completion::AssistantContent::Text(Text { text: refusal })
939 }
940 AssistantContent::OutputText(Text { text }) => {
941 completion::AssistantContent::Text(Text { text })
942 }
943 }
944 }
945}
946
/// Assistant message content: either a text/refusal part or a function call.
/// Untagged, so deserialization tries `Text` first, then `ToolCall`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(untagged)]
pub enum AssistantContentType {
    Text(AssistantContent),
    ToolCall(OutputFunctionCall),
}

/// User content part, internally tagged by `type`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum UserContent {
    /// Plain text (`"input_text"`).
    InputText {
        text: String,
    },
    /// Image reference (`"image_url"`).
    #[serde(rename = "image_url")]
    Image {
        image_url: ImageUrl,
    },
    /// Audio input (`"audio"`).
    Audio {
        input_audio: InputAudio,
    },
    /// Tool result (`"tool"`).
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        output: String,
    },
}
975
976impl TryFrom<message::Message> for Vec<Message> {
977 type Error = message::MessageError;
978
979 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
980 match message {
981 message::Message::User { content } => {
982 let (tool_results, other_content): (Vec<_>, Vec<_>) = content
983 .into_iter()
984 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
985
986 if !tool_results.is_empty() {
989 tool_results
990 .into_iter()
991 .map(|content| match content {
992 message::UserContent::ToolResult(message::ToolResult {
993 call_id,
994 content,
995 ..
996 }) => Ok::<_, message::MessageError>(Message::ToolResult {
997 tool_call_id: call_id.expect("The tool call ID should exist"),
998 output: {
999 let res = content.first();
1000 match res {
1001 completion::message::ToolResultContent::Text(Text {
1002 text,
1003 }) => text,
1004 _ => return Err(MessageError::ConversionError("This API only currently supports text tool results".into()))
1005 }
1006 },
1007 }),
1008 _ => unreachable!(),
1009 })
1010 .collect::<Result<Vec<_>, _>>()
1011 } else {
1012 let other_content = OneOrMany::many(other_content).expect(
1013 "There must be other content here if there were no tool result content",
1014 );
1015
1016 Ok(vec![Message::User {
1017 content: other_content.map(|content| match content {
1018 message::UserContent::Text(message::Text { text }) => {
1019 UserContent::InputText { text }
1020 }
1021 message::UserContent::Image(message::Image {
1022 data, detail, ..
1023 }) => UserContent::Image {
1024 image_url: ImageUrl {
1025 url: data,
1026 detail: detail.unwrap_or_default(),
1027 },
1028 },
1029 message::UserContent::Document(message::Document { data, .. }) => {
1030 UserContent::InputText { text: data }
1031 }
1032 message::UserContent::Audio(message::Audio {
1033 data,
1034 media_type,
1035 ..
1036 }) => UserContent::Audio {
1037 input_audio: InputAudio {
1038 data,
1039 format: match media_type {
1040 Some(media_type) => media_type,
1041 None => AudioMediaType::MP3,
1042 },
1043 },
1044 },
1045 _ => unreachable!(),
1046 }),
1047 name: None,
1048 }])
1049 }
1050 }
1051 message::Message::Assistant { content, id } => {
1052 let assistant_message_id = id;
1053
1054 match content.first() {
1055 crate::message::AssistantContent::Text(Text { text }) => {
1056 Ok(vec![Message::Assistant {
1057 id: assistant_message_id
1058 .expect("The assistant message ID should exist"),
1059 status: ToolStatus::Completed,
1060 content: OneOrMany::one(AssistantContentType::Text(
1061 AssistantContent::OutputText(Text { text }),
1062 )),
1063 name: None,
1064 }])
1065 }
1066 crate::message::AssistantContent::ToolCall(crate::message::ToolCall {
1067 id,
1068 call_id,
1069 function,
1070 }) => Ok(vec![Message::Assistant {
1071 content: OneOrMany::one(AssistantContentType::ToolCall(
1072 OutputFunctionCall {
1073 call_id: call_id.expect("The call ID should exist"),
1074 arguments: function.arguments,
1075 id,
1076 name: function.name,
1077 status: ToolStatus::Completed,
1078 },
1079 )),
1080 id: assistant_message_id.expect("The assistant message ID should exist!"),
1081 name: None,
1082 status: ToolStatus::Completed,
1083 }]),
1084 }
1085 }
1086 }
1087 }
1088}
1089
1090impl FromStr for UserContent {
1091 type Err = Infallible;
1092
1093 fn from_str(s: &str) -> Result<Self, Self::Err> {
1094 Ok(UserContent::InputText {
1095 text: s.to_string(),
1096 })
1097 }
1098}