Skip to main content

llm_agent_runtime/
agent.rs

1//! # Module: Agent
2//!
3//! ## Responsibility
4//! Provides a ReAct (Thought-Action-Observation) agent loop with pluggable tools.
5//! Mirrors the public API of `wasm-agent`.
6//!
7//! ## Guarantees
8//! - Deterministic: the loop terminates after at most `max_iterations` cycles
9//! - Non-panicking: all operations return `Result`
10//! - Tool handlers support both sync and async `Fn` closures
11//!
12//! ## NOT Responsible For
13//! - Actual LLM inference (callers supply a mock/stub inference fn)
14//! - WASM compilation or browser execution
15//! - Streaming partial responses
16
17use crate::error::AgentRuntimeError;
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use std::collections::HashMap;
21use std::future::Future;
22use std::pin::Pin;
23
24#[cfg(feature = "orchestrator")]
25use std::sync::Arc;
26
27// ── Types ─────────────────────────────────────────────────────────────────────
28
29/// A pinned, boxed future returning a `Value`. Used for async tool handlers.
30pub type AsyncToolFuture = Pin<Box<dyn Future<Output = Value> + Send>>;
31
32/// An async tool handler closure.
33pub type AsyncToolHandler = Box<dyn Fn(Value) -> AsyncToolFuture + Send + Sync>;
34
35/// Role of a message in a conversation.
36#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
37pub enum Role {
38    /// System-level instruction injected before the user turn.
39    System,
40    /// Message from the human user.
41    User,
42    /// Message produced by the language model.
43    Assistant,
44    /// Message produced by a tool invocation.
45    Tool,
46}
47
48/// A single message in the conversation history.
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct Message {
51    /// The role of the speaker who produced this message.
52    pub role: Role,
53    /// The textual content of the message.
54    pub content: String,
55}
56
57impl Message {
58    /// Create a new `Message` with the given role and content.
59    ///
60    /// # Panics
61    ///
62    /// This function does not panic.
63    pub fn new(role: Role, content: impl Into<String>) -> Self {
64        Self {
65            role,
66            content: content.into(),
67        }
68    }
69}
70
71/// A single ReAct step: Thought → Action → Observation.
72#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct ReActStep {
74    /// Agent's reasoning about the current state.
75    pub thought: String,
76    /// The action taken (tool name + JSON arguments, or "FINAL_ANSWER").
77    pub action: String,
78    /// The result of the action.
79    pub observation: String,
80}
81
/// Settings that govern a ReAct agent run.
#[derive(Debug, Clone)]
pub struct AgentConfig {
    /// Upper bound on Thought-Action-Observation cycles per run.
    pub max_iterations: usize,
    /// Identifier of the model for the inference backend.
    pub model: String,
    /// System prompt placed at the start of the conversation.
    pub system_prompt: String,
    /// Cap on how many episodic memories are injected into the prompt.
    /// A small value guards against silent token-budget overruns. Default: 3.
    pub max_memory_recalls: usize,
    /// Approximate token budget for injected memories, using a ~4 chars/token
    /// heuristic. `None` means unlimited.
    pub max_memory_tokens: Option<usize>,
}

impl AgentConfig {
    /// Builds a config with the given iteration cap and model, using the default
    /// system prompt, 3 memory recalls and no memory token budget.
    pub fn new(max_iterations: usize, model: impl Into<String>) -> Self {
        let model = model.into();
        let system_prompt = String::from("You are a helpful AI agent.");
        Self {
            max_iterations,
            model,
            system_prompt,
            max_memory_recalls: 3,
            max_memory_tokens: None,
        }
    }

    /// Returns the config with its system prompt replaced by `prompt`.
    pub fn with_system_prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            system_prompt: prompt.into(),
            ..self
        }
    }

    /// Returns the config with the per-run episodic memory cap set to `n`.
    pub fn with_max_memory_recalls(self, n: usize) -> Self {
        Self {
            max_memory_recalls: n,
            ..self
        }
    }

    /// Returns the config with the memory token budget set to `n`
    /// (~4 chars/token heuristic).
    pub fn with_max_memory_tokens(self, n: usize) -> Self {
        Self {
            max_memory_tokens: Some(n),
            ..self
        }
    }
}
129
130// ── ToolSpec ──────────────────────────────────────────────────────────────────
131
132/// Describes and implements a single callable tool.
133pub struct ToolSpec {
134    /// Short identifier used in action strings (e.g. "search").
135    pub name: String,
136    /// Human-readable description passed to the model as part of the system prompt.
137    pub description: String,
138    /// Async handler: receives JSON arguments, returns a future resolving to a JSON result.
139    pub handler: AsyncToolHandler,
140    /// Field names that must be present in the JSON args object.
141    /// Empty means no validation is performed.
142    pub required_fields: Vec<String>,
143    /// Optional per-tool circuit breaker.
144    #[cfg(feature = "orchestrator")]
145    pub circuit_breaker: Option<Arc<crate::orchestrator::CircuitBreaker>>,
146}
147
148impl std::fmt::Debug for ToolSpec {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        let mut s = f.debug_struct("ToolSpec");
151        s.field("name", &self.name)
152            .field("description", &self.description)
153            .field("required_fields", &self.required_fields);
154        #[cfg(feature = "orchestrator")]
155        s.field(
156            "has_circuit_breaker",
157            &self.circuit_breaker.is_some(),
158        );
159        s.finish()
160    }
161}
162
163impl ToolSpec {
164    /// Construct a new `ToolSpec` from a synchronous handler closure.
165    /// The closure is wrapped in an `async move` block automatically.
166    pub fn new(
167        name: impl Into<String>,
168        description: impl Into<String>,
169        handler: impl Fn(Value) -> Value + Send + Sync + 'static,
170    ) -> Self {
171        Self {
172            name: name.into(),
173            description: description.into(),
174            handler: Box::new(move |args| {
175                let result = handler(args);
176                Box::pin(async move { result })
177            }),
178            required_fields: Vec::new(),
179            #[cfg(feature = "orchestrator")]
180            circuit_breaker: None,
181        }
182    }
183
184    /// Construct a new `ToolSpec` from an async handler closure.
185    pub fn new_async(
186        name: impl Into<String>,
187        description: impl Into<String>,
188        handler: impl Fn(Value) -> AsyncToolFuture + Send + Sync + 'static,
189    ) -> Self {
190        Self {
191            name: name.into(),
192            description: description.into(),
193            handler: Box::new(handler),
194            required_fields: Vec::new(),
195            #[cfg(feature = "orchestrator")]
196            circuit_breaker: None,
197        }
198    }
199
200    /// Set the required fields that must be present in the JSON args object.
201    pub fn with_required_fields(mut self, fields: Vec<String>) -> Self {
202        self.required_fields = fields;
203        self
204    }
205
206    /// Attach a circuit breaker to this tool spec.
207    #[cfg(feature = "orchestrator")]
208    pub fn with_circuit_breaker(mut self, cb: Arc<crate::orchestrator::CircuitBreaker>) -> Self {
209        self.circuit_breaker = Some(cb);
210        self
211    }
212
213    /// Invoke the tool with the given JSON arguments.
214    pub async fn call(&self, args: Value) -> Value {
215        (self.handler)(args).await
216    }
217}
218
219// ── ToolRegistry ──────────────────────────────────────────────────────────────
220
221/// Registry of available tools for the agent loop.
222#[derive(Debug, Default)]
223pub struct ToolRegistry {
224    tools: HashMap<String, ToolSpec>,
225}
226
227impl ToolRegistry {
228    /// Create a new empty registry.
229    pub fn new() -> Self {
230        Self {
231            tools: HashMap::new(),
232        }
233    }
234
235    /// Register a tool. Overwrites any existing tool with the same name.
236    pub fn register(&mut self, spec: ToolSpec) {
237        self.tools.insert(spec.name.clone(), spec);
238    }
239
240    /// Call a tool by name.
241    ///
242    /// # Returns
243    /// - `Ok(Value)` — tool result
244    /// - `Err(AgentRuntimeError::AgentLoop)` — if the tool is not found or required fields
245    ///   are missing
246    /// - `Err(AgentRuntimeError::CircuitOpen)` — if the tool's circuit breaker is open
247    #[tracing::instrument(skip_all, fields(tool_name = %name))]
248    pub async fn call(&self, name: &str, args: Value) -> Result<Value, AgentRuntimeError> {
249        let spec = self
250            .tools
251            .get(name)
252            .ok_or_else(|| AgentRuntimeError::AgentLoop(format!("tool '{name}' not found")))?;
253
254        // Item 3 — required field validation
255        if !spec.required_fields.is_empty() {
256            if let Some(obj) = args.as_object() {
257                for field in &spec.required_fields {
258                    if !obj.contains_key(field) {
259                        return Err(AgentRuntimeError::AgentLoop(format!(
260                            "tool '{}' missing required field '{}'",
261                            name, field
262                        )));
263                    }
264                }
265            } else {
266                return Err(AgentRuntimeError::AgentLoop(format!(
267                    "tool '{}' requires JSON object args, got {}",
268                    name, args
269                )));
270            }
271        }
272
273        // Item 7 — per-tool circuit breaker check
274        #[cfg(feature = "orchestrator")]
275        if let Some(ref cb) = spec.circuit_breaker {
276            use crate::orchestrator::CircuitState;
277            if let Ok(CircuitState::Open { .. }) = cb.state() {
278                return Err(AgentRuntimeError::CircuitOpen {
279                    service: format!("tool:{}", name),
280                });
281            }
282        }
283
284        let result = spec.call(args).await;
285        Ok(result)
286    }
287
288    /// Return the list of registered tool names.
289    pub fn tool_names(&self) -> Vec<&str> {
290        self.tools.keys().map(|s| s.as_str()).collect()
291    }
292}
293
294// ── ReActLoop ─────────────────────────────────────────────────────────────────
295
296/// Parses a ReAct response string into a `ReActStep`.
297///
298/// Case-insensitive; tolerates spaces around the colon.
299/// e.g. `Thought :`, `thought:`, `THOUGHT :` all match.
300pub fn parse_react_step(text: &str) -> Result<ReActStep, AgentRuntimeError> {
301    let mut thought = String::new();
302    let mut action = String::new();
303
304    for line in text.lines() {
305        let trimmed = line.trim();
306        let lower = trimmed.to_ascii_lowercase();
307        if lower.starts_with("thought") {
308            if let Some(colon_pos) = trimmed.find(':') {
309                thought = trimmed[colon_pos + 1..].trim().to_owned();
310            }
311        } else if lower.starts_with("action") {
312            if let Some(colon_pos) = trimmed.find(':') {
313                action = trimmed[colon_pos + 1..].trim().to_owned();
314            }
315        }
316    }
317
318    if thought.is_empty() && action.is_empty() {
319        return Err(AgentRuntimeError::AgentLoop(
320            "could not parse ReAct step from response".into(),
321        ));
322    }
323
324    Ok(ReActStep {
325        thought,
326        action,
327        observation: String::new(),
328    })
329}
330
331/// The ReAct agent loop.
332#[derive(Debug)]
333pub struct ReActLoop {
334    config: AgentConfig,
335    registry: ToolRegistry,
336}
337
338impl ReActLoop {
339    /// Create a new `ReActLoop` with the given configuration and an empty tool registry.
340    pub fn new(config: AgentConfig) -> Self {
341        Self {
342            config,
343            registry: ToolRegistry::new(),
344        }
345    }
346
347    /// Register a tool that the agent loop can invoke.
348    pub fn register_tool(&mut self, spec: ToolSpec) {
349        self.registry.register(spec);
350    }
351
352    /// Execute the ReAct loop for the given prompt.
353    ///
354    /// # Arguments
355    /// * `prompt` — user input passed as the initial context
356    /// * `infer`  — async inference function: receives context string, returns response string
357    ///
358    /// # Returns
359    /// - `Ok(Vec<ReActStep>)` — steps executed, ending with a `FINAL_ANSWER` step
360    /// - `Err(AgentRuntimeError::AgentLoop)` — if max iterations reached without `FINAL_ANSWER`
361    ///   or if a ReAct response cannot be parsed
362    #[tracing::instrument(skip(infer))]
363    pub async fn run<F, Fut>(
364        &self,
365        prompt: &str,
366        mut infer: F,
367    ) -> Result<Vec<ReActStep>, AgentRuntimeError>
368    where
369        F: FnMut(String) -> Fut,
370        Fut: Future<Output = String>,
371    {
372        let mut steps: Vec<ReActStep> = Vec::new();
373        let mut context = format!("{}\n\nUser: {}\n", self.config.system_prompt, prompt);
374
375        for iteration in 0..self.config.max_iterations {
376            let response = infer(context.clone()).await;
377            let mut step = parse_react_step(&response)?;
378
379            tracing::debug!(
380                step = iteration,
381                thought = %step.thought,
382                action = %step.action,
383                "ReAct iteration"
384            );
385
386            if step.action.to_ascii_uppercase().starts_with("FINAL_ANSWER") {
387                step.observation = step.action.clone();
388                steps.push(step);
389                tracing::info!(step = iteration, "FINAL_ANSWER reached");
390                return Ok(steps);
391            }
392
393            let (tool_name, args) = parse_tool_call(&step.action);
394
395            tracing::debug!(
396                step = iteration,
397                tool_name = %tool_name,
398                "dispatching tool call"
399            );
400
401            // Item 9 — structured error categorization in observation
402            let observation = match self.registry.call(&tool_name, args).await {
403                Ok(result) => serde_json::json!({ "ok": true, "data": result }).to_string(),
404                Err(e) => {
405                    let kind = match &e {
406                        AgentRuntimeError::AgentLoop(msg) if msg.contains("not found") => {
407                            "not_found"
408                        }
409                        #[cfg(feature = "orchestrator")]
410                        AgentRuntimeError::CircuitOpen { .. } => "transient",
411                        _ => "permanent",
412                    };
413                    serde_json::json!({ "ok": false, "error": e.to_string(), "kind": kind })
414                        .to_string()
415                }
416            };
417
418            step.observation = observation.clone();
419            context.push_str(&format!(
420                "\nThought: {}\nAction: {}\nObservation: {}\n",
421                step.thought, step.action, observation
422            ));
423            steps.push(step);
424        }
425
426        let err = AgentRuntimeError::AgentLoop(format!(
427            "max iterations ({}) reached without final answer",
428            self.config.max_iterations
429        ));
430        tracing::warn!(
431            max_iterations = self.config.max_iterations,
432            "ReAct loop exhausted max iterations without FINAL_ANSWER"
433        );
434        Err(err)
435    }
436}
437
438/// Split `"tool_name {json}"` into `(tool_name, Value)`.
439fn parse_tool_call(action: &str) -> (String, Value) {
440    let mut parts = action.splitn(2, ' ');
441    let name = parts.next().unwrap_or("").to_owned();
442    let args_str = parts.next().unwrap_or("{}");
443    let args: Value = serde_json::from_str(args_str).unwrap_or(Value::String(args_str.to_owned()));
444    (name, args)
445}
446
447/// Agent-specific errors, mirrors `wasm-agent::AgentError`.
448///
449/// Converts to `AgentRuntimeError::AgentLoop` via the `From` implementation.
450#[derive(Debug, thiserror::Error)]
451pub enum AgentError {
452    /// The referenced tool name does not exist in the registry.
453    #[error("Tool '{0}' not found")]
454    ToolNotFound(String),
455    /// The ReAct loop consumed all iterations without emitting `FINAL_ANSWER`.
456    #[error("Max iterations exceeded: {0}")]
457    MaxIterations(usize),
458    /// The model response could not be parsed into a `ReActStep`.
459    #[error("Parse error: {0}")]
460    ParseError(String),
461}
462
463impl From<AgentError> for AgentRuntimeError {
464    fn from(e: AgentError) -> Self {
465        AgentRuntimeError::AgentLoop(e.to_string())
466    }
467}
468
469// ── Tests ─────────────────────────────────────────────────────────────────────
470
#[cfg(test)]
mod tests {
    use super::*;

    /// Canned model reply that ends a run.
    const FINAL_REPLY: &str = "Thought: done\nAction: FINAL_ANSWER done";

    /// Inference stub: answers `first_reply` on the first call and
    /// `FINAL_REPLY` on every later call.
    fn reply_then_finish(
        first_reply: &'static str,
    ) -> impl FnMut(String) -> std::future::Ready<String> {
        let mut calls = 0_usize;
        move |_ctx: String| {
            calls += 1;
            let reply = if calls == 1 { first_reply } else { FINAL_REPLY };
            std::future::ready(reply.to_string())
        }
    }

    /// True when the step's action is a final answer (case-insensitive).
    fn is_final(step: &ReActStep) -> bool {
        step.action.to_ascii_uppercase().starts_with("FINAL_ANSWER")
    }

    #[tokio::test]
    async fn test_final_answer_on_first_step() {
        let agent = ReActLoop::new(AgentConfig::new(5, "test-model"));
        let reply = "Thought: I will answer directly\nAction: FINAL_ANSWER hello";

        let steps = agent
            .run("Say hello", |_ctx| std::future::ready(reply.to_string()))
            .await
            .unwrap();

        assert_eq!(steps.len(), 1);
        assert!(is_final(&steps[0]));
    }

    #[tokio::test]
    async fn test_tool_call_then_final_answer() {
        let mut agent = ReActLoop::new(AgentConfig::new(5, "test-model"));
        let greet = ToolSpec::new("greet", "Greets someone", |_args| {
            serde_json::json!("hello!")
        });
        agent.register_tool(greet);

        let steps = agent
            .run(
                "Say hello",
                reply_then_finish("Thought: I will greet\nAction: greet {}"),
            )
            .await
            .unwrap();

        assert_eq!(steps.len(), 2);
        assert_eq!(steps[0].action, "greet {}");
        assert!(is_final(&steps[1]));
    }

    #[tokio::test]
    async fn test_max_iterations_exceeded() {
        let agent = ReActLoop::new(AgentConfig::new(2, "test-model"));
        let reply = "Thought: thinking\nAction: noop {}";

        let outcome = agent
            .run("loop forever", |_ctx| std::future::ready(reply.to_string()))
            .await;

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("max iterations"));
    }

    #[test]
    fn test_parse_react_step_valid() {
        let parsed =
            parse_react_step("Thought: I should check\nAction: lookup {\"key\":\"val\"}").unwrap();
        assert_eq!(parsed.thought, "I should check");
        assert_eq!(parsed.action, "lookup {\"key\":\"val\"}");
    }

    #[test]
    fn test_parse_react_step_empty_fails() {
        assert!(parse_react_step("no prefix lines here").is_err());
    }

    #[tokio::test]
    async fn test_tool_not_found_returns_error_observation() {
        let agent = ReActLoop::new(AgentConfig::new(3, "test-model"));

        let steps = agent
            .run(
                "test",
                reply_then_finish("Thought: try missing tool\nAction: missing_tool {}"),
            )
            .await
            .unwrap();

        assert_eq!(steps.len(), 2);
        assert!(steps[0].observation.contains("\"ok\":false"));
    }

    #[tokio::test]
    async fn test_new_async_tool_spec() {
        let spec = ToolSpec::new_async("async_tool", "An async tool", |args| {
            Box::pin(async move { serde_json::json!({"echo": args}) })
        });

        let output = spec.call(serde_json::json!({"input": "test"})).await;
        assert!(output.get("echo").is_some());
    }

    // Parser robustness: label case and spacing around the colon.

    #[test]
    fn test_parse_react_step_case_insensitive() {
        let parsed = parse_react_step("THOUGHT: done\nACTION: FINAL_ANSWER").unwrap();
        assert_eq!(parsed.thought, "done");
        assert_eq!(parsed.action, "FINAL_ANSWER");
    }

    #[test]
    fn test_parse_react_step_space_before_colon() {
        let parsed = parse_react_step("Thought : done\nAction : go").unwrap();
        assert_eq!(parsed.thought, "done");
        assert_eq!(parsed.action, "go");
    }

    // Required-field validation surfaces as an error observation.

    #[tokio::test]
    async fn test_tool_required_fields_missing_returns_error() {
        let mut agent = ReActLoop::new(AgentConfig::new(3, "test-model"));
        let search = ToolSpec::new("search", "Searches for something", |args| {
            serde_json::json!({ "result": args })
        })
        .with_required_fields(vec!["q".to_string()]);
        agent.register_tool(search);

        // The first reply calls `search` with an empty object, i.e. without "q".
        let steps = agent
            .run(
                "test",
                reply_then_finish("Thought: searching\nAction: search {}"),
            )
            .await
            .unwrap();

        assert_eq!(steps.len(), 2);
        assert!(
            steps[0].observation.contains("missing required field"),
            "observation was: {}",
            steps[0].observation
        );
    }

    // Error observations carry a machine-readable `kind`.

    #[tokio::test]
    async fn test_tool_error_observation_includes_kind() {
        let agent = ReActLoop::new(AgentConfig::new(3, "test-model"));

        let steps = agent
            .run(
                "test",
                reply_then_finish("Thought: try missing\nAction: nonexistent_tool {}"),
            )
            .await
            .unwrap();

        assert_eq!(steps.len(), 2);
        let obs = &steps[0].observation;
        assert!(obs.contains("\"ok\":false"), "observation: {obs}");
        assert!(obs.contains("\"kind\":\"not_found\""), "observation: {obs}");
    }

    // Circuit breaker: only compiled when the `orchestrator` feature is active.

    #[cfg(feature = "orchestrator")]
    #[tokio::test]
    async fn test_tool_with_circuit_breaker_passes_when_closed() {
        use std::sync::Arc;

        let breaker = crate::orchestrator::CircuitBreaker::new(
            "echo-tool",
            5,
            std::time::Duration::from_secs(30),
        )
        .unwrap();

        let echo = ToolSpec::new("echo", "Echoes args", |args| {
            serde_json::json!({ "echoed": args })
        })
        .with_circuit_breaker(Arc::new(breaker));

        let mut registry = ToolRegistry::new();
        registry.register(echo);

        let result = registry
            .call("echo", serde_json::json!({ "msg": "hi" }))
            .await;
        assert!(result.is_ok(), "expected Ok, got {:?}", result);
    }
}
701}