1use std::fmt;
2
3use serde::{ser::SerializeStruct, Deserialize, Deserializer, Serialize, Serializer};
4use serde_json::Value;
5
6use super::function::FunctionCall;
7
/// A chat message sent to the model, one variant per conversational role.
///
/// Serialization is hand-written: every variant is emitted as a flat object
/// whose `"role"` field is `"user"`, `"tool"`, `"assistant"`, `"system"` or
/// `"developer"`, with optional fields (`name`, `tool_calls`) omitted when `None`.
#[derive(Clone)]
pub enum Message {
    /// A message authored by the end user (`"role": "user"`).
    User {
        /// Optional participant name; written as `name` only when present.
        name: Option<String>,
        /// Content parts; a list holding exactly one text part is written as a plain string.
        content: Vec<MessageContext>
    },
    /// The result of a tool invocation (`"role": "tool"`).
    Tool {
        /// Written as `tool_call_id`. Presumably the id of a call from a prior
        /// assistant `tool_calls` entry — TODO confirm against callers.
        tool_call_id: String,
        /// Content parts; a list holding exactly one text part is written as a plain string.
        content: Vec<MessageContext>
    },
    /// A message produced by the model (`"role": "assistant"`).
    Assistant {
        /// Optional participant name; written as `name` only when present.
        name: Option<String>,
        /// Content parts; a list holding exactly one text part is written as a plain string.
        content: Vec<MessageContext>,
        /// Tool calls attached to the message; the `tool_calls` field is omitted when `None`.
        tool_calls: Option<Vec<FunctionCall>>,
    },
    /// A system instruction (`"role": "system"`); content is always plain text.
    System {
        /// Optional participant name; written as `name` only when present.
        name: Option<String>,
        /// Plain-text content, written as a JSON string.
        content: String
    },
    /// A developer instruction (`"role": "developer"`); content is always plain text.
    Developer {
        /// Optional participant name; written as `name` only when present.
        name: Option<String>,
        /// Plain-text content, written as a JSON string.
        content: String
    },
}
47
48impl fmt::Debug for Message {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 match self {
51 Message::User { name, content } => {
52 writeln!(f, "User: {}", name.as_deref().unwrap_or("Anonymous"))?;
53 for ctx in content {
54 match ctx {
55 MessageContext::Text(text) => writeln!(f, " {}", text)?,
56 MessageContext::Image(image) => writeln!(f, " [Image URL: {}]", image.url)?,
57 }
58 }
59 Ok(())
60 }
61 Message::Tool { tool_call_id, content } => {
62 writeln!(f, "Tool: {} - Tool Call", tool_call_id)?;
63 for ctx in content {
64 match ctx {
65 MessageContext::Text(text) => writeln!(f, " {}", text)?,
66 MessageContext::Image(image) => writeln!(f, " [Image URL: {}]", image.url)?,
67 }
68 }
69 Ok(())
70 }
71 Message::Assistant { name, content, tool_calls } => {
72 writeln!(f, "Assistant: {}", name.as_deref().unwrap_or("Assistant"))?;
73 for ctx in content {
74 match ctx {
75 MessageContext::Text(text) => writeln!(f, " {}", text)?,
76 MessageContext::Image(image) => writeln!(f, " [Image URL: {}]", image.url)?,
77 }
78 }
79 if let Some(calls) = tool_calls {
80 for call in calls {
81 writeln!(f, " Tool Call: {:?}", call)?;
82 }
83 }
84 Ok(())
85 }
86 Message::System { name, content } => {
87 writeln!(f, "System: {}", name.as_deref().unwrap_or("System"))?;
88 writeln!(f, " {}", content)
89 }
90 Message::Developer { name, content } => {
91 writeln!(f, "Developer: {}", name.as_deref().unwrap_or("Developer"))?;
92 writeln!(f, " {}", content)
93 }
94 }
95 }
96}
97
98impl Serialize for Message {
100 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
101 where
102 S: Serializer,
103 {
104 let state = match self {
105 Message::User { name, content } => {
106 let mut s = serializer.serialize_struct("Message", 3)?;
107 s.serialize_field("role", "user")?;
108 if let Some(name) = name {
109 s.serialize_field("name", name)?;
110 }
111 serialize_content_field(&mut s, content)?;
112 s
113 }
114 Message::Tool { tool_call_id, content } => {
115 let mut s = serializer.serialize_struct("Message", 2)?;
116 s.serialize_field("role", "tool")?;
117 s.serialize_field("tool_call_id", tool_call_id)?;
118
119 serialize_content_field(&mut s, content)?;
120 s
121 }
122 Message::Assistant { name, content, tool_calls } => {
123 let mut s = serializer.serialize_struct("Message", 3)?;
124 s.serialize_field("role", "assistant")?;
125 if let Some(name) = name {
126 s.serialize_field("name", name)?;
127 }
128 serialize_content_field(&mut s, content)?;
129 if let Some(tool_calls) = tool_calls {
130 s.serialize_field("tool_calls", tool_calls)?;
131 }
132 s
133 }
134 Message::System { name, content } => {
135 let mut s = serializer.serialize_struct("Message", 3)?;
136 s.serialize_field("role", "system")?;
137 if let Some(name) = name {
138 s.serialize_field("name", name)?;
139 }
140 s.serialize_field("content", content)?;
141 s
142 }
143 Message::Developer { name, content } => {
144 let mut s = serializer.serialize_struct("Message", 3)?;
145 s.serialize_field("role", "developer")?;
146 if let Some(name) = name {
147 s.serialize_field("name", name)?;
148 }
149 s.serialize_field("content", content)?;
150 s
151 }
152 };
153 state.end()
154 }
155}
156
157fn serialize_content_field<S>(
162 state: &mut S,
163 content: &Vec<MessageContext>,
164) -> Result<(), S::Error>
165where
166 S: SerializeStruct,
167{
168 if content.len() == 1 {
169 if let MessageContext::Text(text) = &content[0] {
170 state.serialize_field("content", text)?;
171 } else {
172 state.serialize_field("content", content)?;
173 }
174 } else {
175 state.serialize_field("content", content)?;
176 }
177 Ok(())
178}
179
180impl<'de> Deserialize<'de> for Message {
182 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
183 where
184 D: Deserializer<'de>,
185 {
186 let value: Value = Deserialize::deserialize(deserializer)?;
187
188 let role = value.get("role").and_then(Value::as_str).unwrap_or("");
189
190 match role {
191 "user" => {
192 let name = value.get("name").and_then(Value::as_str).map(String::from);
193 let content = serde_json::from_value(
194 value.get("content").cloned().unwrap_or_default(),
195 )
196 .map_err(serde::de::Error::custom)?;
197 Ok(Message::User { name, content })
198 }
199 "tool" => {
200 let tool_call_id = value
201 .get("tool_call_id")
202 .and_then(Value::as_str)
203 .ok_or_else(|| serde::de::Error::missing_field("tool_call_id"))?
204 .to_string();
205 let content = serde_json::from_value(
206 value.get("content").cloned().unwrap_or_default(),
207 )
208 .map_err(serde::de::Error::custom)?;
209 Ok(Message::Tool { tool_call_id, content })
210 }
211 "assistant" => {
212 let name = value.get("name").and_then(Value::as_str).map(String::from);
213 let content = serde_json::from_value(
214 value.get("content").cloned().unwrap_or_default(),
215 )
216 .map_err(serde::de::Error::custom)?;
217 let tool_calls = value.get("tool_calls").map_or(Ok(None), |v| {
218 serde_json::from_value(v.clone()).map(Some)
219 }).map_err(serde::de::Error::custom)?;
220 Ok(Message::Assistant { name, content, tool_calls })
221 }
222 "system" => {
223 let name = value.get("name").and_then(Value::as_str).map(String::from);
224 let content = value
225 .get("content")
226 .and_then(Value::as_str)
227 .ok_or_else(|| serde::de::Error::missing_field("content"))?
228 .to_string();
229 Ok(Message::System { name, content })
230 }
231 "developer" => {
232 let name = value.get("name").and_then(Value::as_str).map(String::from);
233 let content = value
234 .get("content")
235 .and_then(Value::as_str)
236 .ok_or_else(|| serde::de::Error::missing_field("content"))?
237 .to_string();
238 Ok(Message::Developer { name, content })
239 }
240 _ => Err(serde::de::Error::custom("Invalid message type")),
241 }
242 }
243}
244
245#[derive(Debug, Deserialize, Clone)]
249pub enum MessageContext {
250 Text(String),
252 Image(MessageImage),
254}
255
256impl Serialize for MessageContext {
258 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
259 where
260 S: Serializer,
261 {
262 match self {
263 MessageContext::Text(text) => {
264 let mut state = serializer.serialize_struct("MessageContext", 2)?;
265 state.serialize_field("type", "text")?;
266 state.serialize_field("text", text)?;
267 state.end()
268 }
269 MessageContext::Image(image) => {
270 let mut state = serializer.serialize_struct("MessageContext", 2)?;
271 state.serialize_field("type", "image_url")?;
272 state.serialize_field("image_url", image)?;
273 state.end()
274 }
275 }
276 }
277}
278
/// Payload of an image content part, written under the part's `image_url` key.
///
/// (De)serialization is derived, so the wire shape is `{"url": ..., "detail": ...}`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct MessageImage {
    /// Location of the image, passed through verbatim.
    /// NOTE(review): may also be a `data:` URL; the `Debug` view of `Message` prints it in full.
    pub url: String,

    /// Optional detail hint; the field is omitted from serialized output when `None`.
    /// NOTE(review): accepted values (e.g. "low"/"high"/"auto") are not enforced here — confirm upstream.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
300
/// One entry of a completion response's choices list (deserialize-only).
///
/// NOTE(review): presumably mirrors an OpenAI-style `choices[]` element — TODO confirm.
#[derive(Debug, Deserialize, Clone)]
pub struct Choice {
    /// Position of this choice within the response.
    pub index: usize,

    /// The message generated for this choice.
    pub message: ResponseMessage,

    /// Why generation stopped for this choice.
    /// NOTE(review): a JSON `null` here fails to deserialize into `String`;
    /// switch to `Option<String>` if streaming/partial responses must be parsed.
    pub finish_reason: String,
}
315
/// A message as returned by the model in a response (deserialize-only).
///
/// Unlike the request-side `Message` enum, this is a flat struct with the role
/// kept as a plain string.
#[derive(Debug, Deserialize, Clone)]
pub struct ResponseMessage {
    /// Role string as sent by the server (not validated here).
    pub role: String,

    /// Text content; `None` when the field is absent or `null`.
    pub content: Option<String>,

    /// Tool calls requested by the model, if any.
    pub tool_calls: Option<Vec<FunctionCall>>,

    /// Refusal text, if the server supplied one.
    pub refusal: Option<String>,

    // `#[serde(default)]` is redundant for an `Option` field (a missing field already
    // yields `None`), but it is harmless.
    /// Opaque annotation data, kept as raw JSON.
    #[serde(default)]
    pub annotations: Option<serde_json::Value>
}