Skip to main content

cognis_llm/tools/
mod.rs

1//! Tool trait + ergonomic tiers + supporting types.
2
3pub mod schema_based;
4pub mod simple;
5pub mod types;
6pub mod validation;
7
8pub use schema_based::SchemaBasedTool;
9pub use simple::__simple_async_trait;
10pub use types::{ToolInput, ToolOutput};
11pub use validation::{Format, ValidateArgs};
12
13use std::collections::HashMap;
14use std::sync::Arc;
15
16use async_trait::async_trait;
17use serde::{Deserialize, Serialize};
18
19use cognis_core::{CognisError, Result};
20
21/// Tier-1 tool trait. The most general contract — manual JSON schema,
22/// `serde_json::Value` arg deserialization is the tool's responsibility.
23///
24/// `BaseTool` is a type alias for callers (especially cognis-macros
25/// generated code) that prefer the v1 name.
26#[async_trait]
27pub trait Tool: Send + Sync {
28    /// Tool name as registered with the LLM.
29    fn name(&self) -> &str;
30
31    /// Description shown to the LLM.
32    fn description(&self) -> &str;
33
34    /// Optional JSON Schema for the parameters. None = no parameters.
35    fn args_schema(&self) -> Option<serde_json::Value>;
36
37    /// Hint to the agent: if true, return the tool result directly
38    /// instead of looping back to the LLM.
39    fn return_direct(&self) -> bool {
40        false
41    }
42
43    /// Execute the tool with the given input.
44    async fn _run(&self, input: ToolInput) -> Result<ToolOutput>;
45}
46
47/// Alias for cognis-macros-generated code that emits paths to `BaseTool`.
48pub use Tool as BaseTool;
49
50/// Serializable form of a tool — what gets sent to the LLM API.
51#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
52pub struct ToolDefinition {
53    /// Tool name.
54    pub name: String,
55    /// Description.
56    pub description: String,
57    /// JSON Schema for parameters (None if the tool takes no params).
58    pub parameters: Option<serde_json::Value>,
59}
60
61impl ToolDefinition {
62    /// Build a `ToolDefinition` from any `&dyn Tool`.
63    pub fn from_tool(t: &dyn Tool) -> Self {
64        Self {
65            name: t.name().to_string(),
66            description: t.description().to_string(),
67            parameters: t.args_schema(),
68        }
69    }
70}
71
72/// Per-tool runtime state — enabled/disabled flag + call counter.
73#[derive(Default)]
74struct ToolEntry {
75    tool: Option<Arc<dyn Tool>>,
76    enabled: bool,
77    /// Total successful + failed calls served via [`ToolRegistry::execute`].
78    calls: std::sync::atomic::AtomicUsize,
79    /// Optional permission predicate: `agent_id → allowed?`.
80    #[allow(clippy::type_complexity)]
81    permission: Option<Arc<dyn Fn(&str) -> bool + Send + Sync>>,
82}
83
84impl Clone for ToolEntry {
85    fn clone(&self) -> Self {
86        Self {
87            tool: self.tool.clone(),
88            enabled: self.enabled,
89            calls: std::sync::atomic::AtomicUsize::new(
90                self.calls.load(std::sync::atomic::Ordering::Relaxed),
91            ),
92            permission: self.permission.clone(),
93        }
94    }
95}
96
97/// HashMap-backed tool registry. The agent layer uses this to dispatch
98/// tool calls returned by the LLM.
99///
100/// Per-tool controls (added in V2):
101/// - [`ToolRegistry::enable`] / [`ToolRegistry::disable`] toggle
102///   availability without unregistering.
103/// - [`ToolRegistry::set_permission`] attaches an `agent_id → allowed`
104///   predicate. [`ToolRegistry::is_allowed`] checks; [`ToolRegistry::execute_for`]
105///   enforces.
106/// - [`ToolRegistry::call_count`] returns the cumulative dispatch count
107///   for any tool (incremented on every `execute*` call).
108#[derive(Default, Clone)]
109pub struct ToolRegistry {
110    entries: HashMap<String, ToolEntry>,
111}
112
113impl ToolRegistry {
114    /// Empty registry.
115    pub fn new() -> Self {
116        Self::default()
117    }
118
119    /// Register a tool. Replaces any existing tool with the same name
120    /// (preserving its enabled flag and call counter? no — replaces
121    /// wholesale; the new tool starts enabled with count = 0).
122    pub fn register(&mut self, tool: Arc<dyn Tool>) {
123        let name = tool.name().to_string();
124        self.entries.insert(
125            name,
126            ToolEntry {
127                tool: Some(tool),
128                enabled: true,
129                calls: std::sync::atomic::AtomicUsize::new(0),
130                permission: None,
131            },
132        );
133    }
134
135    /// Register `alias` to point at the same tool as `name`. No-op when
136    /// `name` is not registered.
137    pub fn register_alias(&mut self, alias: impl Into<String>, name: &str) {
138        if let Some(t) = self.entries.get(name).and_then(|e| e.tool.clone()) {
139            self.entries.insert(
140                alias.into(),
141                ToolEntry {
142                    tool: Some(t),
143                    enabled: true,
144                    calls: std::sync::atomic::AtomicUsize::new(0),
145                    permission: None,
146                },
147            );
148        }
149    }
150
151    /// Remove a tool by name. Returns `true` if it was present.
152    pub fn unregister(&mut self, name: &str) -> bool {
153        self.entries.remove(name).is_some()
154    }
155
156    /// Filter the registry: keep only tools whose name passes `predicate`.
157    /// Returns the names of removed tools.
158    pub fn retain<F>(&mut self, mut predicate: F) -> Vec<String>
159    where
160        F: FnMut(&str) -> bool,
161    {
162        let mut removed = Vec::new();
163        self.entries.retain(|k, _| {
164            let keep = predicate(k);
165            if !keep {
166                removed.push(k.clone());
167            }
168            keep
169        });
170        removed
171    }
172
173    /// Get a tool by name. Returns the inner `Arc` only if the tool is
174    /// **enabled** — disabled tools are invisible to dispatch.
175    pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
176        let e = self.entries.get(name)?;
177        if !e.enabled {
178            return None;
179        }
180        e.tool.as_ref()
181    }
182
183    /// True if a tool with this name is registered (regardless of enabled state).
184    pub fn contains(&self, name: &str) -> bool {
185        self.entries.contains_key(name)
186    }
187
188    /// True if a tool with this name is registered AND enabled.
189    pub fn is_enabled(&self, name: &str) -> bool {
190        self.entries.get(name).is_some_and(|e| e.enabled)
191    }
192
193    /// Disable a tool without unregistering it. Disabled tools are
194    /// hidden from `get` / `tool_names` / `definitions` / `execute`,
195    /// but their call counters and permissions persist for when they're
196    /// re-enabled. No-op if the tool isn't registered.
197    pub fn disable(&mut self, name: &str) -> bool {
198        match self.entries.get_mut(name) {
199            Some(e) => {
200                e.enabled = false;
201                true
202            }
203            None => false,
204        }
205    }
206
207    /// Re-enable a previously-disabled tool. No-op if not registered.
208    pub fn enable(&mut self, name: &str) -> bool {
209        match self.entries.get_mut(name) {
210            Some(e) => {
211                e.enabled = true;
212                true
213            }
214            None => false,
215        }
216    }
217
218    /// Attach a permission predicate `agent_id → allowed`. Replace by
219    /// calling again. Pass [`ToolRegistry::clear_permission`] to remove.
220    pub fn set_permission<F>(&mut self, name: &str, predicate: F) -> bool
221    where
222        F: Fn(&str) -> bool + Send + Sync + 'static,
223    {
224        match self.entries.get_mut(name) {
225            Some(e) => {
226                e.permission = Some(Arc::new(predicate));
227                true
228            }
229            None => false,
230        }
231    }
232
233    /// Drop the permission predicate. The tool reverts to "any agent allowed".
234    pub fn clear_permission(&mut self, name: &str) {
235        if let Some(e) = self.entries.get_mut(name) {
236            e.permission = None;
237        }
238    }
239
240    /// True if the tool exists, is enabled, and `agent_id` passes the
241    /// permission predicate (or no predicate is set). Disabled tools
242    /// always return false.
243    pub fn is_allowed(&self, name: &str, agent_id: &str) -> bool {
244        let Some(e) = self.entries.get(name) else {
245            return false;
246        };
247        if !e.enabled {
248            return false;
249        }
250        match &e.permission {
251            Some(p) => p(agent_id),
252            None => true,
253        }
254    }
255
256    /// Number of times the tool was dispatched via `execute` /
257    /// `execute_for` (across all agents, including failed dispatches).
258    /// Returns 0 if the tool isn't registered.
259    pub fn call_count(&self, name: &str) -> usize {
260        self.entries
261            .get(name)
262            .map(|e| e.calls.load(std::sync::atomic::Ordering::Relaxed))
263            .unwrap_or(0)
264    }
265
266    /// All registered + enabled tool names.
267    pub fn tool_names(&self) -> Vec<&str> {
268        self.entries
269            .iter()
270            .filter(|(_, e)| e.enabled)
271            .map(|(k, _)| k.as_str())
272            .collect()
273    }
274
275    /// Build `ToolDefinition`s for every registered + enabled tool.
276    pub fn definitions(&self) -> Vec<ToolDefinition> {
277        self.entries
278            .values()
279            .filter(|e| e.enabled)
280            .filter_map(|e| e.tool.as_ref())
281            .map(|t| ToolDefinition::from_tool(t.as_ref()))
282            .collect()
283    }
284
285    /// Execute a tool by name with the given input. Increments the
286    /// per-tool call counter regardless of success/failure. Errors if
287    /// the tool isn't registered or is disabled.
288    pub async fn execute(&self, name: &str, input: ToolInput) -> Result<ToolOutput> {
289        let entry = self.entries.get(name).ok_or_else(|| CognisError::Tool {
290            name: name.to_string(),
291            reason: "not registered".into(),
292        })?;
293        if !entry.enabled {
294            return Err(CognisError::Tool {
295                name: name.to_string(),
296                reason: "disabled".into(),
297            });
298        }
299        let t = entry.tool.as_ref().ok_or_else(|| CognisError::Tool {
300            name: name.to_string(),
301            reason: "no implementation".into(),
302        })?;
303        entry
304            .calls
305            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
306        t._run(input).await
307    }
308
309    /// Like [`ToolRegistry::execute`] but also enforces the permission
310    /// predicate against `agent_id`.
311    ///
312    /// Error precedence (so the message is honest about the real reason):
313    /// 1. tool not registered  → `"not registered"`
314    /// 2. tool disabled        → `"disabled"`
315    /// 3. agent not allowed    → `"not allowed for agent ..."`
316    /// 4. dispatch errors from the tool itself
317    pub async fn execute_for(
318        &self,
319        name: &str,
320        agent_id: &str,
321        input: ToolInput,
322    ) -> Result<ToolOutput> {
323        let entry = self.entries.get(name).ok_or_else(|| CognisError::Tool {
324            name: name.to_string(),
325            reason: "not registered".into(),
326        })?;
327        if !entry.enabled {
328            return Err(CognisError::Tool {
329                name: name.to_string(),
330                reason: "disabled".into(),
331            });
332        }
333        let allowed = entry
334            .permission
335            .as_ref()
336            .map(|p| p(agent_id))
337            .unwrap_or(true);
338        if !allowed {
339            return Err(CognisError::Tool {
340                name: name.to_string(),
341                reason: format!("not allowed for agent `{agent_id}`"),
342            });
343        }
344        self.execute(name, input).await
345    }
346
347    /// Number of registered tools (including disabled ones).
348    pub fn len(&self) -> usize {
349        self.entries.len()
350    }
351
352    /// True if no tools are registered.
353    pub fn is_empty(&self) -> bool {
354        self.entries.is_empty()
355    }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Test double: advertises a single `text` string parameter and returns
    /// whatever input it receives as JSON content.
    struct Echo;

    #[async_trait]
    impl Tool for Echo {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "echoes input"
        }

        fn args_schema(&self) -> Option<serde_json::Value> {
            Some(json!({
                "type": "object",
                "properties": { "text": { "type": "string" } }
            }))
        }

        async fn _run(&self, input: ToolInput) -> Result<ToolOutput> {
            let echoed = input.into_json();
            Ok(ToolOutput::Content(echoed))
        }
    }

    /// Registry pre-loaded with a single enabled `Echo`.
    fn echo_registry() -> ToolRegistry {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(Echo));
        registry
    }

    /// Shorthand for a plain-text tool input.
    fn text(s: &str) -> ToolInput {
        ToolInput::Text(s.into())
    }

    #[tokio::test]
    async fn registry_register_get_execute() {
        let mut registry = ToolRegistry::new();
        assert!(registry.is_empty());
        registry.register(Arc::new(Echo));
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("echo"));

        let mut args = HashMap::new();
        args.insert("text".into(), json!("hi"));
        let output = registry
            .execute("echo", ToolInput::Structured(args))
            .await
            .unwrap();
        let ToolOutput::Content(value) = output else {
            panic!("wrong variant");
        };
        assert_eq!(value["text"], "hi");
    }

    #[tokio::test]
    async fn unknown_tool_errors() {
        let registry = ToolRegistry::new();
        let err = registry.execute("missing", text("x")).await.unwrap_err();
        assert_eq!(err.category(), "tool");
    }

    #[test]
    fn definition_from_tool() {
        let def = ToolDefinition::from_tool(&Echo);
        assert_eq!(def.name, "echo");
        assert_eq!(def.description, "echoes input");
        assert!(def.parameters.is_some());
    }

    #[tokio::test]
    async fn disable_hides_from_dispatch_and_listing() {
        let mut registry = echo_registry();
        assert!(registry.disable("echo"));
        assert!(registry.contains("echo"), "still registered");
        assert!(!registry.is_enabled("echo"));
        assert!(registry.tool_names().is_empty());
        assert!(registry.definitions().is_empty());

        let err = registry.execute("echo", text("x")).await.unwrap_err();
        assert!(err.to_string().contains("disabled"), "got: {err}");
    }

    #[tokio::test]
    async fn enable_restores() {
        let mut registry = echo_registry();
        registry.disable("echo");
        registry.enable("echo");
        assert!(registry.is_enabled("echo"));
        assert!(registry.execute("echo", text("x")).await.is_ok());
    }

    #[tokio::test]
    async fn call_count_increments_on_execute() {
        let registry = echo_registry();
        assert_eq!(registry.call_count("echo"), 0);
        for _ in 0..3 {
            registry.execute("echo", text("hi")).await.unwrap();
        }
        assert_eq!(registry.call_count("echo"), 3);
        assert_eq!(registry.call_count("missing"), 0);
    }

    #[tokio::test]
    async fn permission_predicate_blocks_disallowed_agents() {
        let mut registry = echo_registry();
        registry.set_permission("echo", |agent_id: &str| agent_id == "writer");
        assert!(registry.is_allowed("echo", "writer"));
        assert!(!registry.is_allowed("echo", "intruder"));

        let allowed = registry.execute_for("echo", "writer", text("hi")).await;
        assert!(allowed.is_ok());

        let denied = registry
            .execute_for("echo", "intruder", text("hi"))
            .await
            .unwrap_err();
        assert!(denied.to_string().contains("not allowed"), "got: {denied}");
    }

    #[tokio::test]
    async fn execute_for_reports_not_registered_before_permission() {
        let registry = ToolRegistry::new();
        let err = registry
            .execute_for("ghost", "writer", text("x"))
            .await
            .unwrap_err();
        assert!(
            err.to_string().contains("not registered"),
            "wrong error: {err}"
        );
    }

    #[tokio::test]
    async fn execute_for_reports_disabled_before_permission() {
        let mut registry = echo_registry();
        registry.disable("echo");
        // Deny everyone — the disabled state must still be what gets reported.
        registry.set_permission("echo", |_| false);
        let err = registry
            .execute_for("echo", "writer", text("x"))
            .await
            .unwrap_err();
        assert!(err.to_string().contains("disabled"), "wrong error: {err}");
    }

    #[tokio::test]
    async fn clear_permission_reopens_dispatch() {
        let mut registry = echo_registry();
        registry.set_permission("echo", |_: &str| false);
        assert!(!registry.is_allowed("echo", "anyone"));
        registry.clear_permission("echo");
        assert!(registry.is_allowed("echo", "anyone"));
    }
}