// offline_intelligence/model_runtime/gguf_runtime.rs

use async_trait::async_trait;
use super::runtime_trait::*;
use std::process::{Child, Command, Stdio};
use std::time::Duration;
use tracing::{info, warn};
use tokio::time::sleep;

/// Runtime backend for GGUF-format models.
///
/// Works by spawning and supervising an external `llama-server` (llama.cpp)
/// process and talking to it over its local HTTP API.
pub struct GGUFRuntime {
    // Configuration captured at `initialize`; `None` until then.
    config: Option<RuntimeConfig>,
    // Handle to the spawned llama-server child process, if one is running.
    server_process: Option<Child>,
    // Reused HTTP client for health checks and inference requests.
    http_client: reqwest::Client,
    // Base URL of the llama-server HTTP API (e.g. "http://127.0.0.1:8080");
    // empty until the server has been started.
    base_url: String,
}
19
20impl GGUFRuntime {
21 pub fn new() -> Self {
22 Self {
23 config: None,
24 server_process: None,
25 http_client: reqwest::Client::builder()
26 .timeout(Duration::from_secs(600))
27 .build()
28 .unwrap_or_default(),
29 base_url: String::new(),
30 }
31 }
32
33 async fn start_server(&mut self, config: &RuntimeConfig) -> anyhow::Result<()> {
35 let binary_path = config.runtime_binary.as_ref()
36 .ok_or_else(|| anyhow::anyhow!("GGUF runtime requires runtime_binary path"))?;
37
38 if !binary_path.exists() {
39 return Err(anyhow::anyhow!(
40 "llama-server binary not found at: {}",
41 binary_path.display()
42 ));
43 }
44
45 info!("Starting llama-server for GGUF model: {}", config.model_path.display());
46 info!(" Binary: {}", binary_path.display());
47 info!(" Port: {}", config.port);
48 info!(" Context Size: {}", config.context_size);
49 info!(" GPU Layers: {}", config.gpu_layers);
50
51 if !config.model_path.exists() {
53 return Err(anyhow::anyhow!(
54 "Model file not found at: {}",
55 config.model_path.display()
56 ));
57 }
58
59 let mut cmd = Command::new(binary_path);
61 cmd.arg("--model").arg(&config.model_path)
62 .arg("--host").arg(&config.host)
63 .arg("--port").arg(config.port.to_string())
64 .arg("--ctx-size").arg(config.context_size.to_string())
65 .arg("--batch-size").arg(config.batch_size.to_string())
66 .arg("--ubatch-size").arg(config.ubatch_size.to_string())
68 .arg("--threads").arg(config.threads.to_string())
69 .arg("--n-gpu-layers").arg(config.gpu_layers.to_string())
70 .arg("--parallel").arg(config.parallel_slots.to_string())
73 .arg("--cont-batching")
77 .arg("--flash-attn").arg("on")
82 .arg("--cache-type-k").arg("q8_0")
85 .arg("--cache-type-v").arg("q8_0")
86 .arg("--defrag-thold").arg("0.1")
90 .arg("--prio").arg("2")
94 .arg("--mlock");
99
100 if let Some(ref draft_path) = config.draft_model_path {
104 if draft_path.exists() {
105 cmd.arg("--model-draft").arg(draft_path)
106 .arg("--draft-max").arg(config.speculative_draft_max.to_string())
107 .arg("--draft-min").arg("1")
108 .arg("--draft-p-min").arg(config.speculative_draft_p_min.to_string());
109 info!("Speculative decoding enabled: draft_model={}", draft_path.display());
110 } else {
111 info!("Speculative decoding disabled: draft model not found at {}", draft_path.display());
112 }
113 }
114
115 info!("Full llama-server command: {:?} --model {} --host {} --port {} --ctx-size {} --batch-size {} --ubatch-size {} --threads {} --n-gpu-layers {} --parallel {} --cont-batching --flash-attn on --cache-type-k q8_0 --cache-type-v q8_0 --defrag-thold 0.1 --prio 2 --mlock",
117 binary_path,
118 config.model_path.display(), config.host, config.port,
119 config.context_size, config.batch_size, config.ubatch_size,
120 config.threads, config.gpu_layers, config.parallel_slots);
121
122 #[cfg(target_os = "macos")]
128 {
129 if let Some(binary_dir) = binary_path.parent() {
130 let lib_path = binary_dir.to_string_lossy().to_string();
131 info!("macOS: setting DYLD_LIBRARY_PATH={}", lib_path);
132 let existing = std::env::var("DYLD_LIBRARY_PATH").unwrap_or_default();
134 let new_val = if existing.is_empty() {
135 lib_path
136 } else {
137 format!("{}:{}", lib_path, existing)
138 };
139 cmd.env("DYLD_LIBRARY_PATH", new_val);
140 }
141 }
142
143 #[cfg(target_os = "windows")]
145 {
146 use std::os::windows::process::CommandExt;
147 const CREATE_NO_WINDOW: u32 = 0x08000000;
148 cmd.creation_flags(CREATE_NO_WINDOW);
149 }
150
151 cmd.stdout(Stdio::piped())
152 .stderr(Stdio::piped());
153
154 let child = cmd.spawn()
156 .map_err(|e| anyhow::anyhow!("Failed to spawn llama-server: {}", e))?;
157
158 self.server_process = Some(child);
159 self.base_url = format!("http://{}:{}", config.host, config.port);
160
161 info!("llama-server process started, waiting for health check...");
162
163 let _start = std::time::Instant::now();
167 let mut delay_ms: u64 = 100;
168 let mut last_log_secs: u64 = 0;
169 loop {
170 sleep(Duration::from_millis(delay_ms)).await;
171
172 if self.is_ready().await {
173 info!("✅ GGUF runtime ready after {:.1}s", _start.elapsed().as_secs_f64());
174
175 let warmup_url = format!("{}/v1/chat/completions", self.base_url);
181 let warmup_payload = serde_json::json!({
182 "model": "local-llm",
183 "messages": [{"role": "user", "content": "hi"}],
184 "max_tokens": 1,
185 "temperature": 0.0,
186 "stream": false,
187 "cache_prompt": true,
188 });
189 info!("Pre-warming CUDA kernels (max_tokens=1 dummy request)...");
190 match self.http_client
191 .post(&warmup_url)
192 .json(&warmup_payload)
193 .timeout(Duration::from_secs(30))
194 .send()
195 .await
196 {
197 Ok(_) => info!("CUDA pre-warm complete — first user request will get warm TTFT"),
198 Err(e) => warn!("CUDA pre-warm failed (non-fatal, first request may be slow): {}", e),
199 }
200
201 return Ok(());
202 }
203 let elapsed_secs = _start.elapsed().as_secs();
204 if elapsed_secs >= 120 {
205 break;
206 }
207 if elapsed_secs >= last_log_secs + 10 {
208 info!("Still waiting for llama-server... ({}/120s)", elapsed_secs);
209 last_log_secs = elapsed_secs;
210 }
211 delay_ms = (delay_ms * 2).min(2_000);
212 }
213
214 Err(anyhow::anyhow!("llama-server failed to become ready within 120 seconds"))
215 }
216
217 #[cfg(unix)]
221 fn send_sigterm_and_wait(child: &mut Child, grace_secs: u64) -> bool {
222 if let Some(pid) = child.id() {
223 let _ = std::process::Command::new("kill")
225 .args(["-TERM", &pid.to_string()])
226 .output();
227
228 let deadline = std::time::Instant::now() + Duration::from_secs(grace_secs);
229 while std::time::Instant::now() < deadline {
230 if let Ok(Some(_)) = child.try_wait() {
231 return true; }
233 std::thread::sleep(Duration::from_millis(100));
234 }
235 }
236 false }
238}
239
impl Default for GGUFRuntime {
    // Delegates to `new()` so the runtime can also be built via `Default`.
    fn default() -> Self {
        Self::new()
    }
}
245
246#[async_trait]
247impl ModelRuntime for GGUFRuntime {
248 fn supported_format(&self) -> ModelFormat {
249 ModelFormat::GGUF
250 }
251
252 async fn initialize(&mut self, config: RuntimeConfig) -> anyhow::Result<()> {
253 info!("Initializing GGUF runtime");
254
255 if config.format != ModelFormat::GGUF {
257 return Err(anyhow::anyhow!(
258 "GGUF runtime received wrong format: {:?}",
259 config.format
260 ));
261 }
262
263 if !config.model_path.exists() {
264 return Err(anyhow::anyhow!(
265 "Model file not found: {}",
266 config.model_path.display()
267 ));
268 }
269
270 self.config = Some(config.clone());
271 self.start_server(&config).await?;
272
273 Ok(())
274 }
275
276 async fn is_ready(&self) -> bool {
277 if self.base_url.is_empty() {
278 return false;
279 }
280
281 let health_url = format!("{}/health", self.base_url);
282 match self.http_client
286 .get(&health_url)
287 .timeout(Duration::from_secs(3))
288 .send()
289 .await
290 {
291 Ok(resp) => resp.status().is_success(),
292 Err(_) => false,
293 }
294 }
295
296 async fn health_check(&self) -> anyhow::Result<String> {
297 if self.base_url.is_empty() {
298 return Err(anyhow::anyhow!("Runtime not initialized"));
299 }
300
301 let health_url = format!("{}/health", self.base_url);
302 let resp = self.http_client.get(&health_url)
303 .send()
304 .await
305 .map_err(|e| anyhow::anyhow!("Health check failed: {}", e))?;
306
307 if resp.status().is_success() {
308 Ok("healthy".to_string())
309 } else {
310 Err(anyhow::anyhow!("Health check returned: {}", resp.status()))
311 }
312 }
313
314 fn base_url(&self) -> String {
315 self.base_url.clone()
316 }
317
318 async fn generate(
319 &self,
320 request: InferenceRequest,
321 ) -> anyhow::Result<InferenceResponse> {
322 let url = self.completions_url();
323
324 let payload = serde_json::json!({
325 "model": "local-llm",
326 "messages": request.messages,
327 "max_tokens": request.max_tokens,
328 "temperature": request.temperature,
329 "stream": false,
330 });
331
332 let resp = self.http_client.post(&url)
333 .json(&payload)
334 .send()
335 .await
336 .map_err(|e| anyhow::anyhow!("Inference request failed: {}", e))?;
337
338 if !resp.status().is_success() {
339 let status = resp.status();
340 let body = resp.text().await.unwrap_or_default();
341 return Err(anyhow::anyhow!("Inference failed ({}): {}", status, body));
342 }
343
344 let response: serde_json::Value = resp.json().await
345 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
346
347 let content = response["choices"][0]["message"]["content"]
348 .as_str()
349 .unwrap_or("")
350 .to_string();
351
352 let finish_reason = response["choices"][0]["finish_reason"]
353 .as_str()
354 .map(|s| s.to_string());
355
356 Ok(InferenceResponse {
357 content,
358 finish_reason,
359 })
360 }
361
362 async fn generate_stream(
363 &self,
364 request: InferenceRequest,
365 ) -> anyhow::Result<Box<dyn futures_util::Stream<Item = Result<String, anyhow::Error>> + Send + Unpin>> {
366 use futures_util::StreamExt;
367
368 let url = self.completions_url();
369
370 let payload = serde_json::json!({
371 "model": "local-llm",
372 "messages": request.messages,
373 "max_tokens": request.max_tokens,
374 "temperature": request.temperature,
375 "stream": true,
376 });
377
378 let resp = self.http_client.post(&url)
379 .json(&payload)
380 .send()
381 .await
382 .map_err(|e| anyhow::anyhow!("Stream request failed: {}", e))?;
383
384 if !resp.status().is_success() {
385 let status = resp.status();
386 let body = resp.text().await.unwrap_or_default();
387 return Err(anyhow::anyhow!("Stream failed ({}): {}", status, body));
388 }
389
390 let byte_stream = resp.bytes_stream();
391
392 let sse_stream = async_stream::try_stream! {
393 let mut buffer = String::new();
394 futures_util::pin_mut!(byte_stream);
395
396 while let Some(chunk_result) = byte_stream.next().await {
397 let chunk = chunk_result.map_err(|e| anyhow::anyhow!("Stream read error: {}", e))?;
398 buffer.push_str(&String::from_utf8_lossy(&chunk));
399
400 while let Some(newline_pos) = buffer.find('\n') {
401 let line = buffer[..newline_pos].trim().to_string();
402 buffer = buffer[newline_pos + 1..].to_string();
403
404 if line.is_empty() || !line.starts_with("data: ") {
405 continue;
406 }
407
408 let data = &line[6..];
409 if data == "[DONE]" {
410 return;
411 }
412
413 yield format!("data: {}\n\n", data);
414 }
415 }
416 };
417
418 Ok(Box::new(Box::pin(sse_stream)))
419 }
420
421 async fn shutdown(&mut self) -> anyhow::Result<()> {
422 info!("Shutting down GGUF runtime");
423
424 if let Some(mut child) = self.server_process.take() {
425 #[cfg(unix)]
429 {
430 let exited_gracefully = Self::send_sigterm_and_wait(&mut child, 1);
433 if exited_gracefully {
434 info!("llama-server shut down gracefully after SIGTERM");
435 return Ok(());
436 }
437 info!("llama-server did not exit after SIGTERM — sending SIGKILL");
438 }
439
440 match child.kill() {
442 Ok(_) => {
443 info!("llama-server process killed");
444 let _ = child.wait();
448 }
449 Err(e) => {
450 warn!("Failed to kill llama-server (may have already exited): {}", e);
452 }
453 }
454 }
455
456 self.config = None;
457 self.base_url.clear();
458 Ok(())
459 }
460
461 fn metadata(&self) -> RuntimeMetadata {
462 RuntimeMetadata {
463 format: ModelFormat::GGUF,
464 runtime_name: "llama.cpp (llama-server)".to_string(),
465 version: "latest".to_string(),
466 supports_gpu: true,
467 supports_streaming: true,
468 }
469 }
470}
471
472impl Drop for GGUFRuntime {
473 fn drop(&mut self) {
474 if let Some(mut child) = self.server_process.take() {
475 let _ = child.kill();
480 }
481 }
482}