Skip to main content

nika_cli/
verbs.rs

//! Direct verb execution from CLI — no YAML needed.
//!
//! Handles `nika infer`, `nika fetch`, `nika invoke`, `nika agent` commands.
//! Each verb creates a one-shot TaskExecutor and dispatches a single TaskAction.
5
6use std::io::{IsTerminal, Read};
7use std::sync::Arc;
8use std::time::Instant;
9
10use colored::Colorize;
11
12use nika_engine::ast::output::{OutputFormat, OutputPolicy};
13use nika_engine::ast::{
14    AgentParams, FetchParams, InferParams, InvokeParams, ResponseFormat, TaskAction,
15};
16use nika_engine::binding::ResolvedBindings;
17use nika_engine::error::NikaError;
18use nika_engine::event::EventLog;
19use nika_engine::runtime::TaskExecutor;
20use nika_engine::store::RunContext;
21
// ═══════════════════════════════════════════════════════════════════════════
// PROVIDER AUTO-DETECTION
// ═══════════════════════════════════════════════════════════════════════════
25
26/// Auto-detect provider from environment variables (priority order).
27///
28/// Uses the canonical provider catalog from `nika_core::catalogs::providers`
29/// instead of hardcoding env var mappings.
30pub fn detect_provider() -> Option<String> {
31    use nika_engine::core::providers::{providers_by_category, ProviderCategory};
32    for provider in providers_by_category(ProviderCategory::Llm) {
33        if provider.has_env_key() {
34            return Some(provider.id.to_string());
35        }
36    }
37    None
38}
39
/// Well-known default model name for an LLM provider id (display only).
///
/// Used so the CLI header can show a concrete model name instead of
/// "(default)" when `--model` was not given. Unknown provider ids map to
/// the literal `"default"`.
fn default_model_for_provider(provider: &str) -> &'static str {
    // (provider id, default model) pairs; lookup is a simple linear scan.
    const DEFAULT_MODELS: [(&str, &str); 7] = [
        ("anthropic", "claude-sonnet-4-6"),
        ("openai", "gpt-4o"),
        ("mistral", "mistral-large-latest"),
        ("groq", "llama-3.3-70b-versatile"),
        ("deepseek", "deepseek-chat"),
        ("gemini", "gemini-2.0-flash"),
        ("xai", "grok-3-fast"),
    ];
    DEFAULT_MODELS
        .iter()
        .find(|(id, _)| *id == provider)
        .map_or("default", |&(_, model)| model)
}
56
57/// Parse composite model identifier: "anthropic/claude-sonnet" → (Some("anthropic"), "claude-sonnet")
58pub fn parse_composite_model(model: &str) -> Result<(Option<&str>, &str), NikaError> {
59    match model.split_once('/') {
60        Some((provider, model_name)) => {
61            if provider.is_empty() || model_name.is_empty() || model_name.contains('/') {
62                Err(NikaError::ValidationError {
63                    reason: format!(
64                        "Invalid model format '{}'. Expected 'provider/model' or 'model'",
65                        model
66                    ),
67                })
68            } else {
69                Ok((Some(provider), model_name))
70            }
71        }
72        None => Ok((None, model)),
73    }
74}
75
76/// Resolve provider + model from flags with auto-detection.
77pub(crate) fn resolve_provider_model(
78    provider_flag: Option<&str>,
79    model_flag: Option<&str>,
80) -> Result<(String, Option<String>), NikaError> {
81    // If model has composite syntax (provider/model), extract both
82    if let Some(model) = model_flag {
83        let (composite_provider, model_name) = parse_composite_model(model)?;
84        if let Some(cp) = composite_provider {
85            return Ok((cp.to_string(), Some(model_name.to_string())));
86        }
87        // Model specified without provider — use provider flag or auto-detect
88        let provider = provider_flag
89            .map(|s| s.to_string())
90            .or_else(detect_provider)
91            .ok_or_else(|| NikaError::ValidationError {
92                reason: "No provider configured. Set an API key env var or use -p <provider>\n\
93                         Fix: export ANTHROPIC_API_KEY=sk-ant-..."
94                    .to_string(),
95            })?;
96        return Ok((provider, Some(model_name.to_string())));
97    }
98
99    // No model specified — use provider flag or auto-detect
100    let provider = provider_flag
101        .map(|s| s.to_string())
102        .or_else(detect_provider)
103        .ok_or_else(|| NikaError::ValidationError {
104            reason: "No provider configured. Set an API key env var or use -p <provider>\n\
105                     Fix: export ANTHROPIC_API_KEY=sk-ant-..."
106                .to_string(),
107        })?;
108    Ok((provider, None))
109}
110
111/// Create a one-shot TaskExecutor.
112async fn one_shot_executor(
113    provider: &str,
114    model: Option<&str>,
115) -> Result<(TaskExecutor, EventLog), NikaError> {
116    let event_log = EventLog::new();
117    let executor = TaskExecutor::new(provider, model, None, event_log.clone())?;
118    Ok((executor, event_log))
119}
120
121/// Read stdin content (spawn_blocking + 10MB limit to prevent OOM).
122async fn read_stdin_content() -> Result<String, NikaError> {
123    const MAX_STDIN_SIZE: u64 = 10 * 1024 * 1024; // 10 MB
124    tokio::task::spawn_blocking(|| {
125        let mut buf = String::new();
126        std::io::stdin()
127            .take(MAX_STDIN_SIZE)
128            .read_to_string(&mut buf)
129            .map_err(|e| NikaError::ParseError {
130                details: format!("Failed to read stdin: {}", e),
131            })?;
132        Ok(buf)
133    })
134    .await
135    .map_err(|e| NikaError::ParseError {
136        details: format!("stdin reader panicked: {}", e),
137    })?
138}
139
// ═══════════════════════════════════════════════════════════════════════════
// VISUAL OUTPUT
// ═══════════════════════════════════════════════════════════════════════════
143
144fn print_header(label: &str, is_tty: bool) {
145    if is_tty {
146        eprintln!("\n  {} {}", "┌─".dimmed(), label.cyan());
147    }
148}
149
150fn print_footer(elapsed: std::time::Duration, extra: &str, is_tty: bool) {
151    if is_tty {
152        eprintln!(
153            "  {} {}",
154            "└─".dimmed(),
155            format!("{}ms{}", elapsed.as_millis(), extra).dimmed()
156        );
157        eprintln!();
158    }
159}
160
// ═══════════════════════════════════════════════════════════════════════════
// VERB HANDLERS
// ═══════════════════════════════════════════════════════════════════════════
164
165/// Handle `nika infer "prompt"` — one-shot LLM call.
166#[allow(clippy::too_many_arguments)]
167pub async fn handle_infer(
168    prompt: String,
169    provider: Option<String>,
170    model: Option<String>,
171    system: Option<String>,
172    temperature: Option<f64>,
173    max_tokens: Option<u32>,
174    json_mode: bool,
175    from_example: Option<String>,
176    read_stdin: bool,
177    quiet: bool,
178) -> Result<(), NikaError> {
179    let is_tty = std::io::stdout().is_terminal();
180
181    // Read stdin if requested
182    let full_prompt = if read_stdin || prompt == "-" {
183        let stdin_content = read_stdin_content().await?;
184        if prompt == "-" {
185            stdin_content
186        } else {
187            format!("{}\n\n{}", stdin_content.trim(), prompt)
188        }
189    } else {
190        prompt
191    };
192
193    // Resolve provider + model
194    let (provider_name, model_name) =
195        resolve_provider_model(provider.as_deref(), model.as_deref())?;
196
197    // Build InferParams
198    let infer = InferParams {
199        prompt: full_prompt,
200        system,
201        temperature,
202        max_tokens,
203        response_format: if json_mode {
204            Some(ResponseFormat::Json)
205        } else {
206            None
207        },
208        ..Default::default()
209    };
210
211    // Build output policy for structured output (from_example)
212    let output_policy = if let Some(ref example) = from_example {
213        let spec = if example.starts_with('{') || example.starts_with('[') {
214            let value: serde_json::Value =
215                serde_json::from_str(example).map_err(|e| NikaError::ParseError {
216                    details: format!("Invalid JSON in --from-example: {}", e),
217                })?;
218            nika_engine::ast::StructuredOutputSpec::with_example_inline(value)
219        } else {
220            nika_engine::ast::StructuredOutputSpec::with_example_file(example)
221        };
222        Some(spec.to_output_policy())
223    } else if json_mode {
224        Some(OutputPolicy {
225            format: OutputFormat::Json,
226            schema: None,
227            from_example: None,
228            max_retries: None,
229            source_structured_spec: None,
230        })
231    } else {
232        None
233    };
234
235    let action = TaskAction::Infer { infer };
236    let task_id: Arc<str> = Arc::from("cli");
237
238    // Show header — resolve actual default model name instead of "(default)"
239    let display_model = model_name
240        .as_deref()
241        .unwrap_or_else(|| default_model_for_provider(&provider_name));
242    if !quiet {
243        print_header(&format!("{} via {}", display_model, provider_name), is_tty);
244    }
245
246    // Execute
247    let (executor, event_log) = one_shot_executor(&provider_name, model_name.as_deref()).await?;
248    let bindings = ResolvedBindings::new();
249    let datastore = RunContext::new();
250    let start = Instant::now();
251    let output = executor
252        .execute(
253            &task_id,
254            &action,
255            &bindings,
256            &datastore,
257            output_policy.as_ref(),
258        )
259        .await?;
260    let elapsed = start.elapsed();
261
262    // Print output
263    println!("{output}");
264
265    // Cost footer (TTY only)
266    if !quiet {
267        let events = event_log.events();
268        let mut tokens = 0u64;
269        let mut cost = 0.0f64;
270        for event in events.iter().rev() {
271            if let nika_engine::event::EventKind::ProviderResponded {
272                input_tokens,
273                output_tokens,
274                cost_usd,
275                ..
276            } = &event.kind
277            {
278                tokens = input_tokens + output_tokens;
279                cost = *cost_usd;
280                break;
281            }
282        }
283        let extra = if tokens > 0 {
284            format!(" · {} tokens · ${:.4}", tokens, cost)
285        } else {
286            String::new()
287        };
288        print_footer(elapsed, &extra, is_tty);
289    }
290
291    Ok(())
292}
293
294/// Handle `nika fetch URL` — HTTP request with extraction.
295#[allow(clippy::too_many_arguments)]
296pub async fn handle_fetch(
297    url: String,
298    extract: Option<String>,
299    selector: Option<String>,
300    method: Option<String>,
301    headers: Vec<String>,
302    body: Option<String>,
303    json_body: Option<String>,
304    response: Option<String>,
305    timeout: Option<u64>,
306    quiet: bool,
307) -> Result<(), NikaError> {
308    let is_tty = std::io::stdout().is_terminal();
309
310    // Parse headers
311    let mut header_map = rustc_hash::FxHashMap::default();
312    for h in &headers {
313        let (key, value) = h
314            .split_once(':')
315            .ok_or_else(|| NikaError::ValidationError {
316                reason: format!("Invalid header '{}', expected KEY:VALUE", h),
317            })?;
318        header_map.insert(key.trim().to_string(), value.trim().to_string());
319    }
320
321    // Parse JSON body
322    let json_value = json_body
323        .map(|j| serde_json::from_str(&j))
324        .transpose()
325        .map_err(|e| NikaError::ParseError {
326            details: format!("Invalid --json-body: {}", e),
327        })?;
328
329    let fetch = FetchParams {
330        url: url.clone(),
331        method: method.unwrap_or_else(|| "GET".to_string()),
332        headers: header_map,
333        body,
334        json: json_value,
335        timeout,
336        extract: extract.clone(),
337        selector,
338        response,
339        retry: None,
340        follow_redirects: None,
341    };
342
343    let action = TaskAction::Fetch { fetch };
344    let task_id: Arc<str> = Arc::from("cli");
345
346    let extract_label = extract.as_deref().unwrap_or("raw");
347    if !quiet {
348        print_header(&format!("{} · {}", url, extract_label), is_tty);
349    }
350
351    // Fetch doesn't need a real LLM provider — use "mock"
352    let (executor, _) = one_shot_executor("mock", None).await?;
353    let bindings = ResolvedBindings::new();
354    let datastore = RunContext::new();
355    let start = Instant::now();
356    let output = executor
357        .execute(&task_id, &action, &bindings, &datastore, None)
358        .await?;
359    let elapsed = start.elapsed();
360
361    println!("{output}");
362
363    if !quiet {
364        let extra = format!(" · {} bytes", output.len());
365        print_footer(elapsed, &extra, is_tty);
366    }
367
368    Ok(())
369}
370
371/// Handle `nika invoke tool` — call builtin nika:* or MCP tool.
372pub async fn handle_invoke(
373    tool: String,
374    file: Option<String>,
375    params: Option<String>,
376    mcp: Option<String>,
377    timeout: Option<u64>,
378    list_tools: bool,
379    quiet: bool,
380) -> Result<(), NikaError> {
381    // --list: show available tools and exit
382    if list_tools {
383        println!("{}", "Builtin Tools (nika:*)".bold());
384        println!("{}", "─".repeat(50));
385        println!("  Tier 1 (always on):");
386        for t in [
387            "import",
388            "dimensions",
389            "thumbhash",
390            "dominant_color",
391            "pipeline",
392        ] {
393            println!("    nika:{}", t);
394        }
395        println!("  Tier 2 (media-core):");
396        for t in [
397            "thumbnail",
398            "convert",
399            "strip",
400            "metadata",
401            "optimize",
402            "svg_render",
403        ] {
404            println!("    nika:{}", t);
405        }
406        println!("  Tier 3 (opt-in):");
407        for t in [
408            "phash",
409            "compare",
410            "pdf_extract",
411            "chart",
412            "provenance",
413            "verify",
414            "qr_validate",
415            "quality",
416            "html_to_md",
417            "css_select",
418            "extract_metadata",
419            "extract_links",
420            "readability",
421        ] {
422            println!("    nika:{}", t);
423        }
424        println!();
425        println!("  Use: nika invoke nika:<tool> [file] [--params JSON]");
426        println!("  Full details: nika media tools");
427        return Ok(());
428    }
429
430    let is_tty = std::io::stdout().is_terminal();
431
432    // Build params: merge positional file arg as "source"
433    let mut tool_params = if let Some(ref p) = params {
434        serde_json::from_str(p).map_err(|e| NikaError::ParseError {
435            details: format!("Invalid --params JSON: {}", e),
436        })?
437    } else {
438        serde_json::json!({})
439    };
440
441    if let Some(ref f) = file {
442        if let Some(obj) = tool_params.as_object_mut() {
443            if !obj.contains_key("source") {
444                obj.insert("source".to_string(), serde_json::json!(f));
445            }
446        }
447    }
448
449    // Parse tool name: "server::tool" or "nika:tool"
450    let (mcp_name, tool_name) = if tool.contains("::") {
451        let (s, t) = tool.split_once("::").unwrap();
452        (Some(s.to_string()), t.to_string())
453    } else {
454        (mcp, tool.clone())
455    };
456
457    let invoke = InvokeParams {
458        tool: Some(tool_name),
459        params: Some(tool_params),
460        mcp: mcp_name,
461        resource: None,
462        timeout,
463    };
464
465    let action = TaskAction::Invoke { invoke };
466    let task_id: Arc<str> = Arc::from("cli");
467
468    if !quiet {
469        print_header(&tool, is_tty);
470    }
471
472    let (executor, _) = one_shot_executor("mock", None).await?;
473    let bindings = ResolvedBindings::new();
474    let datastore = RunContext::new();
475    let start = Instant::now();
476    let output = executor
477        .execute(&task_id, &action, &bindings, &datastore, None)
478        .await?;
479    let elapsed = start.elapsed();
480
481    println!("{output}");
482
483    if !quiet {
484        print_footer(elapsed, "", is_tty);
485    }
486
487    Ok(())
488}
489
490/// Handle `nika agent "prompt"` — multi-turn AI agent.
491#[allow(clippy::too_many_arguments)]
492pub async fn handle_agent(
493    prompt: String,
494    provider: Option<String>,
495    model: Option<String>,
496    system: Option<String>,
497    tools: Vec<String>,
498    mcp_servers: Vec<String>,
499    turns: u32,
500    max_tokens: Option<u32>,
501    temperature: Option<f64>,
502    read_stdin: bool,
503    quiet: bool,
504) -> Result<(), NikaError> {
505    let is_tty = std::io::stdout().is_terminal();
506
507    let full_prompt = if read_stdin || prompt == "-" {
508        let stdin_content = read_stdin_content().await?;
509        if prompt == "-" {
510            stdin_content
511        } else {
512            format!("{}\n\n{}", stdin_content.trim(), prompt)
513        }
514    } else {
515        prompt
516    };
517
518    let (provider_name, model_name) =
519        resolve_provider_model(provider.as_deref(), model.as_deref())?;
520
521    let agent = AgentParams {
522        prompt: full_prompt,
523        system,
524        provider: Some(provider_name.clone()),
525        model: model_name.clone(),
526        tools,
527        mcp: mcp_servers,
528        max_turns: Some(turns),
529        max_tokens,
530        temperature: temperature.map(|t| t as f32),
531        ..Default::default()
532    };
533
534    let action = TaskAction::Agent { agent };
535    let task_id: Arc<str> = Arc::from("cli");
536
537    // Resolve actual default model name instead of "(default)"
538    let display_model = model_name
539        .as_deref()
540        .unwrap_or_else(|| default_model_for_provider(&provider_name));
541    if !quiet {
542        print_header(
543            &format!(
544                "agent · {} via {} · {} turns",
545                display_model, provider_name, turns
546            ),
547            is_tty,
548        );
549    }
550
551    let (executor, event_log) = one_shot_executor(&provider_name, model_name.as_deref()).await?;
552    let bindings = ResolvedBindings::new();
553    let datastore = RunContext::new();
554    let start = Instant::now();
555    let output = executor
556        .execute(&task_id, &action, &bindings, &datastore, None)
557        .await?;
558    let elapsed = start.elapsed();
559
560    println!("{output}");
561
562    if !quiet {
563        let events = event_log.events();
564        let mut total_tokens = 0u64;
565        let mut total_cost = 0.0f64;
566        for event in &events {
567            if let nika_engine::event::EventKind::ProviderResponded {
568                input_tokens,
569                output_tokens,
570                cost_usd,
571                ..
572            } = &event.kind
573            {
574                total_tokens += input_tokens + output_tokens;
575                total_cost += cost_usd;
576            }
577        }
578        let extra = if total_tokens > 0 {
579            format!(" · {} tokens · ${:.4}", total_tokens, total_cost)
580        } else {
581            String::new()
582        };
583        print_footer(elapsed, &extra, is_tty);
584    }
585
586    Ok(())
587}
588
// ═══════════════════════════════════════════════════════════════════════════
// TESTS
// ═══════════════════════════════════════════════════════════════════════════
592
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes every test in this module that reads or mutates provider
    /// env vars.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Provider env var names in priority order (must match the LLM provider
    /// catalog used by `detect_provider`).
    ///
    /// NOTE(review): if the catalog gains LLM providers whose keys are not
    /// listed here, a machine with those keys set could break the "no keys"
    /// tests — keep this list in sync with the catalog.
    const PROVIDER_VARS: [&str; 7] = [
        "ANTHROPIC_API_KEY",
        "OPENAI_API_KEY",
        "MISTRAL_API_KEY",
        "GROQ_API_KEY",
        "DEEPSEEK_API_KEY",
        "GEMINI_API_KEY",
        "XAI_API_KEY",
    ];

    /// RAII guard for env-dependent tests.
    ///
    /// Creating it takes `ENV_LOCK` and removes every `PROVIDER_VARS` entry.
    /// Dropping it restores the previous values, then releases the lock (fields
    /// are dropped only after `Drop::drop` has run). Because restoration lives
    /// in `Drop`, it also runs when an assertion panics. A failing test can
    /// therefore neither leak env changes into later tests nor leave them
    /// stuck on a poisoned lock.
    struct ProviderEnv {
        saved: Vec<(&'static str, Option<OsString>)>,
        _lock: MutexGuard<'static, ()>,
    }

    impl ProviderEnv {
        /// Lock the environment and remove all provider API key vars.
        fn cleared() -> Self {
            // A test that panicked while holding the lock poisons it; the
            // protected data is `()`, so recovering the guard is harmless.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let saved = PROVIDER_VARS
                .iter()
                .map(|&var| {
                    // `var_os` so non-UTF-8 values are restored too.
                    let prev = std::env::var_os(var);
                    // SAFETY: env mutation in this module is serialized by ENV_LOCK.
                    unsafe { std::env::remove_var(var) };
                    (var, prev)
                })
                .collect();
            Self { saved, _lock: lock }
        }

        /// Set `var` (expected to be one of `PROVIDER_VARS`, so it is
        /// restored on drop) to `value`.
        fn set(&self, var: &str, value: &str) {
            // SAFETY: `self` holds ENV_LOCK, serializing env mutation here.
            unsafe { std::env::set_var(var, value) };
        }
    }

    impl Drop for ProviderEnv {
        fn drop(&mut self) {
            for (var, prev) in self.saved.drain(..) {
                // SAFETY: ENV_LOCK is still held (`_lock` drops after this body).
                match prev {
                    Some(value) => unsafe { std::env::set_var(var, value) },
                    None => unsafe { std::env::remove_var(var) },
                }
            }
        }
    }

    // ───────────────────────────────────────────────────────────────────────
    // parse_composite_model
    // ───────────────────────────────────────────────────────────────────────

    #[test]
    fn parse_composite_provider_and_model() {
        let (provider, model) = parse_composite_model("anthropic/claude-sonnet").unwrap();
        assert_eq!(provider, Some("anthropic"));
        assert_eq!(model, "claude-sonnet");
    }

    #[test]
    fn parse_composite_plain_model() {
        let (provider, model) = parse_composite_model("gpt-4o").unwrap();
        assert_eq!(provider, None);
        assert_eq!(model, "gpt-4o");
    }

    #[test]
    fn parse_composite_empty_string() {
        let (provider, model) = parse_composite_model("").unwrap();
        assert_eq!(provider, None);
        assert_eq!(model, "");
    }

    #[test]
    fn parse_composite_slash_only_is_error() {
        let err = parse_composite_model("/").unwrap_err();
        assert!(
            err.to_string().contains("Invalid model format"),
            "expected validation error, got: {}",
            err
        );
    }

    #[test]
    fn parse_composite_empty_model_is_error() {
        let err = parse_composite_model("anthropic/").unwrap_err();
        assert!(
            err.to_string().contains("Invalid model format"),
            "expected validation error, got: {}",
            err
        );
    }

    #[test]
    fn parse_composite_empty_provider_is_error() {
        let err = parse_composite_model("/claude").unwrap_err();
        assert!(
            err.to_string().contains("Invalid model format"),
            "expected validation error, got: {}",
            err
        );
    }

    #[test]
    fn parse_composite_multiple_slashes_is_error() {
        let err = parse_composite_model("a/b/c").unwrap_err();
        assert!(
            err.to_string().contains("Invalid model format"),
            "expected validation error, got: {}",
            err
        );
    }

    #[test]
    fn parse_composite_openai_with_model() {
        let (provider, model) = parse_composite_model("openai/gpt-4.1").unwrap();
        assert_eq!(provider, Some("openai"));
        assert_eq!(model, "gpt-4.1");
    }

    // ───────────────────────────────────────────────────────────────────────
    // detect_provider
    // ───────────────────────────────────────────────────────────────────────

    #[test]
    fn detect_provider_returns_none_when_no_keys() {
        let _env = ProviderEnv::cleared();

        let result = detect_provider();
        assert_eq!(result, None);
    }

    #[test]
    fn detect_provider_returns_anthropic_when_set() {
        let env = ProviderEnv::cleared();

        env.set("ANTHROPIC_API_KEY", "sk-ant-test");
        let result = detect_provider();
        assert_eq!(result.as_deref(), Some("anthropic"));
    }

    #[test]
    fn detect_provider_priority_anthropic_over_openai() {
        let env = ProviderEnv::cleared();

        env.set("OPENAI_API_KEY", "sk-openai-test");
        env.set("ANTHROPIC_API_KEY", "sk-ant-test");
        let result = detect_provider();
        assert_eq!(
            result.as_deref(),
            Some("anthropic"),
            "anthropic should win over openai in priority order"
        );
    }

    #[test]
    fn detect_provider_falls_through_to_openai() {
        let env = ProviderEnv::cleared();

        env.set("OPENAI_API_KEY", "sk-openai-test");
        let result = detect_provider();
        assert_eq!(result.as_deref(), Some("openai"));
    }

    #[test]
    fn detect_provider_ignores_empty_and_whitespace_keys() {
        let env = ProviderEnv::cleared();

        env.set("ANTHROPIC_API_KEY", "");
        env.set("OPENAI_API_KEY", "   ");
        env.set("MISTRAL_API_KEY", "sk-mis-test");
        let result = detect_provider();
        assert_eq!(
            result.as_deref(),
            Some("mistral"),
            "empty/whitespace keys should be skipped"
        );
    }

    #[test]
    fn detect_provider_xai_last_resort() {
        let env = ProviderEnv::cleared();

        env.set("XAI_API_KEY", "xai-test");
        let result = detect_provider();
        assert_eq!(result.as_deref(), Some("xai"));
    }

    // ───────────────────────────────────────────────────────────────────────
    // resolve_provider_model
    // ───────────────────────────────────────────────────────────────────────

    #[test]
    fn resolve_explicit_provider_flag_wins() {
        let env = ProviderEnv::cleared();

        // Even with ANTHROPIC_API_KEY set, explicit flag should win
        env.set("ANTHROPIC_API_KEY", "sk-ant-test");

        let (provider, model) = resolve_provider_model(Some("openai"), Some("gpt-4o")).unwrap();
        assert_eq!(provider, "openai");
        assert_eq!(model.as_deref(), Some("gpt-4o"));
    }

    #[test]
    fn resolve_composite_model_extracts_both() {
        let _env = ProviderEnv::cleared();

        let (provider, model) =
            resolve_provider_model(None, Some("anthropic/claude-sonnet")).unwrap();
        assert_eq!(provider, "anthropic");
        assert_eq!(model.as_deref(), Some("claude-sonnet"));
    }

    #[test]
    fn resolve_composite_model_ignores_provider_flag() {
        // Composite syntax should take precedence over the -p flag
        let _env = ProviderEnv::cleared();

        let (provider, model) =
            resolve_provider_model(Some("openai"), Some("anthropic/claude-sonnet")).unwrap();
        assert_eq!(
            provider, "anthropic",
            "composite model provider should override -p flag"
        );
        assert_eq!(model.as_deref(), Some("claude-sonnet"));
    }

    #[test]
    fn resolve_model_only_uses_env_auto_detect() {
        let env = ProviderEnv::cleared();

        env.set("GROQ_API_KEY", "gsk-test");

        let (provider, model) = resolve_provider_model(None, Some("llama-4-maverick")).unwrap();
        assert_eq!(provider, "groq");
        assert_eq!(model.as_deref(), Some("llama-4-maverick"));
    }

    #[test]
    fn resolve_no_provider_no_env_is_error() {
        let _env = ProviderEnv::cleared();

        let err = resolve_provider_model(None, Some("gpt-4o")).unwrap_err();
        assert!(
            err.to_string().contains("No provider configured"),
            "expected help message, got: {}",
            err
        );
    }

    #[test]
    fn resolve_no_provider_no_model_no_env_is_error() {
        let _env = ProviderEnv::cleared();

        let err = resolve_provider_model(None, None).unwrap_err();
        assert!(
            err.to_string().contains("No provider configured"),
            "expected help message, got: {}",
            err
        );
    }

    #[test]
    fn resolve_provider_flag_only_no_model() {
        let _env = ProviderEnv::cleared();

        let (provider, model) = resolve_provider_model(Some("anthropic"), None).unwrap();
        assert_eq!(provider, "anthropic");
        assert_eq!(model, None, "no model specified means None");
    }

    #[test]
    fn resolve_invalid_composite_model_propagates_error() {
        let _env = ProviderEnv::cleared();

        let err = resolve_provider_model(None, Some("a/b/c")).unwrap_err();
        assert!(
            err.to_string().contains("Invalid model format"),
            "should propagate parse_composite_model error, got: {}",
            err
        );
    }
}