Skip to main content

limit_llm/
types.rs

1//! Core types for LLM message passing and tool definitions.
2//!
3//! This module provides the fundamental types used throughout `limit-llm` for
4//! constructing messages, defining tools, and handling responses from LLM providers.
5//!
6//! # Overview
7//!
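//! - [`Message`] — A single message in a conversation: role, optional content, and tool-call metadata
//! - [`MessageContent`] / [`ContentPart`] — Plain-text or multimodal (text + image) message content
//! - [`Role`] — The sender role (User, Assistant, System, or Tool)
//! - [`Tool`] — A function definition the LLM may call
//! - [`ToolCall`] — A call to one of those functions, as emitted by the assistant
//! - [`CacheControl`] — Prompt-caching hints attached to a message
//! - [`Response`] — Complete response with content, tool calls, and usage
//! - [`Usage`] — Token counts for input, output, and cache reads/writes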
13
14use serde::{Deserialize, Serialize};
15
16/// Content part for multimodal messages
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
18#[serde(tag = "type", rename_all = "lowercase")]
19pub enum ContentPart {
20    /// Text content
21    Text { text: String },
22    /// Image content (base64 or URL)
23    #[serde(rename = "image_url")]
24    ImageUrl { image_url: ImageUrl },
25}
26
/// Image reference carried by [`ContentPart::ImageUrl`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ImageUrl {
    /// Either a regular URL or a `data:<media-type>;base64,<payload>` URI
    /// (the form built by [`ContentPart::image_base64`]).
    pub url: String,
    /// Optional detail level (e.g. "low", "high", "auto"; not validated here).
    /// Omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
36
37impl ContentPart {
38    /// Create a text content part
39    pub fn text(text: impl Into<String>) -> Self {
40        ContentPart::Text { text: text.into() }
41    }
42
43    /// Create an image URL content part
44    pub fn image_url(url: impl Into<String>) -> Self {
45        ContentPart::ImageUrl {
46            image_url: ImageUrl {
47                url: url.into(),
48                detail: None,
49            },
50        }
51    }
52
53    /// Create an image from base64 data
54    pub fn image_base64(media_type: &str, base64_data: &str) -> Self {
55        ContentPart::ImageUrl {
56            image_url: ImageUrl {
57                url: format!("data:{};base64,{}", media_type, base64_data),
58                detail: None,
59            },
60        }
61    }
62}
63
64/// A single message in a conversation.
65///
66/// Messages are the fundamental unit of communication with LLM providers.
67/// Each message has a role (who sent it), content (the text), and optionally
68/// tool calls (for function calling).
69///
70/// # Examples
71///
72/// ## User Message
73///
74/// ```
75/// use limit_llm::{Message, Role};
76///
77/// let msg = Message {
78///     role: Role::User,
79///     content: Some(MessageContent::text("What is the capital of France?")),
80///     tool_calls: None,
81///     tool_call_id: None,
82///     cache_control: None,
83/// };
84/// ```
85///
86/// ## Assistant Message with Tool Calls
87///
88/// ```
89/// use limit_llm::{Message, Role, ToolCall, FunctionCall};
90/// use serde_json::json;
91///
92/// let msg = Message {
93///     role: Role::Assistant,
94///     content: None,
95///     tool_calls: Some(vec![ToolCall {
96///         id: "call_123".to_string(),
97///         tool_type: "function".to_string(),
98///         function: FunctionCall {
99///             name: "get_weather".to_string(),
100///             arguments: json!({"location": "Paris"}).to_string(),
101///         },
102///     }]),
103///     tool_call_id: None,
104///     cache_control: None,
105/// };
106/// ```
107///
108/// ## Tool Result Message
109///
110/// ```
111/// use limit_llm::{Message, Role};
112///
113/// let msg = Message {
114///     role: Role::Tool,
115///     content: Some(r#"{"temp": 22, "condition": "sunny"}"#.to_string()),
116///     tool_calls: None,
117///     tool_call_id: Some("call_123".to_string()),
118///     cache_control: None,
119/// };
120/// ```
121///
122/// ## Assistant Message with Tool Call
123///
124/// ```
125/// use limit_llm::{Message, Role, ToolCall, FunctionCall};
126///
127/// let msg = Message {
128///     role: Role::Assistant,
129///     content: None,
130///     tool_calls: Some(vec![ToolCall {
131///         id: "call_123".to_string(),
132///         tool_type: "function".to_string(),
133///         function: FunctionCall {
134///             name: "get_weather".to_string(),
135///             arguments: r#"{"location": "Paris"}"#.to_string(),
136///         },
137///     }]),
138///     tool_call_id: None,
139///     cache_control: None,
140/// };
141/// ```
142///
143/// ## Tool Result Message
144///
145/// ```
146/// use limit_llm::{Message, Role};
147///
148/// let msg = Message {
149///     role: Role::Tool,
150///     content: Some(r#"{"temp": 22, "condition": "sunny"}"#.to_string()),
151///     tool_calls: None,
152///     tool_call_id: Some("call_123".to_string()),
153///     cache_control: None,
154/// };
155/// ```
156#[derive(Debug, Clone, Serialize, Deserialize)]
157pub struct Message {
158    /// The role of the message sender.
159    pub role: Role,
160
161    /// The content of the message.
162    ///
163    /// Can be either a simple string or an array of content parts (for multimodal).
164    /// Can be `None` for assistant messages that only contain tool calls.
165    #[serde(skip_serializing_if = "Option::is_none")]
166    pub content: Option<MessageContent>,
167
168    /// Tool calls made by the assistant.
169    ///
170    /// Only present in assistant messages when the LLM decides to call tools.
171    #[serde(skip_serializing_if = "Option::is_none")]
172    pub tool_calls: Option<Vec<ToolCall>>,
173
174    /// ID of the tool call this message is responding to.
175    ///
176    /// Only present in tool result messages.
177    #[serde(skip_serializing_if = "Option::is_none")]
178    pub tool_call_id: Option<String>,
179
180    /// Cache control for prompt caching (Anthropic/OpenAI).
181    #[serde(skip_serializing_if = "Option::is_none")]
182    pub cache_control: Option<CacheControl>,
183}
184
/// The body of a [`Message`]: either plain text or a list of multimodal parts.
///
/// Serialized `untagged`: `Text` maps to a bare JSON string and `Parts` to a
/// JSON array of [`ContentPart`] objects. On deserialization serde tries the
/// variants in declaration order and keeps the first that matches.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    /// Plain text content, serialized as a JSON string.
    Text(String),
    /// Multimodal content parts (text and images), serialized as a JSON array.
    Parts(Vec<ContentPart>),
}
194
195impl MessageContent {
196    /// Create simple text content
197    pub fn text(text: impl Into<String>) -> Self {
198        MessageContent::Text(text.into())
199    }
200
201    /// Create multimodal content with parts
202    pub fn parts(parts: Vec<ContentPart>) -> Self {
203        MessageContent::Parts(parts)
204    }
205
206    /// Get text if this is simple text content
207    pub fn as_text(&self) -> Option<&str> {
208        match self {
209            MessageContent::Text(text) => Some(text),
210            MessageContent::Parts(_) => None,
211        }
212    }
213
214    /// Get all text content (concatenates text parts)
215    pub fn to_text(&self) -> String {
216        match self {
217            MessageContent::Text(text) => text.clone(),
218            MessageContent::Parts(parts) => parts
219                .iter()
220                .filter_map(|part| match part {
221                    ContentPart::Text { text } => Some(text.clone()),
222                    _ => None,
223                })
224                .collect::<Vec<_>>()
225                .join(""),
226        }
227    }
228}
229
230impl std::fmt::Display for MessageContent {
231    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
232        match self {
233            MessageContent::Text(text) => write!(f, "{}", text),
234            MessageContent::Parts(parts) => {
235                let text = parts
236                    .iter()
237                    .filter_map(|p| match p {
238                        ContentPart::Text { text } => Some(text.as_str()),
239                        _ => None,
240                    })
241                    .collect::<Vec<_>>()
242                    .join("");
243                write!(f, "{}", text)
244            }
245        }
246    }
247}
248
249impl From<String> for MessageContent {
250    fn from(text: String) -> Self {
251        MessageContent::Text(text)
252    }
253}
254
255impl From<&str> for MessageContent {
256    fn from(text: &str) -> Self {
257        MessageContent::Text(text.to_string())
258    }
259}
260
261/// The role of a message sender in a conversation.
262///
263/// # Serialization
264///
265/// Roles are serialized as lowercase strings:
266/// - `User` → `"user"`
267/// - `Assistant` → `"assistant"`
268/// - `System` → `"system"`
269/// - `Tool` → `"tool"`
270#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
271#[serde(rename_all = "lowercase")]
272pub enum Role {
273    /// A message from the user.
274    User,
275
276    /// A message from the assistant (LLM).
277    Assistant,
278
279    /// A system message providing instructions or context.
280    System,
281
282    /// A tool result message containing the output of a tool execution.
283    Tool,
284}
285
/// Cache control settings for prompt caching.
///
/// Attached to a [`Message`] to enable API-level caching and reduce input
/// token costs. Supported by Anthropic Claude and OpenAI models.
/// Both fields are free-form strings; this type does not validate them.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct CacheControl {
    /// The type of cache control, serialized under the JSON key `"type"`.
    /// Currently only "ephemeral" is supported.
    #[serde(rename = "type")]
    pub cache_type: String,

    /// Time-to-live for the cache entry (Anthropic only).
    /// Options: "5m" (default), "1h" when long retention is enabled.
    /// Omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ttl: Option<String>,
}
301
302impl CacheControl {
303    /// Create a new ephemeral cache control with default TTL.
304    pub fn ephemeral() -> Self {
305        Self {
306            cache_type: "ephemeral".to_string(),
307            ttl: None,
308        }
309    }
310
311    /// Create an ephemeral cache control with long TTL (1 hour).
312    pub fn ephemeral_long() -> Self {
313        Self {
314            cache_type: "ephemeral".to_string(),
315            ttl: Some("1h".to_string()),
316        }
317    }
318}
319
320/// A tool call made by the assistant.
321///
322/// When an LLM decides to use a tool, it returns a `ToolCall` containing
323/// the tool ID, type, and the function to call with its arguments.
324#[derive(Debug, Clone, Serialize, Deserialize)]
325pub struct ToolCall {
326    /// Unique identifier for this tool call.
327    pub id: String,
328
329    /// The type of tool (always "function" for now).
330    #[serde(rename = "type")]
331    pub tool_type: String,
332
333    /// The function call details.
334    pub function: FunctionCall,
335}
336
337/// A function call with name and JSON arguments.
338///
339/// The `arguments` field contains a JSON string representing the function
340/// parameters as defined in the tool schema.
341#[derive(Debug, Clone, Serialize, Deserialize)]
342pub struct FunctionCall {
343    /// The name of the function to call.
344    pub name: String,
345
346    /// JSON string representation of the function arguments.
347    ///
348    /// This is a string because LLMs return arguments as JSON strings
349    /// during streaming. Parse with `serde_json::from_str` if needed.
350    pub arguments: String,
351}
352
/// A tool definition for LLM function calling.
///
/// Tools allow LLMs to perform actions by calling functions with structured
/// parameters. Define tools with JSON Schema for the parameters. When the
/// model decides to use a tool, it responds with a [`ToolCall`] naming the
/// function and carrying its arguments.
///
/// # Example
///
/// ```
/// use limit_llm::{Tool, ToolFunction};
/// use serde_json::json;
///
/// let tool = Tool {
///     tool_type: "function".to_string(),
///     function: ToolFunction {
///         name: "get_weather".to_string(),
///         description: "Get current weather for a location".to_string(),
///         parameters: json!({
///             "type": "object",
///             "properties": {
///                 "location": {
///                     "type": "string",
///                     "description": "City name"
///                 }
///             },
///             "required": ["location"]
///         }),
///     },
/// };
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    /// The type of tool (always "function" for now), serialized under the
    /// JSON key `"type"`.
    #[serde(rename = "type")]
    pub tool_type: String,

    /// The function definition.
    pub function: ToolFunction,
}
391
/// Function definition within a [`Tool`]: its name, description, and
/// parameter schema.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunction {
    /// The function name. Must be unique within the tool set (not enforced by
    /// this type).
    pub name: String,

    /// Human-readable description of what the function does.
    /// This helps the LLM understand when to use the tool.
    pub description: String,

    /// JSON Schema defining the function parameters, stored as an arbitrary
    /// JSON value (not validated here).
    ///
    /// Use `serde_json::json!` to construct the schema inline.
    pub parameters: serde_json::Value,
}
407
/// A complete response from an LLM provider.
///
/// Bundles the response text, any tool calls requested by the assistant, and
/// the token usage reported for the request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Response {
    /// The text content of the response.
    pub content: String,

    /// Tool calls made by the assistant, if any.
    /// Omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,

    /// Token usage statistics.
    pub usage: Usage,
}
421
/// Token usage statistics for a request.
///
/// Tracks the number of tokens used in the prompt (input) and
/// completion (output), plus prompt-cache reads and writes. Use with
/// [`TrackingDb`](crate::TrackingDb) to monitor costs across sessions.
///
/// When deserializing, the cache fields default to `0` if absent. They also
/// accept Anthropic's field names (`cache_read_input_tokens`,
/// `cache_creation_input_tokens`). They are always serialized under
/// `cache_read_tokens` / `cache_write_tokens`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    /// Number of tokens in the prompt/input.
    pub input_tokens: u64,

    /// Number of tokens in the completion/output.
    pub output_tokens: u64,

    /// Number of tokens read from cache (~10% of input cost).
    #[serde(default, alias = "cache_read_input_tokens")]
    pub cache_read_tokens: u64,

    /// Number of tokens written to cache.
    #[serde(default, alias = "cache_creation_input_tokens")]
    pub cache_write_tokens: u64,
}
443
444impl Usage {
445    /// Calculate total tokens including cache operations.
446    pub fn total_tokens(&self) -> u64 {
447        self.input_tokens + self.output_tokens + self.cache_read_tokens + self.cache_write_tokens
448    }
449}
450
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_serialization() {
        let msg = Message {
            role: Role::User,
            content: Some(MessageContent::text("Hello")),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        // Plain-text content serializes as a bare string; `None` fields are omitted.
        assert_eq!(json, r#"{"role":"user","content":"Hello"}"#);
        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.role, Role::User);
        assert_eq!(msg.content, deserialized.content);
    }

    #[test]
    fn test_multimodal_message_serialization() {
        let msg = Message {
            role: Role::User,
            content: Some(MessageContent::parts(vec![
                ContentPart::text("Describe this image"),
                ContentPart::image_url("https://example.com/cat.png"),
            ])),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        };
        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this image"},
                    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}}
                ]
            })
        );
        let deserialized: Message = serde_json::from_value(value).unwrap();
        assert_eq!(msg.content, deserialized.content);
    }

    #[test]
    fn test_message_with_tool_calls() {
        let msg = Message {
            role: Role::Assistant,
            content: Some(MessageContent::text("")),
            tool_calls: Some(vec![ToolCall {
                id: "call_123".to_string(),
                tool_type: "function".to_string(),
                function: FunctionCall {
                    name: "test_tool".to_string(),
                    arguments: serde_json::json!({"arg": "value"}).to_string(),
                },
            }]),
            tool_call_id: None,
            cache_control: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: Message = serde_json::from_str(&json).unwrap();
        let calls = deserialized
            .tool_calls
            .expect("tool_calls should survive a round trip");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, "call_123");
        assert_eq!(calls[0].tool_type, "function");
        assert_eq!(calls[0].function.name, "test_tool");
        assert_eq!(calls[0].function.arguments, r#"{"arg":"value"}"#);
    }

    #[test]
    fn test_tool_result_message() {
        let msg = Message {
            role: Role::Tool,
            content: Some(MessageContent::text("result output")),
            tool_calls: None,
            tool_call_id: Some("call_123".to_string()),
            cache_control: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""tool_call_id":"call_123""#));
        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.role, Role::Tool);
        assert_eq!(deserialized.tool_call_id.as_deref(), Some("call_123"));
    }

    #[test]
    fn test_assistant_with_tool_calls_serialization() {
        let msg = Message {
            role: Role::Assistant,
            content: None,
            tool_calls: Some(vec![ToolCall {
                id: "call_123".to_string(),
                tool_type: "function".to_string(),
                function: FunctionCall {
                    name: "test_tool".to_string(),
                    arguments: serde_json::json!({}).to_string(),
                },
            }]),
            tool_call_id: None,
            cache_control: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        // `None` fields must be omitted entirely, never serialized as `null`.
        assert!(!json.contains("\"content\""));
        assert!(!json.contains("tool_call_id"));
        assert!(!json.contains("cache_control"));
        assert!(json.contains("tool_calls"));
    }

    #[test]
    fn test_role_serialization() {
        let cases = [
            (Role::User, "\"user\""),
            (Role::Assistant, "\"assistant\""),
            (Role::System, "\"system\""),
            (Role::Tool, "\"tool\""),
        ];
        for (role, expected) in &cases {
            assert_eq!(serde_json::to_string(role).unwrap(), *expected);
            let back: Role = serde_json::from_str(expected).unwrap();
            assert_eq!(&back, role);
        }
    }

    #[test]
    fn test_tool_serialization() {
        let tool = Tool {
            tool_type: "function".to_string(),
            function: ToolFunction {
                name: "test_tool".to_string(),
                description: "A test tool".to_string(),
                parameters: serde_json::json!({"type": "object"}),
            },
        };
        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains(r#""type":"function""#));
        let deserialized: Tool = serde_json::from_str(&json).unwrap();
        assert_eq!(tool.function.name, deserialized.function.name);
        assert_eq!(tool.function.parameters, deserialized.function.parameters);
    }

    #[test]
    fn test_response_serialization() {
        let response = Response {
            content: "Hello, world!".to_string(),
            tool_calls: None,
            usage: Usage {
                input_tokens: 10,
                output_tokens: 5,
                cache_read_tokens: 0,
                cache_write_tokens: 0,
            },
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(!json.contains("tool_calls"));
        let deserialized: Response = serde_json::from_str(&json).unwrap();
        assert_eq!(response.content, deserialized.content);
        assert_eq!(response.usage.input_tokens, deserialized.usage.input_tokens);
    }

    #[test]
    fn test_usage_serialization() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 0,
            cache_write_tokens: 0,
        };
        let json = serde_json::to_string(&usage).unwrap();
        let deserialized: Usage = serde_json::from_str(&json).unwrap();
        assert_eq!(usage.input_tokens, deserialized.input_tokens);
        assert_eq!(usage.output_tokens, deserialized.output_tokens);
    }

    #[test]
    fn test_usage_missing_cache_fields_default_to_zero() {
        let usage: Usage =
            serde_json::from_str(r#"{"input_tokens": 1, "output_tokens": 2}"#).unwrap();
        assert_eq!(usage.cache_read_tokens, 0);
        assert_eq!(usage.cache_write_tokens, 0);
        assert_eq!(usage.total_tokens(), 3);
    }

    #[test]
    fn test_cache_control_serialization() {
        let cache = CacheControl::ephemeral();
        let json = serde_json::to_string(&cache).unwrap();
        assert_eq!(json, r#"{"type":"ephemeral"}"#);

        let cache_long = CacheControl::ephemeral_long();
        let json_long = serde_json::to_string(&cache_long).unwrap();
        assert_eq!(json_long, r#"{"type":"ephemeral","ttl":"1h"}"#);
    }

    #[test]
    fn test_message_with_cache_control() {
        let msg = Message {
            role: Role::User,
            content: Some(MessageContent::text("Hello")),
            tool_calls: None,
            tool_call_id: None,
            cache_control: Some(CacheControl::ephemeral()),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("cache_control"));
        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.cache_control, Some(CacheControl::ephemeral()));
    }

    #[test]
    fn test_usage_with_cache_fields() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 80,
            cache_write_tokens: 20,
        };
        assert_eq!(usage.total_tokens(), 250);

        let json = serde_json::to_string(&usage).unwrap();
        assert!(json.contains("cache_read_tokens"));
        assert!(json.contains("cache_write_tokens"));
    }

    #[test]
    fn test_usage_anthropic_aliases() {
        let json = r#"{
            "input_tokens": 100,
            "output_tokens": 50,
            "cache_read_input_tokens": 80,
            "cache_creation_input_tokens": 20
        }"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 50);
        assert_eq!(usage.cache_read_tokens, 80);
        assert_eq!(usage.cache_write_tokens, 20);
        assert_eq!(usage.total_tokens(), 250);
    }

    #[test]
    fn test_content_part_serialization() {
        let text = serde_json::to_value(ContentPart::text("hi")).unwrap();
        assert_eq!(text, serde_json::json!({"type": "text", "text": "hi"}));

        let image = serde_json::to_value(ContentPart::image_base64("image/png", "AAAA")).unwrap();
        assert_eq!(
            image,
            serde_json::json!({
                "type": "image_url",
                "image_url": {"url": "data:image/png;base64,AAAA"}
            })
        );

        let back: ContentPart = serde_json::from_value(image).unwrap();
        assert_eq!(back, ContentPart::image_base64("image/png", "AAAA"));
    }

    #[test]
    fn test_message_content_deserialization() {
        let text: MessageContent = serde_json::from_str(r#""plain""#).unwrap();
        assert_eq!(text, MessageContent::text("plain"));

        let parts: MessageContent = serde_json::from_str(
            r#"[{"type":"text","text":"a"},{"type":"image_url","image_url":{"url":"u","detail":"low"}},{"type":"text","text":"b"}]"#,
        )
        .unwrap();
        assert_eq!(parts.as_text(), None);
        assert_eq!(parts.to_text(), "ab");
        assert_eq!(parts.to_string(), "ab");
    }

    #[test]
    fn test_message_content_text_helpers() {
        let plain = MessageContent::text("hello");
        assert_eq!(plain.as_text(), Some("hello"));
        assert_eq!(plain.to_text(), "hello");
        assert_eq!(plain.to_string(), "hello");

        assert_eq!(MessageContent::from("hi"), MessageContent::text("hi"));
        assert_eq!(MessageContent::from(String::from("hi")), MessageContent::text("hi"));
    }
}