use crate::cli::{default_model_for_provider, known_models_for_provider, provider_api_key_env};
use crate::format::*;
use std::io::{self, BufRead, Write};
/// Providers offered in step 1 of the setup wizard, as `(slug, menu label)` pairs.
///
/// The menu is printed numbered from 1 in exactly this order, and `parse_provider_choice`
/// maps a typed number `n` to entry `n - 1`. Reordering or inserting entries therefore
/// changes which number selects which provider (the unit tests pin the current numbering).
/// Every slug is expected to also appear in `crate::cli::KNOWN_PROVIDERS`, which
/// `test_wizard_providers_are_known` checks.
pub const WIZARD_PROVIDERS: &[(&str, &str)] = &[
("anthropic", "Anthropic (Claude)"),
("openai", "OpenAI (GPT-4o)"),
("google", "Google (Gemini)"),
("ollama", "Ollama (local, no API key needed)"),
("openrouter", "OpenRouter (multi-provider gateway)"),
("deepseek", "DeepSeek"),
("groq", "Groq"),
("xai", "xAI (Grok)"),
("mistral", "Mistral"),
("cerebras", "Cerebras"),
("minimax", "MiniMax"),
(
"bedrock",
"AWS Bedrock (Claude, Nova — uses AWS credentials)",
),
("custom", "Custom (self-hosted OpenAI-compatible)"),
];
/// Settings collected by a successful run of the setup wizard.
///
/// `Debug` is implemented by hand rather than derived so that the credential in
/// `api_key` never ends up in logs, `{:?}` output, or assertion-failure messages.
#[derive(Clone, PartialEq)]
pub struct WizardResult {
    /// Provider slug, e.g. `"anthropic"` or `"bedrock"`.
    pub provider: String,
    /// Credential for the provider. The wizard stores `"not-needed"` for `ollama` and
    /// `"access_key:secret_key"` for `bedrock`.
    pub api_key: String,
    /// Model identifier to use with the provider.
    pub model: String,
    /// Endpoint override; the wizard sets it for `custom` and `bedrock` providers.
    pub base_url: Option<String>,
}

impl std::fmt::Debug for WizardResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WizardResult")
            .field("provider", &self.provider)
            // Never print the secret itself.
            .field("api_key", &format_args!("<redacted>"))
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .finish()
    }
}
/// Renders the TOML text that the setup wizard writes to a config file.
///
/// The output begins with a `# yoyo configuration` header comment. It then has a
/// `provider` key, a `model` key, and, when `base_url` is `Some`, a `base_url` key.
/// For the `bedrock` provider, two trailing comment lines explain how to supply AWS
/// credentials.
///
/// All values are emitted as TOML basic (double-quoted) strings, with `"`, `\` and
/// control characters escaped. This keeps user-typed model names or URLs from
/// producing an unparseable config file.
pub fn generate_config_contents(provider: &str, model: &str, base_url: Option<&str>) -> String {
    /// Escapes `value` so it can sit between the quotes of a TOML basic string.
    fn escape_toml_basic(value: &str) -> String {
        let mut escaped = String::with_capacity(value.len());
        for ch in value.chars() {
            match ch {
                '"' => escaped.push_str("\\\""),
                '\\' => escaped.push_str("\\\\"),
                '\n' => escaped.push_str("\\n"),
                '\r' => escaped.push_str("\\r"),
                '\t' => escaped.push_str("\\t"),
                // TOML forbids raw control characters in basic strings; use \uXXXX.
                c if c.is_control() => {
                    escaped.push_str(&format!("\\u{:04X}", u32::from(c)));
                }
                c => escaped.push(c),
            }
        }
        escaped
    }

    let mut config = String::new();
    config.push_str("# yoyo configuration — generated by setup wizard\n");
    config.push_str(&format!("provider = \"{}\"\n", escape_toml_basic(provider)));
    config.push_str(&format!("model = \"{}\"\n", escape_toml_basic(model)));
    if let Some(url) = base_url {
        config.push_str(&format!("base_url = \"{}\"\n", escape_toml_basic(url)));
    }
    if provider == "bedrock" {
        config.push_str("# For Bedrock, set: AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY\n");
        config.push_str("# Or pass --api-key \"access_key:secret_key\"\n");
    }
    config
}
/// Writes the wizard's configuration to `.yoyo.toml` in the current working directory.
///
/// Any existing file at that path is overwritten. On success, returns the relative path
/// that was written.
///
/// # Errors
/// Returns the I/O error from writing the file, if any.
pub fn save_config_to_file(
    provider: &str,
    model: &str,
    base_url: Option<&str>,
) -> io::Result<String> {
    const PROJECT_CONFIG_PATH: &str = ".yoyo.toml";
    std::fs::write(
        PROJECT_CONFIG_PATH,
        generate_config_contents(provider, model, base_url),
    )
    .map(|()| PROJECT_CONFIG_PATH.to_owned())
}
/// Writes the wizard's configuration to the user-level config file.
///
/// The file location comes from `crate::cli::user_config_path`. Missing parent
/// directories are created, and an existing file is overwritten. On success, returns the
/// written path formatted for display.
///
/// # Errors
/// Returns `ErrorKind::NotFound` when no user config location can be determined.
/// Otherwise propagates any I/O error from creating directories or writing the file.
pub fn save_config_to_user_file(
    provider: &str,
    model: &str,
    base_url: Option<&str>,
) -> io::Result<String> {
    let config_path = match crate::cli::user_config_path() {
        Some(found) => found,
        None => {
            return Err(io::Error::new(
                io::ErrorKind::NotFound,
                "Could not determine user config directory (no HOME or XDG_CONFIG_HOME set)",
            ))
        }
    };
    if let Some(config_dir) = config_path.parent() {
        std::fs::create_dir_all(config_dir)?;
    }
    let rendered = generate_config_contents(provider, model, base_url);
    std::fs::write(&config_path, rendered)?;
    Ok(config_path.display().to_string())
}
/// Interprets the user's answer to the "choose your AI provider" prompt.
///
/// Surrounding whitespace is ignored. A case-insensitive (ASCII) match against a slug in
/// `WIZARD_PROVIDERS` wins first. Otherwise a 1-based menu number selects the
/// corresponding entry. Returns `None` for anything else, including `0` and out-of-range
/// numbers.
pub fn parse_provider_choice(input: &str) -> Option<&'static str> {
    let wanted = input.trim();
    let matched_by_name = WIZARD_PROVIDERS
        .iter()
        .map(|&(slug, _)| slug)
        .find(|slug| slug.eq_ignore_ascii_case(wanted));
    matched_by_name.or_else(|| {
        // Menu numbers are 1-based; `checked_sub` rejects 0 and `get` rejects overflow.
        let index = wanted.parse::<usize>().ok()?.checked_sub(1)?;
        WIZARD_PROVIDERS.get(index).map(|&(slug, _)| slug)
    })
}
/// Where the setup wizard should persist the generated configuration.
#[derive(Debug, Clone, PartialEq)]
pub enum SaveLocation {
    /// `.yoyo.toml` in the current working directory.
    Project,
    /// The per-user config file.
    User,
    /// Do not write any config file.
    Skip,
}

/// Interprets the user's answer to the "Save configuration?" prompt.
///
/// The answer is trimmed and lowercased first.
/// - `2`, `u`, `user` or `global` select [`SaveLocation::User`].
/// - `3`, `n`, `no`, `none`, `s` or `skip` select [`SaveLocation::Skip`].
/// - Everything else, including an empty answer, selects [`SaveLocation::Project`],
///   the prompt's default.
pub fn parse_save_choice(input: &str) -> SaveLocation {
    let answer = input.trim().to_lowercase();
    let wants_user = matches!(answer.as_str(), "2" | "u" | "user" | "global");
    let wants_skip = matches!(
        answer.as_str(),
        "3" | "n" | "no" | "none" | "s" | "skip"
    );
    if wants_user {
        SaveLocation::User
    } else if wants_skip {
        SaveLocation::Skip
    } else {
        SaveLocation::Project
    }
}
/// Returns the user-level config file path as a string for display in wizard prompts.
///
/// Falls back to the literal `~/.config/yoyo/config.toml` when
/// `crate::cli::user_config_path` cannot determine a location.
pub fn user_config_display_path() -> String {
    match crate::cli::user_config_path() {
        Some(resolved) => resolved.display().to_string(),
        None => String::from("~/.config/yoyo/config.toml"),
    }
}
pub fn run_wizard_interactive<R: BufRead, W: Write>(
reader: &mut R,
writer: &mut W,
) -> Option<WizardResult> {
writeln!(writer).ok();
writeln!(writer, " {BOLD}Welcome to yoyo! 🐙{RESET}").ok();
writeln!(writer).ok();
writeln!(
writer,
" Let's get you set up. This will only take a moment."
)
.ok();
writeln!(writer).ok();
writeln!(writer, " {BOLD}Step 1:{RESET} Choose your AI provider:").ok();
writeln!(writer).ok();
for (i, &(_, label)) in WIZARD_PROVIDERS.iter().enumerate() {
writeln!(writer, " {BOLD}{}{RESET}. {label}", i + 1).ok();
}
writeln!(writer).ok();
write!(writer, " Enter number or name [1]: ").ok();
writer.flush().ok();
let mut choice = String::new();
if reader.read_line(&mut choice).is_err() || choice.trim().is_empty() {
choice = "1".to_string();
}
let provider = match parse_provider_choice(&choice) {
Some(p) => p,
None => {
writeln!(writer, " {DIM}(defaulting to Anthropic){RESET}").ok();
"anthropic"
}
};
let provider_label = WIZARD_PROVIDERS
.iter()
.find(|&&(slug, _)| slug == provider)
.map(|&(_, label)| label)
.unwrap_or(provider);
writeln!(writer).ok();
writeln!(
writer,
" {GREEN}✓{RESET} Provider: {BOLD}{provider_label}{RESET}"
)
.ok();
let (api_key, base_url_from_step2) = if provider == "ollama" {
writeln!(writer).ok();
writeln!(
writer,
" {DIM}No API key needed for {provider} — nice!{RESET}"
)
.ok();
("not-needed".to_string(), None)
} else if provider == "bedrock" {
writeln!(writer).ok();
writeln!(writer, " {BOLD}Step 2:{RESET} Enter your AWS credentials").ok();
writeln!(
writer,
" {DIM}(or set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY in your shell){RESET}"
)
.ok();
writeln!(writer).ok();
write!(writer, " AWS Access Key ID: ").ok();
writer.flush().ok();
let mut access_key_input = String::new();
if reader.read_line(&mut access_key_input).is_err() {
return None;
}
let access_key = access_key_input.trim().to_string();
write!(writer, " AWS Secret Access Key: ").ok();
writer.flush().ok();
let mut secret_key_input = String::new();
if reader.read_line(&mut secret_key_input).is_err() {
return None;
}
let secret_key = secret_key_input.trim().to_string();
write!(writer, " AWS Region [us-east-1]: ").ok();
writer.flush().ok();
let mut region_input = String::new();
if reader.read_line(&mut region_input).is_err() {
return None;
}
let region = region_input.trim();
let region = if region.is_empty() {
"us-east-1"
} else {
region
};
let combined_key = if access_key.is_empty() && secret_key.is_empty() {
let env_access = std::env::var("AWS_ACCESS_KEY_ID").unwrap_or_default();
let env_secret = std::env::var("AWS_SECRET_ACCESS_KEY").unwrap_or_default();
if !env_access.is_empty() && !env_secret.is_empty() {
writeln!(
writer,
" {GREEN}✓{RESET} Using credentials from {DIM}AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY{RESET}"
)
.ok();
format!("{env_access}:{env_secret}")
} else {
writeln!(
writer,
" {YELLOW}No AWS credentials provided.{RESET} Set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY or re-run the wizard."
)
.ok();
return None;
}
} else {
writeln!(writer, " {GREEN}✓{RESET} AWS credentials received").ok();
format!("{access_key}:{secret_key}")
};
let bedrock_url = format!("https://bedrock-runtime.{region}.amazonaws.com");
writeln!(
writer,
" {GREEN}✓{RESET} Region: {BOLD}{region}{RESET} → {DIM}{bedrock_url}{RESET}"
)
.ok();
(combined_key, Some(bedrock_url))
} else {
let env_var = provider_api_key_env(provider).unwrap_or("ANTHROPIC_API_KEY");
writeln!(writer).ok();
writeln!(writer, " {BOLD}Step 2:{RESET} Enter your API key").ok();
writeln!(
writer,
" {DIM}(or set {env_var} in your shell and press Enter to skip){RESET}"
)
.ok();
writeln!(writer).ok();
write!(writer, " API key: ").ok();
writer.flush().ok();
let mut key_input = String::new();
if reader.read_line(&mut key_input).is_err() {
return None;
}
let key = key_input.trim().to_string();
if key.is_empty() {
if let Some(env_key) = provider_api_key_env(provider) {
if let Ok(val) = std::env::var(env_key) {
if !val.is_empty() {
writeln!(
writer,
" {GREEN}✓{RESET} Using key from {DIM}{env_key}{RESET}"
)
.ok();
(val, None)
} else {
writeln!(
writer,
" {YELLOW}No API key provided.{RESET} Set {env_var} or re-run the wizard."
)
.ok();
return None;
}
} else {
writeln!(
writer,
" {YELLOW}No API key provided.{RESET} Set {env_var} or re-run the wizard."
)
.ok();
return None;
}
} else {
writeln!(
writer,
" {YELLOW}No API key provided.{RESET} Set {env_var} or re-run the wizard."
)
.ok();
return None;
}
} else {
writeln!(writer, " {GREEN}✓{RESET} API key received").ok();
(key, None)
}
};
let base_url = if base_url_from_step2.is_some() {
base_url_from_step2
} else if provider == "custom" {
writeln!(writer).ok();
writeln!(
writer,
" {BOLD}Base URL:{RESET} Enter the URL of your OpenAI-compatible API"
)
.ok();
writeln!(writer, " {DIM}(e.g. http://localhost:8080/v1){RESET}").ok();
writeln!(writer).ok();
write!(writer, " Base URL: ").ok();
writer.flush().ok();
let mut url_input = String::new();
if reader.read_line(&mut url_input).is_err() {
return None;
}
let url = url_input.trim().to_string();
if url.is_empty() {
writeln!(
writer,
" {YELLOW}No base URL provided.{RESET} A base URL is required for custom providers."
)
.ok();
return None;
}
writeln!(writer, " {GREEN}✓{RESET} Base URL: {BOLD}{url}{RESET}").ok();
Some(url)
} else {
None
};
let default_model = default_model_for_provider(provider);
let known_models = known_models_for_provider(provider);
writeln!(writer).ok();
writeln!(
writer,
" {BOLD}Step 3:{RESET} Choose a model {DIM}(press Enter for default){RESET}"
)
.ok();
if !known_models.is_empty() {
writeln!(writer, " {DIM}Popular models for {provider}:{RESET}").ok();
for m in known_models {
if *m == default_model {
writeln!(writer, " • {m} {DIM}(default){RESET}").ok();
} else {
writeln!(writer, " • {m}").ok();
}
}
}
writeln!(writer).ok();
write!(writer, " Model [{default_model}]: ").ok();
writer.flush().ok();
let mut model_input = String::new();
if reader.read_line(&mut model_input).is_err() {
return None;
}
let model = model_input.trim();
let model = if model.is_empty() {
default_model.clone()
} else {
model.to_string()
};
writeln!(writer, " {GREEN}✓{RESET} Model: {BOLD}{model}{RESET}").ok();
let xdg_display = user_config_display_path();
writeln!(writer).ok();
writeln!(writer, " {BOLD}Step 4:{RESET} Save configuration?").ok();
writeln!(
writer,
" {DIM}This saves your provider and model so you don't need flags next time.{RESET}"
)
.ok();
writeln!(writer).ok();
writeln!(
writer,
" {BOLD}1{RESET}. Save to {CYAN}.yoyo.toml{RESET} (current project only)"
)
.ok();
writeln!(
writer,
" {BOLD}2{RESET}. Save to {CYAN}{xdg_display}{RESET} (user-level, applies everywhere)"
)
.ok();
writeln!(writer, " {BOLD}3{RESET}. Don't save").ok();
writeln!(writer).ok();
write!(writer, " Choice [1]: ").ok();
writer.flush().ok();
let mut save_input = String::new();
if reader.read_line(&mut save_input).is_err() {
save_input = "1".to_string();
}
let save_location = parse_save_choice(&save_input);
match save_location {
SaveLocation::Project => match save_config_to_file(provider, &model, base_url.as_deref()) {
Ok(path) => {
writeln!(writer, " {GREEN}✓{RESET} Saved to {CYAN}{path}{RESET}").ok();
}
Err(e) => {
writeln!(writer, " {YELLOW}Could not save config: {e}{RESET}").ok();
}
},
SaveLocation::User => {
match save_config_to_user_file(provider, &model, base_url.as_deref()) {
Ok(path) => {
writeln!(writer, " {GREEN}✓{RESET} Saved to {CYAN}{path}{RESET}").ok();
}
Err(e) => {
writeln!(writer, " {YELLOW}Could not save config: {e}{RESET}").ok();
}
}
}
SaveLocation::Skip => {
writeln!(
writer,
" {DIM}Skipped — you can create .yoyo.toml or {xdg_display} manually later.{RESET}"
)
.ok();
}
}
writeln!(writer).ok();
writeln!(writer, " {GREEN}{BOLD}All set! Starting yoyo...{RESET}").ok();
writeln!(writer).ok();
Some(WizardResult {
provider: provider.to_string(),
api_key,
model,
base_url,
})
}
/// Runs the setup wizard against the process's real stdin and stdout.
///
/// Thin wrapper around `run_wizard_interactive`, which documents the flow and when
/// `None` is returned.
pub fn run_setup_wizard() -> Option<WizardResult> {
    let stdin_handle = io::stdin();
    let mut locked_input = stdin_handle.lock();
    let mut terminal_output = io::stdout();
    run_wizard_interactive(&mut locked_input, &mut terminal_output)
}
/// Reports whether the first-run setup wizard should be offered for `provider`.
///
/// Setup is *not* needed (returns `false`) when any of these holds, checked in order:
/// - `.yoyo.toml` exists in the current directory;
/// - the user-level config file exists;
/// - the provider is `ollama` or `custom`;
/// - the provider's API-key environment variable is set and non-empty;
/// - `ANTHROPIC_API_KEY` is set and non-empty;
/// - `API_KEY` is set and non-empty.
///
/// Otherwise returns `true`.
pub fn needs_setup(provider: &str) -> bool {
    /// True when `name` is set to a non-empty (valid Unicode) value.
    fn env_is_set(name: &str) -> bool {
        matches!(std::env::var(name), Ok(value) if !value.is_empty())
    }

    let already_configured = std::path::Path::new(".yoyo.toml").exists()
        || crate::cli::user_config_path()
            .filter(|user_path| user_path.exists())
            .is_some()
        || matches!(provider, "ollama" | "custom")
        || provider_api_key_env(provider)
            .filter(|env_name| env_is_set(env_name))
            .is_some()
        || env_is_set("ANTHROPIC_API_KEY")
        || env_is_set("API_KEY");
    !already_configured
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cli::KNOWN_PROVIDERS;
    use std::ffi::{OsStr, OsString};
    use std::path::{Path, PathBuf};
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that touch process-global state (environment variables and the
    /// current working directory).
    ///
    /// `cargo test` runs tests on parallel threads, so without this lock one test's
    /// `XDG_CONFIG_HOME` or cwd change can leak into another. Only tests in this module
    /// take the lock.
    static GLOBAL_STATE_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`GLOBAL_STATE_LOCK`], recovering the guard if a previous holder panicked.
    fn lock_global_state() -> MutexGuard<'static, ()> {
        GLOBAL_STATE_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Restores an environment variable to its previous value (or absence) when dropped.
    ///
    /// This ensures a failing assertion cannot leave the process environment modified.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<OsString>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: impl AsRef<OsStr>) -> Self {
            let previous = std::env::var_os(key);
            std::env::set_var(key, value);
            Self { key, previous }
        }

        fn remove(key: &'static str) -> Self {
            let previous = std::env::var_os(key);
            std::env::remove_var(key);
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }

    /// Changes the current working directory and restores the previous one when dropped.
    struct CwdGuard {
        previous: PathBuf,
    }

    impl CwdGuard {
        fn change_to(dir: &Path) -> Self {
            let previous = std::env::current_dir().unwrap();
            std::env::set_current_dir(dir).unwrap();
            Self { previous }
        }
    }

    impl Drop for CwdGuard {
        fn drop(&mut self) {
            let _ = std::env::set_current_dir(&self.previous);
        }
    }

    /// Per-process scratch directory under the system temp dir.
    ///
    /// Including the process id keeps concurrent `cargo test` runs from colliding.
    fn scratch_dir(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("{name}_{}", std::process::id()))
    }

    #[test]
    fn test_parse_provider_choice_by_number() {
        assert_eq!(parse_provider_choice("1"), Some("anthropic"));
        assert_eq!(parse_provider_choice("2"), Some("openai"));
        assert_eq!(parse_provider_choice("3"), Some("google"));
        assert_eq!(parse_provider_choice("4"), Some("ollama"));
        assert_eq!(parse_provider_choice("5"), Some("openrouter"));
        assert_eq!(parse_provider_choice("6"), Some("deepseek"));
        assert_eq!(parse_provider_choice("7"), Some("groq"));
        assert_eq!(parse_provider_choice("8"), Some("xai"));
        assert_eq!(parse_provider_choice("9"), Some("mistral"));
        assert_eq!(parse_provider_choice("10"), Some("cerebras"));
        assert_eq!(parse_provider_choice("11"), Some("minimax"));
        assert_eq!(parse_provider_choice("12"), Some("bedrock"));
        assert_eq!(parse_provider_choice("13"), Some("custom"));
    }

    #[test]
    fn test_parse_provider_choice_by_name() {
        assert_eq!(parse_provider_choice("anthropic"), Some("anthropic"));
        assert_eq!(parse_provider_choice("OpenAI"), Some("openai"));
        assert_eq!(parse_provider_choice("GOOGLE"), Some("google"));
        assert_eq!(parse_provider_choice("ollama"), Some("ollama"));
        assert_eq!(parse_provider_choice("cerebras"), Some("cerebras"));
        assert_eq!(parse_provider_choice("Cerebras"), Some("cerebras"));
        assert_eq!(parse_provider_choice("minimax"), Some("minimax"));
        assert_eq!(parse_provider_choice("MiniMax"), Some("minimax"));
        assert_eq!(parse_provider_choice("bedrock"), Some("bedrock"));
        assert_eq!(parse_provider_choice("Bedrock"), Some("bedrock"));
        assert_eq!(parse_provider_choice("custom"), Some("custom"));
        assert_eq!(parse_provider_choice("CUSTOM"), Some("custom"));
    }

    #[test]
    fn test_parse_provider_choice_invalid() {
        assert_eq!(parse_provider_choice("0"), None);
        assert_eq!(parse_provider_choice("99"), None);
        assert_eq!(parse_provider_choice("banana"), None);
        assert_eq!(parse_provider_choice(""), None);
    }

    #[test]
    fn test_parse_provider_choice_whitespace() {
        assert_eq!(parse_provider_choice(" 1 "), Some("anthropic"));
        assert_eq!(parse_provider_choice(" openai "), Some("openai"));
    }

    #[test]
    fn test_generate_config_contents() {
        let config = generate_config_contents("anthropic", "claude-opus-4-6", None);
        assert!(config.contains("provider = \"anthropic\""));
        assert!(config.contains("model = \"claude-opus-4-6\""));
        assert!(config.starts_with("# yoyo configuration"));
        assert!(!config.contains("base_url"));
    }

    #[test]
    fn test_generate_config_openai() {
        let config = generate_config_contents("openai", "gpt-4o", None);
        assert!(config.contains("provider = \"openai\""));
        assert!(config.contains("model = \"gpt-4o\""));
    }

    #[test]
    fn test_generate_config_custom_with_base_url() {
        let config =
            generate_config_contents("custom", "my-model", Some("http://localhost:8080/v1"));
        assert!(config.contains("provider = \"custom\""));
        assert!(config.contains("model = \"my-model\""));
        assert!(config.contains("base_url = \"http://localhost:8080/v1\""));
    }

    #[test]
    fn test_wizard_providers_are_known() {
        for &(slug, _) in WIZARD_PROVIDERS {
            assert!(
                KNOWN_PROVIDERS.contains(&slug),
                "Wizard provider '{slug}' not in KNOWN_PROVIDERS"
            );
        }
    }

    #[test]
    fn test_wizard_anthropic_with_key() {
        let input = "1\nsk-test-key-123\n\nn\n";
        let mut reader = io::Cursor::new(input.as_bytes());
        let mut output = Vec::new();
        let result = run_wizard_interactive(&mut reader, &mut output);
        assert!(result.is_some());
        let r = result.unwrap();
        assert_eq!(r.provider, "anthropic");
        assert_eq!(r.api_key, "sk-test-key-123");
        assert_eq!(r.model, "claude-opus-4-6");
        let output_str = String::from_utf8(output).unwrap();
        assert!(output_str.contains("Step 1"));
        assert!(output_str.contains("Step 2"));
        assert!(output_str.contains("Step 3"));
        assert!(output_str.contains("Step 4"));
        assert!(output_str.contains("All set!"));
    }

    #[test]
    fn test_wizard_ollama_skips_api_key() {
        let input = "4\n\nn\n";
        let mut reader = io::Cursor::new(input.as_bytes());
        let mut output = Vec::new();
        let result = run_wizard_interactive(&mut reader, &mut output);
        assert!(result.is_some());
        let r = result.unwrap();
        assert_eq!(r.provider, "ollama");
        assert_eq!(r.api_key, "not-needed");
        assert_eq!(r.model, "llama3.2");
        let output_str = String::from_utf8(output).unwrap();
        assert!(output_str.contains("No API key needed"));
    }

    #[test]
    fn test_wizard_custom_model() {
        let input = "2\nsk-openai-key\ngpt-4.1-mini\nn\n";
        let mut reader = io::Cursor::new(input.as_bytes());
        let mut output = Vec::new();
        let result = run_wizard_interactive(&mut reader, &mut output);
        assert!(result.is_some());
        let r = result.unwrap();
        assert_eq!(r.provider, "openai");
        assert_eq!(r.api_key, "sk-openai-key");
        assert_eq!(r.model, "gpt-4.1-mini");
    }

    #[test]
    fn test_wizard_provider_by_name() {
        let input = "google\ntest-key\n\nn\n";
        let mut reader = io::Cursor::new(input.as_bytes());
        let mut output = Vec::new();
        let result = run_wizard_interactive(&mut reader, &mut output);
        assert!(result.is_some());
        let r = result.unwrap();
        assert_eq!(r.provider, "google");
        assert_eq!(r.api_key, "test-key");
    }

    #[test]
    fn test_wizard_default_provider_on_enter() {
        let input = "\nmy-api-key\n\nn\n";
        let mut reader = io::Cursor::new(input.as_bytes());
        let mut output = Vec::new();
        let result = run_wizard_interactive(&mut reader, &mut output);
        assert!(result.is_some());
        let r = result.unwrap();
        assert_eq!(r.provider, "anthropic");
        assert_eq!(r.api_key, "my-api-key");
    }

    #[test]
    fn test_wizard_no_key_no_env_returns_none() {
        let _lock = lock_global_state();
        // Restores ANTHROPIC_API_KEY even if an assertion below fails.
        let _env = EnvVarGuard::remove("ANTHROPIC_API_KEY");
        let input = "1\n\n";
        let mut reader = io::Cursor::new(input.as_bytes());
        let mut output = Vec::new();
        let result = run_wizard_interactive(&mut reader, &mut output);
        assert!(result.is_none());
        let output_str = String::from_utf8(output).unwrap();
        assert!(output_str.contains("No API key provided"));
    }

    #[test]
    fn test_save_config_to_file() {
        let _lock = lock_global_state();
        let dir = scratch_dir("yoyo_test_wizard");
        std::fs::create_dir_all(&dir).unwrap();
        {
            // Restores the cwd before the directory is removed below.
            let _cwd = CwdGuard::change_to(&dir);
            let result = save_config_to_file("openai", "gpt-4o", None);
            assert!(result.is_ok());
            let content = std::fs::read_to_string(".yoyo.toml").unwrap();
            assert!(content.contains("provider = \"openai\""));
            assert!(content.contains("model = \"gpt-4o\""));
        }
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_wizard_result_fields() {
        let result = WizardResult {
            provider: "anthropic".to_string(),
            api_key: "sk-test".to_string(),
            model: "claude-opus-4-6".to_string(),
            base_url: None,
        };
        assert_eq!(result.provider, "anthropic");
        assert_eq!(result.api_key, "sk-test");
        assert_eq!(result.model, "claude-opus-4-6");
        assert_eq!(result.base_url, None);
    }

    #[test]
    fn test_wizard_cerebras_flow() {
        let input = "10\nsk-cerebras-key\n\nn\n";
        let mut reader = io::Cursor::new(input.as_bytes());
        let mut output = Vec::new();
        let result = run_wizard_interactive(&mut reader, &mut output);
        assert!(result.is_some());
        let r = result.unwrap();
        assert_eq!(r.provider, "cerebras");
        assert_eq!(r.api_key, "sk-cerebras-key");
        assert_eq!(r.model, "llama-3.3-70b");
        assert_eq!(r.base_url, None);
        let output_str = String::from_utf8(output).unwrap();
        assert!(output_str.contains("Cerebras"));
    }

    #[test]
    fn test_wizard_minimax_flow() {
        let input = "11\nsk-minimax-key\n\nn\n";
        let mut reader = io::Cursor::new(input.as_bytes());
        let mut output = Vec::new();
        let result = run_wizard_interactive(&mut reader, &mut output);
        assert!(result.is_some());
        let r = result.unwrap();
        assert_eq!(r.provider, "minimax");
        assert_eq!(r.api_key, "sk-minimax-key");
        assert_eq!(r.model, "MiniMax-M2.7");
        assert_eq!(r.base_url, None);
        let output_str = String::from_utf8(output).unwrap();
        assert!(output_str.contains("MiniMax"));
    }

    #[test]
    fn test_wizard_custom_provider_flow() {
        let input = "13\nmy-custom-key\nhttp://localhost:8080/v1\n\nn\n";
        let mut reader = io::Cursor::new(input.as_bytes());
        let mut output = Vec::new();
        let result = run_wizard_interactive(&mut reader, &mut output);
        assert!(result.is_some());
        let r = result.unwrap();
        assert_eq!(r.provider, "custom");
        assert_eq!(r.api_key, "my-custom-key");
        assert_eq!(r.base_url, Some("http://localhost:8080/v1".to_string()));
        let output_str = String::from_utf8(output).unwrap();
        assert!(output_str.contains("Base URL"));
        assert!(output_str.contains("Custom (self-hosted OpenAI-compatible)"));
    }

    #[test]
    fn test_wizard_custom_provider_no_base_url_returns_none() {
        let input = "13\nmy-custom-key\n\n";
        let mut reader = io::Cursor::new(input.as_bytes());
        let mut output = Vec::new();
        let result = run_wizard_interactive(&mut reader, &mut output);
        assert!(result.is_none());
        let output_str = String::from_utf8(output).unwrap();
        assert!(output_str.contains("No base URL provided"));
    }

    #[test]
    fn test_parse_save_choice_defaults_to_project() {
        assert_eq!(parse_save_choice(""), SaveLocation::Project);
        assert_eq!(parse_save_choice("1"), SaveLocation::Project);
        assert_eq!(parse_save_choice("p"), SaveLocation::Project);
        assert_eq!(parse_save_choice("project"), SaveLocation::Project);
        assert_eq!(parse_save_choice(" 1 "), SaveLocation::Project);
    }

    #[test]
    fn test_parse_save_choice_user() {
        assert_eq!(parse_save_choice("2"), SaveLocation::User);
        assert_eq!(parse_save_choice("u"), SaveLocation::User);
        assert_eq!(parse_save_choice("user"), SaveLocation::User);
        assert_eq!(parse_save_choice("global"), SaveLocation::User);
        assert_eq!(parse_save_choice(" 2 "), SaveLocation::User);
    }

    #[test]
    fn test_parse_save_choice_skip() {
        assert_eq!(parse_save_choice("3"), SaveLocation::Skip);
        assert_eq!(parse_save_choice("n"), SaveLocation::Skip);
        assert_eq!(parse_save_choice("no"), SaveLocation::Skip);
        assert_eq!(parse_save_choice("none"), SaveLocation::Skip);
        assert_eq!(parse_save_choice("s"), SaveLocation::Skip);
        assert_eq!(parse_save_choice("skip"), SaveLocation::Skip);
    }

    #[test]
    fn test_parse_save_choice_unknown_defaults_to_project() {
        assert_eq!(parse_save_choice("banana"), SaveLocation::Project);
        assert_eq!(parse_save_choice("yes"), SaveLocation::Project);
    }

    #[test]
    fn test_save_config_to_user_file() {
        let _lock = lock_global_state();
        let dir = scratch_dir("yoyo_test_xdg_save");
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();
        let _xdg = EnvVarGuard::set("XDG_CONFIG_HOME", &dir);
        let result = save_config_to_user_file("google", "gemini-2.0-flash", None);
        assert!(result.is_ok(), "save_config_to_user_file should succeed");
        let path_str = result.unwrap();
        assert!(
            path_str.contains("yoyo"),
            "path should contain yoyo directory"
        );
        assert!(
            path_str.contains("config.toml"),
            "path should end with config.toml"
        );
        let content = std::fs::read_to_string(&path_str).unwrap();
        assert!(content.contains("provider = \"google\""));
        assert!(content.contains("model = \"gemini-2.0-flash\""));
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_save_config_to_user_file_creates_parent_dirs() {
        let _lock = lock_global_state();
        let dir = scratch_dir("yoyo_test_xdg_nested");
        let _ = std::fs::remove_dir_all(&dir);
        let _xdg = EnvVarGuard::set("XDG_CONFIG_HOME", &dir);
        let result = save_config_to_user_file("openai", "gpt-4o", None);
        assert!(
            result.is_ok(),
            "should create parent dirs: {:?}",
            result.err()
        );
        let expected_path = dir.join("yoyo").join("config.toml");
        assert!(expected_path.exists(), "config file should exist");
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_wizard_step4_shows_three_choices() {
        let input = "4\n\n3\n";
        let mut reader = io::Cursor::new(input.as_bytes());
        let mut output = Vec::new();
        let result = run_wizard_interactive(&mut reader, &mut output);
        assert!(result.is_some());
        let output_str = String::from_utf8(output).unwrap();
        assert!(
            output_str.contains(".yoyo.toml"),
            "should show project-level option"
        );
        assert!(
            output_str.contains("user-level"),
            "should show user-level option"
        );
        assert!(output_str.contains("Don't save"), "should show skip option");
        assert!(
            output_str.contains("Choice [1]"),
            "should show choice prompt with default"
        );
    }

    #[test]
    fn test_wizard_save_to_user_level() {
        let _lock = lock_global_state();
        let dir = scratch_dir("yoyo_test_wizard_user_save");
        let _ = std::fs::remove_dir_all(&dir);
        let _xdg = EnvVarGuard::set("XDG_CONFIG_HOME", &dir);
        let input = "4\n\n2\n";
        let mut reader = io::Cursor::new(input.as_bytes());
        let mut output = Vec::new();
        let result = run_wizard_interactive(&mut reader, &mut output);
        assert!(result.is_some());
        let output_str = String::from_utf8(output).unwrap();
        assert!(
            output_str.contains("Saved to"),
            "should confirm save: {output_str}"
        );
        let expected_path = dir.join("yoyo").join("config.toml");
        assert!(
            expected_path.exists(),
            "user-level config should be created"
        );
        let content = std::fs::read_to_string(&expected_path).unwrap();
        assert!(content.contains("provider = \"ollama\""));
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_user_config_display_path() {
        // Reads XDG_CONFIG_HOME / HOME, which other tests in this module modify.
        let _lock = lock_global_state();
        let display = user_config_display_path();
        assert!(
            display.contains("yoyo") || display.contains("config"),
            "display path should mention yoyo or config: {display}"
        );
    }

    #[test]
    fn test_bedrock_in_wizard_providers() {
        let slugs: Vec<&str> = WIZARD_PROVIDERS.iter().map(|&(s, _)| s).collect();
        assert!(
            slugs.contains(&"bedrock"),
            "bedrock should be in WIZARD_PROVIDERS"
        );
    }

    #[test]
    fn test_generate_config_bedrock() {
        let config = generate_config_contents(
            "bedrock",
            "anthropic.claude-sonnet-4-20250514-v1:0",
            Some("https://bedrock-runtime.us-east-1.amazonaws.com"),
        );
        assert!(config.contains("provider = \"bedrock\""));
        assert!(config.contains("model = \"anthropic.claude-sonnet-4-20250514-v1:0\""));
        assert!(config.contains("base_url = \"https://bedrock-runtime.us-east-1.amazonaws.com\""));
        assert!(config.contains("AWS_ACCESS_KEY_ID"));
        assert!(config.contains("AWS_SECRET_ACCESS_KEY"));
        for line in config.lines() {
            let trimmed = line.trim();
            if trimmed.is_empty() || trimmed.starts_with('#') {
                continue;
            }
            assert!(
                trimmed.contains('='),
                "non-comment line should be key=value: {trimmed}"
            );
        }
    }

    #[test]
    fn test_wizard_bedrock_with_credentials() {
        let input = "12\nAKIATEST123\nwJalrXUtnFEMI/test\n\n\nn\n";
        let mut reader = io::Cursor::new(input.as_bytes());
        let mut output = Vec::new();
        let result = run_wizard_interactive(&mut reader, &mut output);
        assert!(result.is_some(), "wizard should succeed for bedrock");
        let r = result.unwrap();
        assert_eq!(r.provider, "bedrock");
        assert_eq!(r.api_key, "AKIATEST123:wJalrXUtnFEMI/test");
        assert_eq!(r.model, "anthropic.claude-sonnet-4-20250514-v1:0");
        assert_eq!(
            r.base_url.as_deref(),
            Some("https://bedrock-runtime.us-east-1.amazonaws.com")
        );
        let output_str = String::from_utf8(output).unwrap();
        assert!(output_str.contains("AWS credentials received"));
        assert!(output_str.contains("us-east-1"));
    }

    #[test]
    fn test_wizard_bedrock_custom_region() {
        let input = "12\nAKIATEST123\nsecretkey\neu-west-1\n\nn\n";
        let mut reader = io::Cursor::new(input.as_bytes());
        let mut output = Vec::new();
        let result = run_wizard_interactive(&mut reader, &mut output);
        assert!(result.is_some());
        let r = result.unwrap();
        assert_eq!(r.provider, "bedrock");
        assert_eq!(
            r.base_url.as_deref(),
            Some("https://bedrock-runtime.eu-west-1.amazonaws.com")
        );
        let output_str = String::from_utf8(output).unwrap();
        assert!(output_str.contains("eu-west-1"));
    }
}