ai_providers/openai/common/
text.rs1use serde::{Deserialize, Serialize};
2
3use crate::openai::errors::ConversionError;
4
5#[derive(Debug, PartialEq, Serialize, Deserialize)]
6#[serde(rename_all = "snake_case")]
7enum ResponseFormatType {
8 Text,
9 JsonSchema,
10 JsonObject,
11}
12
13#[derive(Debug, PartialEq, Serialize, Deserialize)]
14pub struct TextFormat {
15 #[serde(rename = "type")]
16 type_field: ResponseFormatType, }
18
19impl TextFormat {
20 pub fn new() -> Self {
21 Self {
22 type_field: ResponseFormatType::Text,
23 }
24 }
25}
26
27impl Default for TextFormat {
28 fn default() -> Self {
29 Self::new()
30 }
31}
32
33#[derive(Debug, PartialEq, Serialize, Deserialize)]
34pub struct JsonSchemaFormat {
35 #[serde(rename = "type")]
36 type_field: ResponseFormatType, name: String,
38 schema: serde_json::Value,
39 #[serde(skip_serializing_if = "Option::is_none")]
40 description: Option<String>,
41 #[serde(skip_serializing_if = "Option::is_none")]
42 strict: Option<bool>,
43}
44
45impl JsonSchemaFormat {
46 pub fn new(name: impl Into<String>, schema: serde_json::Value) -> Self {
47 Self {
48 type_field: ResponseFormatType::JsonSchema,
49 name: name.into(),
50 schema,
51 description: None,
52 strict: Some(false),
53 }
54 }
55
56 pub fn description(mut self, value: impl Into<String>) -> Self {
57 self.description = Some(value.into());
58 self
59 }
60
61 pub fn strict(mut self) -> Self {
62 self.strict = Some(true);
63 self
64 }
65}
66
67#[derive(Debug, PartialEq, Serialize, Deserialize)]
68pub struct JsonObjectFormat {
69 #[serde(rename = "type")]
70 type_field: ResponseFormatType, }
72
73impl JsonObjectFormat {
74 pub fn new() -> Self {
75 Self {
76 type_field: ResponseFormatType::JsonObject,
77 }
78 }
79}
80
81impl Default for JsonObjectFormat {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
87#[derive(Debug, PartialEq, Serialize, Deserialize)]
88#[serde(untagged)]
89pub enum ResponseFormat {
90 Text(TextFormat),
91 JsonSchema(JsonSchemaFormat),
92 JsonObject(JsonObjectFormat),
93}
94
95impl std::fmt::Display for ResponseFormat {
96 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97 match self {
98 ResponseFormat::Text(_) => write!(f, "text"),
99 ResponseFormat::JsonSchema(_) => write!(f, "json_schema"),
100 ResponseFormat::JsonObject(_) => write!(f, "json_object"),
101 }
102 }
103}
104
105impl From<TextFormat> for ResponseFormat {
106 fn from(text_format: TextFormat) -> Self {
107 Self::Text(text_format)
108 }
109}
110
111impl From<JsonSchemaFormat> for ResponseFormat {
112 fn from(format: JsonSchemaFormat) -> Self {
113 Self::JsonSchema(format)
114 }
115}
116
117impl From<JsonObjectFormat> for ResponseFormat {
118 fn from(format: JsonObjectFormat) -> Self {
119 Self::JsonObject(format)
120 }
121}
122
123impl TryFrom<ResponseFormat> for TextFormat {
124 type Error = ConversionError;
125
126 fn try_from(format: ResponseFormat) -> Result<Self, Self::Error> {
127 match format {
128 ResponseFormat::Text(inner) => Ok(inner),
129 _ => Err(ConversionError::TryFrom("ResponseFormat".to_string())),
130 }
131 }
132}
133
134impl TryFrom<ResponseFormat> for JsonSchemaFormat {
135 type Error = ConversionError;
136
137 fn try_from(format: ResponseFormat) -> Result<Self, Self::Error> {
138 match format {
139 ResponseFormat::JsonSchema(inner) => Ok(inner),
140 _ => Err(ConversionError::TryFrom("ResponseFormat".to_string())),
141 }
142 }
143}
144
145impl TryFrom<ResponseFormat> for JsonObjectFormat {
146 type Error = ConversionError;
147
148 fn try_from(format: ResponseFormat) -> Result<Self, Self::Error> {
149 match format {
150 ResponseFormat::JsonObject(inner) => Ok(inner),
151 _ => Err(ConversionError::TryFrom("ResponseFormat".to_string())),
152 }
153 }
154}
155
156#[derive(Debug, PartialEq, Serialize, Deserialize)]
157pub struct Text {
158 #[serde(skip_serializing_if = "Option::is_none")]
159 format: Option<ResponseFormat>,
160}
161
162impl Default for Text {
163 fn default() -> Self {
164 Self {
165 format: Some(ResponseFormat::Text(TextFormat::default())),
166 }
167 }
168}
169
170impl Text {
171 pub fn response_format(mut self, value: ResponseFormat) -> Self {
172 self.format = Some(value);
173 self
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn it_builds_text_response_format() {
        let result = Text::default().response_format(TextFormat::new().into());

        assert_eq!(
            result,
            Text {
                format: Some(ResponseFormat::Text(TextFormat {
                    type_field: ResponseFormatType::Text
                }))
            }
        );
    }

    #[test]
    fn it_builds_json_schema_response_format() {
        let schema = json!({
            "name": "Alice",
            "age": 30,
            "active": true,
            "friends": ["Bob", "Charlie"],
            "address": {
                "street": "123 Main St",
                "city": "Somewhere"
            }
        });

        // `schema` is needed again below for the expected value, so the builder gets a clone.
        let response_format: ResponseFormat = JsonSchemaFormat::new("test", schema.clone())
            .description("this is a description")
            .into();

        let result = Text::default().response_format(response_format);

        let expected = Text {
            format: Some(ResponseFormat::JsonSchema(JsonSchemaFormat {
                type_field: ResponseFormatType::JsonSchema,
                name: "test".to_string(),
                schema,
                description: Some("this is a description".to_string()),
                strict: Some(false),
            })),
        };

        assert_eq!(result, expected);
    }

    #[test]
    fn it_builds_json_object_response_format() {
        let response_format: ResponseFormat = JsonObjectFormat::new().into();
        let result = Text::default().response_format(response_format);

        let expected = Text {
            format: Some(ResponseFormat::JsonObject(JsonObjectFormat {
                type_field: ResponseFormatType::JsonObject,
            })),
        };

        assert_eq!(result, expected);
    }

    #[test]
    fn test_json_values() {
        let text = Text::default();
        let json_value = serde_json::to_value(&text).unwrap();
        assert_eq!(
            json_value,
            json!({
                "format": {
                    "type": "text"
                }
            })
        );

        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "age": { "type": "number" },
                "active": { "type": "boolean" }
            },
            "required": ["name", "age"]
        });

        // The expected JSON below spells the schema out literally, so `schema` can be moved.
        let json_schema_format = JsonSchemaFormat::new("user_data", schema)
            .description("User information schema")
            .strict();
        let text_with_schema = Text::default().response_format(json_schema_format.into());
        let json_value = serde_json::to_value(&text_with_schema).unwrap();
        assert_eq!(
            json_value,
            json!({
                "format": {
                    "type": "json_schema",
                    "name": "user_data",
                    "schema": {
                        "type": "object",
                        "properties": {
                            "name": { "type": "string" },
                            "age": { "type": "number" },
                            "active": { "type": "boolean" }
                        },
                        "required": ["name", "age"]
                    },
                    "description": "User information schema",
                    "strict": true
                }
            })
        );

        let text_with_json_object = Text::default().response_format(JsonObjectFormat::new().into());
        let json_value = serde_json::to_value(&text_with_json_object).unwrap();
        assert_eq!(
            json_value,
            json!({
                "format": {
                    "type": "json_object"
                }
            })
        );
    }
}