Skip to main content

limit_llm/
types.rs

//! Core types for LLM message passing and tool definitions.
//!
//! This module provides the fundamental types used throughout `limit-llm` for
//! constructing messages, defining tools, and handling responses from LLM providers.
//!
//! # Overview
//!
//! - [`Message`] — A single message in a conversation with role and content
//! - [`Role`] — The sender role (User, Assistant, System, or Tool)
//! - [`CacheControl`] — Prompt-caching settings that can be attached to a message
//! - [`Tool`] / [`ToolCall`] — Function calling definitions for LLM tool use
//! - [`Response`] — Complete response with content, tool calls, and usage
//! - [`Usage`] — Token counts for input, output, and prompt-cache reads/writes
13
14use serde::{Deserialize, Serialize};
15
16/// A single message in a conversation.
17///
18/// Messages are the fundamental unit of communication with LLM providers.
19/// Each message has a role (who sent it), content (the text), and optionally
20/// tool calls (for function calling).
21///
22/// # Examples
23///
24/// ## User Message
25///
26/// ```
27/// use limit_llm::{Message, Role};
28///
29/// let msg = Message {
30///     role: Role::User,
31///     content: Some("What is the capital of France?".to_string()),
32///     tool_calls: None,
33///     tool_call_id: None,
34/// };
35/// ```
36///
37/// ## Assistant Message with Tool Call
38///
39/// ```
40/// use limit_llm::{Message, Role, ToolCall, FunctionCall};
41///
42/// let msg = Message {
43///     role: Role::Assistant,
44///     content: None,
45///     tool_calls: Some(vec![ToolCall {
46///         id: "call_123".to_string(),
47///         tool_type: "function".to_string(),
48///         function: FunctionCall {
49///             name: "get_weather".to_string(),
50///             arguments: r#"{"location": "Paris"}"#.to_string(),
51///         },
52///     }]),
53///     tool_call_id: None,
54/// };
55/// ```
56///
57/// ## Tool Result Message
58///
59/// ```
60/// use limit_llm::{Message, Role};
61///
62/// let msg = Message {
63///     role: Role::Tool,
64///     content: Some(r#"{"temp": 22, "condition": "sunny"}"#.to_string()),
65///     tool_calls: None,
66///     tool_call_id: Some("call_123".to_string()),
67/// };
68/// ```
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct Message {
71    /// The role of the message sender.
72    pub role: Role,
73
74    /// The text content of the message.
75    ///
76    /// Can be `None` for assistant messages that only contain tool calls.
77    #[serde(skip_serializing_if = "Option::is_none")]
78    pub content: Option<String>,
79
80    /// Tool calls made by the assistant.
81    ///
82    /// Only present in assistant messages when the LLM decides to call tools.
83    #[serde(skip_serializing_if = "Option::is_none")]
84    pub tool_calls: Option<Vec<ToolCall>>,
85
86    /// ID of the tool call this message is responding to.
87    ///
88    /// Only present in tool result messages.
89    #[serde(skip_serializing_if = "Option::is_none")]
90    pub tool_call_id: Option<String>,
91
92    /// Cache control for prompt caching (Anthropic/OpenAI).
93    #[serde(skip_serializing_if = "Option::is_none")]
94    pub cache_control: Option<CacheControl>,
95}
96
97/// The role of a message sender in a conversation.
98///
99/// # Serialization
100///
101/// Roles are serialized as lowercase strings:
102/// - `User` → `"user"`
103/// - `Assistant` → `"assistant"`
104/// - `System` → `"system"`
105/// - `Tool` → `"tool"`
106#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
107#[serde(rename_all = "lowercase")]
108pub enum Role {
109    /// A message from the user.
110    User,
111
112    /// A message from the assistant (LLM).
113    Assistant,
114
115    /// A system message providing instructions or context.
116    System,
117
118    /// A tool result message containing the output of a tool execution.
119    Tool,
120}
121
122/// Cache control settings for prompt caching.
123///
124/// Used to enable API-level caching of messages to reduce input token costs.
125/// Supported by Anthropic Claude and OpenAI models.
126#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
127pub struct CacheControl {
128    /// The type of cache control. Currently only "ephemeral" is supported.
129    #[serde(rename = "type")]
130    pub cache_type: String,
131
132    /// Time-to-live for the cache entry (Anthropic only).
133    /// Options: "5m" (default), "1h" when long retention is enabled.
134    #[serde(skip_serializing_if = "Option::is_none")]
135    pub ttl: Option<String>,
136}
137
138impl CacheControl {
139    /// Create a new ephemeral cache control with default TTL.
140    pub fn ephemeral() -> Self {
141        Self {
142            cache_type: "ephemeral".to_string(),
143            ttl: None,
144        }
145    }
146
147    /// Create an ephemeral cache control with long TTL (1 hour).
148    pub fn ephemeral_long() -> Self {
149        Self {
150            cache_type: "ephemeral".to_string(),
151            ttl: Some("1h".to_string()),
152        }
153    }
154}
155
156/// A tool call made by the assistant.
157///
158/// When an LLM decides to use a tool, it returns a `ToolCall` containing
159/// the tool ID, type, and the function to call with its arguments.
160#[derive(Debug, Clone, Serialize, Deserialize)]
161pub struct ToolCall {
162    /// Unique identifier for this tool call.
163    pub id: String,
164
165    /// The type of tool (always "function" for now).
166    #[serde(rename = "type")]
167    pub tool_type: String,
168
169    /// The function call details.
170    pub function: FunctionCall,
171}
172
173/// A function call with name and JSON arguments.
174///
175/// The `arguments` field contains a JSON string representing the function
176/// parameters as defined in the tool schema.
177#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct FunctionCall {
179    /// The name of the function to call.
180    pub name: String,
181
182    /// JSON string representation of the function arguments.
183    ///
184    /// This is a string because LLMs return arguments as JSON strings
185    /// during streaming. Parse with `serde_json::from_str` if needed.
186    pub arguments: String,
187}
188
189/// A tool definition for LLM function calling.
190///
191/// Tools allow LLMs to perform actions by calling functions with structured
192/// parameters. Define tools with JSON Schema for the parameters.
193///
194/// # Example
195///
196/// ```
197/// use limit_llm::{Tool, ToolFunction};
198/// use serde_json::json;
199///
200/// let tool = Tool {
201///     tool_type: "function".to_string(),
202///     function: ToolFunction {
203///         name: "get_weather".to_string(),
204///         description: "Get current weather for a location".to_string(),
205///         parameters: json!({
206///             "type": "object",
207///             "properties": {
208///                 "location": {
209///                     "type": "string",
210///                     "description": "City name"
211///                 }
212///             },
213///             "required": ["location"]
214///         }),
215///     },
216/// };
217/// ```
218#[derive(Debug, Clone, Serialize, Deserialize)]
219pub struct Tool {
220    /// The type of tool (always "function" for now).
221    #[serde(rename = "type")]
222    pub tool_type: String,
223
224    /// The function definition.
225    pub function: ToolFunction,
226}
227
228/// Function definition within a tool.
229#[derive(Debug, Clone, Serialize, Deserialize)]
230pub struct ToolFunction {
231    /// The function name. Must be unique within the tool set.
232    pub name: String,
233
234    /// Human-readable description of what the function does.
235    /// This helps the LLM understand when to use the tool.
236    pub description: String,
237
238    /// JSON Schema defining the function parameters.
239    ///
240    /// Use `serde_json::json!` to construct the schema inline.
241    pub parameters: serde_json::Value,
242}
243
244/// A complete response from an LLM provider.
245#[derive(Debug, Clone, Serialize, Deserialize)]
246pub struct Response {
247    /// The text content of the response.
248    pub content: String,
249
250    /// Tool calls made by the assistant, if any.
251    #[serde(skip_serializing_if = "Option::is_none")]
252    pub tool_calls: Option<Vec<ToolCall>>,
253
254    /// Token usage statistics.
255    pub usage: Usage,
256}
257
258/// Token usage statistics for a request.
259///
260/// Tracks the number of tokens used in the prompt (input) and
261/// completion (output). Use with [`TrackingDb`](crate::TrackingDb)
262/// to monitor costs across sessions.
263#[derive(Debug, Clone, Serialize, Deserialize)]
264pub struct Usage {
265    /// Number of tokens in the prompt/input.
266    pub input_tokens: u64,
267
268    /// Number of tokens in the completion/output.
269    pub output_tokens: u64,
270
271    /// Number of tokens read from cache (~10% of input cost).
272    #[serde(default, alias = "cache_read_input_tokens")]
273    pub cache_read_tokens: u64,
274
275    /// Number of tokens written to cache.
276    #[serde(default, alias = "cache_creation_input_tokens")]
277    pub cache_write_tokens: u64,
278}
279
280impl Usage {
281    /// Calculate total tokens including cache operations.
282    pub fn total_tokens(&self) -> u64 {
283        self.input_tokens + self.output_tokens + self.cache_read_tokens + self.cache_write_tokens
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_serialization() {
        let msg = Message {
            role: Role::User,
            content: Some("Hello".to_string()),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        // Unset optional fields must be omitted entirely, not emitted as `null`.
        assert_eq!(json, r#"{"role":"user","content":"Hello"}"#);

        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.role, Role::User);
        assert_eq!(msg.content, deserialized.content);
        assert!(deserialized.tool_calls.is_none());
        assert!(deserialized.tool_call_id.is_none());
        assert!(deserialized.cache_control.is_none());
    }

    #[test]
    fn test_message_deserializes_without_optional_fields() {
        let msg: Message = serde_json::from_str(r#"{"role":"assistant"}"#).unwrap();
        assert_eq!(msg.role, Role::Assistant);
        assert!(msg.content.is_none());
        assert!(msg.tool_calls.is_none());
        assert!(msg.tool_call_id.is_none());
        assert!(msg.cache_control.is_none());
    }

    #[test]
    fn test_message_with_tool_calls() {
        let msg = Message {
            role: Role::Assistant,
            content: Some("".to_string()),
            tool_calls: Some(vec![ToolCall {
                id: "call_123".to_string(),
                tool_type: "function".to_string(),
                function: FunctionCall {
                    name: "test_tool".to_string(),
                    arguments: serde_json::json!({"arg": "value"}).to_string(),
                },
            }]),
            tool_call_id: None,
            cache_control: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        // `tool_type` is renamed to the wire key `type`.
        assert!(json.contains(r#""type":"function""#));

        let deserialized: Message = serde_json::from_str(&json).unwrap();
        let calls = deserialized
            .tool_calls
            .expect("tool_calls should survive a round-trip");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, "call_123");
        assert_eq!(calls[0].tool_type, "function");
        assert_eq!(calls[0].function.name, "test_tool");
        assert_eq!(calls[0].function.arguments, r#"{"arg":"value"}"#);
    }

    #[test]
    fn test_tool_result_message() {
        let msg = Message {
            role: Role::Tool,
            content: Some("result output".to_string()),
            tool_calls: None,
            tool_call_id: Some("call_123".to_string()),
            cache_control: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("tool_call_id"));

        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.role, Role::Tool);
        assert_eq!(deserialized.content, Some("result output".to_string()));
        assert_eq!(deserialized.tool_call_id, Some("call_123".to_string()));
    }

    #[test]
    fn test_assistant_with_tool_calls_serialization() {
        let msg = Message {
            role: Role::Assistant,
            content: None,
            tool_calls: Some(vec![ToolCall {
                id: "call_123".to_string(),
                tool_type: "function".to_string(),
                function: FunctionCall {
                    name: "test_tool".to_string(),
                    arguments: serde_json::json!({}).to_string(),
                },
            }]),
            tool_call_id: None,
            cache_control: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        // A `None` content must drop the key entirely (not serialize `null`).
        assert!(!json.contains("\"content\""));
        assert!(json.contains("tool_calls"));
    }

    #[test]
    fn test_role_serialization() {
        let cases = [
            (Role::User, "\"user\""),
            (Role::Assistant, "\"assistant\""),
            (Role::System, "\"system\""),
            (Role::Tool, "\"tool\""),
        ];
        for (role, expected) in &cases {
            let json = serde_json::to_string(role).unwrap();
            assert_eq!(json, *expected);
            let back: Role = serde_json::from_str(&json).unwrap();
            assert_eq!(&back, role);
        }
    }

    #[test]
    fn test_tool_serialization() {
        let tool = Tool {
            tool_type: "function".to_string(),
            function: ToolFunction {
                name: "test_tool".to_string(),
                description: "A test tool".to_string(),
                parameters: serde_json::json!({"type": "object"}),
            },
        };
        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains(r#""type":"function""#));

        let deserialized: Tool = serde_json::from_str(&json).unwrap();
        assert_eq!(tool.tool_type, deserialized.tool_type);
        assert_eq!(tool.function.name, deserialized.function.name);
        assert_eq!(tool.function.description, deserialized.function.description);
        assert_eq!(tool.function.parameters, deserialized.function.parameters);
    }

    #[test]
    fn test_response_serialization() {
        let response = Response {
            content: "Hello, world!".to_string(),
            tool_calls: None,
            usage: Usage {
                input_tokens: 10,
                output_tokens: 5,
                cache_read_tokens: 0,
                cache_write_tokens: 0,
            },
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(!json.contains("tool_calls"));

        let deserialized: Response = serde_json::from_str(&json).unwrap();
        assert_eq!(response.content, deserialized.content);
        assert!(deserialized.tool_calls.is_none());
        assert_eq!(response.usage.input_tokens, deserialized.usage.input_tokens);
        assert_eq!(response.usage.output_tokens, deserialized.usage.output_tokens);
    }

    #[test]
    fn test_usage_serialization() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 0,
            cache_write_tokens: 0,
        };
        let json = serde_json::to_string(&usage).unwrap();
        let deserialized: Usage = serde_json::from_str(&json).unwrap();
        assert_eq!(usage.input_tokens, deserialized.input_tokens);
        assert_eq!(usage.output_tokens, deserialized.output_tokens);
    }

    #[test]
    fn test_usage_missing_cache_fields_default_to_zero() {
        let usage: Usage =
            serde_json::from_str(r#"{"input_tokens": 7, "output_tokens": 3}"#).unwrap();
        assert_eq!(usage.cache_read_tokens, 0);
        assert_eq!(usage.cache_write_tokens, 0);
        assert_eq!(usage.total_tokens(), 10);
    }

    #[test]
    fn test_cache_control_serialization() {
        let cache = CacheControl::ephemeral();
        let json = serde_json::to_string(&cache).unwrap();
        assert_eq!(json, r#"{"type":"ephemeral"}"#);
        let back: CacheControl = serde_json::from_str(&json).unwrap();
        assert_eq!(back, cache);

        let cache_long = CacheControl::ephemeral_long();
        let json_long = serde_json::to_string(&cache_long).unwrap();
        assert_eq!(json_long, r#"{"type":"ephemeral","ttl":"1h"}"#);
        let back_long: CacheControl = serde_json::from_str(&json_long).unwrap();
        assert_eq!(back_long, cache_long);
    }

    #[test]
    fn test_message_with_cache_control() {
        let msg = Message {
            role: Role::User,
            content: Some("Hello".to_string()),
            tool_calls: None,
            tool_call_id: None,
            cache_control: Some(CacheControl::ephemeral()),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("cache_control"));
        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.cache_control, Some(CacheControl::ephemeral()));
    }

    #[test]
    fn test_usage_with_cache_fields() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 80,
            cache_write_tokens: 20,
        };
        assert_eq!(usage.total_tokens(), 250);

        let json = serde_json::to_string(&usage).unwrap();
        assert!(json.contains("cache_read_tokens"));
        assert!(json.contains("cache_write_tokens"));

        let deserialized: Usage = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.cache_read_tokens, 80);
        assert_eq!(deserialized.cache_write_tokens, 20);
    }

    #[test]
    fn test_usage_anthropic_aliases() {
        let json = r#"{
            "input_tokens": 100,
            "output_tokens": 50,
            "cache_read_input_tokens": 80,
            "cache_creation_input_tokens": 20
        }"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 50);
        assert_eq!(usage.cache_read_tokens, 80);
        assert_eq!(usage.cache_write_tokens, 20);
        assert_eq!(usage.total_tokens(), 250);
    }
}