Skip to main content

agent_trace/commands/
model.rs

1use crate::config::{
2    CredentialsStore, GlobalConfig, MergedConfig, SynthesisConfig, SynthesisMode, SynthesisProvider,
3};
4use crate::llm::Llm;
5use crate::observability::CliOutput;
6use anyhow::{bail, Context, Result};
7use clap::Subcommand;
8use std::io;
9use std::time::Instant;
10
11#[derive(Subcommand, Debug)]
12pub enum ModelCmd {
13    /// Show active provider, model, and health status.
14    Status,
15    /// Interactive setup wizard for synthesis provider.
16    Setup,
17    /// Switch active provider.
18    Use { provider: String },
19    /// Set provider and model non-interactively.
20    Set {
21        #[arg(long)]
22        provider: Option<String>,
23        #[arg(long)]
24        model: Option<String>,
25        #[arg(long)]
26        base_url: Option<String>,
27        #[arg(long)]
28        mode: Option<String>,
29    },
30    /// Store an API key for a provider (prompted, never echoed).
31    Credentials {
32        #[command(subcommand)]
33        sub: CredentialsCmd,
34    },
35    /// Run a sample synthesis prompt and report latency.
36    Test,
37    /// List known models per provider.
38    List,
39    /// Pull an Ollama model by name or short alias (e.g. 1.5b, qwen2.5:1.5b).
40    Pull {
41        /// Model tag or short alias (default: 1.5b → qwen2.5:1.5b).
42        #[arg(default_value = "1.5b")]
43        size: String,
44    },
45    /// Check whether Ollama is reachable and the configured model is pulled.
46    ServeCheck,
47    /// Ensure Ollama daemon is running and configured model is pulled.
48    /// Starts ollama serve if unreachable, pulls model if missing.
49    Ensure,
50}
51
52#[derive(Subcommand, Debug)]
53pub enum CredentialsCmd {
54    Set { provider: String },
55    Clear { provider: String },
56}
57
58pub fn run(
59    cmd: ModelCmd,
60    store_root: Option<&std::path::Path>,
61    output: &dyn CliOutput,
62) -> Result<()> {
63    match cmd {
64        ModelCmd::Status => cmd_status(store_root, output),
65        ModelCmd::Setup => cmd_setup(store_root, output),
66        ModelCmd::Use { provider } => cmd_use(&provider, output),
67        ModelCmd::Set {
68            provider,
69            model,
70            base_url,
71            mode,
72        } => cmd_set(provider, model, base_url, mode, output),
73        ModelCmd::Credentials { sub } => cmd_credentials(sub, output),
74        ModelCmd::Test => cmd_test(store_root, output),
75        ModelCmd::List => cmd_list(output),
76        ModelCmd::Pull { size } => cmd_pull(&size, output),
77        ModelCmd::ServeCheck => cmd_serve_check(output),
78        ModelCmd::Ensure => cmd_ensure(store_root, output),
79    }
80}
81
82fn parse_provider(s: &str) -> Result<SynthesisProvider> {
83    match s.to_lowercase().as_str() {
84        "openai" => Ok(SynthesisProvider::Openai),
85        "anthropic" => Ok(SynthesisProvider::Anthropic),
86        "openrouter" => Ok(SynthesisProvider::Openrouter),
87        "ollama" => Ok(SynthesisProvider::Ollama),
88        "custom" => Ok(SynthesisProvider::Custom),
89        "embedded" => {
90            tracing::warn!("'embedded' provider is deprecated and will use Ollama instead");
91            Ok(SynthesisProvider::Ollama)
92        }
93        other => {
94            bail!("Unknown provider '{other}'. Use: openai, anthropic, openrouter, ollama, custom")
95        }
96    }
97}
98
99fn parse_mode(s: &str) -> Result<SynthesisMode> {
100    match s.to_lowercase().as_str() {
101        "auto" => Ok(SynthesisMode::Auto),
102        "remote" => Ok(SynthesisMode::Remote),
103        "ollama" => Ok(SynthesisMode::Ollama),
104        "embedded" => {
105            tracing::warn!("'embedded' mode is deprecated; using 'auto' instead");
106            Ok(SynthesisMode::Auto)
107        }
108        other => bail!("Unknown mode '{other}'. Use: auto, remote, ollama"),
109    }
110}
111
112fn load_merged(store_root: Option<&std::path::Path>) -> Result<MergedConfig> {
113    if let Some(root) = store_root {
114        if root.join(".agent-trace").join("config.toml").exists() {
115            return MergedConfig::load(root);
116        }
117    }
118    let global = GlobalConfig::load()?;
119    Ok(MergedConfig::merge(
120        global,
121        crate::config::StoreConfig {
122            store: crate::config::StoreInfo::new("global".into()),
123            llm: None,
124            synthesis: None,
125            polling: crate::config::PollingConfig::default(),
126        },
127    ))
128}
129
130fn synthesis_status_line(merged: &MergedConfig) -> String {
131    let info = Llm::backend_info_from_config(merged);
132    if info.degraded {
133        "Synthesis: degraded (no backend) — run `agent-trace model ensure`".into()
134    } else {
135        format!("Synthesis: {} (ok)", info.label)
136    }
137}
138
139fn cmd_status(store_root: Option<&std::path::Path>, output: &dyn CliOutput) -> Result<()> {
140    let merged = load_merged(store_root)?;
141    let syn = &merged.synthesis;
142    let info = Llm::backend_info_from_config(&merged);
143    let creds = CredentialsStore::load().unwrap_or_default();
144
145    output.line(&format!("Mode:     {:?}", syn.mode))?;
146    output.line(&format!("Provider: {}", syn.provider.slug()))?;
147    output.line(&format!("Model:    {}", syn.effective_model()))?;
148    output.line(&format!("Base URL: {}", syn.effective_base_url()))?;
149    if let Some(key) = creds.redacted_key(syn.provider) {
150        output.line(&format!("API key:  {key}"))?;
151    }
152    if info.degraded {
153        output.line("Health:   degraded (no reachable backend)")?;
154    } else {
155        output.line(&format!("Health:   ok ({})", info.label))?;
156    }
157    Ok(())
158}
159
160fn cmd_setup(store_root: Option<&std::path::Path>, output: &dyn CliOutput) -> Result<()> {
161    output.line("Agent Trace — synthesis setup")?;
162    output.line("Providers: openai, anthropic, openrouter, ollama, custom")?;
163    output.line("Enter provider [ollama]: ")?;
164    let mut line = String::new();
165    io::stdin().read_line(&mut line)?;
166    let provider = if line.trim().is_empty() {
167        SynthesisProvider::Ollama
168    } else {
169        parse_provider(line.trim())?
170    };
171
172    let mut config = GlobalConfig::load()?;
173    config.synthesis.provider = provider;
174    config.synthesis.mode = SynthesisMode::Auto;
175
176    if SynthesisConfig::provider_needs_credentials(provider) {
177        output.line(&format!("Enter API key for {}: ", provider.slug()))?;
178        let key = read_secret()?;
179        let mut creds = CredentialsStore::load().unwrap_or_default();
180        creds.set_key(provider, key);
181        creds.save()?;
182    }
183
184    output.line(&format!("Model [{}]: ", provider.default_model()))?;
185    line.clear();
186    io::stdin().read_line(&mut line)?;
187    config.synthesis.model = if line.trim().is_empty() {
188        provider.default_model().into()
189    } else {
190        Llm::normalize_model_alias(line.trim())
191    };
192
193    if provider == SynthesisProvider::Custom || provider == SynthesisProvider::Ollama {
194        output.line(&format!("Base URL [{}]: ", provider.default_base_url()))?;
195        line.clear();
196        io::stdin().read_line(&mut line)?;
197        if !line.trim().is_empty() {
198            config.synthesis.base_url = Some(line.trim().into());
199        }
200    }
201
202    config.save()?;
203
204    // For Ollama providers, run ensure_ready to start daemon and pull model
205    if provider == SynthesisProvider::Ollama || provider == SynthesisProvider::Custom {
206        output.line("Ensuring Ollama daemon and model are ready…")?;
207        let merged = load_merged(store_root)?;
208        match Llm::ensure_ready(&merged) {
209            Ok(report) => {
210                for line in report.display().lines() {
211                    output.line(&format!("  {line}"))?;
212                }
213            }
214            Err(e) => output.warn(&format!("  Warning: {e} — run `agent-trace model ensure`"))?,
215        }
216    }
217
218    let merged = load_merged(store_root)?;
219    output.line(&synthesis_status_line(&merged))?;
220    output.line("Run `agent-trace model test` to verify.")?;
221    Ok(())
222}
223
224fn cmd_ensure(store_root: Option<&std::path::Path>, output: &dyn CliOutput) -> Result<()> {
225    let merged = load_merged(store_root)?;
226    output.line("Ensuring Ollama daemon and model are ready…")?;
227    let report = Llm::ensure_ready(&merged)?;
228    for line in report.display().lines() {
229        output.line(line)?;
230    }
231    Ok(())
232}
233
234fn read_secret() -> Result<String> {
235    let mut key = String::new();
236    io::stdin().read_line(&mut key)?;
237    Ok(key.trim().to_string())
238}
239
240fn cmd_use(provider: &str, output: &dyn CliOutput) -> Result<()> {
241    let p = parse_provider(provider)?;
242    let mut config = GlobalConfig::load()?;
243    config.synthesis.provider = p;
244    config.synthesis.model = p.default_model().into();
245    config.save()?;
246    output.line(&format!("Active provider set to {}", p.slug()))?;
247    Ok(())
248}
249
250fn cmd_set(
251    provider: Option<String>,
252    model: Option<String>,
253    base_url: Option<String>,
254    mode: Option<String>,
255    output: &dyn CliOutput,
256) -> Result<()> {
257    let mut config = GlobalConfig::load()?;
258    if let Some(p) = provider {
259        config.synthesis.provider = parse_provider(&p)?;
260    }
261    if let Some(m) = model {
262        config.synthesis.model = m;
263    }
264    if let Some(url) = base_url {
265        config.synthesis.base_url = Some(url);
266    }
267    if let Some(m) = mode {
268        config.synthesis.mode = parse_mode(&m)?;
269    }
270    config.save()?;
271    output.line("Synthesis config updated.")?;
272    Ok(())
273}
274
275fn cmd_credentials(sub: CredentialsCmd, output: &dyn CliOutput) -> Result<()> {
276    match sub {
277        CredentialsCmd::Set { provider } => {
278            let p = parse_provider(&provider)?;
279            if !SynthesisConfig::provider_needs_credentials(p) && p != SynthesisProvider::Custom {
280                bail!("Provider {} does not use stored credentials", p.slug());
281            }
282            output.line(&format!("API key for {}: ", p.slug()))?;
283            let key = read_secret()?;
284            let mut creds = CredentialsStore::load().unwrap_or_default();
285            creds.set_key(p, key);
286            creds.save()?;
287            output.line("Credentials saved.")?;
288        }
289        CredentialsCmd::Clear { provider } => {
290            let p = parse_provider(&provider)?;
291            let mut creds = CredentialsStore::load().unwrap_or_default();
292            creds.clear_key(p);
293            creds.save()?;
294            output.line(&format!("Cleared credentials for {}.", p.slug()))?;
295        }
296    }
297    Ok(())
298}
299
300fn cmd_test(store_root: Option<&std::path::Path>, output: &dyn CliOutput) -> Result<()> {
301    let merged = load_merged(store_root)?;
302    let api = Llm::from_merged_config(&merged)
303        .map_err(|e| anyhow::anyhow!("Cannot initialize LLM backend: {e}"))?;
304    let start = Instant::now();
305    let result = api.summarize_change(
306        std::path::Path::new("plan.md"),
307        &crate::types::DocType::Plan,
308        "+Added phase 2 checklist\n-Removed stale blocker\n",
309    );
310    let elapsed = start.elapsed();
311    match result {
312        Ok(text) => {
313            output.line(&format!("Backend: {}", api.backend_label))?;
314            output.line(&format!("Latency: {:.0}ms", elapsed.as_millis()))?;
315            output.line(&format!("Sample: {text}"))?;
316        }
317        Err(e) => bail!("Synthesis test failed: {e}"),
318    }
319    Ok(())
320}
321
322fn cmd_list(output: &dyn CliOutput) -> Result<()> {
323    output.line("Ollama (local):")?;
324    for m in [
325        "qwen2.5:0.5b",
326        "qwen2.5:1.5b",
327        "qwen2.5:3b",
328        "llama3.2:3b",
329        "phi4:latest",
330    ] {
331        output.line(&format!("  {m}"))?;
332    }
333    output.line("Aliases (short form for ollama pull):")?;
334    for (alias, full) in [
335        ("0.5b", "qwen2.5:0.5b"),
336        ("1.5b", "qwen2.5:1.5b"),
337        ("3b", "qwen2.5:3b"),
338    ] {
339        output.line(&format!("  {alias} → {full}"))?;
340    }
341    output.line("Remote examples:")?;
342    output.line("  openai/gpt-4o-mini")?;
343    output.line("  anthropic/claude-3-5-haiku-latest")?;
344    Ok(())
345}
346
347fn cmd_pull(size: &str, output: &dyn CliOutput) -> Result<()> {
348    let normalized = Llm::normalize_model_alias(size);
349    output.line(&format!("Pulling Ollama model {normalized}…"))?;
350    let config = GlobalConfig::load()?;
351    Llm::pull_model(&config.synthesis, &normalized)
352        .with_context(|| format!("ollama pull {normalized}"))?;
353    let mut config = GlobalConfig::load()?;
354    config.synthesis.provider = SynthesisProvider::Ollama;
355    config.synthesis.model = normalized;
356    config.save()?;
357    output.line("Ollama model pulled and config updated.")?;
358    Ok(())
359}
360
361fn cmd_serve_check(output: &dyn CliOutput) -> Result<()> {
362    // Diagnostic only — reports reachability and model presence without starting
363    // the daemon or pulling. Use `model ensure` for side-effectful readiness.
364    let config = GlobalConfig::load()?;
365    let syn = &config.synthesis;
366    let reachable = Llm::is_reachable(syn);
367    output.line(&format!(
368        "Ollama at {}: {}",
369        syn.effective_base_url(),
370        if reachable {
371            "reachable"
372        } else {
373            "unreachable"
374        }
375    ))?;
376    if reachable {
377        let pulled = Llm::is_model_pulled(syn).unwrap_or(false);
378        output.line(&format!(
379            "Model '{}': {}",
380            syn.effective_model(),
381            if pulled { "pulled" } else { "not pulled" }
382        ))?;
383        if !pulled {
384            output.line("  → Run `agent-trace model ensure` to pull the model automatically")?;
385        }
386    } else {
387        output.line("  → Run `agent-trace model ensure` to start the daemon and pull the model")?;
388    }
389    Ok(())
390}
391
#[cfg(test)]
mod tests {
    use super::*;

    /// Known provider names parse regardless of case; unknown names are rejected.
    #[test]
    fn parse_provider_variants() {
        let accepted = [
            ("openai", SynthesisProvider::Openai),
            ("OLLAMA", SynthesisProvider::Ollama),
        ];
        for (input, expected) in accepted {
            assert_eq!(parse_provider(input).unwrap(), expected);
        }
        assert!(parse_provider("nope").is_err());
    }

    /// The deprecated `embedded` provider name resolves to Ollama.
    #[test]
    fn parse_provider_embedded_deprecated_migrates_to_ollama() {
        let migrated = parse_provider("embedded").unwrap();
        assert_eq!(migrated, SynthesisProvider::Ollama);
    }

    /// The deprecated `embedded` mode name resolves to Auto.
    #[test]
    fn parse_mode_embedded_deprecated_migrates_to_auto() {
        let migrated = parse_mode("embedded").unwrap();
        assert_eq!(migrated, SynthesisMode::Auto);
    }

    /// Short size aliases expand to full qwen2.5 tags.
    #[test]
    fn pull_normalizes_short_alias() {
        let cases = [("1.5b", "qwen2.5:1.5b"), ("0.5b", "qwen2.5:0.5b")];
        for (alias, full) in cases {
            assert_eq!(Llm::normalize_model_alias(alias), full);
        }
    }
}