// ai_providers/openai/common/text.rs

use serde::{Deserialize, Deserializer, Serialize};

use crate::openai::errors::ConversionError;

5#[derive(Debug, PartialEq, Serialize, Deserialize)]
6#[serde(rename_all = "snake_case")]
7enum ResponseFormatType {
8    Text,
9    JsonSchema,
10    JsonObject,
11}
12
13#[derive(Debug, PartialEq, Serialize, Deserialize)]
14pub struct TextFormat {
15    #[serde(rename = "type")]
16    type_field: ResponseFormatType, // always text
17}
18
19impl TextFormat {
20    pub fn new() -> Self {
21        Self {
22            type_field: ResponseFormatType::Text,
23        }
24    }
25}
26
27impl Default for TextFormat {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33#[derive(Debug, PartialEq, Serialize, Deserialize)]
34pub struct JsonSchemaFormat {
35    #[serde(rename = "type")]
36    type_field: ResponseFormatType, // always json_schema
37    name: String,
38    schema: serde_json::Value,
39    #[serde(skip_serializing_if = "Option::is_none")]
40    description: Option<String>,
41    #[serde(skip_serializing_if = "Option::is_none")]
42    strict: Option<bool>,
43}
44
45impl JsonSchemaFormat {
46    pub fn new(name: impl Into<String>, schema: serde_json::Value) -> Self {
47        Self {
48            type_field: ResponseFormatType::JsonSchema,
49            name: name.into(),
50            schema,
51            description: None,
52            strict: Some(false),
53        }
54    }
55
56    pub fn description(mut self, value: impl Into<String>) -> Self {
57        self.description = Some(value.into());
58        self
59    }
60
61    pub fn strict(mut self) -> Self {
62        self.strict = Some(true);
63        self
64    }
65}
66
67#[derive(Debug, PartialEq, Serialize, Deserialize)]
68pub struct JsonObjectFormat {
69    #[serde(rename = "type")]
70    type_field: ResponseFormatType, // always json_object
71}
72
73impl JsonObjectFormat {
74    pub fn new() -> Self {
75        Self {
76            type_field: ResponseFormatType::JsonObject,
77        }
78    }
79}
80
81impl Default for JsonObjectFormat {
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
87#[derive(Debug, PartialEq, Serialize, Deserialize)]
88#[serde(untagged)]
89pub enum ResponseFormat {
90    Text(TextFormat),
91    JsonSchema(JsonSchemaFormat),
92    JsonObject(JsonObjectFormat),
93}
94
95impl std::fmt::Display for ResponseFormat {
96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97        match self {
98            ResponseFormat::Text(_) => write!(f, "text"),
99            ResponseFormat::JsonSchema(_) => write!(f, "json_schema"),
100            ResponseFormat::JsonObject(_) => write!(f, "json_object"),
101        }
102    }
103}
104
105impl From<TextFormat> for ResponseFormat {
106    fn from(text_format: TextFormat) -> Self {
107        Self::Text(text_format)
108    }
109}
110
111impl From<JsonSchemaFormat> for ResponseFormat {
112    fn from(format: JsonSchemaFormat) -> Self {
113        Self::JsonSchema(format)
114    }
115}
116
117impl From<JsonObjectFormat> for ResponseFormat {
118    fn from(format: JsonObjectFormat) -> Self {
119        Self::JsonObject(format)
120    }
121}
122
123impl TryFrom<ResponseFormat> for TextFormat {
124    type Error = ConversionError;
125
126    fn try_from(format: ResponseFormat) -> Result<Self, Self::Error> {
127        match format {
128            ResponseFormat::Text(inner) => Ok(inner),
129            _ => Err(ConversionError::TryFrom("ResponseFormat".to_string())),
130        }
131    }
132}
133
134impl TryFrom<ResponseFormat> for JsonSchemaFormat {
135    type Error = ConversionError;
136
137    fn try_from(format: ResponseFormat) -> Result<Self, Self::Error> {
138        match format {
139            ResponseFormat::JsonSchema(inner) => Ok(inner),
140            _ => Err(ConversionError::TryFrom("ResponseFormat".to_string())),
141        }
142    }
143}
144
145impl TryFrom<ResponseFormat> for JsonObjectFormat {
146    type Error = ConversionError;
147
148    fn try_from(format: ResponseFormat) -> Result<Self, Self::Error> {
149        match format {
150            ResponseFormat::JsonObject(inner) => Ok(inner),
151            _ => Err(ConversionError::TryFrom("ResponseFormat".to_string())),
152        }
153    }
154}
155
156#[derive(Debug, PartialEq, Serialize, Deserialize)]
157pub struct Text {
158    #[serde(skip_serializing_if = "Option::is_none")]
159    format: Option<ResponseFormat>,
160}
161
162impl Default for Text {
163    fn default() -> Self {
164        Self {
165            format: Some(ResponseFormat::Text(TextFormat::default())),
166        }
167    }
168}
169
170impl Text {
171    pub fn response_format(mut self, value: ResponseFormat) -> Self {
172        self.format = Some(value);
173        self
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn it_builds_text_response_format() {
        let result = Text::default().response_format(TextFormat::new().into());

        assert_eq!(
            result,
            Text {
                format: Some(ResponseFormat::Text(TextFormat {
                    type_field: ResponseFormatType::Text
                }))
            }
        );
    }

    #[test]
    fn it_builds_json_schema_response_format() {
        let schema = json!({
            "name": "Alice",
            "age": 30,
            "active": true,
            "friends": ["Bob", "Charlie"],
            "address": {
                "street": "123 Main St",
                "city": "Somewhere"
            }
        });

        // `schema` is reused in `expected` below, so the builder gets a copy.
        let response_format: ResponseFormat = JsonSchemaFormat::new("test", schema.clone())
            .description("this is a description")
            .into();

        let result = Text::default().response_format(response_format);

        let expected = Text {
            format: Some(ResponseFormat::JsonSchema(JsonSchemaFormat {
                type_field: ResponseFormatType::JsonSchema,
                name: "test".to_string(),
                schema,
                description: Some("this is a description".to_string()),
                strict: Some(false),
            })),
        };

        assert_eq!(result, expected);
    }

    #[test]
    fn it_builds_json_object_response_format() {
        let response_format: ResponseFormat = JsonObjectFormat::new().into();
        let result = Text::default().response_format(response_format);

        let expected = Text {
            format: Some(ResponseFormat::JsonObject(JsonObjectFormat {
                type_field: ResponseFormatType::JsonObject,
            })),
        };

        assert_eq!(result, expected);
    }

    #[test]
    fn it_serializes_formats_to_json() {
        // Default text format
        let text = Text::default();
        let json_value = serde_json::to_value(&text).unwrap();
        assert_eq!(
            json_value,
            json!({
                "format": {
                    "type": "text"
                }
            })
        );

        // JSON schema format
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "age": { "type": "number" },
                "active": { "type": "boolean" }
            },
            "required": ["name", "age"]
        });

        let json_schema_format = JsonSchemaFormat::new("user_data", schema)
            .description("User information schema")
            .strict();
        let text_with_schema = Text::default().response_format(json_schema_format.into());
        let json_value = serde_json::to_value(&text_with_schema).unwrap();
        assert_eq!(
            json_value,
            json!({
                "format": {
                    "type": "json_schema",
                    "name": "user_data",
                    "schema": {
                        "type": "object",
                        "properties": {
                            "name": { "type": "string" },
                            "age": { "type": "number" },
                            "active": { "type": "boolean" }
                        },
                        "required": ["name", "age"]
                    },
                    "description": "User information schema",
                    "strict": true
                }
            })
        );

        // JSON object format
        let text_with_json_object = Text::default().response_format(JsonObjectFormat::new().into());
        let json_value = serde_json::to_value(&text_with_json_object).unwrap();
        assert_eq!(
            json_value,
            json!({
                "format": {
                    "type": "json_object"
                }
            })
        );
    }

    #[test]
    fn it_serializes_non_strict_schema_by_default() {
        // `new` starts with `strict: Some(false)`, so the key is always present.
        let format = JsonSchemaFormat::new("empty", json!({}));
        let json_value = serde_json::to_value(&format).unwrap();
        assert_eq!(
            json_value,
            json!({
                "type": "json_schema",
                "name": "empty",
                "schema": {},
                "strict": false
            })
        );
    }
}