Skip to main content

llm_manager/
serve.rs

1use std::net::SocketAddr;
2use std::path::PathBuf;
3
4use anyhow::{Context, Result};
5use tokio::select;
6use tokio::signal;
7use tracing::info;
8
9use crate::backend::server;
10use crate::backend::tls;
11use crate::config::Config;
12use crate::models::{DiscoveredModel, WsMetrics};
13
/// Options for the `serve` command, consumed by `serve_model`.
///
/// Where an option has a counterpart in the loaded configuration, the value
/// given here takes precedence (see the individual fields).
#[derive(Default)]
pub struct ServeOptions {
    /// Path to the model file. `serve_model` rejects paths that do not exist
    /// or that do not have the `gguf` extension.
    pub model_path: String,
    /// Optional profile name, forwarded to `Config::resolve_settings`.
    pub profile_name: Option<String>,
    /// Explicit config file path; when `None`, `Config::load()` is used.
    pub config_path: Option<String>,
    /// When set, an API proxy server is started on this port.
    pub api_port: Option<u16>,
    /// Optional API key handed to the API proxy server.
    pub api_key: Option<String>,
    /// Enables the WebSocket dashboard (OR-ed with `ws_server_enabled` from the config).
    pub ws_enable: bool,
    /// Dashboard port; falls back to `ws_server_port` from the config.
    pub ws_port: Option<u16>,
    /// Dashboard auth key; falls back to `ws_server_auth_key` from the config.
    pub ws_auth: Option<String>,
    /// Path to a custom backend binary; when `None`, the binary is resolved
    /// through `backend::hub::resolve_backend_binary`.
    pub backend_binary: Option<String>,
    /// Overrides the `host` value of the resolved settings.
    pub host: Option<String>,
    /// Enables TLS. Without an explicit certificate, `tls::ensure_tls_certs()`
    /// supplies the certificate/key pair.
    pub tls_enable: bool,
    /// Path to the TLS certificate. TLS is also turned on when both
    /// `tls_cert` and `tls_key` are set.
    pub tls_cert: Option<String>,
    /// Path to the TLS private key that belongs to `tls_cert`.
    pub tls_key: Option<String>,
}
30
31async fn start_metrics_polling_task(
32    host: String,
33    port: u16,
34    pid: u32,
35    model_name: String,
36    settings: crate::models::ModelSettings,
37    cmd_display: String,
38    tx: tokio::sync::broadcast::Sender<WsMetrics>,
39    shutdown_rx: tokio::sync::watch::Receiver<bool>,
40) {
41    let mut consecutive_failures: u32 = 0;
42    let max_failures: u32 = 15;
43
44    loop {
45        // Check shutdown first
46        if *shutdown_rx.borrow() {
47            break;
48        }
49
50        let m = match tokio::time::timeout(
51            std::time::Duration::from_secs(3),
52            server::get_metrics(&host, port, None, Some(pid)),
53        )
54        .await
55        {
56            Ok(Ok(metrics)) => {
57                consecutive_failures = 0;
58                metrics
59            }
60            Ok(Err(_)) | Err(_) => {
61                consecutive_failures += 1;
62                if consecutive_failures >= max_failures {
63                    tracing::warn!(
64                        "Metrics polling aborted after {} consecutive failures (server likely dead)",
65                        max_failures
66                    );
67                    break;
68                }
69                if consecutive_failures % 5 == 1 {
70                    tracing::warn!(
71                        "Metrics polling: server unreachable (attempt {}/{})",
72                        consecutive_failures,
73                        max_failures
74                    );
75                }
76                tokio::time::sleep(std::time::Duration::from_secs(2)).await;
77                continue;
78            }
79        };
80
81        let state = "loaded";
82        let ws_metrics =
83            WsMetrics::from_metrics(&m, &model_name, state, &settings, Some(&cmd_display));
84
85        if let Err(e) = tx.send(ws_metrics) {
86            tracing::debug!("Failed to send metrics to broadcast channel: {e}");
87        }
88
89        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
90    }
91}
92
93/// Serve a model using the llama-server binary, applying all settings from config.yaml.
94///
95/// This is a standalone CLI command (llm-manager serve) that:
96/// 1. Loads config (same config.yaml as the TUI)
97/// 2. Resolves the model path
98/// 3. Fetches settings from config overrides, profiles, and defaults
99/// 4. Builds and spawns the llama-server command
100/// 5. Optionally starts an API proxy server on a separate port
101/// 6. Streams output to stdout/stderr until killed
102///
103/// Usage:
104///   llm-manager serve --model /path/to/model.gguf [--profile qwen] [--config /path/to/config.yaml]
105///   llm-manager serve --model model.gguf --api-port 49222 --api-key secret
106pub async fn serve_model(opts: ServeOptions) -> Result<()> {
107    // Load config from explicit path or default location
108    let config = match opts.config_path.as_deref() {
109        Some(p) => {
110            let path = PathBuf::from(p);
111            Config::load_from(path).map_err(|e| anyhow::anyhow!("Failed to load config: {}", e))?
112        }
113        None => Config::load().map_err(|e| anyhow::anyhow!("Failed to load config: {}", e))?,
114    };
115
116    // Resolve model path
117    let model_path = PathBuf::from(&opts.model_path);
118
119    // Check for broken symlinks first
120    if let Ok(metadata) = model_path.symlink_metadata()
121        && metadata.file_type().is_symlink()
122        && !model_path.exists()
123    {
124        let target = std::fs::read_link(&model_path).unwrap_or_default();
125        let msg = format!(
126            "Model file is a broken symlink: {}\n  Symlink points to: {}\n  The target does not exist. Fix the symlink or use the actual file.",
127            model_path.display(),
128            target.display()
129        );
130        return Err(anyhow::Error::msg(msg));
131    }
132
133    if !model_path.exists() {
134        // Check if parent directory exists
135        if let Some(parent) = model_path.parent()
136            && !parent.exists()
137        {
138            let msg = format!(
139                "Model file not found: {}\n  Parent directory does not exist: {}",
140                model_path.display(),
141                parent.display()
142            );
143            return Err(anyhow::Error::msg(msg));
144        }
145        let msg = format!("Model file not found: {}", model_path.display());
146        return Err(anyhow::Error::msg(msg));
147    }
148
149    if !model_path.extension().map(|e| e == "gguf").unwrap_or(false) {
150        let msg = format!("Model file must be a .gguf file: {}", model_path.display());
151        return Err(anyhow::Error::msg(msg));
152    }
153
154    let name = model_path
155        .file_name()
156        .map(|n| n.to_string_lossy().to_string())
157        .unwrap_or_default();
158    let display_name = model_path
159        .strip_prefix(config.models_dirs.first().unwrap_or(&PathBuf::new()))
160        .ok()
161        .and_then(|p| p.to_str())
162        .map(|s| s.to_string())
163        .unwrap_or_else(|| name.clone());
164
165    let model = DiscoveredModel {
166        path: model_path.clone(),
167        name: name.clone(),
168        file_size: std::fs::metadata(&model_path).map(|m| m.len()).unwrap_or(0),
169        display_name: display_name.clone(),
170    };
171
172    // Build settings: start with defaults, apply model override, then profile override
173    tracing::info!("Model name for config lookup: {}", name);
174    tracing::info!(
175        "Available model config keys: {:?}",
176        config.model_overrides.keys()
177    );
178    let mut settings = config.resolve_settings(Some(&name), opts.profile_name.as_deref());
179
180    // Auto-enable MTP if supported by model and not explicitly enabled in config
181    if settings.spec_type.is_empty()
182        && let Ok(meta) = crate::models::GgufMetadata::from_path(&model_path)
183            && meta.arch == "mtp" {
184                tracing::info!("Auto-enabling MTP (Multi-Token Prediction) for model");
185                settings.spec_type = "draft-mtp".to_string();
186                if settings.draft_tokens == 0 {
187                    settings.draft_tokens = meta.draft_tokens;
188                }
189            }
190
191    // WebSocket settings: CLI flags take precedence, then config.yaml
192    let ws_enable = opts.ws_enable || config.default.ws_server_enabled;
193    let ws_port = opts.ws_port.unwrap_or(config.default.ws_server_port);
194    let ws_auth: Option<String> = opts.ws_auth.or(config.default.ws_server_auth_key.clone());
195
196    // TLS configuration
197    let tls_config = if opts.tls_enable || (opts.tls_cert.is_some() && opts.tls_key.is_some()) {
198        let (cert_path, key_path) = if let Some(cert) = opts.tls_cert {
199            let key = opts.tls_key.unwrap();
200            tls::validate_tls_path(&cert).map_err(|e| anyhow::anyhow!("TLS: {}", e))?;
201            tls::validate_tls_path(&key).map_err(|e| anyhow::anyhow!("TLS: {}", e))?;
202            (cert.to_string(), key.to_string())
203        } else {
204            let (cert, key) = tls::ensure_tls_certs().map_err(|e| anyhow::anyhow!("TLS: {}", e))?;
205            (
206                cert.to_string_lossy().to_string(),
207                key.to_string_lossy().to_string(),
208            )
209        };
210        let tls_cfg = tls::load_tls_config(&cert_path, &key_path)
211            .await
212            .map_err(|e| anyhow::anyhow!("TLS: {}", e))?;
213        Some(tls_cfg)
214    } else {
215        None
216    };
217
218    if tls_config.is_some() {
219        info!("TLS enabled for WebSocket dashboard and API server");
220    }
221
222    // CLI host override
223    if let Some(h) = &opts.host {
224        settings.host = h.to_string();
225    }
226
227    info!("Serving model: {}", model.display_name);
228    let layers_str = match settings.gpu_layers_mode {
229        crate::models::GpuLayersMode::Auto => "auto".to_string(),
230        crate::models::GpuLayersMode::Specific(n) => n.to_string(),
231        crate::models::GpuLayersMode::All => "all".to_string(),
232    };
233    info!(
234        "Settings: {} threads, {} layers, {} context",
235        settings.threads, layers_str, settings.context_length
236    );
237
238    // Trace backend binary selection
239    let active_version = settings.get_active_backend_version();
240    let version_display = settings.get_active_backend_version_display();
241    info!(
242        "Backend: {}, version config: {:?} (display: {})",
243        settings.backend, active_version, version_display
244    );
245    if let Some(ref cpu_ver) = settings.llama_cpp_version_cpu {
246        info!("  llama_cpp_version_cpu = {}", cpu_ver);
247    }
248    if let Some(ref cuda_ver) = settings.llama_cpp_version_cuda {
249        info!("  llama_cpp_version_cuda = {}", cuda_ver);
250    }
251
252    if ws_enable {
253        let auth_info = if let Some(ref auth) = ws_auth {
254            format!(" (auth: {})", &auth[..auth.len().min(8)])
255        } else {
256            String::new()
257        };
258        info!(
259            "WebSocket dashboard enabled on port {}{}",
260            ws_port, auth_info
261        );
262    }
263
264    // Resolve the backend binary (downloads if needed)
265    let binary = if let Some(path) = &opts.backend_binary {
266        let binary_path = PathBuf::from(path);
267        if !binary_path.exists() {
268            anyhow::bail!("Backend binary not found: {}", binary_path.display());
269        }
270        info!("Using custom backend binary: {}", binary_path.display());
271        binary_path
272    } else {
273        let version_param = settings.get_active_backend_version().map(|s| s.as_str());
274        info!(
275            "Resolving backend binary: backend={}, version_param={:?}",
276            settings.backend, version_param
277        );
278        match crate::backend::hub::resolve_backend_binary(
279            settings.backend,
280            version_param,
281            None,
282            None,
283        )
284        .await
285        {
286            Ok(path) => {
287                info!("Resolved binary path: {}", path.display());
288                if !path.exists() {
289                    anyhow::bail!("llama-server binary not found at: {}", path.display());
290                }
291                path
292            }
293            Err(e) => anyhow::bail!("Failed to resolve backend binary: {}", e),
294        }
295    };
296    info!(
297        "Using llama-server: {} (backend: {})",
298        binary.display(),
299        settings.backend
300    );
301
302    // Build the server command
303    let (mut cmd, cmd_display) = server::build_server_cmd(
304        &binary,
305        Some(&model),
306        &settings,
307        &config,
308        config.default.server_mode,
309        config.default.router_max_models,
310    );
311
312    // Set LD_LIBRARY_PATH so the binary can find its shared libraries
313    let bin_dir = binary.parent()
314        .context("Backend binary path has no parent directory. Use a full path for --backend-binary.")?;
315    if let Ok(current) = std::env::var("LD_LIBRARY_PATH") {
316        cmd.env(
317            "LD_LIBRARY_PATH",
318            format!("{}:{}", bin_dir.display(), current),
319        );
320    } else {
321        cmd.env("LD_LIBRARY_PATH", bin_dir);
322    }
323
324    // Spawn the process
325    info!("Command: {}", cmd_display);
326    let mut child = cmd
327        .stdout(std::process::Stdio::inherit())
328        .stderr(std::process::Stdio::inherit())
329        .spawn()
330        .context(format!("Failed to spawn llama-server.\n\n  Command that was attempted:\n    {}\n\n  Check that the binary exists and is executable.", cmd_display))?;
331
332    info!("llama-server started (pid={})", child.id().unwrap_or(0));
333    info!("Press Ctrl+C to stop the server");
334
335    let server_pid = child.id().unwrap_or(0);
336
337    // Optionally start the API proxy server
338    let (api_done_tx, api_done_rx) = tokio::sync::oneshot::channel();
339    let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
340    let mut api_server_handle = if let Some(port) = opts.api_port {
341        let host_str = &settings.host;
342        let addr: SocketAddr = format!("{}:{}", host_str, port).parse()?;
343        let model_name = model.display_name.clone();
344        let server_port = settings.port;
345        let api_key_clone = opts.api_key.clone();
346        let shutdown_rx_for_api = shutdown_rx.clone();
347        let host_clone = host_str.clone();
348        let tls_for_api = tls_config.clone();
349        let handle = tokio::spawn(async move {
350            let _ = crate::serve_api::start_api_server(
351                addr,
352                api_key_clone,
353                server_port,
354                model_name,
355                server_pid,
356                shutdown_rx_for_api,
357                host_clone,
358                tls_for_api,
359            )
360            .await;
361            let _ = api_done_tx.send(());
362        });
363        let api_protocol = if tls_config.is_some() {
364            "https"
365        } else {
366            "http"
367        };
368        info!(
369            "API proxy started on {api_protocol}://{}:{}",
370            host_str, port
371        );
372        Some((handle, api_done_rx, shutdown_tx))
373    } else {
374        None
375    };
376
377    // Start WebSocket dashboard server
378    let ws_server_handle = if ws_enable {
379        let (tx, rx) = tokio::sync::broadcast::channel(64);
380        let ws_rx = std::sync::Arc::new(rx);
381        let host_str = &settings.host;
382        let handle = crate::backend::ws_server::start_ws_server(
383            ws_port,
384            ws_rx,
385            ws_auth.clone(),
386            tls_config.clone(),
387            host_str.clone(),
388        )
389        .await?;
390
391        let auth_param = if let Some(ref auth) = ws_auth {
392            format!("?auth={}", urlencoding::encode(auth))
393        } else {
394            "".to_string()
395        };
396        let protocol = if tls_config.is_some() {
397            "https"
398        } else {
399            "http"
400        };
401        info!(
402            "Dashboard enabled: {protocol}://{}:{}/dashboard{}",
403            host_str, ws_port, auth_param
404        );
405
406        // Start metrics polling task
407        let settings_clone = settings.clone();
408        let model_name_clone = model.display_name.clone();
409        let host_clone = settings.host.clone();
410        let server_port_clone = settings.port;
411        let pid_clone = server_pid;
412        let cmd_display_clone = cmd_display.clone();
413        let shutdown_rx_clone = shutdown_rx.clone();
414        tokio::spawn(async move {
415            start_metrics_polling_task(
416                host_clone,
417                server_port_clone,
418                pid_clone,
419                model_name_clone,
420                settings_clone,
421                cmd_display_clone,
422                tx,
423                shutdown_rx_clone,
424            )
425            .await;
426        });
427
428        Some(handle)
429    } else {
430        None
431    };
432
433    // Wait for either llama-server, API server, or Ctrl+C
434    let status = loop {
435        select! {
436            status = child.wait() => {
437                // llama-server exited — gracefully shut down API server
438                if let Some((_, _, tx)) = &mut api_server_handle {
439                    let _ = tx.send(true);
440                }
441                break status.context("Failed to wait for llama-server")?;
442            }
443            _ = async {
444                let (_, rx, _) = api_server_handle.as_mut().unwrap();
445                let _ = rx.await;
446            }, if api_server_handle.is_some() => {
447                // API server exited — gracefully shut down, then wait for llama-server
448                if let Some((_, _, tx)) = &mut api_server_handle {
449                    let _ = tx.send(true);
450                }
451                break child.wait().await.context("Failed to wait for llama-server")?;
452            }
453            _ = signal::ctrl_c() => {
454                info!("Received SIGINT, shutting down llama-server...");
455                let _ = child.kill().await;
456                if let Some((_, _, tx)) = &mut api_server_handle {
457                    let _ = tx.send(true);
458                }
459            }
460        }
461    };
462
463    // Drop the API server handle so the spawned task can finish
464    if let Some((handle, _, _)) = api_server_handle {
465        let _ = handle.await;
466    }
467
468    // Abort the WebSocket dashboard server
469    if let Some(handle) = ws_server_handle {
470        handle.abort();
471    }
472
473    if status.success() {
474        info!("llama-server exited normally");
475    } else {
476        info!("llama-server exited with status: {}", status);
477    }
478
479    if status.success() {
480        Ok(())
481    } else {
482        anyhow::bail!("llama-server exited with status: {}", status)
483    }
484}