use anyhow::{bail, Context, Result};
use crate::claude::model_config::get_model_registry;
#[derive(Debug)]
pub struct AiCredentialInfo {
pub provider: AiProvider,
pub model: String,
}
/// AI backend selected by `check_ai_credentials` from the provider flags.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AiProvider {
    /// Claude API; chosen when no other provider flag is enabled.
    Claude,
    /// Claude via AWS Bedrock; chosen when `CLAUDE_CODE_USE_BEDROCK` is `"true"`.
    Bedrock,
    /// OpenAI API; chosen when `USE_OPENAI` is `"true"`.
    OpenAi,
    /// Ollama; chosen when `USE_OLLAMA` is `"true"`.
    Ollama,
}

impl std::fmt::Display for AiProvider {
    /// Writes the human-readable provider name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Claude => "Claude API",
            Self::Bedrock => "AWS Bedrock",
            Self::OpenAi => "OpenAI API",
            Self::Ollama => "Ollama",
        };
        // `pad` (unlike `write!`) honours width/fill/alignment such as `{:<12}`, so the
        // provider name lines up correctly when printed in aligned output.
        f.pad(name)
    }
}
/// Determines which AI provider and model to use, and verifies that the provider's
/// credentials are configured.
///
/// Provider flags are read via `get_env_var` and checked in this order; the first one whose
/// value is exactly `"true"` wins:
///
/// 1. `USE_OLLAMA` → [`AiProvider::Ollama`] (no credential check is performed)
/// 2. `USE_OPENAI` → [`AiProvider::OpenAi`]
/// 3. `CLAUDE_CODE_USE_BEDROCK` → [`AiProvider::Bedrock`]
/// 4. otherwise → [`AiProvider::Claude`]
///
/// The model is `model_override` if given. Otherwise it is the provider's model variable
/// (`OLLAMA_MODEL`, `OPENAI_MODEL`, or `ANTHROPIC_MODEL`). Failing that, it is the model
/// registry's default for the provider family, and finally a hard-coded fallback. Ollama
/// does not consult the registry and falls back directly to `"llama2"`.
///
/// # Errors
///
/// - OpenAI: neither `OPENAI_API_KEY` nor `OPENAI_AUTH_TOKEN` can be read.
/// - Bedrock: `ANTHROPIC_AUTH_TOKEN` or `ANTHROPIC_BEDROCK_BASE_URL` cannot be read.
/// - Claude: none of `CLAUDE_API_KEY`, `ANTHROPIC_API_KEY`, `ANTHROPIC_AUTH_TOKEN` can be read.
pub fn check_ai_credentials(model_override: Option<&str>) -> Result<AiCredentialInfo> {
    use crate::utils::settings::{get_env_var, get_env_vars};

    // Last-resort model names, used only when the registry has no default for the family.
    const CLAUDE_FALLBACK_MODEL: &str = "claude-sonnet-4-6";
    const OPENAI_FALLBACK_MODEL: &str = "gpt-5";
    const OLLAMA_FALLBACK_MODEL: &str = "llama2";

    let use_openai = env_flag_enabled("USE_OPENAI");
    let use_ollama = env_flag_enabled("USE_OLLAMA");
    let use_bedrock = env_flag_enabled("CLAUDE_CODE_USE_BEDROCK");

    if use_ollama {
        let model = model_override
            .map(String::from)
            .or_else(|| get_env_var("OLLAMA_MODEL").ok())
            .unwrap_or_else(|| OLLAMA_FALLBACK_MODEL.to_string());
        return Ok(AiCredentialInfo {
            provider: AiProvider::Ollama,
            model,
        });
    }

    if use_openai {
        let model = resolve_model(
            model_override,
            "OPENAI_MODEL",
            "openai",
            OPENAI_FALLBACK_MODEL,
        );
        get_env_vars(&["OPENAI_API_KEY", "OPENAI_AUTH_TOKEN"]).map_err(|_| {
            anyhow::anyhow!(
                "OpenAI API key not found.\n\
                 Set one of these environment variables:\n\
                 - OPENAI_API_KEY\n\
                 - OPENAI_AUTH_TOKEN"
            )
        })?;
        return Ok(AiCredentialInfo {
            provider: AiProvider::OpenAi,
            model,
        });
    }

    if use_bedrock {
        let model = resolve_model(
            model_override,
            "ANTHROPIC_MODEL",
            "claude",
            CLAUDE_FALLBACK_MODEL,
        );
        get_env_var("ANTHROPIC_AUTH_TOKEN").map_err(|_| {
            anyhow::anyhow!(
                "AWS Bedrock authentication not configured.\n\
                 Set ANTHROPIC_AUTH_TOKEN environment variable."
            )
        })?;
        get_env_var("ANTHROPIC_BEDROCK_BASE_URL").map_err(|_| {
            anyhow::anyhow!(
                "AWS Bedrock base URL not configured.\n\
                 Set ANTHROPIC_BEDROCK_BASE_URL environment variable."
            )
        })?;
        return Ok(AiCredentialInfo {
            provider: AiProvider::Bedrock,
            model,
        });
    }

    let model = resolve_model(
        model_override,
        "ANTHROPIC_MODEL",
        "claude",
        CLAUDE_FALLBACK_MODEL,
    );
    get_env_vars(&[
        "CLAUDE_API_KEY",
        "ANTHROPIC_API_KEY",
        "ANTHROPIC_AUTH_TOKEN",
    ])
    .map_err(|_| {
        anyhow::anyhow!(
            "Claude API key not found.\n\
             Set one of these environment variables:\n\
             - CLAUDE_API_KEY\n\
             - ANTHROPIC_API_KEY\n\
             - ANTHROPIC_AUTH_TOKEN"
        )
    })?;
    Ok(AiCredentialInfo {
        provider: AiProvider::Claude,
        model,
    })
}

/// Returns `true` only when `key` can be read and its value is exactly `"true"`.
///
/// A failed lookup or any other value (including `"1"` or `"TRUE"`) counts as disabled.
fn env_flag_enabled(key: &str) -> bool {
    crate::utils::settings::get_env_var(key)
        .map(|val| val == "true")
        .unwrap_or(false)
}

/// Resolves a model name with precedence: `model_override`, then the value of `env_key`,
/// then the registry default for `registry_provider`, then `fallback`.
///
/// The registry is only consulted when neither the override nor the variable supplies a
/// model.
fn resolve_model(
    model_override: Option<&str>,
    env_key: &str,
    registry_provider: &str,
    fallback: &'static str,
) -> String {
    use crate::utils::settings::get_env_var;

    model_override
        .map(String::from)
        .or_else(|| get_env_var(env_key).ok())
        .unwrap_or_else(|| {
            get_model_registry()
                .get_default_model(registry_provider)
                .unwrap_or(fallback)
                .to_string()
        })
}
pub fn check_github_cli() -> Result<()> {
let gh_check = std::process::Command::new("gh")
.args(["--version"])
.output();
match gh_check {
Ok(output) if output.status.success() => {
let repo_check = std::process::Command::new("gh")
.args(["repo", "view", "--json", "name"])
.output();
match repo_check {
Ok(repo_output) if repo_output.status.success() => Ok(()),
Ok(repo_output) => {
let error_details = String::from_utf8_lossy(&repo_output.stderr);
if error_details.contains("authentication") || error_details.contains("login") {
bail!(
"GitHub CLI authentication failed.\n\
Please run 'gh auth login' or set GITHUB_TOKEN environment variable."
)
}
bail!(
"GitHub CLI cannot access this repository.\n\
Error: {}",
error_details.trim()
)
}
Err(e) => bail!("Failed to test GitHub CLI access: {e}"),
}
}
_ => bail!(
"GitHub CLI (gh) is not installed or not in PATH.\n\
Please install it from https://cli.github.com/"
),
}
}
/// Ensures the current working directory belongs to a git repository.
///
/// The opened repository handle is discarded; only the ability to open it matters.
///
/// # Errors
///
/// Returns an error carrying a user-facing hint when `GitRepository::open` fails.
pub fn check_git_repository() -> Result<()> {
    crate::git::GitRepository::open()
        .map(|_repository| ())
        .context(
            "Not in a git repository. Please run this command from within a git repository.",
        )
}
/// Fails if the git working directory has changes that are not committed.
///
/// When the status reports the tree as not clean, the error message lists each entry of
/// the status's `untracked_changes` as `<status> <file>`, one per line.
///
/// # Errors
///
/// Returns an error if the repository cannot be opened, if its status cannot be read, or
/// if the working directory is not clean.
pub fn check_working_directory_clean() -> Result<()> {
    let repo = crate::git::GitRepository::open().context("Failed to open git repository")?;
    let status = repo
        .get_working_directory_status()
        .context("Failed to get working directory status")?;

    if status.clean {
        return Ok(());
    }

    let listing: String = status
        .untracked_changes
        .iter()
        .map(|entry| format!(" {} {}\n", entry.status, entry.file))
        .collect();
    bail!(
        "Working directory has uncommitted changes:\n{listing}\n\
         Please commit or stash your changes before proceeding."
    );
}
/// Runs the preflight checks for AI-backed commands.
///
/// Requires a git repository first. Only if that succeeds are the AI credentials checked
/// (see `check_ai_credentials`).
///
/// # Errors
///
/// Propagates the first failing check's error unchanged.
pub fn check_ai_command_prerequisites(model_override: Option<&str>) -> Result<AiCredentialInfo> {
    check_git_repository().and_then(|()| check_ai_credentials(model_override))
}
/// Runs the preflight checks for pull-request commands.
///
/// Checks run in order and stop at the first failure:
/// 1. the git repository;
/// 2. the AI credentials;
/// 3. GitHub CLI access.
///
/// # Errors
///
/// Propagates the first failing check's error unchanged.
pub fn check_pr_command_prerequisites(model_override: Option<&str>) -> Result<AiCredentialInfo> {
    check_git_repository()?;
    let credentials = check_ai_credentials(model_override)?;
    check_github_cli().map(|()| credentials)
}
#[cfg(test)]
// unwrap/expect are the idiomatic way to fail a test on an unexpected Err/None.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use std::env;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard, OnceLock, PoisonError};

    /// Serializes every test in this module that mutates the process environment.
    static ENV_TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();

    /// Holds `ENV_TEST_LOCK` for a test's duration and restores every touched variable on drop.
    ///
    /// `Drop::drop` runs before the `_lock` field is dropped, so restoration happens while the
    /// lock is still held, including during a panic unwind.
    struct EnvGuard {
        _lock: MutexGuard<'static, ()>,
        vars: Vec<(String, Option<OsString>)>,
    }

    impl EnvGuard {
        fn new() -> Self {
            // A test that panics while holding the lock poisons the mutex. Recover the guard
            // so one failing test does not turn every later env test into a PoisonError.
            let lock = ENV_TEST_LOCK
                .get_or_init(|| Mutex::new(()))
                .lock()
                .unwrap_or_else(PoisonError::into_inner);
            Self {
                _lock: lock,
                vars: Vec::new(),
            }
        }

        fn set(&mut self, key: &str, value: &str) {
            self.remember(key);
            env::set_var(key, value);
        }

        fn remove(&mut self, key: &str) {
            self.remember(key);
            env::remove_var(key);
        }

        /// Records the current value of `key` before it is mutated.
        ///
        /// Uses `var_os` so non-UTF-8 values are restored rather than deleted.
        fn remember(&mut self, key: &str) {
            self.vars.push((key.to_string(), env::var_os(key)));
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Restore in reverse so a key touched twice ends at its pre-test value.
            for (key, original_value) in self.vars.drain(..).rev() {
                match original_value {
                    Some(value) => env::set_var(&key, value),
                    None => env::remove_var(&key),
                }
            }
        }
    }

    #[test]
    fn ai_provider_display() {
        assert_eq!(format!("{}", AiProvider::Claude), "Claude API");
        assert_eq!(format!("{}", AiProvider::Bedrock), "AWS Bedrock");
        assert_eq!(format!("{}", AiProvider::OpenAi), "OpenAI API");
        assert_eq!(format!("{}", AiProvider::Ollama), "Ollama");
    }

    #[test]
    fn ai_provider_equality() {
        assert_eq!(AiProvider::Claude, AiProvider::Claude);
        assert_ne!(AiProvider::Claude, AiProvider::OpenAi);
        assert_ne!(AiProvider::Bedrock, AiProvider::Ollama);
    }

    #[test]
    fn ai_provider_clone() {
        let provider = AiProvider::Bedrock;
        let cloned = provider;
        assert_eq!(provider, cloned);
    }

    #[test]
    fn ai_provider_debug() {
        let debug_str = format!("{:?}", AiProvider::Claude);
        assert_eq!(debug_str, "Claude");
    }

    #[test]
    fn ai_credential_info_debug() {
        let info = AiCredentialInfo {
            provider: AiProvider::Ollama,
            model: "llama2".to_string(),
        };
        let debug_str = format!("{info:?}");
        assert!(debug_str.contains("Ollama"));
        assert!(debug_str.contains("llama2"));
    }

    #[test]
    fn claude_default_model_from_registry() {
        let mut guard = EnvGuard::new();
        guard.remove("USE_OPENAI");
        guard.remove("USE_OLLAMA");
        guard.remove("CLAUDE_CODE_USE_BEDROCK");
        guard.remove("ANTHROPIC_MODEL");
        guard.set("ANTHROPIC_API_KEY", "sk-test-dummy");
        let info = check_ai_credentials(None).unwrap();
        assert_eq!(info.provider, AiProvider::Claude);
        assert_eq!(info.model, "claude-sonnet-4-6");
    }

    #[test]
    fn openai_default_model_from_registry() {
        let mut guard = EnvGuard::new();
        guard.set("USE_OPENAI", "true");
        guard.remove("USE_OLLAMA");
        guard.remove("OPENAI_MODEL");
        guard.set("OPENAI_API_KEY", "sk-test-dummy");
        let info = check_ai_credentials(None).unwrap();
        assert_eq!(info.provider, AiProvider::OpenAi);
        assert_eq!(info.model, "gpt-5-mini");
    }

    #[test]
    fn openai_model_env_var_used_without_override() {
        let mut guard = EnvGuard::new();
        guard.remove("USE_OLLAMA");
        guard.set("USE_OPENAI", "true");
        guard.set("OPENAI_MODEL", "gpt-4o");
        guard.set("OPENAI_API_KEY", "sk-test-dummy");
        let info = check_ai_credentials(None).unwrap();
        assert_eq!(info.provider, AiProvider::OpenAi);
        assert_eq!(info.model, "gpt-4o");
    }

    #[test]
    fn bedrock_default_model_from_registry() {
        let mut guard = EnvGuard::new();
        guard.remove("USE_OPENAI");
        guard.remove("USE_OLLAMA");
        guard.set("CLAUDE_CODE_USE_BEDROCK", "true");
        guard.remove("ANTHROPIC_MODEL");
        guard.set("ANTHROPIC_AUTH_TOKEN", "test-token");
        guard.set("ANTHROPIC_BEDROCK_BASE_URL", "https://bedrock.example.com");
        let info = check_ai_credentials(None).unwrap();
        assert_eq!(info.provider, AiProvider::Bedrock);
        assert_eq!(info.model, "claude-sonnet-4-6");
    }

    #[test]
    fn ollama_takes_precedence_over_other_providers() {
        let mut guard = EnvGuard::new();
        guard.set("USE_OLLAMA", "true");
        guard.set("USE_OPENAI", "true");
        guard.set("CLAUDE_CODE_USE_BEDROCK", "true");
        guard.set("OLLAMA_MODEL", "mistral");
        let info = check_ai_credentials(None).unwrap();
        assert_eq!(info.provider, AiProvider::Ollama);
        assert_eq!(info.model, "mistral");
    }

    #[test]
    fn model_override_takes_precedence() {
        let mut guard = EnvGuard::new();
        guard.remove("USE_OPENAI");
        guard.remove("USE_OLLAMA");
        guard.remove("CLAUDE_CODE_USE_BEDROCK");
        guard.remove("ANTHROPIC_MODEL");
        guard.set("ANTHROPIC_API_KEY", "sk-test-dummy");
        let info = check_ai_credentials(Some("claude-opus-4-6")).unwrap();
        assert_eq!(info.model, "claude-opus-4-6");
    }
}