Skip to main content

claudius/types/
tool_choice.rs

1use serde::{Deserialize, Serialize};
2
3/// Configuration for Claude's tool choice behavior.
4///
5/// This can be one of the following:
6/// - "auto": Let the model decide if and when to use tools
7/// - "any": Allow the model to use any available tool
8/// - "tool": Force the model to use a specific named tool
9/// - "none": Do not use any tools
10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
11#[serde(tag = "type")]
12#[serde(rename_all = "lowercase")]
13pub enum ToolChoice {
14    /// Automatic tool choice
15    Auto {
16        /// Whether to disable parallel tool use.
17        ///
18        /// Defaults to `false`. If set to `true`, the model will output at most one tool use.
19        #[serde(skip_serializing_if = "Option::is_none")]
20        disable_parallel_tool_use: Option<bool>,
21    },
22
23    /// Any tool choice
24    Any {
25        /// Whether to disable parallel tool use.
26        ///
27        /// Defaults to `false`. If set to `true`, the model will output exactly one tool use.
28        #[serde(skip_serializing_if = "Option::is_none")]
29        disable_parallel_tool_use: Option<bool>,
30    },
31
32    /// Specific tool choice
33    Tool {
34        /// The name of the tool to use.
35        name: String,
36
37        /// Whether to disable parallel tool use.
38        ///
39        /// Defaults to `false`. If set to `true`, the model will output exactly one tool use.
40        #[serde(skip_serializing_if = "Option::is_none")]
41        disable_parallel_tool_use: Option<bool>,
42    },
43
44    /// No tools
45    None,
46}
47
48impl ToolChoice {
49    /// Create a new `ToolChoice` with auto mode.
50    pub fn auto() -> Self {
51        Self::Auto {
52            disable_parallel_tool_use: None,
53        }
54    }
55
56    /// Create a new `ToolChoice` with auto mode, specifying whether to disable parallel tool use.
57    pub fn auto_with_disable_parallel(disable: bool) -> Self {
58        Self::Auto {
59            disable_parallel_tool_use: Some(disable),
60        }
61    }
62
63    /// Create a new `ToolChoice` allowing any tool.
64    pub fn any() -> Self {
65        Self::Any {
66            disable_parallel_tool_use: None,
67        }
68    }
69
70    /// Create a new `ToolChoice` allowing any tool, specifying whether to disable parallel tool use.
71    pub fn any_with_disable_parallel(disable: bool) -> Self {
72        Self::Any {
73            disable_parallel_tool_use: Some(disable),
74        }
75    }
76
77    /// Create a new `ToolChoice` with a specific named tool.
78    pub fn tool(name: impl Into<String>) -> Self {
79        Self::Tool {
80            name: name.into(),
81            disable_parallel_tool_use: None,
82        }
83    }
84
85    /// Create a new `ToolChoice` with a specific named tool, specifying whether to disable parallel tool use.
86    pub fn tool_with_disable_parallel(name: impl Into<String>, disable: bool) -> Self {
87        Self::Tool {
88            name: name.into(),
89            disable_parallel_tool_use: Some(disable),
90        }
91    }
92
93    /// Create a new `ToolChoice` with no tools.
94    pub fn none() -> Self {
95        Self::None
96    }
97}
98
99impl Default for ToolChoice {
100    fn default() -> Self {
101        Self::auto()
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, to_value, Value};

    /// Serializes `choice` and asserts that it produces exactly `expected`.
    fn assert_serializes_to(choice: &ToolChoice, expected: Value) {
        assert_eq!(to_value(choice).unwrap(), expected);
    }

    /// Deserializes `value` into a `ToolChoice`, panicking if it is not valid.
    fn parse(value: Value) -> ToolChoice {
        serde_json::from_value(value).expect("valid ToolChoice JSON")
    }

    #[test]
    fn tool_choice_auto() {
        assert_serializes_to(&ToolChoice::auto(), json!({ "type": "auto" }));
    }

    #[test]
    fn tool_choice_any() {
        assert_serializes_to(&ToolChoice::any(), json!({ "type": "any" }));
    }

    #[test]
    fn tool_choice_tool() {
        assert_serializes_to(
            &ToolChoice::tool("my_tool"),
            json!({
                "name": "my_tool",
                "type": "tool"
            }),
        );
    }

    #[test]
    fn tool_choice_none() {
        assert_serializes_to(&ToolChoice::none(), json!({ "type": "none" }));
    }

    #[test]
    fn tool_choice_auto_with_disable_parallel() {
        assert_serializes_to(
            &ToolChoice::auto_with_disable_parallel(true),
            json!({
                "type": "auto",
                "disable_parallel_tool_use": true
            }),
        );
    }

    #[test]
    fn tool_choice_any_with_disable_parallel() {
        assert_serializes_to(
            &ToolChoice::any_with_disable_parallel(true),
            json!({
                "type": "any",
                "disable_parallel_tool_use": true
            }),
        );
    }

    #[test]
    fn tool_choice_tool_with_disable_parallel() {
        // An explicit `false` is still serialized because the field is `Some(false)`.
        assert_serializes_to(
            &ToolChoice::tool_with_disable_parallel("my_tool", false),
            json!({
                "type": "tool",
                "name": "my_tool",
                "disable_parallel_tool_use": false
            }),
        );
    }

    #[test]
    fn tool_choice_default_is_auto() {
        assert_eq!(ToolChoice::default(), ToolChoice::auto());
    }

    #[test]
    fn tool_choice_deserialization_auto() {
        let json = json!({
            "type": "auto",
            "disable_parallel_tool_use": true
        });
        assert_eq!(parse(json), ToolChoice::auto_with_disable_parallel(true));
    }

    #[test]
    fn tool_choice_deserialization_tool() {
        let json = json!({
            "name": "my_tool",
            "type": "tool",
            "disable_parallel_tool_use": true
        });
        assert_eq!(
            parse(json),
            ToolChoice::tool_with_disable_parallel("my_tool", true)
        );
    }

    #[test]
    fn tool_choice_deserialization_missing_flag_is_none() {
        assert_eq!(parse(json!({ "type": "auto" })), ToolChoice::auto());
        assert_eq!(parse(json!({ "type": "any" })), ToolChoice::any());
        assert_eq!(
            parse(json!({ "type": "tool", "name": "my_tool" })),
            ToolChoice::tool("my_tool")
        );
    }

    #[test]
    fn tool_choice_deserialization_none() {
        assert_eq!(parse(json!({ "type": "none" })), ToolChoice::none());
    }

    #[test]
    fn tool_choice_deserialization_rejects_invalid() {
        // Unknown tag.
        assert!(serde_json::from_value::<ToolChoice>(json!({ "type": "bogus" })).is_err());
        // Tags are lowercase only.
        assert!(serde_json::from_value::<ToolChoice>(json!({ "type": "Auto" })).is_err());
        // `tool` requires a `name`.
        assert!(serde_json::from_value::<ToolChoice>(json!({ "type": "tool" })).is_err());
        // Missing tag entirely.
        assert!(serde_json::from_value::<ToolChoice>(json!({ "name": "my_tool" })).is_err());
    }

    #[test]
    fn tool_choice_round_trip() {
        let cases = [
            ToolChoice::auto(),
            ToolChoice::auto_with_disable_parallel(false),
            ToolChoice::any(),
            ToolChoice::any_with_disable_parallel(true),
            ToolChoice::tool("my_tool"),
            ToolChoice::tool_with_disable_parallel("my_tool", true),
            ToolChoice::none(),
        ];
        for choice in cases {
            let json = to_value(&choice).unwrap();
            assert_eq!(parse(json), choice);
        }
    }
}