Skip to main content

llm_agent_runtime/
agent.rs

//! # Module: Agent
//!
//! ## Responsibility
//! Provides a ReAct (Thought-Action-Observation) agent loop with pluggable tools.
//! Mirrors the public API of `wasm-agent`.
//!
//! ## Guarantees
//! - Bounded: the loop always terminates after at most `max_iterations` cycles
//! - Non-panicking: fallible operations return `Result` instead of panicking
//! - Tool handlers support both sync and async `Fn` closures
//!
//! ## NOT Responsible For
//! - Actual LLM inference (callers supply the inference fn, e.g. a real client or a mock/stub)
//! - WASM compilation or browser execution
//! - Streaming partial responses

17use crate::error::AgentRuntimeError;
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use std::collections::HashMap;
21use std::future::Future;
22use std::pin::Pin;
23
24#[cfg(feature = "orchestrator")]
25use std::sync::Arc;
26
27// ── Types ─────────────────────────────────────────────────────────────────────
28
29/// A pinned, boxed future returning a `Value`. Used for async tool handlers.
30pub type AsyncToolFuture = Pin<Box<dyn Future<Output = Value> + Send>>;
31
32/// An async tool handler closure.
33pub type AsyncToolHandler = Box<dyn Fn(Value) -> AsyncToolFuture + Send + Sync>;
34
35/// Role of a message in a conversation.
36#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
37pub enum Role {
38    /// System-level instruction injected before the user turn.
39    System,
40    /// Message from the human user.
41    User,
42    /// Message produced by the language model.
43    Assistant,
44    /// Message produced by a tool invocation.
45    Tool,
46}
47
48/// A single message in the conversation history.
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct Message {
51    /// The role of the speaker who produced this message.
52    pub role: Role,
53    /// The textual content of the message.
54    pub content: String,
55}
56
57impl Message {
58    /// Create a new `Message` with the given role and content.
59    ///
60    /// # Panics
61    ///
62    /// This function does not panic.
63    pub fn new(role: Role, content: impl Into<String>) -> Self {
64        Self {
65            role,
66            content: content.into(),
67        }
68    }
69}
70
71/// A single ReAct step: Thought → Action → Observation.
72#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct ReActStep {
74    /// Agent's reasoning about the current state.
75    pub thought: String,
76    /// The action taken (tool name + JSON arguments, or "FINAL_ANSWER").
77    pub action: String,
78    /// The result of the action.
79    pub observation: String,
80}
81
/// Configuration for the ReAct agent loop.
///
/// Create one with [`AgentConfig::new`], then refine it with the `with_*` builder
/// methods. Each builder consumes `self` and returns the updated config.
#[derive(Debug, Clone)]
pub struct AgentConfig {
    /// Maximum number of Thought-Action-Observation cycles before the loop gives up.
    pub max_iterations: usize,
    /// Model identifier for the inference backend.
    /// `ReActLoop::run` does not read it; callers pass it to their inference fn themselves.
    pub model: String,
    /// System prompt placed at the start of the conversation context.
    pub system_prompt: String,
    /// Maximum number of episodic memories to inject into the prompt.
    /// Keeping this small prevents silent token-budget overruns.  Default: 3.
    pub max_memory_recalls: usize,
    /// Maximum approximate token budget for injected memories.
    /// Uses ~4 chars/token heuristic. `None` means no limit.
    pub max_memory_tokens: Option<usize>,
}

impl AgentConfig {
    /// Create a config with the given iteration cap and model identifier.
    ///
    /// Defaults:
    /// - system prompt `"You are a helpful AI agent."`
    /// - 3 memory recalls
    /// - no memory token budget (`None`)
    #[must_use]
    pub fn new(max_iterations: usize, model: impl Into<String>) -> Self {
        Self {
            max_iterations,
            model: model.into(),
            system_prompt: "You are a helpful AI agent.".into(),
            max_memory_recalls: 3,
            max_memory_tokens: None,
        }
    }

    /// Override the system prompt.
    // `#[must_use]`: builders take `self` by value, so ignoring the result silently
    // discards both the original config and the change.
    #[must_use = "builder methods consume the config and return the updated one"]
    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = prompt.into();
        self
    }

    /// Set the maximum number of episodic memories injected per run.
    #[must_use = "builder methods consume the config and return the updated one"]
    pub fn with_max_memory_recalls(mut self, n: usize) -> Self {
        self.max_memory_recalls = n;
        self
    }

    /// Set a maximum token budget for injected memories (~4 chars/token heuristic).
    #[must_use = "builder methods consume the config and return the updated one"]
    pub fn with_max_memory_tokens(mut self, n: usize) -> Self {
        self.max_memory_tokens = Some(n);
        self
    }
}

130// ── ToolSpec ──────────────────────────────────────────────────────────────────
131
132/// Describes and implements a single callable tool.
133pub struct ToolSpec {
134    /// Short identifier used in action strings (e.g. "search").
135    pub name: String,
136    /// Human-readable description passed to the model as part of the system prompt.
137    pub description: String,
138    /// Async handler: receives JSON arguments, returns a future resolving to a JSON result.
139    pub handler: AsyncToolHandler,
140    /// Field names that must be present in the JSON args object.
141    /// Empty means no validation is performed.
142    pub required_fields: Vec<String>,
143    /// Optional per-tool circuit breaker.
144    #[cfg(feature = "orchestrator")]
145    pub circuit_breaker: Option<Arc<crate::orchestrator::CircuitBreaker>>,
146}
147
148impl std::fmt::Debug for ToolSpec {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        let mut s = f.debug_struct("ToolSpec");
151        s.field("name", &self.name)
152            .field("description", &self.description)
153            .field("required_fields", &self.required_fields);
154        #[cfg(feature = "orchestrator")]
155        s.field("has_circuit_breaker", &self.circuit_breaker.is_some());
156        s.finish()
157    }
158}
159
160impl ToolSpec {
161    /// Construct a new `ToolSpec` from a synchronous handler closure.
162    /// The closure is wrapped in an `async move` block automatically.
163    pub fn new(
164        name: impl Into<String>,
165        description: impl Into<String>,
166        handler: impl Fn(Value) -> Value + Send + Sync + 'static,
167    ) -> Self {
168        Self {
169            name: name.into(),
170            description: description.into(),
171            handler: Box::new(move |args| {
172                let result = handler(args);
173                Box::pin(async move { result })
174            }),
175            required_fields: Vec::new(),
176            #[cfg(feature = "orchestrator")]
177            circuit_breaker: None,
178        }
179    }
180
181    /// Construct a new `ToolSpec` from an async handler closure.
182    pub fn new_async(
183        name: impl Into<String>,
184        description: impl Into<String>,
185        handler: impl Fn(Value) -> AsyncToolFuture + Send + Sync + 'static,
186    ) -> Self {
187        Self {
188            name: name.into(),
189            description: description.into(),
190            handler: Box::new(handler),
191            required_fields: Vec::new(),
192            #[cfg(feature = "orchestrator")]
193            circuit_breaker: None,
194        }
195    }
196
197    /// Set the required fields that must be present in the JSON args object.
198    pub fn with_required_fields(mut self, fields: Vec<String>) -> Self {
199        self.required_fields = fields;
200        self
201    }
202
203    /// Attach a circuit breaker to this tool spec.
204    #[cfg(feature = "orchestrator")]
205    pub fn with_circuit_breaker(mut self, cb: Arc<crate::orchestrator::CircuitBreaker>) -> Self {
206        self.circuit_breaker = Some(cb);
207        self
208    }
209
210    /// Invoke the tool with the given JSON arguments.
211    pub async fn call(&self, args: Value) -> Value {
212        (self.handler)(args).await
213    }
214}
215
216// ── ToolRegistry ──────────────────────────────────────────────────────────────
217
218/// Registry of available tools for the agent loop.
219#[derive(Debug, Default)]
220pub struct ToolRegistry {
221    tools: HashMap<String, ToolSpec>,
222}
223
224impl ToolRegistry {
225    /// Create a new empty registry.
226    pub fn new() -> Self {
227        Self {
228            tools: HashMap::new(),
229        }
230    }
231
232    /// Register a tool. Overwrites any existing tool with the same name.
233    pub fn register(&mut self, spec: ToolSpec) {
234        self.tools.insert(spec.name.clone(), spec);
235    }
236
237    /// Call a tool by name.
238    ///
239    /// # Returns
240    /// - `Ok(Value)` — tool result
241    /// - `Err(AgentRuntimeError::AgentLoop)` — if the tool is not found or required fields
242    ///   are missing
243    /// - `Err(AgentRuntimeError::CircuitOpen)` — if the tool's circuit breaker is open
244    #[tracing::instrument(skip_all, fields(tool_name = %name))]
245    pub async fn call(&self, name: &str, args: Value) -> Result<Value, AgentRuntimeError> {
246        let spec = self
247            .tools
248            .get(name)
249            .ok_or_else(|| AgentRuntimeError::AgentLoop(format!("tool '{name}' not found")))?;
250
251        // Item 3 — required field validation
252        if !spec.required_fields.is_empty() {
253            if let Some(obj) = args.as_object() {
254                for field in &spec.required_fields {
255                    if !obj.contains_key(field) {
256                        return Err(AgentRuntimeError::AgentLoop(format!(
257                            "tool '{}' missing required field '{}'",
258                            name, field
259                        )));
260                    }
261                }
262            } else {
263                return Err(AgentRuntimeError::AgentLoop(format!(
264                    "tool '{}' requires JSON object args, got {}",
265                    name, args
266                )));
267            }
268        }
269
270        // Item 7 — per-tool circuit breaker check
271        #[cfg(feature = "orchestrator")]
272        if let Some(ref cb) = spec.circuit_breaker {
273            use crate::orchestrator::CircuitState;
274            if let Ok(CircuitState::Open { .. }) = cb.state() {
275                return Err(AgentRuntimeError::CircuitOpen {
276                    service: format!("tool:{}", name),
277                });
278            }
279        }
280
281        let result = spec.call(args).await;
282        Ok(result)
283    }
284
285    /// Return the list of registered tool names.
286    pub fn tool_names(&self) -> Vec<&str> {
287        self.tools.keys().map(|s| s.as_str()).collect()
288    }
289}
290
291// ── ReActLoop ─────────────────────────────────────────────────────────────────
292
293/// Parses a ReAct response string into a `ReActStep`.
294///
295/// Case-insensitive; tolerates spaces around the colon.
296/// e.g. `Thought :`, `thought:`, `THOUGHT :` all match.
297pub fn parse_react_step(text: &str) -> Result<ReActStep, AgentRuntimeError> {
298    let mut thought = String::new();
299    let mut action = String::new();
300
301    for line in text.lines() {
302        let trimmed = line.trim();
303        let lower = trimmed.to_ascii_lowercase();
304        if lower.starts_with("thought") {
305            if let Some(colon_pos) = trimmed.find(':') {
306                thought = trimmed[colon_pos + 1..].trim().to_owned();
307            }
308        } else if lower.starts_with("action") {
309            if let Some(colon_pos) = trimmed.find(':') {
310                action = trimmed[colon_pos + 1..].trim().to_owned();
311            }
312        }
313    }
314
315    if thought.is_empty() && action.is_empty() {
316        return Err(AgentRuntimeError::AgentLoop(
317            "could not parse ReAct step from response".into(),
318        ));
319    }
320
321    Ok(ReActStep {
322        thought,
323        action,
324        observation: String::new(),
325    })
326}
327
328/// The ReAct agent loop.
329#[derive(Debug)]
330pub struct ReActLoop {
331    config: AgentConfig,
332    registry: ToolRegistry,
333}
334
335impl ReActLoop {
336    /// Create a new `ReActLoop` with the given configuration and an empty tool registry.
337    pub fn new(config: AgentConfig) -> Self {
338        Self {
339            config,
340            registry: ToolRegistry::new(),
341        }
342    }
343
344    /// Register a tool that the agent loop can invoke.
345    pub fn register_tool(&mut self, spec: ToolSpec) {
346        self.registry.register(spec);
347    }
348
349    /// Execute the ReAct loop for the given prompt.
350    ///
351    /// # Arguments
352    /// * `prompt` — user input passed as the initial context
353    /// * `infer`  — async inference function: receives context string, returns response string
354    ///
355    /// # Returns
356    /// - `Ok(Vec<ReActStep>)` — steps executed, ending with a `FINAL_ANSWER` step
357    /// - `Err(AgentRuntimeError::AgentLoop)` — if max iterations reached without `FINAL_ANSWER`
358    ///   or if a ReAct response cannot be parsed
359    #[tracing::instrument(skip(infer))]
360    pub async fn run<F, Fut>(
361        &self,
362        prompt: &str,
363        mut infer: F,
364    ) -> Result<Vec<ReActStep>, AgentRuntimeError>
365    where
366        F: FnMut(String) -> Fut,
367        Fut: Future<Output = String>,
368    {
369        let mut steps: Vec<ReActStep> = Vec::new();
370        let mut context = format!("{}\n\nUser: {}\n", self.config.system_prompt, prompt);
371
372        for iteration in 0..self.config.max_iterations {
373            let response = infer(context.clone()).await;
374            let mut step = parse_react_step(&response)?;
375
376            tracing::debug!(
377                step = iteration,
378                thought = %step.thought,
379                action = %step.action,
380                "ReAct iteration"
381            );
382
383            if step.action.to_ascii_uppercase().starts_with("FINAL_ANSWER") {
384                step.observation = step.action.clone();
385                steps.push(step);
386                tracing::info!(step = iteration, "FINAL_ANSWER reached");
387                return Ok(steps);
388            }
389
390            let (tool_name, args) = parse_tool_call(&step.action);
391
392            tracing::debug!(
393                step = iteration,
394                tool_name = %tool_name,
395                "dispatching tool call"
396            );
397
398            // Item 9 — structured error categorization in observation
399            let observation = match self.registry.call(&tool_name, args).await {
400                Ok(result) => serde_json::json!({ "ok": true, "data": result }).to_string(),
401                Err(e) => {
402                    let kind = match &e {
403                        AgentRuntimeError::AgentLoop(msg) if msg.contains("not found") => {
404                            "not_found"
405                        }
406                        #[cfg(feature = "orchestrator")]
407                        AgentRuntimeError::CircuitOpen { .. } => "transient",
408                        _ => "permanent",
409                    };
410                    serde_json::json!({ "ok": false, "error": e.to_string(), "kind": kind })
411                        .to_string()
412                }
413            };
414
415            step.observation = observation.clone();
416            context.push_str(&format!(
417                "\nThought: {}\nAction: {}\nObservation: {}\n",
418                step.thought, step.action, observation
419            ));
420            steps.push(step);
421        }
422
423        let err = AgentRuntimeError::AgentLoop(format!(
424            "max iterations ({}) reached without final answer",
425            self.config.max_iterations
426        ));
427        tracing::warn!(
428            max_iterations = self.config.max_iterations,
429            "ReAct loop exhausted max iterations without FINAL_ANSWER"
430        );
431        Err(err)
432    }
433}
434
435/// Split `"tool_name {json}"` into `(tool_name, Value)`.
436fn parse_tool_call(action: &str) -> (String, Value) {
437    let mut parts = action.splitn(2, ' ');
438    let name = parts.next().unwrap_or("").to_owned();
439    let args_str = parts.next().unwrap_or("{}");
440    let args: Value = serde_json::from_str(args_str).unwrap_or(Value::String(args_str.to_owned()));
441    (name, args)
442}
443
444/// Agent-specific errors, mirrors `wasm-agent::AgentError`.
445///
446/// Converts to `AgentRuntimeError::AgentLoop` via the `From` implementation.
447#[derive(Debug, thiserror::Error)]
448pub enum AgentError {
449    /// The referenced tool name does not exist in the registry.
450    #[error("Tool '{0}' not found")]
451    ToolNotFound(String),
452    /// The ReAct loop consumed all iterations without emitting `FINAL_ANSWER`.
453    #[error("Max iterations exceeded: {0}")]
454    MaxIterations(usize),
455    /// The model response could not be parsed into a `ReActStep`.
456    #[error("Parse error: {0}")]
457    ParseError(String),
458}
459
460impl From<AgentError> for AgentRuntimeError {
461    fn from(e: AgentError) -> Self {
462        AgentRuntimeError::AgentLoop(e.to_string())
463    }
464}
465
// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Canned model reply that ends the loop.
    const FINAL_REPLY: &str = "Thought: done\nAction: FINAL_ANSWER done";

    /// Inference stub: replies `first_reply` on the first call and [`FINAL_REPLY`]
    /// on every later call.
    fn scripted_infer(
        first_reply: &'static str,
    ) -> impl FnMut(String) -> std::future::Ready<String> {
        let mut calls = 0usize;
        move |_ctx: String| {
            calls += 1;
            let reply = if calls == 1 { first_reply } else { FINAL_REPLY };
            std::future::ready(reply.to_owned())
        }
    }

    /// True when an action begins with `FINAL_ANSWER` (ASCII case-insensitive).
    fn is_final(action: &str) -> bool {
        action.to_ascii_uppercase().starts_with("FINAL_ANSWER")
    }

    #[tokio::test]
    async fn test_final_answer_on_first_step() {
        let agent = ReActLoop::new(AgentConfig::new(5, "test-model"));
        let infer = scripted_infer("Thought: I will answer directly\nAction: FINAL_ANSWER hello");

        let steps = agent.run("Say hello", infer).await.unwrap();

        assert_eq!(steps.len(), 1);
        assert!(is_final(&steps[0].action));
    }

    #[tokio::test]
    async fn test_tool_call_then_final_answer() {
        let mut agent = ReActLoop::new(AgentConfig::new(5, "test-model"));
        agent.register_tool(ToolSpec::new("greet", "Greets someone", |_args| {
            serde_json::json!("hello!")
        }));

        let steps = agent
            .run("Say hello", scripted_infer("Thought: I will greet\nAction: greet {}"))
            .await
            .unwrap();

        assert_eq!(steps.len(), 2);
        assert_eq!(steps[0].action, "greet {}");
        assert!(is_final(&steps[1].action));
    }

    #[tokio::test]
    async fn test_max_iterations_exceeded() {
        let agent = ReActLoop::new(AgentConfig::new(2, "test-model"));

        let err = agent
            .run("loop forever", |_ctx| {
                std::future::ready("Thought: thinking\nAction: noop {}".to_owned())
            })
            .await
            .expect_err("loop should give up after two iterations");

        assert!(err.to_string().contains("max iterations"));
    }

    #[test]
    fn test_parse_react_step_valid() {
        let step =
            parse_react_step("Thought: I should check\nAction: lookup {\"key\":\"val\"}").unwrap();
        assert_eq!(step.thought, "I should check");
        assert_eq!(step.action, r#"lookup {"key":"val"}"#);
    }

    #[test]
    fn test_parse_react_step_empty_fails() {
        assert!(parse_react_step("no prefix lines here").is_err());
    }

    #[tokio::test]
    async fn test_tool_not_found_returns_error_observation() {
        let agent = ReActLoop::new(AgentConfig::new(3, "test-model"));

        let steps = agent
            .run(
                "test",
                scripted_infer("Thought: try missing tool\nAction: missing_tool {}"),
            )
            .await
            .unwrap();

        assert_eq!(steps.len(), 2);
        assert!(steps[0].observation.contains("\"ok\":false"));
    }

    #[tokio::test]
    async fn test_new_async_tool_spec() {
        let spec = ToolSpec::new_async("async_tool", "An async tool", |args| {
            Box::pin(async move { serde_json::json!({ "echo": args }) })
        });

        let echoed = spec.call(serde_json::json!({ "input": "test" })).await;
        assert!(echoed.get("echo").is_some());
    }

    // Robust ReAct parser: label matching is case-insensitive and tolerant of
    // whitespace before the colon.

    #[test]
    fn test_parse_react_step_case_insensitive() {
        let step = parse_react_step("THOUGHT: done\nACTION: FINAL_ANSWER").unwrap();
        assert_eq!(step.thought, "done");
        assert_eq!(step.action, "FINAL_ANSWER");
    }

    #[test]
    fn test_parse_react_step_space_before_colon() {
        let step = parse_react_step("Thought : done\nAction : go").unwrap();
        assert_eq!(step.thought, "done");
        assert_eq!(step.action, "go");
    }

    // Required-field validation surfaces as an error observation, not a hard failure.

    #[tokio::test]
    async fn test_tool_required_fields_missing_returns_error() {
        let mut agent = ReActLoop::new(AgentConfig::new(3, "test-model"));
        let search = ToolSpec::new("search", "Searches for something", |args| {
            serde_json::json!({ "result": args })
        })
        .with_required_fields(vec!["q".to_string()]);
        agent.register_tool(search);

        // The first reply calls `search` with an empty object, so "q" is missing.
        let steps = agent
            .run("test", scripted_infer("Thought: searching\nAction: search {}"))
            .await
            .unwrap();

        assert_eq!(steps.len(), 2);
        let obs = &steps[0].observation;
        assert!(obs.contains("missing required field"), "observation was: {}", obs);
    }

    // Structured error observations carry a `kind` tag.

    #[tokio::test]
    async fn test_tool_error_observation_includes_kind() {
        let agent = ReActLoop::new(AgentConfig::new(3, "test-model"));

        let steps = agent
            .run(
                "test",
                scripted_infer("Thought: try missing\nAction: nonexistent_tool {}"),
            )
            .await
            .unwrap();

        assert_eq!(steps.len(), 2);
        let obs = &steps[0].observation;
        assert!(obs.contains("\"ok\":false"), "observation: {obs}");
        assert!(obs.contains("\"kind\":\"not_found\""), "observation: {obs}");
    }

    // Circuit breaker: only compiled when the `orchestrator` feature is active.

    #[cfg(feature = "orchestrator")]
    #[tokio::test]
    async fn test_tool_with_circuit_breaker_passes_when_closed() {
        use std::sync::Arc;

        let breaker = crate::orchestrator::CircuitBreaker::new(
            "echo-tool",
            5,
            std::time::Duration::from_secs(30),
        )
        .unwrap();

        let echo = ToolSpec::new("echo", "Echoes args", |args| {
            serde_json::json!({ "echoed": args })
        })
        .with_circuit_breaker(Arc::new(breaker));

        let mut registry = ToolRegistry::new();
        registry.register(echo);

        let outcome = registry
            .call("echo", serde_json::json!({ "msg": "hi" }))
            .await;
        assert!(outcome.is_ok(), "expected Ok, got {:?}", outcome);
    }
}