Skip to main content

limit_llm/
types.rs

//! Core types for LLM message passing and tool definitions.
//!
//! This module provides the fundamental types used throughout `limit-llm` for
//! constructing messages, defining tools, and handling responses from LLM providers.
//!
//! # Overview
//!
//! - [`Message`] — A single message in a conversation with role and content
//! - [`Role`] — The sender role (User, Assistant, System, or Tool)
//! - [`CacheControl`] — Prompt-caching marker that can be attached to a message
//! - [`Tool`] / [`ToolCall`] — Function calling definitions for LLM tool use
//! - [`Response`] — Complete response with content, tool calls, and usage
//! - [`Usage`] — Token counting for prompt, completion, and prompt-cache reads/writes
13
14use serde::{Deserialize, Serialize};
15
16/// A single message in a conversation.
17///
18/// Messages are the fundamental unit of communication with LLM providers.
19/// Each message has a role (who sent it), content (the text), and optionally
20/// tool calls (for function calling).
21///
22/// # Examples
23///
24/// ## User Message
25///
26/// ```
27/// use limit_llm::{Message, Role};
28///
29/// let msg = Message {
30///     role: Role::User,
31///     content: Some("What is the capital of France?".to_string()),
32///     tool_calls: None,
33///     tool_call_id: None,
34///     cache_control: None,
35/// };
36/// ```
37///
38/// ## Assistant Message with Tool Calls
39///
40/// ```
41/// use limit_llm::{Message, Role, ToolCall, FunctionCall};
42/// use serde_json::json;
43///
44/// let msg = Message {
45///     role: Role::Assistant,
46///     content: None,
47///     tool_calls: Some(vec![ToolCall {
48///         id: "call_123".to_string(),
49///         tool_type: "function".to_string(),
50///         function: FunctionCall {
51///             name: "get_weather".to_string(),
52///             arguments: json!({"location": "Paris"}).to_string(),
53///         },
54///     }]),
55///     tool_call_id: None,
56///     cache_control: None,
57/// };
58/// ```
59///
60/// ## Tool Result Message
61///
62/// ```
63/// use limit_llm::{Message, Role};
64///
65/// let msg = Message {
66///     role: Role::Tool,
67///     content: Some(r#"{"temp": 22, "condition": "sunny"}"#.to_string()),
68///     tool_calls: None,
69///     tool_call_id: Some("call_123".to_string()),
70///     cache_control: None,
71/// };
72/// ```
73///
74/// ## Assistant Message with Tool Call
75///
76/// ```
77/// use limit_llm::{Message, Role, ToolCall, FunctionCall};
78///
79/// let msg = Message {
80///     role: Role::Assistant,
81///     content: None,
82///     tool_calls: Some(vec![ToolCall {
83///         id: "call_123".to_string(),
84///         tool_type: "function".to_string(),
85///         function: FunctionCall {
86///             name: "get_weather".to_string(),
87///             arguments: r#"{"location": "Paris"}"#.to_string(),
88///         },
89///     }]),
90///     tool_call_id: None,
91///     cache_control: None,
92/// };
93/// ```
94///
95/// ## Tool Result Message
96///
97/// ```
98/// use limit_llm::{Message, Role};
99///
100/// let msg = Message {
101///     role: Role::Tool,
102///     content: Some(r#"{"temp": 22, "condition": "sunny"}"#.to_string()),
103///     tool_calls: None,
104///     tool_call_id: Some("call_123".to_string()),
105///     cache_control: None,
106/// };
107/// ```
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct Message {
110    /// The role of the message sender.
111    pub role: Role,
112
113    /// The text content of the message.
114    ///
115    /// Can be `None` for assistant messages that only contain tool calls.
116    #[serde(skip_serializing_if = "Option::is_none")]
117    pub content: Option<String>,
118
119    /// Tool calls made by the assistant.
120    ///
121    /// Only present in assistant messages when the LLM decides to call tools.
122    #[serde(skip_serializing_if = "Option::is_none")]
123    pub tool_calls: Option<Vec<ToolCall>>,
124
125    /// ID of the tool call this message is responding to.
126    ///
127    /// Only present in tool result messages.
128    #[serde(skip_serializing_if = "Option::is_none")]
129    pub tool_call_id: Option<String>,
130
131    /// Cache control for prompt caching (Anthropic/OpenAI).
132    #[serde(skip_serializing_if = "Option::is_none")]
133    pub cache_control: Option<CacheControl>,
134}
135
136/// The role of a message sender in a conversation.
137///
138/// # Serialization
139///
140/// Roles are serialized as lowercase strings:
141/// - `User` → `"user"`
142/// - `Assistant` → `"assistant"`
143/// - `System` → `"system"`
144/// - `Tool` → `"tool"`
145#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
146#[serde(rename_all = "lowercase")]
147pub enum Role {
148    /// A message from the user.
149    User,
150
151    /// A message from the assistant (LLM).
152    Assistant,
153
154    /// A system message providing instructions or context.
155    System,
156
157    /// A tool result message containing the output of a tool execution.
158    Tool,
159}
160
161/// Cache control settings for prompt caching.
162///
163/// Used to enable API-level caching of messages to reduce input token costs.
164/// Supported by Anthropic Claude and OpenAI models.
165#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
166pub struct CacheControl {
167    /// The type of cache control. Currently only "ephemeral" is supported.
168    #[serde(rename = "type")]
169    pub cache_type: String,
170
171    /// Time-to-live for the cache entry (Anthropic only).
172    /// Options: "5m" (default), "1h" when long retention is enabled.
173    #[serde(skip_serializing_if = "Option::is_none")]
174    pub ttl: Option<String>,
175}
176
177impl CacheControl {
178    /// Create a new ephemeral cache control with default TTL.
179    pub fn ephemeral() -> Self {
180        Self {
181            cache_type: "ephemeral".to_string(),
182            ttl: None,
183        }
184    }
185
186    /// Create an ephemeral cache control with long TTL (1 hour).
187    pub fn ephemeral_long() -> Self {
188        Self {
189            cache_type: "ephemeral".to_string(),
190            ttl: Some("1h".to_string()),
191        }
192    }
193}
194
195/// A tool call made by the assistant.
196///
197/// When an LLM decides to use a tool, it returns a `ToolCall` containing
198/// the tool ID, type, and the function to call with its arguments.
199#[derive(Debug, Clone, Serialize, Deserialize)]
200pub struct ToolCall {
201    /// Unique identifier for this tool call.
202    pub id: String,
203
204    /// The type of tool (always "function" for now).
205    #[serde(rename = "type")]
206    pub tool_type: String,
207
208    /// The function call details.
209    pub function: FunctionCall,
210}
211
212/// A function call with name and JSON arguments.
213///
214/// The `arguments` field contains a JSON string representing the function
215/// parameters as defined in the tool schema.
216#[derive(Debug, Clone, Serialize, Deserialize)]
217pub struct FunctionCall {
218    /// The name of the function to call.
219    pub name: String,
220
221    /// JSON string representation of the function arguments.
222    ///
223    /// This is a string because LLMs return arguments as JSON strings
224    /// during streaming. Parse with `serde_json::from_str` if needed.
225    pub arguments: String,
226}
227
228/// A tool definition for LLM function calling.
229///
230/// Tools allow LLMs to perform actions by calling functions with structured
231/// parameters. Define tools with JSON Schema for the parameters.
232///
233/// # Example
234///
235/// ```
236/// use limit_llm::{Tool, ToolFunction};
237/// use serde_json::json;
238///
239/// let tool = Tool {
240///     tool_type: "function".to_string(),
241///     function: ToolFunction {
242///         name: "get_weather".to_string(),
243///         description: "Get current weather for a location".to_string(),
244///         parameters: json!({
245///             "type": "object",
246///             "properties": {
247///                 "location": {
248///                     "type": "string",
249///                     "description": "City name"
250///                 }
251///             },
252///             "required": ["location"]
253///         }),
254///     },
255/// };
256/// ```
257#[derive(Debug, Clone, Serialize, Deserialize)]
258pub struct Tool {
259    /// The type of tool (always "function" for now).
260    #[serde(rename = "type")]
261    pub tool_type: String,
262
263    /// The function definition.
264    pub function: ToolFunction,
265}
266
267/// Function definition within a tool.
268#[derive(Debug, Clone, Serialize, Deserialize)]
269pub struct ToolFunction {
270    /// The function name. Must be unique within the tool set.
271    pub name: String,
272
273    /// Human-readable description of what the function does.
274    /// This helps the LLM understand when to use the tool.
275    pub description: String,
276
277    /// JSON Schema defining the function parameters.
278    ///
279    /// Use `serde_json::json!` to construct the schema inline.
280    pub parameters: serde_json::Value,
281}
282
283/// A complete response from an LLM provider.
284#[derive(Debug, Clone, Serialize, Deserialize)]
285pub struct Response {
286    /// The text content of the response.
287    pub content: String,
288
289    /// Tool calls made by the assistant, if any.
290    #[serde(skip_serializing_if = "Option::is_none")]
291    pub tool_calls: Option<Vec<ToolCall>>,
292
293    /// Token usage statistics.
294    pub usage: Usage,
295}
296
297/// Token usage statistics for a request.
298///
299/// Tracks the number of tokens used in the prompt (input) and
300/// completion (output). Use with [`TrackingDb`](crate::TrackingDb)
301/// to monitor costs across sessions.
302#[derive(Debug, Clone, Serialize, Deserialize)]
303pub struct Usage {
304    /// Number of tokens in the prompt/input.
305    pub input_tokens: u64,
306
307    /// Number of tokens in the completion/output.
308    pub output_tokens: u64,
309
310    /// Number of tokens read from cache (~10% of input cost).
311    #[serde(default, alias = "cache_read_input_tokens")]
312    pub cache_read_tokens: u64,
313
314    /// Number of tokens written to cache.
315    #[serde(default, alias = "cache_creation_input_tokens")]
316    pub cache_write_tokens: u64,
317}
318
319impl Usage {
320    /// Calculate total tokens including cache operations.
321    pub fn total_tokens(&self) -> u64 {
322        self.input_tokens + self.output_tokens + self.cache_read_tokens + self.cache_write_tokens
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a user message with the given text and every optional field unset.
    fn user_message(text: &str) -> Message {
        Message {
            role: Role::User,
            content: Some(text.to_string()),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        }
    }

    /// Builds a `"function"`-typed tool call with the given id, name, and raw
    /// JSON argument string.
    fn function_call(id: &str, name: &str, arguments: &str) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: arguments.to_string(),
            },
        }
    }

    #[test]
    fn test_message_serialization() {
        let msg = user_message("Hello");
        let json = serde_json::to_string(&msg).unwrap();
        // Every `None` field is skipped, so only `role` and `content` remain.
        assert_eq!(json, r#"{"role":"user","content":"Hello"}"#);

        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.role, Role::User);
        assert_eq!(msg.content, deserialized.content);
    }

    #[test]
    fn test_message_without_optional_fields() {
        let msg: Message = serde_json::from_str(r#"{"role":"system"}"#).unwrap();
        assert_eq!(msg.role, Role::System);
        assert!(msg.content.is_none());
        assert!(msg.tool_calls.is_none());
        assert!(msg.tool_call_id.is_none());
        assert!(msg.cache_control.is_none());
    }

    #[test]
    fn test_message_with_tool_calls() {
        let arguments = serde_json::json!({"arg": "value"}).to_string();
        let msg = Message {
            role: Role::Assistant,
            content: Some(String::new()),
            tool_calls: Some(vec![function_call("call_123", "test_tool", &arguments)]),
            tool_call_id: None,
            cache_control: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        // `tool_type` is renamed to the wire key "type".
        assert!(json.contains(r#""type":"function""#));
        assert!(!json.contains("tool_type"));

        let deserialized: Message = serde_json::from_str(&json).unwrap();
        let calls = deserialized
            .tool_calls
            .expect("tool_calls should survive a round trip");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, "call_123");
        assert_eq!(calls[0].tool_type, "function");
        assert_eq!(calls[0].function.name, "test_tool");
        assert_eq!(calls[0].function.arguments, arguments);
    }

    #[test]
    fn test_tool_result_message() {
        let msg = Message {
            role: Role::Tool,
            content: Some("result output".to_string()),
            tool_calls: None,
            tool_call_id: Some("call_123".to_string()),
            cache_control: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""tool_call_id":"call_123""#));
        assert!(!json.contains("tool_calls"));

        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.role, Role::Tool);
        assert_eq!(deserialized.tool_call_id.as_deref(), Some("call_123"));
    }

    #[test]
    fn test_assistant_with_tool_calls_serialization() {
        let msg = Message {
            role: Role::Assistant,
            content: None,
            tool_calls: Some(vec![function_call("call_123", "test_tool", "{}")]),
            tool_call_id: None,
            cache_control: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        // `content: None` must be omitted entirely, not sent as `null`.
        assert!(!json.contains("\"content\""));
        assert!(json.contains("tool_calls"));
        assert!(!json.contains("tool_call_id"));
    }

    #[test]
    fn test_role_serialization() {
        let cases = [
            (Role::User, "\"user\""),
            (Role::Assistant, "\"assistant\""),
            (Role::System, "\"system\""),
            (Role::Tool, "\"tool\""),
        ];
        for (role, expected) in &cases {
            let json = serde_json::to_string(role).unwrap();
            assert_eq!(json, *expected);
            let back: Role = serde_json::from_str(&json).unwrap();
            assert_eq!(&back, role);
        }
    }

    #[test]
    fn test_tool_serialization() {
        let tool = Tool {
            tool_type: "function".to_string(),
            function: ToolFunction {
                name: "test_tool".to_string(),
                description: "A test tool".to_string(),
                parameters: serde_json::json!({"type": "object"}),
            },
        };
        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains(r#""type":"function""#));

        let deserialized: Tool = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.tool_type, "function");
        assert_eq!(tool.function.name, deserialized.function.name);
        assert_eq!(tool.function.description, deserialized.function.description);
        assert_eq!(tool.function.parameters, deserialized.function.parameters);
    }

    #[test]
    fn test_response_serialization() {
        let response = Response {
            content: "Hello, world!".to_string(),
            tool_calls: None,
            usage: Usage {
                input_tokens: 10,
                output_tokens: 5,
                cache_read_tokens: 0,
                cache_write_tokens: 0,
            },
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(!json.contains("tool_calls"));

        let deserialized: Response = serde_json::from_str(&json).unwrap();
        assert_eq!(response.content, deserialized.content);
        assert!(deserialized.tool_calls.is_none());
        assert_eq!(response.usage.input_tokens, deserialized.usage.input_tokens);
        assert_eq!(response.usage.output_tokens, deserialized.usage.output_tokens);
    }

    #[test]
    fn test_usage_serialization() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 0,
            cache_write_tokens: 0,
        };
        let json = serde_json::to_string(&usage).unwrap();
        let deserialized: Usage = serde_json::from_str(&json).unwrap();
        assert_eq!(usage.input_tokens, deserialized.input_tokens);
        assert_eq!(usage.output_tokens, deserialized.output_tokens);
        assert_eq!(usage.cache_read_tokens, deserialized.cache_read_tokens);
        assert_eq!(usage.cache_write_tokens, deserialized.cache_write_tokens);
    }

    #[test]
    fn test_usage_missing_cache_fields_default_to_zero() {
        let usage: Usage =
            serde_json::from_str(r#"{"input_tokens": 7, "output_tokens": 3}"#).unwrap();
        assert_eq!(usage.cache_read_tokens, 0);
        assert_eq!(usage.cache_write_tokens, 0);
        assert_eq!(usage.total_tokens(), 10);
    }

    #[test]
    fn test_cache_control_serialization() {
        let cache = CacheControl::ephemeral();
        let json = serde_json::to_string(&cache).unwrap();
        assert_eq!(json, r#"{"type":"ephemeral"}"#);

        let cache_long = CacheControl::ephemeral_long();
        let json_long = serde_json::to_string(&cache_long).unwrap();
        assert_eq!(json_long, r#"{"type":"ephemeral","ttl":"1h"}"#);

        let back: CacheControl = serde_json::from_str(&json_long).unwrap();
        assert_eq!(back, cache_long);
    }

    #[test]
    fn test_message_with_cache_control() {
        let msg = Message {
            cache_control: Some(CacheControl::ephemeral()),
            ..user_message("Hello")
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#));

        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.cache_control, Some(CacheControl::ephemeral()));
    }

    #[test]
    fn test_usage_with_cache_fields() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 80,
            cache_write_tokens: 20,
        };
        assert_eq!(usage.total_tokens(), 250);

        let json = serde_json::to_string(&usage).unwrap();
        // Serialization always uses the canonical names; the Anthropic names
        // are accepted only as deserialization aliases.
        assert!(json.contains(r#""cache_read_tokens":80"#));
        assert!(json.contains(r#""cache_write_tokens":20"#));
        assert!(!json.contains("cache_read_input_tokens"));
        assert!(!json.contains("cache_creation_input_tokens"));
    }

    #[test]
    fn test_usage_anthropic_aliases() {
        let json = r#"{
            "input_tokens": 100,
            "output_tokens": 50,
            "cache_read_input_tokens": 80,
            "cache_creation_input_tokens": 20
        }"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 50);
        assert_eq!(usage.cache_read_tokens, 80);
        assert_eq!(usage.cache_write_tokens, 20);
        assert_eq!(usage.total_tokens(), 250);
    }
}