Skip to main content

offline_intelligence/worker_threads/
llm_worker.rs

//! LLM worker thread implementation
//!
//! Handles LLM inference by proxying requests to the local llama-server process.
//! This is the 1-hop architecture: shared memory state → HTTP to localhost llama-server.

6use futures_util::StreamExt;
7use tracing::{info, debug, warn};
8use serde::{Deserialize, Serialize};
9
10use crate::memory::Message;
11
/// Chat completion request sent to llama-server (OpenAI-compatible format)
#[derive(Debug, Serialize)]
struct ChatCompletionRequest {
    /// Model name field required by the OpenAI wire format.
    model: String,
    /// Messages forming the prompt context.
    messages: Vec<ChatMessage>,
    /// Upper bound on the number of generated tokens.
    max_tokens: u32,
    /// Sampling temperature.
    temperature: f32,
    /// When true, the backend streams deltas as Server-Sent Events.
    stream: bool,
}
21
/// Embedding request sent to llama-server (OpenAI-compatible format)
#[derive(Debug, Serialize)]
struct EmbeddingRequest {
    /// Model name field required by the OpenAI wire format.
    model: String,
    /// Texts to embed; the backend returns one vector per input.
    input: Vec<String>,
}
28
/// Embedding response from llama-server
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
    /// One entry per input string, each carrying its embedding vector.
    data: Vec<EmbeddingData>,
}
34
/// A single embedding entry within an [`EmbeddingResponse`].
#[derive(Debug, Deserialize)]
struct EmbeddingData {
    /// The embedding vector for one input text.
    embedding: Vec<f32>,
}
39
/// A single chat message in the OpenAI-compatible wire format.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct ChatMessage {
    /// Speaker role, e.g. "user" or "assistant".
    role: String,
    /// Message text.
    content: String,
}
45
/// Non-streaming response from llama-server
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
    /// Completion candidates; only the first is consumed by this worker.
    choices: Vec<ChatChoice>,
}
51
/// One completion candidate in a non-streaming response.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    /// The generated message; optional so a missing field deserializes cleanly.
    message: Option<ChatMessage>,
}
56
/// Streaming delta chunk from llama-server
#[derive(Debug, Deserialize)]
struct StreamChunk {
    /// Per-choice deltas carried by this chunk.
    choices: Vec<StreamChoice>,
}
62
/// One choice entry inside a streaming chunk.
#[derive(Debug, Deserialize)]
struct StreamChoice {
    /// Incremental content for this choice, if any.
    delta: Option<ChatDelta>,
    /// Set on the final chunk for this choice (e.g. "stop" — TODO confirm
    /// exact values against the llama-server implementation).
    finish_reason: Option<String>,
}
68
/// Incremental content payload within a streaming choice.
#[derive(Debug, Deserialize, Clone)]
struct ChatDelta {
    /// Newly generated text fragment; absent when the delta carries no text.
    content: Option<String>,
}
73
/// Worker that proxies LLM requests to a local llama-server over HTTP.
pub struct LLMWorker {
    /// Base URL of the llama-server backend; endpoint paths are appended to it.
    backend_url: String,
    /// Reused HTTP client (connection pooling, per-request timeout).
    http_client: reqwest::Client,
}
78
79impl LLMWorker {
80    /// Create with shared state (legacy constructor)
81    pub fn new(shared_state: std::sync::Arc<crate::shared_state::SharedState>) -> Self {
82        let backend_url = shared_state.config.backend_url.clone();
83        Self {
84            backend_url,
85            http_client: reqwest::Client::builder()
86                .timeout(std::time::Duration::from_secs(600))
87                .build()
88                .unwrap_or_default(),
89        }
90    }
91
92    /// Create with explicit backend URL
93    pub fn new_with_backend(backend_url: String) -> Self {
94        info!("LLM worker initialized with backend: {}", backend_url);
95        Self {
96            backend_url,
97            http_client: reqwest::Client::builder()
98                .timeout(std::time::Duration::from_secs(600))
99                .build()
100                .unwrap_or_default(),
101        }
102    }
103
104    /// Get the chat completions endpoint URL
105    fn completions_url(&self) -> String {
106        format!("{}/v1/chat/completions", self.backend_url)
107    }
108
109    /// Get the embeddings endpoint URL
110    fn embeddings_url(&self) -> String {
111        format!("{}/v1/embeddings", self.backend_url)
112    }
113
114    /// Convert internal Message format to OpenAI-compatible ChatMessage
115    fn to_chat_messages(messages: &[Message]) -> Vec<ChatMessage> {
116        messages.iter().map(|m| ChatMessage {
117            role: m.role.clone(),
118            content: m.content.clone(),
119        }).collect()
120    }
121
122    /// Generate a complete (non-streaming) response from the LLM.
123    pub async fn generate_response(
124        &self,
125        _session_id: String,
126        context: Vec<Message>,
127    ) -> anyhow::Result<String> {
128        debug!("LLM worker generating response (non-streaming)");
129
130        let request = ChatCompletionRequest {
131            model: "local-llm".to_string(),
132            messages: Self::to_chat_messages(&context),
133            max_tokens: 2000,
134            temperature: 0.7,
135            stream: false,
136        };
137
138        let response = self.http_client
139            .post(&self.completions_url())
140            .json(&request)
141            .send()
142            .await
143            .map_err(|e| anyhow::anyhow!("LLM backend request failed: {}", e))?;
144
145        if !response.status().is_success() {
146            let status = response.status();
147            let body = response.text().await.unwrap_or_default();
148            return Err(anyhow::anyhow!("LLM backend returned {}: {}", status, body));
149        }
150
151        let completion: ChatCompletionResponse = response.json().await
152            .map_err(|e| anyhow::anyhow!("Failed to parse LLM response: {}", e))?;
153
154        let content = completion.choices
155            .first()
156            .and_then(|c| c.message.as_ref())
157            .map(|m| m.content.clone())
158            .unwrap_or_default();
159
160        Ok(content)
161    }
162
163    /// Stream response tokens from the LLM as Server-Sent Events.
164    /// Returns a stream of SSE-formatted strings ready to send to the client.
165    pub async fn stream_response(
166        &self,
167        messages: Vec<Message>,
168        max_tokens: u32,
169        temperature: f32,
170    ) -> anyhow::Result<impl futures_util::Stream<Item = Result<String, anyhow::Error>>> {
171        debug!("LLM worker starting streaming response");
172
173        let request = ChatCompletionRequest {
174            model: "local-llm".to_string(),
175            messages: Self::to_chat_messages(&messages),
176            max_tokens,
177            temperature,
178            stream: true,
179        };
180
181        let response = self.http_client
182            .post(&self.completions_url())
183            .json(&request)
184            .send()
185            .await
186            .map_err(|e| anyhow::anyhow!("LLM backend request failed: {}", e))?;
187
188        if !response.status().is_success() {
189            let status = response.status();
190            let body = response.text().await.unwrap_or_default();
191            return Err(anyhow::anyhow!("LLM backend returned {}: {}", status, body));
192        }
193
194        let byte_stream = response.bytes_stream();
195
196        let sse_stream = async_stream::try_stream! {
197            let mut buffer = String::new();
198
199            futures_util::pin_mut!(byte_stream);
200
201            while let Some(chunk_result) = byte_stream.next().await {
202                let chunk = chunk_result
203                    .map_err(|e| anyhow::anyhow!("Stream read error: {}", e))?;
204
205                buffer.push_str(&String::from_utf8_lossy(&chunk));
206
207                while let Some(newline_pos) = buffer.find('\n') {
208                    let line = buffer[..newline_pos].trim().to_string();
209                    buffer = buffer[newline_pos + 1..].to_string();
210
211                    if line.is_empty() {
212                        continue;
213                    }
214
215                    if line.starts_with("data: ") {
216                        let data = &line[6..];
217
218                        if data == "[DONE]" {
219                            yield "data: [DONE]\n\n".to_string();
220                            return;
221                        }
222
223                        match serde_json::from_str::<StreamChunk>(data) {
224                            Ok(chunk) => {
225                                let finished = chunk.choices.iter()
226                                    .any(|c| c.finish_reason.is_some());
227
228                                yield format!("data: {}\n\n", data);
229
230                                if finished {
231                                    yield "data: [DONE]\n\n".to_string();
232                                    return;
233                                }
234                            }
235                            Err(_) => {
236                                yield format!("data: {}\n\n", data);
237                            }
238                        }
239                    }
240                }
241            }
242
243            yield "data: [DONE]\n\n".to_string();
244        };
245
246        Ok(sse_stream)
247    }
248
249    /// Batch process multiple prompts (non-streaming)
250    pub async fn batch_process(
251        &self,
252        prompts: Vec<(String, Vec<Message>)>,
253    ) -> anyhow::Result<Vec<String>> {
254        debug!("LLM worker batch processing {} prompts", prompts.len());
255
256        let mut responses = Vec::new();
257        for (session_id, messages) in prompts {
258            match self.generate_response(session_id.clone(), messages).await {
259                Ok(response) => responses.push(response),
260                Err(e) => {
261                    warn!("Batch item {} failed: {}", session_id, e);
262                    responses.push(format!("Error: {}", e));
263                }
264            }
265        }
266
267        info!("Batch processed {} prompts", responses.len());
268        Ok(responses)
269    }
270
271    /// Initialize LLM model (no-op for HTTP proxy mode)
272    pub async fn initialize_model(&self, model_path: &str) -> anyhow::Result<()> {
273        debug!("LLM worker model init (HTTP proxy mode): {}", model_path);
274        Ok(())
275    }
276
277    /// Generate embeddings for one or more text inputs via llama-server's /v1/embeddings endpoint.
278    /// This reuses the vectors llama.cpp already computes during inference — no separate model needed.
279    /// Returns a Vec of embedding vectors (one per input string).
280    pub async fn generate_embeddings(
281        &self,
282        texts: Vec<String>,
283    ) -> anyhow::Result<Vec<Vec<f32>>> {
284        if texts.is_empty() {
285            return Ok(Vec::new());
286        }
287
288        debug!("Generating embeddings for {} text(s) via llama-server", texts.len());
289
290        let request = EmbeddingRequest {
291            model: "local-llm".to_string(),
292            input: texts,
293        };
294
295        let response = self.http_client
296            .post(&self.embeddings_url())
297            .json(&request)
298            .send()
299            .await
300            .map_err(|e| anyhow::anyhow!("Embedding request failed: {}", e))?;
301
302        if !response.status().is_success() {
303            let status = response.status();
304            let body = response.text().await.unwrap_or_default();
305            return Err(anyhow::anyhow!("Embedding endpoint returned {}: {}", status, body));
306        }
307
308        let embedding_response: EmbeddingResponse = response.json().await
309            .map_err(|e| anyhow::anyhow!("Failed to parse embedding response: {}", e))?;
310
311        let embeddings: Vec<Vec<f32>> = embedding_response.data
312            .into_iter()
313            .map(|d| d.embedding)
314            .collect();
315
316        debug!("Generated {} embeddings (dim={})",
317            embeddings.len(),
318            embeddings.first().map(|e| e.len()).unwrap_or(0));
319
320        Ok(embeddings)
321    }
322
323    /// Generate title for a chat using the LLM
324    pub async fn generate_title(
325        &self,
326        prompt: &str,
327        max_tokens: u32,
328    ) -> anyhow::Result<String> {
329        debug!("LLM worker generating title for prompt ({} chars)", prompt.len());
330
331        let messages = vec![Message {
332            role: "user".to_string(),
333            content: prompt.to_string(),
334        }];
335
336        let request = ChatCompletionRequest {
337            model: "local-llm".to_string(),
338            messages: Self::to_chat_messages(&messages),
339            max_tokens: max_tokens.min(20),
340            temperature: 0.3,
341            stream: false,
342        };
343
344        let response = self.http_client
345            .post(&self.completions_url())
346            .json(&request)
347            .send()
348            .await
349            .map_err(|e| anyhow::anyhow!("Title generation request failed: {}", e))?;
350
351        if !response.status().is_success() {
352            let status = response.status();
353            let body = response.text().await.unwrap_or_default();
354            return Err(anyhow::anyhow!("Title generation failed ({}): {}", status, body));
355        }
356
357        let completion: ChatCompletionResponse = response.json().await
358            .map_err(|e| anyhow::anyhow!("Failed to parse title response: {}", e))?;
359
360        let title = completion.choices
361            .first()
362            .and_then(|c| c.message.as_ref())
363            .map(|m| m.content.trim().to_string())
364            .unwrap_or_else(|| "New Chat".to_string());
365
366        let title = title.trim_matches('"').trim_matches('\'').to_string();
367
368        info!("Generated title: '{}'", title);
369        Ok(title)
370    }
371}