1use crate::error::AgentRuntimeError;
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use std::collections::HashMap;
21use std::future::Future;
22use std::pin::Pin;
23
24#[cfg(feature = "orchestrator")]
25use std::sync::Arc;
26
27pub type AsyncToolFuture = Pin<Box<dyn Future<Output = Value> + Send>>;
31
32pub type AsyncToolHandler = Box<dyn Fn(Value) -> AsyncToolFuture + Send + Sync>;
34
35#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
37pub enum Role {
38 System,
40 User,
42 Assistant,
44 Tool,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct Message {
51 pub role: Role,
53 pub content: String,
55}
56
57impl Message {
58 pub fn new(role: Role, content: impl Into<String>) -> Self {
64 Self {
65 role,
66 content: content.into(),
67 }
68 }
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct ReActStep {
74 pub thought: String,
76 pub action: String,
78 pub observation: String,
80}
81
/// Configuration for an agent's ReAct loop, built with [`AgentConfig::new`] and the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct AgentConfig {
    /// Maximum number of model calls the loop makes before giving up.
    pub max_iterations: usize,
    /// Model identifier.
    /// NOTE(review): not read by the loop in this module; presumably consumed by the
    /// caller's inference function — confirm.
    pub model: String,
    /// System prompt placed at the very start of the loop's context.
    pub system_prompt: String,
    /// Maximum number of memory entries to recall (default `3`).
    /// NOTE(review): not read in this module — confirm where it is consumed.
    pub max_memory_recalls: usize,
    /// Optional token budget for recalled memories (default `None`).
    /// NOTE(review): presumably `None` means "no limit"; not read in this module.
    pub max_memory_tokens: Option<usize>,
}

impl AgentConfig {
    /// Creates a config with the given iteration limit and model.
    ///
    /// Defaults: system prompt `"You are a helpful AI agent."`, 3 memory recalls,
    /// and no memory token budget.
    #[must_use]
    pub fn new(max_iterations: usize, model: impl Into<String>) -> Self {
        Self {
            max_iterations,
            model: model.into(),
            system_prompt: "You are a helpful AI agent.".into(),
            max_memory_recalls: 3,
            max_memory_tokens: None,
        }
    }

    /// Replaces the system prompt.
    #[must_use = "builder methods return the updated config; the original is consumed"]
    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = prompt.into();
        self
    }

    /// Sets the maximum number of memory recalls.
    #[must_use = "builder methods return the updated config; the original is consumed"]
    pub fn with_max_memory_recalls(mut self, n: usize) -> Self {
        self.max_memory_recalls = n;
        self
    }

    /// Sets a token budget for recalled memories.
    #[must_use = "builder methods return the updated config; the original is consumed"]
    pub fn with_max_memory_tokens(mut self, n: usize) -> Self {
        self.max_memory_tokens = Some(n);
        self
    }
}
129
130pub struct ToolSpec {
134 pub name: String,
136 pub description: String,
138 pub handler: AsyncToolHandler,
140 pub required_fields: Vec<String>,
143 #[cfg(feature = "orchestrator")]
145 pub circuit_breaker: Option<Arc<crate::orchestrator::CircuitBreaker>>,
146}
147
148impl std::fmt::Debug for ToolSpec {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 let mut s = f.debug_struct("ToolSpec");
151 s.field("name", &self.name)
152 .field("description", &self.description)
153 .field("required_fields", &self.required_fields);
154 #[cfg(feature = "orchestrator")]
155 s.field("has_circuit_breaker", &self.circuit_breaker.is_some());
156 s.finish()
157 }
158}
159
160impl ToolSpec {
161 pub fn new(
164 name: impl Into<String>,
165 description: impl Into<String>,
166 handler: impl Fn(Value) -> Value + Send + Sync + 'static,
167 ) -> Self {
168 Self {
169 name: name.into(),
170 description: description.into(),
171 handler: Box::new(move |args| {
172 let result = handler(args);
173 Box::pin(async move { result })
174 }),
175 required_fields: Vec::new(),
176 #[cfg(feature = "orchestrator")]
177 circuit_breaker: None,
178 }
179 }
180
181 pub fn new_async(
183 name: impl Into<String>,
184 description: impl Into<String>,
185 handler: impl Fn(Value) -> AsyncToolFuture + Send + Sync + 'static,
186 ) -> Self {
187 Self {
188 name: name.into(),
189 description: description.into(),
190 handler: Box::new(handler),
191 required_fields: Vec::new(),
192 #[cfg(feature = "orchestrator")]
193 circuit_breaker: None,
194 }
195 }
196
197 pub fn with_required_fields(mut self, fields: Vec<String>) -> Self {
199 self.required_fields = fields;
200 self
201 }
202
203 #[cfg(feature = "orchestrator")]
205 pub fn with_circuit_breaker(mut self, cb: Arc<crate::orchestrator::CircuitBreaker>) -> Self {
206 self.circuit_breaker = Some(cb);
207 self
208 }
209
210 pub async fn call(&self, args: Value) -> Value {
212 (self.handler)(args).await
213 }
214}
215
216#[derive(Debug, Default)]
220pub struct ToolRegistry {
221 tools: HashMap<String, ToolSpec>,
222}
223
224impl ToolRegistry {
225 pub fn new() -> Self {
227 Self {
228 tools: HashMap::new(),
229 }
230 }
231
232 pub fn register(&mut self, spec: ToolSpec) {
234 self.tools.insert(spec.name.clone(), spec);
235 }
236
237 #[tracing::instrument(skip_all, fields(tool_name = %name))]
245 pub async fn call(&self, name: &str, args: Value) -> Result<Value, AgentRuntimeError> {
246 let spec = self
247 .tools
248 .get(name)
249 .ok_or_else(|| AgentRuntimeError::AgentLoop(format!("tool '{name}' not found")))?;
250
251 if !spec.required_fields.is_empty() {
253 if let Some(obj) = args.as_object() {
254 for field in &spec.required_fields {
255 if !obj.contains_key(field) {
256 return Err(AgentRuntimeError::AgentLoop(format!(
257 "tool '{}' missing required field '{}'",
258 name, field
259 )));
260 }
261 }
262 } else {
263 return Err(AgentRuntimeError::AgentLoop(format!(
264 "tool '{}' requires JSON object args, got {}",
265 name, args
266 )));
267 }
268 }
269
270 #[cfg(feature = "orchestrator")]
272 if let Some(ref cb) = spec.circuit_breaker {
273 use crate::orchestrator::CircuitState;
274 if let Ok(CircuitState::Open { .. }) = cb.state() {
275 return Err(AgentRuntimeError::CircuitOpen {
276 service: format!("tool:{}", name),
277 });
278 }
279 }
280
281 let result = spec.call(args).await;
282 Ok(result)
283 }
284
285 pub fn tool_names(&self) -> Vec<&str> {
287 self.tools.keys().map(|s| s.as_str()).collect()
288 }
289}
290
291pub fn parse_react_step(text: &str) -> Result<ReActStep, AgentRuntimeError> {
298 let mut thought = String::new();
299 let mut action = String::new();
300
301 for line in text.lines() {
302 let trimmed = line.trim();
303 let lower = trimmed.to_ascii_lowercase();
304 if lower.starts_with("thought") {
305 if let Some(colon_pos) = trimmed.find(':') {
306 thought = trimmed[colon_pos + 1..].trim().to_owned();
307 }
308 } else if lower.starts_with("action") {
309 if let Some(colon_pos) = trimmed.find(':') {
310 action = trimmed[colon_pos + 1..].trim().to_owned();
311 }
312 }
313 }
314
315 if thought.is_empty() && action.is_empty() {
316 return Err(AgentRuntimeError::AgentLoop(
317 "could not parse ReAct step from response".into(),
318 ));
319 }
320
321 Ok(ReActStep {
322 thought,
323 action,
324 observation: String::new(),
325 })
326}
327
328#[derive(Debug)]
330pub struct ReActLoop {
331 config: AgentConfig,
332 registry: ToolRegistry,
333}
334
335impl ReActLoop {
336 pub fn new(config: AgentConfig) -> Self {
338 Self {
339 config,
340 registry: ToolRegistry::new(),
341 }
342 }
343
344 pub fn register_tool(&mut self, spec: ToolSpec) {
346 self.registry.register(spec);
347 }
348
349 #[tracing::instrument(skip(infer))]
360 pub async fn run<F, Fut>(
361 &self,
362 prompt: &str,
363 mut infer: F,
364 ) -> Result<Vec<ReActStep>, AgentRuntimeError>
365 where
366 F: FnMut(String) -> Fut,
367 Fut: Future<Output = String>,
368 {
369 let mut steps: Vec<ReActStep> = Vec::new();
370 let mut context = format!("{}\n\nUser: {}\n", self.config.system_prompt, prompt);
371
372 for iteration in 0..self.config.max_iterations {
373 let response = infer(context.clone()).await;
374 let mut step = parse_react_step(&response)?;
375
376 tracing::debug!(
377 step = iteration,
378 thought = %step.thought,
379 action = %step.action,
380 "ReAct iteration"
381 );
382
383 if step.action.to_ascii_uppercase().starts_with("FINAL_ANSWER") {
384 step.observation = step.action.clone();
385 steps.push(step);
386 tracing::info!(step = iteration, "FINAL_ANSWER reached");
387 return Ok(steps);
388 }
389
390 let (tool_name, args) = parse_tool_call(&step.action);
391
392 tracing::debug!(
393 step = iteration,
394 tool_name = %tool_name,
395 "dispatching tool call"
396 );
397
398 let observation = match self.registry.call(&tool_name, args).await {
400 Ok(result) => serde_json::json!({ "ok": true, "data": result }).to_string(),
401 Err(e) => {
402 let kind = match &e {
403 AgentRuntimeError::AgentLoop(msg) if msg.contains("not found") => {
404 "not_found"
405 }
406 #[cfg(feature = "orchestrator")]
407 AgentRuntimeError::CircuitOpen { .. } => "transient",
408 _ => "permanent",
409 };
410 serde_json::json!({ "ok": false, "error": e.to_string(), "kind": kind })
411 .to_string()
412 }
413 };
414
415 step.observation = observation.clone();
416 context.push_str(&format!(
417 "\nThought: {}\nAction: {}\nObservation: {}\n",
418 step.thought, step.action, observation
419 ));
420 steps.push(step);
421 }
422
423 let err = AgentRuntimeError::AgentLoop(format!(
424 "max iterations ({}) reached without final answer",
425 self.config.max_iterations
426 ));
427 tracing::warn!(
428 max_iterations = self.config.max_iterations,
429 "ReAct loop exhausted max iterations without FINAL_ANSWER"
430 );
431 Err(err)
432 }
433}
434
435fn parse_tool_call(action: &str) -> (String, Value) {
437 let mut parts = action.splitn(2, ' ');
438 let name = parts.next().unwrap_or("").to_owned();
439 let args_str = parts.next().unwrap_or("{}");
440 let args: Value = serde_json::from_str(args_str).unwrap_or(Value::String(args_str.to_owned()));
441 (name, args)
442}
443
444#[derive(Debug, thiserror::Error)]
448pub enum AgentError {
449 #[error("Tool '{0}' not found")]
451 ToolNotFound(String),
452 #[error("Max iterations exceeded: {0}")]
454 MaxIterations(usize),
455 #[error("Parse error: {0}")]
457 ParseError(String),
458}
459
460impl From<AgentError> for AgentRuntimeError {
461 fn from(e: AgentError) -> Self {
462 AgentRuntimeError::AgentLoop(e.to_string())
463 }
464}
465
/// Unit tests for the ReAct loop, step parsing, and tool dispatch.
///
/// Model responses are stubbed with closures; several tests use a captured call
/// counter so that the first call returns a tool action and later calls return a
/// final answer.
#[cfg(test)]
mod tests {
    use super::*;

    // A FINAL_ANSWER on the first response ends the loop after exactly one step.
    #[tokio::test]
    async fn test_final_answer_on_first_step() {
        let config = AgentConfig::new(5, "test-model");
        let loop_ = ReActLoop::new(config);

        let steps = loop_
            .run("Say hello", |_ctx| async {
                "Thought: I will answer directly\nAction: FINAL_ANSWER hello".to_string()
            })
            .await
            .unwrap();

        assert_eq!(steps.len(), 1);
        assert!(steps[0]
            .action
            .to_ascii_uppercase()
            .starts_with("FINAL_ANSWER"));
    }

    // One tool call followed by a final answer yields two recorded steps.
    #[tokio::test]
    async fn test_tool_call_then_final_answer() {
        let config = AgentConfig::new(5, "test-model");
        let mut loop_ = ReActLoop::new(config);

        loop_.register_tool(ToolSpec::new("greet", "Greets someone", |_args| {
            serde_json::json!("hello!")
        }));

        // Counter captured by the FnMut closure; copied into each future so the
        // async block does not borrow it.
        let mut call_count = 0;
        let steps = loop_
            .run("Say hello", |_ctx| {
                call_count += 1;
                let count = call_count;
                async move {
                    if count == 1 {
                        "Thought: I will greet\nAction: greet {}".to_string()
                    } else {
                        "Thought: done\nAction: FINAL_ANSWER done".to_string()
                    }
                }
            })
            .await
            .unwrap();

        assert_eq!(steps.len(), 2);
        assert_eq!(steps[0].action, "greet {}");
        assert!(steps[1]
            .action
            .to_ascii_uppercase()
            .starts_with("FINAL_ANSWER"));
    }

    // A model that never answers exhausts the iteration limit and returns an error.
    #[tokio::test]
    async fn test_max_iterations_exceeded() {
        let config = AgentConfig::new(2, "test-model");
        let loop_ = ReActLoop::new(config);

        let result = loop_
            .run("loop forever", |_ctx| async {
                "Thought: thinking\nAction: noop {}".to_string()
            })
            .await;

        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("max iterations"));
    }

    // Thought and action values are extracted; the action keeps its JSON arguments.
    #[tokio::test]
    async fn test_parse_react_step_valid() {
        let text = "Thought: I should check\nAction: lookup {\"key\":\"val\"}";
        let step = parse_react_step(text).unwrap();
        assert_eq!(step.thought, "I should check");
        assert_eq!(step.action, "lookup {\"key\":\"val\"}");
    }

    // Text with neither a Thought nor an Action line is rejected.
    #[tokio::test]
    async fn test_parse_react_step_empty_fails() {
        let result = parse_react_step("no prefix lines here");
        assert!(result.is_err());
    }

    // Calling an unregistered tool is reported as an error observation, not a loop failure.
    #[tokio::test]
    async fn test_tool_not_found_returns_error_observation() {
        let config = AgentConfig::new(3, "test-model");
        let loop_ = ReActLoop::new(config);

        let mut call_count = 0;
        let steps = loop_
            .run("test", |_ctx| {
                call_count += 1;
                let count = call_count;
                async move {
                    if count == 1 {
                        "Thought: try missing tool\nAction: missing_tool {}".to_string()
                    } else {
                        "Thought: done\nAction: FINAL_ANSWER done".to_string()
                    }
                }
            })
            .await
            .unwrap();

        assert_eq!(steps.len(), 2);
        assert!(steps[0].observation.contains("\"ok\":false"));
    }

    // An async handler's future is awaited and its value returned by `ToolSpec::call`.
    #[tokio::test]
    async fn test_new_async_tool_spec() {
        let spec = ToolSpec::new_async("async_tool", "An async tool", |args| {
            Box::pin(async move { serde_json::json!({"echo": args}) })
        });

        let result = spec.call(serde_json::json!({"input": "test"})).await;
        assert!(result.get("echo").is_some());
    }

    // Labels are matched case-insensitively.
    #[tokio::test]
    async fn test_parse_react_step_case_insensitive() {
        let text = "THOUGHT: done\nACTION: FINAL_ANSWER";
        let step = parse_react_step(text).unwrap();
        assert_eq!(step.thought, "done");
        assert_eq!(step.action, "FINAL_ANSWER");
    }

    // Whitespace between a label and its colon is tolerated.
    #[tokio::test]
    async fn test_parse_react_step_space_before_colon() {
        let text = "Thought : done\nAction : go";
        let step = parse_react_step(text).unwrap();
        assert_eq!(step.thought, "done");
        assert_eq!(step.action, "go");
    }

    // A missing required argument field surfaces in the step's observation.
    #[tokio::test]
    async fn test_tool_required_fields_missing_returns_error() {
        let config = AgentConfig::new(3, "test-model");
        let mut loop_ = ReActLoop::new(config);

        loop_.register_tool(
            ToolSpec::new(
                "search",
                "Searches for something",
                |args| serde_json::json!({ "result": args }),
            )
            .with_required_fields(vec!["q".to_string()]),
        );

        let mut call_count = 0;
        let steps = loop_
            .run("test", |_ctx| {
                call_count += 1;
                let count = call_count;
                async move {
                    if count == 1 {
                        // `{}` lacks the required "q" field.
                        "Thought: searching\nAction: search {}".to_string()
                    } else {
                        "Thought: done\nAction: FINAL_ANSWER done".to_string()
                    }
                }
            })
            .await
            .unwrap();

        assert_eq!(steps.len(), 2);
        assert!(
            steps[0].observation.contains("missing required field"),
            "observation was: {}",
            steps[0].observation
        );
    }

    // Error observations carry a machine-readable `kind`; an unknown tool is `not_found`.
    #[tokio::test]
    async fn test_tool_error_observation_includes_kind() {
        let config = AgentConfig::new(3, "test-model");
        let loop_ = ReActLoop::new(config);

        let mut call_count = 0;
        let steps = loop_
            .run("test", |_ctx| {
                call_count += 1;
                let count = call_count;
                async move {
                    if count == 1 {
                        "Thought: try missing\nAction: nonexistent_tool {}".to_string()
                    } else {
                        "Thought: done\nAction: FINAL_ANSWER done".to_string()
                    }
                }
            })
            .await
            .unwrap();

        assert_eq!(steps.len(), 2);
        let obs = &steps[0].observation;
        assert!(obs.contains("\"ok\":false"), "observation: {obs}");
        assert!(obs.contains("\"kind\":\"not_found\""), "observation: {obs}");
    }

    // A tool guarded by a circuit breaker is callable while the breaker is not open.
    // NOTE(review): assumes a freshly constructed breaker starts Closed, and that the
    // `5` / 30s arguments are a failure threshold and recovery timeout — confirm
    // against `orchestrator::CircuitBreaker::new`.
    #[cfg(feature = "orchestrator")]
    #[tokio::test]
    async fn test_tool_with_circuit_breaker_passes_when_closed() {
        use std::sync::Arc;

        let cb = Arc::new(
            crate::orchestrator::CircuitBreaker::new(
                "echo-tool",
                5,
                std::time::Duration::from_secs(30),
            )
            .unwrap(),
        );

        let spec = ToolSpec::new(
            "echo",
            "Echoes args",
            |args| serde_json::json!({ "echoed": args }),
        )
        .with_circuit_breaker(cb);

        let registry = {
            let mut r = ToolRegistry::new();
            r.register(spec);
            r
        };

        let result = registry
            .call("echo", serde_json::json!({ "msg": "hi" }))
            .await;
        assert!(result.is_ok(), "expected Ok, got {:?}", result);
    }
}