1use super::{Client, responses_api::streaming::StreamingCompletionResponse};
11use super::{ImageUrl, InputAudio, SystemContent};
12use crate::completion::CompletionError;
13use crate::json_utils;
14use crate::message::{AudioMediaType, MessageError, Text};
15use crate::one_or_many::string_or_one_or_many;
16
17use crate::{OneOrMany, completion, message};
18use serde::{Deserialize, Serialize};
19use serde_json::{Map, Value};
20
21use std::convert::Infallible;
22use std::ops::Add;
23use std::str::FromStr;
24
25pub mod streaming;
26
27#[derive(Debug, Deserialize, Serialize, Clone)]
30pub struct CompletionRequest {
31 pub input: OneOrMany<InputItem>,
33 pub model: String,
35 #[serde(skip_serializing_if = "Option::is_none")]
37 pub instructions: Option<String>,
38 #[serde(skip_serializing_if = "Option::is_none")]
40 pub max_output_tokens: Option<u64>,
41 #[serde(skip_serializing_if = "Option::is_none")]
43 pub stream: Option<bool>,
44 #[serde(skip_serializing_if = "Option::is_none")]
46 pub temperature: Option<f64>,
47 #[serde(skip_serializing_if = "Vec::is_empty")]
51 pub tools: Vec<ResponsesToolDefinition>,
52 #[serde(flatten)]
54 pub additional_parameters: AdditionalParameters,
55}
56
57impl CompletionRequest {
58 pub fn with_structured_outputs<S>(mut self, schema_name: S, schema: serde_json::Value) -> Self
59 where
60 S: Into<String>,
61 {
62 self.additional_parameters.text = Some(TextConfig::structured_output(schema_name, schema));
63
64 self
65 }
66
67 pub fn with_reasoning(mut self, reasoning: Reasoning) -> Self {
68 self.additional_parameters.reasoning = Some(reasoning);
69
70 self
71 }
72}
73
74#[derive(Debug, Deserialize, Serialize, Clone)]
76pub struct InputItem {
77 #[serde(skip_serializing_if = "Option::is_none")]
81 role: Option<Role>,
82 #[serde(flatten)]
84 input: InputContent,
85}
86
87#[derive(Debug, Deserialize, Serialize, Clone)]
89#[serde(rename_all = "lowercase")]
90pub enum Role {
91 User,
92 Assistant,
93 System,
94}
95
96#[derive(Debug, Deserialize, Serialize, Clone)]
98#[serde(tag = "type", rename_all = "snake_case")]
99enum InputContent {
100 Message(Message),
101 FunctionCall(OutputFunctionCall),
102 FunctionCallOutput(ToolResult),
103}
104
105#[derive(Debug, Deserialize, Serialize, Clone)]
107pub struct ToolResult {
108 call_id: String,
110 output: String,
112 status: ToolStatus,
114}
115
116impl From<Message> for InputItem {
117 fn from(value: Message) -> Self {
118 match value {
119 Message::User { .. } => Self {
120 role: Some(Role::User),
121 input: InputContent::Message(value),
122 },
123 Message::Assistant { .. } => Self {
124 role: Some(Role::Assistant),
125 input: InputContent::Message(value),
126 },
127 Message::System { .. } => Self {
128 role: Some(Role::System),
129 input: InputContent::Message(value),
130 },
131 Message::ToolResult {
132 tool_call_id,
133 output,
134 } => Self {
135 role: None,
136 input: InputContent::FunctionCallOutput(ToolResult {
137 call_id: tool_call_id,
138 output,
139 status: ToolStatus::Completed,
140 }),
141 },
142 }
143 }
144}
145
146impl TryFrom<crate::completion::Message> for Vec<InputItem> {
147 type Error = CompletionError;
148
149 fn try_from(value: crate::completion::Message) -> Result<Self, Self::Error> {
150 match value {
151 crate::completion::Message::User { content } => {
152 let mut items = Vec::new();
153
154 for user_content in content {
155 match user_content {
156 crate::message::UserContent::Text(Text { text }) => {
157 items.push(InputItem {
158 role: Some(Role::User),
159 input: InputContent::Message(Message::User {
160 content: OneOrMany::one(UserContent::InputText { text }),
161 name: None,
162 }),
163 });
164 }
165 crate::message::UserContent::ToolResult(
166 crate::completion::message::ToolResult {
167 call_id,
168 content: tool_content,
169 ..
170 },
171 ) => {
172 for tool_result_content in tool_content {
173 let crate::completion::message::ToolResultContent::Text(Text {
174 text,
175 }) = tool_result_content
176 else {
177 return Err(CompletionError::ProviderError(
178 "This thing only supports text!".to_string(),
179 ));
180 };
181 items.push(InputItem {
183 role: None,
184 input: InputContent::FunctionCallOutput(ToolResult {
185 call_id: call_id
186 .clone()
187 .expect("The call ID of this tool should exist!"),
188 output: text,
189 status: ToolStatus::Completed,
190 }),
191 });
192 }
193 }
194 _ => {
195 return Err(CompletionError::ProviderError(
196 "This API only supports text and tool results at the moment"
197 .to_string(),
198 ));
199 }
200 }
201 }
202
203 Ok(items)
204 }
205 crate::completion::Message::Assistant { id, content } => {
206 let mut items = Vec::new();
207
208 for assistant_content in content {
209 match assistant_content {
210 crate::message::AssistantContent::Text(Text { text }) => {
211 let id = id.as_ref().unwrap_or(&String::default()).clone();
212 items.push(InputItem {
213 role: Some(Role::Assistant),
214 input: InputContent::Message(Message::Assistant {
215 content: OneOrMany::one(AssistantContentType::Text(
216 AssistantContent::OutputText(Text { text }),
217 )),
218 id,
219 name: None,
220 status: ToolStatus::Completed,
221 }),
222 });
223 }
224 crate::message::AssistantContent::ToolCall(crate::message::ToolCall {
225 id: tool_id,
226 call_id,
227 function,
228 }) => {
229 items.push(InputItem {
230 role: None,
231 input: InputContent::FunctionCall(OutputFunctionCall {
232 arguments: function.arguments,
233 call_id: call_id.expect("The tool call ID should exist!"),
234 id: tool_id,
235 name: function.name,
236 status: ToolStatus::Completed,
237 }),
238 });
239 }
240 }
241 }
242
243 Ok(items)
244 }
245 }
246 }
247}
248
249#[derive(Debug, Deserialize, Serialize, Clone)]
251pub struct ResponsesToolDefinition {
252 pub name: String,
254 pub parameters: serde_json::Value,
256 pub strict: bool,
258 #[serde(rename = "type")]
260 pub kind: String,
261 pub description: String,
263}
264
265impl From<completion::ToolDefinition> for ResponsesToolDefinition {
266 fn from(value: completion::ToolDefinition) -> Self {
267 let completion::ToolDefinition {
268 name,
269 mut parameters,
270 description,
271 } = value;
272
273 let parameters = parameters
274 .as_object_mut()
275 .expect("parameters should be a JSON object");
276 parameters.insert(
277 "additionalProperties".to_string(),
278 serde_json::Value::Bool(false),
279 );
280
281 let parameters = serde_json::Value::Object(parameters.clone());
282
283 Self {
284 name,
285 parameters,
286 description,
287 kind: "function".to_string(),
288 strict: true,
289 }
290 }
291}
292
293#[derive(Clone, Debug, Serialize, Deserialize)]
296pub struct ResponsesUsage {
297 pub input_tokens: u64,
299 #[serde(skip_serializing_if = "Option::is_none")]
301 pub input_tokens_details: Option<InputTokensDetails>,
302 pub output_tokens: u64,
304 pub output_tokens_details: OutputTokensDetails,
306 pub total_tokens: u64,
308}
309
310impl ResponsesUsage {
311 pub(crate) fn new() -> Self {
313 Self {
314 input_tokens: 0,
315 input_tokens_details: Some(InputTokensDetails::new()),
316 output_tokens: 0,
317 output_tokens_details: OutputTokensDetails::new(),
318 total_tokens: 0,
319 }
320 }
321}
322
323impl Add for ResponsesUsage {
324 type Output = Self;
325
326 fn add(self, rhs: Self) -> Self::Output {
327 let input_tokens = self.input_tokens + rhs.input_tokens;
328 let input_tokens_details = self.input_tokens_details.map(|lhs| {
329 if let Some(tokens) = rhs.input_tokens_details {
330 lhs + tokens
331 } else {
332 lhs
333 }
334 });
335 let output_tokens = self.output_tokens + rhs.output_tokens;
336 let output_tokens_details = self.output_tokens_details + rhs.output_tokens_details;
337 let total_tokens = self.total_tokens + rhs.total_tokens;
338 Self {
339 input_tokens,
340 input_tokens_details,
341 output_tokens,
342 output_tokens_details,
343 total_tokens,
344 }
345 }
346}
347
348#[derive(Clone, Debug, Serialize, Deserialize)]
350pub struct InputTokensDetails {
351 pub cached_tokens: u64,
353}
354
355impl InputTokensDetails {
356 pub(crate) fn new() -> Self {
357 Self { cached_tokens: 0 }
358 }
359}
360
361impl Add for InputTokensDetails {
362 type Output = Self;
363 fn add(self, rhs: Self) -> Self::Output {
364 Self {
365 cached_tokens: self.cached_tokens + rhs.cached_tokens,
366 }
367 }
368}
369
370#[derive(Clone, Debug, Serialize, Deserialize)]
372pub struct OutputTokensDetails {
373 pub reasoning_tokens: u64,
375}
376
377impl OutputTokensDetails {
378 pub(crate) fn new() -> Self {
379 Self {
380 reasoning_tokens: 0,
381 }
382 }
383}
384
385impl Add for OutputTokensDetails {
386 type Output = Self;
387 fn add(self, rhs: Self) -> Self::Output {
388 Self {
389 reasoning_tokens: self.reasoning_tokens + rhs.reasoning_tokens,
390 }
391 }
392}
393
394#[derive(Clone, Debug, Default, Serialize, Deserialize)]
396pub struct IncompleteDetailsReason {
397 pub reason: String,
399}
400
401#[derive(Clone, Debug, Default, Serialize, Deserialize)]
403pub struct ResponseError {
404 pub code: String,
406 pub message: String,
408}
409
410#[derive(Clone, Debug, Deserialize, Serialize)]
412#[serde(rename_all = "snake_case")]
413pub enum ResponseObject {
414 Response,
415}
416
417#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
419#[serde(rename_all = "snake_case")]
420pub enum ResponseStatus {
421 InProgress,
422 Completed,
423 Failed,
424 Cancelled,
425 Queued,
426 Incomplete,
427}
428
429impl TryFrom<(String, crate::completion::CompletionRequest)> for CompletionRequest {
431 type Error = CompletionError;
432 fn try_from(
433 (model, req): (String, crate::completion::CompletionRequest),
434 ) -> Result<Self, Self::Error> {
435 let input = {
436 let mut partial_history = vec![];
437 if let Some(docs) = req.normalized_documents() {
438 partial_history.push(docs);
439 }
440 partial_history.extend(req.chat_history);
441
442 let mut full_history: Vec<InputItem> = Vec::new();
444
445 full_history.extend(
447 partial_history
448 .into_iter()
449 .map(|x| <Vec<InputItem>>::try_from(x).unwrap())
450 .collect::<Vec<Vec<InputItem>>>()
451 .into_iter()
452 .flatten()
453 .collect::<Vec<InputItem>>(),
454 );
455
456 full_history
457 };
458
459 let input = OneOrMany::many(input)
460 .expect("This should never panic - if it does, please file a bug report");
461
462 let stream = req
463 .additional_params
464 .clone()
465 .unwrap_or(Value::Null)
466 .as_bool();
467
468 let additional_parameters = if let Some(map) = req.additional_params {
469 serde_json::from_value::<AdditionalParameters>(map).expect("Converting additional parameters to AdditionalParameters should never fail as every field is an Option")
470 } else {
471 AdditionalParameters::default()
473 };
474
475 Ok(Self {
476 input,
477 model,
478 instructions: req.preamble,
479 max_output_tokens: req.max_tokens,
480 stream,
481 tools: req
482 .tools
483 .into_iter()
484 .map(ResponsesToolDefinition::from)
485 .collect(),
486 temperature: req.temperature,
487 additional_parameters,
488 })
489 }
490}
491
492#[derive(Clone)]
494pub struct ResponsesCompletionModel {
495 pub(crate) client: Client,
497 pub model: String,
499}
500
501impl ResponsesCompletionModel {
502 pub fn new(client: Client, model: &str) -> Self {
504 Self {
505 client,
506 model: model.to_string(),
507 }
508 }
509
510 pub fn completions_api(self) -> crate::providers::openai::completion::CompletionModel {
512 crate::providers::openai::completion::CompletionModel::new(self.client, &self.model)
513 }
514
515 pub(crate) fn create_completion_request(
517 &self,
518 completion_request: crate::completion::CompletionRequest,
519 ) -> Result<CompletionRequest, CompletionError> {
520 let req = CompletionRequest::try_from((self.model.clone(), completion_request))?;
521
522 Ok(req)
523 }
524}
525
526#[derive(Clone, Debug, Serialize, Deserialize)]
528pub struct CompletionResponse {
529 pub id: String,
531 pub object: ResponseObject,
533 pub created_at: u64,
535 pub status: ResponseStatus,
537 pub error: Option<ResponseError>,
539 pub incomplete_details: Option<IncompleteDetailsReason>,
541 pub instructions: Option<String>,
543 pub max_output_tokens: Option<u64>,
545 pub model: String,
547 pub usage: Option<ResponsesUsage>,
549 pub output: Vec<Output>,
551 pub tools: Vec<ResponsesToolDefinition>,
553 #[serde(flatten)]
555 pub additional_parameters: AdditionalParameters,
556}
557
558#[derive(Clone, Debug, Deserialize, Serialize, Default)]
561pub struct AdditionalParameters {
562 #[serde(skip_serializing_if = "Option::is_none")]
564 pub background: Option<bool>,
565 #[serde(skip_serializing_if = "Option::is_none")]
567 pub text: Option<TextConfig>,
568 #[serde(skip_serializing_if = "Option::is_none")]
570 pub include: Option<Vec<Include>>,
571 #[serde(skip_serializing_if = "Option::is_none")]
573 pub top_p: Option<f64>,
574 #[serde(skip_serializing_if = "Option::is_none")]
576 pub truncation: Option<TruncationStrategy>,
577 #[serde(skip_serializing_if = "Option::is_none")]
579 pub user: Option<String>,
580 #[serde(skip_serializing_if = "Map::is_empty")]
582 pub metadata: serde_json::Map<String, serde_json::Value>,
583 #[serde(skip_serializing_if = "Option::is_none")]
585 pub parallel_tool_calls: Option<bool>,
586 #[serde(skip_serializing_if = "Option::is_none")]
588 pub previous_response_id: Option<String>,
589 #[serde(skip_serializing_if = "Option::is_none")]
591 pub reasoning: Option<Reasoning>,
592 #[serde(skip_serializing_if = "Option::is_none")]
594 pub service_tier: Option<OpenAIServiceTier>,
595 #[serde(skip_serializing_if = "Option::is_none")]
597 pub store: Option<bool>,
598}
599
600#[derive(Clone, Debug, Default, Serialize, Deserialize)]
604#[serde(rename_all = "snake_case")]
605pub enum TruncationStrategy {
606 Auto,
607 #[default]
608 Disabled,
609}
610
611#[derive(Clone, Debug, Serialize, Deserialize)]
614pub struct TextConfig {
615 pub format: TextFormat,
616}
617
618impl TextConfig {
619 pub(crate) fn structured_output<S>(name: S, schema: serde_json::Value) -> Self
620 where
621 S: Into<String>,
622 {
623 Self {
624 format: TextFormat::JsonSchema(StructuredOutputsInput {
625 name: name.into(),
626 schema,
627 strict: true,
628 }),
629 }
630 }
631}
632
633#[derive(Clone, Debug, Serialize, Deserialize, Default)]
636#[serde(tag = "type")]
637#[serde(rename_all = "snake_case")]
638pub enum TextFormat {
639 JsonSchema(StructuredOutputsInput),
640 #[default]
641 Text,
642}
643
644#[derive(Clone, Debug, Serialize, Deserialize)]
646pub struct StructuredOutputsInput {
647 pub name: String,
649 pub schema: serde_json::Value,
651 pub strict: bool,
653}
654
655#[derive(Clone, Debug, Default, Serialize, Deserialize)]
657pub struct Reasoning {
658 pub effort: Option<ReasoningEffort>,
660 #[serde(skip_serializing_if = "Option::is_none")]
662 pub summary: Option<ReasoningSummaryLevel>,
663}
664
665impl Reasoning {
666 pub fn new() -> Self {
668 Self {
669 effort: None,
670 summary: None,
671 }
672 }
673
674 pub fn with_effort(mut self, reasoning_effort: ReasoningEffort) -> Self {
676 self.effort = Some(reasoning_effort);
677
678 self
679 }
680
681 pub fn with_summary_level(mut self, reasoning_summary_level: ReasoningSummaryLevel) -> Self {
683 self.summary = Some(reasoning_summary_level);
684
685 self
686 }
687}
688
689#[derive(Clone, Debug, Default, Serialize, Deserialize)]
691#[serde(rename_all = "snake_case")]
692pub enum OpenAIServiceTier {
693 #[default]
694 Auto,
695 Default,
696 Flex,
697}
698
699#[derive(Clone, Debug, Default, Serialize, Deserialize)]
701#[serde(rename_all = "snake_case")]
702pub enum ReasoningEffort {
703 Low,
704 #[default]
705 Medium,
706 High,
707}
708
709#[derive(Clone, Debug, Default, Serialize, Deserialize)]
711#[serde(rename_all = "snake_case")]
712pub enum ReasoningSummaryLevel {
713 #[default]
714 Auto,
715 Concise,
716 Detailed,
717}
718
719#[derive(Clone, Debug, Deserialize, Serialize)]
722pub enum Include {
723 #[serde(rename = "file_search_call.results")]
724 FileSearchCallResults,
725 #[serde(rename = "message.input_image.image_url")]
726 MessageInputImageImageUrl,
727 #[serde(rename = "computer_call.output.image_url")]
728 ComputerCallOutputOutputImageUrl,
729 #[serde(rename = "reasoning.encrypted_content")]
730 ReasoningEncryptedContent,
731 #[serde(rename = "code_interpreter_call.outputs")]
732 CodeInterpreterCallOutputs,
733}
734
735#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
737#[serde(tag = "type")]
738#[serde(rename_all = "snake_case")]
739pub enum Output {
740 Message(OutputMessage),
741 #[serde(alias = "function_call")]
742 FunctionCall(OutputFunctionCall),
743}
744
745impl From<Output> for Vec<completion::AssistantContent> {
746 fn from(value: Output) -> Self {
747 let res: Vec<completion::AssistantContent> = match value {
748 Output::Message(OutputMessage { content, .. }) => content
749 .into_iter()
750 .map(completion::AssistantContent::from)
751 .collect(),
752 Output::FunctionCall(OutputFunctionCall {
753 id,
754 arguments,
755 call_id,
756 name,
757 ..
758 }) => vec![completion::AssistantContent::tool_call_with_call_id(
759 id, call_id, name, arguments,
760 )],
761 };
762
763 res
764 }
765}
766
767#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
769pub struct OutputFunctionCall {
770 pub id: String,
771 #[serde(with = "json_utils::stringified_json")]
772 pub arguments: serde_json::Value,
773 pub call_id: String,
774 pub name: String,
775 pub status: ToolStatus,
776}
777
778#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
780#[serde(rename_all = "snake_case")]
781pub enum ToolStatus {
782 InProgress,
783 Completed,
784 Incomplete,
785}
786
787#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
789pub struct OutputMessage {
790 pub id: String,
792 pub role: OutputRole,
794 pub status: ResponseStatus,
796 pub content: Vec<AssistantContent>,
798}
799
800#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
802#[serde(rename_all = "snake_case")]
803pub enum OutputRole {
804 Assistant,
805}
806
807impl completion::CompletionModel for ResponsesCompletionModel {
808 type Response = CompletionResponse;
809 type StreamingResponse = StreamingCompletionResponse;
810
811 #[cfg_attr(feature = "worker", worker::send)]
812 async fn completion(
813 &self,
814 completion_request: crate::completion::CompletionRequest,
815 ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
816 let request = self.create_completion_request(completion_request)?;
817 let request = serde_json::to_value(request)?;
818
819 tracing::debug!("Input: {}", serde_json::to_string_pretty(&request)?);
820
821 let response = self.client.post("/responses").json(&request).send().await?;
822
823 if response.status().is_success() {
824 let t = response.text().await?;
825 tracing::debug!(target: "rig", "OpenAI response: {}", t);
826
827 let response = serde_json::from_str::<Self::Response>(&t)?;
828 response.try_into()
829 } else {
830 Err(CompletionError::ProviderError(response.text().await?))
831 }
832 }
833
834 #[cfg_attr(feature = "worker", worker::send)]
835 async fn stream(
836 &self,
837 request: crate::completion::CompletionRequest,
838 ) -> Result<
839 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
840 CompletionError,
841 > {
842 Self::stream(self, request).await
843 }
844}
845
846impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
847 type Error = CompletionError;
848
849 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
850 if response.output.is_empty() {
851 return Err(CompletionError::ResponseError(
852 "Response contained no parts".to_owned(),
853 ));
854 }
855
856 let content: Vec<completion::AssistantContent> = response
857 .output
858 .iter()
859 .cloned()
860 .flat_map(<Vec<completion::AssistantContent>>::from)
861 .collect();
862
863 let choice = OneOrMany::many(content).map_err(|_| {
864 CompletionError::ResponseError(
865 "Response contained no message or tool call (empty)".to_owned(),
866 )
867 })?;
868
869 let usage = response
870 .usage
871 .as_ref()
872 .map(|usage| completion::Usage {
873 input_tokens: usage.input_tokens,
874 output_tokens: usage.output_tokens,
875 total_tokens: usage.total_tokens,
876 })
877 .unwrap_or_default();
878
879 Ok(completion::CompletionResponse {
880 choice,
881 usage,
882 raw_response: response,
883 })
884 }
885}
886
887#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
889#[serde(tag = "role", rename_all = "lowercase")]
890pub enum Message {
891 #[serde(alias = "developer")]
892 System {
893 #[serde(deserialize_with = "string_or_one_or_many")]
894 content: OneOrMany<SystemContent>,
895 #[serde(skip_serializing_if = "Option::is_none")]
896 name: Option<String>,
897 },
898 User {
899 #[serde(deserialize_with = "string_or_one_or_many")]
900 content: OneOrMany<UserContent>,
901 #[serde(skip_serializing_if = "Option::is_none")]
902 name: Option<String>,
903 },
904 Assistant {
905 content: OneOrMany<AssistantContentType>,
906 #[serde(skip_serializing_if = "String::is_empty")]
907 id: String,
908 #[serde(skip_serializing_if = "Option::is_none")]
909 name: Option<String>,
910 status: ToolStatus,
911 },
912 #[serde(rename = "tool")]
913 ToolResult {
914 tool_call_id: String,
915 output: String,
916 },
917}
918
919#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
921#[serde(rename_all = "lowercase")]
922pub enum ToolResultContentType {
923 #[default]
924 Text,
925}
926
927impl Message {
928 pub fn system(content: &str) -> Self {
929 Message::System {
930 content: OneOrMany::one(content.to_owned().into()),
931 name: None,
932 }
933 }
934}
935
936#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
939#[serde(tag = "type", rename_all = "snake_case")]
940pub enum AssistantContent {
941 OutputText(Text),
942 Refusal { refusal: String },
943}
944
945impl From<AssistantContent> for completion::AssistantContent {
946 fn from(value: AssistantContent) -> Self {
947 match value {
948 AssistantContent::Refusal { refusal } => {
949 completion::AssistantContent::Text(Text { text: refusal })
950 }
951 AssistantContent::OutputText(Text { text }) => {
952 completion::AssistantContent::Text(Text { text })
953 }
954 }
955 }
956}
957
958#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
960#[serde(untagged)]
961pub enum AssistantContentType {
962 Text(AssistantContent),
963 ToolCall(OutputFunctionCall),
964}
965
966#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
968#[serde(tag = "type", rename_all = "snake_case")]
969pub enum UserContent {
970 InputText {
971 text: String,
972 },
973 #[serde(rename = "image_url")]
974 Image {
975 image_url: ImageUrl,
976 },
977 Audio {
978 input_audio: InputAudio,
979 },
980 #[serde(rename = "tool")]
981 ToolResult {
982 tool_call_id: String,
983 output: String,
984 },
985}
986
987impl TryFrom<message::Message> for Vec<Message> {
988 type Error = message::MessageError;
989
990 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
991 match message {
992 message::Message::User { content } => {
993 let (tool_results, other_content): (Vec<_>, Vec<_>) = content
994 .into_iter()
995 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
996
997 if !tool_results.is_empty() {
1000 tool_results
1001 .into_iter()
1002 .map(|content| match content {
1003 message::UserContent::ToolResult(message::ToolResult {
1004 call_id,
1005 content,
1006 ..
1007 }) => Ok::<_, message::MessageError>(Message::ToolResult {
1008 tool_call_id: call_id.expect("The tool call ID should exist"),
1009 output: {
1010 let res = content.first();
1011 match res {
1012 completion::message::ToolResultContent::Text(Text {
1013 text,
1014 }) => text,
1015 _ => return Err(MessageError::ConversionError("This API only currently supports text tool results".into()))
1016 }
1017 },
1018 }),
1019 _ => unreachable!(),
1020 })
1021 .collect::<Result<Vec<_>, _>>()
1022 } else {
1023 let other_content = OneOrMany::many(other_content).expect(
1024 "There must be other content here if there were no tool result content",
1025 );
1026
1027 Ok(vec![Message::User {
1028 content: other_content.map(|content| match content {
1029 message::UserContent::Text(message::Text { text }) => {
1030 UserContent::InputText { text }
1031 }
1032 message::UserContent::Image(message::Image {
1033 data, detail, ..
1034 }) => UserContent::Image {
1035 image_url: ImageUrl {
1036 url: data,
1037 detail: detail.unwrap_or_default(),
1038 },
1039 },
1040 message::UserContent::Document(message::Document { data, .. }) => {
1041 UserContent::InputText { text: data }
1042 }
1043 message::UserContent::Audio(message::Audio {
1044 data,
1045 media_type,
1046 ..
1047 }) => UserContent::Audio {
1048 input_audio: InputAudio {
1049 data,
1050 format: match media_type {
1051 Some(media_type) => media_type,
1052 None => AudioMediaType::MP3,
1053 },
1054 },
1055 },
1056 _ => unreachable!(),
1057 }),
1058 name: None,
1059 }])
1060 }
1061 }
1062 message::Message::Assistant { content, id } => {
1063 let assistant_message_id = id;
1064
1065 match content.first() {
1066 crate::message::AssistantContent::Text(Text { text }) => {
1067 Ok(vec![Message::Assistant {
1068 id: assistant_message_id
1069 .expect("The assistant message ID should exist"),
1070 status: ToolStatus::Completed,
1071 content: OneOrMany::one(AssistantContentType::Text(
1072 AssistantContent::OutputText(Text { text }),
1073 )),
1074 name: None,
1075 }])
1076 }
1077 crate::message::AssistantContent::ToolCall(crate::message::ToolCall {
1078 id,
1079 call_id,
1080 function,
1081 }) => Ok(vec![Message::Assistant {
1082 content: OneOrMany::one(AssistantContentType::ToolCall(
1083 OutputFunctionCall {
1084 call_id: call_id.expect("The call ID should exist"),
1085 arguments: function.arguments,
1086 id,
1087 name: function.name,
1088 status: ToolStatus::Completed,
1089 },
1090 )),
1091 id: assistant_message_id.expect("The assistant message ID should exist!"),
1092 name: None,
1093 status: ToolStatus::Completed,
1094 }]),
1095 }
1096 }
1097 }
1098 }
1099}
1100
1101impl FromStr for UserContent {
1102 type Err = Infallible;
1103
1104 fn from_str(s: &str) -> Result<Self, Self::Err> {
1105 Ok(UserContent::InputText {
1106 text: s.to_string(),
1107 })
1108 }
1109}