Skip to main content

roboticus_cli/cli/admin/
models.rs

1use super::*;
2
3// ── Models ───────────────────────────────────────────────────
4
5pub async fn cmd_models_list(base_url: &str, json: bool) -> Result<(), Box<dyn std::error::Error>> {
6    let (DIM, BOLD, ACCENT, GREEN, YELLOW, RED, CYAN, RESET, MONO) = colors();
7    let (OK, ACTION, WARN, DETAIL, ERR) = icons();
8    let resp = super::http_client()?
9        .get(format!("{base_url}/api/config"))
10        .send()
11        .await?;
12    let config: serde_json::Value = resp.json().await?;
13    if json {
14        println!("{}", serde_json::to_string_pretty(&config)?);
15        return Ok(());
16    }
17
18    println!("\n  {BOLD}Configured Models{RESET}\n");
19
20    let primary = config
21        .pointer("/models/primary")
22        .and_then(|v| v.as_str())
23        .unwrap_or("not set");
24    println!("  {:<12} {}", format!("{GREEN}primary{RESET}"), primary);
25
26    if let Some(fallbacks) = config
27        .pointer("/models/fallbacks")
28        .and_then(|v| v.as_array())
29    {
30        for (i, fb) in fallbacks.iter().enumerate() {
31            let name = fb.as_str().unwrap_or("?");
32            println!(
33                "  {:<12} {}",
34                format!("{YELLOW}fallback {}{RESET}", i + 1),
35                name
36            );
37        }
38    }
39
40    let mode = config
41        .pointer("/models/routing/mode")
42        .and_then(|v| v.as_str())
43        .unwrap_or("rule");
44    let threshold = config
45        .pointer("/models/routing/confidence_threshold")
46        .and_then(|v| v.as_f64())
47        .unwrap_or(0.9);
48    let local_first = config
49        .pointer("/models/routing/local_first")
50        .and_then(|v| v.as_bool())
51        .unwrap_or(true);
52
53    println!();
54    println!(
55        "  {DIM}Routing: mode={mode}, threshold={threshold}, local_first={local_first}{RESET}"
56    );
57    println!();
58    Ok(())
59}
60
61pub async fn cmd_models_scan(
62    base_url: &str,
63    provider: Option<&str>,
64) -> Result<(), Box<dyn std::error::Error>> {
65    let (DIM, BOLD, ACCENT, GREEN, YELLOW, RED, CYAN, RESET, MONO) = colors();
66    let (OK, ACTION, WARN, DETAIL, ERR) = icons();
67    println!("\n  {BOLD}Scanning for available models...{RESET}\n");
68
69    let resp = super::http_client()?
70        .get(format!("{base_url}/api/config"))
71        .send()
72        .await?;
73    let config: serde_json::Value = resp.json().await?;
74
75    let providers = config
76        .get("providers")
77        .and_then(|v| v.as_object())
78        .cloned()
79        .unwrap_or_default();
80
81    if providers.is_empty() {
82        println!("  No providers configured.");
83        println!();
84        return Ok(());
85    }
86
87    let client = reqwest::Client::builder()
88        .timeout(std::time::Duration::from_secs(10))
89        .build()?;
90
91    for (name, prov_config) in &providers {
92        if let Some(filter) = provider
93            && name != filter
94        {
95            continue;
96        }
97
98        let url = prov_config
99            .get("url")
100            .and_then(|v| v.as_str())
101            .unwrap_or("");
102
103        if url.is_empty() {
104            println!("  {YELLOW}{name}{RESET}: no URL configured");
105            continue;
106        }
107
108        let name_l = name.to_lowercase();
109        let url_l = url.to_lowercase();
110        let ollama_like = name_l.contains("ollama") || url_l.contains("11434");
111        let models_url = if ollama_like {
112            format!("{url}/api/tags")
113        } else {
114            format!("{url}/v1/models")
115        };
116
117        let scan_result =
118            super::spin_while(&format!("Probing {name}"), client.get(&models_url).send()).await;
119
120        print!("  {CYAN}{name}{RESET} ({url}): ");
121        match scan_result {
122            Ok(resp) if resp.status().is_success() => {
123                let body: serde_json::Value = resp.json().await.unwrap_or_default();
124                let models: Vec<String> =
125                    if let Some(arr) = body.get("models").and_then(|v| v.as_array()) {
126                        arr.iter()
127                            .filter_map(|m| {
128                                m.get("name")
129                                    .or_else(|| m.get("model"))
130                                    .and_then(|v| v.as_str())
131                            })
132                            .map(String::from)
133                            .collect()
134                    } else if let Some(arr) = body.get("data").and_then(|v| v.as_array()) {
135                        arr.iter()
136                            .filter_map(|m| m.get("id").and_then(|v| v.as_str()))
137                            .map(String::from)
138                            .collect()
139                    } else {
140                        vec![]
141                    };
142
143                if models.is_empty() {
144                    println!("no models found");
145                } else {
146                    println!("{} model(s)", models.len());
147                    for model in &models {
148                        println!("    - {model}");
149                    }
150                }
151            }
152            Ok(resp) => {
153                println!("{RED}error: {}{RESET}", resp.status());
154            }
155            Err(e) => {
156                println!("{RED}unreachable: {e}{RESET}");
157            }
158        }
159    }
160
161    println!();
162    Ok(())
163}
164
165/// Exercise a model across the task class matrix (5 complexity x 4 intent)
166/// to populate per-(model, intent_class) quality observations.
167///
168/// `iterations` controls how many full matrix passes to run. Each pass is
169/// 20 prompts. Use iterations=20 for 100 observations per intent class.
170pub async fn cmd_models_exercise(
171    base_url: &str,
172    model: &str,
173    iterations: usize,
174) -> Result<(), Box<dyn std::error::Error>> {
175    let (_dim, bold, _accent, green, yellow, red, cyan, reset, _mono) = colors();
176    let (ok, _action, warn, _detail, err) = icons();
177    let total_prompts = roboticus_llm::exercise::EXERCISE_MATRIX.len() * iterations;
178    println!(
179        "\n  {bold}Exercising model: {cyan}{model}{reset} ({iterations} iteration(s), {total_prompts} prompts)\n"
180    );
181
182    let (pass, fail) = exercise_single_model_iterations(base_url, model, iterations).await;
183
184    println!();
185    let fail_color = if fail > 0 { red } else { _dim };
186    println!(
187        "  {bold}Results:{reset} {green}{pass} passed{reset}, {fail_color}{fail} failed{reset}",
188    );
189    let obs_per_cell = iterations * 5; // 5 prompts per intent class
190    println!("  Observations per intent class: {obs_per_cell}");
191    if fail == 0 {
192        println!("  {ok} Quality observations recorded for all {pass} prompts.");
193    } else {
194        println!("  {warn} Some prompts failed — partial observations recorded.");
195    }
196    println!();
197    Ok(())
198}
199
200/// Suggest a fallback chain based on available providers and discovered models.
201pub async fn cmd_models_suggest(base_url: &str) -> Result<(), Box<dyn std::error::Error>> {
202    let (_dim, bold, _accent, green, _yellow, _red, cyan, reset, _mono) = colors();
203    let (_ok, _action, warn, _detail, _err) = icons();
204    println!("\n  {bold}Scanning for available models...{reset}\n");
205
206    let resp = super::http_client()?
207        .get(format!("{base_url}/api/config"))
208        .send()
209        .await?;
210    let config: serde_json::Value = resp.json().await?;
211
212    let providers = config
213        .get("providers")
214        .and_then(|v| v.as_object())
215        .cloned()
216        .unwrap_or_default();
217
218    if providers.is_empty() {
219        println!("  {warn} No providers configured. Nothing to suggest.");
220        println!();
221        return Ok(());
222    }
223
224    let client = reqwest::Client::builder()
225        .timeout(std::time::Duration::from_secs(10))
226        .build()?;
227
228    let mut available: Vec<(String, bool, f64)> = Vec::new();
229
230    for (name, prov_config) in &providers {
231        let url = prov_config
232            .get("url")
233            .and_then(|v| v.as_str())
234            .unwrap_or("");
235        if url.is_empty() {
236            continue;
237        }
238        let is_local = prov_config
239            .get("is_local")
240            .and_then(|v| v.as_bool())
241            .unwrap_or_else(|| {
242                let nl = name.to_lowercase();
243                nl.contains("ollama") || nl.contains("local") || nl.contains("lmstudio")
244            });
245        let cost = prov_config
246            .get("cost_per_input_token")
247            .and_then(|v| v.as_f64())
248            .unwrap_or(0.0)
249            + prov_config
250                .get("cost_per_output_token")
251                .and_then(|v| v.as_f64())
252                .unwrap_or(0.0);
253
254        let name_l = name.to_lowercase();
255        let url_l = url.to_lowercase();
256        let ollama_like = name_l.contains("ollama") || url_l.contains("11434");
257        let models_url = if ollama_like {
258            format!("{url}/api/tags")
259        } else {
260            format!("{url}/v1/models")
261        };
262
263        if let Ok(resp) = client.get(&models_url).send().await
264            && resp.status().is_success()
265        {
266            let body: serde_json::Value = resp.json().await.unwrap_or_default();
267            let models: Vec<String> =
268                if let Some(arr) = body.get("models").and_then(|v| v.as_array()) {
269                    arr.iter()
270                        .filter_map(|m| {
271                            m.get("name")
272                                .or_else(|| m.get("model"))
273                                .and_then(|v| v.as_str())
274                        })
275                        .map(|m| format!("{name}/{m}"))
276                        .collect()
277                } else if let Some(arr) = body.get("data").and_then(|v| v.as_array()) {
278                    arr.iter()
279                        .filter_map(|m| m.get("id").and_then(|v| v.as_str()))
280                        .map(|m| format!("{name}/{m}"))
281                        .collect()
282                } else {
283                    vec![]
284                };
285
286            for model in models {
287                available.push((model, is_local, cost));
288            }
289        }
290    }
291
292    if available.is_empty() {
293        println!("  {warn} No models discovered from any provider.");
294        println!();
295        return Ok(());
296    }
297
298    // Rank: local models first, then cloud by cost ascending.
299    available.sort_by(|a, b| {
300        b.1.cmp(&a.1)
301            .then(a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal))
302    });
303
304    println!("  {bold}Suggested fallback chain:{reset}\n");
305    for (i, (model, is_local, _cost)) in available.iter().take(6).enumerate() {
306        let role = if i == 0 {
307            "primary  ".to_string()
308        } else {
309            format!("fallback{i}")
310        };
311        let locality = if *is_local {
312            format!("{green}local{reset}")
313        } else {
314            format!("{cyan}cloud{reset}")
315        };
316        println!("  {role:<10} {model}  ({locality})");
317    }
318
319    println!("\n  {_dim}TOML:{reset}\n");
320    if let Some((primary, _, _)) = available.first() {
321        println!("  [models]");
322        println!("  primary = \"{primary}\"");
323        let fallbacks: Vec<&str> = available
324            .iter()
325            .skip(1)
326            .take(5)
327            .map(|(m, _, _)| m.as_str())
328            .collect();
329        if !fallbacks.is_empty() {
330            println!("  fallbacks = {fallbacks:?}");
331        }
332    }
333
334    println!();
335    Ok(())
336}
337
338/// Reset quality observations for a model (or all) to allow re-benchmarking.
339pub async fn cmd_models_reset(
340    base_url: &str,
341    model: Option<&str>,
342) -> Result<(), Box<dyn std::error::Error>> {
343    let (_dim, bold, _accent, green, _yellow, _red, _cyan, reset, _mono) = colors();
344    let (ok, _action, _warn, _detail, _err) = icons();
345    let client = super::http_client()?;
346    let mut req = client.post(format!("{base_url}/api/models/reset"));
347    if let Some(m) = model {
348        req = req.query(&[("model", m)]);
349    }
350    let resp = req.send().await?;
351    let data: serde_json::Value = resp.json().await?;
352    let msg = data["message"].as_str().unwrap_or("done");
353    println!("\n  {bold}{ok}{reset} {green}{msg}{reset}\n");
354    if model.is_some() {
355        println!(
356            "  Run {bold}roboticus models exercise {}{reset} to re-benchmark.",
357            model.unwrap_or("?")
358        );
359    } else {
360        println!("  Run {bold}roboticus models exercise <model>{reset} per model to re-benchmark.");
361    }
362    println!();
363    Ok(())
364}
365
366/// Full baseline: flush all scores, scan providers, exercise every model.
367pub async fn cmd_models_baseline(base_url: &str) -> Result<(), Box<dyn std::error::Error>> {
368    let (_dim, bold, _accent, green, yellow, red, cyan, reset, _mono) = colors();
369    let (ok, _action, warn, _detail, err) = icons();
370
371    // Step 1: Discover available models
372    println!("\n  {bold}Step 1: Discovering available models...{reset}\n");
373    let resp = super::http_client()?
374        .get(format!("{base_url}/api/config"))
375        .send()
376        .await?;
377    let config: serde_json::Value = resp.json().await?;
378    let mut configured: Vec<String> = Vec::new();
379    if let Some(primary) = config.pointer("/models/primary").and_then(|v| v.as_str()) {
380        configured.push(primary.to_string());
381    }
382    if let Some(fbs) = config
383        .pointer("/models/fallbacks")
384        .and_then(|v| v.as_array())
385    {
386        for fb in fbs {
387            if let Some(name) = fb.as_str()
388                && !name.is_empty()
389                && !configured.contains(&name.to_string())
390            {
391                configured.push(name.to_string());
392            }
393        }
394    }
395
396    if configured.is_empty() {
397        println!("  {warn} No models configured. Nothing to baseline.");
398        return Ok(());
399    }
400
401    println!(
402        "  Found {bold}{}{reset} configured model(s):\n",
403        configured.len()
404    );
405    for (i, model) in configured.iter().enumerate() {
406        let role = if i == 0 { "primary" } else { "fallback" };
407        println!("    {cyan}{role:<10}{reset} {model}");
408    }
409
410    // Step 2: Confirm with user
411    println!();
412    print!(
413        "  This will flush all quality scores and exercise each model \
414         across 20 prompts.\n  Proceed? [Y/n] "
415    );
416    use std::io::Write;
417    std::io::stdout().flush().ok();
418    let mut input = String::new();
419    std::io::stdin().read_line(&mut input).ok();
420    let answer = input.trim().to_lowercase();
421    if !answer.is_empty() && !matches!(answer.as_str(), "y" | "yes") {
422        println!("  Cancelled.");
423        return Ok(());
424    }
425
426    // Step 3: Flush all scores
427    println!("\n  {bold}Step 2: Flushing all quality scores...{reset}");
428    let resp = super::http_client()?
429        .post(format!("{base_url}/api/models/reset"))
430        .send()
431        .await?;
432    let data: serde_json::Value = resp.json().await?;
433    let cleared = data["cleared"].as_u64().unwrap_or(0);
434    println!("  {ok} Cleared {cleared} observation entries.\n");
435
436    // Step 4: Exercise each model
437    println!("  {bold}Step 3: Exercising models...{reset}\n");
438    let mut results: Vec<(String, usize, usize)> = Vec::new();
439    for model in &configured {
440        println!("  {cyan}--- {model} ---{reset}");
441        let (pass, fail) = exercise_single_model_iterations(base_url, model, 20).await;
442        results.push((model.clone(), pass, fail));
443        println!();
444    }
445
446    // Step 5: Summary
447    println!("  {bold}Baseline Results:{reset}\n");
448    for (model, pass, fail) in &results {
449        let status = if *fail == 0 {
450            format!("{green}{ok}{reset}")
451        } else {
452            format!("{yellow}{warn}{reset}")
453        };
454        println!(
455            "    {status} {model}: {green}{pass} passed{reset}, {}{fail} failed{reset}",
456            if *fail > 0 { red } else { _dim }
457        );
458    }
459    println!();
460    Ok(())
461}
462
463async fn exercise_single_model_iterations(
464    base_url: &str,
465    model: &str,
466    iterations: usize,
467) -> (usize, usize) {
468    let (_dim, bold, _accent, green, _yellow, red, _cyan, reset, _mono) = colors();
469    let (ok, _action, _warn, _detail, err) = icons();
470    let matrix = roboticus_llm::exercise::EXERCISE_MATRIX;
471    // Use a longer timeout than the default 10s — local models can take 60-120s
472    let client = match reqwest::Client::builder()
473        .timeout(std::time::Duration::from_secs(180))
474        .build()
475    {
476        Ok(c) => c,
477        Err(_) => return (0, matrix.len() * iterations),
478    };
479    let mut pass = 0usize;
480    let mut fail = 0usize;
481    let total = matrix.len() * iterations;
482
483    // Create a dedicated session for exercising (avoids scope_mode errors)
484    let session_id: String = match client
485        .post(format!("{base_url}/api/sessions"))
486        .json(&serde_json::json!({}))
487        .send()
488        .await
489    {
490        Ok(resp) => resp
491            .json::<serde_json::Value>()
492            .await
493            .ok()
494            .and_then(|v| {
495                v.get("session_id")
496                    .or_else(|| v.get("id"))
497                    .and_then(|s| s.as_str())
498                    .map(String::from)
499            })
500            .unwrap_or_default(),
501        Err(_) => String::new(),
502    };
503
504    // Per-intent latency tracking: (intent_class, Vec<latency_ms>)
505    let mut latencies: std::collections::HashMap<String, Vec<u64>> =
506        std::collections::HashMap::new();
507
508    for iter in 0..iterations {
509        for (i, prompt) in matrix.iter().enumerate() {
510            let n = iter * matrix.len() + i + 1;
511            let label = format!(
512                "[{n}/{total}] {}:{}",
513                prompt.complexity, prompt.intent_class
514            );
515            eprint!("    {_dim}{label}{reset} ... ");
516
517            let mut body = serde_json::json!({
518                "content": prompt.prompt,
519                "model_override": model,
520            });
521            if !session_id.is_empty() {
522                body["session_id"] = serde_json::Value::String(session_id.clone());
523            }
524
525            let started = std::time::Instant::now();
526            let result = tokio::time::timeout(
527                std::time::Duration::from_secs(600),
528                client
529                    .post(format!("{base_url}/api/agent/message"))
530                    .json(&body)
531                    .send(),
532            )
533            .await;
534            let elapsed_ms = started.elapsed().as_millis() as u64;
535
536            match result {
537                Ok(Ok(resp)) if resp.status().is_success() => {
538                    pass += 1;
539                    latencies
540                        .entry(prompt.intent_class.to_string())
541                        .or_default()
542                        .push(elapsed_ms);
543                    let secs = elapsed_ms as f64 / 1000.0;
544                    eprintln!("{green}{ok}{reset} {_dim}{secs:.1}s{reset}");
545                }
546                Ok(Ok(resp)) => {
547                    fail += 1;
548                    let status = resp.status();
549                    eprintln!("{red}{err} {status}{reset}");
550                }
551                Ok(Err(e)) => {
552                    fail += 1;
553                    eprintln!("{red}{err} {e}{reset}");
554                }
555                Err(_) => {
556                    fail += 1;
557                    eprintln!("{red}{err} timeout (>600s){reset}");
558                }
559            }
560        }
561    }
562
563    // Print latency scorecard
564    if !latencies.is_empty() {
565        eprintln!();
566        eprintln!("    {_dim}┌──────────────────┬────────┬────────┬────────┐{reset}");
567        eprintln!("    {_dim}│ Intent Class     │  Avg   │  P50   │  P95   │{reset}");
568        eprintln!("    {_dim}├──────────────────┼────────┼────────┼────────┤{reset}");
569        let mut all_latencies: Vec<u64> = Vec::new();
570        let mut intents: Vec<_> = latencies.iter().collect();
571        intents.sort_by_key(|(k, _)| (*k).clone());
572        for (intent, times) in &intents {
573            all_latencies.extend(times.iter().copied());
574            let mut sorted = (*times).clone();
575            sorted.sort();
576            let avg = sorted.iter().sum::<u64>() as f64 / sorted.len() as f64 / 1000.0;
577            let p50 = sorted[sorted.len() / 2] as f64 / 1000.0;
578            let p95_idx = (sorted.len() as f64 * 0.95) as usize;
579            let p95 = sorted[p95_idx.min(sorted.len() - 1)] as f64 / 1000.0;
580            eprintln!(
581                "    {_dim}│{reset} {:<16} {_dim}│{reset} {avg:5.1}s {_dim}│{reset} {p50:5.1}s {_dim}│{reset} {p95:5.1}s {_dim}│{reset}",
582                intent
583            );
584        }
585        all_latencies.sort();
586        if !all_latencies.is_empty() {
587            let avg_all =
588                all_latencies.iter().sum::<u64>() as f64 / all_latencies.len() as f64 / 1000.0;
589            let p50_all = all_latencies[all_latencies.len() / 2] as f64 / 1000.0;
590            let p95_idx = (all_latencies.len() as f64 * 0.95) as usize;
591            let p95_all = all_latencies[p95_idx.min(all_latencies.len() - 1)] as f64 / 1000.0;
592            eprintln!("    {_dim}├──────────────────┼────────┼────────┼────────┤{reset}");
593            eprintln!(
594                "    {_dim}│{reset} {bold}ALL{reset}              {_dim}│{reset} {avg_all:5.1}s {_dim}│{reset} {p50_all:5.1}s {_dim}│{reset} {p95_all:5.1}s {_dim}│{reset}"
595            );
596        }
597        eprintln!("    {_dim}└──────────────────┴────────┴────────┴────────┘{reset}");
598    }
599
600    (pass, fail)
601}