Skip to main content

agent_io/llm/types/
tool.rs

//! Tool definition and function call types.
2
3use serde::{Deserialize, Serialize};
4
5use super::content::JsonSchema;
6
7/// Tool choice strategy
8#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
9#[serde(rename_all = "lowercase")]
10#[derive(Default)]
11pub enum ToolChoice {
12    /// Let the model decide whether to call tools
13    #[default]
14    Auto,
15    /// Force the model to call a tool
16    Required,
17    /// Prevent the model from calling tools
18    None,
19    /// Force a specific tool to be called
20    #[serde(untagged)]
21    Named(String),
22}
23
24impl From<&str> for ToolChoice {
25    fn from(s: &str) -> Self {
26        match s.to_lowercase().as_str() {
27            "auto" => ToolChoice::Auto,
28            "required" => ToolChoice::Required,
29            "none" => ToolChoice::None,
30            name => ToolChoice::Named(name.to_string()),
31        }
32    }
33}
34
35/// Definition of a tool that can be called by the LLM
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct ToolDefinition {
38    /// Name of the tool
39    pub name: String,
40    /// Description of what the tool does
41    pub description: String,
42    /// JSON Schema for the tool parameters
43    pub parameters: JsonSchema,
44    /// Whether to use strict schema validation
45    #[serde(default = "default_strict")]
46    pub strict: bool,
47}
48
/// Serde default for `ToolDefinition::strict`: strict validation is enabled
/// unless the input explicitly disables it.
fn default_strict() -> bool {
    true
}
52
53impl ToolDefinition {
54    /// Create a new tool definition
55    pub fn new(
56        name: impl Into<String>,
57        description: impl Into<String>,
58        parameters: JsonSchema,
59    ) -> Self {
60        Self {
61            name: name.into(),
62            description: description.into(),
63            parameters,
64            strict: true,
65        }
66    }
67
68    /// Set strict mode
69    pub fn with_strict(mut self, strict: bool) -> Self {
70        self.strict = strict;
71        self
72    }
73}
74
75/// Function call from the LLM
76#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
77pub struct Function {
78    /// Name of the function to call
79    pub name: String,
80    /// JSON string of arguments
81    pub arguments: String,
82}
83
84impl Function {
85    /// Parse arguments as a specific type
86    pub fn parse_args<T: for<'de> Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
87        serde_json::from_str(&self.arguments)
88    }
89}
90
91/// Tool call from the LLM
92#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
93pub struct ToolCall {
94    /// Unique identifier for the tool call
95    pub id: String,
96    /// The function to call
97    pub function: Function,
98    /// Type of tool (always "function" for now)
99    #[serde(default = "default_tool_type")]
100    #[serde(rename = "type")]
101    pub tool_type: String,
102    /// Thought signature for Gemini thinking models
103    #[serde(skip_serializing_if = "Option::is_none")]
104    pub thought_signature: Option<String>,
105}
106
/// Serde default for `ToolCall::tool_type`.
fn default_tool_type() -> String {
    String::from("function")
}
110
111impl ToolCall {
112    /// Create a new tool call
113    pub fn new(
114        id: impl Into<String>,
115        name: impl Into<String>,
116        arguments: impl Into<String>,
117    ) -> Self {
118        Self {
119            id: id.into(),
120            function: Function {
121                name: name.into(),
122                arguments: arguments.into(),
123            },
124            tool_type: "function".to_string(),
125            thought_signature: None,
126        }
127    }
128
129    /// Parse arguments as a specific type
130    pub fn parse_args<T: for<'de> Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
131        self.function.parse_args()
132    }
133}