1pub mod schema_based;
4pub mod simple;
5pub mod types;
6pub mod validation;
7
8pub use schema_based::SchemaBasedTool;
9pub use simple::__simple_async_trait;
10pub use types::{ToolInput, ToolOutput};
11pub use validation::{Format, ValidateArgs};
12
13use std::collections::HashMap;
14use std::sync::Arc;
15
16use async_trait::async_trait;
17use serde::{Deserialize, Serialize};
18
19use cognis_core::{CognisError, Result};
20
21#[async_trait]
27pub trait Tool: Send + Sync {
28 fn name(&self) -> &str;
30
31 fn description(&self) -> &str;
33
34 fn args_schema(&self) -> Option<serde_json::Value>;
36
37 fn return_direct(&self) -> bool {
40 false
41 }
42
43 async fn _run(&self, input: ToolInput) -> Result<ToolOutput>;
45}
46
47pub use Tool as BaseTool;
49
50#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
52pub struct ToolDefinition {
53 pub name: String,
55 pub description: String,
57 pub parameters: Option<serde_json::Value>,
59}
60
61impl ToolDefinition {
62 pub fn from_tool(t: &dyn Tool) -> Self {
64 Self {
65 name: t.name().to_string(),
66 description: t.description().to_string(),
67 parameters: t.args_schema(),
68 }
69 }
70}
71
72#[derive(Default)]
74struct ToolEntry {
75 tool: Option<Arc<dyn Tool>>,
76 enabled: bool,
77 calls: std::sync::atomic::AtomicUsize,
79 #[allow(clippy::type_complexity)]
81 permission: Option<Arc<dyn Fn(&str) -> bool + Send + Sync>>,
82}
83
84impl Clone for ToolEntry {
85 fn clone(&self) -> Self {
86 Self {
87 tool: self.tool.clone(),
88 enabled: self.enabled,
89 calls: std::sync::atomic::AtomicUsize::new(
90 self.calls.load(std::sync::atomic::Ordering::Relaxed),
91 ),
92 permission: self.permission.clone(),
93 }
94 }
95}
96
97#[derive(Default, Clone)]
109pub struct ToolRegistry {
110 entries: HashMap<String, ToolEntry>,
111}
112
113impl ToolRegistry {
114 pub fn new() -> Self {
116 Self::default()
117 }
118
119 pub fn register(&mut self, tool: Arc<dyn Tool>) {
123 let name = tool.name().to_string();
124 self.entries.insert(
125 name,
126 ToolEntry {
127 tool: Some(tool),
128 enabled: true,
129 calls: std::sync::atomic::AtomicUsize::new(0),
130 permission: None,
131 },
132 );
133 }
134
135 pub fn register_alias(&mut self, alias: impl Into<String>, name: &str) {
138 if let Some(t) = self.entries.get(name).and_then(|e| e.tool.clone()) {
139 self.entries.insert(
140 alias.into(),
141 ToolEntry {
142 tool: Some(t),
143 enabled: true,
144 calls: std::sync::atomic::AtomicUsize::new(0),
145 permission: None,
146 },
147 );
148 }
149 }
150
151 pub fn unregister(&mut self, name: &str) -> bool {
153 self.entries.remove(name).is_some()
154 }
155
156 pub fn retain<F>(&mut self, mut predicate: F) -> Vec<String>
159 where
160 F: FnMut(&str) -> bool,
161 {
162 let mut removed = Vec::new();
163 self.entries.retain(|k, _| {
164 let keep = predicate(k);
165 if !keep {
166 removed.push(k.clone());
167 }
168 keep
169 });
170 removed
171 }
172
173 pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
176 let e = self.entries.get(name)?;
177 if !e.enabled {
178 return None;
179 }
180 e.tool.as_ref()
181 }
182
183 pub fn contains(&self, name: &str) -> bool {
185 self.entries.contains_key(name)
186 }
187
188 pub fn is_enabled(&self, name: &str) -> bool {
190 self.entries.get(name).is_some_and(|e| e.enabled)
191 }
192
193 pub fn disable(&mut self, name: &str) -> bool {
198 match self.entries.get_mut(name) {
199 Some(e) => {
200 e.enabled = false;
201 true
202 }
203 None => false,
204 }
205 }
206
207 pub fn enable(&mut self, name: &str) -> bool {
209 match self.entries.get_mut(name) {
210 Some(e) => {
211 e.enabled = true;
212 true
213 }
214 None => false,
215 }
216 }
217
218 pub fn set_permission<F>(&mut self, name: &str, predicate: F) -> bool
221 where
222 F: Fn(&str) -> bool + Send + Sync + 'static,
223 {
224 match self.entries.get_mut(name) {
225 Some(e) => {
226 e.permission = Some(Arc::new(predicate));
227 true
228 }
229 None => false,
230 }
231 }
232
233 pub fn clear_permission(&mut self, name: &str) {
235 if let Some(e) = self.entries.get_mut(name) {
236 e.permission = None;
237 }
238 }
239
240 pub fn is_allowed(&self, name: &str, agent_id: &str) -> bool {
244 let Some(e) = self.entries.get(name) else {
245 return false;
246 };
247 if !e.enabled {
248 return false;
249 }
250 match &e.permission {
251 Some(p) => p(agent_id),
252 None => true,
253 }
254 }
255
256 pub fn call_count(&self, name: &str) -> usize {
260 self.entries
261 .get(name)
262 .map(|e| e.calls.load(std::sync::atomic::Ordering::Relaxed))
263 .unwrap_or(0)
264 }
265
266 pub fn tool_names(&self) -> Vec<&str> {
268 self.entries
269 .iter()
270 .filter(|(_, e)| e.enabled)
271 .map(|(k, _)| k.as_str())
272 .collect()
273 }
274
275 pub fn definitions(&self) -> Vec<ToolDefinition> {
277 self.entries
278 .values()
279 .filter(|e| e.enabled)
280 .filter_map(|e| e.tool.as_ref())
281 .map(|t| ToolDefinition::from_tool(t.as_ref()))
282 .collect()
283 }
284
285 pub async fn execute(&self, name: &str, input: ToolInput) -> Result<ToolOutput> {
289 let entry = self.entries.get(name).ok_or_else(|| CognisError::Tool {
290 name: name.to_string(),
291 reason: "not registered".into(),
292 })?;
293 if !entry.enabled {
294 return Err(CognisError::Tool {
295 name: name.to_string(),
296 reason: "disabled".into(),
297 });
298 }
299 let t = entry.tool.as_ref().ok_or_else(|| CognisError::Tool {
300 name: name.to_string(),
301 reason: "no implementation".into(),
302 })?;
303 entry
304 .calls
305 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
306 t._run(input).await
307 }
308
309 pub async fn execute_for(
318 &self,
319 name: &str,
320 agent_id: &str,
321 input: ToolInput,
322 ) -> Result<ToolOutput> {
323 let entry = self.entries.get(name).ok_or_else(|| CognisError::Tool {
324 name: name.to_string(),
325 reason: "not registered".into(),
326 })?;
327 if !entry.enabled {
328 return Err(CognisError::Tool {
329 name: name.to_string(),
330 reason: "disabled".into(),
331 });
332 }
333 let allowed = entry
334 .permission
335 .as_ref()
336 .map(|p| p(agent_id))
337 .unwrap_or(true);
338 if !allowed {
339 return Err(CognisError::Tool {
340 name: name.to_string(),
341 reason: format!("not allowed for agent `{agent_id}`"),
342 });
343 }
344 self.execute(name, input).await
345 }
346
347 pub fn len(&self) -> usize {
349 self.entries.len()
350 }
351
352 pub fn is_empty(&self) -> bool {
354 self.entries.is_empty()
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Minimal tool that returns its input, converted to JSON, as content.
    struct Echo;

    #[async_trait]
    impl Tool for Echo {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "echoes input"
        }

        fn args_schema(&self) -> Option<serde_json::Value> {
            Some(json!({"type": "object", "properties": {"text": {"type": "string"}}}))
        }

        async fn _run(&self, input: ToolInput) -> Result<ToolOutput> {
            Ok(ToolOutput::Content(input.into_json()))
        }
    }

    /// A registry holding a single, enabled `Echo` tool.
    fn echo_registry() -> ToolRegistry {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(Echo));
        registry
    }

    /// Shorthand for a plain-text tool input.
    fn text(s: &'static str) -> ToolInput {
        ToolInput::Text(s.into())
    }

    #[tokio::test]
    async fn registry_register_get_execute() {
        let mut registry = ToolRegistry::new();
        assert!(registry.is_empty());
        registry.register(Arc::new(Echo));
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("echo"));

        let mut args = HashMap::new();
        args.insert("text".into(), json!("hi"));
        let output = registry
            .execute("echo", ToolInput::Structured(args))
            .await
            .unwrap();
        let ToolOutput::Content(value) = output else {
            panic!("wrong variant");
        };
        assert_eq!(value["text"], "hi");
    }

    #[tokio::test]
    async fn unknown_tool_errors() {
        let registry = ToolRegistry::new();
        let error = registry.execute("missing", text("x")).await.unwrap_err();
        assert_eq!(error.category(), "tool");
    }

    #[test]
    fn definition_from_tool() {
        let def = ToolDefinition::from_tool(&Echo);
        assert_eq!(def.name, "echo");
        assert_eq!(def.description, "echoes input");
        assert!(def.parameters.is_some());
    }

    #[tokio::test]
    async fn disable_hides_from_dispatch_and_listing() {
        let mut registry = echo_registry();
        assert!(registry.disable("echo"));
        assert!(registry.contains("echo"), "still registered");
        assert!(!registry.is_enabled("echo"));
        assert!(registry.tool_names().is_empty());
        assert!(registry.definitions().is_empty());

        let error = registry.execute("echo", text("x")).await.unwrap_err();
        assert!(error.to_string().contains("disabled"), "got: {error}");
    }

    #[tokio::test]
    async fn enable_restores() {
        let mut registry = echo_registry();
        registry.disable("echo");
        registry.enable("echo");
        assert!(registry.is_enabled("echo"));

        let outcome = registry.execute("echo", text("x")).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn call_count_increments_on_execute() {
        let registry = echo_registry();
        assert_eq!(registry.call_count("echo"), 0);
        for _ in 0..3 {
            registry.execute("echo", text("hi")).await.unwrap();
        }
        assert_eq!(registry.call_count("echo"), 3);
        assert_eq!(registry.call_count("missing"), 0);
    }

    #[tokio::test]
    async fn permission_predicate_blocks_disallowed_agents() {
        let mut registry = echo_registry();
        registry.set_permission("echo", |agent_id: &str| agent_id == "writer");
        assert!(registry.is_allowed("echo", "writer"));
        assert!(!registry.is_allowed("echo", "intruder"));

        let granted = registry.execute_for("echo", "writer", text("hi")).await;
        assert!(granted.is_ok());

        let denied = registry
            .execute_for("echo", "intruder", text("hi"))
            .await
            .unwrap_err();
        assert!(denied.to_string().contains("not allowed"), "got: {denied}");
    }

    #[tokio::test]
    async fn execute_for_reports_not_registered_before_permission() {
        let registry = ToolRegistry::new();
        let error = registry
            .execute_for("ghost", "writer", text("x"))
            .await
            .unwrap_err();
        assert!(
            error.to_string().contains("not registered"),
            "wrong error: {error}"
        );
    }

    #[tokio::test]
    async fn execute_for_reports_disabled_before_permission() {
        let mut registry = echo_registry();
        registry.disable("echo");
        registry.set_permission("echo", |_| false);

        let error = registry
            .execute_for("echo", "writer", text("x"))
            .await
            .unwrap_err();
        assert!(error.to_string().contains("disabled"), "wrong error: {error}");
    }

    #[tokio::test]
    async fn clear_permission_reopens_dispatch() {
        let mut registry = echo_registry();
        registry.set_permission("echo", |_: &str| false);
        assert!(!registry.is_allowed("echo", "anyone"));

        registry.clear_permission("echo");
        assert!(registry.is_allowed("echo", "anyone"));
    }
}