use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A JSON Schema document, kept as an untyped JSON value rather than a typed model.
pub type JsonSchema = serde_json::value::Value;
/// Declaration of a tool: its name, a human-readable description, and JSON
/// schemas for its input and (optionally) its output.
///
/// Field names are serialized in camelCase (`inputSchema`, `outputSchema`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolSpec {
    /// Tool name.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Schema of the tool's input, serialized as `inputSchema`.
    pub input_schema: InputSchema,
    /// Optional schema of the tool's output, serialized as `outputSchema`;
    /// the key is omitted entirely when this is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_schema: Option<JsonSchema>,
}

impl ToolSpec {
    /// Creates a spec with the given name and description.
    ///
    /// The input schema starts as `InputSchema::default()`, whose `json` is a
    /// JSON `null` (`serde_json::Value`'s default). No output schema is set.
    /// Chain [`ToolSpec::with_input_schema`] / [`ToolSpec::with_output_schema`]
    /// to fill them in.
    #[must_use]
    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            description: description.into(),
            input_schema: InputSchema::default(),
            output_schema: None,
        }
    }

    /// Replaces the input schema with `schema` and returns the updated spec.
    ///
    /// Consumes `self`, so ignoring the return value discards the spec.
    #[must_use = "builder methods consume `self`; use the returned ToolSpec"]
    pub fn with_input_schema(mut self, schema: JsonSchema) -> Self {
        self.input_schema = InputSchema { json: schema };
        self
    }

    /// Sets the output schema to `schema` and returns the updated spec.
    ///
    /// Consumes `self`, so ignoring the return value discards the spec.
    #[must_use = "builder methods consume `self`; use the returned ToolSpec"]
    pub fn with_output_schema(mut self, schema: JsonSchema) -> Self {
        self.output_schema = Some(schema);
        self
    }
}
/// Wrapper that places a tool's input schema under a `json` key, so it
/// serializes as `{"json": <schema>}`.
///
/// The derived `Default` holds a JSON `null` as the schema.
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
pub struct InputSchema {
    /// The schema document itself.
    pub json: JsonSchema,
}

/// A single tool entry; serializes as `{"toolSpec": { ... }}`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
    /// The wrapped specification (serialized as `toolSpec`).
    pub tool_spec: ToolSpec,
}
/// A request to run the tool called `name` with JSON `input`, identified by
/// `tool_use_id` (serialized as `toolUseId`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolUse {
    /// Name of the tool to invoke.
    pub name: String,
    /// Identifier of this particular invocation.
    pub tool_use_id: String,
    /// Raw JSON arguments for the tool.
    pub input: serde_json::Value,
}

impl ToolUse {
    /// Builds a tool use from its name, invocation id and JSON input.
    pub fn new(name: impl Into<String>, tool_use_id: impl Into<String>, input: serde_json::Value) -> Self {
        let name = name.into();
        let tool_use_id = tool_use_id.into();
        Self { name, tool_use_id, input }
    }

    /// Reads `input[key]` and deserializes it into `T`.
    ///
    /// Returns `None` when `input` is not a JSON object, when `key` is absent,
    /// or when the value does not deserialize into `T`. These cases are not
    /// distinguished.
    pub fn get_param<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
        let raw = self.input.get(key)?;
        T::deserialize(raw).ok()
    }
}
/// One content item of a tool result.
///
/// Every field is optional and omitted from the serialized JSON when `None`.
/// The constructors below each populate exactly one field.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct ToolResultContent {
    /// Plain-text payload.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    /// Structured JSON payload.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json: Option<serde_json::Value>,
    /// Image payload.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image: Option<ImageResultContent>,
    /// Document payload.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub document: Option<DocumentResultContent>,
}

impl ToolResultContent {
    /// Content carrying only `text`; serializes as `{"text": ...}`.
    pub fn text(text: impl Into<String>) -> Self {
        let text = Some(text.into());
        Self { text, ..Self::default() }
    }

    /// Content carrying only a JSON `value`; serializes as `{"json": ...}`.
    pub fn json(value: serde_json::Value) -> Self {
        let json = Some(value);
        Self { json, ..Self::default() }
    }
}
/// Image content returned by a tool.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ImageResultContent {
    /// Image format identifier (e.g. a file-type name — TODO confirm accepted values).
    pub format: String,
    /// Image bytes encoded as a string (presumably base64 — verify against producers).
    pub data: String,
}

/// Document content returned by a tool.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DocumentResultContent {
    /// Document format identifier (TODO confirm accepted values).
    pub format: String,
    /// Display name of the document.
    pub name: String,
    /// Document bytes encoded as a string (presumably base64 — verify against producers).
    pub data: String,
}
/// Outcome flag of a tool result; serialized in lowercase (`"success"` / `"error"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultStatus {
    /// Serialized as `"success"`.
    Success,
    /// Serialized as `"error"`.
    Error,
}

impl ToolResultStatus {
    /// Returns the lowercase name of the status, i.e. the same text serde emits.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Success => "success",
            Self::Error => "error",
        }
    }
}
/// Formats the status as its lowercase name (`"success"` / `"error"`), the same
/// text returned by [`ToolResultStatus::as_str`].
impl std::fmt::Display for ToolResultStatus {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` applies the caller's width/fill/alignment/precision flags the same
        // way `str` does (e.g. `format!("{:>8}", status)` == "   error").
        // `write!(f, "{}", ..)` would start a fresh format spec and drop them.
        f.pad(self.as_str())
    }
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolResult {
pub tool_use_id: String,
pub status: ToolResultStatus,
pub content: Vec<ToolResultContent>,
}
impl ToolResult {
pub fn success(tool_use_id: impl Into<String>, text: impl Into<String>) -> Self {
Self {
tool_use_id: tool_use_id.into(),
status: ToolResultStatus::Success,
content: vec![ToolResultContent::text(text)],
}
}
pub fn success_json(tool_use_id: impl Into<String>, json: serde_json::Value) -> Self {
Self {
tool_use_id: tool_use_id.into(),
status: ToolResultStatus::Success,
content: vec![ToolResultContent::json(json)],
}
}
pub fn error(tool_use_id: impl Into<String>, error_message: impl Into<String>) -> Self {
Self {
tool_use_id: tool_use_id.into(),
status: ToolResultStatus::Error,
content: vec![ToolResultContent::text(error_message)],
}
}
pub fn is_success(&self) -> bool { self.status == ToolResultStatus::Success }
pub fn is_error(&self) -> bool { self.status == ToolResultStatus::Error }
}
/// Payload of [`ToolChoice::Auto`]; serializes as an empty object `{}`.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct ToolChoiceAuto {}

/// Payload of [`ToolChoice::Any`]; serializes as an empty object `{}`.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct ToolChoiceAny {}

/// Payload of [`ToolChoice::Tool`], naming the specific tool.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolChoiceTool {
    /// Name of the tool being selected.
    pub name: String,
}

/// Tool-selection mode, serialized as an externally tagged object with a
/// lowercase key: `{"auto":{}}`, `{"any":{}}` or `{"tool":{"name":...}}`.
///
/// The exact model-side semantics of each mode are defined by the consuming
/// API (presumably: model decides / must use some tool / must use the named tool).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto(ToolChoiceAuto),
    Any(ToolChoiceAny),
    Tool(ToolChoiceTool),
}

impl Default for ToolChoice {
    /// Defaults to the `Auto` mode.
    fn default() -> Self {
        Self::auto()
    }
}

impl ToolChoice {
    /// The `Auto` mode.
    pub fn auto() -> Self {
        Self::Auto(ToolChoiceAuto::default())
    }

    /// The `Any` mode.
    pub fn any() -> Self {
        Self::Any(ToolChoiceAny::default())
    }

    /// The `Tool` mode, selecting the tool called `name`.
    pub fn tool(name: impl Into<String>) -> Self {
        let name = name.into();
        Self::Tool(ToolChoiceTool { name })
    }
}
/// The set of tools offered, plus an optional selection mode; serialized in
/// camelCase (`tools`, `toolChoice`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfig {
    /// Tools made available.
    pub tools: Vec<Tool>,
    /// Selection mode; the `toolChoice` key is omitted when this is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}
/// Context for one tool invocation: the [`ToolUse`] being executed plus a
/// string-keyed map of JSON state.
#[derive(Debug, Clone)]
pub struct ToolContext {
    /// The tool use being executed.
    pub tool_use: ToolUse,
    /// Arbitrary JSON values keyed by name.
    pub invocation_state: HashMap<String, serde_json::Value>,
}

impl ToolContext {
    /// Context with an empty state map.
    pub fn new(tool_use: ToolUse) -> Self {
        Self::with_state(tool_use, HashMap::new())
    }

    /// Context using `state` as its state map.
    pub fn with_state(tool_use: ToolUse, state: HashMap<String, serde_json::Value>) -> Self {
        Self { tool_use, invocation_state: state }
    }

    /// Looks up `key` in the state map and deserializes it into `T`.
    ///
    /// Returns `None` if the key is missing or the value does not deserialize
    /// into `T`. The two cases are not distinguished.
    pub fn get_state<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
        let raw = self.invocation_state.get(key)?;
        T::deserialize(raw).ok()
    }

    /// Builds an interrupt identifier of the form
    /// `v1:tool_call:<tool_use_id>:<uuid>`, where `<uuid>` is the name-based
    /// UUIDv5 of `name` in the OID namespace.
    ///
    /// The same tool use and `name` therefore always yield the same id.
    pub fn interrupt_id(&self, name: &str) -> String {
        let name_uuid = uuid::Uuid::new_v5(&uuid::Uuid::NAMESPACE_OID, name.as_bytes());
        format!("v1:tool_call:{}:{}", self.tool_use.tool_use_id, name_uuid)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_spec_creation() {
        let ToolSpec { name, description, .. } =
            ToolSpec::new("get_weather", "Get weather for a location");
        assert_eq!(name, "get_weather");
        assert_eq!(description, "Get weather for a location");
    }

    #[test]
    fn test_tool_result_success() {
        let outcome = ToolResult::success("123", "Weather is sunny");
        assert_eq!((outcome.is_success(), outcome.is_error()), (true, false));
    }

    #[test]
    fn test_tool_result_error() {
        let outcome = ToolResult::error("123", "Failed to fetch weather");
        assert_eq!((outcome.is_success(), outcome.is_error()), (false, true));
    }

    #[test]
    fn test_tool_choice_variants() {
        assert!(matches!(ToolChoice::auto(), ToolChoice::Auto(_)));
        assert!(matches!(ToolChoice::any(), ToolChoice::Any(_)));
        match ToolChoice::tool("my_tool") {
            ToolChoice::Tool(chosen) => assert_eq!(chosen.name, "my_tool"),
            other => panic!("expected ToolChoice::Tool, got {:?}", other),
        }
    }

    #[test]
    fn test_tool_result_content_serialization() {
        let encoded = serde_json::to_string(&ToolResultContent::text("hello")).unwrap();
        assert_eq!(encoded, r#"{"text":"hello"}"#);
    }

    #[test]
    fn test_tool_choice_serialization() {
        let cases = [
            (ToolChoice::auto(), r#"{"auto":{}}"#),
            (ToolChoice::any(), r#"{"any":{}}"#),
            (ToolChoice::tool("my_tool"), r#"{"tool":{"name":"my_tool"}}"#),
        ];
        for (choice, expected) in cases.iter() {
            let encoded = serde_json::to_string(choice).unwrap();
            assert_eq!(encoded, *expected);
        }
    }
}