use serde::{Deserialize, Serialize};
use crate::ExtraMap;
fn extensions_are_absent_or_empty(extensions: &Option<ExtraMap>) -> bool {
extensions.as_ref().is_none_or(ExtraMap::is_empty)
}
/// A tool definition: a name, an optional description, a JSON `parameters` value,
/// and optional extension entries.
///
/// Construct it with [`Tool::new`] plus the builder methods. Because the struct is
/// `#[non_exhaustive]`, code outside this crate cannot build it with a struct literal.
/// NOTE(review): presumably this is advertised to a model for tool/function calling —
/// confirm against the consumer.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct Tool {
    /// The tool's name. Always serialized.
    pub name: String,
    /// Optional description; the key is omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Arbitrary JSON describing the tool's parameters. Always serialized.
    /// NOTE(review): the tests pass a JSON-Schema-like object (`{"type": "object", ...}`);
    /// confirm whether a JSON Schema is required here.
    pub parameters: serde_json::Value,
    /// Extra key/value entries; the key is omitted from serialized output when this is
    /// `None` *or* an empty map. As a consequence, `Some(empty)` does not survive a round
    /// trip: it deserializes back as `None`.
    #[serde(skip_serializing_if = "extensions_are_absent_or_empty")]
    pub extensions: Option<ExtraMap>,
}
impl Tool {
    /// Creates a tool from its `name` and `parameters` JSON.
    ///
    /// The result has no description and no extension map (`extensions` is `None`).
    #[must_use]
    pub fn new(name: impl Into<String>, parameters: serde_json::Value) -> Self {
        let name: String = name.into();
        Self {
            name,
            description: None,
            parameters,
            extensions: None,
        }
    }

    /// Builder method: sets the description, replacing any previous one.
    #[must_use]
    pub fn description(self, description: impl Into<String>) -> Self {
        Self {
            description: Some(description.into()),
            ..self
        }
    }

    /// Builder method: stores `value` under `key` in the extension map.
    ///
    /// The map is created on first use. For a repeated key, the outcome is whatever
    /// `ExtraMap::insert` does (presumably overwrite — confirm against `ExtraMap`).
    #[must_use]
    pub fn with_extension(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        let key: String = key.into();
        let map = self.extensions.get_or_insert_with(ExtraMap::new);
        map.insert(key, value);
        self
    }
}
/// Selects how tools may be used.
///
/// This uses serde's default (externally tagged) enum representation with snake_case
/// variant names:
/// - the unit variants serialize as plain strings (`"auto"`, `"disabled"`, `"required"`);
/// - `Specific` serializes as `{"specific": {"name": "..."}}`.
///
/// NOTE(review): the per-variant meanings below are inferred from the variant names —
/// confirm against the consumer of this type.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum ToolChoice {
    /// Presumably leaves the decision whether to call a tool up to the model.
    Auto,
    /// Presumably disallows tool calls.
    Disabled,
    /// Presumably requires at least one tool call.
    Required,
    /// Presumably forces a call to one particular tool.
    Specific {
        /// Name of the selected tool; presumably matches some [`Tool::name`].
        name: String,
    },
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn tool_serde_round_trip() {
        let tool = Tool::new("search", json!({"type": "object", "properties": {}}))
            .description("Search the web")
            .with_extension("cache", json!(true));
        let serialized = serde_json::to_string(&tool).unwrap();
        let deserialized: Tool = serde_json::from_str(&serialized).unwrap();
        assert_eq!(tool, deserialized);
    }

    // Round trips cannot catch a renamed key, so pin the exact serialized shape.
    #[test]
    fn tool_serializes_expected_keys() {
        let tool = Tool::new("search", json!({})).description("d");
        let value = serde_json::to_value(&tool).unwrap();
        assert_eq!(
            value,
            json!({"name": "search", "description": "d", "parameters": {}})
        );
    }

    #[test]
    fn tool_serde_skips_none_fields() {
        let tool = Tool::new("bare", json!({}));
        let value = serde_json::to_value(&tool).unwrap();
        let obj = value.as_object().unwrap();
        assert!(!obj.contains_key("description"));
        assert!(!obj.contains_key("extensions"));
    }

    // Optional fields must be optional on input too, not just skipped on output.
    #[test]
    fn tool_deserializes_without_optional_fields() {
        let tool: Tool =
            serde_json::from_value(json!({"name": "bare", "parameters": {}})).unwrap();
        assert_eq!(tool, Tool::new("bare", json!({})));
    }

    #[test]
    fn tool_serde_skips_empty_extensions() {
        let tool = Tool {
            name: "bare".into(),
            description: None,
            parameters: json!({}),
            extensions: Some(ExtraMap::new()),
        };
        let value = serde_json::to_value(&tool).unwrap();
        let obj = value.as_object().unwrap();
        assert!(!obj.contains_key("extensions"));
        // Because the key is skipped, an empty map does not survive the round trip.
        let round_tripped: Tool = serde_json::from_value(value).unwrap();
        assert!(round_tripped.extensions.is_none());
    }

    #[test]
    fn tool_choice_serde_round_trip() {
        let cases = [
            ToolChoice::Auto,
            ToolChoice::Disabled,
            ToolChoice::Required,
            ToolChoice::Specific {
                name: "search".into(),
            },
        ];
        for choice in cases {
            let value = serde_json::to_value(&choice).unwrap();
            let round_tripped: ToolChoice = serde_json::from_value(value).unwrap();
            assert_eq!(round_tripped, choice);
        }
    }

    // Pin the externally tagged, snake_case wire format; a round trip alone would
    // still pass if the representation or variant names changed.
    #[test]
    fn tool_choice_wire_format() {
        let cases = [
            (ToolChoice::Auto, json!("auto")),
            (ToolChoice::Disabled, json!("disabled")),
            (ToolChoice::Required, json!("required")),
            (
                ToolChoice::Specific {
                    name: "search".into(),
                },
                json!({"specific": {"name": "search"}}),
            ),
        ];
        for (choice, expected) in cases {
            assert_eq!(serde_json::to_value(&choice).unwrap(), expected);
            let parsed: ToolChoice = serde_json::from_value(expected).unwrap();
            assert_eq!(parsed, choice);
        }
    }
}