Skip to main content

omni_dev/utils/
preflight.rs

//! Preflight validation checks for early failure detection.
//!
//! This module provides functions that validate required services and
//! credentials before expensive operations start. Commands should call these
//! checks early so they fail fast with clear error messages.
6
7use anyhow::{bail, Context, Result};
8
9use crate::claude::model_config::get_model_registry;
10
11/// Result of AI credential validation.
12#[derive(Debug)]
13pub struct AiCredentialInfo {
14    /// The AI provider that will be used.
15    pub provider: AiProvider,
16    /// The model that will be used.
17    pub model: String,
18}
19
/// The backend that will service AI requests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AiProvider {
    /// Anthropic's hosted Claude API.
    Claude,
    /// Claude models accessed through AWS Bedrock.
    Bedrock,
    /// OpenAI's hosted API.
    OpenAi,
    /// A local Ollama instance.
    Ollama,
}

impl std::fmt::Display for AiProvider {
    /// Writes the human-readable provider name (e.g. `"Claude API"`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::Claude => "Claude API",
            Self::Bedrock => "AWS Bedrock",
            Self::OpenAi => "OpenAI API",
            Self::Ollama => "Ollama",
        };
        f.write_str(label)
    }
}
43
44/// Validates that AI credentials are available before processing.
45///
46/// This performs a lightweight check of environment variables without
47/// creating a full AI client. Use this at the start of commands that
48/// require AI to fail fast if credentials are missing.
49pub fn check_ai_credentials(model_override: Option<&str>) -> Result<AiCredentialInfo> {
50    use crate::utils::settings::{get_env_var, get_env_vars};
51
52    // Check provider selection flags
53    let use_openai = get_env_var("USE_OPENAI").is_ok_and(|val| val == "true");
54
55    let use_ollama = get_env_var("USE_OLLAMA").is_ok_and(|val| val == "true");
56
57    let use_bedrock = get_env_var("CLAUDE_CODE_USE_BEDROCK").is_ok_and(|val| val == "true");
58
59    // Check Ollama (no credentials required, just model)
60    if use_ollama {
61        let model = model_override
62            .map(String::from)
63            .or_else(|| get_env_var("OLLAMA_MODEL").ok())
64            .unwrap_or_else(|| "llama2".to_string());
65
66        return Ok(AiCredentialInfo {
67            provider: AiProvider::Ollama,
68            model,
69        });
70    }
71
72    // Check OpenAI
73    if use_openai {
74        let registry = get_model_registry();
75        let model = model_override
76            .map(String::from)
77            .or_else(|| get_env_var("OPENAI_MODEL").ok())
78            .unwrap_or_else(|| {
79                registry
80                    .get_default_model("openai")
81                    .unwrap_or("gpt-5")
82                    .to_string()
83            });
84
85        // Verify API key exists
86        get_env_vars(&["OPENAI_API_KEY", "OPENAI_AUTH_TOKEN"]).map_err(|_| {
87            anyhow::anyhow!(
88                "OpenAI API key not found.\n\
89                 Set one of these environment variables:\n\
90                 - OPENAI_API_KEY\n\
91                 - OPENAI_AUTH_TOKEN"
92            )
93        })?;
94
95        return Ok(AiCredentialInfo {
96            provider: AiProvider::OpenAi,
97            model,
98        });
99    }
100
101    // Check Bedrock
102    if use_bedrock {
103        let registry = get_model_registry();
104        let model = model_override
105            .map(String::from)
106            .or_else(|| get_env_var("ANTHROPIC_MODEL").ok())
107            .unwrap_or_else(|| {
108                registry
109                    .get_default_model("claude")
110                    .unwrap_or("claude-sonnet-4-6")
111                    .to_string()
112            });
113
114        // Verify Bedrock configuration
115        get_env_var("ANTHROPIC_AUTH_TOKEN").map_err(|_| {
116            anyhow::anyhow!(
117                "AWS Bedrock authentication not configured.\n\
118                 Set ANTHROPIC_AUTH_TOKEN environment variable."
119            )
120        })?;
121
122        get_env_var("ANTHROPIC_BEDROCK_BASE_URL").map_err(|_| {
123            anyhow::anyhow!(
124                "AWS Bedrock base URL not configured.\n\
125                 Set ANTHROPIC_BEDROCK_BASE_URL environment variable."
126            )
127        })?;
128
129        return Ok(AiCredentialInfo {
130            provider: AiProvider::Bedrock,
131            model,
132        });
133    }
134
135    // Default: Claude API
136    let registry = get_model_registry();
137    let model = model_override
138        .map(String::from)
139        .or_else(|| get_env_var("ANTHROPIC_MODEL").ok())
140        .unwrap_or_else(|| {
141            registry
142                .get_default_model("claude")
143                .unwrap_or("claude-sonnet-4-6")
144                .to_string()
145        });
146
147    // Verify API key exists
148    get_env_vars(&[
149        "CLAUDE_API_KEY",
150        "ANTHROPIC_API_KEY",
151        "ANTHROPIC_AUTH_TOKEN",
152    ])
153    .map_err(|_| {
154        anyhow::anyhow!(
155            "Claude API key not found.\n\
156                 Set one of these environment variables:\n\
157                 - CLAUDE_API_KEY\n\
158                 - ANTHROPIC_API_KEY\n\
159                 - ANTHROPIC_AUTH_TOKEN"
160        )
161    })?;
162
163    Ok(AiCredentialInfo {
164        provider: AiProvider::Claude,
165        model,
166    })
167}
168
169/// Validates that GitHub CLI is available and authenticated.
170///
171/// This checks:
172/// 1. `gh` CLI is installed and in PATH
173/// 2. User is authenticated (can access the current repo)
174///
175/// Use this at the start of commands that require GitHub API access.
176pub fn check_github_cli() -> Result<()> {
177    // Check if gh CLI is available
178    let gh_check = std::process::Command::new("gh")
179        .args(["--version"])
180        .output();
181
182    match gh_check {
183        Ok(output) if output.status.success() => {
184            // Test if gh can access the current repo
185            let repo_check = std::process::Command::new("gh")
186                .args(["repo", "view", "--json", "name"])
187                .output();
188
189            match repo_check {
190                Ok(repo_output) if repo_output.status.success() => Ok(()),
191                Ok(repo_output) => {
192                    let error_details = String::from_utf8_lossy(&repo_output.stderr);
193                    if error_details.contains("authentication") || error_details.contains("login") {
194                        bail!(
195                            "GitHub CLI authentication failed.\n\
196                             Please run 'gh auth login' or set GITHUB_TOKEN environment variable."
197                        )
198                    }
199                    bail!(
200                        "GitHub CLI cannot access this repository.\n\
201                         Error: {}",
202                        error_details.trim()
203                    )
204                }
205                Err(e) => bail!("Failed to test GitHub CLI access: {e}"),
206            }
207        }
208        _ => bail!(
209            "GitHub CLI (gh) is not installed or not in PATH.\n\
210             Please install it from https://cli.github.com/"
211        ),
212    }
213}
214
215/// Validates that the current directory is in a valid git repository.
216///
217/// This is a lightweight check that opens the repository without
218/// loading any commit data.
219pub fn check_git_repository() -> Result<()> {
220    crate::git::GitRepository::open().context(
221        "Not in a git repository. Please run this command from within a git repository.",
222    )?;
223    Ok(())
224}
225
226/// Validates that the working directory is clean (no uncommitted changes).
227///
228/// This checks for:
229/// - Staged changes
230/// - Unstaged modifications
231/// - Untracked files (excluding ignored files)
232///
233/// Use this before operations that require a clean working directory,
234/// like amending commits.
235pub fn check_working_directory_clean() -> Result<()> {
236    let repo = crate::git::GitRepository::open().context("Failed to open git repository")?;
237
238    let status = repo
239        .get_working_directory_status()
240        .context("Failed to get working directory status")?;
241
242    if !status.clean {
243        let mut message = String::from("Working directory has uncommitted changes:\n");
244        for change in &status.untracked_changes {
245            message.push_str(&format!("  {} {}\n", change.status, change.file));
246        }
247        message.push_str("\nPlease commit or stash your changes before proceeding.");
248        bail!(message);
249    }
250
251    Ok(())
252}
253
254/// Performs combined preflight check for AI commands.
255///
256/// Validates:
257/// - Git repository access
258/// - AI credentials
259///
260/// Returns information about the AI provider that will be used.
261pub fn check_ai_command_prerequisites(model_override: Option<&str>) -> Result<AiCredentialInfo> {
262    check_git_repository()?;
263    check_ai_credentials(model_override)
264}
265
266/// Performs combined preflight check for PR creation.
267///
268/// Validates:
269/// - Git repository access
270/// - AI credentials
271/// - GitHub CLI availability and authentication
272///
273/// Returns information about the AI provider that will be used.
274pub fn check_pr_command_prerequisites(model_override: Option<&str>) -> Result<AiCredentialInfo> {
275    check_git_repository()?;
276    let ai_info = check_ai_credentials(model_override)?;
277    check_github_cli()?;
278    Ok(ai_info)
279}
280
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    use std::env;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard, OnceLock, PoisonError};

    /// Global lock to ensure environment variable tests don't interfere with each other.
    static ENV_TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();

    /// Serializes environment access and restores every touched variable on drop.
    struct EnvGuard {
        _lock: MutexGuard<'static, ()>,
        vars: Vec<(String, Option<OsString>)>,
    }

    impl EnvGuard {
        fn new() -> Self {
            // The mutex guards no data, only process-global env state. If a test
            // panics, `Drop` below still restores its variables while unwinding.
            // So a poisoned lock is safe to reuse, and recovering it keeps one
            // failing test from cascading into every later env test.
            let lock = ENV_TEST_LOCK
                .get_or_init(|| Mutex::new(()))
                .lock()
                .unwrap_or_else(PoisonError::into_inner);
            Self {
                _lock: lock,
                vars: Vec::new(),
            }
        }

        fn set(&mut self, key: &str, value: &str) {
            // `var_os` keeps non-UTF-8 originals so they are restored, not deleted.
            let original = env::var_os(key);
            self.vars.push((key.to_string(), original));
            env::set_var(key, value);
        }

        fn remove(&mut self, key: &str) {
            let original = env::var_os(key);
            self.vars.push((key.to_string(), original));
            env::remove_var(key);
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Restore in reverse so a key touched twice ends at its first original value.
            for (key, original_value) in self.vars.drain(..).rev() {
                match original_value {
                    Some(value) => env::set_var(&key, value),
                    None => env::remove_var(&key),
                }
            }
        }
    }

    #[test]
    fn ai_provider_display() {
        assert_eq!(format!("{}", AiProvider::Claude), "Claude API");
        assert_eq!(format!("{}", AiProvider::Bedrock), "AWS Bedrock");
        assert_eq!(format!("{}", AiProvider::OpenAi), "OpenAI API");
        assert_eq!(format!("{}", AiProvider::Ollama), "Ollama");
    }

    #[test]
    fn ai_provider_equality() {
        assert_eq!(AiProvider::Claude, AiProvider::Claude);
        assert_ne!(AiProvider::Claude, AiProvider::OpenAi);
        assert_ne!(AiProvider::Bedrock, AiProvider::Ollama);
    }

    #[test]
    fn ai_provider_clone() {
        let provider = AiProvider::Bedrock;
        let cloned = provider;
        assert_eq!(provider, cloned);
    }

    #[test]
    fn ai_provider_debug() {
        let debug_str = format!("{:?}", AiProvider::Claude);
        assert_eq!(debug_str, "Claude");
    }

    #[test]
    fn ai_credential_info_debug() {
        let info = AiCredentialInfo {
            provider: AiProvider::Ollama,
            model: "llama2".to_string(),
        };
        let debug_str = format!("{info:?}");
        assert!(debug_str.contains("Ollama"));
        assert!(debug_str.contains("llama2"));
    }

    #[test]
    fn claude_default_model_from_registry() {
        let mut guard = EnvGuard::new();
        // Enable Claude API path with a dummy key, no model override
        guard.remove("USE_OPENAI");
        guard.remove("USE_OLLAMA");
        guard.remove("CLAUDE_CODE_USE_BEDROCK");
        guard.remove("ANTHROPIC_MODEL");
        guard.set("ANTHROPIC_API_KEY", "sk-test-dummy");

        let info = check_ai_credentials(None).unwrap();
        assert_eq!(info.provider, AiProvider::Claude);
        assert_eq!(info.model, "claude-sonnet-4-6");
    }

    #[test]
    fn openai_default_model_from_registry() {
        let mut guard = EnvGuard::new();
        guard.set("USE_OPENAI", "true");
        guard.remove("USE_OLLAMA");
        guard.remove("OPENAI_MODEL");
        guard.set("OPENAI_API_KEY", "sk-test-dummy");

        let info = check_ai_credentials(None).unwrap();
        assert_eq!(info.provider, AiProvider::OpenAi);
        assert_eq!(info.model, "gpt-5-mini");
    }

    #[test]
    fn bedrock_default_model_from_registry() {
        let mut guard = EnvGuard::new();
        guard.remove("USE_OPENAI");
        guard.remove("USE_OLLAMA");
        guard.set("CLAUDE_CODE_USE_BEDROCK", "true");
        guard.remove("ANTHROPIC_MODEL");
        guard.set("ANTHROPIC_AUTH_TOKEN", "test-token");
        guard.set("ANTHROPIC_BEDROCK_BASE_URL", "https://bedrock.example.com");

        let info = check_ai_credentials(None).unwrap();
        assert_eq!(info.provider, AiProvider::Bedrock);
        assert_eq!(info.model, "claude-sonnet-4-6");
    }

    #[test]
    fn model_override_takes_precedence() {
        let mut guard = EnvGuard::new();
        guard.remove("USE_OPENAI");
        guard.remove("USE_OLLAMA");
        guard.remove("CLAUDE_CODE_USE_BEDROCK");
        guard.remove("ANTHROPIC_MODEL");
        guard.set("ANTHROPIC_API_KEY", "sk-test-dummy");

        let info = check_ai_credentials(Some("claude-opus-4-6")).unwrap();
        assert_eq!(info.model, "claude-opus-4-6");
    }

    #[test]
    fn ollama_model_override_takes_precedence() {
        let mut guard = EnvGuard::new();
        guard.set("USE_OLLAMA", "true");
        guard.set("OLLAMA_MODEL", "env-model");

        let info = check_ai_credentials(Some("mistral")).unwrap();
        assert_eq!(info.provider, AiProvider::Ollama);
        assert_eq!(info.model, "mistral");
    }

    #[test]
    fn ollama_model_from_env() {
        let mut guard = EnvGuard::new();
        guard.set("USE_OLLAMA", "true");
        guard.set("OLLAMA_MODEL", "codellama");

        let info = check_ai_credentials(None).unwrap();
        assert_eq!(info.provider, AiProvider::Ollama);
        assert_eq!(info.model, "codellama");
    }
}