Skip to main content

agent_io/llm/
types.rs

1//! Core types for LLM interactions
2
3use serde::{Deserialize, Serialize};
4
5// =============================================================================
6// Tool Definition
7// =============================================================================
8
9/// Tool choice strategy
10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
11#[serde(rename_all = "lowercase")]
12#[derive(Default)]
13pub enum ToolChoice {
14    /// Let the model decide whether to call tools
15    #[default]
16    Auto,
17    /// Force the model to call a tool
18    Required,
19    /// Prevent the model from calling tools
20    None,
21    /// Force a specific tool to be called
22    #[serde(untagged)]
23    Named(String),
24}
25
26impl From<&str> for ToolChoice {
27    fn from(s: &str) -> Self {
28        match s.to_lowercase().as_str() {
29            "auto" => ToolChoice::Auto,
30            "required" => ToolChoice::Required,
31            "none" => ToolChoice::None,
32            name => ToolChoice::Named(name.to_string()),
33        }
34    }
35}
36
37/// JSON Schema for tool parameters
38pub type JsonSchema = serde_json::Map<String, serde_json::Value>;
39
40/// Definition of a tool that can be called by the LLM
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct ToolDefinition {
43    /// Name of the tool
44    pub name: String,
45    /// Description of what the tool does
46    pub description: String,
47    /// JSON Schema for the tool parameters
48    pub parameters: JsonSchema,
49    /// Whether to use strict schema validation
50    #[serde(default = "default_strict")]
51    pub strict: bool,
52}
53
54fn default_strict() -> bool {
55    true
56}
57
58impl ToolDefinition {
59    /// Create a new tool definition
60    pub fn new(
61        name: impl Into<String>,
62        description: impl Into<String>,
63        parameters: JsonSchema,
64    ) -> Self {
65        Self {
66            name: name.into(),
67            description: description.into(),
68            parameters,
69            strict: true,
70        }
71    }
72
73    /// Set strict mode
74    pub fn with_strict(mut self, strict: bool) -> Self {
75        self.strict = strict;
76        self
77    }
78}
79
80// =============================================================================
81// Function Call
82// =============================================================================
83
84/// Function call from the LLM
85#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
86pub struct Function {
87    /// Name of the function to call
88    pub name: String,
89    /// JSON string of arguments
90    pub arguments: String,
91}
92
93impl Function {
94    /// Parse arguments as a specific type
95    pub fn parse_args<T: for<'de> Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
96        serde_json::from_str(&self.arguments)
97    }
98}
99
100/// Tool call from the LLM
101#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
102pub struct ToolCall {
103    /// Unique identifier for the tool call
104    pub id: String,
105    /// The function to call
106    pub function: Function,
107    /// Type of tool (always "function" for now)
108    #[serde(default = "default_tool_type")]
109    #[serde(rename = "type")]
110    pub tool_type: String,
111    /// Thought signature for Gemini thinking models
112    #[serde(skip_serializing_if = "Option::is_none")]
113    pub thought_signature: Option<String>,
114}
115
116fn default_tool_type() -> String {
117    "function".to_string()
118}
119
120impl ToolCall {
121    /// Create a new tool call
122    pub fn new(
123        id: impl Into<String>,
124        name: impl Into<String>,
125        arguments: impl Into<String>,
126    ) -> Self {
127        Self {
128            id: id.into(),
129            function: Function {
130                name: name.into(),
131                arguments: arguments.into(),
132            },
133            tool_type: "function".to_string(),
134            thought_signature: None,
135        }
136    }
137
138    /// Parse arguments as a specific type
139    pub fn parse_args<T: for<'de> Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
140        self.function.parse_args()
141    }
142}
143
144// =============================================================================
145// Content Parts
146// =============================================================================
147
148/// Text content part
149#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
150pub struct ContentPartText {
151    #[serde(rename = "type")]
152    pub content_type: String,
153    pub text: String,
154}
155
156impl ContentPartText {
157    pub fn new(text: impl Into<String>) -> Self {
158        Self {
159            content_type: "text".to_string(),
160            text: text.into(),
161        }
162    }
163}
164
165impl From<String> for ContentPartText {
166    fn from(text: String) -> Self {
167        Self::new(text)
168    }
169}
170
171impl From<&str> for ContentPartText {
172    fn from(text: &str) -> Self {
173        Self::new(text)
174    }
175}
176
177/// Image content part
178#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct ContentPartImage {
180    #[serde(rename = "type")]
181    pub content_type: String,
182    pub image_url: ImageUrl,
183}
184
185/// Image URL structure
186#[derive(Debug, Clone, Serialize, Deserialize)]
187pub struct ImageUrl {
188    pub url: String,
189    #[serde(skip_serializing_if = "Option::is_none")]
190    pub detail: Option<String>,
191}
192
193impl ContentPartImage {
194    /// Create from URL
195    pub fn from_url(url: impl Into<String>) -> Self {
196        Self {
197            content_type: "image_url".to_string(),
198            image_url: ImageUrl {
199                url: url.into(),
200                detail: None,
201            },
202        }
203    }
204
205    /// Create from base64 data
206    pub fn from_base64(media_type: &str, data: &str) -> Self {
207        Self {
208            content_type: "image_url".to_string(),
209            image_url: ImageUrl {
210                url: format!("data:{};base64,{}", media_type, data),
211                detail: None,
212            },
213        }
214    }
215}
216
217/// Document content part (for PDFs etc.)
218#[derive(Debug, Clone, Serialize, Deserialize)]
219pub struct ContentPartDocument {
220    #[serde(rename = "type")]
221    pub content_type: String,
222    pub source: DocumentSource,
223}
224
225#[derive(Debug, Clone, Serialize, Deserialize)]
226pub struct DocumentSource {
227    #[serde(rename = "type")]
228    pub source_type: String,
229    pub media_type: String,
230    pub data: String,
231}
232
233impl ContentPartDocument {
234    pub fn from_base64(media_type: impl Into<String>, data: impl Into<String>) -> Self {
235        Self {
236            content_type: "document".to_string(),
237            source: DocumentSource {
238                source_type: "base64".to_string(),
239                media_type: media_type.into(),
240                data: data.into(),
241            },
242        }
243    }
244}
245
246/// Thinking content part
247#[derive(Debug, Clone, Serialize, Deserialize)]
248pub struct ContentPartThinking {
249    #[serde(rename = "type")]
250    pub content_type: String,
251    pub thinking: String,
252}
253
254impl ContentPartThinking {
255    pub fn new(thinking: impl Into<String>) -> Self {
256        Self {
257            content_type: "thinking".to_string(),
258            thinking: thinking.into(),
259        }
260    }
261}
262
263/// Redacted thinking content
264#[derive(Debug, Clone, Serialize, Deserialize)]
265pub struct ContentPartRedactedThinking {
266    #[serde(rename = "type")]
267    pub content_type: String,
268    pub data: String,
269}
270
271/// Refusal content part
272#[derive(Debug, Clone, Serialize, Deserialize)]
273pub struct ContentPartRefusal {
274    #[serde(rename = "type")]
275    pub content_type: String,
276    pub refusal: String,
277}
278
279/// Union type for all content parts
280#[derive(Debug, Clone, Serialize, Deserialize)]
281#[serde(untagged)]
282pub enum ContentPart {
283    Text(ContentPartText),
284    Image(ContentPartImage),
285    Document(ContentPartDocument),
286    Thinking(ContentPartThinking),
287    RedactedThinking(ContentPartRedactedThinking),
288    Refusal(ContentPartRefusal),
289}
290
291impl ContentPart {
292    pub fn text(content: impl Into<String>) -> Self {
293        ContentPart::Text(ContentPartText::new(content))
294    }
295
296    pub fn is_text(&self) -> bool {
297        matches!(self, ContentPart::Text(_))
298    }
299
300    pub fn as_text(&self) -> Option<&str> {
301        match self {
302            ContentPart::Text(t) => Some(&t.text),
303            _ => None,
304        }
305    }
306}
307
308impl From<String> for ContentPart {
309    fn from(text: String) -> Self {
310        ContentPart::text(text)
311    }
312}
313
314impl From<&str> for ContentPart {
315    fn from(text: &str) -> Self {
316        ContentPart::text(text)
317    }
318}
319
320// =============================================================================
321// Messages
322// =============================================================================
323
324/// User message
325#[derive(Debug, Clone, Serialize, Deserialize)]
326pub struct UserMessage {
327    pub role: String,
328    pub content: Vec<ContentPart>,
329    #[serde(skip_serializing_if = "Option::is_none")]
330    pub name: Option<String>,
331}
332
333impl UserMessage {
334    pub fn new(content: impl Into<String>) -> Self {
335        Self {
336            role: "user".to_string(),
337            content: vec![ContentPart::text(content)],
338            name: None,
339        }
340    }
341
342    pub fn with_parts(content: Vec<ContentPart>) -> Self {
343        Self {
344            role: "user".to_string(),
345            content,
346            name: None,
347        }
348    }
349
350    pub fn with_name(mut self, name: impl Into<String>) -> Self {
351        self.name = Some(name.into());
352        self
353    }
354}
355
356/// System message
357#[derive(Debug, Clone, Serialize, Deserialize)]
358pub struct SystemMessage {
359    pub role: String,
360    pub content: String,
361}
362
363impl SystemMessage {
364    pub fn new(content: impl Into<String>) -> Self {
365        Self {
366            role: "system".to_string(),
367            content: content.into(),
368        }
369    }
370}
371
372/// Developer message (for o1+ models)
373#[derive(Debug, Clone, Serialize, Deserialize)]
374pub struct DeveloperMessage {
375    pub role: String,
376    pub content: String,
377}
378
379impl DeveloperMessage {
380    pub fn new(content: impl Into<String>) -> Self {
381        Self {
382            role: "developer".to_string(),
383            content: content.into(),
384        }
385    }
386}
387
388/// Assistant message
389#[derive(Debug, Clone, Serialize, Deserialize)]
390pub struct AssistantMessage {
391    pub role: String,
392    #[serde(skip_serializing_if = "Option::is_none")]
393    pub content: Option<String>,
394    #[serde(skip_serializing_if = "Option::is_none")]
395    pub thinking: Option<String>,
396    #[serde(skip_serializing_if = "Option::is_none")]
397    pub redacted_thinking: Option<String>,
398    #[serde(default)]
399    pub tool_calls: Vec<ToolCall>,
400    #[serde(skip_serializing_if = "Option::is_none")]
401    pub refusal: Option<String>,
402}
403
404impl AssistantMessage {
405    pub fn new(content: impl Into<String>) -> Self {
406        Self {
407            role: "assistant".to_string(),
408            content: Some(content.into()),
409            thinking: None,
410            redacted_thinking: None,
411            tool_calls: Vec::new(),
412            refusal: None,
413        }
414    }
415
416    pub fn with_tool_calls(mut self, tool_calls: Vec<ToolCall>) -> Self {
417        self.tool_calls = tool_calls;
418        self
419    }
420
421    pub fn with_thinking(mut self, thinking: impl Into<String>) -> Self {
422        self.thinking = Some(thinking.into());
423        self
424    }
425
426    pub fn is_empty(&self) -> bool {
427        self.content.is_none() && self.thinking.is_none() && self.tool_calls.is_empty()
428    }
429}
430
431/// Tool result message
432#[derive(Debug, Clone, Serialize, Deserialize)]
433pub struct ToolMessage {
434    pub role: String,
435    pub content: String,
436    pub tool_call_id: String,
437    /// Tool name that produced this result
438    #[serde(skip_serializing_if = "Option::is_none")]
439    pub tool_name: Option<String>,
440    /// Whether this is an ephemeral message
441    #[serde(default)]
442    pub ephemeral: bool,
443    /// Whether this ephemeral message has been destroyed
444    #[serde(default)]
445    pub destroyed: bool,
446}
447
448impl ToolMessage {
449    pub fn new(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
450        Self {
451            role: "tool".to_string(),
452            content: content.into(),
453            tool_call_id: tool_call_id.into(),
454            tool_name: None,
455            ephemeral: false,
456            destroyed: false,
457        }
458    }
459
460    pub fn with_tool_name(mut self, name: impl Into<String>) -> Self {
461        self.tool_name = Some(name.into());
462        self
463    }
464
465    pub fn with_ephemeral(mut self, ephemeral: bool) -> Self {
466        self.ephemeral = ephemeral;
467        self
468    }
469
470    pub fn destroy(&mut self) {
471        self.destroyed = true;
472        self.content = "<removed to save context>".to_string();
473    }
474}
475
476/// Union type for all messages
477#[derive(Debug, Clone, Serialize, Deserialize)]
478#[serde(tag = "role")]
479#[serde(rename_all = "lowercase")]
480pub enum Message {
481    User(UserMessage),
482    Assistant(AssistantMessage),
483    System(SystemMessage),
484    Developer(DeveloperMessage),
485    Tool(ToolMessage),
486}
487
488impl Message {
489    pub fn user(content: impl Into<String>) -> Self {
490        Message::User(UserMessage::new(content))
491    }
492
493    pub fn assistant(content: impl Into<String>) -> Self {
494        Message::Assistant(AssistantMessage::new(content))
495    }
496
497    pub fn system(content: impl Into<String>) -> Self {
498        Message::System(SystemMessage::new(content))
499    }
500
501    pub fn tool(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
502        Message::Tool(ToolMessage::new(tool_call_id, content))
503    }
504
505    pub fn role(&self) -> &str {
506        match self {
507            Message::User(_) => "user",
508            Message::Assistant(_) => "assistant",
509            Message::System(_) => "system",
510            Message::Developer(_) => "developer",
511            Message::Tool(_) => "tool",
512        }
513    }
514}
515
516impl From<UserMessage> for Message {
517    fn from(msg: UserMessage) -> Self {
518        Message::User(msg)
519    }
520}
521
522impl From<AssistantMessage> for Message {
523    fn from(msg: AssistantMessage) -> Self {
524        Message::Assistant(msg)
525    }
526}
527
528impl From<SystemMessage> for Message {
529    fn from(msg: SystemMessage) -> Self {
530        Message::System(msg)
531    }
532}
533
534impl From<ToolMessage> for Message {
535    fn from(msg: ToolMessage) -> Self {
536        Message::Tool(msg)
537    }
538}
539
540// =============================================================================
541// Response Types
542// =============================================================================
543
544/// Token usage information
545#[derive(Debug, Clone, Default, Serialize, Deserialize)]
546pub struct Usage {
547    pub prompt_tokens: u64,
548    pub completion_tokens: u64,
549    pub total_tokens: u64,
550    #[serde(skip_serializing_if = "Option::is_none")]
551    pub prompt_cached_tokens: Option<u64>,
552    #[serde(skip_serializing_if = "Option::is_none")]
553    pub prompt_cache_creation_tokens: Option<u64>,
554    #[serde(skip_serializing_if = "Option::is_none")]
555    pub prompt_image_tokens: Option<u64>,
556}
557
558impl Usage {
559    pub fn new(prompt_tokens: u64, completion_tokens: u64) -> Self {
560        Self {
561            prompt_tokens,
562            completion_tokens,
563            total_tokens: prompt_tokens + completion_tokens,
564            ..Default::default()
565        }
566    }
567}
568
569/// Stop reason for completion
570#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
571#[serde(rename_all = "snake_case")]
572pub enum StopReason {
573    EndTurn,
574    StopSequence,
575    ToolUse,
576    MaxTokens,
577    #[serde(other)]
578    Unknown,
579}
580
581/// Chat completion response
582#[derive(Debug, Clone)]
583pub struct ChatCompletion {
584    /// Text content (if any)
585    pub content: Option<String>,
586    /// Thinking content (for extended thinking models)
587    pub thinking: Option<String>,
588    /// Redacted thinking content
589    pub redacted_thinking: Option<String>,
590    /// Tool calls (if any)
591    pub tool_calls: Vec<ToolCall>,
592    /// Token usage
593    pub usage: Option<Usage>,
594    /// Why the completion stopped
595    pub stop_reason: Option<StopReason>,
596}
597
598impl ChatCompletion {
599    pub fn text(content: impl Into<String>) -> Self {
600        Self {
601            content: Some(content.into()),
602            thinking: None,
603            redacted_thinking: None,
604            tool_calls: Vec::new(),
605            usage: None,
606            stop_reason: None,
607        }
608    }
609
610    pub fn with_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
611        Self {
612            content: None,
613            thinking: None,
614            redacted_thinking: None,
615            tool_calls,
616            usage: None,
617            stop_reason: Some(StopReason::ToolUse),
618        }
619    }
620
621    pub fn has_tool_calls(&self) -> bool {
622        !self.tool_calls.is_empty()
623    }
624
625    pub fn has_content(&self) -> bool {
626        self.content.is_some() && self.content.as_ref().is_some_and(|c| !c.is_empty())
627    }
628}
629
630// =============================================================================
631// Cache Control
632// =============================================================================
633
634/// Cache control for prompt caching
635#[derive(Debug, Clone, Serialize, Deserialize)]
636pub struct CacheControl {
637    #[serde(rename = "type")]
638    pub control_type: CacheControlType,
639}
640
641#[derive(Debug, Clone, Serialize, Deserialize)]
642#[serde(rename_all = "snake_case")]
643pub enum CacheControlType {
644    Ephemeral,
645}
646
647impl CacheControl {
648    pub fn ephemeral() -> Self {
649        Self {
650            control_type: CacheControlType::Ephemeral,
651        }
652    }
653}