1use anyhow::{bail, Context, Result};
8
9use crate::claude::model_config::get_model_registry;
10
/// Result of a successful credential check: which AI backend was
/// validated and which model name should be used with it.
#[derive(Debug)]
pub struct AiCredentialInfo {
    // Provider whose credentials/configuration were found in the environment.
    pub provider: AiProvider,
    // Resolved model identifier (CLI override > env var > registry default).
    pub model: String,
}
19
/// Supported AI backends, in the order they are probed by
/// `check_ai_credentials` (Ollama, OpenAI, Bedrock, then Claude as default).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AiProvider {
    // Direct Anthropic Claude API (the default when no flag env var is set).
    Claude,
    // Claude models served through AWS Bedrock (CLAUDE_CODE_USE_BEDROCK=true).
    Bedrock,
    // OpenAI API (USE_OPENAI=true).
    OpenAi,
    // Local Ollama server (USE_OLLAMA=true); needs no API key.
    Ollama,
}
32
33impl std::fmt::Display for AiProvider {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 match self {
36 Self::Claude => write!(f, "Claude API"),
37 Self::Bedrock => write!(f, "AWS Bedrock"),
38 Self::OpenAi => write!(f, "OpenAI API"),
39 Self::Ollama => write!(f, "Ollama"),
40 }
41 }
42}
43
44pub fn check_ai_credentials(model_override: Option<&str>) -> Result<AiCredentialInfo> {
50 use crate::utils::settings::{get_env_var, get_env_vars};
51
52 let use_openai = get_env_var("USE_OPENAI").is_ok_and(|val| val == "true");
54
55 let use_ollama = get_env_var("USE_OLLAMA").is_ok_and(|val| val == "true");
56
57 let use_bedrock = get_env_var("CLAUDE_CODE_USE_BEDROCK").is_ok_and(|val| val == "true");
58
59 if use_ollama {
61 let model = model_override
62 .map(String::from)
63 .or_else(|| get_env_var("OLLAMA_MODEL").ok())
64 .unwrap_or_else(|| "llama2".to_string());
65
66 return Ok(AiCredentialInfo {
67 provider: AiProvider::Ollama,
68 model,
69 });
70 }
71
72 if use_openai {
74 let registry = get_model_registry();
75 let model = model_override
76 .map(String::from)
77 .or_else(|| get_env_var("OPENAI_MODEL").ok())
78 .unwrap_or_else(|| {
79 registry
80 .get_default_model("openai")
81 .unwrap_or("gpt-5")
82 .to_string()
83 });
84
85 get_env_vars(&["OPENAI_API_KEY", "OPENAI_AUTH_TOKEN"]).map_err(|_| {
87 anyhow::anyhow!(
88 "OpenAI API key not found.\n\
89 Set one of these environment variables:\n\
90 - OPENAI_API_KEY\n\
91 - OPENAI_AUTH_TOKEN"
92 )
93 })?;
94
95 return Ok(AiCredentialInfo {
96 provider: AiProvider::OpenAi,
97 model,
98 });
99 }
100
101 if use_bedrock {
103 let registry = get_model_registry();
104 let model = model_override
105 .map(String::from)
106 .or_else(|| get_env_var("ANTHROPIC_MODEL").ok())
107 .unwrap_or_else(|| {
108 registry
109 .get_default_model("claude")
110 .unwrap_or("claude-sonnet-4-6")
111 .to_string()
112 });
113
114 get_env_var("ANTHROPIC_AUTH_TOKEN").map_err(|_| {
116 anyhow::anyhow!(
117 "AWS Bedrock authentication not configured.\n\
118 Set ANTHROPIC_AUTH_TOKEN environment variable."
119 )
120 })?;
121
122 get_env_var("ANTHROPIC_BEDROCK_BASE_URL").map_err(|_| {
123 anyhow::anyhow!(
124 "AWS Bedrock base URL not configured.\n\
125 Set ANTHROPIC_BEDROCK_BASE_URL environment variable."
126 )
127 })?;
128
129 return Ok(AiCredentialInfo {
130 provider: AiProvider::Bedrock,
131 model,
132 });
133 }
134
135 let registry = get_model_registry();
137 let model = model_override
138 .map(String::from)
139 .or_else(|| get_env_var("ANTHROPIC_MODEL").ok())
140 .unwrap_or_else(|| {
141 registry
142 .get_default_model("claude")
143 .unwrap_or("claude-sonnet-4-6")
144 .to_string()
145 });
146
147 get_env_vars(&[
149 "CLAUDE_API_KEY",
150 "ANTHROPIC_API_KEY",
151 "ANTHROPIC_AUTH_TOKEN",
152 ])
153 .map_err(|_| {
154 anyhow::anyhow!(
155 "Claude API key not found.\n\
156 Set one of these environment variables:\n\
157 - CLAUDE_API_KEY\n\
158 - ANTHROPIC_API_KEY\n\
159 - ANTHROPIC_AUTH_TOKEN"
160 )
161 })?;
162
163 Ok(AiCredentialInfo {
164 provider: AiProvider::Claude,
165 model,
166 })
167}
168
169pub fn check_github_cli() -> Result<()> {
177 let gh_check = std::process::Command::new("gh")
179 .args(["--version"])
180 .output();
181
182 match gh_check {
183 Ok(output) if output.status.success() => {
184 let repo_check = std::process::Command::new("gh")
186 .args(["repo", "view", "--json", "name"])
187 .output();
188
189 match repo_check {
190 Ok(repo_output) if repo_output.status.success() => Ok(()),
191 Ok(repo_output) => {
192 let error_details = String::from_utf8_lossy(&repo_output.stderr);
193 if error_details.contains("authentication") || error_details.contains("login") {
194 bail!(
195 "GitHub CLI authentication failed.\n\
196 Please run 'gh auth login' or set GITHUB_TOKEN environment variable."
197 )
198 }
199 bail!(
200 "GitHub CLI cannot access this repository.\n\
201 Error: {}",
202 error_details.trim()
203 )
204 }
205 Err(e) => bail!("Failed to test GitHub CLI access: {e}"),
206 }
207 }
208 _ => bail!(
209 "GitHub CLI (gh) is not installed or not in PATH.\n\
210 Please install it from https://cli.github.com/"
211 ),
212 }
213}
214
215pub fn check_git_repository() -> Result<()> {
220 crate::git::GitRepository::open().context(
221 "Not in a git repository. Please run this command from within a git repository.",
222 )?;
223 Ok(())
224}
225
226pub fn check_working_directory_clean() -> Result<()> {
236 let repo = crate::git::GitRepository::open().context("Failed to open git repository")?;
237
238 let status = repo
239 .get_working_directory_status()
240 .context("Failed to get working directory status")?;
241
242 if !status.clean {
243 let mut message = String::from("Working directory has uncommitted changes:\n");
244 for change in &status.untracked_changes {
245 message.push_str(&format!(" {} {}\n", change.status, change.file));
246 }
247 message.push_str("\nPlease commit or stash your changes before proceeding.");
248 bail!(message);
249 }
250
251 Ok(())
252}
253
254pub fn check_ai_command_prerequisites(model_override: Option<&str>) -> Result<AiCredentialInfo> {
262 check_git_repository()?;
263 check_ai_credentials(model_override)
264}
265
266pub fn check_pr_command_prerequisites(model_override: Option<&str>) -> Result<AiCredentialInfo> {
275 check_git_repository()?;
276 let ai_info = check_ai_credentials(model_override)?;
277 check_github_cli()?;
278 Ok(ai_info)
279}
280
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    use std::env;
    use std::sync::Mutex;
    use std::sync::OnceLock;

    // Serializes env-mutating tests: the process environment is global
    // state, so tests that set/remove variables must not run concurrently.
    static ENV_TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();

    // RAII guard that (a) holds the env-test lock for its whole lifetime and
    // (b) records the original value of every variable it touches so they
    // can be restored on drop.
    struct EnvGuard {
        // Keeps the lock held until the guard is dropped; the leading
        // underscore marks it as held-but-unread.
        _lock: std::sync::MutexGuard<'static, ()>,
        // (key, original value) pairs in modification order; None means the
        // variable was unset before the guard touched it.
        vars: Vec<(String, Option<String>)>,
    }

    impl EnvGuard {
        // Acquires the global lock (initializing it on first use) and starts
        // with no recorded modifications.
        fn new() -> Self {
            let lock = ENV_TEST_LOCK.get_or_init(|| Mutex::new(())).lock().unwrap();
            Self {
                _lock: lock,
                vars: Vec::new(),
            }
        }

        // Sets `key` to `value`, remembering the prior value for restoration.
        fn set(&mut self, key: &str, value: &str) {
            let original = env::var(key).ok();
            self.vars.push((key.to_string(), original));
            env::set_var(key, value);
        }

        // Unsets `key`, remembering the prior value for restoration.
        fn remove(&mut self, key: &str) {
            let original = env::var(key).ok();
            self.vars.push((key.to_string(), original));
            env::remove_var(key);
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Restore in reverse order so repeated edits to the same key
            // unwind correctly back to the original value. The lock is still
            // held here (fields drop after this body), so restoration is
            // race-free.
            for (key, original_value) in self.vars.drain(..).rev() {
                match original_value {
                    Some(value) => env::set_var(&key, value),
                    None => env::remove_var(&key),
                }
            }
        }
    }

    #[test]
    fn ai_provider_display() {
        assert_eq!(format!("{}", AiProvider::Claude), "Claude API");
        assert_eq!(format!("{}", AiProvider::Bedrock), "AWS Bedrock");
        assert_eq!(format!("{}", AiProvider::OpenAi), "OpenAI API");
        assert_eq!(format!("{}", AiProvider::Ollama), "Ollama");
    }

    #[test]
    fn ai_provider_equality() {
        assert_eq!(AiProvider::Claude, AiProvider::Claude);
        assert_ne!(AiProvider::Claude, AiProvider::OpenAi);
        assert_ne!(AiProvider::Bedrock, AiProvider::Ollama);
    }

    #[test]
    fn ai_provider_clone() {
        // AiProvider is Copy, so this is a copy, not a Clone call.
        let provider = AiProvider::Bedrock;
        let cloned = provider;
        assert_eq!(provider, cloned);
    }

    #[test]
    fn ai_provider_debug() {
        let debug_str = format!("{:?}", AiProvider::Claude);
        assert_eq!(debug_str, "Claude");
    }

    #[test]
    fn ai_credential_info_debug() {
        let info = AiCredentialInfo {
            provider: AiProvider::Ollama,
            model: "llama2".to_string(),
        };
        let debug_str = format!("{info:?}");
        assert!(debug_str.contains("Ollama"));
        assert!(debug_str.contains("llama2"));
    }

    #[test]
    fn claude_default_model_from_registry() {
        // With no provider flags and no model env var, the Claude branch
        // should use the registry default.
        let mut guard = EnvGuard::new();
        guard.remove("USE_OPENAI");
        guard.remove("USE_OLLAMA");
        guard.remove("CLAUDE_CODE_USE_BEDROCK");
        guard.remove("ANTHROPIC_MODEL");
        guard.set("ANTHROPIC_API_KEY", "sk-test-dummy");

        let info = check_ai_credentials(None).unwrap();
        assert_eq!(info.provider, AiProvider::Claude);
        assert_eq!(info.model, "claude-sonnet-4-6");
    }

    #[test]
    fn openai_default_model_from_registry() {
        let mut guard = EnvGuard::new();
        guard.set("USE_OPENAI", "true");
        guard.remove("USE_OLLAMA");
        guard.remove("OPENAI_MODEL");
        guard.set("OPENAI_API_KEY", "sk-test-dummy");

        let info = check_ai_credentials(None).unwrap();
        assert_eq!(info.provider, AiProvider::OpenAi);
        // NOTE(review): expects the registry default ("gpt-5-mini"), not the
        // static fallback ("gpt-5") — the registry must provide this entry.
        assert_eq!(info.model, "gpt-5-mini");
    }

    #[test]
    fn bedrock_default_model_from_registry() {
        let mut guard = EnvGuard::new();
        guard.remove("USE_OPENAI");
        guard.remove("USE_OLLAMA");
        guard.set("CLAUDE_CODE_USE_BEDROCK", "true");
        guard.remove("ANTHROPIC_MODEL");
        guard.set("ANTHROPIC_AUTH_TOKEN", "test-token");
        guard.set("ANTHROPIC_BEDROCK_BASE_URL", "https://bedrock.example.com");

        let info = check_ai_credentials(None).unwrap();
        assert_eq!(info.provider, AiProvider::Bedrock);
        assert_eq!(info.model, "claude-sonnet-4-6");
    }

    #[test]
    fn model_override_takes_precedence() {
        // An explicit override must beat both env vars and registry defaults.
        let mut guard = EnvGuard::new();
        guard.remove("USE_OPENAI");
        guard.remove("USE_OLLAMA");
        guard.remove("CLAUDE_CODE_USE_BEDROCK");
        guard.remove("ANTHROPIC_MODEL");
        guard.set("ANTHROPIC_API_KEY", "sk-test-dummy");

        let info = check_ai_credentials(Some("claude-opus-4-6")).unwrap();
        assert_eq!(info.model, "claude-opus-4-6");
    }
}