use std::collections::BTreeMap;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio_util::sync::CancellationToken;
use crate::error::{CoreError, Result};
/// Public alias for [`serde_json::Value`], the dynamic JSON type this module uses for
/// tool inputs and tool result content.
pub type JsonValue = Value;
/// Parameter schema of a tool, kept as an untyped map of top-level JSON entries.
///
/// Because `fields` is `#[serde(flatten)]`ed, its entries (de)serialize as the keys of
/// this object itself, not under a nested `"fields"` key. The contents look like a JSON
/// Schema object: [`validate_input`] reads a top-level `"required"` array of key names
/// and interprets nothing else.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ParameterSchema {
    /// Raw schema entries keyed by name. Because this is a `BTreeMap`, entries serialize
    /// in sorted key order. Only `"required"` is read by [`validate_input`]. Presumably
    /// other JSON Schema keywords (e.g. `"type"`, `"properties"`) also live here — confirm.
    #[serde(flatten)]
    pub fields: BTreeMap<String, Value>,
}
/// Per-call context passed by value to [`Tool::execute`].
///
/// NOTE(review): unlike the other types here, this derives neither `Debug` nor `Clone`.
/// `CancellationToken` implements both, so they could be added if no manual impls exist
/// elsewhere.
pub struct InvokeContext {
    /// Identifier of the tool call being executed. Presumably this is the id assigned by
    /// the caller so results can be correlated — confirm.
    pub tool_call_id: String,
    /// Cancellation signal for this invocation. An implementation can check
    /// `cancel.is_cancelled()` or await `cancel.cancelled()` to stop early.
    pub cancel: CancellationToken,
}
/// A request to invoke a tool: the tool's name plus its raw JSON input.
///
/// Presumably `name` is matched against [`ToolDefinition::name`], and `input` is what gets
/// checked by [`validate_input`] and handed to [`Tool::execute`] — confirm at the call site.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Name of the tool to invoke.
    pub name: String,
    /// Arbitrary JSON input. Nothing in this type restricts it to an object; that check
    /// happens in [`validate_input`].
    pub input: Value,
}
/// Value returned by a successful [`Tool::execute`].
///
/// Serializes as `{"content": [...]}`. The `"details"` key is added only when it is `Some`.
/// When deserializing, `content` must be present (it has no `#[serde(default)]`), while a
/// missing `details` becomes `None`. `ToolResult::default()` has empty `content` and no
/// `details`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ToolResult {
    /// Content items produced by the tool. Their JSON shape is not constrained here.
    pub content: Vec<Value>,
    /// Optional extra JSON payload, omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<Value>,
}
/// Description of a tool, as returned by [`Tool::definition`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Identifier of the tool. It is also interpolated into [`validate_input`]'s error
    /// messages.
    pub name: String,
    /// Presumably a human-readable display name — confirm.
    pub label: String,
    /// Free-form description text. Presumably it is shown to whoever selects the tool —
    /// confirm.
    pub description: String,
    /// Schema for the tool's input. [`validate_input`] enforces only its `"required"` list
    /// (plus the requirement that the input be an object).
    pub parameters: ParameterSchema,
}
/// An executable tool.
///
/// The `Send + Sync` supertraits let implementations be shared across threads (e.g.
/// behind an `Arc<dyn Tool>`). `#[async_trait]` rewrites `execute` to return a boxed
/// `Send` future, which keeps the trait usable as a trait object.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Returns this tool's definition (name, label, description, parameter schema).
    ///
    /// The definition is returned by value, so each call may allocate.
    fn definition(&self) -> ToolDefinition;
    /// Runs the tool with the given per-call context and JSON input.
    ///
    /// NOTE(review): nothing here calls [`validate_input`]. Presumably callers validate
    /// `input` before invoking — confirm.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports.
    async fn execute(&self, ctx: InvokeContext, input: Value) -> Result<ToolResult>;
}
/// Performs a minimal check of `input` against the parameter schema in `def`.
///
/// Two rules are enforced, in this order:
/// 1. `input` must be a JSON object.
/// 2. Every string in the schema's top-level `"required"` array must appear as a key of
///    that object. A key whose value is `null` still counts as present.
///
/// Non-string entries in `"required"` are skipped. A schema whose `"required"` entry is
/// absent or is not an array imposes no key constraints. No other JSON Schema keywords
/// are checked.
///
/// # Errors
///
/// Returns [`CoreError::ToolInputValidation`] for the first violation found: either a
/// non-object input, or the first missing key in the order the schema lists them.
pub fn validate_input(def: &ToolDefinition, input: &Value) -> Result<()> {
    let obj = input.as_object().ok_or_else(|| {
        CoreError::ToolInputValidation(format!(
            "tool `{}` expects an object input",
            def.name
        ))
    })?;

    // Lazily walk the string entries of `required` and stop at the first one the
    // object does not contain.
    let first_missing = def
        .parameters
        .fields
        .get("required")
        .and_then(Value::as_array)
        .into_iter()
        .flatten()
        .filter_map(Value::as_str)
        .find(|candidate| !obj.contains_key(*candidate));

    match first_missing {
        None => Ok(()),
        Some(key) => Err(CoreError::ToolInputValidation(format!(
            "tool `{}` missing required parameter `{key}`",
            def.name
        ))),
    }
}