use crate::cli::is_verbose;
use crate::format::*;
use std::collections::HashMap;
use std::io::{self, IsTerminal, Write};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
use yoagent::agent::Agent;
use yoagent::*;
/// Process-wide watch command; `None` while no command is configured.
static WATCH_COMMAND: RwLock<Option<String>> = RwLock::new(None);

/// Stores `cmd` as the watch command, replacing any previously set one.
pub fn set_watch_command(cmd: &str) {
    *WATCH_COMMAND.write().unwrap() = Some(cmd.to_owned());
}

/// Returns a copy of the current watch command, if one is set.
pub fn get_watch_command() -> Option<String> {
    WATCH_COMMAND.read().unwrap().clone()
}

/// Removes the watch command so `get_watch_command` returns `None`.
pub fn clear_watch_command() {
    *WATCH_COMMAND.write().unwrap() = None;
}
/// Global switch for tool-call auditing; starts out disabled.
static AUDIT_ENABLED: AtomicBool = AtomicBool::new(false);

/// Switches tool-call auditing on.
// Kept public even when nothing in this build calls it yet.
#[allow(dead_code)]
pub fn enable_audit_log() {
    AUDIT_ENABLED.store(true, Ordering::Relaxed);
}

/// Returns `true` once `enable_audit_log` has been called.
pub fn is_audit_enabled() -> bool {
    let enabled = AUDIT_ENABLED.load(Ordering::Relaxed);
    enabled
}
/// Records one finished tool call in the audit log when auditing is enabled.
///
/// Logging is best-effort: any I/O error from the underlying write is
/// deliberately discarded so auditing can never interrupt a tool call.
pub fn audit_log_tool_call(
    tool_name: &str,
    args: &serde_json::Value,
    duration_ms: u64,
    success: bool,
) {
    if is_audit_enabled() {
        let _ = write_audit_entry(tool_name, args, duration_ms, success);
    }
}
fn write_audit_entry(
tool_name: &str,
args: &serde_json::Value,
duration_ms: u64,
success: bool,
) -> std::io::Result<()> {
let dir = std::path::Path::new(".yoyo");
std::fs::create_dir_all(dir)?;
let path = dir.join("audit.jsonl");
let mut file = std::fs::OpenOptions::new()
.create(true)
.append(true)
.open(&path)?;
let ts = std::process::Command::new("date")
.arg("+%Y-%m-%dT%H:%M:%S")
.output()
.ok()
.and_then(|o| String::from_utf8(o.stdout).ok())
.map(|s| s.trim().to_string())
.unwrap_or_else(|| "unknown".to_string());
let truncated_args = truncate_audit_args(args);
let entry = serde_json::json!({
"ts": ts,
"tool": tool_name,
"args": truncated_args,
"duration_ms": duration_ms,
"success": success,
});
writeln!(file, "{}", entry)?;
Ok(())
}
/// Returns a copy of `args` suitable for the audit log.
///
/// When `args` is a JSON object, each top-level value is passed through
/// `truncate_audit_value` (nested values are not descended into). Any other
/// JSON value is returned as an unchanged clone.
pub fn truncate_audit_args(args: &serde_json::Value) -> serde_json::Value {
    let serde_json::Value::Object(fields) = args else {
        return args.clone();
    };
    let shortened: serde_json::Map<String, serde_json::Value> = fields
        .iter()
        .map(|(key, value)| (key.clone(), truncate_audit_value(value)))
        .collect();
    serde_json::Value::Object(shortened)
}
/// Shortens a JSON string longer than 200 characters for the audit log.
///
/// The first 200 characters are kept and followed by
/// `... [truncated, N chars total]`. Non-string values, and strings of at most
/// 200 characters, are cloned unchanged.
///
/// Counting and cutting is done in characters, not bytes. The previous
/// `&s[..200]` panicked whenever byte 200 fell inside a multi-byte UTF-8
/// character, and it reported a byte count as "chars total".
fn truncate_audit_value(v: &serde_json::Value) -> serde_json::Value {
    const KEEP_CHARS: usize = 200;
    match v {
        serde_json::Value::String(s) => match s.char_indices().nth(KEEP_CHARS) {
            // `cut` is the byte offset of the first dropped char: always a char boundary.
            Some((cut, _)) => serde_json::Value::String(format!(
                "{}... [truncated, {} chars total]",
                &s[..cut],
                s.chars().count()
            )),
            None => v.clone(),
        },
        other => other.clone(),
    }
}
/// Returns the last `n` lines of `.yoyo/audit.jsonl` (relative to the current
/// directory), oldest first. Returns an empty vector if the file cannot be read.
// Kept public even when nothing in this build calls it yet.
#[allow(dead_code)]
pub fn read_audit_log(n: usize) -> Vec<String> {
    let log_path = std::path::Path::new(".yoyo").join("audit.jsonl");
    let Ok(content) = std::fs::read_to_string(&log_path) else {
        return Vec::new();
    };
    let total = content.lines().count();
    content
        .lines()
        .skip(total.saturating_sub(n))
        .map(str::to_string)
        .collect()
}
/// Empties `.yoyo/audit.jsonl` (relative to the current directory).
///
/// Returns `true` only if the file existed and was truncated successfully.
/// A missing file is left missing and yields `false`.
// Kept public even when nothing in this build calls it yet.
#[allow(dead_code)]
pub fn clear_audit_log() -> bool {
    let log_path = std::path::Path::new(".yoyo").join("audit.jsonl");
    log_path.exists() && std::fs::write(&log_path, "").is_ok()
}
/// Files modified during a session, in first-touched order.
///
/// Clones share the same underlying list (`Arc<Mutex<..>>`), so a clone can
/// be handed out and its recordings are visible through the original.
#[derive(Debug, Clone)]
pub struct SessionChanges {
    inner: Arc<Mutex<Vec<FileChange>>>,
}

/// A touched file together with the most recently recorded kind of change.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FileChange {
    pub path: String,
    pub kind: ChangeKind,
}

/// How a file was modified; displayed as `write` or `edit`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChangeKind {
    Write,
    Edit,
}

impl std::fmt::Display for ChangeKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            ChangeKind::Write => "write",
            ChangeKind::Edit => "edit",
        };
        f.write_str(label)
    }
}

impl SessionChanges {
    /// Creates an empty tracker.
    pub fn new() -> Self {
        Self {
            inner: Arc::default(),
        }
    }

    /// Records a change to `path`.
    ///
    /// A path already in the list keeps its position; only its kind is
    /// overwritten with `kind`.
    pub fn record(&self, path: &str, kind: ChangeKind) {
        let mut list = self.inner.lock().unwrap();
        match list.iter().position(|change| change.path == path) {
            Some(idx) => list[idx].kind = kind,
            None => list.push(FileChange {
                path: path.to_owned(),
                kind,
            }),
        }
    }

    /// Returns a copy of all recorded changes.
    pub fn snapshot(&self) -> Vec<FileChange> {
        let list = self.inner.lock().unwrap();
        list.clone()
    }

    /// Number of distinct paths recorded.
    // Kept even when nothing in this build calls it yet.
    #[allow(dead_code)]
    pub fn len(&self) -> usize {
        let list = self.inner.lock().unwrap();
        list.len()
    }

    /// `true` when nothing has been recorded.
    // Kept even when nothing in this build calls it yet.
    #[allow(dead_code)]
    pub fn is_empty(&self) -> bool {
        let list = self.inner.lock().unwrap();
        list.is_empty()
    }

    /// Forgets every recorded change (shared with all clones).
    pub fn clear(&self) {
        let mut list = self.inner.lock().unwrap();
        list.clear();
    }
}
/// The pre-turn state of every file one agent turn touched, so the turn can be undone.
#[derive(Debug, Clone)]
pub struct TurnSnapshot {
    /// Original contents keyed by path, captured the first time each path is snapshotted.
    pub originals: HashMap<String, String>,
    /// Paths recorded as newly created during the turn (deleted on restore).
    pub created: Vec<String>,
}

impl TurnSnapshot {
    /// Creates a snapshot with nothing captured.
    pub fn new() -> Self {
        Self {
            originals: HashMap::default(),
            created: Vec::default(),
        }
    }

    /// Captures the current contents of `path` unless it was already captured.
    ///
    /// Files that cannot be read (e.g. missing) are silently skipped.
    pub fn snapshot_file(&mut self, path: &str) {
        if let std::collections::hash_map::Entry::Vacant(slot) =
            self.originals.entry(path.to_string())
        {
            if let Ok(content) = std::fs::read_to_string(path) {
                slot.insert(content);
            }
        }
    }

    /// Marks `path` as created during the turn, unless it already has a
    /// captured original or is already marked.
    pub fn record_created(&mut self, path: &str) {
        let already_known =
            self.originals.contains_key(path) || self.created.iter().any(|p| p == path);
        if !already_known {
            self.created.push(path.to_string());
        }
    }

    /// Number of files this snapshot would touch on restore.
    // Kept even when nothing in this build calls it yet.
    #[allow(dead_code)]
    pub fn file_count(&self) -> usize {
        self.originals.len() + self.created.len()
    }

    /// `true` when the turn touched no files.
    pub fn is_empty(&self) -> bool {
        self.originals.is_empty() && self.created.is_empty()
    }

    /// Writes every captured original back, then deletes every created file.
    ///
    /// Returns one human-readable line per file describing what happened.
    pub fn restore(&self) -> Vec<String> {
        let rewrites = self.originals.iter().map(|(path, content)| {
            match std::fs::write(path, content) {
                Ok(()) => format!("restored {path}"),
                Err(_) => format!("failed to restore {path}"),
            }
        });
        let deletions = self
            .created
            .iter()
            .map(|path| match std::fs::remove_file(path) {
                Ok(()) => format!("deleted {path}"),
                Err(_) => format!("failed to delete {path}"),
            });
        // Lazy chain: all rewrites run before any deletion, in that order.
        rewrites.chain(deletions).collect()
    }
}

/// Stack of per-turn snapshots used to undo recent turns.
#[derive(Debug, Clone)]
pub struct TurnHistory {
    turns: Vec<TurnSnapshot>,
}

impl TurnHistory {
    /// Creates an empty history.
    pub fn new() -> Self {
        Self { turns: Vec::new() }
    }

    /// Pushes `snapshot`; snapshots that touched no files are dropped.
    pub fn push(&mut self, snapshot: TurnSnapshot) {
        if snapshot.is_empty() {
            return;
        }
        self.turns.push(snapshot);
    }

    /// Removes and returns the most recent snapshot without restoring it.
    // Kept even when nothing in this build calls it yet.
    #[allow(dead_code)]
    pub fn pop(&mut self) -> Option<TurnSnapshot> {
        self.turns.pop()
    }

    /// Number of undoable turns.
    pub fn len(&self) -> usize {
        self.turns.len()
    }

    /// `true` when there is nothing to undo.
    pub fn is_empty(&self) -> bool {
        self.turns.is_empty()
    }

    /// Undoes up to `n` of the most recent turns, newest first, and returns
    /// the combined restore log.
    pub fn undo_last(&mut self, n: usize) -> Vec<String> {
        let keep = self.turns.len().saturating_sub(n);
        self.turns
            .drain(keep..)
            .rev()
            .flat_map(|snapshot| snapshot.restore())
            .collect()
    }

    /// Forgets all recorded turns without restoring anything.
    pub fn clear(&mut self) {
        self.turns.clear();
    }
}
/// Result of running one user prompt to completion.
#[derive(Debug, Clone, Default)]
pub struct PromptOutcome {
    /// Assistant text collected while streaming the response.
    pub text: String,
    /// Preview of the last tool error, if the most recent tool call failed.
    pub last_tool_error: Option<String>,
    /// `true` when the conversation had to be auto-compacted after a context overflow.
    // Kept even when nothing in this build reads it yet.
    #[allow(dead_code)]
    pub was_overflow: bool,
}
/// Prefixes `input` with a note about the previous failure, if there was one.
///
/// Errors longer than 200 characters are cut to their first 200 characters
/// followed by `…`. The cut is made on a char boundary; the former byte slice
/// `&err[..200]` panicked when byte 200 fell inside a multi-byte character.
/// With `None`, `input` is returned unchanged.
pub fn build_retry_prompt(input: &str, last_error: &Option<String>) -> String {
    const MAX_ERROR_CHARS: usize = 200;
    let Some(err) = last_error else {
        return input.to_string();
    };
    let summary = match err.char_indices().nth(MAX_ERROR_CHARS) {
        // `cut` is the byte offset of the first dropped char: a valid boundary.
        Some((cut, _)) => format!("{}…", &err[..cut]),
        None => err.clone(),
    };
    format!("[Previous attempt failed: {summary}. Try a different approach.]\n\n{input}")
}
/// Extra attempts after a retriable API error (so at most `MAX_RETRIES + 1` calls).
const MAX_RETRIES: u32 = 3;
/// Maximum number of automatic re-prompts after a tool error.
pub const MAX_AUTO_RETRIES: u32 = 2;

/// Builds the prompt used to automatically retry after a tool failure.
///
/// `tool_error` is cut to its first 300 characters (plus `…`) on a char
/// boundary. The former byte slice `&tool_error[..300]` panicked when byte 300
/// fell inside a multi-byte character. `attempt` is shown as
/// `attempt/MAX_AUTO_RETRIES`.
pub fn build_auto_retry_prompt(original_input: &str, tool_error: &str, attempt: u32) -> String {
    const MAX_ERROR_CHARS: usize = 300;
    let summary = match tool_error.char_indices().nth(MAX_ERROR_CHARS) {
        Some((cut, _)) => format!("{}…", &tool_error[..cut]),
        None => tool_error.to_string(),
    };
    format!(
        "[Auto-retry {attempt}/{MAX_AUTO_RETRIES}: a tool failed with: {summary}. \
         Try a different approach or fix the error.]\n\n{original_input}"
    )
}
/// Lower-case fragments that providers use in "context too long" errors.
const OVERFLOW_PHRASES: &[&str] = &[
    "prompt is too long",
    "input is too long",
    "exceeds the context window",
    "exceeds the maximum",
    "maximum prompt length",
    "reduce the length of the messages",
    "maximum context length",
    "exceeds the limit of",
    "exceeds the available context size",
    "greater than the context length",
    "context window exceeds limit",
    "exceeded model token limit",
    "context length exceeded",
    "context_length_exceeded",
    "too many tokens",
    "token limit exceeded",
];

/// Reports whether `msg` looks like a context-window overflow error.
///
/// Performs a case-insensitive substring search for any of
/// `OVERFLOW_PHRASES`; an empty message is never an overflow.
pub fn is_overflow_error(msg: &str) -> bool {
    !msg.is_empty() && {
        let haystack = msg.to_lowercase();
        OVERFLOW_PHRASES
            .iter()
            .any(|needle| haystack.contains(needle))
    }
}
/// Prefixes `original_input` with a notice that earlier context was compacted.
pub fn build_overflow_retry_prompt(original_input: &str) -> String {
    const NOTICE: &str = "[Context was auto-compacted because the conversation exceeded the model's token limit. \
                          Earlier messages have been summarized. Please continue with the task.]";
    format!("{NOTICE}\n\n{original_input}")
}
/// Exponential backoff before retry number `attempt` (1-based).
///
/// Returns 1s, 1s, 2s, 4s, 8s, … for attempts 0, 1, 2, 3, 4, …. For attempts
/// above 64, `1 << (attempt - 1)` would overflow a `u64` (a panic in debug
/// builds), so the delay saturates at `u64::MAX` seconds instead.
pub fn retry_delay(attempt: u32) -> Duration {
    let exponent = attempt.saturating_sub(1);
    let secs = 1u64.checked_shl(exponent).unwrap_or(u64::MAX);
    Duration::from_secs(secs)
}
/// Heuristically decides whether an API error is worth retrying.
///
/// The check is a case-insensitive substring match. Any non-retriable keyword
/// (auth, bad request, not found, …) wins outright. Otherwise the error is
/// retriable if it mentions a transient condition (rate limits, 5xx,
/// timeouts, dropped streams, …). Anything unrecognised is not retried.
pub fn is_retriable_error(error_msg: &str) -> bool {
    const NON_RETRIABLE: &[&str] = &[
        "401",
        "403",
        "400",
        "authentication",
        "unauthorized",
        "forbidden",
        "invalid api key",
        "invalid request",
        "permission denied",
        "invalid_api_key",
        "not_found",
        "404",
    ];
    const RETRIABLE: &[&str] = &[
        "429",
        "rate limit",
        "rate_limit",
        "too many requests",
        "500",
        "502",
        "503",
        "504",
        "internal server error",
        "bad gateway",
        "service unavailable",
        "gateway timeout",
        "overloaded",
        "connection",
        "timeout",
        "timed out",
        "network",
        "temporarily",
        "retry",
        "capacity",
        "server error",
        "stream ended",
        "stream closed",
        "unexpected eof",
        "broken pipe",
        "reset by peer",
        "incomplete",
    ];

    let lower = error_msg.to_lowercase();
    let mentions = |keywords: &'static [&'static str]| keywords.iter().any(|k| lower.contains(k));
    !mentions(NON_RETRIABLE) && mentions(RETRIABLE)
}
/// Turns a raw API error into a human-readable hint, or `None` if no known
/// pattern matches.
///
/// The patterns are checked in order, case-insensitively:
/// 1. authentication failures;
/// 2. unknown models (listing known models for the provider);
/// 3. network failures;
/// 4. 403 / permission errors;
/// 5. interrupted streams.
///
/// The provider is guessed from `model` via `infer_provider_from_model`.
pub fn diagnose_api_error(error: &str, model: &str) -> Option<String> {
    /// `true` if `haystack` contains any of `needles`.
    fn contains_any(haystack: &str, needles: &[&str]) -> bool {
        needles.iter().any(|needle| haystack.contains(needle))
    }

    let provider = infer_provider_from_model(model);
    let lower = error.to_lowercase();

    let auth_failure = contains_any(
        &lower,
        &[
            "401",
            "unauthorized",
            "invalid api key",
            "invalid_api_key",
            "invalid x-api-key",
            "authentication",
        ],
    );
    if auth_failure {
        let env_var = crate::cli::provider_api_key_env(&provider).unwrap_or("ANTHROPIC_API_KEY");
        let status = match std::env::var(env_var) {
            Ok(_) => format!(" {env_var} is set but the API rejected it — check the key value."),
            Err(_) => format!(" {env_var} is not set."),
        };
        let config_hint = "Or add api_key to .yoyo.toml, or use --api-key <key>.";
        return Some(format!(
            "Authentication failed for provider '{provider}'.\n\
             {status}\n\
             Set it with: export {env_var}=<your-key>\n\
             {config_hint}"
        ));
    }

    let model_missing = contains_any(
        &lower,
        &[
            "not_found",
            "model not found",
            "404",
            "does not exist",
            "unknown model",
            "invalid model",
            "no such model",
        ],
    );
    if model_missing {
        let known = crate::cli::known_models_for_provider(&provider);
        let mut msg = format!("Model '{model}' was not found by provider '{provider}'.");
        if !known.is_empty() {
            msg.push_str("\nAvailable models for this provider:");
            for m in known {
                msg.push_str(&format!("\n • {m}"));
            }
            msg.push_str(&format!(
                "\nSwitch with: /model {} or --model {}",
                known[0], known[0]
            ));
        }
        return Some(msg);
    }

    let network_failure = contains_any(
        &lower,
        &[
            "connection refused",
            "connection reset",
            "dns",
            "resolve",
            "name or service not known",
            "network is unreachable",
            "no route to host",
        ],
    );
    if network_failure {
        let advice = match provider.as_str() {
            "ollama" => " Is Ollama running? Try: ollama serve\n".to_string(),
            "custom" => " Check your --base-url value.\n".to_string(),
            _ => format!(
                " Check your internet connection and that {provider}'s API is reachable.\n"
            ),
        };
        return Some(format!(
            "Network error — could not reach the API.\n{advice} You can retry with /retry."
        ));
    }

    if contains_any(&lower, &["403", "forbidden", "permission denied"]) {
        return Some(format!(
            "Access forbidden (403) from provider '{provider}'.\n\
             This usually means your API key doesn't have access to model '{model}'.\n\
             Check your plan/tier with {provider}, or try a different model."
        ));
    }

    let stream_cut = contains_any(
        &lower,
        &[
            "stream ended",
            "stream closed",
            "unexpected eof",
            "broken pipe",
            "incomplete",
        ],
    );
    stream_cut.then(|| {
        "The API stream was interrupted before the response completed.\n\
         This is usually a transient network issue — yoyo will auto-retry.\n\
         If it persists, check your internet connection or try a different model."
            .to_string()
    })
}
/// Guesses the provider name from a model identifier (case-insensitive).
///
/// Anthropic keywords are checked first, then OpenAI prefixes (`gpt-`, `o3`,
/// `o4`), then the keyword table below in order. Unrecognised models fall
/// back to `"anthropic"`.
fn infer_provider_from_model(model: &str) -> String {
    const BY_KEYWORD: &[(&str, &str)] = &[
        ("gemini", "google"),
        ("grok", "xai"),
        ("deepseek", "deepseek"),
        ("mistral", "mistral"),
        ("codestral", "mistral"),
        ("llama", "groq"),
        ("mixtral", "groq"),
        ("gemma", "groq"),
        ("glm", "zai"),
    ];

    let name = model.to_lowercase();
    let provider = if ["claude", "opus", "sonnet", "haiku"]
        .iter()
        .any(|k| name.contains(k))
    {
        "anthropic"
    } else if ["gpt-", "o3", "o4"].iter().any(|p| name.starts_with(p)) {
        "openai"
    } else {
        BY_KEYWORD
            .iter()
            .find(|(keyword, _)| name.contains(keyword))
            .map_or("anthropic", |&(_, provider)| provider)
    };
    provider.to_string()
}
/// Returns the first line of `result`'s text content, shortened to
/// `max_chars` with `truncate_with_ellipsis`.
///
/// Text blocks are joined with single spaces and trimmed before the first
/// line is taken; non-text content is ignored. Returns an empty string when
/// there is no non-whitespace text.
fn tool_result_preview(result: &ToolResult, max_chars: usize) -> String {
    let joined = result
        .content
        .iter()
        .filter_map(|block| match block {
            Content::Text { text } => Some(text.as_str()),
            _ => None,
        })
        .collect::<Vec<_>>()
        .join(" ");
    match joined.trim().lines().next() {
        Some(first_line) => truncate_with_ellipsis(first_line, max_chars),
        None => String::new(),
    }
}
pub fn write_output_file(path: &Option<String>, text: &str) {
if let Some(path) = path {
match std::fs::write(path, text) {
Ok(_) => eprintln!("{DIM} wrote response to {path}{RESET}"),
Err(e) => eprintln!("{RED} error writing to {path}: {e}{RESET}"),
}
}
}
fn message_text(msg: &AgentMessage) -> String {
match msg {
AgentMessage::Llm(Message::User { content, .. }) => content
.iter()
.filter_map(|c| match c {
Content::Text { text } => Some(text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join(" "),
AgentMessage::Llm(Message::Assistant { content, .. }) => {
let mut parts = Vec::new();
for c in content {
match c {
Content::Text { text } if !text.is_empty() => parts.push(text.as_str()),
Content::ToolCall { name, .. } => parts.push(name.as_str()),
_ => {}
}
}
parts.join(" ")
}
AgentMessage::Llm(Message::ToolResult {
tool_name, content, ..
}) => {
let text: String = content
.iter()
.filter_map(|c| match c {
Content::Text { text } => Some(text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join(" ");
format!("{tool_name} {text}")
}
AgentMessage::Extension(ext) => ext.role.clone(),
}
}
/// Wraps every case-insensitive occurrence of `query` in `text` with
/// `BOLD`…`RESET`.
///
/// Returns `text` unchanged when `query` is empty.
///
/// Matching runs on a per-character lower-cased copy of `text`. Lower-casing
/// can change a character's UTF-8 length (e.g. `K` KELVIN SIGN → `k`,
/// `İ` → `i̇`), so match positions are mapped back to `text` through an
/// explicit offset table. They are not reused as byte offsets; reusing them
/// could slice mid-character and panic, or highlight the wrong span.
pub fn highlight_matches(text: &str, query: &str) -> String {
    if query.is_empty() {
        return text.to_string();
    }
    // Lower-case the query char-by-char, exactly like the text below, so both agree.
    let lower_query: String = query.chars().flat_map(char::to_lowercase).collect();

    // For every byte of `lower_text`: the (start, end) byte span of the
    // original character in `text` it was produced from.
    let mut lower_text = String::with_capacity(text.len());
    let mut origin: Vec<(usize, usize)> = Vec::with_capacity(text.len());
    for (start, ch) in text.char_indices() {
        let end = start + ch.len_utf8();
        for lower_ch in ch.to_lowercase() {
            lower_text.push(lower_ch);
            origin.extend(std::iter::repeat((start, end)).take(lower_ch.len_utf8()));
        }
    }

    let mut result = String::with_capacity(text.len() + 32);
    let mut last_end = 0;
    for (lower_start, matched) in lower_text.match_indices(lower_query.as_str()) {
        // `matched` is non-empty because `query` is non-empty.
        let (match_start, _) = origin[lower_start];
        let (_, match_end) = origin[lower_start + matched.len() - 1];
        if match_start < last_end {
            // A second match inside one expanded character; already highlighted.
            continue;
        }
        result.push_str(&text[last_end..match_start]);
        result.push_str(&format!("{BOLD}{}{RESET}", &text[match_start..match_end]));
        last_end = match_end;
    }
    result.push_str(&text[last_end..]);
    result
}
/// Finds messages whose flattened text contains `query` (case-insensitive).
///
/// Returns `(1-based message index, role, highlighted preview)` per match. The
/// preview shows about 20 characters of context around the first occurrence,
/// with `…` where text was cut off.
///
/// The match offset comes from the lower-cased text, whose byte length can
/// differ from the original. It is therefore clamped and snapped to a char
/// boundary before slicing the original text; the old `text[..start]` could
/// panic mid-character.
pub fn search_messages(messages: &[AgentMessage], query: &str) -> Vec<(usize, String, String)> {
    /// Largest char boundary of `s` that is `<= idx`, clamped to `s.len()`.
    fn floor_char_boundary(s: &str, idx: usize) -> usize {
        let mut idx = idx.min(s.len());
        while !s.is_char_boundary(idx) {
            idx -= 1;
        }
        idx
    }

    let query_lower = query.to_lowercase();
    let mut results = Vec::new();
    for (i, msg) in messages.iter().enumerate() {
        let text = message_text(msg);
        let lower = text.to_lowercase();
        let Some(match_pos) = lower.find(&query_lower) else {
            continue;
        };
        let (role, _) = summarize_message(msg);

        // Step back one more character from the snapped context start (kept
        // for parity with the previous preview layout).
        let context_start = floor_char_boundary(&text, match_pos.saturating_sub(20));
        let start = text[..context_start]
            .char_indices()
            .last()
            .map_or(0, |(idx, _)| idx);
        let match_end = match_pos + query_lower.len();
        let end = text
            .char_indices()
            .map(|(idx, ch)| idx + ch.len_utf8())
            .find(|&idx| idx >= match_end + 20)
            .unwrap_or(text.len());

        let snippet = &text[start..end];
        let prefix = if start > 0 { "…" } else { "" };
        let suffix = if end < text.len() { "…" } else { "" };
        let preview = format!("{prefix}{snippet}{suffix}");
        let highlighted = highlight_matches(&preview, query);
        results.push((i + 1, role.to_string(), highlighted));
    }
    results
}
/// Produces a `(role, one-line preview)` pair for listing a message.
///
/// - User text is truncated to 80 chars.
/// - Assistant messages show each non-empty text part (truncated to 60 chars)
///   and the first three tool calls as `→name`, in content order. Further
///   calls are counted as `(+N more tools)`; an assistant message with
///   nothing to show previews as `(empty)`.
/// - Tool results show the tool name and `✓`/`✗`.
/// - Extensions show their role truncated to 60 chars.
pub fn summarize_message(msg: &AgentMessage) -> (&str, String) {
    const SHOWN_TOOLS: usize = 3;
    match msg {
        AgentMessage::Llm(Message::User { content, .. }) => {
            let joined = content
                .iter()
                .filter_map(|block| match block {
                    Content::Text { text } => Some(text.as_str()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join(" ");
            ("user", truncate_with_ellipsis(&joined, 80))
        }
        AgentMessage::Llm(Message::Assistant { content, .. }) => {
            let total_tools = content
                .iter()
                .filter(|block| matches!(block, Content::ToolCall { .. }))
                .count();
            let mut seen_tools = 0usize;
            let mut parts: Vec<String> = content
                .iter()
                .filter_map(|block| match block {
                    Content::Text { text } if !text.is_empty() => {
                        Some(truncate_with_ellipsis(text, 60))
                    }
                    Content::ToolCall { name, .. } => {
                        seen_tools += 1;
                        (seen_tools <= SHOWN_TOOLS).then(|| format!("→{name}"))
                    }
                    _ => None,
                })
                .collect();
            if total_tools > SHOWN_TOOLS {
                parts.push(format!("(+{} more tools)", total_tools - SHOWN_TOOLS));
            }
            let preview = if parts.is_empty() {
                "(empty)".to_string()
            } else {
                parts.join(" ")
            };
            ("assistant", preview)
        }
        AgentMessage::Llm(Message::ToolResult {
            tool_name,
            is_error,
            ..
        }) => {
            let status = match is_error {
                true => "✗",
                false => "✓",
            };
            ("tool", format!("{tool_name} {status}"))
        }
        AgentMessage::Extension(ext) => ("ext", truncate_with_ellipsis(&ext.role, 60)),
    }
}
/// How a single streamed prompt attempt ended.
enum PromptResult {
    /// The stream finished (or the user pressed Ctrl+C).
    Done {
        /// Assistant text gathered during the attempt.
        collected_text: String,
        /// Token usage summed over the attempt's assistant messages.
        usage: Usage,
        /// Preview of the last tool error, if the latest tool call failed.
        last_tool_error: Option<String>,
    },
    /// An assistant message ended with an error that `is_retriable_error` accepts.
    RetriableError {
        error_msg: String,
        usage: Usage,
    },
    /// An assistant message ended with an error that `is_overflow_error` accepts.
    ContextOverflow {
        error_msg: String,
        usage: Usage,
    },
}
/// Sends `input` as a single text prompt and renders the resulting event stream.
async fn run_prompt_once(
    agent: &mut Agent,
    input: &str,
    changes: &SessionChanges,
    model: &str,
) -> PromptResult {
    let events = agent.prompt(input).await;
    handle_prompt_events(agent, events, changes, model).await
}
/// Sends pre-built `messages` (e.g. multi-part content) and renders the resulting event stream.
async fn run_prompt_once_with_messages(
    agent: &mut Agent,
    messages: Vec<AgentMessage>,
    changes: &SessionChanges,
    model: &str,
) -> PromptResult {
    let events = agent.prompt_messages(messages).await;
    handle_prompt_events(agent, events, changes, model).await
}
/// Drains the agent's event stream for one prompt attempt, rendering output
/// live, and classifies how the attempt ended.
///
/// While streaming it:
/// - records `write_file` / `edit_file` paths into `changes`;
/// - prints tool start/finish lines and a summary per batch of consecutive
///   tool calls;
/// - prints streamed assistant text through the markdown renderer (filtered
///   by `ThinkBlockFilter` unless verbose) and thinking deltas on stderr;
/// - writes an audit entry per finished tool call (a no-op unless auditing is on).
///
/// Returns `ContextOverflow` if any assistant message ended with an overflow
/// error; otherwise `RetriableError` for a retriable API error; otherwise
/// `Done`. Ctrl+C aborts the agent and returns `Done` immediately.
///
/// Progress timers for running `bash` calls are now stopped on Ctrl+C and when
/// the stream closes without a matching `ToolExecutionEnd`. Previously they
/// were left running.
async fn handle_prompt_events(
    agent: &mut Agent,
    mut rx: tokio::sync::mpsc::UnboundedReceiver<AgentEvent>,
    changes: &SessionChanges,
    model: &str,
) -> PromptResult {
    let mut usage = Usage::default();
    let mut in_text = false;
    let mut in_thinking = false;
    let mut tool_timers: HashMap<String, Instant> = HashMap::new();
    let mut collected_text = String::new();
    let mut retriable_error: Option<String> = None;
    let mut overflow_error: Option<String> = None;
    let mut last_tool_error: Option<String> = None;
    let mut md_renderer = MarkdownRenderer::new();
    let mut spinner: Option<Spinner> = Some(Spinner::start());
    let mut think_filter = ThinkBlockFilter::new();
    // tool_call_id -> (tool name, args), held until the call ends so the audit entry gets its duration.
    let mut audit_inflight: HashMap<String, (String, serde_json::Value)> = HashMap::new();
    // Live progress indicators for running `bash` calls, keyed by tool_call_id.
    let mut tool_progress_timers: HashMap<String, ToolProgressTimer> = HashMap::new();
    // Counters for the current run of consecutive tool calls; flushed when text resumes or the agent ends.
    let mut batch_count: usize = 0;
    let mut batch_succeeded: usize = 0;
    let mut batch_failed: usize = 0;
    let mut batch_start: Option<Instant> = None;
    let mut turn_number: usize = 0;
    let mut had_text = false;
    loop {
        tokio::select! {
            event = rx.recv() => {
                let Some(event) = event else { break };
                match event {
                    AgentEvent::ToolExecutionStart {
                        tool_call_id, tool_name, args, ..
                    } => {
                        match tool_name.as_str() {
                            "write_file" => {
                                if let Some(path) = args.get("path").and_then(|v| v.as_str()) {
                                    changes.record(path, ChangeKind::Write);
                                }
                            }
                            "edit_file" => {
                                if let Some(path) = args.get("path").and_then(|v| v.as_str()) {
                                    changes.record(path, ChangeKind::Edit);
                                }
                            }
                            _ => {}
                        }
                        if let Some(s) = spinner.take() { s.stop(); }
                        if in_text {
                            println!();
                            in_text = false;
                        }
                        if batch_count == 0 {
                            if batch_start.is_none() {
                                batch_start = Some(Instant::now());
                            }
                            // A new tool batch after assistant text marks a new turn.
                            if had_text {
                                turn_number += 1;
                                println!("{}", turn_boundary(turn_number));
                            }
                        }
                        batch_count += 1;
                        tool_timers.insert(tool_call_id.clone(), Instant::now());
                        audit_inflight.insert(
                            tool_call_id.clone(),
                            (tool_name.clone(), args.clone()),
                        );
                        let summary = format_tool_summary(&tool_name, &args);
                        if tool_name == "sub_agent" {
                            eprintln!("\n{DIM} 🐙 Delegating to sub-agent...{RESET}");
                        }
                        // No newline: the ✓/✗ status is appended to this line when the tool ends.
                        print!("{YELLOW} ▶ {summary}{RESET}");
                        if is_verbose() {
                            println!();
                            let args_str = serde_json::to_string_pretty(&args).unwrap_or_default();
                            for line in args_str.lines() {
                                println!("{DIM} │ {line}{RESET}");
                            }
                        } else if tool_name == "edit_file" {
                            let old_text = args.get("old_text").and_then(|v| v.as_str()).unwrap_or("");
                            let new_text = args.get("new_text").and_then(|v| v.as_str()).unwrap_or("");
                            let diff = format_edit_diff(old_text, new_text);
                            if !diff.is_empty() {
                                println!();
                                println!("{diff}");
                            }
                        }
                        io::stdout().flush().ok();
                        if tool_name == "bash" {
                            let timer = ToolProgressTimer::start(tool_name.clone());
                            tool_progress_timers.insert(tool_call_id.clone(), timer);
                        }
                    }
                    AgentEvent::ToolExecutionEnd { tool_call_id, is_error, result, .. } => {
                        if let Some(timer) = tool_progress_timers.remove(&tool_call_id) {
                            timer.stop();
                        }
                        let elapsed = tool_timers
                            .remove(&tool_call_id)
                            .map(|start| start.elapsed());
                        let dur_str = elapsed
                            .map(|d| format!(" {DIM}({}){RESET}", format_duration(d)))
                            .unwrap_or_default();
                        if let Some((audit_tool, audit_args)) = audit_inflight.remove(&tool_call_id) {
                            let duration_ms = elapsed.map(|d| d.as_millis() as u64).unwrap_or(0);
                            audit_log_tool_call(&audit_tool, &audit_args, duration_ms, !is_error);
                        }
                        if is_error {
                            batch_failed += 1;
                            println!(" {RED}✗{RESET}{dur_str}");
                            // Computed once: shown to the user and kept as the error for auto-retry.
                            let preview = tool_result_preview(&result, 200);
                            if preview.is_empty() {
                                last_tool_error = Some("tool execution failed".to_string());
                            } else {
                                println!("{}", indent_tool_output(&preview));
                                last_tool_error = Some(preview);
                            }
                        } else {
                            batch_succeeded += 1;
                            last_tool_error = None;
                            println!(" {GREEN}✓{RESET}{dur_str}");
                            if is_verbose() {
                                let preview = tool_result_preview(&result, 200);
                                if !preview.is_empty() {
                                    println!("{}", indent_tool_output(&preview));
                                }
                            }
                        }
                    }
                    AgentEvent::ToolExecutionUpdate { tool_call_id, partial_result, .. } => {
                        let line_count = count_result_lines(&partial_result);
                        if let Some(timer) = tool_progress_timers.get(&tool_call_id) {
                            timer.set_line_count(line_count);
                        }
                        // Live output tail only when stdout is an interactive terminal.
                        if io::stdout().is_terminal() {
                            let text = extract_result_text(&partial_result);
                            if !text.is_empty() {
                                let tail = format_partial_tail(&text, 3);
                                if !tail.is_empty() {
                                    println!();
                                    println!("{tail}");
                                    io::stdout().flush().ok();
                                }
                            }
                        }
                    }
                    AgentEvent::MessageUpdate {
                        delta: StreamDelta::Text { delta },
                        ..
                    } => {
                        if let Some(s) = spinner.take() { s.stop(); }
                        if in_thinking {
                            eprintln!();
                            eprintln!("{}", section_divider());
                            let _ = io::stderr().flush();
                            in_thinking = false;
                        }
                        // Text resumed: close out the pending tool batch with its summary.
                        if batch_count > 0 {
                            let batch_duration = batch_start
                                .map(|s| s.elapsed())
                                .unwrap_or_default();
                            let summary = format_tool_batch_summary(
                                batch_count, batch_succeeded, batch_failed, batch_duration,
                            );
                            if !summary.is_empty() {
                                println!("{summary}");
                            }
                            batch_count = 0;
                            batch_succeeded = 0;
                            batch_failed = 0;
                            batch_start = None;
                        }
                        if !in_text {
                            println!();
                            in_text = true;
                            had_text = true;
                        }
                        let filtered = if is_verbose() {
                            delta.clone()
                        } else {
                            think_filter.filter(&delta)
                        };
                        if filtered.is_empty() {
                            io::stdout().flush().ok();
                            continue;
                        }
                        let rendered = md_renderer.render_delta(&filtered);
                        if !rendered.is_empty() {
                            print!("{}", rendered);
                        }
                        io::stdout().flush().ok();
                        collected_text.push_str(&filtered);
                    }
                    AgentEvent::MessageUpdate {
                        delta: StreamDelta::Thinking { delta },
                        ..
                    } => {
                        if let Some(s) = spinner.take() { s.stop(); }
                        if !in_thinking {
                            eprintln!("\n{}", section_header("Thinking"));
                            in_thinking = true;
                        }
                        eprint!("{DIM}{delta}{RESET}");
                        let _ = io::stderr().flush();
                    }
                    AgentEvent::AgentEnd { messages } => {
                        if let Some(s) = spinner.take() { s.stop(); }
                        // Emit whatever the think-block filter was still holding back.
                        let remaining = think_filter.flush();
                        if !remaining.is_empty() {
                            let rendered = md_renderer.render_delta(&remaining);
                            if !rendered.is_empty() {
                                print!("{rendered}");
                                io::stdout().flush().ok();
                            }
                            collected_text.push_str(&remaining);
                        }
                        if batch_count > 0 {
                            let batch_duration = batch_start
                                .map(|s| s.elapsed())
                                .unwrap_or_default();
                            let summary = format_tool_batch_summary(
                                batch_count, batch_succeeded, batch_failed, batch_duration,
                            );
                            if !summary.is_empty() {
                                println!("{summary}");
                            }
                            batch_count = 0;
                            batch_succeeded = 0;
                            batch_failed = 0;
                            batch_start = None;
                        }
                        // Sum usage across assistant messages and classify any terminal error.
                        for msg in &messages {
                            if let AgentMessage::Llm(Message::Assistant { usage: msg_usage, stop_reason, error_message, .. }) = msg {
                                usage.input += msg_usage.input;
                                usage.output += msg_usage.output;
                                usage.cache_read += msg_usage.cache_read;
                                usage.cache_write += msg_usage.cache_write;
                                if *stop_reason == StopReason::Error {
                                    if let Some(err_msg) = error_message {
                                        if in_text {
                                            println!();
                                            in_text = false;
                                        }
                                        if is_overflow_error(err_msg) {
                                            overflow_error = Some(err_msg.clone());
                                        } else if is_retriable_error(err_msg) {
                                            retriable_error = Some(err_msg.clone());
                                        } else {
                                            eprintln!("\n{RED} error: {err_msg}{RESET}");
                                            if let Some(diagnostic) = diagnose_api_error(err_msg, model) {
                                                eprintln!("{YELLOW} 💡 {}{RESET}", diagnostic.replace('\n', &format!("\n{YELLOW} {RESET}")));
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                    AgentEvent::InputRejected { reason } => {
                        if let Some(s) = spinner.take() { s.stop(); }
                        eprintln!("{RED} input rejected: {reason}{RESET}");
                        if let Some(diagnostic) = diagnose_api_error(&reason, model) {
                            eprintln!("{YELLOW} 💡 {}{RESET}", diagnostic.replace('\n', &format!("\n{YELLOW} {RESET}")));
                        }
                    }
                    AgentEvent::ProgressMessage { text, .. } => {
                        if let Some(s) = spinner.take() { s.stop(); }
                        if in_text {
                            println!();
                            in_text = false;
                        }
                        println!("{DIM} {text}{RESET}");
                    }
                    AgentEvent::MessageStart { .. } => {
                        if let Some(s) = spinner.take() { s.stop(); }
                    }
                    AgentEvent::MessageEnd { .. } => {
                        if in_text {
                            let remaining = md_renderer.flush();
                            if !remaining.is_empty() {
                                print!("{remaining}");
                            }
                            println!();
                            in_text = false;
                        }
                    }
                    _ => {}
                }
            }
            _ = tokio::signal::ctrl_c() => {
                if let Some(s) = spinner.take() { s.stop(); }
                // Their ToolExecutionEnd will never be processed after we return.
                for (_, timer) in tool_progress_timers.drain() {
                    timer.stop();
                }
                agent.abort();
                if in_text {
                    println!();
                }
                println!("\n{DIM} (interrupted — press Ctrl+C again to exit){RESET}");
                return PromptResult::Done {
                    collected_text,
                    usage,
                    last_tool_error,
                };
            }
        }
    }
    if let Some(s) = spinner.take() {
        s.stop();
    }
    // The stream may close without a matching ToolExecutionEnd; don't leave timers running.
    for (_, timer) in tool_progress_timers.drain() {
        timer.stop();
    }
    let remaining = md_renderer.flush();
    if !remaining.is_empty() {
        print!("{}", remaining);
        io::stdout().flush().ok();
    }
    if in_text {
        println!();
    }
    // Overflow takes precedence over a retriable error, which takes precedence over success.
    if let Some(err_msg) = overflow_error {
        PromptResult::ContextOverflow {
            error_msg: err_msg,
            usage,
        }
    } else if let Some(err_msg) = retriable_error {
        PromptResult::RetriableError {
            error_msg: err_msg,
            usage,
        }
    } else {
        PromptResult::Done {
            collected_text,
            usage,
            last_tool_error,
        }
    }
}
/// Runs a text prompt with API-error retries, using a throwaway change tracker.
pub async fn run_prompt(
    agent: &mut Agent,
    input: &str,
    session_total: &mut Usage,
    model: &str,
) -> PromptOutcome {
    let tracker = SessionChanges::new();
    run_prompt_with_changes(agent, input, session_total, model, &tracker).await
}
/// Runs a text prompt, retrying transient API errors and recovering once from
/// a context overflow.
///
/// - Retriable errors are retried up to `MAX_RETRIES` times with
///   `retry_delay` backoff. The conversation is restored to its pre-prompt
///   state before each retry.
/// - A context overflow restores that state, compacts the conversation, and
///   re-sends the prompt once (with `build_overflow_retry_prompt`).
///
/// Usage from every attempt is added to `session_total`, then printed.
pub async fn run_prompt_with_changes(
    agent: &mut Agent,
    input: &str,
    session_total: &mut Usage,
    model: &str,
    changes: &SessionChanges,
) -> PromptOutcome {
    /// Adds the token counters of `part` onto `total`.
    fn accumulate(total: &mut Usage, part: &Usage) {
        total.input += part.input;
        total.output += part.output;
        total.cache_read += part.cache_read;
        total.cache_write += part.cache_write;
    }

    /// Prints the provider-specific hint for `error_msg`, if there is one.
    fn print_diagnostic(error_msg: &str, model: &str) {
        if let Some(diagnostic) = diagnose_api_error(error_msg, model) {
            eprintln!(
                "{YELLOW} 💡 {}{RESET}",
                diagnostic.replace('\n', &format!("\n{YELLOW} {RESET}"))
            );
        }
    }

    crate::commands_session::proactive_compact_if_needed(agent);
    let prompt_start = Instant::now();
    let mut total_usage = Usage::default();
    let mut collected_text = String::new();
    let mut last_tool_error: Option<String> = None;
    let mut did_overflow_compact = false;
    // Pre-prompt conversation, restored before each retry so failed attempts leave no trace.
    let saved_state = agent.save_messages().ok();

    for attempt in 0..=MAX_RETRIES {
        if attempt > 0 {
            if let Some(json) = saved_state.as_ref() {
                let _ = agent.restore_messages(json);
            }
        }
        match run_prompt_once(agent, input, changes, model).await {
            PromptResult::Done {
                collected_text: text,
                usage,
                last_tool_error: tool_err,
            } => {
                accumulate(&mut total_usage, &usage);
                collected_text = text;
                last_tool_error = tool_err;
                break;
            }
            PromptResult::RetriableError { error_msg, usage } => {
                accumulate(&mut total_usage, &usage);
                if attempt >= MAX_RETRIES {
                    eprintln!("\n{RED} error: {error_msg}{RESET}");
                    eprintln!("{DIM} (failed after {} attempts){RESET}", MAX_RETRIES + 1);
                    print_diagnostic(&error_msg, model);
                } else {
                    let delay = retry_delay(attempt + 1);
                    let delay_secs = delay.as_secs();
                    let next = attempt + 2;
                    eprintln!(
                        "{DIM} ⚡ retrying (attempt {next}/{}, waiting {delay_secs}s)...{RESET}",
                        MAX_RETRIES + 1
                    );
                    tokio::time::sleep(delay).await;
                }
            }
            PromptResult::ContextOverflow { error_msg, usage } => {
                accumulate(&mut total_usage, &usage);
                eprintln!(
                    "\n{YELLOW} ⚡ context overflow detected — auto-compacting and retrying...{RESET}"
                );
                eprintln!("{DIM} ({error_msg}){RESET}");
                if let Some(json) = saved_state.as_ref() {
                    let _ = agent.restore_messages(json);
                }
                if let Some((before_count, before_tokens, after_count, after_tokens)) =
                    crate::commands_session::compact_agent(agent)
                {
                    eprintln!(
                        "{DIM} compacted: {before_count} → {after_count} messages, ~{} → ~{} tokens{RESET}",
                        crate::format::format_token_count(before_tokens),
                        crate::format::format_token_count(after_tokens)
                    );
                }
                did_overflow_compact = true;
                let retry_input = build_overflow_retry_prompt(input);
                match run_prompt_once(agent, &retry_input, changes, model).await {
                    PromptResult::Done {
                        collected_text: text,
                        usage: retry_usage,
                        last_tool_error: tool_err,
                    } => {
                        accumulate(&mut total_usage, &retry_usage);
                        collected_text = text;
                        last_tool_error = tool_err;
                    }
                    PromptResult::RetriableError {
                        error_msg: retry_err,
                        usage: retry_usage,
                    }
                    | PromptResult::ContextOverflow {
                        error_msg: retry_err,
                        usage: retry_usage,
                    } => {
                        accumulate(&mut total_usage, &retry_usage);
                        eprintln!("\n{RED} error: {retry_err}{RESET}");
                        eprintln!(
                            "{DIM} (overflow retry also failed — try /compact manually){RESET}"
                        );
                    }
                }
                break;
            }
        }
    }

    accumulate(session_total, &total_usage);
    print_usage(&total_usage, session_total, model, prompt_start.elapsed());
    maybe_ring_bell(prompt_start.elapsed());
    println!();
    PromptOutcome {
        text: collected_text,
        last_tool_error,
        was_overflow: did_overflow_compact,
    }
}
/// Runs a text prompt, then re-prompts up to `MAX_AUTO_RETRIES` times while
/// the last tool call keeps failing.
///
/// Each retry quotes the tool error via `build_auto_retry_prompt`. Returns
/// the outcome of the final run.
pub async fn run_prompt_auto_retry(
    agent: &mut Agent,
    input: &str,
    session_total: &mut Usage,
    model: &str,
    changes: &SessionChanges,
) -> PromptOutcome {
    let mut outcome = run_prompt_with_changes(agent, input, session_total, model, changes).await;
    for attempt in 1..=MAX_AUTO_RETRIES {
        let Some(tool_error) = outcome.last_tool_error.as_deref() else {
            break;
        };
        let retry_prompt = build_auto_retry_prompt(input, tool_error, attempt);
        eprintln!(
            "{DIM} ⚡ auto-retrying after tool error (attempt {attempt}/{MAX_AUTO_RETRIES})...{RESET}"
        );
        outcome =
            run_prompt_with_changes(agent, &retry_prompt, session_total, model, changes).await;
    }
    outcome
}
/// Runs a multi-part (e.g. text + image) prompt with a throwaway change tracker.
pub async fn run_prompt_with_content(
    agent: &mut Agent,
    content_blocks: Vec<Content>,
    session_total: &mut Usage,
    model: &str,
) -> PromptOutcome {
    let tracker = SessionChanges::new();
    run_prompt_with_content_and_changes(agent, content_blocks, session_total, model, &tracker)
        .await
}
/// Runs a multi-part prompt, then re-prompts up to `MAX_AUTO_RETRIES` times
/// while the last tool call keeps failing.
///
/// Retries are sent as plain text built from `original_text` (the content
/// blocks are not re-sent). Returns the outcome of the final run.
pub async fn run_prompt_auto_retry_with_content(
    agent: &mut Agent,
    content_blocks: Vec<Content>,
    session_total: &mut Usage,
    model: &str,
    changes: &SessionChanges,
    original_text: &str,
) -> PromptOutcome {
    let mut outcome =
        run_prompt_with_content_and_changes(agent, content_blocks, session_total, model, changes)
            .await;
    for attempt in 1..=MAX_AUTO_RETRIES {
        let Some(tool_error) = outcome.last_tool_error.as_deref() else {
            break;
        };
        let retry_prompt = build_auto_retry_prompt(original_text, tool_error, attempt);
        eprintln!(
            "{DIM} ⚡ auto-retrying after tool error (attempt {attempt}/{MAX_AUTO_RETRIES})...{RESET}"
        );
        outcome =
            run_prompt_with_changes(agent, &retry_prompt, session_total, model, changes).await;
    }
    outcome
}
/// Runs one multimodal user turn (`content_blocks` sent as a single
/// `Message::User`) and retries transient API failures.
///
/// * `agent` – the agent to prompt. Its message history is proactively
///   compacted if needed. History is saved before the first attempt and
///   restored before every retry, so each attempt starts from the same state.
/// * `content_blocks` – the user message content (text and/or images).
/// * `session_total` – running session usage; this turn's usage is added to it.
/// * `model` – model name, used for usage reporting and error diagnostics.
/// * `changes` – file-change tracker passed through to the agent run.
///
/// Retriable errors are retried up to `MAX_RETRIES` times with `retry_delay`
/// backoff. After the last failure the error and an optional diagnostic are
/// printed. Context-overflow errors are not retried here. Instead the returned
/// outcome has `was_overflow` set to `true`, so callers can see the turn hit
/// the context limit.
pub async fn run_prompt_with_content_and_changes(
    agent: &mut Agent,
    content_blocks: Vec<Content>,
    session_total: &mut Usage,
    model: &str,
    changes: &SessionChanges,
) -> PromptOutcome {
    /// Adds the token counters of `delta` into `total`.
    fn add_usage(total: &mut Usage, delta: &Usage) {
        total.input += delta.input;
        total.output += delta.output;
        total.cache_read += delta.cache_read;
        total.cache_write += delta.cache_write;
    }

    crate::commands_session::proactive_compact_if_needed(agent);
    let prompt_start = Instant::now();
    let mut total_usage = Usage::default();
    let mut collected_text = String::new();
    let mut last_tool_error: Option<String> = None;
    let mut was_overflow = false;
    let user_msg = AgentMessage::Llm(Message::User {
        content: content_blocks,
        timestamp: now_ms(),
    });
    // Snapshot history so a retry does not see the partial turn of a failed attempt.
    let saved_state = agent.save_messages().ok();
    for attempt in 0..=MAX_RETRIES {
        if attempt > 0 {
            if let Some(ref json) = saved_state {
                let _ = agent.restore_messages(json);
            }
        }
        match run_prompt_once_with_messages(agent, vec![user_msg.clone()], changes, model).await {
            PromptResult::Done {
                collected_text: text,
                usage,
                last_tool_error: tool_err,
            } => {
                add_usage(&mut total_usage, &usage);
                collected_text = text;
                last_tool_error = tool_err;
                break;
            }
            PromptResult::RetriableError { error_msg, usage } => {
                add_usage(&mut total_usage, &usage);
                if attempt < MAX_RETRIES {
                    let delay = retry_delay(attempt + 1);
                    let delay_secs = delay.as_secs();
                    let next = attempt + 2;
                    eprintln!(
                        "{DIM} ⚡ retrying (attempt {next}/{}, waiting {delay_secs}s)...{RESET}",
                        MAX_RETRIES + 1
                    );
                    tokio::time::sleep(delay).await;
                } else {
                    eprintln!("\n{RED} error: {error_msg}{RESET}");
                    eprintln!("{DIM} (failed after {} attempts){RESET}", MAX_RETRIES + 1);
                    if let Some(diagnostic) = diagnose_api_error(&error_msg, model) {
                        eprintln!(
                            "{YELLOW} 💡 {}{RESET}",
                            diagnostic.replace('\n', &format!("\n{YELLOW} {RESET}"))
                        );
                    }
                }
            }
            PromptResult::ContextOverflow { error_msg, usage } => {
                add_usage(&mut total_usage, &usage);
                eprintln!(
                    "\n{YELLOW} ⚡ context overflow detected — cannot retry with image content{RESET}"
                );
                eprintln!("{DIM} ({error_msg}){RESET}");
                // Report the overflow instead of silently returning `false`.
                was_overflow = true;
                break;
            }
        }
    }
    add_usage(session_total, &total_usage);
    // Measure once so the usage line and the bell agree on the duration.
    let elapsed = prompt_start.elapsed();
    print_usage(&total_usage, session_total, model, elapsed);
    maybe_ring_bell(elapsed);
    println!();
    PromptOutcome {
        text: collected_text,
        last_tool_error,
        was_overflow,
    }
}
/// Builds a human-readable summary of the files modified this session.
///
/// Returns an empty string when no change has been recorded. Otherwise the
/// result is a header line with the (pluralized) file count, followed by one
/// line per file: an icon for the change kind, the path, and the kind itself.
pub fn format_changes(changes: &SessionChanges) -> String {
    let entries = changes.snapshot();
    if entries.is_empty() {
        return String::new();
    }
    let count = entries.len();
    let header = format!(
        " {} {} modified this session:\n",
        count,
        pluralize(count, "file", "files")
    );
    entries.iter().fold(header, |mut report, entry| {
        let icon = match entry.kind {
            ChangeKind::Write => "✏",
            ChangeKind::Edit => "🔧",
        };
        report.push_str(&format!(" {icon} {} ({})\n", entry.path, entry.kind));
        report
    })
}
#[cfg(test)]
mod tests {
    // Unit tests for retry/overflow classification, message search and
    // highlighting, session change tracking, turn snapshots/undo, and audit
    // argument truncation.
    use super::*;

    #[test]
    fn test_retry_delay_exponential_backoff() {
        assert_eq!(retry_delay(1), Duration::from_secs(1));
        assert_eq!(retry_delay(2), Duration::from_secs(2));
        assert_eq!(retry_delay(3), Duration::from_secs(4));
    }

    #[test]
    fn test_retry_delay_zero_attempt() {
        assert_eq!(retry_delay(0), Duration::from_secs(1));
    }

    #[test]
    fn test_is_retriable_rate_limit() {
        assert!(is_retriable_error("429 Too Many Requests"));
        assert!(is_retriable_error("rate limit exceeded"));
        assert!(is_retriable_error("Rate_limit_error: too many requests"));
        assert!(is_retriable_error("too many requests, please slow down"));
    }

    #[test]
    fn test_is_retriable_server_errors() {
        assert!(is_retriable_error("500 Internal Server Error"));
        assert!(is_retriable_error("502 Bad Gateway"));
        assert!(is_retriable_error("503 Service Unavailable"));
        assert!(is_retriable_error("504 Gateway Timeout"));
        assert!(is_retriable_error("the server is overloaded"));
        assert!(is_retriable_error("Server error occurred"));
    }

    #[test]
    fn test_is_retriable_network_errors() {
        assert!(is_retriable_error("connection reset by peer"));
        assert!(is_retriable_error("network error: connection refused"));
        assert!(is_retriable_error("request timed out"));
        assert!(is_retriable_error("timeout waiting for response"));
    }

    #[test]
    fn test_is_not_retriable_auth_errors() {
        assert!(!is_retriable_error("401 Unauthorized"));
        assert!(!is_retriable_error("403 Forbidden"));
        assert!(!is_retriable_error("authentication failed"));
        assert!(!is_retriable_error("invalid api key"));
        assert!(!is_retriable_error("Invalid_api_key: check your key"));
        assert!(!is_retriable_error("permission denied"));
    }

    #[test]
    fn test_is_not_retriable_client_errors() {
        assert!(!is_retriable_error("400 Bad Request"));
        assert!(!is_retriable_error("invalid request body"));
        assert!(!is_retriable_error("404 not_found"));
    }

    #[test]
    fn test_is_not_retriable_unknown_error() {
        assert!(!is_retriable_error("something went wrong"));
        assert!(!is_retriable_error("unexpected error"));
    }

    #[test]
    fn test_is_retriable_stream_errors() {
        assert!(is_retriable_error("Stream ended"));
        assert!(is_retriable_error("stream closed unexpectedly"));
        assert!(is_retriable_error("unexpected eof while reading"));
        assert!(is_retriable_error("broken pipe"));
        assert!(is_retriable_error("connection reset by peer"));
        assert!(is_retriable_error("incomplete response from server"));
    }

    #[test]
    fn test_diagnose_stream_ended() {
        let diag = diagnose_api_error("error: Stream ended", "claude-sonnet-4-20250514");
        assert!(diag.is_some());
        let msg = diag.unwrap();
        assert!(msg.contains("interrupted"));
        assert!(msg.contains("auto-retry"));
    }

    #[test]
    fn test_diagnose_stream_closed() {
        let diag = diagnose_api_error("stream closed unexpectedly", "gpt-4o");
        assert!(diag.is_some());
        assert!(diag.unwrap().contains("interrupted"));
    }

    #[test]
    fn test_diagnose_unexpected_eof() {
        let diag = diagnose_api_error("unexpected eof", "claude-sonnet-4-20250514");
        assert!(diag.is_some());
        assert!(diag.unwrap().contains("interrupted"));
    }

    #[test]
    fn test_diagnose_broken_pipe() {
        let diag = diagnose_api_error("broken pipe while writing", "claude-sonnet-4-20250514");
        assert!(diag.is_some());
        assert!(diag.unwrap().contains("interrupted"));
    }

    #[test]
    fn test_diagnose_incomplete() {
        let diag = diagnose_api_error("incomplete response", "claude-sonnet-4-20250514");
        assert!(diag.is_some());
        assert!(diag.unwrap().contains("interrupted"));
    }

    #[test]
    fn test_summarize_message_user() {
        let msg = AgentMessage::Llm(Message::user("hello world, this is a test"));
        let (role, preview) = summarize_message(&msg);
        assert_eq!(role, "user");
        assert!(preview.contains("hello world"));
    }

    #[test]
    fn test_summarize_message_tool_result() {
        let msg = AgentMessage::Llm(Message::ToolResult {
            tool_call_id: "tc_1".into(),
            tool_name: "bash".into(),
            content: vec![Content::Text {
                text: "output".into(),
            }],
            is_error: false,
            timestamp: 0,
        });
        let (role, preview) = summarize_message(&msg);
        assert_eq!(role, "tool");
        assert!(preview.contains("bash"));
        assert!(preview.contains("✓"));
    }

    #[test]
    fn test_summarize_message_tool_result_error() {
        let msg = AgentMessage::Llm(Message::ToolResult {
            tool_call_id: "tc_2".into(),
            tool_name: "bash".into(),
            content: vec![Content::Text {
                text: "error".into(),
            }],
            is_error: true,
            timestamp: 0,
        });
        let (role, preview) = summarize_message(&msg);
        assert_eq!(role, "tool");
        assert!(preview.contains("✗"));
    }

    #[test]
    fn test_write_output_file_none() {
        write_output_file(&None, "test content");
    }

    #[test]
    fn test_write_output_file_some() {
        // A per-test temp dir avoids clashes between concurrent test runs and
        // is removed automatically when `dir` is dropped.
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test_output.txt");
        let path_str = path.to_string_lossy().to_string();
        write_output_file(&Some(path_str), "hello from yoyo");
        let content = std::fs::read_to_string(&path).unwrap();
        assert_eq!(content, "hello from yoyo");
    }

    #[test]
    fn test_tool_result_preview_empty() {
        let result = ToolResult {
            content: vec![],
            details: serde_json::json!(null),
        };
        assert_eq!(tool_result_preview(&result, 100), "");
    }

    #[test]
    fn test_tool_result_preview_text() {
        let result = ToolResult {
            content: vec![Content::Text {
                text: "error: file not found".into(),
            }],
            details: serde_json::json!(null),
        };
        assert_eq!(tool_result_preview(&result, 100), "error: file not found");
    }

    #[test]
    fn test_tool_result_preview_truncated() {
        let result = ToolResult {
            content: vec![Content::Text {
                text: "a".repeat(200),
            }],
            details: serde_json::json!(null),
        };
        let preview = tool_result_preview(&result, 50);
        assert!(preview.len() < 100);
        assert!(preview.ends_with('…'));
    }

    #[test]
    fn test_tool_result_preview_multiline() {
        let result = ToolResult {
            content: vec![Content::Text {
                text: "first line\nsecond line\nthird line".into(),
            }],
            details: serde_json::json!(null),
        };
        assert_eq!(tool_result_preview(&result, 100), "first line");
    }

    #[test]
    fn test_search_messages_basic_match() {
        let messages = vec![
            AgentMessage::Llm(Message::user("hello world")),
            AgentMessage::Llm(Message::user("goodbye world")),
        ];
        let results = search_messages(&messages, "hello");
        assert_eq!(results.len(), 1);
        // Result indices are 1-based.
        assert_eq!(results[0].0, 1);
        assert_eq!(results[0].1, "user");
        assert!(results[0].2.contains("hello"));
    }

    #[test]
    fn test_search_messages_case_insensitive() {
        let messages = vec![AgentMessage::Llm(Message::user("Hello World"))];
        let results = search_messages(&messages, "hello");
        assert_eq!(results.len(), 1);
        let results2 = search_messages(&messages, "HELLO");
        assert_eq!(results2.len(), 1);
    }

    #[test]
    fn test_search_messages_no_match() {
        let messages = vec![AgentMessage::Llm(Message::user("hello world"))];
        let results = search_messages(&messages, "foobar");
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_messages_empty_messages() {
        let messages: Vec<AgentMessage> = vec![];
        let results = search_messages(&messages, "anything");
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_messages_multiple_matches() {
        let messages = vec![
            AgentMessage::Llm(Message::user("the rust language")),
            AgentMessage::Llm(Message::user("python is great")),
            AgentMessage::Llm(Message::user("rust is fast")),
        ];
        let results = search_messages(&messages, "rust");
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 1);
        assert_eq!(results[1].0, 3);
    }

    #[test]
    fn test_search_messages_tool_result() {
        let messages = vec![AgentMessage::Llm(Message::ToolResult {
            tool_call_id: "tc_1".into(),
            tool_name: "bash".into(),
            content: vec![Content::Text {
                text: "cargo build succeeded".into(),
            }],
            is_error: false,
            timestamp: 0,
        })];
        let results = search_messages(&messages, "cargo");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1, "tool");
    }

    #[test]
    fn test_message_text_user() {
        let msg = AgentMessage::Llm(Message::user("test input"));
        let text = message_text(&msg);
        assert_eq!(text, "test input");
    }

    #[test]
    fn test_message_text_tool_result() {
        let msg = AgentMessage::Llm(Message::ToolResult {
            tool_call_id: "tc_1".into(),
            tool_name: "bash".into(),
            content: vec![Content::Text {
                text: "output text".into(),
            }],
            is_error: false,
            timestamp: 0,
        });
        let text = message_text(&msg);
        assert!(text.contains("bash"));
        assert!(text.contains("output text"));
    }

    #[test]
    fn test_highlight_matches_basic() {
        let result = highlight_matches("hello world", "world");
        assert!(result.contains(&format!("{BOLD}world{RESET}")));
        assert!(result.contains("hello "));
    }

    #[test]
    fn test_highlight_matches_case_insensitive() {
        let result = highlight_matches("Hello World", "hello");
        assert!(result.contains(&format!("{BOLD}Hello{RESET}")));
    }

    #[test]
    fn test_highlight_matches_multiple_occurrences() {
        let result = highlight_matches("rust is fast, rust is safe", "rust");
        let bold_rust = format!("{BOLD}rust{RESET}");
        let count = result.matches(bold_rust.as_str()).count();
        assert_eq!(count, 2);
    }

    #[test]
    fn test_highlight_matches_no_match() {
        let result = highlight_matches("hello world", "foobar");
        assert_eq!(result, "hello world");
    }

    #[test]
    fn test_highlight_matches_empty_query() {
        let result = highlight_matches("hello world", "");
        assert_eq!(result, "hello world");
    }

    #[test]
    fn test_highlight_matches_empty_text() {
        let result = highlight_matches("", "query");
        assert_eq!(result, "");
    }

    #[test]
    fn test_highlight_matches_preserves_original_case() {
        let result = highlight_matches("The Rust Language", "rust");
        assert!(result.contains(&format!("{BOLD}Rust{RESET}")));
    }

    #[test]
    fn test_highlight_matches_entire_string() {
        let result = highlight_matches("hello", "hello");
        assert_eq!(result, format!("{BOLD}hello{RESET}"));
    }

    #[test]
    fn test_search_messages_results_are_highlighted() {
        let messages = vec![AgentMessage::Llm(Message::user("hello world"))];
        let results = search_messages(&messages, "hello");
        assert_eq!(results.len(), 1);
        assert!(results[0].2.contains(&format!("{BOLD}hello{RESET}")));
    }

    #[test]
    fn test_session_changes_new_is_empty() {
        let changes = SessionChanges::new();
        assert!(changes.is_empty());
        assert_eq!(changes.len(), 0);
        assert!(changes.snapshot().is_empty());
    }

    #[test]
    fn test_session_changes_record_write() {
        let changes = SessionChanges::new();
        changes.record("src/main.rs", ChangeKind::Write);
        assert_eq!(changes.len(), 1);
        assert!(!changes.is_empty());
        let snapshot = changes.snapshot();
        assert_eq!(snapshot[0].path, "src/main.rs");
        assert_eq!(snapshot[0].kind, ChangeKind::Write);
    }

    #[test]
    fn test_session_changes_record_edit() {
        let changes = SessionChanges::new();
        changes.record("src/cli.rs", ChangeKind::Edit);
        assert_eq!(changes.len(), 1);
        let snapshot = changes.snapshot();
        assert_eq!(snapshot[0].path, "src/cli.rs");
        assert_eq!(snapshot[0].kind, ChangeKind::Edit);
    }

    #[test]
    fn test_session_changes_deduplicates_same_path() {
        let changes = SessionChanges::new();
        changes.record("src/main.rs", ChangeKind::Write);
        changes.record("src/main.rs", ChangeKind::Edit);
        assert_eq!(changes.len(), 1);
        let snapshot = changes.snapshot();
        assert_eq!(snapshot[0].kind, ChangeKind::Edit);
    }

    #[test]
    fn test_session_changes_multiple_files() {
        let changes = SessionChanges::new();
        changes.record("src/main.rs", ChangeKind::Write);
        changes.record("src/cli.rs", ChangeKind::Edit);
        changes.record("README.md", ChangeKind::Write);
        assert_eq!(changes.len(), 3);
        let snapshot = changes.snapshot();
        assert_eq!(snapshot[0].path, "src/main.rs");
        assert_eq!(snapshot[1].path, "src/cli.rs");
        assert_eq!(snapshot[2].path, "README.md");
    }

    #[test]
    fn test_session_changes_clear() {
        let changes = SessionChanges::new();
        changes.record("src/main.rs", ChangeKind::Write);
        changes.record("src/cli.rs", ChangeKind::Edit);
        assert_eq!(changes.len(), 2);
        changes.clear();
        assert!(changes.is_empty());
        assert_eq!(changes.len(), 0);
    }

    #[test]
    fn test_session_changes_clone_shares_state() {
        // Clones are handles onto the same tracker: a change recorded through
        // the original is visible through the clone.
        let changes = SessionChanges::new();
        changes.record("src/main.rs", ChangeKind::Write);
        let cloned = changes.clone();
        changes.record("src/cli.rs", ChangeKind::Edit);
        assert_eq!(cloned.len(), 2);
    }

    #[test]
    fn test_change_kind_display() {
        assert_eq!(format!("{}", ChangeKind::Write), "write");
        assert_eq!(format!("{}", ChangeKind::Edit), "edit");
    }

    #[test]
    fn test_format_changes_empty() {
        let changes = SessionChanges::new();
        let output = format_changes(&changes);
        assert!(output.is_empty());
    }

    #[test]
    fn test_format_changes_single_write() {
        let changes = SessionChanges::new();
        changes.record("src/main.rs", ChangeKind::Write);
        let output = format_changes(&changes);
        assert!(output.contains("1 file modified"));
        assert!(output.contains("src/main.rs"));
        assert!(output.contains("write"));
        assert!(output.contains("✏"));
    }

    #[test]
    fn test_format_changes_multiple_files() {
        let changes = SessionChanges::new();
        changes.record("src/main.rs", ChangeKind::Write);
        changes.record("src/cli.rs", ChangeKind::Edit);
        let output = format_changes(&changes);
        assert!(output.contains("2 files modified"));
        assert!(output.contains("src/main.rs"));
        assert!(output.contains("src/cli.rs"));
        assert!(output.contains("write"));
        assert!(output.contains("edit"));
        assert!(output.contains("🔧"));
    }

    #[test]
    fn test_build_auto_retry_prompt_includes_error_and_input() {
        let prompt = build_auto_retry_prompt("fix the bug", "file not found: foo.rs", 1);
        assert!(prompt.contains("fix the bug"));
        assert!(prompt.contains("file not found: foo.rs"));
        assert!(prompt.contains("Auto-retry 1/2"));
        assert!(prompt.contains("Try a different approach"));
    }

    #[test]
    fn test_build_auto_retry_prompt_attempt_number() {
        let prompt1 = build_auto_retry_prompt("do something", "error", 1);
        let prompt2 = build_auto_retry_prompt("do something", "error", 2);
        assert!(prompt1.contains("Auto-retry 1/2"));
        assert!(prompt2.contains("Auto-retry 2/2"));
    }

    #[test]
    fn test_build_auto_retry_prompt_truncates_long_errors() {
        let long_error = "x".repeat(500);
        let prompt = build_auto_retry_prompt("input", &long_error, 1);
        assert!(prompt.contains(&"x".repeat(300)));
        assert!(prompt.contains('…'));
        assert!(!prompt.contains(&"x".repeat(500)));
    }

    #[test]
    fn test_build_auto_retry_prompt_preserves_short_errors() {
        let short_error = "command exited with code 1";
        let prompt = build_auto_retry_prompt("input", short_error, 1);
        assert!(prompt.contains(short_error));
        let error_portion = &prompt[..prompt.find("Try a").unwrap()];
        assert!(!error_portion.ends_with("…. "));
    }

    #[test]
    fn test_auto_retry_prompt_with_file_mention_context() {
        let original_text = "explain the bug in this file";
        let error = "bash: command not found: cargo";
        let prompt = build_auto_retry_prompt(original_text, error, 1);
        assert!(
            prompt.contains(original_text),
            "Retry should include original user text: {prompt}"
        );
        assert!(
            prompt.contains(error),
            "Retry should include the tool error: {prompt}"
        );
    }

    #[test]
    fn test_session_changes_shared_across_content_prompts() {
        let changes = SessionChanges::new();
        changes.record("src/main.rs", ChangeKind::Write);
        changes.record("src/cli.rs", ChangeKind::Edit);
        let snapshot = changes.snapshot();
        assert_eq!(snapshot.len(), 2);
        assert_eq!(snapshot[0].path, "src/main.rs");
        assert_eq!(snapshot[0].kind, ChangeKind::Write);
        assert_eq!(snapshot[1].path, "src/cli.rs");
        assert_eq!(snapshot[1].kind, ChangeKind::Edit);
        let output = format_changes(&changes);
        assert!(output.contains("2 files"));
        assert!(output.contains("src/main.rs"));
        assert!(output.contains("src/cli.rs"));
    }

    #[test]
    fn test_max_auto_retries_constant() {
        assert_eq!(MAX_AUTO_RETRIES, 2);
    }

    #[test]
    fn test_is_overflow_error_anthropic() {
        assert!(is_overflow_error(
            "prompt is too long: 213462 tokens > 200000 maximum"
        ));
    }

    #[test]
    fn test_is_overflow_error_openai() {
        assert!(is_overflow_error(
            "Your input exceeds the context window of this model"
        ));
    }

    #[test]
    fn test_is_overflow_error_google() {
        assert!(is_overflow_error(
            "The input token count (1196265) exceeds the maximum number of tokens allowed"
        ));
    }

    #[test]
    fn test_is_overflow_error_generic_too_many_tokens() {
        assert!(is_overflow_error("too many tokens in request"));
    }

    #[test]
    fn test_is_overflow_error_context_length_exceeded() {
        assert!(is_overflow_error("context length exceeded"));
        assert!(is_overflow_error("context_length_exceeded"));
    }

    #[test]
    fn test_is_overflow_error_max_token_exceeded() {
        assert!(is_overflow_error(
            "exceeded model token limit for this request"
        ));
        assert!(is_overflow_error("token limit exceeded"));
    }

    #[test]
    fn test_is_overflow_error_case_insensitive() {
        assert!(is_overflow_error("PROMPT IS TOO LONG"));
        assert!(is_overflow_error("Too Many Tokens"));
        assert!(is_overflow_error("CONTEXT LENGTH EXCEEDED"));
    }

    #[test]
    fn test_is_overflow_error_bedrock() {
        assert!(is_overflow_error("input is too long for requested model"));
    }

    #[test]
    fn test_is_overflow_error_groq() {
        assert!(is_overflow_error(
            "Please reduce the length of the messages or completion"
        ));
    }

    #[test]
    fn test_is_overflow_error_xai() {
        assert!(is_overflow_error(
            "This model's maximum prompt length is 131072 but request contains 537812 tokens"
        ));
    }

    #[test]
    fn test_is_not_overflow_error() {
        assert!(!is_overflow_error("invalid api key"));
        assert!(!is_overflow_error("rate limit exceeded"));
        assert!(!is_overflow_error("500 Internal Server Error"));
        assert!(!is_overflow_error("connection reset"));
        assert!(!is_overflow_error("bad request"));
        assert!(!is_overflow_error(""));
    }

    #[test]
    fn test_build_overflow_retry_prompt() {
        let prompt = build_overflow_retry_prompt("explain the code");
        assert!(prompt.contains("explain the code"));
        assert!(prompt.contains("auto-compacted"));
    }

    #[test]
    fn test_image_content_block_construction() {
        let data = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==".to_string();
        let mime_type = "image/png".to_string();
        let content_blocks = [
            Content::Text {
                text: "describe this image".to_string(),
            },
            Content::Image {
                data: data.clone(),
                mime_type: mime_type.clone(),
            },
        ];
        assert_eq!(content_blocks.len(), 2);
        match &content_blocks[0] {
            Content::Text { text } => assert_eq!(text, "describe this image"),
            _ => panic!("expected Text content"),
        }
        match &content_blocks[1] {
            Content::Image {
                data: d,
                mime_type: m,
            } => {
                assert_eq!(d, &data);
                assert_eq!(m, &mime_type);
            }
            _ => panic!("expected Image content"),
        }
    }

    #[test]
    fn test_user_message_with_image_content() {
        let content_blocks = vec![
            Content::Text {
                text: "what is this?".to_string(),
            },
            Content::Image {
                data: "base64data".to_string(),
                mime_type: "image/jpeg".to_string(),
            },
        ];
        let user_msg = AgentMessage::Llm(Message::User {
            content: content_blocks,
            timestamp: now_ms(),
        });
        assert_eq!(user_msg.role(), "user");
        if let AgentMessage::Llm(Message::User { content, .. }) = &user_msg {
            assert_eq!(content.len(), 2);
        } else {
            panic!("expected Llm(User) message");
        }
    }

    #[test]
    fn test_turn_snapshot_new_is_empty() {
        let snap = TurnSnapshot::new();
        assert!(snap.is_empty());
        assert_eq!(snap.file_count(), 0);
    }

    #[test]
    fn test_turn_snapshot_save_and_restore() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.txt");
        std::fs::write(&path, "original content").unwrap();
        let path_str = path.to_str().unwrap();
        let mut snap = TurnSnapshot::new();
        snap.snapshot_file(path_str);
        assert!(!snap.is_empty());
        assert_eq!(snap.file_count(), 1);
        assert_eq!(snap.originals.get(path_str).unwrap(), "original content");
        std::fs::write(&path, "modified content").unwrap();
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "modified content");
        let actions = snap.restore();
        assert_eq!(actions.len(), 1);
        assert!(actions[0].contains("restored"));
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "original content");
    }

    #[test]
    fn test_turn_snapshot_created_files_deleted() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("new_file.txt");
        let path_str = path.to_str().unwrap();
        let mut snap = TurnSnapshot::new();
        snap.record_created(path_str);
        assert!(!snap.is_empty());
        assert_eq!(snap.file_count(), 1);
        std::fs::write(&path, "new content").unwrap();
        assert!(path.exists());
        let actions = snap.restore();
        assert_eq!(actions.len(), 1);
        assert!(actions[0].contains("deleted"));
        assert!(!path.exists());
    }

    #[test]
    fn test_turn_snapshot_no_duplicate_snapshots() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.txt");
        std::fs::write(&path, "v1").unwrap();
        let path_str = path.to_str().unwrap();
        let mut snap = TurnSnapshot::new();
        snap.snapshot_file(path_str);
        std::fs::write(&path, "v2").unwrap();
        snap.snapshot_file(path_str);
        // The first snapshot wins; later calls must not overwrite it.
        assert_eq!(snap.originals.get(path_str).unwrap(), "v1");
    }

    #[test]
    fn test_turn_snapshot_nonexistent_file() {
        let mut snap = TurnSnapshot::new();
        snap.snapshot_file("/nonexistent/path/to/file.txt");
        assert!(snap.originals.is_empty());
    }

    #[test]
    fn test_turn_snapshot_created_not_duplicated() {
        let mut snap = TurnSnapshot::new();
        snap.record_created("new.txt");
        snap.record_created("new.txt");
        assert_eq!(snap.created.len(), 1);
    }

    #[test]
    fn test_turn_snapshot_created_ignores_existing() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.txt");
        std::fs::write(&path, "content").unwrap();
        let path_str = path.to_str().unwrap();
        let mut snap = TurnSnapshot::new();
        snap.snapshot_file(path_str);
        snap.record_created(path_str);
        assert!(snap.created.is_empty());
    }

    #[test]
    fn test_turn_history_new_is_empty() {
        let hist = TurnHistory::new();
        assert!(hist.is_empty());
        assert_eq!(hist.len(), 0);
    }

    #[test]
    fn test_turn_history_push_pop() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("a.txt");
        std::fs::write(&path, "original").unwrap();
        let mut hist = TurnHistory::new();
        let mut snap = TurnSnapshot::new();
        snap.snapshot_file(path.to_str().unwrap());
        hist.push(snap);
        assert_eq!(hist.len(), 1);
        let popped = hist.pop();
        assert!(popped.is_some());
        assert_eq!(hist.len(), 0);
    }

    #[test]
    fn test_turn_history_skips_empty_snapshots() {
        let mut hist = TurnHistory::new();
        hist.push(TurnSnapshot::new());
        assert!(hist.is_empty());
    }

    #[test]
    fn test_turn_history_undo_last_n() {
        let dir = tempfile::tempdir().unwrap();
        let path_a = dir.path().join("a.txt");
        std::fs::write(&path_a, "a_original").unwrap();
        let mut snap1 = TurnSnapshot::new();
        snap1.snapshot_file(path_a.to_str().unwrap());
        let path_b = dir.path().join("b.txt");
        std::fs::write(&path_b, "b_original").unwrap();
        let mut snap2 = TurnSnapshot::new();
        snap2.snapshot_file(path_b.to_str().unwrap());
        let mut hist = TurnHistory::new();
        hist.push(snap1);
        hist.push(snap2);
        assert_eq!(hist.len(), 2);
        std::fs::write(&path_a, "a_modified").unwrap();
        std::fs::write(&path_b, "b_modified").unwrap();
        // Undo restores the most recent turn first.
        let actions = hist.undo_last(1);
        assert!(!actions.is_empty());
        assert_eq!(std::fs::read_to_string(&path_b).unwrap(), "b_original");
        assert_eq!(std::fs::read_to_string(&path_a).unwrap(), "a_modified");
        assert_eq!(hist.len(), 1);
        let actions = hist.undo_last(1);
        assert!(!actions.is_empty());
        assert_eq!(std::fs::read_to_string(&path_a).unwrap(), "a_original");
        assert!(hist.is_empty());
    }

    #[test]
    fn test_turn_history_undo_more_than_available() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("x.txt");
        std::fs::write(&path, "orig").unwrap();
        let mut snap = TurnSnapshot::new();
        snap.snapshot_file(path.to_str().unwrap());
        let mut hist = TurnHistory::new();
        hist.push(snap);
        std::fs::write(&path, "changed").unwrap();
        let actions = hist.undo_last(5);
        assert!(!actions.is_empty());
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "orig");
        assert!(hist.is_empty());
    }

    #[test]
    fn test_turn_history_clear() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("c.txt");
        std::fs::write(&path, "content").unwrap();
        let mut snap = TurnSnapshot::new();
        snap.snapshot_file(path.to_str().unwrap());
        let mut hist = TurnHistory::new();
        hist.push(snap);
        assert_eq!(hist.len(), 1);
        hist.clear();
        assert!(hist.is_empty());
    }

    #[test]
    fn test_truncate_audit_args_short_values() {
        let args = serde_json::json!({"path": "src/main.rs", "command": "cargo test"});
        let truncated = truncate_audit_args(&args);
        assert_eq!(
            truncated, args,
            "Short strings should pass through unchanged"
        );
    }

    #[test]
    fn test_truncate_audit_args_long_values() {
        let long_content = "x".repeat(500);
        let args = serde_json::json!({"path": "test.txt", "content": long_content});
        let truncated = truncate_audit_args(&args);
        let content_val = truncated.get("content").unwrap().as_str().unwrap();
        assert!(content_val.len() < 500, "Long content should be truncated");
        assert!(
            content_val.contains("... [truncated, 500 chars total]"),
            "Should include truncation marker"
        );
        assert_eq!(truncated.get("path").unwrap().as_str().unwrap(), "test.txt");
    }

    #[test]
    fn test_truncate_audit_args_non_string() {
        let args = serde_json::json!({"count": 42, "flag": true, "ratio": 3.15});
        let truncated = truncate_audit_args(&args);
        assert_eq!(truncated, args, "Non-string values should pass through");
    }

    #[test]
    fn test_truncate_audit_args_nested_object() {
        let args = serde_json::json!({"meta": {"key": "value"}, "name": "test"});
        let truncated = truncate_audit_args(&args);
        assert_eq!(
            truncated.get("meta").unwrap(),
            &serde_json::json!({"key": "value"})
        );
    }

    #[test]
    fn test_audit_enabled_default_false() {
        // NOTE: this only mirrors AUDIT_ENABLED's initializer on a local
        // atomic; it deliberately does not read the global, whose value other
        // code paths may toggle at runtime.
        let fresh = AtomicBool::new(false);
        assert!(!fresh.load(Ordering::Relaxed));
    }

    #[test]
    fn test_read_audit_log_missing_file() {
        // Smoke test: reading the audit log must not panic whether or not the
        // log file exists in the current working directory.
        let _ = read_audit_log(10);
    }

    #[test]
    fn test_truncate_audit_args_exactly_200() {
        let exact = "y".repeat(200);
        let args = serde_json::json!({"content": exact});
        let truncated = truncate_audit_args(&args);
        assert_eq!(
            truncated.get("content").unwrap().as_str().unwrap(),
            exact,
            "Exactly 200-char string should not be truncated"
        );
    }

    #[test]
    fn test_truncate_audit_args_201() {
        let over = "z".repeat(201);
        let args = serde_json::json!({"content": over});
        let truncated = truncate_audit_args(&args);
        let val = truncated.get("content").unwrap().as_str().unwrap();
        assert!(
            val.contains("... [truncated, 201 chars total]"),
            "201-char string should be truncated"
        );
    }
}