// ai_providers/openai/request/input_models/input_message.rs

1use crate::openai::errors::ConversionError;
2use crate::openai::request::input_models::common::{Content, Role};
3use crate::openai::request::input_models::input_reference::InputReference;
4use crate::openai::request::input_models::item::Item;
5use serde::{Deserialize, Serialize};
6use std::str::FromStr;
7
8#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
9pub struct TextInput {
10    pub role: Role,
11    pub content: String,
12    #[serde(rename = "type")]
13    #[serde(skip_serializing_if = "Option::is_none")]
14    pub type_field: Option<String>,
15}
16
17impl TextInput {
18    pub fn new(content: impl Into<String>) -> Self {
19        Self {
20            role: Role::default(),
21            content: content.into(),
22            type_field: None,
23        }
24    }
25
26    pub fn role(mut self, role: impl AsRef<str>) -> Result<Self, ConversionError> {
27        self.role = Role::from_str(role.as_ref())?;
28        Ok(self)
29    }
30
31    pub fn insert_type(mut self) -> Self {
32        self.type_field = Some("message".to_string());
33        self
34    }
35}
36
37#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Default)]
38pub struct InputItemContentList {
39    pub role: Role,
40    pub content: Vec<Content>,
41    #[serde(rename = "type")]
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub type_field: Option<String>,
44}
45
46impl InputItemContentList {
47    pub fn new() -> Self {
48        Self::default()
49    }
50
51    pub fn role(mut self, role: impl AsRef<str>) -> Result<Self, ConversionError> {
52        self.role = Role::from_str(role.as_ref())?;
53        Ok(self)
54    }
55
56    pub fn insert_type(mut self) -> Self {
57        self.type_field = Some("message".to_string());
58        self
59    }
60}
61
62impl From<Item> for InputItemContentList {
63    fn from(_item: Item) -> Self {
64        Self {
65            role: Role::default(),
66            content: Vec::new(),
67            type_field: Some("message".to_string()),
68        }
69    }
70}
71
72impl From<InputReference> for InputItemContentList {
73    fn from(_reference: InputReference) -> Self {
74        Self {
75            role: Role::default(),
76            content: Vec::new(),
77            type_field: Some("message".to_string()),
78        }
79    }
80}
81
82#[derive(Debug, PartialEq, Serialize, Deserialize)]
83#[serde(untagged)]
84pub enum InputMessage {
85    TextInput(TextInput),
86    InputItemContentList(InputItemContentList),
87}
88
89impl From<TextInput> for InputMessage {
90    fn from(text_input: TextInput) -> Self {
91        InputMessage::TextInput(text_input)
92    }
93}
94
95impl From<InputItemContentList> for InputMessage {
96    fn from(content_list: InputItemContentList) -> Self {
97        InputMessage::InputItemContentList(content_list)
98    }
99}
100
#[cfg(test)]
mod tests {
    use crate::openai::request::input_models::common::TextContent;

    use super::*;

    /// A plain-text message serializes without a `"type"` key and
    /// deserializes back into the `TextInput` variant.
    #[test]
    fn test_json_values() {
        let text_input = TextInput::new("Hello, world!");
        let input_message: InputMessage = text_input.clone().into();
        assert_eq!(input_message, InputMessage::TextInput(text_input));

        let json_value = serde_json::to_value(&input_message).unwrap();
        assert_eq!(
            json_value,
            serde_json::json!({
                "role": "user",
                "content": "Hello, world!"
            })
        );

        // Round trip: the untagged enum must pick `TextInput` for a string
        // `content`, and the absent `"type"` key must come back as `None`.
        let decoded: InputMessage = serde_json::from_value(json_value).unwrap();
        assert_eq!(decoded, input_message);
    }

    /// A content-list message serializes with its role, parts and `"type"`,
    /// and deserializes back into the `InputItemContentList` variant.
    #[test]
    fn test_json_values_input_item_content_list() {
        let mut input_item_content_list = InputItemContentList::new()
            .insert_type()
            .role("developer")
            .unwrap();

        input_item_content_list
            .content
            .push(Content::Text(TextContent::new().text("Hello, world!")));

        let input_message: InputMessage = input_item_content_list.clone().into();
        assert_eq!(
            input_message,
            InputMessage::InputItemContentList(input_item_content_list)
        );

        let json_value = serde_json::to_value(&input_message).unwrap();
        assert_eq!(
            json_value,
            serde_json::json!({
                "role": "developer",
                "content": [
                    {
                        "type": "input_text",
                        "text": "Hello, world!"
                    }
                ],
                "type": "message"
            })
        );

        // Round trip: an array `content` cannot match `TextInput`, so the
        // untagged enum must fall through to `InputItemContentList`.
        let decoded: InputMessage = serde_json::from_value(json_value).unwrap();
        assert_eq!(decoded, input_message);
    }
}