1use anyhow::{bail, Context, Result};
8
9use crate::claude::model_config::get_model_registry;
10
/// The provider and model resolved by [`check_ai_credentials`],
/// handed to callers so they know which backend will serve requests.
#[derive(Debug)]
pub struct AiCredentialInfo {
    // Which backing AI service was selected.
    pub provider: AiProvider,
    // The resolved model identifier for that provider.
    pub model: String,
}
19
/// The AI backends this tool can authenticate against.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AiProvider {
    // Direct Anthropic Claude API (the default when no flag is set).
    Claude,
    // Claude models served through AWS Bedrock.
    Bedrock,
    // OpenAI API.
    OpenAi,
    // Locally hosted Ollama server (no credentials required).
    Ollama,
}
32
33impl std::fmt::Display for AiProvider {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 match self {
36 Self::Claude => write!(f, "Claude API"),
37 Self::Bedrock => write!(f, "AWS Bedrock"),
38 Self::OpenAi => write!(f, "OpenAI API"),
39 Self::Ollama => write!(f, "Ollama"),
40 }
41 }
42}
43
44pub fn check_ai_credentials(model_override: Option<&str>) -> Result<AiCredentialInfo> {
50 use crate::utils::settings::{get_env_var, get_env_vars};
51
52 let use_openai = get_env_var("USE_OPENAI")
54 .map(|val| val == "true")
55 .unwrap_or(false);
56
57 let use_ollama = get_env_var("USE_OLLAMA")
58 .map(|val| val == "true")
59 .unwrap_or(false);
60
61 let use_bedrock = get_env_var("CLAUDE_CODE_USE_BEDROCK")
62 .map(|val| val == "true")
63 .unwrap_or(false);
64
65 if use_ollama {
67 let model = model_override
68 .map(String::from)
69 .or_else(|| get_env_var("OLLAMA_MODEL").ok())
70 .unwrap_or_else(|| "llama2".to_string());
71
72 return Ok(AiCredentialInfo {
73 provider: AiProvider::Ollama,
74 model,
75 });
76 }
77
78 if use_openai {
80 let registry = get_model_registry();
81 let model = model_override
82 .map(String::from)
83 .or_else(|| get_env_var("OPENAI_MODEL").ok())
84 .unwrap_or_else(|| {
85 registry
86 .get_default_model("openai")
87 .unwrap_or("gpt-5")
88 .to_string()
89 });
90
91 get_env_vars(&["OPENAI_API_KEY", "OPENAI_AUTH_TOKEN"]).map_err(|_| {
93 anyhow::anyhow!(
94 "OpenAI API key not found.\n\
95 Set one of these environment variables:\n\
96 - OPENAI_API_KEY\n\
97 - OPENAI_AUTH_TOKEN"
98 )
99 })?;
100
101 return Ok(AiCredentialInfo {
102 provider: AiProvider::OpenAi,
103 model,
104 });
105 }
106
107 if use_bedrock {
109 let registry = get_model_registry();
110 let model = model_override
111 .map(String::from)
112 .or_else(|| get_env_var("ANTHROPIC_MODEL").ok())
113 .unwrap_or_else(|| {
114 registry
115 .get_default_model("claude")
116 .unwrap_or("claude-sonnet-4-6")
117 .to_string()
118 });
119
120 get_env_var("ANTHROPIC_AUTH_TOKEN").map_err(|_| {
122 anyhow::anyhow!(
123 "AWS Bedrock authentication not configured.\n\
124 Set ANTHROPIC_AUTH_TOKEN environment variable."
125 )
126 })?;
127
128 get_env_var("ANTHROPIC_BEDROCK_BASE_URL").map_err(|_| {
129 anyhow::anyhow!(
130 "AWS Bedrock base URL not configured.\n\
131 Set ANTHROPIC_BEDROCK_BASE_URL environment variable."
132 )
133 })?;
134
135 return Ok(AiCredentialInfo {
136 provider: AiProvider::Bedrock,
137 model,
138 });
139 }
140
141 let registry = get_model_registry();
143 let model = model_override
144 .map(String::from)
145 .or_else(|| get_env_var("ANTHROPIC_MODEL").ok())
146 .unwrap_or_else(|| {
147 registry
148 .get_default_model("claude")
149 .unwrap_or("claude-sonnet-4-6")
150 .to_string()
151 });
152
153 get_env_vars(&[
155 "CLAUDE_API_KEY",
156 "ANTHROPIC_API_KEY",
157 "ANTHROPIC_AUTH_TOKEN",
158 ])
159 .map_err(|_| {
160 anyhow::anyhow!(
161 "Claude API key not found.\n\
162 Set one of these environment variables:\n\
163 - CLAUDE_API_KEY\n\
164 - ANTHROPIC_API_KEY\n\
165 - ANTHROPIC_AUTH_TOKEN"
166 )
167 })?;
168
169 Ok(AiCredentialInfo {
170 provider: AiProvider::Claude,
171 model,
172 })
173}
174
175pub fn check_github_cli() -> Result<()> {
183 let gh_check = std::process::Command::new("gh")
185 .args(["--version"])
186 .output();
187
188 match gh_check {
189 Ok(output) if output.status.success() => {
190 let repo_check = std::process::Command::new("gh")
192 .args(["repo", "view", "--json", "name"])
193 .output();
194
195 match repo_check {
196 Ok(repo_output) if repo_output.status.success() => Ok(()),
197 Ok(repo_output) => {
198 let error_details = String::from_utf8_lossy(&repo_output.stderr);
199 if error_details.contains("authentication") || error_details.contains("login") {
200 bail!(
201 "GitHub CLI authentication failed.\n\
202 Please run 'gh auth login' or set GITHUB_TOKEN environment variable."
203 )
204 }
205 bail!(
206 "GitHub CLI cannot access this repository.\n\
207 Error: {}",
208 error_details.trim()
209 )
210 }
211 Err(e) => bail!("Failed to test GitHub CLI access: {e}"),
212 }
213 }
214 _ => bail!(
215 "GitHub CLI (gh) is not installed or not in PATH.\n\
216 Please install it from https://cli.github.com/"
217 ),
218 }
219}
220
221pub fn check_git_repository() -> Result<()> {
226 crate::git::GitRepository::open().context(
227 "Not in a git repository. Please run this command from within a git repository.",
228 )?;
229 Ok(())
230}
231
232pub fn check_working_directory_clean() -> Result<()> {
242 let repo = crate::git::GitRepository::open().context("Failed to open git repository")?;
243
244 let status = repo
245 .get_working_directory_status()
246 .context("Failed to get working directory status")?;
247
248 if !status.clean {
249 let mut message = String::from("Working directory has uncommitted changes:\n");
250 for change in &status.untracked_changes {
251 message.push_str(&format!(" {} {}\n", change.status, change.file));
252 }
253 message.push_str("\nPlease commit or stash your changes before proceeding.");
254 bail!(message);
255 }
256
257 Ok(())
258}
259
260pub fn check_ai_command_prerequisites(model_override: Option<&str>) -> Result<AiCredentialInfo> {
268 check_git_repository()?;
269 check_ai_credentials(model_override)
270}
271
272pub fn check_pr_command_prerequisites(model_override: Option<&str>) -> Result<AiCredentialInfo> {
281 check_git_repository()?;
282 let ai_info = check_ai_credentials(model_override)?;
283 check_github_cli()?;
284 Ok(ai_info)
285}
286
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    use std::env;
    use std::sync::Mutex;
    use std::sync::OnceLock;

    // Serializes tests that mutate process-wide environment variables;
    // without this lock, parallel test threads would race on the env.
    static ENV_TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();

    // RAII guard that holds the env lock for its lifetime and records
    // every variable it touches so the original values can be restored
    // on drop.
    struct EnvGuard {
        // Held for the guard's whole lifetime to exclude other env tests.
        _lock: std::sync::MutexGuard<'static, ()>,
        // (key, original value) pairs, in the order they were modified.
        vars: Vec<(String, Option<String>)>,
    }

    impl EnvGuard {
        // Acquires the global env lock; panics if a previous test
        // poisoned the mutex.
        fn new() -> Self {
            let lock = ENV_TEST_LOCK.get_or_init(|| Mutex::new(())).lock().unwrap();
            Self {
                _lock: lock,
                vars: Vec::new(),
            }
        }

        // Sets `key` to `value`, remembering the prior value (if any)
        // for restoration on drop.
        fn set(&mut self, key: &str, value: &str) {
            let original = env::var(key).ok();
            self.vars.push((key.to_string(), original));
            env::set_var(key, value);
        }

        // Unsets `key`, remembering the prior value (if any) for
        // restoration on drop.
        fn remove(&mut self, key: &str) {
            let original = env::var(key).ok();
            self.vars.push((key.to_string(), original));
            env::remove_var(key);
        }
    }

    impl Drop for EnvGuard {
        // Restores variables in reverse order of modification so that
        // repeated touches of the same key unwind correctly.
        fn drop(&mut self) {
            for (key, original_value) in self.vars.drain(..).rev() {
                match original_value {
                    Some(value) => env::set_var(&key, value),
                    None => env::remove_var(&key),
                }
            }
        }
    }

    #[test]
    fn ai_provider_display() {
        assert_eq!(format!("{}", AiProvider::Claude), "Claude API");
        assert_eq!(format!("{}", AiProvider::Bedrock), "AWS Bedrock");
        assert_eq!(format!("{}", AiProvider::OpenAi), "OpenAI API");
        assert_eq!(format!("{}", AiProvider::Ollama), "Ollama");
    }

    #[test]
    fn ai_provider_equality() {
        assert_eq!(AiProvider::Claude, AiProvider::Claude);
        assert_ne!(AiProvider::Claude, AiProvider::OpenAi);
        assert_ne!(AiProvider::Bedrock, AiProvider::Ollama);
    }

    #[test]
    fn ai_provider_clone() {
        // AiProvider is Copy, so this is a bitwise copy, not Clone::clone.
        let provider = AiProvider::Bedrock;
        let cloned = provider;
        assert_eq!(provider, cloned);
    }

    #[test]
    fn ai_provider_debug() {
        let debug_str = format!("{:?}", AiProvider::Claude);
        assert_eq!(debug_str, "Claude");
    }

    #[test]
    fn ai_credential_info_debug() {
        let info = AiCredentialInfo {
            provider: AiProvider::Ollama,
            model: "llama2".to_string(),
        };
        let debug_str = format!("{info:?}");
        assert!(debug_str.contains("Ollama"));
        assert!(debug_str.contains("llama2"));
    }

    #[test]
    fn claude_default_model_from_registry() {
        // With no provider flags and no ANTHROPIC_MODEL, the model must
        // come from the registry (or the hard-coded fallback).
        let mut guard = EnvGuard::new();
        guard.remove("USE_OPENAI");
        guard.remove("USE_OLLAMA");
        guard.remove("CLAUDE_CODE_USE_BEDROCK");
        guard.remove("ANTHROPIC_MODEL");
        guard.set("ANTHROPIC_API_KEY", "sk-test-dummy");

        let info = check_ai_credentials(None).unwrap();
        assert_eq!(info.provider, AiProvider::Claude);
        assert_eq!(info.model, "claude-sonnet-4-6");
    }

    #[test]
    fn openai_default_model_from_registry() {
        // NOTE: expects "gpt-5-mini" from the registry default, which
        // differs from the in-code fallback "gpt-5".
        let mut guard = EnvGuard::new();
        guard.set("USE_OPENAI", "true");
        guard.remove("USE_OLLAMA");
        guard.remove("OPENAI_MODEL");
        guard.set("OPENAI_API_KEY", "sk-test-dummy");

        let info = check_ai_credentials(None).unwrap();
        assert_eq!(info.provider, AiProvider::OpenAi);
        assert_eq!(info.model, "gpt-5-mini");
    }

    #[test]
    fn bedrock_default_model_from_registry() {
        // Bedrock requires both the auth token and the base URL.
        let mut guard = EnvGuard::new();
        guard.remove("USE_OPENAI");
        guard.remove("USE_OLLAMA");
        guard.set("CLAUDE_CODE_USE_BEDROCK", "true");
        guard.remove("ANTHROPIC_MODEL");
        guard.set("ANTHROPIC_AUTH_TOKEN", "test-token");
        guard.set("ANTHROPIC_BEDROCK_BASE_URL", "https://bedrock.example.com");

        let info = check_ai_credentials(None).unwrap();
        assert_eq!(info.provider, AiProvider::Bedrock);
        assert_eq!(info.model, "claude-sonnet-4-6");
    }

    #[test]
    fn model_override_takes_precedence() {
        // The CLI override must win over env vars and registry defaults.
        let mut guard = EnvGuard::new();
        guard.remove("USE_OPENAI");
        guard.remove("USE_OLLAMA");
        guard.remove("CLAUDE_CODE_USE_BEDROCK");
        guard.remove("ANTHROPIC_MODEL");
        guard.set("ANTHROPIC_API_KEY", "sk-test-dummy");

        let info = check_ai_credentials(Some("claude-opus-4-6")).unwrap();
        assert_eq!(info.model, "claude-opus-4-6");
    }
}