1use serde::ser::SerializeMap;
2
3use crate::error::Error;
4
5#[derive(serde::Serialize, serde::Deserialize, Debug)]
6pub struct Function {
7 pub name: String,
8 pub arguments: serde_json::Value,
9}
10
11#[derive(serde::Serialize, serde::Deserialize, Debug)]
12pub struct ToolCall {
13 pub id: String,
14 #[serde(rename = "type")]
15 pub ty: String,
16 pub function: Function,
17}
18
/// Payload of a `tool`-role chat message.
#[derive(Debug)]
pub struct ToolMessage {
    /// Message body; serialized under the `content` key.
    pub content: String,
    /// Identifier of the tool call being answered; serialized under the `tool_call_id` key.
    pub tool_call_id: String,
}
24
/// One element of a multi-part `user` message: either a text part or an image URL part.
#[derive(Debug)]
pub enum ImageMessage {
    /// Text part, serialized as `{"type": "text", "text": ...}`.
    Text(String),
    /// Image part, serialized as `{"type": "image_url", "image_url": {"url": ...}}`.
    ImageUrl(String),
}
30
31#[derive(serde::Serialize)]
33struct ImageUrlWrapper<'a> {
34 url: &'a str,
35}
36
37impl serde::Serialize for ImageMessage {
38 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
39 where
40 S: serde::Serializer,
41 {
42 let mut map = serializer.serialize_map(Some(2))?;
43 match self {
44 ImageMessage::Text(text) => {
45 map.serialize_entry("type", "text")?;
46 map.serialize_entry("text", text)?;
47 },
48 ImageMessage::ImageUrl(url) => {
49 map.serialize_entry("type", "image_url")?;
50 map.serialize_entry("image_url", &ImageUrlWrapper { url: url })?;
51 },
52 }
53 map.end()
54 }
55
56}
57
58#[derive(Debug)]
59pub enum ChatMessage {
60 System(String),
61 User(String),
62 Image(Vec<ImageMessage>),
63 Assistant(String),
64 ToolCall(Vec<ToolCall>),
65 Tool(ToolMessage),
66}
67
68impl serde::Serialize for ChatMessage {
69 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
70 where
71 S: serde::Serializer,
72 {
73 let mut map = serializer.serialize_map(Some(2))?;
74 match self {
75 ChatMessage::System(content) => {
76 map.serialize_entry("role", "system")?;
77 map.serialize_entry("content", content)?;
78 },
79 ChatMessage::User(content) => {
80 map.serialize_entry("role", "user")?;
81 map.serialize_entry("content", content)?;
82 },
83 ChatMessage::Image(images) => {
84 map.serialize_entry("role", "user")?;
85 map.serialize_entry("content", images)?;
86 },
87 ChatMessage::Assistant(content) => {
88 map.serialize_entry("role", "assistant")?;
89 map.serialize_entry("content", content)?;
90 },
91 ChatMessage::ToolCall(tool_calls) => {
92 map.serialize_entry("role", "assistant")?;
93 map.serialize_entry("tool_calls", tool_calls)?;
94 },
95 ChatMessage::Tool(tool_message) => {
96 map.serialize_entry("role", "tool")?;
97 map.serialize_entry("tool_call_id", &tool_message.tool_call_id)?;
98 map.serialize_entry("content", &tool_message.content)?;
99 },
100 }
101
102 map.end()
103 }
104}
105
106impl <'de> serde::Deserialize<'de> for ChatMessage {
107 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
108 where
109 D: serde::Deserializer<'de>,
110 {
111 struct MessageVisitor;
112
113 impl<'de> serde::de::Visitor<'de> for MessageVisitor {
114 type Value = ChatMessage;
115
116 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
117 formatter.write_str("chat message")
118 }
119
120 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
121 where
122 A: serde::de::MapAccess<'de>,
123 {
124 let mut role = None;
125 let mut content: Option<serde_json::Value> = None;
126 let mut tool_calls: Option<Vec<ToolCall>> = None;
127 let mut tool_call_id: Option<String> = None;
128
129 while let Some(key) = map.next_key()? {
130 match key {
131 "role" => role = map.next_value()?,
132 "content" => content = map.next_value()?,
133 "tool_calls" => tool_calls = map.next_value()?,
134 "tool_call_id" => tool_call_id = map.next_value()?,
135 _ => return Err(serde::de::Error::unknown_field(key, &["role", "content", "tool_calls", "tool_call_id"])),
136 }
137 }
138
139 let role: String = role.ok_or_else(|| serde::de::Error::missing_field("role"))?;
140
141 match (role.as_str(), content, tool_calls, tool_call_id) {
142 ("system", Some(serde_json::Value::String(content)), None, None) => Ok(ChatMessage::System(content)),
143 ("user", Some(serde_json::Value::String(content)), None, None) => Ok(ChatMessage::User(content)),
144 ("user", Some(serde_json::Value::Array(content)), None, None) => {
145 let mut images = Vec::new();
146 for image in content {
147 let image = image.as_object().ok_or(serde::de::Error::custom("invalid image format"))?;
148 match image.get("type") {
149 Some(serde_json::Value::String(ty)) => {
150 match ty.as_str() {
151 "text" => {
152 let content = image.get("text").ok_or(serde::de::Error::custom("missing text field"))?;
153 let content = content.as_str().ok_or(serde::de::Error::custom("invalid text field"))?;
154 images.push(ImageMessage::Text(content.to_string()));
155 },
156 "image_url" => {
157 let image_url = image.get("image_url").ok_or(serde::de::Error::custom("missing image_url field"))?;
158 let image_url = image_url.as_object().ok_or(serde::de::Error::custom("invalid image_url field"))?;
159 let url = image_url.get("url").ok_or(serde::de::Error::custom("missing url field"))?;
160 let url = url.as_str().ok_or(serde::de::Error::custom("invalid url field"))?;
161 images.push(ImageMessage::ImageUrl(url.to_string()));
162 }
163 _ => return Err(serde::de::Error::custom("invalid image type")),
164 }
165 },
166 _ => return Err(serde::de::Error::custom("invalid image type")),
167 }
168 }
169 Ok(ChatMessage::Image(images))
170 },
171 ("assistant", Some(serde_json::Value::String(content)), None, None) => Ok(ChatMessage::Assistant(content)),
172 ("assistant", None, Some(tool_calls), None) => Ok(ChatMessage::ToolCall(tool_calls)),
173 ("tool", Some(serde_json::Value::String(content)), None, Some(tool_call_id)) => Ok(ChatMessage::Tool(ToolMessage { content, tool_call_id })),
174 _ => Err(serde::de::Error::custom("invalid message")),
175 }
176 }
177 }
178
179 deserializer.deserialize_map(MessageVisitor)
180 }
181}
182
183#[derive(Debug)]
184pub enum AssistantMessageDelta {
185 Content(String),
186 ToolCall(Vec<ToolCall>),
187}
188
189impl <'de> serde::Deserialize<'de> for AssistantMessageDelta {
190 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
191 where
192 D: serde::Deserializer<'de>
193 {
194 struct DeltaVisitor;
195
196 impl<'de> serde::de::Visitor<'de> for DeltaVisitor {
197 type Value = AssistantMessageDelta;
198
199 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
200 formatter.write_str("assistant message delta")
201 }
202
203 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
204 where
205 A: serde::de::MapAccess<'de>,
206 {
207 let mut role = None;
208 let mut content = None;
209 let mut tool_calls: Option<Vec<ToolCall>> = None;
210
211 while let Some(key) = map.next_key()? {
212 match key {
213 "role" => role = map.next_value()?,
214 "content" => content = map.next_value()?,
215 "tool_calls" => tool_calls = map.next_value()?,
216 _ => return Err(serde::de::Error::unknown_field(key, &["role", "content", "tool_calls"])),
217 }
218 }
219
220 let role: String = role.ok_or_else(|| serde::de::Error::missing_field("role"))?;
221
222 match (role.as_str(), content, tool_calls) {
223 ("assistant", Some(content), None) => Ok(AssistantMessageDelta::Content(content)),
224 ("assistant", None, Some(tool_calls)) => Ok(AssistantMessageDelta::ToolCall(tool_calls)),
225 _ => Err(serde::de::Error::custom("invalid message")),
226 }
227 }
228 }
229
230 deserializer.deserialize_map(DeltaVisitor)
231 }
232}
233
234impl TryFrom<Vec<AssistantMessageDelta>> for ChatMessage {
235 type Error = Error;
236
237 fn try_from(value: Vec<AssistantMessageDelta>) -> Result<Self, Self::Error> {
238 let mut message = None;
239 for delta in value {
240 match delta {
241 AssistantMessageDelta::Content(income) => {
242 match message {
243 Some(ChatMessage::Assistant(ref mut content)) => {
244 content.push_str(&income);
245 },
246 None => {
247 message = Some(ChatMessage::Assistant(income));
248 },
249 _ => return Err(Error::Conflict),
250 }
251 },
252 AssistantMessageDelta::ToolCall(mut income) => {
253 match message {
254 Some(ChatMessage::ToolCall(ref mut tool_calls)) => {
255 tool_calls.append(&mut income);
256 },
257 None => {
258 message = Some(ChatMessage::ToolCall(income));
259 },
260 _ => return Err(Error::Conflict),
261 }
262 }
263 }
264 }
265
266 message.ok_or(Error::EmptyDeltaList)
267 }
268}