1use crate::error::AgentRuntimeError;
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use std::collections::HashMap;
21use std::future::Future;
22use std::pin::Pin;
23
24#[cfg(feature = "orchestrator")]
25use std::sync::Arc;
26
27pub type AsyncToolFuture = Pin<Box<dyn Future<Output = Value> + Send>>;
31
32pub type AsyncToolHandler = Box<dyn Fn(Value) -> AsyncToolFuture + Send + Sync>;
34
35#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
37pub enum Role {
38 System,
40 User,
42 Assistant,
44 Tool,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct Message {
51 pub role: Role,
53 pub content: String,
55}
56
57impl Message {
58 pub fn new(role: Role, content: impl Into<String>) -> Self {
64 Self {
65 role,
66 content: content.into(),
67 }
68 }
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct ReActStep {
74 pub thought: String,
76 pub action: String,
78 pub observation: String,
80}
81
/// Settings for a ReAct agent run.
#[derive(Clone, Debug)]
pub struct AgentConfig {
    /// Upper bound on the number of inference rounds [`ReActLoop::run`] performs.
    pub max_iterations: usize,
    /// Model identifier. Not read by the code in this module; presumably consumed by
    /// the caller's inference function.
    pub model: String,
    /// Text placed at the very start of the context handed to the inference function.
    pub system_prompt: String,
    /// Memory-recall limit (defaults to 3). Not read in this module; the exact
    /// semantics are defined by whoever consumes the config.
    pub max_memory_recalls: usize,
    /// Optional memory token budget (defaults to `None`). Not read in this module.
    pub max_memory_tokens: Option<usize>,
}
98
99impl AgentConfig {
100 pub fn new(max_iterations: usize, model: impl Into<String>) -> Self {
102 Self {
103 max_iterations,
104 model: model.into(),
105 system_prompt: "You are a helpful AI agent.".into(),
106 max_memory_recalls: 3,
107 max_memory_tokens: None,
108 }
109 }
110
111 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
113 self.system_prompt = prompt.into();
114 self
115 }
116
117 pub fn with_max_memory_recalls(mut self, n: usize) -> Self {
119 self.max_memory_recalls = n;
120 self
121 }
122
123 pub fn with_max_memory_tokens(mut self, n: usize) -> Self {
125 self.max_memory_tokens = Some(n);
126 self
127 }
128}
129
130pub struct ToolSpec {
134 pub name: String,
136 pub description: String,
138 pub handler: AsyncToolHandler,
140 pub required_fields: Vec<String>,
143 #[cfg(feature = "orchestrator")]
145 pub circuit_breaker: Option<Arc<crate::orchestrator::CircuitBreaker>>,
146}
147
148impl std::fmt::Debug for ToolSpec {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 let mut s = f.debug_struct("ToolSpec");
151 s.field("name", &self.name)
152 .field("description", &self.description)
153 .field("required_fields", &self.required_fields);
154 #[cfg(feature = "orchestrator")]
155 s.field(
156 "has_circuit_breaker",
157 &self.circuit_breaker.is_some(),
158 );
159 s.finish()
160 }
161}
162
163impl ToolSpec {
164 pub fn new(
167 name: impl Into<String>,
168 description: impl Into<String>,
169 handler: impl Fn(Value) -> Value + Send + Sync + 'static,
170 ) -> Self {
171 Self {
172 name: name.into(),
173 description: description.into(),
174 handler: Box::new(move |args| {
175 let result = handler(args);
176 Box::pin(async move { result })
177 }),
178 required_fields: Vec::new(),
179 #[cfg(feature = "orchestrator")]
180 circuit_breaker: None,
181 }
182 }
183
184 pub fn new_async(
186 name: impl Into<String>,
187 description: impl Into<String>,
188 handler: impl Fn(Value) -> AsyncToolFuture + Send + Sync + 'static,
189 ) -> Self {
190 Self {
191 name: name.into(),
192 description: description.into(),
193 handler: Box::new(handler),
194 required_fields: Vec::new(),
195 #[cfg(feature = "orchestrator")]
196 circuit_breaker: None,
197 }
198 }
199
200 pub fn with_required_fields(mut self, fields: Vec<String>) -> Self {
202 self.required_fields = fields;
203 self
204 }
205
206 #[cfg(feature = "orchestrator")]
208 pub fn with_circuit_breaker(mut self, cb: Arc<crate::orchestrator::CircuitBreaker>) -> Self {
209 self.circuit_breaker = Some(cb);
210 self
211 }
212
213 pub async fn call(&self, args: Value) -> Value {
215 (self.handler)(args).await
216 }
217}
218
219#[derive(Debug, Default)]
223pub struct ToolRegistry {
224 tools: HashMap<String, ToolSpec>,
225}
226
227impl ToolRegistry {
228 pub fn new() -> Self {
230 Self {
231 tools: HashMap::new(),
232 }
233 }
234
235 pub fn register(&mut self, spec: ToolSpec) {
237 self.tools.insert(spec.name.clone(), spec);
238 }
239
240 #[tracing::instrument(skip_all, fields(tool_name = %name))]
248 pub async fn call(&self, name: &str, args: Value) -> Result<Value, AgentRuntimeError> {
249 let spec = self
250 .tools
251 .get(name)
252 .ok_or_else(|| AgentRuntimeError::AgentLoop(format!("tool '{name}' not found")))?;
253
254 if !spec.required_fields.is_empty() {
256 if let Some(obj) = args.as_object() {
257 for field in &spec.required_fields {
258 if !obj.contains_key(field) {
259 return Err(AgentRuntimeError::AgentLoop(format!(
260 "tool '{}' missing required field '{}'",
261 name, field
262 )));
263 }
264 }
265 } else {
266 return Err(AgentRuntimeError::AgentLoop(format!(
267 "tool '{}' requires JSON object args, got {}",
268 name, args
269 )));
270 }
271 }
272
273 #[cfg(feature = "orchestrator")]
275 if let Some(ref cb) = spec.circuit_breaker {
276 use crate::orchestrator::CircuitState;
277 if let Ok(CircuitState::Open { .. }) = cb.state() {
278 return Err(AgentRuntimeError::CircuitOpen {
279 service: format!("tool:{}", name),
280 });
281 }
282 }
283
284 let result = spec.call(args).await;
285 Ok(result)
286 }
287
288 pub fn tool_names(&self) -> Vec<&str> {
290 self.tools.keys().map(|s| s.as_str()).collect()
291 }
292}
293
294pub fn parse_react_step(text: &str) -> Result<ReActStep, AgentRuntimeError> {
301 let mut thought = String::new();
302 let mut action = String::new();
303
304 for line in text.lines() {
305 let trimmed = line.trim();
306 let lower = trimmed.to_ascii_lowercase();
307 if lower.starts_with("thought") {
308 if let Some(colon_pos) = trimmed.find(':') {
309 thought = trimmed[colon_pos + 1..].trim().to_owned();
310 }
311 } else if lower.starts_with("action") {
312 if let Some(colon_pos) = trimmed.find(':') {
313 action = trimmed[colon_pos + 1..].trim().to_owned();
314 }
315 }
316 }
317
318 if thought.is_empty() && action.is_empty() {
319 return Err(AgentRuntimeError::AgentLoop(
320 "could not parse ReAct step from response".into(),
321 ));
322 }
323
324 Ok(ReActStep {
325 thought,
326 action,
327 observation: String::new(),
328 })
329}
330
331#[derive(Debug)]
333pub struct ReActLoop {
334 config: AgentConfig,
335 registry: ToolRegistry,
336}
337
338impl ReActLoop {
339 pub fn new(config: AgentConfig) -> Self {
341 Self {
342 config,
343 registry: ToolRegistry::new(),
344 }
345 }
346
347 pub fn register_tool(&mut self, spec: ToolSpec) {
349 self.registry.register(spec);
350 }
351
352 #[tracing::instrument(skip(infer))]
363 pub async fn run<F, Fut>(
364 &self,
365 prompt: &str,
366 mut infer: F,
367 ) -> Result<Vec<ReActStep>, AgentRuntimeError>
368 where
369 F: FnMut(String) -> Fut,
370 Fut: Future<Output = String>,
371 {
372 let mut steps: Vec<ReActStep> = Vec::new();
373 let mut context = format!("{}\n\nUser: {}\n", self.config.system_prompt, prompt);
374
375 for iteration in 0..self.config.max_iterations {
376 let response = infer(context.clone()).await;
377 let mut step = parse_react_step(&response)?;
378
379 tracing::debug!(
380 step = iteration,
381 thought = %step.thought,
382 action = %step.action,
383 "ReAct iteration"
384 );
385
386 if step.action.to_ascii_uppercase().starts_with("FINAL_ANSWER") {
387 step.observation = step.action.clone();
388 steps.push(step);
389 tracing::info!(step = iteration, "FINAL_ANSWER reached");
390 return Ok(steps);
391 }
392
393 let (tool_name, args) = parse_tool_call(&step.action);
394
395 tracing::debug!(
396 step = iteration,
397 tool_name = %tool_name,
398 "dispatching tool call"
399 );
400
401 let observation = match self.registry.call(&tool_name, args).await {
403 Ok(result) => serde_json::json!({ "ok": true, "data": result }).to_string(),
404 Err(e) => {
405 let kind = match &e {
406 AgentRuntimeError::AgentLoop(msg) if msg.contains("not found") => {
407 "not_found"
408 }
409 #[cfg(feature = "orchestrator")]
410 AgentRuntimeError::CircuitOpen { .. } => "transient",
411 _ => "permanent",
412 };
413 serde_json::json!({ "ok": false, "error": e.to_string(), "kind": kind })
414 .to_string()
415 }
416 };
417
418 step.observation = observation.clone();
419 context.push_str(&format!(
420 "\nThought: {}\nAction: {}\nObservation: {}\n",
421 step.thought, step.action, observation
422 ));
423 steps.push(step);
424 }
425
426 let err = AgentRuntimeError::AgentLoop(format!(
427 "max iterations ({}) reached without final answer",
428 self.config.max_iterations
429 ));
430 tracing::warn!(
431 max_iterations = self.config.max_iterations,
432 "ReAct loop exhausted max iterations without FINAL_ANSWER"
433 );
434 Err(err)
435 }
436}
437
438fn parse_tool_call(action: &str) -> (String, Value) {
440 let mut parts = action.splitn(2, ' ');
441 let name = parts.next().unwrap_or("").to_owned();
442 let args_str = parts.next().unwrap_or("{}");
443 let args: Value = serde_json::from_str(args_str).unwrap_or(Value::String(args_str.to_owned()));
444 (name, args)
445}
446
447#[derive(Debug, thiserror::Error)]
451pub enum AgentError {
452 #[error("Tool '{0}' not found")]
454 ToolNotFound(String),
455 #[error("Max iterations exceeded: {0}")]
457 MaxIterations(usize),
458 #[error("Parse error: {0}")]
460 ParseError(String),
461}
462
463impl From<AgentError> for AgentRuntimeError {
464 fn from(e: AgentError) -> Self {
465 AgentRuntimeError::AgentLoop(e.to_string())
466 }
467}
468
#[cfg(test)]
mod tests {
    use super::*;

    /// Model reply that ends a run.
    const DONE: &str = "Thought: done\nAction: FINAL_ANSWER done";

    /// Builds an inference closure that replays `script` one reply per call and keeps
    /// returning the last reply once the script is exhausted.
    fn scripted(script: &[&str]) -> impl FnMut(String) -> std::future::Ready<String> {
        let replies: Vec<String> = script.iter().map(|reply| reply.to_string()).collect();
        let mut next = 0usize;
        move |_ctx: String| {
            let idx = next.min(replies.len() - 1);
            next += 1;
            std::future::ready(replies[idx].clone())
        }
    }

    /// True when the step's action is a `FINAL_ANSWER` (case-insensitive).
    fn is_final_answer(step: &ReActStep) -> bool {
        step.action.to_ascii_uppercase().starts_with("FINAL_ANSWER")
    }

    #[tokio::test]
    async fn test_final_answer_on_first_step() {
        let agent = ReActLoop::new(AgentConfig::new(5, "test-model"));

        let steps = agent
            .run(
                "Say hello",
                scripted(&["Thought: I will answer directly\nAction: FINAL_ANSWER hello"]),
            )
            .await
            .unwrap();

        assert_eq!(steps.len(), 1);
        assert!(is_final_answer(&steps[0]));
    }

    #[tokio::test]
    async fn test_tool_call_then_final_answer() {
        let mut agent = ReActLoop::new(AgentConfig::new(5, "test-model"));
        let greet = ToolSpec::new("greet", "Greets someone", |_args| {
            serde_json::json!("hello!")
        });
        agent.register_tool(greet);

        let steps = agent
            .run(
                "Say hello",
                scripted(&["Thought: I will greet\nAction: greet {}", DONE]),
            )
            .await
            .unwrap();

        assert_eq!(steps.len(), 2);
        assert_eq!(steps[0].action, "greet {}");
        assert!(is_final_answer(&steps[1]));
    }

    #[tokio::test]
    async fn test_max_iterations_exceeded() {
        let agent = ReActLoop::new(AgentConfig::new(2, "test-model"));

        let result = agent
            .run(
                "loop forever",
                scripted(&["Thought: thinking\nAction: noop {}"]),
            )
            .await;

        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("max iterations"));
    }

    #[tokio::test]
    async fn test_parse_react_step_valid() {
        let parsed =
            parse_react_step("Thought: I should check\nAction: lookup {\"key\":\"val\"}").unwrap();
        assert_eq!(parsed.thought, "I should check");
        assert_eq!(parsed.action, "lookup {\"key\":\"val\"}");
    }

    #[tokio::test]
    async fn test_parse_react_step_empty_fails() {
        assert!(parse_react_step("no prefix lines here").is_err());
    }

    #[tokio::test]
    async fn test_tool_not_found_returns_error_observation() {
        let agent = ReActLoop::new(AgentConfig::new(3, "test-model"));

        let steps = agent
            .run(
                "test",
                scripted(&["Thought: try missing tool\nAction: missing_tool {}", DONE]),
            )
            .await
            .unwrap();

        assert_eq!(steps.len(), 2);
        assert!(steps[0].observation.contains("\"ok\":false"));
    }

    #[tokio::test]
    async fn test_new_async_tool_spec() {
        let spec = ToolSpec::new_async("async_tool", "An async tool", |args| {
            Box::pin(async move { serde_json::json!({"echo": args}) })
        });

        let output = spec.call(serde_json::json!({"input": "test"})).await;
        assert!(output.get("echo").is_some());
    }

    // Labels are matched regardless of ASCII case.
    #[tokio::test]
    async fn test_parse_react_step_case_insensitive() {
        let parsed = parse_react_step("THOUGHT: done\nACTION: FINAL_ANSWER").unwrap();
        assert_eq!(parsed.thought, "done");
        assert_eq!(parsed.action, "FINAL_ANSWER");
    }

    #[tokio::test]
    async fn test_parse_react_step_space_before_colon() {
        let parsed = parse_react_step("Thought : done\nAction : go").unwrap();
        assert_eq!(parsed.thought, "done");
        assert_eq!(parsed.action, "go");
    }

    // A call that omits a declared required field becomes an error observation
    // rather than aborting the run.
    #[tokio::test]
    async fn test_tool_required_fields_missing_returns_error() {
        let mut agent = ReActLoop::new(AgentConfig::new(3, "test-model"));
        let search = ToolSpec::new("search", "Searches for something", |args| {
            serde_json::json!({ "result": args })
        })
        .with_required_fields(vec!["q".to_string()]);
        agent.register_tool(search);

        let steps = agent
            .run(
                "test",
                scripted(&["Thought: searching\nAction: search {}", DONE]),
            )
            .await
            .unwrap();

        assert_eq!(steps.len(), 2);
        assert!(
            steps[0].observation.contains("missing required field"),
            "observation was: {}",
            steps[0].observation
        );
    }

    // Error observations carry a machine-readable `kind`.
    #[tokio::test]
    async fn test_tool_error_observation_includes_kind() {
        let agent = ReActLoop::new(AgentConfig::new(3, "test-model"));

        let steps = agent
            .run(
                "test",
                scripted(&["Thought: try missing\nAction: nonexistent_tool {}", DONE]),
            )
            .await
            .unwrap();

        assert_eq!(steps.len(), 2);
        let obs = &steps[0].observation;
        assert!(obs.contains("\"ok\":false"), "observation: {obs}");
        assert!(obs.contains("\"kind\":\"not_found\""), "observation: {obs}");
    }

    // A freshly created breaker must not block calls.
    #[cfg(feature = "orchestrator")]
    #[tokio::test]
    async fn test_tool_with_circuit_breaker_passes_when_closed() {
        use std::sync::Arc;

        let breaker = crate::orchestrator::CircuitBreaker::new(
            "echo-tool",
            5,
            std::time::Duration::from_secs(30),
        )
        .unwrap();

        let spec = ToolSpec::new("echo", "Echoes args", |args| {
            serde_json::json!({ "echoed": args })
        })
        .with_circuit_breaker(Arc::new(breaker));

        let mut registry = ToolRegistry::new();
        registry.register(spec);

        let result = registry
            .call("echo", serde_json::json!({ "msg": "hi" }))
            .await;
        assert!(result.is_ok(), "expected Ok, got {:?}", result);
    }
}