Skip to main content

mars_agents/cli/
models.rs

1//! CLI handlers for `mars models` subcommands.
2#![allow(clippy::print_literal)]
3
4use clap::{Parser, Subcommand};
5use indexmap::IndexMap;
6
7use crate::error::MarsError;
8use crate::models::{self, HarnessSource, ModelAlias, ModelSpec};
9use crate::types::MarsContext;
10
11/// Manage model aliases and the models cache.
12#[derive(Debug, Parser)]
13pub struct ModelsArgs {
14    #[command(subcommand)]
15    pub command: ModelsCommand,
16}
17
18#[derive(Debug, Subcommand)]
19pub enum ModelsCommand {
20    /// Fetch models from API and update the local cache.
21    Refresh,
22    /// List all model aliases (consumer + deps) with resolved IDs.
23    List(ListArgs),
24    /// Show resolution chain for a specific alias.
25    Resolve(ResolveAliasArgs),
26    /// Quick-add a pinned alias to mars.toml [models].
27    Alias(AddAliasArgs),
28}
29
30#[derive(Debug, Parser)]
31pub struct ListArgs {
32    /// Show all aliases including those without an available harness.
33    #[arg(long)]
34    all: bool,
35    /// Skip automatic models-cache refresh; use whatever's on disk (equivalent to MARS_OFFLINE=1).
36    #[arg(long)]
37    no_refresh_models: bool,
38    /// Only show aliases matching these patterns (overrides config).
39    #[arg(long, value_delimiter = ',', conflicts_with = "exclude")]
40    include: Option<Vec<String>>,
41    /// Hide aliases matching these patterns (overrides config).
42    #[arg(long, value_delimiter = ',', conflicts_with = "include")]
43    exclude: Option<Vec<String>>,
44}
45
46#[derive(Debug, Parser)]
47pub struct ResolveAliasArgs {
48    /// Alias name to resolve.
49    pub name: String,
50    /// Skip automatic models-cache refresh; use whatever's on disk (equivalent to MARS_OFFLINE=1).
51    #[arg(long)]
52    no_refresh_models: bool,
53}
54
55#[derive(Debug, Parser)]
56pub struct AddAliasArgs {
57    /// Alias name.
58    pub name: String,
59    /// Model ID to pin.
60    pub model_id: String,
61    /// Harness for this alias (default: claude).
62    #[arg(long, default_value = "claude")]
63    pub harness: String,
64    /// Optional description.
65    #[arg(long)]
66    pub description: Option<String>,
67}
68
69pub fn run(args: &ModelsArgs, ctx: &MarsContext, json: bool) -> Result<i32, MarsError> {
70    match &args.command {
71        ModelsCommand::Refresh => run_refresh(ctx, json),
72        ModelsCommand::List(args) => run_list(args, ctx, json),
73        ModelsCommand::Resolve(a) => run_resolve(a, ctx, json),
74        ModelsCommand::Alias(a) => run_alias(a, ctx, json),
75    }
76}
77
78fn mars_dir(ctx: &MarsContext) -> std::path::PathBuf {
79    ctx.project_root.join(".mars")
80}
81
82fn run_refresh(ctx: &MarsContext, json: bool) -> Result<i32, MarsError> {
83    let mars = mars_dir(ctx);
84    let ttl = models::load_models_cache_ttl(ctx);
85    eprint!("Fetching models catalog... ");
86
87    let (cache, outcome) = models::ensure_fresh(&mars, ttl, models::RefreshMode::Force)?;
88    let count = cache.models.len();
89    let cache_warning = cache_warning(&outcome);
90
91    if let Some(warning) = cache_warning.as_deref() {
92        eprintln!("warning: {warning}");
93    } else if !json {
94        eprintln!("done.");
95    }
96
97    if json {
98        let out = serde_json::json!({
99            "status": "ok",
100            "models_count": count,
101            "fetched_at": cache.fetched_at,
102        });
103        let mut out = out;
104        if let Some(warning) = cache_warning.as_deref() {
105            out["cache_warning"] = serde_json::json!(warning);
106        }
107        println!("{}", serde_json::to_string_pretty(&out).unwrap());
108    } else {
109        if cache_warning.is_some() {
110            println!(
111                "Using stale models cache with {} models in .mars/models-cache.json",
112                count
113            );
114        } else {
115            println!("Cached {} models in .mars/models-cache.json", count);
116        }
117    }
118
119    Ok(0)
120}
121
122fn run_list(args: &ListArgs, ctx: &MarsContext, json: bool) -> Result<i32, MarsError> {
123    let mars = mars_dir(ctx);
124    let ttl = models::load_models_cache_ttl(ctx);
125    let mode = models::resolve_refresh_mode(args.no_refresh_models);
126    let (cache, outcome) = match models::ensure_fresh(&mars, ttl, mode) {
127        Ok(ok) => ok,
128        Err(err @ MarsError::ModelCacheUnavailable { .. }) if json => {
129            println!(
130                "{}",
131                serde_json::to_string_pretty(&serde_json::json!({
132                    "error": format!("{err}"),
133                }))
134                .unwrap()
135            );
136            return Ok(1);
137        }
138        Err(err) => return Err(err),
139    };
140    let cache_warning = cache_warning(&outcome);
141
142    // Load config to get consumer models + trigger merge
143    let merged = load_merged_aliases(ctx)?;
144    let resolved = models::resolve_all(&merged, &cache);
145
146    // Build effective visibility: CLI overrides config entirely.
147    let config_visibility = crate::config::load(&ctx.project_root)
148        .map(|c| c.settings.model_visibility)
149        .unwrap_or_default();
150
151    let visibility = if args.include.is_some() || args.exclude.is_some() {
152        crate::config::ModelVisibility {
153            include: args.include.clone(),
154            exclude: args.exclude.clone(),
155        }
156    } else {
157        config_visibility
158    };
159
160    let resolved = models::filter_by_visibility(resolved, &visibility);
161
162    if json {
163        let entries: Vec<serde_json::Value> = resolved
164            .values()
165            .map(|r| {
166                let mode = mode_for_alias(merged.get(&r.name).map(|a| &a.spec));
167                let mut obj = serde_json::json!({
168                    "name": r.name,
169                    "harness": r.harness,
170                    "harness_source": r.harness_source,
171                    "harness_candidates": r.harness_candidates,
172                    "provider": r.provider,
173                    "mode": mode,
174                    "model_id": r.model_id,
175                    "resolved_model": r.model_id,
176                    "description": r.description,
177                });
178                if let Some(error) = unavailable_harness_error(r) {
179                    obj["error"] = serde_json::json!(error);
180                }
181                obj
182            })
183            .collect();
184        let mut out = serde_json::json!({
185            "aliases": entries,
186            "cache_available": cache.fetched_at.is_some(),
187        });
188        if let Some(warning) = cache_warning.as_deref() {
189            out["cache_warning"] = serde_json::json!(warning);
190        }
191        println!("{}", serde_json::to_string_pretty(&out).unwrap());
192    } else {
193        if let Some(warning) = cache_warning.as_deref() {
194            eprintln!("warning: {warning}");
195        }
196        // Table output
197        println!(
198            "{:<12} {:<10} {:<14} {:<30} {}",
199            "ALIAS", "HARNESS", "MODE", "RESOLVED", "DESCRIPTION"
200        );
201        for r in resolved.values() {
202            if !args.all && r.harness_source == HarnessSource::Unavailable {
203                continue;
204            }
205            let harness = r.harness.as_deref().unwrap_or("—");
206            let mode = mode_for_alias(merged.get(&r.name).map(|a| &a.spec));
207            let desc = if r.harness_source == HarnessSource::Unavailable {
208                format!("(install: {})", r.harness_candidates.join(", "))
209            } else {
210                r.description.clone().unwrap_or_default()
211            };
212            println!(
213                "{:<12} {:<10} {:<14} {:<30} {}",
214                r.name, harness, mode, r.model_id, desc
215            );
216        }
217    }
218
219    Ok(0)
220}
221
222fn run_resolve(args: &ResolveAliasArgs, ctx: &MarsContext, json: bool) -> Result<i32, MarsError> {
223    let merged = load_merged_aliases(ctx)?;
224    let Some(alias) = merged.get(&args.name) else {
225        if json {
226            println!(
227                "{}",
228                serde_json::to_string_pretty(&serde_json::json!({
229                    "error": format!("unknown alias: {}", args.name),
230                }))
231                .unwrap()
232            );
233        } else {
234            eprintln!("error: unknown alias `{}`", args.name);
235        }
236        return Ok(1);
237    };
238
239    let mars = mars_dir(ctx);
240    let ttl = models::load_models_cache_ttl(ctx);
241    let mode = models::resolve_refresh_mode(args.no_refresh_models);
242    let (cache, outcome) = match models::ensure_fresh(&mars, ttl, mode) {
243        Ok(ok) => ok,
244        Err(err @ MarsError::ModelCacheUnavailable { .. }) if json => {
245            println!(
246                "{}",
247                serde_json::to_string_pretty(&serde_json::json!({
248                    "error": format!("{err}"),
249                }))
250                .unwrap()
251            );
252            return Ok(1);
253        }
254        Err(err) => return Err(err),
255    };
256    let cache_warning = cache_warning(&outcome);
257
258    if let Some(warning) = cache_warning.as_deref()
259        && !json
260    {
261        eprintln!("warning: {warning}");
262    }
263
264    // Determine source layer
265    let source = determine_source(&args.name, ctx)?;
266    let resolved_map = models::resolve_all(&merged, &cache);
267    let resolved_entry = resolved_map.get(&args.name);
268
269    if json {
270        if let Some(r) = resolved_entry {
271            let mut out = serde_json::json!({
272                "name": r.name,
273                "source": source,
274                "provider": r.provider,
275                "harness": r.harness,
276                "harness_source": r.harness_source,
277                "harness_candidates": r.harness_candidates,
278                "model_id": r.model_id,
279                "resolved_model": r.model_id,
280                "spec": format_spec(&alias.spec),
281                "description": r.description,
282            });
283            if let Some(error) = unavailable_harness_error(r) {
284                out["error"] = serde_json::json!(error);
285            }
286            if let Some(warning) = cache_warning.as_deref() {
287                out["cache_warning"] = serde_json::json!(warning);
288            }
289            println!("{}", serde_json::to_string_pretty(&out).unwrap());
290        } else {
291            let mut out = serde_json::json!({
292                "error": format!("alias `{}` did not resolve to a model ID", args.name),
293            });
294            if let Some(warning) = cache_warning.as_deref() {
295                out["cache_warning"] = serde_json::json!(warning);
296            }
297            println!("{}", serde_json::to_string_pretty(&out).unwrap());
298            return Ok(1);
299        }
300    } else {
301        let Some(r) = resolved_entry else {
302            eprintln!("error: alias `{}` did not resolve to a model ID", args.name);
303            return Ok(1);
304        };
305        let harness = r.harness.as_deref().unwrap_or("—");
306        println!("Alias:    {}", args.name);
307        println!("Source:   {}", source);
308        println!(
309            "Harness:  {} ({})",
310            harness,
311            harness_source_label(&r.harness_source)
312        );
313        println!("Provider: {}", r.provider);
314        match &alias.spec {
315            ModelSpec::Pinned { model, provider: _ } => {
316                println!("Mode:     pinned");
317                println!("Model:    {}", model);
318            }
319            ModelSpec::AutoResolve {
320                provider: _,
321                match_patterns,
322                exclude_patterns,
323            } => {
324                println!("Mode:     auto-resolve");
325                println!("Match:    {}", match_patterns.join(", "));
326                if !exclude_patterns.is_empty() {
327                    println!("Exclude:  {}", exclude_patterns.join(", "));
328                }
329                println!("Resolved: {}", r.model_id);
330            }
331        }
332        if let Some(error) = unavailable_harness_error(r) {
333            println!("Error:    {}", error);
334        }
335        if let Some(desc) = &r.description {
336            println!("Desc:     {}", desc);
337        }
338    }
339
340    Ok(0)
341}
342
343fn run_alias(args: &AddAliasArgs, ctx: &MarsContext, json: bool) -> Result<i32, MarsError> {
344    let config_path = ctx.project_root.join("mars.toml");
345
346    // Read existing config
347    let content = std::fs::read_to_string(&config_path).unwrap_or_default();
348
349    let harness = Some(args.harness.clone());
350
351    // Build the TOML entry
352    let mut entry = format!(
353        "\n[models.{}]\nharness = {:?}\nmodel = {:?}\n",
354        args.name,
355        harness.as_deref().unwrap_or("claude"),
356        args.model_id
357    );
358    if let Some(desc) = &args.description {
359        entry.push_str(&format!("description = {:?}\n", desc));
360    }
361
362    // Append to mars.toml
363    let new_content = if content.is_empty() {
364        entry
365    } else {
366        format!("{}{}", content.trim_end(), entry)
367    };
368    std::fs::write(&config_path, new_content)?;
369
370    if json {
371        println!(
372            "{}",
373            serde_json::to_string_pretty(&serde_json::json!({
374                "status": "ok",
375                "alias": args.name,
376                "model": args.model_id,
377                "harness": args.harness,
378            }))
379            .unwrap()
380        );
381    } else {
382        println!(
383            "Added alias `{}` → {} (harness: {})",
384            args.name, args.model_id, args.harness
385        );
386    }
387
388    Ok(0)
389}
390
391// ---------------------------------------------------------------------------
392// Helpers
393// ---------------------------------------------------------------------------
394
395/// Load model aliases by combining cached dependency aliases with consumer config.
396fn load_merged_aliases(
397    ctx: &MarsContext,
398) -> Result<indexmap::IndexMap<String, ModelAlias>, MarsError> {
399    // Start with builtins (lowest precedence)
400    let mut merged = models::builtin_aliases();
401
402    // Layer dep aliases from cached merge file (overrides builtins)
403    let mars_dir = ctx.project_root.join(".mars");
404    let merged_path = mars_dir.join("models-merged.json");
405    if let Ok(content) = std::fs::read_to_string(&merged_path)
406        && let Ok(cached) = serde_json::from_str::<IndexMap<String, ModelAlias>>(&content)
407    {
408        for (name, alias) in cached {
409            merged.insert(name, alias);
410        }
411    }
412
413    // Layer consumer config on top (highest precedence)
414    if let Ok(config) = crate::config::load(&ctx.project_root) {
415        for (name, alias) in &config.models {
416            merged.insert(name.clone(), alias.clone());
417        }
418    }
419
420    Ok(merged)
421}
422
423/// Determine which layer provides an alias (consumer or dependency).
424fn determine_source(name: &str, ctx: &MarsContext) -> Result<String, MarsError> {
425    let config = match crate::config::load(&ctx.project_root) {
426        Ok(c) => c,
427        Err(_) => return Ok("unknown".to_string()),
428    };
429
430    if config.models.contains_key(name) {
431        return Ok("consumer (mars.toml)".to_string());
432    }
433
434    Ok("dependency".to_string())
435}
436
437fn format_spec(spec: &ModelSpec) -> serde_json::Value {
438    match spec {
439        ModelSpec::Pinned { model, provider } => {
440            let mut out = serde_json::json!({ "mode": "pinned", "model": model });
441            if let Some(provider) = provider {
442                out["provider"] = serde_json::json!(provider);
443            }
444            out
445        }
446        ModelSpec::AutoResolve {
447            provider,
448            match_patterns,
449            exclude_patterns,
450        } => serde_json::json!({
451            "mode": "auto-resolve",
452            "provider": provider,
453            "match": match_patterns,
454            "exclude": exclude_patterns,
455        }),
456    }
457}
458
459fn mode_for_alias(spec: Option<&ModelSpec>) -> &'static str {
460    match spec {
461        Some(ModelSpec::Pinned { .. }) => "pinned",
462        Some(ModelSpec::AutoResolve { .. }) => "auto-resolve",
463        None => "unknown",
464    }
465}
466
467fn harness_source_label(source: &HarnessSource) -> &'static str {
468    match source {
469        HarnessSource::Explicit => "explicit",
470        HarnessSource::AutoDetected => "auto-detected",
471        HarnessSource::Unavailable => "unavailable",
472    }
473}
474
475fn unavailable_harness_error(resolved: &models::ResolvedAlias) -> Option<String> {
476    if resolved.harness_source != HarnessSource::Unavailable {
477        return None;
478    }
479    if let Some(h) = &resolved.harness {
480        Some(format!("Harness '{}' is not installed", h))
481    } else {
482        Some(format!(
483            "No installed harness for provider '{}'. Install one of: {}",
484            resolved.provider,
485            resolved.harness_candidates.join(", ")
486        ))
487    }
488}
489
/// Message shown when refreshing the models cache failed and the stale copy is used.
fn stale_warning(reason: &str) -> String {
    ["models cache refresh failed: ", reason, "; using stale cache"].concat()
}
493
494fn cache_warning(outcome: &models::RefreshOutcome) -> Option<String> {
495    match outcome {
496        models::RefreshOutcome::StaleFallback { reason } => Some(stale_warning(reason)),
497        _ => None,
498    }
499}
500
#[cfg(test)]
mod tests {
    use super::*;
    use clap::Parser;
    use tempfile::TempDir;

    /// Create a temporary project whose `mars.toml` holds `config`.
    ///
    /// The `TempDir` is returned too, so the directory lives as long as the caller's binding.
    fn project_with_config(config: &str) -> (TempDir, MarsContext) {
        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join("mars.toml"), config).unwrap();
        let ctx = MarsContext::new(dir.path().to_path_buf()).unwrap();
        (dir, ctx)
    }

    /// Collapse a handler result into the exit code the CLI would report.
    fn exit_code_of(result: Result<i32, MarsError>) -> i32 {
        result.unwrap_or_else(|err| err.exit_code())
    }

    #[test]
    fn list_args_parses_no_refresh_models() {
        let parsed = ListArgs::try_parse_from(["mars", "--no-refresh-models"]).unwrap();
        assert!(parsed.no_refresh_models);
    }

    #[test]
    fn resolve_alias_args_parses_no_refresh_models() {
        let parsed =
            ResolveAliasArgs::try_parse_from(["mars", "opus", "--no-refresh-models"]).unwrap();
        assert!(parsed.no_refresh_models);
    }

    #[test]
    fn list_no_refresh_without_cache_is_non_zero() {
        let (_dir, ctx) = project_with_config("[settings]\n");
        let cli = ModelsArgs::try_parse_from(["mars", "list", "--no-refresh-models"]).unwrap();

        assert_ne!(exit_code_of(run(&cli, &ctx, false)), 0);
    }

    #[test]
    fn resolve_no_refresh_without_cache_is_non_zero() {
        let (_dir, ctx) = project_with_config(
            "[settings]\n\n[models.opus]\nharness = \"claude\"\nmodel = \"claude-opus-4-6\"\n",
        );
        let cli =
            ModelsArgs::try_parse_from(["mars", "resolve", "opus", "--no-refresh-models"]).unwrap();

        assert_ne!(exit_code_of(run(&cli, &ctx, false)), 0);
    }
}
561}