// offline_intelligence/llm_integration.rs

1// Server/src/llm_integration.rs
2// Direct LLM integration module for unified server architecture
3
4use crate::config::Config;
5use crate::memory::Message;
6use anyhow::{Context, Result};
7use futures::{Stream, StreamExt};
8use serde_json::{json, Value};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::process::{Child, Command};
12use tokio::sync::Mutex;
13use tracing::info;
14use bytes::Bytes;
15
/// Manages an external llama backend as a child process and proxies
/// generation requests to it over HTTP.
pub struct LLMEngine {
    // Server configuration (binary path, host/port, model and runtime tuning).
    config: Config,
    // The currently running backend, if any. Arc + async Mutex so model
    // reloads, generation, and shutdown coordinate on the same handle.
    backend_process: Arc<Mutex<Option<BackendProcess>>>,
}
20
/// A spawned llama backend process plus the details needed to reach it
/// and report on it.
struct BackendProcess {
    child: Child,       // process handle, used to kill/wait on reload or shutdown
    port: u16,          // TCP port the backend was started on
    model_path: String, // path of the model this process was launched with
}
26
27impl LLMEngine {
28    pub fn new(config: Config) -> Self {
29        Self {
30            config,
31            backend_process: Arc::new(Mutex::new(None)),
32        }
33    }
34
35    pub async fn initialize(&self) -> Result<()> {
36        let model_path = self.config.model_path.clone();
37        self.load_model(model_path).await
38    }
39
40    pub async fn load_model(&self, model_path: String) -> Result<()> {
41        let mut process_guard = self.backend_process.lock().await;
42        
43        // Stop existing process if running
44        if let Some(existing) = process_guard.as_mut() {
45            info!("Stopping existing backend process on port {}", existing.port);
46            let _ = existing.child.kill().await;
47            let _ = existing.child.wait().await;
48        }
49
50        // Start new backend process
51        let port = self.config.llama_port;
52        let child = self.spawn_backend_process(&model_path, port).await?;
53        
54        let model_path_clone = model_path.clone();
55        *process_guard = Some(BackendProcess {
56            child,
57            port,
58            model_path: model_path_clone,
59        });
60
61        // Wait for backend to be ready
62        self.wait_for_backend_ready(port).await?;
63
64        info!("LLM engine initialized with model: {}", model_path);
65        Ok(())
66    }
67
68    async fn spawn_backend_process(&self, model_path: &str, port: u16) -> Result<Child> {
69        let mut args = vec![
70            "--host".to_string(),
71            self.config.llama_host.clone(),
72            "--port".to_string(),
73            port.to_string(),
74            "-m".to_string(),
75            model_path.to_string(),
76            "-n".to_string(),
77            "-1".to_string(),
78            "--keep".to_string(),
79            "32".to_string(),
80            "-c".to_string(),
81            self.config.ctx_size.to_string(),
82            "-b".to_string(),
83            self.config.batch_size.to_string(),
84            "-t".to_string(),
85            self.config.threads.to_string(),
86        ];
87
88        if self.config.gpu_layers > 0 {
89            args.push("-ngl".to_string());
90            args.push(self.config.gpu_layers.to_string());
91        }
92
93        let mut cmd = Command::new(&self.config.llama_bin);
94        cmd.args(&args);
95        cmd.stdin(std::process::Stdio::null());
96        cmd.stdout(std::process::Stdio::null());
97        cmd.stderr(std::process::Stdio::null());
98
99        let child = cmd.spawn().context("Failed to spawn llama backend process")?;
100        info!("Spawned LLM backend (pid={:?}) on port {}", child.id(), port);
101        
102        Ok(child)
103    }
104
105    async fn wait_for_backend_ready(&self, port: u16) -> Result<()> {
106        let start = std::time::Instant::now();
107        let timeout = Duration::from_secs(self.config.health_timeout_seconds);
108        
109        loop {
110            if self.probe_backend_health(port).await {
111                info!("Backend is healthy on port {}", port);
112                return Ok(());
113            }
114            
115            if start.elapsed() > timeout {
116                return Err(anyhow::anyhow!("Backend did not start within timeout"));
117            }
118            
119            tokio::time::sleep(Duration::from_millis(100)).await;
120        }
121    }
122
123    async fn probe_backend_health(&self, port: u16) -> bool {
124        let url = format!("http://{}:{}/health", self.config.llama_host, port);
125        match reqwest::Client::new()
126            .get(&url)
127            .timeout(Duration::from_secs(5))
128            .send()
129            .await
130        {
131            Ok(response) => response.status().is_success(),
132            Err(_) => false,
133        }
134    }
135
136    pub async fn generate_stream(
137        &self,
138        messages: Vec<Message>,
139        _session_id: String,
140    ) -> Result<impl Stream<Item = Result<Bytes, std::io::Error>>> {
141        let process_guard = self.backend_process.lock().await;
142        
143        let backend = process_guard.as_ref()
144            .ok_or_else(|| anyhow::anyhow!("LLM backend not initialized"))?;
145
146        let openai_payload = json!({
147            "messages": messages,
148            "stream": true,
149            "temperature": 0.7,
150            "max_tokens": 2048
151        });
152
153        let url = format!("http://{}:{}/v1/chat/completions", self.config.llama_host, backend.port);
154        
155        let response = reqwest::Client::new()
156            .post(&url)
157            .timeout(Duration::from_secs(self.config.generate_timeout_seconds))
158            .header("content-type", "application/json")
159            .header("connection", "keep-alive")
160            .header("accept", "text/event-stream")
161            .json(&openai_payload)
162            .send()
163            .await
164            .context("Failed to send request to LLM backend")?;
165
166        if !response.status().is_success() {
167            let status = response.status();
168            let body = response.text().await.unwrap_or_default();
169            return Err(anyhow::anyhow!("Backend status {}: {}", status, body));
170        }
171
172        // Convert response to stream of bytes
173        let byte_stream = response
174            .bytes_stream()
175            .map(|result| result.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)));
176
177        Ok(byte_stream)
178    }
179
180    pub async fn get_backend_info(&self) -> Option<(String, u16)> {
181        let process_guard = self.backend_process.lock().await;
182        process_guard.as_ref().map(|b| (b.model_path.clone(), b.port))
183    }
184
185    pub async fn stop(&self) -> Result<()> {
186        let mut process_guard = self.backend_process.lock().await;
187        
188        if let Some(mut backend) = process_guard.take() {
189            info!("Stopping LLM backend on port {}...", backend.port);
190            backend.child.kill().await.context("Failed to kill backend process")?;
191            backend.child.wait().await?;
192            info!("LLM backend process stopped.");
193        }
194        
195        Ok(())
196    }
197}
198
199// Helper function to extract content from OpenAI-style responses
200pub fn extract_openai_content(openai_response: &Value) -> String {
201    openai_response["choices"][0]["message"]["content"]
202        .as_str()
203        .map(|s| s.to_string())
204        .unwrap_or_else(|| {
205            openai_response["choices"][0]["text"]
206                .as_str()
207                .map(|s| s.to_string())
208                .unwrap_or_else(|| "No response content found".to_string())
209        })
210}
211
212// Helper function to extract thinking content
213pub fn extract_thinking(openai_response: &Value) -> Option<String> {
214    let content = extract_openai_content(openai_response);
215    if content.contains("<thoughts>") && content.contains("</thoughts>") {
216        if let Some(start) = content.find("<thoughts>") {
217            if let Some(end) = content.find("</thoughts>") {
218                let thinking = content[start + 10..end].trim().to_string();
219                if !thinking.is_empty() {
220                    return Some(thinking);
221                }
222            }
223        }
224    }
225    None
226}