Skip to main content

llm_manager/backend/
server.rs

1use std::fmt::Display;
2use std::process::Stdio;
3use tokio::io::{AsyncBufReadExt, BufReader};
4use tokio::process::Command;
5use tokio::sync::mpsc;
6use tracing::{info, warn};
7
8use crate::config::{Config, DEFAULT_SYSTEM_PROMPT};
9use crate::models::{
10    DiscoveredModel, ModelSettings, RopeScaling, ServerMetrics, clean_host, strip_gguf,
11};
12
13/// Manages a single llama.cpp server process.
14#[derive(Clone)]
15pub struct ServerHandle {
16    pub port: u16,
17    pub host: String,
18    pub pid: u32,
19    pub kill_tx: mpsc::Sender<()>,
20}
21
22/// Helper: add an argument to both the Command and the display parts list.
23fn push_arg(cmd: &mut Command, parts: &mut Vec<String>, name: &str, value: impl Display) {
24    let val_str = value.to_string();
25    cmd.arg(name).arg(&val_str);
26    parts.push(name.to_string());
27    parts.push(val_str);
28}
29
30/// Helper: add a flag (argument without value) to both the Command and display parts.
31fn push_flag(cmd: &mut Command, parts: &mut Vec<String>, name: &str) {
32    cmd.arg(name);
33    parts.push(name.to_string());
34}
35
36fn push_gpu_layers(cmd: &mut Command, parts: &mut Vec<String>, settings: &ModelSettings) {
37    match settings.gpu_layers_mode {
38        crate::models::GpuLayersMode::Specific(n) => push_arg(cmd, parts, "-ngl", n),
39        crate::models::GpuLayersMode::All => push_arg(cmd, parts, "-ngl", "999"),
40        crate::models::GpuLayersMode::Auto => {}
41    }
42}
43
44fn push_spec_decoding(cmd: &mut Command, parts: &mut Vec<String>, settings: &ModelSettings) {
45    if !settings.spec_type.is_empty() {
46        push_arg(cmd, parts, "--spec-type", &settings.spec_type);
47        if settings.draft_tokens > 0 {
48            push_arg(cmd, parts, "--spec-draft-n-max", settings.draft_tokens);
49        }
50    }
51}
52
53/// Build the full llama-server command line from settings.
54/// Returns (Command, display_string) where the string is suitable for logging.
55pub fn build_server_cmd(
56    binary: &std::path::Path,
57    model: Option<&DiscoveredModel>,
58    settings: &ModelSettings,
59    config: &Config,
60    server_mode: crate::models::ServerMode,
61    router_max_models: u32,
62) -> (Command, String) {
63    let mut cmd = Command::new(binary);
64    let mut parts: Vec<String> = vec![binary.display().to_string()];
65
66    // ── Model ───────────────────────────────────────────────
67    match server_mode {
68        crate::models::ServerMode::Normal => {
69            if let Some(model) = model {
70                push_arg(&mut cmd, &mut parts, "-m", model.path.display());
71                // Add alias for router mode identification (uses the unique relative path)
72                push_arg(&mut cmd, &mut parts, "--alias", &model.display_name);
73            }
74        }
75        crate::models::ServerMode::Router => {
76            // Router mode: no model in CLI, use /load API to load models
77            if router_max_models > 0 {
78                push_arg(&mut cmd, &mut parts, "--models-max", router_max_models);
79            }
80            // Always pass --models-dir in router mode (global config setting)
81            if let Some(dir) = config.models_dirs.first() {
82                push_arg(&mut cmd, &mut parts, "--models-dir", dir.display());
83            }
84        }
85        crate::models::ServerMode::Bench => {
86            // Should not be reached as Bench uses build_bench_cmd
87        }
88        crate::models::ServerMode::BenchTune => {
89            // Should not be reached as BenchTune uses benchmark tuning function
90        }
91    }
92
93    // Parse GGUF metadata for arch-specific override-kv keys
94    let gguf_meta = model
95        .map(|m| crate::models::GgufMetadata::from_path(&m.path))
96        .transpose();
97
98    // ── Loading ──────────────────────────────────────────────
99    push_arg(&mut cmd, &mut parts, "--threads", settings.threads);
100    push_arg(
101        &mut cmd,
102        &mut parts,
103        "--threads-batch",
104        settings.threads_batch,
105    );
106    let effective_ctx = (settings.context_length as f32 * settings.rope_scale) as u32;
107    push_arg(&mut cmd, &mut parts, "--ctx-size", effective_ctx);
108    push_arg(&mut cmd, &mut parts, "--ubatch-size", settings.ubatch_size);
109    if let Some(n) = settings.max_concurrent_predictions {
110        push_arg(&mut cmd, &mut parts, "--parallel", n);
111    }
112
113    push_flag(&mut cmd, &mut parts, "--no-warmup");
114
115    push_spec_decoding(&mut cmd, &mut parts, settings);
116
117    if let Some(cache_k) = settings.cache_type_k {
118        push_arg(&mut cmd, &mut parts, "--cache-type-k", cache_k);
119    }
120    if let Some(cache_v) = settings.cache_type_v {
121        push_arg(&mut cmd, &mut parts, "--cache-type-v", cache_v);
122    }
123
124    if settings.keep != 0 {
125        push_arg(&mut cmd, &mut parts, "--keep", settings.keep);
126    }
127    if settings.swa_full {
128        push_flag(&mut cmd, &mut parts, "--swa-full");
129    }
130    if settings.mlock {
131        push_flag(&mut cmd, &mut parts, "--mlock");
132    }
133    if !settings.mmap {
134        push_flag(&mut cmd, &mut parts, "--no-mmap");
135    }
136    if settings.numa != Default::default() {
137        push_arg(&mut cmd, &mut parts, "--numa", settings.numa.to_string());
138    }
139    if settings.kv_cache_offload {
140        push_flag(&mut cmd, &mut parts, "--kv-offload");
141    }
142
143    // ── GPU ──────────────────────────────────────────────────
144    push_gpu_layers(&mut cmd, &mut parts, settings);
145
146    if settings.split_mode != Default::default() {
147        push_arg(
148            &mut cmd,
149            &mut parts,
150            "--split-mode",
151            settings.split_mode.to_string(),
152        );
153    }
154    if !settings.tensor_split.is_empty() {
155        push_arg(
156            &mut cmd,
157            &mut parts,
158            "--tensor-split",
159            &settings.tensor_split,
160        );
161    }
162    if settings.main_gpu != 0 {
163        push_arg(&mut cmd, &mut parts, "--main-gpu", settings.main_gpu);
164    }
165    if settings.fit {
166        push_arg(&mut cmd, &mut parts, "--fit", "on");
167    }
168
169    if let Some(ref lora) = settings.lora {
170        push_arg(&mut cmd, &mut parts, "--lora", lora.display());
171    }
172    if let Some((ref lora, scale)) = settings.lora_scaled {
173        push_arg(
174            &mut cmd,
175            &mut parts,
176            "--lora-scaled",
177            format!("{}:{}", lora.display(), scale),
178        );
179    }
180
181    let mut rpc_list = Vec::new();
182    if !settings.rpc.is_empty() {
183        rpc_list.push(settings.rpc.clone());
184    }
185    for worker in &config.rpc_workers {
186        if worker.selected {
187            rpc_list.push(format!("{}:{}", worker.ip, worker.port));
188        }
189    }
190
191    if !rpc_list.is_empty() {
192        let joined_rpc = rpc_list.join(",");
193        push_arg(&mut cmd, &mut parts, "--rpc", joined_rpc);
194    }
195
196    if settings.embedding {
197        push_flag(&mut cmd, &mut parts, "--embedding");
198    }
199
200    if settings.expert_count > 0 {
201        let arch = gguf_meta
202            .as_ref()
203            .ok()
204            .and_then(|opt| opt.as_ref())
205            .map(|m| m.arch.as_str())
206            .unwrap_or("llama");
207        push_arg(
208            &mut cmd,
209            &mut parts,
210            "--override-kv",
211            format!("{}.expert_used_count=int:int:{}", arch, settings.expert_count),
212        );
213    }
214
215    push_arg(
216        &mut cmd,
217        &mut parts,
218        "-fa",
219        if settings.flash_attn { "on" } else { "off" },
220    );
221
222    if settings.jinja {
223        push_flag(&mut cmd, &mut parts, "--jinja");
224    }
225
226    if let Some(ref template) = settings.chat_template {
227        push_arg(&mut cmd, &mut parts, "--chat-template", template);
228    }
229
230    // Inject system prompt via chat template kwargs when it differs from default
231    if settings.system_prompt != DEFAULT_SYSTEM_PROMPT {
232        let mut merged = serde_json::Map::new();
233        if let Some(ref kwargs) = settings.chat_template_kwargs
234            && let Ok(obj) = serde_json::from_str::<serde_json::Value>(kwargs)
235                && let serde_json::Value::Object(map) = obj {
236                    for (k, v) in map {
237                        merged.insert(k, v);
238                    }
239                }
240        merged.insert(
241            "system_prompt".to_string(),
242            serde_json::Value::String(settings.system_prompt.clone()),
243        );
244        push_arg(
245            &mut cmd,
246            &mut parts,
247            "--chat-template-kwargs",
248            serde_json::to_string(&merged).unwrap(),
249        );
250    } else if let Some(ref kwargs) = settings.chat_template_kwargs {
251        push_arg(&mut cmd, &mut parts, "--chat-template-kwargs", kwargs);
252    }
253
254    // ── Sampling ─────────────────────────────────────────────
255    if settings.seed != -1 {
256        push_arg(&mut cmd, &mut parts, "--seed", settings.seed);
257    }
258    if let Some(max_tokens) = settings.max_tokens {
259        push_arg(&mut cmd, &mut parts, "--n-predict", max_tokens);
260    }
261    push_arg(
262        &mut cmd,
263        &mut parts,
264        "--temp",
265        format!("{:.2}", settings.temperature),
266    );
267
268    push_arg(&mut cmd, &mut parts, "--top-k", settings.top_k);
269
270    push_arg(
271        &mut cmd,
272        &mut parts,
273        "--top-p",
274        format!("{:.2}", settings.top_p),
275    );
276
277    push_arg(
278        &mut cmd,
279        &mut parts,
280        "--min-p",
281        format!("{:.2}", settings.min_p),
282    );
283
284    push_arg(
285        &mut cmd,
286        &mut parts,
287        "--typical",
288        format!("{:.2}", settings.typical_p),
289    );
290
291    if settings.mirostat != Default::default() {
292        push_arg(
293            &mut cmd,
294            &mut parts,
295            "--mirostat",
296            settings.mirostat.to_string(),
297        );
298        push_arg(
299            &mut cmd,
300            &mut parts,
301            "--mirostat-lr",
302            format!("{:.2}", settings.mirostat_lr),
303        );
304        push_arg(
305            &mut cmd,
306            &mut parts,
307            "--mirostat-ent",
308            format!("{:.2}", settings.mirostat_ent),
309        );
310    }
311
312    if settings.ignore_eos {
313        push_flag(&mut cmd, &mut parts, "--ignore-eos");
314    }
315
316    if !settings.samplers.0.is_empty() {
317        push_arg(
318            &mut cmd,
319            &mut parts,
320            "--samplers",
321            settings.samplers.to_string(),
322        );
323    }
324
325    if let Some(frequency) = settings.frequency_penalty {
326        push_arg(
327            &mut cmd,
328            &mut parts,
329            "--frequency-penalty",
330            format!("{:.2}", frequency),
331        );
332    }
333
334    if settings.dry_multiplier != 0.0 {
335        push_arg(
336            &mut cmd,
337            &mut parts,
338            "--dry-multiplier",
339            format!("{:.2}", settings.dry_multiplier),
340        );
341        push_arg(
342            &mut cmd,
343            &mut parts,
344            "--dry-base",
345            format!("{:.2}", settings.dry_base),
346        );
347        push_arg(
348            &mut cmd,
349            &mut parts,
350            "--dry-allowed-length",
351            settings.dry_allowed_length,
352        );
353        push_arg(
354            &mut cmd,
355            &mut parts,
356            "--dry-penalty-last-n",
357            settings.dry_penalty_last_n,
358        );
359    }
360
361    // ── RoPE ─────────────────────────────────────────────────
362    let rope_scaling = if settings.rope_yarn_enabled {
363        RopeScaling::Yarn
364    } else {
365        settings.rope_scaling
366    };
367    if rope_scaling != Default::default() {
368        push_arg(
369            &mut cmd,
370            &mut parts,
371            "--rope-scaling",
372            rope_scaling.to_string(),
373        );
374    }
375    if settings.rope_scale != 1.0 {
376        push_arg(
377            &mut cmd,
378            &mut parts,
379            "--rope-scale",
380            format!("{:.2}", settings.rope_scale),
381        );
382    }
383    if settings.rope_freq_base != 0.0 {
384        push_arg(
385            &mut cmd,
386            &mut parts,
387            "--rope-freq-base",
388            format!("{:.2}", settings.rope_freq_base),
389        );
390    }
391    if settings.rope_freq_scale != 1.0 {
392        push_arg(
393            &mut cmd,
394            &mut parts,
395            "--rope-freq-scale",
396            format!("{:.2}", settings.rope_freq_scale),
397        );
398    }
399
400    if settings.rope_yarn_enabled && settings.rope_scale > 1.0 {
401        if let Some(ref meta) = gguf_meta.as_ref().ok().and_then(|x| x.as_ref()) {
402            push_arg(
403                &mut cmd,
404                &mut parts,
405                "--override-kv",
406                format!("{}.context_length=int:{}", meta.arch, effective_ctx),
407            );
408            let orig_ctx = meta.n_ctx_train;
409            push_arg(
410                &mut cmd,
411                &mut parts,
412                "--yarn-orig-ctx",
413                orig_ctx,
414            );
415        }
416    }
417
418    let resolved_host = clean_host(&settings.host);
419    push_arg(&mut cmd, &mut parts, "--host", resolved_host);
420    push_arg(&mut cmd, &mut parts, "--port", settings.port);
421    push_arg(&mut cmd, &mut parts, "--timeout", settings.timeout);
422
423    push_flag(&mut cmd, &mut parts, "--metrics");
424    if !settings.cache_prompt {
425        push_flag(&mut cmd, &mut parts, "--no-cache-prompt");
426    }
427    if settings.cache_reuse != 0 {
428        push_arg(&mut cmd, &mut parts, "--cache-reuse", settings.cache_reuse);
429    }
430    if !settings.webui {
431        push_flag(&mut cmd, &mut parts, "--no-webui");
432    }
433
434    // ── General ──────────────────────────────────────────────
435
436    let display = parts.join(" ");
437    (cmd, display)
438}
439
440/// Build the full llama-bench command line.
441pub fn build_bench_cmd(
442    binary: &std::path::Path,
443    model: &DiscoveredModel,
444    settings: &ModelSettings,
445) -> (Command, String) {
446    let mut cmd = Command::new(binary);
447    let mut parts: Vec<String> = vec![binary.display().to_string()];
448
449    push_arg(&mut cmd, &mut parts, "-m", model.path.display());
450    push_arg(&mut cmd, &mut parts, "-t", settings.threads);
451    push_arg(&mut cmd, &mut parts, "-b", settings.batch_size);
452
453    push_gpu_layers(&mut cmd, &mut parts, settings);
454
455    if settings.flash_attn {
456        push_arg(&mut cmd, &mut parts, "-fa", "1");
457    }
458
459    push_spec_decoding(&mut cmd, &mut parts, settings);
460
461    push_flag(&mut cmd, &mut parts, "--progress");
462
463    let display = parts.join(" ");
464    (cmd, display)
465}
466
467/// Spawn a llama.cpp server process (single model or router).
468/// Returns (ServerHandle, command_string) where command_string is the full CLI.
469pub struct SpawnServerRequest<'a> {
470    pub config: &'a Config,
471    pub model: Option<&'a DiscoveredModel>,
472    pub settings: &'a ModelSettings,
473    pub log_tx: mpsc::Sender<String>,
474    pub progress_tx: Option<tokio::sync::broadcast::Sender<crate::models::DownloadState>>,
475    pub server_mode: crate::models::ServerMode,
476    pub router_max_models: u32,
477    pub exit_tx: mpsc::Sender<()>,
478}
479
480pub async fn spawn_server(
481    req: SpawnServerRequest<'_>,
482) -> Result<(ServerHandle, String), String> {
483    let SpawnServerRequest {
484        config,
485        model,
486        settings,
487        log_tx,
488        progress_tx,
489        server_mode,
490        router_max_models,
491        exit_tx,
492    } = req;
493    if server_mode != crate::models::ServerMode::Bench
494        && server_mode != crate::models::ServerMode::BenchTune
495    {
496        let port = settings.port;
497        // Check if port is already in use (bind to the same host the server will use)
498        let resolved_host = clean_host(&settings.host);
499        if std::net::TcpListener::bind(format!("{}:{}", resolved_host, port)).is_err() {
500            return Err(format!("Port {} is already in use", port));
501        }
502    }
503
504    // BenchTune mode is handled separately in app.start_pending_spawn()
505    // and should never reach this function.
506    if server_mode == crate::models::ServerMode::BenchTune {
507        unreachable!("BenchTune mode must be handled before calling spawn_server")
508    }
509
510    // Resolve the backend binary (downloads if needed)
511    let backend_name = if server_mode == crate::models::ServerMode::Bench {
512        "llama-bench"
513    } else {
514        "llama-server"
515    };
516    let version_display = settings.get_active_backend_version_display();
517    info!(
518        "spawn_server: backend={}, requested_version={:?}, version_display={}",
519        settings.backend,
520        settings.get_active_backend_version(),
521        version_display
522    );
523    log_tx
524        .send(format!(
525            "Resolving {} (v{}) binary...",
526            backend_name, version_display
527        ))
528        .await
529        .ok();
530    let version_param = settings.get_active_backend_version().map(|s| s.as_str());
531
532    let server_binary = match crate::backend::hub::resolve_backend_binary(
533        settings.backend,
534        version_param,
535        Some(log_tx.clone()),
536        progress_tx,
537    )
538    .await
539    {
540        Ok(path) => {
541            info!("spawn_server: resolved binary path={}", path.display());
542            path
543        }
544        Err(e) => {
545            return Err(format!("Failed to resolve backend binary: {}", e));
546        }
547    };
548
549    let binary = if server_mode == crate::models::ServerMode::Bench {
550        server_binary.parent().unwrap().join("llama-bench")
551    } else {
552        server_binary
553    };
554
555    if !binary.exists() {
556        return Err(format!("Binary not found at: {}", binary.display()));
557    }
558    #[cfg(unix)]
559    {
560        use std::os::unix::fs::PermissionsExt;
561        if let Ok(metadata) = binary.metadata()
562            && metadata.permissions().mode() & 0o111 == 0
563        {
564            return Err(format!("Binary is not executable: {}", binary.display()));
565        }
566    }
567
568    let (mut cmd, cmd_string) = if server_mode == crate::models::ServerMode::Bench {
569        if let Some(m) = model {
570            build_bench_cmd(&binary, m, settings)
571        } else {
572            return Err("Model required for benchmark".to_string());
573        }
574    } else {
575        build_server_cmd(
576            &binary,
577            model,
578            settings,
579            config,
580            server_mode,
581            router_max_models,
582        )
583    };
584
585    cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
586
587    // Set platform-specific env vars so the binary can find its shared libraries
588    let bin_dir = binary.parent().unwrap();
589    match std::env::consts::OS {
590        "windows" => {
591            // On Windows, add bin_dir to PATH so llama-server.exe finds libllama.dll
592            if let Ok(current) = std::env::var("PATH") {
593                cmd.env("PATH", format!("{};{}", bin_dir.display(), current));
594            } else {
595                cmd.env("PATH", bin_dir);
596            }
597        }
598        "macos" => {
599            // On macOS, set DYLD_LIBRARY_PATH for dylib loading
600            if let Ok(current) = std::env::var("DYLD_LIBRARY_PATH") {
601                cmd.env(
602                    "DYLD_LIBRARY_PATH",
603                    format!("{}:{}", bin_dir.display(), current),
604                );
605            } else {
606                cmd.env("DYLD_LIBRARY_PATH", bin_dir);
607            }
608        }
609        _ => {
610            // On Linux, set LD_LIBRARY_PATH for so loading
611            if let Ok(current) = std::env::var("LD_LIBRARY_PATH") {
612                cmd.env(
613                    "LD_LIBRARY_PATH",
614                    format!("{}:{}", bin_dir.display(), current),
615                );
616            } else {
617                cmd.env("LD_LIBRARY_PATH", bin_dir);
618            }
619        }
620    }
621
622    info!("Spawning: {}", cmd_string);
623    let _ = log_tx
624        .send(format!("{}: {}", backend_name, cmd_string))
625        .await;
626    let mut child = cmd
627        .spawn()
628        .map_err(|e| format!("Failed to spawn process: {}", e))?;
629    let pid = child.id().unwrap_or(0);
630
631    let (kill_tx, mut kill_rx) = mpsc::channel(1);
632
633    // Background task: read stdout and stderr concurrently via separate tasks.
634    // Each stream gets its own task + mpsc channel so neither can block the other.
635    let log_tx_inner = log_tx.clone();
636    let exit_tx_inner = exit_tx.clone();
637    let backend_name_upper = backend_name.to_uppercase();
638    tokio::spawn(async move {
639        let stdout = child.stdout.take().unwrap();
640        let stderr = child.stderr.take().unwrap();
641
642        let (stdout_tx, mut stdout_rx) = mpsc::channel::<String>(64);
643        let (stderr_tx, mut stderr_rx) = mpsc::channel::<String>(64);
644
645        // Spawn a reader task for each stream
646        let mut std_out = Some(tokio::spawn(async move {
647            let reader = BufReader::new(stdout).lines();
648            tokio::pin!(reader);
649            while let Ok(Some(line)) = reader.next_line().await {
650                if stdout_tx.send(line).await.is_err() {
651                    break;
652                }
653            }
654        }));
655
656        let mut std_err = Some(tokio::spawn(async move {
657            let reader = BufReader::new(stderr).lines();
658            tokio::pin!(reader);
659            while let Ok(Some(line)) = reader.next_line().await {
660                if stderr_tx.send(line).await.is_err() {
661                    break;
662                }
663            }
664        }));
665
666        // Merge loop: block on whichever channel has data.
667        // When both are empty, select! sleeps with zero CPU cost.
668        loop {
669            tokio::select! {
670                _ = kill_rx.recv() => {
671                    let _ = child.kill().await;
672                    if let Some(h) = std_out.take() { let _ = h.await; }
673                    if let Some(h) = std_err.take() { let _ = h.await; }
674                    break;
675                }
676                line = stdout_rx.recv() => {
677                    if let Some(line) = line { let _ = log_tx_inner.send(line).await; } else { break; }
678                }
679                line = stderr_rx.recv() => {
680                    if let Some(line) = line { let _ = log_tx_inner.send(line).await; } else { break; }
681                }
682                else => break,
683            }
684        }
685
686        // Wait for reader tasks to finish
687        if let Some(h) = std_out.take() {
688            let _ = h.await;
689        }
690        if let Some(h) = std_err.take() {
691            let _ = h.await;
692        }
693
694        let exit_code = child.wait().await.ok().and_then(|s| s.code());
695        let _ = exit_tx_inner.send(()).await;
696        let _ = log_tx_inner
697            .send(format!(
698                "{} exited with code {:?}",
699                backend_name_upper, exit_code
700            ))
701            .await;
702    });
703
704    Ok((
705        ServerHandle {
706            port: if server_mode == crate::models::ServerMode::Bench {
707                0
708            } else {
709                settings.port
710            },
711            host: settings.host.clone(),
712            pid,
713            kill_tx,
714        },
715        cmd_string,
716    ))
717}
718
719/// Check if the server is healthy and responsive.
720pub async fn check_health(host: &str, port: u16) -> bool {
721    let host = clean_host(host);
722    let url = format!("http://{}:{}/health", host, port);
723    let client = reqwest::Client::builder()
724        .timeout(std::time::Duration::from_secs(1))
725        .build()
726        .unwrap_or_default();
727
728    match client.get(&url).send().await {
729        Ok(resp) => resp.status().is_success(),
730        Err(_) => false,
731    }
732}
733
734/// Kill a running server.
735pub async fn kill_server(handle: ServerHandle) -> Result<(), String> {
736    handle
737        .kill_tx
738        .send(())
739        .await
740        .map_err(|_| "Server already stopped".to_string())
741}
742
743/// Poll metrics from the server.
744pub async fn get_metrics(
745    host: &str,
746    port: u16,
747    model_name: Option<&str>,
748    pid: Option<u32>,
749) -> Result<ServerMetrics, String> {
750    let host = clean_host(host);
751    // We prefer the /metrics endpoint as it's more stable for system info.
752    // In router mode, we can specify the model via query parameter.
753    let mut url = if let Some(model) = model_name {
754        let name = strip_gguf(model);
755        format!("http://{}:{}/metrics?model={}", host, port, name)
756    } else {
757        format!("http://{}:{}/metrics", host, port)
758    };
759
760    let mut resp = reqwest::get(&url)
761        .await
762        .map_err(|e| format!("Failed to get metrics: {}", e))?;
763
764    // If model-specific metrics fail with 404 or 400, try plain /metrics
765    if (resp.status() == reqwest::StatusCode::NOT_FOUND
766        || resp.status() == reqwest::StatusCode::BAD_REQUEST)
767        && model_name.is_some()
768    {
769        url = format!("http://{}:{}/metrics", host, port);
770        resp = reqwest::get(&url)
771            .await
772            .map_err(|e| format!("Failed to get metrics: {}", e))?;
773    }
774
775    if !resp.status().is_success() {
776        return Err(format!("Server returned {}", resp.status()));
777    }
778
779    let text = resp
780        .text()
781        .await
782        .map_err(|e| format!("Failed to read metrics: {}", e))?;
783
784    let mut m = ServerMetrics { loaded: true, ..Default::default() };
785
786    let mut ctx_max_slots = 0u32;
787    let mut ctx_used_slots = 0u32;
788    let mut ctx_used_global = 0u32;
789    let mut ctx_max_global = 0u32;
790
791    let mut vram_used_slots = 0u64;
792    let mut vram_total_slots = 0u64;
793    let mut vram_used_global = 0u64;
794    let mut vram_total_global = 0u64;
795
796    for line in text.lines() {
797        if line.starts_with('#') || line.is_empty() {
798            continue;
799        }
800
801        let parts: Vec<&str> = line.split_whitespace().collect();
802        if parts.len() < 2 {
803            continue;
804        }
805
806        let name_with_labels = parts[0];
807        let mut val = 0.0;
808        for part in parts.iter().skip(1) {
809            if let Ok(v) = part.parse::<f64>() {
810                val = v;
811                break;
812            }
813        }
814
815        let is_slot = name_with_labels.contains("slot=\"") || name_with_labels.contains("pool=\"");
816        let name = name_with_labels
817            .split('{')
818            .next()
819            .unwrap_or(name_with_labels);
820
821        match name {
822            "llama_kv_cache_usage_bytes"
823            | "kv_cache_usage_bytes"
824            | "llama_server_kv_cache_usage_bytes"
825            | "llama_server_kv_cache_used_bytes"
826            | "llama_server_vram_used_bytes" => {
827                if is_slot {
828                    vram_used_slots += val as u64;
829                } else {
830                    vram_used_global = vram_used_global.max(val as u64);
831                }
832            }
833            "llama_kv_cache_total_bytes"
834            | "kv_cache_total_bytes"
835            | "llama_server_kv_cache_total_bytes"
836            | "llama_server_vram_total_bytes" => {
837                if is_slot {
838                    vram_total_slots += val as u64;
839                } else {
840                    vram_total_global = vram_total_global.max(val as u64);
841                }
842            }
843            "llama_model_memory_usage_bytes"
844            | "model_memory_usage_bytes"
845            | "llama_server_model_memory_usage_bytes"
846            | "llama_server_memory_usage_bytes"
847            | "llama_server_ram_usage_bytes"
848            | "llama_server_mem_used_bytes" => {
849                m.ram_used = m.ram_used.max(val as u64);
850            }
851            "llama_kv_cache_tokens_used"
852            | "kv_cache_usage_tokens"
853            | "kv_cache_tokens_used"
854            | "llama_server_kv_cache_tokens_used"
855            | "llamacpp:n_tokens_used"
856            | "llama_server_n_tokens_used"
857            | "llama_server_n_past"
858            | "llamacpp:n_past" => {
859                if is_slot {
860                    ctx_used_slots += val as u32;
861                } else {
862                    ctx_used_global = ctx_used_global.max(val as u32);
863                }
864            }
865            "llama_kv_cache_tokens_total"
866            | "kv_cache_total_tokens"
867            | "kv_cache_tokens_total"
868            | "llama_server_kv_cache_tokens_total"
869            | "llamacpp:n_ctx"
870            | "llamacpp:n_tokens_max"
871            | "llama_server_n_ctx"
872            | "llama_server_n_tokens_max" => {
873                if is_slot {
874                    ctx_max_slots += val as u32;
875                } else {
876                    ctx_max_global = ctx_max_global.max(val as u32);
877                }
878            }
879            "llama_server_cpu_usage_percentage"
880            | "cpu_usage_percentage"
881            | "llama_server_cpu_usage"
882            | "llama_server_cpu_percent" => {
883                m.cpu_usage = m.cpu_usage.max(val);
884            }
885            "llamacpp:predicted_tokens_seconds"
886            | "llama_server_predicted_tokens_seconds"
887            | "llama_server_tps" => {
888                m.tps += val;
889            }
890            "llamacpp:prompt_tokens_seconds"
891            | "llama_server_prompt_tokens_seconds"
892            | "llama_server_prompt_tps" => {
893                m.prompt_tps += val;
894            }
895            "llamacpp:kv_cache_usage_ratio" | "llama_server_kv_cache_usage_ratio" => {
896                if !is_slot && ctx_max_global > 0 {
897                    ctx_used_global = ctx_used_global.max((val * ctx_max_global as f64) as u32);
898                }
899            }
900            _ => {}
901        }
902    }
903
904    // Prefer global metrics (includes model weights + KV cache) over slot-only (KV cache subset).
905    m.gpu_mem_used = if vram_used_global > 0 {
906        vram_used_global
907    } else if vram_used_slots > 0 {
908        vram_used_slots
909    } else {
910        0
911    };
912    m.gpu_mem_total = if vram_total_global > 0 {
913        vram_total_global
914    } else if vram_total_slots > 0 {
915        vram_total_slots
916    } else {
917        0
918    };
919
920    // ctx_used = tokens currently in the KV cache.
921    // ctx_max = the total context window size allocated by the server.
922    m.ctx_used = if ctx_used_slots > 0 {
923        ctx_used_slots
924    } else {
925        ctx_used_global
926    };
927    m.ctx_max = if ctx_max_slots > 0 {
928        ctx_max_slots
929    } else {
930        ctx_max_global
931    };
932    // ctx_max may be overridden in poll_metrics() by the user-configured value.
933
934    // Prefer actual GPU memory usage from nvidia-smi or amdgpu_top.
935    // llama-server's kv_cache_usage_bytes only reports KV cache (typically 10%
936    // of total VRAM); model weights are loaded into GPU memory but not tracked
937    // by the server, so we use system-level tools to report what users see on GPUs.
938    if model_name.is_none() {
939        // Prefer system-level VRAM over llama-server's KV-only value.
940        // System tools report actual GPU memory including model weights,
941        // which is what users see on their GPUs and expect to read.
942        let set_if_better = |out: &mut ServerMetrics, used: u64, total: u64| {
943            if out.gpu_mem_used == 0 || used > out.gpu_mem_used {
944                out.gpu_mem_used = used;
945                out.gpu_mem_total = total;
946            }
947        };
948
949        let (nv_used, nv_total) = get_nvidia_vram_metrics().unwrap_or((0, 0));
950        set_if_better(&mut m, nv_used, nv_total);
951
952        if m.gpu_mem_total == 0 {
953            // AMD fallback when nvidia-smi is not available.
954            let (amd_used, amd_total) = get_amdgpu_vram_metrics().unwrap_or((0, 0));
955            set_if_better(&mut m, amd_used, amd_total);
956        }
957    } else if m.gpu_mem_used == 0 {
958        // KV-only queries: use system tools as a last resort.
959        if let Ok((used, total)) = get_nvidia_vram_metrics() {
960            m.gpu_mem_used = used;
961            m.gpu_mem_total = total;
962        } else if let Ok((used, total)) = get_amdgpu_vram_metrics() {
963            m.gpu_mem_used = used;
964            m.gpu_mem_total = total;
965        }
966    }
967
968    // Fallback for RAM and CPU using sysinfo (cross-platform)
969    if let Some(p) = pid {
970        if let Ok((ram, cpu)) = get_process_metrics(p) {
971            if m.ram_used == 0 {
972                m.ram_used = ram;
973            }
974            if m.cpu_usage == 0.0 {
975                m.cpu_usage = cpu;
976            }
977        }
978    }
979
980    Ok(m)
981}
982
983/// Get VRAM usage using nvidia-smi
984fn get_nvidia_vram_metrics() -> Result<(u64, u64), String> {
985    let output = std::process::Command::new("nvidia-smi")
986        .args([
987            "--query-gpu=memory.used,memory.total",
988            "--format=csv,noheader,nounits",
989        ])
990        .output()
991        .map_err(|e| e.to_string())?;
992
993    if !output.status.success() {
994        return Err("nvidia-smi failed".to_string());
995    }
996
997    let stdout = String::from_utf8_lossy(&output.stdout);
998    let mut total_used: u64 = 0;
999    let mut total_total: u64 = 0;
1000    for line in stdout.lines() {
1001        let parts: Vec<&str> = line.split(',').collect();
1002        if parts.len() >= 2 {
1003            let used = match parts[0].trim().parse::<u64>() {
1004                Ok(v) => v,
1005                Err(_) => {
1006                    warn!("nvidia-smi: failed to parse used memory from '{}'", parts[0]);
1007                    continue;
1008                }
1009            } * 1024 * 1024;
1010            let total = match parts[1].trim().parse::<u64>() {
1011                Ok(v) => v,
1012                Err(_) => {
1013                    warn!("nvidia-smi: failed to parse total memory from '{}'", parts[1]);
1014                    continue;
1015                }
1016            } * 1024 * 1024;
1017            total_used += used;
1018            total_total += total;
1019        }
1020    }
1021    if total_total > 0 {
1022        return Ok((total_used, total_total));
1023    }
1024
1025    Err("Invalid output from nvidia-smi".to_string())
1026}
1027
1028/// Get VRAM usage using amdgpu_top
1029fn get_amdgpu_vram_metrics() -> Result<(u64, u64), String> {
1030    let output = std::process::Command::new("amdgpu_top")
1031        .args(["-d", "--json"])
1032        .output()
1033        .map_err(|e| e.to_string())?;
1034
1035    if !output.status.success() {
1036        return Err("amdgpu_top failed".to_string());
1037    }
1038
1039    let json: serde_json::Value =
1040        serde_json::from_slice(&output.stdout).map_err(|e| e.to_string())?;
1041
1042    // amdgpu_top --json output has a "devices" array (or sometimes just a list of objects depending on version)
1043    let devices = if json.is_array() {
1044        json.as_array()
1045    } else {
1046        json.get("devices").and_then(|d| d.as_array())
1047    };
1048
1049    if let Some(devices) = devices
1050        && let Some(device) = devices.first()
1051    {
1052        // Priority 1: Check root keys (newer amdgpu_top format as provided by user)
1053        // "VRAM Usage Size": 3070128128, "VRAM Size": 8589934592
1054        let root_used = device.get("VRAM Usage Size").and_then(|v| v.as_u64());
1055        let root_total = device.get("VRAM Size").and_then(|v| v.as_u64());
1056
1057        if let (Some(used), Some(total)) = (root_used, root_total)
1058            && total > 0
1059        {
1060            return Ok((used, total));
1061        }
1062
1063        // Priority 2: Check nested VRAM object (alternative format)
1064        let vram_obj = device.get("VRAM");
1065        if let Some(vram) = vram_obj {
1066            // Check if it's the "Total VRAM Usage" format (usually MiB)
1067            let nested_used = vram
1068                .get("Total VRAM Usage")
1069                .and_then(|v| v.get("value").or(Some(v)))
1070                .and_then(|v| v.as_u64());
1071            let nested_total = vram
1072                .get("Total VRAM")
1073                .and_then(|v| v.get("value").or(Some(v)))
1074                .and_then(|v| v.as_u64());
1075
1076            if let (Some(used), Some(total)) = (nested_used, nested_total) {
1077                // These are usually in MiB if they have a "unit" field
1078                let multiplier = if vram.get("Total VRAM").and_then(|v| v.get("unit")).is_some() {
1079                    1024 * 1024
1080                } else {
1081                    1
1082                };
1083                if total > 0 {
1084                    return Ok((used * multiplier, total * multiplier));
1085                }
1086            }
1087        }
1088
1089        // Priority 3: Check vram_usage key (older format)
1090        let vram_usage = device.get("vram_usage");
1091        if let Some(vram) = vram_usage {
1092            let used = vram
1093                .get("VRAM")
1094                .or_else(|| vram.get("usage"))
1095                .and_then(|v| v.get("value").or(Some(v)))
1096                .and_then(|v| v.as_u64())
1097                .unwrap_or(0);
1098            let total = vram
1099                .get("TotalVRAM")
1100                .or_else(|| vram.get("total"))
1101                .and_then(|v| v.get("value").or(Some(v)))
1102                .and_then(|v| v.as_u64())
1103                .unwrap_or(0);
1104
1105            if total > 0 {
1106                return Ok((used * 1024 * 1024, total * 1024 * 1024));
1107            }
1108        }
1109    }
1110
1111    Err("Could not find VRAM info in amdgpu_top output".to_string())
1112}
1113
1114/// Cross-platform: Get RAM (RSS) and CPU usage for a PID.
1115/// Uses a persistent System instance so sysinfo can compute accurate
1116/// CPU deltas across calls (first call on a fresh instance is always 0).
1117fn get_process_metrics(
1118    pid: u32,
1119) -> Result<(u64, f64), String> {
1120    use std::sync::{LazyLock, Mutex};
1121    use sysinfo::{Pid, ProcessRefreshKind, ProcessesToUpdate, RefreshKind, System};
1122
1123    static SYS: LazyLock<Mutex<System>> = LazyLock::new(|| {
1124        Mutex::new(System::new_with_specifics(
1125            RefreshKind::everything().with_processes(ProcessRefreshKind::nothing().with_cpu().with_memory()),
1126        ))
1127    });
1128
1129    let mut sys = SYS.lock().unwrap();
1130    let pids = [Pid::from(pid as usize)];
1131    sys.refresh_processes_specifics(
1132        ProcessesToUpdate::Some(&pids),
1133        true,
1134        ProcessRefreshKind::nothing().with_cpu().with_memory(),
1135    );
1136
1137    let sys_pid = Pid::from(pid as usize);
1138
1139    if let Some(process) = sys.process(sys_pid) {
1140        let ram = process.memory(); // bytes
1141        let cpu = process.cpu_usage() as f64; // percentage
1142        return Ok((ram, cpu));
1143    }
1144
1145    Err(format!("Process not found: pid={}", pid))
1146}
1147
1148/// Load a model via the llama-server Router API.
1149pub async fn load_model(
1150    host: &str,
1151    port: u16,
1152    model_id: &str,
1153    model_path: Option<&str>,
1154) -> Result<(), String> {
1155    let client = reqwest::Client::new();
1156    let host = clean_host(host);
1157
1158    // Try multiple endpoints
1159    let endpoints = ["/models/load", "/v1/models/load"];
1160
1161    // Construct all possible identification variants
1162    let mut variants = Vec::new();
1163
1164    // 1. Original ID (display_name / relative path)
1165    variants.push(model_id.to_string());
1166    variants.push(strip_gguf(model_id).to_string());
1167
1168    // 2. Just the filename
1169    if let Some(filename) = std::path::Path::new(model_id)
1170        .file_name()
1171        .and_then(|f| f.to_str())
1172    {
1173        variants.push(filename.to_string());
1174        variants.push(strip_gguf(filename).to_string());
1175    }
1176
1177    // 3. Absolute path
1178    if let Some(path) = model_path {
1179        variants.push(path.to_string());
1180    }
1181
1182    let mut last_status = reqwest::StatusCode::OK;
1183    let mut last_error = String::new();
1184
1185    for endpoint in endpoints {
1186        let url = format!("http://{}:{}{}", host, port, endpoint);
1187        for variant in &variants {
1188            // Try both "model" and "alias" fields
1189            let bodies = vec![
1190                serde_json::json!({ "model": variant }),
1191                serde_json::json!({ "alias": variant }),
1192            ];
1193
1194            for body in bodies {
1195                match client.post(&url).json(&body).send().await {
1196                    Ok(res) => {
1197                        if res.status().is_success() {
1198                            return Ok(());
1199                        }
1200                        last_status = res.status();
1201                        last_error = res
1202                            .text()
1203                            .await
1204                            .unwrap_or_else(|_| "Unknown error".to_string());
1205                    }
1206                    Err(e) => {
1207                        last_error = e.to_string();
1208                    }
1209                }
1210            }
1211        }
1212    }
1213
1214    Err(format!(
1215        "Failed to load model (tried {} requests). Last status {}: {}",
1216        endpoints.len() * variants.len() * 2,
1217        last_status,
1218        last_error
1219    ))
1220}
1221
1222/// List all models and their status from the llama-server Router API.
1223pub async fn list_models(
1224    host: &str,
1225    port: u16,
1226) -> Result<Vec<(String, String, Option<String>)>, String> {
1227    let client = reqwest::Client::new();
1228    let host = clean_host(host);
1229    let url = format!("http://{}:{}/models", host, port);
1230
1231    let res = client
1232        .get(&url)
1233        .send()
1234        .await
1235        .map_err(|e| format!("Failed to list models: {}", e))?;
1236
1237    if !res.status().is_success() {
1238        return Err(format!("Server returned error {}", res.status()));
1239    }
1240
1241    let json: serde_json::Value = res
1242        .json()
1243        .await
1244        .map_err(|e| format!("Invalid JSON: {}", e))?;
1245
1246    let mut results = Vec::new();
1247    if let Some(data) = json.get("data").and_then(|d| d.as_array()) {
1248        for model in data {
1249            let id = model
1250                .get("id")
1251                .and_then(|v| v.as_str())
1252                .unwrap_or_default()
1253                .to_string();
1254            // Status can be a string or an object with a "value" field
1255            let status = model
1256                .get("status")
1257                .and_then(|s| s.get("value").or(Some(s)))
1258                .and_then(|v| v.as_str())
1259                .unwrap_or("unloaded")
1260                .to_string();
1261            let path = model
1262                .get("path")
1263                .or_else(|| model.get("filename"))
1264                .and_then(|v| v.as_str())
1265                .map(|s| s.to_string());
1266
1267            results.push((id, status, path));
1268        }
1269    }
1270
1271    Ok(results)
1272}
1273
1274/// Unload a model via the llama-server Router API.
1275pub async fn unload_model(
1276    host: &str,
1277    port: u16,
1278    model_id: &str,
1279    model_path: Option<&str>,
1280) -> Result<(), String> {
1281    let client = reqwest::Client::new();
1282    let host = clean_host(host);
1283
1284    let endpoints = ["/models/unload", "/v1/models/unload"];
1285    let stripped = strip_gguf(model_id);
1286    let mut variants = vec![model_id.to_string(), stripped.to_string()];
1287    if let Some(path) = model_path {
1288        variants.push(path.to_string());
1289    }
1290
1291    let mut last_status = reqwest::StatusCode::OK;
1292    let mut last_error = String::new();
1293
1294    for endpoint in endpoints {
1295        let url = format!("http://{}:{}{}", host, port, endpoint);
1296        for variant in &variants {
1297            let body = serde_json::json!({
1298                "model": variant
1299            });
1300
1301            if let Ok(res) = client.post(&url).json(&body).send().await {
1302                if res.status().is_success() {
1303                    return Ok(());
1304                }
1305                last_status = res.status();
1306                last_error = res
1307                    .text()
1308                    .await
1309                    .unwrap_or_else(|_| "Unknown error".to_string());
1310            }
1311        }
1312    }
1313
1314    tracing::debug!(
1315        "Model unload failed (status {}, error: {}): this is expected if model was already unloaded",
1316        last_status,
1317        last_error
1318    );
1319    Ok(())
1320}