agent_stream_kit/
llm.rs

1use std::{ops::Not, sync::Arc, vec};
2
3use im::Vector;
4use serde::{Deserialize, Serialize};
5
6use crate::error::AgentError;
7use crate::value::AgentValue;
8
9#[cfg(feature = "image")]
10use photon_rs::PhotonImage;
11
12#[derive(Debug, Default, Clone, Serialize, Deserialize)]
13pub struct Message {
14    #[serde(skip_serializing_if = "Option::is_none")]
15    pub id: Option<String>,
16
17    pub role: String,
18
19    pub content: String,
20
21    #[serde(skip_serializing_if = "Option::is_none")]
22    pub tokens: Option<usize>,
23
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub thinking: Option<String>,
26
27    #[serde(skip_serializing_if = "<&bool>::not")]
28    pub streaming: bool,
29
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub tool_calls: Option<Vector<ToolCall>>,
32
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub tool_name: Option<String>,
35
36    #[cfg(feature = "image")]
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub image: Option<Arc<PhotonImage>>,
39}
40
41impl Message {
42    pub fn new(role: String, content: String) -> Self {
43        Self {
44            id: None,
45            role,
46            content,
47            tokens: None,
48            streaming: false,
49            thinking: None,
50            tool_calls: None,
51            tool_name: None,
52
53            #[cfg(feature = "image")]
54            image: None,
55        }
56    }
57
58    pub fn assistant(content: String) -> Self {
59        Message::new("assistant".to_string(), content)
60    }
61
62    pub fn system(content: String) -> Self {
63        Message::new("system".to_string(), content)
64    }
65
66    pub fn user(content: String) -> Self {
67        Message::new("user".to_string(), content)
68    }
69
70    pub fn tool(tool_name: String, content: String) -> Self {
71        let mut message = Message::new("tool".to_string(), content);
72        message.tool_name = Some(tool_name);
73        message
74    }
75
76    #[cfg(feature = "image")]
77    pub fn with_image(mut self, image: Arc<PhotonImage>) -> Self {
78        self.image = Some(image);
79        self
80    }
81}
82
83impl PartialEq for Message {
84    fn eq(&self, other: &Self) -> bool {
85        self.id == other.id && self.role == other.role && self.content == other.content
86    }
87}
88
/// One entry of [`Message::tool_calls`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// The called function: its name, parameters and optional id.
    pub function: ToolCallFunction,
}
93
/// The function part of a [`ToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallFunction {
    /// Name of the function.
    pub name: String,
    /// Parameters of the call, stored as an arbitrary JSON value.
    pub parameters: serde_json::Value,

    /// Optional identifier of the call; left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
}
102
103impl TryFrom<AgentValue> for Message {
104    type Error = AgentError;
105
106    fn try_from(value: AgentValue) -> Result<Self, Self::Error> {
107        match value {
108            AgentValue::Message(msg) => Ok((*msg).clone()),
109            AgentValue::String(s) => Ok(Message::user(s.to_string())),
110
111            #[cfg(feature = "image")]
112            AgentValue::Image(img) => {
113                let mut message = Message::user("".to_string());
114                message.image = Some(img.clone());
115                Ok(message)
116            }
117            AgentValue::Object(obj) => {
118                let role = obj
119                    .get("role")
120                    .and_then(|r| r.as_str())
121                    .unwrap_or("user")
122                    .to_string();
123                let content = obj
124                    .get("content")
125                    .and_then(|c| c.as_str())
126                    .ok_or_else(|| {
127                        AgentError::InvalidValue(
128                            "Message object missing 'content' field".to_string(),
129                        )
130                    })?
131                    .to_string();
132                let mut message = Message::new(role, content);
133
134                let id = obj
135                    .get("id")
136                    .and_then(|i| i.as_str())
137                    .map(|s| s.to_string());
138                message.id = id;
139
140                message.thinking = obj
141                    .get("thinking")
142                    .and_then(|t| t.as_str())
143                    .map(|s| s.to_string());
144
145                message.streaming = obj
146                    .get("streaming")
147                    .and_then(|st| st.as_bool())
148                    .unwrap_or_default();
149
150                if let Some(tool_name) = obj.get("tool_name") {
151                    message.tool_name = Some(
152                        tool_name
153                            .as_str()
154                            .ok_or_else(|| {
155                                AgentError::InvalidValue(
156                                    "'tool_name' field must be a string".to_string(),
157                                )
158                            })?
159                            .to_string(),
160                    );
161                }
162
163                if let Some(tool_calls) = obj.get("tool_calls") {
164                    let mut calls = vec![];
165                    for call_value in tool_calls.as_array().ok_or_else(|| {
166                        AgentError::InvalidValue("'tool_calls' field must be an array".to_string())
167                    })? {
168                        let id = call_value
169                            .get("id")
170                            .and_then(|i| i.as_str())
171                            .map(|s| s.to_string());
172                        let function = call_value.get("function").ok_or_else(|| {
173                            AgentError::InvalidValue(
174                                "Tool call missing 'function' field".to_string(),
175                            )
176                        })?;
177                        let tool_name = function.get_str("name").ok_or_else(|| {
178                            AgentError::InvalidValue(
179                                "Tool call function missing 'name' field".to_string(),
180                            )
181                        })?;
182                        let parameters = function.get("parameters").ok_or_else(|| {
183                            AgentError::InvalidValue(
184                                "Tool call function missing 'parameters' field".to_string(),
185                            )
186                        })?;
187                        let call = ToolCall {
188                            function: ToolCallFunction {
189                                id,
190                                name: tool_name.to_string(),
191                                parameters: parameters.to_json(),
192                            },
193                        };
194                        calls.push(call);
195                    }
196                    message.tool_calls = Some(calls.into());
197                }
198
199                #[cfg(feature = "image")]
200                {
201                    if let Some(image_value) = obj.get("image") {
202                        match image_value {
203                            AgentValue::String(s) => {
204                                message.image = Some(Arc::new(PhotonImage::new_from_base64(
205                                    s.trim_start_matches("data:image/png;base64,"),
206                                )));
207                            }
208                            AgentValue::Image(img) => {
209                                message.image = Some(img.clone());
210                            }
211                            _ => {}
212                        }
213                    }
214                }
215
216                Ok(message)
217            }
218            _ => Err(AgentError::InvalidValue(
219                "Cannot convert AgentValue to Message".to_string(),
220            )),
221        }
222    }
223}
224
225impl From<Message> for AgentValue {
226    fn from(msg: Message) -> Self {
227        AgentValue::Message(Arc::new(msg))
228    }
229}
230
231impl From<Vec<Message>> for AgentValue {
232    fn from(msgs: Vec<Message>) -> Self {
233        let agent_msgs: Vector<AgentValue> = msgs.into_iter().map(|m| m.into()).collect();
234        AgentValue::Array(agent_msgs)
235    }
236}
237
#[cfg(test)]
mod tests {
    use im::{hashmap, vector};

    use super::*;

    // Message tests

    /// A user message survives a trip through `AgentValue` and back.
    #[test]
    fn test_message_to_from_agent_value() {
        let msg = Message::user("What is the weather today?".to_string());

        let value: AgentValue = msg.into();
        assert!(value.is_message());
        let msg_ref = value.as_message().unwrap();
        assert_eq!(msg_ref.role, "user");
        assert_eq!(msg_ref.content, "What is the weather today?");

        let msg_converted: Message = value.try_into().unwrap();
        assert_eq!(msg_converted.role, "user");
        assert_eq!(msg_converted.content, "What is the weather today?");
    }

    /// Tool calls on an assistant message are kept through the round trip.
    #[test]
    fn test_message_with_tool_calls_to_from_agent_value() {
        let mut msg = Message::assistant("".to_string());
        msg.tool_calls = Some(vector![ToolCall {
            function: ToolCallFunction {
                id: Some("call1".to_string()),
                name: "get_weather".to_string(),
                parameters: serde_json::json!({"location": "San Francisco"}),
            },
        }]);

        let value: AgentValue = msg.into();
        assert!(value.is_message());
        let msg_ref = value.as_message().unwrap();
        assert_eq!(msg_ref.role, "assistant");
        assert_eq!(msg_ref.content, "");
        let tool_calls = msg_ref.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);
        let first_call = &tool_calls[0];
        assert_eq!(first_call.function.name, "get_weather");
        assert_eq!(first_call.function.parameters["location"], "San Francisco");

        let msg_converted: Message = value.try_into().unwrap();
        assert_eq!(msg_converted.role, "assistant");
        assert_eq!(msg_converted.content, "");
        let tool_calls = msg_converted.tool_calls.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].function.name, "get_weather");
        assert_eq!(
            tool_calls[0].function.parameters,
            serde_json::json!({"location": "San Francisco"})
        );
    }

    /// A tool message keeps its role, tool name and content through the round trip.
    #[test]
    fn test_tool_message_to_from_agent_value() {
        let msg = Message::tool("get_time".to_string(), "2025-01-02 03:04:05".to_string());

        let value: AgentValue = msg.into();
        let msg_ref = value.as_message().unwrap();
        assert_eq!(msg_ref.role, "tool");
        assert_eq!(msg_ref.tool_name.as_deref().unwrap(), "get_time");
        assert_eq!(msg_ref.content, "2025-01-02 03:04:05");

        let msg_converted: Message = value.try_into().unwrap();
        assert_eq!(msg_converted.role, "tool");
        assert_eq!(msg_converted.tool_name.unwrap(), "get_time");
        assert_eq!(msg_converted.content, "2025-01-02 03:04:05");
    }

    /// A plain string value becomes a user message with that text.
    #[test]
    fn test_message_from_string_value() {
        let value = AgentValue::string("Just a simple message");
        let msg: Message = value.try_into().unwrap();
        assert_eq!(msg.role, "user");
        assert_eq!(msg.content, "Just a simple message");
    }

    /// An object with `role` and `content` is read field by field.
    #[test]
    fn test_message_from_object_value() {
        let value = AgentValue::object(hashmap! {
            "role".into() => AgentValue::string("assistant"),
                "content".into() =>
                AgentValue::string("Here is some information."),
        });
        let msg: Message = value.try_into().unwrap();
        assert_eq!(msg.role, "assistant");
        assert_eq!(msg.content, "Here is some information.");
    }

    /// Values of an unsupported kind (here an integer) are rejected.
    #[test]
    fn test_message_from_invalid_value() {
        let value = AgentValue::integer(42);
        let result: Result<Message, AgentError> = value.try_into();
        assert!(result.is_err());
    }

    /// An object without a `content` field is rejected.
    #[test]
    fn test_message_invalid_object() {
        let value =
            AgentValue::object(hashmap! {"some_key".into() => AgentValue::string("some_value")});
        let result: Result<Message, AgentError> = value.try_into();
        assert!(result.is_err());
    }

    /// Converting into `AgentValue` keeps tool calls, including empty parameters.
    #[test]
    fn test_message_to_agent_value_with_tool_calls() {
        let message = Message {
            role: "assistant".to_string(),
            content: "".to_string(),
            tokens: None,
            thinking: None,
            streaming: false,
            tool_calls: Some(vector![ToolCall {
                function: ToolCallFunction {
                    id: Some("call1".to_string()),
                    name: "active_applications".to_string(),
                    parameters: serde_json::json!({}),
                },
            }]),
            id: None,
            tool_name: None,
            #[cfg(feature = "image")]
            image: None,
        };

        let value: AgentValue = message.into();
        let msg_ref = value.as_message().unwrap();

        assert_eq!(msg_ref.role, "assistant");
        assert_eq!(msg_ref.content, "");

        let tool_calls = msg_ref.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);

        assert_eq!(tool_calls[0].function.name, "active_applications");
        assert!(
            tool_calls[0]
                .function
                .parameters
                .as_object()
                .unwrap()
                .is_empty()
        );
    }

    /// Equality depends on `id`, `role` and `content`.
    #[test]
    fn test_message_partial_eq() {
        let msg1 = Message::user("hello".to_string());
        let msg2 = Message::user("hello".to_string());
        let msg3 = Message::user("world".to_string());

        assert_eq!(msg1, msg2);
        assert_ne!(msg1, msg3);

        let mut msg4 = Message::user("hello".to_string());
        msg4.id = Some("123".to_string());
        assert_ne!(msg1, msg4);
    }
}