use crate::audit::{AuditEventType, AuditLog};
use crate::config::Config;
use crate::error::Result;
use crate::llm::{
ChatMessage, Choice, LLMProviderTrait, MultiModelManager, ProviderFallbackChain, TokenBudget,
};
use crate::mcp::McpClient;
use crate::policy::{Decision, PolicyEngine};
use crate::ravenfabric::RavenFabricClient;
use crate::sandbox::Sandbox;
use crate::tools::{ToolCall, ToolRegistry, ToolResult};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, instrument, warn};
/// Chat history for a single agent conversation, oldest message first.
///
/// `new` places the system prompt at index 0, and trimming never removes index 0,
/// so the system prompt survives any amount of trimming.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationMemory {
    /// Maximum number of stored messages (index 0 included); `0` disables trimming.
    max_messages: usize,
    /// The stored history, oldest first.
    messages: Vec<ChatMessage>,
}
impl ConversationMemory {
    /// Creates a memory seeded with one `system` message.
    ///
    /// `max_messages` caps the total history length, system prompt included;
    /// pass `0` for an unbounded history.
    pub fn new(system_prompt: &str, max_messages: usize) -> Self {
        Self {
            max_messages,
            messages: vec![ChatMessage::new("system", system_prompt.to_string())],
        }
    }

    /// Appends a text `user` message, trims to the cap, and returns the full history.
    pub fn add_user_message(&mut self, content: &str) -> &[ChatMessage] {
        self.messages
            .push(ChatMessage::new("user", content.to_string()));
        self.trim_to_max();
        &self.messages
    }

    /// Appends a `user` message carrying `text` plus image data URIs, trims to
    /// the cap, and returns the full history.
    pub fn add_user_message_with_images(
        &mut self,
        text: &str,
        image_data_uris: Vec<String>,
    ) -> &[ChatMessage] {
        self.messages.push(ChatMessage::with_images(
            "user",
            text.to_string(),
            image_data_uris,
        ));
        self.trim_to_max();
        &self.messages
    }

    /// Appends an `assistant` message and trims to the cap.
    pub fn add_assistant_message(&mut self, content: &str) {
        self.messages
            .push(ChatMessage::new("assistant", content.to_string()));
        self.trim_to_max();
    }

    /// Returns the stored history, oldest first.
    pub fn history(&self) -> &[ChatMessage] {
        &self.messages
    }

    /// Rebuilds a memory from an existing history (e.g. a loaded checkpoint).
    ///
    /// The history is stored as-is; it is only trimmed when the next message is added.
    pub fn from_history(messages: Vec<ChatMessage>, max_messages: usize) -> Self {
        Self {
            max_messages,
            messages,
        }
    }

    /// Number of stored messages.
    #[allow(dead_code)]
    pub fn len(&self) -> usize {
        self.messages.len()
    }

    /// `true` when no messages are stored.
    #[allow(dead_code)]
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// Drops the oldest messages after index 0 until the history fits `max_messages`.
    ///
    /// Index 0 is always kept, so a cap of `1` leaves just that message. A cap of
    /// `0` disables trimming entirely.
    fn trim_to_max(&mut self) {
        if self.max_messages == 0 {
            return;
        }
        // With max_messages >= 1, excess <= len - 1, so `1..=excess` is in bounds.
        let excess = self.messages.len().saturating_sub(self.max_messages);
        if excess > 0 {
            // One O(n) drain instead of `excess` repeated `remove(1)` calls (O(n) each).
            self.messages.drain(1..=excess);
        }
    }
}
/// Snapshot of an agent-loop session, persisted as JSON so the loop can resume.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointState {
    /// Session key; also the checkpoint file stem (`<session_id>.json`).
    pub session_id: String,
    /// Loop iteration that produced this snapshot; the agent loop resumes at `iteration + 1`.
    pub iteration: usize,
    /// Iteration cap the session was started with.
    pub max_iterations: usize,
    /// Full conversation history at snapshot time.
    pub messages: Vec<ChatMessage>,
    /// The user's original prompt.
    pub initial_prompt: String,
    /// The system prompt the session was started with.
    pub system_prompt: String,
    /// LLM provider name at snapshot time.
    pub provider: String,
    /// LLM model name at snapshot time.
    pub model: String,
    /// Whether tool execution was enabled.
    pub enable_tools: bool,
    /// RFC 3339 UTC timestamp of when the snapshot was built.
    pub last_checkpoint: String,
}
impl CheckpointState {
    /// Builds a snapshot from the given loop state.
    ///
    /// Borrowed strings are copied into owned fields, and `last_checkpoint` is
    /// stamped with the current UTC time in RFC 3339 format.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        session_id: String,
        iteration: usize,
        max_iterations: usize,
        messages: Vec<ChatMessage>,
        initial_prompt: &str,
        system_prompt: &str,
        provider: &str,
        model: &str,
        enable_tools: bool,
    ) -> Self {
        let stamped_at = chrono::Utc::now().to_rfc3339();
        Self {
            last_checkpoint: stamped_at,
            session_id,
            iteration,
            max_iterations,
            messages,
            enable_tools,
            initial_prompt: initial_prompt.to_owned(),
            system_prompt: system_prompt.to_owned(),
            provider: provider.to_owned(),
            model: model.to_owned(),
        }
    }
}
/// Persists `state` as `<checkpoint_dir>/<session_id>.json`, creating the directory if needed.
///
/// The JSON is first written to a sibling `<session_id>.json.tmp` file and then
/// renamed over the final path. On filesystems where rename is atomic, readers
/// therefore never see a half-written checkpoint.
///
/// Returns the path of the written checkpoint.
///
/// # Errors
/// Returns a human-readable message if the directory cannot be created, the
/// state cannot be serialized, or the file cannot be written or renamed. On a
/// write or rename failure the temporary file is removed (best effort), so failed
/// saves do not accumulate stray `*.json.tmp` files.
pub fn save_checkpoint(
    checkpoint_dir: &std::path::Path,
    state: &CheckpointState,
) -> std::result::Result<std::path::PathBuf, String> {
    let path = checkpoint_dir.join(format!("{}.json", state.session_id));
    std::fs::create_dir_all(checkpoint_dir)
        .map_err(|e| format!("Failed to create checkpoint directory: {}", e))?;
    let content = serde_json::to_string_pretty(state)
        .map_err(|e| format!("Failed to serialize checkpoint: {}", e))?;
    let tmp_path = path.with_extension("json.tmp");
    if let Err(e) = std::fs::write(&tmp_path, &content) {
        // A partially written temp file may exist; removal errors are irrelevant here.
        let _ = std::fs::remove_file(&tmp_path);
        return Err(format!("Failed to write checkpoint: {}", e));
    }
    if let Err(e) = std::fs::rename(&tmp_path, &path) {
        let _ = std::fs::remove_file(&tmp_path);
        return Err(format!("Failed to finalize checkpoint: {}", e));
    }
    Ok(path)
}
/// Loads the checkpoint stored at `<checkpoint_dir>/<session_id>.json`.
///
/// Returns `None` in three cases:
/// - the file does not exist (silently);
/// - the file cannot be read (logged at `warn`);
/// - the file does not deserialize into a [`CheckpointState`] (logged at `warn`).
///
/// A successful load is logged at `info`.
pub fn load_checkpoint(
    checkpoint_dir: &std::path::Path,
    session_id: &str,
) -> Option<CheckpointState> {
    let checkpoint_path = checkpoint_dir.join(format!("{}.json", session_id));

    let raw = match std::fs::read_to_string(&checkpoint_path) {
        Ok(raw) => raw,
        // No checkpoint yet is the normal "fresh session" case: stay quiet.
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => return None,
        Err(err) => {
            warn!(
                session_id = %session_id,
                error = %err,
                "Failed to read checkpoint"
            );
            return None;
        }
    };

    let state: CheckpointState = match serde_json::from_str(&raw) {
        Ok(parsed) => parsed,
        Err(err) => {
            warn!(
                session_id = %session_id,
                error = %err,
                "Failed to deserialize checkpoint"
            );
            return None;
        }
    };

    info!(
        session_id = %session_id,
        iteration = state.iteration,
        max_iterations = state.max_iterations,
        "Loaded checkpoint"
    );
    Some(state)
}
/// Removes `<checkpoint_dir>/<session_id>.json` if it exists.
///
/// Best effort: a missing file is silently ignored, and any other failure is
/// logged at `warn` and otherwise swallowed. A successful removal is logged at `debug`.
pub fn delete_checkpoint(checkpoint_dir: &std::path::Path, session_id: &str) {
    let path = checkpoint_dir.join(format!("{}.json", session_id));
    // Try the removal directly instead of `exists()` followed by `remove_file()`.
    // That pair races with concurrent deleters, and `exists()` reports `false` on
    // permission errors, which would hide a real failure.
    match std::fs::remove_file(&path) {
        Ok(()) => {
            debug!(
                session_id = %session_id,
                "Deleted checkpoint"
            );
        }
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
        Err(e) => {
            warn!(
                session_id = %session_id,
                error = %e,
                "Failed to delete checkpoint"
            );
        }
    }
}
/// Tunables and optional integrations for the `run_agent_loop*` entry points.
///
/// `Clone` and `Debug` are implemented by hand because `metrics_callback` holds a
/// boxed closure.
pub struct AgentLoopConfig {
    /// Upper bound on loop iterations; each iteration makes one LLM request.
    pub max_iterations: usize,
    /// When `true`, replies are checked for structured and `TOOL_CALL:` text tool calls.
    pub enable_tools: bool,
    /// When `true`, tools flagged by `PolicyEngine::requires_approval` prompt the operator first.
    pub require_approval: bool,
    /// When `true`, each LLM reply is checked by `InjectionDetector`; a suspicious reply aborts the loop.
    pub prompt_injection_protection: bool,
    /// Wall-clock session limit in seconds, checked at the start of each iteration; `0` disables it.
    pub token_lifetime_secs: u64,
    /// When `true`, the first reply that is not a tool call is returned as the
    /// result even without a `FINAL:` marker.
    pub no_final_required: bool,
    /// Fallback providers used when the primary LLM request fails.
    pub fallback_chain: Option<Arc<std::sync::Mutex<ProviderFallbackChain>>>,
    /// Shared token budget: usage is recorded per reply, and the loop aborts once
    /// fewer than 100 tokens remain.
    pub token_budget: Option<Arc<std::sync::Mutex<TokenBudget>>>,
    /// Optional RavenFabric mesh client; when enabled it is health-checked once per iteration.
    pub ravenfabric: Option<RavenFabricClient>,
    /// Directory for per-session checkpoint files; together with `session_id` it enables resuming.
    pub checkpoint_dir: Option<PathBuf>,
    /// Checkpoint key for this run; a random UUID is generated when `None`.
    pub session_id: Option<String>,
    /// Metrics hook, invoked as `(tokens, 0)` after each LLM reply and as `(0, 1)`
    /// after each executed tool call. It is not carried over by `Clone`.
    pub metrics_callback: Option<Box<dyn Fn(u64, u64) + Send + Sync>>,
    /// Admission control consulted before each LLM request; request outcomes are reported back to it.
    pub load_manager: Option<Arc<crate::load::LoadManager>>,
}
impl Default for AgentLoopConfig {
    /// Default settings:
    /// - 10 iterations, no session lifetime limit;
    /// - the first non-tool reply is accepted as final;
    /// - tools and approval prompts off, prompt-injection scanning on;
    /// - every optional integration unset.
    fn default() -> Self {
        Self {
            // Loop shape.
            max_iterations: 10,
            token_lifetime_secs: 0,
            no_final_required: true,
            // Security switches.
            prompt_injection_protection: true,
            enable_tools: false,
            require_approval: false,
            // Optional integrations: all disabled.
            session_id: None,
            checkpoint_dir: None,
            fallback_chain: None,
            token_budget: None,
            load_manager: None,
            metrics_callback: None,
            ravenfabric: None,
        }
    }
}
impl std::fmt::Debug for AgentLoopConfig {
    /// Field-by-field debug output.
    ///
    /// The metrics callback and load manager are rendered as type-name
    /// placeholders (or `None` when unset) rather than their contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let callback_repr = self.metrics_callback.as_ref().map(|_| "Box<Fn>");
        let load_manager_repr = self.load_manager.as_ref().map(|_| "Arc<LoadManager>");

        let mut out = f.debug_struct("AgentLoopConfig");
        out.field("max_iterations", &self.max_iterations)
            .field("enable_tools", &self.enable_tools)
            .field("require_approval", &self.require_approval)
            .field(
                "prompt_injection_protection",
                &self.prompt_injection_protection,
            )
            .field("token_lifetime_secs", &self.token_lifetime_secs)
            .field("no_final_required", &self.no_final_required)
            .field("fallback_chain", &self.fallback_chain)
            .field("token_budget", &self.token_budget)
            .field("ravenfabric", &self.ravenfabric)
            .field("checkpoint_dir", &self.checkpoint_dir)
            .field("session_id", &self.session_id)
            .field("metrics_callback", &callback_repr)
            .field("load_manager", &load_manager_repr);
        out.finish()
    }
}
impl Clone for AgentLoopConfig {
    /// Copies every setting and shares the `Arc`-backed handles with the original.
    ///
    /// `metrics_callback` is a non-cloneable `Box<dyn Fn>`, so the clone always
    /// has it set to `None`. Re-attach a callback on the clone if metrics are needed.
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field forces this impl to be revisited.
        let Self {
            max_iterations,
            enable_tools,
            require_approval,
            prompt_injection_protection,
            token_lifetime_secs,
            no_final_required,
            fallback_chain,
            token_budget,
            ravenfabric,
            checkpoint_dir,
            session_id,
            metrics_callback: _,
            load_manager,
        } = self;
        Self {
            max_iterations: *max_iterations,
            enable_tools: *enable_tools,
            require_approval: *require_approval,
            prompt_injection_protection: *prompt_injection_protection,
            token_lifetime_secs: *token_lifetime_secs,
            no_final_required: *no_final_required,
            fallback_chain: fallback_chain.clone(),
            token_budget: token_budget.clone(),
            ravenfabric: ravenfabric.clone(),
            checkpoint_dir: checkpoint_dir.clone(),
            session_id: session_id.clone(),
            metrics_callback: None,
            load_manager: load_manager.clone(),
        }
    }
}
#[instrument(skip_all, fields(provider = %llm.provider_name(), model = %llm.model()))]
pub async fn run_agent_loop(
llm: Arc<dyn LLMProviderTrait>,
initial_prompt: &str,
system_prompt: &str,
config: AgentLoopConfig,
) -> Result<String> {
run_agent_loop_with_registry(llm, initial_prompt, system_prompt, config, None).await
}
/// Runs the security-checked agent loop (no MCP, no images).
///
/// Uses `tool_registry`, or [`ToolRegistry::with_default_tools`] when it is `None`.
#[instrument(skip_all, fields(provider = %llm.provider_name(), model = %llm.model()))]
pub async fn run_agent_loop_with_registry(
    llm: Arc<dyn LLMProviderTrait>,
    initial_prompt: &str,
    system_prompt: &str,
    config: AgentLoopConfig,
    tool_registry: Option<ToolRegistry>,
) -> Result<String> {
    let registry = match tool_registry {
        Some(custom) => custom,
        None => ToolRegistry::with_default_tools(),
    };
    let mcp_enabled = false;
    let no_images: Vec<String> = Vec::new();
    run_agent_loop_inner(
        llm,
        initial_prompt,
        system_prompt,
        config,
        registry,
        "security integration",
        mcp_enabled,
        no_images,
    )
    .await
}
/// Runs the security-checked agent loop with images attached to the initial user turn.
///
/// Uses `tool_registry`, or [`ToolRegistry::with_default_tools`] when it is `None`.
/// An empty `image_data_uris` behaves like a text-only run.
#[allow(dead_code)]
#[instrument(skip_all, fields(provider = %llm.provider_name(), model = %llm.model(), image_count = image_data_uris.len()))]
pub async fn run_agent_loop_with_images(
    llm: Arc<dyn LLMProviderTrait>,
    initial_prompt: &str,
    system_prompt: &str,
    config: AgentLoopConfig,
    tool_registry: Option<ToolRegistry>,
    image_data_uris: Vec<String>,
) -> Result<String> {
    let registry = match tool_registry {
        Some(custom) => custom,
        None => ToolRegistry::with_default_tools(),
    };
    let mcp_enabled = false;
    run_agent_loop_inner(
        llm,
        initial_prompt,
        system_prompt,
        config,
        registry,
        "security integration",
        mcp_enabled,
        image_data_uris,
    )
    .await
}
#[allow(clippy::too_many_arguments)]
#[instrument(skip_all, fields(provider = %llm.provider_name(), model = %llm.model()))]
async fn run_agent_loop_inner(
llm: Arc<dyn LLMProviderTrait>,
initial_prompt: &str,
system_prompt: &str,
config: AgentLoopConfig,
registry: ToolRegistry,
loop_label: &str,
mcp_enabled: bool,
image_data_uris: Vec<String>,
) -> Result<String> {
let policy_engine = PolicyEngine::default_secure();
let mut sandbox = Sandbox::default();
sandbox.init().await.map_err(|e| {
crate::error::RavenClawsError::CommandExecution(format!("Sandbox init failed: {}", e))
})?;
let audit_log = AuditLog::new(format!("agent-{}", std::process::id()));
let injection_detector = if config.prompt_injection_protection {
Some(crate::policy::InjectionDetector::new())
} else {
None
};
let session_start = std::time::Instant::now();
info!(
provider = llm.provider_name(),
model = llm.model(),
max_iterations = config.max_iterations,
enable_tools = config.enable_tools,
tool_count = registry.len(),
require_approval = config.require_approval,
prompt_injection_protection = config.prompt_injection_protection,
token_lifetime_secs = config.token_lifetime_secs,
"Agent loop starting with {}",
loop_label
);
let _ = audit_log.append(
AuditEventType::AgentStart,
"agent",
&format!(
"Agent loop started with {} (model: {})",
llm.provider_name(),
llm.model()
),
Some(serde_json::json!({
"provider": llm.provider_name(),
"model": llm.model(),
"max_iterations": config.max_iterations,
"enable_tools": config.enable_tools,
"mcp_enabled": mcp_enabled,
"tool_count": registry.len(),
"require_approval": config.require_approval,
"prompt_injection_protection": config.prompt_injection_protection,
"token_lifetime_secs": config.token_lifetime_secs,
})),
);
let (mut memory, start_iteration) = if let Some(ref checkpoint_dir) = config.checkpoint_dir {
if let Some(ref session_id) = config.session_id {
if let Some(checkpoint) = load_checkpoint(checkpoint_dir, session_id) {
info!(
session_id = %session_id,
iteration = checkpoint.iteration,
max_iterations = checkpoint.max_iterations,
"Resuming agent loop from checkpoint"
);
(
ConversationMemory::from_history(checkpoint.messages, 0),
checkpoint.iteration + 1, )
} else {
info!(
session_id = %session_id,
"No checkpoint found, starting fresh"
);
let mut m = ConversationMemory::new(system_prompt, 0);
if image_data_uris.is_empty() {
m.add_user_message(initial_prompt);
} else {
m.add_user_message_with_images(initial_prompt, image_data_uris.clone());
}
(m, 0)
}
} else {
let mut m = ConversationMemory::new(system_prompt, 0);
if image_data_uris.is_empty() {
m.add_user_message(initial_prompt);
} else {
m.add_user_message_with_images(initial_prompt, image_data_uris.clone());
}
(m, 0)
}
} else {
let mut m = ConversationMemory::new(system_prompt, 0);
if image_data_uris.is_empty() {
m.add_user_message(initial_prompt);
} else {
m.add_user_message_with_images(initial_prompt, image_data_uris.clone());
}
(m, 0)
};
let session_id = config
.session_id
.clone()
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
for iteration in start_iteration..config.max_iterations {
if config.token_lifetime_secs > 0 {
let elapsed = session_start.elapsed().as_secs();
if elapsed >= config.token_lifetime_secs {
warn!(
iteration = iteration,
elapsed_secs = elapsed,
token_lifetime_secs = config.token_lifetime_secs,
"Agent loop reached token lifetime limit"
);
let _ = audit_log.append(
AuditEventType::SecurityViolation,
"token_lifetime",
&format!(
"Session expired after {} seconds (limit: {}s)",
elapsed, config.token_lifetime_secs
),
Some(serde_json::json!({
"elapsed_secs": elapsed,
"token_lifetime_secs": config.token_lifetime_secs,
"iteration": iteration,
})),
);
if let Some(ref checkpoint_dir) = config.checkpoint_dir {
delete_checkpoint(checkpoint_dir, &session_id);
}
return Err(crate::error::RavenClawsError::SecurityViolation(format!(
"Session token expired after {} seconds (limit: {}s)",
elapsed, config.token_lifetime_secs
)));
}
}
let messages = memory.history().to_vec();
if let Some(ref budget) = config.token_budget {
let budget = budget.lock().unwrap();
if budget.remaining() < 100 {
warn!(
iteration = iteration,
remaining = budget.remaining(),
"Token budget exhausted"
);
let _ = audit_log.append(
AuditEventType::SecurityViolation,
"token_budget",
&format!("Token budget exhausted (remaining: {})", budget.remaining()),
Some(serde_json::json!({
"remaining": budget.remaining(),
"used": budget.used_tokens,
"iteration": iteration,
})),
);
if let Some(ref checkpoint_dir) = config.checkpoint_dir {
delete_checkpoint(checkpoint_dir, &session_id);
}
return Err(crate::error::RavenClawsError::SecurityViolation(
"Token budget exhausted".to_string(),
));
}
}
if let Some(ref load_manager) = config.load_manager {
let admission = load_manager.check_admission();
if !admission.is_allowed() {
warn!(
?admission,
iteration = iteration,
"Admission denied before LLM call"
);
let _ = audit_log.append(
AuditEventType::Error,
"load_manager",
&format!("Admission denied: {:?}", admission),
None,
);
load_manager.record_outcome(crate::load::RequestOutcome::Failure);
if let Some(ref checkpoint_dir) = config.checkpoint_dir {
delete_checkpoint(checkpoint_dir, &session_id);
}
return Err(crate::error::RavenClawsError::SecurityViolation(format!(
"Admission denied: {:?}",
admission
)));
}
}
let response = match llm.chat(messages.clone()).await {
Ok(r) => {
if let Some(ref load_manager) = config.load_manager {
load_manager.record_outcome(crate::load::RequestOutcome::Success);
}
r
}
Err(e) => {
if let Some(ref load_manager) = config.load_manager {
load_manager.record_outcome(crate::load::RequestOutcome::Failure);
}
if let Some(ref chain) = config.fallback_chain {
warn!(error = %e, "Primary LLM failed, trying fallback chain");
let _ = audit_log.append(
AuditEventType::Error,
"llm",
&format!("Primary LLM failed, trying fallback: {}", e),
None,
);
let configs = {
let c = chain.lock().unwrap();
c.configs.clone()
};
let mut temp_chain = ProviderFallbackChain::new(configs);
match temp_chain.chat_with_fallback(messages).await {
Ok(r) => {
info!("Fallback chain succeeded");
if let Some(ref budget) = config.token_budget {
if let Some(usage) = &r.usage {
let mut b = budget.lock().unwrap();
b.record_usage(usage.total_tokens);
}
}
r
}
Err(fallback_e) => {
warn!(error = %fallback_e, "Fallback chain also failed");
let _ = audit_log.append(
AuditEventType::Error,
"llm",
&format!("All providers failed: {}", fallback_e),
None,
);
if let Some(ref checkpoint_dir) = config.checkpoint_dir {
delete_checkpoint(checkpoint_dir, &session_id);
}
return Err(crate::error::RavenClawsError::Llm(fallback_e));
}
}
} else {
warn!(error = %e, "LLM request failed");
let _ = audit_log.append(
AuditEventType::Error,
"llm",
&format!("LLM request failed: {}", e),
None,
);
if let Some(ref checkpoint_dir) = config.checkpoint_dir {
delete_checkpoint(checkpoint_dir, &session_id);
}
return Err(crate::error::RavenClawsError::Llm(e));
}
}
};
let mut iteration_tokens: u64 = 0;
if let Some(ref budget) = config.token_budget {
if let Some(usage) = &response.usage {
let mut b = budget.lock().unwrap();
b.record_usage(usage.total_tokens);
iteration_tokens = usage.total_tokens as u64;
debug!(
iteration = iteration,
tokens_used = usage.total_tokens,
total_used = b.used_tokens,
remaining = b.remaining(),
"Token usage recorded"
);
}
} else if let Some(usage) = &response.usage {
iteration_tokens = usage.total_tokens as u64;
}
if let Some(ref cb) = config.metrics_callback {
cb(iteration_tokens, 0);
}
if let Some(ref rf) = config.ravenfabric {
if rf.is_enabled() {
let _ = rf.health().await;
info!(
iteration = iteration,
ravenfabric = true,
"RavenFabric health check completed"
);
}
}
let first_choice = response.choices.first();
let content = first_choice
.map(|c| c.message.content.clone())
.unwrap_or_default();
debug!(
iteration = iteration,
response_length = content.len(),
response_preview = %content[..content.len().min(500)],
"LLM response received"
);
if let Some(ref detector) = injection_detector {
match detector.check(&content) {
crate::policy::InjectionVerdict::Suspicious(reason) => {
warn!(
iteration = iteration,
reason = %reason,
"Prompt-injection detected in LLM response"
);
let _ = audit_log.append(
AuditEventType::SecurityViolation,
"injection_detector",
&format!("Prompt-injection detected: {}", reason),
Some(serde_json::json!({
"reason": reason,
"iteration": iteration,
"content_preview": &content[..content.len().min(200)],
})),
);
if let Some(ref checkpoint_dir) = config.checkpoint_dir {
delete_checkpoint(checkpoint_dir, &session_id);
}
return Err(crate::error::RavenClawsError::SecurityViolation(format!(
"LLM response blocked: potential prompt injection ({})",
reason
)));
}
crate::policy::InjectionVerdict::Clean => {}
}
}
if config.enable_tools {
if let Some((tool_name, args)) = first_choice.and_then(parse_structured_tool_call) {
info!(tool = %tool_name, "Structured tool call detected");
if let Some(tool_result) = execute_parsed_tool_call(
tool_name,
args,
®istry,
&policy_engine,
&sandbox,
&audit_log,
config.require_approval,
)
.await
{
let observation = if tool_result.success {
format!("OBSERVATION: {}", tool_result.output)
} else {
format!(
"OBSERVATION: Tool failed with error: {}",
tool_result.error.as_deref().unwrap_or("unknown error")
)
};
memory.add_user_message(&observation);
if let Some(ref cb) = config.metrics_callback {
cb(0, 1);
}
info!(
iteration = iteration,
tool = %tool_result.tool_name,
success = tool_result.success,
"Structured tool executed"
);
continue;
}
}
}
if content.contains("FINAL:") {
let final_response = content
.split("FINAL:")
.nth(1)
.unwrap_or("")
.trim()
.to_string();
memory.add_assistant_message(&content);
let _ = audit_log.append(
AuditEventType::AgentFinish,
"agent",
"Agent loop completed successfully",
Some(serde_json::json!({
"iterations": iteration + 1,
"final_response_length": final_response.len(),
})),
);
if let Some(ref checkpoint_dir) = config.checkpoint_dir {
delete_checkpoint(checkpoint_dir, &session_id);
}
return Ok(final_response);
}
if config.enable_tools {
if let Some(tool_result) = execute_tool_call_with_security(
&content,
®istry,
&policy_engine,
&sandbox,
&audit_log,
)
.await
{
let observation = if tool_result.success {
format!("OBSERVATION: {}", tool_result.output)
} else {
format!(
"OBSERVATION: Tool failed with error: {}",
tool_result.error.as_deref().unwrap_or("unknown error")
)
};
memory.add_assistant_message(&content);
memory.add_user_message(&observation);
if let Some(ref cb) = config.metrics_callback {
cb(0, 1);
}
info!(
iteration = iteration,
tool = %tool_result.tool_name,
success = tool_result.success,
"Tool executed"
);
continue;
}
}
memory.add_assistant_message(&content);
if let Some(ref checkpoint_dir) = config.checkpoint_dir {
let checkpoint = CheckpointState::new(
session_id.clone(),
iteration,
config.max_iterations,
memory.history().to_vec(),
initial_prompt,
system_prompt,
llm.provider_name(),
llm.model(),
config.enable_tools,
);
if let Err(e) = save_checkpoint(checkpoint_dir, &checkpoint) {
warn!(
session_id = %session_id,
iteration = iteration,
error = %e,
"Failed to save checkpoint"
);
} else {
debug!(
session_id = %session_id,
iteration = iteration,
"Checkpoint saved"
);
}
}
if config.no_final_required {
info!(
iteration = iteration,
response_length = content.len(),
"no_final_required: treating response as completion"
);
let _ = audit_log.append(
AuditEventType::AgentFinish,
"agent",
"Agent loop completed (no_final_required)",
Some(serde_json::json!({
"iterations": iteration + 1,
"final_response_length": content.len(),
})),
);
if let Some(ref checkpoint_dir) = config.checkpoint_dir {
delete_checkpoint(checkpoint_dir, &session_id);
}
return Ok(content);
}
info!(
iteration = iteration,
thought = %content.lines().find(|l| l.starts_with("THOUGHT:")).unwrap_or("<no thought>"),
"Agent loop progress"
);
}
warn!(
max_iterations = config.max_iterations,
"Agent loop reached max iterations"
);
let _ = audit_log.append(
AuditEventType::Error,
"agent",
"Agent loop reached max iterations without completing",
Some(serde_json::json!({
"max_iterations": config.max_iterations,
})),
);
if let Some(ref checkpoint_dir) = config.checkpoint_dir {
delete_checkpoint(checkpoint_dir, &session_id);
}
let history = memory.history();
if history.len() > 1 {
if let Some(last) = history.last() {
return Ok(last.content.clone());
}
}
Err(crate::error::RavenClawsError::CommandExecution(
"Agent loop reached max iterations without completing the task".to_string(),
))
}
#[allow(dead_code)]
#[instrument(skip_all, fields(provider = %llm.provider_name(), model = %llm.model()))]
pub async fn run_agent_loop_with_mcp(
llm: Arc<dyn LLMProviderTrait>,
initial_prompt: &str,
system_prompt: &str,
config: AgentLoopConfig,
mcp_client: Option<Arc<RwLock<McpClient>>>,
) -> Result<String> {
run_agent_loop_with_mcp_and_registry(
llm,
initial_prompt,
system_prompt,
config,
mcp_client,
None,
)
.await
}
#[instrument(skip_all, fields(provider = %llm.provider_name(), model = %llm.model()))]
pub async fn run_agent_loop_with_mcp_and_registry(
llm: Arc<dyn LLMProviderTrait>,
initial_prompt: &str,
system_prompt: &str,
config: AgentLoopConfig,
mcp_client: Option<Arc<RwLock<McpClient>>>,
tool_registry: Option<ToolRegistry>,
) -> Result<String> {
let mut registry = tool_registry.unwrap_or_else(ToolRegistry::with_default_tools);
if let Some(client) = &mcp_client {
match crate::mcp::register_mcp_tools(&mut registry, client.clone()).await {
Ok(count) => {
info!(count, "MCP tools registered");
}
Err(e) => {
warn!(error = %e, "Failed to register MCP tools");
}
}
}
let mcp_enabled = mcp_client.is_some();
run_agent_loop_inner(
llm,
initial_prompt,
system_prompt,
config,
registry,
"MCP integration",
mcp_enabled,
Vec::new(),
)
.await
}
#[instrument(skip_all, fields(provider = %llm.provider_name(), model = %llm.model(), image_count = image_data_uris.len()))]
pub async fn run_agent_loop_with_mcp_and_images(
llm: Arc<dyn LLMProviderTrait>,
initial_prompt: &str,
system_prompt: &str,
config: AgentLoopConfig,
mcp_client: Option<Arc<RwLock<McpClient>>>,
tool_registry: Option<ToolRegistry>,
image_data_uris: Vec<String>,
) -> Result<String> {
let mut registry = tool_registry.unwrap_or_else(ToolRegistry::with_default_tools);
if let Some(client) = &mcp_client {
match crate::mcp::register_mcp_tools(&mut registry, client.clone()).await {
Ok(count) => {
info!(count, "MCP tools registered");
}
Err(e) => {
warn!(error = %e, "Failed to register MCP tools");
}
}
}
let mcp_enabled = mcp_client.is_some();
run_agent_loop_inner(
llm,
initial_prompt,
system_prompt,
config,
registry,
"MCP integration",
mcp_enabled,
image_data_uris,
)
.await
}
async fn prompt_for_approval(tool_name: &str, args: &serde_json::Value) -> bool {
use std::io::{IsTerminal, Write};
let args_str = serde_json::to_string_pretty(args).unwrap_or_default();
if !std::io::stdin().is_terminal() {
warn!(
tool = %tool_name,
"stdin is not a TTY — auto-approving tool call (use --require-approval only in interactive mode)"
);
return true;
}
eprintln!("\n⚠️ Tool requires approval:");
eprintln!(" Tool: {}", tool_name);
for line in args_str.lines() {
eprintln!(" {}", line);
}
eprint!(" Approve? [y/N] ");
std::io::stderr().flush().ok();
let mut input = String::new();
match std::io::stdin().read_line(&mut input) {
Ok(_) => {
let trimmed = input.trim().to_lowercase();
trimmed == "y" || trimmed == "yes"
}
Err(e) => {
warn!(error = %e, "Failed to read approval input — denying by default");
false
}
}
}
/// Test double for `prompt_for_approval`.
///
/// Prints the same approval banner to stderr but takes the operator's answer
/// from `input` instead of stdin. Only `y` / `yes` (case-insensitive, surrounding
/// whitespace ignored) approve.
#[cfg(test)]
async fn prompt_for_approval_with_input(
    tool_name: &str,
    args: &serde_json::Value,
    input: &str,
) -> bool {
    use std::io::Write;
    let pretty_args = serde_json::to_string_pretty(args).unwrap_or_default();
    eprintln!("\n⚠️ Tool requires approval:");
    eprintln!(" Tool: {}", tool_name);
    pretty_args
        .lines()
        .for_each(|arg_line| eprintln!(" {}", arg_line));
    eprint!(" Approve? [y/N] ");
    std::io::stderr().flush().ok();
    let answer = input.trim().to_lowercase();
    matches!(answer.as_str(), "y" | "yes")
}
/// Runs one already-parsed tool call through the security pipeline, then executes it.
///
/// Order of checks:
/// 1. The call is written to the audit log.
/// 2. The policy engine decides. A denial short-circuits with a failed
///    [`ToolResult`] (exit code `-1`), and the operator is never prompted.
/// 3. If `require_approval` is set and the policy engine flags the tool as
///    needing approval, the operator is asked. A refusal also yields a failed result.
/// 4. The tool runs through `registry`. Execution errors become a failed
///    `ToolResult` rather than being propagated.
///
/// Always returns `Some`; the `Option` is kept for caller compatibility.
///
/// NOTE(review): `_sandbox` is not used here and is not passed to
/// `registry.execute`. Confirm whether the registry sandboxes tool execution on its own.
async fn execute_parsed_tool_call(
    tool_name: String,
    args: serde_json::Value,
    registry: &ToolRegistry,
    policy_engine: &PolicyEngine,
    _sandbox: &Sandbox,
    audit_log: &AuditLog,
    require_approval: bool,
) -> Option<ToolResult> {
    info!(tool = %tool_name, "Executing parsed tool call");
    let _ = audit_log.tool_call(&tool_name, &args);

    // Policy first, so nobody is asked to approve a call the policy rejects anyway.
    // `check_tool_call` only sees the name and arguments, so its verdict does not
    // depend on the approval outcome.
    let policy_decision = policy_engine.check_tool_call(&tool_name, &args);
    match &policy_decision {
        Decision::Allow => {
            let _ = audit_log.policy_decision(&tool_name, true, None);
        }
        Decision::Deny(reason) => {
            let _ = audit_log.policy_decision(&tool_name, false, Some(reason));
            warn!(tool = %tool_name, reason = %reason, "Tool call denied by policy");
            return Some(ToolResult {
                tool_name: tool_name.clone(),
                success: false,
                output: String::new(),
                error: Some(format!("Policy denied: {}", reason)),
                exit_code: Some(-1),
                duration_ms: None,
            });
        }
    }

    if require_approval && policy_engine.requires_approval(&tool_name) {
        let _ = audit_log.append(
            AuditEventType::ApprovalRequested,
            "approval",
            &format!("Approval required for tool: {}", tool_name),
            Some(serde_json::json!({"tool": tool_name, "args": args})),
        );
        let granted = prompt_for_approval(&tool_name, &args).await;
        if !granted {
            let _ = audit_log.approval(&tool_name, false, Some("Denied by user"));
            warn!(tool = %tool_name, "Tool call denied by user");
            return Some(ToolResult {
                tool_name: tool_name.clone(),
                success: false,
                output: String::new(),
                error: Some(format!("Approval denied by user for tool: {}", tool_name)),
                exit_code: Some(-1),
                duration_ms: None,
            });
        }
        let _ = audit_log.approval(&tool_name, true, Some("Approved by user"));
        info!(tool = %tool_name, "Tool call approved by user");
    }

    let call = ToolCall {
        name: tool_name.clone(),
        arguments: args,
        id: None,
    };
    let result = match registry.execute(call).await {
        Ok(result) => {
            let _ = audit_log.append(
                AuditEventType::ToolResult,
                &tool_name,
                &format!(
                    "Tool executed: {} (success: {})",
                    tool_name, result.success
                ),
                Some(serde_json::json!({
                    "success": result.success,
                    "exit_code": result.exit_code,
                    "duration_ms": result.duration_ms,
                })),
            );
            result
        }
        Err(e) => {
            let _ = audit_log.append(
                AuditEventType::Error,
                &tool_name,
                &format!("Tool execution failed: {}", e),
                None,
            );
            ToolResult {
                tool_name,
                success: false,
                output: String::new(),
                error: Some(e.to_string()),
                exit_code: Some(-1),
                duration_ms: None,
            }
        }
    };
    Some(result)
}
/// Parses a text-protocol tool call (`TOOL_CALL:` / `ARGS:` lines) from `content`
/// and runs it through `execute_parsed_tool_call`.
///
/// Returns `None` when `content` holds no well-formed tool call.
///
/// NOTE(review): interactive approval is always skipped here, because
/// `require_approval` is passed as `false` regardless of configuration. Callers
/// that must honor approval should call `parse_tool_call` and
/// `execute_parsed_tool_call` directly.
// This helper may have no remaining in-crate callers; it is kept as a convenience entry point.
#[allow(dead_code)]
async fn execute_tool_call_with_security(
    content: &str,
    registry: &ToolRegistry,
    policy_engine: &PolicyEngine,
    sandbox: &Sandbox,
    audit_log: &AuditLog,
) -> Option<ToolResult> {
    let parsed = parse_tool_call(content)?;
    let (name, arguments) = parsed;
    let require_approval = false;
    execute_parsed_tool_call(
        name,
        arguments,
        registry,
        policy_engine,
        sandbox,
        audit_log,
        require_approval,
    )
    .await
}
/// Extracts the first native (structured) tool call from `choice`.
///
/// Returns the function name and its arguments parsed as JSON. Returns `None`
/// when the choice has no tool calls or the arguments are not valid JSON. Only
/// the first call is considered; any others in the same choice are ignored.
fn parse_structured_tool_call(choice: &Choice) -> Option<(String, serde_json::Value)> {
    let function = &choice.tool_calls.as_ref()?.first()?.function;
    let arguments = serde_json::from_str::<serde_json::Value>(&function.arguments).ok()?;
    Some((function.name.clone(), arguments))
}
/// Parses the text tool-call protocol from an LLM reply:
///
/// ```text
/// TOOL_CALL: <tool name>
/// ARGS: <JSON on a single line>
/// ```
///
/// The first `TOOL_CALL:` line supplies the tool name, which must be non-empty.
/// The first `ARGS:` line after it must hold the entire JSON value. Whitespace
/// around lines and values is ignored.
///
/// Returns `None` if either line is missing, the name is empty, or the JSON does
/// not parse.
fn parse_tool_call(content: &str) -> Option<(String, serde_json::Value)> {
    let mut trimmed_lines = content.lines().map(str::trim);
    let tool_name = trimmed_lines
        .find_map(|line| line.strip_prefix("TOOL_CALL:"))
        .map(str::trim)
        .filter(|name| !name.is_empty())?
        .to_owned();
    // Continue scanning after the TOOL_CALL line, so ARGS must follow it.
    let args_json = trimmed_lines
        .find_map(|line| line.strip_prefix("ARGS:"))?
        .trim();
    let args: serde_json::Value = serde_json::from_str(args_json).ok()?;
    Some((tool_name, args))
}
/// Single-agent mode.
///
/// Steps:
/// 1. Health-checks RavenFabric when it is enabled.
/// 2. Sends the configured system prompt plus a fixed "ready" user message to `llm`.
/// 3. Logs the first reply.
/// 4. When RavenFabric is enabled, broadcasts the reply's first 500 characters to the mesh.
///
/// LLM failures are logged and swallowed; this function always returns `Ok(())`.
pub async fn run_single(
    llm: Arc<dyn LLMProviderTrait>,
    config: Config,
    ravenfabric: Option<RavenFabricClient>,
) -> Result<()> {
    info!(
        "Starting single agent mode with provider: {}",
        llm.provider_name()
    );
    if let Some(rf) = ravenfabric.as_ref() {
        if rf.is_enabled() {
            info!("RavenFabric remote execution available");
            let health = rf.health().await;
            match health {
                Ok(true) => info!("RavenFabric mesh is healthy"),
                Ok(false) => warn!("RavenFabric mesh returned unhealthy status"),
                Err(e) => warn!(error = %e, "RavenFabric health check failed"),
            }
        }
    }
    let conversation = vec![
        ChatMessage::new("system", config.llm.system_prompt.to_string()),
        ChatMessage::new("user", "Ready. Awaiting instructions."),
    ];
    let reply = match llm.chat(conversation).await {
        Ok(reply) => reply,
        Err(e) => {
            warn!(error = %e, provider = llm.provider_name(), "LLM request failed");
            return Ok(());
        }
    };
    if let Some(choice) = reply.choices.first() {
        info!(provider = llm.provider_name(), model = llm.model(), response = %choice.message.content, "Agent response received");
        if let Some(rf) = ravenfabric.as_ref() {
            if rf.is_enabled() {
                let preview: String = choice.message.content.chars().take(500).collect();
                let _ = rf.broadcast(&preview, 30).await;
                info!("Agent result broadcast to RavenFabric mesh");
            }
        }
    }
    Ok(())
}
pub async fn run_swarm(
llm: Arc<dyn LLMProviderTrait>,
config: Config,
ravenfabric: Option<RavenFabricClient>,
) -> Result<()> {
info!("Starting swarm mode (single-provider) — 3 parallel agents");
if let Some(ref rf) = ravenfabric {
if rf.is_enabled() {
info!("RavenFabric remote execution available for swarm coordination");
match rf.health().await {
Ok(true) => info!("RavenFabric mesh is healthy"),
Ok(false) => warn!("RavenFabric mesh returned unhealthy status"),
Err(e) => warn!(error = %e, "RavenFabric health check failed"),
}
}
}
let _system_prompt = &config.llm.system_prompt;
let num_agents = 3;
let mut handles = Vec::new();
let personas = [
"You are an analytical agent. Focus on logic, structure, and precision.",
"You are a creative agent. Focus on innovation, alternatives, and possibilities.",
"You are a pragmatic agent. Focus on simplicity, efficiency, and practicality.",
];
for (i, persona) in personas.iter().enumerate().take(num_agents) {
let llm_clone = llm.clone();
let persona = persona.to_string();
let task = "Analyze the given task and provide your solution.".to_string();
let handle = tokio::spawn(async move {
let mut memory = ConversationMemory::new(&persona, 10);
memory.add_user_message(&task);
let messages = memory.history().to_vec();
match llm_clone.chat(messages).await {
Ok(response) => {
let content = response
.choices
.first()
.map(|c| c.message.content.clone())
.unwrap_or_default();
Ok((i, content))
}
Err(e) => Err(format!("Agent {} failed: {}", i, e)),
}
});
handles.push(handle);
}
let mut results: Vec<(usize, String)> = Vec::new();
for handle in handles {
match handle.await {
Ok(Ok((idx, result))) => {
info!("Agent {} completed: {} chars", idx, result.len());
results.push((idx, result));
}
Ok(Err(e)) => warn!("Agent failed: {}", e),
Err(e) => warn!("Agent join failed: {}", e),
}
}
println!("\n🐦⬛ Swarm Results ({} agents):", results.len());
for (idx, result) in &results {
println!(
"\n── Agent {} ({}) ──",
idx + 1,
personas[*idx].split('.').next().unwrap_or("Unknown")
);
println!("{}", result);
}
if let Some(ref rf) = ravenfabric {
if rf.is_enabled() {
let summary = format!(
"Swarm completed: {} agents, results: {}",
results.len(),
results
.iter()
.map(|(i, r)| format!("Agent {}: {} chars", i, r.len()))
.collect::<Vec<_>>()
.join(", ")
);
let _ = rf.broadcast(&summary, 30).await;
info!("Swarm results broadcast to RavenFabric mesh");
}
}
Ok(())
}
/// Run the single-provider supervisor loop.
///
/// The supervisor LLM is asked to decompose a fixed placeholder task
/// ("Coordinate the completion of the assigned task.") into `SUBTASK:` /
/// `AGENT:` directives. Each directive is executed by [`run_subtask_agent`] with a
/// shared policy engine, sandbox, audit log and tool registry, and the result is
/// fed back into the supervisor conversation. The loop ends when the supervisor
/// answers with `FINAL:` or after 15 iterations. A failed LLM request consumes an
/// iteration.
///
/// When RavenFabric is enabled, its health is logged at start-up. A summary is
/// broadcast only on the aggregated-results exit path.
///
/// # Errors
///
/// Returns `CommandExecution` if sandbox initialisation fails, or if the loop ends
/// without a `FINAL:` answer and without any successful subtask results.
pub async fn run_supervisor(
    llm: Arc<dyn LLMProviderTrait>,
    config: Config,
    ravenfabric: Option<RavenFabricClient>,
) -> Result<()> {
    info!("Starting supervisor mode (single-provider)");
    if let Some(ref rf) = ravenfabric {
        if rf.is_enabled() {
            info!("RavenFabric remote execution available for supervisor coordination");
            match rf.health().await {
                Ok(true) => info!("RavenFabric mesh is healthy"),
                Ok(false) => warn!("RavenFabric mesh returned unhealthy status"),
                Err(e) => warn!(error = %e, "RavenFabric health check failed"),
            }
        }
    }
    let system_prompt = &config.llm.system_prompt;
    let policy_engine = PolicyEngine::default_secure();
    let mut sandbox = Sandbox::default();
    sandbox.init().await.map_err(|e| {
        crate::error::RavenClawsError::CommandExecution(format!("Sandbox init failed: {}", e))
    })?;
    let audit_log = AuditLog::new(format!("supervisor-{}", std::process::id()));
    let registry = ToolRegistry::with_default_tools();
    let supervisor_prompt = format!(
        "You are a supervisor agent. Your task is to decompose complex tasks into subtasks \
         and coordinate sub-agents to complete them. \
         \n\nFor each subtask, respond with:\n\
         SUBTASK: <description>\n\
         AGENT: <agent_number>\n\
         \nWhen all subtasks are complete, respond with:\n\
         FINAL: <aggregated result>\n\
         \nTask: {}",
        "Coordinate the completion of the assigned task."
    );
    // NOTE(review): `ConversationMemory` trimming only protects index 0 (the system
    // prompt). Once the history exceeds 20 messages, the supervisor instructions
    // added on the next line are the first message evicted.
    let mut memory = ConversationMemory::new(system_prompt, 20);
    memory.add_user_message(&supervisor_prompt);
    let mut subtask_results: Vec<String> = Vec::new();
    let mut iteration = 0;
    let max_iterations = 15;
    loop {
        iteration += 1;
        if iteration > max_iterations {
            warn!("Supervisor reached max iterations");
            break;
        }
        let messages = memory.history().to_vec();
        let response = match llm.chat(messages).await {
            Ok(r) => r,
            Err(e) => {
                // The failed attempt still counts toward `max_iterations`, which bounds retries.
                warn!(error = %e, "Supervisor LLM request failed");
                continue;
            }
        };
        let content = response
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .unwrap_or_default();
        if content.contains("FINAL:") {
            let final_response = content
                .split("FINAL:")
                .nth(1)
                .unwrap_or("")
                .trim()
                .to_string();
            info!("Supervisor completed task: {} chars", final_response.len());
            let _ = audit_log.append(
                AuditEventType::AgentFinish,
                "supervisor",
                "Supervisor completed task coordination",
                Some(serde_json::json!({
                    "iterations": iteration,
                    "subtasks_completed": subtask_results.len(),
                })),
            );
            println!("\n🐦⬛ Supervisor Result:\n{}", final_response);
            return Ok(());
        }
        if content.contains("SUBTASK:") {
            let subtask_block = content.split("SUBTASK:").nth(1).unwrap_or("");
            let subtask_lines: Vec<&str> = subtask_block.lines().take(3).collect();
            // The description is whatever follows `SUBTASK:` on the same line.
            let subtask_desc = subtask_lines.first().unwrap_or(&"").trim();
            // Models often indent directive lines, so match `AGENT:` after leading whitespace.
            let agent_num = subtask_lines
                .iter()
                .copied()
                .map(str::trim_start)
                .find(|l| l.starts_with("AGENT:"))
                .and_then(|l| l.split(':').nth(1))
                .unwrap_or("1")
                .trim();
            if subtask_desc.is_empty() {
                // Record the unusable reply so the next request differs from this one.
                // Otherwise the identical history would be resent and the iteration wasted.
                memory.add_assistant_message(&content);
            } else {
                info!("Subtask {}: {}", agent_num, subtask_desc);
                let subtask_result = run_subtask_agent(
                    llm.clone(),
                    subtask_desc,
                    system_prompt,
                    &policy_engine,
                    &sandbox,
                    &audit_log,
                    &registry,
                )
                .await;
                match subtask_result {
                    Ok(result) => {
                        info!("Subtask {} completed: {} chars", agent_num, result.len());
                        subtask_results.push(format!("Agent {} result: {}", agent_num, result));
                        memory.add_assistant_message(&format!(
                            "Decomposed subtask {}: {}",
                            agent_num, subtask_desc
                        ));
                        memory
                            .add_user_message(&format!("Subtask {} result: {}", agent_num, result));
                    }
                    Err(e) => {
                        warn!("Subtask {} failed: {}", agent_num, e);
                        memory
                            .add_assistant_message(&format!("Subtask {} failed: {}", agent_num, e));
                    }
                }
            }
        } else {
            memory.add_assistant_message(&content);
        }
    }
    // Reached only after `max_iterations` without a `FINAL:` answer.
    if !subtask_results.is_empty() {
        let aggregated = subtask_results.join("\n\n");
        info!(
            "Supervisor aggregated {} subtask results",
            subtask_results.len()
        );
        if let Some(ref rf) = ravenfabric {
            if rf.is_enabled() {
                let summary = format!(
                    "Supervisor completed: {} subtasks, result: {} chars",
                    subtask_results.len(),
                    aggregated.len()
                );
                let _ = rf.broadcast(&summary, 30).await;
                info!("Supervisor result broadcast to RavenFabric mesh");
            }
        }
        println!("\n🐦⬛ Supervisor Aggregated Result:\n{}", aggregated);
        return Ok(());
    }
    Err(crate::error::RavenClawsError::CommandExecution(
        "Supervisor mode completed without results".to_string(),
    ))
}
/// Drive a short-lived worker agent through one supervisor-assigned subtask.
///
/// The worker gets its own 10-message memory, seeded with `system_prompt` and the
/// subtask text, and has at most 5 LLM turns. On each turn:
///
/// * If the reply contains `FINAL:` or `DONE:`, the text after the earliest such
///   marker is returned (trimmed). This mirrors how the supervisor extracts its own
///   `FINAL:` answer.
/// * Otherwise the reply is passed to `execute_tool_call_with_security` together
///   with the policy engine, sandbox and audit log. If that yields a result, its
///   `output` is fed back. If not, the model is told to continue.
///
/// A failed LLM request is logged and consumes a turn. If no completion marker
/// appears within 5 turns, `Ok("Subtask completed")` is returned.
async fn run_subtask_agent(
    llm: Arc<dyn LLMProviderTrait>,
    subtask: &str,
    system_prompt: &str,
    policy_engine: &PolicyEngine,
    sandbox: &Sandbox,
    audit_log: &AuditLog,
    registry: &ToolRegistry,
) -> Result<String> {
    let mut memory = ConversationMemory::new(system_prompt, 10);
    memory.add_user_message(&format!("Execute this subtask: {}", subtask));
    for i in 0..5 {
        let messages = memory.history().to_vec();
        let response = match llm.chat(messages).await {
            Ok(r) => r,
            Err(e) => {
                warn!(error = %e, iteration = i, "Subtask agent LLM failed");
                continue;
            }
        };
        let content = response
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .unwrap_or_default();
        // Locate the earliest completion marker and keep only the answer after it.
        // Any reasoning the model wrote before the marker is dropped.
        let answer_start = ["FINAL:", "DONE:"]
            .iter()
            .copied()
            .filter_map(|marker| content.find(marker).map(|pos| (pos, marker.len())))
            .min_by_key(|&(pos, _)| pos)
            .map(|(pos, len)| pos + len);
        if let Some(start) = answer_start {
            return Ok(content[start..].trim().to_string());
        }
        let follow_up = match execute_tool_call_with_security(
            &content,
            registry,
            policy_engine,
            sandbox,
            audit_log,
        )
        .await
        {
            Some(tool_result) => format!("Tool result: {}", tool_result.output),
            None => "Continue with next step.".to_string(),
        };
        memory.add_assistant_message(&content);
        memory.add_user_message(&follow_up);
    }
    Ok("Subtask completed".to_string())
}
/// Single-agent mode across several providers.
///
/// Sends the same warm-up conversation (the configured system prompt plus the user
/// line "Ready. Awaiting instructions.") to each provider in turn. Each provider's
/// reply or error is logged.
///
/// Client selection: the first slot uses `get_client(0)`. Each later slot uses
/// `next_client(previous)`, where `previous` is the loop position of the last slot
/// that produced a client. Slots with no client are skipped.
///
/// When RavenFabric is enabled, its health is logged first and a completion notice
/// is broadcast at the end.
pub async fn run_single_multi(
    multi_llm: MultiModelManager,
    config: Config,
    ravenfabric: Option<RavenFabricClient>,
) -> Result<()> {
    let provider_total = multi_llm.client_count();
    info!(
        "Starting single agent mode (multi-model) with {} providers",
        provider_total
    );
    if let Some(rf) = ravenfabric.as_ref().filter(|rf| rf.is_enabled()) {
        info!("RavenFabric remote execution available");
        match rf.health().await {
            Ok(true) => info!("RavenFabric mesh is healthy"),
            Ok(false) => warn!("RavenFabric mesh returned unhealthy status"),
            Err(e) => warn!(error = %e, "RavenFabric health check failed"),
        }
    }
    let warmup = vec![
        ChatMessage::new("system", config.llm.system_prompt.to_string()),
        ChatMessage::new("user", "Ready. Awaiting instructions."),
    ];
    let mut previous = 0;
    for slot in 0..provider_total {
        let selected = match slot {
            0 => multi_llm.get_client(0),
            _ => multi_llm.next_client(previous),
        };
        let client = match selected {
            Some(found) => found,
            None => continue,
        };
        let outcome = client.chat(warmup.clone()).await;
        match outcome {
            Err(e) => {
                warn!(error = %e, provider = client.provider_name(), model = client.model(), "Provider request failed");
            }
            Ok(response) => {
                if let Some(choice) = response.choices.first() {
                    info!(provider = client.provider_name(), model = client.model(), response = %choice.message.content, "Provider response received");
                }
            }
        }
        previous = slot;
    }
    if let Some(rf) = ravenfabric.as_ref().filter(|rf| rf.is_enabled()) {
        let _ = rf
            .broadcast("Single agent (multi-model) completed", 30)
            .await;
        info!("Multi-model result broadcast to RavenFabric mesh");
    }
    Ok(())
}
/// Swarm mode across providers.
///
/// Fans the same analysis task out concurrently to at most three providers. Each
/// agent is primed with a different persona (analytical, creative, pragmatic) that
/// acts as its system prompt. Agent `i` uses client `i` and persona `i`. Replies
/// are printed in spawn order. Failed requests and failed task joins are logged and
/// left out of the output.
///
/// When RavenFabric is enabled, its health is logged first and a one-line summary
/// is broadcast at the end.
pub async fn run_swarm_multi(
    multi_llm: MultiModelManager,
    config: Config,
    ravenfabric: Option<RavenFabricClient>,
) -> Result<()> {
    let personas = [
        "You are an analytical agent. Focus on logic, structure, and precision.",
        "You are a creative agent. Focus on innovation, alternatives, and possibilities.",
        "You are a pragmatic agent. Focus on simplicity, efficiency, and practicality.",
    ];
    // One agent per persona at most, and never more agents than providers.
    let num_agents = multi_llm.client_count().min(personas.len());
    info!(
        "Starting swarm mode (multi-model) — {} parallel agents",
        num_agents
    );
    if let Some(ref rf) = ravenfabric {
        if rf.is_enabled() {
            info!("RavenFabric remote execution available for swarm coordination");
            match rf.health().await {
                Ok(true) => info!("RavenFabric mesh is healthy"),
                Ok(false) => warn!("RavenFabric mesh returned unhealthy status"),
                Err(e) => warn!(error = %e, "RavenFabric health check failed"),
            }
        }
    }
    // NOTE(review): the configured system prompt is not used here; each agent's
    // persona replaces it. The binding only keeps `config` referenced.
    let _system_prompt = &config.llm.system_prompt;
    let mut handles = Vec::with_capacity(num_agents);
    for (i, persona) in personas.iter().enumerate().take(num_agents) {
        // `i < client_count()`, but don't panic if the manager can't back its count.
        let client = match multi_llm.get_client(i) {
            Some(client) => client.clone(),
            None => {
                warn!("Agent {} skipped: no client at index {}", i, i);
                continue;
            }
        };
        let persona = persona.to_string();
        let task = "Analyze the given task and provide your solution.".to_string();
        let handle = tokio::spawn(async move {
            let mut memory = ConversationMemory::new(&persona, 10);
            memory.add_user_message(&task);
            let messages = memory.history().to_vec();
            match client.chat(messages).await {
                Ok(response) => {
                    let content = response
                        .choices
                        .first()
                        .map(|c| c.message.content.clone())
                        .unwrap_or_default();
                    Ok((
                        i,
                        client.provider_name().to_string(),
                        client.model().to_string(),
                        content,
                    ))
                }
                Err(e) => Err(format!("Agent {} failed: {}", i, e)),
            }
        });
        handles.push(handle);
    }
    let mut results: Vec<(usize, String, String, String)> = Vec::new();
    for handle in handles {
        match handle.await {
            Ok(Ok((idx, provider, model, result))) => {
                info!(
                    "Agent {} ({}:{}) completed: {} chars",
                    idx,
                    provider,
                    model,
                    result.len()
                );
                results.push((idx, provider, model, result));
            }
            Ok(Err(e)) => warn!("Agent failed: {}", e),
            Err(e) => warn!("Agent join failed: {}", e),
        }
    }
    println!(
        "\n🐦⬛ Swarm Results ({} agents, multi-model):",
        results.len()
    );
    for (idx, provider, model, result) in &results {
        println!("\n── Agent {} ({}:{}) ──", idx + 1, provider, model);
        println!("{}", result);
    }
    if let Some(ref rf) = ravenfabric {
        if rf.is_enabled() {
            let summary = format!("Multi-model swarm completed: {} agents", results.len());
            let _ = rf.broadcast(&summary, 30).await;
            info!("Multi-model swarm results broadcast to RavenFabric mesh");
        }
    }
    Ok(())
}
pub async fn run_supervisor_multi(
multi_llm: MultiModelManager,
config: Config,
ravenfabric: Option<RavenFabricClient>,
) -> Result<()> {
info!(
"Starting supervisor mode (multi-model) with {} providers",
multi_llm.client_count()
);
if let Some(ref rf) = ravenfabric {
if rf.is_enabled() {
info!("RavenFabric remote execution available for supervisor coordination");
match rf.health().await {
Ok(true) => info!("RavenFabric mesh is healthy"),
Ok(false) => warn!("RavenFabric mesh returned unhealthy status"),
Err(e) => warn!(error = %e, "RavenFabric health check failed"),
}
}
}
let system_prompt = &config.llm.system_prompt;
let policy_engine = PolicyEngine::default_secure();
let mut sandbox = Sandbox::default();
sandbox.init().await.map_err(|e| {
crate::error::RavenClawsError::CommandExecution(format!("Sandbox init failed: {}", e))
})?;
let audit_log = AuditLog::new(format!("supervisor-multi-{}", std::process::id()));
let registry = ToolRegistry::with_default_tools();
let supervisor_prompt = format!(
"You are a supervisor agent coordinating multiple LLM providers. \
Decompose tasks and assign them to appropriate providers based on their strengths. \
\n\nFor each subtask, respond with:\n\
SUBTASK: <description>\n\
PROVIDER: <provider_index 0-{}>\n\
\nWhen complete, respond with:\n\
FINAL: <aggregated result>\n\
\nTask: {}",
multi_llm.client_count() - 1,
"Coordinate the completion of the assigned task using available providers."
);
let mut memory = ConversationMemory::new(system_prompt, 20);
memory.add_user_message(&supervisor_prompt);
let mut subtask_results: Vec<String> = Vec::new();
let mut iteration = 0;
let max_iterations = 15;
loop {
iteration += 1;
if iteration > max_iterations {
warn!("Supervisor reached max iterations");
break;
}
let supervisor_client = multi_llm
.get_client(iteration % multi_llm.client_count())
.or_else(|| multi_llm.get_client(0))
.cloned();
let messages = memory.history().to_vec();
let response =
match supervisor_client.map(|c| tokio::spawn(async move { c.chat(messages).await })) {
Some(handle) => match handle.await {
Ok(Ok(r)) => r,
Ok(Err(e)) => {
warn!(error = %e, "Supervisor LLM request failed");
continue;
}
Err(e) => {
warn!(error = %e, "Supervisor task join failed");
continue;
}
},
None => {
warn!("No LLM clients available");
break;
}
};
let content = response
.choices
.first()
.map(|c| c.message.content.clone())
.unwrap_or_default();
if content.contains("FINAL:") {
let final_response = content
.split("FINAL:")
.nth(1)
.unwrap_or("")
.trim()
.to_string();
info!("Supervisor completed task: {} chars", final_response.len());
let _ = audit_log.append(
AuditEventType::AgentFinish,
"supervisor",
"Supervisor completed task coordination",
Some(serde_json::json!({
"iterations": iteration,
"subtasks_completed": subtask_results.len(),
"providers_used": multi_llm.client_count(),
})),
);
println!("\n🐦⬛ Supervisor Result (multi-model):\n{}", final_response);
return Ok(());
}
if content.contains("SUBTASK:") && content.contains("PROVIDER:") {
let subtask_block = content.split("SUBTASK:").nth(1).unwrap_or("");
let subtask_lines: Vec<&str> = subtask_block.lines().take(4).collect();
let subtask_desc = subtask_lines.first().unwrap_or(&"").trim();
let provider_idx = subtask_lines
.iter()
.find(|l| l.starts_with("PROVIDER:"))
.and_then(|l| l.split(':').nth(1))
.and_then(|s| s.trim().parse::<usize>().ok())
.unwrap_or(0);
if !subtask_desc.is_empty() {
info!("Subtask for provider {}: {}", provider_idx, subtask_desc);
let client = multi_llm
.get_client(provider_idx)
.or_else(|| multi_llm.get_client(0));
if let Some(client) = client {
let subtask_result = run_subtask_agent(
client.clone(),
subtask_desc,
system_prompt,
&policy_engine,
&sandbox,
&audit_log,
®istry,
)
.await;
match subtask_result {
Ok(result) => {
info!("Subtask {} completed: {} chars", provider_idx, result.len());
subtask_results.push(format!(
"Provider {} ({}): {}",
provider_idx,
client.provider_name(),
result
));
memory.add_assistant_message(&format!(
"Assigned subtask to provider {}: {}",
provider_idx, subtask_desc
));
memory.add_user_message(&format!(
"Provider {} result: {}",
provider_idx, result
));
}
Err(e) => {
warn!("Subtask {} failed: {}", provider_idx, e);
memory.add_assistant_message(&format!(
"Provider {} subtask failed: {}",
provider_idx, e
));
}
}
}
}
} else {
memory.add_assistant_message(&content);
}
}
if !subtask_results.is_empty() {
let aggregated = subtask_results.join("\n\n");
info!(
"Supervisor aggregated {} subtask results",
subtask_results.len()
);
if let Some(ref rf) = ravenfabric {
if rf.is_enabled() {
let summary = format!(
"Multi-model supervisor completed: {} subtasks, result: {} chars",
subtask_results.len(),
aggregated.len()
);
let _ = rf.broadcast(&summary, 30).await;
info!("Multi-model supervisor result broadcast to RavenFabric mesh");
}
}
println!(
"\n🐦⬛ Supervisor Aggregated Result (multi-model):\n{}",
aggregated
);
return Ok(());
}
Err(crate::error::RavenClawsError::CommandExecution(
"Supervisor mode completed without results".to_string(),
))
}
/// Interactive read–eval–print loop over stdin/stdout.
///
/// Each non-empty input line is appended to an untrimmed conversation
/// (`max_messages == 0`) and sent to `llm`. The first choice is printed and stored
/// as the assistant turn. Commands:
///
/// * `/exit` or `/quit` leaves the loop.
/// * `/reset` starts a fresh conversation seeded with the configured system prompt.
///
/// End of input or a stdin read error also ends the session.
///
/// # Errors
///
/// Only a failure to flush stdout is propagated. LLM errors are printed and the
/// loop continues.
pub async fn run_repl(llm: Arc<dyn LLMProviderTrait>, config: Config) -> Result<()> {
    use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
    info!("Starting interactive REPL mode");
    let system_prompt = &config.llm.system_prompt;
    let fresh_memory = || ConversationMemory::new(system_prompt, 0);
    let mut memory = fresh_memory();
    let mut lines = BufReader::new(tokio::io::stdin()).lines();
    println!("RavenClaws REPL — type /exit to quit, /reset to clear history");
    loop {
        print!("\n> ");
        tokio::io::stdout().flush().await?;
        let raw = match lines.next_line().await {
            Ok(Some(line)) => line,
            Ok(None) => break,
            Err(e) => {
                warn!(error = %e, "REPL read error");
                break;
            }
        };
        let input = raw.trim();
        match input {
            "" => continue,
            "/exit" | "/quit" => {
                println!("Exiting REPL.");
                break;
            }
            "/reset" => {
                memory = fresh_memory();
                println!("Conversation history reset.");
                continue;
            }
            _ => {}
        }
        memory.add_user_message(input);
        let outcome = llm.chat(memory.history().to_vec()).await;
        match outcome {
            Err(e) => {
                warn!(error = %e, "LLM request failed");
                println!("Error: {}", e);
            }
            Ok(response) => {
                if let Some(choice) = response.choices.first() {
                    let reply = &choice.message.content;
                    println!("{}", reply);
                    memory.add_assistant_message(reply);
                }
            }
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_swarm_function_exists() {
        let _fn_ptr: fn(Arc<dyn LLMProviderTrait>, Config, Option<RavenFabricClient>) -> _ =
            run_swarm;
    }

    #[test]
    fn test_supervisor_function_exists() {
        let _fn_ptr: fn(Arc<dyn LLMProviderTrait>, Config, Option<RavenFabricClient>) -> _ =
            run_supervisor;
    }

    #[test]
    fn test_conversation_memory_new() {
        let mem = ConversationMemory::new("system prompt", 10);
        assert_eq!(mem.messages.len(), 1);
        assert_eq!(mem.messages[0].role, "system");
        assert_eq!(mem.messages[0].content, "system prompt");
    }

    #[test]
    fn test_conversation_memory_add_user() {
        let mut mem = ConversationMemory::new("system", 10);
        mem.add_user_message("hello");
        assert_eq!(mem.messages.len(), 2);
        assert_eq!(mem.messages[1].role, "user");
        assert_eq!(mem.messages[1].content, "hello");
    }

    #[test]
    fn test_conversation_memory_trim() {
        let mut mem = ConversationMemory::new("system", 3);
        mem.add_user_message("msg1");
        mem.add_assistant_message("resp1");
        mem.add_user_message("msg2");
        mem.add_assistant_message("resp2");
        // Trimming evicts the oldest non-system message. The system prompt survives
        // and exactly the two newest messages remain.
        assert_eq!(mem.messages.len(), 3);
        assert_eq!(mem.messages[0].role, "system");
        assert_eq!(mem.messages[0].content, "system");
        assert_eq!(mem.messages[1].content, "msg2");
        assert_eq!(mem.messages[2].content, "resp2");
    }

    #[test]
    fn test_conversation_memory_zero_max_is_unbounded() {
        let mut mem = ConversationMemory::new("system", 0);
        for i in 0..50 {
            mem.add_user_message(&format!("msg{}", i));
        }
        assert_eq!(mem.len(), 51);
        assert_eq!(mem.history()[0].role, "system");
        assert_eq!(mem.history()[50].content, "msg49");
    }

    #[test]
    fn test_conversation_memory_from_history_round_trip() {
        let mut original = ConversationMemory::new("system", 10);
        original.add_user_message("hello");
        original.add_assistant_message("hi");
        let restored = ConversationMemory::from_history(original.history().to_vec(), 10);
        assert_eq!(restored.len(), 3);
        assert_eq!(restored.history()[0].role, "system");
        assert_eq!(restored.history()[2].role, "assistant");
        assert_eq!(restored.history()[2].content, "hi");
    }

    #[test]
    fn test_parse_tool_call_valid() {
        let content = "THOUGHT: I need to run a command\nTOOL_CALL: shell_exec\nARGS: {\"command\": \"echo hello\"}";
        let (name, args) = parse_tool_call(content).unwrap();
        assert_eq!(name, "shell_exec");
        assert_eq!(args["command"], "echo hello");
    }

    #[test]
    fn test_parse_tool_call_missing_tool() {
        let content = "THOUGHT: no tool here";
        assert!(parse_tool_call(content).is_none());
    }

    #[test]
    fn test_parse_tool_call_missing_args() {
        let content = "TOOL_CALL: shell_exec\nNo args line";
        assert!(parse_tool_call(content).is_none());
    }

    #[test]
    fn test_parse_tool_call_invalid_json() {
        let content = "TOOL_CALL: shell_exec\nARGS: not valid json";
        assert!(parse_tool_call(content).is_none());
    }

    #[test]
    fn test_agent_loop_config_default() {
        let config = AgentLoopConfig::default();
        assert_eq!(config.max_iterations, 10);
        assert!(!config.enable_tools);
        assert!(!config.require_approval);
    }

    #[test]
    fn test_agent_loop_config_require_approval() {
        let config = AgentLoopConfig {
            max_iterations: 5,
            enable_tools: true,
            require_approval: true,
            prompt_injection_protection: true,
            token_lifetime_secs: 0,
            no_final_required: false,
            fallback_chain: None,
            token_budget: None,
            ravenfabric: None,
            checkpoint_dir: None,
            session_id: None,
            metrics_callback: None,
            load_manager: None,
        };
        assert_eq!(config.max_iterations, 5);
        assert!(config.enable_tools);
        assert!(config.require_approval);
        assert!(config.prompt_injection_protection);
        assert_eq!(config.token_lifetime_secs, 0);
    }

    #[test]
    fn test_prompt_for_approval_yes() {
        let args = serde_json::json!({"command": "echo hello"});
        let result = tokio_test::block_on(prompt_for_approval_with_input("shell_exec", &args, "y"));
        assert!(result, "Should approve for 'y'");
    }

    #[test]
    fn test_prompt_for_approval_yes_full() {
        let args = serde_json::json!({"command": "echo hello"});
        let result =
            tokio_test::block_on(prompt_for_approval_with_input("shell_exec", &args, "yes"));
        assert!(result, "Should approve for 'yes'");
    }

    #[test]
    fn test_prompt_for_approval_no() {
        let args = serde_json::json!({"command": "echo hello"});
        let result = tokio_test::block_on(prompt_for_approval_with_input("shell_exec", &args, "n"));
        assert!(!result, "Should deny for 'n'");
    }

    #[test]
    fn test_prompt_for_approval_no_full() {
        let args = serde_json::json!({"command": "echo hello"});
        let result =
            tokio_test::block_on(prompt_for_approval_with_input("shell_exec", &args, "no"));
        assert!(!result, "Should deny for 'no'");
    }

    #[test]
    fn test_prompt_for_approval_empty() {
        let args = serde_json::json!({"command": "echo hello"});
        let result = tokio_test::block_on(prompt_for_approval_with_input("shell_exec", &args, ""));
        assert!(!result, "Should deny for empty input (default N)");
    }

    #[test]
    fn test_prompt_for_approval_uppercase() {
        let args = serde_json::json!({"command": "echo hello"});
        let result = tokio_test::block_on(prompt_for_approval_with_input("shell_exec", &args, "Y"));
        assert!(result, "Should approve for uppercase 'Y'");
    }

    #[test]
    fn test_prompt_for_approval_auto_approves_non_tty() {
        // NOTE(review): this only constructs the future and drops it without polling.
        // It checks that the call type-checks, not any non-TTY behaviour.
        #[allow(clippy::let_underscore_future)]
        let _ = prompt_for_approval_with_input("test", &serde_json::json!({}), "y");
    }

    #[test]
    fn test_execute_parsed_tool_call_skips_approval_when_not_required() {
        let registry = ToolRegistry::with_default_tools();
        let policy_engine = PolicyEngine::default_secure();
        let sandbox = Sandbox::default();
        let audit_log = AuditLog::new("test-session".to_string());
        let args = serde_json::json!({"command": "echo hello"});
        let result = tokio_test::block_on(execute_parsed_tool_call(
            "shell_exec".to_string(),
            args,
            &registry,
            &policy_engine,
            &sandbox,
            &audit_log,
            false,
        ));
        assert!(result.is_some());
        let tool_result = result.unwrap();
        assert_eq!(tool_result.tool_name, "shell_exec");
    }

    #[test]
    fn test_execute_parsed_tool_call_approval_not_needed_for_read_only_tools() {
        let registry = ToolRegistry::with_default_tools();
        let policy_engine = PolicyEngine::default_secure();
        let sandbox = Sandbox::default();
        let audit_log = AuditLog::new("test-session".to_string());
        let args = serde_json::json!({"path": "/tmp/test.txt"});
        let result = tokio_test::block_on(execute_parsed_tool_call(
            "read_file".to_string(),
            args,
            &registry,
            &policy_engine,
            &sandbox,
            &audit_log,
            true,
        ));
        assert!(result.is_some());
        let tool_result = result.unwrap();
        assert_eq!(tool_result.tool_name, "read_file");
    }

    #[test]
    fn test_agent_loop_config_token_lifetime_zero_disabled() {
        let config = AgentLoopConfig {
            max_iterations: 10,
            enable_tools: false,
            require_approval: false,
            prompt_injection_protection: false,
            token_lifetime_secs: 0,
            no_final_required: false,
            fallback_chain: None,
            token_budget: None,
            ravenfabric: None,
            checkpoint_dir: None,
            session_id: None,
            metrics_callback: None,
            load_manager: None,
        };
        assert_eq!(config.token_lifetime_secs, 0);
    }

    #[test]
    fn test_agent_loop_config_token_lifetime_nonzero() {
        let config = AgentLoopConfig {
            max_iterations: 10,
            enable_tools: false,
            require_approval: false,
            prompt_injection_protection: false,
            token_lifetime_secs: 3600,
            no_final_required: false,
            fallback_chain: None,
            token_budget: None,
            ravenfabric: None,
            checkpoint_dir: None,
            session_id: None,
            metrics_callback: None,
            load_manager: None,
        };
        assert_eq!(config.token_lifetime_secs, 3600);
    }

    #[test]
    fn test_agent_loop_config_default_includes_token_lifetime() {
        let config = AgentLoopConfig::default();
        assert_eq!(config.token_lifetime_secs, 0);
    }
}