Skip to main content

llm_agent_runtime/
agent.rs

1//! # Module: Agent
2//!
3//! ## Responsibility
4//! Provides a ReAct (Thought-Action-Observation) agent loop with pluggable tools.
5//! Mirrors the public API of `wasm-agent`.
6//!
7//! ## Guarantees
8//! - Deterministic: the loop terminates after at most `max_iterations` cycles
9//! - Non-panicking: all operations return `Result`
10//! - Tool handlers support both sync and async `Fn` closures
11//!
12//! ## NOT Responsible For
13//! - Actual LLM inference (callers supply a mock/stub inference fn)
14//! - WASM compilation or browser execution
15//! - Streaming partial responses
16
17use crate::error::AgentRuntimeError;
18use crate::metrics::RuntimeMetrics;
19use serde::{Deserialize, Serialize};
20use serde_json::Value;
21use std::collections::HashMap;
22use std::future::Future;
23use std::pin::Pin;
24use std::sync::Arc;
25
26// ── Types ─────────────────────────────────────────────────────────────────────
27
28/// A pinned, boxed future returning a `Value`. Used for async tool handlers.
29pub type AsyncToolFuture = Pin<Box<dyn Future<Output = Value> + Send>>;
30
31/// A pinned, boxed future returning `Result<Value, String>`. Used for fallible async tool handlers.
32pub type AsyncToolResultFuture = Pin<Box<dyn Future<Output = Result<Value, String>> + Send>>;
33
34/// An async tool handler closure.
35pub type AsyncToolHandler = Box<dyn Fn(Value) -> AsyncToolFuture + Send + Sync>;
36
37/// Role of a message in a conversation.
38#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
39pub enum Role {
40    /// System-level instruction injected before the user turn.
41    System,
42    /// Message from the human user.
43    User,
44    /// Message produced by the language model.
45    Assistant,
46    /// Message produced by a tool invocation.
47    Tool,
48}
49
50/// A single message in the conversation history.
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct Message {
53    /// The role of the speaker who produced this message.
54    pub role: Role,
55    /// The textual content of the message.
56    pub content: String,
57}
58
59impl Message {
60    /// Create a new `Message` with the given role and content.
61    ///
62    /// # Panics
63    ///
64    /// This function does not panic.
65    pub fn new(role: Role, content: impl Into<String>) -> Self {
66        Self {
67            role,
68            content: content.into(),
69        }
70    }
71
72    /// Return a reference to the message role.
73    pub fn role(&self) -> &Role {
74        &self.role
75    }
76
77    /// Return the message content as a `&str`.
78    pub fn content(&self) -> &str {
79        &self.content
80    }
81}
82
83impl std::fmt::Display for Role {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        match self {
86            Role::System => write!(f, "system"),
87            Role::User => write!(f, "user"),
88            Role::Assistant => write!(f, "assistant"),
89            Role::Tool => write!(f, "tool"),
90        }
91    }
92}
93
94/// A single ReAct step: Thought → Action → Observation.
95#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct ReActStep {
97    /// Agent's reasoning about the current state.
98    pub thought: String,
99    /// The action taken (tool name + JSON arguments, or "FINAL_ANSWER").
100    pub action: String,
101    /// The result of the action.
102    pub observation: String,
103    /// Wall-clock duration of this individual step in milliseconds.
104    /// Covers the time from the start of the inference call to the end of the
105    /// tool invocation (or FINAL_ANSWER detection).  Zero for steps that were
106    /// constructed outside the loop (e.g., in tests).
107    #[serde(default)]
108    pub step_duration_ms: u64,
109}
110
111impl ReActStep {
112    /// Returns `true` if this step's action is a `FINAL_ANSWER`.
113    pub fn is_final_answer(&self) -> bool {
114        self.action.trim().to_ascii_uppercase().starts_with("FINAL_ANSWER")
115    }
116
117    /// Returns `true` if this step's action is a tool call (not a FINAL_ANSWER).
118    pub fn is_tool_call(&self) -> bool {
119        !self.is_final_answer() && !self.action.trim().is_empty()
120    }
121}
122
/// Settings that govern a ReAct agent loop.
#[derive(Debug, Clone)]
pub struct AgentConfig {
    /// Upper bound on the number of Thought-Action-Observation cycles.
    pub max_iterations: usize,
    /// Identifier of the model, forwarded to the inference function.
    pub model: String,
    /// System prompt placed at the very start of the conversation.
    pub system_prompt: String,
    /// Upper bound on the number of episodic memories injected into the prompt.
    /// A small value guards against silent token-budget overruns.  Default: 3.
    pub max_memory_recalls: usize,
    /// Approximate token budget for injected memories, estimated at
    /// ~4 chars/token.  `None` means unlimited.
    pub max_memory_tokens: Option<usize>,
    /// Optional wall-clock limit for the whole loop.  Once exceeded, the loop
    /// returns `Err(AgentRuntimeError::AgentLoop("loop timeout ..."))`.
    pub loop_timeout: Option<std::time::Duration>,
    /// Sampling temperature for the model.
    pub temperature: Option<f32>,
    /// Cap on the number of output tokens.
    pub max_tokens: Option<usize>,
    /// Timeout applied to each individual inference request.
    pub request_timeout: Option<std::time::Duration>,
}
149
150impl AgentConfig {
151    /// Create a new config with sensible defaults.
152    pub fn new(max_iterations: usize, model: impl Into<String>) -> Self {
153        Self {
154            max_iterations,
155            model: model.into(),
156            system_prompt: "You are a helpful AI agent.".into(),
157            max_memory_recalls: 3,
158            max_memory_tokens: None,
159            loop_timeout: None,
160            temperature: None,
161            max_tokens: None,
162            request_timeout: None,
163        }
164    }
165
166    /// Override the system prompt.
167    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
168        self.system_prompt = prompt.into();
169        self
170    }
171
172    /// Set the maximum number of episodic memories injected per run.
173    pub fn with_max_memory_recalls(mut self, n: usize) -> Self {
174        self.max_memory_recalls = n;
175        self
176    }
177
178    /// Set a maximum token budget for injected memories (~4 chars/token heuristic).
179    pub fn with_max_memory_tokens(mut self, n: usize) -> Self {
180        self.max_memory_tokens = Some(n);
181        self
182    }
183
184    /// Set a wall-clock timeout for the entire ReAct loop.
185    ///
186    /// If the loop has not reached `FINAL_ANSWER` within this duration,
187    /// [`ReActLoop::run`] returns `Err(AgentRuntimeError::AgentLoop(...))`.
188    pub fn with_loop_timeout(mut self, d: std::time::Duration) -> Self {
189        self.loop_timeout = Some(d);
190        self
191    }
192
193    /// Set the model sampling temperature.
194    pub fn with_temperature(mut self, t: f32) -> Self {
195        self.temperature = Some(t);
196        self
197    }
198
199    /// Set the maximum output tokens.
200    pub fn with_max_tokens(mut self, n: usize) -> Self {
201        self.max_tokens = Some(n);
202        self
203    }
204
205    /// Set the per-inference timeout.
206    pub fn with_request_timeout(mut self, d: std::time::Duration) -> Self {
207        self.request_timeout = Some(d);
208        self
209    }
210}
211
212// ── ToolSpec ──────────────────────────────────────────────────────────────────
213
214/// Describes and implements a single callable tool.
215pub struct ToolSpec {
216    /// Short identifier used in action strings (e.g. "search").
217    pub name: String,
218    /// Human-readable description passed to the model as part of the system prompt.
219    pub description: String,
220    /// Async handler: receives JSON arguments, returns a future resolving to a JSON result.
221    pub(crate) handler: AsyncToolHandler,
222    /// Field names that must be present in the JSON args object.
223    /// Empty means no validation is performed.
224    pub required_fields: Vec<String>,
225    /// Custom argument validators run after `required_fields` and before the handler.
226    /// All validators must pass; the first failure short-circuits execution.
227    pub validators: Vec<Box<dyn ToolValidator>>,
228    /// Optional per-tool circuit breaker.
229    #[cfg(feature = "orchestrator")]
230    pub circuit_breaker: Option<Arc<crate::orchestrator::CircuitBreaker>>,
231}
232
233impl std::fmt::Debug for ToolSpec {
234    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235        let mut s = f.debug_struct("ToolSpec");
236        s.field("name", &self.name)
237            .field("description", &self.description)
238            .field("required_fields", &self.required_fields);
239        #[cfg(feature = "orchestrator")]
240        s.field("has_circuit_breaker", &self.circuit_breaker.is_some());
241        s.finish()
242    }
243}
244
245impl ToolSpec {
246    /// Construct a new `ToolSpec` from a synchronous handler closure.
247    /// The closure is wrapped in an `async move` block automatically.
248    pub fn new(
249        name: impl Into<String>,
250        description: impl Into<String>,
251        handler: impl Fn(Value) -> Value + Send + Sync + 'static,
252    ) -> Self {
253        Self {
254            name: name.into(),
255            description: description.into(),
256            handler: Box::new(move |args| {
257                let result = handler(args);
258                Box::pin(async move { result })
259            }),
260            required_fields: Vec::new(),
261            validators: Vec::new(),
262            #[cfg(feature = "orchestrator")]
263            circuit_breaker: None,
264        }
265    }
266
267    /// Construct a new `ToolSpec` from an async handler closure.
268    pub fn new_async(
269        name: impl Into<String>,
270        description: impl Into<String>,
271        handler: impl Fn(Value) -> AsyncToolFuture + Send + Sync + 'static,
272    ) -> Self {
273        Self {
274            name: name.into(),
275            description: description.into(),
276            handler: Box::new(handler),
277            required_fields: Vec::new(),
278            validators: Vec::new(),
279            #[cfg(feature = "orchestrator")]
280            circuit_breaker: None,
281        }
282    }
283
284    /// Construct a new `ToolSpec` from a synchronous fallible handler closure.
285    /// `Err(msg)` is converted to `{"error": msg, "ok": false}`.
286    pub fn new_fallible(
287        name: impl Into<String>,
288        description: impl Into<String>,
289        handler: impl Fn(Value) -> Result<Value, String> + Send + Sync + 'static,
290    ) -> Self {
291        Self {
292            name: name.into(),
293            description: description.into(),
294            handler: Box::new(move |args| {
295                let result = handler(args);
296                let value = match result {
297                    Ok(v) => v,
298                    Err(msg) => serde_json::json!({"error": msg, "ok": false}),
299                };
300                Box::pin(async move { value })
301            }),
302            required_fields: Vec::new(),
303            validators: Vec::new(),
304            #[cfg(feature = "orchestrator")]
305            circuit_breaker: None,
306        }
307    }
308
309    /// Construct a new `ToolSpec` from an async fallible handler closure.
310    /// `Err(msg)` is converted to `{"error": msg, "ok": false}`.
311    pub fn new_async_fallible(
312        name: impl Into<String>,
313        description: impl Into<String>,
314        handler: impl Fn(Value) -> AsyncToolResultFuture + Send + Sync + 'static,
315    ) -> Self {
316        Self {
317            name: name.into(),
318            description: description.into(),
319            handler: Box::new(move |args| {
320                let fut = handler(args);
321                Box::pin(async move {
322                    match fut.await {
323                        Ok(v) => v,
324                        Err(msg) => serde_json::json!({"error": msg, "ok": false}),
325                    }
326                })
327            }),
328            required_fields: Vec::new(),
329            validators: Vec::new(),
330            #[cfg(feature = "orchestrator")]
331            circuit_breaker: None,
332        }
333    }
334
335    /// Set the required fields that must be present in the JSON args object.
336    pub fn with_required_fields(mut self, fields: Vec<String>) -> Self {
337        self.required_fields = fields;
338        self
339    }
340
341    /// Attach custom argument validators.
342    ///
343    /// Validators run after `required_fields` checks and before the handler.
344    /// The first failing validator short-circuits execution.
345    pub fn with_validators(mut self, validators: Vec<Box<dyn ToolValidator>>) -> Self {
346        self.validators = validators;
347        self
348    }
349
350    /// Attach a circuit breaker to this tool spec.
351    #[cfg(feature = "orchestrator")]
352    pub fn with_circuit_breaker(mut self, cb: Arc<crate::orchestrator::CircuitBreaker>) -> Self {
353        self.circuit_breaker = Some(cb);
354        self
355    }
356
357    /// Invoke the tool with the given JSON arguments.
358    pub async fn call(&self, args: Value) -> Value {
359        (self.handler)(args).await
360    }
361}
362
363// ── ToolCache ─────────────────────────────────────────────────────────────────
364
365/// Cache for tool call results.
366///
367/// Implement to deduplicate repeated identical tool calls within a single
368/// [`ReActLoop::run`] invocation.
369///
370/// ## Cache key
371/// Implementations should key on `(tool_name, args)`.  The `args` value is the
372/// full parsed JSON object passed to the tool.
373///
374/// ## Thread safety
375/// The trait is `Send + Sync`, so implementations must be safe to share across
376/// threads.  Wrap mutable state in a `Mutex` or use lock-free atomics.
377///
378/// ## TTL
379/// TTL semantics are implementation-defined.  A simple in-memory cache may
380/// keep results for the lifetime of the [`ReActLoop::run`] call; a distributed
381/// cache may use Redis with explicit expiry.
382///
383/// ## Lifetime
384/// A cache instance is attached to a `ToolRegistry` and lives for the lifetime
385/// of that registry.  Results are **not** automatically cleared between
386/// `ReActLoop::run` calls — clear the cache explicitly if cross-run dedup is
387/// not desired.
388pub trait ToolCache: Send + Sync {
389    /// Look up a cached result for `(tool_name, args)`.
390    fn get(&self, tool_name: &str, args: &serde_json::Value) -> Option<serde_json::Value>;
391    /// Store a result for `(tool_name, args)`.
392    fn set(&self, tool_name: &str, args: &serde_json::Value, result: serde_json::Value);
393}
394
395// ── ToolRegistry ──────────────────────────────────────────────────────────────
396
397/// Registry of available tools for the agent loop.
398pub struct ToolRegistry {
399    tools: HashMap<String, ToolSpec>,
400    /// Optional tool result cache.
401    cache: Option<Arc<dyn ToolCache>>,
402}
403
404impl std::fmt::Debug for ToolRegistry {
405    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
406        f.debug_struct("ToolRegistry")
407            .field("tools", &self.tools.keys().collect::<Vec<_>>())
408            .field("has_cache", &self.cache.is_some())
409            .finish()
410    }
411}
412
413impl Default for ToolRegistry {
414    fn default() -> Self {
415        Self::new()
416    }
417}
418
419impl ToolRegistry {
420    /// Create a new empty registry.
421    pub fn new() -> Self {
422        Self {
423            tools: HashMap::new(),
424            cache: None,
425        }
426    }
427
428    /// Attach a tool result cache.
429    pub fn with_cache(mut self, cache: Arc<dyn ToolCache>) -> Self {
430        self.cache = Some(cache);
431        self
432    }
433
434    /// Register a tool. Overwrites any existing tool with the same name.
435    pub fn register(&mut self, spec: ToolSpec) {
436        self.tools.insert(spec.name.clone(), spec);
437    }
438
439    /// Register multiple tools at once.
440    ///
441    /// Equivalent to calling [`register`] for each spec in order.  Duplicate
442    /// names overwrite earlier entries.
443    ///
444    /// [`register`]: ToolRegistry::register
445    pub fn register_tools(&mut self, specs: impl IntoIterator<Item = ToolSpec>) {
446        for spec in specs {
447            self.register(spec);
448        }
449    }
450
451    /// Fluent builder: register a tool and return `self`.
452    ///
453    /// Allows chaining multiple registrations:
454    /// ```no_run
455    /// use llm_agent_runtime::agent::{ToolRegistry, ToolSpec};
456    ///
457    /// let registry = ToolRegistry::new()
458    ///     .with_tool(ToolSpec::new("search", "Search", |args| args.clone()))
459    ///     .with_tool(ToolSpec::new("calc", "Calculate", |args| args.clone()));
460    /// ```
461    pub fn with_tool(mut self, spec: ToolSpec) -> Self {
462        self.register(spec);
463        self
464    }
465
466    /// Call a tool by name.
467    ///
468    /// # Errors
469    /// - `AgentRuntimeError::AgentLoop` — tool not found, required field missing, or
470    ///   custom validator rejected the arguments
471    /// - `AgentRuntimeError::CircuitOpen` — the tool's circuit breaker is open
472    ///   (only possible when the `orchestrator` feature is enabled)
473    #[tracing::instrument(skip_all, fields(tool_name = %name))]
474    pub async fn call(&self, name: &str, args: Value) -> Result<Value, AgentRuntimeError> {
475        let spec = self.tools.get(name).ok_or_else(|| {
476            let mut suggestion = String::new();
477            let names = self.tool_names();
478            if !names.is_empty() {
479                if let Some((closest, dist)) = names
480                    .iter()
481                    .map(|n| (n, levenshtein(name, n)))
482                    .min_by_key(|(_, d)| *d)
483                {
484                    if dist <= 3 {
485                        suggestion = format!(" (did you mean '{closest}'?)");
486                    }
487                }
488            }
489            AgentRuntimeError::AgentLoop(format!("tool '{name}' not found{suggestion}"))
490        })?;
491
492        // Item 3 — required field validation
493        if !spec.required_fields.is_empty() {
494            if let Some(obj) = args.as_object() {
495                for field in &spec.required_fields {
496                    if !obj.contains_key(field) {
497                        return Err(AgentRuntimeError::AgentLoop(format!(
498                            "tool '{}' missing required field '{}'",
499                            name, field
500                        )));
501                    }
502                }
503            } else {
504                return Err(AgentRuntimeError::AgentLoop(format!(
505                    "tool '{}' requires JSON object args, got {}",
506                    name, args
507                )));
508            }
509        }
510
511        // Custom validators.
512        for validator in &spec.validators {
513            validator.validate(&args)?;
514        }
515
516        // Per-tool circuit breaker check.
517        #[cfg(feature = "orchestrator")]
518        if let Some(ref cb) = spec.circuit_breaker {
519            use crate::orchestrator::CircuitState;
520            if let Ok(CircuitState::Open { .. }) = cb.state() {
521                return Err(AgentRuntimeError::CircuitOpen {
522                    service: format!("tool:{}", name),
523                });
524            }
525        }
526
527        // Check cache before invoking handler.
528        if let Some(ref cache) = self.cache {
529            if let Some(cached) = cache.get(name, &args) {
530                return Ok(cached);
531            }
532        }
533
534        let result = spec.call(args.clone()).await;
535
536        // Store result in cache.
537        if let Some(ref cache) = self.cache {
538            cache.set(name, &args, result.clone());
539        }
540
541        Ok(result)
542    }
543
544    /// Return the list of registered tool names.
545    pub fn tool_names(&self) -> Vec<&str> {
546        self.tools.keys().map(|s| s.as_str()).collect()
547    }
548}
549
550// ── ReActLoop ─────────────────────────────────────────────────────────────────
551
552/// Parses a ReAct response string into a `ReActStep`.
553///
554/// Case-insensitive; tolerates spaces around the colon.
555/// e.g. `Thought :`, `thought:`, `THOUGHT :` all match.
556pub fn parse_react_step(text: &str) -> Result<ReActStep, AgentRuntimeError> {
557    let mut thought = String::new();
558    let mut action = String::new();
559
560    for line in text.lines() {
561        let trimmed = line.trim();
562        let lower = trimmed.to_ascii_lowercase();
563        if lower.starts_with("thought") {
564            if let Some(colon_pos) = trimmed.find(':') {
565                thought = trimmed[colon_pos + 1..].trim().to_owned();
566            }
567        } else if lower.starts_with("action") {
568            if let Some(colon_pos) = trimmed.find(':') {
569                action = trimmed[colon_pos + 1..].trim().to_owned();
570            }
571        }
572    }
573
574    if thought.is_empty() && action.is_empty() {
575        return Err(AgentRuntimeError::AgentLoop(
576            "could not parse ReAct step from response".into(),
577        ));
578    }
579
580    Ok(ReActStep {
581        thought,
582        action,
583        observation: String::new(),
584        step_duration_ms: 0,
585    })
586}
587
588/// The ReAct agent loop.
589pub struct ReActLoop {
590    config: AgentConfig,
591    registry: ToolRegistry,
592    /// Optional metrics sink; increments `total_tool_calls` / `failed_tool_calls`.
593    metrics: Option<Arc<RuntimeMetrics>>,
594    /// Optional persistence backend for per-step checkpointing during the loop.
595    #[cfg(feature = "persistence")]
596    checkpoint_backend: Option<(Arc<dyn crate::persistence::PersistenceBackend>, String)>,
597    /// Optional observer for agent loop events.
598    observer: Option<Arc<dyn Observer>>,
599    /// Optional action hook called before each tool dispatch.
600    action_hook: Option<ActionHook>,
601}
602
603impl std::fmt::Debug for ReActLoop {
604    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
605        let mut s = f.debug_struct("ReActLoop");
606        s.field("config", &self.config)
607            .field("registry", &self.registry)
608            .field("has_metrics", &self.metrics.is_some())
609            .field("has_observer", &self.observer.is_some())
610            .field("has_action_hook", &self.action_hook.is_some());
611        #[cfg(feature = "persistence")]
612        s.field("has_checkpoint_backend", &self.checkpoint_backend.is_some());
613        s.finish()
614    }
615}
616
617impl ReActLoop {
618    /// Create a new `ReActLoop` with the given configuration and an empty tool registry.
619    pub fn new(config: AgentConfig) -> Self {
620        Self {
621            config,
622            registry: ToolRegistry::new(),
623            metrics: None,
624            #[cfg(feature = "persistence")]
625            checkpoint_backend: None,
626            observer: None,
627            action_hook: None,
628        }
629    }
630
631    /// Attach an observer for agent loop events.
632    pub fn with_observer(mut self, observer: Arc<dyn Observer>) -> Self {
633        self.observer = Some(observer);
634        self
635    }
636
637    /// Attach an action hook called before each tool dispatch.
638    pub fn with_action_hook(mut self, hook: ActionHook) -> Self {
639        self.action_hook = Some(hook);
640        self
641    }
642
643    /// Attach a shared `RuntimeMetrics` instance.
644    ///
645    /// When set, the loop increments `total_tool_calls` on every tool dispatch
646    /// and `failed_tool_calls` whenever a tool returns an error observation.
647    pub fn with_metrics(mut self, metrics: Arc<RuntimeMetrics>) -> Self {
648        self.metrics = Some(metrics);
649        self
650    }
651
652    /// Attach a persistence backend for per-step loop checkpointing.
653    ///
654    /// After each completed step the current `Vec<ReActStep>` is serialised and
655    /// saved under the key `"loop:<session_id>:step:<n>"`.  Checkpoint errors
656    /// are logged but never abort the loop.
657    #[cfg(feature = "persistence")]
658    pub fn with_step_checkpoint(
659        mut self,
660        backend: Arc<dyn crate::persistence::PersistenceBackend>,
661        session_id: impl Into<String>,
662    ) -> Self {
663        self.checkpoint_backend = Some((backend, session_id.into()));
664        self
665    }
666
667    /// Register a tool that the agent loop can invoke.
668    pub fn register_tool(&mut self, spec: ToolSpec) {
669        self.registry.register(spec);
670    }
671
672    /// Register multiple tools at once.
673    ///
674    /// Equivalent to calling [`register_tool`] for each spec.
675    ///
676    /// [`register_tool`]: ReActLoop::register_tool
677    pub fn register_tools(&mut self, specs: impl IntoIterator<Item = ToolSpec>) {
678        for spec in specs {
679            self.registry.register(spec);
680        }
681    }
682
683    /// Emit a blocked-action observation string.
684    fn blocked_observation() -> String {
685        serde_json::json!({
686            "ok": false,
687            "error": "action blocked by reviewer",
688            "kind": "blocked"
689        })
690        .to_string()
691    }
692
693    /// Build the error observation JSON for a failed tool call.
694    fn error_observation(tool_name: &str, e: &AgentRuntimeError) -> String {
695        let _ = tool_name;
696        let kind = match e {
697            AgentRuntimeError::AgentLoop(msg) if msg.contains("not found") => "not_found",
698            #[cfg(feature = "orchestrator")]
699            AgentRuntimeError::CircuitOpen { .. } => "transient",
700            _ => "permanent",
701        };
702        serde_json::json!({ "ok": false, "error": e.to_string(), "kind": kind }).to_string()
703    }
704
705    /// Execute the ReAct loop for the given prompt.
706    ///
707    /// # Arguments
708    /// * `prompt` — user input passed as the initial context
709    /// * `infer`  — async inference function: receives context string, returns response string
710    ///
711    /// # Errors
712    /// - `AgentRuntimeError::AgentLoop("loop timeout …")` — if `loop_timeout` is configured
713    ///   and the loop runs past the deadline
714    /// - `AgentRuntimeError::AgentLoop("max iterations … reached")` — if the loop exhausts
715    ///   `max_iterations` without emitting `FINAL_ANSWER`
716    /// - `AgentRuntimeError::AgentLoop("could not parse …")` — if the model response cannot
717    ///   be parsed into a `ReActStep`
718    #[tracing::instrument(skip(infer))]
719    pub async fn run<F, Fut>(
720        &self,
721        prompt: &str,
722        mut infer: F,
723    ) -> Result<Vec<ReActStep>, AgentRuntimeError>
724    where
725        F: FnMut(String) -> Fut,
726        Fut: Future<Output = String>,
727    {
728        let mut steps: Vec<ReActStep> = Vec::new();
729        let mut context = format!("{}\n\nUser: {}\n", self.config.system_prompt, prompt);
730
731        // Pre-compute optional deadline once so that each iteration is O(1).
732        let deadline = self
733            .config
734            .loop_timeout
735            .map(|d| std::time::Instant::now() + d);
736
737        // Observer: on_loop_start
738        if let Some(ref obs) = self.observer {
739            obs.on_loop_start(prompt);
740        }
741
742        for iteration in 0..self.config.max_iterations {
743            // Wall-clock timeout check.
744            if let Some(dl) = deadline {
745                if std::time::Instant::now() >= dl {
746                    let ms = self
747                        .config
748                        .loop_timeout
749                        .map(|d| d.as_millis())
750                        .unwrap_or(0);
751                    if let Some(ref obs) = self.observer {
752                        obs.on_loop_end(steps.len());
753                    }
754                    return Err(AgentRuntimeError::AgentLoop(format!(
755                        "loop timeout after {ms} ms"
756                    )));
757                }
758            }
759
760            let step_start = std::time::Instant::now();
761            let response = infer(context.clone()).await;
762            let mut step = parse_react_step(&response)?;
763
764            tracing::debug!(
765                step = iteration,
766                thought = %step.thought,
767                action = %step.action,
768                "ReAct iteration"
769            );
770
771            if step.action.to_ascii_uppercase().starts_with("FINAL_ANSWER") {
772                step.observation = step.action.clone();
773                step.step_duration_ms = step_start.elapsed().as_millis() as u64;
774                if let Some(ref m) = self.metrics {
775                    m.record_step_latency(step.step_duration_ms);
776                }
777                if let Some(ref obs) = self.observer {
778                    obs.on_step(iteration, &step);
779                }
780                steps.push(step);
781                tracing::info!(step = iteration, "FINAL_ANSWER reached");
782                if let Some(ref obs) = self.observer {
783                    obs.on_loop_end(steps.len());
784                }
785                return Ok(steps);
786            }
787
788            // Item 3 — propagate parse errors rather than silently falling back.
789            let (tool_name, args) = parse_tool_call(&step.action)?;
790
791            tracing::debug!(
792                step = iteration,
793                tool_name = %tool_name,
794                "dispatching tool call"
795            );
796
797            // Action hook check.
798            if let Some(ref hook) = self.action_hook {
799                if !hook(tool_name.clone(), args.clone()).await {
800                    if let Some(ref obs) = self.observer {
801                        obs.on_action_blocked(&tool_name, &args);
802                    }
803                    if let Some(ref m) = self.metrics {
804                        m.record_tool_call(&tool_name);
805                        m.record_tool_failure(&tool_name);
806                    }
807                    step.observation = Self::blocked_observation();
808                    step.step_duration_ms = step_start.elapsed().as_millis() as u64;
809                    if let Some(ref m) = self.metrics {
810                        m.record_step_latency(step.step_duration_ms);
811                    }
812                    context.push_str(&format!(
813                        "\nThought: {}\nAction: {}\nObservation: {}\n",
814                        step.thought, step.action, step.observation
815                    ));
816                    if let Some(ref obs) = self.observer {
817                        obs.on_step(iteration, &step);
818                    }
819                    steps.push(step);
820                    continue;
821                }
822            }
823
824            // Observer: on_tool_call
825            if let Some(ref obs) = self.observer {
826                obs.on_tool_call(&tool_name, &args);
827            }
828
829            // Count every tool dispatch (global + per-tool).
830            if let Some(ref m) = self.metrics {
831                m.record_tool_call(&tool_name);
832            }
833
834            // Structured error categorization in observation.
835            let observation = match self.registry.call(&tool_name, args).await {
836                Ok(result) => serde_json::json!({ "ok": true, "data": result }).to_string(),
837                Err(e) => {
838                    // Count failed tool calls (global + per-tool).
839                    if let Some(ref m) = self.metrics {
840                        m.record_tool_failure(&tool_name);
841                    }
842                    Self::error_observation(&tool_name, &e)
843                }
844            };
845
846            step.observation = observation.clone();
847            step.step_duration_ms = step_start.elapsed().as_millis() as u64;
848            if let Some(ref m) = self.metrics {
849                m.record_step_latency(step.step_duration_ms);
850            }
851            context.push_str(&format!(
852                "\nThought: {}\nAction: {}\nObservation: {}\n",
853                step.thought, step.action, observation
854            ));
855            if let Some(ref obs) = self.observer {
856                obs.on_step(iteration, &step);
857            }
858            steps.push(step);
859
860            // Item 11 — per-step loop checkpoint (behind feature flag).
861            #[cfg(feature = "persistence")]
862            if let Some((ref backend, ref session_id)) = self.checkpoint_backend {
863                let step_idx = steps.len();
864                let key = format!("loop:{session_id}:step:{step_idx}");
865                match serde_json::to_vec(&steps) {
866                    Ok(bytes) => {
867                        if let Err(e) = backend.save(&key, &bytes).await {
868                            tracing::warn!(
869                                key = %key,
870                                error = %e,
871                                "loop step checkpoint save failed"
872                            );
873                        }
874                    }
875                    Err(e) => {
876                        tracing::warn!(
877                            step = step_idx,
878                            error = %e,
879                            "loop step checkpoint serialisation failed"
880                        );
881                    }
882                }
883            }
884        }
885
886        let err = AgentRuntimeError::AgentLoop(format!(
887            "max iterations ({}) reached without final answer",
888            self.config.max_iterations
889        ));
890        tracing::warn!(
891            max_iterations = self.config.max_iterations,
892            "ReAct loop exhausted max iterations without FINAL_ANSWER"
893        );
894        if let Some(ref obs) = self.observer {
895            obs.on_loop_end(steps.len());
896        }
897        Err(err)
898    }
899
900    /// Execute the ReAct loop using a streaming inference function.
901    ///
902    /// Identical to [`run`] except the inference closure returns a
903    /// `tokio::sync::mpsc::Receiver` that streams token chunks.  All chunks
904    /// are collected into a single `String` before each iteration's response
905    /// is parsed.  Stream errors result in an empty partial response (the
906    /// erroring chunk is silently dropped and a warning is logged).
907    ///
908    /// [`run`]: ReActLoop::run
909    #[tracing::instrument(skip(infer_stream))]
910    pub async fn run_streaming<F, Fut>(
911        &self,
912        prompt: &str,
913        mut infer_stream: F,
914    ) -> Result<Vec<ReActStep>, AgentRuntimeError>
915    where
916        F: FnMut(String) -> Fut,
917        Fut: Future<
918            Output = tokio::sync::mpsc::Receiver<Result<String, AgentRuntimeError>>,
919        >,
920    {
921        self.run(prompt, move |ctx| {
922            let rx_fut = infer_stream(ctx);
923            async move {
924                let mut rx = rx_fut.await;
925                let mut out = String::new();
926                while let Some(chunk) = rx.recv().await {
927                    match chunk {
928                        Ok(s) => out.push_str(&s),
929                        Err(e) => {
930                            tracing::warn!(error = %e, "streaming chunk error; skipping");
931                        }
932                    }
933                }
934                out
935            }
936        })
937        .await
938    }
939}
940
/// Declarative argument validator for a `ToolSpec`.
///
/// Implement this trait to enforce custom argument constraints (type ranges,
/// string patterns, etc.) before the handler is invoked.
///
/// Validators run **after** `required_fields` checks and **before** the handler.
/// The first failing validator short-circuits execution.
///
/// Implementors must be `Send + Sync` (supertrait bounds), and validation is
/// synchronous: `validate` only borrows the arguments and returns a `Result`.
///
/// # Basic Example
/// ```no_run
/// use llm_agent_runtime::agent::ToolValidator;
/// use llm_agent_runtime::AgentRuntimeError;
/// use serde_json::Value;
///
/// struct NonEmptyQuery;
/// impl ToolValidator for NonEmptyQuery {
///     fn validate(&self, args: &Value) -> Result<(), AgentRuntimeError> {
///         let q = args.get("q").and_then(|v| v.as_str()).unwrap_or("");
///         if q.is_empty() {
///             return Err(AgentRuntimeError::AgentLoop(
///                 "tool 'search': q must not be empty".into(),
///             ));
///         }
///         Ok(())
///     }
/// }
/// ```
///
/// # Advanced Example — Parameterised validator
/// ```no_run
/// use llm_agent_runtime::agent::{ToolSpec, ToolValidator};
/// use llm_agent_runtime::AgentRuntimeError;
/// use serde_json::Value;
///
/// /// Validates that a named integer field is within [min, max].
/// struct RangeValidator { field: &'static str, min: i64, max: i64 }
///
/// impl ToolValidator for RangeValidator {
///     fn validate(&self, args: &Value) -> Result<(), AgentRuntimeError> {
///         let n = args
///             .get(self.field)
///             .and_then(|v| v.as_i64())
///             .ok_or_else(|| {
///                 AgentRuntimeError::AgentLoop(format!(
///                     "field '{}' must be an integer", self.field
///                 ))
///             })?;
///         if n < self.min || n > self.max {
///             return Err(AgentRuntimeError::AgentLoop(format!(
///                 "field '{}' = {n} is outside [{}, {}]",
///                 self.field, self.min, self.max,
///             )));
///         }
///         Ok(())
///     }
/// }
///
/// // Attach to a tool spec:
/// let spec = ToolSpec::new("roll_dice", "Roll n dice", |args| {
///     serde_json::json!({ "result": args })
/// })
/// .with_validators(vec![
///     Box::new(RangeValidator { field: "n", min: 1, max: 100 }),
/// ]);
/// ```
pub trait ToolValidator: Send + Sync {
    /// Validate `args` before the tool handler is invoked.
    ///
    /// `args` is the parsed JSON argument value for the call; it is only
    /// borrowed, so a validator cannot modify what the handler receives.
    ///
    /// # Errors
    /// Return `Ok(())` if the arguments are valid, or
    /// `Err(AgentRuntimeError::AgentLoop(...))` with a human-readable message.
    fn validate(&self, args: &Value) -> Result<(), AgentRuntimeError>;
}
1013
/// Compute the Levenshtein edit distance between two strings.
///
/// The distance counts single-`char` insertions, deletions and substitutions,
/// so it is measured in Unicode scalar values rather than bytes.
///
/// Used to suggest close matches when a tool name is not found.
fn levenshtein(a: &str, b: &str) -> usize {
    let source: Vec<char> = a.chars().collect();
    let target: Vec<char> = b.chars().collect();

    // `prev[j]` is the distance between the part of `source` consumed so far
    // and the first `j` chars of `target`; it starts as the distance from the
    // empty prefix, i.e. `j` insertions.
    let mut prev: Vec<usize> = (0..=target.len()).collect();
    let mut curr: Vec<usize> = vec![0; target.len() + 1];

    for (row, &source_ch) in source.iter().enumerate() {
        // Reducing `row + 1` source chars to an empty target takes that many deletions.
        curr[0] = row + 1;
        for (col, &target_ch) in target.iter().enumerate() {
            curr[col + 1] = if source_ch == target_ch {
                prev[col]
            } else {
                // Cheapest of deletion, insertion and substitution.
                1 + prev[col + 1].min(curr[col]).min(prev[col])
            };
        }
        std::mem::swap(&mut prev, &mut curr);
    }

    // After the final swap the last completed row lives in `prev`.
    prev[target.len()]
}
1039
1040/// Split `"tool_name {json}"` into `(tool_name, Value)`.
1041///
1042/// Returns `Err(AgentRuntimeError::AgentLoop)` when:
1043/// - the tool name is empty
1044/// - the argument portion is non-empty but not valid JSON
1045fn parse_tool_call(action: &str) -> Result<(String, Value), AgentRuntimeError> {
1046    let mut parts = action.splitn(2, ' ');
1047    let name = parts.next().unwrap_or("").to_owned();
1048    if name.is_empty() {
1049        return Err(AgentRuntimeError::AgentLoop(
1050            "tool call has an empty tool name".into(),
1051        ));
1052    }
1053    let args_str = parts.next().unwrap_or("{}");
1054    let args: Value = serde_json::from_str(args_str).map_err(|e| {
1055        AgentRuntimeError::AgentLoop(format!(
1056            "invalid JSON args for tool call '{name}': {e} (raw: {args_str})"
1057        ))
1058    })?;
1059    Ok((name, args))
1060}
1061
/// Agent-specific errors, mirrors `wasm-agent::AgentError`.
///
/// Converts to `AgentRuntimeError::AgentLoop` via the `From` implementation;
/// the conversion keeps only the `Display` text defined by the `#[error]`
/// attributes below.
#[derive(Debug, thiserror::Error)]
pub enum AgentError {
    /// The referenced tool name does not exist in the registry.
    ///
    /// The payload is the tool name that is interpolated into the message.
    #[error("Tool '{0}' not found")]
    ToolNotFound(String),
    /// The ReAct loop consumed all iterations without emitting `FINAL_ANSWER`.
    ///
    /// NOTE(review): the payload is presumably the iteration limit that was hit — confirm against callers.
    #[error("Max iterations exceeded: {0}")]
    MaxIterations(usize),
    /// The model response could not be parsed into a `ReActStep`.
    ///
    /// The payload is the detail text appended to the `Parse error:` prefix.
    #[error("Parse error: {0}")]
    ParseError(String),
}
1077
1078impl From<AgentError> for AgentRuntimeError {
1079    fn from(e: AgentError) -> Self {
1080        AgentRuntimeError::AgentLoop(e.to_string())
1081    }
1082}
1083
1084// ── Observer ──────────────────────────────────────────────────────────────────
1085
1086/// Hook trait for observing agent loop events.
1087///
1088/// All methods have no-op default implementations so you only override
1089/// what you care about.
1090pub trait Observer: Send + Sync {
1091    /// Called when a ReAct step completes.
1092    fn on_step(&self, step_index: usize, step: &ReActStep) {
1093        let _ = (step_index, step);
1094    }
1095    /// Called when a tool is about to be dispatched.
1096    fn on_tool_call(&self, tool_name: &str, args: &serde_json::Value) {
1097        let _ = (tool_name, args);
1098    }
1099    /// Called when an action hook blocks a tool call before dispatch.
1100    ///
1101    /// `tool_name` is the name of the blocked tool; `args` are the arguments
1102    /// that were passed to the hook.  This is called *instead of* `on_tool_call`.
1103    fn on_action_blocked(&self, tool_name: &str, args: &serde_json::Value) {
1104        let _ = (tool_name, args);
1105    }
1106    /// Called when the loop starts.
1107    fn on_loop_start(&self, prompt: &str) {
1108        let _ = prompt;
1109    }
1110    /// Called when the loop finishes (success or error).
1111    fn on_loop_end(&self, step_count: usize) {
1112        let _ = step_count;
1113    }
1114}
1115
1116// ── Action ────────────────────────────────────────────────────────────────────
1117
/// A parsed action from a ReAct step.
///
/// Built from the raw action text by [`Action::parse`].
#[derive(Debug, Clone, PartialEq)]
pub enum Action {
    /// The agent has produced a final answer.
    ///
    /// Holds the text that followed the `FINAL_ANSWER` marker, trimmed.
    FinalAnswer(String),
    /// A tool call with a name and JSON arguments.
    ToolCall {
        /// The tool name.
        name: String,
        /// The parsed JSON arguments.
        args: serde_json::Value,
    },
}
1131
1132impl Action {
1133    /// Parse an action string into an `Action`.
1134    ///
1135    /// Returns `Action::FinalAnswer` if the string starts with `FINAL_ANSWER` (case-insensitive).
1136    /// Otherwise parses as a tool call via `parse_tool_call`.
1137    pub fn parse(s: &str) -> Result<Action, AgentRuntimeError> {
1138        if s.trim().to_ascii_uppercase().starts_with("FINAL_ANSWER") {
1139            let answer = s.trim()["FINAL_ANSWER".len()..].trim().to_owned();
1140            return Ok(Action::FinalAnswer(answer));
1141        }
1142        let (name, args) = parse_tool_call(s)?;
1143        Ok(Action::ToolCall { name, args })
1144    }
1145}
1146
1147/// Async hook called before each tool action. Return `true` to proceed, `false` to block.
1148///
1149/// When blocked, the loop inserts a synthetic observation
1150/// `{"ok": false, "error": "action blocked by reviewer", "kind": "blocked"}`
1151/// and continues to the next iteration without invoking the tool.
1152///
1153/// ## Observer interaction
1154///
1155/// When a hook **allows** an action (`true`), the normal observer sequence fires:
1156/// 1. `Observer::on_tool_call` — called before the tool is dispatched
1157/// 2. `Observer::on_step` — called after the observation is recorded
1158///
1159/// When a hook **blocks** an action (`false`), the sequence is:
1160/// 1. `Observer::on_action_blocked` — called instead of `on_tool_call`
1161/// 2. `Observer::on_step` — called after the synthetic blocked observation is recorded
1162///
1163/// Use [`make_action_hook`] to construct a hook from a plain `async fn` without
1164/// writing the `Arc<dyn Fn…>` boilerplate by hand.
1165pub type ActionHook = Arc<dyn Fn(String, serde_json::Value) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>> + Send + Sync>;
1166
1167/// Create an [`ActionHook`] from a plain `async fn` or closure.
1168///
1169/// This helper eliminates the need to manually write
1170/// `Arc::new(|name, args| Box::pin(async move { … }))`.
1171///
1172/// # Example
1173/// ```no_run
1174/// use llm_agent_runtime::agent::make_action_hook;
1175///
1176/// let hook = make_action_hook(|tool_name: String, _args| async move {
1177///     // Block any tool called "dangerous"
1178///     tool_name != "dangerous"
1179/// });
1180/// ```
1181pub fn make_action_hook<F, Fut>(f: F) -> ActionHook
1182where
1183    F: Fn(String, serde_json::Value) -> Fut + Send + Sync + 'static,
1184    Fut: std::future::Future<Output = bool> + Send + 'static,
1185{
1186    Arc::new(move |name, args| Box::pin(f(name, args)))
1187}
1188
1189// ── Tests ─────────────────────────────────────────────────────────────────────
1190
1191#[cfg(test)]
1192mod tests {
1193    use super::*;
1194
1195    #[tokio::test]
1196    async fn test_final_answer_on_first_step() {
1197        let config = AgentConfig::new(5, "test-model");
1198        let loop_ = ReActLoop::new(config);
1199
1200        let steps = loop_
1201            .run("Say hello", |_ctx| async {
1202                "Thought: I will answer directly\nAction: FINAL_ANSWER hello".to_string()
1203            })
1204            .await
1205            .unwrap();
1206
1207        assert_eq!(steps.len(), 1);
1208        assert!(steps[0]
1209            .action
1210            .to_ascii_uppercase()
1211            .starts_with("FINAL_ANSWER"));
1212    }
1213
1214    #[tokio::test]
1215    async fn test_tool_call_then_final_answer() {
1216        let config = AgentConfig::new(5, "test-model");
1217        let mut loop_ = ReActLoop::new(config);
1218
1219        loop_.register_tool(ToolSpec::new("greet", "Greets someone", |_args| {
1220            serde_json::json!("hello!")
1221        }));
1222
1223        let mut call_count = 0;
1224        let steps = loop_
1225            .run("Say hello", |_ctx| {
1226                call_count += 1;
1227                let count = call_count;
1228                async move {
1229                    if count == 1 {
1230                        "Thought: I will greet\nAction: greet {}".to_string()
1231                    } else {
1232                        "Thought: done\nAction: FINAL_ANSWER done".to_string()
1233                    }
1234                }
1235            })
1236            .await
1237            .unwrap();
1238
1239        assert_eq!(steps.len(), 2);
1240        assert_eq!(steps[0].action, "greet {}");
1241        assert!(steps[1]
1242            .action
1243            .to_ascii_uppercase()
1244            .starts_with("FINAL_ANSWER"));
1245    }
1246
1247    #[tokio::test]
1248    async fn test_max_iterations_exceeded() {
1249        let config = AgentConfig::new(2, "test-model");
1250        let loop_ = ReActLoop::new(config);
1251
1252        let result = loop_
1253            .run("loop forever", |_ctx| async {
1254                "Thought: thinking\nAction: noop {}".to_string()
1255            })
1256            .await;
1257
1258        assert!(result.is_err());
1259        let err = result.unwrap_err().to_string();
1260        assert!(err.contains("max iterations"));
1261    }
1262
1263    #[tokio::test]
1264    async fn test_parse_react_step_valid() {
1265        let text = "Thought: I should check\nAction: lookup {\"key\":\"val\"}";
1266        let step = parse_react_step(text).unwrap();
1267        assert_eq!(step.thought, "I should check");
1268        assert_eq!(step.action, "lookup {\"key\":\"val\"}");
1269    }
1270
1271    #[tokio::test]
1272    async fn test_parse_react_step_empty_fails() {
1273        let result = parse_react_step("no prefix lines here");
1274        assert!(result.is_err());
1275    }
1276
1277    #[tokio::test]
1278    async fn test_tool_not_found_returns_error_observation() {
1279        let config = AgentConfig::new(3, "test-model");
1280        let loop_ = ReActLoop::new(config);
1281
1282        let mut call_count = 0;
1283        let steps = loop_
1284            .run("test", |_ctx| {
1285                call_count += 1;
1286                let count = call_count;
1287                async move {
1288                    if count == 1 {
1289                        "Thought: try missing tool\nAction: missing_tool {}".to_string()
1290                    } else {
1291                        "Thought: done\nAction: FINAL_ANSWER done".to_string()
1292                    }
1293                }
1294            })
1295            .await
1296            .unwrap();
1297
1298        assert_eq!(steps.len(), 2);
1299        assert!(steps[0].observation.contains("\"ok\":false"));
1300    }
1301
1302    #[tokio::test]
1303    async fn test_new_async_tool_spec() {
1304        let spec = ToolSpec::new_async("async_tool", "An async tool", |args| {
1305            Box::pin(async move { serde_json::json!({"echo": args}) })
1306        });
1307
1308        let result = spec.call(serde_json::json!({"input": "test"})).await;
1309        assert!(result.get("echo").is_some());
1310    }
1311
1312    // Item 1 — Robust ReAct Parser tests
1313
1314    #[tokio::test]
1315    async fn test_parse_react_step_case_insensitive() {
1316        let text = "THOUGHT: done\nACTION: FINAL_ANSWER";
1317        let step = parse_react_step(text).unwrap();
1318        assert_eq!(step.thought, "done");
1319        assert_eq!(step.action, "FINAL_ANSWER");
1320    }
1321
1322    #[tokio::test]
1323    async fn test_parse_react_step_space_before_colon() {
1324        let text = "Thought : done\nAction : go";
1325        let step = parse_react_step(text).unwrap();
1326        assert_eq!(step.thought, "done");
1327        assert_eq!(step.action, "go");
1328    }
1329
1330    // Item 3 — Tool required field validation tests
1331
1332    #[tokio::test]
1333    async fn test_tool_required_fields_missing_returns_error() {
1334        let config = AgentConfig::new(3, "test-model");
1335        let mut loop_ = ReActLoop::new(config);
1336
1337        loop_.register_tool(
1338            ToolSpec::new(
1339                "search",
1340                "Searches for something",
1341                |args| serde_json::json!({ "result": args }),
1342            )
1343            .with_required_fields(vec!["q".to_string()]),
1344        );
1345
1346        let mut call_count = 0;
1347        let steps = loop_
1348            .run("test", |_ctx| {
1349                call_count += 1;
1350                let count = call_count;
1351                async move {
1352                    if count == 1 {
1353                        // Call with empty object — missing "q"
1354                        "Thought: searching\nAction: search {}".to_string()
1355                    } else {
1356                        "Thought: done\nAction: FINAL_ANSWER done".to_string()
1357                    }
1358                }
1359            })
1360            .await
1361            .unwrap();
1362
1363        assert_eq!(steps.len(), 2);
1364        assert!(
1365            steps[0].observation.contains("missing required field"),
1366            "observation was: {}",
1367            steps[0].observation
1368        );
1369    }
1370
1371    // Item 9 — Structured error observation tests
1372
1373    #[tokio::test]
1374    async fn test_tool_error_observation_includes_kind() {
1375        let config = AgentConfig::new(3, "test-model");
1376        let loop_ = ReActLoop::new(config);
1377
1378        let mut call_count = 0;
1379        let steps = loop_
1380            .run("test", |_ctx| {
1381                call_count += 1;
1382                let count = call_count;
1383                async move {
1384                    if count == 1 {
1385                        "Thought: try missing\nAction: nonexistent_tool {}".to_string()
1386                    } else {
1387                        "Thought: done\nAction: FINAL_ANSWER done".to_string()
1388                    }
1389                }
1390            })
1391            .await
1392            .unwrap();
1393
1394        assert_eq!(steps.len(), 2);
1395        let obs = &steps[0].observation;
1396        assert!(obs.contains("\"ok\":false"), "observation: {obs}");
1397        assert!(obs.contains("\"kind\":\"not_found\""), "observation: {obs}");
1398    }
1399
1400    // ── step_duration_ms ──────────────────────────────────────────────────────
1401
1402    #[tokio::test]
1403    async fn test_step_duration_ms_is_set() {
1404        let config = AgentConfig::new(5, "test-model");
1405        let loop_ = ReActLoop::new(config);
1406
1407        let steps = loop_
1408            .run("time it", |_ctx| async {
1409                "Thought: done\nAction: FINAL_ANSWER ok".to_string()
1410            })
1411            .await
1412            .unwrap();
1413
1414        // step_duration_ms may be 0 on very fast systems but must be a valid u64.
1415        let _ = steps[0].step_duration_ms; // just verify the field exists and is accessible
1416    }
1417
1418    // ── ToolValidator ─────────────────────────────────────────────────────────
1419
1420    struct RequirePositiveN;
1421    impl ToolValidator for RequirePositiveN {
1422        fn validate(&self, args: &Value) -> Result<(), AgentRuntimeError> {
1423            let n = args.get("n").and_then(|v| v.as_i64()).unwrap_or(0);
1424            if n <= 0 {
1425                return Err(AgentRuntimeError::AgentLoop(
1426                    "n must be a positive integer".into(),
1427                ));
1428            }
1429            Ok(())
1430        }
1431    }
1432
1433    #[tokio::test]
1434    async fn test_tool_validator_blocks_invalid_args() {
1435        let mut registry = ToolRegistry::new();
1436        registry.register(
1437            ToolSpec::new("calc", "compute", |args| serde_json::json!({"n": args}))
1438                .with_validators(vec![Box::new(RequirePositiveN)]),
1439        );
1440
1441        // n = -1 should be rejected by the validator.
1442        let result = registry
1443            .call("calc", serde_json::json!({"n": -1}))
1444            .await;
1445        assert!(result.is_err(), "validator should reject n=-1");
1446        assert!(result.unwrap_err().to_string().contains("positive integer"));
1447    }
1448
1449    #[tokio::test]
1450    async fn test_tool_validator_passes_valid_args() {
1451        let mut registry = ToolRegistry::new();
1452        registry.register(
1453            ToolSpec::new("calc", "compute", |_| serde_json::json!(42))
1454                .with_validators(vec![Box::new(RequirePositiveN)]),
1455        );
1456
1457        let result = registry
1458            .call("calc", serde_json::json!({"n": 5}))
1459            .await;
1460        assert!(result.is_ok(), "validator should accept n=5");
1461    }
1462
1463    // ── Empty tool name ───────────────────────────────────────────────────────
1464
1465    #[tokio::test]
1466    async fn test_empty_tool_name_is_rejected() {
1467        // parse_tool_call("") → error because name is empty
1468        let result = parse_tool_call("");
1469        assert!(result.is_err());
1470        assert!(
1471            result.unwrap_err().to_string().contains("empty tool name"),
1472            "expected 'empty tool name' error"
1473        );
1474    }
1475
1476    // ── Bulk register_tools ───────────────────────────────────────────────────
1477
1478    #[tokio::test]
1479    async fn test_register_tools_bulk() {
1480        let mut registry = ToolRegistry::new();
1481        registry.register_tools(vec![
1482            ToolSpec::new("tool_a", "A", |_| serde_json::json!("a")),
1483            ToolSpec::new("tool_b", "B", |_| serde_json::json!("b")),
1484        ]);
1485        assert!(registry.call("tool_a", serde_json::json!({})).await.is_ok());
1486        assert!(registry.call("tool_b", serde_json::json!({})).await.is_ok());
1487    }
1488
1489    // ── run_streaming parity ──────────────────────────────────────────────────
1490
1491    #[tokio::test]
1492    async fn test_run_streaming_parity_with_run() {
1493        use tokio::sync::mpsc;
1494
1495        let config = AgentConfig::new(5, "test-model");
1496        let loop_ = ReActLoop::new(config);
1497
1498        let steps = loop_
1499            .run_streaming("Say hello", |_ctx| async {
1500                let (tx, rx) = mpsc::channel(4);
1501                // Send the response in chunks
1502                tokio::spawn(async move {
1503                    tx.send(Ok("Thought: done\n".to_string())).await.ok();
1504                    tx.send(Ok("Action: FINAL_ANSWER hi".to_string())).await.ok();
1505                });
1506                rx
1507            })
1508            .await
1509            .unwrap();
1510
1511        assert_eq!(steps.len(), 1);
1512        assert!(steps[0]
1513            .action
1514            .to_ascii_uppercase()
1515            .starts_with("FINAL_ANSWER"));
1516    }
1517
1518    #[tokio::test]
1519    async fn test_run_streaming_error_chunk_is_skipped() {
1520        use tokio::sync::mpsc;
1521        use crate::error::AgentRuntimeError;
1522
1523        let config = AgentConfig::new(5, "test-model");
1524        let loop_ = ReActLoop::new(config);
1525
1526        // Even with an error chunk, the loop recovers and returns the valid parts.
1527        let steps = loop_
1528            .run_streaming("test", |_ctx| async {
1529                let (tx, rx) = mpsc::channel(4);
1530                tokio::spawn(async move {
1531                    tx.send(Err(AgentRuntimeError::Provider("stream error".into())))
1532                        .await
1533                        .ok();
1534                    tx.send(Ok("Thought: recovered\nAction: FINAL_ANSWER ok".to_string()))
1535                        .await
1536                        .ok();
1537                });
1538                rx
1539            })
1540            .await
1541            .unwrap();
1542
1543        assert_eq!(steps.len(), 1);
1544    }
1545
1546    // ── Circuit breaker test (only compiled when feature is active) ────────────
1547
1548    #[cfg(feature = "orchestrator")]
1549    #[tokio::test]
1550    async fn test_tool_with_circuit_breaker_passes_when_closed() {
1551        use std::sync::Arc;
1552
1553        let cb = Arc::new(
1554            crate::orchestrator::CircuitBreaker::new(
1555                "echo-tool",
1556                5,
1557                std::time::Duration::from_secs(30),
1558            )
1559            .unwrap(),
1560        );
1561
1562        let spec = ToolSpec::new(
1563            "echo",
1564            "Echoes args",
1565            |args| serde_json::json!({ "echoed": args }),
1566        )
1567        .with_circuit_breaker(cb);
1568
1569        let registry = {
1570            let mut r = ToolRegistry::new();
1571            r.register(spec);
1572            r
1573        };
1574
1575        let result = registry
1576            .call("echo", serde_json::json!({ "msg": "hi" }))
1577            .await;
1578        assert!(result.is_ok(), "expected Ok, got {:?}", result);
1579    }
1580
1581    // ── Improvement 1: AgentConfig builder methods ────────────────────────────
1582
1583    #[test]
1584    fn test_agent_config_builder_methods_set_fields() {
1585        let config = AgentConfig::new(3, "model")
1586            .with_temperature(0.7)
1587            .with_max_tokens(512)
1588            .with_request_timeout(std::time::Duration::from_secs(10));
1589        assert_eq!(config.temperature, Some(0.7));
1590        assert_eq!(config.max_tokens, Some(512));
1591        assert_eq!(config.request_timeout, Some(std::time::Duration::from_secs(10)));
1592    }
1593
1594    // ── Improvement 2: Fallible tool handlers ─────────────────────────────────
1595
1596    #[tokio::test]
1597    async fn test_fallible_tool_returns_error_json_on_err() {
1598        let spec = ToolSpec::new_fallible(
1599            "fail",
1600            "always fails",
1601            |_| Err::<Value, String>("something went wrong".to_string()),
1602        );
1603        let result = spec.call(serde_json::json!({})).await;
1604        assert_eq!(result["ok"], serde_json::json!(false));
1605        assert_eq!(result["error"], serde_json::json!("something went wrong"));
1606    }
1607
1608    #[tokio::test]
1609    async fn test_fallible_tool_returns_value_on_ok() {
1610        let spec = ToolSpec::new_fallible(
1611            "succeed",
1612            "always succeeds",
1613            |_| Ok::<Value, String>(serde_json::json!(42)),
1614        );
1615        let result = spec.call(serde_json::json!({})).await;
1616        assert_eq!(result, serde_json::json!(42));
1617    }
1618
1619    // ── Improvement 4: Did you mean ───────────────────────────────────────────
1620
1621    #[tokio::test]
1622    async fn test_did_you_mean_suggestion_for_typo() {
1623        let mut registry = ToolRegistry::new();
1624        registry.register(ToolSpec::new("search", "search", |_| serde_json::json!("ok")));
1625        let result = registry.call("searc", serde_json::json!({})).await;
1626        assert!(result.is_err());
1627        let msg = result.unwrap_err().to_string();
1628        assert!(msg.contains("did you mean"), "expected suggestion in: {msg}");
1629    }
1630
1631    #[tokio::test]
1632    async fn test_no_suggestion_for_very_different_name() {
1633        let mut registry = ToolRegistry::new();
1634        registry.register(ToolSpec::new("search", "search", |_| serde_json::json!("ok")));
1635        let result = registry.call("xxxxxxxxxxxxxxx", serde_json::json!({})).await;
1636        assert!(result.is_err());
1637        let msg = result.unwrap_err().to_string();
1638        assert!(!msg.contains("did you mean"), "unexpected suggestion in: {msg}");
1639    }
1640
1641    // ── Improvement 11: Action enum ───────────────────────────────────────────
1642
1643    #[test]
1644    fn test_action_parse_final_answer() {
1645        let action = Action::parse("FINAL_ANSWER hello world").unwrap();
1646        assert_eq!(action, Action::FinalAnswer("hello world".to_string()));
1647    }
1648
1649    #[test]
1650    fn test_action_parse_tool_call() {
1651        let action = Action::parse("search {\"q\": \"rust\"}").unwrap();
1652        match action {
1653            Action::ToolCall { name, args } => {
1654                assert_eq!(name, "search");
1655                assert_eq!(args["q"], "rust");
1656            }
1657            _ => panic!("expected ToolCall"),
1658        }
1659    }
1660
1661    #[test]
1662    fn test_action_parse_invalid_returns_err() {
1663        let result = Action::parse("");
1664        assert!(result.is_err());
1665    }
1666
1667    // ── Improvement 13: Observer ──────────────────────────────────────────────
1668
1669    #[tokio::test]
1670    async fn test_observer_on_step_called_for_each_step() {
1671        use std::sync::{Arc, Mutex};
1672
1673        struct CountingObserver {
1674            step_count: Mutex<usize>,
1675        }
1676        impl Observer for CountingObserver {
1677            fn on_step(&self, _step_index: usize, _step: &ReActStep) {
1678                let mut c = self.step_count.lock().unwrap_or_else(|e| e.into_inner());
1679                *c += 1;
1680            }
1681        }
1682
1683        let obs = Arc::new(CountingObserver { step_count: Mutex::new(0) });
1684        let config = AgentConfig::new(5, "test-model");
1685        let mut loop_ = ReActLoop::new(config).with_observer(obs.clone() as Arc<dyn Observer>);
1686        loop_.register_tool(ToolSpec::new("noop", "noop", |_| serde_json::json!("ok")));
1687
1688        let mut call_count = 0;
1689        let _steps = loop_.run("test", |_ctx| {
1690            call_count += 1;
1691            let count = call_count;
1692            async move {
1693                if count == 1 {
1694                    "Thought: call noop\nAction: noop {}".to_string()
1695                } else {
1696                    "Thought: done\nAction: FINAL_ANSWER done".to_string()
1697                }
1698            }
1699        }).await.unwrap();
1700
1701        let count = *obs.step_count.lock().unwrap_or_else(|e| e.into_inner());
1702        assert_eq!(count, 2, "observer should have seen 2 steps");
1703    }
1704
1705    // ── Improvement 14: ToolCache ─────────────────────────────────────────────
1706
1707    #[tokio::test]
1708    async fn test_tool_cache_returns_cached_result_on_second_call() {
1709        use std::collections::HashMap;
1710        use std::sync::Mutex;
1711
1712        struct InMemCache {
1713            map: Mutex<HashMap<String, Value>>,
1714        }
1715        impl ToolCache for InMemCache {
1716            fn get(&self, tool_name: &str, args: &Value) -> Option<Value> {
1717                let key = format!("{tool_name}:{args}");
1718                let map = self.map.lock().unwrap_or_else(|e| e.into_inner());
1719                map.get(&key).cloned()
1720            }
1721            fn set(&self, tool_name: &str, args: &Value, result: Value) {
1722                let key = format!("{tool_name}:{args}");
1723                let mut map = self.map.lock().unwrap_or_else(|e| e.into_inner());
1724                map.insert(key, result);
1725            }
1726        }
1727
1728        let call_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
1729        let call_count_clone = call_count.clone();
1730
1731        let cache = Arc::new(InMemCache { map: Mutex::new(HashMap::new()) });
1732        let registry = ToolRegistry::new()
1733            .with_cache(cache as Arc<dyn ToolCache>);
1734        let mut registry = registry;
1735
1736        registry.register(ToolSpec::new("count", "count calls", move |_| {
1737            call_count_clone.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1738            serde_json::json!({"calls": 1})
1739        }));
1740
1741        let args = serde_json::json!({});
1742        let r1 = registry.call("count", args.clone()).await.unwrap();
1743        let r2 = registry.call("count", args.clone()).await.unwrap();
1744
1745        assert_eq!(r1, r2);
1746        // The handler should only be called once; second call hits cache.
1747        assert_eq!(call_count.load(std::sync::atomic::Ordering::Relaxed), 1);
1748    }
1749
1750    // ── Task 12: Chained validator short-circuit ──────────────────────────────
1751
1752    #[tokio::test]
1753    async fn test_validators_short_circuit_on_first_failure() {
1754        use std::sync::atomic::{AtomicUsize, Ordering as AOrdering};
1755        use std::sync::Arc;
1756
1757        let second_called = Arc::new(AtomicUsize::new(0));
1758        let second_called_clone = Arc::clone(&second_called);
1759
1760        struct AlwaysFail;
1761        impl ToolValidator for AlwaysFail {
1762            fn validate(&self, _args: &Value) -> Result<(), AgentRuntimeError> {
1763                Err(AgentRuntimeError::AgentLoop("first validator failed".into()))
1764            }
1765        }
1766
1767        struct CountCalls(Arc<AtomicUsize>);
1768        impl ToolValidator for CountCalls {
1769            fn validate(&self, _args: &Value) -> Result<(), AgentRuntimeError> {
1770                self.0.fetch_add(1, AOrdering::SeqCst);
1771                Ok(())
1772            }
1773        }
1774
1775        let mut registry = ToolRegistry::new();
1776        registry.register(
1777            ToolSpec::new("guarded", "A guarded tool", |args| args.clone())
1778                .with_validators(vec![
1779                    Box::new(AlwaysFail),
1780                    Box::new(CountCalls(second_called_clone)),
1781                ]),
1782        );
1783
1784        let result = registry.call("guarded", serde_json::json!({})).await;
1785        assert!(result.is_err(), "should fail due to first validator");
1786        assert_eq!(
1787            second_called.load(AOrdering::SeqCst),
1788            0,
1789            "second validator must not be called when first fails"
1790        );
1791    }
1792
1793    // ── Task 14: loop_timeout integration test ────────────────────────────────
1794
1795    #[tokio::test]
1796    async fn test_loop_timeout_fires_between_iterations() {
1797        let mut config = AgentConfig::new(100, "test-model");
1798        // 30 ms deadline; each infer call sleeps 20 ms, so timeout fires after 2 iterations.
1799        config.loop_timeout = Some(std::time::Duration::from_millis(30));
1800        let loop_ = ReActLoop::new(config);
1801
1802        let result = loop_
1803            .run("test", |_ctx| async {
1804                // Sleep just long enough that the cumulative time exceeds the deadline.
1805                tokio::time::sleep(std::time::Duration::from_millis(20)).await;
1806                // Return a valid step that keeps the loop going (unknown tool → error observation → next iter).
1807                "Thought: still working\nAction: noop {}".to_string()
1808            })
1809            .await;
1810
1811        assert!(result.is_err(), "loop should time out");
1812        let msg = result.unwrap_err().to_string();
1813        assert!(msg.contains("loop timeout"), "unexpected error: {msg}");
1814    }
1815
    // ── Improvement 15: ActionHook ────────────────────────────────────────────
    // (covered by `test_action_hook_blocking_inserts_blocked_observation`, at the end of this module)
1817
1818    // ── #2 ReActStep::is_final_answer / is_tool_call ──────────────────────────
1819
1820    #[test]
1821    fn test_react_step_is_final_answer() {
1822        let step = ReActStep {
1823            thought: "".into(),
1824            action: "FINAL_ANSWER done".into(),
1825            observation: "".into(),
1826            step_duration_ms: 0,
1827        };
1828        assert!(step.is_final_answer());
1829        assert!(!step.is_tool_call());
1830    }
1831
1832    #[test]
1833    fn test_react_step_is_tool_call() {
1834        let step = ReActStep {
1835            thought: "".into(),
1836            action: "search {}".into(),
1837            observation: "".into(),
1838            step_duration_ms: 0,
1839        };
1840        assert!(!step.is_final_answer());
1841        assert!(step.is_tool_call());
1842    }
1843
1844    // ── #6 Role Display ───────────────────────────────────────────────────────
1845
1846    #[test]
1847    fn test_role_display() {
1848        assert_eq!(Role::System.to_string(), "system");
1849        assert_eq!(Role::User.to_string(), "user");
1850        assert_eq!(Role::Assistant.to_string(), "assistant");
1851        assert_eq!(Role::Tool.to_string(), "tool");
1852    }
1853
1854    // ── #12 Message accessors ─────────────────────────────────────────────────
1855
1856    #[test]
1857    fn test_message_accessors() {
1858        let msg = Message::new(Role::User, "hello");
1859        assert_eq!(msg.role(), &Role::User);
1860        assert_eq!(msg.content(), "hello");
1861    }
1862
1863    // ── #25 Action parse round-trips ──────────────────────────────────────────
1864
1865    #[test]
1866    fn test_action_parse_final_answer_round_trip() {
1867        let step = ReActStep {
1868            thought: "done".into(),
1869            action: "FINAL_ANSWER Paris".into(),
1870            observation: "".into(),
1871            step_duration_ms: 0,
1872        };
1873        assert!(step.is_final_answer());
1874        let action = Action::parse(&step.action).unwrap();
1875        assert!(matches!(action, Action::FinalAnswer(ref s) if s == "Paris"));
1876    }
1877
1878    #[test]
1879    fn test_action_parse_tool_call_round_trip() {
1880        let step = ReActStep {
1881            thought: "searching".into(),
1882            action: "search {\"q\":\"hello\"}".into(),
1883            observation: "".into(),
1884            step_duration_ms: 0,
1885        };
1886        assert!(step.is_tool_call());
1887        let action = Action::parse(&step.action).unwrap();
1888        assert!(matches!(action, Action::ToolCall { ref name, .. } if name == "search"));
1889    }
1890
1891    // ── #26 Observer step indices ─────────────────────────────────────────────
1892
1893    #[tokio::test]
1894    async fn test_observer_receives_correct_step_indices() {
1895        use std::sync::{Arc, Mutex};
1896
1897        struct IndexCollector(Arc<Mutex<Vec<usize>>>);
1898        impl Observer for IndexCollector {
1899            fn on_step(&self, step_index: usize, _step: &ReActStep) {
1900                self.0.lock().unwrap_or_else(|e| e.into_inner()).push(step_index);
1901            }
1902        }
1903
1904        let indices = Arc::new(Mutex::new(Vec::new()));
1905        let obs = Arc::new(IndexCollector(Arc::clone(&indices)));
1906
1907        let config = AgentConfig::new(5, "test");
1908        let mut loop_ = ReActLoop::new(config).with_observer(obs as Arc<dyn Observer>);
1909        loop_.register_tool(ToolSpec::new("noop", "no-op", |_| serde_json::json!({})));
1910
1911        let mut call_count = 0;
1912        loop_.run("test", |_ctx| {
1913            call_count += 1;
1914            let count = call_count;
1915            async move {
1916                if count == 1 {
1917                    "Thought: step1\nAction: noop {}".to_string()
1918                } else {
1919                    "Thought: done\nAction: FINAL_ANSWER ok".to_string()
1920                }
1921            }
1922        }).await.unwrap();
1923
1924        let collected = indices.lock().unwrap_or_else(|e| e.into_inner()).clone();
1925        assert_eq!(collected, vec![0, 1], "expected step indices 0 and 1");
1926    }
1927
1928    #[tokio::test]
1929    async fn test_action_hook_blocking_inserts_blocked_observation() {
1930        let hook: ActionHook = Arc::new(|_name, _args| {
1931            Box::pin(async move { false }) // always block
1932        });
1933
1934        let config = AgentConfig::new(5, "test-model");
1935        let mut loop_ = ReActLoop::new(config).with_action_hook(hook);
1936        loop_.register_tool(ToolSpec::new("noop", "noop", |_| serde_json::json!("ok")));
1937
1938        let mut call_count = 0;
1939        let steps = loop_.run("test", |_ctx| {
1940            call_count += 1;
1941            let count = call_count;
1942            async move {
1943                if count == 1 {
1944                    "Thought: try tool\nAction: noop {}".to_string()
1945                } else {
1946                    "Thought: done\nAction: FINAL_ANSWER done".to_string()
1947                }
1948            }
1949        }).await.unwrap();
1950
1951        assert!(steps[0].observation.contains("blocked"), "expected blocked observation, got: {}", steps[0].observation);
1952    }
1953}