1use crate::error::AgentRuntimeError;
18use crate::metrics::RuntimeMetrics;
19use serde::{Deserialize, Serialize};
20use serde_json::Value;
21use std::collections::HashMap;
22use std::fmt::Write as FmtWrite;
23use std::future::Future;
24use std::pin::Pin;
25use std::sync::Arc;
26
/// Boxed, `Send` future that resolves to a tool's JSON result.
pub type AsyncToolFuture = Pin<Box<dyn Future<Output = Value> + Send>>;

/// Boxed, `Send` future that resolves to either a JSON result or an error message.
/// `ToolSpec::new_async_fallible` maps the `Err` case to `{"error": .., "ok": false}`.
pub type AsyncToolResultFuture = Pin<Box<dyn Future<Output = Result<Value, String>> + Send>>;

/// Type-erased async tool handler: takes JSON arguments and returns an [`AsyncToolFuture`].
pub type AsyncToolHandler = Box<dyn Fn(Value) -> AsyncToolFuture + Send + Sync>;
37
/// Author of a chat [`Message`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Role {
    /// The `system` role (`Display` renders it as `"system"`).
    System,
    /// The `user` role (`Display` renders it as `"user"`).
    User,
    /// The `assistant` role (`Display` renders it as `"assistant"`).
    Assistant,
    /// The `tool` role (`Display` renders it as `"tool"`).
    Tool,
}
50
/// A single chat message: who authored it and its text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Author of the message.
    pub role: Role,
    /// Message text.
    pub content: String,
}
59
60impl Message {
61 pub fn new(role: Role, content: impl Into<String>) -> Self {
67 Self {
68 role,
69 content: content.into(),
70 }
71 }
72
73 pub fn user(content: impl Into<String>) -> Self {
75 Self::new(Role::User, content)
76 }
77
78 pub fn assistant(content: impl Into<String>) -> Self {
80 Self::new(Role::Assistant, content)
81 }
82
83 pub fn system(content: impl Into<String>) -> Self {
85 Self::new(Role::System, content)
86 }
87
88 pub fn role(&self) -> &Role {
90 &self.role
91 }
92
93 pub fn content(&self) -> &str {
95 &self.content
96 }
97
98 pub fn is_user(&self) -> bool {
100 self.role == Role::User
101 }
102
103 pub fn is_assistant(&self) -> bool {
105 self.role == Role::Assistant
106 }
107
108 pub fn is_system(&self) -> bool {
110 self.role == Role::System
111 }
112
113 pub fn is_tool(&self) -> bool {
115 self.role == Role::Tool
116 }
117
118 pub fn is_empty(&self) -> bool {
120 self.content.is_empty()
121 }
122
123 pub fn word_count(&self) -> usize {
125 self.content.split_whitespace().count()
126 }
127}
128
129impl std::fmt::Display for Role {
130 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131 match self {
132 Role::System => write!(f, "system"),
133 Role::User => write!(f, "user"),
134 Role::Assistant => write!(f, "assistant"),
135 Role::Tool => write!(f, "tool"),
136 }
137 }
138}
139
/// One Thought / Action / Observation cycle recorded by the ReAct loop.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReActStep {
    /// The model's reasoning text for this step.
    pub thought: String,
    /// The chosen action: a tool call, or a line starting with `FINAL_ANSWER`.
    pub action: String,
    /// Result of the action (tool output / error JSON, or the final-answer text).
    pub observation: String,
    /// Wall-clock time spent on this step, in milliseconds.
    /// `#[serde(default)]` makes it 0 when the field is absent during deserialization.
    #[serde(default)]
    pub step_duration_ms: u64,
}
156
157impl ReActStep {
158 pub fn new(
163 thought: impl Into<String>,
164 action: impl Into<String>,
165 observation: impl Into<String>,
166 ) -> Self {
167 Self {
168 thought: thought.into(),
169 action: action.into(),
170 observation: observation.into(),
171 step_duration_ms: 0,
172 }
173 }
174
175 pub fn is_final_answer(&self) -> bool {
177 self.action.trim().to_ascii_uppercase().starts_with("FINAL_ANSWER")
178 }
179
180 pub fn is_tool_call(&self) -> bool {
182 !self.is_final_answer() && !self.action.trim().is_empty()
183 }
184
185 pub fn with_duration(mut self, ms: u64) -> Self {
190 self.step_duration_ms = ms;
191 self
192 }
193
194 pub fn is_empty(&self) -> bool {
196 self.thought.is_empty() && self.action.is_empty() && self.observation.is_empty()
197 }
198
199 pub fn observation_is_empty(&self) -> bool {
203 self.observation.is_empty()
204 }
205}
206
/// Tunables for an agent run; build with [`AgentConfig::new`] and the `with_*` methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
    /// Maximum number of iterations `ReActLoop::run` performs before giving up.
    pub max_iterations: usize,
    /// Model identifier; `ReActLoop::run` records it on each iteration's tracing span.
    pub model: String,
    /// Text placed at the top of the prompt context built by `ReActLoop::run`.
    pub system_prompt: String,
    /// Upper bound on memory recalls (3 by default).
    pub max_memory_recalls: usize,
    /// Optional token budget for recalled memories (unset by default).
    pub max_memory_tokens: Option<usize>,
    /// Wall-clock budget for `ReActLoop::run`, checked at the start of each iteration.
    pub loop_timeout: Option<std::time::Duration>,
    /// Optional sampling temperature.
    pub temperature: Option<f32>,
    /// Optional cap on generated tokens.
    pub max_tokens: Option<usize>,
    /// Optional per-request timeout.
    pub request_timeout: Option<std::time::Duration>,
    /// When set, `ReActLoop::run` drops older steps from its prompt context once the
    /// context exceeds this many bytes (`String::len`).
    pub max_context_chars: Option<usize>,
    /// Stop sequences (empty by default).
    pub stop_sequences: Vec<String>,
}
251
252impl AgentConfig {
253 pub fn new(max_iterations: usize, model: impl Into<String>) -> Self {
255 Self {
256 max_iterations,
257 model: model.into(),
258 system_prompt: "You are a helpful AI agent.".into(),
259 max_memory_recalls: 3,
260 max_memory_tokens: None,
261 loop_timeout: None,
262 temperature: None,
263 max_tokens: None,
264 request_timeout: None,
265 max_context_chars: None,
266 stop_sequences: vec![],
267 }
268 }
269
270 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
272 self.system_prompt = prompt.into();
273 self
274 }
275
276 pub fn with_max_memory_recalls(mut self, n: usize) -> Self {
278 self.max_memory_recalls = n;
279 self
280 }
281
282 pub fn with_max_memory_tokens(mut self, n: usize) -> Self {
284 self.max_memory_tokens = Some(n);
285 self
286 }
287
288 pub fn with_loop_timeout(mut self, d: std::time::Duration) -> Self {
293 self.loop_timeout = Some(d);
294 self
295 }
296
297 pub fn with_loop_timeout_secs(self, secs: u64) -> Self {
301 self.with_loop_timeout(std::time::Duration::from_secs(secs))
302 }
303
304 pub fn with_loop_timeout_ms(self, ms: u64) -> Self {
308 self.with_loop_timeout(std::time::Duration::from_millis(ms))
309 }
310
311 pub fn with_max_iterations(mut self, n: usize) -> Self {
313 self.max_iterations = n;
314 self
315 }
316
317 pub fn max_iterations(&self) -> usize {
319 self.max_iterations
320 }
321
322 pub fn with_temperature(mut self, t: f32) -> Self {
324 self.temperature = Some(t);
325 self
326 }
327
328 pub fn with_max_tokens(mut self, n: usize) -> Self {
330 self.max_tokens = Some(n);
331 self
332 }
333
334 pub fn with_request_timeout(mut self, d: std::time::Duration) -> Self {
336 self.request_timeout = Some(d);
337 self
338 }
339
340 pub fn with_request_timeout_secs(self, secs: u64) -> Self {
344 self.with_request_timeout(std::time::Duration::from_secs(secs))
345 }
346
347 pub fn with_request_timeout_ms(self, ms: u64) -> Self {
351 self.with_request_timeout(std::time::Duration::from_millis(ms))
352 }
353
354 pub fn with_max_context_chars(mut self, n: usize) -> Self {
360 self.max_context_chars = Some(n);
361 self
362 }
363
364 pub fn with_model(mut self, model: impl Into<String>) -> Self {
366 self.model = model.into();
367 self
368 }
369
370 pub fn clone_with_model(&self, model: impl Into<String>) -> Self {
375 let mut copy = self.clone();
376 copy.model = model.into();
377 copy
378 }
379
380 pub fn with_stop_sequences(mut self, sequences: Vec<String>) -> Self {
385 self.stop_sequences = sequences;
386 self
387 }
388
389 pub fn is_valid(&self) -> bool {
394 self.max_iterations >= 1 && !self.model.is_empty()
395 }
396
397 pub fn validate(&self) -> Result<(), crate::error::AgentRuntimeError> {
405 if self.max_iterations == 0 {
406 return Err(crate::error::AgentRuntimeError::AgentLoop(
407 "AgentConfig: max_iterations must be >= 1".into(),
408 ));
409 }
410 if self.model.is_empty() {
411 return Err(crate::error::AgentRuntimeError::AgentLoop(
412 "AgentConfig: model must not be empty".into(),
413 ));
414 }
415 Ok(())
416 }
417
418 pub fn has_loop_timeout(&self) -> bool {
420 self.loop_timeout.is_some()
421 }
422
423 pub fn has_stop_sequences(&self) -> bool {
425 !self.stop_sequences.is_empty()
426 }
427
428 pub fn stop_sequence_count(&self) -> usize {
430 self.stop_sequences.len()
431 }
432
433 pub fn is_single_shot(&self) -> bool {
439 self.max_iterations == 1
440 }
441
442 pub fn has_temperature(&self) -> bool {
444 self.temperature.is_some()
445 }
446
447 pub fn temperature(&self) -> Option<f32> {
449 self.temperature
450 }
451
452 pub fn max_tokens(&self) -> Option<usize> {
454 self.max_tokens
455 }
456
457 pub fn has_request_timeout(&self) -> bool {
459 self.request_timeout.is_some()
460 }
461
462 pub fn request_timeout(&self) -> Option<std::time::Duration> {
464 self.request_timeout
465 }
466
467 pub fn has_max_context_chars(&self) -> bool {
469 self.max_context_chars.is_some()
470 }
471
472 pub fn max_context_chars(&self) -> Option<usize> {
474 self.max_context_chars
475 }
476
477 pub fn remaining_iterations_after(&self, n: usize) -> usize {
482 self.max_iterations.saturating_sub(n)
483 }
484
485 pub fn system_prompt(&self) -> &str {
487 &self.system_prompt
488 }
489
490 pub fn model(&self) -> &str {
492 &self.model
493 }
494}
495
/// A callable tool: its name, description, async handler, and the argument checks
/// `ToolRegistry::call` runs before invoking the handler.
pub struct ToolSpec {
    /// Tool name; `ToolRegistry::register` uses it as the registry key.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Type-erased async handler; invoked through `ToolSpec::call`.
    pub(crate) handler: AsyncToolHandler,
    /// Keys that must be present in the (object) arguments; checked by `ToolRegistry::call`.
    pub required_fields: Vec<String>,
    /// Custom argument validators run in order by `ToolRegistry::call`; the first error aborts the call.
    pub validators: Vec<Box<dyn ToolValidator>>,
    /// Optional circuit breaker: when open, `ToolRegistry::call` fails fast; each call's
    /// outcome is recorded on it.
    #[cfg(feature = "orchestrator")]
    pub circuit_breaker: Option<Arc<crate::orchestrator::CircuitBreaker>>,
}
516
517impl std::fmt::Debug for ToolSpec {
518 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
519 let mut s = f.debug_struct("ToolSpec");
520 s.field("name", &self.name)
521 .field("description", &self.description)
522 .field("required_fields", &self.required_fields);
523 #[cfg(feature = "orchestrator")]
524 s.field("has_circuit_breaker", &self.circuit_breaker.is_some());
525 s.finish()
526 }
527}
528
529impl ToolSpec {
530 pub fn new(
533 name: impl Into<String>,
534 description: impl Into<String>,
535 handler: impl Fn(Value) -> Value + Send + Sync + 'static,
536 ) -> Self {
537 Self {
538 name: name.into(),
539 description: description.into(),
540 handler: Box::new(move |args| {
541 let result = handler(args);
542 Box::pin(async move { result })
543 }),
544 required_fields: Vec::new(),
545 validators: Vec::new(),
546 #[cfg(feature = "orchestrator")]
547 circuit_breaker: None,
548 }
549 }
550
551 pub fn new_async(
553 name: impl Into<String>,
554 description: impl Into<String>,
555 handler: impl Fn(Value) -> AsyncToolFuture + Send + Sync + 'static,
556 ) -> Self {
557 Self {
558 name: name.into(),
559 description: description.into(),
560 handler: Box::new(handler),
561 required_fields: Vec::new(),
562 validators: Vec::new(),
563 #[cfg(feature = "orchestrator")]
564 circuit_breaker: None,
565 }
566 }
567
568 pub fn new_fallible(
571 name: impl Into<String>,
572 description: impl Into<String>,
573 handler: impl Fn(Value) -> Result<Value, String> + Send + Sync + 'static,
574 ) -> Self {
575 Self {
576 name: name.into(),
577 description: description.into(),
578 handler: Box::new(move |args| {
579 let result = handler(args);
580 let value = match result {
581 Ok(v) => v,
582 Err(msg) => serde_json::json!({"error": msg, "ok": false}),
583 };
584 Box::pin(async move { value })
585 }),
586 required_fields: Vec::new(),
587 validators: Vec::new(),
588 #[cfg(feature = "orchestrator")]
589 circuit_breaker: None,
590 }
591 }
592
593 pub fn new_async_fallible(
596 name: impl Into<String>,
597 description: impl Into<String>,
598 handler: impl Fn(Value) -> AsyncToolResultFuture + Send + Sync + 'static,
599 ) -> Self {
600 Self {
601 name: name.into(),
602 description: description.into(),
603 handler: Box::new(move |args| {
604 let fut = handler(args);
605 Box::pin(async move {
606 match fut.await {
607 Ok(v) => v,
608 Err(msg) => serde_json::json!({"error": msg, "ok": false}),
609 }
610 })
611 }),
612 required_fields: Vec::new(),
613 validators: Vec::new(),
614 #[cfg(feature = "orchestrator")]
615 circuit_breaker: None,
616 }
617 }
618
619 pub fn with_required_fields(
625 mut self,
626 fields: impl IntoIterator<Item = impl Into<String>>,
627 ) -> Self {
628 self.required_fields = fields.into_iter().map(Into::into).collect();
629 self
630 }
631
632 pub fn with_validators(mut self, validators: Vec<Box<dyn ToolValidator>>) -> Self {
637 self.validators = validators;
638 self
639 }
640
641 #[cfg(feature = "orchestrator")]
643 pub fn with_circuit_breaker(mut self, cb: Arc<crate::orchestrator::CircuitBreaker>) -> Self {
644 self.circuit_breaker = Some(cb);
645 self
646 }
647
648 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
650 self.description = desc.into();
651 self
652 }
653
654 pub fn with_name(mut self, name: impl Into<String>) -> Self {
656 self.name = name.into();
657 self
658 }
659
660 pub fn required_field_count(&self) -> usize {
662 self.required_fields.len()
663 }
664
665 pub fn has_required_fields(&self) -> bool {
667 !self.required_fields.is_empty()
668 }
669
670 pub fn has_validators(&self) -> bool {
672 !self.validators.is_empty()
673 }
674
675 pub async fn call(&self, args: Value) -> Value {
677 (self.handler)(args).await
678 }
679}
680
/// Pluggable store for tool results, consulted by `ToolRegistry::call`.
///
/// Implementations must be `Send + Sync`.
pub trait ToolCache: Send + Sync {
    /// Returns the stored result for `tool_name` called with `args`, if any.
    fn get(&self, tool_name: &str, args: &serde_json::Value) -> Option<serde_json::Value>;
    /// Records `result` as the outcome of calling `tool_name` with `args`.
    fn set(&self, tool_name: &str, args: &serde_json::Value, result: serde_json::Value);
}
712
/// Mutex-guarded state of [`InMemoryToolCache`].
struct ToolCacheInner {
    /// Cached results keyed by `(tool_name, args.to_string())`.
    map: HashMap<(String, String), serde_json::Value>,
    /// Keys in first-insertion order; the front is evicted first when the cache is bounded.
    order: std::collections::VecDeque<(String, String)>,
}
721
/// In-process [`ToolCache`] keyed by tool name plus serialized arguments, with an optional
/// bound on the number of entries.
pub struct InMemoryToolCache {
    // Map and insertion-order queue, guarded together by one mutex.
    inner: std::sync::Mutex<ToolCacheInner>,
    // Maximum number of entries, or `None` for unbounded.
    max_entries: Option<usize>,
}
740
741impl std::fmt::Debug for InMemoryToolCache {
742 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
743 let len = self.len();
744 f.debug_struct("InMemoryToolCache")
745 .field("entries", &len)
746 .field("max_entries", &self.max_entries)
747 .finish()
748 }
749}
750
751impl InMemoryToolCache {
752 pub fn new() -> Self {
754 Self {
755 inner: std::sync::Mutex::new(ToolCacheInner {
756 map: HashMap::new(),
757 order: std::collections::VecDeque::new(),
758 }),
759 max_entries: None,
760 }
761 }
762
763 pub fn with_max_entries(max: usize) -> Self {
765 Self {
766 inner: std::sync::Mutex::new(ToolCacheInner {
767 map: HashMap::new(),
768 order: std::collections::VecDeque::new(),
769 }),
770 max_entries: Some(max),
771 }
772 }
773
774 pub fn clear(&self) {
776 if let Ok(mut inner) = self.inner.lock() {
777 inner.map.clear();
778 inner.order.clear();
779 }
780 }
781
782 pub fn len(&self) -> usize {
784 self.inner.lock().map(|s| s.map.len()).unwrap_or(0)
785 }
786
787 pub fn is_empty(&self) -> bool {
789 self.len() == 0
790 }
791
792 pub fn contains(&self, tool_name: &str, args: &serde_json::Value) -> bool {
794 let key = (tool_name.to_owned(), args.to_string());
795 self.inner
796 .lock()
797 .map(|s| s.map.contains_key(&key))
798 .unwrap_or(false)
799 }
800
801 pub fn remove(&self, tool_name: &str, args: &serde_json::Value) -> bool {
803 let key = (tool_name.to_owned(), args.to_string());
804 if let Ok(mut inner) = self.inner.lock() {
805 if inner.map.remove(&key).is_some() {
806 inner.order.retain(|k| k != &key);
807 return true;
808 }
809 }
810 false
811 }
812
813 pub fn capacity(&self) -> Option<usize> {
815 self.max_entries
816 }
817}
818
819impl Default for InMemoryToolCache {
820 fn default() -> Self {
821 Self::new()
822 }
823}
824
825impl ToolCache for InMemoryToolCache {
826 fn get(&self, tool_name: &str, args: &serde_json::Value) -> Option<serde_json::Value> {
827 let key = (tool_name.to_owned(), args.to_string());
828 self.inner.lock().ok()?.map.get(&key).cloned()
829 }
830
831 fn set(&self, tool_name: &str, args: &serde_json::Value, result: serde_json::Value) {
832 let key = (tool_name.to_owned(), args.to_string());
833 if let Ok(mut inner) = self.inner.lock() {
834 if !inner.map.contains_key(&key) {
835 inner.order.push_back(key.clone());
836 }
837 inner.map.insert(key, result);
838 if let Some(max) = self.max_entries {
839 while inner.map.len() > max {
840 if let Some(oldest) = inner.order.pop_front() {
841 inner.map.remove(&oldest);
842 }
843 }
844 }
845 }
846 }
847}
848
/// Name-indexed collection of [`ToolSpec`]s with an optional result cache.
pub struct ToolRegistry {
    /// Tools keyed by name.
    tools: HashMap<String, ToolSpec>,
    /// Optional cache consulted by `ToolRegistry::call`.
    cache: Option<Arc<dyn ToolCache>>,
}
857
858impl std::fmt::Debug for ToolRegistry {
859 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
860 f.debug_struct("ToolRegistry")
861 .field("tools", &self.tools.keys().collect::<Vec<_>>())
862 .field("has_cache", &self.cache.is_some())
863 .finish()
864 }
865}
866
867impl Default for ToolRegistry {
868 fn default() -> Self {
869 Self::new()
870 }
871}
872
873impl ToolRegistry {
874 pub fn new() -> Self {
876 Self {
877 tools: HashMap::new(),
878 cache: None,
879 }
880 }
881
882 pub fn with_cache(mut self, cache: Arc<dyn ToolCache>) -> Self {
884 self.cache = Some(cache);
885 self
886 }
887
888 pub fn register(&mut self, spec: ToolSpec) {
890 self.tools.insert(spec.name.clone(), spec);
891 }
892
893 pub fn register_tools(&mut self, specs: impl IntoIterator<Item = ToolSpec>) {
900 for spec in specs {
901 self.register(spec);
902 }
903 }
904
905 pub fn with_tool(mut self, spec: ToolSpec) -> Self {
916 self.register(spec);
917 self
918 }
919
920 #[tracing::instrument(skip_all, fields(tool_name = %name))]
928 pub async fn call(&self, name: &str, args: Value) -> Result<Value, AgentRuntimeError> {
929 let spec = self.tools.get(name).ok_or_else(|| {
930 let mut suggestion = String::new();
931 let names = self.tool_names();
932 if !names.is_empty() {
933 if let Some((closest, dist)) = names
934 .iter()
935 .map(|n| (n, levenshtein(name, n)))
936 .min_by_key(|(_, d)| *d)
937 {
938 if dist <= 3 {
939 suggestion = format!(" (did you mean '{closest}'?)");
940 }
941 }
942 }
943 AgentRuntimeError::AgentLoop(format!("tool '{name}' not found{suggestion}"))
944 })?;
945
946 if !spec.required_fields.is_empty() {
948 if let Some(obj) = args.as_object() {
949 for field in &spec.required_fields {
950 if !obj.contains_key(field) {
951 return Err(AgentRuntimeError::AgentLoop(format!(
952 "tool '{}' missing required field '{}'",
953 name, field
954 )));
955 }
956 }
957 } else {
958 return Err(AgentRuntimeError::AgentLoop(format!(
959 "tool '{}' requires JSON object args, got {}",
960 name, args
961 )));
962 }
963 }
964
965 for validator in &spec.validators {
967 validator.validate(&args)?;
968 }
969
970 #[cfg(feature = "orchestrator")]
972 if let Some(ref cb) = spec.circuit_breaker {
973 use crate::orchestrator::CircuitState;
974 if let Ok(CircuitState::Open { .. }) = cb.state() {
975 return Err(AgentRuntimeError::CircuitOpen {
976 service: format!("tool:{}", name),
977 });
978 }
979 }
980
981 if let Some(ref cache) = self.cache {
983 if let Some(cached) = cache.get(name, &args) {
984 return Ok(cached);
985 }
986 }
987
988 let result = spec.call(args.clone()).await;
989
990 #[cfg(feature = "orchestrator")]
994 if let Some(ref cb) = spec.circuit_breaker {
995 let is_failure = result
996 .get("ok")
997 .and_then(|v| v.as_bool())
998 .is_some_and(|ok| !ok);
999 if is_failure {
1000 cb.record_failure();
1001 } else {
1002 cb.record_success();
1003 }
1004 }
1005
1006 if let Some(ref cache) = self.cache {
1008 cache.set(name, &args, result.clone());
1009 }
1010
1011 Ok(result)
1012 }
1013
1014 pub fn get(&self, name: &str) -> Option<&ToolSpec> {
1016 self.tools.get(name)
1017 }
1018
1019 pub fn has_tool(&self, name: &str) -> bool {
1021 self.tools.contains_key(name)
1022 }
1023
1024 pub fn unregister(&mut self, name: &str) -> bool {
1026 self.tools.remove(name).is_some()
1027 }
1028
1029 pub fn tool_names(&self) -> Vec<&str> {
1031 self.tools.keys().map(|s| s.as_str()).collect()
1032 }
1033
1034 pub fn tool_names_owned(&self) -> Vec<String> {
1041 self.tools.keys().cloned().collect()
1042 }
1043
1044 pub fn all_tool_names(&self) -> Vec<String> {
1048 let mut names: Vec<String> = self.tools.keys().cloned().collect();
1049 names.sort();
1050 names
1051 }
1052
1053 pub fn tool_specs(&self) -> Vec<&ToolSpec> {
1055 self.tools.values().collect()
1056 }
1057
1058 pub fn filter_tools<F: Fn(&ToolSpec) -> bool>(&self, pred: F) -> Vec<&ToolSpec> {
1067 self.tools.values().filter(|s| pred(s)).collect()
1068 }
1069
1070 pub fn rename_tool(&mut self, old_name: &str, new_name: impl Into<String>) -> bool {
1077 let Some(mut spec) = self.tools.remove(old_name) else {
1078 return false;
1079 };
1080 let new_name = new_name.into();
1081 spec.name = new_name.clone();
1082 self.tools.insert(new_name, spec);
1083 true
1084 }
1085
1086 pub fn tool_count(&self) -> usize {
1088 self.tools.len()
1089 }
1090
1091 pub fn is_empty(&self) -> bool {
1093 self.tools.is_empty()
1094 }
1095
1096 pub fn clear(&mut self) {
1101 self.tools.clear();
1102 }
1103
1104 pub fn remove(&mut self, name: &str) -> Option<ToolSpec> {
1108 self.tools.remove(name)
1109 }
1110
1111 pub fn contains(&self, name: &str) -> bool {
1113 self.tools.contains_key(name)
1114 }
1115
1116 pub fn descriptions(&self) -> Vec<(&str, &str)> {
1120 let mut pairs: Vec<(&str, &str)> = self
1121 .tools
1122 .values()
1123 .map(|s| (s.name.as_str(), s.description.as_str()))
1124 .collect();
1125 pairs.sort_unstable_by_key(|(name, _)| *name);
1126 pairs
1127 }
1128
1129 pub fn find_by_description_keyword(&self, keyword: &str) -> Vec<&ToolSpec> {
1132 let lower = keyword.to_ascii_lowercase();
1133 self.tools
1134 .values()
1135 .filter(|s| s.description.to_ascii_lowercase().contains(&lower))
1136 .collect()
1137 }
1138
1139 pub fn tool_count_with_required_fields(&self) -> usize {
1141 self.tools.values().filter(|s| s.has_required_fields()).count()
1142 }
1143
1144 pub fn names(&self) -> Vec<&str> {
1146 let mut names: Vec<&str> = self.tools.keys().map(|k| k.as_str()).collect();
1147 names.sort_unstable();
1148 names
1149 }
1150
1151 pub fn tool_names_starting_with(&self, prefix: &str) -> Vec<&str> {
1154 let mut names: Vec<&str> = self
1155 .tools
1156 .keys()
1157 .filter(|k| k.starts_with(prefix))
1158 .map(|k| k.as_str())
1159 .collect();
1160 names.sort_unstable();
1161 names
1162 }
1163
1164 pub fn description_for(&self, name: &str) -> Option<&str> {
1167 self.tools.get(name).map(|s| s.description.as_str())
1168 }
1169
1170 pub fn count_with_description_containing(&self, keyword: &str) -> usize {
1173 let lower = keyword.to_ascii_lowercase();
1174 self.tools
1175 .values()
1176 .filter(|s| s.description.to_ascii_lowercase().contains(&lower))
1177 .count()
1178 }
1179
1180 pub fn unregister_all(&mut self) {
1182 self.tools.clear();
1183 }
1184
1185 pub fn names_containing(&self, substring: &str) -> Vec<&str> {
1187 let sub = substring.to_ascii_lowercase();
1188 let mut names: Vec<&str> = self
1189 .tools
1190 .keys()
1191 .filter(|name| name.to_ascii_lowercase().contains(&sub))
1192 .map(|s| s.as_str())
1193 .collect();
1194 names.sort_unstable();
1195 names
1196 }
1197
1198 pub fn shortest_description(&self) -> Option<&str> {
1202 self.tools
1203 .values()
1204 .min_by_key(|s| s.description.len())
1205 .map(|s| s.description.as_str())
1206 }
1207
1208 pub fn longest_description(&self) -> Option<&str> {
1212 self.tools
1213 .values()
1214 .max_by_key(|s| s.description.len())
1215 .map(|s| s.description.as_str())
1216 }
1217
1218 pub fn all_descriptions(&self) -> Vec<&str> {
1220 let mut descs: Vec<&str> = self.tools.values().map(|s| s.description.as_str()).collect();
1221 descs.sort_unstable();
1222 descs
1223 }
1224
1225 pub fn tool_names_with_keyword(&self, keyword: &str) -> Vec<&str> {
1227 let kw = keyword.to_ascii_lowercase();
1228 self.tools
1229 .values()
1230 .filter(|s| s.description.to_ascii_lowercase().contains(&kw))
1231 .map(|s| s.name.as_str())
1232 .collect()
1233 }
1234
1235 pub fn avg_description_length(&self) -> f64 {
1239 if self.tools.is_empty() {
1240 return 0.0;
1241 }
1242 let total: usize = self.tools.values().map(|s| s.description.len()).sum();
1243 total as f64 / self.tools.len() as f64
1244 }
1245
1246 pub fn tool_names_sorted(&self) -> Vec<&str> {
1248 let mut names: Vec<&str> = self.tools.keys().map(|k| k.as_str()).collect();
1249 names.sort_unstable();
1250 names
1251 }
1252
1253 pub fn description_contains_count(&self, keyword: &str) -> usize {
1255 let kw = keyword.to_ascii_lowercase();
1256 self.tools
1257 .values()
1258 .filter(|s| s.description.to_ascii_lowercase().contains(&kw))
1259 .count()
1260 }
1261}
1262
1263pub fn parse_react_step(text: &str) -> Result<ReActStep, AgentRuntimeError> {
1274 #[derive(PartialEq)]
1276 enum Section { None, Thought, Action }
1277
1278 let mut thought_lines: Vec<&str> = Vec::new();
1279 let mut action_lines: Vec<&str> = Vec::new();
1280 let mut current = Section::None;
1281
1282 for line in text.lines() {
1283 let trimmed = line.trim();
1284 let lower = trimmed.to_ascii_lowercase();
1285 if lower.starts_with("thought") {
1286 if let Some(colon_pos) = trimmed.find(':') {
1287 current = Section::Thought;
1288 thought_lines.clear();
1289 let first = trimmed[colon_pos + 1..].trim();
1290 if !first.is_empty() {
1291 thought_lines.push(first);
1292 }
1293 continue;
1294 }
1295 } else if lower.starts_with("action") {
1296 if let Some(colon_pos) = trimmed.find(':') {
1297 current = Section::Action;
1298 action_lines.clear();
1299 let first = trimmed[colon_pos + 1..].trim();
1300 if !first.is_empty() {
1301 action_lines.push(first);
1302 }
1303 continue;
1304 }
1305 } else if lower.starts_with("observation") {
1306 current = Section::None;
1308 continue;
1309 }
1310 match current {
1312 Section::Thought => thought_lines.push(trimmed),
1313 Section::Action => action_lines.push(trimmed),
1314 Section::None => {}
1315 }
1316 }
1317
1318 let thought = thought_lines.join(" ");
1319 let action = action_lines.join("\n").trim().to_owned();
1320
1321 if thought.is_empty() && action.is_empty() {
1322 return Err(AgentRuntimeError::AgentLoop(
1323 "could not parse ReAct step from response".into(),
1324 ));
1325 }
1326
1327 Ok(ReActStep {
1328 thought,
1329 action,
1330 observation: String::new(),
1331 step_duration_ms: 0,
1332 })
1333}
1334
/// Drives the Thought → Action → Observation loop against a caller-supplied inference
/// function and a registry of tools.
pub struct ReActLoop {
    /// Iteration budget, prompts, timeouts and context limits.
    config: AgentConfig,
    /// Tools that actions are dispatched to.
    registry: ToolRegistry,
    /// Optional metrics sink for tool calls, tool failures and step latency.
    metrics: Option<Arc<RuntimeMetrics>>,
    /// Optional persistence backend plus session id used to checkpoint steps.
    #[cfg(feature = "persistence")]
    checkpoint_backend: Option<(Arc<dyn crate::persistence::PersistenceBackend>, String)>,
    /// Optional observer notified of loop, step, tool-call and error events.
    observer: Option<Arc<dyn Observer>>,
    /// Optional gate awaited before each tool call; a `false` result blocks the action.
    action_hook: Option<ActionHook>,
}
1349
1350impl std::fmt::Debug for ReActLoop {
1351 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1352 let mut s = f.debug_struct("ReActLoop");
1353 s.field("config", &self.config)
1354 .field("registry", &self.registry)
1355 .field("has_metrics", &self.metrics.is_some())
1356 .field("has_observer", &self.observer.is_some())
1357 .field("has_action_hook", &self.action_hook.is_some());
1358 #[cfg(feature = "persistence")]
1359 s.field("has_checkpoint_backend", &self.checkpoint_backend.is_some());
1360 s.finish()
1361 }
1362}
1363
1364impl ReActLoop {
1365 pub fn new(config: AgentConfig) -> Self {
1367 Self {
1368 config,
1369 registry: ToolRegistry::new(),
1370 metrics: None,
1371 #[cfg(feature = "persistence")]
1372 checkpoint_backend: None,
1373 observer: None,
1374 action_hook: None,
1375 }
1376 }
1377
1378 pub fn with_observer(mut self, observer: Arc<dyn Observer>) -> Self {
1380 self.observer = Some(observer);
1381 self
1382 }
1383
1384 pub fn with_action_hook(mut self, hook: ActionHook) -> Self {
1386 self.action_hook = Some(hook);
1387 self
1388 }
1389
1390 pub fn with_metrics(mut self, metrics: Arc<RuntimeMetrics>) -> Self {
1395 self.metrics = Some(metrics);
1396 self
1397 }
1398
/// Enables step checkpointing: after each dispatched tool call, `run` saves the serialized
/// step list to `backend` under `loop:{session_id}:step:{n}`.
#[cfg(feature = "persistence")]
pub fn with_step_checkpoint(
    self,
    backend: Arc<dyn crate::persistence::PersistenceBackend>,
    session_id: impl Into<String>,
) -> Self {
    let session_id: String = session_id.into();
    Self {
        checkpoint_backend: Some((backend, session_id)),
        ..self
    }
}
1413
1414 pub fn registry(&self) -> &ToolRegistry {
1416 &self.registry
1417 }
1418
1419 pub fn tool_count(&self) -> usize {
1423 self.registry.tool_count()
1424 }
1425
1426 pub fn unregister_tool(&mut self, name: &str) -> bool {
1428 self.registry.unregister(name)
1429 }
1430
1431 pub fn register_tool(&mut self, spec: ToolSpec) {
1433 self.registry.register(spec);
1434 }
1435
1436 pub fn register_tools(&mut self, specs: impl IntoIterator<Item = ToolSpec>) {
1442 for spec in specs {
1443 self.registry.register(spec);
1444 }
1445 }
1446
/// Drops the oldest `Thought:` blocks from `context` until it fits in `max_chars` bytes.
///
/// Everything before the first `"\nThought:"` marker (the system prompt and the user's
/// request) is preserved, as is the most recent step. The result can therefore still
/// exceed `max_chars` when those parts alone are larger than the budget.
fn maybe_trim_context(context: &mut String, max_chars: usize) {
    while context.len() > max_chars {
        let Some(first) = context.find("\nThought:") else {
            break;
        };
        let Some(second) = context[first + 1..]
            .find("\nThought:")
            .map(|p| first + 1 + p)
        else {
            // Only one step left; nothing more can be dropped.
            break;
        };
        // Remove just the oldest step. Draining from 0 would also delete the preamble and
        // leave the model without its instructions or the user's question.
        context.drain(first..second);
    }
}
1469
1470 fn blocked_observation() -> String {
1472 serde_json::json!({
1473 "ok": false,
1474 "error": "action blocked by reviewer",
1475 "kind": "blocked"
1476 })
1477 .to_string()
1478 }
1479
1480 fn error_observation(_tool_name: &str, e: &AgentRuntimeError) -> String {
1482 let kind = match e {
1483 AgentRuntimeError::AgentLoop(msg) if msg.contains("not found") => "not_found",
1484 #[cfg(feature = "orchestrator")]
1485 AgentRuntimeError::CircuitOpen { .. } => "transient",
1486 _ => "permanent",
1487 };
1488 serde_json::json!({ "ok": false, "error": e.to_string(), "kind": kind }).to_string()
1489 }
1490
1491 #[tracing::instrument(skip(infer))]
1505 pub async fn run<F, Fut>(
1506 &self,
1507 prompt: &str,
1508 mut infer: F,
1509 ) -> Result<Vec<ReActStep>, AgentRuntimeError>
1510 where
1511 F: FnMut(String) -> Fut,
1512 Fut: Future<Output = String>,
1513 {
1514 let mut steps: Vec<ReActStep> = Vec::new();
1515 let mut context = format!("{}\n\nUser: {}\n", self.config.system_prompt, prompt);
1516
1517 let deadline = self
1519 .config
1520 .loop_timeout
1521 .map(|d| std::time::Instant::now() + d);
1522
1523 if let Some(ref obs) = self.observer {
1525 obs.on_loop_start(prompt);
1526 }
1527
1528 for iteration in 0..self.config.max_iterations {
1529 let iter_span = tracing::info_span!(
1530 "react_iteration",
1531 iteration = iteration,
1532 model = %self.config.model,
1533 );
1534 let _iter_guard = iter_span.enter();
1535
1536 if let Some(dl) = deadline {
1538 if std::time::Instant::now() >= dl {
1539 let ms = self
1540 .config
1541 .loop_timeout
1542 .map(|d| d.as_millis())
1543 .unwrap_or(0);
1544 let err = AgentRuntimeError::AgentLoop(format!("loop timeout after {ms} ms"));
1545 if let Some(ref obs) = self.observer {
1546 obs.on_error(&err);
1547 obs.on_loop_end(steps.len());
1548 }
1549 return Err(err);
1550 }
1551 }
1552
1553 let step_start = std::time::Instant::now();
1554 let response = infer(context.clone()).await;
1555 let mut step = parse_react_step(&response)?;
1556
1557 tracing::debug!(
1558 step = iteration,
1559 thought = %step.thought,
1560 action = %step.action,
1561 "ReAct iteration"
1562 );
1563
1564 if step.action.to_ascii_uppercase().starts_with("FINAL_ANSWER") {
1565 step.observation = step.action.clone();
1566 step.step_duration_ms = step_start.elapsed().as_millis() as u64;
1567 if let Some(ref m) = self.metrics {
1568 m.record_step_latency(step.step_duration_ms);
1569 }
1570 if let Some(ref obs) = self.observer {
1571 obs.on_step(iteration, &step);
1572 }
1573 steps.push(step);
1574 tracing::info!(step = iteration, "FINAL_ANSWER reached");
1575 if let Some(ref obs) = self.observer {
1576 obs.on_loop_end(steps.len());
1577 }
1578 return Ok(steps);
1579 }
1580
1581 let (tool_name, args) = parse_tool_call(&step.action)?;
1583
1584 tracing::debug!(
1585 step = iteration,
1586 tool_name = %tool_name,
1587 "dispatching tool call"
1588 );
1589
1590 if let Some(ref hook) = self.action_hook {
1592 if !hook(tool_name.clone(), args.clone()).await {
1593 if let Some(ref obs) = self.observer {
1594 obs.on_action_blocked(&tool_name, &args);
1595 }
1596 if let Some(ref m) = self.metrics {
1597 m.record_tool_call(&tool_name);
1598 m.record_tool_failure(&tool_name);
1599 }
1600 step.observation = Self::blocked_observation();
1601 step.step_duration_ms = step_start.elapsed().as_millis() as u64;
1602 if let Some(ref m) = self.metrics {
1603 m.record_step_latency(step.step_duration_ms);
1604 }
1605 let _ = write!(
1606 context,
1607 "\nThought: {}\nAction: {}\nObservation: {}\n",
1608 step.thought, step.action, step.observation
1609 );
1610 if let Some(max) = self.config.max_context_chars {
1611 Self::maybe_trim_context(&mut context, max);
1612 }
1613 if let Some(ref obs) = self.observer {
1614 obs.on_step(iteration, &step);
1615 }
1616 steps.push(step);
1617 continue;
1618 }
1619 }
1620
1621 if let Some(ref obs) = self.observer {
1623 obs.on_tool_call(&tool_name, &args);
1624 }
1625
1626 if let Some(ref m) = self.metrics {
1628 m.record_tool_call(&tool_name);
1629 }
1630
1631 let tool_span = tracing::info_span!("tool_dispatch", tool = %tool_name);
1633 let _tool_guard = tool_span.enter();
1634 let observation = match self.registry.call(&tool_name, args).await {
1635 Ok(result) => serde_json::json!({ "ok": true, "data": result }).to_string(),
1636 Err(e) => {
1637 if let Some(ref m) = self.metrics {
1639 m.record_tool_failure(&tool_name);
1640 }
1641 Self::error_observation(&tool_name, &e)
1642 }
1643 };
1644
1645 step.observation = observation.clone();
1646 step.step_duration_ms = step_start.elapsed().as_millis() as u64;
1647 if let Some(ref m) = self.metrics {
1648 m.record_step_latency(step.step_duration_ms);
1649 }
1650 let _ = write!(
1651 context,
1652 "\nThought: {}\nAction: {}\nObservation: {}\n",
1653 step.thought, step.action, observation
1654 );
1655 if let Some(max) = self.config.max_context_chars {
1656 Self::maybe_trim_context(&mut context, max);
1657 }
1658 if let Some(ref obs) = self.observer {
1659 obs.on_step(iteration, &step);
1660 }
1661 steps.push(step);
1662
1663 #[cfg(feature = "persistence")]
1665 if let Some((ref backend, ref session_id)) = self.checkpoint_backend {
1666 let step_idx = steps.len();
1667 let key = format!("loop:{session_id}:step:{step_idx}");
1668 match serde_json::to_vec(&steps) {
1669 Ok(bytes) => {
1670 if let Err(e) = backend.save(&key, &bytes).await {
1671 tracing::warn!(
1672 key = %key,
1673 error = %e,
1674 "loop step checkpoint save failed"
1675 );
1676 }
1677 }
1678 Err(e) => {
1679 tracing::warn!(
1680 step = step_idx,
1681 error = %e,
1682 "loop step checkpoint serialisation failed"
1683 );
1684 }
1685 }
1686 }
1687 }
1688
1689 let err = AgentRuntimeError::AgentLoop(format!(
1690 "max iterations ({}) reached without final answer",
1691 self.config.max_iterations
1692 ));
1693 tracing::warn!(
1694 max_iterations = self.config.max_iterations,
1695 "ReAct loop exhausted max iterations without FINAL_ANSWER"
1696 );
1697 if let Some(ref obs) = self.observer {
1698 obs.on_error(&err);
1699 obs.on_loop_end(steps.len());
1700 }
1701 Err(err)
1702 }
1703
1704 #[tracing::instrument(skip(infer_stream))]
1714 pub async fn run_streaming<F, Fut>(
1715 &self,
1716 prompt: &str,
1717 mut infer_stream: F,
1718 ) -> Result<Vec<ReActStep>, AgentRuntimeError>
1719 where
1720 F: FnMut(String) -> Fut,
1721 Fut: Future<
1722 Output = tokio::sync::mpsc::Receiver<Result<String, AgentRuntimeError>>,
1723 >,
1724 {
1725 self.run(prompt, move |ctx| {
1726 let rx_fut = infer_stream(ctx);
1727 async move {
1728 let mut rx = rx_fut.await;
1729 let mut out = String::new();
1730 while let Some(chunk) = rx.recv().await {
1731 match chunk {
1732 Ok(s) => out.push_str(&s),
1733 Err(e) => {
1734 tracing::warn!(error = %e, "streaming chunk error; skipping");
1735 }
1736 }
1737 }
1738 out
1739 }
1740 })
1741 .await
1742 }
1743}
1744
1745pub trait ToolValidator: Send + Sync {
1811 fn validate(&self, args: &Value) -> Result<(), AgentRuntimeError>;
1816}
1817
/// Computes the Levenshtein (edit) distance between `a` and `b`.
///
/// The distance is the minimum number of single-character insertions,
/// deletions, or substitutions needed to turn `a` into `b`. Characters are
/// compared as Unicode scalar values (`char`), not bytes, so a multi-byte
/// character counts as a single edit.
///
/// Only one rolling row of the dynamic-programming table is kept, so memory is
/// `O(len(b))` rather than the full `O(len(a) * len(b))` matrix.
fn levenshtein(a: &str, b: &str) -> usize {
    let b: Vec<char> = b.chars().collect();
    // row[j] = distance between the prefix of `a` processed so far and b[..j].
    // Initially (empty prefix of `a`) that is just j insertions.
    let mut row: Vec<usize> = (0..=b.len()).collect();
    for (i, ca) in a.chars().enumerate() {
        // `diag` carries the previous row's value at column j (the
        // substitution/match cell for column j + 1).
        let mut diag = row[0];
        row[0] = i + 1;
        for (j, &cb) in b.iter().enumerate() {
            // Previous row, same column: the deletion cell.
            let above = row[j + 1];
            row[j + 1] = if ca == cb {
                diag
            } else {
                // row[j] was already updated this pass: the insertion cell.
                1 + above.min(row[j]).min(diag)
            };
            diag = above;
        }
    }
    row[b.len()]
}
1843
1844fn parse_tool_call(action: &str) -> Result<(String, Value), AgentRuntimeError> {
1850 let mut parts = action.splitn(2, ' ');
1851 let name = parts.next().unwrap_or("").to_owned();
1852 if name.is_empty() {
1853 return Err(AgentRuntimeError::AgentLoop(
1854 "tool call has an empty tool name".into(),
1855 ));
1856 }
1857 let args_str = parts.next().unwrap_or("{}");
1858 let args: Value = serde_json::from_str(args_str).map_err(|e| {
1859 AgentRuntimeError::AgentLoop(format!(
1860 "invalid JSON args for tool call '{name}': {e} (raw: {args_str})"
1861 ))
1862 })?;
1863 Ok((name, args))
1864}
1865
1866#[derive(Debug, thiserror::Error)]
1870pub enum AgentError {
1871 #[error("Tool '{0}' not found")]
1873 ToolNotFound(String),
1874 #[error("Max iterations exceeded: {0}")]
1876 MaxIterations(usize),
1877 #[error("Parse error: {0}")]
1879 ParseError(String),
1880}
1881
1882impl From<AgentError> for AgentRuntimeError {
1883 fn from(e: AgentError) -> Self {
1884 AgentRuntimeError::AgentLoop(e.to_string())
1885 }
1886}
1887
1888pub trait Observer: Send + Sync {
1895 fn on_step(&self, step_index: usize, step: &ReActStep) {
1897 let _ = (step_index, step);
1898 }
1899 fn on_tool_call(&self, tool_name: &str, args: &serde_json::Value) {
1901 let _ = (tool_name, args);
1902 }
1903 fn on_action_blocked(&self, tool_name: &str, args: &serde_json::Value) {
1908 let _ = (tool_name, args);
1909 }
1910 fn on_loop_start(&self, prompt: &str) {
1912 let _ = prompt;
1913 }
1914 fn on_loop_end(&self, step_count: usize) {
1916 let _ = step_count;
1917 }
1918 fn on_error(&self, error: &crate::error::AgentRuntimeError) {
1923 let _ = error;
1924 }
1925}
1926
1927#[derive(Debug, Clone, PartialEq)]
1931pub enum Action {
1932 FinalAnswer(String),
1934 ToolCall {
1936 name: String,
1938 args: serde_json::Value,
1940 },
1941}
1942
1943impl Action {
1944 pub fn parse(s: &str) -> Result<Action, AgentRuntimeError> {
1949 if s.trim().to_ascii_uppercase().starts_with("FINAL_ANSWER") {
1950 let answer = s.trim()["FINAL_ANSWER".len()..].trim().to_owned();
1951 return Ok(Action::FinalAnswer(answer));
1952 }
1953 let (name, args) = parse_tool_call(s)?;
1954 Ok(Action::ToolCall { name, args })
1955 }
1956}
1957
1958pub type ActionHook = Arc<dyn Fn(String, serde_json::Value) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>> + Send + Sync>;
1977
1978pub fn make_action_hook<F, Fut>(f: F) -> ActionHook
1993where
1994 F: Fn(String, serde_json::Value) -> Fut + Send + Sync + 'static,
1995 Fut: std::future::Future<Output = bool> + Send + 'static,
1996{
1997 Arc::new(move |name, args| Box::pin(f(name, args)))
1998}
1999
2000#[cfg(test)]
2003mod tests {
2004 use super::*;
2005
2006 #[tokio::test]
2007 async fn test_final_answer_on_first_step() {
2008 let config = AgentConfig::new(5, "test-model");
2009 let loop_ = ReActLoop::new(config);
2010
2011 let steps = loop_
2012 .run("Say hello", |_ctx| async {
2013 "Thought: I will answer directly\nAction: FINAL_ANSWER hello".to_string()
2014 })
2015 .await
2016 .unwrap();
2017
2018 assert_eq!(steps.len(), 1);
2019 assert!(steps[0]
2020 .action
2021 .to_ascii_uppercase()
2022 .starts_with("FINAL_ANSWER"));
2023 }
2024
2025 #[tokio::test]
2026 async fn test_tool_call_then_final_answer() {
2027 let config = AgentConfig::new(5, "test-model");
2028 let mut loop_ = ReActLoop::new(config);
2029
2030 loop_.register_tool(ToolSpec::new("greet", "Greets someone", |_args| {
2031 serde_json::json!("hello!")
2032 }));
2033
2034 let mut call_count = 0;
2035 let steps = loop_
2036 .run("Say hello", |_ctx| {
2037 call_count += 1;
2038 let count = call_count;
2039 async move {
2040 if count == 1 {
2041 "Thought: I will greet\nAction: greet {}".to_string()
2042 } else {
2043 "Thought: done\nAction: FINAL_ANSWER done".to_string()
2044 }
2045 }
2046 })
2047 .await
2048 .unwrap();
2049
2050 assert_eq!(steps.len(), 2);
2051 assert_eq!(steps[0].action, "greet {}");
2052 assert!(steps[1]
2053 .action
2054 .to_ascii_uppercase()
2055 .starts_with("FINAL_ANSWER"));
2056 }
2057
2058 #[tokio::test]
2059 async fn test_max_iterations_exceeded() {
2060 let config = AgentConfig::new(2, "test-model");
2061 let loop_ = ReActLoop::new(config);
2062
2063 let result = loop_
2064 .run("loop forever", |_ctx| async {
2065 "Thought: thinking\nAction: noop {}".to_string()
2066 })
2067 .await;
2068
2069 assert!(result.is_err());
2070 let err = result.unwrap_err().to_string();
2071 assert!(err.contains("max iterations"));
2072 }
2073
2074 #[tokio::test]
2075 async fn test_parse_react_step_valid() {
2076 let text = "Thought: I should check\nAction: lookup {\"key\":\"val\"}";
2077 let step = parse_react_step(text).unwrap();
2078 assert_eq!(step.thought, "I should check");
2079 assert_eq!(step.action, "lookup {\"key\":\"val\"}");
2080 }
2081
2082 #[tokio::test]
2083 async fn test_parse_react_step_empty_fails() {
2084 let result = parse_react_step("no prefix lines here");
2085 assert!(result.is_err());
2086 }
2087
2088 #[tokio::test]
2089 async fn test_tool_not_found_returns_error_observation() {
2090 let config = AgentConfig::new(3, "test-model");
2091 let loop_ = ReActLoop::new(config);
2092
2093 let mut call_count = 0;
2094 let steps = loop_
2095 .run("test", |_ctx| {
2096 call_count += 1;
2097 let count = call_count;
2098 async move {
2099 if count == 1 {
2100 "Thought: try missing tool\nAction: missing_tool {}".to_string()
2101 } else {
2102 "Thought: done\nAction: FINAL_ANSWER done".to_string()
2103 }
2104 }
2105 })
2106 .await
2107 .unwrap();
2108
2109 assert_eq!(steps.len(), 2);
2110 assert!(steps[0].observation.contains("\"ok\":false"));
2111 }
2112
2113 #[tokio::test]
2114 async fn test_new_async_tool_spec() {
2115 let spec = ToolSpec::new_async("async_tool", "An async tool", |args| {
2116 Box::pin(async move { serde_json::json!({"echo": args}) })
2117 });
2118
2119 let result = spec.call(serde_json::json!({"input": "test"})).await;
2120 assert!(result.get("echo").is_some());
2121 }
2122
2123 #[tokio::test]
2126 async fn test_parse_react_step_case_insensitive() {
2127 let text = "THOUGHT: done\nACTION: FINAL_ANSWER";
2128 let step = parse_react_step(text).unwrap();
2129 assert_eq!(step.thought, "done");
2130 assert_eq!(step.action, "FINAL_ANSWER");
2131 }
2132
2133 #[tokio::test]
2134 async fn test_parse_react_step_space_before_colon() {
2135 let text = "Thought : done\nAction : go";
2136 let step = parse_react_step(text).unwrap();
2137 assert_eq!(step.thought, "done");
2138 assert_eq!(step.action, "go");
2139 }
2140
2141 #[tokio::test]
2144 async fn test_tool_required_fields_missing_returns_error() {
2145 let config = AgentConfig::new(3, "test-model");
2146 let mut loop_ = ReActLoop::new(config);
2147
2148 loop_.register_tool(
2149 ToolSpec::new(
2150 "search",
2151 "Searches for something",
2152 |args| serde_json::json!({ "result": args }),
2153 )
2154 .with_required_fields(vec!["q".to_string()]),
2155 );
2156
2157 let mut call_count = 0;
2158 let steps = loop_
2159 .run("test", |_ctx| {
2160 call_count += 1;
2161 let count = call_count;
2162 async move {
2163 if count == 1 {
2164 "Thought: searching\nAction: search {}".to_string()
2166 } else {
2167 "Thought: done\nAction: FINAL_ANSWER done".to_string()
2168 }
2169 }
2170 })
2171 .await
2172 .unwrap();
2173
2174 assert_eq!(steps.len(), 2);
2175 assert!(
2176 steps[0].observation.contains("missing required field"),
2177 "observation was: {}",
2178 steps[0].observation
2179 );
2180 }
2181
2182 #[tokio::test]
2185 async fn test_tool_error_observation_includes_kind() {
2186 let config = AgentConfig::new(3, "test-model");
2187 let loop_ = ReActLoop::new(config);
2188
2189 let mut call_count = 0;
2190 let steps = loop_
2191 .run("test", |_ctx| {
2192 call_count += 1;
2193 let count = call_count;
2194 async move {
2195 if count == 1 {
2196 "Thought: try missing\nAction: nonexistent_tool {}".to_string()
2197 } else {
2198 "Thought: done\nAction: FINAL_ANSWER done".to_string()
2199 }
2200 }
2201 })
2202 .await
2203 .unwrap();
2204
2205 assert_eq!(steps.len(), 2);
2206 let obs = &steps[0].observation;
2207 assert!(obs.contains("\"ok\":false"), "observation: {obs}");
2208 assert!(obs.contains("\"kind\":\"not_found\""), "observation: {obs}");
2209 }
2210
2211 #[tokio::test]
2214 async fn test_step_duration_ms_is_set() {
2215 let config = AgentConfig::new(5, "test-model");
2216 let loop_ = ReActLoop::new(config);
2217
2218 let steps = loop_
2219 .run("time it", |_ctx| async {
2220 "Thought: done\nAction: FINAL_ANSWER ok".to_string()
2221 })
2222 .await
2223 .unwrap();
2224
2225 let _ = steps[0].step_duration_ms; }
2228
2229 struct RequirePositiveN;
2232 impl ToolValidator for RequirePositiveN {
2233 fn validate(&self, args: &Value) -> Result<(), AgentRuntimeError> {
2234 let n = args.get("n").and_then(|v| v.as_i64()).unwrap_or(0);
2235 if n <= 0 {
2236 return Err(AgentRuntimeError::AgentLoop(
2237 "n must be a positive integer".into(),
2238 ));
2239 }
2240 Ok(())
2241 }
2242 }
2243
2244 #[tokio::test]
2245 async fn test_tool_validator_blocks_invalid_args() {
2246 let mut registry = ToolRegistry::new();
2247 registry.register(
2248 ToolSpec::new("calc", "compute", |args| serde_json::json!({"n": args}))
2249 .with_validators(vec![Box::new(RequirePositiveN)]),
2250 );
2251
2252 let result = registry
2254 .call("calc", serde_json::json!({"n": -1}))
2255 .await;
2256 assert!(result.is_err(), "validator should reject n=-1");
2257 assert!(result.unwrap_err().to_string().contains("positive integer"));
2258 }
2259
2260 #[tokio::test]
2261 async fn test_tool_validator_passes_valid_args() {
2262 let mut registry = ToolRegistry::new();
2263 registry.register(
2264 ToolSpec::new("calc", "compute", |_| serde_json::json!(42))
2265 .with_validators(vec![Box::new(RequirePositiveN)]),
2266 );
2267
2268 let result = registry
2269 .call("calc", serde_json::json!({"n": 5}))
2270 .await;
2271 assert!(result.is_ok(), "validator should accept n=5");
2272 }
2273
2274 #[tokio::test]
2277 async fn test_empty_tool_name_is_rejected() {
2278 let result = parse_tool_call("");
2280 assert!(result.is_err());
2281 assert!(
2282 result.unwrap_err().to_string().contains("empty tool name"),
2283 "expected 'empty tool name' error"
2284 );
2285 }
2286
2287 #[tokio::test]
2290 async fn test_register_tools_bulk() {
2291 let mut registry = ToolRegistry::new();
2292 registry.register_tools(vec![
2293 ToolSpec::new("tool_a", "A", |_| serde_json::json!("a")),
2294 ToolSpec::new("tool_b", "B", |_| serde_json::json!("b")),
2295 ]);
2296 assert!(registry.call("tool_a", serde_json::json!({})).await.is_ok());
2297 assert!(registry.call("tool_b", serde_json::json!({})).await.is_ok());
2298 }
2299
2300 #[tokio::test]
2303 async fn test_run_streaming_parity_with_run() {
2304 use tokio::sync::mpsc;
2305
2306 let config = AgentConfig::new(5, "test-model");
2307 let loop_ = ReActLoop::new(config);
2308
2309 let steps = loop_
2310 .run_streaming("Say hello", |_ctx| async {
2311 let (tx, rx) = mpsc::channel(4);
2312 tokio::spawn(async move {
2314 tx.send(Ok("Thought: done\n".to_string())).await.ok();
2315 tx.send(Ok("Action: FINAL_ANSWER hi".to_string())).await.ok();
2316 });
2317 rx
2318 })
2319 .await
2320 .unwrap();
2321
2322 assert_eq!(steps.len(), 1);
2323 assert!(steps[0]
2324 .action
2325 .to_ascii_uppercase()
2326 .starts_with("FINAL_ANSWER"));
2327 }
2328
2329 #[tokio::test]
2330 async fn test_run_streaming_error_chunk_is_skipped() {
2331 use tokio::sync::mpsc;
2332 use crate::error::AgentRuntimeError;
2333
2334 let config = AgentConfig::new(5, "test-model");
2335 let loop_ = ReActLoop::new(config);
2336
2337 let steps = loop_
2339 .run_streaming("test", |_ctx| async {
2340 let (tx, rx) = mpsc::channel(4);
2341 tokio::spawn(async move {
2342 tx.send(Err(AgentRuntimeError::Provider("stream error".into())))
2343 .await
2344 .ok();
2345 tx.send(Ok("Thought: recovered\nAction: FINAL_ANSWER ok".to_string()))
2346 .await
2347 .ok();
2348 });
2349 rx
2350 })
2351 .await
2352 .unwrap();
2353
2354 assert_eq!(steps.len(), 1);
2355 }
2356
2357 #[cfg(feature = "orchestrator")]
2360 #[tokio::test]
2361 async fn test_tool_with_circuit_breaker_passes_when_closed() {
2362 use std::sync::Arc;
2363
2364 let cb = Arc::new(
2365 crate::orchestrator::CircuitBreaker::new(
2366 "echo-tool",
2367 5,
2368 std::time::Duration::from_secs(30),
2369 )
2370 .unwrap(),
2371 );
2372
2373 let spec = ToolSpec::new(
2374 "echo",
2375 "Echoes args",
2376 |args| serde_json::json!({ "echoed": args }),
2377 )
2378 .with_circuit_breaker(cb);
2379
2380 let registry = {
2381 let mut r = ToolRegistry::new();
2382 r.register(spec);
2383 r
2384 };
2385
2386 let result = registry
2387 .call("echo", serde_json::json!({ "msg": "hi" }))
2388 .await;
2389 assert!(result.is_ok(), "expected Ok, got {:?}", result);
2390 }
2391
2392 #[test]
2395 fn test_agent_config_builder_methods_set_fields() {
2396 let config = AgentConfig::new(3, "model")
2397 .with_temperature(0.7)
2398 .with_max_tokens(512)
2399 .with_request_timeout(std::time::Duration::from_secs(10));
2400 assert_eq!(config.temperature, Some(0.7));
2401 assert_eq!(config.max_tokens, Some(512));
2402 assert_eq!(config.request_timeout, Some(std::time::Duration::from_secs(10)));
2403 }
2404
2405 #[tokio::test]
2408 async fn test_fallible_tool_returns_error_json_on_err() {
2409 let spec = ToolSpec::new_fallible(
2410 "fail",
2411 "always fails",
2412 |_| Err::<Value, String>("something went wrong".to_string()),
2413 );
2414 let result = spec.call(serde_json::json!({})).await;
2415 assert_eq!(result["ok"], serde_json::json!(false));
2416 assert_eq!(result["error"], serde_json::json!("something went wrong"));
2417 }
2418
2419 #[tokio::test]
2420 async fn test_fallible_tool_returns_value_on_ok() {
2421 let spec = ToolSpec::new_fallible(
2422 "succeed",
2423 "always succeeds",
2424 |_| Ok::<Value, String>(serde_json::json!(42)),
2425 );
2426 let result = spec.call(serde_json::json!({})).await;
2427 assert_eq!(result, serde_json::json!(42));
2428 }
2429
2430 #[tokio::test]
2433 async fn test_did_you_mean_suggestion_for_typo() {
2434 let mut registry = ToolRegistry::new();
2435 registry.register(ToolSpec::new("search", "search", |_| serde_json::json!("ok")));
2436 let result = registry.call("searc", serde_json::json!({})).await;
2437 assert!(result.is_err());
2438 let msg = result.unwrap_err().to_string();
2439 assert!(msg.contains("did you mean"), "expected suggestion in: {msg}");
2440 }
2441
2442 #[tokio::test]
2443 async fn test_no_suggestion_for_very_different_name() {
2444 let mut registry = ToolRegistry::new();
2445 registry.register(ToolSpec::new("search", "search", |_| serde_json::json!("ok")));
2446 let result = registry.call("xxxxxxxxxxxxxxx", serde_json::json!({})).await;
2447 assert!(result.is_err());
2448 let msg = result.unwrap_err().to_string();
2449 assert!(!msg.contains("did you mean"), "unexpected suggestion in: {msg}");
2450 }
2451
2452 #[test]
2455 fn test_action_parse_final_answer() {
2456 let action = Action::parse("FINAL_ANSWER hello world").unwrap();
2457 assert_eq!(action, Action::FinalAnswer("hello world".to_string()));
2458 }
2459
2460 #[test]
2461 fn test_action_parse_tool_call() {
2462 let action = Action::parse("search {\"q\": \"rust\"}").unwrap();
2463 match action {
2464 Action::ToolCall { name, args } => {
2465 assert_eq!(name, "search");
2466 assert_eq!(args["q"], "rust");
2467 }
2468 _ => panic!("expected ToolCall"),
2469 }
2470 }
2471
2472 #[test]
2473 fn test_action_parse_invalid_returns_err() {
2474 let result = Action::parse("");
2475 assert!(result.is_err());
2476 }
2477
2478 #[tokio::test]
2481 async fn test_observer_on_step_called_for_each_step() {
2482 use std::sync::{Arc, Mutex};
2483
2484 struct CountingObserver {
2485 step_count: Mutex<usize>,
2486 }
2487 impl Observer for CountingObserver {
2488 fn on_step(&self, _step_index: usize, _step: &ReActStep) {
2489 let mut c = self.step_count.lock().unwrap_or_else(|e| e.into_inner());
2490 *c += 1;
2491 }
2492 }
2493
2494 let obs = Arc::new(CountingObserver { step_count: Mutex::new(0) });
2495 let config = AgentConfig::new(5, "test-model");
2496 let mut loop_ = ReActLoop::new(config).with_observer(obs.clone() as Arc<dyn Observer>);
2497 loop_.register_tool(ToolSpec::new("noop", "noop", |_| serde_json::json!("ok")));
2498
2499 let mut call_count = 0;
2500 let _steps = loop_.run("test", |_ctx| {
2501 call_count += 1;
2502 let count = call_count;
2503 async move {
2504 if count == 1 {
2505 "Thought: call noop\nAction: noop {}".to_string()
2506 } else {
2507 "Thought: done\nAction: FINAL_ANSWER done".to_string()
2508 }
2509 }
2510 }).await.unwrap();
2511
2512 let count = *obs.step_count.lock().unwrap_or_else(|e| e.into_inner());
2513 assert_eq!(count, 2, "observer should have seen 2 steps");
2514 }
2515
2516 #[tokio::test]
2519 async fn test_tool_cache_returns_cached_result_on_second_call() {
2520 use std::collections::HashMap;
2521 use std::sync::Mutex;
2522
2523 struct InMemCache {
2524 map: Mutex<HashMap<String, Value>>,
2525 }
2526 impl ToolCache for InMemCache {
2527 fn get(&self, tool_name: &str, args: &Value) -> Option<Value> {
2528 let key = format!("{tool_name}:{args}");
2529 let map = self.map.lock().unwrap_or_else(|e| e.into_inner());
2530 map.get(&key).cloned()
2531 }
2532 fn set(&self, tool_name: &str, args: &Value, result: Value) {
2533 let key = format!("{tool_name}:{args}");
2534 let mut map = self.map.lock().unwrap_or_else(|e| e.into_inner());
2535 map.insert(key, result);
2536 }
2537 }
2538
2539 let call_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
2540 let call_count_clone = call_count.clone();
2541
2542 let cache = Arc::new(InMemCache { map: Mutex::new(HashMap::new()) });
2543 let registry = ToolRegistry::new()
2544 .with_cache(cache as Arc<dyn ToolCache>);
2545 let mut registry = registry;
2546
2547 registry.register(ToolSpec::new("count", "count calls", move |_| {
2548 call_count_clone.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
2549 serde_json::json!({"calls": 1})
2550 }));
2551
2552 let args = serde_json::json!({});
2553 let r1 = registry.call("count", args.clone()).await.unwrap();
2554 let r2 = registry.call("count", args.clone()).await.unwrap();
2555
2556 assert_eq!(r1, r2);
2557 assert_eq!(call_count.load(std::sync::atomic::Ordering::Relaxed), 1);
2559 }
2560
2561 #[tokio::test]
2564 async fn test_validators_short_circuit_on_first_failure() {
2565 use std::sync::atomic::{AtomicUsize, Ordering as AOrdering};
2566 use std::sync::Arc;
2567
2568 let second_called = Arc::new(AtomicUsize::new(0));
2569 let second_called_clone = Arc::clone(&second_called);
2570
2571 struct AlwaysFail;
2572 impl ToolValidator for AlwaysFail {
2573 fn validate(&self, _args: &Value) -> Result<(), AgentRuntimeError> {
2574 Err(AgentRuntimeError::AgentLoop("first validator failed".into()))
2575 }
2576 }
2577
2578 struct CountCalls(Arc<AtomicUsize>);
2579 impl ToolValidator for CountCalls {
2580 fn validate(&self, _args: &Value) -> Result<(), AgentRuntimeError> {
2581 self.0.fetch_add(1, AOrdering::SeqCst);
2582 Ok(())
2583 }
2584 }
2585
2586 let mut registry = ToolRegistry::new();
2587 registry.register(
2588 ToolSpec::new("guarded", "A guarded tool", |args| args.clone())
2589 .with_validators(vec![
2590 Box::new(AlwaysFail),
2591 Box::new(CountCalls(second_called_clone)),
2592 ]),
2593 );
2594
2595 let result = registry.call("guarded", serde_json::json!({})).await;
2596 assert!(result.is_err(), "should fail due to first validator");
2597 assert_eq!(
2598 second_called.load(AOrdering::SeqCst),
2599 0,
2600 "second validator must not be called when first fails"
2601 );
2602 }
2603
2604 #[tokio::test]
2607 async fn test_loop_timeout_fires_between_iterations() {
2608 let mut config = AgentConfig::new(100, "test-model");
2609 config.loop_timeout = Some(std::time::Duration::from_millis(30));
2611 let loop_ = ReActLoop::new(config);
2612
2613 let result = loop_
2614 .run("test", |_ctx| async {
2615 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
2617 "Thought: still working\nAction: noop {}".to_string()
2619 })
2620 .await;
2621
2622 assert!(result.is_err(), "loop should time out");
2623 let msg = result.unwrap_err().to_string();
2624 assert!(msg.contains("loop timeout"), "unexpected error: {msg}");
2625 }
2626
2627 #[test]
2632 fn test_react_step_is_final_answer() {
2633 let step = ReActStep {
2634 thought: "".into(),
2635 action: "FINAL_ANSWER done".into(),
2636 observation: "".into(),
2637 step_duration_ms: 0,
2638 };
2639 assert!(step.is_final_answer());
2640 assert!(!step.is_tool_call());
2641 }
2642
2643 #[test]
2644 fn test_react_step_is_tool_call() {
2645 let step = ReActStep {
2646 thought: "".into(),
2647 action: "search {}".into(),
2648 observation: "".into(),
2649 step_duration_ms: 0,
2650 };
2651 assert!(!step.is_final_answer());
2652 assert!(step.is_tool_call());
2653 }
2654
2655 #[test]
2658 fn test_role_display() {
2659 assert_eq!(Role::System.to_string(), "system");
2660 assert_eq!(Role::User.to_string(), "user");
2661 assert_eq!(Role::Assistant.to_string(), "assistant");
2662 assert_eq!(Role::Tool.to_string(), "tool");
2663 }
2664
2665 #[test]
2668 fn test_message_accessors() {
2669 let msg = Message::new(Role::User, "hello");
2670 assert_eq!(msg.role(), &Role::User);
2671 assert_eq!(msg.content(), "hello");
2672 }
2673
2674 #[test]
2677 fn test_action_parse_final_answer_round_trip() {
2678 let step = ReActStep {
2679 thought: "done".into(),
2680 action: "FINAL_ANSWER Paris".into(),
2681 observation: "".into(),
2682 step_duration_ms: 0,
2683 };
2684 assert!(step.is_final_answer());
2685 let action = Action::parse(&step.action).unwrap();
2686 assert!(matches!(action, Action::FinalAnswer(ref s) if s == "Paris"));
2687 }
2688
2689 #[test]
2690 fn test_action_parse_tool_call_round_trip() {
2691 let step = ReActStep {
2692 thought: "searching".into(),
2693 action: "search {\"q\":\"hello\"}".into(),
2694 observation: "".into(),
2695 step_duration_ms: 0,
2696 };
2697 assert!(step.is_tool_call());
2698 let action = Action::parse(&step.action).unwrap();
2699 assert!(matches!(action, Action::ToolCall { ref name, .. } if name == "search"));
2700 }
2701
2702 #[tokio::test]
2705 async fn test_observer_receives_correct_step_indices() {
2706 use std::sync::{Arc, Mutex};
2707
2708 struct IndexCollector(Arc<Mutex<Vec<usize>>>);
2709 impl Observer for IndexCollector {
2710 fn on_step(&self, step_index: usize, _step: &ReActStep) {
2711 self.0.lock().unwrap_or_else(|e| e.into_inner()).push(step_index);
2712 }
2713 }
2714
2715 let indices = Arc::new(Mutex::new(Vec::new()));
2716 let obs = Arc::new(IndexCollector(Arc::clone(&indices)));
2717
2718 let config = AgentConfig::new(5, "test");
2719 let mut loop_ = ReActLoop::new(config).with_observer(obs as Arc<dyn Observer>);
2720 loop_.register_tool(ToolSpec::new("noop", "no-op", |_| serde_json::json!({})));
2721
2722 let mut call_count = 0;
2723 loop_.run("test", |_ctx| {
2724 call_count += 1;
2725 let count = call_count;
2726 async move {
2727 if count == 1 {
2728 "Thought: step1\nAction: noop {}".to_string()
2729 } else {
2730 "Thought: done\nAction: FINAL_ANSWER ok".to_string()
2731 }
2732 }
2733 }).await.unwrap();
2734
2735 let collected = indices.lock().unwrap_or_else(|e| e.into_inner()).clone();
2736 assert_eq!(collected, vec![0, 1], "expected step indices 0 and 1");
2737 }
2738
2739 #[tokio::test]
2740 async fn test_action_hook_blocking_inserts_blocked_observation() {
2741 let hook: ActionHook = Arc::new(|_name, _args| {
2742 Box::pin(async move { false }) });
2744
2745 let config = AgentConfig::new(5, "test-model");
2746 let mut loop_ = ReActLoop::new(config).with_action_hook(hook);
2747 loop_.register_tool(ToolSpec::new("noop", "noop", |_| serde_json::json!("ok")));
2748
2749 let mut call_count = 0;
2750 let steps = loop_.run("test", |_ctx| {
2751 call_count += 1;
2752 let count = call_count;
2753 async move {
2754 if count == 1 {
2755 "Thought: try tool\nAction: noop {}".to_string()
2756 } else {
2757 "Thought: done\nAction: FINAL_ANSWER done".to_string()
2758 }
2759 }
2760 }).await.unwrap();
2761
2762 assert!(steps[0].observation.contains("blocked"), "expected blocked observation, got: {}", steps[0].observation);
2763 }
2764
2765 #[test]
2766 fn test_react_step_new_constructor() {
2767 let s = ReActStep::new("think", "act", "obs");
2768 assert_eq!(s.thought, "think");
2769 assert_eq!(s.action, "act");
2770 assert_eq!(s.observation, "obs");
2771 assert_eq!(s.step_duration_ms, 0);
2772 }
2773
2774 #[test]
2775 fn test_react_step_new_is_tool_call() {
2776 let s = ReActStep::new("think", "search {}", "result");
2777 assert!(s.is_tool_call());
2778 assert!(!s.is_final_answer());
2779 }
2780
2781 #[test]
2782 fn test_react_step_new_is_final_answer() {
2783 let s = ReActStep::new("done", "FINAL_ANSWER 42", "");
2784 assert!(s.is_final_answer());
2785 assert!(!s.is_tool_call());
2786 }
2787
2788 #[test]
2789 fn test_agent_config_is_valid_with_valid_config() {
2790 let cfg = AgentConfig::new(5, "my-model");
2791 assert!(cfg.is_valid());
2792 }
2793
2794 #[test]
2795 fn test_agent_config_is_valid_with_zero_iterations() {
2796 let mut cfg = AgentConfig::new(1, "my-model");
2797 cfg.max_iterations = 0;
2798 assert!(!cfg.is_valid());
2799 }
2800
2801 #[test]
2802 fn test_agent_config_is_valid_with_empty_model() {
2803 let mut cfg = AgentConfig::new(5, "my-model");
2804 cfg.model = String::new();
2805 assert!(!cfg.is_valid());
2806 }
2807
2808 #[test]
2809 fn test_react_loop_tool_count_delegates_to_registry() {
2810 let cfg = AgentConfig::new(5, "model");
2811 let mut loop_ = ReActLoop::new(cfg);
2812 assert_eq!(loop_.tool_count(), 0);
2813 loop_.register_tool(ToolSpec::new("t1", "desc", |_| serde_json::json!("ok")));
2814 loop_.register_tool(ToolSpec::new("t2", "desc", |_| serde_json::json!("ok")));
2815 assert_eq!(loop_.tool_count(), 2);
2816 }
2817
2818 #[test]
2819 fn test_tool_registry_has_tool_returns_true_when_registered() {
2820 let mut reg = ToolRegistry::new();
2821 reg.register(ToolSpec::new("my-tool", "desc", |_| serde_json::json!("ok")));
2822 assert!(reg.has_tool("my-tool"));
2823 assert!(!reg.has_tool("other-tool"));
2824 }
2825
2826 #[test]
2829 fn test_agent_config_validate_ok_for_valid_config() {
2830 let cfg = AgentConfig::new(5, "my-model");
2831 assert!(cfg.validate().is_ok());
2832 }
2833
2834 #[test]
2835 fn test_agent_config_validate_err_for_zero_iterations() {
2836 let cfg = AgentConfig::new(0, "my-model");
2837 let err = cfg.validate().unwrap_err();
2838 assert!(err.to_string().contains("max_iterations"));
2839 }
2840
2841 #[test]
2842 fn test_agent_config_validate_err_for_empty_model() {
2843 let cfg = AgentConfig::new(5, "");
2844 let err = cfg.validate().unwrap_err();
2845 assert!(err.to_string().contains("model"));
2846 }
2847
2848 #[test]
2851 fn test_clone_with_model_produces_new_model_string() {
2852 let cfg = AgentConfig::new(5, "gpt-4");
2853 let new_cfg = cfg.clone_with_model("claude-3");
2854 assert_eq!(new_cfg.model, "claude-3");
2855 assert_eq!(cfg.model, "gpt-4");
2857 }
2858
2859 #[test]
2860 fn test_clone_with_model_preserves_other_fields() {
2861 let cfg = AgentConfig::new(10, "gpt-4").with_stop_sequences(vec!["STOP".to_string()]);
2862 let new_cfg = cfg.clone_with_model("o1");
2863 assert_eq!(new_cfg.max_iterations, 10);
2864 assert_eq!(new_cfg.stop_sequences, cfg.stop_sequences);
2865 }
2866
2867 #[tokio::test]
2868 async fn test_tool_spec_with_name_changes_name() {
2869 let spec = ToolSpec::new("original", "desc", |_| serde_json::json!("ok"))
2870 .with_name("renamed");
2871 assert_eq!(spec.name, "renamed");
2872 }
2873
2874 #[tokio::test]
2875 async fn test_tool_spec_with_name_and_description_chainable() {
2876 let spec = ToolSpec::new("old", "old desc", |_| serde_json::json!("ok"))
2877 .with_name("new")
2878 .with_description("new desc");
2879 assert_eq!(spec.name, "new");
2880 assert_eq!(spec.description, "new desc");
2881 }
2882
2883 #[test]
2886 fn test_message_user_sets_role_and_content() {
2887 let m = Message::user("hello");
2888 assert_eq!(m.content(), "hello");
2889 assert!(m.is_user());
2890 assert!(!m.is_assistant());
2891 }
2892
2893 #[test]
2894 fn test_message_assistant_sets_role() {
2895 let m = Message::assistant("reply");
2896 assert!(m.is_assistant());
2897 assert!(!m.is_user());
2898 assert!(!m.is_system());
2899 }
2900
2901 #[test]
2902 fn test_message_system_sets_role() {
2903 let m = Message::system("system prompt");
2904 assert!(m.is_system());
2905 assert_eq!(m.content(), "system prompt");
2906 }
2907
2908 #[test]
2909 fn test_parse_react_step_valid_input() {
2910 let text = "Thought: I need to search\nAction: search[query]";
2911 let step = parse_react_step(text).unwrap();
2912 assert!(step.thought.contains("search"));
2913 assert!(step.action.contains("search"));
2914 }
2915
2916 #[test]
2917 fn test_parse_react_step_missing_fields_returns_err() {
2918 let text = "no structured content here";
2919 assert!(parse_react_step(text).is_err());
2920 }
2921
2922 #[test]
2925 fn test_react_step_is_final_answer_true() {
2926 let step = ReActStep::new("t", "FINAL_ANSWER Paris", "");
2927 assert!(step.is_final_answer());
2928 assert!(!step.is_tool_call());
2929 }
2930
2931 #[test]
2932 fn test_react_step_is_tool_call_true() {
2933 let step = ReActStep::new("t", "search {}", "result");
2934 assert!(step.is_tool_call());
2935 assert!(!step.is_final_answer());
2936 }
2937
2938 #[test]
2939 fn test_tool_registry_unregister_returns_true_when_present() {
2940 let mut reg = ToolRegistry::new();
2941 reg.register(ToolSpec::new("tool-x", "desc", |_| serde_json::json!("ok")));
2942 assert!(reg.unregister("tool-x"));
2943 assert!(!reg.has_tool("tool-x"));
2944 }
2945
2946 #[test]
2947 fn test_tool_registry_unregister_returns_false_when_absent() {
2948 let mut reg = ToolRegistry::new();
2949 assert!(!reg.unregister("ghost"));
2950 }
2951
2952 #[test]
2953 fn test_tool_registry_contains_matches_has_tool() {
2954 let mut reg = ToolRegistry::new();
2955 reg.register(ToolSpec::new("alpha", "desc", |_| serde_json::json!("ok")));
2956 assert!(reg.contains("alpha"));
2957 assert!(!reg.contains("beta"));
2958 }
2959
2960 #[test]
2961 fn test_agent_config_with_system_prompt() {
2962 let cfg = AgentConfig::new(5, "model")
2963 .with_system_prompt("You are helpful.");
2964 assert_eq!(cfg.system_prompt, "You are helpful.");
2965 }
2966
2967 #[test]
2968 fn test_agent_config_with_temperature_and_max_tokens() {
2969 let cfg = AgentConfig::new(3, "model")
2970 .with_temperature(0.7)
2971 .with_max_tokens(512);
2972 assert!((cfg.temperature.unwrap() - 0.7).abs() < 1e-6);
2973 assert_eq!(cfg.max_tokens, Some(512));
2974 }
2975
2976 #[test]
2977 fn test_agent_config_clone_with_model() {
2978 let orig = AgentConfig::new(5, "gpt-4");
2979 let cloned = orig.clone_with_model("claude-3");
2980 assert_eq!(cloned.model, "claude-3");
2981 assert_eq!(cloned.max_iterations, 5);
2982 }
2983
2984 #[test]
2987 fn test_agent_config_with_loop_timeout_secs() {
2988 let cfg = AgentConfig::new(5, "model").with_loop_timeout_secs(30);
2989 assert_eq!(cfg.loop_timeout, Some(std::time::Duration::from_secs(30)));
2990 }
2991
2992 #[test]
2993 fn test_agent_config_with_max_context_chars() {
2994 let cfg = AgentConfig::new(5, "model").with_max_context_chars(4096);
2995 assert_eq!(cfg.max_context_chars, Some(4096));
2996 }
2997
2998 #[test]
2999 fn test_agent_config_with_stop_sequences() {
3000 let cfg = AgentConfig::new(5, "model")
3001 .with_stop_sequences(vec!["STOP".to_string(), "END".to_string()]);
3002 assert_eq!(cfg.stop_sequences, vec!["STOP", "END"]);
3003 }
3004
3005 #[test]
3006 fn test_message_is_tool_false_for_non_tool_roles() {
3007 assert!(!Message::user("hi").is_tool());
3008 assert!(!Message::assistant("reply").is_tool());
3009 assert!(!Message::system("prompt").is_tool());
3010 }
3011
3012 #[test]
3015 fn test_agent_config_with_max_iterations() {
3016 let cfg = AgentConfig::new(5, "m").with_max_iterations(20);
3017 assert_eq!(cfg.max_iterations, 20);
3018 }
3019
3020 #[test]
3021 fn test_tool_registry_tool_names_owned_returns_strings() {
3022 let mut reg = ToolRegistry::new();
3023 reg.register(ToolSpec::new("alpha", "d", |_| serde_json::json!("ok")));
3024 reg.register(ToolSpec::new("beta", "d", |_| serde_json::json!("ok")));
3025 let mut names = reg.tool_names_owned();
3026 names.sort();
3027 assert_eq!(names, vec!["alpha".to_string(), "beta".to_string()]);
3028 }
3029
3030 #[test]
3031 fn test_tool_registry_tool_names_owned_empty_when_no_tools() {
3032 let reg = ToolRegistry::new();
3033 assert!(reg.tool_names_owned().is_empty());
3034 }
3035
3036 #[test]
3039 fn test_tool_registry_tool_specs_returns_all_specs() {
3040 let mut reg = ToolRegistry::new();
3041 reg.register(ToolSpec::new("t1", "desc1", |_| serde_json::json!("ok")));
3042 reg.register(ToolSpec::new("t2", "desc2", |_| serde_json::json!("ok")));
3043 let specs = reg.tool_specs();
3044 assert_eq!(specs.len(), 2);
3045 }
3046
3047 #[test]
3048 fn test_tool_registry_tool_specs_empty_when_no_tools() {
3049 let reg = ToolRegistry::new();
3050 assert!(reg.tool_specs().is_empty());
3051 }
3052
3053 #[test]
3056 fn test_rename_tool_updates_name_and_key() {
3057 let mut reg = ToolRegistry::new();
3058 reg.register(ToolSpec::new("old", "desc", |_| serde_json::json!("ok")));
3059 assert!(reg.rename_tool("old", "new"));
3060 assert!(reg.has_tool("new"));
3061 assert!(!reg.has_tool("old"));
3062 let spec = reg.get("new").unwrap();
3063 assert_eq!(spec.name, "new");
3064 }
3065
3066 #[test]
3067 fn test_rename_tool_returns_false_for_unknown_name() {
3068 let mut reg = ToolRegistry::new();
3069 assert!(!reg.rename_tool("ghost", "other"));
3070 }
3071
3072 #[test]
3075 fn test_filter_tools_returns_matching_specs() {
3076 let mut reg = ToolRegistry::new();
3077 reg.register(ToolSpec::new("short_desc", "hi", |_| serde_json::json!({})));
3078 reg.register(ToolSpec::new("long_desc", "a longer description here", |_| serde_json::json!({})));
3079 let long_ones = reg.filter_tools(|s| s.description.len() > 10);
3080 assert_eq!(long_ones.len(), 1);
3081 assert_eq!(long_ones[0].name, "long_desc");
3082 }
3083
3084 #[test]
3085 fn test_filter_tools_returns_empty_when_none_match() {
3086 let mut reg = ToolRegistry::new();
3087 reg.register(ToolSpec::new("t1", "desc", |_| serde_json::json!({})));
3088 let none: Vec<_> = reg.filter_tools(|_| false);
3089 assert!(none.is_empty());
3090 }
3091
3092 #[test]
3093 fn test_filter_tools_returns_all_when_predicate_always_true() {
3094 let mut reg = ToolRegistry::new();
3095 reg.register(ToolSpec::new("a", "d1", |_| serde_json::json!({})));
3096 reg.register(ToolSpec::new("b", "d2", |_| serde_json::json!({})));
3097 let all = reg.filter_tools(|_| true);
3098 assert_eq!(all.len(), 2);
3099 }
3100
3101 #[test]
3104 fn test_agent_config_max_iterations_getter_returns_configured_value() {
3105 let cfg = AgentConfig::new(5, "model-x");
3106 assert_eq!(cfg.max_iterations(), 5);
3107 }
3108
3109 #[test]
3110 fn test_agent_config_with_max_iterations_updates_getter() {
3111 let cfg = AgentConfig::new(3, "m").with_max_iterations(10);
3112 assert_eq!(cfg.max_iterations(), 10);
3113 }
3114
3115 #[test]
3118 fn test_tool_registry_is_empty_true_when_new() {
3119 let reg = ToolRegistry::new();
3120 assert!(reg.is_empty());
3121 }
3122
3123 #[test]
3124 fn test_tool_registry_is_empty_false_after_register() {
3125 let mut reg = ToolRegistry::new();
3126 reg.register(ToolSpec::new("t", "d", |_| serde_json::json!({})));
3127 assert!(!reg.is_empty());
3128 }
3129
3130 #[test]
3131 fn test_tool_registry_clear_empties_registry() {
3132 let mut reg = ToolRegistry::new();
3133 reg.register(ToolSpec::new("t1", "d", |_| serde_json::json!({})));
3134 reg.register(ToolSpec::new("t2", "d", |_| serde_json::json!({})));
3135 reg.clear();
3136 assert!(reg.is_empty());
3137 assert_eq!(reg.tool_count(), 0);
3138 }
3139
3140 #[test]
3141 fn test_tool_registry_remove_returns_spec_and_decrements_count() {
3142 let mut reg = ToolRegistry::new();
3143 reg.register(ToolSpec::new("myTool", "desc", |_| serde_json::json!({})));
3144 assert_eq!(reg.tool_count(), 1);
3145 let removed = reg.remove("myTool");
3146 assert!(removed.is_some());
3147 assert_eq!(reg.tool_count(), 0);
3148 }
3149
3150 #[test]
3151 fn test_tool_registry_remove_returns_none_for_absent_tool() {
3152 let mut reg = ToolRegistry::new();
3153 assert!(reg.remove("ghost").is_none());
3154 }
3155
3156 #[test]
3159 fn test_all_tool_names_returns_sorted_names() {
3160 let mut reg = ToolRegistry::new();
3161 reg.register(ToolSpec::new("zebra", "d", |_| serde_json::json!({})));
3162 reg.register(ToolSpec::new("apple", "d", |_| serde_json::json!({})));
3163 reg.register(ToolSpec::new("mango", "d", |_| serde_json::json!({})));
3164 let names = reg.all_tool_names();
3165 assert_eq!(names, vec!["apple", "mango", "zebra"]);
3166 }
3167
3168 #[test]
3169 fn test_all_tool_names_empty_for_empty_registry() {
3170 let reg = ToolRegistry::new();
3171 assert!(reg.all_tool_names().is_empty());
3172 }
3173
3174 #[test]
3177 fn test_remaining_iterations_after_full_budget() {
3178 let cfg = AgentConfig::new(10, "m");
3179 assert_eq!(cfg.remaining_iterations_after(0), 10);
3180 }
3181
3182 #[test]
3183 fn test_remaining_iterations_after_partial_use() {
3184 let cfg = AgentConfig::new(10, "m");
3185 assert_eq!(cfg.remaining_iterations_after(3), 7);
3186 }
3187
3188 #[test]
3189 fn test_remaining_iterations_after_saturates_at_zero() {
3190 let cfg = AgentConfig::new(5, "m");
3191 assert_eq!(cfg.remaining_iterations_after(10), 0);
3192 }
3193
3194 #[test]
3195 fn test_tool_spec_required_field_count_zero_by_default() {
3196 let spec = ToolSpec::new("t", "d", |_| serde_json::json!({}));
3197 assert_eq!(spec.required_field_count(), 0);
3198 }
3199
3200 #[test]
3201 fn test_tool_spec_required_field_count_after_adding() {
3202 let spec = ToolSpec::new("t", "d", |_| serde_json::json!({}))
3203 .with_required_fields(["query", "limit"]);
3204 assert_eq!(spec.required_field_count(), 2);
3205 }
3206
3207 #[test]
3208 fn test_tool_spec_has_required_fields_false_by_default() {
3209 let spec = ToolSpec::new("t", "d", |_| serde_json::json!({}));
3210 assert!(!spec.has_required_fields());
3211 }
3212
3213 #[test]
3214 fn test_tool_spec_has_required_fields_true_after_adding() {
3215 let spec = ToolSpec::new("t", "d", |_| serde_json::json!({}))
3216 .with_required_fields(["key"]);
3217 assert!(spec.has_required_fields());
3218 }
3219
3220 #[test]
3221 fn test_tool_spec_has_validators_false_by_default() {
3222 let spec = ToolSpec::new("t", "d", |_| serde_json::json!({}));
3223 assert!(!spec.has_validators());
3224 }
3225
3226 #[test]
3229 fn test_tool_registry_contains_true_for_registered_tool() {
3230 let mut reg = ToolRegistry::new();
3231 reg.register(ToolSpec::new("search", "d", |_| serde_json::json!({})));
3232 assert!(reg.contains("search"));
3233 }
3234
3235 #[test]
3236 fn test_tool_registry_contains_false_for_unknown_tool() {
3237 let reg = ToolRegistry::new();
3238 assert!(!reg.contains("missing"));
3239 }
3240
3241 #[test]
3242 fn test_tool_registry_descriptions_sorted_by_name() {
3243 let mut reg = ToolRegistry::new();
3244 reg.register(ToolSpec::new("zebra", "z-desc", |_| serde_json::json!({})));
3245 reg.register(ToolSpec::new("apple", "a-desc", |_| serde_json::json!({})));
3246 let descs = reg.descriptions();
3247 assert_eq!(descs[0], ("apple", "a-desc"));
3248 assert_eq!(descs[1], ("zebra", "z-desc"));
3249 }
3250
3251 #[test]
3252 fn test_tool_registry_descriptions_empty_when_no_tools() {
3253 let reg = ToolRegistry::new();
3254 assert!(reg.descriptions().is_empty());
3255 }
3256
3257 #[test]
3258 fn test_tool_registry_tool_count_increments_on_register() {
3259 let mut reg = ToolRegistry::new();
3260 assert_eq!(reg.tool_count(), 0);
3261 reg.register(ToolSpec::new("t1", "d", |_| serde_json::json!({})));
3262 assert_eq!(reg.tool_count(), 1);
3263 reg.register(ToolSpec::new("t2", "d", |_| serde_json::json!({})));
3264 assert_eq!(reg.tool_count(), 2);
3265 }
3266
3267 #[test]
3270 fn test_observation_is_empty_true_for_empty_string() {
3271 let step = ReActStep::new("think", "search", "");
3272 assert!(step.observation_is_empty());
3273 }
3274
3275 #[test]
3276 fn test_observation_is_empty_false_for_non_empty() {
3277 let step = ReActStep::new("think", "search", "found results");
3278 assert!(!step.observation_is_empty());
3279 }
3280
3281 #[test]
3284 fn test_agent_config_temperature_getter_none_by_default() {
3285 let cfg = AgentConfig::new(5, "gpt-4");
3286 assert!(cfg.temperature().is_none());
3287 }
3288
3289 #[test]
3290 fn test_agent_config_temperature_getter_some_when_set() {
3291 let cfg = AgentConfig::new(5, "gpt-4").with_temperature(0.7);
3292 assert!((cfg.temperature().unwrap() - 0.7).abs() < 1e-5);
3293 }
3294
3295 #[test]
3296 fn test_agent_config_max_tokens_getter_none_by_default() {
3297 let cfg = AgentConfig::new(5, "gpt-4");
3298 assert!(cfg.max_tokens().is_none());
3299 }
3300
3301 #[test]
3302 fn test_agent_config_max_tokens_getter_some_when_set() {
3303 let cfg = AgentConfig::new(5, "gpt-4").with_max_tokens(512);
3304 assert_eq!(cfg.max_tokens(), Some(512));
3305 }
3306
3307 #[test]
3308 fn test_agent_config_request_timeout_getter_none_by_default() {
3309 let cfg = AgentConfig::new(5, "gpt-4");
3310 assert!(cfg.request_timeout().is_none());
3311 }
3312
3313 #[test]
3314 fn test_agent_config_request_timeout_getter_some_when_set() {
3315 let cfg = AgentConfig::new(5, "gpt-4")
3316 .with_request_timeout(std::time::Duration::from_secs(10));
3317 assert_eq!(cfg.request_timeout(), Some(std::time::Duration::from_secs(10)));
3318 }
3319
3320 #[test]
3323 fn test_agent_config_has_max_context_chars_false_by_default() {
3324 let cfg = AgentConfig::new(5, "gpt-4");
3325 assert!(!cfg.has_max_context_chars());
3326 }
3327
3328 #[test]
3329 fn test_agent_config_has_max_context_chars_true_after_setting() {
3330 let cfg = AgentConfig::new(5, "gpt-4").with_max_context_chars(8192);
3331 assert!(cfg.has_max_context_chars());
3332 }
3333
3334 #[test]
3335 fn test_agent_config_max_context_chars_none_by_default() {
3336 let cfg = AgentConfig::new(5, "gpt-4");
3337 assert_eq!(cfg.max_context_chars(), None);
3338 }
3339
3340 #[test]
3341 fn test_agent_config_max_context_chars_some_after_setting() {
3342 let cfg = AgentConfig::new(5, "gpt-4").with_max_context_chars(4096);
3343 assert_eq!(cfg.max_context_chars(), Some(4096));
3344 }
3345
3346 #[test]
3347 fn test_agent_config_system_prompt_returns_configured_prompt() {
3348 let cfg = AgentConfig::new(5, "gpt-4").with_system_prompt("Be concise.");
3349 assert_eq!(cfg.system_prompt(), "Be concise.");
3350 }
3351
3352 #[test]
3353 fn test_agent_config_model_returns_configured_model() {
3354 let cfg = AgentConfig::new(5, "claude-3");
3355 assert_eq!(cfg.model(), "claude-3");
3356 }
3357
3358 #[test]
3361 fn test_message_is_system_true_for_system_role() {
3362 let m = Message::system("context");
3363 assert!(m.is_system());
3364 }
3365
3366 #[test]
3367 fn test_message_is_system_false_for_user_role() {
3368 let m = Message::user("hello");
3369 assert!(!m.is_system());
3370 }
3371
3372 #[test]
3373 fn test_message_word_count_counts_whitespace_words() {
3374 let m = Message::user("hello world foo");
3375 assert_eq!(m.word_count(), 3);
3376 }
3377
3378 #[test]
3379 fn test_message_word_count_zero_for_empty_content() {
3380 let m = Message::user("");
3381 assert_eq!(m.word_count(), 0);
3382 }
3383
3384 #[test]
3385 fn test_agent_config_has_loop_timeout_false_by_default() {
3386 let cfg = AgentConfig::new(5, "m");
3387 assert!(!cfg.has_loop_timeout());
3388 }
3389
3390 #[test]
3391 fn test_agent_config_has_loop_timeout_true_after_setting() {
3392 let cfg = AgentConfig::new(5, "m")
3393 .with_loop_timeout(std::time::Duration::from_secs(30));
3394 assert!(cfg.has_loop_timeout());
3395 }
3396
3397 #[test]
3398 fn test_agent_config_has_stop_sequences_false_by_default() {
3399 let cfg = AgentConfig::new(5, "m");
3400 assert!(!cfg.has_stop_sequences());
3401 }
3402
3403 #[test]
3404 fn test_agent_config_has_stop_sequences_true_after_adding() {
3405 let cfg = AgentConfig::new(5, "m").with_stop_sequences(vec!["STOP".to_string()]);
3406 assert!(cfg.has_stop_sequences());
3407 }
3408
3409 #[test]
3410 fn test_agent_config_is_single_shot_true_when_max_iterations_one() {
3411 let cfg = AgentConfig::new(1, "m");
3412 assert!(cfg.is_single_shot());
3413 }
3414
3415 #[test]
3416 fn test_agent_config_is_single_shot_false_when_max_iterations_gt_one() {
3417 let cfg = AgentConfig::new(5, "m");
3418 assert!(!cfg.is_single_shot());
3419 }
3420
3421 #[test]
3422 fn test_agent_config_has_temperature_false_by_default() {
3423 let cfg = AgentConfig::new(5, "m");
3424 assert!(!cfg.has_temperature());
3425 }
3426
3427 #[test]
3428 fn test_agent_config_has_temperature_true_after_setting() {
3429 let cfg = AgentConfig::new(5, "m").with_temperature(0.7);
3430 assert!(cfg.has_temperature());
3431 }
3432
3433 #[test]
3436 fn test_tool_spec_new_fallible_returns_ok_value() {
3437 let rt = tokio::runtime::Runtime::new().unwrap();
3438 let tool = ToolSpec::new_fallible(
3439 "add",
3440 "adds numbers",
3441 |_args| Ok(serde_json::json!({"result": 42})),
3442 );
3443 let result = rt.block_on(tool.call(serde_json::json!({})));
3444 assert_eq!(result["result"], 42);
3445 }
3446
3447 #[test]
3448 fn test_tool_spec_new_fallible_wraps_error_as_json() {
3449 let rt = tokio::runtime::Runtime::new().unwrap();
3450 let tool = ToolSpec::new_fallible(
3451 "fail",
3452 "always fails",
3453 |_| Err("bad input".to_string()),
3454 );
3455 let result = rt.block_on(tool.call(serde_json::json!({})));
3456 assert_eq!(result["error"], "bad input");
3457 assert_eq!(result["ok"], false);
3458 }
3459
3460 #[test]
3461 fn test_tool_spec_new_async_fallible_wraps_error() {
3462 let rt = tokio::runtime::Runtime::new().unwrap();
3463 let tool = ToolSpec::new_async_fallible(
3464 "async_fail",
3465 "async error",
3466 |_| Box::pin(async { Err("async bad".to_string()) }),
3467 );
3468 let result = rt.block_on(tool.call(serde_json::json!({})));
3469 assert_eq!(result["error"], "async bad");
3470 }
3471
3472 #[test]
3473 fn test_tool_spec_with_required_fields_sets_fields() {
3474 let tool = ToolSpec::new("t", "d", |_| serde_json::json!({}))
3475 .with_required_fields(["name", "value"]);
3476 assert_eq!(tool.required_field_count(), 2);
3477 }
3478
3479 #[test]
3480 fn test_tool_spec_with_description_overrides_description() {
3481 let tool = ToolSpec::new("t", "original", |_| serde_json::json!({}))
3482 .with_description("updated description");
3483 assert_eq!(tool.description, "updated description");
3484 }
3485
3486 #[test]
3489 fn test_agent_config_stop_sequence_count_zero_by_default() {
3490 let cfg = AgentConfig::new(5, "gpt-4");
3491 assert_eq!(cfg.stop_sequence_count(), 0);
3492 }
3493
3494 #[test]
3495 fn test_agent_config_stop_sequence_count_reflects_configured_count() {
3496 let cfg = AgentConfig::new(5, "gpt-4")
3497 .with_stop_sequences(vec!["STOP".to_string(), "END".to_string()]);
3498 assert_eq!(cfg.stop_sequence_count(), 2);
3499 }
3500
3501 #[test]
3502 fn test_tool_registry_find_by_description_keyword_empty_when_no_match() {
3503 let mut reg = ToolRegistry::new();
3504 reg.register(ToolSpec::new("calc", "Performs arithmetic", |_| serde_json::json!({})));
3505 let results = reg.find_by_description_keyword("weather");
3506 assert!(results.is_empty());
3507 }
3508
3509 #[test]
3510 fn test_tool_registry_find_by_description_keyword_case_insensitive() {
3511 let mut reg = ToolRegistry::new();
3512 reg.register(ToolSpec::new("calc", "Performs ARITHMETIC operations", |_| serde_json::json!({})));
3513 reg.register(ToolSpec::new("search", "Searches the web", |_| serde_json::json!({})));
3514 let results = reg.find_by_description_keyword("arithmetic");
3515 assert_eq!(results.len(), 1);
3516 assert_eq!(results[0].name, "calc");
3517 }
3518
3519 #[test]
3520 fn test_tool_registry_find_by_description_keyword_multiple_matches() {
3521 let mut reg = ToolRegistry::new();
3522 reg.register(ToolSpec::new("t1", "query the database", |_| serde_json::json!({})));
3523 reg.register(ToolSpec::new("t2", "query the cache", |_| serde_json::json!({})));
3524 reg.register(ToolSpec::new("t3", "send a message", |_| serde_json::json!({})));
3525 let results = reg.find_by_description_keyword("query");
3526 assert_eq!(results.len(), 2);
3527 }
3528
3529 #[test]
3533 fn test_message_is_user_true_for_user_role_r31() {
3534 let msg = Message::user("hello");
3535 assert!(msg.is_user());
3536 assert!(!msg.is_assistant());
3537 }
3538
3539 #[test]
3540 fn test_message_is_assistant_true_for_assistant_role_r31() {
3541 let msg = Message::assistant("hi there");
3542 assert!(msg.is_assistant());
3543 assert!(!msg.is_user());
3544 }
3545
3546 #[test]
3547 fn test_agent_config_stop_sequence_count_zero_for_new_config() {
3548 let cfg = AgentConfig::new(5, "model");
3549 assert_eq!(cfg.stop_sequence_count(), 0);
3550 }
3551
3552 #[test]
3553 fn test_agent_config_stop_sequence_count_after_setting() {
3554 let cfg = AgentConfig::new(5, "model")
3555 .with_stop_sequences(vec!["<stop>".to_string(), "END".to_string()]);
3556 assert_eq!(cfg.stop_sequence_count(), 2);
3557 }
3558
3559 #[test]
3560 fn test_agent_config_has_request_timeout_false_by_default() {
3561 let cfg = AgentConfig::new(5, "model");
3562 assert!(!cfg.has_request_timeout());
3563 }
3564
3565 #[test]
3566 fn test_agent_config_has_request_timeout_true_after_setting() {
3567 let cfg = AgentConfig::new(5, "model")
3568 .with_request_timeout(std::time::Duration::from_secs(30));
3569 assert!(cfg.has_request_timeout());
3570 }
3571
3572 #[test]
3575 fn test_react_loop_unregister_tool_removes_registered_tool() {
3576 let mut agent = ReActLoop::new(AgentConfig::new(5, "m"));
3577 agent.register_tool(ToolSpec::new("t1", "desc", |_| serde_json::json!({})));
3578 assert!(agent.unregister_tool("t1"));
3579 assert_eq!(agent.tool_count(), 0);
3580 }
3581
3582 #[test]
3583 fn test_react_loop_unregister_tool_returns_false_for_unknown() {
3584 let mut agent = ReActLoop::new(AgentConfig::new(5, "m"));
3585 assert!(!agent.unregister_tool("nonexistent"));
3586 }
3587
3588 #[test]
3591 fn test_tool_count_with_required_fields_zero_when_empty() {
3592 let reg = ToolRegistry::new();
3593 assert_eq!(reg.tool_count_with_required_fields(), 0);
3594 }
3595
3596 #[test]
3597 fn test_tool_count_with_required_fields_excludes_tools_without_fields() {
3598 let mut reg = ToolRegistry::new();
3599 reg.register(ToolSpec::new("t1", "d", |_| serde_json::json!({})));
3600 assert_eq!(reg.tool_count_with_required_fields(), 0);
3601 }
3602
3603 #[test]
3604 fn test_tool_count_with_required_fields_counts_only_tools_with_fields() {
3605 let mut reg = ToolRegistry::new();
3606 reg.register(
3607 ToolSpec::new("t1", "d", |_| serde_json::json!({}))
3608 .with_required_fields(["query"]),
3609 );
3610 reg.register(ToolSpec::new("t2", "d", |_| serde_json::json!({}))); reg.register(
3612 ToolSpec::new("t3", "d", |_| serde_json::json!({}))
3613 .with_required_fields(["url", "method"]),
3614 );
3615 assert_eq!(reg.tool_count_with_required_fields(), 2);
3616 }
3617
3618 #[test]
3621 fn test_tool_registry_names_empty_when_no_tools() {
3622 let reg = ToolRegistry::new();
3623 assert!(reg.names().is_empty());
3624 }
3625
3626 #[test]
3627 fn test_tool_registry_names_sorted_alphabetically() {
3628 let mut reg = ToolRegistry::new();
3629 reg.register(ToolSpec::new("zebra", "d", |_| serde_json::json!({})));
3630 reg.register(ToolSpec::new("alpha", "d", |_| serde_json::json!({})));
3631 reg.register(ToolSpec::new("mango", "d", |_| serde_json::json!({})));
3632 assert_eq!(reg.names(), vec!["alpha", "mango", "zebra"]);
3633 }
3634
3635 #[test]
3638 fn test_tool_names_starting_with_empty_when_no_match() {
3639 let mut reg = ToolRegistry::new();
3640 reg.register(ToolSpec::new("search", "d", |_| serde_json::json!({})));
3641 assert!(reg.tool_names_starting_with("calc").is_empty());
3642 }
3643
3644 #[test]
3645 fn test_tool_names_starting_with_returns_sorted_matches() {
3646 let mut reg = ToolRegistry::new();
3647 reg.register(ToolSpec::new("db_write", "d", |_| serde_json::json!({})));
3648 reg.register(ToolSpec::new("db_read", "d", |_| serde_json::json!({})));
3649 reg.register(ToolSpec::new("cache_get", "d", |_| serde_json::json!({})));
3650 let results = reg.tool_names_starting_with("db_");
3651 assert_eq!(results, vec!["db_read", "db_write"]);
3652 }
3653
3654 #[test]
3657 fn test_tool_registry_description_for_none_when_missing() {
3658 let reg = ToolRegistry::new();
3659 assert!(reg.description_for("unknown").is_none());
3660 }
3661
3662 #[test]
3663 fn test_tool_registry_description_for_returns_description() {
3664 let mut reg = ToolRegistry::new();
3665 reg.register(ToolSpec::new("search", "Find web results", |_| serde_json::json!({})));
3666 assert_eq!(reg.description_for("search"), Some("Find web results"));
3667 }
3668
3669 #[test]
3672 fn test_count_with_description_containing_zero_when_no_match() {
3673 let mut reg = ToolRegistry::new();
3674 reg.register(ToolSpec::new("t1", "database query", |_| serde_json::json!({})));
3675 assert_eq!(reg.count_with_description_containing("weather"), 0);
3676 }
3677
3678 #[test]
3679 fn test_count_with_description_containing_case_insensitive() {
3680 let mut reg = ToolRegistry::new();
3681 reg.register(ToolSpec::new("t1", "Search the WEB", |_| serde_json::json!({})));
3682 reg.register(ToolSpec::new("t2", "web scraper tool", |_| serde_json::json!({})));
3683 reg.register(ToolSpec::new("t3", "database lookup", |_| serde_json::json!({})));
3684 assert_eq!(reg.count_with_description_containing("web"), 2);
3685 }
3686
3687 #[test]
3688 fn test_unregister_all_clears_all_tools() {
3689 let mut reg = ToolRegistry::new();
3690 reg.register(ToolSpec::new("t1", "tool one", |_| serde_json::json!({})));
3691 reg.register(ToolSpec::new("t2", "tool two", |_| serde_json::json!({})));
3692 assert_eq!(reg.tool_count(), 2);
3693 reg.unregister_all();
3694 assert_eq!(reg.tool_count(), 0);
3695 }
3696
3697 #[test]
3698 fn test_tool_names_with_keyword_returns_matching_tool_names() {
3699 let mut reg = ToolRegistry::new();
3700 reg.register(ToolSpec::new("search", "search the web for info", |_| serde_json::json!({})));
3701 reg.register(ToolSpec::new("db", "query database records", |_| serde_json::json!({})));
3702 reg.register(ToolSpec::new("web-fetch", "fetch a WEB page", |_| serde_json::json!({})));
3703 let mut names = reg.tool_names_with_keyword("web");
3704 names.sort_unstable();
3705 assert_eq!(names, vec!["search", "web-fetch"]);
3706 }
3707
3708 #[test]
3709 fn test_tool_names_with_keyword_no_match_returns_empty() {
3710 let mut reg = ToolRegistry::new();
3711 reg.register(ToolSpec::new("t", "some tool", |_| serde_json::json!({})));
3712 assert!(reg.tool_names_with_keyword("missing").is_empty());
3713 }
3714
3715 #[test]
3716 fn test_all_descriptions_returns_sorted_descriptions() {
3717 let mut reg = ToolRegistry::new();
3718 reg.register(ToolSpec::new("t1", "z description", |_| serde_json::json!({})));
3719 reg.register(ToolSpec::new("t2", "a description", |_| serde_json::json!({})));
3720 assert_eq!(reg.all_descriptions(), vec!["a description", "z description"]);
3721 }
3722
3723 #[test]
3724 fn test_all_descriptions_empty_registry_returns_empty() {
3725 let reg = ToolRegistry::new();
3726 assert!(reg.all_descriptions().is_empty());
3727 }
3728
3729 #[test]
3730 fn test_longest_description_returns_longest() {
3731 let mut reg = ToolRegistry::new();
3732 reg.register(ToolSpec::new("t1", "short", |_| serde_json::json!({})));
3733 reg.register(ToolSpec::new("t2", "a much longer description here", |_| serde_json::json!({})));
3734 assert_eq!(reg.longest_description(), Some("a much longer description here"));
3735 }
3736
3737 #[test]
3738 fn test_longest_description_empty_registry_returns_none() {
3739 let reg = ToolRegistry::new();
3740 assert!(reg.longest_description().is_none());
3741 }
3742
3743 #[test]
3744 fn test_names_containing_returns_sorted_matching_names() {
3745 let mut reg = ToolRegistry::new();
3746 reg.register(ToolSpec::new("search-web", "search tool", |_| serde_json::json!({})));
3747 reg.register(ToolSpec::new("web-fetch", "fetch tool", |_| serde_json::json!({})));
3748 reg.register(ToolSpec::new("db-query", "database tool", |_| serde_json::json!({})));
3749 let names = reg.names_containing("web");
3750 assert_eq!(names, vec!["search-web", "web-fetch"]);
3751 }
3752
3753 #[test]
3754 fn test_names_containing_no_match_returns_empty() {
3755 let mut reg = ToolRegistry::new();
3756 reg.register(ToolSpec::new("t", "tool", |_| serde_json::json!({})));
3757 assert!(reg.names_containing("missing").is_empty());
3758 }
3759
3760 #[test]
3763 fn test_avg_description_length_returns_mean_byte_length() {
3764 let mut reg = ToolRegistry::new();
3765 reg.register(ToolSpec::new("a", "ab", |_| serde_json::json!({}))); reg.register(ToolSpec::new("b", "abcd", |_| serde_json::json!({}))); let avg = reg.avg_description_length();
3768 assert!((avg - 3.0).abs() < 1e-9);
3769 }
3770
3771 #[test]
3772 fn test_avg_description_length_returns_zero_when_empty() {
3773 let reg = ToolRegistry::new();
3774 assert_eq!(reg.avg_description_length(), 0.0);
3775 }
3776
3777 #[test]
3780 fn test_shortest_description_returns_shortest_string() {
3781 let mut reg = ToolRegistry::new();
3782 reg.register(ToolSpec::new("a", "hello world", |_| serde_json::json!({})));
3783 reg.register(ToolSpec::new("b", "hi", |_| serde_json::json!({})));
3784 reg.register(ToolSpec::new("c", "greetings", |_| serde_json::json!({})));
3785 assert_eq!(reg.shortest_description(), Some("hi"));
3786 }
3787
3788 #[test]
3789 fn test_shortest_description_returns_none_when_empty() {
3790 let reg = ToolRegistry::new();
3791 assert!(reg.shortest_description().is_none());
3792 }
3793
3794 #[test]
3797 fn test_tool_names_sorted_returns_names_in_alphabetical_order() {
3798 let mut reg = ToolRegistry::new();
3799 reg.register(ToolSpec::new("zap", "z tool", |_| serde_json::json!({})));
3800 reg.register(ToolSpec::new("alpha", "a tool", |_| serde_json::json!({})));
3801 reg.register(ToolSpec::new("middle", "m tool", |_| serde_json::json!({})));
3802 assert_eq!(reg.tool_names_sorted(), vec!["alpha", "middle", "zap"]);
3803 }
3804
3805 #[test]
3806 fn test_tool_names_sorted_empty_returns_empty() {
3807 let reg = ToolRegistry::new();
3808 assert!(reg.tool_names_sorted().is_empty());
3809 }
3810
3811 #[test]
3814 fn test_description_contains_count_counts_matching_descriptions() {
3815 let mut reg = ToolRegistry::new();
3816 reg.register(ToolSpec::new("a", "search the web", |_| serde_json::json!({})));
3817 reg.register(ToolSpec::new("b", "write to disk", |_| serde_json::json!({})));
3818 reg.register(ToolSpec::new("c", "search and filter", |_| serde_json::json!({})));
3819 assert_eq!(reg.description_contains_count("search"), 2);
3820 assert_eq!(reg.description_contains_count("SEARCH"), 2);
3821 assert_eq!(reg.description_contains_count("missing"), 0);
3822 }
3823
3824 #[test]
3825 fn test_description_contains_count_zero_when_empty() {
3826 let reg = ToolRegistry::new();
3827 assert_eq!(reg.description_contains_count("anything"), 0);
3828 }
3829}