1#![cfg(feature = "llm")]
7
8use std::{sync::Arc, vec};
9
10use im::Vector;
11use serde::{Deserialize, Serialize};
12
13use crate::error::AgentError;
14use crate::value::AgentValue;
15
16#[cfg(feature = "image")]
17use photon_rs::PhotonImage;
18
/// A single chat message.
///
/// Equality (see the `PartialEq` impl) only looks at `id`, `role` and
/// `content`. Serialization is hand-written: unset optional fields are omitted.
#[derive(Debug, Default, Clone)]
pub struct Message {
    /// Optional message identifier.
    pub id: Option<String>,

    /// Speaker role. The constructors on this type produce `"user"`,
    /// `"assistant"`, `"system"` and `"tool"`.
    pub role: String,

    /// Text body of the message.
    pub content: String,

    /// Optional token count. NOTE(review): presumably the number of tokens
    /// attributed to this message by the model/provider — confirm with producers.
    pub tokens: Option<usize>,

    /// Optional "thinking" text kept separate from `content`. NOTE(review):
    /// presumably the model's reasoning output — confirm.
    pub thinking: Option<String>,

    /// Streaming flag; only serialized when `true`. NOTE(review): presumably
    /// marks a message whose content is still arriving — confirm.
    pub streaming: bool,

    /// Tool calls attached to this message, if any.
    pub tool_calls: Option<Vector<ToolCall>>,

    /// Name of the tool this message belongs to; set by `Message::tool`.
    pub tool_name: Option<String>,

    #[cfg(feature = "image")]
    /// Optional attached image (only with the `image` feature); serialized as
    /// the string returned by `PhotonImage::get_base64`.
    pub image: Option<Arc<PhotonImage>>,
}
76
77impl Message {
78 pub fn new(role: String, content: String) -> Self {
85 Self {
86 id: None,
87 role,
88 content,
89 tokens: None,
90 streaming: false,
91 thinking: None,
92 tool_calls: None,
93 tool_name: None,
94
95 #[cfg(feature = "image")]
96 image: None,
97 }
98 }
99
100 pub fn assistant(content: String) -> Self {
102 Message::new("assistant".to_string(), content)
103 }
104
105 pub fn system(content: String) -> Self {
109 Message::new("system".to_string(), content)
110 }
111
112 pub fn user(content: String) -> Self {
114 Message::new("user".to_string(), content)
115 }
116
117 pub fn tool(tool_name: String, content: String) -> Self {
127 let mut message = Message::new("tool".to_string(), content);
128 message.tool_name = Some(tool_name);
129 message
130 }
131
132 #[cfg(feature = "image")]
136 pub fn with_image(mut self, image: Arc<PhotonImage>) -> Self {
137 self.image = Some(image);
138 self
139 }
140}
141
142impl PartialEq for Message {
143 fn eq(&self, other: &Self) -> bool {
144 self.id == other.id && self.role == other.role && self.content == other.content
145 }
146}
147
148impl Serialize for Message {
149 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
150 where
151 S: serde::Serializer,
152 {
153 let mut map = serde_json::Map::new();
154 if let Some(id) = &self.id {
155 map.insert("id".to_string(), serde_json::Value::String(id.clone()));
156 }
157 map.insert(
158 "role".to_string(),
159 serde_json::Value::String(self.role.clone()),
160 );
161 map.insert(
162 "content".to_string(),
163 serde_json::Value::String(self.content.clone()),
164 );
165 if let Some(tokens) = &self.tokens {
166 map.insert(
167 "tokens".to_string(),
168 serde_json::Value::Number((*tokens).into()),
169 );
170 }
171 if let Some(thinking) = &self.thinking {
172 map.insert(
173 "thinking".to_string(),
174 serde_json::Value::String(thinking.clone()),
175 );
176 }
177 if self.streaming {
178 map.insert("streaming".to_string(), serde_json::Value::Bool(true));
179 }
180 if let Some(tool_calls) = &self.tool_calls {
181 let mut tool_calls_vec = vec![];
182 for call in tool_calls {
183 tool_calls_vec.push(serde_json::to_value(call).map_err(serde::ser::Error::custom)?);
184 }
185 map.insert(
186 "tool_calls".to_string(),
187 serde_json::Value::Array(tool_calls_vec),
188 );
189 }
190 if let Some(tool_name) = &self.tool_name {
191 map.insert(
192 "tool_name".to_string(),
193 serde_json::Value::String(tool_name.clone()),
194 );
195 }
196 #[cfg(feature = "image")]
197 {
198 if let Some(image) = &self.image {
199 map.insert(
200 "image".to_string(),
201 serde_json::Value::String(image.get_base64()),
202 );
203 }
204 }
205 map.serialize(serializer)
206 }
207}
208
209impl<'de> Deserialize<'de> for Message {
210 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
211 where
212 D: serde::Deserializer<'de>,
213 {
214 let mut message = Message::user(String::default());
215 let map = serde_json::Map::deserialize(deserializer)?;
216
217 if let Some(id) = map.get("id") {
218 message.id = id.as_str().map(|s| s.to_string());
219 }
220 if let Some(role) = map.get("role") {
221 message.role = role
222 .as_str()
223 .ok_or_else(|| serde::de::Error::custom("role must be a string"))?
224 .to_string();
225 }
226 if let Some(content) = map.get("content") {
227 message.content = content
228 .as_str()
229 .ok_or_else(|| serde::de::Error::custom("content must be a string"))?
230 .to_string();
231 }
232 if let Some(tokens) = map.get("tokens") {
233 message.tokens = tokens.as_u64().map(|u| u as usize);
234 }
235 if let Some(thinking) = map.get("thinking") {
236 message.thinking = thinking.as_str().map(|s| s.to_string());
237 }
238 if let Some(streaming) = map.get("streaming") {
239 message.streaming = streaming.as_bool().unwrap_or(false);
240 }
241 if let Some(tool_calls) = map.get("tool_calls") {
242 let tool_calls = serde_json::from_value::<Vec<ToolCall>>(tool_calls.clone())
243 .map_err(|e| serde::de::Error::custom(e.to_string()))?;
244 message.tool_calls = Some(tool_calls.into());
245 }
246 if let Some(tool_name) = map.get("tool_name") {
247 message.tool_name = tool_name.as_str().map(|s| s.to_string());
248 }
249 #[cfg(feature = "image")]
250 if let Some(image) = map.get("image") {
251 let image_str = image
252 .as_str()
253 .ok_or_else(|| serde::de::Error::custom("image must be a string"))?;
254 let image = Arc::new(PhotonImage::new_from_base64(image_str));
255 message.image = Some(image);
256 }
257 Ok(message)
258 }
259}
260
/// A tool invocation attached to a [`Message`].
///
/// With the derived serde impls this (de)serializes as `{"function": {...}}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// The function to call together with its parameters and optional id.
    pub function: ToolCallFunction,
}
270
/// The function part of a [`ToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallFunction {
    /// Name of the tool/function to call.
    pub name: String,

    /// Call arguments as an arbitrary JSON value.
    pub parameters: serde_json::Value,

    /// Optional call identifier; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
}
287
288impl TryFrom<AgentValue> for Message {
289 type Error = AgentError;
290
291 fn try_from(value: AgentValue) -> Result<Self, Self::Error> {
292 match value {
293 AgentValue::Message(msg) => Ok((*msg).clone()),
294 AgentValue::String(s) => Ok(Message::user(s.to_string())),
295
296 #[cfg(feature = "image")]
297 AgentValue::Image(img) => {
298 let mut message = Message::user("".to_string());
299 message.image = Some(img.clone());
300 Ok(message)
301 }
302 AgentValue::Object(obj) => {
303 let role = obj
304 .get("role")
305 .and_then(|r| r.as_str())
306 .unwrap_or("user")
307 .to_string();
308 let content = obj
309 .get("content")
310 .and_then(|c| c.as_str())
311 .ok_or_else(|| {
312 AgentError::InvalidValue(
313 "Message object missing 'content' field".to_string(),
314 )
315 })?
316 .to_string();
317 let mut message = Message::new(role, content);
318
319 let id = obj
320 .get("id")
321 .and_then(|i| i.as_str())
322 .map(|s| s.to_string());
323 message.id = id;
324
325 message.thinking = obj
326 .get("thinking")
327 .and_then(|t| t.as_str())
328 .map(|s| s.to_string());
329
330 message.streaming = obj
331 .get("streaming")
332 .and_then(|st| st.as_bool())
333 .unwrap_or_default();
334
335 if let Some(tool_name) = obj.get("tool_name") {
336 message.tool_name = Some(
337 tool_name
338 .as_str()
339 .ok_or_else(|| {
340 AgentError::InvalidValue(
341 "'tool_name' field must be a string".to_string(),
342 )
343 })?
344 .to_string(),
345 );
346 }
347
348 if let Some(tool_calls) = obj.get("tool_calls") {
349 let mut calls = vec![];
350 for call_value in tool_calls.as_array().ok_or_else(|| {
351 AgentError::InvalidValue("'tool_calls' field must be an array".to_string())
352 })? {
353 let id = call_value
354 .get("id")
355 .and_then(|i| i.as_str())
356 .map(|s| s.to_string());
357 let function = call_value.get("function").ok_or_else(|| {
358 AgentError::InvalidValue(
359 "Tool call missing 'function' field".to_string(),
360 )
361 })?;
362 let tool_name = function.get_str("name").ok_or_else(|| {
363 AgentError::InvalidValue(
364 "Tool call function missing 'name' field".to_string(),
365 )
366 })?;
367 let parameters = function.get("parameters").ok_or_else(|| {
368 AgentError::InvalidValue(
369 "Tool call function missing 'parameters' field".to_string(),
370 )
371 })?;
372 let call = ToolCall {
373 function: ToolCallFunction {
374 id,
375 name: tool_name.to_string(),
376 parameters: parameters.to_json(),
377 },
378 };
379 calls.push(call);
380 }
381 message.tool_calls = Some(calls.into());
382 }
383
384 #[cfg(feature = "image")]
385 {
386 if let Some(image_value) = obj.get("image") {
387 match image_value {
388 AgentValue::String(s) => {
389 message.image = Some(Arc::new(PhotonImage::new_from_base64(
390 s.trim_start_matches("data:image/png;base64,"),
391 )));
392 }
393 AgentValue::Image(img) => {
394 message.image = Some(img.clone());
395 }
396 _ => {}
397 }
398 }
399 }
400
401 Ok(message)
402 }
403 _ => Err(AgentError::InvalidValue(
404 "Cannot convert AgentValue to Message".to_string(),
405 )),
406 }
407 }
408}
409
410impl From<Message> for AgentValue {
411 fn from(msg: Message) -> Self {
412 AgentValue::Message(Arc::new(msg))
413 }
414}
415
416impl From<Vec<Message>> for AgentValue {
417 fn from(msgs: Vec<Message>) -> Self {
418 let agent_msgs: Vector<AgentValue> = msgs.into_iter().map(|m| m.into()).collect();
419 AgentValue::Array(agent_msgs)
420 }
421}
422
#[cfg(test)]
mod tests {
    use im::{hashmap, vector};

    use super::*;

    #[test]
    fn test_message_to_from_agent_value() {
        let msg = Message::user("What is the weather today?".to_string());

        let value: AgentValue = msg.into();
        assert!(value.is_message());
        let msg_ref = value.as_message().unwrap();
        assert_eq!(msg_ref.role, "user");
        assert_eq!(msg_ref.content, "What is the weather today?");

        let msg_converted: Message = value.try_into().unwrap();
        assert_eq!(msg_converted.role, "user");
        assert_eq!(msg_converted.content, "What is the weather today?");
    }

    #[test]
    fn test_message_with_tool_calls_to_from_agent_value() {
        let mut msg = Message::assistant("".to_string());
        msg.tool_calls = Some(vector![ToolCall {
            function: ToolCallFunction {
                id: Some("call1".to_string()),
                name: "get_weather".to_string(),
                parameters: serde_json::json!({"location": "San Francisco"}),
            },
        }]);

        let value: AgentValue = msg.into();
        assert!(value.is_message());
        let msg_ref = value.as_message().unwrap();
        assert_eq!(msg_ref.role, "assistant");
        assert_eq!(msg_ref.content, "");
        let tool_calls = msg_ref.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);
        let first_call = &tool_calls[0];
        assert_eq!(first_call.function.name, "get_weather");
        assert_eq!(first_call.function.parameters["location"], "San Francisco");

        let msg_converted: Message = value.try_into().unwrap();
        assert_eq!(msg_converted.role, "assistant");
        assert_eq!(msg_converted.content, "");
        let tool_calls = msg_converted.tool_calls.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].function.name, "get_weather");
        assert_eq!(
            tool_calls[0].function.parameters,
            serde_json::json!({"location": "San Francisco"})
        );
    }

    #[test]
    fn test_tool_message_to_from_agent_value() {
        let msg = Message::tool("get_time".to_string(), "2025-01-02 03:04:05".to_string());

        let value: AgentValue = msg.clone().into();
        let msg_ref = value.as_message().unwrap();
        assert_eq!(msg_ref.role, "tool");
        assert_eq!(msg_ref.tool_name.as_deref().unwrap(), "get_time");
        assert_eq!(msg_ref.content, "2025-01-02 03:04:05");

        let msg_converted: Message = value.try_into().unwrap();
        assert_eq!(msg_converted.role, "tool");
        assert_eq!(msg_converted.tool_name.unwrap(), "get_time");
        assert_eq!(msg_converted.content, "2025-01-02 03:04:05");
    }

    #[test]
    fn test_message_from_string_value() {
        let value = AgentValue::string("Just a simple message");
        let msg: Message = value.try_into().unwrap();
        assert_eq!(msg.role, "user");
        assert_eq!(msg.content, "Just a simple message");
    }

    #[test]
    fn test_message_from_object_value() {
        let value = AgentValue::object(hashmap! {
            "role".into() => AgentValue::string("assistant"),
            "content".into() =>
            AgentValue::string("Here is some information."),
        });
        let msg: Message = value.try_into().unwrap();
        assert_eq!(msg.role, "assistant");
        assert_eq!(msg.content, "Here is some information.");
    }

    #[test]
    fn test_message_from_object_with_tool_calls() {
        let value = AgentValue::object(hashmap! {
            "role".into() => AgentValue::string("assistant"),
            "content".into() => AgentValue::string(""),
            "tool_calls".into() => AgentValue::Array(vector![AgentValue::object(hashmap! {
                "id".into() => AgentValue::string("call1"),
                "function".into() => AgentValue::object(hashmap! {
                    "name".into() => AgentValue::string("get_weather"),
                    "parameters".into() => AgentValue::object(hashmap! {
                        "location".into() => AgentValue::string("San Francisco"),
                    }),
                }),
            })]),
        });
        let msg: Message = value.try_into().unwrap();
        let tool_calls = msg.tool_calls.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].function.id.as_deref(), Some("call1"));
        assert_eq!(tool_calls[0].function.name, "get_weather");
        assert_eq!(tool_calls[0].function.parameters["location"], "San Francisco");
    }

    #[test]
    fn test_message_from_invalid_value() {
        let value = AgentValue::integer(42);
        let result: Result<Message, AgentError> = value.try_into();
        assert!(result.is_err());
    }

    #[test]
    fn test_message_invalid_object() {
        let value =
            AgentValue::object(hashmap! {"some_key".into() => AgentValue::string("some_value")});
        let result: Result<Message, AgentError> = value.try_into();
        assert!(result.is_err());
    }

    #[test]
    fn test_message_to_agent_value_with_tool_calls() {
        let message = Message {
            role: "assistant".to_string(),
            content: "".to_string(),
            tokens: None,
            thinking: None,
            streaming: false,
            tool_calls: Some(vector![ToolCall {
                function: ToolCallFunction {
                    id: Some("call1".to_string()),
                    name: "active_applications".to_string(),
                    parameters: serde_json::json!({}),
                },
            }]),
            id: None,
            tool_name: None,
            #[cfg(feature = "image")]
            image: None,
        };

        let value: AgentValue = message.into();
        let msg_ref = value.as_message().unwrap();

        assert_eq!(msg_ref.role, "assistant");
        assert_eq!(msg_ref.content, "");

        let tool_calls = msg_ref.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);

        assert_eq!(tool_calls[0].function.name, "active_applications");
        assert!(
            tool_calls[0]
                .function
                .parameters
                .as_object()
                .unwrap()
                .is_empty()
        );
    }

    #[test]
    fn test_message_partial_eq() {
        let msg1 = Message::user("hello".to_string());
        let msg2 = Message::user("hello".to_string());
        let msg3 = Message::user("world".to_string());

        assert_eq!(msg1, msg2);
        assert_ne!(msg1, msg3);

        let mut msg4 = Message::user("hello".to_string());
        msg4.id = Some("123".to_string());
        assert_ne!(msg1, msg4);
    }

    #[test]
    fn test_message_serde_round_trip() {
        let mut msg = Message::assistant("Checking the weather.".to_string());
        msg.id = Some("msg-1".to_string());
        msg.tokens = Some(42);
        msg.thinking = Some("Need a forecast.".to_string());
        msg.streaming = true;
        msg.tool_name = Some("weather".to_string());
        msg.tool_calls = Some(vector![ToolCall {
            function: ToolCallFunction {
                id: Some("call1".to_string()),
                name: "get_weather".to_string(),
                parameters: serde_json::json!({"location": "San Francisco"}),
            },
        }]);

        let json = serde_json::to_string(&msg).unwrap();
        let decoded: Message = serde_json::from_str(&json).unwrap();

        assert_eq!(decoded, msg);
        assert_eq!(decoded.tokens, Some(42));
        assert_eq!(decoded.thinking.as_deref(), Some("Need a forecast."));
        assert!(decoded.streaming);
        assert_eq!(decoded.tool_name.as_deref(), Some("weather"));
        let calls = decoded.tool_calls.unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.id.as_deref(), Some("call1"));
        assert_eq!(calls[0].function.name, "get_weather");
        assert_eq!(
            calls[0].function.parameters,
            serde_json::json!({"location": "San Francisco"})
        );
    }

    #[test]
    fn test_message_serialize_omits_unset_fields() {
        let value = serde_json::to_value(Message::user("hi".to_string())).unwrap();
        assert_eq!(value, serde_json::json!({"role": "user", "content": "hi"}));
    }

    #[test]
    fn test_message_deserialize_defaults() {
        let msg: Message = serde_json::from_str("{}").unwrap();
        assert_eq!(msg.role, "user");
        assert_eq!(msg.content, "");
        assert!(msg.id.is_none());
        assert!(msg.tokens.is_none());
        assert!(!msg.streaming);
        assert!(msg.tool_calls.is_none());
    }
}