claudius/types/
tool_choice.rs1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
11#[serde(tag = "type")]
12#[serde(rename_all = "lowercase")]
13pub enum ToolChoice {
14 Auto {
16 #[serde(skip_serializing_if = "Option::is_none")]
20 disable_parallel_tool_use: Option<bool>,
21 },
22
23 Any {
25 #[serde(skip_serializing_if = "Option::is_none")]
29 disable_parallel_tool_use: Option<bool>,
30 },
31
32 Tool {
34 name: String,
36
37 #[serde(skip_serializing_if = "Option::is_none")]
41 disable_parallel_tool_use: Option<bool>,
42 },
43
44 None,
46}
47
48impl ToolChoice {
49 pub fn auto() -> Self {
51 Self::Auto {
52 disable_parallel_tool_use: None,
53 }
54 }
55
56 pub fn auto_with_disable_parallel(disable: bool) -> Self {
58 Self::Auto {
59 disable_parallel_tool_use: Some(disable),
60 }
61 }
62
63 pub fn any() -> Self {
65 Self::Any {
66 disable_parallel_tool_use: None,
67 }
68 }
69
70 pub fn any_with_disable_parallel(disable: bool) -> Self {
72 Self::Any {
73 disable_parallel_tool_use: Some(disable),
74 }
75 }
76
77 pub fn tool(name: impl Into<String>) -> Self {
79 Self::Tool {
80 name: name.into(),
81 disable_parallel_tool_use: None,
82 }
83 }
84
85 pub fn tool_with_disable_parallel(name: impl Into<String>, disable: bool) -> Self {
87 Self::Tool {
88 name: name.into(),
89 disable_parallel_tool_use: Some(disable),
90 }
91 }
92
93 pub fn none() -> Self {
95 Self::None
96 }
97}
98
99impl Default for ToolChoice {
100 fn default() -> Self {
101 Self::auto()
102 }
103}
104
// Wire-format tests for `ToolChoice`: the JSON shapes asserted here pin what the
// serde attributes on the enum produce and accept.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{from_value, json, to_value};

    #[test]
    fn tool_choice_auto() {
        let param = ToolChoice::auto();
        let value = to_value(&param).unwrap();

        // A `None` flag is skipped, so only the tag remains.
        assert_eq!(value, json!({ "type": "auto" }));
    }

    #[test]
    fn tool_choice_any() {
        let param = ToolChoice::any();
        let value = to_value(&param).unwrap();

        assert_eq!(value, json!({ "type": "any" }));
    }

    #[test]
    fn tool_choice_tool() {
        let param = ToolChoice::tool("my_tool");
        let value = to_value(&param).unwrap();

        assert_eq!(value, json!({ "name": "my_tool", "type": "tool" }));
    }

    #[test]
    fn tool_choice_none() {
        let param = ToolChoice::none();
        let value = to_value(&param).unwrap();

        assert_eq!(value, json!({ "type": "none" }));
    }

    #[test]
    fn tool_choice_auto_with_disable_parallel() {
        let param = ToolChoice::auto_with_disable_parallel(true);
        let value = to_value(&param).unwrap();

        assert_eq!(
            value,
            json!({ "type": "auto", "disable_parallel_tool_use": true })
        );
    }

    #[test]
    fn tool_choice_any_with_disable_parallel_false_is_kept() {
        // Only `None` is skipped; an explicit `Some(false)` must still be serialized.
        let param = ToolChoice::any_with_disable_parallel(false);
        let value = to_value(&param).unwrap();

        assert_eq!(
            value,
            json!({ "type": "any", "disable_parallel_tool_use": false })
        );
    }

    #[test]
    fn tool_choice_tool_with_disable_parallel() {
        let param = ToolChoice::tool_with_disable_parallel("my_tool", true);
        let value = to_value(&param).unwrap();

        assert_eq!(
            value,
            json!({
                "type": "tool",
                "name": "my_tool",
                "disable_parallel_tool_use": true
            })
        );
    }

    #[test]
    fn tool_choice_deserialization_auto() {
        let value = json!({
            "type": "auto",
            "disable_parallel_tool_use": true
        });

        let param: ToolChoice = from_value(value).unwrap();
        assert_eq!(
            param,
            ToolChoice::Auto {
                disable_parallel_tool_use: Some(true),
            }
        );
    }

    #[test]
    fn tool_choice_deserialization_tool() {
        let value = json!({
            "name": "my_tool",
            "type": "tool",
            "disable_parallel_tool_use": true
        });

        let param: ToolChoice = from_value(value).unwrap();
        assert_eq!(
            param,
            ToolChoice::Tool {
                name: "my_tool".to_string(),
                disable_parallel_tool_use: Some(true),
            }
        );
    }

    #[test]
    fn tool_choice_deserialization_missing_flag_is_none() {
        // The optional flag may be absent on input; it must come back as `None`.
        let param: ToolChoice = from_value(json!({ "type": "any" })).unwrap();
        assert_eq!(
            param,
            ToolChoice::Any {
                disable_parallel_tool_use: None,
            }
        );
    }

    #[test]
    fn tool_choice_deserialization_none() {
        let param: ToolChoice = from_value(json!({ "type": "none" })).unwrap();
        assert_eq!(param, ToolChoice::None);
    }

    #[test]
    fn tool_choice_deserialization_rejects_unknown_type() {
        assert!(from_value::<ToolChoice>(json!({ "type": "required" })).is_err());
    }

    #[test]
    fn tool_choice_deserialization_tool_requires_name() {
        assert!(from_value::<ToolChoice>(json!({ "type": "tool" })).is_err());
    }

    #[test]
    fn tool_choice_round_trip() {
        let cases = [
            ToolChoice::auto(),
            ToolChoice::auto_with_disable_parallel(false),
            ToolChoice::any(),
            ToolChoice::any_with_disable_parallel(true),
            ToolChoice::tool("my_tool"),
            ToolChoice::tool_with_disable_parallel("my_tool", true),
            ToolChoice::none(),
        ];

        for original in &cases {
            let value = to_value(original).unwrap();
            let decoded: ToolChoice = from_value(value).unwrap();
            assert_eq!(&decoded, original);
        }
    }

    #[test]
    fn tool_choice_default_is_auto() {
        assert_eq!(ToolChoice::default(), ToolChoice::auto());
    }
}