// offline_intelligence/model_runtime/gguf_runtime.rs
//! GGUF Runtime Adapter
//!
//! Wraps the existing llama-server.exe (llama.cpp) for GGUF models.
//! This adapter spawns the llama-server process and proxies requests via HTTP.

use std::process::{Child, Command, Stdio};
use std::time::Duration;

use async_trait::async_trait;
use tokio::time::sleep;
use tracing::{error, info, warn};

use super::runtime_trait::*;

/// Adapter that serves GGUF models by spawning an external llama-server
/// (llama.cpp) process and proxying requests to it over HTTP.
pub struct GGUFRuntime {
    // Config captured at initialize(); None until initialized / after shutdown.
    config: Option<RuntimeConfig>,
    // Handle to the spawned llama-server child process, if one is running.
    server_process: Option<Child>,
    // HTTP client used for health checks and inference proxying.
    http_client: reqwest::Client,
    // Base URL of the llama-server endpoint, e.g. "http://127.0.0.1:8080";
    // empty string means "not initialized".
    base_url: String,
}
20impl GGUFRuntime {
21    pub fn new() -> Self {
22        Self {
23            config: None,
24            server_process: None,
25            http_client: reqwest::Client::builder()
26                .timeout(Duration::from_secs(600))
27                .build()
28                .unwrap_or_default(),
29            base_url: String::new(),
30        }
31    }
32
33    /// Start llama-server process
34    async fn start_server(&mut self, config: &RuntimeConfig) -> anyhow::Result<()> {
35        let binary_path = config.runtime_binary.as_ref()
36            .ok_or_else(|| anyhow::anyhow!("GGUF runtime requires runtime_binary path"))?;
37
38        if !binary_path.exists() {
39            return Err(anyhow::anyhow!(
40                "llama-server binary not found at: {}",
41                binary_path.display()
42            ));
43        }
44
45        info!("Starting llama-server for GGUF model: {}", config.model_path.display());
46        info!("  Binary: {}", binary_path.display());
47        info!("  Port: {}", config.port);
48        info!("  Context Size: {}", config.context_size);
49        info!("  GPU Layers: {}", config.gpu_layers);
50
51        // Build command arguments
52        let mut cmd = Command::new(binary_path);
53        cmd.arg("--model").arg(&config.model_path)
54            .arg("--host").arg(&config.host)
55            .arg("--port").arg(config.port.to_string())
56            .arg("--ctx-size").arg(config.context_size.to_string())
57            .arg("--batch-size").arg(config.batch_size.to_string())
58            .arg("--threads").arg(config.threads.to_string())
59            .arg("--n-gpu-layers").arg(config.gpu_layers.to_string())
60            .stdout(Stdio::piped())
61            .stderr(Stdio::piped());
62
63        // Spawn the process
64        let child = cmd.spawn()
65            .map_err(|e| anyhow::anyhow!("Failed to spawn llama-server: {}", e))?;
66
67        self.server_process = Some(child);
68        self.base_url = format!("http://{}:{}", config.host, config.port);
69
70        info!("llama-server process started, waiting for health check...");
71
72        // Wait for server to be ready (up to 60 seconds)
73        for attempt in 1..=30 {
74            sleep(Duration::from_secs(2)).await;
75            
76            if self.is_ready().await {
77                info!("✅ GGUF runtime ready after {} seconds", attempt * 2);
78                return Ok(());
79            }
80            
81            if attempt % 5 == 0 {
82                info!("Still waiting for llama-server... ({}/60s)", attempt * 2);
83            }
84        }
85
86        Err(anyhow::anyhow!("llama-server failed to start within 60 seconds"))
87    }
88}
90impl Default for GGUFRuntime {
91    fn default() -> Self {
92        Self::new()
93    }
94}
96#[async_trait]
97impl ModelRuntime for GGUFRuntime {
98    fn supported_format(&self) -> ModelFormat {
99        ModelFormat::GGUF
100    }
101
102    async fn initialize(&mut self, config: RuntimeConfig) -> anyhow::Result<()> {
103        info!("Initializing GGUF runtime");
104        
105        // Validate config
106        if config.format != ModelFormat::GGUF {
107            return Err(anyhow::anyhow!(
108                "GGUF runtime received wrong format: {:?}",
109                config.format
110            ));
111        }
112
113        if !config.model_path.exists() {
114            return Err(anyhow::anyhow!(
115                "Model file not found: {}",
116                config.model_path.display()
117            ));
118        }
119
120        self.config = Some(config.clone());
121        self.start_server(&config).await?;
122        
123        Ok(())
124    }
125
126    async fn is_ready(&self) -> bool {
127        if self.base_url.is_empty() {
128            return false;
129        }
130
131        let health_url = format!("{}/health", self.base_url);
132        match self.http_client.get(&health_url).send().await {
133            Ok(resp) => resp.status().is_success(),
134            Err(_) => false,
135        }
136    }
137
138    async fn health_check(&self) -> anyhow::Result<String> {
139        if self.base_url.is_empty() {
140            return Err(anyhow::anyhow!("Runtime not initialized"));
141        }
142
143        let health_url = format!("{}/health", self.base_url);
144        let resp = self.http_client.get(&health_url)
145            .send()
146            .await
147            .map_err(|e| anyhow::anyhow!("Health check failed: {}", e))?;
148
149        if resp.status().is_success() {
150            Ok("healthy".to_string())
151        } else {
152            Err(anyhow::anyhow!("Health check returned: {}", resp.status()))
153        }
154    }
155
156    fn base_url(&self) -> String {
157        self.base_url.clone()
158    }
159
160    async fn generate(
161        &self,
162        request: InferenceRequest,
163    ) -> anyhow::Result<InferenceResponse> {
164        let url = self.completions_url();
165        
166        let payload = serde_json::json!({
167            "model": "local-llm",
168            "messages": request.messages,
169            "max_tokens": request.max_tokens,
170            "temperature": request.temperature,
171            "stream": false,
172        });
173
174        let resp = self.http_client.post(&url)
175            .json(&payload)
176            .send()
177            .await
178            .map_err(|e| anyhow::anyhow!("Inference request failed: {}", e))?;
179
180        if !resp.status().is_success() {
181            let status = resp.status();
182            let body = resp.text().await.unwrap_or_default();
183            return Err(anyhow::anyhow!("Inference failed ({}): {}", status, body));
184        }
185
186        let response: serde_json::Value = resp.json().await
187            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
188
189        let content = response["choices"][0]["message"]["content"]
190            .as_str()
191            .unwrap_or("")
192            .to_string();
193
194        let finish_reason = response["choices"][0]["finish_reason"]
195            .as_str()
196            .map(|s| s.to_string());
197
198        Ok(InferenceResponse {
199            content,
200            finish_reason,
201        })
202    }
203
204    async fn generate_stream(
205        &self,
206        request: InferenceRequest,
207    ) -> anyhow::Result<Box<dyn futures_util::Stream<Item = Result<String, anyhow::Error>> + Send + Unpin>> {
208        use futures_util::StreamExt;
209        
210        let url = self.completions_url();
211        
212        let payload = serde_json::json!({
213            "model": "local-llm",
214            "messages": request.messages,
215            "max_tokens": request.max_tokens,
216            "temperature": request.temperature,
217            "stream": true,
218        });
219
220        let resp = self.http_client.post(&url)
221            .json(&payload)
222            .send()
223            .await
224            .map_err(|e| anyhow::anyhow!("Stream request failed: {}", e))?;
225
226        if !resp.status().is_success() {
227            let status = resp.status();
228            let body = resp.text().await.unwrap_or_default();
229            return Err(anyhow::anyhow!("Stream failed ({}): {}", status, body));
230        }
231
232        let byte_stream = resp.bytes_stream();
233        
234        let sse_stream = async_stream::try_stream! {
235            let mut buffer = String::new();
236            futures_util::pin_mut!(byte_stream);
237
238            while let Some(chunk_result) = byte_stream.next().await {
239                let chunk = chunk_result.map_err(|e| anyhow::anyhow!("Stream read error: {}", e))?;
240                buffer.push_str(&String::from_utf8_lossy(&chunk));
241
242                while let Some(newline_pos) = buffer.find('\n') {
243                    let line = buffer[..newline_pos].trim().to_string();
244                    buffer = buffer[newline_pos + 1..].to_string();
245
246                    if line.is_empty() || !line.starts_with("data: ") {
247                        continue;
248                    }
249
250                    let data = &line[6..];
251                    if data == "[DONE]" {
252                        return;
253                    }
254
255                    yield format!("data: {}\n\n", data);
256                }
257            }
258        };
259
260        Ok(Box::new(Box::pin(sse_stream)))
261    }
262
263    async fn shutdown(&mut self) -> anyhow::Result<()> {
264        info!("Shutting down GGUF runtime");
265        
266        if let Some(mut child) = self.server_process.take() {
267            match child.kill() {
268                Ok(_) => {
269                    info!("llama-server process killed successfully");
270                    let _ = child.wait();
271                }
272                Err(e) => {
273                    warn!("Failed to kill llama-server process: {}", e);
274                }
275            }
276        }
277
278        self.config = None;
279        self.base_url.clear();
280        Ok(())
281    }
282
283    fn metadata(&self) -> RuntimeMetadata {
284        RuntimeMetadata {
285            format: ModelFormat::GGUF,
286            runtime_name: "llama.cpp (llama-server)".to_string(),
287            version: "latest".to_string(),
288            supports_gpu: true,
289            supports_streaming: true,
290        }
291    }
292}
294impl Drop for GGUFRuntime {
295    fn drop(&mut self) {
296        if let Some(mut child) = self.server_process.take() {
297            let _ = child.kill();
298            let _ = child.wait();
299        }
300    }
301}