// offline_intelligence/worker_threads/llm_worker.rs

use futures_util::StreamExt;
use tracing::{info, debug, warn};
use serde::{Deserialize, Serialize};

use crate::memory::Message;

/// Request body for the OpenAI-compatible `/v1/chat/completions` endpoint.
#[derive(Debug, Serialize)]
struct ChatCompletionRequest {
    model: String,
    messages: Vec<ChatMessage>,
    max_tokens: u32,
    temperature: f32,
    stream: bool,
}

/// Request body for the OpenAI-compatible `/v1/embeddings` endpoint.
#[derive(Debug, Serialize)]
struct EmbeddingRequest {
    model: String,
    input: Vec<String>,
}

#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
    data: Vec<EmbeddingData>,
}

#[derive(Debug, Deserialize)]
struct EmbeddingData {
    embedding: Vec<f32>,
}

/// A single chat message in OpenAI wire format.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct ChatMessage {
    role: String,
    content: String,
}

/// Non-streaming response from `/v1/chat/completions`.
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
    choices: Vec<ChatChoice>,
}

#[derive(Debug, Deserialize)]
struct ChatChoice {
    message: Option<ChatMessage>,
}

/// One server-sent-events chunk of a streaming completion.
#[derive(Debug, Deserialize)]
struct StreamChunk {
    choices: Vec<StreamChoice>,
}

#[derive(Debug, Deserialize)]
struct StreamChoice {
    delta: Option<ChatDelta>,
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize, Clone)]
struct ChatDelta {
    content: Option<String>,
}

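// Example of the streaming payloads these types parse, in the shape emitted
// by OpenAI-compatible servers (llama-server included); only the fields the
// worker actually reads are modeled above:
//
//     data: {"choices":[{"delta":{"content":"Hel"},"finish_reason":null}]}
//     data: {"choices":[{"delta":{},"finish_reason":"stop"}]}
//     data: [DONE]
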
/// Worker that proxies LLM requests to an OpenAI-compatible HTTP backend
/// (e.g. llama-server).
pub struct LLMWorker {
    backend_url: String,
    http_client: reqwest::Client,
}

impl LLMWorker {
    /// Builds the HTTP client shared by all backend calls. The long timeout
    /// accommodates slow local inference; if the builder fails (e.g. TLS
    /// init), we fall back to a default client, which silently drops the
    /// timeout but keeps the worker usable.
    fn build_client() -> reqwest::Client {
        reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(600))
            .build()
            .unwrap_or_default()
    }

    pub fn new(shared_state: std::sync::Arc<crate::shared_state::SharedState>) -> Self {
        Self {
            backend_url: shared_state.config.backend_url.clone(),
            http_client: Self::build_client(),
        }
    }

    pub fn new_with_backend(backend_url: String) -> Self {
        info!("LLM worker initialized with backend: {}", backend_url);
        Self {
            backend_url,
            http_client: Self::build_client(),
        }
    }

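    // Hedged construction sketch: `SharedState` wiring is app-specific, so
    // the URL-based constructor is the simplest entry point. The port is an
    // assumption (llama-server's default):
    //
    //     let worker = LLMWorker::new_with_backend("http://127.0.0.1:8080".to_string());
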
    fn completions_url(&self) -> String {
        // Trim a trailing slash so a configured "http://host/" does not
        // produce a double-slash path.
        format!("{}/v1/chat/completions", self.backend_url.trim_end_matches('/'))
    }

    fn embeddings_url(&self) -> String {
        format!("{}/v1/embeddings", self.backend_url.trim_end_matches('/'))
    }

    fn to_chat_messages(messages: &[Message]) -> Vec<ChatMessage> {
        messages
            .iter()
            .map(|m| ChatMessage {
                role: m.role.clone(),
                content: m.content.clone(),
            })
            .collect()
    }

    /// Sends the full conversation context to the backend and returns the
    /// complete (non-streamed) assistant reply.
    pub async fn generate_response(
        &self,
        _session_id: String,
        context: Vec<Message>,
    ) -> anyhow::Result<String> {
        debug!("LLM worker generating response (non-streaming)");

        let request = ChatCompletionRequest {
            model: "local-llm".to_string(),
            messages: Self::to_chat_messages(&context),
            max_tokens: 2000,
            temperature: 0.7,
            stream: false,
        };

        let response = self
            .http_client
            .post(self.completions_url())
            .json(&request)
            .send()
            .await
            .map_err(|e| anyhow::anyhow!("LLM backend request failed: {}", e))?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(anyhow::anyhow!("LLM backend returned {}: {}", status, body));
        }

        let completion: ChatCompletionResponse = response
            .json()
            .await
            .map_err(|e| anyhow::anyhow!("Failed to parse LLM response: {}", e))?;

        let content = completion
            .choices
            .first()
            .and_then(|c| c.message.as_ref())
            .map(|m| m.content.clone())
            .unwrap_or_default();

        Ok(content)
    }

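    // Sketch of a non-streaming call, assuming a `worker` instance and the
    // crate's own `Message` type:
    //
    //     let reply = worker
    //         .generate_response("session-1".to_string(), vec![Message {
    //             role: "user".to_string(),
    //             content: "Summarize this file.".to_string(),
    //         }])
    //         .await?;
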
    /// Streams a completion from the backend, re-emitting its SSE lines as
    /// `data: ...\n\n` frames and always terminating with `data: [DONE]`.
    pub async fn stream_response(
        &self,
        messages: Vec<Message>,
        max_tokens: u32,
        temperature: f32,
    ) -> anyhow::Result<impl futures_util::Stream<Item = Result<String, anyhow::Error>>> {
        debug!("LLM worker starting streaming response");

        let request = ChatCompletionRequest {
            model: "local-llm".to_string(),
            messages: Self::to_chat_messages(&messages),
            max_tokens,
            temperature,
            stream: true,
        };

        let response = self
            .http_client
            .post(self.completions_url())
            .json(&request)
            .send()
            .await
            .map_err(|e| anyhow::anyhow!("LLM backend request failed: {}", e))?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(anyhow::anyhow!("LLM backend returned {}: {}", status, body));
        }

        let byte_stream = response.bytes_stream();

        let sse_stream = async_stream::try_stream! {
            // Bytes can arrive split across chunk boundaries, so accumulate
            // into a buffer and process one complete line at a time.
            let mut buffer = String::new();

            futures_util::pin_mut!(byte_stream);

            while let Some(chunk_result) = byte_stream.next().await {
                let chunk = chunk_result
                    .map_err(|e| anyhow::anyhow!("Stream read error: {}", e))?;

                buffer.push_str(&String::from_utf8_lossy(&chunk));

                while let Some(newline_pos) = buffer.find('\n') {
                    let line = buffer[..newline_pos].trim().to_string();
                    buffer = buffer[newline_pos + 1..].to_string();

                    if line.is_empty() {
                        continue;
                    }

                    if let Some(data) = line.strip_prefix("data: ") {
                        if data == "[DONE]" {
                            yield "data: [DONE]\n\n".to_string();
                            return;
                        }

                        match serde_json::from_str::<StreamChunk>(data) {
                            Ok(chunk) => {
                                let finished = chunk
                                    .choices
                                    .iter()
                                    .any(|c| c.finish_reason.is_some());

                                yield format!("data: {}\n\n", data);

                                if finished {
                                    yield "data: [DONE]\n\n".to_string();
                                    return;
                                }
                            }
                            // Pass unparseable payloads through untouched so
                            // the client can decide how to handle them.
                            Err(_) => {
                                yield format!("data: {}\n\n", data);
                            }
                        }
                    }
                }
            }

            // Backend closed the connection without a [DONE] marker.
            yield "data: [DONE]\n\n".to_string();
        };

        Ok(sse_stream)
    }

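    // Sketch of consuming the stream, e.g. from an SSE or WebSocket handler.
    // `forward_to_client` is hypothetical; the pin is required because the
    // return value is an unnamed `impl Stream`:
    //
    //     let stream = worker.stream_response(history, 1024, 0.7).await?;
    //     futures_util::pin_mut!(stream);
    //     while let Some(frame) = stream.next().await {
    //         forward_to_client(frame?).await?;
    //     }
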
    /// Processes prompts sequentially. A failed item yields an "Error: ..."
    /// string in its slot rather than aborting the whole batch, so callers
    /// can zip results back to inputs by index.
    pub async fn batch_process(
        &self,
        prompts: Vec<(String, Vec<Message>)>,
    ) -> anyhow::Result<Vec<String>> {
        debug!("LLM worker batch processing {} prompts", prompts.len());

        let mut responses = Vec::new();
        for (session_id, messages) in prompts {
            match self.generate_response(session_id.clone(), messages).await {
                Ok(response) => responses.push(response),
                Err(e) => {
                    warn!("Batch item {} failed: {}", session_id, e);
                    responses.push(format!("Error: {}", e));
                }
            }
        }

        info!("Batch processed {} prompts", responses.len());
        Ok(responses)
    }

    /// No-op in HTTP proxy mode: the backend server owns model loading.
    pub async fn initialize_model(&self, model_path: &str) -> anyhow::Result<()> {
        debug!("LLM worker model init (HTTP proxy mode): {}", model_path);
        Ok(())
    }

    /// Requests embeddings for a batch of texts from the backend's
    /// `/v1/embeddings` endpoint.
    pub async fn generate_embeddings(
        &self,
        texts: Vec<String>,
    ) -> anyhow::Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        debug!("Generating embeddings for {} text(s) via llama-server", texts.len());

        let request = EmbeddingRequest {
            model: "local-llm".to_string(),
            input: texts,
        };

        let response = self
            .http_client
            .post(self.embeddings_url())
            .json(&request)
            .send()
            .await
            .map_err(|e| anyhow::anyhow!("Embedding request failed: {}", e))?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(anyhow::anyhow!("Embedding endpoint returned {}: {}", status, body));
        }

        let embedding_response: EmbeddingResponse = response
            .json()
            .await
            .map_err(|e| anyhow::anyhow!("Failed to parse embedding response: {}", e))?;

        let embeddings: Vec<Vec<f32>> = embedding_response
            .data
            .into_iter()
            .map(|d| d.embedding)
            .collect();

        debug!(
            "Generated {} embeddings (dim={})",
            embeddings.len(),
            embeddings.first().map(|e| e.len()).unwrap_or(0)
        );

        Ok(embeddings)
    }

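    // Sketch: the returned vectors are typically compared by cosine
    // similarity. Whether llama-server pre-normalizes embeddings depends on
    // its flags, so normalize defensively:
    //
    //     fn cosine(a: &[f32], b: &[f32]) -> f32 {
    //         let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    //         let na = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    //         let nb = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    //         dot / (na * nb).max(f32::EPSILON)
    //     }
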
    /// Generates a short conversation title. Caps the token budget at 20 and
    /// uses a low temperature for stable, terse output.
    pub async fn generate_title(
        &self,
        prompt: &str,
        max_tokens: u32,
    ) -> anyhow::Result<String> {
        debug!("LLM worker generating title for prompt ({} chars)", prompt.len());

        let messages = vec![Message {
            role: "user".to_string(),
            content: prompt.to_string(),
        }];

        let request = ChatCompletionRequest {
            model: "local-llm".to_string(),
            messages: Self::to_chat_messages(&messages),
            max_tokens: max_tokens.min(20),
            temperature: 0.3,
            stream: false,
        };

        let response = self
            .http_client
            .post(self.completions_url())
            .json(&request)
            .send()
            .await
            .map_err(|e| anyhow::anyhow!("Title generation request failed: {}", e))?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(anyhow::anyhow!("Title generation failed ({}): {}", status, body));
        }

        let completion: ChatCompletionResponse = response
            .json()
            .await
            .map_err(|e| anyhow::anyhow!("Failed to parse title response: {}", e))?;

        let title = completion
            .choices
            .first()
            .and_then(|c| c.message.as_ref())
            .map(|m| m.content.trim().to_string())
            .unwrap_or_else(|| "New Chat".to_string());

        // Models often wrap titles in quotes; strip them.
        let title = title.trim_matches('"').trim_matches('\'').to_string();

        info!("Generated title: '{}'", title);
        Ok(title)
    }
}
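
// A minimal smoke test, kept as a sketch: it assumes a running
// OpenAI-compatible backend on localhost and a `tokio` dev-dependency, so it
// is ignored by default.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    #[ignore = "requires a running OpenAI-compatible backend"]
    async fn smoke_generate_and_embed() -> anyhow::Result<()> {
        let worker = LLMWorker::new_with_backend("http://127.0.0.1:8080".to_string());

        let reply = worker
            .generate_response(
                "test-session".to_string(),
                vec![Message {
                    role: "user".to_string(),
                    content: "Say hello.".to_string(),
                }],
            )
            .await?;
        assert!(!reply.is_empty());

        let embeddings = worker
            .generate_embeddings(vec!["hello".to_string(), "world".to_string()])
            .await?;
        assert_eq!(embeddings.len(), 2);

        Ok(())
    }
}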