1use anyhow::{anyhow, Result};
8use memvid_core::enrich::{EnrichmentContext, EnrichmentEngine, EnrichmentResult};
9use memvid_core::types::{MemoryCard, MemoryCardBuilder, MemoryKind, Polarity};
10use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
11use reqwest::blocking::Client;
12use serde::{Deserialize, Serialize};
13use std::sync::Arc;
14use std::time::Duration;
15use tracing::{debug, info, warn};
16
/// Instruction text placed in front of every frame's text before it is sent to the model.
///
/// It asks for one `MEMORY_START` … `MEMORY_END` block per extracted memory, each
/// holding `kind:`, `entity:`, `slot:`, `value:` and `polarity:` lines, or the single
/// marker `MEMORY_NONE` when nothing can be extracted. `GroqEngine::parse_output`
/// looks for exactly these markers and prefixes, so the two must be kept in sync.
const EXTRACTION_PROMPT: &str = r#"You are a memory extraction assistant. Extract structured facts from the text.

For each distinct fact, preference, event, or relationship mentioned, output a memory card in this exact format:
MEMORY_START
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about, use "user" for the human in the conversation>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If there are no clear facts to extract, output MEMORY_NONE.

Extract memories from this text:
"#;
34
/// One entry of the `messages` array in a chat-completion request.
#[derive(Debug, Serialize, Clone)]
struct ChatMessage {
    /// Speaker role; this file only ever sends `"user"`.
    role: String,
    /// Message text (the extraction prompt followed by the frame text).
    content: String,
}
41
/// JSON body posted to the Groq chat-completions endpoint.
#[derive(Debug, Serialize)]
struct ChatRequest {
    /// Model identifier to run the completion with.
    model: String,
    /// Conversation to complete; this file sends a single user message.
    messages: Vec<ChatMessage>,
    /// Sent as `max_tokens`; this file always uses 1024.
    max_tokens: u32,
    /// Sent as `temperature`; this file always uses 0.0.
    temperature: f32,
}
50
/// Chat-completion response body; only the field this file reads is declared.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    /// Returned completions; only the first one is used.
    choices: Vec<ChatChoice>,
}
56
/// A single completion inside [`ChatResponse::choices`].
#[derive(Debug, Deserialize)]
struct ChatChoice {
    /// The message generated for this choice.
    message: ChatMessageResponse,
}
61
/// The generated message of a [`ChatChoice`].
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    /// Raw model output; later parsed for `MEMORY_START` / `MEMORY_END` blocks.
    content: String,
}
66
/// Enrichment engine that extracts memory cards from text via the Groq chat API.
///
/// Construct it with `new()` or `with_model()`, then call `init()` before any
/// enrichment so that the HTTP client exists.
pub struct GroqEngine {
    /// API key read from the `GROQ_API_KEY` environment variable (empty if unset).
    api_key: String,
    /// Model identifier sent with every request.
    model: String,
    /// Set to `true` once `init()` has succeeded.
    ready: bool,
    /// Thread count requested for the rayon pool used by `enrich_batch` (default 20).
    parallelism: usize,
    /// Blocking HTTP client; `None` until `init()` has created it.
    client: Option<Client>,
}
80
81impl GroqEngine {
82 pub fn new() -> Self {
84 let api_key = std::env::var("GROQ_API_KEY").unwrap_or_default();
85 Self {
86 api_key,
87 model: "llama-3.3-70b-versatile".to_string(),
88 ready: false,
89 parallelism: 20,
90 client: None,
91 }
92 }
93
94 pub fn with_model(model: &str) -> Self {
96 let api_key = std::env::var("GROQ_API_KEY").unwrap_or_default();
97 Self {
98 api_key,
99 model: model.to_string(),
100 ready: false,
101 parallelism: 20,
102 client: None,
103 }
104 }
105
106 pub fn with_parallelism(mut self, n: usize) -> Self {
108 self.parallelism = n;
109 self
110 }
111
112
113 fn run_inference_blocking(
115 client: &Client,
116 api_key: &str,
117 model: &str,
118 text: &str,
119 ) -> Result<String> {
120 let prompt = format!("{}\n\n{}", EXTRACTION_PROMPT, text);
121
122 let request = ChatRequest {
123 model: model.to_string(),
124 messages: vec![ChatMessage {
125 role: "user".to_string(),
126 content: prompt,
127 }],
128 max_tokens: 1024,
129 temperature: 0.0,
130 };
131
132 let response = client
133 .post("https://api.groq.com/openai/v1/chat/completions")
134 .header("Authorization", format!("Bearer {}", api_key))
135 .header("Content-Type", "application/json")
136 .json(&request)
137 .send()
138 .map_err(|e| anyhow!("Groq API request failed: {}", e))?;
139
140 if !response.status().is_success() {
141 let status = response.status();
142 let body = response.text().unwrap_or_default();
143 return Err(anyhow!("Groq API error {}: {}", status, body));
144 }
145
146 let chat_response: ChatResponse = response
147 .json()
148 .map_err(|e| anyhow!("Failed to parse Groq response: {}", e))?;
149
150 chat_response
151 .choices
152 .first()
153 .map(|c| c.message.content.clone())
154 .ok_or_else(|| anyhow!("No response from Groq"))
155 }
156
157 fn parse_output(output: &str, frame_id: u64, uri: &str, timestamp: i64) -> Vec<MemoryCard> {
159 let mut cards = Vec::new();
160
161 if output.contains("MEMORY_NONE") {
162 return cards;
163 }
164
165 for block in output.split("MEMORY_START") {
166 let block = block.trim();
167 if block.is_empty() || !block.contains("MEMORY_END") {
168 continue;
169 }
170
171 let block = block.split("MEMORY_END").next().unwrap_or("").trim();
172
173 let mut kind = None;
174 let mut entity = None;
175 let mut slot = None;
176 let mut value = None;
177 let mut polarity = Polarity::Neutral;
178
179 for line in block.lines() {
180 let line = line.trim();
181 if let Some(rest) = line.strip_prefix("kind:") {
182 kind = parse_memory_kind(rest.trim());
183 } else if let Some(rest) = line.strip_prefix("entity:") {
184 entity = Some(rest.trim().to_string());
185 } else if let Some(rest) = line.strip_prefix("slot:") {
186 slot = Some(rest.trim().to_string());
187 } else if let Some(rest) = line.strip_prefix("value:") {
188 value = Some(rest.trim().to_string());
189 } else if let Some(rest) = line.strip_prefix("polarity:") {
190 polarity = parse_polarity(rest.trim());
191 }
192 }
193
194 if let (Some(k), Some(e), Some(s), Some(v)) = (kind, entity, slot, value) {
195 if !e.is_empty() && !s.is_empty() && !v.is_empty() {
196 match MemoryCardBuilder::new()
197 .kind(k)
198 .entity(&e)
199 .slot(&s)
200 .value(&v)
201 .polarity(polarity)
202 .source(frame_id, Some(uri.to_string()))
203 .document_date(timestamp)
204 .engine("groq:llama-3.3-70b", "1.0.0")
205 .build(0)
206 {
207 Ok(card) => cards.push(card),
208 Err(err) => {
209 warn!("Failed to build memory card: {}", err);
210 }
211 }
212 }
213 }
214 }
215
216 cards
217 }
218
219 pub fn enrich_batch(
221 &self,
222 contexts: Vec<EnrichmentContext>,
223 ) -> Result<Vec<(u64, Vec<MemoryCard>)>> {
224 let client = self
225 .client
226 .as_ref()
227 .ok_or_else(|| anyhow!("Groq engine not initialized (init() not called)"))?
228 .clone();
229 let client = Arc::new(client);
230 let api_key = Arc::new(self.api_key.clone());
231 let model = Arc::new(self.model.clone());
232 let total = contexts.len();
233
234 info!(
235 "Starting parallel enrichment of {} frames with {} workers",
236 total, self.parallelism
237 );
238
239 let pool = rayon::ThreadPoolBuilder::new()
240 .num_threads(self.parallelism)
241 .build()
242 .map_err(|err| anyhow!("failed to build enrichment thread pool: {err}"))?;
243
244 let results: Vec<(u64, Vec<MemoryCard>)> = pool.install(|| {
245 contexts
246 .into_par_iter()
247 .enumerate()
248 .map(|(i, ctx)| {
249 if ctx.text.is_empty() {
250 return (ctx.frame_id, vec![]);
251 }
252
253 if i > 0 && i % 50 == 0 {
254 info!("Enrichment progress: {}/{} frames", i, total);
255 }
256
257 match Self::run_inference_blocking(&client, &api_key, &model, &ctx.text) {
258 Ok(output) => {
259 debug!(
260 "Groq output for frame {}: {}",
261 ctx.frame_id,
262 &output[..output.len().min(100)]
263 );
264 let cards =
265 Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
266 (ctx.frame_id, cards)
267 }
268 Err(err) => {
269 warn!(
270 "Groq inference failed for frame {}: {}",
271 ctx.frame_id, err
272 );
273 (ctx.frame_id, vec![])
274 }
275 }
276 })
277 .collect()
278 });
279
280 info!(
281 "Parallel enrichment complete: {} frames processed",
282 results.len()
283 );
284 Ok(results)
285 }
286}
287
288fn parse_memory_kind(s: &str) -> Option<MemoryKind> {
289 match s.to_lowercase().as_str() {
290 "fact" => Some(MemoryKind::Fact),
291 "preference" => Some(MemoryKind::Preference),
292 "event" => Some(MemoryKind::Event),
293 "profile" => Some(MemoryKind::Profile),
294 "relationship" => Some(MemoryKind::Relationship),
295 "other" => Some(MemoryKind::Other),
296 _ => None,
297 }
298}
299
300fn parse_polarity(s: &str) -> Polarity {
301 match s.to_lowercase().as_str() {
302 "positive" => Polarity::Positive,
303 "negative" => Polarity::Negative,
304 _ => Polarity::Neutral,
305 }
306}
307
308impl EnrichmentEngine for GroqEngine {
309 fn kind(&self) -> &str {
310 "groq:llama-3.3-70b"
311 }
312
313 fn version(&self) -> &str {
314 "1.0.0"
315 }
316
317 fn init(&mut self) -> memvid_core::Result<()> {
318 if self.api_key.is_empty() {
319 return Err(memvid_core::MemvidError::EmbeddingFailed {
320 reason: "GROQ_API_KEY environment variable not set".into(),
321 });
322 }
323 let client = crate::http::blocking_client(Duration::from_secs(60)).map_err(|err| {
324 memvid_core::MemvidError::EmbeddingFailed {
325 reason: format!("Failed to create Groq HTTP client: {err}").into(),
326 }
327 })?;
328 self.client = Some(client);
329 self.ready = true;
330 Ok(())
331 }
332
333 fn is_ready(&self) -> bool {
334 self.ready
335 }
336
337 fn enrich(&self, ctx: &EnrichmentContext) -> EnrichmentResult {
338 if ctx.text.is_empty() {
339 return EnrichmentResult::empty();
340 }
341
342 let client = match self.client.as_ref() {
343 Some(client) => client,
344 None => {
345 return EnrichmentResult::failed(
346 "Groq engine not initialized (init() not called)".to_string(),
347 )
348 }
349 };
350
351 match Self::run_inference_blocking(client, &self.api_key, &self.model, &ctx.text) {
352 Ok(output) => {
353 debug!("Groq output for frame {}: {}", ctx.frame_id, output);
354 let cards = Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
355 EnrichmentResult::success(cards)
356 }
357 Err(err) => EnrichmentResult::failed(format!("Groq inference failed: {}", err)),
358 }
359 }
360}
361
362impl Default for GroqEngine {
363 fn default() -> Self {
364 Self::new()
365 }
366}