use super::types::{ErrorHandler, ResponseFormat, ToolInput, ToolOutput};
use crate::error::{Result, CognisError};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
/// Serializable description of a tool.
///
/// Holds the tool's name and description, an optional parameter schema, and optional
/// extra fields. The two optional fields fall back to `None` when missing during
/// deserialization. They are left out of the serialized output when they are `None`.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct ToolSchema {
    /// Tool name.
    pub name: String,
    /// Human-readable tool description.
    pub description: String,
    /// Optional parameter schema as raw JSON.
    /// NOTE(review): presumably a JSON Schema object; confirm with producers.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parameters: Option<Value>,
    /// Optional free-form extra fields, keyed by name.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub extras: Option<HashMap<String, Value>>,
}
/// A group of tools that are handed out together.
///
/// Implementors must be `Send + Sync`.
pub trait BaseToolkit
where
    Self: Send + Sync,
{
    /// Returns every tool in this toolkit as a boxed `BaseTool` trait object.
    fn get_tools(&self) -> Vec<Box<dyn BaseTool>>;
}
#[async_trait]
pub trait BaseTool: Send + Sync {
fn name(&self) -> &str;
fn description(&self) -> &str;
fn args_schema(&self) -> Option<Value> {
None
}
fn tool_call_schema(&self) -> Value {
self.args_schema()
.unwrap_or(Value::Object(Default::default()))
}
fn return_direct(&self) -> bool {
false
}
fn handle_tool_error(&self) -> &ErrorHandler {
&ErrorHandler::Propagate
}
fn handle_validation_error(&self) -> &ErrorHandler {
&ErrorHandler::Propagate
}
fn response_format(&self) -> ResponseFormat {
ResponseFormat::Content
}
fn tags(&self) -> &[String] {
&[]
}
fn metadata(&self) -> Option<&HashMap<String, Value>> {
None
}
fn extras(&self) -> Option<&HashMap<String, Value>> {
None
}
async fn _run(&self, input: ToolInput) -> Result<ToolOutput>;
async fn run(&self, input: ToolInput, _tool_call_id: Option<&str>) -> Result<Value> {
match self._run(input).await {
Ok(output) => {
let content = match output {
ToolOutput::Content(v) => v,
ToolOutput::ContentAndArtifact { content, .. } => content,
};
Ok(content)
}
Err(CognisError::ToolException(msg)) => match self.handle_tool_error() {
ErrorHandler::Propagate => Err(CognisError::ToolException(msg)),
ErrorHandler::DefaultMessage => Ok(Value::String(msg)),
ErrorHandler::StaticMessage(s) => Ok(Value::String(s.clone())),
ErrorHandler::Dynamic(f) => Ok(Value::String(f(&msg))),
},
Err(CognisError::ToolValidationError(msg)) => match self.handle_validation_error() {
ErrorHandler::Propagate => Err(CognisError::ToolValidationError(msg)),
ErrorHandler::DefaultMessage => Ok(Value::String(msg)),
ErrorHandler::StaticMessage(s) => Ok(Value::String(s.clone())),
ErrorHandler::Dynamic(f) => Ok(Value::String(f(&msg))),
},
Err(e) => Err(e),
}
}
async fn run_str(&self, input: &str) -> Result<Value> {
self.run(ToolInput::Text(input.to_string()), None).await
}
async fn run_json(&self, input: &Value) -> Result<Value> {
let map: HashMap<String, Value> = match input {
Value::Object(m) => m.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
Value::String(s) => return self.run(ToolInput::Text(s.clone()), None).await,
_ => return self.run(ToolInput::Text(input.to_string()), None).await,
};
self.run(ToolInput::Structured(map), None).await
}
}