Skip to main content

imp_llm/
message.rs

1use serde::{Deserialize, Serialize};
2
3/// A message in the conversation, tagged by role.
4#[derive(Debug, Clone, Serialize, Deserialize)]
5#[serde(tag = "role")]
6pub enum Message {
7    /// Content from the human user.
8    #[serde(rename = "user")]
9    User(UserMessage),
10    /// Content from the LLM assistant.
11    #[serde(rename = "assistant")]
12    Assistant(AssistantMessage),
13    /// Result of a tool execution returned to the model.
14    #[serde(rename = "tool_result")]
15    ToolResult(ToolResultMessage),
16}
17
18/// A message sent by the user.
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct UserMessage {
21    /// One or more content blocks (text, images, etc.).
22    pub content: Vec<ContentBlock>,
23    /// Unix timestamp in seconds when the message was created.
24    pub timestamp: u64,
25}
26
27/// A response from the assistant.
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct AssistantMessage {
30    /// Content blocks produced by the model.
31    pub content: Vec<ContentBlock>,
32    /// Token usage for this response, if reported by the provider.
33    pub usage: Option<crate::usage::Usage>,
34    /// Why the model stopped generating.
35    pub stop_reason: StopReason,
36    /// Unix timestamp in seconds.
37    pub timestamp: u64,
38}
39
40/// The result of executing a tool, sent back to the model.
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct ToolResultMessage {
43    /// Provider-assigned call id that pairs this result with its tool call.
44    pub tool_call_id: String,
45    /// Name of the tool that was executed.
46    pub tool_name: String,
47    /// Output content blocks.
48    pub content: Vec<ContentBlock>,
49    /// Whether the tool execution failed.
50    pub is_error: bool,
51    /// Arbitrary metadata about the execution.
52    #[serde(default)]
53    pub details: serde_json::Value,
54    /// Unix timestamp in seconds.
55    pub timestamp: u64,
56}
57
58/// A single block of content within a message.
59#[derive(Debug, Clone, Serialize, Deserialize)]
60#[serde(tag = "type")]
61pub enum ContentBlock {
62    /// Plain text content.
63    #[serde(rename = "text")]
64    Text { text: String },
65    /// Extended thinking / chain-of-thought output.
66    #[serde(rename = "thinking")]
67    Thinking { text: String },
68    /// A request from the model to call a tool.
69    #[serde(rename = "tool_call")]
70    ToolCall {
71        id: String,
72        name: String,
73        arguments: serde_json::Value,
74    },
75    /// Base64-encoded image data.
76    #[serde(rename = "image")]
77    Image { media_type: String, data: String },
78}
79
80/// Reason the model stopped generating tokens.
81#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
82pub enum StopReason {
83    /// Natural end of response.
84    EndTurn,
85    /// Model wants to call one or more tools.
86    ToolUse,
87    /// Hit the max_tokens limit.
88    MaxTokens,
89    /// An error occurred during generation.
90    Error(String),
91}
92
93impl Message {
94    /// Convenience constructor for a simple text user message.
95    pub fn user(text: impl Into<String>) -> Self {
96        Message::User(UserMessage {
97            content: vec![ContentBlock::Text { text: text.into() }],
98            timestamp: crate::now(),
99        })
100    }
101
102    /// True if this is a user message.
103    pub fn is_user(&self) -> bool {
104        matches!(self, Message::User(_))
105    }
106
107    /// True if this is an assistant message.
108    pub fn is_assistant(&self) -> bool {
109        matches!(self, Message::Assistant(_))
110    }
111
112    /// True if this is a tool result.
113    pub fn is_tool_result(&self) -> bool {
114        matches!(self, Message::ToolResult(_))
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn message_user_round_trip() {
        let msg = Message::User(UserMessage {
            content: vec![ContentBlock::Text {
                text: "Hello".into(),
            }],
            timestamp: 1700000000,
        });
        let json = serde_json::to_string(&msg).unwrap();
        let restored: Message = serde_json::from_str(&json).unwrap();
        assert!(restored.is_user());
        if let Message::User(u) = &restored {
            assert_eq!(u.timestamp, 1700000000);
            assert_eq!(u.content.len(), 1);
        } else {
            panic!("expected User variant");
        }
    }

    #[test]
    fn message_assistant_round_trip() {
        let msg = Message::Assistant(AssistantMessage {
            content: vec![
                ContentBlock::Text {
                    text: "Sure!".into(),
                },
                ContentBlock::Thinking {
                    text: "Let me think...".into(),
                },
            ],
            usage: Some(crate::usage::Usage {
                input_tokens: 100,
                output_tokens: 50,
                cache_read_tokens: 0,
                cache_write_tokens: 0,
            }),
            stop_reason: StopReason::EndTurn,
            timestamp: 1700000001,
        });
        let json = serde_json::to_string(&msg).unwrap();
        let restored: Message = serde_json::from_str(&json).unwrap();
        assert!(restored.is_assistant());
        if let Message::Assistant(a) = &restored {
            assert_eq!(a.content.len(), 2);
            assert_eq!(a.stop_reason, StopReason::EndTurn);
            assert_eq!(a.usage.as_ref().unwrap().input_tokens, 100);
        } else {
            panic!("expected Assistant variant");
        }
    }

    #[test]
    fn message_tool_result_round_trip() {
        let msg = Message::ToolResult(ToolResultMessage {
            tool_call_id: "call_123".into(),
            tool_name: "read_file".into(),
            content: vec![ContentBlock::Text {
                text: "file contents".into(),
            }],
            is_error: false,
            details: serde_json::json!({"path": "/tmp/test"}),
            timestamp: 1700000002,
        });
        let json = serde_json::to_string(&msg).unwrap();
        let restored: Message = serde_json::from_str(&json).unwrap();
        assert!(restored.is_tool_result());
        if let Message::ToolResult(t) = &restored {
            assert_eq!(t.tool_call_id, "call_123");
            assert_eq!(t.tool_name, "read_file");
            assert!(!t.is_error);
            assert_eq!(t.details["path"], "/tmp/test");
        } else {
            panic!("expected ToolResult variant");
        }
    }

    #[test]
    fn tool_call_content_block_round_trip() {
        let block = ContentBlock::ToolCall {
            id: "tc_1".into(),
            name: "bash".into(),
            arguments: serde_json::json!({"command": "ls"}),
        };
        let json = serde_json::to_string(&block).unwrap();
        let restored: ContentBlock = serde_json::from_str(&json).unwrap();
        if let ContentBlock::ToolCall {
            id,
            name,
            arguments,
        } = restored
        {
            assert_eq!(id, "tc_1");
            assert_eq!(name, "bash");
            assert_eq!(arguments["command"], "ls");
        } else {
            panic!("expected ToolCall variant");
        }
    }

    #[test]
    fn image_content_block_round_trip() {
        let block = ContentBlock::Image {
            media_type: "image/png".into(),
            data: "iVBORw0KGgo=".into(),
        };
        let json = serde_json::to_string(&block).unwrap();
        let restored: ContentBlock = serde_json::from_str(&json).unwrap();
        if let ContentBlock::Image { media_type, data } = restored {
            assert_eq!(media_type, "image/png");
            assert_eq!(data, "iVBORw0KGgo=");
        } else {
            panic!("expected Image variant");
        }
    }

    #[test]
    fn empty_content_assistant_message_round_trip() {
        let msg = Message::Assistant(AssistantMessage {
            content: vec![],
            usage: None,
            stop_reason: StopReason::EndTurn,
            timestamp: 1700000000,
        });
        let json = serde_json::to_string(&msg).unwrap();
        let restored: Message = serde_json::from_str(&json).unwrap();
        if let Message::Assistant(a) = restored {
            assert!(a.content.is_empty());
            assert!(a.usage.is_none());
            assert_eq!(a.stop_reason, StopReason::EndTurn);
        } else {
            panic!("expected Assistant variant");
        }
    }

    #[test]
    fn tool_result_with_is_error_round_trip() {
        let msg = Message::ToolResult(ToolResultMessage {
            tool_call_id: "call_err".into(),
            tool_name: "bash".into(),
            content: vec![ContentBlock::Text {
                text: "command not found".into(),
            }],
            is_error: true,
            details: serde_json::Value::Null,
            timestamp: 1700000003,
        });
        let json = serde_json::to_string(&msg).unwrap();
        let restored: Message = serde_json::from_str(&json).unwrap();
        if let Message::ToolResult(tr) = restored {
            assert!(tr.is_error);
            assert_eq!(tr.tool_call_id, "call_err");
        } else {
            panic!("expected ToolResult variant");
        }
    }

    #[test]
    fn message_user_helper() {
        let msg = Message::user("test prompt");
        assert!(msg.is_user());
        assert!(!msg.is_assistant());
        assert!(!msg.is_tool_result());
        // Every branch must assert: a silent fall-through would let a
        // regression in `Message::user` pass unnoticed.
        if let Message::User(u) = msg {
            assert_eq!(u.content.len(), 1);
            if let ContentBlock::Text { text } = &u.content[0] {
                assert_eq!(text, "test prompt");
            } else {
                panic!("expected Text block");
            }
        } else {
            panic!("expected User variant");
        }
    }

    #[test]
    fn content_block_variant_discrimination() {
        // All four variants should deserialize to the correct type
        let text_json = r#"{"type":"text","text":"hello"}"#;
        let thinking_json = r#"{"type":"thinking","text":"hmm"}"#;
        let tool_json = r#"{"type":"tool_call","id":"t1","name":"bash","arguments":{}}"#;
        let image_json = r#"{"type":"image","media_type":"image/jpeg","data":"abc"}"#;

        let text: ContentBlock = serde_json::from_str(text_json).unwrap();
        assert!(matches!(text, ContentBlock::Text { .. }));

        let thinking: ContentBlock = serde_json::from_str(thinking_json).unwrap();
        assert!(matches!(thinking, ContentBlock::Thinking { .. }));

        let tool: ContentBlock = serde_json::from_str(tool_json).unwrap();
        assert!(matches!(tool, ContentBlock::ToolCall { .. }));

        let image: ContentBlock = serde_json::from_str(image_json).unwrap();
        assert!(matches!(image, ContentBlock::Image { .. }));
    }

    #[test]
    fn stop_reason_round_trip() {
        // A fixed array is enough here; no heap allocation needed.
        let reasons = [
            StopReason::EndTurn,
            StopReason::ToolUse,
            StopReason::MaxTokens,
            StopReason::Error("rate_limit".into()),
        ];
        for reason in &reasons {
            let json = serde_json::to_string(reason).unwrap();
            let restored: StopReason = serde_json::from_str(&json).unwrap();
            assert_eq!(&restored, reason);
        }
    }

    #[test]
    fn stop_reason_wire_format() {
        // Pins serde's externally tagged representation so that persisted
        // transcripts stay readable if the enum's attributes change.
        assert_eq!(
            serde_json::to_string(&StopReason::EndTurn).unwrap(),
            r#""EndTurn""#
        );
        assert_eq!(
            serde_json::to_string(&StopReason::Error("boom".into())).unwrap(),
            r#"{"Error":"boom"}"#
        );
    }

    #[test]
    fn message_role_tags_on_the_wire() {
        let user = serde_json::to_value(Message::user("hi")).unwrap();
        assert_eq!(user["role"], "user");
        assert_eq!(user["content"][0]["type"], "text");

        let assistant = serde_json::to_value(Message::Assistant(AssistantMessage {
            content: vec![],
            usage: None,
            stop_reason: StopReason::ToolUse,
            timestamp: 1,
        }))
        .unwrap();
        assert_eq!(assistant["role"], "assistant");

        let tool = serde_json::to_value(Message::ToolResult(ToolResultMessage {
            tool_call_id: "c1".into(),
            tool_name: "bash".into(),
            content: vec![],
            is_error: false,
            details: serde_json::Value::Null,
            timestamp: 1,
        }))
        .unwrap();
        assert_eq!(tool["role"], "tool_result");
    }

    #[test]
    fn tool_result_details_default_to_null_when_absent() {
        let json = r#"{"role":"tool_result","tool_call_id":"c1","tool_name":"bash","content":[],"is_error":false,"timestamp":1}"#;
        let restored: Message = serde_json::from_str(json).unwrap();
        if let Message::ToolResult(tr) = restored {
            assert!(tr.details.is_null());
        } else {
            panic!("expected ToolResult variant");
        }
    }

    #[test]
    fn assistant_usage_defaults_to_none_when_absent() {
        let json = r#"{"role":"assistant","content":[],"stop_reason":"EndTurn","timestamp":1}"#;
        let restored: Message = serde_json::from_str(json).unwrap();
        if let Message::Assistant(a) = restored {
            assert!(a.usage.is_none());
            assert_eq!(a.stop_reason, StopReason::EndTurn);
        } else {
            panic!("expected Assistant variant");
        }
    }

    #[test]
    fn unknown_role_is_rejected() {
        let json = r#"{"role":"system","content":[],"timestamp":1}"#;
        assert!(serde_json::from_str::<Message>(json).is_err());
    }
}