1use crate::error::AgentRuntimeError;
18use crate::metrics::RuntimeMetrics;
19use serde::{Deserialize, Serialize};
20use serde_json::Value;
21use std::collections::HashMap;
22use std::future::Future;
23use std::pin::Pin;
24use std::sync::Arc;
25
26pub type AsyncToolFuture = Pin<Box<dyn Future<Output = Value> + Send>>;
30
31pub type AsyncToolResultFuture = Pin<Box<dyn Future<Output = Result<Value, String>> + Send>>;
33
34pub type AsyncToolHandler = Box<dyn Fn(Value) -> AsyncToolFuture + Send + Sync>;
36
/// Author of a [`Message`].
///
/// `Display` renders the lowercase variant name (`system`, `user`,
/// `assistant`, `tool`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Role {
    /// System-level message.
    System,
    /// Message attributed to the user.
    User,
    /// Message attributed to the assistant (model).
    Assistant,
    /// Message attributed to a tool.
    Tool,
}
49
/// A single conversation message: who authored it and its text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Author of the message.
    pub role: Role,
    /// Message text.
    pub content: String,
}
58
59impl Message {
60 pub fn new(role: Role, content: impl Into<String>) -> Self {
66 Self {
67 role,
68 content: content.into(),
69 }
70 }
71
72 pub fn role(&self) -> &Role {
74 &self.role
75 }
76
77 pub fn content(&self) -> &str {
79 &self.content
80 }
81}
82
83impl std::fmt::Display for Role {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 match self {
86 Role::System => write!(f, "system"),
87 Role::User => write!(f, "user"),
88 Role::Assistant => write!(f, "assistant"),
89 Role::Tool => write!(f, "tool"),
90 }
91 }
92}
93
/// One Thought → Action → Observation cycle recorded by the ReAct loop.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReActStep {
    /// Text following the `Thought:` label in the model response.
    pub thought: String,
    /// Text following the `Action:` label: either `FINAL_ANSWER ...` or
    /// `<tool_name> <json args>`.
    pub action: String,
    /// What was fed back to the model for this step: the tool result or error
    /// JSON, the "blocked" JSON, or a copy of the action for a final answer.
    pub observation: String,
    /// Wall-clock time spent on this step, in milliseconds. Defaults to 0 when
    /// the field is absent during deserialization.
    #[serde(default)]
    pub step_duration_ms: u64,
}
110
111impl ReActStep {
112 pub fn is_final_answer(&self) -> bool {
114 self.action.trim().to_ascii_uppercase().starts_with("FINAL_ANSWER")
115 }
116
117 pub fn is_tool_call(&self) -> bool {
119 !self.is_final_answer() && !self.action.trim().is_empty()
120 }
121}
122
/// Settings for a [`ReActLoop`]; build with [`AgentConfig::new`] and the
/// `with_*` methods.
#[derive(Debug, Clone)]
pub struct AgentConfig {
    /// Maximum number of loop iterations before the run fails.
    pub max_iterations: usize,
    /// Model identifier. Not read by the loop code in this file; presumably
    /// consumed by the caller's inference backend — confirm.
    pub model: String,
    /// Text placed at the start of the context sent to the model.
    pub system_prompt: String,
    /// Presumably the number of memories to recall (default 3). Not read in
    /// this file — NOTE(review): confirm the consumer.
    pub max_memory_recalls: usize,
    /// Optional memory token budget. Not read in this file — confirm consumer.
    pub max_memory_tokens: Option<usize>,
    /// Optional wall-clock budget for a whole run; checked before each
    /// iteration.
    pub loop_timeout: Option<std::time::Duration>,
    /// Optional sampling temperature. Not read in this file; presumably
    /// forwarded to the inference backend — confirm.
    pub temperature: Option<f32>,
    /// Optional generated-token cap. Not read in this file — confirm consumer.
    pub max_tokens: Option<usize>,
    /// Optional per-request timeout. Not read in this file — confirm consumer.
    pub request_timeout: Option<std::time::Duration>,
}
149
150impl AgentConfig {
151 pub fn new(max_iterations: usize, model: impl Into<String>) -> Self {
153 Self {
154 max_iterations,
155 model: model.into(),
156 system_prompt: "You are a helpful AI agent.".into(),
157 max_memory_recalls: 3,
158 max_memory_tokens: None,
159 loop_timeout: None,
160 temperature: None,
161 max_tokens: None,
162 request_timeout: None,
163 }
164 }
165
166 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
168 self.system_prompt = prompt.into();
169 self
170 }
171
172 pub fn with_max_memory_recalls(mut self, n: usize) -> Self {
174 self.max_memory_recalls = n;
175 self
176 }
177
178 pub fn with_max_memory_tokens(mut self, n: usize) -> Self {
180 self.max_memory_tokens = Some(n);
181 self
182 }
183
184 pub fn with_loop_timeout(mut self, d: std::time::Duration) -> Self {
189 self.loop_timeout = Some(d);
190 self
191 }
192
193 pub fn with_temperature(mut self, t: f32) -> Self {
195 self.temperature = Some(t);
196 self
197 }
198
199 pub fn with_max_tokens(mut self, n: usize) -> Self {
201 self.max_tokens = Some(n);
202 self
203 }
204
205 pub fn with_request_timeout(mut self, d: std::time::Duration) -> Self {
207 self.request_timeout = Some(d);
208 self
209 }
210}
211
/// A named tool the ReAct loop can invoke, plus optional argument checks.
pub struct ToolSpec {
    /// Name used to invoke the tool (the leading token of an `Action:` tool call).
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Type-erased async handler; invoked through [`ToolSpec::call`].
    pub(crate) handler: AsyncToolHandler,
    /// Keys that must be present in the JSON object arguments; checked by
    /// [`ToolRegistry::call`] before the handler runs.
    pub required_fields: Vec<String>,
    /// Custom argument validators run by [`ToolRegistry::call`] before the
    /// handler.
    pub validators: Vec<Box<dyn ToolValidator>>,
    /// Optional circuit breaker; while it reports `Open`,
    /// [`ToolRegistry::call`] rejects the call with `CircuitOpen`.
    #[cfg(feature = "orchestrator")]
    pub circuit_breaker: Option<Arc<crate::orchestrator::CircuitBreaker>>,
}
232
233impl std::fmt::Debug for ToolSpec {
234 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235 let mut s = f.debug_struct("ToolSpec");
236 s.field("name", &self.name)
237 .field("description", &self.description)
238 .field("required_fields", &self.required_fields);
239 #[cfg(feature = "orchestrator")]
240 s.field("has_circuit_breaker", &self.circuit_breaker.is_some());
241 s.finish()
242 }
243}
244
245impl ToolSpec {
246 pub fn new(
249 name: impl Into<String>,
250 description: impl Into<String>,
251 handler: impl Fn(Value) -> Value + Send + Sync + 'static,
252 ) -> Self {
253 Self {
254 name: name.into(),
255 description: description.into(),
256 handler: Box::new(move |args| {
257 let result = handler(args);
258 Box::pin(async move { result })
259 }),
260 required_fields: Vec::new(),
261 validators: Vec::new(),
262 #[cfg(feature = "orchestrator")]
263 circuit_breaker: None,
264 }
265 }
266
267 pub fn new_async(
269 name: impl Into<String>,
270 description: impl Into<String>,
271 handler: impl Fn(Value) -> AsyncToolFuture + Send + Sync + 'static,
272 ) -> Self {
273 Self {
274 name: name.into(),
275 description: description.into(),
276 handler: Box::new(handler),
277 required_fields: Vec::new(),
278 validators: Vec::new(),
279 #[cfg(feature = "orchestrator")]
280 circuit_breaker: None,
281 }
282 }
283
284 pub fn new_fallible(
287 name: impl Into<String>,
288 description: impl Into<String>,
289 handler: impl Fn(Value) -> Result<Value, String> + Send + Sync + 'static,
290 ) -> Self {
291 Self {
292 name: name.into(),
293 description: description.into(),
294 handler: Box::new(move |args| {
295 let result = handler(args);
296 let value = match result {
297 Ok(v) => v,
298 Err(msg) => serde_json::json!({"error": msg, "ok": false}),
299 };
300 Box::pin(async move { value })
301 }),
302 required_fields: Vec::new(),
303 validators: Vec::new(),
304 #[cfg(feature = "orchestrator")]
305 circuit_breaker: None,
306 }
307 }
308
309 pub fn new_async_fallible(
312 name: impl Into<String>,
313 description: impl Into<String>,
314 handler: impl Fn(Value) -> AsyncToolResultFuture + Send + Sync + 'static,
315 ) -> Self {
316 Self {
317 name: name.into(),
318 description: description.into(),
319 handler: Box::new(move |args| {
320 let fut = handler(args);
321 Box::pin(async move {
322 match fut.await {
323 Ok(v) => v,
324 Err(msg) => serde_json::json!({"error": msg, "ok": false}),
325 }
326 })
327 }),
328 required_fields: Vec::new(),
329 validators: Vec::new(),
330 #[cfg(feature = "orchestrator")]
331 circuit_breaker: None,
332 }
333 }
334
335 pub fn with_required_fields(mut self, fields: Vec<String>) -> Self {
337 self.required_fields = fields;
338 self
339 }
340
341 pub fn with_validators(mut self, validators: Vec<Box<dyn ToolValidator>>) -> Self {
346 self.validators = validators;
347 self
348 }
349
350 #[cfg(feature = "orchestrator")]
352 pub fn with_circuit_breaker(mut self, cb: Arc<crate::orchestrator::CircuitBreaker>) -> Self {
353 self.circuit_breaker = Some(cb);
354 self
355 }
356
357 pub async fn call(&self, args: Value) -> Value {
359 (self.handler)(args).await
360 }
361}
362
/// Pluggable store for tool results, consulted by [`ToolRegistry::call`].
///
/// Both methods take `&self`, so implementations that mutate state need
/// interior mutability; the trait requires `Send + Sync`.
pub trait ToolCache: Send + Sync {
    /// Returns a stored result for `tool_name` called with `args`, or `None`
    /// on a miss.
    fn get(&self, tool_name: &str, args: &serde_json::Value) -> Option<serde_json::Value>;
    /// Stores `result` for `tool_name` called with `args`.
    fn set(&self, tool_name: &str, args: &serde_json::Value, result: serde_json::Value);
}
394
/// Name-indexed collection of [`ToolSpec`]s with an optional result cache.
pub struct ToolRegistry {
    /// Registered tools keyed by their `name`.
    tools: HashMap<String, ToolSpec>,
    /// Optional cache consulted by [`ToolRegistry::call`].
    cache: Option<Arc<dyn ToolCache>>,
}
403
404impl std::fmt::Debug for ToolRegistry {
405 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
406 f.debug_struct("ToolRegistry")
407 .field("tools", &self.tools.keys().collect::<Vec<_>>())
408 .field("has_cache", &self.cache.is_some())
409 .finish()
410 }
411}
412
413impl Default for ToolRegistry {
414 fn default() -> Self {
415 Self::new()
416 }
417}
418
419impl ToolRegistry {
420 pub fn new() -> Self {
422 Self {
423 tools: HashMap::new(),
424 cache: None,
425 }
426 }
427
428 pub fn with_cache(mut self, cache: Arc<dyn ToolCache>) -> Self {
430 self.cache = Some(cache);
431 self
432 }
433
434 pub fn register(&mut self, spec: ToolSpec) {
436 self.tools.insert(spec.name.clone(), spec);
437 }
438
439 pub fn register_tools(&mut self, specs: impl IntoIterator<Item = ToolSpec>) {
446 for spec in specs {
447 self.register(spec);
448 }
449 }
450
451 pub fn with_tool(mut self, spec: ToolSpec) -> Self {
462 self.register(spec);
463 self
464 }
465
466 #[tracing::instrument(skip_all, fields(tool_name = %name))]
474 pub async fn call(&self, name: &str, args: Value) -> Result<Value, AgentRuntimeError> {
475 let spec = self.tools.get(name).ok_or_else(|| {
476 let mut suggestion = String::new();
477 let names = self.tool_names();
478 if !names.is_empty() {
479 if let Some((closest, dist)) = names
480 .iter()
481 .map(|n| (n, levenshtein(name, n)))
482 .min_by_key(|(_, d)| *d)
483 {
484 if dist <= 3 {
485 suggestion = format!(" (did you mean '{closest}'?)");
486 }
487 }
488 }
489 AgentRuntimeError::AgentLoop(format!("tool '{name}' not found{suggestion}"))
490 })?;
491
492 if !spec.required_fields.is_empty() {
494 if let Some(obj) = args.as_object() {
495 for field in &spec.required_fields {
496 if !obj.contains_key(field) {
497 return Err(AgentRuntimeError::AgentLoop(format!(
498 "tool '{}' missing required field '{}'",
499 name, field
500 )));
501 }
502 }
503 } else {
504 return Err(AgentRuntimeError::AgentLoop(format!(
505 "tool '{}' requires JSON object args, got {}",
506 name, args
507 )));
508 }
509 }
510
511 for validator in &spec.validators {
513 validator.validate(&args)?;
514 }
515
516 #[cfg(feature = "orchestrator")]
518 if let Some(ref cb) = spec.circuit_breaker {
519 use crate::orchestrator::CircuitState;
520 if let Ok(CircuitState::Open { .. }) = cb.state() {
521 return Err(AgentRuntimeError::CircuitOpen {
522 service: format!("tool:{}", name),
523 });
524 }
525 }
526
527 if let Some(ref cache) = self.cache {
529 if let Some(cached) = cache.get(name, &args) {
530 return Ok(cached);
531 }
532 }
533
534 let result = spec.call(args.clone()).await;
535
536 if let Some(ref cache) = self.cache {
538 cache.set(name, &args, result.clone());
539 }
540
541 Ok(result)
542 }
543
544 pub fn tool_names(&self) -> Vec<&str> {
546 self.tools.keys().map(|s| s.as_str()).collect()
547 }
548}
549
550pub fn parse_react_step(text: &str) -> Result<ReActStep, AgentRuntimeError> {
557 let mut thought = String::new();
558 let mut action = String::new();
559
560 for line in text.lines() {
561 let trimmed = line.trim();
562 let lower = trimmed.to_ascii_lowercase();
563 if lower.starts_with("thought") {
564 if let Some(colon_pos) = trimmed.find(':') {
565 thought = trimmed[colon_pos + 1..].trim().to_owned();
566 }
567 } else if lower.starts_with("action") {
568 if let Some(colon_pos) = trimmed.find(':') {
569 action = trimmed[colon_pos + 1..].trim().to_owned();
570 }
571 }
572 }
573
574 if thought.is_empty() && action.is_empty() {
575 return Err(AgentRuntimeError::AgentLoop(
576 "could not parse ReAct step from response".into(),
577 ));
578 }
579
580 Ok(ReActStep {
581 thought,
582 action,
583 observation: String::new(),
584 step_duration_ms: 0,
585 })
586}
587
/// Drives a Thought → Action → Observation loop against a caller-supplied
/// inference function, dispatching tool calls through a [`ToolRegistry`].
pub struct ReActLoop {
    /// Iteration cap, loop timeout and system prompt.
    config: AgentConfig,
    /// Tools available to the model.
    registry: ToolRegistry,
    /// Optional sink for tool-call, tool-failure and step-latency metrics.
    metrics: Option<Arc<RuntimeMetrics>>,
    /// Optional persistence backend plus session id used to checkpoint steps.
    #[cfg(feature = "persistence")]
    checkpoint_backend: Option<(Arc<dyn crate::persistence::PersistenceBackend>, String)>,
    /// Optional lifecycle callbacks.
    observer: Option<Arc<dyn Observer>>,
    /// Optional gate consulted before each tool call; `false` blocks the call.
    action_hook: Option<ActionHook>,
}
602
603impl std::fmt::Debug for ReActLoop {
604 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
605 let mut s = f.debug_struct("ReActLoop");
606 s.field("config", &self.config)
607 .field("registry", &self.registry)
608 .field("has_metrics", &self.metrics.is_some())
609 .field("has_observer", &self.observer.is_some())
610 .field("has_action_hook", &self.action_hook.is_some());
611 #[cfg(feature = "persistence")]
612 s.field("has_checkpoint_backend", &self.checkpoint_backend.is_some());
613 s.finish()
614 }
615}
616
617impl ReActLoop {
618 pub fn new(config: AgentConfig) -> Self {
620 Self {
621 config,
622 registry: ToolRegistry::new(),
623 metrics: None,
624 #[cfg(feature = "persistence")]
625 checkpoint_backend: None,
626 observer: None,
627 action_hook: None,
628 }
629 }
630
631 pub fn with_observer(mut self, observer: Arc<dyn Observer>) -> Self {
633 self.observer = Some(observer);
634 self
635 }
636
637 pub fn with_action_hook(mut self, hook: ActionHook) -> Self {
639 self.action_hook = Some(hook);
640 self
641 }
642
643 pub fn with_metrics(mut self, metrics: Arc<RuntimeMetrics>) -> Self {
648 self.metrics = Some(metrics);
649 self
650 }
651
652 #[cfg(feature = "persistence")]
658 pub fn with_step_checkpoint(
659 mut self,
660 backend: Arc<dyn crate::persistence::PersistenceBackend>,
661 session_id: impl Into<String>,
662 ) -> Self {
663 self.checkpoint_backend = Some((backend, session_id.into()));
664 self
665 }
666
667 pub fn register_tool(&mut self, spec: ToolSpec) {
669 self.registry.register(spec);
670 }
671
672 pub fn register_tools(&mut self, specs: impl IntoIterator<Item = ToolSpec>) {
678 for spec in specs {
679 self.registry.register(spec);
680 }
681 }
682
683 fn blocked_observation() -> String {
685 serde_json::json!({
686 "ok": false,
687 "error": "action blocked by reviewer",
688 "kind": "blocked"
689 })
690 .to_string()
691 }
692
693 fn error_observation(tool_name: &str, e: &AgentRuntimeError) -> String {
695 let _ = tool_name;
696 let kind = match e {
697 AgentRuntimeError::AgentLoop(msg) if msg.contains("not found") => "not_found",
698 #[cfg(feature = "orchestrator")]
699 AgentRuntimeError::CircuitOpen { .. } => "transient",
700 _ => "permanent",
701 };
702 serde_json::json!({ "ok": false, "error": e.to_string(), "kind": kind }).to_string()
703 }
704
705 #[tracing::instrument(skip(infer))]
719 pub async fn run<F, Fut>(
720 &self,
721 prompt: &str,
722 mut infer: F,
723 ) -> Result<Vec<ReActStep>, AgentRuntimeError>
724 where
725 F: FnMut(String) -> Fut,
726 Fut: Future<Output = String>,
727 {
728 let mut steps: Vec<ReActStep> = Vec::new();
729 let mut context = format!("{}\n\nUser: {}\n", self.config.system_prompt, prompt);
730
731 let deadline = self
733 .config
734 .loop_timeout
735 .map(|d| std::time::Instant::now() + d);
736
737 if let Some(ref obs) = self.observer {
739 obs.on_loop_start(prompt);
740 }
741
742 for iteration in 0..self.config.max_iterations {
743 if let Some(dl) = deadline {
745 if std::time::Instant::now() >= dl {
746 let ms = self
747 .config
748 .loop_timeout
749 .map(|d| d.as_millis())
750 .unwrap_or(0);
751 if let Some(ref obs) = self.observer {
752 obs.on_loop_end(steps.len());
753 }
754 return Err(AgentRuntimeError::AgentLoop(format!(
755 "loop timeout after {ms} ms"
756 )));
757 }
758 }
759
760 let step_start = std::time::Instant::now();
761 let response = infer(context.clone()).await;
762 let mut step = parse_react_step(&response)?;
763
764 tracing::debug!(
765 step = iteration,
766 thought = %step.thought,
767 action = %step.action,
768 "ReAct iteration"
769 );
770
771 if step.action.to_ascii_uppercase().starts_with("FINAL_ANSWER") {
772 step.observation = step.action.clone();
773 step.step_duration_ms = step_start.elapsed().as_millis() as u64;
774 if let Some(ref m) = self.metrics {
775 m.record_step_latency(step.step_duration_ms);
776 }
777 if let Some(ref obs) = self.observer {
778 obs.on_step(iteration, &step);
779 }
780 steps.push(step);
781 tracing::info!(step = iteration, "FINAL_ANSWER reached");
782 if let Some(ref obs) = self.observer {
783 obs.on_loop_end(steps.len());
784 }
785 return Ok(steps);
786 }
787
788 let (tool_name, args) = parse_tool_call(&step.action)?;
790
791 tracing::debug!(
792 step = iteration,
793 tool_name = %tool_name,
794 "dispatching tool call"
795 );
796
797 if let Some(ref hook) = self.action_hook {
799 if !hook(tool_name.clone(), args.clone()).await {
800 if let Some(ref obs) = self.observer {
801 obs.on_action_blocked(&tool_name, &args);
802 }
803 if let Some(ref m) = self.metrics {
804 m.record_tool_call(&tool_name);
805 m.record_tool_failure(&tool_name);
806 }
807 step.observation = Self::blocked_observation();
808 step.step_duration_ms = step_start.elapsed().as_millis() as u64;
809 if let Some(ref m) = self.metrics {
810 m.record_step_latency(step.step_duration_ms);
811 }
812 context.push_str(&format!(
813 "\nThought: {}\nAction: {}\nObservation: {}\n",
814 step.thought, step.action, step.observation
815 ));
816 if let Some(ref obs) = self.observer {
817 obs.on_step(iteration, &step);
818 }
819 steps.push(step);
820 continue;
821 }
822 }
823
824 if let Some(ref obs) = self.observer {
826 obs.on_tool_call(&tool_name, &args);
827 }
828
829 if let Some(ref m) = self.metrics {
831 m.record_tool_call(&tool_name);
832 }
833
834 let observation = match self.registry.call(&tool_name, args).await {
836 Ok(result) => serde_json::json!({ "ok": true, "data": result }).to_string(),
837 Err(e) => {
838 if let Some(ref m) = self.metrics {
840 m.record_tool_failure(&tool_name);
841 }
842 Self::error_observation(&tool_name, &e)
843 }
844 };
845
846 step.observation = observation.clone();
847 step.step_duration_ms = step_start.elapsed().as_millis() as u64;
848 if let Some(ref m) = self.metrics {
849 m.record_step_latency(step.step_duration_ms);
850 }
851 context.push_str(&format!(
852 "\nThought: {}\nAction: {}\nObservation: {}\n",
853 step.thought, step.action, observation
854 ));
855 if let Some(ref obs) = self.observer {
856 obs.on_step(iteration, &step);
857 }
858 steps.push(step);
859
860 #[cfg(feature = "persistence")]
862 if let Some((ref backend, ref session_id)) = self.checkpoint_backend {
863 let step_idx = steps.len();
864 let key = format!("loop:{session_id}:step:{step_idx}");
865 match serde_json::to_vec(&steps) {
866 Ok(bytes) => {
867 if let Err(e) = backend.save(&key, &bytes).await {
868 tracing::warn!(
869 key = %key,
870 error = %e,
871 "loop step checkpoint save failed"
872 );
873 }
874 }
875 Err(e) => {
876 tracing::warn!(
877 step = step_idx,
878 error = %e,
879 "loop step checkpoint serialisation failed"
880 );
881 }
882 }
883 }
884 }
885
886 let err = AgentRuntimeError::AgentLoop(format!(
887 "max iterations ({}) reached without final answer",
888 self.config.max_iterations
889 ));
890 tracing::warn!(
891 max_iterations = self.config.max_iterations,
892 "ReAct loop exhausted max iterations without FINAL_ANSWER"
893 );
894 if let Some(ref obs) = self.observer {
895 obs.on_loop_end(steps.len());
896 }
897 Err(err)
898 }
899
900 #[tracing::instrument(skip(infer_stream))]
910 pub async fn run_streaming<F, Fut>(
911 &self,
912 prompt: &str,
913 mut infer_stream: F,
914 ) -> Result<Vec<ReActStep>, AgentRuntimeError>
915 where
916 F: FnMut(String) -> Fut,
917 Fut: Future<
918 Output = tokio::sync::mpsc::Receiver<Result<String, AgentRuntimeError>>,
919 >,
920 {
921 self.run(prompt, move |ctx| {
922 let rx_fut = infer_stream(ctx);
923 async move {
924 let mut rx = rx_fut.await;
925 let mut out = String::new();
926 while let Some(chunk) = rx.recv().await {
927 match chunk {
928 Ok(s) => out.push_str(&s),
929 Err(e) => {
930 tracing::warn!(error = %e, "streaming chunk error; skipping");
931 }
932 }
933 }
934 out
935 }
936 })
937 .await
938 }
939}
940
/// Custom argument check run by [`ToolRegistry::call`] before a tool's
/// handler is invoked.
pub trait ToolValidator: Send + Sync {
    /// Returns `Err` to reject `args`; the error is propagated unchanged to
    /// the caller of `ToolRegistry::call`.
    fn validate(&self, args: &Value) -> Result<(), AgentRuntimeError>;
}
1013
/// Edit distance between `a` and `b`, counted in `char`s.
///
/// Insertions, deletions and substitutions each cost 1. Used to suggest the
/// closest registered tool name when a lookup fails.
fn levenshtein(a: &str, b: &str) -> usize {
    let target: Vec<char> = b.chars().collect();
    // `row[j]` is the distance between the prefix of `a` consumed so far and
    // the first `j` chars of `b`; only one DP row is kept at a time.
    let mut row: Vec<usize> = (0..=target.len()).collect();
    for (i, ca) in a.chars().enumerate() {
        let mut diag = row[0];
        row[0] = i + 1;
        for (j, &cb) in target.iter().enumerate() {
            let above = row[j + 1];
            row[j + 1] = if ca == cb {
                diag
            } else {
                1 + above.min(row[j]).min(diag)
            };
            diag = above;
        }
    }
    row[target.len()]
}
1039
1040fn parse_tool_call(action: &str) -> Result<(String, Value), AgentRuntimeError> {
1046 let mut parts = action.splitn(2, ' ');
1047 let name = parts.next().unwrap_or("").to_owned();
1048 if name.is_empty() {
1049 return Err(AgentRuntimeError::AgentLoop(
1050 "tool call has an empty tool name".into(),
1051 ));
1052 }
1053 let args_str = parts.next().unwrap_or("{}");
1054 let args: Value = serde_json::from_str(args_str).map_err(|e| {
1055 AgentRuntimeError::AgentLoop(format!(
1056 "invalid JSON args for tool call '{name}': {e} (raw: {args_str})"
1057 ))
1058 })?;
1059 Ok((name, args))
1060}
1061
/// Agent-loop error variants. Each converts into
/// [`AgentRuntimeError::AgentLoop`], carrying its `Display` text, via `From`.
#[derive(Debug, thiserror::Error)]
pub enum AgentError {
    /// Carries the name of the tool that was not found.
    #[error("Tool '{0}' not found")]
    ToolNotFound(String),
    /// Carries the iteration limit that was exceeded.
    #[error("Max iterations exceeded: {0}")]
    MaxIterations(usize),
    /// Carries a description of the parse failure.
    #[error("Parse error: {0}")]
    ParseError(String),
}
1077
1078impl From<AgentError> for AgentRuntimeError {
1079 fn from(e: AgentError) -> Self {
1080 AgentRuntimeError::AgentLoop(e.to_string())
1081 }
1082}
1083
/// Callbacks for observing a [`ReActLoop`] run.
///
/// Every method has a no-op default, so implementors override only what they
/// need.
pub trait Observer: Send + Sync {
    /// Called after each step is completed, with its zero-based iteration index.
    fn on_step(&self, step_index: usize, step: &ReActStep) {
        // Bind the parameters to silence unused-variable warnings.
        let _ = (step_index, step);
    }
    /// Called just before an allowed tool call is dispatched to the registry.
    fn on_tool_call(&self, tool_name: &str, args: &serde_json::Value) {
        let _ = (tool_name, args);
    }
    /// Called when the action hook blocks a tool call; the tool is not run.
    fn on_action_blocked(&self, tool_name: &str, args: &serde_json::Value) {
        let _ = (tool_name, args);
    }
    /// Called once at the start of a run with the user prompt.
    fn on_loop_start(&self, prompt: &str) {
        let _ = prompt;
    }
    /// Called when a run finishes (final answer, timeout, or iteration limit)
    /// with the number of recorded steps.
    fn on_loop_end(&self, step_count: usize) {
        let _ = step_count;
    }
}
1115
/// Structured form of a ReAct action string, produced by [`Action::parse`].
#[derive(Debug, Clone, PartialEq)]
pub enum Action {
    /// The model's final answer: the text after the `FINAL_ANSWER` marker.
    FinalAnswer(String),
    /// A request to invoke a tool.
    ToolCall {
        /// Tool name (the first token of the action).
        name: String,
        /// Parsed JSON arguments (`{}` when none were given).
        args: serde_json::Value,
    },
}
1131
1132impl Action {
1133 pub fn parse(s: &str) -> Result<Action, AgentRuntimeError> {
1138 if s.trim().to_ascii_uppercase().starts_with("FINAL_ANSWER") {
1139 let answer = s.trim()["FINAL_ANSWER".len()..].trim().to_owned();
1140 return Ok(Action::FinalAnswer(answer));
1141 }
1142 let (name, args) = parse_tool_call(s)?;
1143 Ok(Action::ToolCall { name, args })
1144 }
1145}
1146
1147pub type ActionHook = Arc<dyn Fn(String, serde_json::Value) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>> + Send + Sync>;
1166
1167pub fn make_action_hook<F, Fut>(f: F) -> ActionHook
1182where
1183 F: Fn(String, serde_json::Value) -> Fut + Send + Sync + 'static,
1184 Fut: std::future::Future<Output = bool> + Send + 'static,
1185{
1186 Arc::new(move |name, args| Box::pin(f(name, args)))
1187}
1188
1189#[cfg(test)]
1192mod tests {
1193 use super::*;
1194
1195 #[tokio::test]
1196 async fn test_final_answer_on_first_step() {
1197 let config = AgentConfig::new(5, "test-model");
1198 let loop_ = ReActLoop::new(config);
1199
1200 let steps = loop_
1201 .run("Say hello", |_ctx| async {
1202 "Thought: I will answer directly\nAction: FINAL_ANSWER hello".to_string()
1203 })
1204 .await
1205 .unwrap();
1206
1207 assert_eq!(steps.len(), 1);
1208 assert!(steps[0]
1209 .action
1210 .to_ascii_uppercase()
1211 .starts_with("FINAL_ANSWER"));
1212 }
1213
1214 #[tokio::test]
1215 async fn test_tool_call_then_final_answer() {
1216 let config = AgentConfig::new(5, "test-model");
1217 let mut loop_ = ReActLoop::new(config);
1218
1219 loop_.register_tool(ToolSpec::new("greet", "Greets someone", |_args| {
1220 serde_json::json!("hello!")
1221 }));
1222
1223 let mut call_count = 0;
1224 let steps = loop_
1225 .run("Say hello", |_ctx| {
1226 call_count += 1;
1227 let count = call_count;
1228 async move {
1229 if count == 1 {
1230 "Thought: I will greet\nAction: greet {}".to_string()
1231 } else {
1232 "Thought: done\nAction: FINAL_ANSWER done".to_string()
1233 }
1234 }
1235 })
1236 .await
1237 .unwrap();
1238
1239 assert_eq!(steps.len(), 2);
1240 assert_eq!(steps[0].action, "greet {}");
1241 assert!(steps[1]
1242 .action
1243 .to_ascii_uppercase()
1244 .starts_with("FINAL_ANSWER"));
1245 }
1246
1247 #[tokio::test]
1248 async fn test_max_iterations_exceeded() {
1249 let config = AgentConfig::new(2, "test-model");
1250 let loop_ = ReActLoop::new(config);
1251
1252 let result = loop_
1253 .run("loop forever", |_ctx| async {
1254 "Thought: thinking\nAction: noop {}".to_string()
1255 })
1256 .await;
1257
1258 assert!(result.is_err());
1259 let err = result.unwrap_err().to_string();
1260 assert!(err.contains("max iterations"));
1261 }
1262
1263 #[tokio::test]
1264 async fn test_parse_react_step_valid() {
1265 let text = "Thought: I should check\nAction: lookup {\"key\":\"val\"}";
1266 let step = parse_react_step(text).unwrap();
1267 assert_eq!(step.thought, "I should check");
1268 assert_eq!(step.action, "lookup {\"key\":\"val\"}");
1269 }
1270
1271 #[tokio::test]
1272 async fn test_parse_react_step_empty_fails() {
1273 let result = parse_react_step("no prefix lines here");
1274 assert!(result.is_err());
1275 }
1276
1277 #[tokio::test]
1278 async fn test_tool_not_found_returns_error_observation() {
1279 let config = AgentConfig::new(3, "test-model");
1280 let loop_ = ReActLoop::new(config);
1281
1282 let mut call_count = 0;
1283 let steps = loop_
1284 .run("test", |_ctx| {
1285 call_count += 1;
1286 let count = call_count;
1287 async move {
1288 if count == 1 {
1289 "Thought: try missing tool\nAction: missing_tool {}".to_string()
1290 } else {
1291 "Thought: done\nAction: FINAL_ANSWER done".to_string()
1292 }
1293 }
1294 })
1295 .await
1296 .unwrap();
1297
1298 assert_eq!(steps.len(), 2);
1299 assert!(steps[0].observation.contains("\"ok\":false"));
1300 }
1301
1302 #[tokio::test]
1303 async fn test_new_async_tool_spec() {
1304 let spec = ToolSpec::new_async("async_tool", "An async tool", |args| {
1305 Box::pin(async move { serde_json::json!({"echo": args}) })
1306 });
1307
1308 let result = spec.call(serde_json::json!({"input": "test"})).await;
1309 assert!(result.get("echo").is_some());
1310 }
1311
1312 #[tokio::test]
1315 async fn test_parse_react_step_case_insensitive() {
1316 let text = "THOUGHT: done\nACTION: FINAL_ANSWER";
1317 let step = parse_react_step(text).unwrap();
1318 assert_eq!(step.thought, "done");
1319 assert_eq!(step.action, "FINAL_ANSWER");
1320 }
1321
1322 #[tokio::test]
1323 async fn test_parse_react_step_space_before_colon() {
1324 let text = "Thought : done\nAction : go";
1325 let step = parse_react_step(text).unwrap();
1326 assert_eq!(step.thought, "done");
1327 assert_eq!(step.action, "go");
1328 }
1329
1330 #[tokio::test]
1333 async fn test_tool_required_fields_missing_returns_error() {
1334 let config = AgentConfig::new(3, "test-model");
1335 let mut loop_ = ReActLoop::new(config);
1336
1337 loop_.register_tool(
1338 ToolSpec::new(
1339 "search",
1340 "Searches for something",
1341 |args| serde_json::json!({ "result": args }),
1342 )
1343 .with_required_fields(vec!["q".to_string()]),
1344 );
1345
1346 let mut call_count = 0;
1347 let steps = loop_
1348 .run("test", |_ctx| {
1349 call_count += 1;
1350 let count = call_count;
1351 async move {
1352 if count == 1 {
1353 "Thought: searching\nAction: search {}".to_string()
1355 } else {
1356 "Thought: done\nAction: FINAL_ANSWER done".to_string()
1357 }
1358 }
1359 })
1360 .await
1361 .unwrap();
1362
1363 assert_eq!(steps.len(), 2);
1364 assert!(
1365 steps[0].observation.contains("missing required field"),
1366 "observation was: {}",
1367 steps[0].observation
1368 );
1369 }
1370
1371 #[tokio::test]
1374 async fn test_tool_error_observation_includes_kind() {
1375 let config = AgentConfig::new(3, "test-model");
1376 let loop_ = ReActLoop::new(config);
1377
1378 let mut call_count = 0;
1379 let steps = loop_
1380 .run("test", |_ctx| {
1381 call_count += 1;
1382 let count = call_count;
1383 async move {
1384 if count == 1 {
1385 "Thought: try missing\nAction: nonexistent_tool {}".to_string()
1386 } else {
1387 "Thought: done\nAction: FINAL_ANSWER done".to_string()
1388 }
1389 }
1390 })
1391 .await
1392 .unwrap();
1393
1394 assert_eq!(steps.len(), 2);
1395 let obs = &steps[0].observation;
1396 assert!(obs.contains("\"ok\":false"), "observation: {obs}");
1397 assert!(obs.contains("\"kind\":\"not_found\""), "observation: {obs}");
1398 }
1399
1400 #[tokio::test]
1403 async fn test_step_duration_ms_is_set() {
1404 let config = AgentConfig::new(5, "test-model");
1405 let loop_ = ReActLoop::new(config);
1406
1407 let steps = loop_
1408 .run("time it", |_ctx| async {
1409 "Thought: done\nAction: FINAL_ANSWER ok".to_string()
1410 })
1411 .await
1412 .unwrap();
1413
1414 let _ = steps[0].step_duration_ms; }
1417
1418 struct RequirePositiveN;
1421 impl ToolValidator for RequirePositiveN {
1422 fn validate(&self, args: &Value) -> Result<(), AgentRuntimeError> {
1423 let n = args.get("n").and_then(|v| v.as_i64()).unwrap_or(0);
1424 if n <= 0 {
1425 return Err(AgentRuntimeError::AgentLoop(
1426 "n must be a positive integer".into(),
1427 ));
1428 }
1429 Ok(())
1430 }
1431 }
1432
1433 #[tokio::test]
1434 async fn test_tool_validator_blocks_invalid_args() {
1435 let mut registry = ToolRegistry::new();
1436 registry.register(
1437 ToolSpec::new("calc", "compute", |args| serde_json::json!({"n": args}))
1438 .with_validators(vec![Box::new(RequirePositiveN)]),
1439 );
1440
1441 let result = registry
1443 .call("calc", serde_json::json!({"n": -1}))
1444 .await;
1445 assert!(result.is_err(), "validator should reject n=-1");
1446 assert!(result.unwrap_err().to_string().contains("positive integer"));
1447 }
1448
1449 #[tokio::test]
1450 async fn test_tool_validator_passes_valid_args() {
1451 let mut registry = ToolRegistry::new();
1452 registry.register(
1453 ToolSpec::new("calc", "compute", |_| serde_json::json!(42))
1454 .with_validators(vec![Box::new(RequirePositiveN)]),
1455 );
1456
1457 let result = registry
1458 .call("calc", serde_json::json!({"n": 5}))
1459 .await;
1460 assert!(result.is_ok(), "validator should accept n=5");
1461 }
1462
1463 #[tokio::test]
1466 async fn test_empty_tool_name_is_rejected() {
1467 let result = parse_tool_call("");
1469 assert!(result.is_err());
1470 assert!(
1471 result.unwrap_err().to_string().contains("empty tool name"),
1472 "expected 'empty tool name' error"
1473 );
1474 }
1475
1476 #[tokio::test]
1479 async fn test_register_tools_bulk() {
1480 let mut registry = ToolRegistry::new();
1481 registry.register_tools(vec![
1482 ToolSpec::new("tool_a", "A", |_| serde_json::json!("a")),
1483 ToolSpec::new("tool_b", "B", |_| serde_json::json!("b")),
1484 ]);
1485 assert!(registry.call("tool_a", serde_json::json!({})).await.is_ok());
1486 assert!(registry.call("tool_b", serde_json::json!({})).await.is_ok());
1487 }
1488
1489 #[tokio::test]
1492 async fn test_run_streaming_parity_with_run() {
1493 use tokio::sync::mpsc;
1494
1495 let config = AgentConfig::new(5, "test-model");
1496 let loop_ = ReActLoop::new(config);
1497
1498 let steps = loop_
1499 .run_streaming("Say hello", |_ctx| async {
1500 let (tx, rx) = mpsc::channel(4);
1501 tokio::spawn(async move {
1503 tx.send(Ok("Thought: done\n".to_string())).await.ok();
1504 tx.send(Ok("Action: FINAL_ANSWER hi".to_string())).await.ok();
1505 });
1506 rx
1507 })
1508 .await
1509 .unwrap();
1510
1511 assert_eq!(steps.len(), 1);
1512 assert!(steps[0]
1513 .action
1514 .to_ascii_uppercase()
1515 .starts_with("FINAL_ANSWER"));
1516 }
1517
1518 #[tokio::test]
1519 async fn test_run_streaming_error_chunk_is_skipped() {
1520 use tokio::sync::mpsc;
1521 use crate::error::AgentRuntimeError;
1522
1523 let config = AgentConfig::new(5, "test-model");
1524 let loop_ = ReActLoop::new(config);
1525
1526 let steps = loop_
1528 .run_streaming("test", |_ctx| async {
1529 let (tx, rx) = mpsc::channel(4);
1530 tokio::spawn(async move {
1531 tx.send(Err(AgentRuntimeError::Provider("stream error".into())))
1532 .await
1533 .ok();
1534 tx.send(Ok("Thought: recovered\nAction: FINAL_ANSWER ok".to_string()))
1535 .await
1536 .ok();
1537 });
1538 rx
1539 })
1540 .await
1541 .unwrap();
1542
1543 assert_eq!(steps.len(), 1);
1544 }
1545
/// A tool guarded by a freshly constructed circuit breaker should execute
/// normally; the test expects a new breaker to start in the closed state.
#[cfg(feature = "orchestrator")]
#[tokio::test]
async fn test_tool_with_circuit_breaker_passes_when_closed() {
    use std::sync::Arc;
    use std::time::Duration;

    let breaker = Arc::new(
        crate::orchestrator::CircuitBreaker::new("echo-tool", 5, Duration::from_secs(30))
            .unwrap(),
    );

    let echo = ToolSpec::new("echo", "Echoes args", |args| {
        serde_json::json!({ "echoed": args })
    })
    .with_circuit_breaker(breaker);

    let mut registry = ToolRegistry::new();
    registry.register(echo);

    let outcome = registry
        .call("echo", serde_json::json!({ "msg": "hi" }))
        .await;
    assert!(outcome.is_ok(), "expected Ok, got {:?}", outcome);
}
1580
/// Each `with_*` builder method on `AgentConfig` should populate the optional
/// field it names.
#[test]
fn test_agent_config_builder_methods_set_fields() {
    let timeout = std::time::Duration::from_secs(10);
    let cfg = AgentConfig::new(3, "model")
        .with_temperature(0.7)
        .with_max_tokens(512)
        .with_request_timeout(timeout);

    assert_eq!(cfg.temperature, Some(0.7));
    assert_eq!(cfg.max_tokens, Some(512));
    assert_eq!(cfg.request_timeout, Some(timeout));
}
1593
1594 #[tokio::test]
1597 async fn test_fallible_tool_returns_error_json_on_err() {
1598 let spec = ToolSpec::new_fallible(
1599 "fail",
1600 "always fails",
1601 |_| Err::<Value, String>("something went wrong".to_string()),
1602 );
1603 let result = spec.call(serde_json::json!({})).await;
1604 assert_eq!(result["ok"], serde_json::json!(false));
1605 assert_eq!(result["error"], serde_json::json!("something went wrong"));
1606 }
1607
1608 #[tokio::test]
1609 async fn test_fallible_tool_returns_value_on_ok() {
1610 let spec = ToolSpec::new_fallible(
1611 "succeed",
1612 "always succeeds",
1613 |_| Ok::<Value, String>(serde_json::json!(42)),
1614 );
1615 let result = spec.call(serde_json::json!({})).await;
1616 assert_eq!(result, serde_json::json!(42));
1617 }
1618
1619 #[tokio::test]
1622 async fn test_did_you_mean_suggestion_for_typo() {
1623 let mut registry = ToolRegistry::new();
1624 registry.register(ToolSpec::new("search", "search", |_| serde_json::json!("ok")));
1625 let result = registry.call("searc", serde_json::json!({})).await;
1626 assert!(result.is_err());
1627 let msg = result.unwrap_err().to_string();
1628 assert!(msg.contains("did you mean"), "expected suggestion in: {msg}");
1629 }
1630
1631 #[tokio::test]
1632 async fn test_no_suggestion_for_very_different_name() {
1633 let mut registry = ToolRegistry::new();
1634 registry.register(ToolSpec::new("search", "search", |_| serde_json::json!("ok")));
1635 let result = registry.call("xxxxxxxxxxxxxxx", serde_json::json!({})).await;
1636 assert!(result.is_err());
1637 let msg = result.unwrap_err().to_string();
1638 assert!(!msg.contains("did you mean"), "unexpected suggestion in: {msg}");
1639 }
1640
/// `FINAL_ANSWER <text>` parses to `Action::FinalAnswer` holding the text
/// after the keyword.
#[test]
fn test_action_parse_final_answer() {
    let expected = Action::FinalAnswer(String::from("hello world"));
    let parsed = Action::parse("FINAL_ANSWER hello world").unwrap();
    assert_eq!(parsed, expected);
}
1648
/// `<tool> <json>` parses to `Action::ToolCall` with the tool name and the
/// decoded JSON arguments.
#[test]
fn test_action_parse_tool_call() {
    let parsed = Action::parse(r#"search {"q": "rust"}"#).unwrap();
    if let Action::ToolCall { name, args } = parsed {
        assert_eq!(name, "search");
        assert_eq!(args["q"], "rust");
    } else {
        panic!("expected ToolCall");
    }
}
1660
/// An empty action string is rejected by `Action::parse`.
#[test]
fn test_action_parse_invalid_returns_err() {
    assert!(Action::parse("").is_err());
}
1666
1667 #[tokio::test]
1670 async fn test_observer_on_step_called_for_each_step() {
1671 use std::sync::{Arc, Mutex};
1672
1673 struct CountingObserver {
1674 step_count: Mutex<usize>,
1675 }
1676 impl Observer for CountingObserver {
1677 fn on_step(&self, _step_index: usize, _step: &ReActStep) {
1678 let mut c = self.step_count.lock().unwrap_or_else(|e| e.into_inner());
1679 *c += 1;
1680 }
1681 }
1682
1683 let obs = Arc::new(CountingObserver { step_count: Mutex::new(0) });
1684 let config = AgentConfig::new(5, "test-model");
1685 let mut loop_ = ReActLoop::new(config).with_observer(obs.clone() as Arc<dyn Observer>);
1686 loop_.register_tool(ToolSpec::new("noop", "noop", |_| serde_json::json!("ok")));
1687
1688 let mut call_count = 0;
1689 let _steps = loop_.run("test", |_ctx| {
1690 call_count += 1;
1691 let count = call_count;
1692 async move {
1693 if count == 1 {
1694 "Thought: call noop\nAction: noop {}".to_string()
1695 } else {
1696 "Thought: done\nAction: FINAL_ANSWER done".to_string()
1697 }
1698 }
1699 }).await.unwrap();
1700
1701 let count = *obs.step_count.lock().unwrap_or_else(|e| e.into_inner());
1702 assert_eq!(count, 2, "observer should have seen 2 steps");
1703 }
1704
1705 #[tokio::test]
1708 async fn test_tool_cache_returns_cached_result_on_second_call() {
1709 use std::collections::HashMap;
1710 use std::sync::Mutex;
1711
1712 struct InMemCache {
1713 map: Mutex<HashMap<String, Value>>,
1714 }
1715 impl ToolCache for InMemCache {
1716 fn get(&self, tool_name: &str, args: &Value) -> Option<Value> {
1717 let key = format!("{tool_name}:{args}");
1718 let map = self.map.lock().unwrap_or_else(|e| e.into_inner());
1719 map.get(&key).cloned()
1720 }
1721 fn set(&self, tool_name: &str, args: &Value, result: Value) {
1722 let key = format!("{tool_name}:{args}");
1723 let mut map = self.map.lock().unwrap_or_else(|e| e.into_inner());
1724 map.insert(key, result);
1725 }
1726 }
1727
1728 let call_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
1729 let call_count_clone = call_count.clone();
1730
1731 let cache = Arc::new(InMemCache { map: Mutex::new(HashMap::new()) });
1732 let registry = ToolRegistry::new()
1733 .with_cache(cache as Arc<dyn ToolCache>);
1734 let mut registry = registry;
1735
1736 registry.register(ToolSpec::new("count", "count calls", move |_| {
1737 call_count_clone.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1738 serde_json::json!({"calls": 1})
1739 }));
1740
1741 let args = serde_json::json!({});
1742 let r1 = registry.call("count", args.clone()).await.unwrap();
1743 let r2 = registry.call("count", args.clone()).await.unwrap();
1744
1745 assert_eq!(r1, r2);
1746 assert_eq!(call_count.load(std::sync::atomic::Ordering::Relaxed), 1);
1748 }
1749
1750 #[tokio::test]
1753 async fn test_validators_short_circuit_on_first_failure() {
1754 use std::sync::atomic::{AtomicUsize, Ordering as AOrdering};
1755 use std::sync::Arc;
1756
1757 let second_called = Arc::new(AtomicUsize::new(0));
1758 let second_called_clone = Arc::clone(&second_called);
1759
1760 struct AlwaysFail;
1761 impl ToolValidator for AlwaysFail {
1762 fn validate(&self, _args: &Value) -> Result<(), AgentRuntimeError> {
1763 Err(AgentRuntimeError::AgentLoop("first validator failed".into()))
1764 }
1765 }
1766
1767 struct CountCalls(Arc<AtomicUsize>);
1768 impl ToolValidator for CountCalls {
1769 fn validate(&self, _args: &Value) -> Result<(), AgentRuntimeError> {
1770 self.0.fetch_add(1, AOrdering::SeqCst);
1771 Ok(())
1772 }
1773 }
1774
1775 let mut registry = ToolRegistry::new();
1776 registry.register(
1777 ToolSpec::new("guarded", "A guarded tool", |args| args.clone())
1778 .with_validators(vec![
1779 Box::new(AlwaysFail),
1780 Box::new(CountCalls(second_called_clone)),
1781 ]),
1782 );
1783
1784 let result = registry.call("guarded", serde_json::json!({})).await;
1785 assert!(result.is_err(), "should fail due to first validator");
1786 assert_eq!(
1787 second_called.load(AOrdering::SeqCst),
1788 0,
1789 "second validator must not be called when first fails"
1790 );
1791 }
1792
1793 #[tokio::test]
1796 async fn test_loop_timeout_fires_between_iterations() {
1797 let mut config = AgentConfig::new(100, "test-model");
1798 config.loop_timeout = Some(std::time::Duration::from_millis(30));
1800 let loop_ = ReActLoop::new(config);
1801
1802 let result = loop_
1803 .run("test", |_ctx| async {
1804 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
1806 "Thought: still working\nAction: noop {}".to_string()
1808 })
1809 .await;
1810
1811 assert!(result.is_err(), "loop should time out");
1812 let msg = result.unwrap_err().to_string();
1813 assert!(msg.contains("loop timeout"), "unexpected error: {msg}");
1814 }
1815
/// A step whose action starts with `FINAL_ANSWER` is classified as a final
/// answer and not as a tool call.
#[test]
fn test_react_step_is_final_answer() {
    let step = ReActStep {
        action: String::from("FINAL_ANSWER done"),
        thought: String::new(),
        observation: String::new(),
        step_duration_ms: 0,
    };
    assert!(step.is_final_answer());
    assert!(
        !step.is_tool_call(),
        "a final answer must not also count as a tool call"
    );
}
1831
/// A step with a non-empty, non-`FINAL_ANSWER` action is classified as a tool
/// call.
#[test]
fn test_react_step_is_tool_call() {
    let step = ReActStep {
        action: String::from("search {}"),
        thought: String::new(),
        observation: String::new(),
        step_duration_ms: 0,
    };
    assert!(
        !step.is_final_answer(),
        "a tool invocation must not be treated as a final answer"
    );
    assert!(step.is_tool_call());
}
1843
/// Every `Role` variant renders as its lowercase wire name through `Display`.
#[test]
fn test_role_display() {
    let cases = [
        (Role::System, "system"),
        (Role::User, "user"),
        (Role::Assistant, "assistant"),
        (Role::Tool, "tool"),
    ];
    for (role, label) in cases {
        assert_eq!(role.to_string(), label);
    }
}
1853
/// `Message::role` and `Message::content` return what `Message::new` was
/// given.
#[test]
fn test_message_accessors() {
    let message = Message::new(Role::User, "hello");
    assert_eq!(*message.role(), Role::User);
    assert_eq!(message.content(), "hello");
}
1862
/// A step classified as a final answer parses back into
/// `Action::FinalAnswer` carrying the text after the keyword.
#[test]
fn test_action_parse_final_answer_round_trip() {
    let step = ReActStep {
        thought: String::from("done"),
        action: String::from("FINAL_ANSWER Paris"),
        observation: String::new(),
        step_duration_ms: 0,
    };
    assert!(step.is_final_answer());

    match Action::parse(&step.action).unwrap() {
        Action::FinalAnswer(answer) => assert_eq!(answer, "Paris"),
        _ => panic!("expected FinalAnswer"),
    }
}
1877
/// A step classified as a tool call parses back into `Action::ToolCall` with
/// the same tool name.
#[test]
fn test_action_parse_tool_call_round_trip() {
    let step = ReActStep {
        thought: String::from("searching"),
        action: String::from(r#"search {"q":"hello"}"#),
        observation: String::new(),
        step_duration_ms: 0,
    };
    assert!(step.is_tool_call());

    match Action::parse(&step.action).unwrap() {
        Action::ToolCall { name, .. } => assert_eq!(name, "search"),
        _ => panic!("expected ToolCall"),
    }
}
1890
1891 #[tokio::test]
1894 async fn test_observer_receives_correct_step_indices() {
1895 use std::sync::{Arc, Mutex};
1896
1897 struct IndexCollector(Arc<Mutex<Vec<usize>>>);
1898 impl Observer for IndexCollector {
1899 fn on_step(&self, step_index: usize, _step: &ReActStep) {
1900 self.0.lock().unwrap_or_else(|e| e.into_inner()).push(step_index);
1901 }
1902 }
1903
1904 let indices = Arc::new(Mutex::new(Vec::new()));
1905 let obs = Arc::new(IndexCollector(Arc::clone(&indices)));
1906
1907 let config = AgentConfig::new(5, "test");
1908 let mut loop_ = ReActLoop::new(config).with_observer(obs as Arc<dyn Observer>);
1909 loop_.register_tool(ToolSpec::new("noop", "no-op", |_| serde_json::json!({})));
1910
1911 let mut call_count = 0;
1912 loop_.run("test", |_ctx| {
1913 call_count += 1;
1914 let count = call_count;
1915 async move {
1916 if count == 1 {
1917 "Thought: step1\nAction: noop {}".to_string()
1918 } else {
1919 "Thought: done\nAction: FINAL_ANSWER ok".to_string()
1920 }
1921 }
1922 }).await.unwrap();
1923
1924 let collected = indices.lock().unwrap_or_else(|e| e.into_inner()).clone();
1925 assert_eq!(collected, vec![0, 1], "expected step indices 0 and 1");
1926 }
1927
1928 #[tokio::test]
1929 async fn test_action_hook_blocking_inserts_blocked_observation() {
1930 let hook: ActionHook = Arc::new(|_name, _args| {
1931 Box::pin(async move { false }) });
1933
1934 let config = AgentConfig::new(5, "test-model");
1935 let mut loop_ = ReActLoop::new(config).with_action_hook(hook);
1936 loop_.register_tool(ToolSpec::new("noop", "noop", |_| serde_json::json!("ok")));
1937
1938 let mut call_count = 0;
1939 let steps = loop_.run("test", |_ctx| {
1940 call_count += 1;
1941 let count = call_count;
1942 async move {
1943 if count == 1 {
1944 "Thought: try tool\nAction: noop {}".to_string()
1945 } else {
1946 "Thought: done\nAction: FINAL_ANSWER done".to_string()
1947 }
1948 }
1949 }).await.unwrap();
1950
1951 assert!(steps[0].observation.contains("blocked"), "expected blocked observation, got: {}", steps[0].observation);
1952 }
1953}