ferro_ai/tools/mod.rs
1//! Tool calling: `ToolDef`, `ToolError`, `ToolRegistry`, and the bounded dispatch loop.
2//!
3//! ## Safety contract (D-12, SC#5)
4//!
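//! [`ToolRegistry::new`] is the only constructor that takes a cap, and `max_iterations`
//! is required — there is no `Default` impl and no zero-arg `new()`.
//! [`ToolRegistry::with_default_iterations`] is a named shorthand for `new(10)`, so every
//! registry still carries a finite cap. The dispatch loop returns
//! [`Error::ToolIterationLimit`] at the hard cap with no override path.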
8//!
9//! ## Error surfacing (D-13, SC#6)
10//!
11//! Tool handler failures are surfaced to the LLM as [`ToolError`] messages, never as
12//! raw Rust panics, stack traces, or DB-constraint strings. The Rust caller receives
13//! [`Error::ToolIterationLimit`] when the loop exceeds its cap.
14//!
15//! ## Handler lifetime (D-11)
16//!
17//! Handler closures must satisfy `'static` — all captured state must be owned or
18//! `Arc`-wrapped. Capturing `&references` will not compile.
19
20use crate::client::{
21 CompletionRequest, CompletionResponse, LlmClient, Message, Role, ToolChoice, ToolRequest,
22 ToolUseBlock,
23};
24use crate::error::Error;
25use futures::future::BoxFuture;
26use std::collections::HashMap;
27use tracing::{error, warn};
28
/// Model-legible tool error.
///
/// Surfaced to the LLM as a `tool_result` message carrying only `message`.
/// Never exposed to Rust callers as a panic or raw DB string (SC#6, T-166-02).
///
/// Handler implementations are responsible for mapping domain errors to a
/// human-readable `message` before returning `Err(ToolError { ... })`.
///
/// Implements [`std::error::Error`], so it can be boxed into
/// `Box<dyn std::error::Error>` or propagated with `?` from handler helpers
/// that use a generic error type.
#[derive(Debug, Clone)]
pub struct ToolError {
    /// The model-legible error message. Must not contain raw Rust panics,
    /// stack traces, or DB-constraint strings.
    pub message: String,
}

impl std::fmt::Display for ToolError {
    /// Writes exactly `message` and nothing else — this is the text the model sees.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

// No `source()`: the type deliberately carries only the sanitized message, never
// the underlying domain error.
impl std::error::Error for ToolError {}
48
49/// A registered tool with its async handler.
50///
51/// `parameters_schema` must already be normalized via `schema::for_structured_output`
52/// before registration. The handler must own all captured state (no `&references` —
53/// wrap shared state in `Arc<T>` to satisfy the `'static` bound).
54pub struct ToolDef {
55 /// The tool name. Must match what the LLM will call.
56 pub name: String,
57 /// Human-readable description of what the tool does.
58 pub description: String,
59 /// JSON Schema for the tool's input parameters.
60 ///
61 /// Must be pre-normalized via `schema::for_structured_output`. The LLM-generated
62 /// input is passed as-is to the handler — handler implementations are responsible
63 /// for validating their own inputs before privileged actions (T-166-03).
64 pub parameters_schema: serde_json::Value,
65 /// The async handler closure.
66 ///
67 /// Receives the LLM-generated `serde_json::Value` and returns either a JSON result
68 /// or a [`ToolError`] with a model-legible message.
69 pub handler: Box<
70 dyn Fn(serde_json::Value) -> BoxFuture<'static, Result<serde_json::Value, ToolError>>
71 + Send
72 + Sync,
73 >,
74}
75
76/// Helper to wrap an `async fn` or closure into the boxed handler type required by [`ToolDef`].
77///
78/// # Example
79///
80/// ```rust,ignore
81/// use ferro_ai::tools::{make_handler, ToolDef, ToolError};
82///
83/// let def = ToolDef {
84/// name: "greet".into(),
85/// description: "Greet a user by name".into(),
86/// parameters_schema: serde_json::json!({"type":"object","properties":{"name":{"type":"string"}},"required":["name"]}),
87/// handler: make_handler(|input| async move {
88/// let name = input["name"].as_str().unwrap_or("world");
89/// Ok(serde_json::json!({"greeting": format!("Hello, {name}!")}))
90/// }),
91/// };
92/// ```
93pub fn make_handler<F, Fut>(
94 f: F,
95) -> Box<
96 dyn Fn(serde_json::Value) -> BoxFuture<'static, Result<serde_json::Value, ToolError>>
97 + Send
98 + Sync,
99>
100where
101 F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
102 Fut: std::future::Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
103{
104 Box::new(move |input| Box::pin(f(input)))
105}
106
107/// Registry of named tools for the LLM dispatch loop.
108///
109/// ## Construction
110///
111/// `max_iterations` is required at construction — there is no zero-arg constructor
112/// and no way to create an unbounded loop (SC#5, D-12). Suggested default: 10.
113///
114/// ```rust,ignore
115/// let registry = ToolRegistry::new(10);
116/// // or equivalently:
117/// let registry = ToolRegistry::with_default_iterations();
118/// ```
119///
120/// ## Dispatch
121///
122/// [`ToolRegistry::dispatch`] loops until the LLM returns a text response or the
123/// iteration cap is reached. At iteration 5 a warning is logged; at the cap an error
124/// is logged and [`Error::ToolIterationLimit`] is returned.
125pub struct ToolRegistry {
126 tools: HashMap<String, ToolDef>,
127 max_iterations: u32,
128}
129
130impl ToolRegistry {
131 /// Create a new registry with an explicit iteration cap.
132 ///
133 /// There is no `Default` impl and no zero-arg `new()`. Every `ToolRegistry`
134 /// must carry an explicit `max_iterations` to prevent unbounded loops (SC#5).
135 pub fn new(max_iterations: u32) -> Self {
136 Self {
137 tools: HashMap::new(),
138 max_iterations,
139 }
140 }
141
142 /// Convenience constructor with `max_iterations = 10`.
143 pub fn with_default_iterations() -> Self {
144 Self::new(10)
145 }
146
147 /// Register a tool definition.
148 ///
149 /// If a tool with the same name is already registered, it is replaced.
150 pub fn register(&mut self, tool: ToolDef) {
151 self.tools.insert(tool.name.clone(), tool);
152 }
153
154 /// Build a `CompletionRequest` for one dispatch iteration.
155 fn build_request(&self, messages: Vec<Message>) -> CompletionRequest {
156 let tool_requests: Vec<ToolRequest> = self
157 .tools
158 .values()
159 .map(|t| ToolRequest {
160 name: t.name.clone(),
161 description: t.description.clone(),
162 parameters_schema: t.parameters_schema.clone(),
163 })
164 .collect();
165
166 CompletionRequest {
167 system: None,
168 messages,
169 max_tokens: 4096,
170 model_override: None,
171 schema: None,
172 tools: if tool_requests.is_empty() {
173 None
174 } else {
175 Some(tool_requests)
176 },
177 tool_choice: Some(ToolChoice::Auto),
178 }
179 }
180
181 /// Convert a tool handler result into a `Message` to send back to the LLM.
182 ///
183 /// On `Ok(value)` → JSON-serialized result.
184 /// On `Err(ToolError { message })` → the model-legible message (SC#6).
185 ///
186 /// The `block_id` is stored in `tool_call_id` so each provider's `build_body`
187 /// can place it in the correct wire location without string encoding/decoding:
188 /// - Anthropic: `tool_use_id` inside a `tool_result` content block.
189 /// - OpenAI: top-level `tool_call_id` field on the `role: "tool"` message.
190 fn result_to_message(block_id: &str, result: Result<serde_json::Value, ToolError>) -> Message {
191 let content = match result {
192 Ok(value) => value.to_string(),
193 Err(te) => te.message,
194 };
195 Message {
196 role: Role::Tool,
197 content,
198 tool_call_id: Some(block_id.to_string()),
199 }
200 }
201
202 /// Dispatch a tool-calling conversation loop.
203 ///
204 /// Calls `client.complete_with_tools` repeatedly until the LLM returns a text
205 /// response or `max_iterations` is reached. Each `ToolUse` response dispatches
206 /// registered handlers and appends results before the next iteration.
207 ///
208 /// ## Iteration limits (SC#5, T-166-01)
209 ///
210 /// - At iteration 5: `tracing::warn!` (advisory — loop still continues).
211 /// - At `max_iterations`: `tracing::error!` + `Err(Error::ToolIterationLimit)`.
212 /// This is a hard cap with no override path.
213 ///
214 /// ## Error surfacing (SC#6, T-166-02)
215 ///
216 /// Handler `Err(ToolError { message })` is sent to the LLM as a tool_result
217 /// message carrying only `message`. Unknown tool names are also surfaced to the
218 /// LLM as model-recoverable error strings (not `Error::ToolNotFound`) so the
219 /// model can adapt its tool selection.
220 pub async fn dispatch(
221 &self,
222 mut messages: Vec<Message>,
223 client: &dyn LlmClient,
224 ) -> Result<Vec<Message>, Error> {
225 for iteration in 0..=self.max_iterations {
226 // WR-02: warn fires before the cap check so it is reachable when max_iterations > 5.
227 if iteration == 5 && self.max_iterations > 5 {
228 warn!(
229 iteration,
230 max = self.max_iterations,
231 "tool dispatch at iteration 5"
232 );
233 }
234 if iteration == self.max_iterations {
235 error!(
236 max_iterations = self.max_iterations,
237 "tool dispatch hit iteration limit"
238 );
239 return Err(Error::ToolIterationLimit(self.max_iterations));
240 }
241
242 let request = self.build_request(messages.clone());
243 let response = client.complete_with_tools(request).await?;
244
245 match response {
246 CompletionResponse::Text(text) => {
247 messages.push(Message {
248 role: Role::Assistant,
249 content: text,
250 tool_call_id: None,
251 });
252 return Ok(messages);
253 }
254 // CR-02: push the assistant tool-use turn BEFORE the tool result messages.
255 // Both Anthropic and OpenAI require alternating roles with the assistant's
256 // tool_use/tool_calls block present before the corresponding tool_result.
257 CompletionResponse::ToolUse {
258 blocks,
259 assistant_content,
260 } => {
261 messages.push(Message {
262 role: Role::Assistant,
263 content: assistant_content,
264 tool_call_id: None,
265 });
266 for block in &blocks {
267 let result = self.call_tool(block).await;
268 messages.push(Self::result_to_message(&block.id, result));
269 }
270 }
271 }
272 }
273 unreachable!()
274 }
275
276 /// Call the handler for one tool-use block.
277 ///
278 /// Unknown tool names are surfaced to the LLM as a model-recoverable error string
279 /// rather than aborting the dispatch loop — the model can select a different tool.
280 async fn call_tool(&self, block: &ToolUseBlock) -> Result<serde_json::Value, ToolError> {
281 match self.tools.get(&block.name) {
282 None => Err(ToolError {
283 message: format!("tool '{}' is not registered", block.name),
284 }),
285 Some(tool) => (tool.handler)(block.input.clone()).await,
286 }
287 }
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::{CompletionRequest, TokenStream};
    use async_trait::async_trait;
    use std::sync::{
        atomic::{AtomicU32, Ordering},
        Arc,
    };

    // ─── SC#4: ToolDef construction ──────────────────────────────────────────

    /// SC#4: ToolDef carries name, description, parameters_schema, and async handler.
    #[tokio::test]
    async fn tool_def_construction() {
        let schema = serde_json::json!({"type": "object", "properties": {"x": {"type": "string"}}});
        let def = ToolDef {
            name: "my_tool".into(),
            description: "does a thing".into(),
            parameters_schema: schema.clone(),
            handler: make_handler(
                |_input| async move { Ok(serde_json::json!({"result": "done"})) },
            ),
        };
        assert_eq!(def.name, "my_tool");
        assert_eq!(def.description, "does a thing");
        assert_eq!(def.parameters_schema, schema);
        // Handler must be callable and return Ok.
        let result = (def.handler)(serde_json::json!({})).await;
        assert!(result.is_ok());
    }

    // ─── SC#6: ToolError is model-legible ───────────────────────────────────

    /// SC#6: ToolError Display returns exactly the message, nothing else.
    #[test]
    fn tool_error_is_model_legible() {
        let err = ToolError {
            message: "domain message".into(),
        };
        assert_eq!(format!("{err}"), "domain message");
        // Debug output contains the struct name and field, but Display is the
        // model-facing representation — assert Display == message only.
        let debug_str = format!("{err:?}");
        assert!(debug_str.contains("domain message"));
    }

    // ─── No unbounded path ───────────────────────────────────────────────────

    /// Documents that ToolRegistry::new(n) works and with_default_iterations works.
    /// The absence of Default and a zero-arg new() is enforced by the compiler —
    /// this test documents the expected construction API.
    #[test]
    fn tool_registry_requires_max_iterations() {
        let r1 = ToolRegistry::new(3);
        assert_eq!(r1.max_iterations, 3);
        let r2 = ToolRegistry::with_default_iterations();
        assert_eq!(r2.max_iterations, 10);
    }

    // ─── Dispatch loop tests ─────────────────────────────────────────────────

    /// Mock LlmClient that returns ToolUse for `stop_after` calls then returns Text.
    struct LoopingClient {
        calls: Arc<AtomicU32>,
        stop_after: u32,
        tool_name: String,
    }

    impl LoopingClient {
        /// Build a client together with a shared handle to its call counter, so a
        /// test can check how many times `complete_with_tools` was invoked.
        fn with_counter(stop_after: u32, tool_name: &str) -> (Self, Arc<AtomicU32>) {
            let calls = Arc::new(AtomicU32::new(0));
            let client = Self {
                calls: Arc::clone(&calls),
                stop_after,
                tool_name: tool_name.into(),
            };
            (client, calls)
        }
    }

    #[async_trait]
    impl LlmClient for LoopingClient {
        fn default_model(&self) -> &str {
            "test"
        }

        async fn complete(&self, _: CompletionRequest) -> Result<String, Error> {
            Err(Error::Unsupported)
        }

        async fn complete_stream(&self, _: CompletionRequest) -> Result<TokenStream, Error> {
            Err(Error::Unsupported)
        }

        async fn embed(&self, _: &str) -> Result<Vec<f32>, Error> {
            Err(Error::Unsupported)
        }

        async fn complete_with_tools(
            &self,
            _: CompletionRequest,
        ) -> Result<CompletionResponse, Error> {
            let n = self.calls.fetch_add(1, Ordering::SeqCst);
            if n >= self.stop_after {
                Ok(CompletionResponse::Text("done".into()))
            } else {
                Ok(CompletionResponse::ToolUse {
                    blocks: vec![ToolUseBlock {
                        id: format!("call_{n}"),
                        name: self.tool_name.clone(),
                        input: serde_json::json!({}),
                    }],
                    assistant_content: format!(
                        r#"[{{"type":"tool_use","id":"call_{n}","name":"{}","input":{{}}}}]"#,
                        self.tool_name
                    ),
                })
            }
        }
    }

    /// SC#5: dispatch returns Err(ToolIterationLimit) at the hard cap, after calling
    /// the client exactly `max_iterations` times.
    #[tokio::test]
    async fn tool_registry_enforces_max_iterations() {
        let registry = ToolRegistry::new(3);
        // stop_after = 99: the client never returns Text on its own.
        let (client, calls) = LoopingClient::with_counter(99, "no_op");

        let result = registry.dispatch(vec![], &client).await;
        assert!(
            matches!(result, Err(Error::ToolIterationLimit(3))),
            "expected ToolIterationLimit(3), got {result:?}"
        );
        // Pins the cap against off-by-one: 3 round-trips, not 2 or 4.
        assert_eq!(
            calls.load(Ordering::SeqCst),
            3,
            "client must be called exactly max_iterations times"
        );
    }

    /// dispatch returns Ok when the client returns Text on the first call.
    #[tokio::test]
    async fn dispatch_returns_on_text() {
        let registry = ToolRegistry::new(5);
        // stop_after = 0: the client returns Text immediately.
        let (client, calls) = LoopingClient::with_counter(0, "no_op");

        let result = registry.dispatch(vec![], &client).await;
        assert!(result.is_ok());
        let messages = result.unwrap();
        assert!(
            messages
                .iter()
                .any(|m| matches!(m.role, Role::Assistant) && m.content == "done"),
            "expected assistant message with 'done'"
        );
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "a Text response must end the loop after a single client call"
        );
    }

    /// SC#6: a handler returning ToolError surfaces only its message to the LLM.
    ///
    /// The dispatch loop must complete (not abort) when a registered handler fails,
    /// and the tool_result message must carry the model-legible ToolError message,
    /// not a raw panic or Rust debug string.
    #[tokio::test]
    async fn dispatch_surfaces_tool_error() {
        let mut registry = ToolRegistry::new(5);

        // Register a tool that always fails with a model-legible message.
        registry.register(ToolDef {
            name: "failing_tool".into(),
            description: "always fails".into(),
            parameters_schema: serde_json::json!({}),
            handler: make_handler(|_| async move {
                Err(ToolError {
                    message: "order not found".into(),
                })
            }),
        });

        // Client: first call returns ToolUse for failing_tool, second returns Text.
        let (client, _calls) = LoopingClient::with_counter(1, "failing_tool");

        let result = registry.dispatch(vec![], &client).await;
        assert!(
            result.is_ok(),
            "dispatch must complete even after tool error"
        );

        let messages = result.unwrap();
        // There must be a Role::Tool message carrying the model-legible error.
        let tool_result = messages.iter().find(|m| matches!(m.role, Role::Tool));
        assert!(
            tool_result.is_some(),
            "expected a Role::Tool result message"
        );
        let content = &tool_result.unwrap().content;
        assert!(
            content.contains("order not found"),
            "ToolError message must appear in tool result, got: {content}"
        );
        // Must NOT contain raw Rust panic text or debug noise.
        assert!(
            !content.contains("panicked at"),
            "tool result must not contain panic text"
        );
    }

    /// CR-02 regression: the dispatch loop must push the assistant tool-use turn into
    /// history BEFORE the tool result messages. Providers require alternating roles.
    #[tokio::test]
    async fn dispatch_includes_assistant_turn_before_tool_results() {
        let mut registry = ToolRegistry::new(5);

        registry.register(ToolDef {
            name: "echo".into(),
            description: "echoes input".into(),
            parameters_schema: serde_json::json!({}),
            handler: make_handler(|_| async move { Ok(serde_json::json!({"result": "ok"})) }),
        });

        // Client: one ToolUse call then Text.
        let (client, _calls) = LoopingClient::with_counter(1, "echo");

        let messages = registry.dispatch(vec![], &client).await.unwrap();

        // Find positions of the assistant tool-use turn and the tool result turn.
        let assistant_pos = messages
            .iter()
            .position(|m| matches!(m.role, Role::Assistant) && m.content.contains("tool_use"))
            .expect("must have an assistant turn with tool_use content");
        let tool_result_pos = messages
            .iter()
            .position(|m| matches!(m.role, Role::Tool))
            .expect("must have a tool result message");

        assert!(
            assistant_pos < tool_result_pos,
            "assistant tool-use turn (pos {assistant_pos}) must precede tool result (pos {tool_result_pos})"
        );

        // The tool result must carry the id of the block it answers, in the
        // dedicated field rather than embedded in content.
        let tool_msg = &messages[tool_result_pos];
        assert_eq!(
            tool_msg.tool_call_id.as_deref(),
            Some("call_0"),
            "tool result message must carry the originating block id as tool_call_id"
        );
        assert!(
            !tool_msg.content.contains("call_"),
            "tool_call_id must not be embedded in content string, got: {}",
            tool_msg.content
        );
    }

    /// WR-03: call_tool returns a ToolError (not Error::ToolNotFound) for unknown tool names,
    /// so the dispatch loop can surface it to the LLM as a recoverable message.
    /// Error::ToolNotFound is reserved as a public API variant for future direct-dispatch helpers.
    #[tokio::test]
    async fn dispatch_surfaces_unknown_tool_as_tool_error() {
        // Registry with no registered tools.
        let registry = ToolRegistry::new(5);

        // Client returns one ToolUse for an unregistered tool, then Text.
        let (client, _calls) = LoopingClient::with_counter(1, "nonexistent_tool");

        let result = registry.dispatch(vec![], &client).await;
        // Dispatch must complete (not abort with ToolNotFound) — unknown tool is LLM-recoverable.
        assert!(
            result.is_ok(),
            "dispatch must not abort for unknown tool; got {result:?}"
        );
        let messages = result.unwrap();
        let tool_msg = messages
            .iter()
            .find(|m| matches!(m.role, Role::Tool))
            .expect("must have a tool result message for the unknown tool");
        assert!(
            tool_msg.content.contains("not registered"),
            "unknown tool error must surface to LLM as a message, got: {}",
            tool_msg.content
        );
    }
}
576}