use crate::error::AgentRuntimeError;
use crate::metrics::RuntimeMetrics;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::fmt::Write as FmtWrite;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// Boxed, pinned, `Send` future that resolves to the JSON value a tool produces.
pub type AsyncToolFuture = Pin<Box<dyn Future<Output = Value> + Send>>;
/// Boxed, pinned, `Send` future for a fallible tool; the error side is a plain `String`.
pub type AsyncToolResultFuture = Pin<Box<dyn Future<Output = Result<Value, String>> + Send>>;
/// Type-erased tool handler: takes JSON arguments and returns an [`AsyncToolFuture`].
pub type AsyncToolHandler = Box<dyn Fn(Value) -> AsyncToolFuture + Send + Sync>;
/// Role attached to a [`Message`].
///
/// The `Display` impl renders each variant as its lowercase name.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Role {
/// Displayed as `system`.
System,
/// Displayed as `user`.
User,
/// Displayed as `assistant`.
Assistant,
/// Displayed as `tool`.
Tool,
}
/// A single piece of conversation text tagged with the [`Role`] that produced it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// Who the content is attributed to.
pub role: Role,
/// The message text.
pub content: String,
}
impl Message {
    /// Builds a message from a role and anything convertible into a `String`.
    pub fn new(role: Role, content: impl Into<String>) -> Self {
        let content: String = content.into();
        Self { role, content }
    }

    /// Builds a [`Role::User`] message.
    pub fn user(content: impl Into<String>) -> Self {
        Self {
            role: Role::User,
            content: content.into(),
        }
    }

    /// Builds a [`Role::Assistant`] message.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self {
            role: Role::Assistant,
            content: content.into(),
        }
    }

    /// Builds a [`Role::System`] message.
    pub fn system(content: impl Into<String>) -> Self {
        Self {
            role: Role::System,
            content: content.into(),
        }
    }

    /// Borrows the message role.
    pub fn role(&self) -> &Role {
        let Self { role, .. } = self;
        role
    }

    /// Borrows the message text.
    pub fn content(&self) -> &str {
        self.content.as_str()
    }

    /// `true` when the role is [`Role::User`].
    pub fn is_user(&self) -> bool {
        matches!(self.role, Role::User)
    }

    /// `true` when the role is [`Role::Assistant`].
    pub fn is_assistant(&self) -> bool {
        matches!(self.role, Role::Assistant)
    }

    /// `true` when the role is [`Role::System`].
    pub fn is_system(&self) -> bool {
        matches!(self.role, Role::System)
    }

    /// `true` when the role is [`Role::Tool`].
    pub fn is_tool(&self) -> bool {
        matches!(self.role, Role::Tool)
    }

    /// `true` when the content has zero bytes (whitespace-only content is not empty).
    pub fn is_empty(&self) -> bool {
        self.byte_len() == 0
    }

    /// Number of whitespace-separated words in the content.
    pub fn word_count(&self) -> usize {
        let words = self.content.split_whitespace();
        words.count()
    }

    /// Length of the content in bytes (not characters).
    pub fn byte_len(&self) -> usize {
        self.content.as_bytes().len()
    }

    /// `true` when the content begins with `prefix`.
    pub fn content_starts_with(&self, prefix: &str) -> bool {
        self.content.as_str().starts_with(prefix)
    }
}
impl std::fmt::Display for Role {
    /// Writes the lowercase name of the role.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Role::System => "system",
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::Tool => "tool",
        };
        f.write_str(label)
    }
}
impl std::fmt::Display for Message {
    /// Renders the message as `<role>: <content>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { role, content } = self;
        write!(f, "{role}: {content}")
    }
}
impl From<(Role, String)> for Message {
fn from((role, content): (Role, String)) -> Self {
Self::new(role, content)
}
}
impl From<(Role, &str)> for Message {
fn from((role, content): (Role, &str)) -> Self {
Self::new(role, content)
}
}
/// One thought / action / observation triple of a ReAct loop.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReActStep {
/// Text parsed from the model's `Thought:` section.
pub thought: String,
/// Text parsed from the model's `Action:` section; an action starting with
/// `FINAL_ANSWER` (case-insensitive) marks the final step.
pub action: String,
/// Observation recorded for the step; empty when freshly parsed.
pub observation: String,
/// Step duration in milliseconds; defaults to 0 when absent during deserialization.
#[serde(default)]
pub step_duration_ms: u64,
}
impl ReActStep {
    /// Builds a step from its three text parts with a zero duration.
    pub fn new(
        thought: impl Into<String>,
        action: impl Into<String>,
        observation: impl Into<String>,
    ) -> Self {
        Self {
            thought: thought.into(),
            action: action.into(),
            observation: observation.into(),
            step_duration_ms: 0,
        }
    }

    /// `true` when the trimmed action starts with `FINAL_ANSWER` (ASCII case-insensitive).
    pub fn is_final_answer(&self) -> bool {
        self.action.trim().to_ascii_uppercase().starts_with("FINAL_ANSWER")
    }

    /// `true` when the action is non-blank and is not a final answer.
    pub fn is_tool_call(&self) -> bool {
        !self.is_final_answer() && !self.action.trim().is_empty()
    }

    /// Returns the step with `step_duration_ms` set to `ms`.
    pub fn with_duration(mut self, ms: u64) -> Self {
        self.step_duration_ms = ms;
        self
    }

    /// `true` when thought, action and observation are all zero-length.
    pub fn is_empty(&self) -> bool {
        self.thought.is_empty() && self.action.is_empty() && self.observation.is_empty()
    }

    /// `true` when the observation is zero-length (no trimming).
    pub fn observation_is_empty(&self) -> bool {
        self.observation.is_empty()
    }

    /// Number of whitespace-separated words in the thought.
    pub fn thought_word_count(&self) -> usize {
        self.thought.split_whitespace().count()
    }

    /// Number of whitespace-separated words in the observation.
    pub fn observation_word_count(&self) -> usize {
        self.observation.split_whitespace().count()
    }

    /// `true` when the thought is empty or whitespace-only.
    pub fn thought_is_empty(&self) -> bool {
        self.thought.trim().is_empty()
    }

    /// One-line summary: `[FINAL|TOOL] thought=… action=… obs=…`.
    ///
    /// Each part is trimmed and cut to at most 40 bytes, followed by `…` when
    /// it was shortened. The cut is moved back to a character boundary so
    /// multi-byte UTF-8 text cannot cause a slicing panic.
    pub fn summary(&self) -> String {
        fn preview(s: &str) -> String {
            const MAX_BYTES: usize = 40;
            if s.len() <= MAX_BYTES {
                return s.to_owned();
            }
            // Byte 40 may fall inside a multi-byte character; back off until
            // the cut is valid. Index 0 is always a boundary, so this ends.
            let mut cut = MAX_BYTES;
            while !s.is_char_boundary(cut) {
                cut -= 1;
            }
            format!("{}…", &s[..cut])
        }
        let kind = if self.is_final_answer() { "FINAL" } else { "TOOL" };
        format!(
            "[{kind}] thought={t} action={a} obs={o}",
            t = preview(self.thought.trim()),
            a = preview(self.action.trim()),
            o = preview(self.observation.trim()),
        )
    }

    /// Sum of the byte lengths of thought, action and observation.
    pub fn combined_byte_length(&self) -> usize {
        self.thought.len() + self.action.len() + self.observation.len()
    }

    /// `true` when the action is empty or whitespace-only.
    pub fn action_is_empty(&self) -> bool {
        self.action.trim().is_empty()
    }

    /// Total whitespace-separated word count across all three fields.
    pub fn total_word_count(&self) -> usize {
        self.thought_word_count() + self.action_word_count() + self.observation_word_count()
    }

    /// `true` when all three fields are non-zero-length (no trimming).
    pub fn is_complete(&self) -> bool {
        !self.has_empty_fields()
    }

    /// `true` when the observation begins with `prefix`.
    pub fn observation_starts_with(&self, prefix: &str) -> bool {
        self.observation.starts_with(prefix)
    }

    /// Number of whitespace-separated words in the action.
    pub fn action_word_count(&self) -> usize {
        self.action.split_whitespace().count()
    }

    /// Byte length of the thought.
    pub fn thought_byte_len(&self) -> usize {
        self.thought.len()
    }

    /// Byte length of the action.
    pub fn action_byte_len(&self) -> usize {
        self.action.len()
    }

    /// `true` when at least one of the three fields is zero-length (no trimming).
    pub fn has_empty_fields(&self) -> bool {
        self.thought.is_empty() || self.action.is_empty() || self.observation.is_empty()
    }

    /// Byte length of the observation.
    pub fn observation_byte_len(&self) -> usize {
        self.observation.len()
    }

    /// `true` when every field contains at least one non-whitespace word.
    pub fn all_fields_have_words(&self) -> bool {
        self.thought.split_whitespace().next().is_some()
            && self.action.split_whitespace().next().is_some()
            && self.observation.split_whitespace().next().is_some()
    }
}
/// Configuration for an agent loop.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
/// Upper bound on loop iterations; `validate` rejects 0.
pub max_iterations: usize,
/// Model identifier; `validate` rejects an empty string.
pub model: String,
/// Text placed at the start of the loop context; defaults to
/// `"You are a helpful AI agent."`.
pub system_prompt: String,
/// Defaults to 3. NOTE(review): presumably a cap on recalled memories — confirm against the consumer.
pub max_memory_recalls: usize,
/// Optional memory token budget; `None` by default.
pub max_memory_tokens: Option<usize>,
/// Optional wall-clock limit for the whole loop; `None` by default.
pub loop_timeout: Option<std::time::Duration>,
/// Optional sampling temperature; `None` by default.
pub temperature: Option<f32>,
/// Optional token limit; `None` by default.
pub max_tokens: Option<usize>,
/// Optional per-request timeout; `None` by default.
pub request_timeout: Option<std::time::Duration>,
/// Optional limit used for context trimming; `None` by default.
pub max_context_chars: Option<usize>,
/// Stop sequences; empty by default.
pub stop_sequences: Vec<String>,
}
impl AgentConfig {
    /// Converts a duration to whole milliseconds, clamping at `u64::MAX`
    /// instead of silently truncating the 128-bit value.
    fn saturating_millis(d: std::time::Duration) -> u64 {
        d.as_millis().min(u128::from(u64::MAX)) as u64
    }

    /// Creates a config with the given iteration cap and model; every other
    /// field takes its default (see the field docs).
    pub fn new(max_iterations: usize, model: impl Into<String>) -> Self {
        Self {
            max_iterations,
            model: model.into(),
            system_prompt: "You are a helpful AI agent.".into(),
            max_memory_recalls: 3,
            max_memory_tokens: None,
            loop_timeout: None,
            temperature: None,
            max_tokens: None,
            request_timeout: None,
            max_context_chars: None,
            stop_sequences: vec![],
        }
    }

    /// Replaces the system prompt.
    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = prompt.into();
        self
    }

    /// Sets `max_memory_recalls`.
    pub fn with_max_memory_recalls(mut self, n: usize) -> Self {
        self.max_memory_recalls = n;
        self
    }

    /// Sets `max_memory_tokens` to `Some(n)`.
    pub fn with_max_memory_tokens(mut self, n: usize) -> Self {
        self.max_memory_tokens = Some(n);
        self
    }

    /// Sets the whole-loop timeout.
    pub fn with_loop_timeout(mut self, d: std::time::Duration) -> Self {
        self.loop_timeout = Some(d);
        self
    }

    /// Sets the whole-loop timeout in seconds.
    pub fn with_loop_timeout_secs(self, secs: u64) -> Self {
        self.with_loop_timeout(std::time::Duration::from_secs(secs))
    }

    /// Sets the whole-loop timeout in milliseconds.
    pub fn with_loop_timeout_ms(self, ms: u64) -> Self {
        self.with_loop_timeout(std::time::Duration::from_millis(ms))
    }

    /// Sets the iteration cap.
    pub fn with_max_iterations(mut self, n: usize) -> Self {
        self.max_iterations = n;
        self
    }

    /// Returns the iteration cap.
    pub fn max_iterations(&self) -> usize {
        self.max_iterations
    }

    /// Sets the temperature to `Some(t)`; the value is not range-checked here.
    pub fn with_temperature(mut self, t: f32) -> Self {
        self.temperature = Some(t);
        self
    }

    /// Sets `max_tokens` to `Some(n)`.
    pub fn with_max_tokens(mut self, n: usize) -> Self {
        self.max_tokens = Some(n);
        self
    }

    /// Sets the per-request timeout.
    pub fn with_request_timeout(mut self, d: std::time::Duration) -> Self {
        self.request_timeout = Some(d);
        self
    }

    /// Sets the per-request timeout in seconds.
    pub fn with_request_timeout_secs(self, secs: u64) -> Self {
        self.with_request_timeout(std::time::Duration::from_secs(secs))
    }

    /// Sets the per-request timeout in milliseconds.
    pub fn with_request_timeout_ms(self, ms: u64) -> Self {
        self.with_request_timeout(std::time::Duration::from_millis(ms))
    }

    /// Sets `max_context_chars` to `Some(n)`.
    pub fn with_max_context_chars(mut self, n: usize) -> Self {
        self.max_context_chars = Some(n);
        self
    }

    /// Replaces the model identifier.
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Returns a copy of this config with a different model.
    pub fn clone_with_model(&self, model: impl Into<String>) -> Self {
        self.clone().with_model(model)
    }

    /// Returns a copy of this config with a different system prompt.
    pub fn clone_with_system_prompt(&self, prompt: impl Into<String>) -> Self {
        self.clone().with_system_prompt(prompt)
    }

    /// Returns a copy of this config with a different iteration cap.
    pub fn clone_with_max_iterations(&self, n: usize) -> Self {
        self.clone().with_max_iterations(n)
    }

    /// Replaces the stop sequences.
    pub fn with_stop_sequences(mut self, sequences: Vec<String>) -> Self {
        self.stop_sequences = sequences;
        self
    }

    /// `true` when [`AgentConfig::validate`] succeeds.
    pub fn is_valid(&self) -> bool {
        // Delegate so the boolean and the error-returning check cannot drift apart.
        self.validate().is_ok()
    }

    /// Checks the config invariants.
    ///
    /// # Errors
    /// Returns `AgentRuntimeError::AgentLoop` when `max_iterations` is 0 or
    /// when `model` is empty.
    pub fn validate(&self) -> Result<(), crate::error::AgentRuntimeError> {
        if self.max_iterations == 0 {
            return Err(crate::error::AgentRuntimeError::AgentLoop(
                "AgentConfig: max_iterations must be >= 1".into(),
            ));
        }
        if self.model.is_empty() {
            return Err(crate::error::AgentRuntimeError::AgentLoop(
                "AgentConfig: model must not be empty".into(),
            ));
        }
        Ok(())
    }

    /// `true` when a loop timeout is configured.
    pub fn has_loop_timeout(&self) -> bool {
        self.loop_timeout.is_some()
    }

    /// `true` when at least one stop sequence is configured.
    pub fn has_stop_sequences(&self) -> bool {
        !self.stop_sequences.is_empty()
    }

    /// Number of configured stop sequences.
    pub fn stop_sequence_count(&self) -> usize {
        self.stop_sequences.len()
    }

    /// `true` when the iteration cap is exactly 1.
    pub fn is_single_shot(&self) -> bool {
        self.max_iterations == 1
    }

    /// `true` when a temperature is configured.
    pub fn has_temperature(&self) -> bool {
        self.temperature.is_some()
    }

    /// Returns the configured temperature, if any.
    pub fn temperature(&self) -> Option<f32> {
        self.temperature
    }

    /// Returns the configured token limit, if any.
    pub fn max_tokens(&self) -> Option<usize> {
        self.max_tokens
    }

    /// `true` when a per-request timeout is configured.
    pub fn has_request_timeout(&self) -> bool {
        self.request_timeout.is_some()
    }

    /// Returns the per-request timeout, if any.
    pub fn request_timeout(&self) -> Option<std::time::Duration> {
        self.request_timeout
    }

    /// `true` when a context-size limit is configured.
    pub fn has_max_context_chars(&self) -> bool {
        self.max_context_chars.is_some()
    }

    /// Returns the context-size limit, if any.
    pub fn max_context_chars(&self) -> Option<usize> {
        self.max_context_chars
    }

    /// Iterations left after `n` have been used (never underflows).
    pub fn remaining_iterations_after(&self, n: usize) -> usize {
        self.max_iterations.saturating_sub(n)
    }

    /// Borrows the system prompt.
    pub fn system_prompt(&self) -> &str {
        &self.system_prompt
    }

    /// `true` when the system prompt is empty or whitespace-only.
    pub fn system_prompt_is_empty(&self) -> bool {
        self.system_prompt.trim().is_empty()
    }

    /// Borrows the model identifier.
    pub fn model(&self) -> &str {
        &self.model
    }

    /// Loop timeout in milliseconds, or 0 when none is configured.
    /// Saturates at `u64::MAX` for durations too long to represent.
    pub fn loop_timeout_ms(&self) -> u64 {
        self.loop_timeout.map_or(0, Self::saturating_millis)
    }

    /// Loop timeout plus `max_iterations × request_timeout`, in milliseconds.
    ///
    /// Missing timeouts count as 0. Every step saturates at `u64::MAX`, so a
    /// very large configuration cannot overflow (which would panic in debug
    /// builds and wrap in release builds).
    pub fn total_timeout_ms(&self) -> u64 {
        let loop_ms = self.loop_timeout_ms();
        let req_ms = self.request_timeout.map_or(0, Self::saturating_millis);
        let all_requests_ms = (self.max_iterations as u64).saturating_mul(req_ms);
        loop_ms.saturating_add(all_requests_ms)
    }

    /// `true` when the model identifier equals `name` exactly.
    pub fn model_is(&self, name: &str) -> bool {
        self.model == name
    }

    /// Number of whitespace-separated words in the system prompt.
    pub fn system_prompt_word_count(&self) -> usize {
        self.system_prompt.split_whitespace().count()
    }

    /// Same as [`AgentConfig::remaining_iterations_after`].
    pub fn iteration_budget_remaining(&self, steps_done: usize) -> usize {
        self.remaining_iterations_after(steps_done)
    }

    /// `true` when the system prompt is blank and the iteration cap is 1.
    pub fn is_minimal(&self) -> bool {
        self.system_prompt_is_empty() && self.is_single_shot()
    }

    /// `true` when the model identifier begins with `prefix`.
    pub fn model_starts_with(&self, prefix: &str) -> bool {
        self.model.starts_with(prefix)
    }

    /// `true` when `steps_done` has reached or passed the iteration cap.
    pub fn exceeds_iteration_limit(&self, steps_done: usize) -> bool {
        steps_done >= self.max_iterations
    }

    /// `true` when either `max_tokens` or `max_context_chars` is set.
    pub fn token_budget_configured(&self) -> bool {
        self.max_tokens.is_some() || self.max_context_chars.is_some()
    }

    /// Returns `max_tokens`, or `default` when unset.
    pub fn max_tokens_or_default(&self, default: usize) -> usize {
        self.max_tokens.unwrap_or(default)
    }

    /// Returns the temperature, or `1.0` when unset.
    pub fn effective_temperature(&self) -> f32 {
        self.temperature.unwrap_or(1.0)
    }

    /// `true` when the system prompt begins with `prefix`.
    pub fn system_prompt_starts_with(&self, prefix: &str) -> bool {
        self.system_prompt.starts_with(prefix)
    }

    /// `true` when the iteration cap is strictly greater than `n`.
    pub fn max_iterations_above(&self, n: usize) -> bool {
        self.max_iterations > n
    }

    /// `true` when `s` is one of the stop sequences (exact match).
    pub fn stop_sequences_contain(&self, s: &str) -> bool {
        self.stop_sequences.iter().any(|seq| seq == s)
    }

    /// Byte length of the system prompt.
    pub fn system_prompt_byte_len(&self) -> usize {
        self.system_prompt.len()
    }

    /// `true` when a temperature is set and lies in `0.0..=2.0`.
    /// An unset temperature (and NaN) yields `false`.
    pub fn has_valid_temperature(&self) -> bool {
        self.temperature.is_some_and(|t| (0.0..=2.0).contains(&t))
    }
}
/// A callable tool: name, description, handler and argument checks.
pub struct ToolSpec {
/// Name the tool is registered and looked up under.
pub name: String,
/// Human-readable description of the tool.
pub description: String,
/// Type-erased async handler invoked by `call`.
pub(crate) handler: AsyncToolHandler,
/// Keys that must be present in the JSON-object arguments.
pub required_fields: Vec<String>,
/// Extra validators run against the arguments before the handler.
pub validators: Vec<Box<dyn ToolValidator>>,
/// Optional circuit breaker consulted before, and updated after, each call.
#[cfg(feature = "orchestrator")]
pub circuit_breaker: Option<Arc<crate::orchestrator::CircuitBreaker>>,
}
impl std::fmt::Debug for ToolSpec {
    /// Prints name, description and required fields; the handler and the
    /// validators are omitted. With the `orchestrator` feature, also prints
    /// whether a circuit breaker is attached.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ToolSpec");
        out.field("name", &self.name);
        out.field("description", &self.description);
        out.field("required_fields", &self.required_fields);
        #[cfg(feature = "orchestrator")]
        out.field("has_circuit_breaker", &self.circuit_breaker.is_some());
        out.finish()
    }
}
impl ToolSpec {
pub fn new(
name: impl Into<String>,
description: impl Into<String>,
handler: impl Fn(Value) -> Value + Send + Sync + 'static,
) -> Self {
Self {
name: name.into(),
description: description.into(),
handler: Box::new(move |args| {
let result = handler(args);
Box::pin(async move { result })
}),
required_fields: Vec::new(),
validators: Vec::new(),
#[cfg(feature = "orchestrator")]
circuit_breaker: None,
}
}
pub fn new_async(
name: impl Into<String>,
description: impl Into<String>,
handler: impl Fn(Value) -> AsyncToolFuture + Send + Sync + 'static,
) -> Self {
Self {
name: name.into(),
description: description.into(),
handler: Box::new(handler),
required_fields: Vec::new(),
validators: Vec::new(),
#[cfg(feature = "orchestrator")]
circuit_breaker: None,
}
}
pub fn new_fallible(
name: impl Into<String>,
description: impl Into<String>,
handler: impl Fn(Value) -> Result<Value, String> + Send + Sync + 'static,
) -> Self {
Self {
name: name.into(),
description: description.into(),
handler: Box::new(move |args| {
let result = handler(args);
let value = match result {
Ok(v) => v,
Err(msg) => serde_json::json!({"error": msg, "ok": false}),
};
Box::pin(async move { value })
}),
required_fields: Vec::new(),
validators: Vec::new(),
#[cfg(feature = "orchestrator")]
circuit_breaker: None,
}
}
pub fn new_async_fallible(
name: impl Into<String>,
description: impl Into<String>,
handler: impl Fn(Value) -> AsyncToolResultFuture + Send + Sync + 'static,
) -> Self {
Self {
name: name.into(),
description: description.into(),
handler: Box::new(move |args| {
let fut = handler(args);
Box::pin(async move {
match fut.await {
Ok(v) => v,
Err(msg) => serde_json::json!({"error": msg, "ok": false}),
}
})
}),
required_fields: Vec::new(),
validators: Vec::new(),
#[cfg(feature = "orchestrator")]
circuit_breaker: None,
}
}
pub fn with_required_fields(
mut self,
fields: impl IntoIterator<Item = impl Into<String>>,
) -> Self {
self.required_fields = fields.into_iter().map(Into::into).collect();
self
}
pub fn with_validators(mut self, validators: Vec<Box<dyn ToolValidator>>) -> Self {
self.validators = validators;
self
}
#[cfg(feature = "orchestrator")]
pub fn with_circuit_breaker(mut self, cb: Arc<crate::orchestrator::CircuitBreaker>) -> Self {
self.circuit_breaker = Some(cb);
self
}
pub fn with_description(mut self, desc: impl Into<String>) -> Self {
self.description = desc.into();
self
}
pub fn with_name(mut self, name: impl Into<String>) -> Self {
self.name = name.into();
self
}
pub fn required_field_count(&self) -> usize {
self.required_fields.len()
}
pub fn has_required_fields(&self) -> bool {
!self.required_fields.is_empty()
}
pub fn has_validators(&self) -> bool {
!self.validators.is_empty()
}
pub async fn call(&self, args: Value) -> Value {
(self.handler)(args).await
}
}
/// Cache of tool results keyed by tool name and JSON arguments.
pub trait ToolCache: Send + Sync {
/// Returns the cached result for `tool_name` called with `args`, if any.
fn get(&self, tool_name: &str, args: &serde_json::Value) -> Option<serde_json::Value>;
/// Stores `result` for `tool_name` called with `args`.
fn set(&self, tool_name: &str, args: &serde_json::Value, result: serde_json::Value);
}
/// Mutex-protected state of [`InMemoryToolCache`].
struct ToolCacheInner {
/// Cached results keyed by `(tool name, arguments rendered as a JSON string)`.
map: HashMap<(String, String), serde_json::Value>,
/// Keys in first-insertion order; the front is evicted first.
order: std::collections::VecDeque<(String, String)>,
}
/// In-process [`ToolCache`] with an optional cap on the number of entries.
pub struct InMemoryToolCache {
/// Map and insertion-order queue, guarded by one mutex.
inner: std::sync::Mutex<ToolCacheInner>,
/// Maximum number of entries; `None` means no limit.
max_entries: Option<usize>,
}
impl std::fmt::Debug for InMemoryToolCache {
    /// Prints the current entry count and the configured limit, not the contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("InMemoryToolCache");
        out.field("entries", &self.len());
        out.field("max_entries", &self.max_entries);
        out.finish()
    }
}
impl InMemoryToolCache {
    /// Builds an empty cache with the given optional entry limit.
    fn empty_with_limit(max_entries: Option<usize>) -> Self {
        let state = ToolCacheInner {
            map: HashMap::new(),
            order: std::collections::VecDeque::new(),
        };
        Self {
            inner: std::sync::Mutex::new(state),
            max_entries,
        }
    }

    /// Creates an empty cache with no entry limit.
    pub fn new() -> Self {
        Self::empty_with_limit(None)
    }

    /// Creates an empty cache that holds at most `max` entries.
    pub fn with_max_entries(max: usize) -> Self {
        Self::empty_with_limit(Some(max))
    }

    /// Removes every entry. Does nothing if the mutex is poisoned.
    pub fn clear(&self) {
        let Ok(mut state) = self.inner.lock() else {
            return;
        };
        state.map.clear();
        state.order.clear();
    }

    /// Number of cached entries; reports 0 if the mutex is poisoned.
    pub fn len(&self) -> usize {
        match self.inner.lock() {
            Ok(state) => state.map.len(),
            Err(_) => 0,
        }
    }

    /// `true` when [`InMemoryToolCache::len`] is 0.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// `true` when a result is cached for `tool_name` with exactly these
    /// arguments; `false` if the mutex is poisoned.
    pub fn contains(&self, tool_name: &str, args: &serde_json::Value) -> bool {
        let key = (tool_name.to_owned(), args.to_string());
        match self.inner.lock() {
            Ok(state) => state.map.contains_key(&key),
            Err(_) => false,
        }
    }

    /// Removes the entry for `tool_name` with these arguments.
    ///
    /// Returns `true` when an entry was removed; `false` when there was none
    /// or the mutex is poisoned.
    pub fn remove(&self, tool_name: &str, args: &serde_json::Value) -> bool {
        let key = (tool_name.to_owned(), args.to_string());
        let Ok(mut state) = self.inner.lock() else {
            return false;
        };
        let removed = state.map.remove(&key).is_some();
        if removed {
            // Keep the eviction queue in step with the map.
            state.order.retain(|queued| queued != &key);
        }
        removed
    }

    /// Returns the configured entry limit, if any.
    pub fn capacity(&self) -> Option<usize> {
        self.max_entries
    }
}
impl Default for InMemoryToolCache {
fn default() -> Self {
Self::new()
}
}
impl ToolCache for InMemoryToolCache {
    /// Returns a clone of the cached result, or `None` when there is no entry
    /// or the mutex is poisoned. Lookups do not change eviction order.
    fn get(&self, tool_name: &str, args: &serde_json::Value) -> Option<serde_json::Value> {
        let key = (tool_name.to_owned(), args.to_string());
        let state = self.inner.lock().ok()?;
        state.map.get(&key).cloned()
    }

    /// Stores `result`, then evicts the oldest-inserted entries while the
    /// cache exceeds `max_entries`.
    ///
    /// Overwriting an existing key keeps its original queue position. Does
    /// nothing if the mutex is poisoned.
    fn set(&self, tool_name: &str, args: &serde_json::Value, result: serde_json::Value) {
        let key = (tool_name.to_owned(), args.to_string());
        let Ok(mut state) = self.inner.lock() else {
            return;
        };
        if !state.map.contains_key(&key) {
            state.order.push_back(key.clone());
        }
        state.map.insert(key, result);
        if let Some(max) = self.max_entries {
            while state.map.len() > max {
                // Stop when the queue is exhausted, so a map/queue mismatch
                // can never spin forever while holding the lock.
                let Some(oldest) = state.order.pop_front() else {
                    break;
                };
                state.map.remove(&oldest);
            }
        }
    }
}
/// Collection of [`ToolSpec`]s addressable by name, with an optional result cache.
pub struct ToolRegistry {
/// Registered tools keyed by tool name.
tools: HashMap<String, ToolSpec>,
/// Optional cache consulted and filled by `call`.
cache: Option<Arc<dyn ToolCache>>,
}
impl std::fmt::Debug for ToolRegistry {
    /// Prints the registered tool names (in map iteration order) and whether
    /// a cache is attached.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let registered: Vec<&String> = self.tools.keys().collect();
        let cached = self.cache.is_some();
        let mut out = f.debug_struct("ToolRegistry");
        out.field("tools", &registered);
        out.field("has_cache", &cached);
        out.finish()
    }
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}
impl ToolRegistry {
    /// ASCII-case-insensitive test of whether `spec`'s description contains
    /// `lowered_keyword`, which must already be lowercased.
    fn description_mentions(spec: &ToolSpec, lowered_keyword: &str) -> bool {
        spec.description.to_ascii_lowercase().contains(lowered_keyword)
    }

    /// Creates an empty registry with no cache.
    pub fn new() -> Self {
        Self {
            tools: HashMap::new(),
            cache: None,
        }
    }

    /// Attaches a result cache used by [`ToolRegistry::call`].
    pub fn with_cache(mut self, cache: Arc<dyn ToolCache>) -> Self {
        self.cache = Some(cache);
        self
    }

    /// Registers `spec` under its own name, replacing any tool already
    /// registered under that name.
    pub fn register(&mut self, spec: ToolSpec) {
        self.tools.insert(spec.name.clone(), spec);
    }

    /// Registers every spec in `specs`, in order.
    pub fn register_tools(&mut self, specs: impl IntoIterator<Item = ToolSpec>) {
        for spec in specs {
            self.register(spec);
        }
    }

    /// Builder form of [`ToolRegistry::register`].
    pub fn with_tool(mut self, spec: ToolSpec) -> Self {
        self.register(spec);
        self
    }

    /// Looks up `name`, validates `args`, and invokes the tool.
    ///
    /// Order of operations: lookup, required-field check, validators, circuit
    /// breaker check (`orchestrator` feature), cache lookup, handler call,
    /// circuit breaker bookkeeping, cache store.
    ///
    /// Every handler result is stored in the cache, including values with
    /// `"ok": false`.
    ///
    /// # Errors
    /// * `AgentLoop` when the tool is unknown. The message suggests the
    ///   closest registered name when its edit distance is at most 3; ties go
    ///   to the lexicographically smallest name so the message is stable.
    /// * `AgentLoop` when required fields are declared and `args` is not a
    ///   JSON object or lacks one of them.
    /// * Whatever a validator returns.
    /// * `CircuitOpen` when the tool's circuit breaker is open.
    #[tracing::instrument(skip_all, fields(tool_name = %name))]
    pub async fn call(&self, name: &str, args: Value) -> Result<Value, AgentRuntimeError> {
        let spec = self.tools.get(name).ok_or_else(|| {
            let mut suggestion = String::new();
            let names = self.tool_names();
            // HashMap iteration order is arbitrary, so break distance ties by
            // name; otherwise the suggestion could differ between runs.
            if let Some((closest, dist)) = names
                .iter()
                .map(|n| (n, levenshtein(name, n)))
                .min_by(|a, b| a.1.cmp(&b.1).then_with(|| a.0.cmp(b.0)))
            {
                if dist <= 3 {
                    suggestion = format!(" (did you mean '{closest}'?)");
                }
            }
            AgentRuntimeError::AgentLoop(format!("tool '{name}' not found{suggestion}"))
        })?;
        if !spec.required_fields.is_empty() {
            if let Some(obj) = args.as_object() {
                for field in &spec.required_fields {
                    if !obj.contains_key(field) {
                        return Err(AgentRuntimeError::AgentLoop(format!(
                            "tool '{}' missing required field '{}'",
                            name, field
                        )));
                    }
                }
            } else {
                return Err(AgentRuntimeError::AgentLoop(format!(
                    "tool '{}' requires JSON object args, got {}",
                    name, args
                )));
            }
        }
        for validator in &spec.validators {
            validator.validate(&args)?;
        }
        #[cfg(feature = "orchestrator")]
        if let Some(ref cb) = spec.circuit_breaker {
            use crate::orchestrator::CircuitState;
            if let Ok(CircuitState::Open { .. }) = cb.state() {
                return Err(AgentRuntimeError::CircuitOpen {
                    service: format!("tool:{}", name),
                });
            }
        }
        if let Some(ref cache) = self.cache {
            if let Some(cached) = cache.get(name, &args) {
                return Ok(cached);
            }
        }
        let result = spec.call(args.clone()).await;
        #[cfg(feature = "orchestrator")]
        if let Some(ref cb) = spec.circuit_breaker {
            // Only an explicit `"ok": false` counts as a failure.
            let is_failure = result
                .get("ok")
                .and_then(|v| v.as_bool())
                .is_some_and(|ok| !ok);
            if is_failure {
                cb.record_failure();
            } else {
                cb.record_success();
            }
        }
        if let Some(ref cache) = self.cache {
            cache.set(name, &args, result.clone());
        }
        Ok(result)
    }

    /// Borrows the spec registered under `name`.
    pub fn get(&self, name: &str) -> Option<&ToolSpec> {
        self.tools.get(name)
    }

    /// `true` when a tool is registered under `name`.
    pub fn has_tool(&self, name: &str) -> bool {
        self.tools.contains_key(name)
    }

    /// Removes the tool named `name`; returns whether one was removed.
    pub fn unregister(&mut self, name: &str) -> bool {
        self.tools.remove(name).is_some()
    }

    /// Registered tool names, in unspecified order.
    pub fn tool_names(&self) -> Vec<&str> {
        self.tools.keys().map(|s| s.as_str()).collect()
    }

    /// Owned copies of the registered tool names, in unspecified order.
    pub fn tool_names_owned(&self) -> Vec<String> {
        self.tools.keys().cloned().collect()
    }

    /// Owned copies of the registered tool names, sorted.
    pub fn all_tool_names(&self) -> Vec<String> {
        let mut names: Vec<String> = self.tools.keys().cloned().collect();
        names.sort();
        names
    }

    /// All registered specs, in unspecified order.
    pub fn tool_specs(&self) -> Vec<&ToolSpec> {
        self.tools.values().collect()
    }

    /// Specs for which `pred` returns `true`, in unspecified order.
    pub fn filter_tools<F: Fn(&ToolSpec) -> bool>(&self, pred: F) -> Vec<&ToolSpec> {
        self.tools.values().filter(|s| pred(s)).collect()
    }

    /// Re-registers the tool `old_name` as `new_name`, updating the spec's
    /// own `name`.
    ///
    /// Returns `false` when `old_name` is not registered. A tool already
    /// registered under `new_name` is replaced.
    pub fn rename_tool(&mut self, old_name: &str, new_name: impl Into<String>) -> bool {
        let Some(mut spec) = self.tools.remove(old_name) else {
            return false;
        };
        let new_name = new_name.into();
        spec.name = new_name.clone();
        self.tools.insert(new_name, spec);
        true
    }

    /// Number of registered tools.
    pub fn tool_count(&self) -> usize {
        self.tools.len()
    }

    /// `true` when no tools are registered.
    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }

    /// Removes every tool. The attached cache, if any, is left untouched.
    pub fn clear(&mut self) {
        self.tools.clear();
    }

    /// Removes and returns the spec registered under `name`.
    pub fn remove(&mut self, name: &str) -> Option<ToolSpec> {
        self.tools.remove(name)
    }

    /// Same as [`ToolRegistry::has_tool`].
    pub fn contains(&self, name: &str) -> bool {
        self.has_tool(name)
    }

    /// `(name, description)` pairs sorted by name.
    pub fn descriptions(&self) -> Vec<(&str, &str)> {
        let mut pairs: Vec<(&str, &str)> = self
            .tools
            .values()
            .map(|s| (s.name.as_str(), s.description.as_str()))
            .collect();
        pairs.sort_unstable_by_key(|(name, _)| *name);
        pairs
    }

    /// Specs whose description contains `keyword` (ASCII case-insensitive),
    /// in unspecified order.
    pub fn find_by_description_keyword(&self, keyword: &str) -> Vec<&ToolSpec> {
        let lower = keyword.to_ascii_lowercase();
        self.tools
            .values()
            .filter(|s| Self::description_mentions(s, &lower))
            .collect()
    }

    /// Number of tools that declare at least one required field.
    pub fn tool_count_with_required_fields(&self) -> usize {
        self.tools.values().filter(|s| s.has_required_fields()).count()
    }

    /// Number of tools that have at least one validator.
    pub fn tool_count_with_validators(&self) -> usize {
        self.tools.values().filter(|s| s.has_validators()).count()
    }

    /// Registered tool names, sorted.
    pub fn names(&self) -> Vec<&str> {
        let mut names = self.tool_names();
        names.sort_unstable();
        names
    }

    /// Sorted names of tools whose name starts with `prefix`.
    pub fn tool_names_starting_with(&self, prefix: &str) -> Vec<&str> {
        let mut names: Vec<&str> = self
            .tools
            .keys()
            .filter(|k| k.starts_with(prefix))
            .map(|k| k.as_str())
            .collect();
        names.sort_unstable();
        names
    }

    /// Description of the tool named `name`, if registered.
    pub fn description_for(&self, name: &str) -> Option<&str> {
        self.tools.get(name).map(|s| s.description.as_str())
    }

    /// Number of tools whose description contains `keyword`
    /// (ASCII case-insensitive).
    pub fn count_with_description_containing(&self, keyword: &str) -> usize {
        let lower = keyword.to_ascii_lowercase();
        self.tools
            .values()
            .filter(|s| Self::description_mentions(s, &lower))
            .count()
    }

    /// Same as [`ToolRegistry::clear`].
    pub fn unregister_all(&mut self) {
        self.clear();
    }

    /// Sorted names of tools whose name contains `substring`
    /// (ASCII case-insensitive).
    pub fn names_containing(&self, substring: &str) -> Vec<&str> {
        let sub = substring.to_ascii_lowercase();
        let mut names: Vec<&str> = self
            .tools
            .keys()
            .filter(|name| name.to_ascii_lowercase().contains(&sub))
            .map(|s| s.as_str())
            .collect();
        names.sort_unstable();
        names
    }

    /// Shortest description by byte length, or `None` when empty.
    ///
    /// Equal-length ties go to the lexicographically smallest description, so
    /// the result does not depend on map iteration order.
    pub fn shortest_description(&self) -> Option<&str> {
        self.tools
            .values()
            .map(|s| s.description.as_str())
            .min_by(|a, b| a.len().cmp(&b.len()).then_with(|| a.cmp(b)))
    }

    /// Longest description by byte length, or `None` when empty.
    ///
    /// Equal-length ties go to the lexicographically smallest description, so
    /// the result does not depend on map iteration order.
    pub fn longest_description(&self) -> Option<&str> {
        self.tools
            .values()
            .map(|s| s.description.as_str())
            .max_by(|a, b| a.len().cmp(&b.len()).then_with(|| b.cmp(a)))
    }

    /// All descriptions, sorted lexicographically.
    pub fn all_descriptions(&self) -> Vec<&str> {
        let mut descs: Vec<&str> = self.tools.values().map(|s| s.description.as_str()).collect();
        descs.sort_unstable();
        descs
    }

    /// Names of tools whose description contains `keyword`
    /// (ASCII case-insensitive), in unspecified order.
    pub fn tool_names_with_keyword(&self, keyword: &str) -> Vec<&str> {
        let kw = keyword.to_ascii_lowercase();
        self.tools
            .values()
            .filter(|s| Self::description_mentions(s, &kw))
            .map(|s| s.name.as_str())
            .collect()
    }

    /// Mean description length in bytes; `0.0` for an empty registry.
    pub fn avg_description_length(&self) -> f64 {
        if self.tools.is_empty() {
            return 0.0;
        }
        self.total_description_bytes() as f64 / self.tools.len() as f64
    }

    /// Same as [`ToolRegistry::names`].
    pub fn tool_names_sorted(&self) -> Vec<&str> {
        self.names()
    }

    /// Same as [`ToolRegistry::count_with_description_containing`].
    pub fn description_contains_count(&self, keyword: &str) -> usize {
        self.count_with_description_containing(keyword)
    }

    /// Sum of all description lengths in bytes.
    pub fn total_description_bytes(&self) -> usize {
        self.tools.values().map(|s| s.description.len()).sum()
    }

    /// Byte length of the shortest description; 0 for an empty registry.
    pub fn shortest_description_length(&self) -> usize {
        self.tools
            .values()
            .map(|s| s.description.len())
            .min()
            .unwrap_or(0)
    }

    /// Byte length of the longest description; 0 for an empty registry.
    pub fn longest_description_length(&self) -> usize {
        self.tools
            .values()
            .map(|s| s.description.len())
            .max()
            .unwrap_or(0)
    }

    /// Number of tools whose description is strictly longer than `min_bytes`.
    pub fn tool_count_above_desc_bytes(&self, min_bytes: usize) -> usize {
        self.tools
            .values()
            .filter(|s| s.description.len() > min_bytes)
            .count()
    }

    /// Specs that list `field` as required, in unspecified order.
    pub fn tools_with_required_field(&self, field: &str) -> Vec<&ToolSpec> {
        self.tools
            .values()
            .filter(|s| s.required_fields.iter().any(|f| f == field))
            .collect()
    }

    /// Specs with no required fields, in unspecified order.
    pub fn tools_without_required_fields(&self) -> Vec<&ToolSpec> {
        self.tools
            .values()
            .filter(|s| s.required_fields.is_empty())
            .collect()
    }

    /// Mean number of required fields per tool; `0.0` for an empty registry.
    pub fn avg_required_fields_count(&self) -> f64 {
        if self.tools.is_empty() {
            return 0.0;
        }
        self.total_required_fields() as f64 / self.tools.len() as f64
    }

    /// Total ASCII-whitespace-separated word count over all descriptions.
    pub fn tool_descriptions_total_words(&self) -> usize {
        self.tools
            .values()
            .map(|spec| spec.description.split_ascii_whitespace().count())
            .sum()
    }

    /// `true` when any tool has an empty or whitespace-only description.
    pub fn has_tools_with_empty_descriptions(&self) -> bool {
        self.tools.values().any(|s| s.description.trim().is_empty())
    }

    /// Total number of required fields across all tools.
    pub fn total_required_fields(&self) -> usize {
        self.tools.values().map(|s| s.required_fields.len()).sum()
    }

    /// `true` when any description contains `keyword`.
    ///
    /// Unlike the other keyword helpers, this match is case-sensitive.
    pub fn has_tool_with_description_containing(&self, keyword: &str) -> bool {
        self.tools.values().any(|s| s.description.contains(keyword))
    }

    /// Sorted names of tools whose description is strictly longer than
    /// `min_bytes`.
    pub fn tools_with_description_longer_than(&self, min_bytes: usize) -> Vec<&str> {
        let mut names: Vec<&str> = self
            .tools
            .values()
            .filter(|s| s.description.len() > min_bytes)
            .map(|s| s.name.as_str())
            .collect();
        names.sort_unstable();
        names
    }

    /// Same as [`ToolRegistry::longest_description_length`].
    pub fn max_description_bytes(&self) -> usize {
        self.longest_description_length()
    }

    /// Same as [`ToolRegistry::shortest_description_length`].
    pub fn min_description_bytes(&self) -> usize {
        self.shortest_description_length()
    }

    /// `true` when any description starts with any of `prefixes`.
    pub fn description_starts_with_any(&self, prefixes: &[&str]) -> bool {
        self.tools
            .values()
            .any(|s| prefixes.iter().any(|p| s.description.starts_with(p)))
    }

    /// Spec with the most required fields; ties go to the lexicographically
    /// smallest name. `None` for an empty registry.
    pub fn tool_with_most_required_fields(&self) -> Option<&ToolSpec> {
        self.tools.values().max_by(|a, b| {
            a.required_fields
                .len()
                .cmp(&b.required_fields.len())
                .then_with(|| b.name.cmp(&a.name))
        })
    }

    /// Same as [`ToolRegistry::get`].
    pub fn tool_by_name(&self, name: &str) -> Option<&ToolSpec> {
        self.get(name)
    }

    /// Sorted names of tools that have no validators.
    pub fn tools_without_validators(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self
            .tools
            .values()
            .filter(|s| s.validators.is_empty())
            .map(|s| s.name.as_str())
            .collect();
        names.sort_unstable();
        names
    }

    /// Sorted names of tools that declare at least one required field.
    pub fn tool_names_with_required_fields(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self
            .tools
            .values()
            .filter(|s| !s.required_fields.is_empty())
            .map(|s| s.name.as_str())
            .collect();
        names.sort_unstable();
        names
    }

    /// `true` when every name in `names` is registered (vacuously true for
    /// an empty slice).
    pub fn has_all_tools(&self, names: &[&str]) -> bool {
        names.iter().all(|n| self.tools.contains_key(*n))
    }

    /// Same as [`ToolRegistry::tool_count_with_required_fields`].
    pub fn tools_with_required_fields_count(&self) -> usize {
        self.tool_count_with_required_fields()
    }

    /// Same as [`ToolRegistry::tool_names_starting_with`].
    pub fn tool_names_with_prefix<'a>(&'a self, prefix: &str) -> Vec<&'a str> {
        self.tool_names_starting_with(prefix)
    }
}
/// Parses a model response into a [`ReActStep`] with an empty observation.
///
/// Lines are trimmed and matched case-insensitively:
/// * a line starting with `thought` that contains a `:` opens the thought
///   section and discards any thought text collected so far;
/// * a line starting with `action` that contains a `:` does the same for the
///   action section;
/// * a line starting with `observation` closes the current section and is dropped;
/// * any other line (including a `thought`/`action` line without a colon) is
///   appended to the open section, if any.
///
/// Thought lines are joined with a space; action lines with a newline, and
/// the joined action is trimmed.
///
/// # Errors
/// Returns `AgentRuntimeError::AgentLoop` when both thought and action end up empty.
pub fn parse_react_step(text: &str) -> Result<ReActStep, AgentRuntimeError> {
    #[derive(Clone, Copy)]
    enum Target {
        Skip,
        Thought,
        Action,
    }

    // Trimmed text after the first ':' of a header line, if there is one.
    fn after_colon(line: &str) -> Option<&str> {
        line.split_once(':').map(|(_, rest)| rest.trim())
    }

    let mut thought_parts: Vec<&str> = Vec::new();
    let mut action_parts: Vec<&str> = Vec::new();
    let mut target = Target::Skip;

    for raw in text.lines() {
        let line = raw.trim();
        let folded = line.to_ascii_lowercase();

        if folded.starts_with("thought") {
            if let Some(rest) = after_colon(line) {
                target = Target::Thought;
                thought_parts.clear();
                if !rest.is_empty() {
                    thought_parts.push(rest);
                }
                continue;
            }
        } else if folded.starts_with("action") {
            if let Some(rest) = after_colon(line) {
                target = Target::Action;
                action_parts.clear();
                if !rest.is_empty() {
                    action_parts.push(rest);
                }
                continue;
            }
        } else if folded.starts_with("observation") {
            target = Target::Skip;
            continue;
        }

        // Continuation line: it belongs to whichever section is open.
        match target {
            Target::Thought => thought_parts.push(line),
            Target::Action => action_parts.push(line),
            Target::Skip => {}
        }
    }

    let thought = thought_parts.join(" ");
    let action = action_parts.join("\n").trim().to_owned();

    if thought.is_empty() && action.is_empty() {
        return Err(AgentRuntimeError::AgentLoop(
            "could not parse ReAct step from response".into(),
        ));
    }

    Ok(ReActStep {
        thought,
        action,
        observation: String::new(),
        step_duration_ms: 0,
    })
}
/// Driver for a ReAct (thought / action / observation) loop over a tool registry.
pub struct ReActLoop {
/// Loop settings such as the iteration cap, system prompt and timeouts.
config: AgentConfig,
/// Tools registered on this loop.
registry: ToolRegistry,
/// Optional metrics sink for step latency and tool-call counters.
metrics: Option<Arc<RuntimeMetrics>>,
/// Optional persistence backend paired with the session id given to `with_step_checkpoint`.
#[cfg(feature = "persistence")]
checkpoint_backend: Option<(Arc<dyn crate::persistence::PersistenceBackend>, String)>,
/// Optional observer notified of loop start/end, steps, errors and blocked actions.
observer: Option<Arc<dyn Observer>>,
/// Optional hook awaited with the tool name and arguments before a tool
/// call; a `false` result blocks the action.
action_hook: Option<ActionHook>,
}
impl std::fmt::Debug for ReActLoop {
    /// Prints the config and registry in full; the optional collaborators are
    /// reduced to `has_*` booleans.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ReActLoop");
        out.field("config", &self.config);
        out.field("registry", &self.registry);
        out.field("has_metrics", &self.metrics.is_some());
        out.field("has_observer", &self.observer.is_some());
        out.field("has_action_hook", &self.action_hook.is_some());
        #[cfg(feature = "persistence")]
        out.field("has_checkpoint_backend", &self.checkpoint_backend.is_some());
        out.finish()
    }
}
impl ReActLoop {
/// Creates a loop with the given config, an empty tool registry and no
/// metrics, observer, action hook or checkpoint backend.
pub fn new(config: AgentConfig) -> Self {
    let registry = ToolRegistry::new();
    Self {
        config,
        registry,
        metrics: None,
        #[cfg(feature = "persistence")]
        checkpoint_backend: None,
        observer: None,
        action_hook: None,
    }
}
/// Attaches an observer, replacing any observer set earlier.
pub fn with_observer(self, observer: Arc<dyn Observer>) -> Self {
    let mut configured = self;
    configured.observer = Some(observer);
    configured
}
/// Installs an action hook, replacing any hook set earlier.
pub fn with_action_hook(self, hook: ActionHook) -> Self {
    let mut configured = self;
    configured.action_hook = Some(hook);
    configured
}
/// Attaches a metrics sink, replacing any sink set earlier.
pub fn with_metrics(self, metrics: Arc<RuntimeMetrics>) -> Self {
    let mut configured = self;
    configured.metrics = Some(metrics);
    configured
}
/// Stores a persistence backend together with a session id, replacing any
/// pair set earlier.
#[cfg(feature = "persistence")]
pub fn with_step_checkpoint(
    self,
    backend: Arc<dyn crate::persistence::PersistenceBackend>,
    session_id: impl Into<String>,
) -> Self {
    let mut configured = self;
    let session: String = session_id.into();
    configured.checkpoint_backend = Some((backend, session));
    configured
}
pub fn registry(&self) -> &ToolRegistry {
&self.registry
}
pub fn tool_count(&self) -> usize {
self.registry.tool_count()
}
/// Removes the tool named `name`; returns whether one was removed.
pub fn unregister_tool(&mut self, name: &str) -> bool {
    let removed = self.registry.unregister(name);
    removed
}
pub fn register_tool(&mut self, spec: ToolSpec) {
self.registry.register(spec);
}
/// Registers every spec in `specs`, in iteration order.
pub fn register_tools(&mut self, specs: impl IntoIterator<Item = ToolSpec>) {
    specs
        .into_iter()
        .for_each(|spec| self.registry.register(spec));
}
/// Shrinks `context` towards `max_chars` bytes by removing the oldest step.
///
/// A step starts at a `"\nThought:"` marker. While the context is longer
/// than `max_chars` (measured in bytes) and holds at least two markers, the
/// text from the first marker up to the second is removed. Everything before
/// the first marker is kept, and so is the most recent step, so the result
/// may still exceed `max_chars`.
fn maybe_trim_context(context: &mut String, max_chars: usize) {
    const STEP_MARKER: &str = "\nThought:";
    while context.len() > max_chars {
        let Some(first) = context.find(STEP_MARKER) else {
            break;
        };
        // Search past the first marker's leading '\n' for the next step.
        let Some(next) = context[first + 1..]
            .find(STEP_MARKER)
            .map(|offset| first + 1 + offset)
        else {
            // Only one step left; nothing more can be trimmed.
            break;
        };
        // Drop only the oldest step. Draining from index 0 would also discard
        // the preamble that precedes the first step.
        context.drain(first..next);
    }
}
/// JSON text recorded as the observation when the action hook rejects a tool
/// call: `ok` is false, `kind` is `"blocked"`.
fn blocked_observation() -> String {
    let payload = serde_json::json!({
        "ok": false,
        "error": "action blocked by reviewer",
        "kind": "blocked"
    });
    payload.to_string()
}
fn error_observation(_tool_name: &str, e: &AgentRuntimeError) -> String {
let kind = match e {
AgentRuntimeError::AgentLoop(msg) if msg.contains("not found") => "not_found",
#[cfg(feature = "orchestrator")]
AgentRuntimeError::CircuitOpen { .. } => "transient",
_ => "permanent",
};
serde_json::json!({ "ok": false, "error": e.to_string(), "kind": kind }).to_string()
}
/// Runs the ReAct loop for `prompt`, calling `infer` once per iteration.
///
/// The context handed to `infer` starts as the configured system prompt
/// followed by `User: {prompt}`. After every tool step a
/// `Thought:` / `Action:` / `Observation:` record is appended and, when
/// `max_context_chars` is configured, the context is trimmed.
///
/// Each model response is parsed into a step. An action that starts with
/// `FINAL_ANSWER` (ASCII case-insensitive) ends the loop successfully. Any
/// other action is parsed as a tool call, optionally vetted by the action
/// hook, and dispatched through the registry. Tool failures and blocked
/// actions do not abort the loop: they are fed back to the model as a JSON
/// observation with `"ok": false`.
///
/// The observer, when present, receives `on_loop_start` once and
/// `on_loop_end` on every exit path; every error returned from this function
/// is reported through `on_error` first.
///
/// # Errors
///
/// Returns an `AgentRuntimeError` when
/// - a model response cannot be parsed into a step,
/// - a non-final action cannot be parsed as a tool call,
/// - the configured `loop_timeout` has elapsed at the start of an iteration, or
/// - `max_iterations` iterations complete without a `FINAL_ANSWER`.
#[tracing::instrument(skip(infer))]
pub async fn run<F, Fut>(
    &self,
    prompt: &str,
    mut infer: F,
) -> Result<Vec<ReActStep>, AgentRuntimeError>
where
    F: FnMut(String) -> Fut,
    Fut: Future<Output = String>,
{
    let mut steps: Vec<ReActStep> = Vec::new();
    let mut context = format!("{}\n\nUser: {}\n", self.config.system_prompt, prompt);
    // `checked_add` avoids the panic `Instant + Duration` raises on overflow;
    // an unrepresentable deadline is treated as "no deadline".
    let deadline = self
        .config
        .loop_timeout
        .and_then(|d| std::time::Instant::now().checked_add(d));
    if let Some(ref obs) = self.observer {
        obs.on_loop_start(prompt);
    }
    for iteration in 0..self.config.max_iterations {
        let iter_span = tracing::info_span!(
            "react_iteration",
            iteration = iteration,
            model = %self.config.model,
        );
        // The timeout is only checked between iterations; a slow `infer` or
        // tool call is not interrupted.
        if let Some(dl) = deadline {
            if std::time::Instant::now() >= dl {
                let ms = self
                    .config
                    .loop_timeout
                    .map(|d| d.as_millis())
                    .unwrap_or(0);
                let err = AgentRuntimeError::AgentLoop(format!("loop timeout after {ms} ms"));
                if let Some(ref obs) = self.observer {
                    obs.on_error(&err);
                    obs.on_loop_end(steps.len());
                }
                return Err(err);
            }
        }
        // Attach the span to the future instead of holding a `Span::enter`
        // guard across `.await`, which yields incorrect traces.
        let outcome = tracing::Instrument::instrument(
            self.run_iteration(iteration, &mut infer, &mut context, &mut steps),
            iter_span,
        )
        .await;
        match outcome {
            Ok(true) => {
                if let Some(ref obs) = self.observer {
                    obs.on_loop_end(steps.len());
                }
                return Ok(steps);
            }
            Ok(false) => {}
            Err(err) => {
                // Parse failures get the same observer notifications as the
                // timeout and exhaustion paths, so every `on_loop_start` is
                // paired with an `on_loop_end`.
                if let Some(ref obs) = self.observer {
                    obs.on_error(&err);
                    obs.on_loop_end(steps.len());
                }
                return Err(err);
            }
        }
    }
    let err = AgentRuntimeError::AgentLoop(format!(
        "max iterations ({}) reached without final answer",
        self.config.max_iterations
    ));
    tracing::warn!(
        max_iterations = self.config.max_iterations,
        "ReAct loop exhausted max iterations without FINAL_ANSWER"
    );
    if let Some(ref obs) = self.observer {
        obs.on_error(&err);
        obs.on_loop_end(steps.len());
    }
    Err(err)
}

/// Executes one ReAct iteration: infer, parse, then finish or dispatch a tool.
///
/// Appends the resulting step to `steps` and, for tool steps, the
/// Thought/Action/Observation record to `context`. Returns `Ok(true)` when the
/// step was a `FINAL_ANSWER` and `Ok(false)` when the loop should continue.
/// Parse errors are returned to the caller, which notifies the observer.
async fn run_iteration<F, Fut>(
    &self,
    iteration: usize,
    infer: &mut F,
    context: &mut String,
    steps: &mut Vec<ReActStep>,
) -> Result<bool, AgentRuntimeError>
where
    F: FnMut(String) -> Fut,
    Fut: Future<Output = String>,
{
    let step_start = std::time::Instant::now();
    // Saturate rather than silently truncate the u128 millisecond count.
    let elapsed_ms = || u64::try_from(step_start.elapsed().as_millis()).unwrap_or(u64::MAX);

    let response = infer(context.clone()).await;
    let mut step = parse_react_step(&response)?;
    tracing::debug!(
        step = iteration,
        thought = %step.thought,
        action = %step.action,
        "ReAct iteration"
    );

    if step.action.to_ascii_uppercase().starts_with("FINAL_ANSWER") {
        // The final step echoes its action as the observation and is not
        // written back into the context.
        step.observation = step.action.clone();
        step.step_duration_ms = elapsed_ms();
        if let Some(ref m) = self.metrics {
            m.record_step_latency(step.step_duration_ms);
        }
        if let Some(ref obs) = self.observer {
            obs.on_step(iteration, &step);
        }
        steps.push(step);
        tracing::info!(step = iteration, "FINAL_ANSWER reached");
        return Ok(true);
    }

    let (tool_name, args) = parse_tool_call(&step.action)?;
    tracing::debug!(
        step = iteration,
        tool_name = %tool_name,
        "dispatching tool call"
    );

    // With no hook installed every action is approved.
    let approved = match self.action_hook {
        Some(ref hook) => hook(tool_name.clone(), args.clone()).await,
        None => true,
    };

    step.observation = if approved {
        if let Some(ref obs) = self.observer {
            obs.on_tool_call(&tool_name, &args);
        }
        if let Some(ref m) = self.metrics {
            m.record_tool_call(&tool_name);
        }
        let tool_span = tracing::info_span!("tool_dispatch", tool = %tool_name);
        let dispatch =
            tracing::Instrument::instrument(self.registry.call(&tool_name, args), tool_span);
        match dispatch.await {
            Ok(result) => serde_json::json!({ "ok": true, "data": result }).to_string(),
            Err(e) => {
                if let Some(ref m) = self.metrics {
                    m.record_tool_failure(&tool_name);
                }
                Self::error_observation(&tool_name, &e)
            }
        }
    } else {
        if let Some(ref obs) = self.observer {
            obs.on_action_blocked(&tool_name, &args);
        }
        // A vetoed action is counted as both a call and a failure.
        if let Some(ref m) = self.metrics {
            m.record_tool_call(&tool_name);
            m.record_tool_failure(&tool_name);
        }
        Self::blocked_observation()
    };

    step.step_duration_ms = elapsed_ms();
    if let Some(ref m) = self.metrics {
        m.record_step_latency(step.step_duration_ms);
    }
    // Writing into a `String` cannot fail, so the result is ignored.
    let _ = write!(
        context,
        "\nThought: {}\nAction: {}\nObservation: {}\n",
        step.thought, step.action, step.observation
    );
    if let Some(max) = self.config.max_context_chars {
        Self::maybe_trim_context(context, max);
    }
    if let Some(ref obs) = self.observer {
        obs.on_step(iteration, &step);
    }
    steps.push(step);

    // Blocked steps are not checkpointed on their own; they are included in
    // the next checkpoint because the whole `steps` vector is serialised.
    if !approved {
        return Ok(false);
    }

    #[cfg(feature = "persistence")]
    if let Some((ref backend, ref session_id)) = self.checkpoint_backend {
        let step_idx = steps.len();
        let key = format!("loop:{session_id}:step:{step_idx}");
        match serde_json::to_vec(&*steps) {
            Ok(bytes) => {
                // Checkpointing is best-effort: failures are logged, not returned.
                if let Err(e) = backend.save(&key, &bytes).await {
                    tracing::warn!(
                        key = %key,
                        error = %e,
                        "loop step checkpoint save failed"
                    );
                }
            }
            Err(e) => {
                tracing::warn!(
                    step = step_idx,
                    error = %e,
                    "loop step checkpoint serialisation failed"
                );
            }
        }
    }
    Ok(false)
}
/// Streaming variant of [`run`](Self::run).
///
/// `infer_stream` returns a channel receiver for each inference call. All
/// `Ok` chunks are concatenated, in arrival order, into the response text that
/// is handed to `run`. An `Err` chunk is logged as a warning and skipped; it
/// does not abort the response.
///
/// # Errors
///
/// Returns whatever [`run`](Self::run) returns for the assembled responses.
#[tracing::instrument(skip(infer_stream))]
pub async fn run_streaming<F, Fut>(
    &self,
    prompt: &str,
    mut infer_stream: F,
) -> Result<Vec<ReActStep>, AgentRuntimeError>
where
    F: FnMut(String) -> Fut,
    Fut: Future<
        Output = tokio::sync::mpsc::Receiver<Result<String, AgentRuntimeError>>,
    >,
{
    // Adapts the chunked stream into the "one String per call" shape `run` expects.
    let collect_response = move |ctx: String| {
        let pending_receiver = infer_stream(ctx);
        async move {
            let mut receiver = pending_receiver.await;
            let mut assembled = String::new();
            loop {
                match receiver.recv().await {
                    // Channel closed: the response is complete.
                    None => break assembled,
                    Some(Ok(piece)) => assembled.push_str(&piece),
                    Some(Err(e)) => {
                        tracing::warn!(error = %e, "streaming chunk error; skipping");
                    }
                }
            }
        }
    };
    self.run(prompt, collect_response).await
}
}
/// A check applied to a tool's JSON arguments.
///
/// Implementors must be `Send + Sync`.
pub trait ToolValidator: Send + Sync {
/// Inspects `args` and returns `Ok(())` to accept them or an
/// `AgentRuntimeError` describing why they are rejected.
fn validate(&self, args: &Value) -> Result<(), AgentRuntimeError>;
}
/// Computes the Levenshtein edit distance between `a` and `b`.
///
/// The distance is the minimum number of single-character insertions,
/// deletions and substitutions needed to turn `a` into `b`. Strings are
/// compared per `char` (Unicode scalar value), not per byte.
///
/// Only two rows of the dynamic-programming table are kept, so memory use is
/// proportional to the length of `b` rather than to the product of both
/// lengths.
fn levenshtein(a: &str, b: &str) -> usize {
    let b: Vec<char> = b.chars().collect();
    // `prev[j]` is the distance between the processed prefix of `a` and the
    // first `j` chars of `b`; for the empty prefix that is simply `j`.
    let mut prev: Vec<usize> = (0..=b.len()).collect();
    let mut curr = vec![0usize; b.len() + 1];
    for (i, ca) in a.chars().enumerate() {
        // Turning `i + 1` chars into the empty string takes `i + 1` deletions.
        curr[0] = i + 1;
        for (j, &cb) in b.iter().enumerate() {
            curr[j + 1] = if ca == cb {
                prev[j]
            } else {
                // Cheapest of substitution, deletion and insertion.
                1 + prev[j].min(prev[j + 1]).min(curr[j])
            };
        }
        std::mem::swap(&mut prev, &mut curr);
    }
    prev[b.len()]
}
/// Splits an action string of the form `<tool_name> <json-args>` into the tool
/// name and its parsed JSON arguments.
///
/// Surrounding whitespace is ignored and the name may be separated from the
/// arguments by any whitespace (space, tab, newline). When no arguments follow
/// the name, an empty JSON object (`{}`) is used.
///
/// # Errors
///
/// Returns `AgentRuntimeError::AgentLoop` when the action contains no tool
/// name or when the argument text is not valid JSON.
fn parse_tool_call(action: &str) -> Result<(String, Value), AgentRuntimeError> {
    let trimmed = action.trim();
    // After trimming, the name is everything up to the first whitespace char.
    let (name, rest) = match trimmed.split_once(char::is_whitespace) {
        Some((name, rest)) => (name, rest.trim_start()),
        None => (trimmed, ""),
    };
    if name.is_empty() {
        return Err(AgentRuntimeError::AgentLoop(
            "tool call has an empty tool name".into(),
        ));
    }
    // A bare tool name means "call with no arguments".
    let args_str = if rest.is_empty() { "{}" } else { rest };
    let args: Value = serde_json::from_str(args_str).map_err(|e| {
        AgentRuntimeError::AgentLoop(format!(
            "invalid JSON args for tool call '{name}': {e} (raw: {args_str})"
        ))
    })?;
    Ok((name.to_owned(), args))
}
/// Agent-level error conditions.
///
/// Converts into `AgentRuntimeError::AgentLoop` carrying this error's display
/// text via the `From` impl below.
#[derive(Debug, thiserror::Error)]
pub enum AgentError {
/// No tool with the given name; the payload is the requested tool name.
#[error("Tool '{0}' not found")]
ToolNotFound(String),
/// The iteration limit was exceeded; the payload is the number shown in the message.
#[error("Max iterations exceeded: {0}")]
MaxIterations(usize),
/// Input could not be parsed; the payload is the detail text.
#[error("Parse error: {0}")]
ParseError(String),
}
impl From<AgentError> for AgentRuntimeError {
fn from(e: AgentError) -> Self {
AgentRuntimeError::AgentLoop(e.to_string())
}
}
/// Callbacks invoked by the ReAct loop as it progresses.
///
/// Every method has a no-op default, so implementors override only the events
/// they care about. Implementors must be `Send + Sync`.
pub trait Observer: Send + Sync {
/// Called once a step is complete (its observation and duration are set),
/// with the zero-based iteration index and the step itself.
fn on_step(&self, step_index: usize, step: &ReActStep) {
let _ = (step_index, step);
}
/// Called just before an approved tool call is dispatched to the registry.
fn on_tool_call(&self, tool_name: &str, args: &serde_json::Value) {
let _ = (tool_name, args);
}
/// Called when the action hook rejects a tool call; the tool is not run.
fn on_action_blocked(&self, tool_name: &str, args: &serde_json::Value) {
let _ = (tool_name, args);
}
/// Called once at the start of a run with the user prompt.
fn on_loop_start(&self, prompt: &str) {
let _ = prompt;
}
/// Called when a run ends, with the number of steps collected so far.
fn on_loop_end(&self, step_count: usize) {
let _ = step_count;
}
/// Called with the error a run is about to return, e.g. on loop timeout or
/// when the iteration limit is exhausted.
fn on_error(&self, error: &crate::error::AgentRuntimeError) {
let _ = error;
}
}
/// A parsed model action: either the final answer or a tool invocation.
#[derive(Debug, Clone, PartialEq)]
pub enum Action {
/// The model finished; the payload is the text that followed `FINAL_ANSWER`.
FinalAnswer(String),
/// The model asked for a tool to be run.
ToolCall {
/// Name of the tool to call.
name: String,
/// JSON arguments for the tool.
args: serde_json::Value,
},
}
impl Action {
    /// Parses a raw action string into an [`Action`].
    ///
    /// Surrounding whitespace is ignored. Text that starts with `FINAL_ANSWER`
    /// (ASCII case-insensitive; this is a plain prefix match with no separator
    /// required) becomes [`Action::FinalAnswer`] holding the trimmed remainder.
    /// Anything else is parsed as a tool call of the form
    /// `<tool_name> <json-args>`.
    ///
    /// # Errors
    ///
    /// Returns the tool-call parser's error when the text is not a final
    /// answer and has no tool name or carries invalid JSON arguments.
    pub fn parse(s: &str) -> Result<Action, AgentRuntimeError> {
        const FINAL_ANSWER: &str = "FINAL_ANSWER";
        let trimmed = s.trim();
        // `get` yields `None` when the input is shorter than the marker or the
        // cut would split a multi-byte char; neither can be a match. This also
        // avoids allocating an upper-cased copy of the whole input.
        let is_final = trimmed
            .get(..FINAL_ANSWER.len())
            .map_or(false, |prefix| prefix.eq_ignore_ascii_case(FINAL_ANSWER));
        if is_final {
            let answer = trimmed[FINAL_ANSWER.len()..].trim().to_owned();
            return Ok(Action::FinalAnswer(answer));
        }
        // Pass the trimmed text on so leading whitespace is not mistaken for
        // an empty tool name.
        let (name, args) = parse_tool_call(trimmed)?;
        Ok(Action::ToolCall { name, args })
    }
}
pub type ActionHook = Arc<dyn Fn(String, serde_json::Value) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>> + Send + Sync>;
pub fn make_action_hook<F, Fut>(f: F) -> ActionHook
where
F: Fn(String, serde_json::Value) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = bool> + Send + 'static,
{
Arc::new(move |name, args| Box::pin(f(name, args)))
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_final_answer_on_first_step() {
let config = AgentConfig::new(5, "test-model");
let loop_ = ReActLoop::new(config);
let steps = loop_
.run("Say hello", |_ctx| async {
"Thought: I will answer directly\nAction: FINAL_ANSWER hello".to_string()
})
.await
.unwrap();
assert_eq!(steps.len(), 1);
assert!(steps[0]
.action
.to_ascii_uppercase()
.starts_with("FINAL_ANSWER"));
}
#[tokio::test]
async fn test_tool_call_then_final_answer() {
let config = AgentConfig::new(5, "test-model");
let mut loop_ = ReActLoop::new(config);
loop_.register_tool(ToolSpec::new("greet", "Greets someone", |_args| {
serde_json::json!("hello!")
}));
let mut call_count = 0;
let steps = loop_
.run("Say hello", |_ctx| {
call_count += 1;
let count = call_count;
async move {
if count == 1 {
"Thought: I will greet\nAction: greet {}".to_string()
} else {
"Thought: done\nAction: FINAL_ANSWER done".to_string()
}
}
})
.await
.unwrap();
assert_eq!(steps.len(), 2);
assert_eq!(steps[0].action, "greet {}");
assert!(steps[1]
.action
.to_ascii_uppercase()
.starts_with("FINAL_ANSWER"));
}
#[tokio::test]
async fn test_max_iterations_exceeded() {
let config = AgentConfig::new(2, "test-model");
let loop_ = ReActLoop::new(config);
let result = loop_
.run("loop forever", |_ctx| async {
"Thought: thinking\nAction: noop {}".to_string()
})
.await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("max iterations"));
}
#[tokio::test]
async fn test_parse_react_step_valid() {
let text = "Thought: I should check\nAction: lookup {\"key\":\"val\"}";
let step = parse_react_step(text).unwrap();
assert_eq!(step.thought, "I should check");
assert_eq!(step.action, "lookup {\"key\":\"val\"}");
}
#[tokio::test]
async fn test_parse_react_step_empty_fails() {
let result = parse_react_step("no prefix lines here");
assert!(result.is_err());
}
#[tokio::test]
async fn test_tool_not_found_returns_error_observation() {
let config = AgentConfig::new(3, "test-model");
let loop_ = ReActLoop::new(config);
let mut call_count = 0;
let steps = loop_
.run("test", |_ctx| {
call_count += 1;
let count = call_count;
async move {
if count == 1 {
"Thought: try missing tool\nAction: missing_tool {}".to_string()
} else {
"Thought: done\nAction: FINAL_ANSWER done".to_string()
}
}
})
.await
.unwrap();
assert_eq!(steps.len(), 2);
assert!(steps[0].observation.contains("\"ok\":false"));
}
#[tokio::test]
async fn test_new_async_tool_spec() {
let spec = ToolSpec::new_async("async_tool", "An async tool", |args| {
Box::pin(async move { serde_json::json!({"echo": args}) })
});
let result = spec.call(serde_json::json!({"input": "test"})).await;
assert!(result.get("echo").is_some());
}
#[tokio::test]
async fn test_parse_react_step_case_insensitive() {
let text = "THOUGHT: done\nACTION: FINAL_ANSWER";
let step = parse_react_step(text).unwrap();
assert_eq!(step.thought, "done");
assert_eq!(step.action, "FINAL_ANSWER");
}
#[tokio::test]
async fn test_parse_react_step_space_before_colon() {
let text = "Thought : done\nAction : go";
let step = parse_react_step(text).unwrap();
assert_eq!(step.thought, "done");
assert_eq!(step.action, "go");
}
#[tokio::test]
async fn test_tool_required_fields_missing_returns_error() {
let config = AgentConfig::new(3, "test-model");
let mut loop_ = ReActLoop::new(config);
loop_.register_tool(
ToolSpec::new(
"search",
"Searches for something",
|args| serde_json::json!({ "result": args }),
)
.with_required_fields(vec!["q".to_string()]),
);
let mut call_count = 0;
let steps = loop_
.run("test", |_ctx| {
call_count += 1;
let count = call_count;
async move {
if count == 1 {
"Thought: searching\nAction: search {}".to_string()
} else {
"Thought: done\nAction: FINAL_ANSWER done".to_string()
}
}
})
.await
.unwrap();
assert_eq!(steps.len(), 2);
assert!(
steps[0].observation.contains("missing required field"),
"observation was: {}",
steps[0].observation
);
}
#[tokio::test]
async fn test_tool_error_observation_includes_kind() {
let config = AgentConfig::new(3, "test-model");
let loop_ = ReActLoop::new(config);
let mut call_count = 0;
let steps = loop_
.run("test", |_ctx| {
call_count += 1;
let count = call_count;
async move {
if count == 1 {
"Thought: try missing\nAction: nonexistent_tool {}".to_string()
} else {
"Thought: done\nAction: FINAL_ANSWER done".to_string()
}
}
})
.await
.unwrap();
assert_eq!(steps.len(), 2);
let obs = &steps[0].observation;
assert!(obs.contains("\"ok\":false"), "observation: {obs}");
assert!(obs.contains("\"kind\":\"not_found\""), "observation: {obs}");
}
#[tokio::test]
async fn test_step_duration_ms_is_set() {
let config = AgentConfig::new(5, "test-model");
let loop_ = ReActLoop::new(config);
let steps = loop_
.run("time it", |_ctx| async {
"Thought: done\nAction: FINAL_ANSWER ok".to_string()
})
.await
.unwrap();
let _ = steps[0].step_duration_ms; }
struct RequirePositiveN;
impl ToolValidator for RequirePositiveN {
fn validate(&self, args: &Value) -> Result<(), AgentRuntimeError> {
let n = args.get("n").and_then(|v| v.as_i64()).unwrap_or(0);
if n <= 0 {
return Err(AgentRuntimeError::AgentLoop(
"n must be a positive integer".into(),
));
}
Ok(())
}
}
#[tokio::test]
async fn test_tool_validator_blocks_invalid_args() {
let mut registry = ToolRegistry::new();
registry.register(
ToolSpec::new("calc", "compute", |args| serde_json::json!({"n": args}))
.with_validators(vec![Box::new(RequirePositiveN)]),
);
let result = registry
.call("calc", serde_json::json!({"n": -1}))
.await;
assert!(result.is_err(), "validator should reject n=-1");
assert!(result.unwrap_err().to_string().contains("positive integer"));
}
#[tokio::test]
async fn test_tool_validator_passes_valid_args() {
let mut registry = ToolRegistry::new();
registry.register(
ToolSpec::new("calc", "compute", |_| serde_json::json!(42))
.with_validators(vec![Box::new(RequirePositiveN)]),
);
let result = registry
.call("calc", serde_json::json!({"n": 5}))
.await;
assert!(result.is_ok(), "validator should accept n=5");
}
#[tokio::test]
async fn test_empty_tool_name_is_rejected() {
let result = parse_tool_call("");
assert!(result.is_err());
assert!(
result.unwrap_err().to_string().contains("empty tool name"),
"expected 'empty tool name' error"
);
}
#[tokio::test]
async fn test_register_tools_bulk() {
let mut registry = ToolRegistry::new();
registry.register_tools(vec![
ToolSpec::new("tool_a", "A", |_| serde_json::json!("a")),
ToolSpec::new("tool_b", "B", |_| serde_json::json!("b")),
]);
assert!(registry.call("tool_a", serde_json::json!({})).await.is_ok());
assert!(registry.call("tool_b", serde_json::json!({})).await.is_ok());
}
#[tokio::test]
async fn test_run_streaming_parity_with_run() {
use tokio::sync::mpsc;
let config = AgentConfig::new(5, "test-model");
let loop_ = ReActLoop::new(config);
let steps = loop_
.run_streaming("Say hello", |_ctx| async {
let (tx, rx) = mpsc::channel(4);
tokio::spawn(async move {
tx.send(Ok("Thought: done\n".to_string())).await.ok();
tx.send(Ok("Action: FINAL_ANSWER hi".to_string())).await.ok();
});
rx
})
.await
.unwrap();
assert_eq!(steps.len(), 1);
assert!(steps[0]
.action
.to_ascii_uppercase()
.starts_with("FINAL_ANSWER"));
}
#[tokio::test]
async fn test_run_streaming_error_chunk_is_skipped() {
use tokio::sync::mpsc;
use crate::error::AgentRuntimeError;
let config = AgentConfig::new(5, "test-model");
let loop_ = ReActLoop::new(config);
let steps = loop_
.run_streaming("test", |_ctx| async {
let (tx, rx) = mpsc::channel(4);
tokio::spawn(async move {
tx.send(Err(AgentRuntimeError::Provider("stream error".into())))
.await
.ok();
tx.send(Ok("Thought: recovered\nAction: FINAL_ANSWER ok".to_string()))
.await
.ok();
});
rx
})
.await
.unwrap();
assert_eq!(steps.len(), 1);
}
#[cfg(feature = "orchestrator")]
#[tokio::test]
async fn test_tool_with_circuit_breaker_passes_when_closed() {
use std::sync::Arc;
let cb = Arc::new(
crate::orchestrator::CircuitBreaker::new(
"echo-tool",
5,
std::time::Duration::from_secs(30),
)
.unwrap(),
);
let spec = ToolSpec::new(
"echo",
"Echoes args",
|args| serde_json::json!({ "echoed": args }),
)
.with_circuit_breaker(cb);
let registry = {
let mut r = ToolRegistry::new();
r.register(spec);
r
};
let result = registry
.call("echo", serde_json::json!({ "msg": "hi" }))
.await;
assert!(result.is_ok(), "expected Ok, got {:?}", result);
}
#[test]
fn test_agent_config_builder_methods_set_fields() {
let config = AgentConfig::new(3, "model")
.with_temperature(0.7)
.with_max_tokens(512)
.with_request_timeout(std::time::Duration::from_secs(10));
assert_eq!(config.temperature, Some(0.7));
assert_eq!(config.max_tokens, Some(512));
assert_eq!(config.request_timeout, Some(std::time::Duration::from_secs(10)));
}
#[tokio::test]
async fn test_fallible_tool_returns_error_json_on_err() {
let spec = ToolSpec::new_fallible(
"fail",
"always fails",
|_| Err::<Value, String>("something went wrong".to_string()),
);
let result = spec.call(serde_json::json!({})).await;
assert_eq!(result["ok"], serde_json::json!(false));
assert_eq!(result["error"], serde_json::json!("something went wrong"));
}
#[tokio::test]
async fn test_fallible_tool_returns_value_on_ok() {
let spec = ToolSpec::new_fallible(
"succeed",
"always succeeds",
|_| Ok::<Value, String>(serde_json::json!(42)),
);
let result = spec.call(serde_json::json!({})).await;
assert_eq!(result, serde_json::json!(42));
}
#[tokio::test]
async fn test_did_you_mean_suggestion_for_typo() {
let mut registry = ToolRegistry::new();
registry.register(ToolSpec::new("search", "search", |_| serde_json::json!("ok")));
let result = registry.call("searc", serde_json::json!({})).await;
assert!(result.is_err());
let msg = result.unwrap_err().to_string();
assert!(msg.contains("did you mean"), "expected suggestion in: {msg}");
}
#[tokio::test]
async fn test_no_suggestion_for_very_different_name() {
let mut registry = ToolRegistry::new();
registry.register(ToolSpec::new("search", "search", |_| serde_json::json!("ok")));
let result = registry.call("xxxxxxxxxxxxxxx", serde_json::json!({})).await;
assert!(result.is_err());
let msg = result.unwrap_err().to_string();
assert!(!msg.contains("did you mean"), "unexpected suggestion in: {msg}");
}
#[test]
fn test_action_parse_final_answer() {
let action = Action::parse("FINAL_ANSWER hello world").unwrap();
assert_eq!(action, Action::FinalAnswer("hello world".to_string()));
}
#[test]
fn test_action_parse_tool_call() {
let action = Action::parse("search {\"q\": \"rust\"}").unwrap();
match action {
Action::ToolCall { name, args } => {
assert_eq!(name, "search");
assert_eq!(args["q"], "rust");
}
_ => panic!("expected ToolCall"),
}
}
#[test]
fn test_action_parse_invalid_returns_err() {
let result = Action::parse("");
assert!(result.is_err());
}
#[tokio::test]
async fn test_observer_on_step_called_for_each_step() {
use std::sync::{Arc, Mutex};
struct CountingObserver {
step_count: Mutex<usize>,
}
impl Observer for CountingObserver {
fn on_step(&self, _step_index: usize, _step: &ReActStep) {
let mut c = self.step_count.lock().unwrap_or_else(|e| e.into_inner());
*c += 1;
}
}
let obs = Arc::new(CountingObserver { step_count: Mutex::new(0) });
let config = AgentConfig::new(5, "test-model");
let mut loop_ = ReActLoop::new(config).with_observer(obs.clone() as Arc<dyn Observer>);
loop_.register_tool(ToolSpec::new("noop", "noop", |_| serde_json::json!("ok")));
let mut call_count = 0;
let _steps = loop_.run("test", |_ctx| {
call_count += 1;
let count = call_count;
async move {
if count == 1 {
"Thought: call noop\nAction: noop {}".to_string()
} else {
"Thought: done\nAction: FINAL_ANSWER done".to_string()
}
}
}).await.unwrap();
let count = *obs.step_count.lock().unwrap_or_else(|e| e.into_inner());
assert_eq!(count, 2, "observer should have seen 2 steps");
}
#[tokio::test]
async fn test_tool_cache_returns_cached_result_on_second_call() {
use std::collections::HashMap;
use std::sync::Mutex;
struct InMemCache {
map: Mutex<HashMap<String, Value>>,
}
impl ToolCache for InMemCache {
fn get(&self, tool_name: &str, args: &Value) -> Option<Value> {
let key = format!("{tool_name}:{args}");
let map = self.map.lock().unwrap_or_else(|e| e.into_inner());
map.get(&key).cloned()
}
fn set(&self, tool_name: &str, args: &Value, result: Value) {
let key = format!("{tool_name}:{args}");
let mut map = self.map.lock().unwrap_or_else(|e| e.into_inner());
map.insert(key, result);
}
}
let call_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
let call_count_clone = call_count.clone();
let cache = Arc::new(InMemCache { map: Mutex::new(HashMap::new()) });
let registry = ToolRegistry::new()
.with_cache(cache as Arc<dyn ToolCache>);
let mut registry = registry;
registry.register(ToolSpec::new("count", "count calls", move |_| {
call_count_clone.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
serde_json::json!({"calls": 1})
}));
let args = serde_json::json!({});
let r1 = registry.call("count", args.clone()).await.unwrap();
let r2 = registry.call("count", args.clone()).await.unwrap();
assert_eq!(r1, r2);
assert_eq!(call_count.load(std::sync::atomic::Ordering::Relaxed), 1);
}
#[tokio::test]
async fn test_validators_short_circuit_on_first_failure() {
use std::sync::atomic::{AtomicUsize, Ordering as AOrdering};
use std::sync::Arc;
let second_called = Arc::new(AtomicUsize::new(0));
let second_called_clone = Arc::clone(&second_called);
struct AlwaysFail;
impl ToolValidator for AlwaysFail {
fn validate(&self, _args: &Value) -> Result<(), AgentRuntimeError> {
Err(AgentRuntimeError::AgentLoop("first validator failed".into()))
}
}
struct CountCalls(Arc<AtomicUsize>);
impl ToolValidator for CountCalls {
fn validate(&self, _args: &Value) -> Result<(), AgentRuntimeError> {
self.0.fetch_add(1, AOrdering::SeqCst);
Ok(())
}
}
let mut registry = ToolRegistry::new();
registry.register(
ToolSpec::new("guarded", "A guarded tool", |args| args.clone())
.with_validators(vec![
Box::new(AlwaysFail),
Box::new(CountCalls(second_called_clone)),
]),
);
let result = registry.call("guarded", serde_json::json!({})).await;
assert!(result.is_err(), "should fail due to first validator");
assert_eq!(
second_called.load(AOrdering::SeqCst),
0,
"second validator must not be called when first fails"
);
}
#[tokio::test]
async fn test_loop_timeout_fires_between_iterations() {
let mut config = AgentConfig::new(100, "test-model");
config.loop_timeout = Some(std::time::Duration::from_millis(30));
let loop_ = ReActLoop::new(config);
let result = loop_
.run("test", |_ctx| async {
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
"Thought: still working\nAction: noop {}".to_string()
})
.await;
assert!(result.is_err(), "loop should time out");
let msg = result.unwrap_err().to_string();
assert!(msg.contains("loop timeout"), "unexpected error: {msg}");
}
#[test]
fn test_react_step_is_final_answer() {
let step = ReActStep {
thought: "".into(),
action: "FINAL_ANSWER done".into(),
observation: "".into(),
step_duration_ms: 0,
};
assert!(step.is_final_answer());
assert!(!step.is_tool_call());
}
#[test]
fn test_react_step_is_tool_call() {
let step = ReActStep {
thought: "".into(),
action: "search {}".into(),
observation: "".into(),
step_duration_ms: 0,
};
assert!(!step.is_final_answer());
assert!(step.is_tool_call());
}
#[test]
fn test_role_display() {
assert_eq!(Role::System.to_string(), "system");
assert_eq!(Role::User.to_string(), "user");
assert_eq!(Role::Assistant.to_string(), "assistant");
assert_eq!(Role::Tool.to_string(), "tool");
}
#[test]
fn test_message_accessors() {
let msg = Message::new(Role::User, "hello");
assert_eq!(msg.role(), &Role::User);
assert_eq!(msg.content(), "hello");
}
#[test]
fn test_action_parse_final_answer_round_trip() {
let step = ReActStep {
thought: "done".into(),
action: "FINAL_ANSWER Paris".into(),
observation: "".into(),
step_duration_ms: 0,
};
assert!(step.is_final_answer());
let action = Action::parse(&step.action).unwrap();
assert!(matches!(action, Action::FinalAnswer(ref s) if s == "Paris"));
}
#[test]
fn test_action_parse_tool_call_round_trip() {
let step = ReActStep {
thought: "searching".into(),
action: "search {\"q\":\"hello\"}".into(),
observation: "".into(),
step_duration_ms: 0,
};
assert!(step.is_tool_call());
let action = Action::parse(&step.action).unwrap();
assert!(matches!(action, Action::ToolCall { ref name, .. } if name == "search"));
}
#[tokio::test]
async fn test_observer_receives_correct_step_indices() {
use std::sync::{Arc, Mutex};
struct IndexCollector(Arc<Mutex<Vec<usize>>>);
impl Observer for IndexCollector {
fn on_step(&self, step_index: usize, _step: &ReActStep) {
self.0.lock().unwrap_or_else(|e| e.into_inner()).push(step_index);
}
}
let indices = Arc::new(Mutex::new(Vec::new()));
let obs = Arc::new(IndexCollector(Arc::clone(&indices)));
let config = AgentConfig::new(5, "test");
let mut loop_ = ReActLoop::new(config).with_observer(obs as Arc<dyn Observer>);
loop_.register_tool(ToolSpec::new("noop", "no-op", |_| serde_json::json!({})));
let mut call_count = 0;
loop_.run("test", |_ctx| {
call_count += 1;
let count = call_count;
async move {
if count == 1 {
"Thought: step1\nAction: noop {}".to_string()
} else {
"Thought: done\nAction: FINAL_ANSWER ok".to_string()
}
}
}).await.unwrap();
let collected = indices.lock().unwrap_or_else(|e| e.into_inner()).clone();
assert_eq!(collected, vec![0, 1], "expected step indices 0 and 1");
}
#[tokio::test]
async fn test_action_hook_blocking_inserts_blocked_observation() {
let hook: ActionHook = Arc::new(|_name, _args| {
Box::pin(async move { false }) });
let config = AgentConfig::new(5, "test-model");
let mut loop_ = ReActLoop::new(config).with_action_hook(hook);
loop_.register_tool(ToolSpec::new("noop", "noop", |_| serde_json::json!("ok")));
let mut call_count = 0;
let steps = loop_.run("test", |_ctx| {
call_count += 1;
let count = call_count;
async move {
if count == 1 {
"Thought: try tool\nAction: noop {}".to_string()
} else {
"Thought: done\nAction: FINAL_ANSWER done".to_string()
}
}
}).await.unwrap();
assert!(steps[0].observation.contains("blocked"), "expected blocked observation, got: {}", steps[0].observation);
}
#[test]
fn test_react_step_new_constructor() {
let s = ReActStep::new("think", "act", "obs");
assert_eq!(s.thought, "think");
assert_eq!(s.action, "act");
assert_eq!(s.observation, "obs");
assert_eq!(s.step_duration_ms, 0);
}
#[test]
fn test_react_step_new_is_tool_call() {
let s = ReActStep::new("think", "search {}", "result");
assert!(s.is_tool_call());
assert!(!s.is_final_answer());
}
#[test]
fn test_react_step_new_is_final_answer() {
let s = ReActStep::new("done", "FINAL_ANSWER 42", "");
assert!(s.is_final_answer());
assert!(!s.is_tool_call());
}
#[test]
fn test_agent_config_is_valid_with_valid_config() {
let cfg = AgentConfig::new(5, "my-model");
assert!(cfg.is_valid());
}
#[test]
fn test_agent_config_is_valid_with_zero_iterations() {
let mut cfg = AgentConfig::new(1, "my-model");
cfg.max_iterations = 0;
assert!(!cfg.is_valid());
}
#[test]
fn test_agent_config_is_valid_with_empty_model() {
let mut cfg = AgentConfig::new(5, "my-model");
cfg.model = String::new();
assert!(!cfg.is_valid());
}
#[test]
fn test_react_loop_tool_count_delegates_to_registry() {
let cfg = AgentConfig::new(5, "model");
let mut loop_ = ReActLoop::new(cfg);
assert_eq!(loop_.tool_count(), 0);
loop_.register_tool(ToolSpec::new("t1", "desc", |_| serde_json::json!("ok")));
loop_.register_tool(ToolSpec::new("t2", "desc", |_| serde_json::json!("ok")));
assert_eq!(loop_.tool_count(), 2);
}
#[test]
fn test_tool_registry_has_tool_returns_true_when_registered() {
let mut reg = ToolRegistry::new();
reg.register(ToolSpec::new("my-tool", "desc", |_| serde_json::json!("ok")));
assert!(reg.has_tool("my-tool"));
assert!(!reg.has_tool("other-tool"));
}
#[test]
fn test_agent_config_validate_ok_for_valid_config() {
let cfg = AgentConfig::new(5, "my-model");
assert!(cfg.validate().is_ok());
}
#[test]
fn test_agent_config_validate_err_for_zero_iterations() {
let cfg = AgentConfig::new(0, "my-model");
let err = cfg.validate().unwrap_err();
assert!(err.to_string().contains("max_iterations"));
}
#[test]
fn test_agent_config_validate_err_for_empty_model() {
let cfg = AgentConfig::new(5, "");
let err = cfg.validate().unwrap_err();
assert!(err.to_string().contains("model"));
}
#[test]
fn test_clone_with_model_produces_new_model_string() {
let cfg = AgentConfig::new(5, "gpt-4");
let new_cfg = cfg.clone_with_model("claude-3");
assert_eq!(new_cfg.model, "claude-3");
assert_eq!(cfg.model, "gpt-4");
}
#[test]
fn test_clone_with_model_preserves_other_fields() {
let cfg = AgentConfig::new(10, "gpt-4").with_stop_sequences(vec!["STOP".to_string()]);
let new_cfg = cfg.clone_with_model("o1");
assert_eq!(new_cfg.max_iterations, 10);
assert_eq!(new_cfg.stop_sequences, cfg.stop_sequences);
}
#[tokio::test]
async fn test_tool_spec_with_name_changes_name() {
let spec = ToolSpec::new("original", "desc", |_| serde_json::json!("ok"))
.with_name("renamed");
assert_eq!(spec.name, "renamed");
}
#[tokio::test]
async fn test_tool_spec_with_name_and_description_chainable() {
let spec = ToolSpec::new("old", "old desc", |_| serde_json::json!("ok"))
.with_name("new")
.with_description("new desc");
assert_eq!(spec.name, "new");
assert_eq!(spec.description, "new desc");
}
#[test]
fn test_message_user_sets_role_and_content() {
let m = Message::user("hello");
assert_eq!(m.content(), "hello");
assert!(m.is_user());
assert!(!m.is_assistant());
}
#[test]
fn test_message_assistant_sets_role() {
let m = Message::assistant("reply");
assert!(m.is_assistant());
assert!(!m.is_user());
assert!(!m.is_system());
}
#[test]
fn test_message_system_sets_role() {
let m = Message::system("system prompt");
assert!(m.is_system());
assert_eq!(m.content(), "system prompt");
}
#[test]
fn test_parse_react_step_valid_input() {
let text = "Thought: I need to search\nAction: search[query]";
let step = parse_react_step(text).unwrap();
assert!(step.thought.contains("search"));
assert!(step.action.contains("search"));
}
#[test]
fn test_parse_react_step_missing_fields_returns_err() {
let text = "no structured content here";
assert!(parse_react_step(text).is_err());
}
#[test]
fn test_react_step_is_final_answer_true() {
let step = ReActStep::new("t", "FINAL_ANSWER Paris", "");
assert!(step.is_final_answer());
assert!(!step.is_tool_call());
}
#[test]
fn test_react_step_is_tool_call_true() {
let step = ReActStep::new("t", "search {}", "result");
assert!(step.is_tool_call());
assert!(!step.is_final_answer());
}
#[test]
fn test_tool_registry_unregister_returns_true_when_present() {
let mut reg = ToolRegistry::new();
reg.register(ToolSpec::new("tool-x", "desc", |_| serde_json::json!("ok")));
assert!(reg.unregister("tool-x"));
assert!(!reg.has_tool("tool-x"));
}
#[test]
fn test_tool_registry_unregister_returns_false_when_absent() {
let mut reg = ToolRegistry::new();
assert!(!reg.unregister("ghost"));
}
#[test]
fn test_tool_registry_contains_matches_has_tool() {
let mut reg = ToolRegistry::new();
reg.register(ToolSpec::new("alpha", "desc", |_| serde_json::json!("ok")));
assert!(reg.contains("alpha"));
assert!(!reg.contains("beta"));
}
#[test]
fn test_agent_config_with_system_prompt() {
let cfg = AgentConfig::new(5, "model")
.with_system_prompt("You are helpful.");
assert_eq!(cfg.system_prompt, "You are helpful.");
}
#[test]
fn test_agent_config_with_temperature_and_max_tokens() {
let cfg = AgentConfig::new(3, "model")
.with_temperature(0.7)
.with_max_tokens(512);
assert!((cfg.temperature.unwrap() - 0.7).abs() < 1e-6);
assert_eq!(cfg.max_tokens, Some(512));
}
#[test]
fn test_agent_config_clone_with_model() {
let orig = AgentConfig::new(5, "gpt-4");
let cloned = orig.clone_with_model("claude-3");
assert_eq!(cloned.model, "claude-3");
assert_eq!(cloned.max_iterations, 5);
}
#[test]
fn test_agent_config_with_loop_timeout_secs() {
let cfg = AgentConfig::new(5, "model").with_loop_timeout_secs(30);
assert_eq!(cfg.loop_timeout, Some(std::time::Duration::from_secs(30)));
}
#[test]
fn test_agent_config_with_max_context_chars() {
let cfg = AgentConfig::new(5, "model").with_max_context_chars(4096);
assert_eq!(cfg.max_context_chars, Some(4096));
}
#[test]
fn test_agent_config_with_stop_sequences() {
let cfg = AgentConfig::new(5, "model")
.with_stop_sequences(vec!["STOP".to_string(), "END".to_string()]);
assert_eq!(cfg.stop_sequences, vec!["STOP", "END"]);
}
/// Messages built with the user, assistant and system constructors are not tool messages.
#[test]
fn test_message_is_tool_false_for_non_tool_roles() {
    let from_user = Message::user("hi");
    assert!(!from_user.is_tool());

    let from_assistant = Message::assistant("reply");
    assert!(!from_assistant.is_tool());

    let from_system = Message::system("prompt");
    assert!(!from_system.is_tool());
}
/// `with_max_iterations` replaces the iteration limit given to `new`.
#[test]
fn test_agent_config_with_max_iterations() {
    let base = AgentConfig::new(5, "m");
    let config = base.with_max_iterations(20);
    assert_eq!(config.max_iterations, 20);
}
/// `tool_names_owned` returns owned `String` names for every registered tool.
#[test]
fn test_tool_registry_tool_names_owned_returns_strings() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new("alpha", "d", |_| serde_json::json!("ok")));
    registry.register(ToolSpec::new("beta", "d", |_| serde_json::json!("ok")));

    // Sort before comparing so the assertion does not depend on return order.
    let mut owned_names = registry.tool_names_owned();
    owned_names.sort();

    let expected = vec!["alpha".to_string(), "beta".to_string()];
    assert_eq!(owned_names, expected);
}
/// `tool_names_owned` is empty for a freshly created registry.
#[test]
fn test_tool_registry_tool_names_owned_empty_when_no_tools() {
    let registry = ToolRegistry::new();
    let owned_names = registry.tool_names_owned();
    assert!(owned_names.is_empty());
}
/// `tool_specs` returns one entry per registered tool.
#[test]
fn test_tool_registry_tool_specs_returns_all_specs() {
    let mut registry = ToolRegistry::new();
    let first = ToolSpec::new("t1", "desc1", |_| serde_json::json!("ok"));
    let second = ToolSpec::new("t2", "desc2", |_| serde_json::json!("ok"));
    registry.register(first);
    registry.register(second);

    let all_specs = registry.tool_specs();
    assert_eq!(all_specs.len(), 2);
}
/// `tool_specs` is empty for a freshly created registry.
#[test]
fn test_tool_registry_tool_specs_empty_when_no_tools() {
    let registry = ToolRegistry::new();
    let all_specs = registry.tool_specs();
    assert!(all_specs.is_empty());
}
/// Renaming moves the tool to the new lookup key and updates the spec's `name` field.
#[test]
fn test_rename_tool_updates_name_and_key() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new("old", "desc", |_| serde_json::json!("ok")));

    let renamed = registry.rename_tool("old", "new");
    assert!(renamed);

    // The new key resolves and the old key no longer does.
    assert!(registry.has_tool("new"));
    assert!(!registry.has_tool("old"));

    let renamed_spec = registry.get("new").unwrap();
    assert_eq!(renamed_spec.name, "new");
}
/// Renaming a name that is not registered returns `false`.
#[test]
fn test_rename_tool_returns_false_for_unknown_name() {
    let mut registry = ToolRegistry::new();
    let renamed = registry.rename_tool("ghost", "other");
    assert!(!renamed);
}
/// `filter_tools` returns only the specs for which the predicate holds.
#[test]
fn test_filter_tools_returns_matching_specs() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new("short_desc", "hi", |_| serde_json::json!({})));
    registry.register(ToolSpec::new(
        "long_desc",
        "a longer description here",
        |_| serde_json::json!({}),
    ));

    // Only "long_desc" has a description longer than ten bytes.
    let matches = registry.filter_tools(|spec| spec.description.len() > 10);
    assert_eq!(matches.len(), 1);
    assert_eq!(matches[0].name, "long_desc");
}
/// An always-false predicate yields an empty result from `filter_tools`.
#[test]
fn test_filter_tools_returns_empty_when_none_match() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new("t1", "desc", |_| serde_json::json!({})));

    let matches: Vec<_> = registry.filter_tools(|_| false);
    assert!(matches.is_empty());
}
/// An always-true predicate yields every registered spec from `filter_tools`.
#[test]
fn test_filter_tools_returns_all_when_predicate_always_true() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new("a", "d1", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("b", "d2", |_| serde_json::json!({})));

    let matches = registry.filter_tools(|_| true);
    assert_eq!(matches.len(), 2);
}
/// The `max_iterations()` getter returns the value passed to `new`.
#[test]
fn test_agent_config_max_iterations_getter_returns_configured_value() {
    let config = AgentConfig::new(5, "model-x");
    let limit = config.max_iterations();
    assert_eq!(limit, 5);
}
/// The `max_iterations()` getter reflects a value set through `with_max_iterations`.
#[test]
fn test_agent_config_with_max_iterations_updates_getter() {
    let base = AgentConfig::new(3, "m");
    let config = base.with_max_iterations(10);
    assert_eq!(config.max_iterations(), 10);
}
/// A freshly created registry reports itself as empty.
#[test]
fn test_tool_registry_is_empty_true_when_new() {
    let registry = ToolRegistry::new();
    let empty = registry.is_empty();
    assert!(empty);
}
/// A registry is no longer empty once a tool has been registered.
#[test]
fn test_tool_registry_is_empty_false_after_register() {
    let mut registry = ToolRegistry::new();
    let spec = ToolSpec::new("t", "d", |_| serde_json::json!({}));
    registry.register(spec);
    assert!(!registry.is_empty());
}
/// `clear` removes every tool: the registry is empty and its count is zero afterwards.
#[test]
fn test_tool_registry_clear_empties_registry() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new("t1", "d", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("t2", "d", |_| serde_json::json!({})));

    registry.clear();

    assert!(registry.is_empty());
    assert_eq!(registry.tool_count(), 0);
}
/// `remove` hands back the spec that was registered under the given name and
/// drops the tool count from one to zero.
#[test]
fn test_tool_registry_remove_returns_spec_and_decrements_count() {
    let mut reg = ToolRegistry::new();
    reg.register(ToolSpec::new("myTool", "desc", |_| serde_json::json!({})));
    assert_eq!(reg.tool_count(), 1);
    // Check the identity of the returned spec, not merely that something was
    // returned; the test name promises the removed spec itself.
    let removed = reg
        .remove("myTool")
        .expect("remove should return the spec of a registered tool");
    assert_eq!(removed.name, "myTool");
    assert_eq!(reg.tool_count(), 0);
}
/// `remove` returns `None` for a name that was never registered.
#[test]
fn test_tool_registry_remove_returns_none_for_absent_tool() {
    let mut registry = ToolRegistry::new();
    let removed = registry.remove("ghost");
    assert!(removed.is_none());
}
/// Expects `all_tool_names` to list names alphabetically regardless of registration order.
#[test]
fn test_all_tool_names_returns_sorted_names() {
    let mut registry = ToolRegistry::new();
    // Register deliberately out of alphabetical order.
    registry.register(ToolSpec::new("zebra", "d", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("apple", "d", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("mango", "d", |_| serde_json::json!({})));

    let listed = registry.all_tool_names();
    assert_eq!(listed, vec!["apple", "mango", "zebra"]);
}
/// `all_tool_names` is empty for a freshly created registry.
#[test]
fn test_all_tool_names_empty_for_empty_registry() {
    let registry = ToolRegistry::new();
    let listed = registry.all_tool_names();
    assert!(listed.is_empty());
}
/// With zero iterations used, the whole budget of 10 remains.
#[test]
fn test_remaining_iterations_after_full_budget() {
    let config = AgentConfig::new(10, "m");
    let remaining = config.remaining_iterations_after(0);
    assert_eq!(remaining, 10);
}
/// After 3 of 10 iterations have been used, 7 remain.
#[test]
fn test_remaining_iterations_after_partial_use() {
    let config = AgentConfig::new(10, "m");
    let remaining = config.remaining_iterations_after(3);
    assert_eq!(remaining, 7);
}
/// Using more iterations (10) than the budget (5) yields 0 remaining.
#[test]
fn test_remaining_iterations_after_saturates_at_zero() {
    let config = AgentConfig::new(5, "m");
    let remaining = config.remaining_iterations_after(10);
    assert_eq!(remaining, 0);
}
/// A spec built with `ToolSpec::new` alone has no required fields.
#[test]
fn test_tool_spec_required_field_count_zero_by_default() {
    let plain_spec = ToolSpec::new("t", "d", |_| serde_json::json!({}));
    let count = plain_spec.required_field_count();
    assert_eq!(count, 0);
}
/// Adding two required fields makes `required_field_count` report 2.
#[test]
fn test_tool_spec_required_field_count_after_adding() {
    let base = ToolSpec::new("t", "d", |_| serde_json::json!({}));
    let with_fields = base.with_required_fields(["query", "limit"]);
    assert_eq!(with_fields.required_field_count(), 2);
}
/// A spec built with `ToolSpec::new` alone reports no required fields.
#[test]
fn test_tool_spec_has_required_fields_false_by_default() {
    let plain_spec = ToolSpec::new("t", "d", |_| serde_json::json!({}));
    let has_fields = plain_spec.has_required_fields();
    assert!(!has_fields);
}
/// `has_required_fields` becomes true once a required field is added.
#[test]
fn test_tool_spec_has_required_fields_true_after_adding() {
    let base = ToolSpec::new("t", "d", |_| serde_json::json!({}));
    let with_fields = base.with_required_fields(["key"]);
    assert!(with_fields.has_required_fields());
}
/// A spec built with `ToolSpec::new` alone reports no validators.
#[test]
fn test_tool_spec_has_validators_false_by_default() {
    let plain_spec = ToolSpec::new("t", "d", |_| serde_json::json!({}));
    let has_any = plain_spec.has_validators();
    assert!(!has_any);
}
/// `contains` is true for the name of a registered tool.
#[test]
fn test_tool_registry_contains_true_for_registered_tool() {
    let mut registry = ToolRegistry::new();
    let spec = ToolSpec::new("search", "d", |_| serde_json::json!({}));
    registry.register(spec);
    assert!(registry.contains("search"));
}
/// `contains` is false for any name in an empty registry.
#[test]
fn test_tool_registry_contains_false_for_unknown_tool() {
    let registry = ToolRegistry::new();
    let found = registry.contains("missing");
    assert!(!found);
}
/// Expects `descriptions` to return `(name, description)` pairs ordered by name.
#[test]
fn test_tool_registry_descriptions_sorted_by_name() {
    let mut registry = ToolRegistry::new();
    // Register in reverse alphabetical order to exercise the ordering.
    registry.register(ToolSpec::new("zebra", "z-desc", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("apple", "a-desc", |_| serde_json::json!({})));

    let pairs = registry.descriptions();
    assert_eq!(pairs[0], ("apple", "a-desc"));
    assert_eq!(pairs[1], ("zebra", "z-desc"));
}
/// `descriptions` is empty for a freshly created registry.
#[test]
fn test_tool_registry_descriptions_empty_when_no_tools() {
    let registry = ToolRegistry::new();
    let pairs = registry.descriptions();
    assert!(pairs.is_empty());
}
/// `tool_count` starts at zero and grows by one with each distinct registration.
#[test]
fn test_tool_registry_tool_count_increments_on_register() {
    let mut registry = ToolRegistry::new();
    assert_eq!(registry.tool_count(), 0);

    let first = ToolSpec::new("t1", "d", |_| serde_json::json!({}));
    registry.register(first);
    assert_eq!(registry.tool_count(), 1);

    let second = ToolSpec::new("t2", "d", |_| serde_json::json!({}));
    registry.register(second);
    assert_eq!(registry.tool_count(), 2);
}
/// A step created with an empty observation reports `observation_is_empty`.
#[test]
fn test_observation_is_empty_true_for_empty_string() {
    let blank_step = ReActStep::new("think", "search", "");
    let empty = blank_step.observation_is_empty();
    assert!(empty);
}
/// A step created with a non-empty observation does not report `observation_is_empty`.
#[test]
fn test_observation_is_empty_false_for_non_empty() {
    let filled_step = ReActStep::new("think", "search", "found results");
    let empty = filled_step.observation_is_empty();
    assert!(!empty);
}
/// The `temperature()` getter is `None` on a config built with `new` alone.
#[test]
fn test_agent_config_temperature_getter_none_by_default() {
    let config = AgentConfig::new(5, "gpt-4");
    let temperature = config.temperature();
    assert!(temperature.is_none());
}
/// The `temperature()` getter returns the value set through `with_temperature`.
#[test]
fn test_agent_config_temperature_getter_some_when_set() {
    let config = AgentConfig::new(5, "gpt-4").with_temperature(0.7);
    // Compare the float with a tolerance rather than exact equality.
    let delta = (config.temperature().unwrap() - 0.7).abs();
    assert!(delta < 1e-5);
}
/// The `max_tokens()` getter is `None` on a config built with `new` alone.
#[test]
fn test_agent_config_max_tokens_getter_none_by_default() {
    let config = AgentConfig::new(5, "gpt-4");
    let limit = config.max_tokens();
    assert!(limit.is_none());
}
/// The `max_tokens()` getter returns the value set through `with_max_tokens`.
#[test]
fn test_agent_config_max_tokens_getter_some_when_set() {
    let base = AgentConfig::new(5, "gpt-4");
    let config = base.with_max_tokens(512);
    assert_eq!(config.max_tokens(), Some(512));
}
/// The `request_timeout()` getter is `None` on a config built with `new` alone.
#[test]
fn test_agent_config_request_timeout_getter_none_by_default() {
    let config = AgentConfig::new(5, "gpt-4");
    let timeout = config.request_timeout();
    assert!(timeout.is_none());
}
/// The `request_timeout()` getter returns the duration set through `with_request_timeout`.
#[test]
fn test_agent_config_request_timeout_getter_some_when_set() {
    let ten_seconds = std::time::Duration::from_secs(10);
    let config = AgentConfig::new(5, "gpt-4").with_request_timeout(ten_seconds);
    assert_eq!(
        config.request_timeout(),
        Some(std::time::Duration::from_secs(10))
    );
}
/// `has_max_context_chars` is false on a config built with `new` alone.
#[test]
fn test_agent_config_has_max_context_chars_false_by_default() {
    let config = AgentConfig::new(5, "gpt-4");
    let has_limit = config.has_max_context_chars();
    assert!(!has_limit);
}
/// `has_max_context_chars` is true after `with_max_context_chars` is applied.
#[test]
fn test_agent_config_has_max_context_chars_true_after_setting() {
    let base = AgentConfig::new(5, "gpt-4");
    let config = base.with_max_context_chars(8192);
    assert!(config.has_max_context_chars());
}
/// The `max_context_chars()` getter is `None` on a config built with `new` alone.
#[test]
fn test_agent_config_max_context_chars_none_by_default() {
    let config = AgentConfig::new(5, "gpt-4");
    let limit = config.max_context_chars();
    assert_eq!(limit, None);
}
/// The `max_context_chars()` getter returns the value set through `with_max_context_chars`.
#[test]
fn test_agent_config_max_context_chars_some_after_setting() {
    let base = AgentConfig::new(5, "gpt-4");
    let config = base.with_max_context_chars(4096);
    assert_eq!(config.max_context_chars(), Some(4096));
}
/// The `system_prompt()` getter returns the text set through `with_system_prompt`.
#[test]
fn test_agent_config_system_prompt_returns_configured_prompt() {
    let base = AgentConfig::new(5, "gpt-4");
    let config = base.with_system_prompt("Be concise.");
    assert_eq!(config.system_prompt(), "Be concise.");
}
/// The `model()` getter returns the model name passed to `new`.
#[test]
fn test_agent_config_model_returns_configured_model() {
    let config = AgentConfig::new(5, "claude-3");
    let model_name = config.model();
    assert_eq!(model_name, "claude-3");
}
/// A message built with `Message::system` reports `is_system`.
#[test]
fn test_message_is_system_true_for_system_role() {
    let system_msg = Message::system("context");
    let flagged = system_msg.is_system();
    assert!(flagged);
}
/// A message built with `Message::user` does not report `is_system`.
#[test]
fn test_message_is_system_false_for_user_role() {
    let user_msg = Message::user("hello");
    let flagged = user_msg.is_system();
    assert!(!flagged);
}
/// `word_count` counts whitespace-separated words in the message content.
#[test]
fn test_message_word_count_counts_whitespace_words() {
    let three_words = Message::user("hello world foo");
    let count = three_words.word_count();
    assert_eq!(count, 3);
}
/// `word_count` is zero when the message content is empty.
#[test]
fn test_message_word_count_zero_for_empty_content() {
    let blank = Message::user("");
    let count = blank.word_count();
    assert_eq!(count, 0);
}
/// `has_loop_timeout` is false on a config built with `new` alone.
#[test]
fn test_agent_config_has_loop_timeout_false_by_default() {
    let config = AgentConfig::new(5, "m");
    let has_timeout = config.has_loop_timeout();
    assert!(!has_timeout);
}
/// `has_loop_timeout` is true after `with_loop_timeout` is applied.
#[test]
fn test_agent_config_has_loop_timeout_true_after_setting() {
    let thirty_seconds = std::time::Duration::from_secs(30);
    let config = AgentConfig::new(5, "m").with_loop_timeout(thirty_seconds);
    assert!(config.has_loop_timeout());
}
/// `has_stop_sequences` is false on a config built with `new` alone.
#[test]
fn test_agent_config_has_stop_sequences_false_by_default() {
    let config = AgentConfig::new(5, "m");
    let has_any = config.has_stop_sequences();
    assert!(!has_any);
}
/// `has_stop_sequences` is true once a stop sequence has been supplied.
#[test]
fn test_agent_config_has_stop_sequences_true_after_adding() {
    let sequences = vec!["STOP".to_string()];
    let config = AgentConfig::new(5, "m").with_stop_sequences(sequences);
    assert!(config.has_stop_sequences());
}
/// A config with an iteration limit of 1 is single-shot.
#[test]
fn test_agent_config_is_single_shot_true_when_max_iterations_one() {
    let config = AgentConfig::new(1, "m");
    let single_shot = config.is_single_shot();
    assert!(single_shot);
}
/// A config with an iteration limit of 5 is not single-shot.
#[test]
fn test_agent_config_is_single_shot_false_when_max_iterations_gt_one() {
    let config = AgentConfig::new(5, "m");
    let single_shot = config.is_single_shot();
    assert!(!single_shot);
}
/// `has_temperature` is false on a config built with `new` alone.
#[test]
fn test_agent_config_has_temperature_false_by_default() {
    let config = AgentConfig::new(5, "m");
    let has_value = config.has_temperature();
    assert!(!has_value);
}
/// `has_temperature` is true after `with_temperature` is applied.
#[test]
fn test_agent_config_has_temperature_true_after_setting() {
    let base = AgentConfig::new(5, "m");
    let config = base.with_temperature(0.7);
    assert!(config.has_temperature());
}
/// A fallible handler that returns `Ok` has its value passed through by `call`.
#[test]
fn test_tool_spec_new_fallible_returns_ok_value() {
    let runtime = tokio::runtime::Runtime::new().unwrap();
    let adder = ToolSpec::new_fallible("add", "adds numbers", |_args| {
        Ok(serde_json::json!({"result": 42}))
    });

    // `call` is async, so drive it to completion on the local runtime.
    let output = runtime.block_on(adder.call(serde_json::json!({})));
    assert_eq!(output["result"], 42);
}
/// A fallible handler that returns `Err` yields JSON with `error` set to the message
/// and `ok` set to `false`.
#[test]
fn test_tool_spec_new_fallible_wraps_error_as_json() {
    let runtime = tokio::runtime::Runtime::new().unwrap();
    let failing = ToolSpec::new_fallible("fail", "always fails", |_| {
        Err("bad input".to_string())
    });

    // `call` is async, so drive it to completion on the local runtime.
    let output = runtime.block_on(failing.call(serde_json::json!({})));
    assert_eq!(output["error"], "bad input");
    assert_eq!(output["ok"], false);
}
/// An async fallible handler that resolves to `Err` yields JSON whose `error` is the message.
#[test]
fn test_tool_spec_new_async_fallible_wraps_error() {
    let runtime = tokio::runtime::Runtime::new().unwrap();
    let failing = ToolSpec::new_async_fallible("async_fail", "async error", |_| {
        Box::pin(async { Err("async bad".to_string()) })
    });

    // `call` is async, so drive it to completion on the local runtime.
    let output = runtime.block_on(failing.call(serde_json::json!({})));
    assert_eq!(output["error"], "async bad");
}
/// `with_required_fields` with two names results in a required-field count of 2.
#[test]
fn test_tool_spec_with_required_fields_sets_fields() {
    let base = ToolSpec::new("t", "d", |_| serde_json::json!({}));
    let with_fields = base.with_required_fields(["name", "value"]);
    assert_eq!(with_fields.required_field_count(), 2);
}
/// `with_description` replaces the description given to `ToolSpec::new`.
#[test]
fn test_tool_spec_with_description_overrides_description() {
    let base = ToolSpec::new("t", "original", |_| serde_json::json!({}));
    let updated = base.with_description("updated description");
    assert_eq!(updated.description, "updated description");
}
/// `stop_sequence_count` is zero on a config built with `new` alone.
#[test]
fn test_agent_config_stop_sequence_count_zero_by_default() {
    let config = AgentConfig::new(5, "gpt-4");
    let count = config.stop_sequence_count();
    assert_eq!(count, 0);
}
/// `stop_sequence_count` equals the number of sequences supplied.
#[test]
fn test_agent_config_stop_sequence_count_reflects_configured_count() {
    let sequences = vec!["STOP".to_string(), "END".to_string()];
    let config = AgentConfig::new(5, "gpt-4").with_stop_sequences(sequences);
    assert_eq!(config.stop_sequence_count(), 2);
}
/// A keyword that appears in no description yields an empty result.
#[test]
fn test_tool_registry_find_by_description_keyword_empty_when_no_match() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new(
        "calc",
        "Performs arithmetic",
        |_| serde_json::json!({}),
    ));

    let hits = registry.find_by_description_keyword("weather");
    assert!(hits.is_empty());
}
/// A lowercase keyword matches an uppercase occurrence in a description.
#[test]
fn test_tool_registry_find_by_description_keyword_case_insensitive() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new(
        "calc",
        "Performs ARITHMETIC operations",
        |_| serde_json::json!({}),
    ));
    registry.register(ToolSpec::new(
        "search",
        "Searches the web",
        |_| serde_json::json!({}),
    ));

    let hits = registry.find_by_description_keyword("arithmetic");
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].name, "calc");
}
/// Every tool whose description contains the keyword is returned.
#[test]
fn test_tool_registry_find_by_description_keyword_multiple_matches() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new("t1", "query the database", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("t2", "query the cache", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("t3", "send a message", |_| serde_json::json!({})));

    // Two of the three descriptions mention "query".
    let hits = registry.find_by_description_keyword("query");
    assert_eq!(hits.len(), 2);
}
/// A message built with `Message::user` is a user message and not an assistant message.
#[test]
fn test_message_is_user_true_for_user_role_r31() {
    let user_msg = Message::user("hello");
    assert!(user_msg.is_user());
    assert!(!user_msg.is_assistant());
}
/// A message built with `Message::assistant` is an assistant message and not a user message.
#[test]
fn test_message_is_assistant_true_for_assistant_role_r31() {
    let assistant_msg = Message::assistant("hi there");
    assert!(assistant_msg.is_assistant());
    assert!(!assistant_msg.is_user());
}
/// A newly constructed config reports zero stop sequences.
#[test]
fn test_agent_config_stop_sequence_count_zero_for_new_config() {
    let config = AgentConfig::new(5, "model");
    let count = config.stop_sequence_count();
    assert_eq!(count, 0);
}
/// After supplying two stop sequences, `stop_sequence_count` is 2.
#[test]
fn test_agent_config_stop_sequence_count_after_setting() {
    let sequences = vec!["<stop>".to_string(), "END".to_string()];
    let config = AgentConfig::new(5, "model").with_stop_sequences(sequences);
    assert_eq!(config.stop_sequence_count(), 2);
}
/// `has_request_timeout` is false on a config built with `new` alone.
#[test]
fn test_agent_config_has_request_timeout_false_by_default() {
    let config = AgentConfig::new(5, "model");
    let has_timeout = config.has_request_timeout();
    assert!(!has_timeout);
}
/// `has_request_timeout` is true after `with_request_timeout` is applied.
#[test]
fn test_agent_config_has_request_timeout_true_after_setting() {
    let thirty_seconds = std::time::Duration::from_secs(30);
    let config = AgentConfig::new(5, "model").with_request_timeout(thirty_seconds);
    assert!(config.has_request_timeout());
}
/// Unregistering the only tool from a `ReActLoop` returns `true` and leaves zero tools.
#[test]
fn test_react_loop_unregister_tool_removes_registered_tool() {
    let config = AgentConfig::new(5, "m");
    let mut react_loop = ReActLoop::new(config);
    react_loop.register_tool(ToolSpec::new("t1", "desc", |_| serde_json::json!({})));

    let was_removed = react_loop.unregister_tool("t1");
    assert!(was_removed);
    assert_eq!(react_loop.tool_count(), 0);
}
/// Unregistering a name that was never added to a `ReActLoop` returns `false`.
#[test]
fn test_react_loop_unregister_tool_returns_false_for_unknown() {
    let config = AgentConfig::new(5, "m");
    let mut react_loop = ReActLoop::new(config);
    let was_removed = react_loop.unregister_tool("nonexistent");
    assert!(!was_removed);
}
/// An empty registry has zero tools with required fields.
#[test]
fn test_tool_count_with_required_fields_zero_when_empty() {
    let registry = ToolRegistry::new();
    let count = registry.tool_count_with_required_fields();
    assert_eq!(count, 0);
}
/// A tool registered without required fields is not counted.
#[test]
fn test_tool_count_with_required_fields_excludes_tools_without_fields() {
    let mut registry = ToolRegistry::new();
    let plain_spec = ToolSpec::new("t1", "d", |_| serde_json::json!({}));
    registry.register(plain_spec);
    assert_eq!(registry.tool_count_with_required_fields(), 0);
}
/// Of three tools, only the two that declare required fields are counted.
#[test]
fn test_tool_count_with_required_fields_counts_only_tools_with_fields() {
    let one_field = ToolSpec::new("t1", "d", |_| serde_json::json!({}))
        .with_required_fields(["query"]);
    // This one declares no required fields and must not be counted.
    let no_fields = ToolSpec::new("t2", "d", |_| serde_json::json!({}));
    let two_fields = ToolSpec::new("t3", "d", |_| serde_json::json!({}))
        .with_required_fields(["url", "method"]);

    let mut registry = ToolRegistry::new();
    registry.register(one_field);
    registry.register(no_fields);
    registry.register(two_fields);

    assert_eq!(registry.tool_count_with_required_fields(), 2);
}
/// `names` is empty for a freshly created registry.
#[test]
fn test_tool_registry_names_empty_when_no_tools() {
    let registry = ToolRegistry::new();
    let listed = registry.names();
    assert!(listed.is_empty());
}
/// Expects `names` to list tools alphabetically regardless of registration order.
#[test]
fn test_tool_registry_names_sorted_alphabetically() {
    let mut registry = ToolRegistry::new();
    // Register deliberately out of alphabetical order.
    registry.register(ToolSpec::new("zebra", "d", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("alpha", "d", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("mango", "d", |_| serde_json::json!({})));

    let listed = registry.names();
    assert_eq!(listed, vec!["alpha", "mango", "zebra"]);
}
/// A prefix that no tool name starts with yields an empty result.
#[test]
fn test_tool_names_starting_with_empty_when_no_match() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new("search", "d", |_| serde_json::json!({})));

    let matches = registry.tool_names_starting_with("calc");
    assert!(matches.is_empty());
}
/// Expects names with the given prefix to be returned in alphabetical order.
#[test]
fn test_tool_names_starting_with_returns_sorted_matches() {
    let mut registry = ToolRegistry::new();
    // "db_write" is registered before "db_read" to exercise the ordering.
    registry.register(ToolSpec::new("db_write", "d", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("db_read", "d", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("cache_get", "d", |_| serde_json::json!({})));

    let matches = registry.tool_names_starting_with("db_");
    assert_eq!(matches, vec!["db_read", "db_write"]);
}
/// `description_for` returns `None` for a name that is not registered.
#[test]
fn test_tool_registry_description_for_none_when_missing() {
    let registry = ToolRegistry::new();
    let description = registry.description_for("unknown");
    assert!(description.is_none());
}
/// `description_for` returns the description of a registered tool.
#[test]
fn test_tool_registry_description_for_returns_description() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new(
        "search",
        "Find web results",
        |_| serde_json::json!({}),
    ));

    let description = registry.description_for("search");
    assert_eq!(description, Some("Find web results"));
}
/// The count is zero when no description contains the keyword.
#[test]
fn test_count_with_description_containing_zero_when_no_match() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new("t1", "database query", |_| serde_json::json!({})));

    let count = registry.count_with_description_containing("weather");
    assert_eq!(count, 0);
}
/// The keyword "web" matches both "WEB" and "web" in descriptions.
#[test]
fn test_count_with_description_containing_case_insensitive() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new("t1", "Search the WEB", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("t2", "web scraper tool", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("t3", "database lookup", |_| serde_json::json!({})));

    let count = registry.count_with_description_containing("web");
    assert_eq!(count, 2);
}
/// `unregister_all` takes the tool count from two to zero.
#[test]
fn test_unregister_all_clears_all_tools() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new("t1", "tool one", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("t2", "tool two", |_| serde_json::json!({})));
    assert_eq!(registry.tool_count(), 2);

    registry.unregister_all();

    assert_eq!(registry.tool_count(), 0);
}
/// Names are returned for tools whose description contains the keyword; "web" also matches "WEB".
#[test]
fn test_tool_names_with_keyword_returns_matching_tool_names() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new(
        "search",
        "search the web for info",
        |_| serde_json::json!({}),
    ));
    registry.register(ToolSpec::new(
        "db",
        "query database records",
        |_| serde_json::json!({}),
    ));
    registry.register(ToolSpec::new(
        "web-fetch",
        "fetch a WEB page",
        |_| serde_json::json!({}),
    ));

    // Sort before comparing so the assertion does not depend on return order.
    let mut matched = registry.tool_names_with_keyword("web");
    matched.sort_unstable();
    assert_eq!(matched, vec!["search", "web-fetch"]);
}
/// A keyword that appears in no description yields no names.
#[test]
fn test_tool_names_with_keyword_no_match_returns_empty() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new("t", "some tool", |_| serde_json::json!({})));

    let matched = registry.tool_names_with_keyword("missing");
    assert!(matched.is_empty());
}
/// Expects `all_descriptions` to return descriptions in sorted order.
#[test]
fn test_all_descriptions_returns_sorted_descriptions() {
    let mut registry = ToolRegistry::new();
    // The "z" description is registered first to exercise the ordering.
    registry.register(ToolSpec::new("t1", "z description", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("t2", "a description", |_| serde_json::json!({})));

    let listed = registry.all_descriptions();
    assert_eq!(listed, vec!["a description", "z description"]);
}
/// `all_descriptions` is empty for a freshly created registry.
#[test]
fn test_all_descriptions_empty_registry_returns_empty() {
    let registry = ToolRegistry::new();
    let listed = registry.all_descriptions();
    assert!(listed.is_empty());
}
/// `longest_description` returns the longer of two registered descriptions.
#[test]
fn test_longest_description_returns_longest() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new("t1", "short", |_| serde_json::json!({})));
    registry.register(ToolSpec::new(
        "t2",
        "a much longer description here",
        |_| serde_json::json!({}),
    ));

    let longest = registry.longest_description();
    assert_eq!(longest, Some("a much longer description here"));
}
/// `longest_description` is `None` for a freshly created registry.
#[test]
fn test_longest_description_empty_registry_returns_none() {
    let registry = ToolRegistry::new();
    let longest = registry.longest_description();
    assert!(longest.is_none());
}
/// Expects names containing the substring to be returned in alphabetical order.
#[test]
fn test_names_containing_returns_sorted_matching_names() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new("search-web", "search tool", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("web-fetch", "fetch tool", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("db-query", "database tool", |_| serde_json::json!({})));

    let matched = registry.names_containing("web");
    assert_eq!(matched, vec!["search-web", "web-fetch"]);
}
/// A substring that appears in no tool name yields an empty result.
#[test]
fn test_names_containing_no_match_returns_empty() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new("t", "tool", |_| serde_json::json!({})));

    let matched = registry.names_containing("missing");
    assert!(matched.is_empty());
}
/// Descriptions of 2 and 4 bytes average to 3.0.
#[test]
fn test_avg_description_length_returns_mean_byte_length() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new("a", "ab", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("b", "abcd", |_| serde_json::json!({})));

    // Compare the float with a tolerance rather than exact equality.
    let mean = registry.avg_description_length();
    let delta = (mean - 3.0).abs();
    assert!(delta < 1e-9);
}
/// `avg_description_length` is exactly 0.0 for a freshly created registry.
#[test]
fn test_avg_description_length_returns_zero_when_empty() {
    let registry = ToolRegistry::new();
    let mean = registry.avg_description_length();
    assert_eq!(mean, 0.0);
}
/// `shortest_description` returns the shortest of three registered descriptions.
#[test]
fn test_shortest_description_returns_shortest_string() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new("a", "hello world", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("b", "hi", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("c", "greetings", |_| serde_json::json!({})));

    let shortest = registry.shortest_description();
    assert_eq!(shortest, Some("hi"));
}
/// `shortest_description` is `None` for a freshly created registry.
#[test]
fn test_shortest_description_returns_none_when_empty() {
    let registry = ToolRegistry::new();
    let shortest = registry.shortest_description();
    assert!(shortest.is_none());
}
/// Expects `tool_names_sorted` to list names alphabetically regardless of registration order.
#[test]
fn test_tool_names_sorted_returns_names_in_alphabetical_order() {
    let mut registry = ToolRegistry::new();
    // Register deliberately out of alphabetical order.
    registry.register(ToolSpec::new("zap", "z tool", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("alpha", "a tool", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("middle", "m tool", |_| serde_json::json!({})));

    let listed = registry.tool_names_sorted();
    assert_eq!(listed, vec!["alpha", "middle", "zap"]);
}
/// `tool_names_sorted` is empty for a freshly created registry.
#[test]
fn test_tool_names_sorted_empty_returns_empty() {
    let registry = ToolRegistry::new();
    let listed = registry.tool_names_sorted();
    assert!(listed.is_empty());
}
/// Counts descriptions containing the keyword; lowercase and uppercase keywords give the
/// same count, and an absent keyword gives zero.
#[test]
fn test_description_contains_count_counts_matching_descriptions() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new("a", "search the web", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("b", "write to disk", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("c", "search and filter", |_| serde_json::json!({})));

    let lowercase_hits = registry.description_contains_count("search");
    assert_eq!(lowercase_hits, 2);

    let uppercase_hits = registry.description_contains_count("SEARCH");
    assert_eq!(uppercase_hits, 2);

    let absent_hits = registry.description_contains_count("missing");
    assert_eq!(absent_hits, 0);
}
/// `description_contains_count` is zero for a freshly created registry.
#[test]
fn test_description_contains_count_zero_when_empty() {
    let registry = ToolRegistry::new();
    let count = registry.description_contains_count("anything");
    assert_eq!(count, 0);
}
/// For a JSON tool action, the summary starts with `[TOOL]` and includes the thought and
/// the observation.
#[test]
fn test_react_step_summary_tool_kind() {
    let tool_step = ReActStep::new(
        "I need to search",
        r#"{"tool":"search","q":"rust"}"#,
        "results",
    );

    let summary = tool_step.summary();
    assert!(summary.starts_with("[TOOL]"));
    assert!(summary.contains("I need to search"));
    assert!(summary.contains("results"));
}
/// For a `FINAL_ANSWER` action, the summary starts with `[FINAL]` and includes the action text.
#[test]
fn test_react_step_summary_final_kind() {
    let final_step = ReActStep::new("Done", "FINAL_ANSWER hello", "");

    let summary = final_step.summary();
    assert!(summary.starts_with("[FINAL]"));
    assert!(summary.contains("FINAL_ANSWER hello"));
}
/// With 100-character fields, the summary contains an ellipsis character (`…`).
#[test]
fn test_react_step_summary_truncates_long_fields() {
    let filler = "a".repeat(100);
    // The last argument can take ownership; only the first two need clones.
    let long_step = ReActStep::new(filler.clone(), filler.clone(), filler);

    let summary = long_step.summary();
    assert!(summary.contains('…'));
}
/// A step with all-empty fields still produces a summary containing `[TOOL]`.
#[test]
fn test_react_step_summary_empty_fields() {
    let blank_step = ReActStep::new("", "", "");
    let summary = blank_step.summary();
    assert!(summary.contains("[TOOL]"));
}
/// Descriptions of 5 and 6 bytes sum to 11.
#[test]
fn test_tool_registry_total_description_bytes_sums_correctly() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new("a", "hello", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("b", "world!", |_| serde_json::json!({})));

    let total = registry.total_description_bytes();
    assert_eq!(total, 11);
}
/// `total_description_bytes` is zero for a freshly created registry.
#[test]
fn test_tool_registry_total_description_bytes_empty_returns_zero() {
    let registry = ToolRegistry::new();
    let total = registry.total_description_bytes();
    assert_eq!(total, 0);
}
/// A three-word thought yields a thought word count of 3.
#[test]
fn test_react_step_thought_word_count_counts_words() {
    let wordy_step = ReActStep::new("hello world foo", "act", "obs");
    let count = wordy_step.thought_word_count();
    assert_eq!(count, 3);
}
/// An empty thought yields a thought word count of 0.
#[test]
fn test_react_step_thought_word_count_empty_thought_returns_zero() {
    let blank_thought = ReActStep::new("", "act", "obs");
    let count = blank_thought.thought_word_count();
    assert_eq!(count, 0);
}
/// `clone_with_system_prompt` sets the new prompt and keeps the model and iteration limit.
#[test]
fn test_agent_config_clone_with_system_prompt_changes_only_prompt() {
    let source = AgentConfig::new(5, "gpt-4");
    let derived = source.clone_with_system_prompt("Custom prompt.");

    assert_eq!(derived.system_prompt, "Custom prompt.");
    // Fields other than the prompt carry over from the source config.
    assert_eq!(derived.model, "gpt-4");
    assert_eq!(derived.max_iterations, 5);
}
/// `clone_with_system_prompt` does not alter the prompt of the config it was called on.
#[test]
fn test_agent_config_clone_with_system_prompt_leaves_original_unchanged() {
    let source = AgentConfig::new(3, "claude").with_system_prompt("Original.");
    let _derived = source.clone_with_system_prompt("New.");
    assert_eq!(source.system_prompt, "Original.");
}
/// `clone_with_max_iterations` sets the new limit and keeps the model.
#[test]
fn test_agent_config_clone_with_max_iterations_changes_only_iterations() {
    let source = AgentConfig::new(5, "claude-3");
    let derived = source.clone_with_max_iterations(20);

    assert_eq!(derived.max_iterations, 20);
    assert_eq!(derived.model, "claude-3");
}
/// `clone_with_max_iterations` does not alter the limit of the config it was called on.
#[test]
fn test_agent_config_clone_with_max_iterations_leaves_original_unchanged() {
    let source = AgentConfig::new(5, "claude-3");
    let _derived = source.clone_with_max_iterations(10);
    assert_eq!(source.max_iterations, 5);
}
/// A user message renders via `to_string` as `user: <content>`.
#[test]
fn test_message_display_user_role() {
    let user_msg = Message::user("hello world");
    let rendered = user_msg.to_string();
    assert_eq!(rendered, "user: hello world");
}
/// An assistant message renders via `to_string` as `assistant: <content>`.
#[test]
fn test_message_display_assistant_role() {
    let assistant_msg = Message::assistant("I can help");
    let rendered = assistant_msg.to_string();
    assert_eq!(rendered, "assistant: I can help");
}
/// A system message renders via `to_string` as `system: <content>`.
#[test]
fn test_message_display_system_role() {
    let system_msg = Message::system("Be helpful");
    let rendered = system_msg.to_string();
    assert_eq!(rendered, "system: Be helpful");
}
/// `Message::from` accepts a `(Role, String)` tuple and keeps both parts.
#[test]
fn test_message_from_role_string_tuple() {
    let pair = (Role::User, "hello".to_owned());
    let converted = Message::from(pair);
    assert_eq!(converted.role, Role::User);
    assert_eq!(converted.content, "hello");
}
/// `Message::from` accepts a `(Role, &str)` tuple and keeps both parts.
#[test]
fn test_message_from_role_str_ref_tuple() {
    let pair = (Role::Assistant, "ok");
    let converted = Message::from(pair);
    assert_eq!(converted.role, Role::Assistant);
    assert_eq!(converted.content, "ok");
}
/// A `(Role, &str)` tuple converts into a `Message` through `.into()`.
#[test]
fn test_message_into_from_system_tuple() {
    let pair = (Role::System, "sys prompt");
    let converted: Message = pair.into();
    assert!(converted.is_system());
    assert_eq!(converted.content(), "sys prompt");
}
/// With descriptions of 11, 2 and 10 bytes, the shortest description length is 2.
#[test]
fn test_tool_registry_shortest_description_length_returns_min_bytes() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new("a", "hello world", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("b", "hi", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("c", "greetings!", |_| serde_json::json!({})));

    let shortest_len = registry.shortest_description_length();
    assert_eq!(shortest_len, 2);
}
/// `shortest_description_length` is zero for a freshly created registry.
#[test]
fn test_tool_registry_shortest_description_length_empty_returns_zero() {
    let registry = ToolRegistry::new();
    let shortest_len = registry.shortest_description_length();
    assert_eq!(shortest_len, 0);
}
/// Test-only validator type; its `ToolValidator` implementation accepts every input.
struct AlwaysOk;
/// Validator stub used by the registry tests: every argument value passes.
impl ToolValidator for AlwaysOk {
    /// Ignores `_args` and always returns `Ok(())`.
    fn validate(&self, _args: &Value) -> Result<(), AgentRuntimeError> {
        Ok(())
    }
}
/// Of three tools, only the two that were given a validator are counted.
#[test]
fn test_tool_count_with_validators_counts_tools_that_have_validators() {
    let validated_a = ToolSpec::new("a", "desc", |_| serde_json::json!({}))
        .with_validators(vec![Box::new(AlwaysOk)]);
    // This one has no validators and must not be counted.
    let unvalidated = ToolSpec::new("b", "desc", |_| serde_json::json!({}));
    let validated_c = ToolSpec::new("c", "desc", |_| serde_json::json!({}))
        .with_validators(vec![Box::new(AlwaysOk)]);

    let mut registry = ToolRegistry::new();
    registry.register(validated_a);
    registry.register(unvalidated);
    registry.register(validated_c);

    assert_eq!(registry.tool_count_with_validators(), 2);
}
/// A tool registered without validators is not counted.
#[test]
fn test_tool_count_with_validators_zero_when_none_have_validators() {
    let mut registry = ToolRegistry::new();
    let plain_spec = ToolSpec::new("a", "desc", |_| serde_json::json!({}));
    registry.register(plain_spec);
    assert_eq!(registry.tool_count_with_validators(), 0);
}
/// `tool_count_with_validators` is zero for a freshly created registry.
#[test]
fn test_tool_count_with_validators_zero_for_empty_registry() {
    let registry = ToolRegistry::new();
    let count = registry.tool_count_with_validators();
    assert_eq!(count, 0);
}
/// With descriptions of 2, 11 and 2 bytes, the longest description length is 11.
#[test]
fn test_longest_description_length_returns_max_bytes() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new("a", "hi", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("b", "hello world", |_| serde_json::json!({})));
    registry.register(ToolSpec::new("c", "yo", |_| serde_json::json!({})));

    let longest_len = registry.longest_description_length();
    assert_eq!(longest_len, 11);
}
/// `longest_description_length` is zero for a freshly created registry.
#[test]
fn test_longest_description_length_zero_for_empty_registry() {
    let registry = ToolRegistry::new();
    let longest_len = registry.longest_description_length();
    assert_eq!(longest_len, 0);
}
/// Tools "a" and "c" both require `query`; tool "b" has no required fields and is excluded.
#[test]
fn test_tools_with_required_field_returns_matching_tools() {
    let query_only = ToolSpec::new("a", "desc", |_| serde_json::json!({}))
        .with_required_fields(vec!["query".to_string()]);
    let no_fields = ToolSpec::new("b", "desc", |_| serde_json::json!({}));
    let query_and_limit = ToolSpec::new("c", "desc", |_| serde_json::json!({}))
        .with_required_fields(vec!["query".to_string(), "limit".to_string()]);

    let mut registry = ToolRegistry::new();
    registry.register(query_only);
    registry.register(no_fields);
    registry.register(query_and_limit);

    // Check membership by name so the assertions do not depend on return order.
    let matches = registry.tools_with_required_field("query");
    assert_eq!(matches.len(), 2);
    assert!(matches.iter().any(|tool| tool.name == "a"));
    assert!(matches.iter().any(|tool| tool.name == "c"));
}
/// A field name that no tool requires yields an empty result.
#[test]
fn test_tools_with_required_field_empty_when_no_match() {
    let requires_x = ToolSpec::new("a", "desc", |_| serde_json::json!({}))
        .with_required_fields(vec!["x".to_string()]);
    let mut registry = ToolRegistry::new();
    registry.register(requires_x);

    let matches = registry.tools_with_required_field("missing");
    assert!(matches.is_empty());
}
/// `tools_with_required_field` is empty for a freshly created registry.
#[test]
fn test_tools_with_required_field_empty_registry_returns_empty() {
    let registry = ToolRegistry::new();
    let matches = registry.tools_with_required_field("any");
    assert!(matches.is_empty());
}
/// A three-word observation yields an observation word count of 3.
#[test]
fn test_observation_word_count_counts_words() {
    // Built with a struct literal so every field is set explicitly.
    let wordy_step = ReActStep {
        thought: "t".into(),
        action: "a".into(),
        observation: "hello world foo".into(),
        step_duration_ms: 0,
    };
    let count = wordy_step.observation_word_count();
    assert_eq!(count, 3);
}
/// An empty observation yields an observation word count of 0.
#[test]
fn test_observation_word_count_zero_for_empty() {
    // Built with a struct literal so every field is set explicitly.
    let blank_step = ReActStep {
        thought: "t".into(),
        action: "a".into(),
        observation: "".into(),
        step_duration_ms: 0,
    };
    let count = blank_step.observation_word_count();
    assert_eq!(count, 0);
}
/// Between a tool with one required field and a tool with three, the latter is returned.
#[test]
fn test_tool_with_most_required_fields_returns_correct_tool() {
    let one_field = ToolSpec::new("few", "d", |_| serde_json::json!({}))
        .with_required_fields(vec!["a".to_string()]);
    let three_fields = ToolSpec::new("many", "d", |_| serde_json::json!({}))
        .with_required_fields(vec!["a".to_string(), "b".to_string(), "c".to_string()]);

    let mut registry = ToolRegistry::new();
    registry.register(one_field);
    registry.register(three_fields);

    let top = registry.tool_with_most_required_fields().unwrap();
    assert_eq!(top.name, "many");
}
/// `tool_with_most_required_fields` is `None` for a freshly created registry.
#[test]
fn test_tool_with_most_required_fields_returns_none_for_empty_registry() {
    let registry = ToolRegistry::new();
    let top = registry.tool_with_most_required_fields();
    assert!(top.is_none());
}
/// With a threshold of 2, the 2-byte description is not counted and the longer one is.
#[test]
fn test_tool_count_above_desc_bytes_counts_tools_with_long_descriptions() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new("short", "hi", |_| serde_json::json!({})));
    registry.register(ToolSpec::new(
        "long",
        "a much longer description here",
        |_| serde_json::json!({}),
    ));

    let count = registry.tool_count_above_desc_bytes(2);
    assert_eq!(count, 1);
}
/// No tool is counted when the threshold (100) exceeds every description length.
#[test]
fn test_tool_count_above_desc_bytes_zero_when_none_exceed() {
    let mut registry = ToolRegistry::new();
    registry.register(ToolSpec::new("t", "ab", |_| serde_json::json!({})));

    let count = registry.tool_count_above_desc_bytes(100);
    assert_eq!(count, 0);
}
/// An empty registry reports a count of zero even for a threshold of 0.
#[test]
fn test_tool_count_above_desc_bytes_zero_for_empty_registry() {
    let registry = ToolRegistry::new();
    let count = registry.tool_count_above_desc_bytes(0);
    assert_eq!(count, 0);
}
/// Names of tools that declare required fields come back sorted (`a`, `b`),
/// and the tool without required fields (`c`) is left out.
///
/// Tools are registered out of order (`b` before `a`) so the expected output
/// depends on sorting rather than on insertion order.
#[test]
fn test_tool_names_with_required_fields_returns_sorted_names() {
    let mut reg = ToolRegistry::new();
    reg.register(
        ToolSpec::new("b", "desc", |_| serde_json::json!({}))
            .with_required_fields(vec!["x".to_string()]),
    );
    reg.register(
        ToolSpec::new("a", "desc", |_| serde_json::json!({}))
            .with_required_fields(vec!["y".to_string()]),
    );
    // `c` has no required fields and must not appear in the result.
    reg.register(ToolSpec::new("c", "desc", |_| serde_json::json!({})));
    assert_eq!(reg.tool_names_with_required_fields(), vec!["a", "b"]);
}
/// A registry whose only tool declares no required fields yields no names.
#[test]
fn test_tool_names_with_required_fields_empty_when_none_have_fields() {
    let mut registry = ToolRegistry::new();
    let spec = ToolSpec::new("a", "desc", |_| serde_json::json!({}));
    registry.register(spec);
    let names = registry.tool_names_with_required_fields();
    assert!(names.is_empty());
}
/// An empty registry yields no tool names with required fields.
#[test]
fn test_tool_names_with_required_fields_empty_for_empty_registry() {
    let registry = ToolRegistry::new();
    let names = registry.tool_names_with_required_fields();
    assert!(names.is_empty());
}
/// Only the tool that declares no required fields (`no-req`) is returned.
#[test]
fn test_tools_without_required_fields_returns_tools_with_no_required_fields() {
    let mut registry = ToolRegistry::new();
    let plain = ToolSpec::new("no-req", "desc", |_| serde_json::json!({}));
    registry.register(plain);
    let constrained = ToolSpec::new("with-req", "desc", |_| serde_json::json!({}))
        .with_required_fields(vec![String::from("x")]);
    registry.register(constrained);
    let unconstrained = registry.tools_without_required_fields();
    assert_eq!(unconstrained.len(), 1);
    assert_eq!(unconstrained[0].name, "no-req");
}
/// An empty registry has no tools without required fields either.
#[test]
fn test_tools_without_required_fields_empty_for_empty_registry() {
    let registry = ToolRegistry::new();
    let unconstrained = registry.tools_without_required_fields();
    assert!(unconstrained.is_empty());
}
/// Two tools with zero and two required fields average to 1.0.
#[test]
fn test_avg_required_fields_count_computes_mean() {
    let mut registry = ToolRegistry::new();
    let none_required = ToolSpec::new("t1", "d", |_| serde_json::json!({}));
    registry.register(none_required);
    let two_required = ToolSpec::new("t2", "d", |_| serde_json::json!({}))
        .with_required_fields(vec![String::from("a"), String::from("b")]);
    registry.register(two_required);
    let mean = registry.avg_required_fields_count();
    // Compared with a tolerance because the result is a float.
    assert!((mean - 1.0).abs() < 1e-9);
}
/// The average for an empty registry is exactly 0.0.
#[test]
fn test_avg_required_fields_count_zero_for_empty_registry() {
    let registry = ToolRegistry::new();
    let mean = registry.avg_required_fields_count();
    assert_eq!(mean, 0.0);
}
/// A step built with an empty thought string reports `thought_is_empty`.
#[test]
fn test_thought_is_empty_true_for_empty_thought() {
    let blank_thought = ReActStep::new("", "action", "obs");
    let is_empty = blank_thought.thought_is_empty();
    assert!(is_empty);
}
/// A thought consisting only of whitespace also counts as empty.
#[test]
fn test_thought_is_empty_true_for_whitespace_only() {
    let whitespace_thought = ReActStep::new(" ", "action", "obs");
    let is_empty = whitespace_thought.thought_is_empty();
    assert!(is_empty);
}
/// A step with real thought text does not report `thought_is_empty`.
#[test]
fn test_thought_is_empty_false_for_nonempty_thought() {
    let thoughtful = ReActStep::new("I need to search", "action", "obs");
    let is_empty = thoughtful.thought_is_empty();
    assert!(!is_empty);
}
/// `model_is` is true when queried with the same name the config was built with.
#[test]
fn test_model_is_true_for_matching_name() {
    let model = "claude-sonnet-4-6";
    let cfg = AgentConfig::new(10, model);
    assert!(cfg.model_is(model));
}
/// `model_is` is false when queried with a name other than the configured one.
#[test]
fn test_model_is_false_for_different_name() {
    let cfg = AgentConfig::new(10, "claude-opus-4-6");
    let matches_other = cfg.model_is("claude-sonnet-4-6");
    assert!(!matches_other);
}
/// Without a loop timeout configured, `loop_timeout_ms` reports 0.
#[test]
fn test_loop_timeout_ms_returns_zero_when_not_configured() {
    let cfg = AgentConfig::new(10, "m");
    let millis = cfg.loop_timeout_ms();
    assert_eq!(millis, 0);
}
/// A configured loop timeout of 5000 ms is reported back as 5000.
#[test]
fn test_loop_timeout_ms_returns_millis_when_configured() {
    let timeout = std::time::Duration::from_millis(5000);
    let cfg = AgentConfig::new(10, "m").with_loop_timeout(timeout);
    let millis = cfg.loop_timeout_ms();
    assert_eq!(millis, 5000);
}
/// With neither a loop nor a request timeout set, the total timeout is 0.
#[test]
fn test_total_timeout_ms_zero_when_neither_configured() {
    let cfg = AgentConfig::new(10, "m");
    let total = cfg.total_timeout_ms();
    assert_eq!(total, 0);
}
/// With 4 iterations, a 1000 ms loop timeout and a 500 ms request timeout the
/// total is 3000 ms (presumably loop budget plus one request budget per
/// iteration: 1000 + 4 * 500 -- confirm against `total_timeout_ms`).
#[test]
fn test_total_timeout_ms_includes_loop_and_request_budgets() {
    let loop_budget = std::time::Duration::from_millis(1000);
    let request_budget = std::time::Duration::from_millis(500);
    let cfg = AgentConfig::new(4, "m")
        .with_loop_timeout(loop_budget)
        .with_request_timeout(request_budget);
    let total = cfg.total_timeout_ms();
    assert_eq!(total, 3000);
}
/// Word counts of all descriptions are summed: three words plus two words is five.
#[test]
fn test_tool_descriptions_total_words_sums_words() {
    let mut registry = ToolRegistry::new();
    let three_words = ToolSpec::new("t1", "one two three", |_| serde_json::json!({}));
    registry.register(three_words);
    let two_words = ToolSpec::new("t2", "four five", |_| serde_json::json!({}));
    registry.register(two_words);
    let total = registry.tool_descriptions_total_words();
    assert_eq!(total, 5);
}
/// An empty registry has a total description word count of zero.
#[test]
fn test_tool_descriptions_total_words_zero_for_empty_registry() {
    let registry = ToolRegistry::new();
    let total = registry.tool_descriptions_total_words();
    assert_eq!(total, 0);
}
/// `content_starts_with` is true for a prefix the message content begins with.
#[test]
fn test_content_starts_with_true_for_matching_prefix() {
    let greeting = Message::user("Hello, world!");
    let has_prefix = greeting.content_starts_with("Hello");
    assert!(has_prefix);
}
/// `content_starts_with` is false when the content does not begin with the
/// given text ("World" versus content starting with "Hello").
#[test]
fn test_content_starts_with_false_for_non_matching_prefix() {
    let greeting = Message::user("Hello, world!");
    let has_prefix = greeting.content_starts_with("World");
    assert!(!has_prefix);
}
/// The empty prefix matches any message content.
#[test]
fn test_content_starts_with_empty_prefix_always_true() {
    let reply = Message::assistant("anything");
    let has_prefix = reply.content_starts_with("");
    assert!(has_prefix);
}
/// A config whose system prompt was set to the empty string reports it as empty.
#[test]
fn test_system_prompt_is_empty_true_for_blank_prompt() {
    let config = AgentConfig::new(5, "m").with_system_prompt("");
    let is_empty = config.system_prompt_is_empty();
    assert!(is_empty);
}
/// A config with a non-empty system prompt does not report it as empty.
#[test]
fn test_system_prompt_is_empty_false_when_set() {
    let config = AgentConfig::new(5, "m").with_system_prompt("You are helpful.");
    let is_empty = config.system_prompt_is_empty();
    assert!(!is_empty);
}
/// A tool whose description is only whitespace counts as having an empty description.
#[test]
fn test_has_tools_with_empty_descriptions_true_when_blank_present() {
    let mut registry = ToolRegistry::new();
    let blank = ToolSpec::new("t1", " ", |_| serde_json::json!({}));
    registry.register(blank);
    let any_blank = registry.has_tools_with_empty_descriptions();
    assert!(any_blank);
}
/// When every registered tool has description text, no empty descriptions are reported.
#[test]
fn test_has_tools_with_empty_descriptions_false_when_all_filled() {
    let mut registry = ToolRegistry::new();
    let described = ToolSpec::new("t1", "desc", |_| serde_json::json!({}));
    registry.register(described);
    let any_blank = registry.has_tools_with_empty_descriptions();
    assert!(!any_blank);
}
/// Required-field counts are summed over all tools: two plus one is three.
#[test]
fn test_total_required_fields_sums_across_tools() {
    let two_fields = ToolSpec::new("t1", "d", |_| serde_json::json!({}))
        .with_required_fields(vec![String::from("a"), String::from("b")]);
    let one_field = ToolSpec::new("t2", "d", |_| serde_json::json!({}))
        .with_required_fields(vec![String::from("c")]);
    let mut registry = ToolRegistry::new();
    registry.register(two_fields);
    registry.register(one_field);
    let total = registry.total_required_fields();
    assert_eq!(total, 3);
}
/// An empty registry has zero required fields in total.
#[test]
fn test_total_required_fields_zero_for_empty_registry() {
    let registry = ToolRegistry::new();
    let total = registry.total_required_fields();
    assert_eq!(total, 0);
}
/// With a limit of 5, only the tool named `long` has a long enough description.
#[test]
fn test_tools_with_description_longer_than_returns_matching_tools() {
    let mut registry = ToolRegistry::new();
    let short = ToolSpec::new("short", "hi", |_| serde_json::json!({}));
    registry.register(short);
    let long = ToolSpec::new("long", "a much longer description", |_| serde_json::json!({}));
    registry.register(long);
    let long_names = registry.tools_with_description_longer_than(5);
    assert_eq!(long_names, vec!["long"]);
}
/// The maximum description length is that of "hello world" (11 bytes).
#[test]
fn test_max_description_bytes_returns_longest() {
    let mut registry = ToolRegistry::new();
    let short = ToolSpec::new("t1", "hi", |_| serde_json::json!({}));
    registry.register(short);
    let long = ToolSpec::new("t2", "hello world", |_| serde_json::json!({}));
    registry.register(long);
    let longest = registry.max_description_bytes();
    assert_eq!(longest, 11);
}
/// The minimum description length is that of "hi" (2 bytes).
#[test]
fn test_min_description_bytes_returns_shortest() {
    let mut registry = ToolRegistry::new();
    let short = ToolSpec::new("t1", "hi", |_| serde_json::json!({}));
    registry.register(short);
    let long = ToolSpec::new("t2", "hello world", |_| serde_json::json!({}));
    registry.register(long);
    let shortest = registry.min_description_bytes();
    assert_eq!(shortest, 2);
}
/// An empty registry reports a maximum description length of zero.
#[test]
fn test_max_description_bytes_zero_for_empty_registry() {
    let registry = ToolRegistry::new();
    let longest = registry.max_description_bytes();
    assert_eq!(longest, 0);
}
/// The keyword "web" is found in the description "search the web".
#[test]
fn test_has_tool_with_description_containing_true_when_keyword_found() {
    let mut registry = ToolRegistry::new();
    let search = ToolSpec::new("search", "search the web", |_| serde_json::json!({}));
    registry.register(search);
    let found = registry.has_tool_with_description_containing("web");
    assert!(found);
}
/// The keyword "database" does not occur in the description "search the web".
#[test]
fn test_has_tool_with_description_containing_false_when_keyword_absent() {
    let mut registry = ToolRegistry::new();
    let search = ToolSpec::new("search", "search the web", |_| serde_json::json!({}));
    registry.register(search);
    let found = registry.has_tool_with_description_containing("database");
    assert!(!found);
}
/// With no tools registered, no description can contain the keyword.
#[test]
fn test_has_tool_with_description_containing_false_for_empty_registry() {
    let registry = ToolRegistry::new();
    let found = registry.has_tool_with_description_containing("anything");
    assert!(!found);
}
/// The prompt "You are a helpful AI agent." is counted as six words.
#[test]
fn test_system_prompt_word_count_counts_words() {
    let prompt = "You are a helpful AI agent.";
    let config = AgentConfig::new(10, "m").with_system_prompt(prompt);
    let words = config.system_prompt_word_count();
    assert_eq!(words, 6);
}
/// An empty system prompt has a word count of zero.
#[test]
fn test_system_prompt_word_count_zero_for_empty_prompt() {
    let config = AgentConfig::new(10, "m").with_system_prompt("");
    let words = config.system_prompt_word_count();
    assert_eq!(words, 0);
}
/// One of the candidate prefixes ("Search") matches the start of the description.
#[test]
fn test_description_starts_with_any_true_when_prefix_matches() {
    let mut registry = ToolRegistry::new();
    let search = ToolSpec::new("t1", "Search the web", |_| serde_json::json!({}));
    registry.register(search);
    let matched = registry.description_starts_with_any(&["Search", "Write"]);
    assert!(matched);
}
/// Neither candidate prefix matches a description starting with "Read".
#[test]
fn test_description_starts_with_any_false_when_no_prefix_matches() {
    let mut registry = ToolRegistry::new();
    let reader = ToolSpec::new("t1", "Read a file", |_| serde_json::json!({}));
    registry.register(reader);
    let matched = registry.description_starts_with_any(&["Search", "Write"]);
    assert!(!matched);
}
/// With no tools registered, no description can match any prefix.
#[test]
fn test_description_starts_with_any_false_for_empty_registry() {
    let registry = ToolRegistry::new();
    let matched = registry.description_starts_with_any(&["Search"]);
    assert!(!matched);
}
/// The combined length adds thought (5), action (6) and observation (6) bytes.
#[test]
fn test_combined_byte_length_sums_all_fields() {
    let populated = ReActStep::new("hello", "search", "result");
    let total = populated.combined_byte_length();
    assert_eq!(total, 5 + 6 + 6);
}
/// A step whose three text fields are all empty has a combined length of zero.
#[test]
fn test_combined_byte_length_zero_for_empty_step() {
    let blank = ReActStep::new("", "", "");
    let total = blank.combined_byte_length();
    assert_eq!(total, 0);
}
/// With zero steps done, the whole budget of 10 iterations remains.
#[test]
fn test_iteration_budget_remaining_full_when_no_steps_done() {
    let config = AgentConfig::new(10, "m");
    let remaining = config.iteration_budget_remaining(0);
    assert_eq!(remaining, 10);
}
/// After 7 of 10 iterations, 3 remain.
#[test]
fn test_iteration_budget_remaining_decreases_with_steps() {
    let config = AgentConfig::new(10, "m");
    let remaining = config.iteration_budget_remaining(7);
    assert_eq!(remaining, 3);
}
/// Doing more steps (10) than the budget (5) leaves zero remaining rather than underflowing.
#[test]
fn test_iteration_budget_remaining_saturates_at_zero() {
    let config = AgentConfig::new(5, "m");
    let remaining = config.iteration_budget_remaining(10);
    assert_eq!(remaining, 0);
}
/// `has_all_tools` is true when every requested name is registered.
#[test]
fn test_has_all_tools_true_when_all_registered() {
    let mut registry = ToolRegistry::new();
    let search = ToolSpec::new("search", "Search", |_| serde_json::json!({}));
    registry.register(search);
    let write = ToolSpec::new("write", "Write", |_| serde_json::json!({}));
    registry.register(write);
    let all_present = registry.has_all_tools(&["search", "write"]);
    assert!(all_present);
}
/// `has_all_tools` is false when one requested name (`write`) is not registered.
#[test]
fn test_has_all_tools_false_when_one_missing() {
    let mut registry = ToolRegistry::new();
    let search = ToolSpec::new("search", "Search", |_| serde_json::json!({}));
    registry.register(search);
    let all_present = registry.has_all_tools(&["search", "write"]);
    assert!(!all_present);
}
/// Asking for no tools at all is trivially satisfied, even by an empty registry.
#[test]
fn test_has_all_tools_true_for_empty_slice() {
    let registry = ToolRegistry::new();
    let all_present = registry.has_all_tools(&[]);
    assert!(all_present);
}
/// Looking up a registered name yields a tool carrying that same name.
#[test]
fn test_tool_by_name_returns_tool_when_present() {
    let mut registry = ToolRegistry::new();
    let search = ToolSpec::new("search", "Search the web", |_| serde_json::json!({}));
    registry.register(search);
    let first_lookup = registry.tool_by_name("search");
    assert!(first_lookup.is_some());
    let tool = registry.tool_by_name("search").unwrap();
    assert_eq!(tool.name, "search");
}
/// Looking up an unregistered name yields `None`.
#[test]
fn test_tool_by_name_returns_none_when_absent() {
    let registry = ToolRegistry::new();
    let lookup = registry.tool_by_name("missing");
    assert!(lookup.is_none());
}
/// Both tools, registered without any validator setup, are listed by name.
#[test]
fn test_tools_without_validators_returns_unvalidated_tools() {
    let mut registry = ToolRegistry::new();
    let tool_a = ToolSpec::new("a", "Tool A", |_| serde_json::json!({}));
    registry.register(tool_a);
    let tool_b = ToolSpec::new("b", "Tool B", |_| serde_json::json!({}));
    registry.register(tool_b);
    let unvalidated = registry.tools_without_validators();
    assert!(unvalidated.contains(&"a"));
    assert!(unvalidated.contains(&"b"));
}
/// An empty registry lists no tools without validators.
#[test]
fn test_tools_without_validators_empty_for_empty_registry() {
    let registry = ToolRegistry::new();
    let unvalidated = registry.tools_without_validators();
    assert!(unvalidated.is_empty());
}
/// A step built with an empty action string reports `action_is_empty`.
#[test]
fn test_action_is_empty_true_for_empty_action() {
    let no_action = ReActStep::new("thought", "", "obs");
    let is_empty = no_action.action_is_empty();
    assert!(is_empty);
}
/// A step with action text does not report `action_is_empty`.
#[test]
fn test_action_is_empty_false_for_nonempty_action() {
    let with_action = ReActStep::new("thought", "search", "obs");
    let is_empty = with_action.action_is_empty();
    assert!(!is_empty);
}
/// An action consisting only of whitespace also counts as empty.
#[test]
fn test_action_is_empty_true_for_whitespace_only() {
    let whitespace_action = ReActStep::new("thought", " ", "obs");
    let is_empty = whitespace_action.action_is_empty();
    assert!(is_empty);
}
/// A config with one iteration and an empty system prompt is minimal.
#[test]
fn test_is_minimal_true_for_single_iteration_no_prompt() {
    let config = AgentConfig::new(1, "m").with_system_prompt("");
    let minimal = config.is_minimal();
    assert!(minimal);
}
/// A config allowing five iterations is not minimal.
#[test]
fn test_is_minimal_false_when_max_iterations_above_one() {
    let config = AgentConfig::new(5, "m");
    let minimal = config.is_minimal();
    assert!(!minimal);
}
/// A single-iteration config with a non-empty system prompt is not minimal.
#[test]
fn test_is_minimal_false_when_system_prompt_set() {
    let config = AgentConfig::new(1, "m").with_system_prompt("prompt");
    let minimal = config.is_minimal();
    assert!(!minimal);
}
/// The model name "claude-3-opus" starts with "claude".
#[test]
fn test_model_starts_with_true_when_prefix_matches() {
    let config = AgentConfig::new(3, "claude-3-opus");
    let has_prefix = config.model_starts_with("claude");
    assert!(has_prefix);
}
/// The model name "gpt-4o" does not start with "claude".
#[test]
fn test_model_starts_with_false_when_prefix_differs() {
    let config = AgentConfig::new(3, "gpt-4o");
    let has_prefix = config.model_starts_with("claude");
    assert!(!has_prefix);
}
/// Of two registered tools, only `search` declares required fields, so the count is one.
#[test]
fn test_tools_with_required_fields_count_correct() {
    let search = ToolSpec::new("search", "desc", |_| serde_json::json!("ok"))
        .with_required_fields(vec![String::from("query")]);
    let noop = ToolSpec::new("noop", "desc", |_| serde_json::json!("ok"));
    let mut reg = ToolRegistry::new();
    reg.register(search);
    reg.register(noop);
    let count = reg.tools_with_required_fields_count();
    assert_eq!(count, 1);
}
/// An empty registry has zero tools with required fields.
#[test]
fn test_tools_with_required_fields_count_zero_for_empty_registry() {
    let reg = ToolRegistry::new();
    let count = reg.tools_with_required_fields_count();
    assert_eq!(count, 0);
}
/// Only names starting with `search_` are returned, in the order
/// `search_code`, `search_web` (not registration order).
#[test]
fn test_tool_names_with_prefix_returns_matching_names() {
    let mut registry = ToolRegistry::new();
    let web = ToolSpec::new("search_web", "desc", |_| serde_json::json!({}));
    registry.register(web);
    let code = ToolSpec::new("search_code", "desc", |_| serde_json::json!({}));
    registry.register(code);
    let writer = ToolSpec::new("write_file", "desc", |_| serde_json::json!({}));
    registry.register(writer);
    let matching = registry.tool_names_with_prefix("search_");
    assert_eq!(matching, vec!["search_code", "search_web"]);
}
/// No names are returned when no registered tool starts with the prefix.
#[test]
fn test_tool_names_with_prefix_empty_when_no_match() {
    let mut registry = ToolRegistry::new();
    let writer = ToolSpec::new("write_file", "desc", |_| serde_json::json!({}));
    registry.register(writer);
    let matching = registry.tool_names_with_prefix("search_");
    assert!(matching.is_empty());
}
/// With a limit of 5, step counts of exactly 5 and of 10 both count as exceeding it.
#[test]
fn test_exceeds_iteration_limit_true_when_at_limit() {
    let config = AgentConfig::new(5, "m");
    let at_limit = config.exceeds_iteration_limit(5);
    assert!(at_limit);
    let past_limit = config.exceeds_iteration_limit(10);
    assert!(past_limit);
}
/// With a limit of 5, step counts of 4 and 0 do not exceed it.
#[test]
fn test_exceeds_iteration_limit_false_when_below_limit() {
    let config = AgentConfig::new(5, "m");
    let just_below = config.exceeds_iteration_limit(4);
    assert!(!just_below);
    let at_start = config.exceeds_iteration_limit(0);
    assert!(!at_start);
}
/// Words are counted across thought (2), action (1) and observation (3) for a total of six.
#[test]
fn test_total_word_count_sums_all_fields() {
    let populated = ReActStep::new("one two", "three", "four five six");
    let words = populated.total_word_count();
    assert_eq!(words, 6);
}
/// A step whose three text fields are all empty has a total word count of zero.
#[test]
fn test_total_word_count_zero_for_empty_step() {
    let blank = ReActStep::new("", "", "");
    let words = blank.total_word_count();
    assert_eq!(words, 0);
}
/// Setting `max_tokens` alone is enough for the token budget to count as configured.
#[test]
fn test_token_budget_configured_true_when_max_tokens_set() {
    let config = AgentConfig::new(3, "m").with_max_tokens(100);
    let configured = config.token_budget_configured();
    assert!(configured);
}
/// Setting `max_context_chars` alone is enough for the token budget to count as configured.
#[test]
fn test_token_budget_configured_true_when_max_context_chars_set() {
    let config = AgentConfig::new(3, "m").with_max_context_chars(200);
    let configured = config.token_budget_configured();
    assert!(configured);
}
/// A freshly built config has no token budget configured.
#[test]
fn test_token_budget_configured_false_when_neither_set() {
    let config = AgentConfig::new(3, "m");
    let configured = config.token_budget_configured();
    assert!(!configured);
}
/// A step with thought, action and observation all filled in is complete.
#[test]
fn test_is_complete_true_when_all_fields_nonempty() {
    let populated = ReActStep::new("thought", "action", "observation");
    let complete = populated.is_complete();
    assert!(complete);
}
/// A step with an empty observation is not complete.
#[test]
fn test_is_complete_false_when_observation_empty() {
    let no_observation = ReActStep::new("thought", "action", "");
    let complete = no_observation.is_complete();
    assert!(!complete);
}
/// A step with an empty action is not complete.
#[test]
fn test_is_complete_false_when_action_empty() {
    let no_action = ReActStep::new("thought", "", "obs");
    let complete = no_action.is_complete();
    assert!(!complete);
}
/// When `max_tokens` is set (512), it wins over the supplied fallback (100).
#[test]
fn test_max_tokens_or_default_returns_value_when_set() {
    let config = AgentConfig::new(3, "m").with_max_tokens(512);
    let effective = config.max_tokens_or_default(100);
    assert_eq!(effective, 512);
}
/// When `max_tokens` was never set, the supplied fallback (256) is returned.
#[test]
fn test_max_tokens_or_default_returns_default_when_unset() {
    let config = AgentConfig::new(3, "m");
    let effective = config.max_tokens_or_default(256);
    assert_eq!(effective, 256);
}
/// The observation "Result: ok" starts with "Result:".
#[test]
fn test_observation_starts_with_true_for_matching_prefix() {
    let ok_step = ReActStep::new("t", "a", "Result: ok");
    let has_prefix = ok_step.observation_starts_with("Result:");
    assert!(has_prefix);
}
/// The observation "Error: failed" does not start with "Result:".
#[test]
fn test_observation_starts_with_false_for_non_matching_prefix() {
    let failed_step = ReActStep::new("t", "a", "Error: failed");
    let has_prefix = failed_step.observation_starts_with("Result:");
    assert!(!has_prefix);
}
/// A configured temperature of 0.5 is what `effective_temperature` reports.
#[test]
fn test_effective_temperature_returns_configured_value() {
    let config = AgentConfig::new(3, "m").with_temperature(0.5);
    let temperature = config.effective_temperature();
    // Compared with a tolerance because the value is a float.
    assert!((temperature - 0.5_f32).abs() < 1e-6);
}
/// With no temperature configured, `effective_temperature` is expected to be 1.0.
#[test]
fn test_effective_temperature_returns_default_when_unset() {
    let config = AgentConfig::new(3, "m");
    let temperature = config.effective_temperature();
    // Compared with a tolerance because the value is a float.
    assert!((temperature - 1.0_f32).abs() < 1e-6);
}
/// The action "do this now" is counted as three words.
#[test]
fn test_action_word_count_returns_words_in_action() {
    let wordy = ReActStep::new("think", "do this now", "ok");
    let words = wordy.action_word_count();
    assert_eq!(words, 3);
}
/// An empty action has a word count of zero.
#[test]
fn test_action_word_count_zero_for_empty_action() {
    let no_action = ReActStep::new("think", "", "ok");
    let words = no_action.action_word_count();
    assert_eq!(words, 0);
}
/// `thought_byte_len` equals the byte length of the thought text the step was built with.
#[test]
fn test_thought_byte_len_matches_string_len() {
    let thought = "hello";
    let step = ReActStep::new(thought, "act", "obs");
    assert_eq!(step.thought_byte_len(), thought.len());
}
/// `action_byte_len` equals the byte length of the action text the step was built with.
#[test]
fn test_action_byte_len_matches_string_len() {
    let action = "do it";
    let step = ReActStep::new("think", action, "obs");
    assert_eq!(step.action_byte_len(), action.len());
}
/// A step with an empty observation reports that it has empty fields.
#[test]
fn test_has_empty_fields_true_when_observation_empty() {
    let no_observation = ReActStep::new("think", "act", "");
    let any_empty = no_observation.has_empty_fields();
    assert!(any_empty);
}
/// A step with all three text fields filled in has no empty fields.
#[test]
fn test_has_empty_fields_false_when_all_populated() {
    let populated = ReActStep::new("think", "act", "obs");
    let any_empty = populated.has_empty_fields();
    assert!(!any_empty);
}
/// The prompt "You are a helpful assistant." starts with "You are".
#[test]
fn test_system_prompt_starts_with_true_for_matching_prefix() {
    let config = AgentConfig::new(3, "m").with_system_prompt("You are a helpful assistant.");
    let has_prefix = config.system_prompt_starts_with("You are");
    assert!(has_prefix);
}
/// The prompt "Hello world" does not start with "Goodbye".
#[test]
fn test_system_prompt_starts_with_false_for_non_matching_prefix() {
    let config = AgentConfig::new(3, "m").with_system_prompt("Hello world");
    let has_prefix = config.system_prompt_starts_with("Goodbye");
    assert!(!has_prefix);
}
/// A limit of 5 iterations is above the threshold 4.
#[test]
fn test_max_iterations_above_true_when_greater() {
    let config = AgentConfig::new(5, "m");
    let above = config.max_iterations_above(4);
    assert!(above);
}
/// A limit of 5 iterations is not above the threshold 5; the comparison is strict.
#[test]
fn test_max_iterations_above_false_when_equal() {
    let config = AgentConfig::new(5, "m");
    let above = config.max_iterations_above(5);
    assert!(!above);
}
/// A sequence that was configured ("STOP") is reported as contained.
#[test]
fn test_stop_sequences_contain_true_for_present_sequence() {
    let stops = vec![String::from("STOP"), String::from("END")];
    let config = AgentConfig::new(3, "m").with_stop_sequences(stops);
    let present = config.stop_sequences_contain("STOP");
    assert!(present);
}
/// A sequence that was not configured ("END") is not reported as contained.
#[test]
fn test_stop_sequences_contain_false_for_absent_sequence() {
    let stops = vec![String::from("STOP")];
    let config = AgentConfig::new(3, "m").with_stop_sequences(stops);
    let present = config.stop_sequences_contain("END");
    assert!(!present);
}
/// A config built without stop sequences contains none of them.
#[test]
fn test_stop_sequences_contain_false_for_empty_config() {
    let config = AgentConfig::new(3, "m");
    let present = config.stop_sequences_contain("STOP");
    assert!(!present);
}
/// `observation_byte_len` equals the byte length of the observation text the step was built with.
#[test]
fn test_observation_byte_len_matches_string_len() {
    let observation = "result";
    let step = ReActStep::new("t", "a", observation);
    assert_eq!(step.observation_byte_len(), observation.len());
}
/// An empty observation has a byte length of zero.
#[test]
fn test_observation_byte_len_zero_for_empty() {
    let no_observation = ReActStep::new("t", "a", "");
    let bytes = no_observation.observation_byte_len();
    assert_eq!(bytes, 0);
}
/// When thought, action and observation each hold a word, all fields have words.
#[test]
fn test_all_fields_have_words_true_when_all_populated() {
    let populated = ReActStep::new("think", "act", "obs");
    let all_wordy = populated.all_fields_have_words();
    assert!(all_wordy);
}
/// An empty action means not all fields have words.
#[test]
fn test_all_fields_have_words_false_when_action_empty() {
    let no_action = ReActStep::new("think", "", "obs");
    let all_wordy = no_action.all_fields_have_words();
    assert!(!all_wordy);
}
/// `system_prompt_byte_len` equals the byte length of the prompt that was set.
#[test]
fn test_system_prompt_byte_len_returns_length() {
    let prompt = "Hello!";
    let config = AgentConfig::new(3, "m").with_system_prompt(prompt);
    assert_eq!(config.system_prompt_byte_len(), prompt.len());
}
/// A config built without an explicit prompt is expected to report the byte
/// length of "You are a helpful AI agent.".
#[test]
fn test_system_prompt_byte_len_default_is_nonzero() {
    let expected_prompt = "You are a helpful AI agent.";
    let config = AgentConfig::new(3, "m");
    assert_eq!(config.system_prompt_byte_len(), expected_prompt.len());
}
/// A configured temperature of 0.7 is considered valid.
#[test]
fn test_has_valid_temperature_true_for_in_range() {
    let config = AgentConfig::new(3, "m").with_temperature(0.7);
    let valid = config.has_valid_temperature();
    assert!(valid);
}
/// A config with no temperature configured does not have a valid temperature.
#[test]
fn test_has_valid_temperature_false_when_unset() {
    let config = AgentConfig::new(3, "m");
    let valid = config.has_valid_temperature();
    assert!(!valid);
}
/// Temperatures of exactly 0.0 and exactly 2.0 are both considered valid.
#[test]
fn test_has_valid_temperature_true_at_boundaries() {
    let at_lower = AgentConfig::new(3, "m").with_temperature(0.0);
    assert!(at_lower.has_valid_temperature());
    let at_upper = AgentConfig::new(3, "m").with_temperature(2.0);
    assert!(at_upper.has_valid_temperature());
}
}