use crate::format::*;
use std::collections::HashMap;
use yoagent::skills::SkillSet;
use yoagent::ThinkingLevel;
/// Crate version string, captured at compile time from Cargo's `CARGO_PKG_VERSION`.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Token budget constant.
/// NOTE(review): presumably the model context-window size that compaction is measured
/// against — confirm at the use site.
pub const MAX_CONTEXT_TOKENS: u64 = 200_000;
/// Fractional threshold (0.80).
/// NOTE(review): presumably the share of `MAX_CONTEXT_TOKENS` at which auto-compaction
/// triggers — confirm at the use site.
pub const AUTO_COMPACT_THRESHOLD: f64 = 0.80;
/// Default session file name, relative to the working directory.
/// NOTE(review): presumably the default target of `/save` and `/load` — confirm.
pub const DEFAULT_SESSION_PATH: &str = "yoyo-session.json";
/// Relative path of the automatically saved session.
/// NOTE(review): presumably what `--continue` resumes from — confirm.
pub const AUTO_SAVE_SESSION_PATH: &str = ".yoyo/last-session.json";
/// Built-in system prompt.
///
/// `parse_args` falls back to this text when neither `--system` nor `--system-file`
/// is supplied, and appends any loaded project context after it.
pub const SYSTEM_PROMPT: &str = r#"You are a coding assistant working in the user's terminal.
You have access to the filesystem and shell. Be direct and concise.
When the user asks you to do something, do it — don't just explain how.
Use tools proactively: read files to understand context, run commands to verify your work.
After making changes, run tests or verify the result when appropriate."#;
/// Provider names the CLI recognizes (all lowercase).
///
/// `parse_args` lowercases the requested provider and prints a warning — it does not
/// abort — when the name is missing from this list.
pub const KNOWN_PROVIDERS: &[&str] = &[
"anthropic",
"openai",
"google",
"openrouter",
"ollama",
"xai",
"groq",
"deepseek",
"mistral",
"cerebras",
"zai",
"custom",
];
/// Glob-based allow/deny rules for commands.
///
/// Populated from repeated `--allow` / `--deny` flags or from the `[permissions]`
/// section of a config file; evaluated by `PermissionConfig::check`.
#[derive(Debug, Clone, Default)]
pub struct PermissionConfig {
/// Glob patterns (only `*` is a wildcard) for commands that are approved.
pub allow: Vec<String>,
/// Glob patterns for commands that are rejected; these are consulted before `allow`.
pub deny: Vec<String>,
}
/// Directory allow/deny lists used to restrict which paths may be accessed.
///
/// Populated from repeated `--allow-dir` / `--deny-dir` flags or from the
/// `[directories]` section of a config file; evaluated by `check_path`.
#[derive(Debug, Clone, Default)]
pub struct DirectoryRestrictions {
/// When non-empty, a path must lie under at least one of these directories.
pub allow: Vec<String>,
/// A path under any of these directories is rejected; checked before `allow`.
pub deny: Vec<String>,
}
impl DirectoryRestrictions {
    /// Returns `true` when no allow or deny directories are configured.
    pub fn is_empty(&self) -> bool {
        self.deny.is_empty() && self.allow.is_empty()
    }

    /// Checks `path` against the configured directory lists.
    ///
    /// Deny entries win over allow entries. With a non-empty allow list the path must
    /// fall under at least one allowed directory. Both the path and every configured
    /// directory are passed through `resolve_path` before comparison.
    ///
    /// # Errors
    ///
    /// Returns a human-readable "Access denied" message when the path is rejected.
    pub fn check_path(&self, path: &str) -> Result<(), String> {
        if self.is_empty() {
            return Ok(());
        }

        let target = resolve_path(path);

        // Deny rules are evaluated first so they can carve holes out of allowed trees.
        let blocked_by = self
            .deny
            .iter()
            .find(|entry| path_is_under(&target, &resolve_path(entry)));
        if let Some(denied) = blocked_by {
            return Err(format!(
                "Access denied: '{}' is under restricted directory '{}'",
                path, denied
            ));
        }

        let permitted = self.allow.is_empty()
            || self
                .allow
                .iter()
                .any(|entry| path_is_under(&target, &resolve_path(entry)));
        if permitted {
            Ok(())
        } else {
            Err(format!(
                "Access denied: '{}' is not under any allowed directory",
                path
            ))
        }
    }
}
/// Resolves `path` to an absolute, normalized string for prefix comparisons.
///
/// Resolution strategy:
/// 1. If the path exists, return its canonical form (symlinks resolved).
/// 2. Otherwise make it absolute against the current directory (falling back to `/`
///    when the current directory is unavailable) and collapse `.` / `..` lexically.
///    A `..` at the filesystem root stays at the root instead of discarding the root
///    component, so `/../etc/x` cannot degrade into the relative path `etc/x`.
/// 3. Canonicalize the deepest ancestor that does exist and re-attach the missing
///    tail, so a not-yet-created file under a symlinked directory still compares
///    equal to canonicalized restriction directories.
fn resolve_path(path: &str) -> String {
    use std::path::{Component, Path, PathBuf};

    if let Ok(canonical) = std::fs::canonicalize(path) {
        return canonical.to_string_lossy().into_owned();
    }

    let p = Path::new(path);
    let absolute = if p.is_absolute() {
        p.to_path_buf()
    } else {
        std::env::current_dir()
            .unwrap_or_else(|_| PathBuf::from("/"))
            .join(p)
    };

    let mut components: Vec<Component> = Vec::new();
    for component in absolute.components() {
        match component {
            Component::CurDir => {}
            Component::ParentDir => {
                // Only step out of a named directory; never pop the root or a prefix.
                if matches!(components.last(), Some(Component::Normal(_))) {
                    components.pop();
                }
            }
            other => components.push(other),
        }
    }
    let normalized: PathBuf = components.iter().collect();

    // Walk upwards until some ancestor exists, canonicalize it, then re-append the
    // path segments that do not exist yet.
    let mut missing_tail: Vec<&std::ffi::OsStr> = Vec::new();
    let mut existing: &Path = &normalized;
    loop {
        if let Ok(mut base) = std::fs::canonicalize(existing) {
            base.extend(missing_tail.iter().rev());
            return base.to_string_lossy().into_owned();
        }
        match (existing.parent(), existing.file_name()) {
            (Some(parent), Some(name)) => {
                missing_tail.push(name);
                existing = parent;
            }
            _ => break,
        }
    }

    normalized.to_string_lossy().into_owned()
}
/// Returns `true` when `path` equals `dir` or lies somewhere beneath it.
///
/// The comparison is component-wise via `Path::starts_with`, so `/foobar` is not
/// considered to be under `/foo`, a trailing separator on `dir` is irrelevant, and
/// the platform's own path separator is honored instead of a hard-coded `/`.
fn path_is_under(path: &str, dir: &str) -> bool {
    std::path::Path::new(path).starts_with(dir)
}
impl PermissionConfig {
    /// Classifies `command` against the configured patterns.
    ///
    /// Returns `Some(false)` when any deny pattern matches, otherwise `Some(true)`
    /// when any allow pattern matches, otherwise `None` (no rule applies).
    pub fn check(&self, command: &str) -> Option<bool> {
        let matches_any =
            |patterns: &[String]| patterns.iter().any(|pattern| glob_match(pattern, command));

        // Deny is evaluated first so it always overrides an overlapping allow rule.
        if matches_any(self.deny.as_slice()) {
            Some(false)
        } else if matches_any(self.allow.as_slice()) {
            Some(true)
        } else {
            None
        }
    }

    /// Returns `true` when neither allow nor deny patterns are configured.
    pub fn is_empty(&self) -> bool {
        self.deny.is_empty() && self.allow.is_empty()
    }
}
/// Matches `text` against a minimal glob `pattern` in which `*` is the only wildcard.
///
/// `*` matches any run of characters, including none. A pattern without `*` must
/// equal the text exactly. The text before the first `*` is an anchored prefix, the
/// text after the last `*` is an anchored suffix, and every literal in between must
/// occur in order (leftmost match) in what remains.
pub fn glob_match(pattern: &str, text: &str) -> bool {
    let Some((prefix, after_first_star)) = pattern.split_once('*') else {
        return pattern == text;
    };

    let Some(mut remaining) = text.strip_prefix(prefix) else {
        return false;
    };

    // Separate the anchored suffix from the floating middle literals (if any).
    let (middles, suffix) = match after_first_star.rsplit_once('*') {
        Some((middle_part, tail)) => (Some(middle_part), tail),
        None => (None, after_first_star),
    };

    if let Some(middle_part) = middles {
        for literal in middle_part.split('*').filter(|piece| !piece.is_empty()) {
            match remaining.find(literal) {
                Some(at) => remaining = &remaining[at + literal.len()..],
                None => return false,
            }
        }
    }

    remaining.ends_with(suffix)
}
/// Parses a single-line, TOML-style array of strings such as `["git *", 'cargo *']`.
///
/// * Input that is not wrapped in `[` … `]` yields an empty vector.
/// * Elements are separated by commas; a comma inside a quoted element does not split
///   it. A quote only opens a quoted element when it is the first non-blank character
///   of that element.
/// * One pair of matching surrounding quotes (`"` or `'`) is removed per element;
///   unquoted elements are kept verbatim after trimming.
/// * Empty elements are dropped.
///
/// Escape sequences and multi-line arrays are not supported. Malformed elements
/// (for example a lone quote character) are returned as-is instead of panicking.
pub fn parse_toml_array(value: &str) -> Vec<String> {
    /// Removes one pair of matching surrounding quotes, if present.
    fn unquote(item: &str) -> &str {
        item.strip_prefix('"')
            .and_then(|rest| rest.strip_suffix('"'))
            .or_else(|| {
                item.strip_prefix('\'')
                    .and_then(|rest| rest.strip_suffix('\''))
            })
            .unwrap_or(item)
    }

    let Some(inner) = value
        .trim()
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'))
    else {
        return Vec::new();
    };

    let mut pieces: Vec<&str> = Vec::new();
    let mut start = 0;
    let mut open_quote: Option<char> = None;
    for (idx, ch) in inner.char_indices() {
        match open_quote {
            Some(quote) => {
                if ch == quote {
                    open_quote = None;
                }
            }
            None => match ch {
                ',' => {
                    pieces.push(&inner[start..idx]);
                    start = idx + 1;
                }
                // Only a quote that starts the element opens a quoted region, so an
                // apostrophe inside a bare word does not swallow later commas.
                '"' | '\'' if inner[start..idx].trim().is_empty() => open_quote = Some(ch),
                _ => {}
            },
        }
    }
    pieces.push(&inner[start..]);

    pieces
        .into_iter()
        .map(|piece| unquote(piece.trim()))
        .filter(|item| !item.is_empty())
        .map(str::to_string)
        .collect()
}
/// Extracts the `[permissions]` section from config-file text.
///
/// Blank lines and lines starting with `#` are skipped. Any line of the form `[...]`
/// switches sections; only `allow = [...]` and `deny = [...]` inside `[permissions]`
/// are read, each parsed with `parse_toml_array`. A repeated key overwrites the
/// earlier value.
pub fn parse_permissions_from_config(content: &str) -> PermissionConfig {
    let mut permissions = PermissionConfig::default();
    let mut inside_section = false;

    let meaningful_lines = content
        .lines()
        .map(str::trim)
        .filter(|entry| !entry.is_empty() && !entry.starts_with('#'));

    for entry in meaningful_lines {
        if entry.starts_with('[') && entry.ends_with(']') {
            inside_section = entry == "[permissions]";
            continue;
        }
        if !inside_section {
            continue;
        }
        let Some((raw_key, raw_value)) = entry.split_once('=') else {
            continue;
        };
        match raw_key.trim() {
            "allow" => permissions.allow = parse_toml_array(raw_value.trim()),
            "deny" => permissions.deny = parse_toml_array(raw_value.trim()),
            _ => {}
        }
    }

    permissions
}
/// Extracts the `[directories]` section from config-file text.
///
/// Blank lines and `#` comment lines are ignored. A `[...]` line switches sections;
/// only the `allow` and `deny` keys found while inside `[directories]` are parsed
/// (via `parse_toml_array`). Later occurrences of a key replace earlier ones.
pub fn parse_directories_from_config(content: &str) -> DirectoryRestrictions {
    let mut restrictions = DirectoryRestrictions::default();
    let mut active = false;

    for raw_line in content.lines() {
        let entry = raw_line.trim();
        let is_comment_or_blank = entry.is_empty() || entry.starts_with('#');
        if is_comment_or_blank {
            continue;
        }

        let is_section_header = entry.starts_with('[') && entry.ends_with(']');
        if is_section_header {
            active = entry == "[directories]";
        } else if active {
            match entry.split_once('=').map(|(k, v)| (k.trim(), v.trim())) {
                Some(("allow", list)) => restrictions.allow = parse_toml_array(list),
                Some(("deny", list)) => restrictions.deny = parse_toml_array(list),
                _ => {}
            }
        }
    }

    restrictions
}
/// Fully resolved runtime configuration, as assembled by `parse_args` from CLI flags,
/// the config file, and environment variables.
pub struct Config {
/// Model name: `--model`, else the config file's `model`, else the provider default.
pub model: String,
/// API key chosen from `--api-key`, environment variables, or the config file
/// (`"not-needed"` for the `ollama` and `custom` providers when none is found).
pub api_key: String,
/// Lowercased provider name (`--provider`, config `provider`, or `"anthropic"`).
pub provider: String,
/// Optional custom API endpoint (`--base-url` or config `base_url`).
pub base_url: Option<String>,
/// Skills loaded from every `--skills <dir>`; empty when none are given or loading fails.
pub skills: SkillSet,
/// System prompt, with any project context appended under "# Project Instructions".
pub system_prompt: String,
/// Thinking level (`--thinking` or config `thinking`); `Off` when unspecified.
pub thinking: ThinkingLevel,
/// Optional output-token cap (`--max-tokens` or config `max_tokens`).
pub max_tokens: Option<u32>,
/// Optional sampling temperature, clamped to `0.0..=1.0` by `clamp_temperature`.
pub temperature: Option<f32>,
/// Optional cap on agent turns (`--max-turns` or config `max_turns`).
pub max_turns: Option<usize>,
/// `true` when `--continue` / `-c` was passed.
pub continue_session: bool,
/// Value of `--output` / `-o`, if given.
pub output_path: Option<String>,
/// Value of `--prompt` / `-p`, if given.
pub prompt_arg: Option<String>,
/// `true` when `--verbose` / `-v` was passed.
pub verbose: bool,
/// One entry per `--mcp <cmd>` occurrence.
pub mcp_servers: Vec<String>,
/// One entry per `--openapi <spec>` occurrence.
pub openapi_specs: Vec<String>,
/// `true` when `--yes` / `-y` was passed.
pub auto_approve: bool,
/// Command allow/deny patterns; CLI flags replace the config-file section entirely.
pub permissions: PermissionConfig,
/// Directory allow/deny lists; CLI flags replace the config-file section entirely.
pub dir_restrictions: DirectoryRestrictions,
}
/// Process-wide verbose flag. It is only ever set to `true`; "unset" means not verbose.
static VERBOSE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();

/// Turns verbose output on for the remainder of the process. Calling it again is a no-op.
pub fn enable_verbose() {
    let _ = VERBOSE.set(true);
}

/// Reports whether verbose output has been enabled.
///
/// This only reads the flag. It deliberately does not initialize the cell: initializing
/// it to `false` here would make every later `enable_verbose()` call silently fail.
pub fn is_verbose() -> bool {
    VERBOSE.get().copied().unwrap_or(false)
}
/// Project instruction files, read relative to the working directory.
///
/// `load_project_context` concatenates every non-empty file in this order.
pub const PROJECT_CONTEXT_FILES: &[&str] = &["YOYO.md", "CLAUDE.md", ".yoyo/instructions.md"];
/// Prints the full usage text to stdout: CLI options, REPL commands, environment
/// variables, and the config-file format.
///
/// The environment section lists every provider key that `provider_api_key_env`
/// resolves, including `MISTRAL_API_KEY` and `CEREBRAS_API_KEY`.
pub fn print_help() {
    println!("yoyo v{VERSION} — a coding agent growing up in public");
    println!();
    println!("Usage: yoyo [OPTIONS]");
    println!();
    println!("Options:");
    println!(" --model <name> Model to use (default: claude-opus-4-6)");
    println!(" --provider <name> Provider: anthropic (default), openai, google, openrouter,");
    println!(" ollama, xai, groq, deepseek, mistral, cerebras, zai, custom");
    println!(" --base-url <url> Custom API endpoint (e.g., http://localhost:11434/v1)");
    println!(" --thinking <lvl> Enable extended thinking (off, minimal, low, medium, high)");
    println!(" --max-tokens <n> Maximum output tokens per response (default: 8192)");
    println!(" --max-turns <n> Maximum agent turns per prompt (default: 50)");
    println!(" --temperature <f> Sampling temperature (0.0-1.0, default: model default)");
    println!(" --skills <dir> Directory containing skill files");
    println!(" --system <text> Custom system prompt (overrides default)");
    println!(" --system-file <f> Read system prompt from file");
    println!(" --prompt, -p <t> Run a single prompt and exit (no REPL)");
    println!(" --output, -o <f> Write final response text to a file");
    println!(" --api-key <key> API key (overrides provider-specific env var)");
    println!(" --mcp <cmd> Connect to an MCP server via stdio (repeatable)");
    println!(" --openapi <spec> Load OpenAPI spec file and register API tools (repeatable)");
    println!(" --no-color Disable colored output (also respects NO_COLOR env)");
    println!(" --verbose, -v Show debug info (API errors, request details)");
    println!(" --yes, -y Auto-approve all tool executions (skip confirmation prompts)");
    println!(" --allow <pat> Auto-approve bash commands matching glob pattern (repeatable)");
    println!(" --deny <pat> Auto-deny bash commands matching glob pattern (repeatable)");
    println!(" --allow-dir <d> Restrict file access to this directory (repeatable)");
    println!(" --deny-dir <d> Block file access to this directory (repeatable)");
    println!(" --continue, -c Resume last saved session");
    println!(" --help, -h Show this help message");
    println!(" --version, -V Show version");
    println!();
    println!("Commands (in REPL):");
    println!(" /quit, /exit Exit the agent");
    println!(" /clear Clear conversation history");
    println!(" /compact Compact conversation to save context space");
    println!(" /commit [msg] Commit staged changes (AI-generates message if no msg)");
    println!(" /config Show all current settings");
    println!(" /context Show loaded project context files (YOYO.md)");
    println!(" /cost Show estimated session cost");
    println!(" /diff Show git diff summary of uncommitted changes");
    println!(" /docs <crate> Look up docs.rs documentation for a Rust crate");
    println!(" /find <pattern> Fuzzy-search project files by name");
    println!(" /fix Auto-fix build/lint errors (runs checks, sends failures to AI)");
    println!(" /forget <n> Remove a project memory by index");
    println!(" /git <subcmd> Quick git: status, log [n], add <path>, stash, stash pop");
    println!(" /health Run project health checks (auto-detects project type)");
    println!(" /pr [number] List open PRs, or view details of a specific PR");
    println!(" /history Show summary of conversation messages");
    println!(" /search <query> Search conversation history for matching messages");
    println!(" /init Create a starter YOYO.md project context file");
    println!(" /lint Auto-detect and run project linter");
    println!(" /load [path] Load session from file");
    println!(" /memories List project-specific memories");
    println!(" /model <name> Switch model mid-session");
    println!(" /retry Re-send the last user input");
    println!(" /remember <note> Save a project-specific memory (persists across sessions)");
    println!(" /review [path] AI code review: staged changes (default) or a specific file");
    println!(" /run <cmd> Run a shell command directly (no AI, no tokens)");
    println!(" /save [path] Save session to file");
    println!(" /spawn <task> Spawn a subagent with fresh context to handle a task");
    println!(" /status Show session info");
    println!(" /test Auto-detect and run project tests");
    println!(" /think [level] Show or change thinking level (off/low/medium/high)");
    println!(" /tokens Show token usage and context window");
    println!(" /tree [depth] Show project directory tree (default depth: 3)");
    println!(" /undo Revert all uncommitted changes (git checkout)");
    println!(" /version Show yoyo version");
    println!();
    println!("Environment:");
    println!(" ANTHROPIC_API_KEY API key for Anthropic (default provider)");
    println!(" OPENAI_API_KEY API key for OpenAI");
    println!(" GOOGLE_API_KEY API key for Google/Gemini");
    println!(" GROQ_API_KEY API key for Groq");
    println!(" XAI_API_KEY API key for xAI");
    println!(" DEEPSEEK_API_KEY API key for DeepSeek");
    println!(" OPENROUTER_API_KEY API key for OpenRouter");
    println!(" MISTRAL_API_KEY API key for Mistral");
    println!(" CEREBRAS_API_KEY API key for Cerebras");
    println!(" ZAI_API_KEY API key for ZAI (Zhipu AI / z.ai)");
    println!(" API_KEY Fallback API key (any provider)");
    println!();
    println!("Config files (searched in order, first found wins):");
    println!(" .yoyo.toml Project-level config (current directory)");
    println!(" ~/.config/yoyo/config.toml User-level config");
    println!();
    println!("Config file format (key = value):");
    println!(" model = \"claude-sonnet-4-20250514\"");
    println!(" provider = \"openai\"");
    println!(" base_url = \"http://localhost:11434/v1\"");
    println!(" thinking = \"medium\"");
    println!(" max_tokens = 4096");
    println!(" max_turns = 20");
    println!(" api_key = \"sk-ant-...\"");
    println!();
    println!(" [permissions]");
    println!(" allow = [\"git *\", \"cargo *\"]");
    println!(" deny = [\"rm -rf *\"]");
    println!();
    println!(" [directories]");
    println!(" allow = [\"./src\", \"./tests\"]");
    println!(" deny = [\"~/.ssh\", \"/etc\"]");
    println!();
    println!("CLI flags override config file values.");
}
/// Prints the startup banner (name, version, tagline) followed by a one-line usage
/// hint, using the color constants from the `format` module.
pub fn print_banner() {
    let title_line = format!(
        "{BOLD}{CYAN} yoyo{RESET} v{VERSION} {DIM}— a coding agent growing up in public{RESET}"
    );
    let hint_line = format!("{DIM} Type /help for commands, /quit to exit{RESET}");

    // A blank line precedes the title and another follows the hint.
    println!("\n{title_line}");
    println!("{hint_line}\n");
}
/// Converts a user-supplied thinking level (case-insensitive) into a `ThinkingLevel`.
///
/// Accepted spellings: `off`/`none`, `minimal`/`min`, `low`, `medium`/`med`,
/// `high`/`max`. Anything else prints a warning to stderr and yields
/// `ThinkingLevel::Medium`.
pub fn parse_thinking_level(s: &str) -> ThinkingLevel {
    let normalized = s.to_lowercase();
    let recognized = match normalized.as_str() {
        "off" | "none" => Some(ThinkingLevel::Off),
        "minimal" | "min" => Some(ThinkingLevel::Minimal),
        "low" => Some(ThinkingLevel::Low),
        "medium" | "med" => Some(ThinkingLevel::Medium),
        "high" | "max" => Some(ThinkingLevel::High),
        _ => None,
    };

    recognized.unwrap_or_else(|| {
        eprintln!(
            "{YELLOW}warning:{RESET} Unknown thinking level '{s}', using 'medium'. \
             Valid: off, minimal, low, medium, high"
        );
        ThinkingLevel::Medium
    })
}
/// Restricts a sampling temperature to the inclusive range `0.0..=1.0`.
///
/// Out-of-range values are replaced by the nearest bound and a warning is written to
/// stderr; in-range values are returned unchanged.
pub fn clamp_temperature(t: f32) -> f32 {
    if t < 0.0 {
        eprintln!("{YELLOW}warning:{RESET} Temperature {t} is below 0.0, clamping to 0.0");
        return 0.0;
    }
    if t > 1.0 {
        eprintln!("{YELLOW}warning:{RESET} Temperature {t} is above 1.0, clamping to 1.0");
        return 1.0;
    }
    t
}
/// Every command-line flag the program recognizes, in long and short form.
///
/// `warn_unknown_flags` reports any argument that starts with `-` and is not in this
/// list. Keep it in sync with the flags handled in `parse_args` and shown by
/// `print_help`.
const KNOWN_FLAGS: &[&str] = &[
"--model",
"--provider",
"--base-url",
"--thinking",
"--max-tokens",
"--max-turns",
"--temperature",
"--skills",
"--system",
"--system-file",
"--prompt",
"-p",
"--output",
"-o",
"--api-key",
"--mcp",
"--openapi",
"--allow",
"--deny",
"--allow-dir",
"--deny-dir",
"--no-color",
"--verbose",
"-v",
"--yes",
"-y",
"--continue",
"-c",
"--help",
"-h",
"--version",
"-V",
];
/// Warns on stderr about every dash-prefixed argument that is not a known flag.
///
/// `args[0]` (the program name) is skipped. When an argument is listed in
/// `flags_needing_values`, the argument right after it is treated as that flag's value
/// and is not inspected.
pub fn warn_unknown_flags(args: &[String], flags_needing_values: &[&str]) {
    let mut remaining = args.iter().skip(1);
    while let Some(arg) = remaining.next() {
        if !arg.starts_with('-') {
            continue;
        }
        if flags_needing_values.contains(&arg.as_str()) {
            // Consume the flag's value so it is never mistaken for a flag itself.
            remaining.next();
            continue;
        }
        if KNOWN_FLAGS.contains(&arg.as_str()) {
            continue;
        }
        eprintln!("{YELLOW}warning:{RESET} Unknown flag '{arg}' — ignored. Run --help for usage.");
    }
}
/// Maximum number of tracked files that `get_project_file_listing` includes before
/// summarizing the rest as "... and N more files".
pub const MAX_PROJECT_FILES: usize = 200;
/// Builds a newline-separated listing of git-tracked files via `git ls-files`.
///
/// At most `MAX_PROJECT_FILES` names are included; when more exist, a trailing
/// "... and N more files" line is appended. Returns `None` when git cannot be run,
/// exits unsuccessfully, or reports no files.
pub fn get_project_file_listing() -> Option<String> {
    let result = std::process::Command::new("git")
        .arg("ls-files")
        .output()
        .ok()?;
    if !result.status.success() {
        return None;
    }

    let text = String::from_utf8_lossy(&result.stdout);
    let mut listing = String::new();
    let mut total = 0usize;
    for name in text.lines().filter(|line| !line.is_empty()) {
        if total < MAX_PROJECT_FILES {
            if total > 0 {
                listing.push('\n');
            }
            listing.push_str(name);
        }
        total += 1;
    }

    if total == 0 {
        return None;
    }
    if total > MAX_PROJECT_FILES {
        let hidden = total - MAX_PROJECT_FILES;
        listing.push_str(&format!("\n... and {} more files", hidden));
    }
    Some(listing)
}
/// Lists files modified in the last 20 commits, most recent first and de-duplicated.
///
/// Runs `git log --diff-filter=M --name-only --pretty=format: -n 20` and keeps at most
/// `max_files` distinct names. Returns `None` when git cannot be run, exits
/// unsuccessfully, or yields no file names.
pub fn get_recently_changed_files(max_files: usize) -> Option<Vec<String>> {
    let git_args = [
        "log",
        "--diff-filter=M",
        "--name-only",
        "--pretty=format:",
        "-n",
        "20",
    ];
    let result = std::process::Command::new("git")
        .args(git_args)
        .output()
        .ok()?;
    if !result.status.success() {
        return None;
    }

    let text = String::from_utf8_lossy(&result.stdout);
    let mut already_listed: std::collections::HashSet<&str> = std::collections::HashSet::new();
    let mut recent: Vec<String> = Vec::new();
    for line in text.lines() {
        if recent.len() >= max_files {
            break;
        }
        // `insert` returns false for names that were already recorded.
        if line.is_empty() || !already_listed.insert(line) {
            continue;
        }
        recent.push(line.to_string());
    }

    if recent.is_empty() {
        None
    } else {
        Some(recent)
    }
}
/// Cap on the number of recently changed files that `load_project_context` requests
/// from `get_recently_changed_files`.
pub const MAX_RECENT_FILES: usize = 20;
/// Assembles the project context that is appended to the system prompt.
///
/// Sections, in order and separated by blank lines:
/// 1. the trimmed contents of each non-empty file in `PROJECT_CONTEXT_FILES`,
/// 2. a "## Project Files" listing from `get_project_file_listing`,
/// 3. a "## Recently Changed Files" list from `get_recently_changed_files`,
/// 4. project memories as formatted by `crate::memory::format_memories_for_prompt`.
///
/// A short summary of the contributing sources is written to stderr. Returns `None`
/// when nothing at all was found.
pub fn load_project_context() -> Option<String> {
    /// Appends `text`, preceded by a blank-line separator when the buffer has content.
    fn append_section(buffer: &mut String, text: &str) {
        if !buffer.is_empty() {
            buffer.push_str("\n\n");
        }
        buffer.push_str(text);
    }

    let mut context = String::new();
    let mut found = Vec::new();

    for name in PROJECT_CONTEXT_FILES {
        let Ok(raw) = std::fs::read_to_string(name) else {
            continue;
        };
        let body = raw.trim();
        if body.is_empty() {
            continue;
        }
        append_section(&mut context, body);
        found.push(*name);
    }

    if let Some(file_listing) = get_project_file_listing() {
        append_section(&mut context, "## Project Files\n\n");
        context.push_str(&file_listing);
        // Announced here only when no instruction file was found.
        if found.is_empty() {
            eprintln!("{DIM} context: project file listing{RESET}");
        }
    }

    if let Some(recent_files) = get_recently_changed_files(MAX_RECENT_FILES) {
        append_section(&mut context, "## Recently Changed Files\n\n");
        context.push_str(&recent_files.join("\n"));
    }

    let memory = crate::memory::load_memories();
    if let Some(memories_section) = crate::memory::format_memories_for_prompt(&memory) {
        append_section(&mut context, &memories_section);
    }

    if context.is_empty() && found.is_empty() {
        return None;
    }

    for name in &found {
        eprintln!("{DIM} context: {name}{RESET}");
    }
    if context.contains("## Recently Changed Files") {
        eprintln!("{DIM} context: recently changed files{RESET}");
    }
    let memory_count = memory.entries.len();
    if memory_count > 0 {
        eprintln!("{DIM} context: {} project memories{RESET}", memory_count);
    }
    Some(context)
}
/// Reports which project context files exist with non-blank content.
///
/// Returns `(file name, line count of the trimmed content)` pairs in
/// `PROJECT_CONTEXT_FILES` order; unreadable or blank files are omitted.
pub fn list_project_context_files() -> Vec<(&'static str, usize)> {
    PROJECT_CONTEXT_FILES
        .iter()
        .filter_map(|name| {
            let raw = std::fs::read_to_string(name).ok()?;
            let body = raw.trim();
            if body.is_empty() {
                None
            } else {
                Some((*name, body.lines().count()))
            }
        })
        .collect()
}
/// Project-level config file names, looked up relative to the working directory and
/// tried before the user-level config path.
const CONFIG_FILE_NAMES: &[&str] = &[".yoyo.toml"];
/// Path of the user-level config file: `<config dir>/yoyo/config.toml`.
///
/// Returns `None` when `dirs_hint` cannot determine a config directory.
fn user_config_path() -> Option<std::path::PathBuf> {
    let config_root = dirs_hint()?;
    Some(config_root.join("yoyo").join("config.toml"))
}
/// Determines the base directory for user configuration.
///
/// Uses `$XDG_CONFIG_HOME` when it holds an absolute path; otherwise falls back to
/// `$HOME/.config`. An empty or relative `XDG_CONFIG_HOME` is ignored, as the XDG
/// Base Directory Specification requires, instead of producing a path relative to the
/// working directory. Returns `None` when neither variable is usable.
fn dirs_hint() -> Option<std::path::PathBuf> {
    std::env::var("XDG_CONFIG_HOME")
        .ok()
        .map(std::path::PathBuf::from)
        .filter(|dir| dir.is_absolute())
        .or_else(|| {
            std::env::var("HOME")
                .ok()
                .map(|home| std::path::PathBuf::from(home).join(".config"))
        })
}
/// Determines the base directory for user data.
///
/// Uses `$XDG_DATA_HOME` when it holds an absolute path; otherwise falls back to
/// `$HOME/.local/share`. An empty or relative `XDG_DATA_HOME` is ignored, as the XDG
/// Base Directory Specification requires, so callers never create directories
/// relative to the working directory. Returns `None` when neither variable is usable.
fn data_dir_hint() -> Option<std::path::PathBuf> {
    std::env::var("XDG_DATA_HOME")
        .ok()
        .map(std::path::PathBuf::from)
        .filter(|dir| dir.is_absolute())
        .or_else(|| {
            std::env::var("HOME")
                .ok()
                .map(|home| std::path::PathBuf::from(home).join(".local").join("share"))
        })
}
/// Chooses where the REPL history file lives.
///
/// Prefers `<data dir>/yoyo/history`, creating the `yoyo` directory if needed. When
/// the data directory is unknown or cannot be created, falls back to
/// `$HOME/.yoyo_history`. Returns `None` when `HOME` is unavailable as well.
pub fn history_file_path() -> Option<std::path::PathBuf> {
    let preferred = data_dir_hint()
        .map(|data_dir| data_dir.join("yoyo"))
        .filter(|yoyo_dir| std::fs::create_dir_all(yoyo_dir).is_ok())
        .map(|yoyo_dir| yoyo_dir.join("history"));

    preferred.or_else(|| {
        let home = std::env::var("HOME").ok()?;
        Some(std::path::PathBuf::from(home).join(".yoyo_history"))
    })
}
/// Parses flat `key = value` config text into a map.
///
/// * Blank lines and lines starting with `#` are skipped.
/// * Lines without `=` (such as `[section]` headers) are ignored; sections are not
///   tracked, so a key inside a section lands in the same map.
/// * Keys and values are trimmed; one pair of matching surrounding quotes (`"` or `'`)
///   is stripped from the value.
/// * A later occurrence of a key overwrites an earlier one.
///
/// Malformed values such as a lone quote character are stored verbatim rather than
/// causing a panic.
pub fn parse_config_file(content: &str) -> HashMap<String, String> {
    let mut map = HashMap::new();
    for line in content.lines() {
        let line = line.trim();
        if line.is_empty() || line.starts_with('#') {
            continue;
        }
        let Some((key, value)) = line.split_once('=') else {
            continue;
        };
        let value = value.trim();
        // `strip_prefix` + `strip_suffix` require two distinct quote characters, so a
        // one-character value like `"` is left untouched.
        let unquoted = value
            .strip_prefix('"')
            .and_then(|rest| rest.strip_suffix('"'))
            .or_else(|| {
                value
                    .strip_prefix('\'')
                    .and_then(|rest| rest.strip_suffix('\''))
            })
            .unwrap_or(value);
        map.insert(key.trim().to_string(), unquoted.to_string());
    }
    map
}
/// Loads the first readable config file as a flat key/value map.
///
/// Project-level files (`CONFIG_FILE_NAMES`) are tried first, then the user-level
/// config path. The chosen file is announced on stderr. Returns an empty map when no
/// file can be read.
fn load_config_file() -> HashMap<String, String> {
    for name in CONFIG_FILE_NAMES {
        let Ok(content) = std::fs::read_to_string(name) else {
            continue;
        };
        eprintln!("{DIM} config: {name}{RESET}");
        return parse_config_file(&content);
    }

    let Some(path) = user_config_path() else {
        return HashMap::new();
    };
    match std::fs::read_to_string(&path) {
        Ok(content) => {
            eprintln!("{DIM} config: {}{RESET}", path.display());
            parse_config_file(&content)
        }
        Err(_) => HashMap::new(),
    }
}
/// Reads the `[permissions]` section from the first readable config file.
///
/// Project-level files are tried before the user-level path; the first file that can
/// be read decides the result, even if it has no `[permissions]` section. Returns the
/// default (empty) configuration when no file is readable.
fn load_permissions_from_config_file() -> PermissionConfig {
    let project_level = CONFIG_FILE_NAMES
        .iter()
        .find_map(|name| std::fs::read_to_string(name).ok());
    let content = project_level
        .or_else(|| user_config_path().and_then(|path| std::fs::read_to_string(path).ok()));

    match content {
        Some(text) => parse_permissions_from_config(&text),
        None => PermissionConfig::default(),
    }
}
/// Reads the `[directories]` section from the first readable config file.
///
/// Project-level files take precedence over the user-level path; whichever file is
/// read first determines the result. Falls back to empty restrictions when no config
/// file can be read.
fn load_directories_from_config_file() -> DirectoryRestrictions {
    let mut content = None;
    for name in CONFIG_FILE_NAMES {
        if let Ok(text) = std::fs::read_to_string(name) {
            content = Some(text);
            break;
        }
    }
    if content.is_none() {
        content = user_config_path().and_then(|path| std::fs::read_to_string(path).ok());
    }

    content
        .map(|text| parse_directories_from_config(&text))
        .unwrap_or_default()
}
pub fn parse_args(args: &[String]) -> Option<Config> {
if args.iter().any(|a| a == "--help" || a == "-h") {
print_help();
return None;
}
if args.iter().any(|a| a == "--version" || a == "-V") {
println!("yoyo v{VERSION}");
return None;
}
let file_config = load_config_file();
let flags_needing_values = [
"--model",
"--provider",
"--base-url",
"--thinking",
"--max-tokens",
"--max-turns",
"--temperature",
"--skills",
"--system",
"--system-file",
"--prompt",
"-p",
"--output",
"-o",
"--api-key",
"--mcp",
"--openapi",
"--allow",
"--deny",
"--allow-dir",
"--deny-dir",
];
for flag in &flags_needing_values {
if let Some(pos) = args.iter().position(|a| a == flag) {
match args.get(pos + 1) {
None => {
eprintln!("{RED}error:{RESET} {flag} requires a value");
eprintln!("Run with --help for usage information.");
std::process::exit(1);
}
Some(next)
if next.starts_with('-')
&& !next.chars().nth(1).is_some_and(|c| c.is_ascii_digit()) =>
{
eprintln!(
"{YELLOW}warning:{RESET} {flag} value looks like another flag: '{next}'"
);
}
_ => {}
}
}
}
warn_unknown_flags(args, &flags_needing_values);
let provider = args
.iter()
.position(|a| a == "--provider")
.and_then(|i| args.get(i + 1))
.cloned()
.or_else(|| file_config.get("provider").cloned())
.unwrap_or_else(|| "anthropic".into())
.to_lowercase();
if !KNOWN_PROVIDERS.contains(&provider.as_str()) {
eprintln!(
"{YELLOW}warning:{RESET} Unknown provider '{provider}'. Known providers: {}",
KNOWN_PROVIDERS.join(", ")
);
}
let base_url = args
.iter()
.position(|a| a == "--base-url")
.and_then(|i| args.get(i + 1))
.cloned()
.or_else(|| file_config.get("base_url").cloned());
let api_key_from_flag = args
.iter()
.position(|a| a == "--api-key")
.and_then(|i| args.get(i + 1))
.cloned();
let provider_env_var = provider_api_key_env(&provider);
let api_key = match api_key_from_flag {
Some(key) if !key.is_empty() => key,
_ => {
let from_provider_env = provider_env_var
.and_then(|var| std::env::var(var).ok())
.filter(|k| !k.is_empty());
match from_provider_env {
Some(key) => key,
None => {
match std::env::var("ANTHROPIC_API_KEY").or_else(|_| std::env::var("API_KEY")) {
Ok(key) if !key.is_empty() => key,
_ => match file_config.get("api_key").cloned() {
Some(key) if !key.is_empty() => key,
_ => {
if provider == "ollama" || provider == "custom" {
"not-needed".to_string()
} else {
let env_hint = provider_env_var.unwrap_or("ANTHROPIC_API_KEY");
eprintln!("{RED}error:{RESET} No API key found.");
eprintln!(
"Set {env_hint} env var, use --api-key <key>, or add api_key to .yoyo.toml."
);
std::process::exit(1);
}
}
},
}
}
}
}
};
let model = args
.iter()
.position(|a| a == "--model")
.and_then(|i| args.get(i + 1))
.cloned()
.or_else(|| file_config.get("model").cloned())
.unwrap_or_else(|| default_model_for_provider(&provider));
let skill_dirs: Vec<String> = args
.iter()
.enumerate()
.filter(|(_, a)| a.as_str() == "--skills")
.filter_map(|(i, _)| args.get(i + 1).cloned())
.collect();
let skills = if skill_dirs.is_empty() {
SkillSet::empty()
} else {
match SkillSet::load(&skill_dirs) {
Ok(s) => s,
Err(e) => {
eprintln!("{YELLOW}warning:{RESET} Failed to load skills: {e}");
SkillSet::empty()
}
}
};
let custom_system = args
.iter()
.position(|a| a == "--system")
.and_then(|i| args.get(i + 1))
.cloned();
let system_from_file = args
.iter()
.position(|a| a == "--system-file")
.and_then(|i| args.get(i + 1))
.map(|path| {
std::fs::read_to_string(path).unwrap_or_else(|e| {
eprintln!("{RED}error:{RESET} Failed to read system prompt file '{path}': {e}");
std::process::exit(1);
})
});
let mut system_prompt = system_from_file
.or(custom_system)
.unwrap_or_else(|| SYSTEM_PROMPT.to_string());
if let Some(project_context) = load_project_context() {
system_prompt.push_str("\n\n# Project Instructions\n\n");
system_prompt.push_str(&project_context);
}
let thinking = args
.iter()
.position(|a| a == "--thinking")
.and_then(|i| args.get(i + 1))
.map(|s| parse_thinking_level(s))
.or_else(|| file_config.get("thinking").map(|s| parse_thinking_level(s)))
.unwrap_or(ThinkingLevel::Off);
let continue_session = args.iter().any(|a| a == "--continue" || a == "-c");
let max_tokens = args
.iter()
.position(|a| a == "--max-tokens")
.and_then(|i| args.get(i + 1))
.and_then(|s| {
s.parse::<u32>().ok().or_else(|| {
eprintln!(
"{YELLOW}warning:{RESET} Invalid --max-tokens value '{s}', using default"
);
None
})
})
.or_else(|| {
file_config
.get("max_tokens")
.and_then(|s| s.parse::<u32>().ok())
});
let temperature = args
.iter()
.position(|a| a == "--temperature")
.and_then(|i| args.get(i + 1))
.and_then(|s| {
s.parse::<f32>().ok().or_else(|| {
eprintln!(
"{YELLOW}warning:{RESET} Invalid --temperature value '{s}', using default"
);
None
})
})
.or_else(|| {
file_config
.get("temperature")
.and_then(|s| s.parse::<f32>().ok())
})
.map(clamp_temperature);
let max_turns = args
.iter()
.position(|a| a == "--max-turns")
.and_then(|i| args.get(i + 1))
.and_then(|s| {
s.parse::<usize>().ok().or_else(|| {
eprintln!("{YELLOW}warning:{RESET} Invalid --max-turns value '{s}', using default");
None
})
})
.or_else(|| {
file_config
.get("max_turns")
.and_then(|s| s.parse::<usize>().ok())
});
let output_path = args
.iter()
.position(|a| a == "--output" || a == "-o")
.and_then(|i| args.get(i + 1))
.cloned();
let prompt_arg = args
.iter()
.position(|a| a == "--prompt" || a == "-p")
.and_then(|i| args.get(i + 1))
.cloned();
let verbose = args.iter().any(|a| a == "--verbose" || a == "-v");
let auto_approve = args.iter().any(|a| a == "--yes" || a == "-y");
let cli_allow: Vec<String> = args
.iter()
.enumerate()
.filter(|(_, a)| a.as_str() == "--allow")
.filter_map(|(i, _)| args.get(i + 1).cloned())
.collect();
let cli_deny: Vec<String> = args
.iter()
.enumerate()
.filter(|(_, a)| a.as_str() == "--deny")
.filter_map(|(i, _)| args.get(i + 1).cloned())
.collect();
let permissions = if cli_allow.is_empty() && cli_deny.is_empty() {
load_permissions_from_config_file()
} else {
PermissionConfig {
allow: cli_allow,
deny: cli_deny,
}
};
let cli_allow_dirs: Vec<String> = args
.iter()
.enumerate()
.filter(|(_, a)| a.as_str() == "--allow-dir")
.filter_map(|(i, _)| args.get(i + 1).cloned())
.collect();
let cli_deny_dirs: Vec<String> = args
.iter()
.enumerate()
.filter(|(_, a)| a.as_str() == "--deny-dir")
.filter_map(|(i, _)| args.get(i + 1).cloned())
.collect();
let dir_restrictions = if cli_allow_dirs.is_empty() && cli_deny_dirs.is_empty() {
load_directories_from_config_file()
} else {
DirectoryRestrictions {
allow: cli_allow_dirs,
deny: cli_deny_dirs,
}
};
let mcp_servers: Vec<String> = args
.iter()
.enumerate()
.filter(|(_, a)| a.as_str() == "--mcp")
.filter_map(|(i, _)| args.get(i + 1).cloned())
.collect();
let openapi_specs: Vec<String> = args
.iter()
.enumerate()
.filter(|(_, a)| a.as_str() == "--openapi")
.filter_map(|(i, _)| args.get(i + 1).cloned())
.collect();
Some(Config {
model,
api_key,
provider,
base_url,
skills,
system_prompt,
thinking,
max_tokens,
temperature,
max_turns,
continue_session,
output_path,
prompt_arg,
verbose,
mcp_servers,
openapi_specs,
auto_approve,
permissions,
dir_restrictions,
})
}
/// Maps a provider name to the environment variable that holds its API key.
///
/// The lookup is case-sensitive (callers pass a lowercased name). Providers without a
/// dedicated variable — including `ollama`, `custom`, and unknown names — yield `None`.
pub fn provider_api_key_env(provider: &str) -> Option<&'static str> {
    const ENV_VARS: &[(&str, &str)] = &[
        ("openai", "OPENAI_API_KEY"),
        ("google", "GOOGLE_API_KEY"),
        ("groq", "GROQ_API_KEY"),
        ("xai", "XAI_API_KEY"),
        ("deepseek", "DEEPSEEK_API_KEY"),
        ("openrouter", "OPENROUTER_API_KEY"),
        ("mistral", "MISTRAL_API_KEY"),
        ("cerebras", "CEREBRAS_API_KEY"),
        ("zai", "ZAI_API_KEY"),
        ("anthropic", "ANTHROPIC_API_KEY"),
    ];

    ENV_VARS
        .iter()
        .find(|(name, _)| *name == provider)
        .map(|(_, var)| *var)
}
/// Returns the default model name for `provider`.
///
/// The lookup is case-sensitive. Any provider without a dedicated entry — including
/// `anthropic`, `custom`, and unknown names — gets `claude-opus-4-6`.
pub fn default_model_for_provider(provider: &str) -> String {
    const FALLBACK_MODEL: &str = "claude-opus-4-6";
    const PROVIDER_DEFAULTS: &[(&str, &str)] = &[
        ("openai", "gpt-4o"),
        ("google", "gemini-2.0-flash"),
        ("openrouter", "anthropic/claude-sonnet-4-20250514"),
        ("ollama", "llama3.2"),
        ("xai", "grok-3"),
        ("groq", "llama-3.3-70b-versatile"),
        ("deepseek", "deepseek-chat"),
        ("mistral", "mistral-large-latest"),
        ("cerebras", "llama-3.3-70b"),
        ("zai", "glm-4-plus"),
    ];

    PROVIDER_DEFAULTS
        .iter()
        .find(|(name, _)| *name == provider)
        .map_or(FALLBACK_MODEL, |(_, model)| *model)
        .to_string()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_version_constant_exists() {
assert!(
VERSION.contains('.'),
"Version should contain a dot: {VERSION}"
);
}
#[test]
fn test_parse_thinking_level() {
assert_eq!(parse_thinking_level("off"), ThinkingLevel::Off);
assert_eq!(parse_thinking_level("none"), ThinkingLevel::Off);
assert_eq!(parse_thinking_level("minimal"), ThinkingLevel::Minimal);
assert_eq!(parse_thinking_level("min"), ThinkingLevel::Minimal);
assert_eq!(parse_thinking_level("low"), ThinkingLevel::Low);
assert_eq!(parse_thinking_level("medium"), ThinkingLevel::Medium);
assert_eq!(parse_thinking_level("med"), ThinkingLevel::Medium);
assert_eq!(parse_thinking_level("high"), ThinkingLevel::High);
assert_eq!(parse_thinking_level("max"), ThinkingLevel::High);
assert_eq!(parse_thinking_level("HIGH"), ThinkingLevel::High);
assert_eq!(parse_thinking_level("Medium"), ThinkingLevel::Medium);
assert_eq!(parse_thinking_level("unknown"), ThinkingLevel::Medium);
}
#[test]
fn test_system_flag_parsing() {
let args = [
"yoyo".to_string(),
"--system".to_string(),
"You are a Rust expert.".to_string(),
];
let system = args
.iter()
.position(|a| a == "--system")
.and_then(|i| args.get(i + 1))
.cloned();
assert_eq!(system, Some("You are a Rust expert.".to_string()));
}
#[test]
fn test_system_flag_missing() {
let args = ["yoyo".to_string()];
let system = args
.iter()
.position(|a| a == "--system")
.and_then(|i| args.get(i + 1))
.cloned();
assert_eq!(system, None);
}
#[test]
fn test_system_file_flag() {
let args = [
"yoyo".to_string(),
"--system-file".to_string(),
"prompt.txt".to_string(),
];
let system_file = args
.iter()
.position(|a| a == "--system-file")
.and_then(|i| args.get(i + 1))
.cloned();
assert_eq!(system_file, Some("prompt.txt".to_string()));
}
#[test]
fn test_continue_flag_parsing() {
let args_short = ["yoyo".to_string(), "-c".to_string()];
assert!(args_short.iter().any(|a| a == "--continue" || a == "-c"));
let args_long = ["yoyo".to_string(), "--continue".to_string()];
assert!(args_long.iter().any(|a| a == "--continue" || a == "-c"));
let args_none = ["yoyo".to_string()];
assert!(!args_none.iter().any(|a| a == "--continue" || a == "-c"));
}
#[test]
fn test_prompt_flag_parsing() {
let args = [
"yoyo".to_string(),
"-p".to_string(),
"explain this code".to_string(),
];
let prompt = args
.iter()
.position(|a| a == "--prompt" || a == "-p")
.and_then(|i| args.get(i + 1))
.cloned();
assert_eq!(prompt, Some("explain this code".to_string()));
let args_long = [
"yoyo".to_string(),
"--prompt".to_string(),
"what does this do?".to_string(),
];
let prompt_long = args_long
.iter()
.position(|a| a == "--prompt" || a == "-p")
.and_then(|i| args_long.get(i + 1))
.cloned();
assert_eq!(prompt_long, Some("what does this do?".to_string()));
let args_none = ["yoyo".to_string()];
let prompt_none = args_none
.iter()
.position(|a| a == "--prompt" || a == "-p")
.and_then(|i| args_none.get(i + 1))
.cloned();
assert_eq!(prompt_none, None);
}
#[test]
fn test_output_flag_parsing() {
let args = [
"yoyo".to_string(),
"-o".to_string(),
"output.md".to_string(),
];
let output = args
.iter()
.position(|a| a == "--output" || a == "-o")
.and_then(|i| args.get(i + 1))
.cloned();
assert_eq!(output, Some("output.md".to_string()));
let args_long = [
"yoyo".to_string(),
"--output".to_string(),
"result.txt".to_string(),
];
let output_long = args_long
.iter()
.position(|a| a == "--output" || a == "-o")
.and_then(|i| args_long.get(i + 1))
.cloned();
assert_eq!(output_long, Some("result.txt".to_string()));
let args_none = ["yoyo".to_string()];
let output_none = args_none
.iter()
.position(|a| a == "--output" || a == "-o")
.and_then(|i| args_none.get(i + 1))
.cloned();
assert_eq!(output_none, None);
}
#[test]
fn test_default_session_path() {
assert_eq!(DEFAULT_SESSION_PATH, "yoyo-session.json");
}
#[test]
fn test_auto_compact_threshold_constants() {
assert_eq!(MAX_CONTEXT_TOKENS, 200_000);
assert!((AUTO_COMPACT_THRESHOLD - 0.80).abs() < f64::EPSILON);
}
#[test]
fn test_max_tokens_flag_parsing() {
let args = [
"yoyo".to_string(),
"--max-tokens".to_string(),
"4096".to_string(),
];
let max_tokens = args
.iter()
.position(|a| a == "--max-tokens")
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse::<u32>().ok());
assert_eq!(max_tokens, Some(4096));
}
#[test]
fn test_max_tokens_flag_missing() {
let args = ["yoyo".to_string()];
let max_tokens = args
.iter()
.position(|a| a == "--max-tokens")
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse::<u32>().ok());
assert_eq!(max_tokens, None);
}
/// A non-numeric argument after `--max-tokens` is discarded rather than
/// causing a failure.
#[test]
fn test_max_tokens_flag_invalid() {
    let argv = ["yoyo", "--max-tokens", "not_a_number"].map(String::from);
    // Each window is (flag candidate, following argument).
    let max_tokens = argv
        .windows(2)
        .find(|pair| pair[0] == "--max-tokens")
        .and_then(|pair| pair[1].parse::<u32>().ok());
    assert_eq!(max_tokens, None);
}
/// `--no-color` can be detected by a plain membership check on the arguments.
#[test]
fn test_no_color_flag_recognized() {
    let argv = ["yoyo", "--no-color"].map(String::from);
    let flag = String::from("--no-color");
    assert!(argv.contains(&flag));
}
/// Quoted and bare values from a simple `key = value` file are returned with
/// the quotes removed.
#[test]
fn test_parse_config_file_basic() {
    let content = r#"
model = "claude-sonnet-4-20250514"
thinking = "medium"
max_tokens = 4096
"#;
    let parsed = parse_config_file(content);
    let expectations = [
        ("model", "claude-sonnet-4-20250514"),
        ("thinking", "medium"),
        ("max_tokens", "4096"),
    ];
    for (key, expected) in expectations {
        assert_eq!(parsed.get(key).unwrap(), expected);
    }
}
/// `#` comment lines contribute no entries: only the two real keys are parsed.
#[test]
fn test_parse_config_file_comments_and_blanks() {
    let content = r#"
# This is a comment
model = "claude-opus-4-6"
# Another comment
thinking = "high"
"#;
    let parsed = parse_config_file(content);
    for (key, expected) in [("model", "claude-opus-4-6"), ("thinking", "high")] {
        assert_eq!(parsed.get(key).unwrap(), expected);
    }
    let entry_count = parsed.len();
    assert_eq!(entry_count, 2);
}
/// Unquoted values are accepted as-is.
#[test]
fn test_parse_config_file_no_quotes() {
    let parsed = parse_config_file("model = claude-haiku-35\nmax_tokens = 2048");
    let model = parsed.get("model").unwrap();
    let max_tokens = parsed.get("max_tokens").unwrap();
    assert_eq!(model, "claude-haiku-35");
    assert_eq!(max_tokens, "2048");
}
/// Single-quoted values have their quotes removed.
#[test]
fn test_parse_config_file_single_quotes() {
    let parsed = parse_config_file("model = 'claude-opus-4-6'");
    let model = parsed.get("model").unwrap();
    assert_eq!(model, "claude-opus-4-6");
}
/// Empty input yields an empty configuration.
#[test]
fn test_parse_config_file_empty() {
    let parsed = parse_config_file("");
    let has_no_entries = parsed.is_empty();
    assert!(has_no_entries);
}
/// Whitespace around the key and the value is ignored.
#[test]
fn test_parse_config_file_whitespace_handling() {
    let parsed = parse_config_file(" model = claude-opus-4-6 ");
    let model = parsed.get("model").unwrap();
    assert_eq!(model, "claude-opus-4-6");
}
/// Every reported context file has a non-empty name and a positive line count.
/// An empty result is acceptable.
#[test]
fn test_list_project_context_files_returns_vec() {
    let found = list_project_context_files();
    found.iter().for_each(|(name, lines)| {
        assert!(!name.is_empty());
        assert!(*lines > 0);
    });
}
/// Pins the number, order and non-emptiness of the recognised context files.
#[test]
fn test_project_context_file_names_not_empty() {
    let expected_order = ["YOYO.md", "CLAUDE.md", ".yoyo/instructions.md"];
    assert_eq!(PROJECT_CONTEXT_FILES.len(), 3);
    for (idx, want) in expected_order.iter().enumerate() {
        assert_eq!(PROJECT_CONTEXT_FILES[idx], *want);
    }
    assert!(PROJECT_CONTEXT_FILES.iter().all(|name| !name.is_empty()));
}
/// A decimal argument after `--temperature` parses into an `f32`.
#[test]
fn test_temperature_flag_parsing() {
    let argv = ["yoyo", "--temperature", "0.7"].map(String::from);
    let temp = match argv.iter().position(|arg| arg == "--temperature") {
        Some(idx) => argv.get(idx + 1).and_then(|raw| raw.parse::<f32>().ok()),
        None => None,
    };
    assert_eq!(temp, Some(0.7));
}
/// Without `--temperature` on the command line no value is produced.
#[test]
fn test_temperature_flag_missing() {
    let argv = [String::from("yoyo")];
    let flag_idx = argv.iter().position(|arg| arg == "--temperature");
    let temp = flag_idx
        .and_then(|idx| argv.get(idx + 1))
        .and_then(|raw| raw.parse::<f32>().ok());
    assert_eq!(temp, None);
}
/// A non-numeric argument after `--temperature` is discarded.
#[test]
fn test_temperature_flag_invalid() {
    let argv = ["yoyo", "--temperature", "not_a_number"].map(String::from);
    // Each window is (flag candidate, following argument).
    let temp = argv
        .windows(2)
        .find(|pair| pair[0] == "--temperature")
        .and_then(|pair| pair[1].parse::<f32>().ok());
    assert_eq!(temp, None);
}
/// Both `-v` and `--verbose` are detected; neither is reported when absent.
#[test]
fn test_verbose_flag_parsing() {
    fn wants_verbose(argv: &[String]) -> bool {
        argv.iter().any(|a| a == "--verbose" || a == "-v")
    }

    let short_form = ["yoyo", "-v"].map(String::from);
    let long_form = ["yoyo", "--verbose"].map(String::from);
    let bare = [String::from("yoyo")];

    assert!(wants_verbose(&short_form));
    assert!(wants_verbose(&long_form));
    assert!(!wants_verbose(&bare));
}
/// Values already inside `[0.0, 1.0]` pass through unchanged.
#[test]
fn test_clamp_temperature_in_range() {
    for (input, expected) in [(0.0, 0.0), (0.5, 0.5), (1.0, 1.0)] {
        assert_eq!(clamp_temperature(input), expected);
    }
}
/// Negative temperatures are raised to `0.0`.
#[test]
fn test_clamp_temperature_below_zero() {
    for too_low in [-0.5, -100.0] {
        assert_eq!(clamp_temperature(too_low), 0.0);
    }
}
/// Temperatures above one are lowered to `1.0`.
#[test]
fn test_clamp_temperature_above_one() {
    for too_high in [1.5, 99.0] {
        assert_eq!(clamp_temperature(too_high), 1.0);
    }
}
/// Every value-taking flag listed here must be registered in `KNOWN_FLAGS`.
#[test]
fn test_known_flags_contains_all_flags() {
    let expected_flags = [
        "--model",
        "--thinking",
        "--max-tokens",
        "--max-turns",
        "--temperature",
        "--skills",
        "--system",
        "--system-file",
        "--prompt",
        "-p",
        "--output",
        "-o",
        "--api-key",
        "--openapi",
        "--allow",
        "--deny",
        "--allow-dir",
        "--deny-dir",
    ];
    expected_flags.iter().for_each(|flag| {
        assert!(
            KNOWN_FLAGS.contains(flag),
            "Flag {flag} should be in KNOWN_FLAGS"
        );
    });
}
/// Smoke test: `warn_unknown_flags` must not panic for an unknown flag, a
/// known flag with its value, or a bare program name.
#[test]
fn test_warn_unknown_flags_no_panic() {
    let flags_needing_values = ["--model", "--thinking"];

    let with_unknown = ["yoyo", "--unknown"].map(String::from);
    let with_known = ["yoyo", "--model", "test"].map(String::from);
    let bare = [String::from("yoyo")];

    // Same three invocations, in the same order, driven from one table.
    let scenarios: [&[String]; 3] = [&with_unknown, &with_known, &bare];
    for argv in scenarios {
        warn_unknown_flags(argv, &flags_needing_values);
    }
}
/// The argument following `--api-key` is taken as the key.
#[test]
fn test_api_key_flag_parsing() {
    let argv = ["yoyo", "--api-key", "sk-test-key"].map(String::from);
    let mut api_key = None;
    for (idx, arg) in argv.iter().enumerate() {
        if arg == "--api-key" {
            // Only the first occurrence is considered.
            api_key = argv.get(idx + 1).cloned();
            break;
        }
    }
    assert_eq!(api_key, Some("sk-test-key".to_string()));
}
/// Without `--api-key` on the command line no key is found.
#[test]
fn test_api_key_flag_missing() {
    let argv = [String::from("yoyo")];
    let api_key = argv
        .windows(2)
        .find(|pair| pair[0] == "--api-key")
        .map(|pair| pair[1].clone());
    assert_eq!(api_key, None);
}
/// `--api-key` must be registered in `KNOWN_FLAGS`.
#[test]
fn test_api_key_flag_in_known_flags() {
    let registered = KNOWN_FLAGS.contains(&"--api-key");
    assert!(registered, "--api-key should be in KNOWN_FLAGS");
}
/// An `api_key` entry in the config file is parsed like any other key.
#[test]
fn test_api_key_from_config_file() {
    let parsed = parse_config_file("api_key = \"sk-ant-test-from-config\"");
    let api_key = parsed.get("api_key").unwrap();
    assert_eq!(api_key, "sk-ant-test-from-config");
}
/// When a listing is available it must be non-empty, stay within the file
/// cap (plus one extra line), and mention `Cargo.toml`. No listing is
/// acceptable.
#[test]
fn test_get_project_file_listing_no_panic() {
    let listing = match get_project_file_listing() {
        Some(text) => text,
        None => return,
    };
    assert!(!listing.is_empty(), "File listing should not be empty");

    let line_count = listing.lines().count();
    assert!(
        line_count <= MAX_PROJECT_FILES + 1,
        "File listing should be capped at {} files",
        MAX_PROJECT_FILES
    );

    let mentions_manifest = listing.contains("Cargo.toml");
    assert!(mentions_manifest, "File listing should contain Cargo.toml");
}
/// Pins the value of `MAX_PROJECT_FILES`.
#[test]
fn test_max_project_files_constant() {
    let expected_cap = 200;
    assert_eq!(MAX_PROJECT_FILES, expected_cap);
}
/// If both a project context and a file listing are available, the context
/// must carry a "Project Files" section.
#[test]
fn test_load_project_context_includes_file_listing() {
    match load_project_context() {
        // The listing is only queried once a context exists.
        Some(context) if get_project_file_listing().is_some() => {
            assert!(
                context.contains("## Project Files"),
                "Context should contain Project Files section"
            );
        }
        _ => {}
    }
}
/// With `HOME` set, a history path must be returned; it must mention `yoyo`
/// and end in one of the two accepted file names.
#[test]
fn test_history_file_path_returns_some() {
    let path = history_file_path();
    if std::env::var("HOME").is_err() {
        // Nothing to check without a home directory.
        return;
    }
    let p = match path {
        Some(found) => found,
        None => panic!("Should return a path when HOME is set"),
    };
    let p_str = p.to_string_lossy();
    assert!(
        p_str.contains("yoyo"),
        "History path should contain 'yoyo': {p_str}"
    );
    let has_expected_suffix = p_str.ends_with("history") || p_str.ends_with(".yoyo_history");
    assert!(
        has_expected_suffix,
        "History path should end with 'history' or '.yoyo_history': {p_str}"
    );
}
/// A history path must be available whenever `HOME` or `XDG_DATA_HOME` is set.
///
/// NOTE(review): the scratch directory created here is never exported as
/// `XDG_DATA_HOME`, so XDG precedence itself is not exercised; confirm intent.
#[test]
fn test_history_file_path_prefers_xdg() {
    let scratch = std::env::temp_dir().join("yoyo_test_xdg_data");
    let _ = std::fs::create_dir_all(&scratch);

    let path = history_file_path();
    let has_base_dir = ["HOME", "XDG_DATA_HOME"]
        .iter()
        .any(|var| std::env::var(var).is_ok());
    if has_base_dir {
        assert!(path.is_some());
    }

    let _ = std::fs::remove_dir_all(&scratch);
}
/// `YOYO.md` takes the first slot; `CLAUDE.md` stays in the list but not first.
#[test]
fn test_yoyo_md_is_primary_context_file() {
    let primary = PROJECT_CONTEXT_FILES[0];
    assert_eq!(primary, "YOYO.md", "YOYO.md must be the primary context file");

    let supports_claude_md = PROJECT_CONTEXT_FILES.contains(&"CLAUDE.md");
    assert!(
        supports_claude_md,
        "CLAUDE.md should still be supported for compatibility"
    );

    assert_ne!(
        primary, "CLAUDE.md",
        "CLAUDE.md should not be the primary context file"
    );
}
/// A data directory hint must exist whenever `HOME` or `XDG_DATA_HOME` is set.
#[test]
fn test_data_dir_hint_returns_path() {
    let env_is_set = |name: &str| std::env::var(name).is_ok();
    if !(env_is_set("HOME") || env_is_set("XDG_DATA_HOME")) {
        return;
    }
    let hint = data_dir_hint();
    assert!(hint.is_some(), "Should return a data dir path");
}
/// A pattern without wildcards only matches the identical string.
#[test]
fn test_glob_match_exact() {
    let cases = [("ls", "ls", true), ("ls", "ls -la", false), ("ls -la", "ls", false)];
    for (pattern, text, expected) in cases {
        assert_eq!(glob_match(pattern, text), expected);
    }
}
/// A trailing `*` matches any remainder, but the literal prefix (including
/// its space) must still match at the start.
#[test]
fn test_glob_match_wildcard_suffix() {
    let pattern = "git *";
    for matching in ["git status", "git commit -m 'hello'"] {
        assert!(glob_match(pattern, matching));
    }
    for rejected in ["echo git", "gitignore"] {
        assert!(!glob_match(pattern, rejected));
    }
}
/// A leading `*` matches any prefix, including path separators.
#[test]
fn test_glob_match_wildcard_prefix() {
    let pattern = "*.rs";
    let cases = [("main.rs", true), ("src/main.rs", true), ("main.py", false)];
    for (text, expected) in cases {
        assert_eq!(glob_match(pattern, text), expected);
    }
}
/// A `*` in the middle requires both the literal prefix and suffix to match.
#[test]
fn test_glob_match_wildcard_middle() {
    let pattern = "cargo * --release";
    let cases = [
        ("cargo build --release", true),
        ("cargo test --release", true),
        ("cargo build --debug", false),
    ];
    for (text, expected) in cases {
        assert_eq!(glob_match(pattern, text), expected);
    }
}
/// `*git*` behaves as a substring test for `git`.
#[test]
fn test_glob_match_multiple_wildcards() {
    let pattern = "*git*";
    for matching in ["git status", "echo git hello", "something git something"] {
        assert!(glob_match(pattern, matching));
    }
    assert!(!glob_match(pattern, "echo hello"));
}
/// A lone `*` matches everything, including the empty string.
#[test]
fn test_glob_match_star_only() {
    for text in ["anything", "", "ls -la /tmp"] {
        assert!(glob_match("*", text));
    }
}
/// The empty pattern matches only the empty string.
#[test]
fn test_glob_match_empty_pattern() {
    for (text, expected) in [("", true), ("something", false)] {
        assert_eq!(glob_match("", text), expected);
    }
}
/// `rm -rf *` matches only commands that start with exactly `rm -rf `.
#[test]
fn test_glob_match_rm_rf() {
    let pattern = "rm -rf *";
    let cases = [
        ("rm -rf /", true),
        ("rm -rf /tmp", true),
        ("rm file.txt", false),
        ("rm -r dir", false),
    ];
    for (command, expected) in cases {
        assert_eq!(glob_match(pattern, command), expected);
    }
}
/// Commands matching an allow pattern yield `Some(true)`; unmatched commands
/// yield `None`.
#[test]
fn test_permission_config_check_allow() {
    let config = PermissionConfig {
        allow: ["git *", "cargo *"].map(String::from).to_vec(),
        deny: Vec::new(),
    };
    let cases = [
        ("git status", Some(true)),
        ("cargo build", Some(true)),
        ("rm -rf /", None),
    ];
    for (command, verdict) in cases {
        assert_eq!(config.check(command), verdict);
    }
}
/// Commands matching a deny pattern yield `Some(false)`; unmatched commands
/// yield `None`.
#[test]
fn test_permission_config_check_deny() {
    let config = PermissionConfig {
        allow: Vec::new(),
        deny: ["rm -rf *", "sudo *"].map(String::from).to_vec(),
    };
    let cases = [
        ("rm -rf /tmp", Some(false)),
        ("sudo apt install", Some(false)),
        ("ls", None),
    ];
    for (command, verdict) in cases {
        assert_eq!(config.check(command), verdict);
    }
}
/// A deny pattern wins even when a catch-all allow pattern also matches.
#[test]
fn test_permission_config_deny_overrides_allow() {
    let config = PermissionConfig {
        allow: vec![String::from("*")],
        deny: vec![String::from("rm -rf *")],
    };
    let cases = [
        ("rm -rf /", Some(false)),
        ("ls", Some(true)),
        ("git status", Some(true)),
    ];
    for (command, verdict) in cases {
        assert_eq!(config.check(command), verdict);
    }
}
/// The default configuration is empty and has no opinion on any command.
#[test]
fn test_permission_config_empty() {
    let config: PermissionConfig = Default::default();
    assert!(config.is_empty());
    let verdict = config.check("anything");
    assert_eq!(verdict, None);
}
/// A double-quoted two-element array parses into both strings, in order.
#[test]
fn test_parse_toml_array_basic() {
    let input = r#"["git *", "cargo *"]"#;
    let expected = vec!["git *", "cargo *"];
    assert_eq!(parse_toml_array(input), expected);
}
/// A single-element array parses into one string.
#[test]
fn test_parse_toml_array_single() {
    let input = r#"["rm -rf *"]"#;
    let expected = vec!["rm -rf *"];
    assert_eq!(parse_toml_array(input), expected);
}
/// `[]` parses into an empty list.
#[test]
fn test_parse_toml_array_empty() {
    let parsed = parse_toml_array("[]");
    let is_empty = parsed.is_empty();
    assert!(is_empty);
}
/// Single-quoted elements are accepted as well.
#[test]
fn test_parse_toml_array_single_quotes() {
    let input = "['git *', 'ls']";
    let expected = vec!["git *", "ls"];
    assert_eq!(parse_toml_array(input), expected);
}
/// Input that is not an array parses into an empty list.
#[test]
fn test_parse_toml_array_not_array() {
    let parsed = parse_toml_array("not an array");
    let is_empty = parsed.is_empty();
    assert!(is_empty);
}
/// The `[permissions]` section supplies both lists; top-level keys before it
/// are ignored.
#[test]
fn test_parse_permissions_from_config() {
    let content = r#"
model = "claude-opus-4-6"
thinking = "medium"
[permissions]
allow = ["git *", "cargo *", "echo *"]
deny = ["rm -rf *", "sudo *"]
"#;
    let PermissionConfig { allow, deny } = parse_permissions_from_config(content);
    assert_eq!(allow, vec!["git *", "cargo *", "echo *"]);
    assert_eq!(deny, vec!["rm -rf *", "sudo *"]);
}
/// Without a `[permissions]` section the result is empty.
#[test]
fn test_parse_permissions_from_config_no_section() {
    let content = r#"
model = "claude-opus-4-6"
thinking = "medium"
"#;
    let parsed = parse_permissions_from_config(content);
    let nothing_configured = parsed.is_empty();
    assert!(nothing_configured);
}
/// A `[permissions]` header with no keys under it yields an empty result.
#[test]
fn test_parse_permissions_from_config_empty_section() {
    let content = r#"
[permissions]
"#;
    let parsed = parse_permissions_from_config(content);
    let nothing_configured = parsed.is_empty();
    assert!(nothing_configured);
}
/// When only `allow` is given, `deny` stays empty.
#[test]
fn test_parse_permissions_from_config_only_allow() {
    let content = r#"
[permissions]
allow = ["git *"]
"#;
    let PermissionConfig { allow, deny } = parse_permissions_from_config(content);
    assert_eq!(allow, vec!["git *"]);
    assert!(deny.is_empty());
}
/// Keys of a section that follows `[permissions]` do not leak into the result.
#[test]
fn test_parse_permissions_from_config_other_section_after() {
    let content = r#"
[permissions]
allow = ["git *"]
[other]
key = "value"
"#;
    let PermissionConfig { allow, deny } = parse_permissions_from_config(content);
    assert_eq!(allow, vec!["git *"]);
    assert!(deny.is_empty());
}
/// A mixed allow/deny setup: allowed commands pass, denied commands are
/// blocked, and everything else is left undecided (`None`).
#[test]
fn test_permission_config_realistic_scenario() {
    let allow = ["git *", "cargo *", "cat *", "ls *", "echo *"]
        .map(String::from)
        .to_vec();
    let deny = ["rm -rf *", "sudo *", "curl * | sh"]
        .map(String::from)
        .to_vec();
    let config = PermissionConfig { allow, deny };

    let cases = [
        ("git status", Some(true)),
        ("cargo test", Some(true)),
        ("cat Cargo.toml", Some(true)),
        ("rm -rf /", Some(false)),
        ("sudo rm -rf /", Some(false)),
        ("python script.py", None),
        ("npm install", None),
    ];
    for (command, verdict) in cases {
        assert_eq!(config.check(command), verdict);
    }
}
/// Repeated `--allow` / `--deny` flags each contribute the argument that
/// follows them, in command-line order.
#[test]
fn test_allow_deny_flags_parsing() {
    // Collects the argument after every occurrence of `flag`.
    fn values_after(argv: &[String], flag: &str) -> Vec<String> {
        argv.windows(2)
            .filter(|pair| pair[0] == flag)
            .map(|pair| pair[1].clone())
            .collect()
    }

    let argv = [
        "yoyo", "--allow", "git *", "--allow", "cargo *", "--deny", "rm -rf *",
    ]
    .map(String::from);

    let allow = values_after(&argv, "--allow");
    let deny = values_after(&argv, "--deny");
    assert_eq!(allow, vec!["git *", "cargo *"]);
    assert_eq!(deny, vec!["rm -rf *"]);
}
/// A single `--openapi` flag yields exactly one spec path.
#[test]
fn test_openapi_flag_parsing_single() {
    let argv = ["yoyo", "--openapi", "petstore.yaml"].map(String::from);
    let mut specs: Vec<String> = Vec::new();
    for (idx, arg) in argv.iter().enumerate() {
        if arg.as_str() != "--openapi" {
            continue;
        }
        // A trailing flag with no value contributes nothing.
        if let Some(value) = argv.get(idx + 1) {
            specs.push(value.clone());
        }
    }
    assert_eq!(specs, vec!["petstore.yaml"]);
}
/// Repeated `--openapi` flags are all collected, and unrelated flags are
/// left out.
#[test]
fn test_openapi_flag_parsing_multiple() {
    let argv = [
        "yoyo",
        "--openapi",
        "api1.yaml",
        "--openapi",
        "api2.json",
        "--model",
        "claude-opus-4-6",
    ]
    .map(String::from);
    // Each window is (flag candidate, following argument).
    let specs: Vec<String> = argv
        .windows(2)
        .filter(|pair| pair[0] == "--openapi")
        .map(|pair| pair[1].clone())
        .collect();
    assert_eq!(specs, vec!["api1.yaml", "api2.json"]);
}
/// `--openapi` must be registered in `KNOWN_FLAGS`.
#[test]
fn test_openapi_flag_in_known_flags() {
    let registered = KNOWN_FLAGS.contains(&"--openapi");
    assert!(registered, "--openapi should be in KNOWN_FLAGS");
}
/// When a result is available it must be non-empty, free of duplicates, and
/// no longer than the requested limit. No result is acceptable.
#[test]
fn test_get_recently_changed_files_in_git_repo() {
    let files = match get_recently_changed_files(20) {
        Some(list) => list,
        None => return,
    };
    assert!(!files.is_empty(), "Should have recently changed files");

    let distinct_count = files
        .iter()
        .collect::<std::collections::HashSet<&String>>()
        .len();
    assert_eq!(
        files.len(),
        distinct_count,
        "Recently changed files should be deduplicated"
    );

    assert!(files.len() <= 20, "Should not exceed max_files limit");
}
/// A small limit is honoured: at most two entries come back.
#[test]
fn test_get_recently_changed_files_respects_limit() {
    let files = match get_recently_changed_files(2) {
        Some(list) => list,
        None => return,
    };
    assert!(
        files.len() <= 2,
        "Should respect max_files=2, got {}",
        files.len()
    );
}
/// Each returned file name appears only once.
#[test]
fn test_get_recently_changed_files_no_duplicates() {
    let files = match get_recently_changed_files(50) {
        Some(list) => list,
        None => return,
    };
    let distinct_count = files
        .iter()
        .collect::<std::collections::HashSet<&String>>()
        .len();
    assert_eq!(files.len(), distinct_count, "Files should be deduplicated");
}
/// Pins the value of `MAX_RECENT_FILES`.
#[test]
fn test_max_recent_files_constant() {
    let expected_cap = 20;
    assert_eq!(MAX_RECENT_FILES, expected_cap);
}
/// If both a project context and recently changed files are available, the
/// context must carry a "Recently Changed Files" section.
#[test]
fn test_load_project_context_includes_recently_changed() {
    match load_project_context() {
        // Recent files are only queried once a context exists.
        Some(context) if get_recently_changed_files(MAX_RECENT_FILES).is_some() => {
            assert!(
                context.contains("## Recently Changed Files"),
                "Context should contain Recently Changed Files section"
            );
        }
        _ => {}
    }
}
/// With no allow or deny entries every path is accepted.
#[test]
fn test_directory_restrictions_empty_allows_everything() {
    let restrictions: DirectoryRestrictions = Default::default();
    assert!(restrictions.is_empty());
    for path in ["/etc/passwd", "src/main.rs"] {
        assert!(restrictions.check_path(path).is_ok());
    }
}
/// Paths under a denied directory are rejected; paths elsewhere are accepted.
#[test]
fn test_directory_restrictions_deny_blocks_path() {
    let restrictions = DirectoryRestrictions {
        allow: Vec::new(),
        deny: vec![String::from("/etc")],
    };
    for blocked in ["/etc/passwd", "/etc/shadow"] {
        assert!(restrictions.check_path(blocked).is_err());
    }
    assert!(restrictions.check_path("/tmp/file.txt").is_ok());
}
/// A non-empty allow list rejects every path outside the listed directories.
#[test]
fn test_directory_restrictions_allow_restricts_to_listed() {
    let cwd = std::env::current_dir()
        .unwrap()
        .to_string_lossy()
        .into_owned();
    let restrictions = DirectoryRestrictions {
        allow: vec![format!("{cwd}/src")],
        deny: Vec::new(),
    };

    let inside = format!("{cwd}/src/main.rs");
    assert!(restrictions.check_path(&inside).is_ok());
    assert!(restrictions.check_path("/tmp/file.txt").is_err());
}
/// A denied subdirectory stays blocked even when its parent is allowed.
#[test]
fn test_directory_restrictions_deny_overrides_allow() {
    let cwd = std::env::current_dir()
        .unwrap()
        .to_string_lossy()
        .into_owned();
    let restrictions = DirectoryRestrictions {
        deny: vec![format!("{cwd}/secrets")],
        allow: vec![cwd.clone()],
    };

    let permitted = format!("{cwd}/src/main.rs");
    let forbidden = format!("{cwd}/secrets/key.pem");
    assert!(restrictions.check_path(&permitted).is_ok());
    assert!(restrictions.check_path(&forbidden).is_err());
}
/// A `..` component cannot be used to step out of an allowed directory.
#[test]
fn test_directory_restrictions_parent_dir_escape_blocked() {
    let cwd = std::env::current_dir()
        .unwrap()
        .to_string_lossy()
        .into_owned();
    let restrictions = DirectoryRestrictions {
        allow: vec![format!("{cwd}/src")],
        deny: Vec::new(),
    };

    let escaping = format!("{cwd}/src/../secrets/key.pem");
    assert!(restrictions.check_path(&escaping).is_err());
}
/// Relative paths are checked against an absolute deny entry under the
/// current working directory.
#[test]
fn test_directory_restrictions_relative_paths() {
    let cwd = std::env::current_dir()
        .unwrap()
        .to_string_lossy()
        .into_owned();
    let restrictions = DirectoryRestrictions {
        allow: Vec::new(),
        deny: vec![format!("{cwd}/secrets")],
    };

    let denied = restrictions.check_path("secrets/file.txt");
    assert!(denied.is_err());
    let permitted = restrictions.check_path("src/main.rs");
    assert!(permitted.is_ok());
}
/// The denied directory itself is blocked, but a sibling that merely shares
/// its name as a prefix is not.
#[test]
fn test_directory_restrictions_exact_dir_match() {
    let restrictions = DirectoryRestrictions {
        allow: Vec::new(),
        deny: vec![String::from("/etc")],
    };
    for blocked in ["/etc", "/etc/passwd"] {
        assert!(restrictions.check_path(blocked).is_err());
    }
    assert!(restrictions.check_path("/etcetc/file").is_ok());
}
/// `..` components are collapsed in the resolved path.
#[test]
fn test_resolve_path_normalizes_parent_dir() {
    let expected = "/tmp/b";
    assert_eq!(resolve_path("/tmp/a/../b"), expected);
}
/// An absolute input resolves to an absolute path that still mentions `usr`.
#[test]
fn test_resolve_path_absolute() {
    let resolved = resolve_path("/usr/bin/env");
    let is_absolute = resolved.starts_with('/');
    assert!(is_absolute);
    let mentions_usr = resolved.contains("usr");
    assert!(mentions_usr);
}
/// Containment respects component boundaries: a directory contains itself and
/// its children, but not a sibling that shares its name as a prefix.
#[test]
fn test_path_is_under_basic() {
    let cases = [
        ("/etc/passwd", "/etc", true),
        ("/etc", "/etc", true),
        ("/etcetc", "/etc", false),
        ("/tmp/file", "/etc", false),
    ];
    for (path, dir, expected) in cases {
        assert_eq!(path_is_under(path, dir), expected);
    }
}
/// The `[directories]` section supplies both lists verbatim (no path
/// expansion is expected here).
#[test]
fn test_parse_directories_from_config() {
    let content = r#"
model = "claude-opus-4-6"
[directories]
allow = ["./src", "./tests"]
deny = ["~/.ssh", "/etc"]
"#;
    let DirectoryRestrictions { allow, deny } = parse_directories_from_config(content);
    assert_eq!(allow, vec!["./src", "./tests"]);
    assert_eq!(deny, vec!["~/.ssh", "/etc"]);
}
/// Without a `[directories]` section the result is empty.
#[test]
fn test_parse_directories_from_config_no_section() {
    let content = r#"
model = "claude-opus-4-6"
"#;
    let parsed = parse_directories_from_config(content);
    let nothing_configured = parsed.is_empty();
    assert!(nothing_configured);
}
/// `[permissions]` and `[directories]` share key names, but each parser only
/// reads its own section.
#[test]
fn test_parse_directories_from_config_does_not_interfere_with_permissions() {
    let content = r#"
[permissions]
allow = ["git *"]
deny = ["rm -rf *"]
[directories]
deny = ["/etc"]
"#;
    let PermissionConfig {
        allow: allowed_commands,
        deny: denied_commands,
    } = parse_permissions_from_config(content);
    assert_eq!(allowed_commands, vec!["git *"]);
    assert_eq!(denied_commands, vec!["rm -rf *"]);

    let DirectoryRestrictions {
        allow: allowed_dirs,
        deny: denied_dirs,
    } = parse_directories_from_config(content);
    assert!(allowed_dirs.is_empty());
    assert_eq!(denied_dirs, vec!["/etc"]);
}
/// Repeated `--allow-dir` / `--deny-dir` flags each contribute the argument
/// that follows them, in command-line order.
#[test]
fn test_allow_dir_deny_dir_flags_parsing() {
    // Pairs every argument with its successor and keeps the successors of `flag`.
    fn dirs_after(argv: &[String], flag: &str) -> Vec<String> {
        argv.iter()
            .zip(argv.iter().skip(1))
            .filter(|(candidate, _)| candidate.as_str() == flag)
            .map(|(_, value)| value.clone())
            .collect()
    }

    let argv = [
        "yoyo",
        "--allow-dir",
        "./src",
        "--allow-dir",
        "./tests",
        "--deny-dir",
        "/etc",
    ]
    .map(String::from);

    let allow_dirs = dirs_after(&argv, "--allow-dir");
    let deny_dirs = dirs_after(&argv, "--deny-dir");
    assert_eq!(allow_dirs, vec!["./src", "./tests"]);
    assert_eq!(deny_dirs, vec!["/etc"]);
}
/// Both directory-restriction flags must be registered in `KNOWN_FLAGS`.
#[test]
fn test_allow_dir_deny_dir_in_known_flags() {
    let has_allow_dir = KNOWN_FLAGS.contains(&"--allow-dir");
    assert!(has_allow_dir, "--allow-dir should be in KNOWN_FLAGS");

    let has_deny_dir = KNOWN_FLAGS.contains(&"--deny-dir");
    assert!(has_deny_dir, "--deny-dir should be in KNOWN_FLAGS");
}
}