Skip to main content

ferro_ai/tools/
mod.rs

//! Tool calling: `ToolDef`, `ToolError`, `ToolRegistry`, and the bounded dispatch loop.
//!
//! ## Safety contract (D-12, SC#5)
//!
//! [`ToolRegistry::new`] is the ONLY full constructor, and its `max_iterations` argument is
//! required — there is no `Default` impl and no zero-arg `new()`.
//! [`ToolRegistry::with_default_iterations`] is only a shorthand for `new(10)`. The dispatch
//! loop returns [`Error::ToolIterationLimit`] at the hard cap with no override path.
//!
//! ## Error surfacing (D-13, SC#6)
//!
//! Tool handler failures are surfaced to the LLM as [`ToolError`] messages, never as
//! raw Rust panics, stack traces, or DB-constraint strings; handlers are responsible for
//! keeping `ToolError::message` free of such text. A handler that panics is not caught —
//! the panic propagates out of [`ToolRegistry::dispatch`]. The Rust caller receives
//! [`Error::ToolIterationLimit`] when the loop reaches its cap without a text response.
//!
//! ## Handler lifetime (D-11)
//!
//! Handler closures must satisfy `'static` — all captured state must be owned or
//! `Arc`-wrapped. Capturing non-`'static` references (e.g. `&T` borrowed from the caller)
//! will not compile.
19
20use crate::client::{
21    CompletionRequest, CompletionResponse, LlmClient, Message, Role, ToolChoice, ToolRequest,
22    ToolUseBlock,
23};
24use crate::error::Error;
25use futures::future::BoxFuture;
26use std::collections::HashMap;
27use tracing::{error, warn};
28
/// Model-legible tool error.
///
/// Surfaced to the LLM as a `tool_result` message carrying only `message`.
/// Never exposed to Rust callers as a panic or raw DB string (SC#6, T-166-02).
///
/// Handler implementations are responsible for mapping domain errors to a
/// human-readable `message` before returning `Err(ToolError { ... })` (or
/// `Err(ToolError::new(...))`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolError {
    /// The model-legible error message. Must not contain raw Rust panics,
    /// stack traces, or DB-constraint strings.
    pub message: String,
}

impl ToolError {
    /// Build a `ToolError` from any string-like, model-legible message.
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            message: message.into(),
        }
    }
}

impl std::fmt::Display for ToolError {
    /// Writes exactly `message` and nothing else — this is the model-facing form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

/// Lets a `ToolError` be boxed as `dyn std::error::Error` or propagated with `?`
/// into generic error types. It wraps no underlying cause, so `source()` is `None`.
impl std::error::Error for ToolError {}
48
49/// A registered tool with its async handler.
50///
51/// `parameters_schema` must already be normalized via `schema::for_structured_output`
52/// before registration. The handler must own all captured state (no `&references` —
53/// wrap shared state in `Arc<T>` to satisfy the `'static` bound).
54pub struct ToolDef {
55    /// The tool name. Must match what the LLM will call.
56    pub name: String,
57    /// Human-readable description of what the tool does.
58    pub description: String,
59    /// JSON Schema for the tool's input parameters.
60    ///
61    /// Must be pre-normalized via `schema::for_structured_output`. The LLM-generated
62    /// input is passed as-is to the handler — handler implementations are responsible
63    /// for validating their own inputs before privileged actions (T-166-03).
64    pub parameters_schema: serde_json::Value,
65    /// The async handler closure.
66    ///
67    /// Receives the LLM-generated `serde_json::Value` and returns either a JSON result
68    /// or a [`ToolError`] with a model-legible message.
69    pub handler: Box<
70        dyn Fn(serde_json::Value) -> BoxFuture<'static, Result<serde_json::Value, ToolError>>
71            + Send
72            + Sync,
73    >,
74}
75
76/// Helper to wrap an `async fn` or closure into the boxed handler type required by [`ToolDef`].
77///
78/// # Example
79///
80/// ```rust,ignore
81/// use ferro_ai::tools::{make_handler, ToolDef, ToolError};
82///
83/// let def = ToolDef {
84///     name: "greet".into(),
85///     description: "Greet a user by name".into(),
86///     parameters_schema: serde_json::json!({"type":"object","properties":{"name":{"type":"string"}},"required":["name"]}),
87///     handler: make_handler(|input| async move {
88///         let name = input["name"].as_str().unwrap_or("world");
89///         Ok(serde_json::json!({"greeting": format!("Hello, {name}!")}))
90///     }),
91/// };
92/// ```
93pub fn make_handler<F, Fut>(
94    f: F,
95) -> Box<
96    dyn Fn(serde_json::Value) -> BoxFuture<'static, Result<serde_json::Value, ToolError>>
97        + Send
98        + Sync,
99>
100where
101    F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
102    Fut: std::future::Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
103{
104    Box::new(move |input| Box::pin(f(input)))
105}
106
107/// Registry of named tools for the LLM dispatch loop.
108///
109/// ## Construction
110///
111/// `max_iterations` is required at construction — there is no zero-arg constructor
112/// and no way to create an unbounded loop (SC#5, D-12). Suggested default: 10.
113///
114/// ```rust,ignore
115/// let registry = ToolRegistry::new(10);
116/// // or equivalently:
117/// let registry = ToolRegistry::with_default_iterations();
118/// ```
119///
120/// ## Dispatch
121///
122/// [`ToolRegistry::dispatch`] loops until the LLM returns a text response or the
123/// iteration cap is reached. At iteration 5 a warning is logged; at the cap an error
124/// is logged and [`Error::ToolIterationLimit`] is returned.
125pub struct ToolRegistry {
126    tools: HashMap<String, ToolDef>,
127    max_iterations: u32,
128}
129
130impl ToolRegistry {
131    /// Create a new registry with an explicit iteration cap.
132    ///
133    /// There is no `Default` impl and no zero-arg `new()`. Every `ToolRegistry`
134    /// must carry an explicit `max_iterations` to prevent unbounded loops (SC#5).
135    pub fn new(max_iterations: u32) -> Self {
136        Self {
137            tools: HashMap::new(),
138            max_iterations,
139        }
140    }
141
142    /// Convenience constructor with `max_iterations = 10`.
143    pub fn with_default_iterations() -> Self {
144        Self::new(10)
145    }
146
147    /// Register a tool definition.
148    ///
149    /// If a tool with the same name is already registered, it is replaced.
150    pub fn register(&mut self, tool: ToolDef) {
151        self.tools.insert(tool.name.clone(), tool);
152    }
153
154    /// Build a `CompletionRequest` for one dispatch iteration.
155    fn build_request(&self, messages: Vec<Message>) -> CompletionRequest {
156        let tool_requests: Vec<ToolRequest> = self
157            .tools
158            .values()
159            .map(|t| ToolRequest {
160                name: t.name.clone(),
161                description: t.description.clone(),
162                parameters_schema: t.parameters_schema.clone(),
163            })
164            .collect();
165
166        CompletionRequest {
167            system: None,
168            messages,
169            max_tokens: 4096,
170            model_override: None,
171            schema: None,
172            tools: if tool_requests.is_empty() {
173                None
174            } else {
175                Some(tool_requests)
176            },
177            tool_choice: Some(ToolChoice::Auto),
178        }
179    }
180
181    /// Convert a tool handler result into a `Message` to send back to the LLM.
182    ///
183    /// On `Ok(value)` → JSON-serialized result.
184    /// On `Err(ToolError { message })` → the model-legible message (SC#6).
185    ///
186    /// The `block_id` is stored in `tool_call_id` so each provider's `build_body`
187    /// can place it in the correct wire location without string encoding/decoding:
188    /// - Anthropic: `tool_use_id` inside a `tool_result` content block.
189    /// - OpenAI: top-level `tool_call_id` field on the `role: "tool"` message.
190    fn result_to_message(block_id: &str, result: Result<serde_json::Value, ToolError>) -> Message {
191        let content = match result {
192            Ok(value) => value.to_string(),
193            Err(te) => te.message,
194        };
195        Message {
196            role: Role::Tool,
197            content,
198            tool_call_id: Some(block_id.to_string()),
199        }
200    }
201
202    /// Dispatch a tool-calling conversation loop.
203    ///
204    /// Calls `client.complete_with_tools` repeatedly until the LLM returns a text
205    /// response or `max_iterations` is reached. Each `ToolUse` response dispatches
206    /// registered handlers and appends results before the next iteration.
207    ///
208    /// ## Iteration limits (SC#5, T-166-01)
209    ///
210    /// - At iteration 5: `tracing::warn!` (advisory — loop still continues).
211    /// - At `max_iterations`: `tracing::error!` + `Err(Error::ToolIterationLimit)`.
212    ///   This is a hard cap with no override path.
213    ///
214    /// ## Error surfacing (SC#6, T-166-02)
215    ///
216    /// Handler `Err(ToolError { message })` is sent to the LLM as a tool_result
217    /// message carrying only `message`. Unknown tool names are also surfaced to the
218    /// LLM as model-recoverable error strings (not `Error::ToolNotFound`) so the
219    /// model can adapt its tool selection.
220    pub async fn dispatch(
221        &self,
222        mut messages: Vec<Message>,
223        client: &dyn LlmClient,
224    ) -> Result<Vec<Message>, Error> {
225        for iteration in 0..=self.max_iterations {
226            // WR-02: warn fires before the cap check so it is reachable when max_iterations > 5.
227            if iteration == 5 && self.max_iterations > 5 {
228                warn!(
229                    iteration,
230                    max = self.max_iterations,
231                    "tool dispatch at iteration 5"
232                );
233            }
234            if iteration == self.max_iterations {
235                error!(
236                    max_iterations = self.max_iterations,
237                    "tool dispatch hit iteration limit"
238                );
239                return Err(Error::ToolIterationLimit(self.max_iterations));
240            }
241
242            let request = self.build_request(messages.clone());
243            let response = client.complete_with_tools(request).await?;
244
245            match response {
246                CompletionResponse::Text(text) => {
247                    messages.push(Message {
248                        role: Role::Assistant,
249                        content: text,
250                        tool_call_id: None,
251                    });
252                    return Ok(messages);
253                }
254                // CR-02: push the assistant tool-use turn BEFORE the tool result messages.
255                // Both Anthropic and OpenAI require alternating roles with the assistant's
256                // tool_use/tool_calls block present before the corresponding tool_result.
257                CompletionResponse::ToolUse {
258                    blocks,
259                    assistant_content,
260                } => {
261                    messages.push(Message {
262                        role: Role::Assistant,
263                        content: assistant_content,
264                        tool_call_id: None,
265                    });
266                    for block in &blocks {
267                        let result = self.call_tool(block).await;
268                        messages.push(Self::result_to_message(&block.id, result));
269                    }
270                }
271            }
272        }
273        unreachable!()
274    }
275
276    /// Call the handler for one tool-use block.
277    ///
278    /// Unknown tool names are surfaced to the LLM as a model-recoverable error string
279    /// rather than aborting the dispatch loop — the model can select a different tool.
280    async fn call_tool(&self, block: &ToolUseBlock) -> Result<serde_json::Value, ToolError> {
281        match self.tools.get(&block.name) {
282            None => Err(ToolError {
283                message: format!("tool '{}' is not registered", block.name),
284            }),
285            Some(tool) => (tool.handler)(block.input.clone()).await,
286        }
287    }
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::{CompletionRequest, TokenStream};
    use async_trait::async_trait;
    use std::sync::{
        atomic::{AtomicU32, Ordering},
        Arc,
    };

    // ─── SC#4: ToolDef construction ──────────────────────────────────────────

    /// SC#4: ToolDef carries name, description, parameters_schema, and async handler.
    #[tokio::test]
    async fn tool_def_construction() {
        let schema = serde_json::json!({"type": "object", "properties": {"x": {"type": "string"}}});
        let def = ToolDef {
            name: "my_tool".into(),
            description: "does a thing".into(),
            parameters_schema: schema.clone(),
            handler: make_handler(
                |_input| async move { Ok(serde_json::json!({"result": "done"})) },
            ),
        };
        assert_eq!(def.name, "my_tool");
        assert_eq!(def.description, "does a thing");
        assert_eq!(def.parameters_schema, schema);
        // Handler must be callable and return Ok.
        let result = (def.handler)(serde_json::json!({})).await;
        assert!(result.is_ok());
    }

    // ─── SC#6: ToolError is model-legible ───────────────────────────────────

    /// SC#6: ToolError Display returns exactly the message, nothing else.
    #[test]
    fn tool_error_is_model_legible() {
        let err = ToolError {
            message: "domain message".into(),
        };
        assert_eq!(format!("{err}"), "domain message");
        // Debug output contains the struct name and field, but Display is the
        // model-facing representation — assert Display == message only.
        let debug_str = format!("{err:?}");
        assert!(debug_str.contains("domain message"));
    }

    // ─── No unbounded path ───────────────────────────────────────────────────

    /// Documents that ToolRegistry::new(n) works and with_default_iterations works.
    /// The absence of Default and a zero-arg new() is enforced by the compiler —
    /// this test documents the expected construction API.
    #[test]
    fn tool_registry_requires_max_iterations() {
        let r1 = ToolRegistry::new(3);
        assert_eq!(r1.max_iterations, 3);
        let r2 = ToolRegistry::with_default_iterations();
        assert_eq!(r2.max_iterations, 10);
    }

    // ─── Dispatch loop tests ─────────────────────────────────────────────────

    /// Mock LlmClient that returns ToolUse for `stop_after` calls then returns Text.
    struct LoopingClient {
        calls: Arc<AtomicU32>,
        stop_after: u32,
        tool_name: String,
    }

    #[async_trait]
    impl LlmClient for LoopingClient {
        fn default_model(&self) -> &str {
            "test"
        }

        async fn complete(&self, _: CompletionRequest) -> Result<String, Error> {
            Err(Error::Unsupported)
        }

        async fn complete_stream(&self, _: CompletionRequest) -> Result<TokenStream, Error> {
            Err(Error::Unsupported)
        }

        async fn embed(&self, _: &str) -> Result<Vec<f32>, Error> {
            Err(Error::Unsupported)
        }

        async fn complete_with_tools(
            &self,
            _: CompletionRequest,
        ) -> Result<CompletionResponse, Error> {
            let n = self.calls.fetch_add(1, Ordering::SeqCst);
            if n >= self.stop_after {
                Ok(CompletionResponse::Text("done".into()))
            } else {
                Ok(CompletionResponse::ToolUse {
                    blocks: vec![ToolUseBlock {
                        id: format!("call_{n}"),
                        name: self.tool_name.clone(),
                        input: serde_json::json!({}),
                    }],
                    assistant_content: format!(
                        r#"[{{"type":"tool_use","id":"call_{n}","name":"{}","input":{{}}}}]"#,
                        self.tool_name
                    ),
                })
            }
        }
    }

    /// Build a `LoopingClient` and hand back a shared handle to its call counter so
    /// tests can assert how many LLM calls `dispatch` actually made.
    fn looping_client(stop_after: u32, tool_name: &str) -> (LoopingClient, Arc<AtomicU32>) {
        let calls = Arc::new(AtomicU32::new(0));
        let client = LoopingClient {
            calls: Arc::clone(&calls),
            stop_after,
            tool_name: tool_name.into(),
        };
        (client, calls)
    }

    /// SC#5: dispatch returns Err(ToolIterationLimit) at the hard cap, after exactly
    /// `max_iterations` LLM calls.
    #[tokio::test]
    async fn tool_registry_enforces_max_iterations() {
        let registry = ToolRegistry::new(3);
        let (client, calls) = looping_client(99, "no_op"); // never stops on its own
        let result = registry.dispatch(vec![], &client).await;
        assert!(
            matches!(result, Err(Error::ToolIterationLimit(3))),
            "expected ToolIterationLimit(3), got {result:?}"
        );
        assert_eq!(
            calls.load(Ordering::SeqCst),
            3,
            "the cap must bound the number of LLM calls"
        );
    }

    /// SC#5 boundary: a cap of zero fails immediately without calling the LLM.
    #[tokio::test]
    async fn zero_max_iterations_fails_without_calling_llm() {
        let registry = ToolRegistry::new(0);
        let (client, calls) = looping_client(0, "no_op");
        let result = registry.dispatch(vec![], &client).await;
        assert!(
            matches!(result, Err(Error::ToolIterationLimit(0))),
            "expected ToolIterationLimit(0), got {result:?}"
        );
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    /// WR-02: the iteration-5 warning is advisory — dispatch keeps going when the cap
    /// allows more calls, and still returns the eventual text response.
    #[tokio::test]
    async fn dispatch_continues_past_warning_iteration() {
        let registry = ToolRegistry::new(10);
        let (client, calls) = looping_client(7, "no_op");
        let messages = registry
            .dispatch(vec![], &client)
            .await
            .expect("text arrives before the cap");
        assert_eq!(calls.load(Ordering::SeqCst), 8);
        // 7 tool-use rounds (assistant turn + tool result each) plus the final text.
        assert_eq!(messages.len(), 15);
        let last = messages.last().expect("history is non-empty");
        assert!(matches!(last.role, Role::Assistant));
        assert_eq!(last.content, "done");
    }

    /// dispatch returns Ok when the client returns Text on the first call.
    #[tokio::test]
    async fn dispatch_returns_on_text() {
        let registry = ToolRegistry::new(5);
        let (client, calls) = looping_client(0, "no_op"); // returns Text immediately
        let messages = registry
            .dispatch(vec![], &client)
            .await
            .expect("dispatch must succeed on an immediate text response");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(
            messages
                .iter()
                .any(|m| matches!(m.role, Role::Assistant) && m.content == "done"),
            "expected assistant message with 'done'"
        );
    }

    /// SC#6: a handler returning ToolError surfaces only its message to the LLM.
    ///
    /// The dispatch loop must complete (not abort) when a registered handler fails,
    /// and the tool_result message must carry the model-legible ToolError message,
    /// not a raw panic or Rust debug string.
    #[tokio::test]
    async fn dispatch_surfaces_tool_error() {
        let mut registry = ToolRegistry::new(5);

        // Register a tool that always fails with a model-legible message.
        registry.register(ToolDef {
            name: "failing_tool".into(),
            description: "always fails".into(),
            parameters_schema: serde_json::json!({}),
            handler: make_handler(|_| async move {
                Err(ToolError {
                    message: "order not found".into(),
                })
            }),
        });

        // Client: first call returns ToolUse for failing_tool, second returns Text.
        let (client, _calls) = looping_client(1, "failing_tool");

        let result = registry.dispatch(vec![], &client).await;
        assert!(
            result.is_ok(),
            "dispatch must complete even after tool error"
        );

        let messages = result.unwrap();
        // There must be a Role::Tool message carrying the model-legible error.
        let tool_result = messages
            .iter()
            .find(|m| matches!(m.role, Role::Tool))
            .expect("expected a Role::Tool result message");
        let content = &tool_result.content;
        assert!(
            content.contains("order not found"),
            "ToolError message must appear in tool result, got: {content}"
        );
        // Must NOT contain raw Rust panic text or debug noise.
        assert!(
            !content.contains("panicked at"),
            "tool result must not contain panic text"
        );
    }

    /// CR-02 regression: the dispatch loop must push the assistant tool-use turn into
    /// history BEFORE the tool result messages. Providers require alternating roles.
    #[tokio::test]
    async fn dispatch_includes_assistant_turn_before_tool_results() {
        let mut registry = ToolRegistry::new(5);

        registry.register(ToolDef {
            name: "echo".into(),
            description: "echoes input".into(),
            parameters_schema: serde_json::json!({}),
            handler: make_handler(|_| async move { Ok(serde_json::json!({"result": "ok"})) }),
        });

        // Client: one ToolUse call then Text.
        let (client, _calls) = looping_client(1, "echo");

        let messages = registry.dispatch(vec![], &client).await.unwrap();

        // Find positions of the assistant tool-use turn and the tool result turn.
        let assistant_pos = messages
            .iter()
            .position(|m| matches!(m.role, Role::Assistant) && m.content.contains("tool_use"))
            .expect("must have an assistant turn with tool_use content");
        let tool_result_pos = messages
            .iter()
            .position(|m| matches!(m.role, Role::Tool))
            .expect("must have a tool result message");

        assert!(
            assistant_pos < tool_result_pos,
            "assistant tool-use turn (pos {assistant_pos}) must precede tool result (pos {tool_result_pos})"
        );

        // The tool result must carry a real tool_call_id, not embedded in content.
        let tool_msg = &messages[tool_result_pos];
        assert!(
            tool_msg.tool_call_id.is_some(),
            "tool result message must carry tool_call_id"
        );
        assert!(
            !tool_msg.content.contains("call_"),
            "tool_call_id must not be embedded in content string, got: {}",
            tool_msg.content
        );
    }

    /// WR-03: call_tool returns a ToolError (not Error::ToolNotFound) for unknown tool names,
    /// so the dispatch loop can surface it to the LLM as a recoverable message.
    /// Error::ToolNotFound is reserved as a public API variant for future direct-dispatch helpers.
    #[tokio::test]
    async fn dispatch_surfaces_unknown_tool_as_tool_error() {
        // Registry with no registered tools.
        let registry = ToolRegistry::new(5);

        // Client returns one ToolUse for an unregistered tool, then Text.
        let (client, _calls) = looping_client(1, "nonexistent_tool");

        let result = registry.dispatch(vec![], &client).await;
        // Dispatch must complete (not abort with ToolNotFound) — unknown tool is LLM-recoverable.
        assert!(
            result.is_ok(),
            "dispatch must not abort for unknown tool; got {result:?}"
        );
        let messages = result.unwrap();
        let tool_msg = messages
            .iter()
            .find(|m| matches!(m.role, Role::Tool))
            .expect("must have a tool result message for the unknown tool");
        assert!(
            tool_msg.content.contains("not registered"),
            "unknown tool error must surface to LLM as a message, got: {}",
            tool_msg.content
        );
    }
}