ai_providers/openai/request/input_models/
input_message.rs1use crate::openai::errors::ConversionError;
2use crate::openai::request::input_models::common::{Content, Role};
3use crate::openai::request::input_models::input_reference::InputReference;
4use crate::openai::request::input_models::item::Item;
5use serde::{Deserialize, Serialize};
6use std::str::FromStr;
7
8#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
9pub struct TextInput {
10 pub role: Role,
11 pub content: String,
12 #[serde(rename = "type")]
13 #[serde(skip_serializing_if = "Option::is_none")]
14 pub type_field: Option<String>,
15}
16
17impl TextInput {
18 pub fn new(content: impl Into<String>) -> Self {
19 Self {
20 role: Role::default(),
21 content: content.into(),
22 type_field: None,
23 }
24 }
25
26 pub fn role(mut self, role: impl AsRef<str>) -> Result<Self, ConversionError> {
27 self.role = Role::from_str(role.as_ref())?;
28 Ok(self)
29 }
30
31 pub fn insert_type(mut self) -> Self {
32 self.type_field = Some("message".to_string());
33 self
34 }
35}
36
37#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Default)]
38pub struct InputItemContentList {
39 pub role: Role,
40 pub content: Vec<Content>,
41 #[serde(rename = "type")]
42 #[serde(skip_serializing_if = "Option::is_none")]
43 pub type_field: Option<String>,
44}
45
46impl InputItemContentList {
47 pub fn new() -> Self {
48 Self::default()
49 }
50
51 pub fn role(mut self, role: impl AsRef<str>) -> Result<Self, ConversionError> {
52 self.role = Role::from_str(role.as_ref())?;
53 Ok(self)
54 }
55
56 pub fn insert_type(mut self) -> Self {
57 self.type_field = Some("message".to_string());
58 self
59 }
60}
61
62impl From<Item> for InputItemContentList {
63 fn from(_item: Item) -> Self {
64 Self {
65 role: Role::default(),
66 content: Vec::new(),
67 type_field: Some("message".to_string()),
68 }
69 }
70}
71
72impl From<InputReference> for InputItemContentList {
73 fn from(_reference: InputReference) -> Self {
74 Self {
75 role: Role::default(),
76 content: Vec::new(),
77 type_field: Some("message".to_string()),
78 }
79 }
80}
81
82#[derive(Debug, PartialEq, Serialize, Deserialize)]
83#[serde(untagged)]
84pub enum InputMessage {
85 TextInput(TextInput),
86 InputItemContentList(InputItemContentList),
87}
88
89impl From<TextInput> for InputMessage {
90 fn from(text_input: TextInput) -> Self {
91 InputMessage::TextInput(text_input)
92 }
93}
94
95impl From<InputItemContentList> for InputMessage {
96 fn from(content_list: InputItemContentList) -> Self {
97 InputMessage::InputItemContentList(content_list)
98 }
99}
100
#[cfg(test)]
mod tests {
    use crate::openai::request::input_models::common::TextContent;

    use super::*;

    /// Serializes `message`, checks the JSON against `expected`, then deserializes
    /// that JSON and checks it yields the same message. The second half exercises
    /// the untagged enum's variant selection, which serialization alone never does.
    fn assert_round_trip(message: &InputMessage, expected: serde_json::Value) {
        let json_value = serde_json::to_value(message).unwrap();
        assert_eq!(json_value, expected);

        let decoded: InputMessage = serde_json::from_value(json_value).unwrap();
        assert_eq!(&decoded, message);
    }

    #[test]
    fn test_json_values() {
        let text_input = TextInput::new("Hello, world!");
        let input_message: InputMessage = text_input.clone().into();
        assert_eq!(input_message, InputMessage::TextInput(text_input));

        assert_round_trip(
            &input_message,
            serde_json::json!({
                "role": "user",
                "content": "Hello, world!"
            }),
        );
    }

    #[test]
    fn test_json_values_text_input_with_type() {
        let input_message: InputMessage = TextInput::new("Hello, world!").insert_type().into();

        assert_round_trip(
            &input_message,
            serde_json::json!({
                "role": "user",
                "content": "Hello, world!",
                "type": "message"
            }),
        );
    }

    #[test]
    fn test_json_values_input_item_content_list() {
        let mut input_item_content_list = InputItemContentList::new()
            .insert_type()
            .role("developer")
            .unwrap();

        input_item_content_list
            .content
            .push(Content::Text(TextContent::new().text("Hello, world!")));

        let input_message: InputMessage = input_item_content_list.clone().into();
        assert_eq!(
            input_message,
            InputMessage::InputItemContentList(input_item_content_list)
        );

        assert_round_trip(
            &input_message,
            serde_json::json!({
                "role": "developer",
                "content": [
                    {
                        "type": "input_text",
                        "text": "Hello, world!"
                    }
                ],
                "type": "message"
            }),
        );
    }
}