use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
/// Kind of caller a tool may be invoked by (see `Tool::allowed_callers`).
///
/// Variants are (de)serialized in `snake_case`: `"direct"` and `"code_execution"`.
/// The default is [`ToolCaller::Direct`].
#[derive(Clone, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolCaller {
    /// Serialized as `"direct"`; this is the `Default` variant.
    #[default]
    Direct,
    /// Serialized as `"code_execution"`.
    CodeExecution,
}
/// Definition of a tool: its name, description, input schema and invocation flags.
///
/// Deserialization is lenient: any missing field takes its `Default` value (the
/// container-level `#[serde(default)]` draws from `Tool::default()`, which is
/// derived field-by-field). `allowed_callers` and `input_examples` are left out of
/// serialized output when empty.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(default)]
pub struct Tool {
    /// Tool name.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// Schema of the tool's input; defaults to an `"object"` schema.
    pub input_schema: ToolInputSchema,
    /// Whether invoking the tool requires approval (`false` by default).
    pub requires_approval: bool,
    /// Whether loading of the tool is deferred (`false` by default).
    pub defer_loading: bool,
    /// Caller kinds permitted to invoke the tool; omitted from output when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub allowed_callers: Vec<ToolCaller>,
    /// Example inputs as raw JSON values; omitted from output when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub input_examples: Vec<Value>,
}
/// Schema describing a tool's input, serialized with JSON-Schema-style keys
/// (`type`, `properties`, `required`).
///
/// When `type` is missing on input it defaults to `"object"` (via
/// [`default_schema_type`]). `properties` and `required` are omitted from
/// serialized output when `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolInputSchema {
    /// Value of the `type` key.
    #[serde(rename = "type", default = "default_schema_type")]
    pub schema_type: String,
    /// Per-property schemas keyed by property name, as raw JSON values.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub properties: Option<HashMap<String, Value>>,
    /// Names of required properties.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub required: Option<Vec<String>>,
}
/// Serde field default for `ToolInputSchema::schema_type`: always `"object"`.
fn default_schema_type() -> String {
    String::from("object")
}
impl Default for ToolInputSchema {
fn default() -> Self {
Self {
schema_type: "object".to_string(),
properties: None,
required: None,
}
}
}
impl ToolInputSchema {
    /// Builds a schema of type `"object"` from property schemas and required names.
    ///
    /// Both collections are stored as `Some(..)` even when empty. Since the fields
    /// are skipped only when `None`, they are always present in serialized output.
    pub fn object(properties: HashMap<String, Value>, required: Vec<String>) -> Self {
        // Intentionally a literal rather than the serde default: this constructor
        // always means "object", whatever the default type may become.
        let schema_type = String::from("object");
        let properties = Some(properties);
        let required = Some(required);
        Self {
            schema_type,
            properties,
            required,
        }
    }
}
/// A tool-use record: the tool `name`, its raw JSON `input`, and an identifier `id`.
///
/// All three fields are required when deserializing.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolUse {
    /// Identifier of this tool use. NOTE(review): presumably matched by
    /// `ToolResult::tool_use_id`; confirm with callers.
    pub id: String,
    /// Name of the tool.
    pub name: String,
    /// Arguments as an arbitrary JSON value.
    pub input: Value,
}
/// Result text for a tool use, tagged with the `tool_use_id` it answers.
///
/// `is_error` is `false` when absent from deserialized input.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolResult {
    /// Identifier of the tool use this result belongs to.
    pub tool_use_id: String,
    /// Result payload, or the error message when `is_error` is set.
    pub content: String,
    /// Whether `content` describes a failure.
    #[serde(default)]
    pub is_error: bool,
}
impl ToolResult {
    /// Builds a successful result (`is_error == false`) for `tool_use_id`.
    ///
    /// Each argument is independently generic over `Into<String>`. Callers can
    /// therefore mix argument types, e.g. an owned `String` id with a `&str` literal.
    /// The previous single shared type parameter rejected such calls.
    pub fn success<I, C>(tool_use_id: I, content: C) -> Self
    where
        I: Into<String>,
        C: Into<String>,
    {
        Self {
            tool_use_id: tool_use_id.into(),
            content: content.into(),
            is_error: false,
        }
    }

    /// Builds a failed result (`is_error == true`) whose `content` is `error`.
    ///
    /// Like [`ToolResult::success`], both arguments may be of different
    /// `Into<String>` types.
    pub fn error<I, E>(tool_use_id: I, error: E) -> Self
    where
        I: Into<String>,
        E: Into<String>,
    {
        Self {
            tool_use_id: tool_use_id.into(),
            content: error.into(),
            is_error: true,
        }
    }
}
/// A cached outcome stored in an `IdempotencyRegistry`.
#[derive(Clone, Debug)]
pub struct IdempotencyRecord {
    /// Unix timestamp (seconds, UTC) taken when the record was first inserted.
    pub executed_at: i64,
    /// The result string stored for the key.
    pub cached_result: String,
}
/// Thread-safe map from idempotency key to its first recorded result.
///
/// The map sits behind `Arc<Mutex<..>>`, so clones share the same underlying
/// entries rather than copying them.
#[derive(Clone, Debug, Default)]
pub struct IdempotencyRegistry(Arc<Mutex<HashMap<String, IdempotencyRecord>>>);
impl IdempotencyRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn get(&self, key: &str) -> Option<IdempotencyRecord> {
self.0
.lock()
.expect("idempotency registry lock poisoned")
.get(key)
.cloned()
}
pub fn record(&self, key: String, result: String) {
let mut map = self.0.lock().expect("idempotency registry lock poisoned");
map.entry(key).or_insert_with(|| {
use chrono::Utc;
IdempotencyRecord {
executed_at: Utc::now().timestamp(),
cached_result: result,
}
});
}
pub fn len(&self) -> usize {
self.0
.lock()
.expect("idempotency registry lock poisoned")
.len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
/// A write handed to a `StagingBackend` via `stage`.
#[derive(Clone, Debug)]
pub struct StagedWrite {
    /// Key identifying this write. NOTE(review): uniqueness rules are up to the backend.
    pub key: String,
    /// Destination path for `content`.
    pub target_path: PathBuf,
    /// Text to be written.
    pub content: String,
}
/// Summary returned by `StagingBackend::commit`.
#[derive(Clone, Debug)]
pub struct CommitResult {
    /// Count reported by the backend as committed.
    pub committed: usize,
    /// Paths reported by the backend. NOTE(review): presumably those written; confirm.
    pub paths: Vec<PathBuf>,
}
/// Pluggable store for [`StagedWrite`]s, shareable across threads (`Send + Sync`).
///
/// NOTE(review): the semantics below are inferred from method names and
/// signatures only; confirm against the implementations.
pub trait StagingBackend: Send + Sync + std::fmt::Debug {
    /// Hands `write` to the backend. The meaning of the returned flag is
    /// implementation-defined (e.g. accepted vs. rejected); TODO confirm.
    fn stage(&self, write: StagedWrite) -> bool;

    /// Commits pending writes, returning a summary or an error.
    fn commit(&self) -> anyhow::Result<CommitResult>;

    /// Discards pending writes (presumably; confirm).
    fn rollback(&self);

    /// Number of writes the backend reports as pending.
    fn pending_count(&self) -> usize;
}
/// Context passed to tool executions.
///
/// In `ToolContext::default()` every optional facility is `None`. The
/// `with_idempotency_registry` / `with_staging_backend` builders install them.
#[derive(Clone, Debug)]
pub struct ToolContext {
    /// Working directory as a string. `Default` uses the process's current directory.
    pub working_directory: String,
    /// Identifier of the requesting user, if known.
    pub user_id: Option<String>,
    /// Free-form string key/value pairs.
    pub metadata: HashMap<String, String>,
    /// Arbitrary JSON capabilities value. NOTE(review): schema not defined in this file.
    pub capabilities: Option<Value>,
    /// Registry of first-recorded results by idempotency key, if enabled.
    pub idempotency_registry: Option<IdempotencyRegistry>,
    /// Shared staging backend, if any.
    pub staging_backend: Option<Arc<dyn StagingBackend>>,
}
impl ToolContext {
    /// Returns this context with a fresh, empty [`IdempotencyRegistry`] installed.
    /// Any previously installed registry is replaced.
    pub fn with_idempotency_registry(self) -> Self {
        Self {
            idempotency_registry: Some(IdempotencyRegistry::new()),
            ..self
        }
    }

    /// Returns this context using `backend` as its staging backend.
    /// Any previously installed backend is replaced.
    pub fn with_staging_backend(self, backend: Arc<dyn StagingBackend>) -> Self {
        Self {
            staging_backend: Some(backend),
            ..self
        }
    }
}
impl Default for ToolContext {
    /// A context rooted at the process's current directory with every optional field unset.
    ///
    /// `working_directory` falls back to `"."` in two cases: the current directory
    /// cannot be determined, or it is not valid UTF-8.
    fn default() -> Self {
        let working_directory = match std::env::current_dir() {
            Ok(dir) => dir
                .into_os_string()
                .into_string()
                .unwrap_or_else(|_| ".".to_string()),
            Err(_) => ".".to_string(),
        };
        Self {
            working_directory,
            user_id: None,
            metadata: HashMap::new(),
            capabilities: None,
            idempotency_registry: None,
            staging_backend: None,
        }
    }
}
/// Tool-selection mode; the default is [`ToolMode::Smart`].
///
/// There are no serde rename attributes, so serde's default externally-tagged
/// representation applies: unit variants serialize as their exact names (`"Full"`,
/// `"Smart"`, ...), and `Explicit` as `{"Explicit": [...]}`. For lowercase labels
/// see `display_name`. NOTE(review): the semantics of each mode are defined
/// outside this file.
#[derive(Clone, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
pub enum ToolMode {
    Full,
    /// Carries an explicit list of strings (presumably tool names; confirm).
    Explicit(Vec<String>),
    #[default]
    Smart,
    Core,
    None,
}
impl ToolMode {
pub fn display_name(&self) -> &'static str {
match self {
ToolMode::Full => "full",
ToolMode::Explicit(_) => "explicit",
ToolMode::Smart => "smart",
ToolMode::Core => "core",
ToolMode::None => "none",
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_tool_result_success() {
        let outcome = ToolResult::success("tool-1", "Success!");
        assert!(!outcome.is_error);
    }

    #[test]
    fn test_tool_result_error() {
        let outcome = ToolResult::error("tool-2", "Failed!");
        assert!(outcome.is_error);
    }

    #[test]
    fn test_tool_input_schema_object() {
        let props: HashMap<String, Value> =
            HashMap::from([("name".to_string(), json!({"type": "string"}))]);
        let schema = ToolInputSchema::object(props, vec!["name".to_string()]);
        assert_eq!(schema.schema_type, "object");
        assert!(schema.properties.is_some());
    }

    #[test]
    fn test_idempotency_registry_basic() {
        let registry = IdempotencyRegistry::new();
        assert!(registry.is_empty());

        registry.record("key-1".to_string(), "result-1".to_string());
        assert_eq!(registry.len(), 1);
        let first = registry.get("key-1").expect("key-1 should be recorded");
        assert_eq!(first.cached_result, "result-1");
        assert!(first.executed_at > 0);

        // Recording the same key again must keep the original result.
        registry.record("key-1".to_string(), "result-DIFFERENT".to_string());
        let again = registry.get("key-1").expect("key-1 should still be recorded");
        assert_eq!(again.cached_result, "result-1");
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_idempotency_registry_clone_shares_state() {
        let original = IdempotencyRegistry::new();
        let alias = original.clone();
        original.record("k".to_string(), "v".to_string());
        assert!(alias.get("k").is_some());
    }

    #[test]
    fn test_tool_context_default_has_no_registry() {
        assert!(ToolContext::default().idempotency_registry.is_none());
    }

    #[test]
    fn test_tool_context_with_registry() {
        let ctx = ToolContext::default().with_idempotency_registry();
        let Some(registry) = ctx.idempotency_registry else {
            panic!("registry should be installed");
        };
        assert!(registry.is_empty());
    }
}