Skip to main content

memvid_cli/enrich/
groq.rs

1//! Groq enrichment engine using Llama 3.3 70B.
2//!
3//! This engine uses the Groq API (OpenAI-compatible) to extract structured memory cards
4//! from text content. Groq offers ultra-fast inference on LPU hardware.
5//! Supports parallel batch processing for speed.
6
7use anyhow::{anyhow, Result};
8use memvid_core::enrich::{EnrichmentContext, EnrichmentEngine, EnrichmentResult};
9use memvid_core::types::{MemoryCard, MemoryCardBuilder, MemoryKind, Polarity};
10use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
11use reqwest::blocking::Client;
12use serde::{Deserialize, Serialize};
13use std::sync::Arc;
14use std::time::Duration;
15use tracing::{debug, info, warn};
16
/// The extraction prompt sent to Groq ahead of each text chunk.
///
/// Instructs the model to emit zero or more MEMORY_START/MEMORY_END blocks,
/// or the literal MEMORY_NONE when nothing is extractable. The field names
/// inside a block (kind/entity/slot/value/polarity) must stay in sync with
/// `GroqEngine::parse_output`, which parses them back out line by line.
const EXTRACTION_PROMPT: &str = r#"You are a memory extraction assistant. Extract structured facts from the text.

For each distinct fact, preference, event, or relationship mentioned, output a memory card in this exact format:
MEMORY_START
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about, use "user" for the human in the conversation>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If there are no clear facts to extract, output MEMORY_NONE.

Extract memories from this text:
"#;
34
/// OpenAI-compatible API request message (one turn of the conversation).
#[derive(Debug, Serialize, Clone)]
struct ChatMessage {
    /// Speaker role, e.g. "user"; serialized as-is into the request JSON.
    role: String,
    /// Message text for this turn.
    content: String,
}
41
/// OpenAI-compatible chat-completion request body.
#[derive(Debug, Serialize)]
struct ChatRequest {
    /// Model identifier, e.g. "llama-3.3-70b-versatile".
    model: String,
    /// Conversation history; this engine always sends a single user message.
    messages: Vec<ChatMessage>,
    /// Upper bound on tokens generated in the completion.
    max_tokens: u32,
    /// Sampling temperature; 0.0 here for deterministic extraction.
    temperature: f32,
}
50
/// OpenAI-compatible API response (only the fields this engine reads).
#[derive(Debug, Deserialize)]
struct ChatResponse {
    /// Candidate completions; only the first is used.
    choices: Vec<ChatChoice>,
}
56
/// One completion candidate within a [`ChatResponse`].
#[derive(Debug, Deserialize)]
struct ChatChoice {
    /// The assistant message produced for this choice.
    message: ChatMessageResponse,
}
61
/// Assistant message payload within a [`ChatChoice`].
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    /// Raw completion text (the MEMORY_START/MEMORY_END blocks to parse).
    content: String,
}
66
/// Groq enrichment engine using Llama 3.3 70B with parallel processing.
///
/// Call `init()` before use: it validates the API key and builds the shared
/// HTTP client. Single frames go through the `EnrichmentEngine` trait; bulk
/// work should use `enrich_batch`, which fans requests out over a dedicated
/// rayon thread pool.
pub struct GroqEngine {
    /// API key, read from the `GROQ_API_KEY` environment variable at construction.
    api_key: String,
    /// Model identifier sent in each chat-completion request.
    model: String,
    /// Whether the engine is initialized (set by `init()`).
    ready: bool,
    /// Number of parallel workers for `enrich_batch` (default: 20)
    parallelism: usize,
    /// Shared HTTP client (built in `init`; `None` until then)
    client: Option<Client>,
}
80
81impl GroqEngine {
82    /// Create a new Groq engine.
83    pub fn new() -> Self {
84        let api_key = std::env::var("GROQ_API_KEY").unwrap_or_default();
85        Self {
86            api_key,
87            model: "llama-3.3-70b-versatile".to_string(),
88            ready: false,
89            parallelism: 20,
90            client: None,
91        }
92    }
93
94    /// Create with a specific model.
95    pub fn with_model(model: &str) -> Self {
96        let api_key = std::env::var("GROQ_API_KEY").unwrap_or_default();
97        Self {
98            api_key,
99            model: model.to_string(),
100            ready: false,
101            parallelism: 20,
102            client: None,
103        }
104    }
105
106    /// Set parallelism level.
107    pub fn with_parallelism(mut self, n: usize) -> Self {
108        self.parallelism = n;
109        self
110    }
111
112    /// Run inference via Groq API (blocking, thread-safe).
113    fn run_inference_blocking(
114        client: &Client,
115        api_key: &str,
116        model: &str,
117        text: &str,
118    ) -> Result<String> {
119        let prompt = format!("{}\n\n{}", EXTRACTION_PROMPT, text);
120
121        let request = ChatRequest {
122            model: model.to_string(),
123            messages: vec![ChatMessage {
124                role: "user".to_string(),
125                content: prompt,
126            }],
127            max_tokens: 1024,
128            temperature: 0.0,
129        };
130
131        let response = client
132            .post("https://api.groq.com/openai/v1/chat/completions")
133            .header("Authorization", format!("Bearer {}", api_key))
134            .header("Content-Type", "application/json")
135            .json(&request)
136            .send()
137            .map_err(|e| anyhow!("Groq API request failed: {}", e))?;
138
139        if !response.status().is_success() {
140            let status = response.status();
141            let body = response.text().unwrap_or_default();
142            return Err(anyhow!("Groq API error {}: {}", status, body));
143        }
144
145        let chat_response: ChatResponse = response
146            .json()
147            .map_err(|e| anyhow!("Failed to parse Groq response: {}", e))?;
148
149        chat_response
150            .choices
151            .first()
152            .map(|c| c.message.content.clone())
153            .ok_or_else(|| anyhow!("No response from Groq"))
154    }
155
156    /// Parse the LLM output into memory cards.
157    fn parse_output(output: &str, frame_id: u64, uri: &str, timestamp: i64) -> Vec<MemoryCard> {
158        let mut cards = Vec::new();
159
160        if output.contains("MEMORY_NONE") {
161            return cards;
162        }
163
164        for block in output.split("MEMORY_START") {
165            let block = block.trim();
166            if block.is_empty() || !block.contains("MEMORY_END") {
167                continue;
168            }
169
170            let block = block.split("MEMORY_END").next().unwrap_or("").trim();
171
172            let mut kind = None;
173            let mut entity = None;
174            let mut slot = None;
175            let mut value = None;
176            let mut polarity = Polarity::Neutral;
177
178            for line in block.lines() {
179                let line = line.trim();
180                if let Some(rest) = line.strip_prefix("kind:") {
181                    kind = parse_memory_kind(rest.trim());
182                } else if let Some(rest) = line.strip_prefix("entity:") {
183                    entity = Some(rest.trim().to_string());
184                } else if let Some(rest) = line.strip_prefix("slot:") {
185                    slot = Some(rest.trim().to_string());
186                } else if let Some(rest) = line.strip_prefix("value:") {
187                    value = Some(rest.trim().to_string());
188                } else if let Some(rest) = line.strip_prefix("polarity:") {
189                    polarity = parse_polarity(rest.trim());
190                }
191            }
192
193            if let (Some(k), Some(e), Some(s), Some(v)) = (kind, entity, slot, value) {
194                if !e.is_empty() && !s.is_empty() && !v.is_empty() {
195                    match MemoryCardBuilder::new()
196                        .kind(k)
197                        .entity(&e)
198                        .slot(&s)
199                        .value(&v)
200                        .polarity(polarity)
201                        .source(frame_id, Some(uri.to_string()))
202                        .document_date(timestamp)
203                        .engine("groq:llama-3.3-70b", "1.0.0")
204                        .build(0)
205                    {
206                        Ok(card) => cards.push(card),
207                        Err(err) => {
208                            warn!("Failed to build memory card: {}", err);
209                        }
210                    }
211                }
212            }
213        }
214
215        cards
216    }
217
218    /// Process multiple frames in parallel and return all cards.
219    pub fn enrich_batch(
220        &self,
221        contexts: Vec<EnrichmentContext>,
222    ) -> Result<Vec<(u64, Vec<MemoryCard>)>> {
223        let client = self
224            .client
225            .as_ref()
226            .ok_or_else(|| anyhow!("Groq engine not initialized (init() not called)"))?
227            .clone();
228        let client = Arc::new(client);
229        let api_key = Arc::new(self.api_key.clone());
230        let model = Arc::new(self.model.clone());
231        let total = contexts.len();
232
233        info!(
234            "Starting parallel enrichment of {} frames with {} workers",
235            total, self.parallelism
236        );
237
238        let pool = rayon::ThreadPoolBuilder::new()
239            .num_threads(self.parallelism)
240            .build()
241            .map_err(|err| anyhow!("failed to build enrichment thread pool: {err}"))?;
242
243        let results: Vec<(u64, Vec<MemoryCard>)> = pool.install(|| {
244            contexts
245                .into_par_iter()
246                .enumerate()
247                .map(|(i, ctx)| {
248                    if ctx.text.is_empty() {
249                        return (ctx.frame_id, vec![]);
250                    }
251
252                    if i > 0 && i % 50 == 0 {
253                        info!("Enrichment progress: {}/{} frames", i, total);
254                    }
255
256                    match Self::run_inference_blocking(&client, &api_key, &model, &ctx.text) {
257                        Ok(output) => {
258                            debug!(
259                                "Groq output for frame {}: {}",
260                                ctx.frame_id,
261                                &output[..output.len().min(100)]
262                            );
263                            let cards =
264                                Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
265                            (ctx.frame_id, cards)
266                        }
267                        Err(err) => {
268                            warn!("Groq inference failed for frame {}: {}", ctx.frame_id, err);
269                            (ctx.frame_id, vec![])
270                        }
271                    }
272                })
273                .collect()
274        });
275
276        info!(
277            "Parallel enrichment complete: {} frames processed",
278            results.len()
279        );
280        Ok(results)
281    }
282}
283
284fn parse_memory_kind(s: &str) -> Option<MemoryKind> {
285    match s.to_lowercase().as_str() {
286        "fact" => Some(MemoryKind::Fact),
287        "preference" => Some(MemoryKind::Preference),
288        "event" => Some(MemoryKind::Event),
289        "profile" => Some(MemoryKind::Profile),
290        "relationship" => Some(MemoryKind::Relationship),
291        "other" => Some(MemoryKind::Other),
292        _ => None,
293    }
294}
295
296fn parse_polarity(s: &str) -> Polarity {
297    match s.to_lowercase().as_str() {
298        "positive" => Polarity::Positive,
299        "negative" => Polarity::Negative,
300        _ => Polarity::Neutral,
301    }
302}
303
304impl EnrichmentEngine for GroqEngine {
305    fn kind(&self) -> &str {
306        "groq:llama-3.3-70b"
307    }
308
309    fn version(&self) -> &str {
310        "1.0.0"
311    }
312
313    fn init(&mut self) -> memvid_core::Result<()> {
314        if self.api_key.is_empty() {
315            return Err(memvid_core::MemvidError::EmbeddingFailed {
316                reason: "GROQ_API_KEY environment variable not set".into(),
317            });
318        }
319        let client = crate::http::blocking_client(Duration::from_secs(60)).map_err(|err| {
320            memvid_core::MemvidError::EmbeddingFailed {
321                reason: format!("Failed to create Groq HTTP client: {err}").into(),
322            }
323        })?;
324        self.client = Some(client);
325        self.ready = true;
326        Ok(())
327    }
328
329    fn is_ready(&self) -> bool {
330        self.ready
331    }
332
333    fn enrich(&self, ctx: &EnrichmentContext) -> EnrichmentResult {
334        if ctx.text.is_empty() {
335            return EnrichmentResult::empty();
336        }
337
338        let client = match self.client.as_ref() {
339            Some(client) => client,
340            None => {
341                return EnrichmentResult::failed(
342                    "Groq engine not initialized (init() not called)".to_string(),
343                )
344            }
345        };
346
347        match Self::run_inference_blocking(client, &self.api_key, &self.model, &ctx.text) {
348            Ok(output) => {
349                debug!("Groq output for frame {}: {}", ctx.frame_id, output);
350                let cards = Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
351                EnrichmentResult::success(cards)
352            }
353            Err(err) => EnrichmentResult::failed(format!("Groq inference failed: {}", err)),
354        }
355    }
356}
357
impl Default for GroqEngine {
    /// Equivalent to [`GroqEngine::new`]: default model, key from the
    /// `GROQ_API_KEY` environment variable.
    fn default() -> Self {
        Self::new()
    }
}