use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use crate::messaging::{AgentMessage, MessageContent, MessageMetadata, MessageRole};
use crate::state::AgentStateSnapshot;
/// One node of a tool's parameter schema.
///
/// The serialized field names (see the `rename`s below) mirror JSON Schema
/// keywords: `type`, `description`, `properties`, `required`, `items`, `enum`,
/// `default`. Fields that are `None` are omitted when serializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolParameterSchema {
/// JSON type name, serialized as `"type"`. The constructors in this module
/// use `"string"`, `"number"`, `"integer"`, `"boolean"`, `"object"` and `"array"`.
#[serde(rename = "type")]
pub schema_type: String,
/// Human-readable description of the parameter.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Child schemas keyed by property name (used by `"object"` nodes).
#[serde(skip_serializing_if = "Option::is_none")]
pub properties: Option<HashMap<String, ToolParameterSchema>>,
/// Names of properties that are required (used by `"object"` nodes).
#[serde(skip_serializing_if = "Option::is_none")]
pub required: Option<Vec<String>>,
/// Schema of each element (used by `"array"` nodes). Boxed because the type
/// contains itself and would otherwise have infinite size.
#[serde(skip_serializing_if = "Option::is_none")]
pub items: Option<Box<ToolParameterSchema>>,
/// Allowed values, serialized under the `"enum"` key.
#[serde(rename = "enum", skip_serializing_if = "Option::is_none")]
pub enum_values: Option<Vec<Value>>,
/// Default value of the parameter.
#[serde(skip_serializing_if = "Option::is_none")]
pub default: Option<Value>,
/// Any other keys. Because of `#[serde(flatten)]`, these are written inline
/// next to the named fields when serializing. When deserializing, keys that
/// match no named field are collected here.
#[serde(flatten)]
pub additional: HashMap<String, Value>,
}
impl ToolParameterSchema {
    /// Builds a schema node of JSON type `schema_type` with `description` set
    /// and every other keyword unset (`None` / empty `additional`).
    ///
    /// Every public constructor starts from this one. Adding a field to the
    /// struct therefore needs a single change here instead of one per
    /// constructor.
    fn of_type(schema_type: &str, description: impl Into<String>) -> Self {
        Self {
            schema_type: schema_type.to_string(),
            description: Some(description.into()),
            properties: None,
            required: None,
            items: None,
            enum_values: None,
            default: None,
            additional: HashMap::new(),
        }
    }

    /// Schema for a `"string"` parameter.
    pub fn string(description: impl Into<String>) -> Self {
        Self::of_type("string", description)
    }

    /// Schema for a `"number"` parameter.
    pub fn number(description: impl Into<String>) -> Self {
        Self::of_type("number", description)
    }

    /// Schema for an `"integer"` parameter.
    pub fn integer(description: impl Into<String>) -> Self {
        Self::of_type("integer", description)
    }

    /// Schema for a `"boolean"` parameter.
    pub fn boolean(description: impl Into<String>) -> Self {
        Self::of_type("boolean", description)
    }

    /// Schema for an `"object"` parameter.
    ///
    /// * `properties` - child schemas keyed by property name.
    /// * `required` - names of the properties that must be present. These are
    ///   stored as given and are not checked against `properties`.
    pub fn object(
        description: impl Into<String>,
        properties: HashMap<String, ToolParameterSchema>,
        required: Vec<String>,
    ) -> Self {
        Self {
            properties: Some(properties),
            required: Some(required),
            ..Self::of_type("object", description)
        }
    }

    /// Schema for an `"array"` parameter whose elements match `items`.
    pub fn array(description: impl Into<String>, items: ToolParameterSchema) -> Self {
        Self {
            items: Some(Box::new(items)),
            ..Self::of_type("array", description)
        }
    }
}
/// Describes a tool: its name, a human-readable description, and the schema
/// of its parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolSchema {
/// Tool name. `ToolRegistry::register` uses it as the registry key.
pub name: String,
/// Human-readable description of what the tool does.
pub description: String,
/// Schema of the tool's parameters (presumably describing the `args` value
/// passed to `Tool::execute`; confirm against callers).
pub parameters: ToolParameterSchema,
}
impl ToolSchema {
pub fn new(
name: impl Into<String>,
description: impl Into<String>,
parameters: ToolParameterSchema,
) -> Self {
Self {
name: name.into(),
description: description.into(),
parameters,
}
}
pub fn no_params(name: impl Into<String>, description: impl Into<String>) -> Self {
Self {
name: name.into(),
description: description.into(),
parameters: ToolParameterSchema {
schema_type: "object".to_string(),
description: None,
properties: Some(HashMap::new()),
required: Some(Vec::new()),
items: None,
enum_values: None,
default: None,
additional: HashMap::new(),
},
}
}
}
/// Per-invocation context passed to `Tool::execute`.
#[derive(Clone)]
pub struct ToolContext {
/// Read-only snapshot of agent state.
pub state: Arc<AgentStateSnapshot>,
/// Optional shared, lock-protected state. Set by `with_mutable_state`;
/// `None` when built with `new`.
/// NOTE(review): this is a `std::sync::RwLock`, so a guard taken inside an
/// async `execute` should not be held across an `.await`.
pub state_handle: Option<Arc<std::sync::RwLock<AgentStateSnapshot>>>,
/// ID of the tool call being answered. When set, `text_response` and
/// `json_response` copy it into the response metadata.
pub tool_call_id: Option<String>,
}
impl ToolContext {
pub fn new(state: Arc<AgentStateSnapshot>) -> Self {
Self {
state,
state_handle: None,
tool_call_id: None,
}
}
pub fn with_mutable_state(
state: Arc<AgentStateSnapshot>,
state_handle: Arc<std::sync::RwLock<AgentStateSnapshot>>,
) -> Self {
Self {
state,
state_handle: Some(state_handle),
tool_call_id: None,
}
}
pub fn with_call_id(mut self, call_id: Option<String>) -> Self {
self.tool_call_id = call_id;
self
}
pub fn text_response(&self, content: impl Into<String>) -> AgentMessage {
AgentMessage {
role: MessageRole::Tool,
content: MessageContent::Text(content.into()),
metadata: self.tool_call_id.as_ref().map(|id| MessageMetadata {
tool_call_id: Some(id.clone()),
cache_control: None,
}),
}
}
pub fn json_response(&self, content: Value) -> AgentMessage {
AgentMessage {
role: MessageRole::Tool,
content: MessageContent::Json(content),
metadata: self.tool_call_id.as_ref().map(|id| MessageMetadata {
tool_call_id: Some(id.clone()),
cache_control: None,
}),
}
}
}
/// Outcome of a tool execution.
#[derive(Debug, Clone)]
pub enum ToolResult {
/// A single response message.
Message(AgentMessage),
/// A response message together with a `StateDiff` (presumably applied to
/// agent state by whoever runs the tool; confirm with the caller).
WithStateUpdate {
message: AgentMessage,
state_diff: crate::command::StateDiff,
},
}
impl ToolResult {
pub fn text(ctx: &ToolContext, content: impl Into<String>) -> Self {
Self::Message(ctx.text_response(content))
}
pub fn json(ctx: &ToolContext, content: Value) -> Self {
Self::Message(ctx.json_response(content))
}
pub fn with_state(message: AgentMessage, state_diff: crate::command::StateDiff) -> Self {
Self::WithStateUpdate {
message,
state_diff,
}
}
}
/// A callable tool.
///
/// `#[async_trait]` makes the async `execute` usable through `dyn Tool`,
/// which is how tools are stored (`ToolBox = Arc<dyn Tool>`).
#[async_trait]
pub trait Tool: Send + Sync {
/// Returns the tool's schema. `ToolRegistry` calls this on registration
/// (to read the name) and again in `ToolRegistry::schemas`.
fn schema(&self) -> ToolSchema;
/// Runs the tool with JSON arguments and a per-call context.
async fn execute(&self, args: Value, ctx: ToolContext) -> anyhow::Result<ToolResult>;
}
/// Shared, type-erased handle to a tool. Cloning it only bumps the `Arc`
/// reference count.
pub type ToolBox = Arc<dyn Tool>;
/// Collection of tools keyed by name. Cloning the registry clones the `Arc`
/// handles, so clones share the same tool instances.
#[derive(Clone, Default)]
pub struct ToolRegistry {
/// Tools keyed by the `schema().name` they reported when registered.
tools: HashMap<String, ToolBox>,
}
impl ToolRegistry {
pub fn new() -> Self {
Self {
tools: HashMap::new(),
}
}
pub fn register(&mut self, tool: ToolBox) -> &mut Self {
let name = tool.schema().name.clone();
self.tools.insert(name, tool);
self
}
pub fn register_all<I>(&mut self, tools: I) -> &mut Self
where
I: IntoIterator<Item = ToolBox>,
{
for tool in tools {
self.register(tool);
}
self
}
pub fn get(&self, name: &str) -> Option<&ToolBox> {
self.tools.get(name)
}
pub fn all(&self) -> Vec<&ToolBox> {
self.tools.values().collect()
}
pub fn schemas(&self) -> Vec<ToolSchema> {
self.tools.values().map(|t| t.schema()).collect()
}
pub fn names(&self) -> Vec<String> {
self.tools.keys().cloned().collect()
}
pub fn has(&self, name: &str) -> bool {
self.tools.contains_key(name)
}
pub fn len(&self) -> usize {
self.tools.len()
}
pub fn is_empty(&self) -> bool {
self.tools.is_empty()
}
}