1use std::{ops::Not, sync::Arc, vec};
2
3use im::Vector;
4use serde::{Deserialize, Serialize};
5
6use crate::error::AgentError;
7use crate::value::AgentValue;
8
9#[cfg(feature = "image")]
10use photon_rs::PhotonImage;
11
12#[derive(Debug, Default, Clone, Serialize, Deserialize)]
13pub struct Message {
14 #[serde(skip_serializing_if = "Option::is_none")]
15 pub id: Option<String>,
16
17 pub role: String,
18
19 pub content: String,
20
21 #[serde(skip_serializing_if = "Option::is_none")]
22 pub tokens: Option<usize>,
23
24 #[serde(skip_serializing_if = "Option::is_none")]
25 pub thinking: Option<String>,
26
27 #[serde(skip_serializing_if = "<&bool>::not")]
28 pub streaming: bool,
29
30 #[serde(skip_serializing_if = "Option::is_none")]
31 pub tool_calls: Option<Vector<ToolCall>>,
32
33 #[serde(skip_serializing_if = "Option::is_none")]
34 pub tool_name: Option<String>,
35
36 #[cfg(feature = "image")]
37 #[serde(skip_serializing_if = "Option::is_none")]
38 pub image: Option<Arc<PhotonImage>>,
39}
40
41impl Message {
42 pub fn new(role: String, content: String) -> Self {
43 Self {
44 id: None,
45 role,
46 content,
47 tokens: None,
48 streaming: false,
49 thinking: None,
50 tool_calls: None,
51 tool_name: None,
52
53 #[cfg(feature = "image")]
54 image: None,
55 }
56 }
57
58 pub fn assistant(content: String) -> Self {
59 Message::new("assistant".to_string(), content)
60 }
61
62 pub fn system(content: String) -> Self {
63 Message::new("system".to_string(), content)
64 }
65
66 pub fn user(content: String) -> Self {
67 Message::new("user".to_string(), content)
68 }
69
70 pub fn tool(tool_name: String, content: String) -> Self {
71 let mut message = Message::new("tool".to_string(), content);
72 message.tool_name = Some(tool_name);
73 message
74 }
75
76 #[cfg(feature = "image")]
77 pub fn with_image(mut self, image: Arc<PhotonImage>) -> Self {
78 self.image = Some(image);
79 self
80 }
81}
82
83impl PartialEq for Message {
84 fn eq(&self, other: &Self) -> bool {
85 self.id == other.id && self.role == other.role && self.content == other.content
86 }
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct ToolCall {
91 pub function: ToolCallFunction,
92}
93
94#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct ToolCallFunction {
96 pub name: String,
97 pub parameters: serde_json::Value,
98
99 #[serde(skip_serializing_if = "Option::is_none")]
100 pub id: Option<String>,
101}
102
103impl TryFrom<AgentValue> for Message {
104 type Error = AgentError;
105
106 fn try_from(value: AgentValue) -> Result<Self, Self::Error> {
107 match value {
108 AgentValue::Message(msg) => Ok((*msg).clone()),
109 AgentValue::String(s) => Ok(Message::user(s.to_string())),
110
111 #[cfg(feature = "image")]
112 AgentValue::Image(img) => {
113 let mut message = Message::user("".to_string());
114 message.image = Some(img.clone());
115 Ok(message)
116 }
117 AgentValue::Object(obj) => {
118 let role = obj
119 .get("role")
120 .and_then(|r| r.as_str())
121 .unwrap_or("user")
122 .to_string();
123 let content = obj
124 .get("content")
125 .and_then(|c| c.as_str())
126 .ok_or_else(|| {
127 AgentError::InvalidValue(
128 "Message object missing 'content' field".to_string(),
129 )
130 })?
131 .to_string();
132 let mut message = Message::new(role, content);
133
134 let id = obj
135 .get("id")
136 .and_then(|i| i.as_str())
137 .map(|s| s.to_string());
138 message.id = id;
139
140 message.thinking = obj
141 .get("thinking")
142 .and_then(|t| t.as_str())
143 .map(|s| s.to_string());
144
145 message.streaming = obj
146 .get("streaming")
147 .and_then(|st| st.as_bool())
148 .unwrap_or_default();
149
150 if let Some(tool_name) = obj.get("tool_name") {
151 message.tool_name = Some(
152 tool_name
153 .as_str()
154 .ok_or_else(|| {
155 AgentError::InvalidValue(
156 "'tool_name' field must be a string".to_string(),
157 )
158 })?
159 .to_string(),
160 );
161 }
162
163 if let Some(tool_calls) = obj.get("tool_calls") {
164 let mut calls = vec![];
165 for call_value in tool_calls.as_array().ok_or_else(|| {
166 AgentError::InvalidValue("'tool_calls' field must be an array".to_string())
167 })? {
168 let id = call_value
169 .get("id")
170 .and_then(|i| i.as_str())
171 .map(|s| s.to_string());
172 let function = call_value.get("function").ok_or_else(|| {
173 AgentError::InvalidValue(
174 "Tool call missing 'function' field".to_string(),
175 )
176 })?;
177 let tool_name = function.get_str("name").ok_or_else(|| {
178 AgentError::InvalidValue(
179 "Tool call function missing 'name' field".to_string(),
180 )
181 })?;
182 let parameters = function.get("parameters").ok_or_else(|| {
183 AgentError::InvalidValue(
184 "Tool call function missing 'parameters' field".to_string(),
185 )
186 })?;
187 let call = ToolCall {
188 function: ToolCallFunction {
189 id,
190 name: tool_name.to_string(),
191 parameters: parameters.to_json(),
192 },
193 };
194 calls.push(call);
195 }
196 message.tool_calls = Some(calls.into());
197 }
198
199 #[cfg(feature = "image")]
200 {
201 if let Some(image_value) = obj.get("image") {
202 match image_value {
203 AgentValue::String(s) => {
204 message.image = Some(Arc::new(PhotonImage::new_from_base64(
205 s.trim_start_matches("data:image/png;base64,"),
206 )));
207 }
208 AgentValue::Image(img) => {
209 message.image = Some(img.clone());
210 }
211 _ => {}
212 }
213 }
214 }
215
216 Ok(message)
217 }
218 _ => Err(AgentError::InvalidValue(
219 "Cannot convert AgentValue to Message".to_string(),
220 )),
221 }
222 }
223}
224
225impl From<Message> for AgentValue {
226 fn from(msg: Message) -> Self {
227 AgentValue::Message(Arc::new(msg))
228 }
229}
230
231impl From<Vec<Message>> for AgentValue {
232 fn from(msgs: Vec<Message>) -> Self {
233 let agent_msgs: Vector<AgentValue> = msgs.into_iter().map(|m| m.into()).collect();
234 AgentValue::Array(agent_msgs)
235 }
236}
237
#[cfg(test)]
mod tests {
    use im::{hashmap, vector};

    use super::*;

    #[test]
    fn test_message_to_from_agent_value() {
        let msg = Message::user("What is the weather today?".to_string());

        let value: AgentValue = msg.into();
        assert!(value.is_message());
        let msg_ref = value.as_message().unwrap();
        assert_eq!(msg_ref.role, "user");
        assert_eq!(msg_ref.content, "What is the weather today?");

        let msg_converted: Message = value.try_into().unwrap();
        assert_eq!(msg_converted.role, "user");
        assert_eq!(msg_converted.content, "What is the weather today?");
    }

    #[test]
    fn test_message_with_tool_calls_to_from_agent_value() {
        let mut msg = Message::assistant("".to_string());
        msg.tool_calls = Some(vector![ToolCall {
            function: ToolCallFunction {
                id: Some("call1".to_string()),
                name: "get_weather".to_string(),
                parameters: serde_json::json!({"location": "San Francisco"}),
            },
        }]);

        let value: AgentValue = msg.into();
        assert!(value.is_message());
        let msg_ref = value.as_message().unwrap();
        assert_eq!(msg_ref.role, "assistant");
        assert_eq!(msg_ref.content, "");
        let tool_calls = msg_ref.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);
        let first_call = &tool_calls[0];
        assert_eq!(first_call.function.name, "get_weather");
        assert_eq!(first_call.function.parameters["location"], "San Francisco");

        let msg_converted: Message = value.try_into().unwrap();
        assert_eq!(msg_converted.role, "assistant");
        assert_eq!(msg_converted.content, "");
        let tool_calls = msg_converted.tool_calls.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].function.name, "get_weather");
        assert_eq!(
            tool_calls[0].function.parameters,
            serde_json::json!({"location": "San Francisco"})
        );
    }

    #[test]
    fn test_tool_message_to_from_agent_value() {
        let msg = Message::tool("get_time".to_string(), "2025-01-02 03:04:05".to_string());

        let value: AgentValue = msg.clone().into();
        let msg_ref = value.as_message().unwrap();
        assert_eq!(msg_ref.role, "tool");
        assert_eq!(msg_ref.tool_name.as_deref().unwrap(), "get_time");
        assert_eq!(msg_ref.content, "2025-01-02 03:04:05");

        let msg_converted: Message = value.try_into().unwrap();
        assert_eq!(msg_converted.role, "tool");
        assert_eq!(msg_converted.tool_name.unwrap(), "get_time");
        assert_eq!(msg_converted.content, "2025-01-02 03:04:05");
    }

    #[test]
    fn test_message_from_string_value() {
        let value = AgentValue::string("Just a simple message");
        let msg: Message = value.try_into().unwrap();
        assert_eq!(msg.role, "user");
        assert_eq!(msg.content, "Just a simple message");
    }

    #[test]
    fn test_message_from_object_value() {
        let value = AgentValue::object(hashmap! {
            "role".into() => AgentValue::string("assistant"),
            "content".into() =>
                AgentValue::string("Here is some information."),
        });
        let msg: Message = value.try_into().unwrap();
        assert_eq!(msg.role, "assistant");
        assert_eq!(msg.content, "Here is some information.");
    }

    #[test]
    fn test_message_from_object_optional_fields() {
        let value = AgentValue::object(hashmap! {
            "role".into() => AgentValue::string("tool"),
            "content".into() => AgentValue::string("ok"),
            "id".into() => AgentValue::string("m1"),
            "thinking".into() => AgentValue::string("considering"),
            "tool_name".into() => AgentValue::string("get_time"),
        });
        let msg: Message = value.try_into().unwrap();
        assert_eq!(msg.role, "tool");
        assert_eq!(msg.content, "ok");
        assert_eq!(msg.id.as_deref(), Some("m1"));
        assert_eq!(msg.thinking.as_deref(), Some("considering"));
        assert_eq!(msg.tool_name.as_deref(), Some("get_time"));
        assert!(!msg.streaming);
    }

    #[test]
    fn test_message_from_object_with_tool_calls() {
        let value = AgentValue::object(hashmap! {
            "role".into() => AgentValue::string("assistant"),
            "content".into() => AgentValue::string(""),
            "tool_calls".into() => AgentValue::Array(vector![AgentValue::object(hashmap! {
                "id".into() => AgentValue::string("call1"),
                "function".into() => AgentValue::object(hashmap! {
                    "name".into() => AgentValue::string("get_weather"),
                    "parameters".into() => AgentValue::object(hashmap! {
                        "location".into() => AgentValue::string("San Francisco"),
                    }),
                }),
            })]),
        });
        let msg: Message = value.try_into().unwrap();
        let tool_calls = msg.tool_calls.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].function.id.as_deref(), Some("call1"));
        assert_eq!(tool_calls[0].function.name, "get_weather");
        assert_eq!(tool_calls[0].function.parameters["location"], "San Francisco");
    }

    #[test]
    fn test_message_from_object_with_invalid_tool_calls() {
        let value = AgentValue::object(hashmap! {
            "content".into() => AgentValue::string(""),
            "tool_calls".into() => AgentValue::string("not an array"),
        });
        let result: Result<Message, AgentError> = value.try_into();
        assert!(result.is_err());
    }

    #[test]
    fn test_message_from_invalid_value() {
        let value = AgentValue::integer(42);
        let result: Result<Message, AgentError> = value.try_into();
        assert!(result.is_err());
    }

    #[test]
    fn test_message_invalid_object() {
        let value =
            AgentValue::object(hashmap! {"some_key".into() => AgentValue::string("some_value")});
        let result: Result<Message, AgentError> = value.try_into();
        assert!(result.is_err());
    }

    #[test]
    fn test_message_to_agent_value_with_tool_calls() {
        // Struct update keeps this test compiling when new fields are added.
        let message = Message {
            tool_calls: Some(vector![ToolCall {
                function: ToolCallFunction {
                    id: Some("call1".to_string()),
                    name: "active_applications".to_string(),
                    parameters: serde_json::json!({}),
                },
            }]),
            ..Message::assistant("".to_string())
        };

        let value: AgentValue = message.into();
        let msg_ref = value.as_message().unwrap();

        assert_eq!(msg_ref.role, "assistant");
        assert_eq!(msg_ref.content, "");

        let tool_calls = msg_ref.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);

        assert_eq!(tool_calls[0].function.name, "active_applications");
        assert!(
            tool_calls[0]
                .function
                .parameters
                .as_object()
                .unwrap()
                .is_empty()
        );
    }

    #[test]
    fn test_message_partial_eq() {
        let msg1 = Message::user("hello".to_string());
        let msg2 = Message::user("hello".to_string());
        let msg3 = Message::user("world".to_string());

        assert_eq!(msg1, msg2);
        assert_ne!(msg1, msg3);

        let mut msg4 = Message::user("hello".to_string());
        msg4.id = Some("123".to_string());
        assert_ne!(msg1, msg4);
    }
}