1use serde::{Deserialize, Serialize};
9
10use crate::messages::cache::CacheControl;
11
12#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
18#[serde(untagged)]
19pub enum Tool {
20 Custom(CustomTool),
22 Builtin(BuiltinTool),
26}
27
28impl Tool {
29 pub fn custom(name: impl Into<String>, input_schema: serde_json::Value) -> Self {
31 Self::Custom(CustomTool {
32 name: name.into(),
33 description: None,
34 input_schema,
35 cache_control: None,
36 })
37 }
38
39 pub fn builtin(value: serde_json::Value) -> Self {
54 Self::Builtin(BuiltinTool::Other(value))
55 }
56
57 #[must_use]
59 pub fn web_search() -> Self {
60 Self::Builtin(BuiltinTool::Known(KnownBuiltinTool::WebSearch20250305 {
61 name: "web_search".into(),
62 max_uses: None,
63 allowed_domains: None,
64 blocked_domains: None,
65 user_location: None,
66 cache_control: None,
67 }))
68 }
69
70 #[must_use]
73 pub fn computer(display_width_px: u32, display_height_px: u32) -> Self {
74 Self::Builtin(BuiltinTool::Known(KnownBuiltinTool::Computer20250124 {
75 name: "computer".into(),
76 display_width_px,
77 display_height_px,
78 display_number: None,
79 cache_control: None,
80 }))
81 }
82
83 #[must_use]
85 pub fn bash() -> Self {
86 Self::Builtin(BuiltinTool::Known(KnownBuiltinTool::Bash20250124 {
87 name: "bash".into(),
88 cache_control: None,
89 }))
90 }
91
92 #[must_use]
94 pub fn text_editor() -> Self {
95 Self::Builtin(BuiltinTool::Known(KnownBuiltinTool::TextEditor20250124 {
96 name: "str_replace_editor".into(),
97 cache_control: None,
98 }))
99 }
100
101 #[must_use]
103 pub fn code_execution() -> Self {
104 Self::Builtin(BuiltinTool::Known(
105 KnownBuiltinTool::CodeExecution20250825 {
106 name: "code_execution".into(),
107 cache_control: None,
108 },
109 ))
110 }
111
112 #[cfg(feature = "schemars-tools")]
120 #[cfg_attr(docsrs, doc(cfg(feature = "schemars-tools")))]
121 pub fn from_schemars<T: schemars::JsonSchema>(name: impl Into<String>) -> Self {
122 let schema = schemars::r#gen::SchemaGenerator::default().into_root_schema_for::<T>();
123 let schema_value =
124 serde_json::to_value(schema).expect("RootSchema is always JSON-serializable");
125 Self::Custom(CustomTool {
126 name: name.into(),
127 description: None,
128 input_schema: schema_value,
129 cache_control: None,
130 })
131 }
132}
133
134#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
136#[non_exhaustive]
137pub struct CustomTool {
138 pub name: String,
140 #[serde(default, skip_serializing_if = "Option::is_none")]
142 pub description: Option<String>,
143 pub input_schema: serde_json::Value,
145 #[serde(default, skip_serializing_if = "Option::is_none")]
147 pub cache_control: Option<CacheControl>,
148}
149
150impl CustomTool {
151 pub fn new(name: impl Into<String>, input_schema: serde_json::Value) -> Self {
153 Self {
154 name: name.into(),
155 description: None,
156 input_schema,
157 cache_control: None,
158 }
159 }
160
161 #[must_use]
163 pub fn description(mut self, description: impl Into<String>) -> Self {
164 self.description = Some(description.into());
165 self
166 }
167
168 #[must_use]
170 pub fn cache_control(mut self, cache_control: CacheControl) -> Self {
171 self.cache_control = Some(cache_control);
172 self
173 }
174
175 #[must_use]
178 pub fn with_ephemeral_cache(self) -> Self {
179 self.cache_control(CacheControl::ephemeral())
180 }
181}
182
183#[derive(Debug, Clone, PartialEq)]
191pub enum BuiltinTool {
192 Known(KnownBuiltinTool),
194 Other(serde_json::Value),
196}
197
198#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
203#[serde(tag = "type")]
204#[non_exhaustive]
205pub enum KnownBuiltinTool {
206 #[serde(rename = "web_search_20250305")]
211 WebSearch20250305 {
212 name: String,
214 #[serde(default, skip_serializing_if = "Option::is_none")]
216 max_uses: Option<u32>,
217 #[serde(default, skip_serializing_if = "Option::is_none")]
219 allowed_domains: Option<Vec<String>>,
220 #[serde(default, skip_serializing_if = "Option::is_none")]
222 blocked_domains: Option<Vec<String>>,
223 #[serde(default, skip_serializing_if = "Option::is_none")]
225 user_location: Option<UserLocation>,
226 #[serde(default, skip_serializing_if = "Option::is_none")]
228 cache_control: Option<CacheControl>,
229 },
230 #[serde(rename = "computer_20250124")]
233 Computer20250124 {
234 name: String,
236 display_width_px: u32,
238 display_height_px: u32,
240 #[serde(default, skip_serializing_if = "Option::is_none")]
242 display_number: Option<u32>,
243 #[serde(default, skip_serializing_if = "Option::is_none")]
245 cache_control: Option<CacheControl>,
246 },
247 #[serde(rename = "bash_20250124")]
249 Bash20250124 {
250 name: String,
252 #[serde(default, skip_serializing_if = "Option::is_none")]
254 cache_control: Option<CacheControl>,
255 },
256 #[serde(rename = "text_editor_20250124")]
258 TextEditor20250124 {
259 name: String,
261 #[serde(default, skip_serializing_if = "Option::is_none")]
263 cache_control: Option<CacheControl>,
264 },
265 #[serde(rename = "code_execution_20250825")]
268 CodeExecution20250825 {
269 name: String,
271 #[serde(default, skip_serializing_if = "Option::is_none")]
273 cache_control: Option<CacheControl>,
274 },
275}
276
/// `type` tags that `BuiltinTool`'s `Deserialize` impl passes to
/// `crate::forward_compat::dispatch_known_or_other` as "known".
///
/// This list must mirror the `#[serde(rename = "...")]` tags on
/// `KnownBuiltinTool` exactly. NOTE(review): presumably a typed tag missing
/// from here would be routed to `BuiltinTool::Other` instead of its typed
/// variant — confirm against `dispatch_known_or_other`.
const KNOWN_BUILTIN_TAGS: &[&str] = &[
    "web_search_20250305",
    "computer_20250124",
    "bash_20250124",
    "text_editor_20250124",
    "code_execution_20250825",
];
284
285impl serde::Serialize for BuiltinTool {
286 fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
287 match self {
288 BuiltinTool::Known(k) => k.serialize(s),
289 BuiltinTool::Other(v) => v.serialize(s),
290 }
291 }
292}
293
294impl<'de> serde::Deserialize<'de> for BuiltinTool {
295 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
296 let raw = serde_json::Value::deserialize(d)?;
297 crate::forward_compat::dispatch_known_or_other(
298 raw,
299 KNOWN_BUILTIN_TAGS,
300 BuiltinTool::Known,
301 BuiltinTool::Other,
302 )
303 .map_err(serde::de::Error::custom)
304 }
305}
306
307impl From<KnownBuiltinTool> for BuiltinTool {
308 fn from(k: KnownBuiltinTool) -> Self {
309 BuiltinTool::Known(k)
310 }
311}
312
313#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
315#[non_exhaustive]
316pub struct UserLocation {
317 #[serde(rename = "type", default = "default_user_location_kind")]
319 pub kind: String,
320 #[serde(default, skip_serializing_if = "Option::is_none")]
322 pub city: Option<String>,
323 #[serde(default, skip_serializing_if = "Option::is_none")]
325 pub region: Option<String>,
326 #[serde(default, skip_serializing_if = "Option::is_none")]
328 pub country: Option<String>,
329 #[serde(default, skip_serializing_if = "Option::is_none")]
331 pub timezone: Option<String>,
332}
333
/// Serde default for [`UserLocation::kind`]: the string `"approximate"`.
fn default_user_location_kind() -> String {
    String::from("approximate")
}
337
338#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
340#[serde(tag = "type", rename_all = "snake_case")]
341#[non_exhaustive]
342pub enum ToolChoice {
343 Auto {
345 #[serde(default, skip_serializing_if = "Option::is_none")]
347 disable_parallel_tool_use: Option<bool>,
348 },
349 Any {
351 #[serde(default, skip_serializing_if = "Option::is_none")]
353 disable_parallel_tool_use: Option<bool>,
354 },
355 Tool {
357 name: String,
359 #[serde(default, skip_serializing_if = "Option::is_none")]
361 disable_parallel_tool_use: Option<bool>,
362 },
363 None,
365}
366
367impl ToolChoice {
368 #[must_use]
370 pub fn auto() -> Self {
371 Self::Auto {
372 disable_parallel_tool_use: None,
373 }
374 }
375
376 #[must_use]
378 pub fn any() -> Self {
379 Self::Any {
380 disable_parallel_tool_use: None,
381 }
382 }
383
384 #[must_use]
386 pub fn tool(name: impl Into<String>) -> Self {
387 Self::Tool {
388 name: name.into(),
389 disable_parallel_tool_use: None,
390 }
391 }
392
393 #[must_use]
395 pub fn none() -> Self {
396 Self::None
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use serde_json::json;

    /// One instance from every typed builtin constructor on [`Tool`].
    fn typed_builtin_constructors() -> Vec<Tool> {
        vec![
            Tool::web_search(),
            Tool::computer(1024, 768),
            Tool::bash(),
            Tool::text_editor(),
            Tool::code_execution(),
        ]
    }

    #[test]
    fn custom_tool_round_trips() {
        let t = Tool::Custom(
            CustomTool::new(
                "get_weather",
                json!({"type": "object", "properties": {"city": {"type": "string"}}}),
            )
            .description("Look up the weather"),
        );
        let v = serde_json::to_value(&t).unwrap();
        assert_eq!(
            v,
            json!({
                "name": "get_weather",
                "description": "Look up the weather",
                "input_schema": {"type": "object", "properties": {"city": {"type": "string"}}}
            })
        );
        let parsed: Tool = serde_json::from_value(v).unwrap();
        assert_eq!(parsed, t);
    }

    #[test]
    fn custom_tool_with_cache_control_round_trips() {
        let t = Tool::Custom(
            CustomTool::new("noop", json!({"type": "object"}))
                .cache_control(CacheControl::ephemeral()),
        );
        let v = serde_json::to_value(&t).unwrap();
        assert_eq!(
            v,
            json!({
                "name": "noop",
                "input_schema": {"type": "object"},
                "cache_control": {"type": "ephemeral"}
            })
        );
        let parsed: Tool = serde_json::from_value(v).unwrap();
        assert_eq!(parsed, t);
    }

    #[test]
    fn unknown_builtin_round_trips_through_other() {
        // A tag that is not in KNOWN_BUILTIN_TAGS must survive untouched.
        let raw = json!({"type": "future_builtin_2099", "name": "future_tool"});
        let t = Tool::builtin(raw.clone());
        let serialized = serde_json::to_value(&t).unwrap();
        assert_eq!(serialized, raw, "Other must serialize transparently");
        let parsed: Tool = serde_json::from_value(serialized).unwrap();
        assert_eq!(parsed, t);
    }

    #[test]
    fn known_builtin_parses_into_typed_variant() {
        let raw = json!({
            "type": "web_search_20250305",
            "name": "web_search",
            "max_uses": 5
        });
        let parsed: Tool = serde_json::from_value(raw).unwrap();
        match parsed {
            Tool::Builtin(BuiltinTool::Known(KnownBuiltinTool::WebSearch20250305 {
                name,
                max_uses,
                ..
            })) => {
                assert_eq!(name, "web_search");
                assert_eq!(max_uses, Some(5));
            }
            other => panic!("expected typed WebSearch20250305, got {other:?}"),
        }
    }

    #[test]
    fn web_search_default_serializes_to_minimal_wire_form() {
        let t = Tool::web_search();
        let v = serde_json::to_value(&t).unwrap();
        assert_eq!(
            v,
            json!({"type": "web_search_20250305", "name": "web_search"})
        );
    }

    #[test]
    fn web_search_with_options_round_trips() {
        let t = Tool::Builtin(BuiltinTool::Known(KnownBuiltinTool::WebSearch20250305 {
            name: "web_search".into(),
            max_uses: Some(3),
            allowed_domains: Some(vec!["wikipedia.org".into()]),
            blocked_domains: None,
            user_location: Some(UserLocation {
                kind: "approximate".into(),
                city: Some("Paris".into()),
                region: None,
                country: Some("FR".into()),
                timezone: Some("Europe/Paris".into()),
            }),
            cache_control: Some(CacheControl::ephemeral()),
        }));
        let v = serde_json::to_value(&t).unwrap();
        assert_eq!(
            v,
            json!({
                "type": "web_search_20250305",
                "name": "web_search",
                "max_uses": 3,
                "allowed_domains": ["wikipedia.org"],
                "user_location": {
                    "type": "approximate",
                    "city": "Paris",
                    "country": "FR",
                    "timezone": "Europe/Paris"
                },
                "cache_control": {"type": "ephemeral"}
            })
        );
        let parsed: Tool = serde_json::from_value(v).unwrap();
        assert_eq!(parsed, t);
    }

    #[test]
    fn user_location_kind_defaults_to_approximate_when_absent() {
        let loc: UserLocation = serde_json::from_value(json!({"city": "Paris"})).unwrap();
        assert_eq!(loc.kind, "approximate");
        assert_eq!(loc.city.as_deref(), Some("Paris"));
        assert_eq!(loc.region, None);
    }

    #[test]
    fn computer_default_serializes_with_required_dims() {
        let t = Tool::computer(1920, 1080);
        let v = serde_json::to_value(&t).unwrap();
        assert_eq!(
            v,
            json!({
                "type": "computer_20250124",
                "name": "computer",
                "display_width_px": 1920,
                "display_height_px": 1080
            })
        );
    }

    #[test]
    fn bash_text_editor_code_execution_defaults_serialize() {
        assert_eq!(
            serde_json::to_value(Tool::bash()).unwrap(),
            json!({"type": "bash_20250124", "name": "bash"})
        );
        assert_eq!(
            serde_json::to_value(Tool::text_editor()).unwrap(),
            json!({"type": "text_editor_20250124", "name": "str_replace_editor"})
        );
        assert_eq!(
            serde_json::to_value(Tool::code_execution()).unwrap(),
            json!({"type": "code_execution_20250825", "name": "code_execution"})
        );
    }

    /// Guards against KNOWN_BUILTIN_TAGS drifting from the typed variants: a
    /// typed tag missing from the list would stop parsing into `Known`.
    #[test]
    fn known_builtin_tags_match_typed_constructors() {
        let mut from_constructors: Vec<String> = typed_builtin_constructors()
            .iter()
            .map(|tool| {
                let v = serde_json::to_value(tool).unwrap();
                v["type"]
                    .as_str()
                    .expect("typed builtins serialize a string `type` tag")
                    .to_owned()
            })
            .collect();
        from_constructors.sort_unstable();
        let mut listed: Vec<String> = KNOWN_BUILTIN_TAGS
            .iter()
            .map(|tag| (*tag).to_owned())
            .collect();
        listed.sort_unstable();
        assert_eq!(from_constructors, listed);
    }

    #[test]
    fn typed_builtin_constructors_round_trip_as_known() {
        for t in typed_builtin_constructors() {
            let v = serde_json::to_value(&t).unwrap();
            let parsed: Tool = serde_json::from_value(v).unwrap();
            assert!(
                matches!(parsed, Tool::Builtin(BuiltinTool::Known(_))),
                "expected a typed builtin, got {parsed:?}"
            );
            assert_eq!(parsed, t);
        }
    }

    #[test]
    fn malformed_known_builtin_errors_not_silent_fallthrough() {
        let raw = json!({
            // Known tag, but `display_width_px` has the wrong JSON type.
            "type": "computer_20250124",
            "name": "computer",
            "display_width_px": "wide",
            "display_height_px": 1080
        });
        let result: Result<Tool, _> = serde_json::from_value(raw);
        assert!(
            result.is_err(),
            "malformed known builtin must error, not fall through to Other"
        );
    }

    #[test]
    fn untagged_enum_disambiguates_custom_from_builtin() {
        let custom: Tool = serde_json::from_value(json!({
            // Has `input_schema`, so the untagged `Custom` arm matches first.
            "name": "x",
            "input_schema": {"type": "object"}
        }))
        .unwrap();
        assert!(matches!(custom, Tool::Custom(_)));

        let builtin: Tool = serde_json::from_value(json!({
            // No `input_schema`, so deserialization falls through to `Builtin`.
            "type": "web_search_20250305",
            "name": "web_search"
        }))
        .unwrap();
        assert!(matches!(builtin, Tool::Builtin(_)));
    }

    #[test]
    fn tool_choice_auto_round_trips() {
        let c = ToolChoice::auto();
        let v = serde_json::to_value(&c).unwrap();
        assert_eq!(v, json!({"type": "auto"}));
        let parsed: ToolChoice = serde_json::from_value(v).unwrap();
        assert_eq!(parsed, c);
    }

    #[test]
    fn tool_choice_any_with_no_parallel_round_trips() {
        let c = ToolChoice::Any {
            disable_parallel_tool_use: Some(true),
        };
        let v = serde_json::to_value(&c).unwrap();
        assert_eq!(v, json!({"type": "any", "disable_parallel_tool_use": true}));
        let parsed: ToolChoice = serde_json::from_value(v).unwrap();
        assert_eq!(parsed, c);
    }

    #[test]
    fn tool_choice_specific_tool_round_trips() {
        let c = ToolChoice::tool("get_weather");
        let v = serde_json::to_value(&c).unwrap();
        assert_eq!(v, json!({"type": "tool", "name": "get_weather"}));
        let parsed: ToolChoice = serde_json::from_value(v).unwrap();
        assert_eq!(parsed, c);
    }

    #[test]
    fn tool_choice_none_round_trips() {
        let c = ToolChoice::none();
        let v = serde_json::to_value(&c).unwrap();
        assert_eq!(v, json!({"type": "none"}));
        let parsed: ToolChoice = serde_json::from_value(v).unwrap();
        assert_eq!(parsed, c);
    }

    #[cfg(feature = "schemars-tools")]
    #[test]
    fn from_schemars_builds_custom_tool() {
        #[derive(schemars::JsonSchema, serde::Deserialize)]
        #[allow(dead_code)] // fields exist only to shape the generated schema
        struct Args {
            city: String,
            units: Option<String>,
        }

        let t = Tool::from_schemars::<Args>("get_weather");
        match t {
            Tool::Custom(c) => {
                assert_eq!(c.name, "get_weather");
                assert!(c.input_schema.is_object());
            }
            Tool::Builtin(_) => panic!("expected Custom"),
        }
    }
}