use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::hash::{Hash, Hasher};
use std::time::Duration;
use clawgarden_proto::{Envelope, EventType, MessagePayload};
use crate::bus_client::BusClient;
use crate::loop_checkpoint::{
clear_checkpoint, load_checkpoint, save_checkpoint, LoopCheckpoint, CHECKPOINT_VERSION,
};
use crate::loop_guard::LoopGuard;
use crate::loop_policy::LoopPolicy;
use crate::response_parser::ToolCallAction;
/// Phase or terminal outcome of an [`agent_loop`] run.
///
/// The current state is persisted in loop checkpoints, and the state the loop
/// ends in is reported as [`AgentLoopResult::termination`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum LoopState {
    /// NOTE(review): not assigned by `agent_loop` in this file.
    Init,
    /// Initial state of a fresh (non-resumed) loop run.
    Reasoning,
    /// NOTE(review): not assigned by `agent_loop` in this file.
    ToolDispatch,
    /// NOTE(review): not assigned by `agent_loop` in this file.
    Observe,
    /// A tool result has just been appended to the conversation.
    Integrate,
    /// The model produced a message, which was sent as the reply.
    FinalRespond,
    /// The model chose silence, or the LLM returned no response; nothing was sent.
    FinalSilent,
    /// The LLM call exceeded its timeout.
    AbortTimeout,
    /// Stall detection fired (repeated action, repeated action/outcome cycle,
    /// or insufficient progress over the window).
    AbortStall,
    /// Too many consecutive tool errors, or the LLM call itself returned an error.
    AbortErrorBudget,
    /// The hard step limit was exceeded.
    AbortStepBudget,
}
/// One entry of the agent loop's conversation history.
///
/// Converted to OpenAI-style chat JSON by [`ChatMessage::to_api_message`] and
/// persisted in loop checkpoints.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ChatMessage {
    /// A user turn (`role: "user"`).
    User {
        content: String,
    },
    /// A single tool call issued by the assistant (`role: "assistant"` with one
    /// entry in `tool_calls`).
    AssistantToolCall {
        /// Tool-call id, echoed back by the matching `ToolResult::tool_call_id`.
        id: String,
        /// Function name, e.g. `exec`, `read_file` or `create_skill`.
        name: String,
        /// Function arguments as a JSON string, e.g. `{"command": "ls"}`.
        arguments: String,
    },
    /// The output of a tool call (`role: "tool"`).
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
impl ChatMessage {
pub fn user(content: &str) -> Self {
Self::User {
content: content.to_string(),
}
}
pub fn assistant_tool_call(id: &str, name: &str, arguments: &str) -> Self {
Self::AssistantToolCall {
id: id.to_string(),
name: name.to_string(),
arguments: arguments.to_string(),
}
}
pub fn tool_result(tool_call_id: &str, content: &str) -> Self {
Self::ToolResult {
tool_call_id: tool_call_id.to_string(),
content: content.to_string(),
}
}
pub fn to_api_message(&self) -> serde_json::Value {
match self {
ChatMessage::User { content } => serde_json::json!({
"role": "user",
"content": content,
}),
ChatMessage::AssistantToolCall {
id,
name,
arguments,
} => serde_json::json!({
"role": "assistant",
"content": null,
"tool_calls": [{
"id": id,
"type": "function",
"function": {
"name": name,
"arguments": arguments,
}
}]
}),
ChatMessage::ToolResult {
tool_call_id,
content,
} => serde_json::json!({
"role": "tool",
"tool_call_id": tool_call_id,
"content": content,
}),
}
}
}
pub fn messages_to_api(messages: &[ChatMessage]) -> Vec<serde_json::Value> {
messages.iter().map(|m| m.to_api_message()).collect()
}
/// Outcome of an [`agent_loop`] run.
#[derive(Debug, Clone)]
pub struct AgentLoopResult {
    /// Action left for the caller. `agent_loop` in this file always sets
    /// `ToolCallAction::Silence`, because the loop sends any reply itself.
    pub action: ToolCallAction,
    /// Number of loop steps executed.
    pub steps: usize,
    /// Labels of executed tool calls, e.g. `exec:<command>` or `read_file:<path>`.
    pub tool_calls: Vec<String>,
    /// State the loop terminated in.
    pub termination: LoopState,
}
/// Outcome summary of an `exec` tool call.
///
/// Only the lengths and hashes of stdout/stderr are kept, not their text. The
/// serialized form is what the LLM receives as the tool result, and it is also
/// hashed by `ToolObservation::fingerprint`. Serde serializes fields in
/// declaration order, so reordering fields changes both.
///
/// NOTE(review): the LLM therefore never sees the command's output, only its
/// size and hash. Confirm this is intended.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ExecObservation {
    /// Process exit code; `None` is treated as a failure by `is_error`.
    exit_code: Option<i32>,
    timed_out: bool,
    duration_ms: u64,
    /// Byte length of captured stdout.
    stdout_len: usize,
    /// Byte length of captured stderr.
    stderr_len: usize,
    stdout_hash: u64,
    stderr_hash: u64,
    /// Set when the command could not be started at all.
    spawn_error: Option<String>,
}
/// Outcome summary of a `read_file` tool call.
///
/// Only the byte size and hash of the file content are kept. On failure,
/// `bytes` and `content_hash` are 0 and `error` holds the I/O error text.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ReadObservation {
    ok: bool,
    bytes: usize,
    content_hash: u64,
    error: Option<String>,
}
/// Outcome of a `create_skill` tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SkillObservation {
    ok: bool,
    /// Name of the skill that was requested.
    name: String,
    /// Error text when creation failed.
    error: Option<String>,
}
/// Observation produced by executing one tool action, one variant per tool.
///
/// Serialized (externally tagged by variant name) both for the LLM tool result
/// and for fingerprinting.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum ToolObservation {
    Exec(ExecObservation),
    ReadFile(ReadObservation),
    CreateSkill(SkillObservation),
}
impl ToolObservation {
fn is_error(&self) -> bool {
match self {
ToolObservation::Exec(o) => {
o.timed_out || o.spawn_error.is_some() || o.exit_code.unwrap_or(1) != 0
}
ToolObservation::ReadFile(o) => !o.ok,
ToolObservation::CreateSkill(o) => !o.ok,
}
}
fn progress_delta(&self) -> i32 {
if self.is_error() { -1 } else { 1 }
}
fn fingerprint(&self) -> String {
hash_json(self)
}
fn to_llm_content(&self) -> String {
serde_json::to_string_pretty(self)
.unwrap_or_else(|_| "{\"error\":\"failed_to_serialize_observation\"}".to_string())
}
}
/// Hashes the compact JSON serialization of `value` and returns it as lowercase hex.
///
/// If serialization fails, the empty string is hashed instead, so all
/// unserializable values share one fingerprint.
///
/// NOTE: `DefaultHasher`'s algorithm is not guaranteed stable across Rust
/// releases. Fingerprints written to checkpoints may therefore not match ones
/// computed by a binary built with a different toolchain.
fn hash_json<T: Serialize>(value: &T) -> String {
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    serde_json::to_string(value)
        .unwrap_or_default()
        .hash(&mut hasher);
    let digest = hasher.finish();
    format!("{:x}", digest)
}
/// Returns the `DefaultHasher` hash of `s`, used to summarize tool output
/// without keeping the text itself.
fn hash_str(s: &str) -> u64 {
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    Hash::hash(s, &mut hasher);
    hasher.finish()
}
/// Sends `content` as the loop's final reply and records it locally.
///
/// In order, this:
/// 1. registers the text with `guard` under the envelope's correlation id,
/// 2. appends `"[<agent_name>]: <content>"` to the conversation history,
/// 3. builds a reply envelope and sends it on `bus`.
///
/// # Errors
///
/// Returns any error from `bus.send`.
async fn send_final_message(
    bus: &mut BusClient,
    env: &Envelope,
    agent_name: &str,
    reply_event_type: EventType,
    reply_target: &str,
    guard: &mut LoopGuard,
    content: &str,
) -> Result<()> {
    guard.record(&env.correlation_id, content);

    let history_line = format!("[{}]: {}", agent_name, content);
    crate::main_helpers::record_history(&env.conversation_id, &history_line);

    let payload = MessagePayload {
        content: content.to_owned(),
        context: Vec::new(),
    };
    let envelope = crate::main_helpers::make_envelope(
        env,
        agent_name,
        reply_event_type,
        reply_target,
        payload,
    );
    bus.send(&envelope).await?;
    Ok(())
}
/// Concatenates the content of every tool result in `messages`, one per line,
/// for inclusion in a final or abort reply.
///
/// Output longer than 2000 bytes is cut at the last UTF-8 character boundary
/// at or before that limit, then suffixed with `"\n... (truncated)"`. When
/// there are no tool results at all, a fixed placeholder sentence is returned.
fn summarize_tool_results(messages: &[ChatMessage]) -> String {
    const MAX_SUMMARY_BYTES: usize = 2000;

    let tool_results: Vec<&str> = messages
        .iter()
        .filter_map(|m| match m {
            ChatMessage::ToolResult { content, .. } => Some(content.as_str()),
            ChatMessage::User { .. } | ChatMessage::AssistantToolCall { .. } => None,
        })
        .collect();
    if tool_results.is_empty() {
        return "Task completed but could not summarize results.".to_string();
    }
    let combined = tool_results.join("\n");
    if combined.len() <= MAX_SUMMARY_BYTES {
        return combined;
    }
    // Slicing a `str` at a byte index inside a multi-byte character panics,
    // so back off to the nearest boundary. Index 0 is always a boundary, so
    // this loop terminates.
    let mut cut = MAX_SUMMARY_BYTES;
    while !combined.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}\n... (truncated)", &combined[..cut])
}
/// Builds a colon-separated fingerprint of an LLM-chosen action.
///
/// The fingerprint is the action kind followed by the `hash_json` of each of
/// its fields, e.g. `exec:<cmd>:<workdir>:<timeout>`. `Silence` carries no
/// data and is fingerprinted as the bare string `silence`.
fn action_fingerprint(action: &ToolCallAction) -> String {
    let (kind, field_hashes): (&str, Vec<String>) = match action {
        ToolCallAction::Exec {
            command,
            workdir,
            timeout_secs,
        } => (
            "exec",
            vec![hash_json(command), hash_json(workdir), hash_json(timeout_secs)],
        ),
        ToolCallAction::ReadFile { path } => ("read_file", vec![hash_json(path)]),
        ToolCallAction::SkillCreate {
            skill_name,
            description,
            body,
        } => (
            "create_skill",
            vec![hash_json(skill_name), hash_json(description), hash_json(body)],
        ),
        ToolCallAction::Message { content } => ("message", vec![hash_json(content)]),
        ToolCallAction::Silence => return "silence".to_string(),
    };
    format!("{}:{}", kind, field_hashes.join(":"))
}
/// Decides whether the loop has stalled.
///
/// The loop counts as stalled if any of the following holds:
/// - the same action repeated at least `same_action_repeat_limit` times in a row;
/// - the same (action, outcome) pair has occurred `cycle_repeat_limit` times;
/// - the progress window is full and its sum is at most `stall_min_progress_sum`.
fn detect_stall(
    policy: &LoopPolicy,
    same_action_repeat_count: usize,
    progress_window: &VecDeque<i32>,
    cycle_count: usize,
) -> bool {
    let repeating = same_action_repeat_count >= policy.same_action_repeat_limit;
    let cycling = cycle_count >= policy.cycle_repeat_limit;
    let window_full = progress_window.len() >= policy.stall_window;
    let not_progressing =
        window_full && progress_window.iter().sum::<i32>() <= policy.stall_min_progress_sum;
    repeating || cycling || not_progressing
}
/// Finalizes a loop run and builds its [`AgentLoopResult`].
///
/// When checkpointing is enabled, the checkpoint for this
/// (agent, conversation, correlation) triple is cleared first, so a finished
/// run is not resumed. The reported action is always `ToolCallAction::Silence`.
///
/// # Errors
///
/// Returns any error from `clear_checkpoint`.
async fn make_result(
    termination: LoopState,
    step: usize,
    tool_call_log: Vec<String>,
    agent_name: &str,
    env: &Envelope,
    checkpoint_enabled: bool,
) -> Result<AgentLoopResult> {
    if checkpoint_enabled {
        clear_checkpoint(agent_name, &env.conversation_id, &env.correlation_id).await?;
    }
    let result = AgentLoopResult {
        termination,
        steps: step,
        tool_calls: tool_call_log,
        action: ToolCallAction::Silence,
    };
    Ok(result)
}
/// Updates the loop-safety counters with one tool observation and reports
/// whether the loop must abort.
///
/// All bookkeeping happens before any decision is made:
/// - the cycle counter for the `action_fp|observation-fingerprint` key is bumped;
/// - the observation's progress delta is pushed into the window, evicting the
///   oldest entry once `stall_window` entries are held;
/// - the consecutive-error streak is extended on error, or reset to 0 on success.
///
/// Returns `Some(AbortErrorBudget)` if the error streak reached
/// `consecutive_error_limit`, otherwise `Some(AbortStall)` if `detect_stall`
/// fires, otherwise `None`.
fn check_abort(
    obs: &ToolObservation,
    action_fp: &str,
    same_action_repeat_count: usize,
    consecutive_error_count: &mut usize,
    progress_window: &mut VecDeque<i32>,
    cycle_counter: &mut HashMap<String, usize>,
    policy: &LoopPolicy,
) -> Option<LoopState> {
    let cycle_count = {
        let slot = cycle_counter
            .entry(format!("{}|{}", action_fp, obs.fingerprint()))
            .or_insert(0);
        *slot += 1;
        *slot
    };

    if progress_window.len() >= policy.stall_window {
        progress_window.pop_front();
    }
    progress_window.push_back(obs.progress_delta());

    *consecutive_error_count = if obs.is_error() {
        *consecutive_error_count + 1
    } else {
        0
    };

    if *consecutive_error_count >= policy.consecutive_error_limit {
        Some(LoopState::AbortErrorBudget)
    } else if detect_stall(policy, same_action_repeat_count, progress_window, cycle_count) {
        Some(LoopState::AbortStall)
    } else {
        None
    }
}
#[allow(clippy::too_many_arguments)]
pub async fn agent_loop(
agent_name: &str,
role: &str,
memory: &str,
skills: &str,
initial_message: &str,
history: &[String],
force: bool,
bus: &mut BusClient,
env: &Envelope,
reply_event_type: EventType,
reply_target: &str,
guard: &mut LoopGuard,
llm_timeout_ms: u64,
policy: &LoopPolicy,
) -> Result<AgentLoopResult> {
policy.validate();
let mut state = LoopState::Reasoning;
let mut step = 0usize;
let mut tool_call_log = Vec::<String>::new();
let mut messages = vec![ChatMessage::user(initial_message)];
let mut last_action_fingerprint: Option<String> = None;
let mut same_action_repeat_count = 0usize;
let mut consecutive_error_count = 0usize;
let mut progress_window: VecDeque<i32> = VecDeque::with_capacity(policy.stall_window);
let mut cycle_counter: HashMap<String, usize> = HashMap::new();
if policy.checkpoint_enabled {
if let Some(cp) = load_checkpoint(
agent_name,
&env.conversation_id,
&env.correlation_id,
&env.trace_id,
)
.await?
{
state = cp.state;
step = cp.step;
messages = cp.messages;
tool_call_log = cp.tool_call_log;
last_action_fingerprint = cp.last_action_fingerprint;
same_action_repeat_count = cp.same_action_repeat_count;
consecutive_error_count = cp.consecutive_error_count;
progress_window = cp.progress_window.into_iter().collect();
log::info!(
"Resumed agent loop from checkpoint: step={} state={:?}",
step,
state
);
}
}
loop {
step += 1;
if step > policy.hard_max_steps {
state = LoopState::AbortStepBudget;
let summary = format!(
"Safety limit reached (steps={}). Terminating.\n\n{}",
policy.hard_max_steps,
summarize_tool_results(&messages)
);
send_final_message(
bus, env, agent_name,
reply_event_type.clone(), reply_target,
guard, &summary,
)
.await?;
return make_result(state, step - 1, tool_call_log, agent_name, env, policy.checkpoint_enabled).await;
}
if policy.checkpoint_enabled {
save_checkpoint(&LoopCheckpoint {
version: CHECKPOINT_VERSION,
agent_name: agent_name.to_string(),
conversation_id: env.conversation_id.clone(),
correlation_id: env.correlation_id.clone(),
trace_id: env.trace_id.clone(),
step,
state: state.clone(),
messages: messages.clone(),
tool_call_log: tool_call_log.clone(),
last_action_fingerprint: last_action_fingerprint.clone(),
same_action_repeat_count,
consecutive_error_count,
progress_window: progress_window.iter().copied().collect(),
updated_at_unix_ms: chrono::Utc::now().timestamp_millis(),
})
.await?;
}
let llm_result = tokio::time::timeout(
Duration::from_millis(llm_timeout_ms),
crate::pi_rpc::call_llm(
agent_name, role, memory, skills, &messages, history, force,
),
)
.await;
let response = match llm_result {
Ok(Ok(Some(r))) => r,
Ok(Ok(None)) => {
return make_result(
LoopState::FinalSilent, step, tool_call_log,
agent_name, env, policy.checkpoint_enabled,
).await;
}
Ok(Err(e)) => {
state = LoopState::AbortErrorBudget;
let content = format!(
"LLM call error, aborting loop: {}\n\n{}",
e,
summarize_tool_results(&messages)
);
send_final_message(
bus, env, agent_name,
reply_event_type.clone(), reply_target,
guard, &content,
).await?;
return make_result(state, step, tool_call_log, agent_name, env, policy.checkpoint_enabled).await;
}
Err(_) => {
state = LoopState::AbortTimeout;
let content = format!(
"Loop timed out.\n\n{}",
summarize_tool_results(&messages)
);
send_final_message(
bus, env, agent_name,
reply_event_type.clone(), reply_target,
guard, &content,
).await?;
return make_result(state, step, tool_call_log, agent_name, env, policy.checkpoint_enabled).await;
}
};
let action_fp = action_fingerprint(&response.action);
if last_action_fingerprint.as_deref() == Some(action_fp.as_str()) {
same_action_repeat_count += 1;
} else {
last_action_fingerprint = Some(action_fp.clone());
same_action_repeat_count = 1;
}
if let Some(ref tc_id) = response.tool_call_id {
messages.push(ChatMessage::assistant_tool_call(
tc_id,
match &response.action {
ToolCallAction::Exec { .. } => "exec",
ToolCallAction::ReadFile { .. } => "read_file",
ToolCallAction::SkillCreate { .. } => "create_skill",
_ => "unknown",
},
response.tool_call_arguments.as_deref().unwrap_or("{}"),
));
}
let tool_call_id = response
.tool_call_id
.unwrap_or_else(|| format!("step_{}", step));
let (obs, log_label): (ToolObservation, String) = match response.action {
ToolCallAction::Message { content } => {
state = LoopState::FinalRespond;
send_final_message(
bus, env, agent_name,
reply_event_type.clone(), reply_target,
guard, &content,
).await?;
return make_result(state, step, tool_call_log, agent_name, env, policy.checkpoint_enabled).await;
}
ToolCallAction::Silence => {
return make_result(
LoopState::FinalSilent, step, tool_call_log,
agent_name, env, policy.checkpoint_enabled,
).await;
}
ToolCallAction::Exec { command, workdir, timeout_secs } => {
let timeout = timeout_secs.unwrap_or(30);
let exec = crate::main_helpers::execute_command(
&command, workdir.as_deref(), timeout,
).await;
let obs = ToolObservation::Exec(ExecObservation {
exit_code: exec.exit_code,
timed_out: exec.timed_out,
duration_ms: exec.duration_ms,
stdout_len: exec.stdout.len(),
stderr_len: exec.stderr.len(),
stdout_hash: hash_str(&exec.stdout),
stderr_hash: hash_str(&exec.stderr),
spawn_error: exec.spawn_error,
});
(obs, format!("exec:{}", command))
}
ToolCallAction::ReadFile { path } => {
let read = tokio::fs::read_to_string(&path).await;
let obs = match read {
Ok(content) => ToolObservation::ReadFile(ReadObservation {
ok: true,
bytes: content.len(),
content_hash: hash_str(&content),
error: None,
}),
Err(e) => ToolObservation::ReadFile(ReadObservation {
ok: false,
bytes: 0,
content_hash: 0,
error: Some(e.to_string()),
}),
};
(obs, format!("read_file:{}", path))
}
ToolCallAction::SkillCreate { skill_name, description, body } => {
let result = crate::main_helpers::create_skill(
bus, agent_name, env, &skill_name, &description, &body,
).await;
let obs = match result {
Ok(()) => ToolObservation::CreateSkill(SkillObservation {
ok: true,
name: skill_name.clone(),
error: None,
}),
Err(e) => ToolObservation::CreateSkill(SkillObservation {
ok: false,
name: skill_name.clone(),
error: Some(e.to_string()),
}),
};
(obs, format!("create_skill:{}", skill_name))
}
};
state = LoopState::Integrate;
tool_call_log.push(log_label);
messages.push(ChatMessage::tool_result(&tool_call_id, &obs.to_llm_content()));
if let Some(abort_state) = check_abort(
&obs,
&action_fp,
same_action_repeat_count,
&mut consecutive_error_count,
&mut progress_window,
&mut cycle_counter,
policy,
) {
state = abort_state;
let content = format!(
"Loop safety abort ({:?}).\n\n{}",
state,
summarize_tool_results(&messages)
);
send_final_message(
bus, env, agent_name,
reply_event_type.clone(), reply_target,
guard, &content,
).await?;
return make_result(state, step, tool_call_log, agent_name, env, policy.checkpoint_enabled).await;
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Policy shared by the stall-detection tests: limits of 3 and a
    /// four-entry progress window that must sum above 0.
    fn stall_test_policy() -> LoopPolicy {
        LoopPolicy {
            hard_max_steps: 64,
            same_action_repeat_limit: 3,
            consecutive_error_limit: 3,
            stall_window: 4,
            stall_min_progress_sum: 0,
            cycle_repeat_limit: 3,
            checkpoint_enabled: false,
        }
    }

    #[test]
    fn test_chat_message_user() {
        let api = ChatMessage::user("Hello").to_api_message();
        assert_eq!(api["role"], "user");
        assert_eq!(api["content"], "Hello");
    }

    #[test]
    fn test_chat_message_tool_call() {
        let api = ChatMessage::assistant_tool_call("call_1", "exec", "{\"command\": \"ls\"}")
            .to_api_message();
        let call = &api["tool_calls"][0];
        assert_eq!(api["role"], "assistant");
        assert_eq!(call["id"], "call_1");
        assert_eq!(call["function"]["name"], "exec");
    }

    #[test]
    fn test_chat_message_tool_result() {
        let api = ChatMessage::tool_result("call_1", "file1.rs\nfile2.rs").to_api_message();
        assert_eq!(api["role"], "tool");
        assert_eq!(api["tool_call_id"], "call_1");
        assert_eq!(api["content"], "file1.rs\nfile2.rs");
    }

    #[test]
    fn test_messages_to_api() {
        let conversation = [
            ChatMessage::user("Search for main"),
            ChatMessage::assistant_tool_call("c1", "exec", "{\"command\": \"rg main\"}"),
            ChatMessage::tool_result("c1", "main.rs:1:fn main()"),
        ];
        let api = messages_to_api(&conversation);
        assert_eq!(api.len(), 3);
        for (message, expected_role) in api.iter().zip(["user", "assistant", "tool"]) {
            assert_eq!(message["role"], expected_role);
        }
    }

    #[test]
    fn test_detect_stall_by_repeat() {
        let window: VecDeque<i32> = VecDeque::from(vec![1, 1, 1, 1]);
        assert!(detect_stall(&stall_test_policy(), 3, &window, 1));
    }

    #[test]
    fn test_detect_stall_by_progress() {
        let window: VecDeque<i32> = VecDeque::from(vec![-1, -1, 1, -1]);
        assert!(detect_stall(&stall_test_policy(), 1, &window, 1));
    }
}