Skip to main content

llm_manager/backend/
server.rs

1use std::fmt::Display;
2use std::process::Stdio;
3use tokio::io::{AsyncBufReadExt, BufReader};
4use tokio::process::Command;
5use tokio::sync::mpsc;
6use tracing::info;
7
8use crate::config::Config;
9use crate::models::{
10    DiscoveredModel, ModelSettings, RopeScaling, ServerMetrics, clean_host, strip_gguf,
11};
12
13/// Manages a single llama.cpp server process.
14#[derive(Clone)]
15pub struct ServerHandle {
16    pub port: u16,
17    pub host: String,
18    pub pid: u32,
19    pub kill_tx: mpsc::Sender<()>,
20}
21
22/// Helper: add an argument to both the Command and the display parts list.
23fn push_arg(cmd: &mut Command, parts: &mut Vec<String>, name: &str, value: impl Display) {
24    let val_str = value.to_string();
25    cmd.arg(name).arg(&val_str);
26    parts.push(name.to_string());
27    parts.push(val_str);
28}
29
30/// Helper: add a flag (argument without value) to both the Command and display parts.
31fn push_flag(cmd: &mut Command, parts: &mut Vec<String>, name: &str) {
32    cmd.arg(name);
33    parts.push(name.to_string());
34}
35
36fn push_gpu_layers(cmd: &mut Command, parts: &mut Vec<String>, settings: &ModelSettings) {
37    match settings.gpu_layers_mode {
38        crate::models::GpuLayersMode::Specific(n) => push_arg(cmd, parts, "-ngl", n),
39        crate::models::GpuLayersMode::All => push_arg(cmd, parts, "-ngl", "999"),
40        crate::models::GpuLayersMode::Auto => {}
41    }
42}
43
44fn push_spec_decoding(cmd: &mut Command, parts: &mut Vec<String>, settings: &ModelSettings) {
45    if !settings.spec_type.is_empty() {
46        push_arg(cmd, parts, "--spec-type", &settings.spec_type);
47        if settings.draft_tokens > 0 {
48            push_arg(cmd, parts, "--spec-draft-n-max", settings.draft_tokens);
49        }
50    }
51}
52
53/// Build the full llama-server command line from settings.
54/// Returns (Command, display_string) where the string is suitable for logging.
55pub fn build_server_cmd(
56    binary: &std::path::Path,
57    model: Option<&DiscoveredModel>,
58    settings: &ModelSettings,
59    config: &Config,
60    server_mode: crate::models::ServerMode,
61    router_max_models: u32,
62) -> (Command, String) {
63    let mut cmd = Command::new(binary);
64    let mut parts: Vec<String> = vec![binary.display().to_string()];
65
66    // ── Model ───────────────────────────────────────────────
67    match server_mode {
68        crate::models::ServerMode::Normal => {
69            if let Some(model) = model {
70                push_arg(&mut cmd, &mut parts, "-m", model.path.display());
71                // Add alias for router mode identification (uses the unique relative path)
72                push_arg(&mut cmd, &mut parts, "--alias", &model.display_name);
73            }
74        }
75        crate::models::ServerMode::Router => {
76            // Router mode: no model in CLI, use /load API to load models
77            if router_max_models > 0 {
78                push_arg(&mut cmd, &mut parts, "--models-max", router_max_models);
79            }
80            // Always pass --models-dir in router mode (global config setting)
81            if let Some(dir) = config.models_dirs.first() {
82                push_arg(&mut cmd, &mut parts, "--models-dir", dir.display());
83            }
84        }
85        crate::models::ServerMode::Bench => {
86            // Should not be reached as Bench uses build_bench_cmd
87        }
88        crate::models::ServerMode::BenchTune => {
89            // Should not be reached as BenchTune uses benchmark tuning function
90        }
91    }
92
93    // ── Loading ──────────────────────────────────────────────
94    push_arg(&mut cmd, &mut parts, "--threads", settings.threads);
95    push_arg(
96        &mut cmd,
97        &mut parts,
98        "--threads-batch",
99        settings.threads_batch,
100    );
101    let effective_ctx = (settings.context_length as f32 * settings.rope_scale) as u32;
102    push_arg(&mut cmd, &mut parts, "--ctx-size", effective_ctx);
103    push_arg(&mut cmd, &mut parts, "--ubatch-size", settings.ubatch_size);
104    if let Some(n) = settings.max_concurrent_predictions {
105        push_arg(&mut cmd, &mut parts, "--parallel", n);
106    }
107
108    push_flag(&mut cmd, &mut parts, "--no-warmup");
109
110    push_spec_decoding(&mut cmd, &mut parts, settings);
111
112    if let Some(cache_k) = settings.cache_type_k {
113        push_arg(&mut cmd, &mut parts, "--cache-type-k", cache_k);
114    }
115    if let Some(cache_v) = settings.cache_type_v {
116        push_arg(&mut cmd, &mut parts, "--cache-type-v", cache_v);
117    }
118
119    if settings.keep != 0 {
120        push_arg(&mut cmd, &mut parts, "--keep", settings.keep);
121    }
122    if settings.swa_full {
123        push_flag(&mut cmd, &mut parts, "--swa-full");
124    }
125    if settings.mlock {
126        push_flag(&mut cmd, &mut parts, "--mlock");
127    }
128    if !settings.mmap {
129        push_flag(&mut cmd, &mut parts, "--no-mmap");
130    }
131    if settings.numa != Default::default() {
132        push_arg(&mut cmd, &mut parts, "--numa", settings.numa.to_string());
133    }
134    if settings.kv_cache_offload {
135        push_flag(&mut cmd, &mut parts, "--kv-offload");
136    }
137
138    // ── GPU ──────────────────────────────────────────────────
139    push_gpu_layers(&mut cmd, &mut parts, settings);
140
141    if settings.split_mode != Default::default() {
142        push_arg(
143            &mut cmd,
144            &mut parts,
145            "--split-mode",
146            settings.split_mode.to_string(),
147        );
148    }
149    if !settings.tensor_split.is_empty() {
150        push_arg(
151            &mut cmd,
152            &mut parts,
153            "--tensor-split",
154            &settings.tensor_split,
155        );
156    }
157    if settings.main_gpu != 0 {
158        push_arg(&mut cmd, &mut parts, "--main-gpu", settings.main_gpu);
159    }
160    if settings.fit {
161        push_arg(&mut cmd, &mut parts, "--fit", "on");
162    } else {
163        push_arg(&mut cmd, &mut parts, "--fit", "off");
164    }
165
166    if let Some(ref lora) = settings.lora {
167        push_arg(&mut cmd, &mut parts, "--lora", lora.display());
168    }
169    if let Some((ref lora, scale)) = settings.lora_scaled {
170        push_arg(
171            &mut cmd,
172            &mut parts,
173            "--lora-scaled",
174            format!("{}:{}", lora.display(), scale),
175        );
176    }
177
178    let mut rpc_list = Vec::new();
179    if !settings.rpc.is_empty() {
180        rpc_list.push(settings.rpc.clone());
181    }
182    for worker in &config.rpc_workers {
183        if worker.selected {
184            rpc_list.push(format!("{}:{}", worker.ip, worker.port));
185        }
186    }
187
188    if !rpc_list.is_empty() {
189        let joined_rpc = rpc_list.join(",");
190        push_arg(&mut cmd, &mut parts, "--rpc", joined_rpc);
191    }
192
193    if settings.embedding {
194        push_flag(&mut cmd, &mut parts, "--embedding");
195    }
196
197    if settings.expert_count > 0 {
198        push_arg(
199            &mut cmd,
200            &mut parts,
201            "--override-kv",
202            format!("llama.expert_used_count=int:int:{}", settings.expert_count),
203        );
204    }
205
206    push_arg(
207        &mut cmd,
208        &mut parts,
209        "-fa",
210        if settings.flash_attn { "on" } else { "off" },
211    );
212
213    if settings.jinja {
214        push_flag(&mut cmd, &mut parts, "--jinja");
215    }
216
217    if let Some(ref template) = settings.chat_template {
218        push_arg(&mut cmd, &mut parts, "--chat-template", template);
219    }
220
221    // Inject system prompt via chat template kwargs when it differs from default
222    if settings.system_prompt != "You are a helpful assistant." {
223        let escaped = settings
224            .system_prompt
225            .replace('\\', "\\\\")
226            .replace('"', "\\\"");
227        let mut merged = serde_json::Map::new();
228        if let Some(ref kwargs) = settings.chat_template_kwargs
229            && let Ok(obj) = serde_json::from_str::<serde_json::Value>(kwargs)
230                && let serde_json::Value::Object(map) = obj {
231                    for (k, v) in map {
232                        merged.insert(k, v);
233                    }
234                }
235        merged.insert(
236            "system_prompt".to_string(),
237            serde_json::Value::String(escaped),
238        );
239        push_arg(
240            &mut cmd,
241            &mut parts,
242            "--chat-template-kwargs",
243            serde_json::to_string(&merged).unwrap(),
244        );
245    } else if let Some(ref kwargs) = settings.chat_template_kwargs {
246        push_arg(&mut cmd, &mut parts, "--chat-template-kwargs", kwargs);
247    }
248
249    // ── Sampling ─────────────────────────────────────────────
250    if settings.seed != -1 {
251        push_arg(&mut cmd, &mut parts, "--seed", settings.seed);
252    }
253    if let Some(max_tokens) = settings.max_tokens {
254        push_arg(&mut cmd, &mut parts, "--n-predict", max_tokens);
255    }
256    push_arg(
257        &mut cmd,
258        &mut parts,
259        "--temp",
260        format!("{:.2}", settings.temperature),
261    );
262
263    push_arg(&mut cmd, &mut parts, "--top-k", settings.top_k);
264
265    push_arg(
266        &mut cmd,
267        &mut parts,
268        "--top-p",
269        format!("{:.2}", settings.top_p),
270    );
271
272    push_arg(
273        &mut cmd,
274        &mut parts,
275        "--min-p",
276        format!("{:.2}", settings.min_p),
277    );
278
279    push_arg(
280        &mut cmd,
281        &mut parts,
282        "--typical",
283        format!("{:.2}", settings.typical_p),
284    );
285
286    if settings.mirostat != Default::default() {
287        push_arg(
288            &mut cmd,
289            &mut parts,
290            "--mirostat",
291            settings.mirostat.to_string(),
292        );
293        push_arg(
294            &mut cmd,
295            &mut parts,
296            "--mirostat-lr",
297            format!("{:.2}", settings.mirostat_lr),
298        );
299        push_arg(
300            &mut cmd,
301            &mut parts,
302            "--mirostat-ent",
303            format!("{:.2}", settings.mirostat_ent),
304        );
305    }
306
307    if settings.ignore_eos {
308        push_flag(&mut cmd, &mut parts, "--ignore-eos");
309    }
310
311    if !settings.samplers.0.is_empty() {
312        push_arg(
313            &mut cmd,
314            &mut parts,
315            "--samplers",
316            settings.samplers.to_string(),
317        );
318    }
319
320    if let Some(frequency) = settings.frequency_penalty {
321        push_arg(
322            &mut cmd,
323            &mut parts,
324            "--frequency-penalty",
325            format!("{:.2}", frequency),
326        );
327    }
328
329    if settings.dry_multiplier != 0.0 {
330        push_arg(
331            &mut cmd,
332            &mut parts,
333            "--dry-multiplier",
334            format!("{:.2}", settings.dry_multiplier),
335        );
336        push_arg(
337            &mut cmd,
338            &mut parts,
339            "--dry-base",
340            format!("{:.2}", settings.dry_base),
341        );
342        push_arg(
343            &mut cmd,
344            &mut parts,
345            "--dry-allowed-length",
346            settings.dry_allowed_length,
347        );
348        push_arg(
349            &mut cmd,
350            &mut parts,
351            "--dry-penalty-last-n",
352            settings.dry_penalty_last_n,
353        );
354    }
355
356    // ── RoPE ─────────────────────────────────────────────────
357    let rope_scaling = if settings.rope_yarn_enabled {
358        RopeScaling::Yarn
359    } else {
360        settings.rope_scaling
361    };
362    if rope_scaling != Default::default() {
363        push_arg(
364            &mut cmd,
365            &mut parts,
366            "--rope-scaling",
367            rope_scaling.to_string(),
368        );
369    }
370    if settings.rope_scale != 1.0 {
371        push_arg(
372            &mut cmd,
373            &mut parts,
374            "--rope-scale",
375            format!("{:.2}", settings.rope_scale),
376        );
377    }
378    if settings.rope_freq_base != 0.0 {
379        push_arg(
380            &mut cmd,
381            &mut parts,
382            "--rope-freq-base",
383            format!("{:.2}", settings.rope_freq_base),
384        );
385    }
386    if settings.rope_freq_scale != 1.0 {
387        push_arg(
388            &mut cmd,
389            &mut parts,
390            "--rope-freq-scale",
391            format!("{:.2}", settings.rope_freq_scale),
392        );
393    }
394
395    let resolved_host = clean_host(&settings.host);
396    push_arg(&mut cmd, &mut parts, "--host", resolved_host);
397    push_arg(&mut cmd, &mut parts, "--port", settings.port);
398    push_arg(&mut cmd, &mut parts, "--timeout", settings.timeout);
399
400    push_flag(&mut cmd, &mut parts, "--metrics");
401    if !settings.cache_prompt {
402        push_flag(&mut cmd, &mut parts, "--no-cache-prompt");
403    }
404    if settings.cache_reuse != 0 {
405        push_arg(&mut cmd, &mut parts, "--cache-reuse", settings.cache_reuse);
406    }
407    if !settings.webui {
408        push_flag(&mut cmd, &mut parts, "--no-webui");
409    }
410
411    // ── General ──────────────────────────────────────────────
412
413    let display = parts.join(" ");
414    (cmd, display)
415}
416
417/// Build the full llama-bench command line.
418pub fn build_bench_cmd(
419    binary: &std::path::Path,
420    model: &DiscoveredModel,
421    settings: &ModelSettings,
422) -> (Command, String) {
423    let mut cmd = Command::new(binary);
424    let mut parts: Vec<String> = vec![binary.display().to_string()];
425
426    push_arg(&mut cmd, &mut parts, "-m", model.path.display());
427    push_arg(&mut cmd, &mut parts, "-t", settings.threads);
428    push_arg(&mut cmd, &mut parts, "-b", settings.batch_size);
429
430    push_gpu_layers(&mut cmd, &mut parts, settings);
431
432    if settings.flash_attn {
433        push_arg(&mut cmd, &mut parts, "-fa", "1");
434    }
435
436    push_spec_decoding(&mut cmd, &mut parts, settings);
437
438    push_flag(&mut cmd, &mut parts, "--progress");
439
440    let display = parts.join(" ");
441    (cmd, display)
442}
443
444/// Spawn a llama.cpp server process (single model or router).
445/// Returns (ServerHandle, command_string) where command_string is the full CLI.
446pub struct SpawnServerRequest<'a> {
447    pub config: &'a Config,
448    pub model: Option<&'a DiscoveredModel>,
449    pub settings: &'a ModelSettings,
450    pub log_tx: mpsc::Sender<String>,
451    pub progress_tx: Option<tokio::sync::broadcast::Sender<crate::models::DownloadState>>,
452    pub server_mode: crate::models::ServerMode,
453    pub router_max_models: u32,
454    pub exit_tx: mpsc::Sender<()>,
455}
456
457pub async fn spawn_server(
458    req: SpawnServerRequest<'_>,
459) -> Result<(ServerHandle, String), String> {
460    let SpawnServerRequest {
461        config,
462        model,
463        settings,
464        log_tx,
465        progress_tx,
466        server_mode,
467        router_max_models,
468        exit_tx,
469    } = req;
470    if server_mode != crate::models::ServerMode::Bench
471        && server_mode != crate::models::ServerMode::BenchTune
472    {
473        let port = settings.port;
474        // Check if port is already in use
475        if std::net::TcpListener::bind(("127.0.0.1", port)).is_err() {
476            return Err(format!("Port {} is already in use", port));
477        }
478    }
479
480    // BenchTune mode is handled separately in app.start_pending_spawn()
481    // and should never reach this function.
482    if server_mode == crate::models::ServerMode::BenchTune {
483        unreachable!("BenchTune mode must be handled before calling spawn_server")
484    }
485
486    // Resolve the backend binary (downloads if needed)
487    let backend_name = if server_mode == crate::models::ServerMode::Bench {
488        "llama-bench"
489    } else {
490        "llama-server"
491    };
492    let version_display = settings.get_active_backend_version_display();
493    info!(
494        "spawn_server: backend={}, requested_version={:?}, version_display={}",
495        settings.backend,
496        settings.get_active_backend_version(),
497        version_display
498    );
499    log_tx
500        .send(format!(
501            "Resolving {} (v{}) binary...",
502            backend_name, version_display
503        ))
504        .await
505        .ok();
506    let version_param = settings.get_active_backend_version().map(|s| s.as_str());
507
508    let server_binary = match crate::backend::hub::resolve_backend_binary(
509        settings.backend,
510        version_param,
511        Some(log_tx.clone()),
512        progress_tx,
513    )
514    .await
515    {
516        Ok(path) => {
517            info!("spawn_server: resolved binary path={}", path.display());
518            path
519        }
520        Err(e) => {
521            return Err(format!("Failed to resolve backend binary: {}", e));
522        }
523    };
524
525    let binary = if server_mode == crate::models::ServerMode::Bench {
526        server_binary.parent().unwrap().join("llama-bench")
527    } else {
528        server_binary
529    };
530
531    if !binary.exists() {
532        return Err(format!("Binary not found at: {}", binary.display()));
533    }
534    #[cfg(unix)]
535    {
536        use std::os::unix::fs::PermissionsExt;
537        if let Ok(metadata) = binary.metadata()
538            && metadata.permissions().mode() & 0o111 == 0
539        {
540            return Err(format!("Binary is not executable: {}", binary.display()));
541        }
542    }
543
544    let (mut cmd, cmd_string) = if server_mode == crate::models::ServerMode::Bench {
545        if let Some(m) = model {
546            build_bench_cmd(&binary, m, settings)
547        } else {
548            return Err("Model required for benchmark".to_string());
549        }
550    } else {
551        build_server_cmd(
552            &binary,
553            model,
554            settings,
555            config,
556            server_mode,
557            router_max_models,
558        )
559    };
560
561    cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
562
563    // Set platform-specific env vars so the binary can find its shared libraries
564    let bin_dir = binary.parent().unwrap();
565    match std::env::consts::OS {
566        "windows" => {
567            // On Windows, add bin_dir to PATH so llama-server.exe finds libllama.dll
568            if let Ok(current) = std::env::var("PATH") {
569                cmd.env("PATH", format!("{};{}", bin_dir.display(), current));
570            } else {
571                cmd.env("PATH", bin_dir);
572            }
573        }
574        "macos" => {
575            // On macOS, set DYLD_LIBRARY_PATH for dylib loading
576            if let Ok(current) = std::env::var("DYLD_LIBRARY_PATH") {
577                cmd.env(
578                    "DYLD_LIBRARY_PATH",
579                    format!("{}:{}", bin_dir.display(), current),
580                );
581            } else {
582                cmd.env("DYLD_LIBRARY_PATH", bin_dir);
583            }
584        }
585        _ => {
586            // On Linux, set LD_LIBRARY_PATH for so loading
587            if let Ok(current) = std::env::var("LD_LIBRARY_PATH") {
588                cmd.env(
589                    "LD_LIBRARY_PATH",
590                    format!("{}:{}", bin_dir.display(), current),
591                );
592            } else {
593                cmd.env("LD_LIBRARY_PATH", bin_dir);
594            }
595        }
596    }
597
598    info!("Spawning: {}", cmd_string);
599    let _ = log_tx
600        .send(format!("{}: {}", backend_name, cmd_string))
601        .await;
602    let mut child = cmd
603        .spawn()
604        .map_err(|e| format!("Failed to spawn process: {}", e))?;
605    let pid = child.id().unwrap_or(0);
606
607    let (kill_tx, mut kill_rx) = mpsc::channel(1);
608
609    // Background task: read stdout and stderr concurrently via separate tasks.
610    // Each stream gets its own task + mpsc channel so neither can block the other.
611    let log_tx_inner = log_tx.clone();
612    let exit_tx_inner = exit_tx.clone();
613    let backend_name_upper = backend_name.to_uppercase();
614    tokio::spawn(async move {
615        let stdout = child.stdout.take().unwrap();
616        let stderr = child.stderr.take().unwrap();
617
618        let (stdout_tx, mut stdout_rx) = mpsc::channel::<String>(64);
619        let (stderr_tx, mut stderr_rx) = mpsc::channel::<String>(64);
620
621        // Spawn a reader task for each stream
622        let mut std_out = Some(tokio::spawn(async move {
623            let reader = BufReader::new(stdout).lines();
624            tokio::pin!(reader);
625            while let Ok(Some(line)) = reader.next_line().await {
626                if stdout_tx.send(line).await.is_err() {
627                    break;
628                }
629            }
630        }));
631
632        let mut std_err = Some(tokio::spawn(async move {
633            let reader = BufReader::new(stderr).lines();
634            tokio::pin!(reader);
635            while let Ok(Some(line)) = reader.next_line().await {
636                if stderr_tx.send(line).await.is_err() {
637                    break;
638                }
639            }
640        }));
641
642        // Merge loop: block on whichever channel has data.
643        // When both are empty, select! sleeps with zero CPU cost.
644        loop {
645            tokio::select! {
646                _ = kill_rx.recv() => {
647                    let _ = child.kill().await;
648                    if let Some(h) = std_out.take() { let _ = h.await; }
649                    if let Some(h) = std_err.take() { let _ = h.await; }
650                    break;
651                }
652                line = stdout_rx.recv() => {
653                    if let Some(line) = line { let _ = log_tx_inner.send(line).await; } else { break; }
654                }
655                line = stderr_rx.recv() => {
656                    if let Some(line) = line { let _ = log_tx_inner.send(line).await; } else { break; }
657                }
658                else => break,
659            }
660        }
661
662        // Wait for reader tasks to finish
663        if let Some(h) = std_out.take() {
664            let _ = h.await;
665        }
666        if let Some(h) = std_err.take() {
667            let _ = h.await;
668        }
669
670        let exit_code = child.wait().await.ok().and_then(|s| s.code());
671        let _ = exit_tx_inner.send(()).await;
672        let _ = log_tx_inner
673            .send(format!(
674                "{} exited with code {:?}",
675                backend_name_upper, exit_code
676            ))
677            .await;
678    });
679
680    Ok((
681        ServerHandle {
682            port: if server_mode == crate::models::ServerMode::Bench {
683                0
684            } else {
685                settings.port
686            },
687            host: settings.host.clone(),
688            pid,
689            kill_tx,
690        },
691        cmd_string,
692    ))
693}
694
695/// Check if the server is healthy and responsive.
696pub async fn check_health(host: &str, port: u16) -> bool {
697    let host = clean_host(host);
698    let url = format!("http://{}:{}/health", host, port);
699    let client = reqwest::Client::builder()
700        .timeout(std::time::Duration::from_secs(1))
701        .build()
702        .unwrap_or_default();
703
704    match client.get(&url).send().await {
705        Ok(resp) => resp.status().is_success(),
706        Err(_) => false,
707    }
708}
709
710/// Kill a running server.
711pub async fn kill_server(handle: ServerHandle) -> Result<(), String> {
712    handle
713        .kill_tx
714        .send(())
715        .await
716        .map_err(|_| "Server already stopped".to_string())
717}
718
719/// Poll metrics from the server.
720pub async fn get_metrics(
721    host: &str,
722    port: u16,
723    model_name: Option<&str>,
724    pid: Option<u32>,
725) -> Result<ServerMetrics, String> {
726    let host = clean_host(host);
727    // We prefer the /metrics endpoint as it's more stable for system info.
728    // In router mode, we can specify the model via query parameter.
729    let mut url = if let Some(model) = model_name {
730        let name = strip_gguf(model);
731        format!("http://{}:{}/metrics?model={}", host, port, name)
732    } else {
733        format!("http://{}:{}/metrics", host, port)
734    };
735
736    let mut resp = reqwest::get(&url)
737        .await
738        .map_err(|e| format!("Failed to get metrics: {}", e))?;
739
740    // If model-specific metrics fail with 404 or 400, try plain /metrics
741    if (resp.status() == reqwest::StatusCode::NOT_FOUND
742        || resp.status() == reqwest::StatusCode::BAD_REQUEST)
743        && model_name.is_some()
744    {
745        url = format!("http://{}:{}/metrics", host, port);
746        resp = reqwest::get(&url)
747            .await
748            .map_err(|e| format!("Failed to get metrics: {}", e))?;
749    }
750
751    if !resp.status().is_success() {
752        return Err(format!("Server returned {}", resp.status()));
753    }
754
755    let text = resp
756        .text()
757        .await
758        .map_err(|e| format!("Failed to read metrics: {}", e))?;
759
760    let mut m = ServerMetrics { loaded: true, ..Default::default() };
761
762    let mut ctx_max_slots = 0u32;
763    let mut ctx_used_slots = 0u32;
764    let mut ctx_used_global = 0u32;
765    let mut ctx_max_global = 0u32;
766
767    let mut vram_used_slots = 0u64;
768    let mut vram_total_slots = 0u64;
769    let mut vram_used_global = 0u64;
770    let mut vram_total_global = 0u64;
771
772    for line in text.lines() {
773        if line.starts_with('#') || line.is_empty() {
774            continue;
775        }
776
777        let parts: Vec<&str> = line.split_whitespace().collect();
778        if parts.len() < 2 {
779            continue;
780        }
781
782        let name_with_labels = parts[0];
783        let mut val = 0.0;
784        for part in parts.iter().skip(1) {
785            if let Ok(v) = part.parse::<f64>() {
786                val = v;
787                break;
788            }
789        }
790
791        let is_slot = name_with_labels.contains("slot=\"") || name_with_labels.contains("pool=\"");
792        let name = name_with_labels
793            .split('{')
794            .next()
795            .unwrap_or(name_with_labels);
796
797        match name {
798            "llama_kv_cache_usage_bytes"
799            | "kv_cache_usage_bytes"
800            | "llama_server_kv_cache_usage_bytes"
801            | "llama_server_kv_cache_used_bytes"
802            | "llama_server_vram_used_bytes" => {
803                if is_slot {
804                    vram_used_slots += val as u64;
805                } else {
806                    vram_used_global = vram_used_global.max(val as u64);
807                }
808            }
809            "llama_kv_cache_total_bytes"
810            | "kv_cache_total_bytes"
811            | "llama_server_kv_cache_total_bytes"
812            | "llama_server_vram_total_bytes" => {
813                if is_slot {
814                    vram_total_slots += val as u64;
815                } else {
816                    vram_total_global = vram_total_global.max(val as u64);
817                }
818            }
819            "llama_model_memory_usage_bytes"
820            | "model_memory_usage_bytes"
821            | "llama_server_model_memory_usage_bytes"
822            | "llama_server_memory_usage_bytes"
823            | "llama_server_ram_usage_bytes"
824            | "llama_server_mem_used_bytes" => {
825                m.ram_used = m.ram_used.max(val as u64);
826            }
827            "llama_kv_cache_tokens_used"
828            | "kv_cache_usage_tokens"
829            | "kv_cache_tokens_used"
830            | "llama_server_kv_cache_tokens_used"
831            | "llamacpp:n_tokens_used"
832            | "llama_server_n_tokens_used"
833            | "llama_server_n_past"
834            | "llamacpp:n_past" => {
835                if is_slot {
836                    ctx_used_slots += val as u32;
837                } else {
838                    ctx_used_global = ctx_used_global.max(val as u32);
839                }
840            }
841            "llama_kv_cache_tokens_total"
842            | "kv_cache_total_tokens"
843            | "kv_cache_tokens_total"
844            | "llama_server_kv_cache_tokens_total"
845            | "llamacpp:n_ctx"
846            | "llamacpp:n_tokens_max"
847            | "llama_server_n_ctx"
848            | "llama_server_n_tokens_max" => {
849                if is_slot {
850                    ctx_max_slots += val as u32;
851                } else {
852                    ctx_max_global = ctx_max_global.max(val as u32);
853                }
854            }
855            "llama_server_cpu_usage_percentage"
856            | "cpu_usage_percentage"
857            | "llama_server_cpu_usage"
858            | "llama_server_cpu_percent" => {
859                m.cpu_usage = m.cpu_usage.max(val);
860            }
861            "llamacpp:predicted_tokens_seconds"
862            | "llama_server_predicted_tokens_seconds"
863            | "llama_server_tps" => {
864                m.tps += val;
865            }
866            "llamacpp:prompt_tokens_seconds"
867            | "llama_server_prompt_tokens_seconds"
868            | "llama_server_prompt_tps" => {
869                m.prompt_tps += val;
870            }
871            "llamacpp:kv_cache_usage_ratio" | "llama_server_kv_cache_usage_ratio" => {
872                if !is_slot && ctx_max_global > 0 {
873                    ctx_used_global = ctx_used_global.max((val * ctx_max_global as f64) as u32);
874                }
875            }
876            _ => {}
877        }
878    }
879
880    m.gpu_mem_used = if vram_used_slots > 0 {
881        vram_used_slots
882    } else {
883        vram_used_global
884    };
885    m.gpu_mem_total = if vram_total_slots > 0 {
886        vram_total_slots
887    } else {
888        vram_total_global
889    };
890
891    // ctx_used = tokens currently in the KV cache.
892    // ctx_max = the total context window size allocated by the server.
893    m.ctx_used = if ctx_used_slots > 0 {
894        ctx_used_slots
895    } else {
896        ctx_used_global
897    };
898    m.ctx_max = if ctx_max_slots > 0 {
899        ctx_max_slots
900    } else {
901        ctx_max_global
902    };
903    // ctx_max may be overridden in poll_metrics() by the user-configured value.
904
905    // Prefer actual GPU memory usage from nvidia-smi or amdgpu_top.
906    // llama-server's kv_cache_usage_bytes only reports KV cache (typically 10%
907    // of total VRAM); model weights are loaded into GPU memory but not tracked
908    // by the server, so we use system-level tools to report what users see on GPUs.
909    if model_name.is_none() {
910        // Prefer system-level VRAM over llama-server's KV-only value.
911        // System tools report actual GPU memory including model weights,
912        // which is what users see on their GPUs and expect to read.
913        let set_if_better = |out: &mut ServerMetrics, used: u64, total: u64| {
914            if out.gpu_mem_used == 0 || used > out.gpu_mem_used {
915                out.gpu_mem_used = used;
916                out.gpu_mem_total = total;
917            }
918        };
919
920        let (nv_used, nv_total) = get_nvidia_vram_metrics().unwrap_or((0, 0));
921        set_if_better(&mut m, nv_used, nv_total);
922
923        if m.gpu_mem_total == 0 {
924            // AMD fallback when nvidia-smi is not available.
925            let (amd_used, amd_total) = get_amdgpu_vram_metrics().unwrap_or((0, 0));
926            set_if_better(&mut m, amd_used, amd_total);
927        }
928    } else if m.gpu_mem_used == 0 {
929        // KV-only queries: use system tools as a last resort.
930        if let Ok((used, total)) = get_nvidia_vram_metrics() {
931            m.gpu_mem_used = used;
932            m.gpu_mem_total = total;
933        } else if let Ok((used, total)) = get_amdgpu_vram_metrics() {
934            m.gpu_mem_used = used;
935            m.gpu_mem_total = total;
936        }
937    }
938
939    // Fallback for RAM and CPU using sysinfo (cross-platform)
940    if let Some(p) = pid {
941        if let Ok((ram, cpu)) = get_process_metrics(p) {
942            if m.ram_used == 0 {
943                m.ram_used = ram;
944            }
945            if m.cpu_usage == 0.0 {
946                m.cpu_usage = cpu;
947            }
948        }
949    }
950
951    Ok(m)
952}
953
/// Get VRAM usage of the first NVIDIA GPU using `nvidia-smi`.
///
/// Runs `nvidia-smi --query-gpu=memory.used,memory.total --format=csv,noheader,nounits`
/// and reads the first output line. The values are printed in MiB, so they are
/// scaled to bytes. Returns `(used_bytes, total_bytes)`.
///
/// # Errors
/// Returns `Err` when `nvidia-smi` cannot be spawned, exits unsuccessfully, prints
/// nothing, prints a field that is not an unsigned integer (e.g. `[N/A]`), or
/// reports a zero total. Failing instead of returning `(0, 0)` lets callers fall
/// back to other VRAM sources.
fn get_nvidia_vram_metrics() -> Result<(u64, u64), String> {
    const MIB: u64 = 1024 * 1024;

    let output = std::process::Command::new("nvidia-smi")
        .args([
            "--query-gpu=memory.used,memory.total",
            "--format=csv,noheader,nounits",
        ])
        .output()
        .map_err(|e| e.to_string())?;

    if !output.status.success() {
        return Err("nvidia-smi failed".to_string());
    }

    let stdout = String::from_utf8_lossy(&output.stdout);
    // One line per GPU; only the first GPU is reported.
    let line = stdout.lines().next().ok_or("No output from nvidia-smi")?;

    let mut fields = line.split(',').map(|field| field.trim().parse::<u64>());
    match (fields.next(), fields.next()) {
        (Some(Ok(used)), Some(Ok(total))) if total > 0 => {
            Ok((used.saturating_mul(MIB), total.saturating_mul(MIB)))
        }
        _ => Err(format!("Invalid output from nvidia-smi: {}", line.trim())),
    }
}
979
980/// Get VRAM usage using amdgpu_top
981fn get_amdgpu_vram_metrics() -> Result<(u64, u64), String> {
982    let output = std::process::Command::new("amdgpu_top")
983        .args(["-d", "--json"])
984        .output()
985        .map_err(|e| e.to_string())?;
986
987    if !output.status.success() {
988        return Err("amdgpu_top failed".to_string());
989    }
990
991    let json: serde_json::Value =
992        serde_json::from_slice(&output.stdout).map_err(|e| e.to_string())?;
993
994    // amdgpu_top --json output has a "devices" array (or sometimes just a list of objects depending on version)
995    let devices = if json.is_array() {
996        json.as_array()
997    } else {
998        json.get("devices").and_then(|d| d.as_array())
999    };
1000
1001    if let Some(devices) = devices
1002        && let Some(device) = devices.first()
1003    {
1004        // Priority 1: Check root keys (newer amdgpu_top format as provided by user)
1005        // "VRAM Usage Size": 3070128128, "VRAM Size": 8589934592
1006        let root_used = device.get("VRAM Usage Size").and_then(|v| v.as_u64());
1007        let root_total = device.get("VRAM Size").and_then(|v| v.as_u64());
1008
1009        if let (Some(used), Some(total)) = (root_used, root_total)
1010            && total > 0
1011        {
1012            return Ok((used, total));
1013        }
1014
1015        // Priority 2: Check nested VRAM object (alternative format)
1016        let vram_obj = device.get("VRAM");
1017        if let Some(vram) = vram_obj {
1018            // Check if it's the "Total VRAM Usage" format (usually MiB)
1019            let nested_used = vram
1020                .get("Total VRAM Usage")
1021                .and_then(|v| v.get("value").or(Some(v)))
1022                .and_then(|v| v.as_u64());
1023            let nested_total = vram
1024                .get("Total VRAM")
1025                .and_then(|v| v.get("value").or(Some(v)))
1026                .and_then(|v| v.as_u64());
1027
1028            if let (Some(used), Some(total)) = (nested_used, nested_total) {
1029                // These are usually in MiB if they have a "unit" field
1030                let multiplier = if vram.get("Total VRAM").and_then(|v| v.get("unit")).is_some() {
1031                    1024 * 1024
1032                } else {
1033                    1
1034                };
1035                if total > 0 {
1036                    return Ok((used * multiplier, total * multiplier));
1037                }
1038            }
1039        }
1040
1041        // Priority 3: Check vram_usage key (older format)
1042        let vram_usage = device.get("vram_usage");
1043        if let Some(vram) = vram_usage {
1044            let used = vram
1045                .get("VRAM")
1046                .or_else(|| vram.get("usage"))
1047                .and_then(|v| v.get("value").or(Some(v)))
1048                .and_then(|v| v.as_u64())
1049                .unwrap_or(0);
1050            let total = vram
1051                .get("TotalVRAM")
1052                .or_else(|| vram.get("total"))
1053                .and_then(|v| v.get("value").or(Some(v)))
1054                .and_then(|v| v.as_u64())
1055                .unwrap_or(0);
1056
1057            if total > 0 {
1058                return Ok((used * 1024 * 1024, total * 1024 * 1024));
1059            }
1060        }
1061    }
1062
1063    Err("Could not find VRAM info in amdgpu_top output".to_string())
1064}
1065
1066/// Cross-platform: Get RAM (RSS) and CPU usage for a PID.
1067/// Uses a persistent System instance so sysinfo can compute accurate
1068/// CPU deltas across calls (first call on a fresh instance is always 0).
1069fn get_process_metrics(
1070    pid: u32,
1071) -> Result<(u64, f64), String> {
1072    use std::sync::{LazyLock, Mutex};
1073    use sysinfo::{Pid, ProcessRefreshKind, ProcessesToUpdate, RefreshKind, System};
1074
1075    static SYS: LazyLock<Mutex<System>> = LazyLock::new(|| {
1076        Mutex::new(System::new_with_specifics(
1077            RefreshKind::everything().with_processes(ProcessRefreshKind::nothing().with_cpu().with_memory()),
1078        ))
1079    });
1080
1081    let mut sys = SYS.lock().unwrap();
1082    let pids = [Pid::from(pid as usize)];
1083    sys.refresh_processes_specifics(
1084        ProcessesToUpdate::Some(&pids),
1085        true,
1086        ProcessRefreshKind::nothing().with_cpu().with_memory(),
1087    );
1088
1089    let sys_pid = Pid::from(pid as usize);
1090
1091    if let Some(process) = sys.process(sys_pid) {
1092        let ram = process.memory(); // bytes
1093        let cpu = process.cpu_usage() as f64; // percentage
1094        return Ok((ram, cpu));
1095    }
1096
1097    Err(format!("Process not found: pid={}", pid))
1098}
1099
1100/// Load a model via the llama-server Router API.
1101pub async fn load_model(
1102    host: &str,
1103    port: u16,
1104    model_id: &str,
1105    model_path: Option<&str>,
1106) -> Result<(), String> {
1107    let client = reqwest::Client::new();
1108    let host = clean_host(host);
1109
1110    // Try multiple endpoints
1111    let endpoints = ["/models/load", "/v1/models/load"];
1112
1113    // Construct all possible identification variants
1114    let mut variants = Vec::new();
1115
1116    // 1. Original ID (display_name / relative path)
1117    variants.push(model_id.to_string());
1118    variants.push(strip_gguf(model_id).to_string());
1119
1120    // 2. Just the filename
1121    if let Some(filename) = std::path::Path::new(model_id)
1122        .file_name()
1123        .and_then(|f| f.to_str())
1124    {
1125        variants.push(filename.to_string());
1126        variants.push(strip_gguf(filename).to_string());
1127    }
1128
1129    // 3. Absolute path
1130    if let Some(path) = model_path {
1131        variants.push(path.to_string());
1132    }
1133
1134    let mut last_status = reqwest::StatusCode::OK;
1135    let mut last_error = String::new();
1136
1137    for endpoint in endpoints {
1138        let url = format!("http://{}:{}{}", host, port, endpoint);
1139        for variant in &variants {
1140            // Try both "model" and "alias" fields
1141            let bodies = vec![
1142                serde_json::json!({ "model": variant }),
1143                serde_json::json!({ "alias": variant }),
1144            ];
1145
1146            for body in bodies {
1147                match client.post(&url).json(&body).send().await {
1148                    Ok(res) => {
1149                        if res.status().is_success() {
1150                            return Ok(());
1151                        }
1152                        last_status = res.status();
1153                        last_error = res
1154                            .text()
1155                            .await
1156                            .unwrap_or_else(|_| "Unknown error".to_string());
1157                    }
1158                    Err(e) => {
1159                        last_error = e.to_string();
1160                    }
1161                }
1162            }
1163        }
1164    }
1165
1166    Err(format!(
1167        "Failed to load model (tried {} variants). Last status {}: {}",
1168        variants.len() * 2,
1169        last_status,
1170        last_error
1171    ))
1172}
1173
1174/// List all models and their status from the llama-server Router API.
1175pub async fn list_models(
1176    host: &str,
1177    port: u16,
1178) -> Result<Vec<(String, String, Option<String>)>, String> {
1179    let client = reqwest::Client::new();
1180    let host = clean_host(host);
1181    let url = format!("http://{}:{}/models", host, port);
1182
1183    let res = client
1184        .get(&url)
1185        .send()
1186        .await
1187        .map_err(|e| format!("Failed to list models: {}", e))?;
1188
1189    if !res.status().is_success() {
1190        return Err(format!("Server returned error {}", res.status()));
1191    }
1192
1193    let json: serde_json::Value = res
1194        .json()
1195        .await
1196        .map_err(|e| format!("Invalid JSON: {}", e))?;
1197
1198    let mut results = Vec::new();
1199    if let Some(data) = json.get("data").and_then(|d| d.as_array()) {
1200        for model in data {
1201            let id = model
1202                .get("id")
1203                .and_then(|v| v.as_str())
1204                .unwrap_or_default()
1205                .to_string();
1206            // Status can be a string or an object with a "value" field
1207            let status = model
1208                .get("status")
1209                .and_then(|s| s.get("value").or(Some(s)))
1210                .and_then(|v| v.as_str())
1211                .unwrap_or("unloaded")
1212                .to_string();
1213            let path = model
1214                .get("path")
1215                .or_else(|| model.get("filename"))
1216                .and_then(|v| v.as_str())
1217                .map(|s| s.to_string());
1218
1219            results.push((id, status, path));
1220        }
1221    }
1222
1223    Ok(results)
1224}
1225
1226/// Unload a model via the llama-server Router API.
1227pub async fn unload_model(
1228    host: &str,
1229    port: u16,
1230    model_id: &str,
1231    model_path: Option<&str>,
1232) -> Result<(), String> {
1233    let client = reqwest::Client::new();
1234    let host = clean_host(host);
1235
1236    let endpoints = ["/models/unload", "/v1/models/unload"];
1237    let stripped = strip_gguf(model_id);
1238    let mut variants = vec![model_id.to_string(), stripped.to_string()];
1239    if let Some(path) = model_path {
1240        variants.push(path.to_string());
1241    }
1242
1243    for endpoint in endpoints {
1244        let url = format!("http://{}:{}{}", host, port, endpoint);
1245        for variant in &variants {
1246            let body = serde_json::json!({
1247                "model": variant
1248            });
1249
1250            if let Ok(res) = client.post(&url).json(&body).send().await
1251                && res.status().is_success()
1252            {
1253                return Ok(());
1254            }
1255        }
1256    }
1257
1258    Ok(()) // Silently ignore unload errors as it's often just a cleanup
1259}