1use std::net::SocketAddr;
2use std::path::PathBuf;
3
4use anyhow::{Context, Result};
5use tokio::select;
6use tokio::signal;
7use tracing::info;
8
9use crate::backend::server;
10use crate::backend::tls;
11use crate::config::Config;
12use crate::models::{DiscoveredModel, WsMetrics};
13
/// Options accepted by [`serve_model`].
///
/// Most fields are optional overrides: when left as `None`/`false`,
/// `serve_model` falls back to values from the loaded `Config`.
#[derive(Default)]
pub struct ServeOptions {
    /// Path to the model file. `serve_model` requires it to exist and to
    /// carry a `.gguf` extension.
    pub model_path: String,
    /// Profile name forwarded to `Config::resolve_settings`.
    pub profile_name: Option<String>,
    /// Explicit configuration file, loaded with `Config::load_from`; when
    /// `None`, `Config::load()` is used instead.
    pub config_path: Option<String>,
    /// When set, `serve_model` starts the API proxy
    /// (`serve_api::start_api_server`) on this port.
    pub api_port: Option<u16>,
    /// Forwarded unchanged to `serve_api::start_api_server`
    /// (presumably the key API clients must present — confirm there).
    pub api_key: Option<String>,
    /// Enables the WebSocket dashboard. It is also enabled when
    /// `config.default.ws_server_enabled` is true.
    pub ws_enable: bool,
    /// Dashboard port; falls back to `config.default.ws_server_port`.
    pub ws_port: Option<u16>,
    /// Dashboard auth key; falls back to `config.default.ws_server_auth_key`.
    pub ws_auth: Option<String>,
    /// Path to a specific backend binary. When `None`, the binary is resolved
    /// through `backend::hub::resolve_backend_binary`.
    pub backend_binary: Option<String>,
    /// Overrides the `host` value of the resolved model settings.
    pub host: Option<String>,
    /// Enables TLS. Without an explicit certificate, `tls::ensure_tls_certs`
    /// supplies the certificate/key pair.
    pub tls_enable: bool,
    /// Path to the TLS certificate. Supplying both `tls_cert` and `tls_key`
    /// enables TLS even when `tls_enable` is false.
    pub tls_cert: Option<String>,
    /// Path to the TLS private key belonging to `tls_cert`.
    pub tls_key: Option<String>,
}
30
31async fn start_metrics_polling_task(
32 host: String,
33 port: u16,
34 pid: u32,
35 model_name: String,
36 settings: crate::models::ModelSettings,
37 cmd_display: String,
38 tx: tokio::sync::broadcast::Sender<WsMetrics>,
39 shutdown_rx: tokio::sync::watch::Receiver<bool>,
40) {
41 let mut consecutive_failures: u32 = 0;
42 let max_failures: u32 = 15;
43
44 loop {
45 if *shutdown_rx.borrow() {
47 break;
48 }
49
50 let m = match tokio::time::timeout(
51 std::time::Duration::from_secs(3),
52 server::get_metrics(&host, port, None, Some(pid)),
53 )
54 .await
55 {
56 Ok(Ok(metrics)) => {
57 consecutive_failures = 0;
58 metrics
59 }
60 Ok(Err(_)) | Err(_) => {
61 consecutive_failures += 1;
62 if consecutive_failures >= max_failures {
63 tracing::warn!(
64 "Metrics polling aborted after {} consecutive failures (server likely dead)",
65 max_failures
66 );
67 break;
68 }
69 if consecutive_failures % 5 == 1 {
70 tracing::warn!(
71 "Metrics polling: server unreachable (attempt {}/{})",
72 consecutive_failures,
73 max_failures
74 );
75 }
76 tokio::time::sleep(std::time::Duration::from_secs(2)).await;
77 continue;
78 }
79 };
80
81 let state = "loaded";
82 let ws_metrics =
83 WsMetrics::from_metrics(&m, &model_name, state, &settings, Some(&cmd_display));
84
85 if let Err(e) = tx.send(ws_metrics) {
86 tracing::debug!("Failed to send metrics to broadcast channel: {e}");
87 }
88
89 tokio::time::sleep(std::time::Duration::from_secs(2)).await;
90 }
91}
92
93pub async fn serve_model(opts: ServeOptions) -> Result<()> {
107 let config = match opts.config_path.as_deref() {
109 Some(p) => {
110 let path = PathBuf::from(p);
111 Config::load_from(path).map_err(|e| anyhow::anyhow!("Failed to load config: {}", e))?
112 }
113 None => Config::load().map_err(|e| anyhow::anyhow!("Failed to load config: {}", e))?,
114 };
115
116 let model_path = PathBuf::from(&opts.model_path);
118
119 if let Ok(metadata) = model_path.symlink_metadata()
121 && metadata.file_type().is_symlink()
122 && !model_path.exists()
123 {
124 let target = std::fs::read_link(&model_path).unwrap_or_default();
125 let msg = format!(
126 "Model file is a broken symlink: {}\n Symlink points to: {}\n The target does not exist. Fix the symlink or use the actual file.",
127 model_path.display(),
128 target.display()
129 );
130 return Err(anyhow::Error::msg(msg));
131 }
132
133 if !model_path.exists() {
134 if let Some(parent) = model_path.parent()
136 && !parent.exists()
137 {
138 let msg = format!(
139 "Model file not found: {}\n Parent directory does not exist: {}",
140 model_path.display(),
141 parent.display()
142 );
143 return Err(anyhow::Error::msg(msg));
144 }
145 let msg = format!("Model file not found: {}", model_path.display());
146 return Err(anyhow::Error::msg(msg));
147 }
148
149 if !model_path.extension().map(|e| e == "gguf").unwrap_or(false) {
150 let msg = format!("Model file must be a .gguf file: {}", model_path.display());
151 return Err(anyhow::Error::msg(msg));
152 }
153
154 let name = model_path
155 .file_name()
156 .map(|n| n.to_string_lossy().to_string())
157 .unwrap_or_default();
158 let display_name = model_path
159 .strip_prefix(config.models_dirs.first().unwrap_or(&PathBuf::new()))
160 .ok()
161 .and_then(|p| p.to_str())
162 .map(|s| s.to_string())
163 .unwrap_or_else(|| name.clone());
164
165 let model = DiscoveredModel {
166 path: model_path.clone(),
167 name: name.clone(),
168 file_size: std::fs::metadata(&model_path).map(|m| m.len()).unwrap_or(0),
169 display_name: display_name.clone(),
170 };
171
172 tracing::info!("Model name for config lookup: {}", name);
174 tracing::info!(
175 "Available model config keys: {:?}",
176 config.model_overrides.keys()
177 );
178 let mut settings = config.resolve_settings(Some(&name), opts.profile_name.as_deref());
179
180 if settings.spec_type.is_empty()
182 && let Ok(meta) = crate::models::GgufMetadata::from_path(&model_path)
183 && meta.arch == "mtp" {
184 tracing::info!("Auto-enabling MTP (Multi-Token Prediction) for model");
185 settings.spec_type = "draft-mtp".to_string();
186 if settings.draft_tokens == 0 {
187 settings.draft_tokens = meta.draft_tokens;
188 }
189 }
190
191 let ws_enable = opts.ws_enable || config.default.ws_server_enabled;
193 let ws_port = opts.ws_port.unwrap_or(config.default.ws_server_port);
194 let ws_auth: Option<String> = opts.ws_auth.or(config.default.ws_server_auth_key.clone());
195
196 let tls_config = if opts.tls_enable || (opts.tls_cert.is_some() && opts.tls_key.is_some()) {
198 let (cert_path, key_path) = if let Some(cert) = opts.tls_cert {
199 let key = opts.tls_key.unwrap();
200 tls::validate_tls_path(&cert).map_err(|e| anyhow::anyhow!("TLS: {}", e))?;
201 tls::validate_tls_path(&key).map_err(|e| anyhow::anyhow!("TLS: {}", e))?;
202 (cert.to_string(), key.to_string())
203 } else {
204 let (cert, key) = tls::ensure_tls_certs().map_err(|e| anyhow::anyhow!("TLS: {}", e))?;
205 (
206 cert.to_string_lossy().to_string(),
207 key.to_string_lossy().to_string(),
208 )
209 };
210 let tls_cfg = tls::load_tls_config(&cert_path, &key_path)
211 .await
212 .map_err(|e| anyhow::anyhow!("TLS: {}", e))?;
213 Some(tls_cfg)
214 } else {
215 None
216 };
217
218 if tls_config.is_some() {
219 info!("TLS enabled for WebSocket dashboard and API server");
220 }
221
222 if let Some(h) = &opts.host {
224 settings.host = h.to_string();
225 }
226
227 info!("Serving model: {}", model.display_name);
228 let layers_str = match settings.gpu_layers_mode {
229 crate::models::GpuLayersMode::Auto => "auto".to_string(),
230 crate::models::GpuLayersMode::Specific(n) => n.to_string(),
231 crate::models::GpuLayersMode::All => "all".to_string(),
232 };
233 info!(
234 "Settings: {} threads, {} layers, {} context",
235 settings.threads, layers_str, settings.context_length
236 );
237
238 let active_version = settings.get_active_backend_version();
240 let version_display = settings.get_active_backend_version_display();
241 info!(
242 "Backend: {}, version config: {:?} (display: {})",
243 settings.backend, active_version, version_display
244 );
245 if let Some(ref cpu_ver) = settings.llama_cpp_version_cpu {
246 info!(" llama_cpp_version_cpu = {}", cpu_ver);
247 }
248 if let Some(ref cuda_ver) = settings.llama_cpp_version_cuda {
249 info!(" llama_cpp_version_cuda = {}", cuda_ver);
250 }
251
252 if ws_enable {
253 let auth_info = if let Some(ref auth) = ws_auth {
254 format!(" (auth: {})", &auth[..auth.len().min(8)])
255 } else {
256 String::new()
257 };
258 info!(
259 "WebSocket dashboard enabled on port {}{}",
260 ws_port, auth_info
261 );
262 }
263
264 let binary = if let Some(path) = &opts.backend_binary {
266 let binary_path = PathBuf::from(path);
267 if !binary_path.exists() {
268 anyhow::bail!("Backend binary not found: {}", binary_path.display());
269 }
270 info!("Using custom backend binary: {}", binary_path.display());
271 binary_path
272 } else {
273 let version_param = settings.get_active_backend_version().map(|s| s.as_str());
274 info!(
275 "Resolving backend binary: backend={}, version_param={:?}",
276 settings.backend, version_param
277 );
278 match crate::backend::hub::resolve_backend_binary(
279 settings.backend,
280 version_param,
281 None,
282 None,
283 )
284 .await
285 {
286 Ok(path) => {
287 info!("Resolved binary path: {}", path.display());
288 if !path.exists() {
289 anyhow::bail!("llama-server binary not found at: {}", path.display());
290 }
291 path
292 }
293 Err(e) => anyhow::bail!("Failed to resolve backend binary: {}", e),
294 }
295 };
296 info!(
297 "Using llama-server: {} (backend: {})",
298 binary.display(),
299 settings.backend
300 );
301
302 let (mut cmd, cmd_display) = server::build_server_cmd(
304 &binary,
305 Some(&model),
306 &settings,
307 &config,
308 config.default.server_mode,
309 config.default.router_max_models,
310 );
311
312 let bin_dir = binary.parent()
314 .context("Backend binary path has no parent directory. Use a full path for --backend-binary.")?;
315 if let Ok(current) = std::env::var("LD_LIBRARY_PATH") {
316 cmd.env(
317 "LD_LIBRARY_PATH",
318 format!("{}:{}", bin_dir.display(), current),
319 );
320 } else {
321 cmd.env("LD_LIBRARY_PATH", bin_dir);
322 }
323
324 info!("Command: {}", cmd_display);
326 let mut child = cmd
327 .stdout(std::process::Stdio::inherit())
328 .stderr(std::process::Stdio::inherit())
329 .spawn()
330 .context(format!("Failed to spawn llama-server.\n\n Command that was attempted:\n {}\n\n Check that the binary exists and is executable.", cmd_display))?;
331
332 info!("llama-server started (pid={})", child.id().unwrap_or(0));
333 info!("Press Ctrl+C to stop the server");
334
335 let server_pid = child.id().unwrap_or(0);
336
337 let (api_done_tx, api_done_rx) = tokio::sync::oneshot::channel();
339 let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
340 let mut api_server_handle = if let Some(port) = opts.api_port {
341 let host_str = &settings.host;
342 let addr: SocketAddr = format!("{}:{}", host_str, port).parse()?;
343 let model_name = model.display_name.clone();
344 let server_port = settings.port;
345 let api_key_clone = opts.api_key.clone();
346 let shutdown_rx_for_api = shutdown_rx.clone();
347 let host_clone = host_str.clone();
348 let tls_for_api = tls_config.clone();
349 let handle = tokio::spawn(async move {
350 let _ = crate::serve_api::start_api_server(
351 addr,
352 api_key_clone,
353 server_port,
354 model_name,
355 server_pid,
356 shutdown_rx_for_api,
357 host_clone,
358 tls_for_api,
359 )
360 .await;
361 let _ = api_done_tx.send(());
362 });
363 let api_protocol = if tls_config.is_some() {
364 "https"
365 } else {
366 "http"
367 };
368 info!(
369 "API proxy started on {api_protocol}://{}:{}",
370 host_str, port
371 );
372 Some((handle, api_done_rx, shutdown_tx))
373 } else {
374 None
375 };
376
377 let ws_server_handle = if ws_enable {
379 let (tx, rx) = tokio::sync::broadcast::channel(64);
380 let ws_rx = std::sync::Arc::new(rx);
381 let host_str = &settings.host;
382 let handle = crate::backend::ws_server::start_ws_server(
383 ws_port,
384 ws_rx,
385 ws_auth.clone(),
386 tls_config.clone(),
387 host_str.clone(),
388 )
389 .await?;
390
391 let auth_param = if let Some(ref auth) = ws_auth {
392 format!("?auth={}", urlencoding::encode(auth))
393 } else {
394 "".to_string()
395 };
396 let protocol = if tls_config.is_some() {
397 "https"
398 } else {
399 "http"
400 };
401 info!(
402 "Dashboard enabled: {protocol}://{}:{}/dashboard{}",
403 host_str, ws_port, auth_param
404 );
405
406 let settings_clone = settings.clone();
408 let model_name_clone = model.display_name.clone();
409 let host_clone = settings.host.clone();
410 let server_port_clone = settings.port;
411 let pid_clone = server_pid;
412 let cmd_display_clone = cmd_display.clone();
413 let shutdown_rx_clone = shutdown_rx.clone();
414 tokio::spawn(async move {
415 start_metrics_polling_task(
416 host_clone,
417 server_port_clone,
418 pid_clone,
419 model_name_clone,
420 settings_clone,
421 cmd_display_clone,
422 tx,
423 shutdown_rx_clone,
424 )
425 .await;
426 });
427
428 Some(handle)
429 } else {
430 None
431 };
432
433 let status = loop {
435 select! {
436 status = child.wait() => {
437 if let Some((_, _, tx)) = &mut api_server_handle {
439 let _ = tx.send(true);
440 }
441 break status.context("Failed to wait for llama-server")?;
442 }
443 _ = async {
444 let (_, rx, _) = api_server_handle.as_mut().unwrap();
445 let _ = rx.await;
446 }, if api_server_handle.is_some() => {
447 if let Some((_, _, tx)) = &mut api_server_handle {
449 let _ = tx.send(true);
450 }
451 break child.wait().await.context("Failed to wait for llama-server")?;
452 }
453 _ = signal::ctrl_c() => {
454 info!("Received SIGINT, shutting down llama-server...");
455 let _ = child.kill().await;
456 if let Some((_, _, tx)) = &mut api_server_handle {
457 let _ = tx.send(true);
458 }
459 }
460 }
461 };
462
463 if let Some((handle, _, _)) = api_server_handle {
465 let _ = handle.await;
466 }
467
468 if let Some(handle) = ws_server_handle {
470 handle.abort();
471 }
472
473 if status.success() {
474 info!("llama-server exited normally");
475 } else {
476 info!("llama-server exited with status: {}", status);
477 }
478
479 if status.success() {
480 Ok(())
481 } else {
482 anyhow::bail!("llama-server exited with status: {}", status)
483 }
484}