pub mod backend;
mod preflight;
pub mod events;
pub mod session;
pub mod git_session;
pub use events::{
AgentEvent, FinishReason, ThinkingDeltaEvent, ToolApprovalEvent,
ToolCompleteEvent, ToolStartEvent, TokenUsageInfo, TurnEndEvent,
TurnStartEvent, SessionEndEvent,
};
use crate::config::{LlmProvider, PawanConfig};
use crate::coordinator::{CoordinatorResult, ToolCallingConfig, ToolCoordinator};
use crate::credentials;
use crate::tools::{ToolDefinition, ToolRegistry};
use crate::{PawanError, Result};
use backend::openai_compat::{OpenAiCompatBackend, OpenAiCompatConfig};
use backend::LlmBackend;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;
/// A single message in the agent's conversation history.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Message {
    pub role: Role,
    pub content: String,
    /// Tool invocations requested by an assistant message (empty otherwise).
    #[serde(default)]
    pub tool_calls: Vec<ToolCallRequest>,
    /// Present only on `Role::Tool` messages carrying a tool's output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_result: Option<ToolResultMessage>,
}
/// Chat role of a [`Message`]; serialized lowercase ("system", "user", …)
/// to match OpenAI-style chat APIs.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
    Tool,
}
/// A tool invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallRequest {
    /// Provider-assigned call id; echoed back in the matching tool result.
    pub id: String,
    pub name: String,
    /// JSON arguments as produced by the model.
    pub arguments: Value,
}
/// The outcome of one tool call, attached to a `Role::Tool` message.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolResultMessage {
    /// Id of the [`ToolCallRequest`] this result answers.
    pub tool_call_id: String,
    pub content: Value,
    pub success: bool,
}
/// Full record of an executed (or denied) tool call, surfaced to callers
/// in [`AgentResponse`] and to the `on_tool` callback.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRecord {
    pub id: String,
    pub name: String,
    pub arguments: Value,
    pub result: Value,
    pub success: bool,
    /// Wall-clock execution time; 0 for calls denied before running.
    pub duration_ms: u64,
}
/// Accumulated token counts across one or more LLM calls.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
    /// Completion tokens spent on reasoning/"thinking" output.
    pub reasoning_tokens: u64,
    /// Completion tokens spent on actionable (non-reasoning) output.
    pub action_tokens: u64,
}
/// One raw model completion as returned by a backend.
#[derive(Debug, Clone)]
pub struct LLMResponse {
    pub content: String,
    /// Separated reasoning text, when the backend extracts it.
    pub reasoning: Option<String>,
    pub tool_calls: Vec<ToolCallRequest>,
    /// Provider finish reason; the agent loop also uses "error" for
    /// synthesized failure responses.
    pub finish_reason: String,
    pub usage: Option<TokenUsage>,
}
/// Final result of one agent turn: the answer plus everything that
/// happened along the way.
#[derive(Debug)]
pub struct AgentResponse {
    pub content: String,
    pub tool_calls: Vec<ToolCallRecord>,
    /// Number of LLM round-trips used for this turn.
    pub iterations: usize,
    pub usage: TokenUsage,
}
/// Invoked with each streamed token of model output.
pub type TokenCallback = Box<dyn Fn(&str) + Send + Sync>;
/// Invoked after each tool call completes (or is denied), with its record.
pub type ToolCallback = Box<dyn Fn(&ToolCallRecord) + Send + Sync>;
/// Invoked with the tool name just before a tool starts executing.
pub type ToolStartCallback = Box<dyn Fn(&str) + Send + Sync>;
/// Approval request shown to the user before a sensitive tool runs.
#[derive(Debug, Clone)]
pub struct PermissionRequest {
    pub tool_name: String,
    /// Short (truncated) preview of the tool's arguments.
    pub args_summary: String,
}
/// Asks the user to approve a tool call; the returned oneshot resolves to
/// `true` (approve) or `false` (deny). A dropped sender counts as denial.
pub type PermissionCallback =
    Box<dyn Fn(PermissionRequest) -> tokio::sync::oneshot::Receiver<bool> + Send + Sync>;
/// The core agent: owns the conversation history, tool registry, and the
/// active LLM backend, and drives the generate → tool-call → feedback loop.
pub struct PawanAgent {
    config: PawanConfig,
    tools: ToolRegistry,
    /// Full conversation so far (system/user/assistant/tool messages).
    history: Vec<Message>,
    workspace_root: PathBuf,
    backend: Box<dyn LlmBackend>,
    /// Rough token estimate (content bytes / 4) used to trigger pruning.
    context_tokens_estimate: usize,
    /// Optional Eruka memory bridge; `None` when disabled in config.
    eruka: Option<crate::eruka_bridge::ErukaClient>,
    session_id: String,
    /// Contents of `.pawan/arch.md`, prepended to the user prompt.
    arch_context: Option<String>,
    /// Timestamp of the most recent tool call, for idle-timeout detection.
    last_tool_call_time: Option<Instant>,
}
/// Best-effort reachability probe for a local OpenAI-compatible endpoint.
///
/// Extracts `host[:port]` from `url` (defaulting the port to 443 for
/// `https://` URLs, 80 otherwise), resolves it, and attempts a TCP connect
/// with a 100 ms timeout. Returns `false` on any parse, resolution, or
/// connect failure.
fn probe_local_endpoint(url: &str) -> bool {
    use std::net::{TcpStream, ToSocketAddrs};
    use std::time::Duration;
    // Strip the scheme and anything after the first path separator.
    let hostport = url
        .trim_start_matches("http://")
        .trim_start_matches("https://")
        .split('/')
        .next()
        .unwrap_or("");
    if hostport.is_empty() {
        return false;
    }
    let addr = if hostport.contains(':') {
        hostport.to_string()
    } else if url.starts_with("https://") {
        format!("{hostport}:443")
    } else {
        format!("{hostport}:80")
    };
    // `SocketAddr::from_str` only accepts literal IPs; hostnames such as
    // `localhost` or a LAN machine name would always fail to parse and make
    // the probe report "unreachable". Resolve via the system resolver
    // instead, trying the first resolved address.
    let Ok(mut addrs) = addr.to_socket_addrs() else {
        return false;
    };
    let Some(socket_addr) = addrs.next() else {
        return false;
    };
    TcpStream::connect_timeout(&socket_addr, Duration::from_millis(100)).is_ok()
}
/// Resolves an API key, preferring the environment variable and falling
/// back to the OS credential store. A key found in the store is mirrored
/// into the environment so subsequent lookups hit the fast path. Store
/// errors are logged and treated as "not found".
fn get_api_key_with_secure_fallback(env_var: &str, key_name: &str) -> Option<String> {
    if let Ok(value) = std::env::var(env_var) {
        return Some(value);
    }
    let stored = match credentials::get_api_key(key_name) {
        Ok(found) => found,
        Err(err) => {
            tracing::warn!("Failed to retrieve {} from secure store: {}", key_name, err);
            None
        }
    };
    if let Some(ref value) = stored {
        std::env::set_var(env_var, value);
    }
    stored
}
/// Interactively prompts for a missing API key, stores it in the OS
/// credential store (best effort), and exports it into the environment.
///
/// Returns the entered key, or `None` when the user enters nothing or the
/// prompt fails. A failed secure-store write downgrades gracefully to a
/// session-only key (still exported to the environment).
fn prompt_and_store_api_key(env_var: &str, key_name: &str, provider: &str) -> Option<String> {
    eprintln!("\n🔑 {} API key not found.", provider);
    eprintln!("You can set it via:");
    eprintln!("  - Environment variable: export {}=<your-key>", env_var);
    eprintln!("  - Interactive entry (recommended for security)");
    eprintln!("\nEnter your {} API key:", provider);
    eprintln!("  (Your key will be stored securely in the OS credential store)\n");
    // The unix and windows paths were byte-identical duplicates; merged into
    // a single cfg so they cannot drift apart.
    #[cfg(any(unix, windows))]
    let key = {
        use std::io::{self, Write};
        let mut stdout = io::stdout();
        stdout.flush().ok();
        // Hidden (no-echo) password prompt.
        rpassword::prompt_password("> ").ok()
    };
    // Other platforms: fall back to a plain (echoed) stdin read.
    #[cfg(not(any(unix, windows)))]
    let key = {
        use std::io::{self, Write, BufRead};
        let mut stdout = io::stdout();
        let stdin = io::stdin();
        stdout.flush().ok();
        print!("> ");
        stdout.flush().ok();
        let mut input = String::new();
        stdin.lock().read_line(&mut input).ok();
        Some(input.trim().to_string())
    };
    match key {
        Some(k) if !k.trim().is_empty() => {
            let key = k.trim().to_string();
            match credentials::store_api_key(key_name, &key) {
                Ok(()) => {
                    tracing::info!("{} API key stored securely", provider);
                    std::env::set_var(env_var, &key);
                    Some(key)
                }
                Err(e) => {
                    // Secure storage unavailable: keep the key for this
                    // process only rather than failing the whole flow.
                    tracing::warn!("Failed to store key securely: {}. Using session-only.", e);
                    std::env::set_var(env_var, &key);
                    Some(key)
                }
            }
        }
        _ => {
            eprintln!("\n⚠️  No key entered. {} will not work until a key is set.", provider);
            None
        }
    }
}
/// Loads `.pawan/arch.md` from the workspace, if present and non-blank.
///
/// The content is capped at 2 000 characters (cut on a char boundary, with
/// a truncation marker appended). Returns `None` when the file is missing,
/// unreadable, or contains only whitespace.
fn load_arch_context(workspace_root: &std::path::Path) -> Option<String> {
    const MAX_CHARS: usize = 2_000;
    let path = workspace_root.join(".pawan").join("arch.md");
    let content = std::fs::read_to_string(&path).ok()?;
    if content.trim().is_empty() {
        return None;
    }
    if content.len() <= MAX_CHARS {
        return Some(content);
    }
    // Byte length exceeded the cap; find the byte offset of char number
    // MAX_CHARS so the slice stays on a char boundary.
    let boundary = content
        .char_indices()
        .nth(MAX_CHARS)
        .map(|(i, _)| i)
        .unwrap_or(content.len());
    Some(format!("{}…(truncated)", &content[..boundary]))
}
impl PawanAgent {
/// Builds an agent rooted at `workspace_root` with the default tool
/// registry, a fresh session id, empty history, and a backend chosen
/// from `config`. Eruka and the architecture context are optional and
/// loaded only when available.
pub fn new(config: PawanConfig, workspace_root: PathBuf) -> Self {
    let system_prompt = config.get_system_prompt();
    let backend = Self::create_backend(&config, &system_prompt);
    let eruka = config
        .eruka
        .enabled
        .then(|| crate::eruka_bridge::ErukaClient::new(config.eruka.clone()));
    Self {
        tools: ToolRegistry::with_defaults(workspace_root.clone()),
        history: Vec::new(),
        context_tokens_estimate: 0,
        session_id: uuid::Uuid::new_v4().to_string(),
        arch_context: load_arch_context(&workspace_root),
        last_tool_call_time: None,
        eruka,
        backend,
        config,
        workspace_root,
    }
}
/// Selects and constructs the LLM backend for `config`.
///
/// Resolution order:
/// 1. `local_first` — probe the configured (or default) local
///    OpenAI-compatible endpoint and use it when reachable.
/// 2. `use_ares_backend` — try the ares-crate-backed client.
/// 3. Provider match — NVIDIA / OpenAI / MLX through the OpenAI-compatible
///    backend (with optional cloud fallback), or native Ollama.
fn create_backend(config: &PawanConfig, system_prompt: &str) -> Box<dyn LlmBackend> {
    if config.local_first {
        let local_url = config
            .local_endpoint
            .clone()
            .unwrap_or_else(|| "http://localhost:11434/v1".to_string());
        if probe_local_endpoint(&local_url) {
            tracing::info!(
                url = %local_url,
                model = %config.model,
                "local_first: local server reachable, using local inference"
            );
            // Local inference: no API key, thinking mode disabled, no
            // fallback models and no cloud escape hatch.
            return Box::new(OpenAiCompatBackend::new(
                backend::openai_compat::OpenAiCompatConfig {
                    api_url: local_url,
                    api_key: None,
                    model: config.model.clone(),
                    temperature: config.temperature,
                    top_p: config.top_p,
                    max_tokens: config.max_tokens,
                    system_prompt: system_prompt.to_string(),
                    use_thinking: false,
                    max_retries: config.max_retries,
                    fallback_models: Vec::new(),
                    cloud: None,
                },
            ));
        }
        tracing::info!(
            url = %local_url,
            "local_first: local server unreachable, falling back to cloud provider"
        );
    }
    if config.use_ares_backend {
        if let Some(backend) = Self::try_create_ares_backend(config, system_prompt) {
            return backend;
        }
        tracing::warn!(
            "use_ares_backend=true but ares backend creation failed; \
             falling back to pawan's native backend"
        );
    }
    match config.provider {
        LlmProvider::Nvidia | LlmProvider::OpenAI | LlmProvider::Mlx => {
            // Resolve endpoint and key per provider; NVIDIA/OpenAI may
            // prompt interactively (and cache securely) for a missing key.
            let (api_url, api_key) = match config.provider {
                LlmProvider::Nvidia => {
                    let url = std::env::var("NVIDIA_API_URL")
                        .unwrap_or_else(|_| crate::DEFAULT_NVIDIA_API_URL.to_string());
                    let key = get_api_key_with_secure_fallback("NVIDIA_API_KEY", "nvidia_api_key");
                    let key = if key.is_none() {
                        prompt_and_store_api_key("NVIDIA_API_KEY", "nvidia_api_key", "NVIDIA")
                    } else {
                        key
                    };
                    if key.is_none() {
                        tracing::warn!("NVIDIA_API_KEY not set. Model calls will fail until a key is provided.");
                    }
                    (url, key)
                },
                LlmProvider::OpenAI => {
                    // Explicit base_url in config wins over the env var.
                    let url = config.base_url.clone()
                        .or_else(|| std::env::var("OPENAI_API_URL").ok())
                        .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
                    let key = get_api_key_with_secure_fallback("OPENAI_API_KEY", "openai_api_key");
                    let key = if key.is_none() {
                        prompt_and_store_api_key("OPENAI_API_KEY", "openai_api_key", "OpenAI")
                    } else {
                        key
                    };
                    (url, key)
                },
                LlmProvider::Mlx => {
                    // MLX serves locally; no key required.
                    let url = config.base_url.clone()
                        .unwrap_or_else(|| "http://localhost:8080/v1".to_string());
                    tracing::info!(url = %url, "Using MLX LM server (Apple Silicon native)");
                    (url, None) },
                // The outer arm already restricts to Nvidia/OpenAI/Mlx.
                _ => unreachable!(),
            };
            // Optional cloud fallback target. Unlike the primary path, this
            // never prompts interactively for a missing key.
            let cloud = config.cloud.as_ref().map(|c| {
                let (cloud_url, cloud_key) = match c.provider {
                    LlmProvider::Nvidia => {
                        let url = std::env::var("NVIDIA_API_URL")
                            .unwrap_or_else(|_| crate::DEFAULT_NVIDIA_API_URL.to_string());
                        let key = get_api_key_with_secure_fallback("NVIDIA_API_KEY", "nvidia_api_key");
                        (url, key)
                    },
                    LlmProvider::OpenAI => {
                        let url = std::env::var("OPENAI_API_URL")
                            .unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
                        let key = get_api_key_with_secure_fallback("OPENAI_API_KEY", "openai_api_key");
                        (url, key)
                    },
                    LlmProvider::Mlx => {
                        ("http://localhost:8080/v1".to_string(), None)
                    },
                    _ => {
                        // Unsupported fallback provider: degrade to the
                        // NVIDIA endpoint with no key rather than failing.
                        tracing::warn!("Cloud fallback only supports nvidia/openai/mlx providers");
                        ("https://integrate.api.nvidia.com/v1".to_string(), None)
                    }
                };
                backend::openai_compat::CloudFallback {
                    api_url: cloud_url,
                    api_key: cloud_key,
                    model: c.model.clone(),
                    fallback_models: c.fallback_models.clone(),
                }
            });
            Box::new(OpenAiCompatBackend::new(OpenAiCompatConfig {
                api_url,
                api_key,
                model: config.model.clone(),
                temperature: config.temperature,
                top_p: config.top_p,
                max_tokens: config.max_tokens,
                system_prompt: system_prompt.to_string(),
                // Thinking mode only when no explicit budget is set
                // (thinking_budget == 0) and the config enables it.
                use_thinking: config.thinking_budget == 0 && config.use_thinking_mode(),
                max_retries: config.max_retries,
                fallback_models: config.fallback_models.clone(),
                cloud,
            }))
        }
        LlmProvider::Ollama => {
            let url = std::env::var("OLLAMA_URL")
                .unwrap_or_else(|_| "http://localhost:11434".to_string());
            Box::new(backend::ollama::OllamaBackend::new(
                url,
                config.model.clone(),
                config.temperature,
                system_prompt.to_string(),
            ))
        }
    }
}
/// Attempts to build an `ares`-crate-backed LLM backend for the current
/// provider. Returns `None` when the provider is unsupported (Ollama) or
/// a required credential is missing.
fn try_create_ares_backend(
    config: &PawanConfig,
    system_prompt: &str,
) -> Option<Box<dyn LlmBackend>> {
    use ares::llm::client::{ModelParams, Provider};
    let params = ModelParams {
        temperature: Some(config.temperature),
        max_tokens: Some(config.max_tokens as u32),
        top_p: Some(config.top_p),
        frequency_penalty: None,
        presence_penalty: None,
    };
    let provider = match config.provider {
        LlmProvider::Nvidia => {
            let api_base = std::env::var("NVIDIA_API_URL")
                .unwrap_or_else(|_| crate::DEFAULT_NVIDIA_API_URL.to_string());
            // NOTE(review): a missing NVIDIA key aborts ares creation here
            // (`?`), while the OpenAI arm below tolerates an empty key —
            // confirm this asymmetry is intentional.
            let api_key = std::env::var("NVIDIA_API_KEY").ok()?;
            Provider::OpenAI {
                api_key,
                api_base,
                model: config.model.clone(),
                params,
            }
        }
        LlmProvider::OpenAI => {
            let api_base = config
                .base_url
                .clone()
                .or_else(|| std::env::var("OPENAI_API_URL").ok())
                .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
            let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
            Provider::OpenAI {
                api_key,
                api_base,
                model: config.model.clone(),
                params,
            }
        }
        LlmProvider::Mlx => {
            // Local MLX server: OpenAI-compatible API, no key needed.
            let api_base = config
                .base_url
                .clone()
                .unwrap_or_else(|| "http://localhost:8080/v1".to_string());
            Provider::OpenAI {
                api_key: String::new(),
                api_base,
                model: config.model.clone(),
                params,
            }
        }
        LlmProvider::Ollama => {
            // Ollama is not routed through ares; caller falls back.
            return None;
        }
    };
    let client: Box<dyn ares::llm::LLMClient> = match provider {
        Provider::OpenAI {
            api_key,
            api_base,
            model,
            params,
        } => Box::new(ares::llm::openai::OpenAIClient::with_params(
            api_key, api_base, model, params,
        )),
        // All arms above construct Provider::OpenAI; anything else is a
        // future ares provider we don't support yet.
        _ => return None,
    };
    tracing::info!(
        provider = ?config.provider,
        model = %config.model,
        "Using ares-backed LLM backend"
    );
    Some(Box::new(backend::ares_backend::AresBackend::new(
        client,
        system_prompt.to_string(),
    )))
}
/// Builder-style override of the tool registry.
pub fn with_tools(mut self, tools: ToolRegistry) -> Self {
    self.tools = tools;
    self
}
/// Mutable access to the tool registry (e.g. to register extra tools).
pub fn tools_mut(&mut self) -> &mut ToolRegistry {
    &mut self.tools
}
/// Builder-style override of the LLM backend (useful for tests/mocks).
pub fn with_backend(mut self, backend: Box<dyn LlmBackend>) -> Self {
    self.backend = backend;
    self
}
/// Read-only view of the conversation history.
pub fn history(&self) -> &[Message] {
    &self.history
}
/// Persists the current conversation as a new session and returns the
/// generated session id.
pub fn save_session(&self) -> Result<String> {
    let mut session = session::Session::new(&self.config.model);
    session.messages = self.history.clone();
    // NOTE(review): stores the chars/4 context estimate, not exact usage.
    session.total_tokens = self.context_tokens_estimate as u64;
    session.save()?;
    Ok(session.id)
}
/// Replaces the agent's state with a previously saved session's history,
/// token estimate, and id.
pub fn resume_session(&mut self, session_id: &str) -> Result<()> {
    let session = session::Session::load(session_id)?;
    self.history = session.messages;
    self.context_tokens_estimate = session.total_tokens as usize;
    self.session_id = session_id.to_string();
    Ok(())
}
/// Archives the current conversation to Eruka; a no-op `Ok(())` when the
/// Eruka bridge is disabled.
pub async fn archive_to_eruka(&self) -> Result<()> {
    let Some(eruka) = &self.eruka else {
        return Ok(());
    };
    // Build a transient session snapshot (not persisted locally here).
    let mut session = session::Session::new(&self.config.model);
    session.id = self.session_id.clone();
    session.messages = self.history.clone();
    session.total_tokens = self.context_tokens_estimate as u64;
    eruka.archive_session(&session).await
}
/// Renders a compact role-tagged transcript for Eruka's pre-compression
/// hook: each message contributes at most 200 chars, and the snapshot
/// stops growing once it passes ~4000 bytes.
fn history_snapshot_for_eruka(history: &[Message]) -> String {
    let mut snapshot = String::with_capacity(2048);
    for msg in history {
        let tag = match msg.role {
            Role::User => "U: ",
            Role::Assistant => "A: ",
            Role::Tool => "T: ",
            Role::System => "S: ",
        };
        snapshot.push_str(tag);
        snapshot.extend(msg.content.chars().take(200));
        snapshot.push('\n');
        if snapshot.len() > 4000 {
            break;
        }
    }
    snapshot
}
/// The agent's active configuration.
pub fn config(&self) -> &PawanConfig {
    &self.config
}
/// Discards the entire conversation history.
pub fn clear_history(&mut self) {
    self.history.clear();
}
/// Compacts the conversation once it grows past 5 messages: keeps the
/// first message and the last 4 verbatim, and replaces everything in
/// between with a single system "summary" message built from the
/// highest-importance pruned messages first.
fn prune_history(&mut self) {
    let len = self.history.len();
    // Nothing to prune until there is a middle beyond head + 4-message tail.
    if len <= 5 {
        return;
    }
    let keep_end = 4;
    // Prune the half-open range [1, len - keep_end): index 0 and the most
    // recent `keep_end` messages survive verbatim.
    let start = 1;
    let end = len - keep_end;
    let pruned_count = end - start;
    // Score the middle messages; higher score = earlier (more prominent)
    // placement in the summary — not verbatim preservation.
    let mut scored: Vec<(f32, &Message)> = self.history[start..end]
        .iter()
        .map(|msg| {
            let score = Self::message_importance(msg);
            (score, msg)
        })
        .collect();
    // Descending by score; NaN-safe comparison treats incomparables equal.
    scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
    let mut summary = String::with_capacity(2048);
    for (score, msg) in &scored {
        let prefix = match msg.role {
            Role::User => "User: ",
            Role::Assistant => "Assistant: ",
            // Per message_importance, tool messages above 0.7 are failures.
            Role::Tool => if *score > 0.7 { "Tool error: " } else { "Tool: " },
            Role::System => "System: ",
        };
        // Each message contributes at most 200 chars to the summary.
        let chunk: String = msg.content.chars().take(200).collect();
        summary.push_str(prefix);
        summary.push_str(&chunk);
        summary.push('\n');
        if summary.len() > 2000 {
            // Cap the summary near 2000 bytes, cutting on a char boundary.
            let safe_end = summary.char_indices()
                .take_while(|(i, _)| *i <= 2000)
                .last()
                .map(|(i, c)| i + c.len_utf8())
                .unwrap_or(0);
            summary.truncate(safe_end);
            break;
        }
    }
    let summary_msg = Message {
        role: Role::System,
        content: format!("Previous conversation summary (pruned {} messages, importance-ranked): {}", pruned_count, summary),
        tool_calls: vec![],
        tool_result: None,
    };
    // Replace the pruned middle with the single summary message.
    self.history.drain(start..end);
    self.history.insert(start, summary_msg);
    tracing::info!(pruned = pruned_count, context_estimate = self.context_tokens_estimate, "Pruned messages from history (importance-ranked)");
}
fn message_importance(msg: &Message) -> f32 {
match msg.role {
Role::User => 0.6, Role::System => 0.3, Role::Assistant => {
if msg.content.contains("error") || msg.content.contains("Error") { 0.8 }
else { 0.4 }
}
Role::Tool => {
if let Some(ref result) = msg.tool_result {
if !result.success { 0.9 } else { 0.2 } } else {
0.3
}
}
}
}
/// Appends a message to the conversation history without calling the LLM.
pub fn add_message(&mut self, message: Message) {
    self.history.push(message);
}
/// Switches the active model at runtime, rebuilding the backend so the
/// new model (with a refreshed system prompt) takes effect immediately.
pub fn switch_model(&mut self, model: &str) {
    self.config.model = model.to_string();
    let system_prompt = self.config.get_system_prompt();
    self.backend = Self::create_backend(&self.config, &system_prompt);
    tracing::info!(model = model, "Model switched at runtime");
}
/// Name of the currently configured model.
pub fn model_name(&self) -> &str {
    &self.config.model
}
/// Definitions of every registered tool (unfiltered).
pub fn get_tool_definitions(&self) -> Vec<ToolDefinition> {
    self.tools.get_definitions()
}
/// Runs one user turn with no callbacks; see `execute_with_all_callbacks`
/// for the full loop semantics.
pub async fn execute(&mut self, user_prompt: &str) -> Result<AgentResponse> {
    self.execute_with_callbacks(user_prompt, None, None, None)
        .await
}
/// Runs one user turn with streaming/tool callbacks but no interactive
/// permission prompt (tools needing approval are auto-denied with a hint).
pub async fn execute_with_callbacks(
    &mut self,
    user_prompt: &str,
    on_token: Option<TokenCallback>,
    on_tool: Option<ToolCallback>,
    on_tool_start: Option<ToolStartCallback>,
) -> Result<AgentResponse> {
    self.execute_with_all_callbacks(user_prompt, on_token, on_tool, on_tool_start, None)
        .await
}
pub async fn execute_with_all_callbacks(
&mut self,
user_prompt: &str,
on_token: Option<TokenCallback>,
on_tool: Option<ToolCallback>,
on_tool_start: Option<ToolStartCallback>,
on_permission: Option<PermissionCallback>,
) -> Result<AgentResponse> {
if self.config.use_coordinator {
if on_token.is_some() || on_tool.is_some() || on_tool_start.is_some() || on_permission.is_some() {
tracing::warn!(
"Callbacks and permission prompts are not supported in coordinator mode; ignoring them"
);
}
return self.execute_with_coordinator(user_prompt).await;
}
self.last_tool_call_time = None;
if let Some(eruka) = &self.eruka {
if let Err(e) = eruka.inject_core_memory(&mut self.history).await {
tracing::warn!("Eruka memory injection failed (non-fatal): {}", e);
}
match eruka.prefetch(user_prompt, 2000).await {
Ok(Some(ctx)) => {
self.history.push(Message {
role: Role::System,
content: ctx,
tool_calls: vec![],
tool_result: None,
});
}
Ok(None) => {}
Err(e) => tracing::warn!("Eruka prefetch failed (non-fatal): {}", e),
}
}
let effective_prompt = match &self.arch_context {
Some(ctx) => format!(
"[Workspace Architecture]\n{ctx}\n[/Workspace Architecture]\n\n{user_prompt}"
),
None => user_prompt.to_string(),
};
self.history.push(Message {
role: Role::User,
content: effective_prompt,
tool_calls: vec![],
tool_result: None,
});
let mut all_tool_calls = Vec::new();
let mut total_usage = TokenUsage::default();
let mut iterations = 0;
let max_iterations = self.config.max_tool_iterations;
loop {
if let Some(last_time) = self.last_tool_call_time {
let elapsed = last_time.elapsed().as_secs();
if elapsed > self.config.tool_call_idle_timeout_secs {
return Err(PawanError::Agent(format!(
"Tool idle timeout exceeded ({}s > {}s)",
elapsed, self.config.tool_call_idle_timeout_secs
)));
}
}
iterations += 1;
if iterations > max_iterations {
return Err(PawanError::Agent(format!(
"Max tool iterations ({}) exceeded",
max_iterations
)));
}
let remaining = max_iterations.saturating_sub(iterations);
if remaining == 3 && iterations > 1 {
self.history.push(Message {
role: Role::User,
content: format!(
"[SYSTEM] You have {} tool iterations remaining. \
Stop exploring and write the most important output now. \
If you have code to write, write it immediately.",
remaining
),
tool_calls: vec![],
tool_result: None,
});
}
self.context_tokens_estimate = self.history.iter().map(|m| m.content.len()).sum::<usize>() / 4;
if self.context_tokens_estimate > self.config.max_context_tokens {
if let Some(eruka) = &self.eruka {
let snapshot = Self::history_snapshot_for_eruka(&self.history);
if let Err(e) = eruka.on_pre_compress(&snapshot, &self.session_id).await {
tracing::warn!("Eruka on_pre_compress failed (non-fatal): {}", e);
}
}
self.prune_history();
}
let latest_query = self.history.iter().rev()
.find(|m| m.role == Role::User)
.map(|m| m.content.as_str())
.unwrap_or("");
let tool_defs = self.tools.select_for_query(latest_query, 12);
if iterations == 1 {
let tool_names: Vec<&str> = tool_defs.iter().map(|t| t.name.as_str()).collect();
tracing::info!(tools = ?tool_names, count = tool_defs.len(), "Selected tools for query");
}
self.last_tool_call_time = Some(Instant::now());
let response = {
#[allow(unused_assignments)]
let mut last_err = None;
let max_llm_retries = 3;
let mut attempt = 0;
loop {
attempt += 1;
match self.backend.generate(&self.history, &tool_defs, on_token.as_ref()).await {
Ok(resp) => break resp,
Err(e) => {
let err_str = e.to_string();
let is_transient = err_str.contains("timeout")
|| err_str.contains("connection")
|| err_str.contains("429")
|| err_str.contains("500")
|| err_str.contains("502")
|| err_str.contains("503")
|| err_str.contains("504")
|| err_str.contains("reset")
|| err_str.contains("broken pipe");
if is_transient && attempt <= max_llm_retries {
let delay = std::time::Duration::from_secs(2u64.pow(attempt as u32));
tracing::warn!(
attempt = attempt,
delay_secs = delay.as_secs(),
error = err_str.as_str(),
"LLM call failed (transient) — retrying"
);
tokio::time::sleep(delay).await;
if err_str.contains("context") || err_str.contains("token") {
tracing::info!("Pruning history before retry (possible context overflow)");
if let Some(eruka) = &self.eruka {
let snapshot = Self::history_snapshot_for_eruka(&self.history);
if let Err(e) = eruka.on_pre_compress(&snapshot, &self.session_id).await {
tracing::warn!("Eruka on_pre_compress failed (non-fatal): {}", e);
}
}
self.prune_history();
}
continue;
}
last_err = Some(e);
break {
tracing::error!(
attempt = attempt,
error = last_err.as_ref().map(|e| e.to_string()).unwrap_or_default().as_str(),
"LLM call failed permanently — returning error as content"
);
LLMResponse {
content: format!(
"LLM error after {} attempts: {}. The task could not be completed.",
attempt,
last_err.as_ref().map(|e| e.to_string()).unwrap_or_default()
),
reasoning: None,
tool_calls: vec![],
finish_reason: "error".to_string(),
usage: None,
}
};
}
}
}
};
if let Some(ref usage) = response.usage {
total_usage.prompt_tokens += usage.prompt_tokens;
total_usage.completion_tokens += usage.completion_tokens;
total_usage.total_tokens += usage.total_tokens;
total_usage.reasoning_tokens += usage.reasoning_tokens;
total_usage.action_tokens += usage.action_tokens;
if usage.reasoning_tokens > 0 {
tracing::info!(
iteration = iterations,
think = usage.reasoning_tokens,
act = usage.action_tokens,
total = usage.completion_tokens,
"Token budget: think:{} act:{} (total:{})",
usage.reasoning_tokens, usage.action_tokens, usage.completion_tokens
);
}
let thinking_budget = self.config.thinking_budget;
if thinking_budget > 0 && usage.reasoning_tokens > thinking_budget as u64 {
tracing::warn!(
budget = thinking_budget,
actual = usage.reasoning_tokens,
"Thinking budget exceeded ({}/{} tokens)",
usage.reasoning_tokens, thinking_budget
);
}
}
let clean_content = {
let mut s = response.content.clone();
loop {
let lower = s.to_lowercase();
let open = lower.find("<think>");
let close = lower.find("</think>");
match (open, close) {
(Some(i), Some(j)) if j > i => {
let before = s[..i].trim_end().to_string();
let after = if s.len() > j + 8 { s[j + 8..].trim_start().to_string() } else { String::new() };
s = if before.is_empty() { after } else if after.is_empty() { before } else { format!("{}\n{}", before, after) };
}
_ => break,
}
}
s
};
if response.tool_calls.is_empty() {
let has_tools = !tool_defs.is_empty();
let lower = clean_content.to_lowercase();
let planning_prefix = lower.starts_with("let me")
|| lower.starts_with("i'll help")
|| lower.starts_with("i will help")
|| lower.starts_with("sure, i")
|| lower.starts_with("okay, i");
let looks_like_planning = clean_content.len() > 200 || (planning_prefix && clean_content.len() > 50);
if has_tools && looks_like_planning && iterations == 1 && iterations < max_iterations && response.finish_reason != "error" {
tracing::warn!(
"No tool calls at iteration {} (content: {}B) — nudging model to use tools",
iterations, clean_content.len()
);
self.history.push(Message {
role: Role::Assistant,
content: clean_content.clone(),
tool_calls: vec![],
tool_result: None,
});
self.history.push(Message {
role: Role::User,
content: "You must use tools to complete this task. Do NOT just describe what you would do — actually call the tools. Start with bash or read_file.".to_string(),
tool_calls: vec![],
tool_result: None,
});
continue;
}
if iterations > 1 {
let prev_assistant = self.history.iter().rev()
.find(|m| m.role == Role::Assistant && !m.content.is_empty());
if let Some(prev) = prev_assistant {
if prev.content.trim() == clean_content.trim() && iterations < max_iterations {
tracing::warn!("Repeated response detected at iteration {} — injecting correction", iterations);
self.history.push(Message {
role: Role::Assistant,
content: clean_content.clone(),
tool_calls: vec![],
tool_result: None,
});
self.history.push(Message {
role: Role::User,
content: "You gave the same response as before. Try a different approach. Use anchor_text in edit_file_lines, or use insert_after, or use bash with sed.".to_string(),
tool_calls: vec![],
tool_result: None,
});
continue;
}
}
}
self.history.push(Message {
role: Role::Assistant,
content: clean_content.clone(),
tool_calls: vec![],
tool_result: None,
});
if let Some(eruka) = &self.eruka {
if let Err(e) = eruka
.sync_turn(user_prompt, &clean_content, &self.session_id)
.await
{
tracing::warn!("Eruka sync_turn failed (non-fatal): {}", e);
}
}
return Ok(AgentResponse {
content: clean_content,
tool_calls: all_tool_calls,
iterations,
usage: total_usage,
});
}
self.history.push(Message {
role: Role::Assistant,
content: response.content.clone(),
tool_calls: response.tool_calls.clone(),
tool_result: None,
});
for tool_call in &response.tool_calls {
self.tools.activate(&tool_call.name);
let perm = crate::config::ToolPermission::resolve(
&tool_call.name, &self.config.permissions
);
let denied = match perm {
crate::config::ToolPermission::Deny => Some("Tool denied by permission policy"),
crate::config::ToolPermission::Prompt => {
if tool_call.name == "bash" {
if let Some(cmd) = tool_call.arguments.get("command").and_then(|v| v.as_str()) {
if crate::tools::bash::is_read_only(cmd) {
tracing::debug!(command = cmd, "Auto-allowing read-only bash command under Prompt permission");
None
} else if let Some(ref perm_cb) = on_permission {
let args_summary = cmd.chars().take(120).collect::<String>();
let rx = perm_cb(PermissionRequest {
tool_name: tool_call.name.clone(),
args_summary,
});
match rx.await {
Ok(true) => None,
_ => Some("User denied tool execution"),
}
} else {
Some("Bash command requires user approval (read-only commands auto-allowed)")
}
} else {
Some("Tool requires user approval")
}
} else if let Some(ref perm_cb) = on_permission {
let args_summary = tool_call.arguments.to_string().chars().take(120).collect::<String>();
let rx = perm_cb(PermissionRequest {
tool_name: tool_call.name.clone(),
args_summary,
});
match rx.await {
Ok(true) => None,
_ => Some("User denied tool execution"),
}
} else {
Some("Tool requires user approval (set permission to 'allow' or use TUI mode)")
}
}
crate::config::ToolPermission::Allow => None,
};
if let Some(reason) = denied {
let record = ToolCallRecord {
id: tool_call.id.clone(),
name: tool_call.name.clone(),
arguments: tool_call.arguments.clone(),
result: json!({"error": reason}),
success: false,
duration_ms: 0,
};
if let Some(ref callback) = on_tool {
callback(&record);
}
all_tool_calls.push(record);
self.history.push(Message {
role: Role::Tool,
content: format!("{{\"error\": \"{}\"}}", reason),
tool_calls: vec![],
tool_result: Some(ToolResultMessage {
tool_call_id: tool_call.id.clone(),
content: json!({"error": reason}),
success: false,
}),
});
continue;
}
if let Some(ref callback) = on_tool_start {
callback(&tool_call.name);
}
tracing::debug!(
tool = tool_call.name.as_str(),
args_len = serde_json::to_string(&tool_call.arguments).unwrap_or_default().len(),
"Tool call: {}({})",
tool_call.name,
serde_json::to_string(&tool_call.arguments)
.unwrap_or_default()
.chars()
.take(200)
.collect::<String>()
);
if let Some(tool) = self.tools.get(&tool_call.name) {
let schema = tool.parameters_schema();
if let Ok(params) = thulp_core::ToolDefinition::parse_mcp_input_schema(&schema) {
let thulp_def = thulp_core::ToolDefinition {
name: tool_call.name.clone(),
description: String::new(),
parameters: params,
};
if let Err(e) = thulp_def.validate_args(&tool_call.arguments) {
tracing::warn!(
tool = tool_call.name.as_str(),
error = %e,
"Tool argument validation failed (continuing anyway)"
);
}
}
}
let start = std::time::Instant::now();
let tool = self.tools.get(&tool_call.name);
let is_mutating = tool.map(|t| t.mutating()).unwrap_or(false);
if is_mutating {
if let Some(ref callback) = on_permission {
let args_summary = summarize_args(&tool_call.arguments);
let request = PermissionRequest {
tool_name: tool_call.name.clone(),
args_summary,
};
let permission_rx = (callback)(request);
match permission_rx.await {
Ok(true) => {
}
Ok(false) => {
tracing::info!(tool = tool_call.name.as_str(), "Tool execution denied by user");
let record = ToolCallRecord {
id: tool_call.id.clone(),
name: tool_call.name.clone(),
arguments: tool_call.arguments.clone(),
result: json!({"error": "Tool execution denied by user", "tool": tool_call.name}),
success: false,
duration_ms: 0,
};
if let Some(ref callback) = on_tool {
callback(&record);
}
continue;
}
Err(_) => {
let record = ToolCallRecord {
id: tool_call.id.clone(),
name: tool_call.name.clone(),
arguments: tool_call.arguments.clone(),
result: json!({"error": "Permission channel closed", "tool": tool_call.name}),
success: false,
duration_ms: 0,
};
if let Some(ref callback) = on_tool {
callback(&record);
}
continue;
}
}
} else {
tracing::warn!(tool = tool_call.name.as_str(), "No permission callback, auto-approving mutating tool");
}
}
let result = {
let tool_future = self.tools.execute(&tool_call.name, tool_call.arguments.clone());
let timeout_dur = if tool_call.name == "bash" {
std::time::Duration::from_secs(self.config.bash_timeout_secs)
} else {
std::time::Duration::from_secs(30)
};
match tokio::time::timeout(timeout_dur, tool_future).await {
Ok(inner) => inner,
Err(_) => Err(PawanError::Tool(format!(
"Tool '{}' timed out after {}s", tool_call.name, timeout_dur.as_secs()
))),
}
};
let duration_ms = start.elapsed().as_millis() as u64;
let (result_value, success) = match result {
Ok(v) => (v, true),
Err(e) => {
tracing::warn!(tool = tool_call.name.as_str(), error = %e, "Tool execution failed");
(json!({"error": e.to_string(), "tool": tool_call.name, "hint": "Try a different approach or tool"}), false)
}
};
let max_result_chars = self.config.max_result_chars;
let result_value = truncate_tool_result(result_value, max_result_chars);
let record = ToolCallRecord {
id: tool_call.id.clone(),
name: tool_call.name.clone(),
arguments: tool_call.arguments.clone(),
result: result_value.clone(),
success,
duration_ms,
};
if let Some(ref callback) = on_tool {
callback(&record);
}
all_tool_calls.push(record);
self.history.push(Message {
role: Role::Tool,
content: serde_json::to_string(&result_value).unwrap_or_default(),
tool_calls: vec![],
tool_result: Some(ToolResultMessage {
tool_call_id: tool_call.id.clone(),
content: result_value,
success,
}),
});
if success && tool_call.name == "write_file" {
let wrote_rs = tool_call.arguments.get("path")
.and_then(|p| p.as_str())
.map(|p| p.ends_with(".rs"))
.unwrap_or(false);
if wrote_rs {
let ws = self.workspace_root.clone();
let check_result = tokio::process::Command::new("cargo")
.arg("check")
.arg("--message-format=short")
.current_dir(&ws)
.output()
.await;
match check_result {
Ok(output) if !output.status.success() => {
let stderr = String::from_utf8_lossy(&output.stderr);
let err_msg: String = stderr.chars().take(1500).collect();
tracing::info!("Compile-gate: cargo check failed after write_file, injecting errors");
self.history.push(Message {
role: Role::User,
content: format!(
"[SYSTEM] cargo check failed after your write_file. Fix the errors:\n```\n{}\n```",
err_msg
),
tool_calls: vec![],
tool_result: None,
});
}
Ok(_) => {
tracing::debug!("Compile-gate: cargo check passed");
}
Err(e) => {
tracing::warn!("Compile-gate: cargo check failed to run: {}", e);
}
}
}
}
}
}
}
/// Runs one user turn through the standalone `ToolCoordinator` pipeline
/// (an alternative to the agent's built-in tool-calling loop).
///
/// Sequencing matters here: Eruka memory/context is injected into
/// `self.history` *before* the coordinator runs, and the finished turn is
/// synced back to Eruka *after* a response is produced. Every Eruka failure
/// is logged and swallowed (non-fatal).
async fn execute_with_coordinator(&mut self, user_prompt: &str) -> Result<AgentResponse> {
    // Reset idle-timeout bookkeeping for the new turn.
    self.last_tool_call_time = None;
    if let Some(eruka) = &self.eruka {
        // Core-memory injection is best-effort; a failure only warns.
        if let Err(e) = eruka.inject_core_memory(&mut self.history).await {
            tracing::warn!("Eruka memory injection failed (non-fatal): {}", e);
        }
        // Prefetch up to 2000 chars of context relevant to this prompt and,
        // when anything was found, surface it to the model as a system message.
        match eruka.prefetch(user_prompt, 2000).await {
            Ok(Some(ctx)) => {
                self.history.push(Message {
                    role: Role::System,
                    content: ctx,
                    tool_calls: vec![],
                    tool_result: None,
                });
            }
            Ok(None) => {}
            Err(e) => tracing::warn!("Eruka prefetch failed (non-fatal): {}", e),
        }
    }
    // Prepend cached workspace-architecture notes (if loaded) to the prompt.
    let effective_prompt = match &self.arch_context {
        Some(ctx) => format!(
            "[Workspace Architecture]\n{ctx}\n[/Workspace Architecture]\n\n{user_prompt}"
        ),
        None => user_prompt.to_string(),
    };
    // Coordinator knobs come from the agent config; parallel tool execution
    // is always on and errors never abort the loop here.
    let coordinator_config = ToolCallingConfig {
        max_iterations: self.config.max_tool_iterations,
        parallel_execution: true,
        tool_timeout: std::time::Duration::from_secs(self.config.bash_timeout_secs),
        stop_on_error: false,
    };
    let system_prompt = self.config.get_system_prompt();
    let backend = Self::create_backend(&self.config, &system_prompt);
    let backend = Arc::from(backend);
    // NOTE(review): this builds a fresh default ToolRegistry rather than
    // reusing the agent's own registry — confirm that coordinator runs are
    // meant to ignore custom tools installed via with_tools().
    let registry = Arc::new(ToolRegistry::with_defaults(self.workspace_root.clone()));
    let coordinator = ToolCoordinator::new(backend, registry, coordinator_config);
    let result: CoordinatorResult = coordinator
        .execute(Some(&system_prompt), &effective_prompt)
        .await
        .map_err(|e| PawanError::Agent(format!("Coordinator execution failed: {}", e)))?;
    // Keep a copy of the final text for the post-turn Eruka sync below,
    // since `result.content` is moved into the response.
    let content = result.content.clone();
    let agent_response = AgentResponse {
        content: result.content,
        tool_calls: result.tool_calls,
        iterations: result.iterations,
        usage: result.total_usage,
    };
    if let Some(eruka) = &self.eruka {
        // Archive the completed turn; failures are non-fatal.
        if let Err(e) = eruka
            .sync_turn(user_prompt, &content, &self.session_id)
            .await
        {
            tracing::warn!("Eruka sync_turn failed (non-fatal): {}", e);
        }
    }
    Ok(agent_response)
}
/// Scans the workspace for compiler diagnostics and failing tests, then asks
/// the agent to fix whatever was found.
///
/// The prompt always opens with the workspace path and closes with a
/// fix-one-at-a-time instruction; the diagnostics and failed-test sections
/// are only included when non-empty.
pub async fn heal(&mut self) -> Result<AgentResponse> {
    let healer = crate::healing::Healer::new(
        self.workspace_root.clone(),
        self.config.healing.clone(),
    );
    let diagnostics = healer.get_diagnostics().await?;
    let failed_tests = healer.get_failed_tests().await?;

    // Build each optional section up front; `None` means "nothing to report".
    let compile_section = (!diagnostics.is_empty()).then(|| {
        format!(
            "## Compilation Issues ({} found)\n{}\n",
            diagnostics.len(),
            healer.format_diagnostics_for_prompt(&diagnostics)
        )
    });
    let test_section = (!failed_tests.is_empty()).then(|| {
        format!(
            "## Failed Tests ({} found)\n{}\n",
            failed_tests.len(),
            healer.format_tests_for_prompt(&failed_tests)
        )
    });

    let mut prompt = format!(
        "I need you to heal this Rust project at: {}\n",
        self.workspace_root.display()
    );
    if let Some(section) = &compile_section {
        prompt.push_str(section);
    }
    if let Some(section) = &test_section {
        prompt.push_str(section);
    }
    if compile_section.is_none() && test_section.is_none() {
        prompt.push_str("No issues found! Run cargo check and cargo test to verify.\n");
    }
    prompt.push_str("\nFix each issue one at a time. Verify with cargo check after each fix.");
    self.execute(&prompt).await
}
/// Runs `heal()` repeatedly until the workspace passes both verification
/// stages or the retry budget runs out.
///
/// Stage 1: `cargo check` must report zero errors.
/// Stage 2: the optional `verify_cmd` from the healing config must pass.
///
/// Anti-thrash guard: every remaining error is fingerprinted; when an error's
/// fingerprint survives `max_attempts` rounds unchanged, the loop halts early
/// so the agent does not keep rewriting the same code without progress.
///
/// Returns the most recent heal response in every exit path.
pub async fn heal_with_retries(&mut self, max_attempts: usize) -> Result<AgentResponse> {
    use std::collections::{HashMap, HashSet};
    let mut last_response = self.heal().await?;
    // fingerprint -> number of consecutive rounds this error has survived
    let mut stuck_counts: HashMap<u64, usize> = HashMap::new();
    for attempt in 1..max_attempts {
        let fixer = crate::healing::CompilerFixer::new(self.workspace_root.clone());
        let remaining = fixer.check().await?;
        let errors: Vec<_> = remaining
            .iter()
            .filter(|d| d.kind == crate::healing::DiagnosticKind::Error)
            .collect();
        if !errors.is_empty() {
            let current_fps: HashSet<u64> = errors.iter().map(|d| d.fingerprint()).collect();
            // Drop counters for errors that disappeared this round.
            stuck_counts.retain(|fp, _| current_fps.contains(fp));
            // BUG FIX: this loop previously read `for fp in ¤t_fps` — the
            // `&curren` of `&current_fps` had been mangled into the `¤`
            // character by an HTML-entity pass, which does not compile.
            for fp in &current_fps {
                *stuck_counts.entry(*fp).or_insert(0) += 1;
            }
            let thrashing: Vec<u64> = stuck_counts
                .iter()
                .filter_map(|(&fp, &count)| if count >= max_attempts { Some(fp) } else { None })
                .collect();
            if !thrashing.is_empty() {
                tracing::warn!(
                    stuck_fingerprints = thrashing.len(),
                    attempt,
                    "Anti-thrash: {} error(s) unchanged after {} attempts, halting heal loop",
                    thrashing.len(),
                    max_attempts
                );
                return Ok(last_response);
            }
            tracing::warn!(
                errors = errors.len(),
                attempt,
                "Stage 1 (cargo check): errors remain, retrying"
            );
            last_response = self.heal().await?;
            continue;
        }
        // Stage 1 clean: reset the thrash counters before stage 2.
        stuck_counts.clear();
        let verify_cmd = self.config.healing.verify_cmd.clone();
        if let Some(ref cmd) = verify_cmd {
            match crate::healing::run_verify_cmd(&self.workspace_root, cmd).await {
                Ok(None) => {
                    tracing::info!(attempts = attempt, "Stage 2 (verify_cmd) passed, healing complete");
                    return Ok(last_response);
                }
                Ok(Some(diag)) => {
                    tracing::warn!(
                        attempt,
                        cmd,
                        output = diag.raw,
                        "Stage 2 (verify_cmd) failed, retrying"
                    );
                    last_response = self.heal().await?;
                    continue;
                }
                Err(e) => {
                    // The command itself could not run (e.g. missing binary):
                    // treat stage 2 as unverifiable rather than failed.
                    tracing::warn!(cmd, error = %e, "verify_cmd could not be run, skipping stage 2");
                    return Ok(last_response);
                }
            }
        } else {
            tracing::info!(attempts = attempt, "Stage 1 (cargo check) passed, healing complete");
            return Ok(last_response);
        }
    }
    tracing::info!(attempts = max_attempts, "Healing finished (may still have errors)");
    Ok(last_response)
}
/// Asks the agent to carry out a free-form coding task in the workspace,
/// following an explore → change → `cargo check` → test workflow.
pub async fn task(&mut self, task_description: &str) -> Result<AgentResponse> {
    let workspace = self.workspace_root.display();
    let prompt = format!(
        r#"I need you to complete the following coding task:
{task_description}
The workspace is at: {workspace}
Please:
1. First explore the codebase to understand the relevant code
2. Make the necessary changes
3. Verify the changes compile with `cargo check`
4. Run relevant tests if applicable
Explain your changes as you go."#
    );
    self.execute(&prompt).await
}
/// Has the agent inspect the git state (`git status` / `git diff`) and draft
/// a conventional-commits message; returns only the message text.
pub async fn generate_commit_message(&mut self) -> Result<String> {
    const PROMPT: &str = r#"Please:
1. Run `git status` to see what files are changed
2. Run `git diff --cached` to see staged changes (or `git diff` for unstaged)
3. Generate a concise, descriptive commit message following conventional commits format
Only output the suggested commit message, nothing else."#;
    self.execute(PROMPT).await.map(|response| response.content)
}
}
/// Caps the serialized size of a tool result before it is stored in history.
///
/// If the JSON-serialized `value` already fits within `max_chars` it is
/// returned unchanged. Otherwise:
/// - objects: string fields longer than 500 bytes are trimmed proportionally
///   to their share of the oversized payload (never below 200 chars); other
///   fields are truncated recursively;
/// - bare strings longer than `max_chars`: trimmed to `max_chars` characters;
/// - arrays: items are kept until the running serialized length exceeds
///   `max_chars`, then a marker reporting how many items were dropped is
///   appended.
///
/// Truncation iterates `chars()` so multi-byte UTF-8 sequences are never
/// split, and every truncated value carries an explicit "...[truncated]"
/// marker so the model knows data is missing.
fn truncate_tool_result(value: Value, max_chars: usize) -> Value {
    let serialized = serde_json::to_string(&value).unwrap_or_default();
    if serialized.len() <= max_chars {
        return value;
    }
    match value {
        Value::Object(map) => {
            let mut result = serde_json::Map::new();
            let total = serialized.len();
            for (k, v) in map {
                if let Value::String(s) = &v {
                    if s.len() > 500 {
                        // Budget this field proportionally to its share of the
                        // oversized payload, but keep at least 200 chars.
                        let target = (s.len() * max_chars / total).max(200);
                        let truncated: String = s.chars().take(target).collect();
                        result.insert(
                            k,
                            json!(format!("{}...[truncated from {} chars]", truncated, s.len())),
                        );
                        continue;
                    }
                }
                result.insert(k, truncate_tool_result(v, max_chars));
            }
            Value::Object(result)
        }
        Value::String(s) if s.len() > max_chars => {
            let truncated: String = s.chars().take(max_chars).collect();
            json!(format!("{}...[truncated from {} chars]", truncated, s.len()))
        }
        // No size guard needed: the early return above guarantees the
        // serialized form is already over budget here.
        Value::Array(arr) => {
            let total_items = arr.len();
            let mut result = Vec::new();
            let mut running_len = 2; // account for the surrounding "[]"
            for (idx, item) in arr.into_iter().enumerate() {
                let item_str = serde_json::to_string(&item).unwrap_or_default();
                running_len += item_str.len() + 1; // +1 for the separating comma
                if running_len > max_chars {
                    // BUG FIX: the marker previously hard-coded 0, always
                    // reporting "0 more items truncated"; report the real
                    // number of dropped items instead.
                    result.push(json!(format!(
                        "...[{} more items truncated]",
                        total_items - idx
                    )));
                    break;
                }
                result.push(item);
            }
            Value::Array(result)
        }
        other => other,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use crate::agent::backend::mock::{MockBackend, MockResponse};

    // --- Message type serialization ---

    #[test]
    fn test_message_serialization() {
        let msg = Message {
            role: Role::User,
            content: "Hello".to_string(),
            tool_calls: vec![],
            tool_result: None,
        };
        let json = serde_json::to_string(&msg).expect("Serialization failed");
        // Role serializes lowercase via #[serde(rename_all = "lowercase")].
        assert!(json.contains("user"));
        assert!(json.contains("Hello"));
    }

    #[test]
    fn test_tool_call_request() {
        let tc = ToolCallRequest {
            id: "123".to_string(),
            name: "read_file".to_string(),
            arguments: json!({"path": "test.txt"}),
        };
        let json = serde_json::to_string(&tc).expect("Serialization failed");
        assert!(json.contains("read_file"));
        assert!(json.contains("test.txt"));
    }

    // Builds an agent whose history holds exactly `n` messages: one system
    // prompt followed by n-1 alternating user/assistant messages.
    fn agent_with_messages(n: usize) -> PawanAgent {
        let config = PawanConfig::default();
        let mut agent = PawanAgent::new(config, PathBuf::from("."));
        agent.add_message(Message {
            role: Role::System,
            content: "System prompt".to_string(),
            tool_calls: vec![],
            tool_result: None,
        });
        for i in 1..n {
            agent.add_message(Message {
                role: if i % 2 == 1 { Role::User } else { Role::Assistant },
                content: format!("Message {}", i),
                tool_calls: vec![],
                tool_result: None,
            });
        }
        assert_eq!(agent.history().len(), n);
        agent
    }

    // --- History pruning ---

    #[test]
    fn test_prune_history_no_op_when_small() {
        let mut agent = agent_with_messages(5);
        agent.prune_history();
        assert_eq!(agent.history().len(), 5, "Should not prune <= 5 messages");
    }

    #[test]
    fn test_prune_history_reduces_messages() {
        let mut agent = agent_with_messages(12);
        assert_eq!(agent.history().len(), 12);
        agent.prune_history();
        // Pruned shape: system + summary + last 4 messages.
        assert_eq!(agent.history().len(), 6);
    }

    #[test]
    fn test_prune_history_preserves_system_prompt() {
        let mut agent = agent_with_messages(10);
        let original_system = agent.history()[0].content.clone();
        agent.prune_history();
        assert_eq!(agent.history()[0].content, original_system, "System prompt must survive pruning");
    }

    #[test]
    fn test_prune_history_preserves_last_messages() {
        let mut agent = agent_with_messages(10);
        let last4: Vec<String> = agent.history()[6..10].iter().map(|m| m.content.clone()).collect();
        agent.prune_history();
        // After pruning the last 4 originals sit at indices 2..6
        // (behind the system prompt and the inserted summary).
        let after_last4: Vec<String> = agent.history()[2..6].iter().map(|m| m.content.clone()).collect();
        assert_eq!(last4, after_last4, "Last 4 messages must be preserved after pruning");
    }

    #[test]
    fn test_prune_history_inserts_summary() {
        let mut agent = agent_with_messages(10);
        agent.prune_history();
        // The summary takes slot 1, right after the system prompt.
        assert_eq!(agent.history()[1].role, Role::System);
        assert!(agent.history()[1].content.contains("summary"), "Summary message should contain 'summary'");
    }

    #[test]
    fn test_prune_history_utf8_safe() {
        let config = PawanConfig::default();
        let mut agent = PawanAgent::new(config, PathBuf::from("."));
        agent.add_message(Message {
            role: Role::System, content: "sys".into(), tool_calls: vec![], tool_result: None,
        });
        // Multi-byte content ensures pruning never slices mid-character.
        for _ in 0..10 {
            agent.add_message(Message {
                role: Role::User,
                content: "こんにちは世界 🌍 ".repeat(50),
                tool_calls: vec![],
                tool_result: None,
            });
        }
        agent.prune_history();
        assert!(agent.history().len() < 11, "Should have pruned");
        let summary = &agent.history()[1].content;
        assert!(summary.is_char_boundary(0));
    }

    #[test]
    fn test_prune_history_exactly_6_messages() {
        // 6 is the post-prune size, so pruning at 6 must be a no-op.
        let mut agent = agent_with_messages(6);
        agent.prune_history();
        assert_eq!(agent.history().len(), 6);
    }

    #[test]
    fn test_message_role_roundtrip() {
        for role in [Role::User, Role::Assistant, Role::System, Role::Tool] {
            let json = serde_json::to_string(&role).unwrap();
            let back: Role = serde_json::from_str(&json).unwrap();
            assert_eq!(role, back);
        }
    }

    #[test]
    fn test_agent_response_construction() {
        let resp = AgentResponse {
            content: String::new(),
            tool_calls: vec![],
            iterations: 3,
            usage: TokenUsage::default(),
        };
        assert!(resp.content.is_empty());
        assert!(resp.tool_calls.is_empty());
        assert_eq!(resp.iterations, 3);
    }

    // --- Tool-result truncation ---

    #[test]
    fn test_truncate_small_result_unchanged() {
        let val = json!({"success": true, "output": "hello"});
        let result = truncate_tool_result(val.clone(), 8000);
        assert_eq!(result, val);
    }

    #[test]
    fn test_truncate_large_string_value() {
        let big = "x".repeat(10000);
        let val = json!({"stdout": big, "success": true});
        let result = truncate_tool_result(val, 2000);
        let stdout = result["stdout"].as_str().unwrap();
        assert!(stdout.len() < 10000, "Should be truncated");
        assert!(stdout.contains("truncated"), "Should indicate truncation");
    }

    #[test]
    fn test_truncate_preserves_valid_json() {
        let big = "x".repeat(20000);
        let val = json!({"data": big, "meta": "keep"});
        let result = truncate_tool_result(val, 5000);
        let serialized = serde_json::to_string(&result).unwrap();
        let _reparsed: Value = serde_json::from_str(&serialized).unwrap();
        assert_eq!(result["meta"], "keep");
    }

    #[test]
    fn test_truncate_bare_string() {
        let big = json!("x".repeat(10000));
        let result = truncate_tool_result(big, 500);
        let s = result.as_str().unwrap();
        // 500 kept chars + the truncation marker suffix.
        assert!(s.len() <= 600); assert!(s.contains("truncated"));
    }

    #[test]
    fn test_truncate_array() {
        let items: Vec<Value> = (0..1000).map(|i| json!(format!("item_{}", i))).collect();
        let val = Value::Array(items);
        let result = truncate_tool_result(val, 500);
        let arr = result.as_array().unwrap();
        assert!(arr.len() < 1000, "Array should be truncated");
    }

    // --- Message importance scoring for pruning ---

    #[test]
    fn test_importance_failed_tool_highest() {
        let msg = Message {
            role: Role::Tool,
            content: "error".into(),
            tool_calls: vec![],
            tool_result: Some(ToolResultMessage {
                tool_call_id: "1".into(),
                content: json!({"error": "failed"}),
                success: false,
            }),
        };
        assert!(PawanAgent::message_importance(&msg) > 0.8, "Failed tools should be high importance");
    }

    #[test]
    fn test_importance_successful_tool_lowest() {
        let msg = Message {
            role: Role::Tool,
            content: "ok".into(),
            tool_calls: vec![],
            tool_result: Some(ToolResultMessage {
                tool_call_id: "1".into(),
                content: json!({"success": true}),
                success: true,
            }),
        };
        assert!(PawanAgent::message_importance(&msg) < 0.3, "Successful tools should be low importance");
    }

    #[test]
    fn test_importance_user_medium() {
        let msg = Message { role: Role::User, content: "hello".into(), tool_calls: vec![], tool_result: None };
        let score = PawanAgent::message_importance(&msg);
        assert!(score > 0.4 && score < 0.8, "User messages should be medium: {}", score);
    }

    #[test]
    fn test_importance_error_assistant_high() {
        let msg = Message { role: Role::Assistant, content: "Error: something failed".into(), tool_calls: vec![], tool_result: None };
        assert!(PawanAgent::message_importance(&msg) > 0.7, "Error assistant messages should be high importance");
    }

    #[test]
    fn test_importance_ordering() {
        let failed_tool = Message { role: Role::Tool, content: "err".into(), tool_calls: vec![], tool_result: Some(ToolResultMessage { tool_call_id: "1".into(), content: json!({}), success: false }) };
        let user = Message { role: Role::User, content: "hi".into(), tool_calls: vec![], tool_result: None };
        let ok_tool = Message { role: Role::Tool, content: "ok".into(), tool_calls: vec![], tool_result: Some(ToolResultMessage { tool_call_id: "2".into(), content: json!({}), success: true }) };
        let f = PawanAgent::message_importance(&failed_tool);
        let u = PawanAgent::message_importance(&user);
        let s = PawanAgent::message_importance(&ok_tool);
        assert!(f > u && u > s, "Ordering should be: failed({}) > user({}) > success({})", f, u, s);
    }

    // --- History manipulation API ---

    #[test]
    fn test_agent_clear_history_removes_all() {
        let mut agent = agent_with_messages(8);
        assert_eq!(agent.history().len(), 8);
        agent.clear_history();
        assert_eq!(agent.history().len(), 0, "clear_history should drop every message");
    }

    #[test]
    fn test_agent_add_message_appends_in_order() {
        let config = PawanConfig::default();
        let mut agent = PawanAgent::new(config, PathBuf::from("."));
        assert_eq!(agent.history().len(), 0);
        let first = Message {
            role: Role::User,
            content: "first".into(),
            tool_calls: vec![],
            tool_result: None,
        };
        let second = Message {
            role: Role::Assistant,
            content: "second".into(),
            tool_calls: vec![],
            tool_result: None,
        };
        agent.add_message(first);
        agent.add_message(second);
        assert_eq!(agent.history().len(), 2);
        assert_eq!(agent.history()[0].content, "first");
        assert_eq!(agent.history()[1].content, "second");
        assert_eq!(agent.history()[0].role, Role::User);
        assert_eq!(agent.history()[1].role, Role::Assistant);
    }

    #[test]
    fn test_agent_switch_model_updates_name() {
        let config = PawanConfig::default();
        let mut agent = PawanAgent::new(config, PathBuf::from("."));
        let original = agent.model_name().to_string();
        agent.switch_model("gpt-oss-120b");
        assert_eq!(agent.model_name(), "gpt-oss-120b");
        assert_ne!(
            agent.model_name(),
            original,
            "switch_model should change model_name"
        );
    }

    #[test]
    fn test_agent_with_tools_replaces_registry() {
        let config = PawanConfig::default();
        let agent = PawanAgent::new(config, PathBuf::from("."));
        let original_tool_count = agent.get_tool_definitions().len();
        // with_tools replaces — not merges — the registry.
        let empty = ToolRegistry::new();
        let agent = agent.with_tools(empty);
        assert_eq!(
            agent.get_tool_definitions().len(),
            0,
            "with_tools(empty) should drop default registry (had {} tools)",
            original_tool_count
        );
    }

    #[test]
    fn test_agent_get_tool_definitions_returns_deterministic_set() {
        let config = PawanConfig::default();
        let agent_a = PawanAgent::new(config.clone(), PathBuf::from("."));
        let agent_b = PawanAgent::new(config, PathBuf::from("."));
        let defs_a: Vec<String> = agent_a.get_tool_definitions().iter().map(|d| d.name.clone()).collect();
        let defs_b: Vec<String> = agent_b.get_tool_definitions().iter().map(|d| d.name.clone()).collect();
        assert!(!defs_a.is_empty(), "default agent should have tools");
        assert_eq!(defs_a.len(), defs_b.len(), "two default agents must have same tool count");
        let names: Vec<&str> = defs_a.iter().map(|s| s.as_str()).collect();
        assert!(names.contains(&"read_file"), "should have read_file in defaults");
        assert!(names.contains(&"bash"), "should have bash in defaults");
    }

    // --- Truncation edge cases ---

    #[test]
    fn test_truncate_empty_object_unchanged() {
        let val = json!({});
        let result = truncate_tool_result(val.clone(), 10);
        assert_eq!(result, val);
    }

    #[test]
    fn test_truncate_null_value_unchanged() {
        let val = Value::Null;
        let result = truncate_tool_result(val.clone(), 10);
        assert_eq!(result, val);
    }

    #[test]
    fn test_truncate_numeric_values_pass_through() {
        let val = json!({"count": 42, "ratio": 2.5, "enabled": true});
        let result = truncate_tool_result(val.clone(), 8000);
        assert_eq!(result, val);
    }

    #[test]
    fn test_truncate_large_string_is_utf8_safe() {
        // 4-byte emoji stresses char-boundary handling during truncation.
        let emoji_heavy = "🦀".repeat(3000);
        let val = json!({"crabs": emoji_heavy});
        let result = truncate_tool_result(val, 1000);
        let out = result["crabs"].as_str().unwrap();
        assert!(out.contains("truncated"), "truncation marker must be present");
        assert!(out.starts_with('🦀'), "must preserve char boundary");
    }

    #[test]
    fn test_truncate_nested_object_remains_valid_json() {
        let inner_big = "y".repeat(5000);
        let val = json!({
            "meta": "small",
            "nested": { "inner": inner_big }
        });
        let result = truncate_tool_result(val, 1500);
        assert_eq!(result["meta"], "small");
        let serialized = serde_json::to_string(&result).unwrap();
        let _reparsed: Value = serde_json::from_str(&serialized)
            .expect("truncated result must be valid JSON");
    }

    #[test]
    fn test_truncate_short_bare_string_unchanged() {
        let val = json!("short string");
        let result = truncate_tool_result(val.clone(), 1000);
        assert_eq!(result, val);
    }

    // --- Session handling ---

    #[test]
    fn test_session_id_is_unique_per_agent() {
        let a1 = PawanAgent::new(PawanConfig::default(), PathBuf::from("."));
        let a2 = PawanAgent::new(PawanConfig::default(), PathBuf::from("."));
        assert_ne!(a1.session_id, a2.session_id);
        assert!(!a1.session_id.is_empty());
        // 36 chars = canonical hyphenated UUID length.
        assert_eq!(a1.session_id.len(), 36);
    }

    #[test]
    fn test_resume_session_adopts_loaded_id() {
        use std::io::Write;
        // Writes a fake session file under a temp HOME so resume_session can
        // find it, then restores HOME afterwards.
        let tmp = tempfile::TempDir::new().unwrap();
        let sess_dir = tmp.path().join(".pawan").join("sessions");
        std::fs::create_dir_all(&sess_dir).unwrap();
        let sess_id = "resume-test-xyz";
        let sess_path = sess_dir.join(format!("{}.json", sess_id));
        let sess_json = serde_json::json!({
            "id": sess_id,
            "model": "test-model",
            "created_at": "2026-04-11T00:00:00Z",
            "updated_at": "2026-04-11T00:00:00Z",
            "messages": [],
            "total_tokens": 0,
            "iteration_count": 0
        });
        let mut f = std::fs::File::create(&sess_path).unwrap();
        f.write_all(sess_json.to_string().as_bytes()).unwrap();
        let prev_home = std::env::var("HOME").ok();
        std::env::set_var("HOME", tmp.path());
        let mut agent = PawanAgent::new(PawanConfig::default(), PathBuf::from("."));
        let orig_id = agent.session_id.clone();
        agent.resume_session(sess_id).expect("resume should succeed");
        assert_eq!(agent.session_id, sess_id);
        assert_ne!(agent.session_id, orig_id);
        if let Some(h) = prev_home {
            std::env::set_var("HOME", h);
        } else {
            std::env::remove_var("HOME");
        }
    }

    // --- Eruka history snapshot ---

    #[test]
    fn test_history_snapshot_for_eruka_bounded() {
        let mut history = Vec::new();
        for i in 0..100 {
            history.push(Message {
                role: if i % 2 == 0 { Role::User } else { Role::Assistant },
                content: "x".repeat(500),
                tool_calls: vec![],
                tool_result: None,
            });
        }
        let snapshot = PawanAgent::history_snapshot_for_eruka(&history);
        assert!(snapshot.len() <= 4400, "snapshot too long: {} chars", snapshot.len());
        assert!(snapshot.len() > 200, "snapshot too short: {} chars", snapshot.len());
    }

    #[test]
    fn test_history_snapshot_for_eruka_includes_role_prefixes() {
        let history = vec![
            Message { role: Role::User, content: "hi".into(), tool_calls: vec![], tool_result: None },
            Message { role: Role::Assistant, content: "hello".into(), tool_calls: vec![], tool_result: None },
            Message { role: Role::Tool, content: "ok".into(), tool_calls: vec![], tool_result: None },
            Message { role: Role::System, content: "sys".into(), tool_calls: vec![], tool_result: None },
        ];
        let snapshot = PawanAgent::history_snapshot_for_eruka(&history);
        // Each role gets a single-letter prefix in the snapshot.
        assert!(snapshot.contains("U: hi"));
        assert!(snapshot.contains("A: hello"));
        assert!(snapshot.contains("T: ok"));
        assert!(snapshot.contains("S: sys"));
    }

    #[tokio::test]
    async fn test_archive_to_eruka_ok_when_disabled() {
        let agent = PawanAgent::new(PawanConfig::default(), PathBuf::from("."));
        assert!(agent.eruka.is_none(), "default config should disable eruka");
        let result = agent.archive_to_eruka().await;
        assert!(result.is_ok(), "archive_to_eruka should be non-fatal when disabled");
    }

    // --- Local endpoint probing ---

    #[test]
    fn test_probe_local_endpoint_closed_port_returns_false() {
        assert!(
            !probe_local_endpoint("http://localhost:1999/v1"),
            "closed port should return false"
        );
    }

    #[test]
    fn test_probe_local_endpoint_open_port_returns_true() {
        use std::net::TcpListener;
        // Bind to port 0 to let the OS pick a genuinely open port.
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind failed");
        let port = listener.local_addr().unwrap().port();
        let url = format!("http://localhost:{port}/v1");
        assert!(probe_local_endpoint(&url), "open port should return true");
    }

    #[test]
    fn test_probe_local_endpoint_url_without_explicit_port() {
        // Only checks that a port-less URL doesn't panic; the result depends
        // on the host environment.
        let _ = probe_local_endpoint("http://localhost/v1");
    }

    // --- Architecture context loading ---

    #[test]
    fn test_load_arch_context_absent_returns_none() {
        let dir = tempfile::TempDir::new().unwrap();
        assert!(load_arch_context(dir.path()).is_none());
    }

    #[test]
    fn test_load_arch_context_reads_file_content() {
        let dir = tempfile::TempDir::new().unwrap();
        let pawan_dir = dir.path().join(".pawan");
        std::fs::create_dir_all(&pawan_dir).unwrap();
        std::fs::write(pawan_dir.join("arch.md"), "## Architecture\nUse tokio.\n").unwrap();
        let result = load_arch_context(dir.path());
        assert!(result.is_some());
        assert!(result.unwrap().contains("Use tokio"));
    }

    #[test]
    fn test_load_arch_context_empty_file_returns_none() {
        let dir = tempfile::TempDir::new().unwrap();
        let pawan_dir = dir.path().join(".pawan");
        std::fs::create_dir_all(&pawan_dir).unwrap();
        std::fs::write(pawan_dir.join("arch.md"), " \n").unwrap();
        assert!(load_arch_context(dir.path()).is_none(), "whitespace-only file should be None");
    }

    #[test]
    fn test_load_arch_context_truncates_at_2000_chars() {
        let dir = tempfile::TempDir::new().unwrap();
        let pawan_dir = dir.path().join(".pawan");
        std::fs::create_dir_all(&pawan_dir).unwrap();
        let content = "x".repeat(2_500);
        std::fs::write(pawan_dir.join("arch.md"), &content).unwrap();
        let result = load_arch_context(dir.path()).unwrap();
        assert!(
            result.len() < 2_100,
            "truncated result should be close to 2000 chars, got {}",
            result.len()
        );
        assert!(result.ends_with("(truncated)"), "truncated output must end with marker");
    }

    // --- Tool idle-timeout behavior ---

    #[tokio::test]
    async fn test_tool_idle_timeout_triggered() {
        use std::time::Duration;
        use tokio::time::sleep;
        let mut config = PawanConfig::default();
        // Zero idle timeout: any delay between tool calls should trip it.
        config.tool_call_idle_timeout_secs = 0;
        // Backend that issues a tool call, then stalls before the next one.
        struct SlowBackend {
            index: Arc<std::sync::atomic::AtomicUsize>,
        }
        #[async_trait::async_trait]
        impl LlmBackend for SlowBackend {
            async fn generate(&self, _m: &[Message], _t: &[ToolDefinition], _o: Option<&TokenCallback>) -> Result<LLMResponse> {
                let idx = self.index.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                if idx == 0 {
                    Ok(LLMResponse {
                        content: String::new(),
                        reasoning: None,
                        tool_calls: vec![ToolCallRequest {
                            id: "1".to_string(),
                            name: "read_file".to_string(),
                            arguments: json!({"path": "foo"}),
                        }],
                        finish_reason: "tool_calls".to_string(),
                        usage: None,
                    })
                } else if idx == 1 {
                    // Stall past the idle timeout before the second call.
                    sleep(Duration::from_millis(1100)).await;
                    Ok(LLMResponse {
                        content: String::new(),
                        reasoning: None,
                        tool_calls: vec![ToolCallRequest {
                            id: "2".to_string(),
                            name: "read_file".to_string(),
                            arguments: json!({"path": "bar"}),
                        }],
                        finish_reason: "tool_calls".to_string(),
                        usage: None,
                    })
                } else {
                    Ok(LLMResponse {
                        content: "Done".to_string(),
                        reasoning: None,
                        tool_calls: vec![],
                        finish_reason: "stop".to_string(),
                        usage: None,
                    })
                }
            }
        }
        let mut agent = PawanAgent::new(config, PathBuf::from("."));
        agent.backend = Box::new(SlowBackend { index: Arc::new(std::sync::atomic::AtomicUsize::new(0)) });
        let result = agent.execute_with_all_callbacks("test", None, None, None, None).await;
        match result {
            Err(PawanError::Agent(msg)) => {
                assert!(msg.contains("Tool idle timeout exceeded"), "Error message should contain timeout: {}", msg);
            }
            Ok(_) => panic!("Expected timeout error, but it succeeded. This means the timeout check didn't catch the delay."),
            Err(e) => panic!("Unexpected error: {:?}", e),
        }
    }

    #[tokio::test]
    async fn test_tool_idle_timeout_not_triggered() {
        let mut config = PawanConfig::default();
        config.tool_call_idle_timeout_secs = 10;
        let backend = MockBackend::new(vec![
            MockResponse::text("Done"),
        ]);
        let mut agent = PawanAgent::new(config, PathBuf::from("."));
        agent.backend = Box::new(backend);
        let result = agent.execute_with_all_callbacks("test", None, None, None, None).await;
        assert!(result.is_ok());
    }
}
/// Renders tool-call arguments as a short, human-readable one-line summary.
///
/// String values are clipped with an ellipsis, arrays longer than three items
/// collapse to an item count, and object fields are joined as `key: value`
/// pairs so log lines stay compact.
fn summarize_args(args: &serde_json::Value) -> String {
    // Clips `s` to `keep` chars (plus "...") when it exceeds `max` chars,
    // otherwise quotes it whole.
    // BUG FIX: the previous version byte-sliced (`&s[..47]` etc.), which
    // panics when the cut lands inside a multi-byte UTF-8 character; char
    // iteration is boundary-safe.
    fn clip(s: &str, max: usize, keep: usize) -> String {
        if s.chars().count() > max {
            let head: String = s.chars().take(keep).collect();
            format!("\"{head}...\"")
        } else {
            format!("\"{s}\"")
        }
    }
    match args {
        serde_json::Value::Object(map) => {
            let mut parts = Vec::new();
            for (key, value) in map {
                let value_str = match value {
                    serde_json::Value::String(s) => clip(s, 50, 47),
                    serde_json::Value::Array(arr) if arr.len() > 3 => {
                        format!("[... {} items]", arr.len())
                    }
                    serde_json::Value::Array(arr) => {
                        // Show at most the first three items, each clipped.
                        let items: Vec<String> = arr
                            .iter()
                            .take(3)
                            .map(|v| match v {
                                serde_json::Value::String(s) => clip(s, 20, 17),
                                _ => v.to_string(),
                            })
                            .collect();
                        format!("[{}]", items.join(", "))
                    }
                    _ => value.to_string(),
                };
                parts.push(format!("{}: {}", key, value_str));
            }
            parts.join(", ")
        }
        // Top-level bare strings get a larger budget than object fields.
        serde_json::Value::String(s) => clip(s, 100, 97),
        serde_json::Value::Array(arr) => {
            format!("[{} items]", arr.len())
        }
        _ => args.to_string(),
    }
}
#[cfg(test)]
mod coordinator_tests {
    use super::*;
    use crate::agent::backend::mock::{MockBackend, MockResponse};
    use crate::coordinator::{FinishReason, ToolCallingConfig};
    use std::sync::Arc;

    // --- Config flag plumbing ---

    #[test]
    fn test_config_default_use_coordinator_false() {
        let config = PawanConfig::default();
        assert!(!config.use_coordinator);
    }

    #[test]
    fn test_config_use_coordinator_true() {
        let config = PawanConfig {
            use_coordinator: true,
            ..Default::default()
        };
        assert!(config.use_coordinator);
    }

    #[tokio::test]
    async fn test_execute_with_coordinator_flag_enabled() {
        let config = PawanConfig {
            use_coordinator: true,
            model: "test-model".to_string(),
            ..Default::default()
        };
        let agent = PawanAgent::new(config, PathBuf::from("."));
        assert!(agent.config().use_coordinator);
    }

    #[tokio::test]
    async fn test_execute_with_coordinator_produces_response() {
        let config = PawanConfig {
            use_coordinator: true,
            max_tool_iterations: 1,
            model: "test-model".to_string(),
            ..Default::default()
        };
        let agent = PawanAgent::new(config, PathBuf::from("."));
        let backend = MockBackend::with_text("Hello from coordinator!");
        let mut agent = agent.with_backend(Box::new(backend));
        assert!(agent.config().use_coordinator);
    }

    #[test]
    fn test_tool_calling_config_defaults() {
        let cfg = ToolCallingConfig::default();
        assert_eq!(cfg.max_iterations, 10);
        assert!(cfg.parallel_execution);
        assert_eq!(cfg.tool_timeout.as_secs(), 30);
        assert!(!cfg.stop_on_error);
    }

    #[test]
    fn test_tool_calling_config_custom() {
        let cfg = ToolCallingConfig {
            max_iterations: 5,
            parallel_execution: false,
            tool_timeout: std::time::Duration::from_secs(60),
            stop_on_error: true,
        };
        assert_eq!(cfg.max_iterations, 5);
        assert!(!cfg.parallel_execution);
        assert_eq!(cfg.tool_timeout.as_secs(), 60);
        assert!(cfg.stop_on_error);
    }

    #[tokio::test]
    async fn test_coordinator_dispatch_when_flag_is_false() {
        let config = PawanConfig::default();
        assert!(!config.use_coordinator);
    }

    // --- Coordinator execution behavior (against MockBackend) ---

    #[tokio::test]
    async fn test_coordinator_error_handling_unknown_tool() {
        use crate::coordinator::ToolCoordinator;
        let mock_backend = Arc::new(MockBackend::with_tool_call(
            "call_1",
            "nonexistent_tool",
            json!({}),
            "Trying to call unknown tool",
        ));
        // Empty registry: every requested tool is unknown.
        let registry = Arc::new(ToolRegistry::new());
        let config = ToolCallingConfig::default();
        let coordinator = ToolCoordinator::new(mock_backend, registry, config);
        let result = coordinator.execute(None, "Use a tool").await.unwrap();
        assert!(matches!(result.finish_reason, FinishReason::UnknownTool(_)));
    }

    #[tokio::test]
    async fn test_coordinator_max_iterations_limit() {
        use crate::coordinator::ToolCoordinator;
        use crate::tools::Tool;
        use async_trait::async_trait;
        use serde_json::json;
        use std::sync::Arc;
        // Minimal tool that always succeeds, so the loop only stops when the
        // iteration cap is hit.
        struct DummyTool;
        #[async_trait]
        impl Tool for DummyTool {
            fn name(&self) -> &str { "test_tool" }
            fn description(&self) -> &str { "Dummy tool for testing" }
            fn parameters_schema(&self) -> serde_json::Value { json!({}) }
            async fn execute(&self, _args: serde_json::Value) -> crate::Result<serde_json::Value> {
                Ok(json!({ "status": "ok" }))
            }
        }
        let mock_backend = Arc::new(MockBackend::with_repeated_tool_call("test_tool"));
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));
        let registry = Arc::new(registry);
        let config = ToolCallingConfig {
            max_iterations: 3,
            ..Default::default()
        };
        let coordinator = ToolCoordinator::new(mock_backend, registry, config);
        let result = coordinator.execute(None, "Use tools").await.unwrap();
        assert_eq!(result.iterations, 3);
        assert!(matches!(result.finish_reason, FinishReason::MaxIterations));
    }

    #[tokio::test]
    async fn test_coordinator_timeout_handling() {
        use crate::coordinator::ToolCoordinator;
        // 1ms timeout against `sleep 10` guarantees the tool call times out.
        let mock_backend = Arc::new(MockBackend::with_tool_call(
            "call_1",
            "bash",
            json!({"command": "sleep 10"}),
            "Run slow command",
        ));
        let registry = Arc::new(ToolRegistry::with_defaults(PathBuf::from(".")));
        let config = ToolCallingConfig {
            tool_timeout: std::time::Duration::from_millis(1),
            ..Default::default()
        };
        let coordinator = ToolCoordinator::new(mock_backend, registry, config);
        let result = coordinator.execute(None, "Run a command").await.unwrap();
        assert!(!result.tool_calls.is_empty());
        let first_call = &result.tool_calls[0];
        // A timed-out call is recorded as a failed call with an error payload.
        assert!(!first_call.success);
        assert!(first_call.result.get("error").is_some());
    }

    #[tokio::test]
    async fn test_coordinator_token_usage_accumulation() {
        use crate::coordinator::ToolCoordinator;
        let mock_backend = Arc::new(MockBackend::with_text_and_usage(
            "Response",
            100,
            50,
        ));
        let registry = Arc::new(ToolRegistry::new());
        let config = ToolCallingConfig::default();
        let coordinator = ToolCoordinator::new(mock_backend, registry, config);
        let result = coordinator.execute(None, "Hello").await.unwrap();
        assert_eq!(result.total_usage.prompt_tokens, 100);
        assert_eq!(result.total_usage.completion_tokens, 50);
        assert_eq!(result.total_usage.total_tokens, 150);
    }

    #[tokio::test]
    async fn test_coordinator_parallel_execution() {
        use crate::coordinator::ToolCoordinator;
        let mock_backend = Arc::new(MockBackend::with_multiple_tool_calls(vec![
            ("call_1", "bash", json!({"command": "echo 1"})),
            ("call_2", "bash", json!({"command": "echo 2"})),
            ("call_3", "read_file", json!({"path": "test.txt"})),
        ]));
        let registry = Arc::new(ToolRegistry::with_defaults(PathBuf::from(".")));
        let config = ToolCallingConfig {
            parallel_execution: true,
            ..Default::default()
        };
        let coordinator = ToolCoordinator::new(mock_backend, registry, config);
        let result = coordinator.execute(None, "Run multiple commands").await.unwrap();
        assert!(result.tool_calls.len() >= 3);
    }
}