use serde::{Deserialize, Serialize};
use crate::messages::cache::CacheControl;
/// A tool definition: either a user-defined tool described by an input schema, or a
/// builtin tool identified by a versioned `type` tag.
///
/// Serialized `untagged`: the JSON is exactly the inner value's form, with no wrapper.
/// On deserialization serde tries the variants in declaration order — `Custom` first,
/// then `Builtin` — so this order matters. `CustomTool::input_schema` has no serde
/// default, so objects without an `input_schema` key fall through to `Builtin`.
///
/// NOTE(review): `CustomTool` does not reject unknown fields, so any object that has
/// both `name` and `input_schema` (even one carrying a `type` tag) is parsed as
/// `Custom`, and its extra keys are dropped on re-serialization.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Tool {
    /// A user-defined tool (name + JSON input schema).
    Custom(CustomTool),
    /// A builtin tool, either typed ([`BuiltinTool::Known`]) or preserved raw
    /// ([`BuiltinTool::Other`]).
    Builtin(BuiltinTool),
}
impl Tool {
pub fn custom(name: impl Into<String>, input_schema: serde_json::Value) -> Self {
Self::Custom(CustomTool {
name: name.into(),
description: None,
input_schema,
cache_control: None,
})
}
pub fn builtin(value: serde_json::Value) -> Self {
Self::Builtin(BuiltinTool::Other(value))
}
#[must_use]
pub fn web_search() -> Self {
Self::Builtin(BuiltinTool::Known(KnownBuiltinTool::WebSearch20250305 {
name: "web_search".into(),
max_uses: None,
allowed_domains: None,
blocked_domains: None,
user_location: None,
cache_control: None,
}))
}
#[must_use]
pub fn computer(display_width_px: u32, display_height_px: u32) -> Self {
Self::Builtin(BuiltinTool::Known(KnownBuiltinTool::Computer20250124 {
name: "computer".into(),
display_width_px,
display_height_px,
display_number: None,
cache_control: None,
}))
}
#[must_use]
pub fn bash() -> Self {
Self::Builtin(BuiltinTool::Known(KnownBuiltinTool::Bash20250124 {
name: "bash".into(),
cache_control: None,
}))
}
#[must_use]
pub fn text_editor() -> Self {
Self::Builtin(BuiltinTool::Known(KnownBuiltinTool::TextEditor20250124 {
name: "str_replace_editor".into(),
cache_control: None,
}))
}
#[must_use]
pub fn code_execution() -> Self {
Self::Builtin(BuiltinTool::Known(
KnownBuiltinTool::CodeExecution20250825 {
name: "code_execution".into(),
cache_control: None,
},
))
}
#[cfg(feature = "schemars-tools")]
#[cfg_attr(docsrs, doc(cfg(feature = "schemars-tools")))]
pub fn from_schemars<T: schemars::JsonSchema>(name: impl Into<String>) -> Self {
let schema = schemars::r#gen::SchemaGenerator::default().into_root_schema_for::<T>();
let schema_value =
serde_json::to_value(schema).expect("RootSchema is always JSON-serializable");
Self::Custom(CustomTool {
name: name.into(),
description: None,
input_schema: schema_value,
cache_control: None,
})
}
}
/// A user-defined tool: a name plus a JSON input schema, with an optional description
/// and cache-control marker.
///
/// Marked `#[non_exhaustive]`, so code outside this crate cannot use a struct literal.
/// Construct it with [`CustomTool::new`] and the chained builder methods instead.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct CustomTool {
    /// Tool name.
    pub name: String,
    /// Optional description; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Schema for the tool's input, stored as opaque JSON. Required on
    /// deserialization (no serde default); `Tool`'s untagged parsing relies on this to
    /// tell custom tools apart from builtins.
    pub input_schema: serde_json::Value,
    /// Optional cache-control marker; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_control: Option<CacheControl>,
}
impl CustomTool {
    /// Creates a custom tool with the given `name` and `input_schema`.
    ///
    /// No description or cache control is set.
    #[must_use]
    pub fn new(name: impl Into<String>, input_schema: serde_json::Value) -> Self {
        Self {
            name: name.into(),
            description: None,
            input_schema,
            cache_control: None,
        }
    }

    /// Sets the description, replacing any previous one.
    #[must_use]
    pub fn description(mut self, description: impl Into<String>) -> Self {
        self.description = Some(description.into());
        self
    }

    /// Sets the cache-control marker, replacing any previous one.
    #[must_use]
    pub fn cache_control(mut self, cache_control: CacheControl) -> Self {
        self.cache_control = Some(cache_control);
        self
    }

    /// Shorthand for `self.cache_control(CacheControl::ephemeral())`.
    #[must_use]
    pub fn with_ephemeral_cache(self) -> Self {
        self.cache_control(CacheControl::ephemeral())
    }
}
/// A builtin tool definition, kept forward compatible.
///
/// Tags this crate models are parsed into [`KnownBuiltinTool`]. Any other tag is
/// preserved verbatim as raw JSON so it can be sent back unchanged. `Serialize` and
/// `Deserialize` are implemented by hand below rather than derived.
#[derive(Debug, Clone, PartialEq)]
pub enum BuiltinTool {
    /// A builtin whose `type` tag is in `KNOWN_BUILTIN_TAGS`, parsed into typed form.
    Known(KnownBuiltinTool),
    /// Any other builtin, stored and re-serialized as-is.
    Other(serde_json::Value),
}
/// Typed builtin tool definitions, internally tagged by a versioned `type` field.
///
/// Every `#[serde(rename = ...)]` value below is a wire tag. It must also appear in
/// `KNOWN_BUILTIN_TAGS`, or `BuiltinTool`'s deserializer will not route that tag here.
/// `Option` fields are omitted from the JSON when `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub enum KnownBuiltinTool {
    /// Wire tag `web_search_20250305`; [`Tool::web_search`] builds it with name
    /// `"web_search"` and all options unset.
    #[serde(rename = "web_search_20250305")]
    WebSearch20250305 {
        name: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        max_uses: Option<u32>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        allowed_domains: Option<Vec<String>>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        blocked_domains: Option<Vec<String>>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        user_location: Option<UserLocation>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// Wire tag `computer_20250124`. Both display dimensions are required (no serde
    /// default), so input missing them, or carrying wrong types, fails to deserialize.
    #[serde(rename = "computer_20250124")]
    Computer20250124 {
        name: String,
        display_width_px: u32,
        display_height_px: u32,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        display_number: Option<u32>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// Wire tag `bash_20250124`.
    #[serde(rename = "bash_20250124")]
    Bash20250124 {
        name: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// Wire tag `text_editor_20250124`; [`Tool::text_editor`] names it
    /// `"str_replace_editor"`.
    #[serde(rename = "text_editor_20250124")]
    TextEditor20250124 {
        name: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// Wire tag `code_execution_20250825`.
    #[serde(rename = "code_execution_20250825")]
    CodeExecution20250825 {
        name: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}
/// Wire `type` tags of every [`KnownBuiltinTool`] variant.
///
/// `BuiltinTool`'s deserializer passes this list to
/// `forward_compat::dispatch_known_or_other`. Tags listed here are parsed into
/// `BuiltinTool::Known`, and the tests show a malformed known builtin is an error.
/// Other tags are kept as `BuiltinTool::Other`.
/// Must list exactly the `#[serde(rename = ...)]` values on `KnownBuiltinTool`.
/// A tag missing here silently degrades to `Other`.
const KNOWN_BUILTIN_TAGS: &[&str] = &[
    "web_search_20250305",
    "computer_20250124",
    "bash_20250124",
    "text_editor_20250124",
    "code_execution_20250825",
];
impl serde::Serialize for BuiltinTool {
fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
match self {
BuiltinTool::Known(k) => k.serialize(s),
BuiltinTool::Other(v) => v.serialize(s),
}
}
}
impl<'de> serde::Deserialize<'de> for BuiltinTool {
fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let raw = serde_json::Value::deserialize(d)?;
crate::forward_compat::dispatch_known_or_other(
raw,
KNOWN_BUILTIN_TAGS,
BuiltinTool::Known,
BuiltinTool::Other,
)
.map_err(serde::de::Error::custom)
}
}
impl From<KnownBuiltinTool> for BuiltinTool {
fn from(k: KnownBuiltinTool) -> Self {
BuiltinTool::Known(k)
}
}
/// Location data for the `user_location` field of
/// [`KnownBuiltinTool::WebSearch20250305`].
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot use a struct
/// literal. Build it with [`UserLocation::approximate`] (or [`Default`]) and the
/// chained setters instead.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct UserLocation {
    /// Wire `type` tag; defaults to `"approximate"` when absent from the input.
    #[serde(rename = "type", default = "default_user_location_kind")]
    pub kind: String,
    /// Optional city; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub city: Option<String>,
    /// Optional region; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub region: Option<String>,
    /// Optional country; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub country: Option<String>,
    /// Optional timezone; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub timezone: Option<String>,
}

/// Serde default for [`UserLocation::kind`].
fn default_user_location_kind() -> String {
    "approximate".to_owned()
}

impl Default for UserLocation {
    /// Same as [`UserLocation::approximate`]. A derived `Default` would set `kind` to
    /// `""`, which differs from the serde default.
    fn default() -> Self {
        Self::approximate()
    }
}

impl UserLocation {
    /// An `"approximate"` location with every optional field unset.
    #[must_use]
    pub fn approximate() -> Self {
        Self {
            kind: default_user_location_kind(),
            city: None,
            region: None,
            country: None,
            timezone: None,
        }
    }

    /// Sets the city, replacing any previous one.
    #[must_use]
    pub fn city(mut self, city: impl Into<String>) -> Self {
        self.city = Some(city.into());
        self
    }

    /// Sets the region, replacing any previous one.
    #[must_use]
    pub fn region(mut self, region: impl Into<String>) -> Self {
        self.region = Some(region.into());
        self
    }

    /// Sets the country, replacing any previous one.
    #[must_use]
    pub fn country(mut self, country: impl Into<String>) -> Self {
        self.country = Some(country.into());
        self
    }

    /// Sets the timezone, replacing any previous one.
    #[must_use]
    pub fn timezone(mut self, timezone: impl Into<String>) -> Self {
        self.timezone = Some(timezone.into());
        self
    }
}
/// How tools may be chosen, internally tagged by `type`.
///
/// Variant names map to the snake_case tags `"auto"`, `"any"`, `"tool"` and `"none"`.
/// `disable_parallel_tool_use` is omitted from the JSON when `None`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[non_exhaustive]
pub enum ToolChoice {
    /// Tag `"auto"`.
    Auto {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        disable_parallel_tool_use: Option<bool>,
    },
    /// Tag `"any"`.
    Any {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        disable_parallel_tool_use: Option<bool>,
    },
    /// Tag `"tool"`, naming one specific tool.
    Tool {
        name: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        disable_parallel_tool_use: Option<bool>,
    },
    /// Tag `"none"`; carries no options.
    None,
}
impl ToolChoice {
#[must_use]
pub fn auto() -> Self {
Self::Auto {
disable_parallel_tool_use: None,
}
}
#[must_use]
pub fn any() -> Self {
Self::Any {
disable_parallel_tool_use: None,
}
}
#[must_use]
pub fn tool(name: impl Into<String>) -> Self {
Self::Tool {
name: name.into(),
disable_parallel_tool_use: None,
}
}
#[must_use]
pub fn none() -> Self {
Self::None
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use serde_json::json;

    #[test]
    fn custom_tool_round_trips() {
        let t = Tool::Custom(
            CustomTool::new(
                "get_weather",
                json!({"type": "object", "properties": {"city": {"type": "string"}}}),
            )
            .description("Look up the weather"),
        );
        let v = serde_json::to_value(&t).unwrap();
        assert_eq!(
            v,
            json!({
                "name": "get_weather",
                "description": "Look up the weather",
                "input_schema": {"type": "object", "properties": {"city": {"type": "string"}}}
            })
        );
        let parsed: Tool = serde_json::from_value(v).unwrap();
        assert_eq!(parsed, t);
    }

    #[test]
    fn custom_tool_with_cache_control_round_trips() {
        let t = Tool::Custom(
            CustomTool::new("noop", json!({"type": "object"}))
                .cache_control(CacheControl::ephemeral()),
        );
        let v = serde_json::to_value(&t).unwrap();
        assert_eq!(
            v,
            json!({
                "name": "noop",
                "input_schema": {"type": "object"},
                "cache_control": {"type": "ephemeral"}
            })
        );
        let parsed: Tool = serde_json::from_value(v).unwrap();
        assert_eq!(parsed, t);
    }

    #[test]
    fn tool_custom_matches_custom_tool_new() {
        let schema = json!({"type": "object"});
        assert_eq!(
            Tool::custom("noop", schema.clone()),
            Tool::Custom(CustomTool::new("noop", schema))
        );
    }

    #[test]
    fn unknown_builtin_round_trips_through_other() {
        let raw = json!({"type": "future_builtin_2099", "name": "future_tool"});
        let t = Tool::builtin(raw.clone());
        let serialized = serde_json::to_value(&t).unwrap();
        assert_eq!(serialized, raw, "Other must serialize transparently");
        let parsed: Tool = serde_json::from_value(serialized).unwrap();
        assert_eq!(parsed, t);
    }

    #[test]
    fn known_builtin_parses_into_typed_variant() {
        let raw = json!({
            "type": "web_search_20250305",
            "name": "web_search",
            "max_uses": 5
        });
        let parsed: Tool = serde_json::from_value(raw).unwrap();
        match parsed {
            Tool::Builtin(BuiltinTool::Known(KnownBuiltinTool::WebSearch20250305 {
                name,
                max_uses,
                ..
            })) => {
                assert_eq!(name, "web_search");
                assert_eq!(max_uses, Some(5));
            }
            other => panic!("expected typed WebSearch20250305, got {other:?}"),
        }
    }

    #[test]
    fn known_builtin_tags_match_serde_renames_and_round_trip() {
        // One default instance per `KnownBuiltinTool` variant, in the same order as
        // `KNOWN_BUILTIN_TAGS`. When a variant is added, extend both lists.
        let tools = [
            Tool::web_search(),
            Tool::computer(1024, 768),
            Tool::bash(),
            Tool::text_editor(),
            Tool::code_execution(),
        ];
        let mut tags: Vec<String> = Vec::with_capacity(tools.len());
        for t in &tools {
            let v = serde_json::to_value(t).unwrap();
            let tag = v["type"]
                .as_str()
                .expect("known builtin must serialize a string `type` tag")
                .to_owned();
            // A tag missing from KNOWN_BUILTIN_TAGS would deserialize as
            // `BuiltinTool::Other`, which never equals the typed original.
            let parsed: Tool = serde_json::from_value(v).unwrap();
            assert_eq!(
                &parsed, t,
                "tag {} must round-trip into its typed variant",
                tag
            );
            tags.push(tag);
        }
        let expected: Vec<String> = KNOWN_BUILTIN_TAGS
            .iter()
            .map(|t| (*t).to_owned())
            .collect();
        assert_eq!(tags, expected);
    }

    #[test]
    fn web_search_default_serializes_to_minimal_wire_form() {
        let t = Tool::web_search();
        let v = serde_json::to_value(&t).unwrap();
        assert_eq!(
            v,
            json!({"type": "web_search_20250305", "name": "web_search"})
        );
    }

    #[test]
    fn web_search_with_options_round_trips() {
        let t = Tool::Builtin(BuiltinTool::Known(KnownBuiltinTool::WebSearch20250305 {
            name: "web_search".into(),
            max_uses: Some(3),
            allowed_domains: Some(vec!["wikipedia.org".into()]),
            blocked_domains: None,
            user_location: Some(UserLocation {
                kind: "approximate".into(),
                city: Some("Paris".into()),
                region: None,
                country: Some("FR".into()),
                timezone: Some("Europe/Paris".into()),
            }),
            cache_control: Some(CacheControl::ephemeral()),
        }));
        let v = serde_json::to_value(&t).unwrap();
        assert_eq!(
            v,
            json!({
                "type": "web_search_20250305",
                "name": "web_search",
                "max_uses": 3,
                "allowed_domains": ["wikipedia.org"],
                "user_location": {
                    "type": "approximate",
                    "city": "Paris",
                    "country": "FR",
                    "timezone": "Europe/Paris"
                },
                "cache_control": {"type": "ephemeral"}
            })
        );
        let parsed: Tool = serde_json::from_value(v).unwrap();
        assert_eq!(parsed, t);
    }

    #[test]
    fn computer_default_serializes_with_required_dims() {
        let t = Tool::computer(1920, 1080);
        let v = serde_json::to_value(&t).unwrap();
        assert_eq!(
            v,
            json!({
                "type": "computer_20250124",
                "name": "computer",
                "display_width_px": 1920,
                "display_height_px": 1080
            })
        );
    }

    #[test]
    fn bash_text_editor_code_execution_defaults_serialize() {
        assert_eq!(
            serde_json::to_value(Tool::bash()).unwrap(),
            json!({"type": "bash_20250124", "name": "bash"})
        );
        assert_eq!(
            serde_json::to_value(Tool::text_editor()).unwrap(),
            json!({"type": "text_editor_20250124", "name": "str_replace_editor"})
        );
        assert_eq!(
            serde_json::to_value(Tool::code_execution()).unwrap(),
            json!({"type": "code_execution_20250825", "name": "code_execution"})
        );
    }

    #[test]
    fn malformed_known_builtin_errors_not_silent_fallthrough() {
        let raw = json!({
            "type": "computer_20250124",
            "name": "computer",
            "display_width_px": "wide",
            "display_height_px": 1080
        });
        let result: Result<Tool, _> = serde_json::from_value(raw);
        assert!(
            result.is_err(),
            "malformed known builtin must error, not fall through to Other"
        );
    }

    #[test]
    fn untagged_enum_disambiguates_custom_from_builtin() {
        let custom: Tool = serde_json::from_value(json!({
            "name": "x",
            "input_schema": {"type": "object"}
        }))
        .unwrap();
        assert!(matches!(custom, Tool::Custom(_)));
        let builtin: Tool = serde_json::from_value(json!({
            "type": "web_search_20250305",
            "name": "web_search"
        }))
        .unwrap();
        assert!(matches!(builtin, Tool::Builtin(_)));
    }

    #[test]
    fn tool_choice_auto_round_trips() {
        let c = ToolChoice::auto();
        let v = serde_json::to_value(&c).unwrap();
        assert_eq!(v, json!({"type": "auto"}));
        let parsed: ToolChoice = serde_json::from_value(v).unwrap();
        assert_eq!(parsed, c);
    }

    #[test]
    fn tool_choice_any_with_no_parallel_round_trips() {
        let c = ToolChoice::Any {
            disable_parallel_tool_use: Some(true),
        };
        let v = serde_json::to_value(&c).unwrap();
        assert_eq!(v, json!({"type": "any", "disable_parallel_tool_use": true}));
        let parsed: ToolChoice = serde_json::from_value(v).unwrap();
        assert_eq!(parsed, c);
    }

    #[test]
    fn tool_choice_specific_tool_round_trips() {
        let c = ToolChoice::tool("get_weather");
        let v = serde_json::to_value(&c).unwrap();
        assert_eq!(v, json!({"type": "tool", "name": "get_weather"}));
        let parsed: ToolChoice = serde_json::from_value(v).unwrap();
        assert_eq!(parsed, c);
    }

    #[test]
    fn tool_choice_none_round_trips() {
        let c = ToolChoice::none();
        let v = serde_json::to_value(&c).unwrap();
        assert_eq!(v, json!({"type": "none"}));
        let parsed: ToolChoice = serde_json::from_value(v).unwrap();
        assert_eq!(parsed, c);
    }

    #[cfg(feature = "schemars-tools")]
    #[test]
    fn from_schemars_builds_custom_tool() {
        #[derive(schemars::JsonSchema, serde::Deserialize)]
        #[allow(dead_code)]
        struct Args {
            city: String,
            units: Option<String>,
        }
        let t = Tool::from_schemars::<Args>("get_weather");
        match t {
            Tool::Custom(c) => {
                assert_eq!(c.name, "get_weather");
                assert!(c.input_schema.is_object());
            }
            Tool::Builtin(_) => panic!("expected Custom"),
        }
    }
}