memvid_cli/enrich/groq.rs

1//! Groq enrichment engine using Llama 3.3 70B.
2//!
3//! This engine uses the Groq API (OpenAI-compatible) to extract structured memory cards
4//! from text content. Groq offers ultra-fast inference on LPU hardware.
5//! Supports parallel batch processing for speed.
6
7use anyhow::{anyhow, Result};
8use memvid_core::enrich::{EnrichmentContext, EnrichmentEngine, EnrichmentResult};
9use memvid_core::types::{MemoryCard, MemoryCardBuilder, MemoryKind, Polarity};
10use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
11use reqwest::blocking::Client;
12use serde::{Deserialize, Serialize};
13use std::sync::Arc;
14use std::time::Duration;
15use tracing::{debug, info, warn};
16
/// Prompt prepended to every extraction request sent to Groq.
///
/// Instructs the model to emit zero or more `MEMORY_START`/`MEMORY_END`
/// blocks (or the `MEMORY_NONE` sentinel when nothing is extractable).
/// The field names inside each block (`kind`, `entity`, `slot`, `value`,
/// `polarity`) must stay in sync with `GroqEngine::parse_output`.
const EXTRACTION_PROMPT: &str = r#"You are a memory extraction assistant. Extract structured facts from the text.

For each distinct fact, preference, event, or relationship mentioned, output a memory card in this exact format:
MEMORY_START
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about, use "user" for the human in the conversation>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If there are no clear facts to extract, output MEMORY_NONE.

Extract memories from this text:
"#;
34
/// One message in an OpenAI-compatible chat request.
///
/// Field names are part of the wire format expected by the Groq API
/// (serde serializes them verbatim) — do not rename.
#[derive(Debug, Serialize, Clone)]
struct ChatMessage {
    // Message role, e.g. "user".
    role: String,
    // Message text content.
    content: String,
}
41
/// Request body for the OpenAI-compatible `/chat/completions` endpoint.
///
/// Field names are part of the wire format — do not rename.
#[derive(Debug, Serialize)]
struct ChatRequest {
    // Model identifier, e.g. "llama-3.3-70b-versatile".
    model: String,
    // Conversation history; this engine sends a single user message.
    messages: Vec<ChatMessage>,
    // Upper bound on generated tokens.
    max_tokens: u32,
    // Sampling temperature; 0.0 is used for deterministic extraction.
    temperature: f32,
}
50
/// Response body from the OpenAI-compatible `/chat/completions` endpoint.
///
/// Only the fields this engine reads are deserialized; unknown fields
/// in the API response are ignored by serde's defaults.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    // Completion choices; only the first is used.
    choices: Vec<ChatChoice>,
}
56
/// One completion choice inside a [`ChatResponse`].
#[derive(Debug, Deserialize)]
struct ChatChoice {
    // The generated assistant message.
    message: ChatMessageResponse,
}
61
/// The assistant message inside a [`ChatChoice`].
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    // Raw completion text produced by the model.
    content: String,
}
66
/// Groq enrichment engine using Llama 3.3 70B with parallel processing.
///
/// Construct with [`GroqEngine::new`] (reads `GROQ_API_KEY` from the
/// environment), then call `init()` before use — `init` validates the
/// key and builds the shared HTTP client.
pub struct GroqEngine {
    /// API key (empty string when `GROQ_API_KEY` is unset; rejected by `init`)
    api_key: String,
    /// Model to use
    model: String,
    /// Whether the engine is initialized (set by `init`)
    ready: bool,
    /// Number of parallel workers (default: 20)
    parallelism: usize,
    /// Shared HTTP client (built in `init`; `None` until then)
    client: Option<Client>,
}
80
81impl GroqEngine {
82    /// Create a new Groq engine.
83    pub fn new() -> Self {
84        let api_key = std::env::var("GROQ_API_KEY").unwrap_or_default();
85        Self {
86            api_key,
87            model: "llama-3.3-70b-versatile".to_string(),
88            ready: false,
89            parallelism: 20,
90            client: None,
91        }
92    }
93
94    /// Create with a specific model.
95    pub fn with_model(model: &str) -> Self {
96        let api_key = std::env::var("GROQ_API_KEY").unwrap_or_default();
97        Self {
98            api_key,
99            model: model.to_string(),
100            ready: false,
101            parallelism: 20,
102            client: None,
103        }
104    }
105
106    /// Set parallelism level.
107    pub fn with_parallelism(mut self, n: usize) -> Self {
108        self.parallelism = n;
109        self
110    }
111
112
113    /// Run inference via Groq API (blocking, thread-safe).
114    fn run_inference_blocking(
115        client: &Client,
116        api_key: &str,
117        model: &str,
118        text: &str,
119    ) -> Result<String> {
120        let prompt = format!("{}\n\n{}", EXTRACTION_PROMPT, text);
121
122        let request = ChatRequest {
123            model: model.to_string(),
124            messages: vec![ChatMessage {
125                role: "user".to_string(),
126                content: prompt,
127            }],
128            max_tokens: 1024,
129            temperature: 0.0,
130        };
131
132        let response = client
133            .post("https://api.groq.com/openai/v1/chat/completions")
134            .header("Authorization", format!("Bearer {}", api_key))
135            .header("Content-Type", "application/json")
136            .json(&request)
137            .send()
138            .map_err(|e| anyhow!("Groq API request failed: {}", e))?;
139
140        if !response.status().is_success() {
141            let status = response.status();
142            let body = response.text().unwrap_or_default();
143            return Err(anyhow!("Groq API error {}: {}", status, body));
144        }
145
146        let chat_response: ChatResponse = response
147            .json()
148            .map_err(|e| anyhow!("Failed to parse Groq response: {}", e))?;
149
150        chat_response
151            .choices
152            .first()
153            .map(|c| c.message.content.clone())
154            .ok_or_else(|| anyhow!("No response from Groq"))
155    }
156
157    /// Parse the LLM output into memory cards.
158    fn parse_output(output: &str, frame_id: u64, uri: &str, timestamp: i64) -> Vec<MemoryCard> {
159        let mut cards = Vec::new();
160
161        if output.contains("MEMORY_NONE") {
162            return cards;
163        }
164
165        for block in output.split("MEMORY_START") {
166            let block = block.trim();
167            if block.is_empty() || !block.contains("MEMORY_END") {
168                continue;
169            }
170
171            let block = block.split("MEMORY_END").next().unwrap_or("").trim();
172
173            let mut kind = None;
174            let mut entity = None;
175            let mut slot = None;
176            let mut value = None;
177            let mut polarity = Polarity::Neutral;
178
179            for line in block.lines() {
180                let line = line.trim();
181                if let Some(rest) = line.strip_prefix("kind:") {
182                    kind = parse_memory_kind(rest.trim());
183                } else if let Some(rest) = line.strip_prefix("entity:") {
184                    entity = Some(rest.trim().to_string());
185                } else if let Some(rest) = line.strip_prefix("slot:") {
186                    slot = Some(rest.trim().to_string());
187                } else if let Some(rest) = line.strip_prefix("value:") {
188                    value = Some(rest.trim().to_string());
189                } else if let Some(rest) = line.strip_prefix("polarity:") {
190                    polarity = parse_polarity(rest.trim());
191                }
192            }
193
194            if let (Some(k), Some(e), Some(s), Some(v)) = (kind, entity, slot, value) {
195                if !e.is_empty() && !s.is_empty() && !v.is_empty() {
196                    match MemoryCardBuilder::new()
197                        .kind(k)
198                        .entity(&e)
199                        .slot(&s)
200                        .value(&v)
201                        .polarity(polarity)
202                        .source(frame_id, Some(uri.to_string()))
203                        .document_date(timestamp)
204                        .engine("groq:llama-3.3-70b", "1.0.0")
205                        .build(0)
206                    {
207                        Ok(card) => cards.push(card),
208                        Err(err) => {
209                            warn!("Failed to build memory card: {}", err);
210                        }
211                    }
212                }
213            }
214        }
215
216        cards
217    }
218
219    /// Process multiple frames in parallel and return all cards.
220    pub fn enrich_batch(
221        &self,
222        contexts: Vec<EnrichmentContext>,
223    ) -> Result<Vec<(u64, Vec<MemoryCard>)>> {
224        let client = self
225            .client
226            .as_ref()
227            .ok_or_else(|| anyhow!("Groq engine not initialized (init() not called)"))?
228            .clone();
229        let client = Arc::new(client);
230        let api_key = Arc::new(self.api_key.clone());
231        let model = Arc::new(self.model.clone());
232        let total = contexts.len();
233
234        info!(
235            "Starting parallel enrichment of {} frames with {} workers",
236            total, self.parallelism
237        );
238
239        let pool = rayon::ThreadPoolBuilder::new()
240            .num_threads(self.parallelism)
241            .build()
242            .map_err(|err| anyhow!("failed to build enrichment thread pool: {err}"))?;
243
244        let results: Vec<(u64, Vec<MemoryCard>)> = pool.install(|| {
245            contexts
246                .into_par_iter()
247                .enumerate()
248                .map(|(i, ctx)| {
249                    if ctx.text.is_empty() {
250                        return (ctx.frame_id, vec![]);
251                    }
252
253                    if i > 0 && i % 50 == 0 {
254                        info!("Enrichment progress: {}/{} frames", i, total);
255                    }
256
257                    match Self::run_inference_blocking(&client, &api_key, &model, &ctx.text) {
258                        Ok(output) => {
259                            debug!(
260                                "Groq output for frame {}: {}",
261                                ctx.frame_id,
262                                &output[..output.len().min(100)]
263                            );
264                            let cards =
265                                Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
266                            (ctx.frame_id, cards)
267                        }
268                        Err(err) => {
269                            warn!(
270                                "Groq inference failed for frame {}: {}",
271                                ctx.frame_id, err
272                            );
273                            (ctx.frame_id, vec![])
274                        }
275                    }
276                })
277                .collect()
278        });
279
280        info!(
281            "Parallel enrichment complete: {} frames processed",
282            results.len()
283        );
284        Ok(results)
285    }
286}
287
288fn parse_memory_kind(s: &str) -> Option<MemoryKind> {
289    match s.to_lowercase().as_str() {
290        "fact" => Some(MemoryKind::Fact),
291        "preference" => Some(MemoryKind::Preference),
292        "event" => Some(MemoryKind::Event),
293        "profile" => Some(MemoryKind::Profile),
294        "relationship" => Some(MemoryKind::Relationship),
295        "other" => Some(MemoryKind::Other),
296        _ => None,
297    }
298}
299
300fn parse_polarity(s: &str) -> Polarity {
301    match s.to_lowercase().as_str() {
302        "positive" => Polarity::Positive,
303        "negative" => Polarity::Negative,
304        _ => Polarity::Neutral,
305    }
306}
307
308impl EnrichmentEngine for GroqEngine {
309    fn kind(&self) -> &str {
310        "groq:llama-3.3-70b"
311    }
312
313    fn version(&self) -> &str {
314        "1.0.0"
315    }
316
317    fn init(&mut self) -> memvid_core::Result<()> {
318        if self.api_key.is_empty() {
319            return Err(memvid_core::MemvidError::EmbeddingFailed {
320                reason: "GROQ_API_KEY environment variable not set".into(),
321            });
322        }
323        let client = crate::http::blocking_client(Duration::from_secs(60)).map_err(|err| {
324            memvid_core::MemvidError::EmbeddingFailed {
325                reason: format!("Failed to create Groq HTTP client: {err}").into(),
326            }
327        })?;
328        self.client = Some(client);
329        self.ready = true;
330        Ok(())
331    }
332
333    fn is_ready(&self) -> bool {
334        self.ready
335    }
336
337    fn enrich(&self, ctx: &EnrichmentContext) -> EnrichmentResult {
338        if ctx.text.is_empty() {
339            return EnrichmentResult::empty();
340        }
341
342        let client = match self.client.as_ref() {
343            Some(client) => client,
344            None => {
345                return EnrichmentResult::failed(
346                    "Groq engine not initialized (init() not called)".to_string(),
347                )
348            }
349        };
350
351        match Self::run_inference_blocking(client, &self.api_key, &self.model, &ctx.text) {
352            Ok(output) => {
353                debug!("Groq output for frame {}: {}", ctx.frame_id, output);
354                let cards = Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
355                EnrichmentResult::success(cards)
356            }
357            Err(err) => EnrichmentResult::failed(format!("Groq inference failed: {}", err)),
358        }
359    }
360}
361
362impl Default for GroqEngine {
363    fn default() -> Self {
364        Self::new()
365    }
366}