Skip to main content

omni_dev/utils/
preflight.rs

//! Preflight validation checks for early failure detection.
//!
//! This module provides functions to validate required services and credentials
//! before starting expensive operations. Commands should call these checks early
//! to fail fast with clear error messages.
6
7use anyhow::{bail, Context, Result};
8
9use crate::claude::model_config::get_model_registry;
10
/// Provider and model selected by a successful AI credential check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AiCredentialInfo {
    /// The AI provider that will be used.
    pub provider: AiProvider,
    /// The model that will be requested from that provider.
    pub model: String,
}

/// AI backends recognised by the preflight checks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AiProvider {
    /// Anthropic Claude API.
    Claude,
    /// AWS Bedrock with Claude.
    Bedrock,
    /// OpenAI API.
    OpenAi,
    /// Local Ollama.
    Ollama,
}

impl std::fmt::Display for AiProvider {
    /// Writes the human-readable provider name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` (unlike `write!` with a literal) honors width, fill and
        // alignment flags, so e.g. `{:<12}` lines providers up in tables.
        f.pad(match self {
            Self::Claude => "Claude API",
            Self::Bedrock => "AWS Bedrock",
            Self::OpenAi => "OpenAI API",
            Self::Ollama => "Ollama",
        })
    }
}
43
44/// Validates that AI credentials are available before processing.
45///
46/// This performs a lightweight check of environment variables without
47/// creating a full AI client. Use this at the start of commands that
48/// require AI to fail fast if credentials are missing.
49pub fn check_ai_credentials(model_override: Option<&str>) -> Result<AiCredentialInfo> {
50    use crate::utils::settings::{get_env_var, get_env_vars};
51
52    // Check provider selection flags
53    let use_openai = get_env_var("USE_OPENAI")
54        .map(|val| val == "true")
55        .unwrap_or(false);
56
57    let use_ollama = get_env_var("USE_OLLAMA")
58        .map(|val| val == "true")
59        .unwrap_or(false);
60
61    let use_bedrock = get_env_var("CLAUDE_CODE_USE_BEDROCK")
62        .map(|val| val == "true")
63        .unwrap_or(false);
64
65    // Check Ollama (no credentials required, just model)
66    if use_ollama {
67        let model = model_override
68            .map(String::from)
69            .or_else(|| get_env_var("OLLAMA_MODEL").ok())
70            .unwrap_or_else(|| "llama2".to_string());
71
72        return Ok(AiCredentialInfo {
73            provider: AiProvider::Ollama,
74            model,
75        });
76    }
77
78    // Check OpenAI
79    if use_openai {
80        let registry = get_model_registry();
81        let model = model_override
82            .map(String::from)
83            .or_else(|| get_env_var("OPENAI_MODEL").ok())
84            .unwrap_or_else(|| {
85                registry
86                    .get_default_model("openai")
87                    .unwrap_or("gpt-5")
88                    .to_string()
89            });
90
91        // Verify API key exists
92        get_env_vars(&["OPENAI_API_KEY", "OPENAI_AUTH_TOKEN"]).map_err(|_| {
93            anyhow::anyhow!(
94                "OpenAI API key not found.\n\
95                 Set one of these environment variables:\n\
96                 - OPENAI_API_KEY\n\
97                 - OPENAI_AUTH_TOKEN"
98            )
99        })?;
100
101        return Ok(AiCredentialInfo {
102            provider: AiProvider::OpenAi,
103            model,
104        });
105    }
106
107    // Check Bedrock
108    if use_bedrock {
109        let registry = get_model_registry();
110        let model = model_override
111            .map(String::from)
112            .or_else(|| get_env_var("ANTHROPIC_MODEL").ok())
113            .unwrap_or_else(|| {
114                registry
115                    .get_default_model("claude")
116                    .unwrap_or("claude-sonnet-4-6")
117                    .to_string()
118            });
119
120        // Verify Bedrock configuration
121        get_env_var("ANTHROPIC_AUTH_TOKEN").map_err(|_| {
122            anyhow::anyhow!(
123                "AWS Bedrock authentication not configured.\n\
124                 Set ANTHROPIC_AUTH_TOKEN environment variable."
125            )
126        })?;
127
128        get_env_var("ANTHROPIC_BEDROCK_BASE_URL").map_err(|_| {
129            anyhow::anyhow!(
130                "AWS Bedrock base URL not configured.\n\
131                 Set ANTHROPIC_BEDROCK_BASE_URL environment variable."
132            )
133        })?;
134
135        return Ok(AiCredentialInfo {
136            provider: AiProvider::Bedrock,
137            model,
138        });
139    }
140
141    // Default: Claude API
142    let registry = get_model_registry();
143    let model = model_override
144        .map(String::from)
145        .or_else(|| get_env_var("ANTHROPIC_MODEL").ok())
146        .unwrap_or_else(|| {
147            registry
148                .get_default_model("claude")
149                .unwrap_or("claude-sonnet-4-6")
150                .to_string()
151        });
152
153    // Verify API key exists
154    get_env_vars(&[
155        "CLAUDE_API_KEY",
156        "ANTHROPIC_API_KEY",
157        "ANTHROPIC_AUTH_TOKEN",
158    ])
159    .map_err(|_| {
160        anyhow::anyhow!(
161            "Claude API key not found.\n\
162                 Set one of these environment variables:\n\
163                 - CLAUDE_API_KEY\n\
164                 - ANTHROPIC_API_KEY\n\
165                 - ANTHROPIC_AUTH_TOKEN"
166        )
167    })?;
168
169    Ok(AiCredentialInfo {
170        provider: AiProvider::Claude,
171        model,
172    })
173}
174
175/// Validates that GitHub CLI is available and authenticated.
176///
177/// This checks:
178/// 1. `gh` CLI is installed and in PATH
179/// 2. User is authenticated (can access the current repo)
180///
181/// Use this at the start of commands that require GitHub API access.
182pub fn check_github_cli() -> Result<()> {
183    // Check if gh CLI is available
184    let gh_check = std::process::Command::new("gh")
185        .args(["--version"])
186        .output();
187
188    match gh_check {
189        Ok(output) if output.status.success() => {
190            // Test if gh can access the current repo
191            let repo_check = std::process::Command::new("gh")
192                .args(["repo", "view", "--json", "name"])
193                .output();
194
195            match repo_check {
196                Ok(repo_output) if repo_output.status.success() => Ok(()),
197                Ok(repo_output) => {
198                    let error_details = String::from_utf8_lossy(&repo_output.stderr);
199                    if error_details.contains("authentication") || error_details.contains("login") {
200                        bail!(
201                            "GitHub CLI authentication failed.\n\
202                             Please run 'gh auth login' or set GITHUB_TOKEN environment variable."
203                        )
204                    }
205                    bail!(
206                        "GitHub CLI cannot access this repository.\n\
207                         Error: {}",
208                        error_details.trim()
209                    )
210                }
211                Err(e) => bail!("Failed to test GitHub CLI access: {e}"),
212            }
213        }
214        _ => bail!(
215            "GitHub CLI (gh) is not installed or not in PATH.\n\
216             Please install it from https://cli.github.com/"
217        ),
218    }
219}
220
221/// Validates that the current directory is in a valid git repository.
222///
223/// This is a lightweight check that opens the repository without
224/// loading any commit data.
225pub fn check_git_repository() -> Result<()> {
226    crate::git::GitRepository::open().context(
227        "Not in a git repository. Please run this command from within a git repository.",
228    )?;
229    Ok(())
230}
231
232/// Validates that the working directory is clean (no uncommitted changes).
233///
234/// This checks for:
235/// - Staged changes
236/// - Unstaged modifications
237/// - Untracked files (excluding ignored files)
238///
239/// Use this before operations that require a clean working directory,
240/// like amending commits.
241pub fn check_working_directory_clean() -> Result<()> {
242    let repo = crate::git::GitRepository::open().context("Failed to open git repository")?;
243
244    let status = repo
245        .get_working_directory_status()
246        .context("Failed to get working directory status")?;
247
248    if !status.clean {
249        let mut message = String::from("Working directory has uncommitted changes:\n");
250        for change in &status.untracked_changes {
251            message.push_str(&format!("  {} {}\n", change.status, change.file));
252        }
253        message.push_str("\nPlease commit or stash your changes before proceeding.");
254        bail!(message);
255    }
256
257    Ok(())
258}
259
260/// Performs combined preflight check for AI commands.
261///
262/// Validates:
263/// - Git repository access
264/// - AI credentials
265///
266/// Returns information about the AI provider that will be used.
267pub fn check_ai_command_prerequisites(model_override: Option<&str>) -> Result<AiCredentialInfo> {
268    check_git_repository()?;
269    check_ai_credentials(model_override)
270}
271
272/// Performs combined preflight check for PR creation.
273///
274/// Validates:
275/// - Git repository access
276/// - AI credentials
277/// - GitHub CLI availability and authentication
278///
279/// Returns information about the AI provider that will be used.
280pub fn check_pr_command_prerequisites(model_override: Option<&str>) -> Result<AiCredentialInfo> {
281    check_git_repository()?;
282    let ai_info = check_ai_credentials(model_override)?;
283    check_github_cli()?;
284    Ok(ai_info)
285}
286
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    use std::env;
    use std::sync::{Mutex, MutexGuard, OnceLock, PoisonError};

    /// Global lock so tests that mutate process environment variables run one at a time.
    static ENV_TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();

    /// Temporarily overrides environment variables for a test.
    ///
    /// Holds [`ENV_TEST_LOCK`] for its whole lifetime. On drop it restores every
    /// touched variable to its previous value, in reverse order, so a key touched
    /// more than once ends up with its original value.
    struct EnvGuard {
        _lock: MutexGuard<'static, ()>,
        vars: Vec<(String, Option<String>)>,
    }

    impl EnvGuard {
        fn new() -> Self {
            // The mutex protects no data, so a lock poisoned by a test that
            // panicked while holding it is safe to reuse. Unwrapping instead
            // would make every later env test fail with a PoisonError.
            let lock = ENV_TEST_LOCK
                .get_or_init(|| Mutex::new(()))
                .lock()
                .unwrap_or_else(PoisonError::into_inner);
            Self {
                _lock: lock,
                vars: Vec::new(),
            }
        }

        /// Records the current value of `key` so `Drop` can restore it.
        fn remember(&mut self, key: &str) {
            self.vars.push((key.to_string(), env::var(key).ok()));
        }

        fn set(&mut self, key: &str, value: &str) {
            self.remember(key);
            env::set_var(key, value);
        }

        fn remove(&mut self, key: &str) {
            self.remember(key);
            env::remove_var(key);
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for (key, original_value) in self.vars.drain(..).rev() {
                match original_value {
                    Some(value) => env::set_var(&key, value),
                    None => env::remove_var(&key),
                }
            }
        }
    }

    #[test]
    fn ai_provider_display() {
        assert_eq!(format!("{}", AiProvider::Claude), "Claude API");
        assert_eq!(format!("{}", AiProvider::Bedrock), "AWS Bedrock");
        assert_eq!(format!("{}", AiProvider::OpenAi), "OpenAI API");
        assert_eq!(format!("{}", AiProvider::Ollama), "Ollama");
    }

    #[test]
    fn ai_provider_equality() {
        assert_eq!(AiProvider::Claude, AiProvider::Claude);
        assert_ne!(AiProvider::Claude, AiProvider::OpenAi);
        assert_ne!(AiProvider::Bedrock, AiProvider::Ollama);
    }

    #[test]
    fn ai_provider_clone() {
        let provider = AiProvider::Bedrock;
        let cloned = provider;
        assert_eq!(provider, cloned);
    }

    #[test]
    fn ai_provider_debug() {
        let debug_str = format!("{:?}", AiProvider::Claude);
        assert_eq!(debug_str, "Claude");
    }

    #[test]
    fn ai_credential_info_debug() {
        let info = AiCredentialInfo {
            provider: AiProvider::Ollama,
            model: "llama2".to_string(),
        };
        let debug_str = format!("{info:?}");
        assert!(debug_str.contains("Ollama"));
        assert!(debug_str.contains("llama2"));
    }

    #[test]
    fn env_guard_survives_poisoned_lock() {
        // Poison ENV_TEST_LOCK by panicking while a guard is alive.
        let outcome = std::panic::catch_unwind(|| {
            let _guard = EnvGuard::new();
            panic!("intentional panic to poison ENV_TEST_LOCK");
        });
        assert!(outcome.is_err());

        // A new guard must still be obtainable and usable.
        let mut guard = EnvGuard::new();
        guard.set("OMNI_DEV_PREFLIGHT_TEST_VAR", "1");
        assert_eq!(env::var("OMNI_DEV_PREFLIGHT_TEST_VAR").unwrap(), "1");
    }

    #[test]
    fn ollama_default_model_and_precedence() {
        let mut guard = EnvGuard::new();
        // Ollama wins even when OpenAI is also enabled, and needs no API key.
        guard.set("USE_OLLAMA", "true");
        guard.set("USE_OPENAI", "true");
        guard.remove("OLLAMA_MODEL");

        let info = check_ai_credentials(None).unwrap();
        assert_eq!(info.provider, AiProvider::Ollama);
        assert_eq!(info.model, "llama2");
    }

    #[test]
    fn claude_default_model_from_registry() {
        let mut guard = EnvGuard::new();
        // Enable Claude API path with a dummy key, no model override
        guard.remove("USE_OPENAI");
        guard.remove("USE_OLLAMA");
        guard.remove("CLAUDE_CODE_USE_BEDROCK");
        guard.remove("ANTHROPIC_MODEL");
        guard.set("ANTHROPIC_API_KEY", "sk-test-dummy");

        let info = check_ai_credentials(None).unwrap();
        assert_eq!(info.provider, AiProvider::Claude);
        assert_eq!(info.model, "claude-sonnet-4-6");
    }

    #[test]
    fn openai_default_model_from_registry() {
        let mut guard = EnvGuard::new();
        guard.set("USE_OPENAI", "true");
        guard.remove("USE_OLLAMA");
        guard.remove("OPENAI_MODEL");
        guard.set("OPENAI_API_KEY", "sk-test-dummy");

        let info = check_ai_credentials(None).unwrap();
        assert_eq!(info.provider, AiProvider::OpenAi);
        assert_eq!(info.model, "gpt-5-mini");
    }

    #[test]
    fn bedrock_default_model_from_registry() {
        let mut guard = EnvGuard::new();
        guard.remove("USE_OPENAI");
        guard.remove("USE_OLLAMA");
        guard.set("CLAUDE_CODE_USE_BEDROCK", "true");
        guard.remove("ANTHROPIC_MODEL");
        guard.set("ANTHROPIC_AUTH_TOKEN", "test-token");
        guard.set("ANTHROPIC_BEDROCK_BASE_URL", "https://bedrock.example.com");

        let info = check_ai_credentials(None).unwrap();
        assert_eq!(info.provider, AiProvider::Bedrock);
        assert_eq!(info.model, "claude-sonnet-4-6");
    }

    #[test]
    fn model_override_takes_precedence() {
        let mut guard = EnvGuard::new();
        guard.remove("USE_OPENAI");
        guard.remove("USE_OLLAMA");
        guard.remove("CLAUDE_CODE_USE_BEDROCK");
        guard.remove("ANTHROPIC_MODEL");
        guard.set("ANTHROPIC_API_KEY", "sk-test-dummy");

        let info = check_ai_credentials(Some("claude-opus-4-6")).unwrap();
        assert_eq!(info.model, "claude-opus-4-6");
    }
}
432}