Skip to main content

offline_intelligence/model_runtime/
gguf_runtime.rs

1//! GGUF Runtime Adapter
2//!
3//! Wraps the existing llama-server (llama.cpp) for GGUF models.
4//! This adapter spawns the llama-server process and proxies requests via HTTP.
5
6use async_trait::async_trait;
7use super::runtime_trait::*;
8use std::process::{Child, Command, Stdio};
9use std::time::Duration;
10use tracing::{info, warn};
11use tokio::time::sleep;
12
13pub struct GGUFRuntime {
14    config: Option<RuntimeConfig>,
15    server_process: Option<Child>,
16    http_client: reqwest::Client,
17    base_url: String,
18}
19
20impl GGUFRuntime {
21    pub fn new() -> Self {
22        Self {
23            config: None,
24            server_process: None,
25            http_client: reqwest::Client::builder()
26                .timeout(Duration::from_secs(600))
27                .build()
28                .unwrap_or_default(),
29            base_url: String::new(),
30        }
31    }
32
33    /// Start llama-server process
34    async fn start_server(&mut self, config: &RuntimeConfig) -> anyhow::Result<()> {
35        let binary_path = config.runtime_binary.as_ref()
36            .ok_or_else(|| anyhow::anyhow!("GGUF runtime requires runtime_binary path"))?;
37
38        if !binary_path.exists() {
39            return Err(anyhow::anyhow!(
40                "llama-server binary not found at: {}",
41                binary_path.display()
42            ));
43        }
44
45        info!("Starting llama-server for GGUF model: {}", config.model_path.display());
46        info!("  Binary: {}", binary_path.display());
47        info!("  Port: {}", config.port);
48        info!("  Context Size: {}", config.context_size);
49        info!("  GPU Layers: {}", config.gpu_layers);
50
51        // Verify model file exists before starting
52        if !config.model_path.exists() {
53            return Err(anyhow::anyhow!(
54                "Model file not found at: {}",
55                config.model_path.display()
56            ));
57        }
58
59        // Build command arguments
60        let mut cmd = Command::new(binary_path);
61        cmd.arg("--model").arg(&config.model_path)
62            .arg("--host").arg(&config.host)
63            .arg("--port").arg(config.port.to_string())
64            .arg("--ctx-size").arg(config.context_size.to_string())
65            .arg("--batch-size").arg(config.batch_size.to_string())
66            // Micro-batch size: larger value keeps GPU tensor cores busy.
67            .arg("--ubatch-size").arg(config.ubatch_size.to_string())
68            .arg("--threads").arg(config.threads.to_string())
69            .arg("--n-gpu-layers").arg(config.gpu_layers.to_string())
70            // Parallel KV-cache slots — each slot handles one concurrent request.
71            // Enables continuous batching so multiple users share a single GPU pass.
72            .arg("--parallel").arg(config.parallel_slots.to_string())
73            // Continuous batching: interleave prefill and decode across all active
74            // slots every generation step. Without this flag, --parallel has no
75            // throughput effect.
76            .arg("--cont-batching")
77            // Flash Attention 2: replaces O(n²) attention with fused CUDA kernels.
78            // +15–30% throughput at 8k context, halves KV VRAM consumption.
79            // Must pass "on" explicitly — bare flag causes the parser in b8037 to
80            // consume the next argument as the value, silently breaking the command.
81            .arg("--flash-attn").arg("on")
82            // KV cache quantisation: store K/V matrices in Q8_0 instead of F16.
83            // Halves KV VRAM (~256 MB → ~128 MB for 8192 ctx, 28 layers, 1 slot).
84            .arg("--cache-type-k").arg("q8_0")
85            .arg("--cache-type-v").arg("q8_0")
86            // Defragmentation threshold: when KV-cache fragmentation exceeds 10%
87            // of total slots, compact the cache in-place.  Prevents the gradual
88            // throughput degradation visible in long-running sessions.
89            .arg("--defrag-thold").arg("0.1")
90            // Process priority: HIGH (2) reduces OS scheduler jitter so llama-server
91            // is not preempted mid-decode. Measurable effect on TTFT P90/P99 and
92            // consistent throughput. Values: 0=normal 1=medium 2=high 3=realtime.
93            .arg("--prio").arg("2")
94            // Lock all model pages in RAM. Prevents the OS from paging weight
95            // tensors to disk under memory pressure. Eliminates rare 100-500ms
96            // TTFT spikes caused by page-fault stalls during decode. The model
97            // (~1.9 GB) comfortably fits in the 15.7 GB system RAM.
98            .arg("--mlock");
99
100        // Speculative decoding: if a draft model path is set and the file exists,
101        // enable speculative decoding. The draft model generates candidate tokens
102        // which the main model verifies in one forward pass — 2–3× throughput boost.
103        if let Some(ref draft_path) = config.draft_model_path {
104            if draft_path.exists() {
105                cmd.arg("--model-draft").arg(draft_path)
106                    .arg("--draft-max").arg(config.speculative_draft_max.to_string())
107                    .arg("--draft-min").arg("1")
108                    .arg("--draft-p-min").arg(config.speculative_draft_p_min.to_string());
109                info!("Speculative decoding enabled: draft_model={}", draft_path.display());
110            } else {
111                info!("Speculative decoding disabled: draft model not found at {}", draft_path.display());
112            }
113        }
114
115        // Log the full command for debugging
116        info!("Full llama-server command: {:?} --model {} --host {} --port {} --ctx-size {} --batch-size {} --ubatch-size {} --threads {} --n-gpu-layers {} --parallel {} --cont-batching --flash-attn on --cache-type-k q8_0 --cache-type-v q8_0 --defrag-thold 0.1 --prio 2 --mlock",
117            binary_path,
118            config.model_path.display(), config.host, config.port,
119            config.context_size, config.batch_size, config.ubatch_size,
120            config.threads, config.gpu_layers, config.parallel_slots);
121
122        // On macOS: set DYLD_LIBRARY_PATH to the directory that contains
123        // llama-server so that co-located dylibs (libllama.dylib,
124        // libggml.dylib, libggml-metal.dylib, libggml-cpu.dylib …) are
125        // found by dyld at process start.  Without this, the child process
126        // will exit immediately with a "dyld: Library not loaded" error.
127        #[cfg(target_os = "macos")]
128        {
129            if let Some(binary_dir) = binary_path.parent() {
130                let lib_path = binary_dir.to_string_lossy().to_string();
131                info!("macOS: setting DYLD_LIBRARY_PATH={}", lib_path);
132                // Prepend to any existing value so system dylibs are still found.
133                let existing = std::env::var("DYLD_LIBRARY_PATH").unwrap_or_default();
134                let new_val = if existing.is_empty() {
135                    lib_path
136                } else {
137                    format!("{}:{}", lib_path, existing)
138                };
139                cmd.env("DYLD_LIBRARY_PATH", new_val);
140            }
141        }
142
143        // On Windows, hide the console window
144        #[cfg(target_os = "windows")]
145        {
146            use std::os::windows::process::CommandExt;
147            const CREATE_NO_WINDOW: u32 = 0x08000000;
148            cmd.creation_flags(CREATE_NO_WINDOW);
149        }
150
151        cmd.stdout(Stdio::piped())
152            .stderr(Stdio::piped());
153
154        // Spawn the process
155        let child = cmd.spawn()
156            .map_err(|e| anyhow::anyhow!("Failed to spawn llama-server: {}", e))?;
157
158        self.server_process = Some(child);
159        self.base_url = format!("http://{}:{}", config.host, config.port);
160
161        info!("llama-server process started, waiting for health check...");
162
163        // Wait for server to be ready (up to 120 seconds) with exponential backoff.
164        // Checks at 100 ms → 200 ms → 400 ms → … → 2 s (cap) so a fast start is
165        // detected in < 200 ms instead of the old fixed 2 s minimum.
166        let _start = std::time::Instant::now();
167        let mut delay_ms: u64 = 100;
168        let mut last_log_secs: u64 = 0;
169        loop {
170            sleep(Duration::from_millis(delay_ms)).await;
171
172            if self.is_ready().await {
173                info!("✅ GGUF runtime ready after {:.1}s", _start.elapsed().as_secs_f64());
174
175                // Pre-warm: fire one minimal completion request so that CUDA kernels
176                // are JIT-compiled and GPU caches are hot before the first real user
177                // request arrives.  Without this the very first request pays a
178                // 400–600 ms CUDA cold-start penalty even though the model is loaded.
179                // max_tokens=1 keeps this fast (~100 ms total).
180                let warmup_url = format!("{}/v1/chat/completions", self.base_url);
181                let warmup_payload = serde_json::json!({
182                    "model": "local-llm",
183                    "messages": [{"role": "user", "content": "hi"}],
184                    "max_tokens": 1,
185                    "temperature": 0.0,
186                    "stream": false,
187                    "cache_prompt": true,
188                });
189                info!("Pre-warming CUDA kernels (max_tokens=1 dummy request)...");
190                match self.http_client
191                    .post(&warmup_url)
192                    .json(&warmup_payload)
193                    .timeout(Duration::from_secs(30))
194                    .send()
195                    .await
196                {
197                    Ok(_) => info!("CUDA pre-warm complete — first user request will get warm TTFT"),
198                    Err(e) => warn!("CUDA pre-warm failed (non-fatal, first request may be slow): {}", e),
199                }
200
201                return Ok(());
202            }
203            let elapsed_secs = _start.elapsed().as_secs();
204            if elapsed_secs >= 120 {
205                break;
206            }
207            if elapsed_secs >= last_log_secs + 10 {
208                info!("Still waiting for llama-server... ({}/120s)", elapsed_secs);
209                last_log_secs = elapsed_secs;
210            }
211            delay_ms = (delay_ms * 2).min(2_000);
212        }
213
214        Err(anyhow::anyhow!("llama-server failed to become ready within 120 seconds"))
215    }
216
217    /// Send SIGTERM to the child process (Unix only) and wait up to
218    /// `grace_secs` seconds for it to exit before returning.
219    /// Returns true if the process exited gracefully, false on timeout.
220    #[cfg(unix)]
221    fn send_sigterm_and_wait(child: &mut Child, grace_secs: u64) -> bool {
222        if let Some(pid) = child.id() {
223            // `kill -TERM <pid>` — portable across macOS and Linux
224            let _ = std::process::Command::new("kill")
225                .args(["-TERM", &pid.to_string()])
226                .output();
227
228            let deadline = std::time::Instant::now() + Duration::from_secs(grace_secs);
229            while std::time::Instant::now() < deadline {
230                if let Ok(Some(_)) = child.try_wait() {
231                    return true; // exited gracefully
232                }
233                std::thread::sleep(Duration::from_millis(100));
234            }
235        }
236        false // timed out
237    }
238}
239
240impl Default for GGUFRuntime {
241    fn default() -> Self {
242        Self::new()
243    }
244}
245
246#[async_trait]
247impl ModelRuntime for GGUFRuntime {
248    fn supported_format(&self) -> ModelFormat {
249        ModelFormat::GGUF
250    }
251
252    async fn initialize(&mut self, config: RuntimeConfig) -> anyhow::Result<()> {
253        info!("Initializing GGUF runtime");
254
255        // Validate config
256        if config.format != ModelFormat::GGUF {
257            return Err(anyhow::anyhow!(
258                "GGUF runtime received wrong format: {:?}",
259                config.format
260            ));
261        }
262
263        if !config.model_path.exists() {
264            return Err(anyhow::anyhow!(
265                "Model file not found: {}",
266                config.model_path.display()
267            ));
268        }
269
270        self.config = Some(config.clone());
271        self.start_server(&config).await?;
272
273        Ok(())
274    }
275
276    async fn is_ready(&self) -> bool {
277        if self.base_url.is_empty() {
278            return false;
279        }
280
281        let health_url = format!("{}/health", self.base_url);
282        // Use a short per-request timeout for health probes so that the
283        // /healthz handler never blocks longer than 3 s even if llama-server
284        // is in a degraded/hung state (e.g. orphan process from a previous run).
285        match self.http_client
286            .get(&health_url)
287            .timeout(Duration::from_secs(3))
288            .send()
289            .await
290        {
291            Ok(resp) => resp.status().is_success(),
292            Err(_) => false,
293        }
294    }
295
296    async fn health_check(&self) -> anyhow::Result<String> {
297        if self.base_url.is_empty() {
298            return Err(anyhow::anyhow!("Runtime not initialized"));
299        }
300
301        let health_url = format!("{}/health", self.base_url);
302        let resp = self.http_client.get(&health_url)
303            .send()
304            .await
305            .map_err(|e| anyhow::anyhow!("Health check failed: {}", e))?;
306
307        if resp.status().is_success() {
308            Ok("healthy".to_string())
309        } else {
310            Err(anyhow::anyhow!("Health check returned: {}", resp.status()))
311        }
312    }
313
314    fn base_url(&self) -> String {
315        self.base_url.clone()
316    }
317
318    async fn generate(
319        &self,
320        request: InferenceRequest,
321    ) -> anyhow::Result<InferenceResponse> {
322        let url = self.completions_url();
323
324        let payload = serde_json::json!({
325            "model": "local-llm",
326            "messages": request.messages,
327            "max_tokens": request.max_tokens,
328            "temperature": request.temperature,
329            "stream": false,
330        });
331
332        let resp = self.http_client.post(&url)
333            .json(&payload)
334            .send()
335            .await
336            .map_err(|e| anyhow::anyhow!("Inference request failed: {}", e))?;
337
338        if !resp.status().is_success() {
339            let status = resp.status();
340            let body = resp.text().await.unwrap_or_default();
341            return Err(anyhow::anyhow!("Inference failed ({}): {}", status, body));
342        }
343
344        let response: serde_json::Value = resp.json().await
345            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
346
347        let content = response["choices"][0]["message"]["content"]
348            .as_str()
349            .unwrap_or("")
350            .to_string();
351
352        let finish_reason = response["choices"][0]["finish_reason"]
353            .as_str()
354            .map(|s| s.to_string());
355
356        Ok(InferenceResponse {
357            content,
358            finish_reason,
359        })
360    }
361
362    async fn generate_stream(
363        &self,
364        request: InferenceRequest,
365    ) -> anyhow::Result<Box<dyn futures_util::Stream<Item = Result<String, anyhow::Error>> + Send + Unpin>> {
366        use futures_util::StreamExt;
367
368        let url = self.completions_url();
369
370        let payload = serde_json::json!({
371            "model": "local-llm",
372            "messages": request.messages,
373            "max_tokens": request.max_tokens,
374            "temperature": request.temperature,
375            "stream": true,
376        });
377
378        let resp = self.http_client.post(&url)
379            .json(&payload)
380            .send()
381            .await
382            .map_err(|e| anyhow::anyhow!("Stream request failed: {}", e))?;
383
384        if !resp.status().is_success() {
385            let status = resp.status();
386            let body = resp.text().await.unwrap_or_default();
387            return Err(anyhow::anyhow!("Stream failed ({}): {}", status, body));
388        }
389
390        let byte_stream = resp.bytes_stream();
391
392        let sse_stream = async_stream::try_stream! {
393            let mut buffer = String::new();
394            futures_util::pin_mut!(byte_stream);
395
396            while let Some(chunk_result) = byte_stream.next().await {
397                let chunk = chunk_result.map_err(|e| anyhow::anyhow!("Stream read error: {}", e))?;
398                buffer.push_str(&String::from_utf8_lossy(&chunk));
399
400                while let Some(newline_pos) = buffer.find('\n') {
401                    let line = buffer[..newline_pos].trim().to_string();
402                    buffer = buffer[newline_pos + 1..].to_string();
403
404                    if line.is_empty() || !line.starts_with("data: ") {
405                        continue;
406                    }
407
408                    let data = &line[6..];
409                    if data == "[DONE]" {
410                        return;
411                    }
412
413                    yield format!("data: {}\n\n", data);
414                }
415            }
416        };
417
418        Ok(Box::new(Box::pin(sse_stream)))
419    }
420
421    async fn shutdown(&mut self) -> anyhow::Result<()> {
422        info!("Shutting down GGUF runtime");
423
424        if let Some(mut child) = self.server_process.take() {
425            // On Unix (macOS + Linux): send SIGTERM first so llama-server can
426            // release Metal command queues / CUDA contexts gracefully.
427            // Give it up to 3 s before escalating to SIGKILL.
428            #[cfg(unix)]
429            {
430                // 1 s grace (was 3 s) — enough for llama-server to flush its Metal/CUDA
431                // contexts; any longer only adds latency to model switching.
432                let exited_gracefully = Self::send_sigterm_and_wait(&mut child, 1);
433                if exited_gracefully {
434                    info!("llama-server shut down gracefully after SIGTERM");
435                    return Ok(());
436                }
437                info!("llama-server did not exit after SIGTERM — sending SIGKILL");
438            }
439
440            // SIGKILL (or TerminateProcess on Windows)
441            match child.kill() {
442                Ok(_) => {
443                    info!("llama-server process killed");
444                    // wait() is safe here: we are in an async fn but this is
445                    // a blocking call on an already-dead process, so it returns
446                    // immediately.
447                    let _ = child.wait();
448                }
449                Err(e) => {
450                    // Process may have already exited on its own
451                    warn!("Failed to kill llama-server (may have already exited): {}", e);
452                }
453            }
454        }
455
456        self.config = None;
457        self.base_url.clear();
458        Ok(())
459    }
460
461    fn metadata(&self) -> RuntimeMetadata {
462        RuntimeMetadata {
463            format: ModelFormat::GGUF,
464            runtime_name: "llama.cpp (llama-server)".to_string(),
465            version: "latest".to_string(),
466            supports_gpu: true,
467            supports_streaming: true,
468        }
469    }
470}
471
472impl Drop for GGUFRuntime {
473    fn drop(&mut self) {
474        if let Some(mut child) = self.server_process.take() {
475            // Best-effort kill — we intentionally do NOT call child.wait() here
476            // because Drop can be invoked from an async Tokio context and a
477            // blocking wait would stall the thread-pool worker.
478            // The OS reclaims the zombie when the Tokio runtime itself exits.
479            let _ = child.kill();
480        }
481    }
482}