1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::collections::HashMap;
4use std::path::PathBuf;
5use std::sync::{Arc, Mutex};
6
7#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
10#[serde(rename_all = "snake_case")]
11#[derive(Default)]
12pub enum ToolCaller {
13 #[default]
15 Direct,
16 CodeExecution,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize, Default)]
22pub struct Tool {
23 #[serde(default)]
25 pub name: String,
26 #[serde(default)]
28 pub description: String,
29 #[serde(default)]
31 pub input_schema: ToolInputSchema,
32 #[serde(default)]
34 pub requires_approval: bool,
35 #[serde(default)]
37 pub defer_loading: bool,
38 #[serde(default, skip_serializing_if = "Vec::is_empty")]
40 pub allowed_callers: Vec<ToolCaller>,
41 #[serde(default, skip_serializing_if = "Vec::is_empty")]
43 pub input_examples: Vec<Value>,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct ToolInputSchema {
49 #[serde(rename = "type", default = "default_schema_type")]
51 pub schema_type: String,
52 #[serde(skip_serializing_if = "Option::is_none")]
54 pub properties: Option<HashMap<String, Value>>,
55 #[serde(skip_serializing_if = "Option::is_none")]
57 pub required: Option<Vec<String>>,
58}
59
60fn default_schema_type() -> String {
61 "object".to_string()
62}
63
64impl Default for ToolInputSchema {
65 fn default() -> Self {
66 Self {
67 schema_type: "object".to_string(),
68 properties: None,
69 required: None,
70 }
71 }
72}
73
74impl ToolInputSchema {
75 pub fn object(properties: HashMap<String, Value>, required: Vec<String>) -> Self {
77 Self {
78 schema_type: "object".to_string(),
79 properties: Some(properties),
80 required: Some(required),
81 }
82 }
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct ToolUse {
88 pub id: String,
90 pub name: String,
92 pub input: Value,
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct ToolResult {
99 pub tool_use_id: String,
101 pub content: String,
103 #[serde(default)]
105 pub is_error: bool,
106}
107
108impl ToolResult {
109 pub fn success<S: Into<String>>(tool_use_id: S, content: S) -> Self {
111 Self {
112 tool_use_id: tool_use_id.into(),
113 content: content.into(),
114 is_error: false,
115 }
116 }
117
118 pub fn error<S: Into<String>>(tool_use_id: S, error: S) -> Self {
120 Self {
121 tool_use_id: tool_use_id.into(),
122 content: error.into(),
123 is_error: true,
124 }
125 }
126}
127
/// Cached outcome stored in an `IdempotencyRegistry` under an idempotency key.
#[derive(Debug, Clone)]
pub struct IdempotencyRecord {
    /// Time the record was first stored. `IdempotencyRegistry::record` fills it with
    /// `chrono::Utc::now().timestamp()`.
    pub executed_at: i64,
    /// Result string captured the first time the key was recorded.
    pub cached_result: String,
}
138
139#[derive(Debug, Clone, Default)]
146pub struct IdempotencyRegistry(Arc<Mutex<HashMap<String, IdempotencyRecord>>>);
147
148impl IdempotencyRegistry {
149 pub fn new() -> Self {
151 Self::default()
152 }
153
154 pub fn get(&self, key: &str) -> Option<IdempotencyRecord> {
156 self.0
157 .lock()
158 .expect("idempotency registry lock poisoned")
159 .get(key)
160 .cloned()
161 }
162
163 pub fn record(&self, key: String, result: String) {
167 let mut map = self.0.lock().expect("idempotency registry lock poisoned");
168 map.entry(key).or_insert_with(|| {
169 use chrono::Utc;
170 IdempotencyRecord {
171 executed_at: Utc::now().timestamp(),
172 cached_result: result,
173 }
174 });
175 }
176
177 pub fn len(&self) -> usize {
179 self.0
180 .lock()
181 .expect("idempotency registry lock poisoned")
182 .len()
183 }
184
185 pub fn is_empty(&self) -> bool {
187 self.len() == 0
188 }
189}
190
/// A write handed to a `StagingBackend`: `content` destined for `target_path`.
#[derive(Debug, Clone)]
pub struct StagedWrite {
    /// Identifier for this staged write.
    /// NOTE(review): uniqueness/dedup semantics are defined by the backend — confirm.
    pub key: String,
    /// Path the content is meant to be written to.
    pub target_path: PathBuf,
    /// Full content to write.
    pub content: String,
}
203
/// Summary returned by `StagingBackend::commit`.
#[derive(Debug, Clone)]
pub struct CommitResult {
    /// Number of writes the backend reports as committed.
    pub committed: usize,
    /// Paths the backend reports as written by the commit.
    pub paths: Vec<PathBuf>,
}
212
213pub trait StagingBackend: std::fmt::Debug + Send + Sync {
221 fn stage(&self, write: StagedWrite) -> bool;
226
227 fn commit(&self) -> anyhow::Result<CommitResult>;
232
233 fn rollback(&self);
235
236 fn pending_count(&self) -> usize;
238}
239
240#[derive(Debug, Clone)]
247pub struct ToolContext {
248 pub working_directory: String,
250 pub user_id: Option<String>,
252 pub metadata: HashMap<String, String>,
254 pub capabilities: Option<Value>,
260 pub idempotency_registry: Option<IdempotencyRegistry>,
267 pub staging_backend: Option<Arc<dyn StagingBackend>>,
278}
279
280impl ToolContext {
281 pub fn with_idempotency_registry(mut self) -> Self {
283 self.idempotency_registry = Some(IdempotencyRegistry::new());
284 self
285 }
286
287 pub fn with_staging_backend(mut self, backend: Arc<dyn StagingBackend>) -> Self {
289 self.staging_backend = Some(backend);
290 self
291 }
292}
293
294impl Default for ToolContext {
295 fn default() -> Self {
296 Self {
297 working_directory: std::env::current_dir()
298 .ok()
299 .and_then(|p| p.to_str().map(|s| s.to_string()))
300 .unwrap_or_else(|| ".".to_string()),
301 user_id: None,
302 metadata: HashMap::new(),
303 capabilities: None,
304 idempotency_registry: None,
305 staging_backend: None,
306 }
307 }
308}
309
310#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
312pub enum ToolMode {
313 Full,
315 Explicit(Vec<String>),
317 #[default]
319 Smart,
320 Core,
322 None,
324}
325
326impl ToolMode {
327 pub fn display_name(&self) -> &'static str {
329 match self {
330 ToolMode::Full => "full",
331 ToolMode::Explicit(_) => "explicit",
332 ToolMode::Smart => "smart",
333 ToolMode::Core => "core",
334 ToolMode::None => "none",
335 }
336 }
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_tool_result_success() {
        let result = ToolResult::success("tool-1", "Success!");
        assert!(!result.is_error);
        assert_eq!(result.tool_use_id, "tool-1");
        assert_eq!(result.content, "Success!");
    }

    #[test]
    fn test_tool_result_error() {
        let result = ToolResult::error("tool-2", "Failed!");
        assert!(result.is_error);
        assert_eq!(result.tool_use_id, "tool-2");
        assert_eq!(result.content, "Failed!");
    }

    #[test]
    fn test_tool_result_is_error_defaults_to_false() {
        let result: ToolResult =
            serde_json::from_value(json!({"tool_use_id": "t", "content": "c"})).unwrap();
        assert!(!result.is_error);
    }

    #[test]
    fn test_tool_input_schema_object() {
        let mut props = HashMap::new();
        props.insert("name".to_string(), json!({"type": "string"}));
        let schema = ToolInputSchema::object(props, vec!["name".to_string()]);
        assert_eq!(schema.schema_type, "object");
        assert!(schema.properties.is_some());
        assert_eq!(schema.required, Some(vec!["name".to_string()]));
    }

    #[test]
    fn test_tool_input_schema_default_serde() {
        // `None` fields are skipped and `type` is the renamed `schema_type`.
        let value = serde_json::to_value(ToolInputSchema::default()).unwrap();
        assert_eq!(value, json!({"type": "object"}));

        // A missing `type` falls back to "object".
        let parsed: ToolInputSchema = serde_json::from_value(json!({})).unwrap();
        assert_eq!(parsed.schema_type, "object");
        assert!(parsed.properties.is_none());
        assert!(parsed.required.is_none());
    }

    #[test]
    fn test_tool_deserializes_from_empty_object_and_skips_empty_lists() {
        let tool: Tool = serde_json::from_value(json!({})).unwrap();
        assert!(tool.name.is_empty());
        assert_eq!(tool.input_schema.schema_type, "object");
        assert!(!tool.requires_approval);
        assert!(!tool.defer_loading);
        assert!(tool.allowed_callers.is_empty());

        let value = serde_json::to_value(&tool).unwrap();
        assert!(value.get("allowed_callers").is_none());
        assert!(value.get("input_examples").is_none());
    }

    #[test]
    fn test_tool_caller_default_and_snake_case() {
        assert_eq!(ToolCaller::default(), ToolCaller::Direct);
        assert_eq!(
            serde_json::to_value(ToolCaller::CodeExecution).unwrap(),
            json!("code_execution")
        );
        let parsed: ToolCaller = serde_json::from_value(json!("direct")).unwrap();
        assert_eq!(parsed, ToolCaller::Direct);
    }

    #[test]
    fn test_tool_mode_default_and_display_name() {
        assert_eq!(ToolMode::default(), ToolMode::Smart);
        assert_eq!(ToolMode::Full.display_name(), "full");
        assert_eq!(ToolMode::Explicit(vec!["a".to_string()]).display_name(), "explicit");
        assert_eq!(ToolMode::Smart.display_name(), "smart");
        assert_eq!(ToolMode::Core.display_name(), "core");
        assert_eq!(ToolMode::None.display_name(), "none");
    }

    #[test]
    fn test_tool_mode_serde_round_trip() {
        let mode = ToolMode::Explicit(vec!["read".to_string()]);
        let value = serde_json::to_value(&mode).unwrap();
        let back: ToolMode = serde_json::from_value(value).unwrap();
        assert_eq!(back, mode);
    }

    #[test]
    fn test_idempotency_registry_basic() {
        let registry = IdempotencyRegistry::new();
        assert!(registry.is_empty());

        registry.record("key-1".to_string(), "result-1".to_string());
        assert_eq!(registry.len(), 1);

        let record = registry.get("key-1").unwrap();
        assert_eq!(record.cached_result, "result-1");
        assert!(record.executed_at > 0);

        // First write wins: a second record under the same key is ignored.
        registry.record("key-1".to_string(), "result-DIFFERENT".to_string());
        assert_eq!(registry.get("key-1").unwrap().cached_result, "result-1");
        assert_eq!(registry.len(), 1);
        assert!(registry.get("missing").is_none());
    }

    #[test]
    fn test_idempotency_registry_clone_shares_state() {
        let registry = IdempotencyRegistry::new();
        let clone = registry.clone();

        registry.record("k".to_string(), "v".to_string());
        assert!(clone.get("k").is_some());
        assert_eq!(clone.len(), 1);
    }

    #[test]
    fn test_tool_context_default_has_no_registry() {
        let ctx = ToolContext::default();
        assert!(ctx.idempotency_registry.is_none());
        assert!(ctx.staging_backend.is_none());
        assert!(!ctx.working_directory.is_empty());
    }

    #[test]
    fn test_tool_context_with_registry() {
        let ctx = ToolContext::default().with_idempotency_registry();
        assert!(ctx.idempotency_registry.is_some());
        assert!(ctx.idempotency_registry.unwrap().is_empty());
    }

    /// Minimal in-memory backend used to exercise `with_staging_backend`.
    #[derive(Debug, Default)]
    struct MemoryBackend(Mutex<Vec<StagedWrite>>);

    impl StagingBackend for MemoryBackend {
        fn stage(&self, write: StagedWrite) -> bool {
            self.0.lock().unwrap().push(write);
            true
        }

        fn commit(&self) -> anyhow::Result<CommitResult> {
            let writes = std::mem::take(&mut *self.0.lock().unwrap());
            Ok(CommitResult {
                committed: writes.len(),
                paths: writes.into_iter().map(|w| w.target_path).collect(),
            })
        }

        fn rollback(&self) {
            self.0.lock().unwrap().clear();
        }

        fn pending_count(&self) -> usize {
            self.0.lock().unwrap().len()
        }
    }

    #[test]
    fn test_tool_context_with_staging_backend() {
        let backend = Arc::new(MemoryBackend::default());
        let ctx = ToolContext::default().with_staging_backend(Arc::clone(&backend));
        let staging = ctx.staging_backend.expect("backend should be attached");

        assert!(staging.stage(StagedWrite {
            key: "k".to_string(),
            target_path: PathBuf::from("a.txt"),
            content: "x".to_string(),
        }));
        // The context holds the same backend instance the caller passed in.
        assert_eq!(backend.pending_count(), 1);

        let result = staging.commit().unwrap();
        assert_eq!(result.committed, 1);
        assert_eq!(result.paths, vec![PathBuf::from("a.txt")]);
        assert_eq!(staging.pending_count(), 0);
    }
}
405}