Skip to main content

crabtalk_core/agent/
tool.rs

//! Tool registry, dispatcher trait, and handler types.
//!
//! [`ToolRegistry`] stores `crabllm_core::Tool` schemas by name — no
//! handlers, no closures. [`ToolDispatcher`] is the trait agents call to
//! execute a tool call; [`ToolHandler`] is the per-tool async closure
//! type stored in a [`ToolEntry`].
7
8use crate::model::HistoryEntry;
9use crabllm_core::{FunctionDef, Tool, ToolType};
10use heck::ToSnakeCase;
11use schemars::JsonSchema;
12use std::{collections::BTreeMap, future::Future, pin::Pin, sync::Arc};
13
/// Boxed, `Send` future returned by a [`ToolDispatcher::dispatch`] call.
///
/// Resolves to `Ok(output)` on success or `Err(message)` on failure; both
/// sides are plain strings. The lifetime `'a` lets the future borrow from
/// the arguments it was created from.
pub type ToolFuture<'a> = Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>>;

/// Dynamic tool dispatch surface.
///
/// The Agent holds an `Arc<dyn ToolDispatcher>` and calls `dispatch` for
/// every tool call the model emits. Implementors look the tool up by
/// name, enforce scope, and invoke the registered handler.
pub trait ToolDispatcher: Send + Sync + 'static {
    /// Execute the tool called `name`.
    ///
    /// * `name` — function name of the requested tool.
    /// * `args`, `agent`, `sender`, `conversation_id` — presumably mirror
    ///   the same-named fields of [`ToolDispatch`] (JSON arguments string,
    ///   calling agent, sender identity, optional conversation ID).
    ///
    /// The returned future may borrow any of the string arguments and
    /// `self` for `'a`.
    fn dispatch<'a>(
        &'a self,
        name: &'a str,
        args: &'a str,
        agent: &'a str,
        sender: &'a str,
        conversation_id: Option<u64>,
    ) -> ToolFuture<'a>;
}

/// Arguments passed to a tool handler during dispatch.
///
/// Owned (not borrowed) so a [`ToolHandler`] can move it into a
/// `'static` future.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ToolDispatch {
    /// JSON-encoded arguments string.
    pub args: String,
    /// Name of the agent making this call.
    pub agent: String,
    /// Sender identity (empty for local/owner conversations).
    pub sender: String,
    /// Conversation ID, if running within a conversation.
    pub conversation_id: Option<u64>,
}

/// A type-erased async tool handler.
///
/// Takes an owned [`ToolDispatch`] and returns a [`ToolFuture`] with a
/// `'static` lifetime — the same boxed future type as
/// `Pin<Box<dyn Future<Output = Result<String, String>> + Send>>`, so the
/// future cannot borrow from the caller.
pub type ToolHandler = Arc<dyn Fn(ToolDispatch) -> ToolFuture<'static> + Send + Sync>;
52
53/// Callback invoked before each agent run to inject context entries.
54pub type BeforeRunHook = Arc<dyn Fn(&[HistoryEntry]) -> Vec<HistoryEntry> + Send + Sync>;
55
56/// A registered tool: schema + handler + optional lifecycle hooks.
57pub struct ToolEntry {
58    /// Tool schema for the LLM.
59    pub schema: Tool,
60    /// Dispatch handler.
61    pub handler: ToolHandler,
62    /// Appended to agent system prompt at build time.
63    pub system_prompt: Option<String>,
64    /// Injected before each agent turn (auto-recall, context, etc).
65    pub before_run: Option<BeforeRunHook>,
66}
67
68/// Schema-only registry of named tools.
69///
70/// Stores `crabllm_core::Tool` definitions keyed by function name. Used by
71/// `Runtime` to filter tool schemas per agent at `add_agent` time. No
72/// handlers or closures are stored here.
73#[derive(Default, Clone)]
74pub struct ToolRegistry {
75    tools: BTreeMap<String, Tool>,
76}
77
78impl ToolRegistry {
79    /// Create an empty registry.
80    pub fn new() -> Self {
81        Self::default()
82    }
83
84    /// Insert a tool schema, keyed by its function name.
85    pub fn insert(&mut self, tool: Tool) {
86        self.tools.insert(tool.function.name.clone(), tool);
87    }
88
89    /// Insert multiple tool schemas.
90    pub fn insert_all(&mut self, tools: Vec<Tool>) {
91        for tool in tools {
92            self.insert(tool);
93        }
94    }
95
96    /// Remove a tool by name. Returns `true` if it existed.
97    pub fn remove(&mut self, name: &str) -> bool {
98        self.tools.remove(name).is_some()
99    }
100
101    /// Check if a tool is registered.
102    pub fn contains(&self, name: &str) -> bool {
103        self.tools.contains_key(name)
104    }
105
106    /// Number of registered tools.
107    pub fn len(&self) -> usize {
108        self.tools.len()
109    }
110
111    /// Whether the registry is empty.
112    pub fn is_empty(&self) -> bool {
113        self.tools.is_empty()
114    }
115
116    /// Return all tool schemas as a `Vec`.
117    pub fn tools(&self) -> Vec<Tool> {
118        self.tools.values().cloned().collect()
119    }
120
121    /// Build a filtered list of tool schemas matching the given names.
122    ///
123    /// If `names` is empty, all tools are returned. Used by `Runtime::add_agent`
124    /// to build the per-agent schema snapshot stored on `Agent`.
125    pub fn filtered_snapshot(&self, names: &[String]) -> Vec<Tool> {
126        if names.is_empty() {
127            return self.tools();
128        }
129        self.tools
130            .iter()
131            .filter(|(k, _)| names.iter().any(|n| n == *k))
132            .map(|(_, v)| v.clone())
133            .collect()
134    }
135}
136
137/// Trait to convert a type into a `crabllm_core::Tool`. The tool's
138/// description is read from the `///` doc comment on the struct —
139/// schemars puts it in the schema's top-level `description` field.
140pub trait AsTool {
141    /// Convert the type into a `crabllm_core::Tool` (the enveloped
142    /// `{kind, function}` wire shape).
143    fn as_tool() -> Tool;
144}
145
146impl<T: JsonSchema> AsTool for T {
147    fn as_tool() -> Tool {
148        let schema = schemars::schema_for!(T);
149        let description = schema
150            .get("description")
151            .and_then(|v| v.as_str())
152            .map(str::to_owned);
153        // `strict: None` matches the prior wire behavior: the wcore
154        // `Tool.strict: bool` field was set to `true` by every `AsTool` impl
155        // but silently dropped by the converter (old convert::to_ct_tool
156        // hard-coded `strict: None`). Turning on strict-mode validation
157        // here would be a behavior change masquerading as a refactor —
158        // leave any opt-in to a separate commit that validates every tool
159        // schema.
160        Tool {
161            kind: ToolType::Function,
162            function: FunctionDef {
163                name: T::schema_name().to_snake_case(),
164                description,
165                parameters: Some(serde_json::to_value(&schema).unwrap_or_default()),
166            },
167            strict: None,
168        }
169    }
170}
171
172impl ToolDispatcher for () {
173    fn dispatch<'a>(
174        &'a self,
175        name: &'a str,
176        _args: &'a str,
177        _agent: &'a str,
178        _sender: &'a str,
179        _conversation_id: Option<u64>,
180    ) -> ToolFuture<'a> {
181        Box::pin(async move { Err(format!("tool not registered: {name}")) })
182    }
183}