use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::{watch, Mutex, RwLock};
use tracing::{debug, error, info, info_span, warn, Instrument};
use crate::agent::context_monitor::ContextMonitor;
use crate::bus::{InboundMessage, MessageBus, OutboundMessage};
use crate::cache::ResponseCache;
use crate::config::Config;
use crate::error::{Result, ZeptoError};
use crate::health::UsageMetrics;
use crate::providers::{ChatOptions, LLMProvider, LLMToolCall};
use crate::safety::SafetyLayer;
use crate::session::{Message, Role, SessionManager, ToolCall};
use crate::tools::approval::ApprovalGate;
use crate::tools::{Tool, ToolCategory, ToolContext, ToolRegistry};
use crate::utils::metrics::MetricsCollector;
use super::budget::TokenBudget;
use super::context::ContextBuilder;
// Instruction appended as the final user turn before compaction, asking the
// model to persist durable facts via the `longterm_memory` tool (see
// `memory_flush`). Runtime string — sent verbatim to the provider.
const MEMORY_FLUSH_PROMPT: &str =
"Review the conversation above. Save any important facts, decisions, \
user preferences, or learnings to long-term memory using the longterm_memory tool. \
Also review existing memories for duplicates — merge or delete stale entries. \
Be selective: only save what would be useful in future conversations.";
// Upper bound, in seconds, on the pre-compaction memory-flush LLM call.
const MEMORY_FLUSH_TIMEOUT_SECS: u64 = 10;
/// Decide whether a batch of tool calls must run one at a time.
///
/// Returns `true` when any requested tool mutates state (filesystem writes
/// or shell commands). A tool name that is not in the registry also forces
/// sequential execution, since its effects are unknown.
async fn needs_sequential_execution(
    tools: &Arc<RwLock<ToolRegistry>>,
    tool_calls: &[LLMToolCall],
) -> bool {
    let registry = tools.read().await;
    for call in tool_calls {
        let mutating = match registry.get(&call.name) {
            Some(tool) => matches!(
                tool.category(),
                ToolCategory::FilesystemWrite | ToolCategory::Shell
            ),
            // Unknown tool: assume the worst and serialize.
            None => true,
        };
        if mutating {
            return true;
        }
    }
    false
}
/// Progress notification emitted around a tool execution, delivered on the
/// optional `tool_feedback_tx` channel (e.g. for UI spinners).
#[derive(Debug, Clone)]
pub struct ToolFeedback {
/// Name of the tool this feedback concerns.
pub tool_name: String,
/// Which lifecycle stage the tool call is in.
pub phase: ToolFeedbackPhase,
}
/// Lifecycle stage reported in a [`ToolFeedback`] message.
#[derive(Debug, Clone)]
pub enum ToolFeedbackPhase {
/// Execution is about to begin.
Starting,
/// Execution finished successfully.
Done {
/// Wall-clock execution time in milliseconds.
elapsed_ms: u64,
},
/// Execution failed.
Failed {
/// Wall-clock execution time in milliseconds.
elapsed_ms: u64,
/// Human-readable error description.
error: String,
},
}
/// Core orchestrator: consumes inbound messages, drives LLM provider calls
/// and tool execution, and publishes responses on the message bus.
pub struct AgentLoop {
/// Runtime configuration (read-only after construction).
config: Config,
/// Conversation persistence (load/save sessions).
session_manager: Arc<SessionManager>,
/// Inbound/outbound message transport.
bus: Arc<MessageBus>,
/// Default LLM provider; `None` until `set_provider` is called.
provider: Arc<RwLock<Option<Arc<dyn LLMProvider>>>>,
/// Named providers selectable via the `provider_override` metadata key.
provider_registry: Arc<RwLock<HashMap<String, Arc<dyn LLMProvider>>>>,
/// All registered tools.
tools: Arc<RwLock<ToolRegistry>>,
/// Running flag exposed through `is_running`.
running: AtomicBool,
/// Builds the message list sent to the provider.
context_builder: ContextBuilder,
/// Optional health/usage counters (set via `set_usage_metrics`).
usage_metrics: Arc<RwLock<Option<Arc<UsageMetrics>>>>,
/// Latency/token/tool-call metrics sink.
metrics_collector: Arc<MetricsCollector>,
/// Watch channel used to signal shutdown to spawned tasks.
shutdown_tx: watch::Sender<bool>,
/// Per-session-key mutexes serializing message processing.
session_locks: Arc<Mutex<HashMap<String, Arc<Mutex<()>>>>>,
/// Messages queued while a session is busy; drained afterwards.
pending_messages: Arc<Mutex<HashMap<String, Vec<InboundMessage>>>>,
// NOTE(review): `streaming` is not read within this chunk — confirm usage.
streaming: AtomicBool,
/// When set, tools report what they would do instead of executing.
dry_run: AtomicBool,
/// Cumulative token budget enforcement across requests.
token_budget: Arc<TokenBudget>,
/// Per-tool approval requirements.
approval_gate: Arc<ApprovalGate>,
/// Security mode controlling which tool categories may run.
agent_mode: crate::security::AgentMode,
/// Optional tool-output screening layer (config-gated).
safety_layer: Option<Arc<SafetyLayer>>,
/// Optional context-size watcher that triggers compaction (config-gated).
context_monitor: Option<ContextMonitor>,
/// Optional channel for per-tool progress feedback (e.g. to a UI).
tool_feedback_tx: Arc<RwLock<Option<tokio::sync::mpsc::UnboundedSender<ToolFeedback>>>>,
/// Optional response cache keyed by model + prompts (config-gated).
cache: Option<Arc<std::sync::Mutex<ResponseCache>>>,
/// Optional device-pairing manager (config-gated).
pairing: Option<Arc<std::sync::Mutex<crate::security::PairingManager>>>,
}
impl AgentLoop {
/// Construct the response cache when enabled in config, `None` otherwise.
fn build_cache(config: &Config) -> Option<Arc<std::sync::Mutex<ResponseCache>>> {
    config.cache.enabled.then(|| {
        Arc::new(std::sync::Mutex::new(ResponseCache::new(
            config.cache.ttl_secs,
            config.cache.max_entries,
        )))
    })
}
/// Construct the pairing manager when enabled in config, `None` otherwise.
fn build_pairing(
    config: &Config,
) -> Option<Arc<std::sync::Mutex<crate::security::PairingManager>>> {
    config.pairing.enabled.then(|| {
        Arc::new(std::sync::Mutex::new(
            crate::security::PairingManager::new(
                config.pairing.max_attempts,
                config.pairing.lockout_secs,
            ),
        ))
    })
}
pub fn new(config: Config, session_manager: SessionManager, bus: Arc<MessageBus>) -> Self {
let (shutdown_tx, _) = watch::channel(false);
let token_budget = Arc::new(TokenBudget::new(config.agents.defaults.token_budget));
let approval_gate = Arc::new(ApprovalGate::new(config.approval.clone()));
let agent_mode = config.agent_mode.resolve();
let safety_layer = if config.safety.enabled {
Some(Arc::new(SafetyLayer::new(config.safety.clone())))
} else {
None
};
let context_monitor = if config.compaction.enabled {
Some(ContextMonitor::new(
config.compaction.context_limit,
config.compaction.threshold,
))
} else {
None
};
let cache = Self::build_cache(&config);
let pairing = Self::build_pairing(&config);
Self {
config,
session_manager: Arc::new(session_manager),
bus,
provider: Arc::new(RwLock::new(None)),
provider_registry: Arc::new(RwLock::new(HashMap::new())),
tools: Arc::new(RwLock::new(ToolRegistry::new())),
running: AtomicBool::new(false),
context_builder: ContextBuilder::new(),
usage_metrics: Arc::new(RwLock::new(None)),
metrics_collector: Arc::new(MetricsCollector::new()),
shutdown_tx,
session_locks: Arc::new(Mutex::new(HashMap::new())),
pending_messages: Arc::new(Mutex::new(HashMap::new())),
streaming: AtomicBool::new(false),
dry_run: AtomicBool::new(false),
token_budget,
approval_gate,
agent_mode,
safety_layer,
context_monitor,
tool_feedback_tx: Arc::new(RwLock::new(None)),
cache,
pairing,
}
}
/// Create an agent loop using a caller-supplied [`ContextBuilder`];
/// otherwise identical in behavior to [`Self::new`].
pub fn with_context_builder(
    config: Config,
    session_manager: SessionManager,
    bus: Arc<MessageBus>,
    context_builder: ContextBuilder,
) -> Self {
    // Receiver is dropped; tasks subscribe later via the sender.
    let (shutdown_tx, _shutdown_rx) = watch::channel(false);
    // Optional subsystems are gated on their respective config flags.
    let safety_layer = if config.safety.enabled {
        Some(Arc::new(SafetyLayer::new(config.safety.clone())))
    } else {
        None
    };
    let context_monitor = if config.compaction.enabled {
        Some(ContextMonitor::new(
            config.compaction.context_limit,
            config.compaction.threshold,
        ))
    } else {
        None
    };
    Self {
        session_manager: Arc::new(session_manager),
        bus,
        provider: Arc::new(RwLock::new(None)),
        provider_registry: Arc::new(RwLock::new(HashMap::new())),
        tools: Arc::new(RwLock::new(ToolRegistry::new())),
        running: AtomicBool::new(false),
        context_builder,
        usage_metrics: Arc::new(RwLock::new(None)),
        metrics_collector: Arc::new(MetricsCollector::new()),
        shutdown_tx,
        session_locks: Arc::new(Mutex::new(HashMap::new())),
        pending_messages: Arc::new(Mutex::new(HashMap::new())),
        streaming: AtomicBool::new(false),
        dry_run: AtomicBool::new(false),
        token_budget: Arc::new(TokenBudget::new(config.agents.defaults.token_budget)),
        approval_gate: Arc::new(ApprovalGate::new(config.approval.clone())),
        agent_mode: config.agent_mode.resolve(),
        safety_layer,
        context_monitor,
        tool_feedback_tx: Arc::new(RwLock::new(None)),
        cache: Self::build_cache(&config),
        pairing: Self::build_pairing(&config),
        // `config` is consumed last so the field initializers above may
        // still borrow it.
        config,
    }
}
/// Whether the loop is currently marked as running.
pub fn is_running(&self) -> bool {
self.running.load(Ordering::SeqCst)
}
/// Install (or replace) the default LLM provider.
pub async fn set_provider(&self, provider: Box<dyn LLMProvider>) {
    *self.provider.write().await = Some(Arc::from(provider));
}
/// Register a named provider selectable via the `provider_override`
/// metadata key. Replaces any provider already registered under `name`.
pub async fn set_provider_in_registry(&self, name: &str, provider: Box<dyn LLMProvider>) {
    self.provider_registry
        .write()
        .await
        .insert(name.to_string(), Arc::from(provider));
}
/// Look up a registered provider by name.
pub async fn get_provider_by_name(&self, name: &str) -> Option<Arc<dyn LLMProvider>> {
    self.provider_registry.read().await.get(name).cloned()
}
/// Names of all providers currently in the registry (unordered).
pub async fn registered_provider_names(&self) -> Vec<String> {
    self.provider_registry
        .read()
        .await
        .keys()
        .cloned()
        .collect()
}
/// Pick the model for a message: a non-empty `model_override` metadata
/// entry wins, otherwise the configured default model is used.
pub fn resolve_model_for_message(&self, msg: &InboundMessage) -> String {
    match msg.metadata.get("model_override") {
        Some(overridden) if !overridden.is_empty() => overridden.clone(),
        _ => self.config.agents.defaults.model.clone(),
    }
}
/// Resolve the provider for a message.
///
/// A non-empty `provider_override` metadata entry selects a registry
/// provider; an unknown override logs a warning and falls back to the
/// default provider (which may itself be `None`).
///
/// Defect fixed: the warning previously logged the provider name twice —
/// once as a structured `provider` field and again interpolated into the
/// message text. The structured field alone is kept.
pub async fn resolve_provider_for_message(
    &self,
    msg: &InboundMessage,
) -> Option<Arc<dyn LLMProvider>> {
    if let Some(provider_name) = msg
        .metadata
        .get("provider_override")
        .filter(|p| !p.is_empty())
    {
        if let Some(provider) = self.get_provider_by_name(provider_name).await {
            return Some(provider);
        }
        warn!(
            provider = %provider_name,
            "Provider override not found in registry, falling back to default"
        );
    }
    self.provider.read().await.clone()
}
/// Attach (or replace) the usage-metrics sink.
pub async fn set_usage_metrics(&self, metrics: Arc<UsageMetrics>) {
    *self.usage_metrics.write().await = Some(metrics);
}
/// Shared handle to the metrics collector.
pub fn metrics_collector(&self) -> Arc<MetricsCollector> {
    self.metrics_collector.clone()
}
/// Add a tool to the registry, making it available to the agent.
pub async fn register_tool(&self, tool: Box<dyn Tool>) {
    self.tools.write().await.register(tool);
}
/// Number of registered tools.
pub async fn tool_count(&self) -> usize {
    self.tools.read().await.len()
}
/// Whether a tool with the given name is registered.
pub async fn has_tool(&self, name: &str) -> bool {
    self.tools.read().await.has(name)
}
/// Process one inbound message end-to-end and return the assistant's final
/// text reply.
///
/// Flow: per-session lock → provider resolution → optional context
/// compaction → cache probe → LLM call → bounded tool-execution loop →
/// persist session.
///
/// # Errors
/// Fails when no provider is configured, the token budget is already
/// exceeded, a provider call fails, or the session store fails.
pub async fn process_message(&self, msg: &InboundMessage) -> Result<String> {
// Serialize all processing for this session key.
let session_lock = self.session_lock_for(&msg.session_key).await;
let _session_guard = session_lock.lock().await;
let provider = self
.resolve_provider_for_message(msg)
.await
.ok_or_else(|| ZeptoError::Provider("No provider configured".into()))?;
// Snapshot the optional usage-metrics handle once up front.
let usage_metrics = {
let metrics = self.usage_metrics.read().await;
metrics.clone()
};
let metrics_collector = Arc::clone(&self.metrics_collector);
let mut session = self.session_manager.get_or_create(&msg.session_key).await?;
// Pre-flight compaction: persist durable memories first, then attempt
// tiered context recovery.
if let Some(ref monitor) = self.context_monitor {
if monitor.needs_compaction(&session.messages) {
self.memory_flush(&session.messages).await;
let context_limit = self.config.compaction.context_limit;
// NOTE(review): the literals 8 and 5120 look like keep-recent /
// size-cap knobs for try_recover_context — confirm and consider
// naming them.
let (recovered, tier) = crate::agent::compaction::try_recover_context(
session.messages,
context_limit,
8, 5120, );
if tier > 0 {
debug!(
tier = tier,
"Context recovered via tier {} compaction", tier
);
}
session.messages = recovered;
}
}
let messages = self
.context_builder
.build_messages(&session.messages, &msg.content);
let tool_definitions = {
let tools = self.tools.read().await;
tools.definitions_with_options(self.config.agents.defaults.compact_tools)
};
let options = ChatOptions::new()
.with_max_tokens(self.config.agents.defaults.max_tokens)
.with_temperature(self.config.agents.defaults.temperature);
let model_string = self.resolve_model_for_message(msg);
let model = Some(model_string.as_str());
// Refuse to start new work once the budget is spent.
if self.token_budget.is_exceeded() {
return Err(ZeptoError::Provider(format!(
"Token budget exceeded: {}",
self.token_budget.summary()
)));
}
// NOTE(review): the cache key uses the configured default model rather
// than the resolved `model_string`, so a `model_override` request can
// hit a cache entry produced by a different model — confirm intent.
let cache_key = self.cache.as_ref().map(|_| {
let system_prompt = messages
.first()
.filter(|m| m.role == Role::System)
.map(|m| m.content.as_str())
.unwrap_or("");
ResponseCache::cache_key(
self.config.agents.defaults.model.as_str(),
system_prompt,
&msg.content,
)
});
let cached_hit = if let (Some(ref cache_mutex), Some(ref key)) = (&self.cache, &cache_key) {
cache_mutex.lock().ok().and_then(|mut c| c.get(key))
} else {
None
};
// Cache hit short-circuits the provider entirely (no token accounting).
if let Some(cached_response) = cached_hit {
debug!("Cache hit for initial prompt");
session.add_message(Message::user(&msg.content));
session.add_message(Message::assistant(&cached_response));
self.session_manager.save(&session).await?;
return Ok(cached_response);
}
let mut response = provider
.chat(messages, tool_definitions, model, options.clone())
.await?;
// Record token usage in both the optional health metrics and the
// collector/budget.
if let (Some(metrics), Some(usage)) = (usage_metrics.as_ref(), response.usage.as_ref()) {
metrics.record_tokens(usage.prompt_tokens as u64, usage.completion_tokens as u64);
}
if let Some(usage) = response.usage.as_ref() {
metrics_collector
.record_tokens(usage.prompt_tokens as u64, usage.completion_tokens as u64);
self.token_budget
.record(usage.prompt_tokens as u64, usage.completion_tokens as u64);
}
// Only tool-free responses are cacheable.
if !response.has_tool_calls() {
if let (Some(ref cache_mutex), Some(key)) = (&self.cache, cache_key) {
let token_count = response
.usage
.as_ref()
.map(|u| u.completion_tokens)
.unwrap_or(0);
if let Ok(mut cache) = cache_mutex.lock() {
cache.put(key, response.content.clone(), token_count);
debug!("Cached initial LLM response");
}
}
}
session.add_message(Message::user(&msg.content));
// Tool loop: execute requested tools and re-prompt until the model
// stops calling tools or limits are hit.
let max_iterations = self.config.agents.defaults.max_tool_iterations;
let mut iteration = 0;
while response.has_tool_calls() && iteration < max_iterations {
iteration += 1;
debug!("Tool iteration {} of {}", iteration, max_iterations);
if let Some(metrics) = usage_metrics.as_ref() {
metrics.record_tool_calls(response.tool_calls.len() as u64);
}
// Echo the assistant turn (with its tool calls) into the session so
// the next provider call sees it.
let mut assistant_msg = Message::assistant(&response.content);
assistant_msg.tool_calls = Some(
response
.tool_calls
.iter()
.map(|tc| ToolCall {
id: tc.id.clone(),
name: tc.name.clone(),
arguments: tc.arguments.clone(),
})
.collect(),
);
session.add_message(assistant_msg);
let workspace = self.config.workspace_path();
let workspace_str = workspace.to_string_lossy();
let tool_ctx = ToolContext::new()
.with_channel(&msg.channel, &msg.chat_id)
.with_workspace(&workspace_str);
let approval_gate = Arc::clone(&self.approval_gate);
let safety_layer = self.safety_layer.clone();
let hook_engine = Arc::new(
crate::hooks::HookEngine::new(self.config.hooks.clone())
.with_bus(Arc::clone(&self.bus)),
);
// Budget each tool result so combined output cannot blow the context.
let current_tokens = ContextMonitor::estimate_tokens(&session.messages);
let context_limit = self.config.compaction.context_limit;
let result_budget = crate::utils::sanitize::compute_tool_result_budget(
context_limit,
current_tokens,
response.tool_calls.len(),
);
let tool_feedback_tx = self.tool_feedback_tx.clone();
let is_dry_run = self.dry_run.load(Ordering::SeqCst);
let current_agent_mode = self.agent_mode;
// Mutating tool sets (filesystem writes, shell) run one at a time.
let run_sequential =
needs_sequential_execution(&self.tools, &response.tool_calls).await;
let tool_futures: Vec<_> = response
.tool_calls
.iter()
.map(|tool_call| {
let tools = Arc::clone(&self.tools);
let ctx = tool_ctx.clone();
let name = tool_call.name.clone();
let id = tool_call.id.clone();
let raw_args = tool_call.arguments.clone();
let usage_metrics = usage_metrics.clone();
let metrics_collector = Arc::clone(&metrics_collector);
let gate = Arc::clone(&approval_gate);
let hooks = Arc::clone(&hook_engine);
let safety = safety_layer.clone();
let budget = result_budget;
let tool_feedback_tx = tool_feedback_tx.clone();
let dry_run = is_dry_run;
let agent_mode = current_agent_mode;
let bus_for_tools = Arc::clone(&self.bus);
async move {
// Malformed JSON becomes a sentinel object instead of a hard
// failure so the LLM can see and correct its own mistake.
let args: serde_json::Value = match serde_json::from_str(&raw_args) {
Ok(v) => v,
Err(e) => {
tracing::warn!(tool = %name, error = %e, "Invalid JSON in tool arguments");
serde_json::json!({"_parse_error": format!("Invalid arguments JSON: {}", e)})
}
};
let channel_name = ctx.channel.as_deref().unwrap_or("cli");
let chat_id = ctx.chat_id.as_deref().unwrap_or(channel_name);
// Gate 1: user-configured hooks may veto the call.
if let crate::hooks::HookResult::Block(msg) =
hooks.before_tool(&name, &args, channel_name, chat_id)
{
return (id, format!("Tool '{}' blocked by hook: {}", name, msg));
}
// Gate 2: the agent mode's category policy.
{
let mode_policy = crate::security::ModePolicy::new(agent_mode);
let tools_guard = tools.read().await;
if let Some(tool) = tools_guard.get(&name) {
let tool_category = tool.category();
match mode_policy.check(tool_category) {
crate::security::CategoryPermission::Blocked => {
info!(tool = %name, mode = %agent_mode, category = ?tool_category, "Tool blocked by agent mode");
return (id, format!(
"Tool '{}' is blocked in {} mode (category: {})",
name, agent_mode, tool_category
));
}
crate::security::CategoryPermission::RequiresApproval => {
if !gate.requires_approval(&name) {
info!(tool = %name, mode = %agent_mode, category = ?tool_category, "Tool requires approval per agent mode");
return (id, format!(
"Tool '{}' requires approval in {} mode (category: {}). Not executed.",
name, agent_mode, tool_category
));
}
}
crate::security::CategoryPermission::Allowed => {}
}
}
}
// Gate 3: the explicit per-tool approval list.
if gate.requires_approval(&name) {
let prompt = gate.format_approval_request(&name, &args);
info!(tool = %name, "Tool requires approval, blocking execution");
return (id, format!("Tool '{}' requires user approval and was not executed. {}", name, prompt));
}
// Dry-run reports what would happen without executing.
if dry_run {
return (id, Self::dry_run_result(&name, &args, &raw_args, budget));
}
if let Some(tx) = tool_feedback_tx.read().await.as_ref() {
let _ = tx.send(ToolFeedback {
tool_name: name.clone(),
phase: ToolFeedbackPhase::Starting,
});
}
let tool_start = std::time::Instant::now();
let (result, success) = {
let tools_guard = tools.read().await;
match tools_guard.execute_with_context(&name, args, &ctx).await {
Ok(output) => {
let elapsed = tool_start.elapsed();
let latency_ms = elapsed.as_millis() as u64;
debug!(tool = %name, latency_ms = latency_ms, "Tool executed successfully");
hooks.after_tool(&name, &output.for_llm, elapsed, channel_name, chat_id);
if let Some(tx) = tool_feedback_tx.read().await.as_ref() {
let _ = tx.send(ToolFeedback {
tool_name: name.clone(),
phase: ToolFeedbackPhase::Done { elapsed_ms: latency_ms },
});
}
// Tools may produce a user-facing message alongside the
// LLM-facing result; publish it directly on the bus.
if let Some(ref user_msg) = output.for_user {
let outbound = crate::bus::OutboundMessage::new(
ctx.channel.as_deref().unwrap_or(""),
ctx.chat_id.as_deref().unwrap_or(""),
user_msg,
);
let _ = bus_for_tools.publish_outbound(outbound).await;
}
(output.for_llm, !output.is_error)
}
Err(e) => {
let elapsed = tool_start.elapsed();
let latency_ms = elapsed.as_millis() as u64;
error!(tool = %name, latency_ms = latency_ms, error = %e, "Tool execution failed");
hooks.on_error(&name, &e.to_string(), channel_name, chat_id);
if let Some(metrics) = usage_metrics.as_ref() {
metrics.record_error();
}
if let Some(tx) = tool_feedback_tx.read().await.as_ref() {
let _ = tx.send(ToolFeedback {
tool_name: name.clone(),
phase: ToolFeedbackPhase::Failed {
elapsed_ms: latency_ms,
error: e.to_string(),
},
});
}
(format!("Error: {}", e), false)
}
}
};
metrics_collector.record_tool_call(&name, tool_start.elapsed(), success);
// Trim to its size budget, then let the safety layer inspect it.
let sanitized = crate::utils::sanitize::sanitize_tool_result(
&result,
budget,
);
let sanitized = if let Some(ref safety) = safety {
let safety_result = safety.check_tool_output(&sanitized);
if safety_result.blocked {
format!(
"[Safety blocked]: {}",
safety_result.block_reason.unwrap_or_default()
)
} else {
safety_result.content
}
} else {
sanitized
};
(id, sanitized)
}
})
.collect();
let results = if run_sequential {
let mut out = Vec::with_capacity(tool_futures.len());
for fut in tool_futures {
out.push(fut.await);
}
out
} else {
futures::future::join_all(tool_futures).await
};
for (id, result) in results {
session.add_message(Message::tool_result(&id, &result));
}
// Definitions are re-read in case tools changed mid-loop.
let tool_definitions = {
let tools = self.tools.read().await;
tools.definitions_with_options(self.config.agents.defaults.compact_tools)
};
if self.token_budget.is_exceeded() {
info!(budget = %self.token_budget.summary(), "Token budget exceeded during tool loop");
break;
}
// Rebuild the prompt; the filter strips the user turn that
// build_messages presumably appends for the empty string — confirm.
let messages: Vec<_> = self
.context_builder
.build_messages(&session.messages, "")
.into_iter()
.filter(|m| !(m.role == Role::User && m.content.is_empty()))
.collect();
response = provider
.chat(messages, tool_definitions, model, options.clone())
.await?;
if let (Some(metrics), Some(usage)) = (usage_metrics.as_ref(), response.usage.as_ref())
{
metrics.record_tokens(usage.prompt_tokens as u64, usage.completion_tokens as u64);
}
if let Some(usage) = response.usage.as_ref() {
metrics_collector
.record_tokens(usage.prompt_tokens as u64, usage.completion_tokens as u64);
self.token_budget
.record(usage.prompt_tokens as u64, usage.completion_tokens as u64);
}
}
// Hitting the cap with tool calls still pending returns whatever text
// the last response carried.
if iteration >= max_iterations && response.has_tool_calls() {
info!(
iterations = iteration,
"Tool loop reached maximum iterations, returning partial response"
);
}
session.add_message(Message::assistant(&response.content));
self.session_manager.save(&session).await?;
Ok(response.content)
}
/// Streaming variant of [`AgentLoop::process_message`]: runs the same tool
/// loop with non-streaming provider calls, then streams only the final
/// tool-free turn as `StreamEvent`s on the returned channel.
///
/// NOTE(review): unlike `process_message`, this path skips the response
/// cache, the hook engine, and the optional usage metrics — confirm that
/// divergence is intentional.
pub async fn process_message_streaming(
&self,
msg: &InboundMessage,
) -> Result<tokio::sync::mpsc::Receiver<crate::providers::StreamEvent>> {
use crate::providers::StreamEvent;
// Serialize all processing for this session key.
let session_lock = self.session_lock_for(&msg.session_key).await;
let _session_guard = session_lock.lock().await;
let provider = self
.resolve_provider_for_message(msg)
.await
.ok_or_else(|| ZeptoError::Provider("No provider configured".into()))?;
let metrics_collector = Arc::clone(&self.metrics_collector);
let mut session = self.session_manager.get_or_create(&msg.session_key).await?;
// Pre-flight compaction, mirroring process_message.
if let Some(ref monitor) = self.context_monitor {
if monitor.needs_compaction(&session.messages) {
self.memory_flush(&session.messages).await;
let context_limit = self.config.compaction.context_limit;
let (recovered, tier) = crate::agent::compaction::try_recover_context(
session.messages,
context_limit,
8, 5120, );
if tier > 0 {
debug!(
tier = tier,
"Context recovered via tier {} compaction (streaming)", tier
);
}
session.messages = recovered;
}
}
let messages = self
.context_builder
.build_messages(&session.messages, &msg.content);
let tool_definitions = {
let tools = self.tools.read().await;
tools.definitions_with_options(self.config.agents.defaults.compact_tools)
};
let options = ChatOptions::new()
.with_max_tokens(self.config.agents.defaults.max_tokens)
.with_temperature(self.config.agents.defaults.temperature);
let model_string = self.resolve_model_for_message(msg);
let model = Some(model_string.as_str());
if self.token_budget.is_exceeded() {
return Err(ZeptoError::Provider(format!(
"Token budget exceeded: {}",
self.token_budget.summary()
)));
}
// First call is non-streaming so tool calls can be detected.
let mut response = provider
.chat(messages, tool_definitions, model, options.clone())
.await?;
if let Some(usage) = response.usage.as_ref() {
self.token_budget
.record(usage.prompt_tokens as u64, usage.completion_tokens as u64);
}
session.add_message(Message::user(&msg.content));
let max_iterations = self.config.agents.defaults.max_tool_iterations;
let mut iteration = 0;
while response.has_tool_calls() && iteration < max_iterations {
iteration += 1;
// Echo the assistant turn (with its tool calls) into the session.
let mut assistant_msg = Message::assistant(&response.content);
assistant_msg.tool_calls = Some(
response
.tool_calls
.iter()
.map(|tc| ToolCall {
id: tc.id.clone(),
name: tc.name.clone(),
arguments: tc.arguments.clone(),
})
.collect(),
);
session.add_message(assistant_msg);
let workspace = self.config.workspace_path();
let workspace_str = workspace.to_string_lossy();
let tool_ctx = ToolContext::new()
.with_channel(&msg.channel, &msg.chat_id)
.with_workspace(&workspace_str);
let approval_gate = Arc::clone(&self.approval_gate);
let safety_layer_stream = self.safety_layer.clone();
// Budget each tool result against the remaining context.
let current_tokens_stream = ContextMonitor::estimate_tokens(&session.messages);
let context_limit_stream = self.config.compaction.context_limit;
let result_budget_stream = crate::utils::sanitize::compute_tool_result_budget(
context_limit_stream,
current_tokens_stream,
response.tool_calls.len(),
);
let tool_feedback_tx = self.tool_feedback_tx.clone();
let is_dry_run_stream = self.dry_run.load(Ordering::SeqCst);
let current_agent_mode_stream = self.agent_mode;
// Mutating tool sets run one at a time.
let run_sequential =
needs_sequential_execution(&self.tools, &response.tool_calls).await;
let tool_futures: Vec<_> = response
.tool_calls
.iter()
.map(|tool_call| {
let tools = Arc::clone(&self.tools);
let ctx = tool_ctx.clone();
let name = tool_call.name.clone();
let id = tool_call.id.clone();
let raw_args = tool_call.arguments.clone();
let metrics_collector = Arc::clone(&metrics_collector);
let gate = Arc::clone(&approval_gate);
let safety = safety_layer_stream.clone();
let budget = result_budget_stream;
let tool_feedback_tx = tool_feedback_tx.clone();
let dry_run = is_dry_run_stream;
let agent_mode = current_agent_mode_stream;
let bus_for_tools = Arc::clone(&self.bus);
async move {
// NOTE(review): malformed argument JSON silently becomes `{}`
// here, whereas process_message surfaces a `_parse_error`
// marker to the model — consider unifying.
let args: serde_json::Value = serde_json::from_str(&raw_args)
.unwrap_or_else(|_| serde_json::json!({}));
// Agent-mode category policy gate.
{
let mode_policy = crate::security::ModePolicy::new(agent_mode);
let tools_guard = tools.read().await;
if let Some(tool) = tools_guard.get(&name) {
let tool_category = tool.category();
match mode_policy.check(tool_category) {
crate::security::CategoryPermission::Blocked => {
info!(tool = %name, mode = %agent_mode, category = ?tool_category, "Tool blocked by agent mode");
return (id, format!(
"Tool '{}' is blocked in {} mode (category: {})",
name, agent_mode, tool_category
));
}
crate::security::CategoryPermission::RequiresApproval => {
if !gate.requires_approval(&name) {
info!(tool = %name, mode = %agent_mode, category = ?tool_category, "Tool requires approval per agent mode");
return (id, format!(
"Tool '{}' requires approval in {} mode (category: {}). Not executed.",
name, agent_mode, tool_category
));
}
}
crate::security::CategoryPermission::Allowed => {}
}
}
}
// Explicit per-tool approval gate.
if gate.requires_approval(&name) {
let prompt = gate.format_approval_request(&name, &args);
info!(tool = %name, "Tool requires approval, blocking execution");
return (
id,
format!(
"Tool '{}' requires user approval and was not executed. {}",
name, prompt
),
);
}
// Dry-run reports what would happen without executing.
if dry_run {
return (id, Self::dry_run_result(&name, &args, &raw_args, budget));
}
if let Some(tx) = tool_feedback_tx.read().await.as_ref() {
let _ = tx.send(ToolFeedback {
tool_name: name.clone(),
phase: ToolFeedbackPhase::Starting,
});
}
let tool_start = std::time::Instant::now();
let (result, success) = {
let tools_guard = tools.read().await;
match tools_guard.execute_with_context(&name, args, &ctx).await {
Ok(output) => {
// Forward any user-facing side message straight to the bus.
if let Some(ref user_msg) = output.for_user {
let outbound = crate::bus::OutboundMessage::new(
ctx.channel.as_deref().unwrap_or(""),
ctx.chat_id.as_deref().unwrap_or(""),
user_msg,
);
let _ = bus_for_tools.publish_outbound(outbound).await;
}
(output.for_llm, !output.is_error)
}
Err(e) => (format!("Error: {}", e), false),
}
};
metrics_collector.record_tool_call(&name, tool_start.elapsed(), success);
if let Some(tx) = tool_feedback_tx.read().await.as_ref() {
let latency_ms = tool_start.elapsed().as_millis() as u64;
if success {
let _ = tx.send(ToolFeedback {
tool_name: name.clone(),
phase: ToolFeedbackPhase::Done {
elapsed_ms: latency_ms,
},
});
} else {
let _ = tx.send(ToolFeedback {
tool_name: name.clone(),
phase: ToolFeedbackPhase::Failed {
elapsed_ms: latency_ms,
error: result.clone(),
},
});
}
}
// Trim to budget, then safety-screen the result.
let sanitized =
crate::utils::sanitize::sanitize_tool_result(&result, budget);
let sanitized = if let Some(ref safety) = safety {
let safety_result = safety.check_tool_output(&sanitized);
if safety_result.blocked {
format!(
"[Safety blocked]: {}",
safety_result.block_reason.unwrap_or_default()
)
} else {
safety_result.content
}
} else {
sanitized
};
(id, sanitized)
}
})
.collect();
let results = if run_sequential {
let mut out = Vec::with_capacity(tool_futures.len());
for fut in tool_futures {
out.push(fut.await);
}
out
} else {
futures::future::join_all(tool_futures).await
};
for (id, result) in results {
session.add_message(Message::tool_result(&id, &result));
}
let tool_definitions = {
let tools = self.tools.read().await;
tools.definitions_with_options(self.config.agents.defaults.compact_tools)
};
if self.token_budget.is_exceeded() {
info!(budget = %self.token_budget.summary(), "Token budget exceeded during streaming tool loop");
break;
}
let messages: Vec<_> = self
.context_builder
.build_messages(&session.messages, "")
.into_iter()
.filter(|m| !(m.role == Role::User && m.content.is_empty()))
.collect();
response = provider
.chat(messages, tool_definitions, model, options.clone())
.await?;
if let Some(usage) = response.usage.as_ref() {
metrics_collector
.record_tokens(usage.prompt_tokens as u64, usage.completion_tokens as u64);
self.token_budget
.record(usage.prompt_tokens as u64, usage.completion_tokens as u64);
}
}
if !response.has_tool_calls() {
// Re-issue the final turn as a true streaming call and forward its
// events to the caller, persisting the session on Done.
let messages: Vec<_> = self
.context_builder
.build_messages(&session.messages, "")
.into_iter()
.filter(|m| !(m.role == Role::User && m.content.is_empty()))
.collect();
let tool_definitions = {
let tools = self.tools.read().await;
tools.definitions_with_options(self.config.agents.defaults.compact_tools)
};
let stream_rx = provider
.chat_stream(messages, tool_definitions, model, options)
.await?;
let (out_tx, out_rx) = tokio::sync::mpsc::channel::<StreamEvent>(32);
let session_manager = Arc::clone(&self.session_manager);
let session_clone = session.clone();
let metrics_collector = Arc::clone(&metrics_collector);
tokio::spawn(async move {
let mut session = session_clone;
let mut stream_rx = stream_rx;
while let Some(event) = stream_rx.recv().await {
match &event {
StreamEvent::Done { content, usage } => {
if let Some(usage) = usage.as_ref() {
metrics_collector.record_tokens(
usage.prompt_tokens as u64,
usage.completion_tokens as u64,
);
}
session.add_message(Message::assistant(content));
let _ = session_manager.save(&session).await;
let _ = out_tx.send(event).await;
return;
}
StreamEvent::ToolCalls(_) => {
// NOTE(review): tool calls arriving mid-stream are
// forwarded but not executed here — confirm callers
// handle them.
let _ = out_tx.send(event).await;
return;
}
_ => {
// Caller dropped the receiver; stop forwarding.
if out_tx.send(event).await.is_err() {
return;
}
}
}
}
});
Ok(out_rx)
} else {
// Iteration cap hit with tool calls still pending: return the last
// non-streamed content as a single Done event.
session.add_message(Message::assistant(&response.content));
self.session_manager.save(&session).await?;
let (tx, rx) = tokio::sync::mpsc::channel(1);
let _ = tx
.send(StreamEvent::Done {
content: response.content,
usage: response.usage,
})
.await;
Ok(rx)
}
}
/// Best-effort pre-compaction "memory flush": asks the LLM (with only the
/// `longterm_memory` tool exposed) to persist durable facts from the
/// conversation before older messages are compacted away.
///
/// Never fails the caller: any error, timeout, or missing prerequisite is
/// logged and the function simply returns.
async fn memory_flush(&self, messages: &[crate::session::Message]) {
use tokio::time::{timeout, Duration};
// Uses the default provider only; per-message overrides are not consulted.
let provider = {
let guard = self.provider.read().await;
match guard.as_ref() {
Some(p) => Arc::clone(p),
None => {
tracing::warn!("memory_flush: no provider configured, skipping");
return;
}
}
};
// Skip entirely when the longterm_memory tool is not registered.
let tool_defs = {
let tools = self.tools.read().await;
let defs = tools.definitions_for_tools(&["longterm_memory"]);
if defs.is_empty() {
tracing::debug!("memory_flush: longterm_memory tool not registered, skipping");
return;
}
defs
};
// Rebuild the conversation under a dedicated system prompt, with the
// flush instruction appended as the final user turn.
let mut flush_messages: Vec<crate::session::Message> =
vec![Message::system("You are a memory management assistant.")];
flush_messages.extend(messages.iter().cloned());
flush_messages.push(Message::user(MEMORY_FLUSH_PROMPT));
// Deterministic, small response: temperature 0, modest token cap.
let options = ChatOptions::new()
.with_max_tokens(1024)
.with_temperature(0.0);
let model = Some(self.config.agents.defaults.model.as_str());
info!("memory_flush: running pre-compaction memory flush");
// Hard wall-clock cap so a slow provider cannot stall compaction.
let flush_result = timeout(
Duration::from_secs(MEMORY_FLUSH_TIMEOUT_SECS),
provider.chat(flush_messages, tool_defs.clone(), model, options.clone()),
)
.await;
let response = match flush_result {
Ok(Ok(resp)) => resp,
Ok(Err(e)) => {
tracing::warn!(error = %e, "memory_flush: LLM call failed");
return;
}
Err(_) => {
tracing::warn!(
"memory_flush: timed out after {}s",
MEMORY_FLUSH_TIMEOUT_SECS
);
return;
}
};
// Execute each requested tool call sequentially; individual failures
// are logged and skipped.
if response.has_tool_calls() {
let workspace = self.config.workspace_path();
let workspace_str = workspace.to_string_lossy();
let tool_ctx = ToolContext::new().with_workspace(&workspace_str);
for tc in &response.tool_calls {
let args: serde_json::Value = match serde_json::from_str(&tc.arguments) {
Ok(v) => v,
Err(e) => {
tracing::warn!(
tool = %tc.name,
error = %e,
"memory_flush: invalid tool arguments"
);
continue;
}
};
let result = {
let tools = self.tools.read().await;
tools.execute_with_context(&tc.name, args, &tool_ctx).await
};
match result {
Ok(_) => {
debug!(tool = %tc.name, "memory_flush: tool executed successfully");
}
Err(e) => {
tracing::warn!(
tool = %tc.name,
error = %e,
"memory_flush: tool execution failed"
);
}
}
}
}
info!("memory_flush: completed");
}
async fn session_lock_for(&self, session_key: &str) -> Arc<Mutex<()>> {
let mut locks = self.session_locks.lock().await;
locks
.entry(session_key.to_string())
.or_insert_with(|| Arc::new(Mutex::new(())))
.clone()
}
/// Read the current (input, output) token counters, or `None` when no
/// metrics sink is attached.
fn token_snapshot(usage_metrics: Option<&Arc<UsageMetrics>>) -> Option<(u64, u64)> {
    let metrics = usage_metrics?;
    let input = metrics
        .input_tokens
        .load(std::sync::atomic::Ordering::Relaxed);
    let output = metrics
        .output_tokens
        .load(std::sync::atomic::Ordering::Relaxed);
    Some((input, output))
}
/// Tokens consumed since `before` (a prior `token_snapshot`), saturating
/// at zero. Returns `(0, 0)` when either the snapshot or the metrics sink
/// is absent.
fn token_delta(
    usage_metrics: Option<&Arc<UsageMetrics>>,
    before: Option<(u64, u64)>,
) -> (u64, u64) {
    match (usage_metrics, before) {
        (Some(metrics), Some((input_before, output_before))) => {
            let input_after = metrics
                .input_tokens
                .load(std::sync::atomic::Ordering::Relaxed);
            let output_after = metrics
                .output_tokens
                .load(std::sync::atomic::Ordering::Relaxed);
            (
                input_after.saturating_sub(input_before),
                output_after.saturating_sub(output_before),
            )
        }
        _ => (0, 0),
    }
}
/// Re-dispatch messages that were queued for this session while it was busy.
///
/// Behavior depends on the configured `message_queue_mode`:
/// - `Collect`: all queued messages are merged into one numbered synthetic
///   message and published back onto the inbound bus.
/// - `Followup`: each queued message is re-published individually, in order.
async fn drain_pending_messages(&self, msg: &InboundMessage) {
    // Take ownership of this session's queue, holding the map lock only
    // for the removal.
    let pending = {
        let mut map = self.pending_messages.lock().await;
        map.remove(&msg.session_key).unwrap_or_default()
    };
    if pending.is_empty() {
        return;
    }
    match self.config.agents.defaults.message_queue_mode {
        crate::config::MessageQueueMode::Collect => {
            let combined: Vec<String> = pending
                .iter()
                .enumerate()
                .map(|(index, item)| format!("{}. {}", index + 1, item.content))
                .collect();
            let combined_content = format!(
                "[Queued messages while I was busy]\n\n{}",
                combined.join("\n")
            );
            let mut synthetic = InboundMessage::new(
                &msg.channel,
                &msg.sender_id,
                &msg.chat_id,
                &combined_content,
            );
            // Fix: carry the original metadata over to the synthetic message.
            // Previously it was dropped, so pairing auth (`auth_token`) and
            // other per-message metadata (e.g. `model_override`) were lost —
            // with pairing enabled, the collected message would be rejected
            // as unpaired when it re-entered the loop.
            synthetic.metadata = msg.metadata.clone();
            if let Err(e) = self.bus.publish_inbound(synthetic).await {
                error!("Failed to re-queue collected messages: {}", e);
            }
        }
        crate::config::MessageQueueMode::Followup => {
            for pending_msg in pending {
                if let Err(e) = self.bus.publish_inbound(pending_msg).await {
                    error!("Failed to re-queue followup message: {}", e);
                }
            }
        }
    }
}
async fn process_inbound_message(
&self,
msg: &InboundMessage,
usage_metrics: Option<Arc<UsageMetrics>>,
) {
info!("Processing message");
let start = std::time::Instant::now();
let tokens_before = Self::token_snapshot(usage_metrics.as_ref());
if let Some(metrics) = usage_metrics.as_ref() {
metrics.record_request();
}
let timeout_duration =
std::time::Duration::from_secs(self.config.agents.defaults.agent_timeout_secs);
let process_result =
tokio::time::timeout(timeout_duration, self.process_message(msg)).await;
let agent_completed = match process_result {
Ok(Ok(response)) => {
let latency_ms = start.elapsed().as_millis() as u64;
let (input_tokens, output_tokens) =
Self::token_delta(usage_metrics.as_ref(), tokens_before);
info!(
latency_ms = latency_ms,
response_len = response.len(),
input_tokens = input_tokens,
output_tokens = output_tokens,
"Request completed"
);
let outbound = OutboundMessage::new(&msg.channel, &msg.chat_id, &response);
if let Err(e) = self.bus.publish_outbound(outbound).await {
error!("Failed to publish outbound message: {}", e);
if let Some(metrics) = usage_metrics.as_ref() {
metrics.record_error();
}
}
true
}
Ok(Err(e)) => {
let latency_ms = start.elapsed().as_millis() as u64;
error!(latency_ms = latency_ms, error = %e, "Request failed");
if let Some(metrics) = usage_metrics.as_ref() {
metrics.record_error();
}
let error_msg =
OutboundMessage::new(&msg.channel, &msg.chat_id, &format!("Error: {}", e));
self.bus.publish_outbound(error_msg).await.ok();
false
}
Err(_elapsed) => {
let timeout_secs = self.config.agents.defaults.agent_timeout_secs;
error!(timeout_secs = timeout_secs, "Agent run timed out");
if let Some(metrics) = usage_metrics.as_ref() {
metrics.record_error();
}
let timeout_msg = OutboundMessage::new(
&msg.channel,
&msg.chat_id,
&format!(
"Agent run timed out after {}s. Try a simpler request.",
timeout_secs
),
);
self.bus.publish_outbound(timeout_msg).await.ok();
false
}
};
let slo = crate::utils::slo::SessionSLO::evaluate(&self.metrics_collector, agent_completed);
slo.emit();
debug!(slo_summary = %slo.summary(), "Session SLO summary");
self.drain_pending_messages(msg).await;
}
/// If the session identified by `msg.session_key` is currently busy,
/// queue the message for later delivery and return `true`; otherwise
/// return `false` so the caller processes it immediately.
///
/// NOTE(review): the `try_lock` guard is dropped right away, so this is
/// only a busyness probe — presumably the session lock is actually
/// acquired inside `process_message`; confirm, otherwise a concurrent
/// message could slip past the queue between probe and processing.
pub async fn try_queue_or_process(&self, msg: &InboundMessage) -> bool {
    let session_lock = self.session_lock_for(&msg.session_key).await;
    if session_lock.try_lock().is_ok() {
        // Session idle: nothing queued, caller proceeds.
        return false;
    }
    let mut pending = self.pending_messages.lock().await;
    pending
        .entry(msg.session_key.clone())
        .or_default()
        .push(msg.clone());
    debug!(session = %msg.session_key, "Message queued (session busy)");
    true
}
/// Run the main agent loop: consume inbound messages, enforce device
/// pairing (when enabled), and process each message inside a tracing
/// span. Blocks until `stop()` is called or the inbound channel closes.
///
/// # Errors
/// Returns `ZeptoError::Config` if the loop is already running.
pub async fn start(&self) -> Result<()> {
    // `swap` both sets the running flag and detects a concurrent start.
    if self.running.swap(true, Ordering::SeqCst) {
        return Err(ZeptoError::Config("Agent loop already running".into()));
    }
    info!("Starting agent loop");
    let mut shutdown_rx = self.shutdown_tx.subscribe();
    // Mark the current watch value as seen so a stale `true` from a
    // previous run does not trip `changed()` immediately (allows restart).
    let _ = *shutdown_rx.borrow_and_update();
    loop {
        tokio::select! {
            _ = shutdown_rx.changed() => {
                if *shutdown_rx.borrow() {
                    info!("Received shutdown signal");
                    break;
                }
            }
            msg = self.bus.consume_inbound() => {
                if let Some(msg) = msg {
                    // Pairing gate: when enabled, every message must carry a
                    // valid `auth_token` for its sender or it is rejected.
                    if let Some(ref pairing) = self.pairing {
                        let identifier = msg.sender_id.clone();
                        let token = msg.metadata.get("auth_token").cloned();
                        let valid = match token {
                            Some(raw_token) => {
                                // A poisoned pairing lock counts as invalid.
                                match pairing.lock() {
                                    Ok(mut mgr) => mgr.validate_token(&raw_token, &identifier).is_some(),
                                    Err(_) => false,
                                }
                            }
                            None => false,
                        };
                        if !valid {
                            warn!(
                                sender = %msg.sender_id,
                                channel = %msg.channel,
                                "Rejected unpaired device (pairing enabled)"
                            );
                            let rejection = OutboundMessage::new(
                                &msg.channel,
                                &msg.chat_id,
                                "Access denied: device not paired. Use `zeptoclaw pair new` to generate a pairing code.",
                            );
                            if let Err(e) = self.bus.publish_outbound(rejection).await {
                                error!("Failed to publish pairing rejection: {}", e);
                            }
                            continue;
                        }
                    }
                    // Tenant falls back to the chat id when no non-empty
                    // `tenant_id` metadata is present.
                    let tenant_id = msg
                        .metadata
                        .get("tenant_id")
                        .filter(|v| !v.is_empty())
                        .map(String::as_str)
                        .unwrap_or(&msg.chat_id);
                    let request_id = uuid::Uuid::new_v4();
                    let request_span = info_span!(
                        "request",
                        request_id = %request_id,
                        tenant_id = %tenant_id,
                        chat_id = %msg.chat_id,
                        session_id = %msg.session_key,
                        channel = %msg.channel,
                        sender = %msg.sender_id,
                    );
                    let msg_ref = &msg;
                    // All per-request work runs instrumented with the span.
                    async {
                        // Busy session: queue instead of processing now.
                        if self.try_queue_or_process(msg_ref).await {
                            return;
                        }
                        let usage_metrics = {
                            let metrics = self.usage_metrics.read().await;
                            metrics.clone()
                        };
                        self.process_inbound_message(msg_ref, usage_metrics).await;
                    }
                    .instrument(request_span)
                    .await;
                } else {
                    info!("Inbound channel closed");
                    break;
                }
            }
        }
        // Secondary exit path: `stop()` clears the flag directly.
        if !self.running.load(Ordering::SeqCst) {
            break;
        }
    }
    self.running.store(false, Ordering::SeqCst);
    info!("Agent loop stopped");
    Ok(())
}
/// Signal the agent loop to shut down: clears the running flag and
/// notifies the shutdown watch channel so `start` wakes even when no
/// message is in flight.
pub fn stop(&self) {
    info!("Stopping agent loop");
    self.running.store(false, Ordering::SeqCst);
    let _ = self.shutdown_tx.send(true);
}
/// Session manager backing this loop.
pub fn session_manager(&self) -> &Arc<SessionManager> {
    &self.session_manager
}
/// Message bus used for inbound/outbound traffic.
pub fn bus(&self) -> &Arc<MessageBus> {
    &self.bus
}
/// Agent configuration.
pub fn config(&self) -> &Config {
    &self.config
}
/// Currently configured LLM provider, if any.
pub async fn provider(&self) -> Option<Arc<dyn LLMProvider>> {
    let guard = self.provider.read().await;
    guard.clone()
}
/// Enable or disable streaming mode.
pub fn set_streaming(&self, enabled: bool) {
    self.streaming.store(enabled, Ordering::SeqCst);
}
/// Whether streaming mode is enabled.
pub fn is_streaming(&self) -> bool {
    self.streaming.load(Ordering::SeqCst)
}
/// Enable or disable dry-run mode (tool calls are described via
/// `dry_run_result` instead of executed).
pub fn set_dry_run(&self, enabled: bool) {
    self.dry_run.store(enabled, Ordering::SeqCst);
}
/// Whether dry-run mode is enabled.
pub fn is_dry_run(&self) -> bool {
    self.dry_run.load(Ordering::SeqCst)
}
/// Build the placeholder result reported instead of executing a tool
/// when dry-run mode is active.
///
/// `budget` caps the size of the rendered arguments; `raw_args` is used
/// verbatim if the parsed arguments cannot be re-serialized.
fn dry_run_result(
    name: &str,
    args: &serde_json::Value,
    raw_args: &str,
    budget: usize,
) -> String {
    let rendered = match serde_json::to_string_pretty(args) {
        Ok(pretty) => pretty,
        Err(_) => raw_args.to_string(),
    };
    let sanitized = crate::utils::sanitize::sanitize_tool_result(&rendered, budget);
    format!(
        "[DRY RUN] Would execute tool '{}' with arguments: {}",
        name, sanitized
    )
}
/// Install the channel used to report per-tool execution feedback
/// (`ToolFeedback` values), replacing any previously set sender.
pub async fn set_tool_feedback(&self, tx: tokio::sync::mpsc::UnboundedSender<ToolFeedback>) {
    *self.tool_feedback_tx.write().await = Some(tx);
}
/// Token budget tracker for this loop.
pub fn token_budget(&self) -> &TokenBudget {
    &self.token_budget
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{LLMResponse, ToolDefinition};
    use async_trait::async_trait;

    // Minimal LLMProvider stub that answers every chat with the text "ok".
    #[derive(Debug)]
    struct TestProvider {
        name: &'static str,
        model: &'static str,
    }
    #[async_trait]
    impl LLMProvider for TestProvider {
        fn name(&self) -> &str {
            self.name
        }
        fn default_model(&self) -> &str {
            self.model
        }
        async fn chat(
            &self,
            _messages: Vec<Message>,
            _tools: Vec<ToolDefinition>,
            _model: Option<&str>,
            _options: ChatOptions,
        ) -> Result<LLMResponse> {
            Ok(LLMResponse::text("ok"))
        }
    }

    // A freshly constructed loop is not running.
    #[tokio::test]
    async fn test_agent_loop_creation() {
        let config = Config::default();
        let session_manager = SessionManager::new_memory();
        let bus = Arc::new(MessageBus::new());
        let agent = AgentLoop::new(config, session_manager, bus);
        assert!(!agent.is_running());
    }

    // Registry lookup misses when nothing has been registered.
    #[tokio::test]
    async fn test_provider_registry_lookup() {
        let config = Config::default();
        let session_manager = SessionManager::new_memory();
        let bus = Arc::new(MessageBus::new());
        let agent = AgentLoop::new(config, session_manager, bus);
        assert!(agent.get_provider_by_name("openai").await.is_none());
    }

    // Registering a provider makes it retrievable by name.
    #[tokio::test]
    async fn test_provider_registry_set_and_get() {
        let config = Config::default();
        let session_manager = SessionManager::new_memory();
        let bus = Arc::new(MessageBus::new());
        let agent = AgentLoop::new(config, session_manager, bus);
        agent
            .set_provider_in_registry(
                "openai",
                Box::new(TestProvider {
                    name: "openai",
                    model: "gpt-5.1",
                }),
            )
            .await;
        let p = agent.get_provider_by_name("openai").await;
        assert!(p.is_some());
        assert_eq!(p.unwrap().name(), "openai");
    }

    // `model_override` metadata takes precedence over the config default.
    #[tokio::test]
    async fn test_process_message_uses_model_override_metadata() {
        let config = Config::default();
        let session_manager = SessionManager::new_memory();
        let bus = Arc::new(MessageBus::new());
        let agent = AgentLoop::new(config, session_manager, bus);
        let msg = InboundMessage::new("telegram", "user1", "chat1", "hello")
            .with_metadata("model_override", "gpt-5.1");
        let model = agent.resolve_model_for_message(&msg);
        assert_eq!(model, "gpt-5.1");
    }

    // Without an override, the configured default model is used.
    #[tokio::test]
    async fn test_resolve_model_falls_back_to_config_default() {
        let mut config = Config::default();
        config.agents.defaults.model = "claude-sonnet-4-5-20250929".to_string();
        let session_manager = SessionManager::new_memory();
        let bus = Arc::new(MessageBus::new());
        let agent = AgentLoop::new(config, session_manager, bus);
        let msg = InboundMessage::new("telegram", "user1", "chat1", "hello");
        let model = agent.resolve_model_for_message(&msg);
        assert_eq!(model, "claude-sonnet-4-5-20250929");
    }

    // Constructing with a custom context builder still yields a stopped loop.
    #[tokio::test]
    async fn test_agent_loop_with_context_builder() {
        let config = Config::default();
        let session_manager = SessionManager::new_memory();
        let bus = Arc::new(MessageBus::new());
        let context_builder = ContextBuilder::new().with_system_prompt("Custom prompt");
        let agent = AgentLoop::with_context_builder(config, session_manager, bus, context_builder);
        assert!(!agent.is_running());
    }

    // Tool registration updates count and lookup.
    #[tokio::test]
    async fn test_agent_loop_tool_registration() {
        use crate::tools::EchoTool;
        let config = Config::default();
        let session_manager = SessionManager::new_memory();
        let bus = Arc::new(MessageBus::new());
        let agent = AgentLoop::new(config, session_manager, bus);
        assert_eq!(agent.tool_count().await, 0);
        assert!(!agent.has_tool("echo").await);
        agent.register_tool(Box::new(EchoTool)).await;
        assert_eq!(agent.tool_count().await, 1);
        assert!(agent.has_tool("echo").await);
    }

    // Smoke test: accessors are callable.
    #[tokio::test]
    async fn test_agent_loop_accessors() {
        let config = Config::default();
        let session_manager = SessionManager::new_memory();
        let bus = Arc::new(MessageBus::new());
        let agent = AgentLoop::new(config, session_manager, bus);
        let _ = agent.config();
        let _ = agent.bus();
        let _ = agent.session_manager();
    }

    // Processing without a configured provider yields a Provider error.
    #[tokio::test]
    async fn test_process_message_no_provider() {
        let config = Config::default();
        let session_manager = SessionManager::new_memory();
        let bus = Arc::new(MessageBus::new());
        let agent = AgentLoop::new(config, session_manager, bus);
        let msg = InboundMessage::new("test", "user123", "chat456", "Hello");
        let result = agent.process_message(&msg).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, ZeptoError::Provider(_)));
        assert!(err.to_string().contains("No provider configured"));
    }

    // The same session key maps to the same cached lock; different keys differ.
    #[tokio::test]
    async fn test_session_lock_for_reuses_same_session_lock() {
        let config = Config::default();
        let session_manager = SessionManager::new_memory();
        let bus = Arc::new(MessageBus::new());
        let agent = AgentLoop::new(config, session_manager, bus);
        let first = agent.session_lock_for("telegram:chat1").await;
        let second = agent.session_lock_for("telegram:chat1").await;
        let other = agent.session_lock_for("telegram:chat2").await;
        assert!(Arc::ptr_eq(&first, &second));
        assert!(!Arc::ptr_eq(&first, &other));
    }

    // Idle session: message is not queued.
    #[tokio::test]
    async fn test_try_queue_or_process_returns_false_when_session_idle() {
        let config = Config::default();
        let session_manager = SessionManager::new_memory();
        let bus = Arc::new(MessageBus::new());
        let agent = AgentLoop::new(config, session_manager, bus);
        let msg = InboundMessage::new("telegram", "user1", "chat1", "hello");
        let queued = agent.try_queue_or_process(&msg).await;
        assert!(!queued);
        let pending = agent.pending_messages.lock().await;
        assert!(pending.get(&msg.session_key).is_none());
    }

    // Busy session (lock held externally): message lands in the pending queue.
    #[tokio::test]
    async fn test_try_queue_or_process_queues_when_session_busy() {
        let config = Config::default();
        let session_manager = SessionManager::new_memory();
        let bus = Arc::new(MessageBus::new());
        let agent = AgentLoop::new(config, session_manager, bus);
        let msg = InboundMessage::new("telegram", "user1", "chat1", "followup");
        let session_lock = agent.session_lock_for(&msg.session_key).await;
        let _guard = session_lock.lock().await;
        let queued = agent.try_queue_or_process(&msg).await;
        assert!(queued);
        let pending = agent.pending_messages.lock().await;
        let queued_msgs = pending
            .get(&msg.session_key)
            .expect("pending messages should contain queued message");
        assert_eq!(queued_msgs.len(), 1);
        assert_eq!(queued_msgs[0].content, msg.content);
    }

    // start/stop round trip; a dummy message nudges the select! loop so it
    // observes the cleared running flag.
    #[tokio::test]
    async fn test_agent_loop_start_stop() {
        let config = Config::default();
        let session_manager = SessionManager::new_memory();
        let bus = Arc::new(MessageBus::new());
        let agent = Arc::new(AgentLoop::new(config, session_manager, bus.clone()));
        assert!(!agent.is_running());
        let agent_clone = Arc::clone(&agent);
        let handle = tokio::spawn(async move { agent_clone.start().await });
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        assert!(agent.is_running());
        agent.stop();
        let dummy_msg = InboundMessage::new("test", "user", "chat", "dummy");
        bus.publish_inbound(dummy_msg).await.ok();
        let result = tokio::time::timeout(tokio::time::Duration::from_millis(200), handle).await;
        assert!(result.is_ok());
        assert!(!agent.is_running());
    }

    // Second concurrent start() is rejected with "already running".
    #[tokio::test]
    async fn test_agent_loop_double_start() {
        let config = Config::default();
        let session_manager = SessionManager::new_memory();
        let bus = Arc::new(MessageBus::new());
        let agent = Arc::new(AgentLoop::new(config, session_manager, bus.clone()));
        let agent_clone = Arc::clone(&agent);
        let handle = tokio::spawn(async move { agent_clone.start().await });
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        let result = agent.start().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("already running"));
        agent.stop();
        let dummy_msg = InboundMessage::new("test", "user", "chat", "dummy");
        bus.publish_inbound(dummy_msg).await.ok();
        let _ = tokio::time::timeout(tokio::time::Duration::from_millis(200), handle).await;
    }

    // stop() alone (via the watch channel) must end the loop — no message needed.
    #[tokio::test]
    async fn test_agent_loop_graceful_shutdown() {
        let config = Config::default();
        let session_manager = SessionManager::new_memory();
        let bus = Arc::new(MessageBus::new());
        let agent = Arc::new(AgentLoop::new(config, session_manager, bus));
        let agent_clone = Arc::clone(&agent);
        let handle = tokio::spawn(async move { agent_clone.start().await });
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        assert!(agent.is_running());
        agent.stop();
        let result = tokio::time::timeout(tokio::time::Duration::from_millis(100), handle).await;
        assert!(
            result.is_ok(),
            "Agent loop should stop gracefully without needing a message"
        );
        assert!(!agent.is_running());
    }

    // The loop can be started again after a stop (stale shutdown signal is
    // cleared by borrow_and_update in start()).
    #[tokio::test]
    async fn test_agent_loop_can_restart_after_stop() {
        let config = Config::default();
        let session_manager = SessionManager::new_memory();
        let bus = Arc::new(MessageBus::new());
        let agent = Arc::new(AgentLoop::new(config, session_manager, bus));
        let agent_clone = Arc::clone(&agent);
        let first = tokio::spawn(async move { agent_clone.start().await });
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        agent.stop();
        let first_result =
            tokio::time::timeout(tokio::time::Duration::from_millis(200), first).await;
        assert!(first_result.is_ok());
        assert!(!agent.is_running());
        let agent_clone = Arc::clone(&agent);
        let second = tokio::spawn(async move { agent_clone.start().await });
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        assert!(agent.is_running());
        agent.stop();
        let second_result =
            tokio::time::timeout(tokio::time::Duration::from_millis(200), second).await;
        assert!(second_result.is_ok());
        assert!(!agent.is_running());
    }

    // Default system prompt mentions the product name.
    #[test]
    fn test_context_builder_standalone() {
        let builder = ContextBuilder::new();
        let system = builder.build_system_message();
        assert!(system.content.contains("ZeptoClaw"));
    }

    // System message + user message for an empty history.
    #[test]
    fn test_build_messages_standalone() {
        let builder = ContextBuilder::new();
        let messages = builder.build_messages(&[], "Hello");
        assert_eq!(messages.len(), 2);
        assert!(messages[1].content == "Hello");
    }

    // Streaming is off by default.
    #[tokio::test]
    async fn test_agent_loop_streaming_flag_default() {
        let config = Config::default();
        let session_manager = SessionManager::new_memory();
        let bus = Arc::new(MessageBus::new());
        let agent = AgentLoop::new(config, session_manager, bus);
        assert!(!agent.is_streaming());
    }

    // Streaming flag can be toggled on.
    #[tokio::test]
    async fn test_agent_loop_set_streaming() {
        let config = Config::default();
        let session_manager = SessionManager::new_memory();
        let bus = Arc::new(MessageBus::new());
        let agent = AgentLoop::new(config, session_manager, bus);
        agent.set_streaming(true);
        assert!(agent.is_streaming());
    }

    // Debug output includes tool name and phase.
    #[test]
    fn test_tool_feedback_debug() {
        let fb = ToolFeedback {
            tool_name: "shell".to_string(),
            phase: ToolFeedbackPhase::Starting,
        };
        let debug_str = format!("{:?}", fb);
        assert!(debug_str.contains("shell"));
        assert!(debug_str.contains("Starting"));
    }

    // All three phase variants render their payloads.
    #[test]
    fn test_tool_feedback_phases() {
        let starting = ToolFeedbackPhase::Starting;
        let done = ToolFeedbackPhase::Done { elapsed_ms: 1200 };
        let failed = ToolFeedbackPhase::Failed {
            elapsed_ms: 500,
            error: "timeout".to_string(),
        };
        assert!(format!("{:?}", starting).contains("Starting"));
        assert!(format!("{:?}", done).contains("1200"));
        assert!(format!("{:?}", failed).contains("timeout"));
    }

    // No feedback sender is installed until set_tool_feedback is called.
    #[tokio::test]
    async fn test_tool_feedback_channel_none_by_default() {
        let config = Config::default();
        let session_manager = SessionManager::new_memory();
        let bus = Arc::new(MessageBus::new());
        let agent = AgentLoop::new(config, session_manager, bus);
        let guard = agent.tool_feedback_tx.read().await;
        assert!(guard.is_none());
    }

    // Sanity-check the memory-flush prompt content.
    #[test]
    fn test_memory_flush_prompt_is_valid() {
        assert!(MEMORY_FLUSH_PROMPT.contains("long-term memory"));
        assert!(MEMORY_FLUSH_PROMPT.contains("longterm_memory"));
        assert!(MEMORY_FLUSH_PROMPT.contains("duplicates"));
    }

    // Timeout constant stays within a sane range.
    #[test]
    fn test_memory_flush_timeout_is_reasonable() {
        assert!(MEMORY_FLUSH_TIMEOUT_SECS > 0);
        assert!(MEMORY_FLUSH_TIMEOUT_SECS <= 30);
    }

    // memory_flush with no provider must be a no-op (not panic).
    #[tokio::test]
    async fn test_memory_flush_no_provider() {
        let config = Config::default();
        let session_manager = SessionManager::new_memory();
        let bus = Arc::new(MessageBus::new());
        let agent = AgentLoop::new(config, session_manager, bus);
        let messages = vec![Message::user("hello"), Message::assistant("hi")];
        agent.memory_flush(&messages).await;
    }

    // Dry-run is off by default.
    #[test]
    fn test_dry_run_default_false() {
        let config = Config::default();
        let session_manager = SessionManager::new_memory();
        let bus = Arc::new(MessageBus::new());
        let agent = AgentLoop::new(config, session_manager, bus);
        assert!(!agent.is_dry_run());
    }

    // Dry-run flag toggles both ways.
    #[test]
    fn test_set_dry_run() {
        let config = Config::default();
        let session_manager = SessionManager::new_memory();
        let bus = Arc::new(MessageBus::new());
        let agent = AgentLoop::new(config, session_manager, bus);
        assert!(!agent.is_dry_run());
        agent.set_dry_run(true);
        assert!(agent.is_dry_run());
        agent.set_dry_run(false);
        assert!(!agent.is_dry_run());
    }

    // No-op tool stub with a configurable category, for the
    // needs_sequential_execution tests below.
    #[derive(Debug)]
    struct StubTool {
        name: &'static str,
        category: ToolCategory,
    }
    #[async_trait]
    impl Tool for StubTool {
        fn name(&self) -> &str {
            self.name
        }
        fn description(&self) -> &str {
            ""
        }
        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({})
        }
        fn category(&self) -> ToolCategory {
            self.category
        }
        async fn execute(
            &self,
            _args: serde_json::Value,
            _ctx: &ToolContext,
        ) -> std::result::Result<crate::tools::ToolOutput, crate::error::ZeptoError> {
            Ok(crate::tools::ToolOutput::llm_only("ok"))
        }
    }

    // Build an LLMToolCall with empty JSON arguments.
    fn make_tool_call(name: &str) -> LLMToolCall {
        LLMToolCall {
            id: format!("call_{name}"),
            name: name.to_string(),
            arguments: "{}".to_string(),
        }
    }

    // Wrap a set of stub tools in a shared registry.
    fn registry_with(tools: Vec<StubTool>) -> Arc<RwLock<ToolRegistry>> {
        let mut reg = ToolRegistry::new();
        for t in tools {
            reg.register(Box::new(t));
        }
        Arc::new(RwLock::new(reg))
    }

    // A filesystem-write tool in the batch forces sequential execution.
    #[tokio::test]
    async fn test_sequential_triggered_by_filesystem_write() {
        let reg = registry_with(vec![
            StubTool {
                name: "write_file",
                category: ToolCategory::FilesystemWrite,
            },
            StubTool {
                name: "read_file",
                category: ToolCategory::FilesystemRead,
            },
        ]);
        let calls = vec![make_tool_call("write_file"), make_tool_call("read_file")];
        assert!(needs_sequential_execution(&reg, &calls).await);
    }

    // A shell tool in the batch forces sequential execution.
    #[tokio::test]
    async fn test_sequential_triggered_by_shell() {
        let reg = registry_with(vec![
            StubTool {
                name: "shell",
                category: ToolCategory::Shell,
            },
            StubTool {
                name: "read_file",
                category: ToolCategory::FilesystemRead,
            },
        ]);
        let calls = vec![make_tool_call("shell"), make_tool_call("read_file")];
        assert!(needs_sequential_execution(&reg, &calls).await);
    }

    // Read-only batches may run in parallel.
    #[tokio::test]
    async fn test_parallel_when_only_reads() {
        let reg = registry_with(vec![
            StubTool {
                name: "read_file",
                category: ToolCategory::FilesystemRead,
            },
            StubTool {
                name: "web_fetch",
                category: ToolCategory::NetworkRead,
            },
        ]);
        let calls = vec![make_tool_call("read_file"), make_tool_call("web_fetch")];
        assert!(!needs_sequential_execution(&reg, &calls).await);
    }

    // Unknown tools fail safe: sequential (unwrap_or(true) in the helper).
    #[tokio::test]
    async fn test_sequential_for_unknown_tool_fail_safe() {
        let reg = registry_with(vec![StubTool {
            name: "read_file",
            category: ToolCategory::FilesystemRead,
        }]);
        let calls = vec![make_tool_call("read_file"), make_tool_call("mystery_tool")];
        assert!(needs_sequential_execution(&reg, &calls).await);
    }

    // A single read-like tool does not require sequencing.
    #[tokio::test]
    async fn test_parallel_for_single_read_tool() {
        let reg = registry_with(vec![StubTool {
            name: "memory_search",
            category: ToolCategory::Memory,
        }]);
        let calls = vec![make_tool_call("memory_search")];
        assert!(!needs_sequential_execution(&reg, &calls).await);
    }
}