Skip to main content

erio_core/
message.rs

1//! Message types for LLM conversations.
2
3use serde::{Deserialize, Serialize};
4
5/// Role of a message participant.
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
7#[serde(rename_all = "lowercase")]
8pub enum Role {
9    /// User/human message.
10    User,
11    /// Assistant/AI response.
12    Assistant,
13    /// System instructions.
14    System,
15    /// Tool execution result.
16    Tool,
17}
18
19/// A tool call request from the assistant.
20#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
21pub struct ToolCall {
22    /// Unique identifier for this tool call.
23    pub id: String,
24    /// Name of the tool to invoke.
25    pub name: String,
26    /// Arguments to pass to the tool.
27    pub arguments: serde_json::Value,
28}
29
30/// Content within a message.
31#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
32#[serde(tag = "type", rename_all = "snake_case")]
33pub enum Content {
34    /// Plain text content.
35    Text {
36        /// The text value.
37        text: String,
38    },
39    /// A tool call request.
40    ToolCall(ToolCall),
41}
42
43impl Content {
44    /// Returns the text content if this is a text variant.
45    pub fn as_text(&self) -> Option<&str> {
46        match self {
47            Self::Text { text } => Some(text),
48            Self::ToolCall(_) => None,
49        }
50    }
51}
52
53/// A message in a conversation.
54#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
55pub struct Message {
56    /// Role of the message sender.
57    pub role: Role,
58    /// Content blocks in this message.
59    pub content: Vec<Content>,
60    /// ID of the tool call this message responds to (for Tool role).
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub tool_call_id: Option<String>,
63}
64
65impl Message {
66    /// Creates a user message with text content.
67    pub fn user(text: impl Into<String>) -> Self {
68        Self {
69            role: Role::User,
70            content: vec![Content::Text { text: text.into() }],
71            tool_call_id: None,
72        }
73    }
74
75    /// Creates an assistant message with text content.
76    pub fn assistant(text: impl Into<String>) -> Self {
77        Self {
78            role: Role::Assistant,
79            content: vec![Content::Text { text: text.into() }],
80            tool_call_id: None,
81        }
82    }
83
84    /// Creates a system message with text content.
85    pub fn system(text: impl Into<String>) -> Self {
86        Self {
87            role: Role::System,
88            content: vec![Content::Text { text: text.into() }],
89            tool_call_id: None,
90        }
91    }
92
93    /// Creates a tool result message.
94    pub fn tool_result(call_id: impl Into<String>, result: impl Into<String>) -> Self {
95        Self {
96            role: Role::Tool,
97            content: vec![Content::Text {
98                text: result.into(),
99            }],
100            tool_call_id: Some(call_id.into()),
101        }
102    }
103
104    /// Returns the first text content in this message.
105    pub fn text(&self) -> Option<&str> {
106        self.content.iter().find_map(Content::as_text)
107    }
108
109    /// Returns the tool call ID if this is a tool result message.
110    pub fn tool_call_id(&self) -> Option<&str> {
111        self.tool_call_id.as_deref()
112    }
113
114    /// Returns an iterator over tool calls in this message.
115    pub fn tool_calls(&self) -> impl Iterator<Item = &ToolCall> {
116        self.content.iter().filter_map(|c| match c {
117            Content::ToolCall(tc) => Some(tc),
118            Content::Text { .. } => None,
119        })
120    }
121
122    /// Returns true if the message has no content.
123    pub fn is_empty(&self) -> bool {
124        self.content.is_empty()
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    // === Role Tests ===

    #[test]
    fn role_serializes_to_lowercase() {
        assert_eq!(serde_json::to_string(&Role::User).unwrap(), "\"user\"");
        assert_eq!(
            serde_json::to_string(&Role::Assistant).unwrap(),
            "\"assistant\""
        );
        assert_eq!(serde_json::to_string(&Role::System).unwrap(), "\"system\"");
        assert_eq!(serde_json::to_string(&Role::Tool).unwrap(), "\"tool\"");
    }

    #[test]
    fn role_deserializes_from_lowercase() {
        assert_eq!(
            serde_json::from_str::<Role>("\"user\"").unwrap(),
            Role::User
        );
        assert_eq!(
            serde_json::from_str::<Role>("\"assistant\"").unwrap(),
            Role::Assistant
        );
        assert_eq!(
            serde_json::from_str::<Role>("\"system\"").unwrap(),
            Role::System
        );
        assert_eq!(
            serde_json::from_str::<Role>("\"tool\"").unwrap(),
            Role::Tool
        );
    }

    // === Message Construction Tests ===

    #[test]
    fn message_user_creates_user_message() {
        let msg = Message::user("Hello");

        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.text(), Some("Hello"));
        assert_eq!(msg.tool_call_id(), None);
    }

    #[test]
    fn message_assistant_creates_assistant_message() {
        let msg = Message::assistant("Hi there");

        assert_eq!(msg.role, Role::Assistant);
        assert_eq!(msg.text(), Some("Hi there"));
        assert_eq!(msg.tool_call_id(), None);
    }

    #[test]
    fn message_system_creates_system_message() {
        let msg = Message::system("You are helpful");

        assert_eq!(msg.role, Role::System);
        assert_eq!(msg.text(), Some("You are helpful"));
        assert_eq!(msg.tool_call_id(), None);
    }

    #[test]
    fn message_tool_result_creates_tool_message() {
        let msg = Message::tool_result("call_123", "result data");

        assert_eq!(msg.role, Role::Tool);
        assert_eq!(msg.tool_call_id(), Some("call_123"));
        assert_eq!(msg.text(), Some("result data"));
    }

    // === Content Tests ===

    #[test]
    fn content_text_returns_text() {
        let content = Content::Text {
            text: "hello".into(),
        };
        assert_eq!(content.as_text(), Some("hello"));
    }

    #[test]
    fn content_tool_call_returns_none_for_text() {
        let content = Content::ToolCall(ToolCall {
            id: "id".into(),
            name: "shell".into(),
            arguments: serde_json::json!({}),
        });
        assert_eq!(content.as_text(), None);
    }

    #[test]
    fn content_text_serializes_with_type_tag() {
        let content = Content::Text {
            text: "hello".into(),
        };

        assert_eq!(
            serde_json::to_value(&content).unwrap(),
            serde_json::json!({"type": "text", "text": "hello"})
        );
    }

    #[test]
    fn content_tool_call_serializes_flat_with_type_tag() {
        let content = Content::ToolCall(ToolCall {
            id: "call_1".into(),
            name: "shell".into(),
            arguments: serde_json::json!({"command": "pwd"}),
        });

        assert_eq!(
            serde_json::to_value(&content).unwrap(),
            serde_json::json!({
                "type": "tool_call",
                "id": "call_1",
                "name": "shell",
                "arguments": {"command": "pwd"}
            })
        );
    }

    // === ToolCall Tests ===

    #[test]
    fn tool_call_serializes_correctly() {
        let call = ToolCall {
            id: "call_abc123".into(),
            name: "read_file".into(),
            arguments: serde_json::json!({"path": "/tmp/test.txt"}),
        };

        let json = serde_json::to_value(&call).unwrap();

        assert_eq!(json["id"], "call_abc123");
        assert_eq!(json["name"], "read_file");
        assert_eq!(json["arguments"]["path"], "/tmp/test.txt");
    }

    #[test]
    fn tool_call_deserializes_correctly() {
        let json = serde_json::json!({
            "id": "call_xyz",
            "name": "shell",
            "arguments": {"command": "ls -la"}
        });

        let call: ToolCall = serde_json::from_value(json).unwrap();

        assert_eq!(call.id, "call_xyz");
        assert_eq!(call.name, "shell");
        assert_eq!(call.arguments["command"], "ls -la");
    }

    // === Message Serde Tests ===

    #[test]
    fn message_text_serde_roundtrip() {
        let original = Message::user("Test message");

        let json = serde_json::to_string(&original).unwrap();
        let deserialized: Message = serde_json::from_str(&json).unwrap();

        // `Message: PartialEq`, so compare the whole value, not just a few fields.
        assert_eq!(deserialized, original);
    }

    #[test]
    fn message_with_tool_calls_serde_roundtrip() {
        let original = Message {
            role: Role::Assistant,
            content: vec![
                Content::Text {
                    text: "I'll help you with that.".into(),
                },
                Content::ToolCall(ToolCall {
                    id: "call_1".into(),
                    name: "shell".into(),
                    arguments: serde_json::json!({"command": "pwd"}),
                }),
            ],
            tool_call_id: None,
        };

        let json = serde_json::to_string(&original).unwrap();
        let deserialized: Message = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized, original);
    }

    #[test]
    fn tool_result_serde_roundtrip_keeps_tool_call_id() {
        let original = Message::tool_result("call_9", "done");

        let json = serde_json::to_value(&original).unwrap();
        assert_eq!(json["tool_call_id"], "call_9");

        let deserialized: Message = serde_json::from_value(json).unwrap();
        assert_eq!(deserialized, original);
    }

    #[test]
    fn message_omits_tool_call_id_when_none() {
        let json = serde_json::to_value(Message::user("hi")).unwrap();

        assert!(json.get("tool_call_id").is_none());
    }

    #[test]
    fn message_deserializes_without_tool_call_id_field() {
        let json = serde_json::json!({
            "role": "assistant",
            "content": [{"type": "text", "text": "hi"}]
        });

        let msg: Message = serde_json::from_value(json).unwrap();

        assert_eq!(msg, Message::assistant("hi"));
    }

    // === Message Accessor Tests ===

    #[test]
    fn message_text_returns_first_text_content() {
        let msg = Message {
            role: Role::Assistant,
            content: vec![
                Content::Text {
                    text: "First".into(),
                },
                Content::Text {
                    text: "Second".into(),
                },
            ],
            tool_call_id: None,
        };

        assert_eq!(msg.text(), Some("First"));
    }

    #[test]
    fn message_text_skips_leading_tool_call() {
        let msg = Message {
            role: Role::Assistant,
            content: vec![
                Content::ToolCall(ToolCall {
                    id: "id".into(),
                    name: "test".into(),
                    arguments: serde_json::json!({}),
                }),
                Content::Text {
                    text: "after".into(),
                },
            ],
            tool_call_id: None,
        };

        assert_eq!(msg.text(), Some("after"));
    }

    #[test]
    fn message_text_returns_none_when_no_text() {
        let msg = Message {
            role: Role::Assistant,
            content: vec![Content::ToolCall(ToolCall {
                id: "id".into(),
                name: "test".into(),
                arguments: serde_json::json!({}),
            })],
            tool_call_id: None,
        };

        assert_eq!(msg.text(), None);
    }

    #[test]
    fn message_tool_calls_returns_all_tool_calls() {
        let msg = Message {
            role: Role::Assistant,
            content: vec![
                Content::Text {
                    text: "Let me help".into(),
                },
                Content::ToolCall(ToolCall {
                    id: "call_1".into(),
                    name: "shell".into(),
                    arguments: serde_json::json!({}),
                }),
                Content::ToolCall(ToolCall {
                    id: "call_2".into(),
                    name: "read_file".into(),
                    arguments: serde_json::json!({}),
                }),
            ],
            tool_call_id: None,
        };

        let calls: Vec<_> = msg.tool_calls().collect();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "shell");
        assert_eq!(calls[1].name, "read_file");
    }

    #[test]
    fn message_tool_calls_empty_for_text_only_message() {
        let msg = Message::user("no tools here");

        assert_eq!(msg.tool_calls().count(), 0);
    }

    #[test]
    fn message_is_empty_when_no_content() {
        let msg = Message {
            role: Role::User,
            content: vec![],
            tool_call_id: None,
        };

        assert!(msg.is_empty());
    }

    #[test]
    fn message_is_not_empty_with_content() {
        let msg = Message::user("hello");
        assert!(!msg.is_empty());
    }
}