use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::error::OpenRouterError;
/// A tool that can be offered to the model.
///
/// Serializes as `{"type": <tool_type>, "function": <FunctionDefinition>}`.
/// Construct one with [`Tool::new`], [`Tool::builder`], or the
/// [`create_tool`] helper.
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
#[builder(build_fn(error = "OpenRouterError"))]
pub struct Tool {
    /// Tool kind, emitted under the JSON key `"type"`.
    /// The builder fills this in with `"function"` when it is not set.
    #[builder(default = r#""function".to_string()"#)]
    #[serde(rename = "type")]
    pub tool_type: String,
    /// Definition of the function this tool exposes.
    pub function: FunctionDefinition,
}
impl Tool {
pub fn builder() -> ToolBuilder {
ToolBuilder::default()
}
pub fn new(name: &str, description: &str, parameters: Value) -> Self {
Self {
tool_type: "function".to_string(),
function: FunctionDefinition {
name: name.to_string(),
description: description.to_string(),
parameters,
},
}
}
}
/// The function half of a [`Tool`]: its name, description and parameters.
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
#[builder(build_fn(error = "OpenRouterError"))]
pub struct FunctionDefinition {
    /// Function name; the builder setter accepts anything `Into<String>`.
    #[builder(setter(into))]
    pub name: String,
    /// Human-readable description; the builder setter accepts anything `Into<String>`.
    #[builder(setter(into))]
    pub description: String,
    /// Parameters as a raw JSON value. The builder setter is hand-written
    /// (see `parameters`, `parameters_from` and `parameters_json` on
    /// `FunctionDefinitionBuilder`).
    #[builder(setter(custom))]
    pub parameters: Value,
}
impl FunctionDefinition {
pub fn builder() -> FunctionDefinitionBuilder {
FunctionDefinitionBuilder::default()
}
}
impl ToolBuilder {
pub fn name(&mut self, name: &str) -> &mut Self {
self.function = Some(
FunctionDefinition::builder()
.name(name)
.description("")
.parameters(Value::Null)
.build()
.unwrap(),
);
self
}
pub fn description(&mut self, description: &str) -> &mut Self {
if let Some(ref mut func) = self.function {
func.description = description.to_string();
}
self
}
pub fn parameters(&mut self, parameters: Value) -> &mut Self {
if let Some(ref mut func) = self.function {
func.parameters = parameters;
}
self
}
pub fn parameters_from<T: Serialize>(
&mut self,
params: &T,
) -> Result<&mut Self, OpenRouterError> {
let value = serde_json::to_value(params).map_err(OpenRouterError::Serialization)?;
Ok(self.parameters(value))
}
pub fn parameters_json(&mut self, json: &str) -> Result<&mut Self, OpenRouterError> {
let value: Value = serde_json::from_str(json).map_err(OpenRouterError::Serialization)?;
Ok(self.parameters(value))
}
}
/// Hand-written setters for the `parameters` field, which is marked
/// `setter(custom)` on [`FunctionDefinition`].
impl FunctionDefinitionBuilder {
    /// Stores `parameters` verbatim as the parameters value.
    pub fn parameters(&mut self, parameters: Value) -> &mut Self {
        self.parameters = Some(parameters);
        self
    }

    /// Serializes `params` with `serde_json::to_value` and stores the result.
    ///
    /// # Errors
    /// Returns `OpenRouterError::Serialization` if serialization fails; the
    /// builder is left unchanged in that case.
    pub fn parameters_from<T: Serialize>(
        &mut self,
        params: &T,
    ) -> Result<&mut Self, OpenRouterError> {
        match serde_json::to_value(params) {
            Ok(schema) => Ok(self.parameters(schema)),
            Err(cause) => Err(OpenRouterError::Serialization(cause)),
        }
    }

    /// Parses `json` into a value and stores it.
    ///
    /// # Errors
    /// Returns `OpenRouterError::Serialization` if `json` does not parse; the
    /// builder is left unchanged in that case.
    pub fn parameters_json(&mut self, json: &str) -> Result<&mut Self, OpenRouterError> {
        match serde_json::from_str::<Value>(json) {
            Ok(schema) => Ok(self.parameters(schema)),
            Err(cause) => Err(OpenRouterError::Serialization(cause)),
        }
    }
}
/// How the model may pick tools.
///
/// Serialized untagged: `String` becomes a bare JSON string (e.g. `"auto"`),
/// `Specific` becomes an object. When deserializing, serde tries the variants
/// in declaration order, so keep `String` first.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    /// A bare mode string such as `"none"`, `"auto"` or `"required"`.
    String(String),
    /// An object naming one specific tool.
    Specific(SpecificToolChoice),
}
impl ToolChoice {
pub fn none() -> Self {
Self::String("none".to_string())
}
pub fn auto() -> Self {
Self::String("auto".to_string())
}
pub fn required() -> Self {
Self::String("required".to_string())
}
pub fn force_tool(tool_name: &str) -> Self {
Self::Specific(SpecificToolChoice {
tool_type: "function".to_string(),
function: SpecificToolFunction {
name: tool_name.to_string(),
},
})
}
}
/// Object form of [`ToolChoice`] that names a single tool.
///
/// Serializes as `{"type": <tool_type>, "function": {"name": ...}}`;
/// [`ToolChoice::force_tool`] builds one with `tool_type` set to `"function"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpecificToolChoice {
    /// Tool kind, emitted under the JSON key `"type"`.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The named function being selected.
    pub function: SpecificToolFunction,
}
/// Name-only reference to a function, used inside [`SpecificToolChoice`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpecificToolFunction {
    /// Name of the function being selected.
    pub name: String,
}
/// Builds a `"function"` [`Tool`] whose parameters are an object schema.
///
/// The parameters value is
/// `{"type": "object", "properties": <properties>, "required": <required>}`,
/// where `required` is emitted as a JSON array of strings.
pub fn create_tool(name: &str, description: &str, properties: Value, required: &[&str]) -> Tool {
    Tool::new(
        name,
        description,
        serde_json::json!({
            "type": "object",
            "properties": properties,
            "required": required,
        }),
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// The builder fills in `tool_type` and forwards name/description.
    #[test]
    fn test_tool_creation() {
        let mut builder = Tool::builder();
        builder
            .name("test_function")
            .description("A test function")
            .parameters(json!({"type": "object"}));
        let built = builder.build().unwrap();

        assert_eq!(built.tool_type, "function");
        assert_eq!(built.function.name, "test_function");
        assert_eq!(built.function.description, "A test function");
    }

    /// Mode strings serialize as bare JSON strings; `force_tool` yields `Specific`.
    #[test]
    fn test_tool_choice_variants() {
        let to_json = |choice: &ToolChoice| serde_json::to_string(choice).unwrap();

        assert_eq!(to_json(&ToolChoice::auto()), r#""auto""#);
        assert_eq!(to_json(&ToolChoice::none()), r#""none""#);
        assert_eq!(to_json(&ToolChoice::required()), r#""required""#);

        match ToolChoice::force_tool("my_function") {
            ToolChoice::Specific(chosen) => assert_eq!(chosen.function.name, "my_function"),
            _ => panic!("Expected specific tool choice"),
        }
    }

    /// `create_tool` wraps properties/required into an object schema.
    #[test]
    fn test_create_tool_helper() {
        let weather = create_tool(
            "weather",
            "Get weather",
            json!({"location": {"type": "string"}}),
            &["location"],
        );

        assert_eq!(weather.function.name, "weather");
        assert_eq!(weather.function.description, "Get weather");

        let schema = &weather.function.parameters;
        assert_eq!(schema["type"], "object");
        assert_eq!(schema["required"], json!(["location"]));
    }
}