ai_providers/openai/request/input_models/
common.rs

1use crate::openai::errors::ConversionError;
2use std::str::FromStr;
3
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Default, PartialEq, Copy, Clone, Serialize, Deserialize)]
7#[serde(rename_all = "lowercase")]
8pub enum Role {
9    #[default]
10    User,
11    Assistant,
12    System,
13    Developer,
14}
15
16impl FromStr for Role {
17    type Err = ConversionError;
18
19    fn from_str(s: &str) -> Result<Self, Self::Err> {
20        match s {
21            "user" => Ok(Role::User),
22            "assistant" => Ok(Role::Assistant),
23            "system" => Ok(Role::System),
24            "developer" => Ok(Role::Developer),
25            _ => Err(ConversionError::FromStr(s.to_string())),
26        }
27    }
28}
29
30#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
31#[serde(rename_all = "snake_case")]
32pub enum ContentType {
33    InputText,
34    InputImage,
35    InputFile,
36}
37
38impl FromStr for ContentType {
39    type Err = ConversionError;
40
41    fn from_str(s: &str) -> Result<Self, Self::Err> {
42        match s {
43            "input_text" => Ok(ContentType::InputText),
44            "input_image" => Ok(ContentType::InputImage),
45            "input_file" => Ok(ContentType::InputFile),
46            _ => Err(ConversionError::FromStr(s.to_string())),
47        }
48    }
49}
50
51#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
52pub struct TextContent {
53    #[serde(rename = "type")]
54    pub type_field: ContentType, // always InputText
55    pub text: String,
56}
57
58impl Default for TextContent {
59    fn default() -> Self {
60        Self {
61            type_field: ContentType::InputText,
62            text: String::new(),
63        }
64    }
65}
66
67impl TextContent {
68    pub fn new() -> Self {
69        Self::default()
70    }
71
72    pub fn text(mut self, text: impl Into<String>) -> Self {
73        self.text = text.into();
74        self
75    }
76}
77
78#[derive(Debug, Default, PartialEq, Clone, Serialize, Deserialize)]
79#[serde(rename_all = "lowercase")]
80pub enum ImageDetail {
81    High,
82    Low,
83    #[default]
84    Auto,
85}
86
87impl FromStr for ImageDetail {
88    type Err = ConversionError;
89
90    fn from_str(s: &str) -> Result<Self, Self::Err> {
91        match s {
92            "high" => Ok(ImageDetail::High),
93            "low" => Ok(ImageDetail::Low),
94            "auto" => Ok(ImageDetail::Auto),
95            _ => Err(ConversionError::FromStr(s.to_string())),
96        }
97    }
98}
99
100#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
101pub struct ImageContent {
102    #[serde(rename = "type")]
103    pub type_field: ContentType, // always InputImage
104    #[serde(skip_serializing_if = "Option::is_none")]
105    pub image_url: Option<String>,
106    #[serde(skip_serializing_if = "Option::is_none")]
107    pub file_id: Option<String>,
108    pub detail: ImageDetail,
109}
110
111impl Default for ImageContent {
112    fn default() -> Self {
113        Self {
114            type_field: ContentType::InputImage,
115            image_url: None,
116            file_id: None,
117            detail: ImageDetail::Auto,
118        }
119    }
120}
121
122impl ImageContent {
123    pub fn new() -> Self {
124        Self::default()
125    }
126
127    pub fn image_url(mut self, value: impl Into<String>) -> Self {
128        self.image_url = Some(value.into());
129        self
130    }
131
132    pub fn file_id(mut self, value: impl Into<String>) -> Self {
133        self.file_id = Some(value.into());
134        self
135    }
136
137    pub fn detail(mut self, value: impl AsRef<str>) -> Result<Self, ConversionError> {
138        self.detail = ImageDetail::from_str(value.as_ref())?;
139        Ok(self)
140    }
141}
142
143#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
144pub struct FileContent {
145    #[serde(rename = "type")]
146    pub type_field: ContentType, // always InputFile,
147    #[serde(skip_serializing_if = "Option::is_none")]
148    pub file_id: Option<String>,
149    #[serde(skip_serializing_if = "Option::is_none")]
150    pub file_data: Option<String>,
151    #[serde(skip_serializing_if = "Option::is_none")]
152    pub filename: Option<String>,
153}
154
155impl Default for FileContent {
156    fn default() -> Self {
157        Self {
158            type_field: ContentType::InputFile,
159            file_id: None,
160            file_data: None,
161            filename: None,
162        }
163    }
164}
165
166impl FileContent {
167    pub fn new() -> Self {
168        Self::default()
169    }
170
171    pub fn file_id(mut self, value: impl Into<String>) -> Self {
172        self.file_id = Some(value.into());
173        self
174    }
175
176    pub fn file_data(mut self, value: impl Into<String>) -> Self {
177        self.file_data = Some(value.into());
178        self
179    }
180
181    pub fn filename(mut self, value: impl Into<String>) -> Self {
182        self.filename = Some(value.into());
183        self
184    }
185}
186
187#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
188#[serde(untagged)]
189pub enum Content {
190    Text(TextContent),
191    Image(ImageContent),
192    File(FileContent),
193}
194
195impl From<TextContent> for Content {
196    fn from(text_content: TextContent) -> Self {
197        Self::Text(text_content)
198    }
199}
200
201impl From<ImageContent> for Content {
202    fn from(image_content: ImageContent) -> Self {
203        Self::Image(image_content)
204    }
205}
206
207impl From<FileContent> for Content {
208    fn from(file_content: FileContent) -> Self {
209        Self::File(file_content)
210    }
211}
212
/// Unit tests for the content models: JSON shape of defaults and builders,
/// `FromStr` parsing (including rejection of unknown names), `From`
/// conversions into [`Content`], and a serde round-trip of [`Content`].
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Default values must serialize to the minimal JSON object for each part:
    // optional fields that are `None` are omitted entirely.
    #[test]
    fn test_default_values() {
        let text_content = TextContent::default();
        let image_content = ImageContent::default();
        let file_content = FileContent::default();

        let text_content_json = serde_json::to_value(&text_content).unwrap();
        let image_content_json = serde_json::to_value(&image_content).unwrap();
        let file_content_json = serde_json::to_value(&file_content).unwrap();

        assert_eq!(text_content_json, json!({"type": "input_text", "text": ""}));
        assert_eq!(
            image_content_json,
            json!({"type": "input_image", "detail": "auto"})
        );
        assert_eq!(file_content_json, json!({"type": "input_file"}));
    }

    #[test]
    fn test_text_content() {
        let text = "Hello, world!";
        let text_content = TextContent::new().text(text);
        let text_content_json = serde_json::to_value(&text_content).unwrap();
        assert_eq!(
            text_content_json,
            json!({"type": "input_text", "text": text})
        );
    }

    #[test]
    fn test_image_content() {
        let image_url = "https://example.com/image.png";
        let file_id = "1234567890";
        let detail = "auto";

        let image_content = ImageContent::new()
            .image_url(image_url)
            .file_id(file_id)
            .detail(detail)
            .unwrap();

        let image_content_json = serde_json::to_value(&image_content).unwrap();
        assert_eq!(
            image_content_json,
            json!({"type": "input_image", "image_url": image_url, "file_id": file_id, "detail": detail})
        );
    }

    #[test]
    fn test_file_content() {
        let file_id = "1234567890";
        let file_data = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+ip1sAAAAASUVORK5CYII=";
        let filename = "image.png";

        let file_content = FileContent::new()
            .file_id(file_id)
            .file_data(file_data)
            .filename(filename);

        let file_content_json = serde_json::to_value(&file_content).unwrap();
        assert_eq!(
            file_content_json,
            json!({"type": "input_file", "file_id": file_id, "file_data": file_data, "filename": filename})
        );
    }

    #[test]
    fn test_role_from_str() {
        assert_eq!(Role::from_str("user").unwrap(), Role::User);
        assert_eq!(Role::from_str("assistant").unwrap(), Role::Assistant);
        assert_eq!(Role::from_str("system").unwrap(), Role::System);
        assert_eq!(Role::from_str("developer").unwrap(), Role::Developer);
    }

    #[test]
    fn test_content_type_from_str() {
        assert_eq!(
            ContentType::from_str("input_text").unwrap(),
            ContentType::InputText
        );
        assert_eq!(
            ContentType::from_str("input_image").unwrap(),
            ContentType::InputImage
        );
        assert_eq!(
            ContentType::from_str("input_file").unwrap(),
            ContentType::InputFile
        );
    }

    #[test]
    fn test_image_detail_from_str() {
        assert_eq!(ImageDetail::from_str("high").unwrap(), ImageDetail::High);
        assert_eq!(ImageDetail::from_str("low").unwrap(), ImageDetail::Low);
        assert_eq!(ImageDetail::from_str("auto").unwrap(), ImageDetail::Auto);
    }

    // Parsing is exact and case-sensitive: unknown, empty, or differently
    // cased names must be rejected rather than mapped to a default.
    #[test]
    fn test_from_str_rejects_unknown_values() {
        assert!(Role::from_str("").is_err());
        assert!(Role::from_str("User").is_err());
        assert!(Role::from_str("admin").is_err());

        assert!(ContentType::from_str("").is_err());
        assert!(ContentType::from_str("InputText").is_err());
        assert!(ContentType::from_str("text").is_err());

        assert!(ImageDetail::from_str("").is_err());
        assert!(ImageDetail::from_str("HIGH").is_err());
        assert!(ImageDetail::from_str("medium").is_err());
    }

    #[test]
    fn test_image_content_detail_rejects_unknown_value() {
        assert!(ImageContent::new().detail("ultra").is_err());
    }

    #[test]
    fn test_from_specific_content_to_content() {
        let text = "Hello, world!";
        let image_url = "https://example.com/image.png";
        let file_id = "1234567890";

        let text_content_builder = TextContent::new().text(text);
        let text_content: Content = text_content_builder.into();

        let image_content_builder = ImageContent::new().image_url(image_url);
        let image_content: Content = image_content_builder.into();

        let file_content_builder = FileContent::new().file_id(file_id);
        let file_content: Content = file_content_builder.into();

        assert_eq!(text_content, Content::Text(TextContent::new().text(text)));
        assert_eq!(
            image_content,
            Content::Image(ImageContent::new().image_url(image_url))
        );
        assert_eq!(
            file_content,
            Content::File(FileContent::new().file_id(file_id))
        );
    }

    // `Content` is untagged, so deserialization picks the first variant whose
    // required fields are present. Well-formed parts must come back as the
    // variant they were serialized from.
    #[test]
    fn test_content_serde_round_trip() {
        let contents = [
            Content::from(TextContent::new().text("Hello, world!")),
            Content::from(ImageContent::new().image_url("https://example.com/image.png")),
            Content::from(FileContent::new().file_id("1234567890")),
        ];

        for content in &contents {
            let json = serde_json::to_value(content).unwrap();
            let parsed: Content = serde_json::from_value(json).unwrap();
            assert_eq!(&parsed, content);
        }
    }
}