Skip to main content

astrid_types/
llm.rs

1//! LLM types for messages, tools, and streaming.
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
6/// A message in the conversation.
7#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
8pub struct Message {
9    /// Message role.
10    pub role: MessageRole,
11    /// Message content.
12    pub content: MessageContent,
13}
14
15impl Message {
16    /// Create a user message.
17    pub fn user(content: impl Into<String>) -> Self {
18        Self {
19            role: MessageRole::User,
20            content: MessageContent::Text(content.into()),
21        }
22    }
23
24    /// Create an assistant message.
25    pub fn assistant(content: impl Into<String>) -> Self {
26        Self {
27            role: MessageRole::Assistant,
28            content: MessageContent::Text(content.into()),
29        }
30    }
31
32    /// Create a system message.
33    pub fn system(content: impl Into<String>) -> Self {
34        Self {
35            role: MessageRole::System,
36            content: MessageContent::Text(content.into()),
37        }
38    }
39
40    /// Create an assistant message with tool calls.
41    #[must_use]
42    pub fn assistant_with_tools(tool_calls: Vec<ToolCall>) -> Self {
43        Self {
44            role: MessageRole::Assistant,
45            content: MessageContent::ToolCalls(tool_calls),
46        }
47    }
48
49    /// Create a tool result message.
50    #[must_use]
51    pub fn tool_result(result: ToolCallResult) -> Self {
52        Self {
53            role: MessageRole::Tool,
54            content: MessageContent::ToolResult(result),
55        }
56    }
57
58    /// Get text content if this is a text message.
59    #[must_use]
60    pub fn text(&self) -> Option<&str> {
61        match &self.content {
62            MessageContent::Text(s) => Some(s),
63            _ => None,
64        }
65    }
66
67    /// Get tool calls if this is a tool call message.
68    #[must_use]
69    pub fn tool_calls(&self) -> Option<&[ToolCall]> {
70        match &self.content {
71            MessageContent::ToolCalls(calls) => Some(calls),
72            _ => None,
73        }
74    }
75}
76
77/// Message role.
78#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
79#[serde(rename_all = "lowercase")]
80pub enum MessageRole {
81    /// System message (instructions).
82    System,
83    /// User message.
84    User,
85    /// Assistant message.
86    Assistant,
87    /// Tool result.
88    Tool,
89}
90
91/// Message content.
92#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
93#[serde(untagged)]
94pub enum MessageContent {
95    /// Plain text content.
96    Text(String),
97    /// Tool calls.
98    ToolCalls(Vec<ToolCall>),
99    /// Tool result.
100    ToolResult(ToolCallResult),
101    /// Multi-part content (text + images).
102    MultiPart(Vec<ContentPart>),
103}
104
105/// A part of multi-part content.
106#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
107#[serde(tag = "type", rename_all = "snake_case")]
108pub enum ContentPart {
109    /// Text content.
110    Text {
111        /// The text.
112        text: String,
113    },
114    /// Image content.
115    Image {
116        /// Base64-encoded image data.
117        data: String,
118        /// MIME type.
119        media_type: String,
120    },
121}
122
123/// A tool call from the assistant.
124#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
125pub struct ToolCall {
126    /// Unique call ID.
127    pub id: String,
128    /// Tool name.
129    pub name: String,
130    /// Tool arguments (JSON).
131    pub arguments: Value,
132}
133
134impl ToolCall {
135    /// Create a new tool call.
136    pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
137        Self {
138            id: id.into(),
139            name: name.into(),
140            arguments: Value::Object(serde_json::Map::new()),
141        }
142    }
143
144    /// Set arguments.
145    #[must_use]
146    pub fn with_arguments(mut self, args: Value) -> Self {
147        self.arguments = args;
148        self
149    }
150
151    /// Parse the server and tool name from "server:tool" format.
152    #[must_use]
153    pub fn parse_name(&self) -> Option<(&str, &str)> {
154        self.name.split_once(':')
155    }
156}
157
158/// Result of a tool call.
159#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
160pub struct ToolCallResult {
161    /// Tool call ID this is responding to.
162    pub call_id: String,
163    /// Result content.
164    pub content: String,
165    /// Whether this is an error result.
166    #[serde(default)]
167    pub is_error: bool,
168}
169
170impl ToolCallResult {
171    /// Create a successful result.
172    pub fn success(call_id: impl Into<String>, content: impl Into<String>) -> Self {
173        Self {
174            call_id: call_id.into(),
175            content: content.into(),
176            is_error: false,
177        }
178    }
179
180    /// Create an error result.
181    pub fn error(call_id: impl Into<String>, error: impl Into<String>) -> Self {
182        Self {
183            call_id: call_id.into(),
184            content: error.into(),
185            is_error: true,
186        }
187    }
188}
189
190/// Tool definition for the LLM.
191#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
192pub struct LlmToolDefinition {
193    /// Tool name.
194    pub name: String,
195    /// Description.
196    pub description: Option<String>,
197    /// Input JSON schema.
198    pub input_schema: Value,
199}
200
201impl LlmToolDefinition {
202    /// Create a new tool definition.
203    pub fn new(name: impl Into<String>) -> Self {
204        Self {
205            name: name.into(),
206            description: None,
207            input_schema: serde_json::json!({"type": "object"}),
208        }
209    }
210
211    /// Set description.
212    #[must_use]
213    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
214        self.description = Some(desc.into());
215        self
216    }
217
218    /// Set input schema.
219    #[must_use]
220    pub fn with_schema(mut self, schema: Value) -> Self {
221        self.input_schema = schema;
222        self
223    }
224}
225
226/// Streaming event from the LLM.
227#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
228pub enum StreamEvent {
229    /// Partial text output.
230    TextDelta(String),
231    /// Tool call started.
232    ToolCallStart {
233        /// Call ID.
234        id: String,
235        /// Tool name.
236        name: String,
237    },
238    /// Tool call arguments delta.
239    ToolCallDelta {
240        /// Call ID.
241        id: String,
242        /// Partial arguments JSON.
243        args_delta: String,
244    },
245    /// Tool call completed.
246    ToolCallEnd {
247        /// Call ID.
248        id: String,
249    },
250    /// Reasoning/chain-of-thought delta (used by Z.AI, `DeepSeek`, `OpenAI` o-series, etc.).
251    ReasoningDelta(String),
252    /// Usage information.
253    Usage {
254        /// Input tokens.
255        input_tokens: usize,
256        /// Output tokens.
257        output_tokens: usize,
258    },
259    /// Stream completed.
260    Done,
261    /// Error occurred.
262    Error(String),
263}
264
265/// LLM response (non-streaming).
266#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
267pub struct LlmResponse {
268    /// Response message.
269    pub message: Message,
270    /// Whether the response has tool calls.
271    pub has_tool_calls: bool,
272    /// Stop reason.
273    pub stop_reason: StopReason,
274    /// Token usage.
275    pub usage: Usage,
276}
277
278/// Reason the model stopped generating.
279#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
280pub enum StopReason {
281    /// Natural end of response.
282    EndTurn,
283    /// Hit max tokens.
284    MaxTokens,
285    /// Tool use requested.
286    ToolUse,
287    /// Stop sequence hit.
288    StopSequence,
289}
290
291/// Token usage information.
292#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
293pub struct Usage {
294    /// Input tokens.
295    pub input_tokens: usize,
296    /// Output tokens.
297    pub output_tokens: usize,
298}
299
300impl Usage {
301    /// Total tokens.
302    #[must_use]
303    pub fn total(&self) -> usize {
304        self.input_tokens.saturating_add(self.output_tokens)
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_creation() {
        let user = Message::user("Hello");
        assert_eq!(user.role, MessageRole::User);
        assert_eq!(user.text(), Some("Hello"));
        assert_eq!(user.tool_calls(), None);

        let assistant = Message::assistant("Hi there!");
        assert_eq!(assistant.role, MessageRole::Assistant);
        assert_eq!(assistant.text(), Some("Hi there!"));

        let system = Message::system("Be terse.");
        assert_eq!(system.role, MessageRole::System);
        assert_eq!(system.text(), Some("Be terse."));
    }

    #[test]
    fn test_tool_message_accessors() {
        let call = ToolCall::new("1", "fs:read");
        let with_tools = Message::assistant_with_tools(vec![call.clone()]);
        assert_eq!(with_tools.role, MessageRole::Assistant);
        assert_eq!(with_tools.text(), None);
        assert_eq!(with_tools.tool_calls(), Some(&[call][..]));

        let result = Message::tool_result(ToolCallResult::success("1", "ok"));
        assert_eq!(result.role, MessageRole::Tool);
        assert_eq!(result.text(), None);
        assert_eq!(result.tool_calls(), None);
    }

    #[test]
    fn test_tool_call() {
        let call = ToolCall::new("123", "filesystem:read_file")
            .with_arguments(serde_json::json!({"path": "/tmp/test.txt"}));

        assert_eq!(call.parse_name(), Some(("filesystem", "read_file")));
        assert_eq!(call.arguments["path"], "/tmp/test.txt");
    }

    #[test]
    fn test_tool_call_defaults_and_name_parsing() {
        let plain = ToolCall::new("1", "plain_tool");
        assert_eq!(plain.arguments, serde_json::json!({}));
        assert_eq!(plain.parse_name(), None);

        // Only the first colon separates server from tool.
        let nested = ToolCall::new("2", "server:ns:tool");
        assert_eq!(nested.parse_name(), Some(("server", "ns:tool")));
    }

    #[test]
    fn test_tool_result() {
        let success = ToolCallResult::success("123", "file contents");
        assert!(!success.is_error);
        assert_eq!(success.content, "file contents");

        let error = ToolCallResult::error("123", "file not found");
        assert!(error.is_error);
        assert_eq!(error.content, "file not found");
    }

    #[test]
    fn test_tool_result_is_error_defaults_to_false() {
        let parsed: ToolCallResult =
            serde_json::from_str(r#"{"call_id":"1","content":"ok"}"#).expect("valid JSON");
        assert_eq!(parsed, ToolCallResult::success("1", "ok"));
    }

    #[test]
    fn test_content_part_is_internally_tagged() {
        let part = ContentPart::Text {
            text: "hi".into(),
        };
        let value = serde_json::to_value(&part).expect("serialize");
        assert_eq!(value, serde_json::json!({"type": "text", "text": "hi"}));
    }

    #[test]
    fn test_message_serde_round_trip() {
        // `MessageContent` is untagged, so each variant must survive a round trip
        // without being captured by an earlier variant.
        let messages = vec![
            Message::user("Hello"),
            Message::assistant_with_tools(vec![ToolCall::new("1", "fs:read")
                .with_arguments(serde_json::json!({"path": "a.txt"}))]),
            Message::tool_result(ToolCallResult::error("1", "missing")),
            Message {
                role: MessageRole::User,
                content: MessageContent::MultiPart(vec![
                    ContentPart::Text {
                        text: "look".into(),
                    },
                    ContentPart::Image {
                        data: "AAAA".into(),
                        media_type: "image/png".into(),
                    },
                ]),
            },
        ];

        for message in messages {
            let json = serde_json::to_string(&message).expect("serialize");
            let back: Message = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(back, message);
        }
    }

    #[test]
    fn test_usage_total_saturates() {
        let usage = Usage {
            input_tokens: 3,
            output_tokens: 4,
        };
        assert_eq!(usage.total(), 7);

        let huge = Usage {
            input_tokens: usize::MAX,
            output_tokens: 1,
        };
        assert_eq!(huge.total(), usize::MAX);
    }
}