use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
/// A tool definition attached to a request.
///
/// `tool_type` is serialized under the JSON key `"type"`; every `Option` field is left
/// out of the serialized output while it is `None`.
///
/// Build one with [`Tool::function`], [`Tool::function_with_params`] or
/// [`Tool::web_search`], then chain the builder setters. The `max_keyword`,
/// `force_search`, `limit` and `user_location` setters do not look at `tool_type`;
/// this file only exercises them on web-search tools.
/// NOTE(review): the exact meaning of those options is defined by the target API —
/// confirm there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    /// Kind of tool; serialized as `"type"`.
    #[serde(rename = "type")]
    pub tool_type: ToolType,
    /// Function definition: `Some` for the function constructors, `None` for web search.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<FunctionTool>,
    /// `max_keyword` option; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_keyword: Option<u32>,
    /// `force_search` option; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub force_search: Option<bool>,
    /// `limit` option; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub limit: Option<u32>,
    /// Location hint; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_location: Option<UserLocation>,
}

impl Tool {
    /// Common starting point for every constructor: the given kind and function,
    /// with all optional settings unset.
    fn skeleton(tool_type: ToolType, function: Option<FunctionTool>) -> Self {
        Self {
            tool_type,
            function,
            max_keyword: None,
            force_search: None,
            limit: None,
            user_location: None,
        }
    }

    /// Creates a function tool with a name and description and no parameter schema.
    pub fn function(name: impl Into<String>, description: impl Into<String>) -> Self {
        let spec = FunctionTool::new(name, description);
        Self::skeleton(ToolType::Function, Some(spec))
    }

    /// Creates a function tool whose `parameters` schema is `parameters`
    /// (see `ParameterBuilder` for one way to build it).
    pub fn function_with_params(
        name: impl Into<String>,
        description: impl Into<String>,
        parameters: HashMap<String, Value>,
    ) -> Self {
        let spec = FunctionTool::with_params(name, description, parameters);
        Self::skeleton(ToolType::Function, Some(spec))
    }

    /// Creates a web-search tool with no function and no options set.
    pub fn web_search() -> Self {
        Self::skeleton(ToolType::WebSearch, None)
    }

    /// Sets `max_keyword`.
    pub fn max_keyword(self, max: u32) -> Self {
        Self {
            max_keyword: Some(max),
            ..self
        }
    }

    /// Sets `force_search`.
    pub fn force_search(self, force: bool) -> Self {
        Self {
            force_search: Some(force),
            ..self
        }
    }

    /// Sets `limit`.
    pub fn limit(self, limit: u32) -> Self {
        Self {
            limit: Some(limit),
            ..self
        }
    }

    /// Sets `user_location`.
    pub fn user_location(self, location: UserLocation) -> Self {
        Self {
            user_location: Some(location),
            ..self
        }
    }

    /// Sets `strict` on the wrapped [`FunctionTool`].
    ///
    /// Silently does nothing when `function` is `None` (e.g. a tool built with
    /// [`Tool::web_search`]).
    pub fn strict(mut self, strict: bool) -> Self {
        if let Some(function) = self.function.as_mut() {
            function.strict = Some(strict);
        }
        self
    }
}
/// Kind of a [`Tool`].
///
/// Serialized in `snake_case`: `Function` becomes `"function"` and `WebSearch`
/// becomes `"web_search"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ToolType {
    /// A callable function described by a [`FunctionTool`].
    Function,
    /// A web-search tool (no [`FunctionTool`] attached).
    WebSearch,
}
/// Location hint that can be attached to a [`Tool`].
///
/// `location_type` is serialized under the JSON key `"type"`; `country`, `region`
/// and `city` are omitted from the output while they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserLocation {
    /// Location kind, serialized as `"type"` (e.g. `"approximate"`).
    #[serde(rename = "type")]
    pub location_type: String,
    /// Country name; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub country: Option<String>,
    /// Region name; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub region: Option<String>,
    /// City name; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub city: Option<String>,
}

impl UserLocation {
    /// Creates a location of the given kind with no country, region or city.
    pub fn new(location_type: impl Into<String>) -> Self {
        let location_type: String = location_type.into();
        Self {
            location_type,
            country: None,
            region: None,
            city: None,
        }
    }

    /// Shorthand for `UserLocation::new("approximate")`.
    pub fn approximate() -> Self {
        Self::new("approximate")
    }

    /// Sets `country`.
    pub fn country(self, country: impl Into<String>) -> Self {
        Self {
            country: Some(country.into()),
            ..self
        }
    }

    /// Sets `region`.
    pub fn region(self, region: impl Into<String>) -> Self {
        Self {
            region: Some(region.into()),
            ..self
        }
    }

    /// Sets `city`.
    pub fn city(self, city: impl Into<String>) -> Self {
        Self {
            city: Some(city.into()),
            ..self
        }
    }
}
/// Definition of a callable function exposed as a [`Tool`].
///
/// `description`, `parameters` and `strict` are omitted from the serialized output
/// while they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionTool {
    /// Function name.
    pub name: String,
    /// Human-readable description; both constructors set it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Parameter schema as a JSON map (see `ParameterBuilder` and `schema`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<HashMap<String, Value>>,
    /// `strict` flag; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}

impl FunctionTool {
    /// Creates a function definition with a name and description, no parameters,
    /// and `strict` unset.
    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
        let (name, description): (String, String) = (name.into(), description.into());
        Self {
            name,
            description: Some(description),
            parameters: None,
            strict: None,
        }
    }

    /// Same as [`FunctionTool::new`] but with `parameters` already set.
    pub fn with_params(
        name: impl Into<String>,
        description: impl Into<String>,
        parameters: HashMap<String, Value>,
    ) -> Self {
        Self::new(name, description).parameters(parameters)
    }

    /// Sets (or replaces) the parameter schema.
    pub fn parameters(self, parameters: HashMap<String, Value>) -> Self {
        Self {
            parameters: Some(parameters),
            ..self
        }
    }

    /// Sets the `strict` flag.
    pub fn strict(self, strict: bool) -> Self {
        Self {
            strict: Some(strict),
            ..self
        }
    }
}
/// Incrementally builds a JSON-Schema-style parameter map for a [`FunctionTool`].
///
/// Properties are collected under `"properties"`; names added with
/// [`ParameterBuilder::required_property`] are also listed under `"required"`.
#[derive(Debug, Clone, Default)]
pub struct ParameterBuilder {
    params: HashMap<String, Value>,
}

impl ParameterBuilder {
    /// Creates an empty builder (same as [`ParameterBuilder::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets `"type": "object"` on the root schema.
    pub fn type_object(mut self) -> Self {
        self.params
            .insert("type".to_string(), Value::String("object".to_string()));
        self
    }

    /// Adds (or replaces) property `name` with `schema` and marks it required.
    ///
    /// Calling this again with the same `name` replaces the property's schema but
    /// keeps a single entry for it in `"required"`. JSON Schema requires the items
    /// of `required` to be unique.
    pub fn required_property(mut self, name: &str, schema: HashMap<String, Value>) -> Self {
        self.insert_property(name, schema);
        let required = self
            .params
            .entry("required".to_string())
            .or_insert_with(|| Value::Array(Vec::new()));
        if let Value::Array(req) = required {
            if !req.iter().any(|entry| entry.as_str() == Some(name)) {
                req.push(Value::String(name.to_string()));
            }
        }
        self
    }

    /// Adds (or replaces) property `name` with `schema` without touching `"required"`.
    ///
    /// If `name` was previously added with [`ParameterBuilder::required_property`],
    /// it stays listed in `"required"`.
    pub fn optional_property(mut self, name: &str, schema: HashMap<String, Value>) -> Self {
        self.insert_property(name, schema);
        self
    }

    /// Returns the accumulated parameter map.
    pub fn build(self) -> HashMap<String, Value> {
        self.params
    }

    /// Inserts `name -> schema` into the `"properties"` object, creating it on first use.
    fn insert_property(&mut self, name: &str, schema: HashMap<String, Value>) {
        let properties = self
            .params
            .entry("properties".to_string())
            .or_insert_with(|| Value::Object(serde_json::Map::new()));
        // `params` is private and "properties" is only ever set to an object here,
        // so this match always succeeds.
        if let Value::Object(props) = properties {
            props.insert(
                name.to_string(),
                Value::Object(schema.into_iter().collect()),
            );
        }
    }
}
/// Small constructors for JSON-Schema fragments, for use as property schemas with
/// [`super::ParameterBuilder`].
pub mod schema {
    use serde_json::Value;
    use std::collections::HashMap;

    /// Builds the one-entry map `{"type": json_type}` that every helper starts from.
    fn typed(json_type: &str) -> HashMap<String, Value> {
        HashMap::from([("type".to_string(), Value::String(json_type.to_string()))])
    }

    /// `{"type": "string"}`
    pub fn string() -> HashMap<String, Value> {
        typed("string")
    }

    /// `{"type": "string", "description": desc}`
    pub fn string_with_description(desc: &str) -> HashMap<String, Value> {
        let mut fragment = string();
        fragment.insert("description".to_string(), Value::String(desc.to_string()));
        fragment
    }

    /// `{"type": "number"}`
    pub fn number() -> HashMap<String, Value> {
        typed("number")
    }

    /// `{"type": "integer"}`
    pub fn integer() -> HashMap<String, Value> {
        typed("integer")
    }

    /// `{"type": "boolean"}`
    pub fn boolean() -> HashMap<String, Value> {
        typed("boolean")
    }

    /// `{"type": "array", "items": items}`
    pub fn array(items: HashMap<String, Value>) -> HashMap<String, Value> {
        let mut fragment = typed("array");
        let item_schema = Value::Object(items.into_iter().collect());
        fragment.insert("items".to_string(), item_schema);
        fragment
    }

    /// `{"type": "string", "enum": [values...]}`, preserving the order of `values`.
    pub fn enum_values(values: &[&str]) -> HashMap<String, Value> {
        let variants: Vec<Value> = values.iter().map(|&v| Value::from(v)).collect();
        let mut fragment = typed("string");
        fragment.insert("enum".to_string(), Value::Array(variants));
        fragment
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_creation() {
        let tool = Tool::function("get_weather", "Get the current weather");
        assert_eq!(tool.tool_type, ToolType::Function);
        assert!(tool.function.is_some());
    }

    #[test]
    fn test_web_search_tool() {
        let tool = Tool::web_search();
        assert_eq!(tool.tool_type, ToolType::WebSearch);
        assert!(tool.function.is_none());
    }

    #[test]
    fn test_tool_serialization() {
        let tool = Tool::function("get_weather", "Get weather info");
        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"get_weather\""));
    }

    #[test]
    fn test_tool_with_parameters() {
        let mut params = HashMap::new();
        params.insert("type".to_string(), Value::String("object".to_string()));
        // `params` is not used afterwards, so it is moved rather than cloned.
        let tool = Tool::function_with_params("get_weather", "Get weather for a location", params);
        let stored = tool
            .function
            .expect("function tool must carry a FunctionTool")
            .parameters
            .expect("parameters were supplied");
        assert_eq!(
            stored.get("type"),
            Some(&Value::String("object".to_string()))
        );
    }

    #[test]
    fn test_parameter_builder() {
        let params = ParameterBuilder::new()
            .type_object()
            .required_property("location", schema::string())
            .build();
        assert_eq!(
            params.get("type").unwrap(),
            &Value::String("object".to_string())
        );
        assert_eq!(params["properties"]["location"]["type"], "string");
        assert_eq!(
            params.get("required"),
            Some(&Value::Array(vec![Value::String("location".to_string())]))
        );
    }

    #[test]
    fn test_optional_property_is_not_required() {
        let params = ParameterBuilder::new()
            .type_object()
            .optional_property("unit", schema::string())
            .build();
        assert_eq!(params["properties"]["unit"]["type"], "string");
        assert!(!params.contains_key("required"));
    }

    #[test]
    fn test_schema_helpers() {
        let s = schema::string();
        assert_eq!(s.get("type").unwrap(), &Value::String("string".to_string()));
        let n = schema::number();
        assert_eq!(n.get("type").unwrap(), &Value::String("number".to_string()));
        let e = schema::enum_values(&["a", "b", "c"]);
        assert!(e.contains_key("enum"));
    }

    #[test]
    fn test_strict_mode() {
        let tool = Tool::function("test", "test function").strict(true);
        assert_eq!(tool.function.unwrap().strict, Some(true));
    }

    #[test]
    fn test_strict_is_noop_without_function() {
        let tool = Tool::web_search().strict(true);
        assert!(tool.function.is_none());
    }

    #[test]
    fn test_web_search_with_options() {
        let tool = Tool::web_search()
            .max_keyword(3)
            .force_search(true)
            .limit(5);
        assert_eq!(tool.tool_type, ToolType::WebSearch);
        assert_eq!(tool.max_keyword, Some(3));
        assert_eq!(tool.force_search, Some(true));
        assert_eq!(tool.limit, Some(5));
    }

    #[test]
    fn test_web_search_serialization() {
        let tool = Tool::web_search().max_keyword(3).force_search(true);
        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains("\"type\":\"web_search\""));
        assert!(json.contains("\"max_keyword\":3"));
        assert!(json.contains("\"force_search\":true"));
    }

    #[test]
    fn test_web_search_serialization_omits_unset_fields() {
        let json = serde_json::to_string(&Tool::web_search()).unwrap();
        assert_eq!(json, r#"{"type":"web_search"}"#);
    }

    #[test]
    fn test_user_location() {
        let location = UserLocation::approximate()
            .country("China")
            .region("Hubei")
            .city("Wuhan");
        assert_eq!(location.location_type, "approximate");
        assert_eq!(location.country, Some("China".to_string()));
        assert_eq!(location.region, Some("Hubei".to_string()));
        assert_eq!(location.city, Some("Wuhan".to_string()));
    }

    #[test]
    fn test_user_location_serialization() {
        let location = UserLocation::approximate().country("China").city("Beijing");
        let json = serde_json::to_string(&location).unwrap();
        assert!(json.contains("\"type\":\"approximate\""));
        assert!(json.contains("\"country\":\"China\""));
        assert!(json.contains("\"city\":\"Beijing\""));
        assert!(!json.contains("region"));
    }

    #[test]
    fn test_web_search_with_location() {
        let tool = Tool::web_search().user_location(
            UserLocation::approximate()
                .country("China")
                .city("Shanghai"),
        );
        let loc = tool
            .user_location
            .expect("user_location was set by the builder");
        assert_eq!(loc.country, Some("China".to_string()));
        assert_eq!(loc.city, Some("Shanghai".to_string()));
    }
}