1use anyhow::{anyhow, Result};
7use memvid_core::enrich::{EnrichmentContext, EnrichmentEngine, EnrichmentResult};
8use memvid_core::types::{MemoryCard, MemoryCardBuilder, MemoryKind, Polarity};
9use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
10use reqwest::blocking::Client;
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use std::time::Duration;
14use tracing::{debug, info, warn};
15
/// Prompt for single-frame extraction. Instructs the model to emit zero or
/// more `MEMORY_START`/`MEMORY_END` blocks (or `MEMORY_NONE`), the line
/// format consumed by `OpenAiEngine::parse_output`. Frame text is appended
/// after this prompt.
const EXTRACTION_PROMPT: &str = r#"You are a memory extraction assistant. Extract structured facts from the text.

For each distinct fact, preference, event, or relationship mentioned, output a memory card in this exact format:
MEMORY_START
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about, use "user" for the human in the conversation>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If there are no clear facts to extract, output MEMORY_NONE.

Extract memories from this text:
"#;
33
/// Prompt for multi-frame (batched) extraction. Same block format as
/// `EXTRACTION_PROMPT` plus a `frame_id:` field so
/// `OpenAiEngine::parse_batched_output` can route each card back to its
/// source frame. The labeled text blocks are appended after this prompt.
const BATCH_EXTRACTION_PROMPT: &str = r#"You are a memory extraction assistant. Extract structured facts from multiple text blocks.

Each text block is labeled with a FRAME_ID. For each distinct fact in each block, output a memory card with the frame_id field:

MEMORY_START
frame_id: <the FRAME_ID of the source text>
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about, use "user" for the human in the conversation>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If a text block has no facts, output MEMORY_NONE with its frame_id.

Process these text blocks:
"#;
53
/// One message in an OpenAI chat-completions request.
#[derive(Debug, Serialize, Clone)]
struct ChatMessage {
    // Sender role as expected by the API (this module only sends "user").
    role: String,
    // Message text (prompt plus frame text).
    content: String,
}
60
/// Request body for the `/v1/chat/completions` endpoint.
#[derive(Debug, Serialize)]
struct ChatRequest {
    // Model name, e.g. "gpt-4o-mini".
    model: String,
    messages: Vec<ChatMessage>,
    // Upper bound on generated tokens.
    max_tokens: u32,
    // Sampling temperature; this module always sends 0.0 for determinism.
    temperature: f32,
}
69
/// Minimal subset of the chat-completions response we deserialize;
/// fields other than `choices` are ignored.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    choices: Vec<ChatChoice>,
}
75
/// One completion choice; only the generated message is used.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    message: ChatMessageResponse,
}
80
/// The assistant message inside a choice; only its text content is kept.
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    content: String,
}
85
/// Enrichment engine that extracts structured memory cards from text by
/// calling the OpenAI chat-completions API (blocking HTTP).
pub struct OpenAiEngine {
    // Read from OPENAI_API_KEY at construction; validated in init().
    api_key: String,
    // Chat model name sent with every request (default "gpt-4o-mini").
    model: String,
    // Set to true by init() once the key is present and the client is built.
    ready: bool,
    // Worker-thread count for the rayon pool used by enrich_batch().
    parallelism: usize,
    // Number of frames grouped into a single API request (kept >= 1).
    batch_size: usize,
    // Blocking HTTP client; None until init() succeeds.
    client: Option<Client>,
}
101
102impl OpenAiEngine {
103 pub fn new() -> Self {
105 let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
106 Self {
107 api_key,
108 model: "gpt-4o-mini".to_string(),
109 ready: false,
110 parallelism: 20, batch_size: 10, client: None,
113 }
114 }
115
116 pub fn with_model(model: &str) -> Self {
118 let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
119 Self {
120 api_key,
121 model: model.to_string(),
122 ready: false,
123 parallelism: 20,
124 batch_size: 10,
125 client: None,
126 }
127 }
128
129 pub fn with_parallelism(mut self, n: usize) -> Self {
131 self.parallelism = n;
132 self
133 }
134
135 pub fn with_batch_size(mut self, n: usize) -> Self {
137 self.batch_size = n.max(1); self
139 }
140
141 fn run_inference_blocking(
143 client: &Client,
144 api_key: &str,
145 model: &str,
146 text: &str,
147 ) -> Result<String> {
148 let prompt = format!("{}\n\n{}", EXTRACTION_PROMPT, text);
149
150 let request = ChatRequest {
151 model: model.to_string(),
152 messages: vec![ChatMessage {
153 role: "user".to_string(),
154 content: prompt,
155 }],
156 max_tokens: 1024,
157 temperature: 0.0,
158 };
159
160 let response = client
161 .post("https://api.openai.com/v1/chat/completions")
162 .header("Authorization", format!("Bearer {}", api_key))
163 .header("Content-Type", "application/json")
164 .json(&request)
165 .send()
166 .map_err(|e| anyhow!("OpenAI API request failed: {}", e))?;
167
168 if !response.status().is_success() {
169 let status = response.status();
170 let body = response.text().unwrap_or_default();
171 return Err(anyhow!("OpenAI API error {}: {}", status, body));
172 }
173
174 let chat_response: ChatResponse = response
175 .json()
176 .map_err(|e| anyhow!("Failed to parse OpenAI response: {}", e))?;
177
178 chat_response
179 .choices
180 .first()
181 .map(|c| c.message.content.clone())
182 .ok_or_else(|| anyhow!("No response from OpenAI"))
183 }
184
185 fn parse_output(output: &str, frame_id: u64, uri: &str, timestamp: i64) -> Vec<MemoryCard> {
187 let mut cards = Vec::new();
188
189 if output.contains("MEMORY_NONE") {
191 return cards;
192 }
193
194 for block in output.split("MEMORY_START") {
196 let block = block.trim();
197 if block.is_empty() || !block.contains("MEMORY_END") {
198 continue;
199 }
200
201 let block = block.split("MEMORY_END").next().unwrap_or("").trim();
202
203 let mut kind = None;
205 let mut entity = None;
206 let mut slot = None;
207 let mut value = None;
208 let mut polarity = Polarity::Neutral;
209
210 for line in block.lines() {
211 let line = line.trim();
212 if let Some(rest) = line.strip_prefix("kind:") {
213 kind = parse_memory_kind(rest.trim());
214 } else if let Some(rest) = line.strip_prefix("entity:") {
215 entity = Some(rest.trim().to_string());
216 } else if let Some(rest) = line.strip_prefix("slot:") {
217 slot = Some(rest.trim().to_string());
218 } else if let Some(rest) = line.strip_prefix("value:") {
219 value = Some(rest.trim().to_string());
220 } else if let Some(rest) = line.strip_prefix("polarity:") {
221 polarity = parse_polarity(rest.trim());
222 }
223 }
224
225 if let (Some(k), Some(e), Some(s), Some(v)) = (kind, entity, slot, value) {
227 if !e.is_empty() && !s.is_empty() && !v.is_empty() {
228 match MemoryCardBuilder::new()
229 .kind(k)
230 .entity(&e)
231 .slot(&s)
232 .value(&v)
233 .polarity(polarity)
234 .source(frame_id, Some(uri.to_string()))
235 .document_date(timestamp)
236 .engine("openai:gpt-4o-mini", "1.0.0")
237 .build(0)
238 {
239 Ok(card) => cards.push(card),
240 Err(err) => {
241 warn!("Failed to build memory card: {}", err);
242 }
243 }
244 }
245 }
246 }
247
248 cards
249 }
250
251 fn run_batched_inference_blocking(
253 client: &Client,
254 api_key: &str,
255 model: &str,
256 contexts: &[&EnrichmentContext],
257 ) -> Result<String> {
258 let mut prompt = BATCH_EXTRACTION_PROMPT.to_string();
260 for ctx in contexts {
261 prompt.push_str(&format!(
262 "\n\n=== FRAME_ID: {} ===\n{}",
263 ctx.frame_id, ctx.text
264 ));
265 }
266
267 let max_tokens = 1024 + (contexts.len() as u32 * 512);
269
270 let request = ChatRequest {
271 model: model.to_string(),
272 messages: vec![ChatMessage {
273 role: "user".to_string(),
274 content: prompt,
275 }],
276 max_tokens: max_tokens.min(4096), temperature: 0.0,
278 };
279
280 let response = client
281 .post("https://api.openai.com/v1/chat/completions")
282 .header("Authorization", format!("Bearer {}", api_key))
283 .header("Content-Type", "application/json")
284 .json(&request)
285 .send()
286 .map_err(|e| anyhow!("OpenAI API request failed: {}", e))?;
287
288 if !response.status().is_success() {
289 let status = response.status();
290 let body = response.text().unwrap_or_default();
291 return Err(anyhow!("OpenAI API error {}: {}", status, body));
292 }
293
294 let chat_response: ChatResponse = response
295 .json()
296 .map_err(|e| anyhow!("Failed to parse OpenAI response: {}", e))?;
297
298 chat_response
299 .choices
300 .first()
301 .map(|c| c.message.content.clone())
302 .ok_or_else(|| anyhow!("No response from OpenAI"))
303 }
304
305 fn parse_batched_output(
307 output: &str,
308 contexts: &[&EnrichmentContext],
309 ) -> std::collections::HashMap<u64, Vec<MemoryCard>> {
310 let mut results: std::collections::HashMap<u64, Vec<MemoryCard>> =
311 std::collections::HashMap::new();
312
313 for ctx in contexts {
315 results.insert(ctx.frame_id, Vec::new());
316 }
317
318 let ctx_lookup: std::collections::HashMap<u64, &EnrichmentContext> =
320 contexts.iter().map(|c| (c.frame_id, *c)).collect();
321
322 for block in output.split("MEMORY_START") {
324 let block = block.trim();
325 if block.is_empty() || !block.contains("MEMORY_END") {
326 continue;
327 }
328
329 let block = block.split("MEMORY_END").next().unwrap_or("").trim();
330
331 let mut frame_id: Option<u64> = None;
333 let mut kind = None;
334 let mut entity = None;
335 let mut slot = None;
336 let mut value = None;
337 let mut polarity = Polarity::Neutral;
338
339 for line in block.lines() {
340 let line = line.trim();
341 if let Some(rest) = line.strip_prefix("frame_id:") {
342 frame_id = rest.trim().parse().ok();
343 } else if let Some(rest) = line.strip_prefix("kind:") {
344 kind = parse_memory_kind(rest.trim());
345 } else if let Some(rest) = line.strip_prefix("entity:") {
346 entity = Some(rest.trim().to_string());
347 } else if let Some(rest) = line.strip_prefix("slot:") {
348 slot = Some(rest.trim().to_string());
349 } else if let Some(rest) = line.strip_prefix("value:") {
350 value = Some(rest.trim().to_string());
351 } else if let Some(rest) = line.strip_prefix("polarity:") {
352 polarity = parse_polarity(rest.trim());
353 }
354 }
355
356 if let (Some(fid), Some(k), Some(e), Some(s), Some(v)) =
358 (frame_id, kind, entity, slot, value)
359 {
360 if let Some(ctx) = ctx_lookup.get(&fid) {
361 let uri = &ctx.uri;
362 let timestamp = ctx.timestamp;
363
364 if !e.is_empty() && !s.is_empty() && !v.is_empty() {
365 match MemoryCardBuilder::new()
366 .kind(k)
367 .entity(&e)
368 .slot(&s)
369 .value(&v)
370 .polarity(polarity)
371 .source(fid, Some(uri.to_string()))
372 .document_date(timestamp)
373 .engine("openai:gpt-4o-mini", "1.0.0")
374 .build(0)
375 {
376 Ok(card) => {
377 results.entry(fid).or_default().push(card);
378 }
379 Err(err) => {
380 warn!("Failed to build memory card: {}", err);
381 }
382 }
383 }
384 }
385 }
386 }
387
388 results
389 }
390
391 pub fn enrich_batch(
395 &self,
396 contexts: Vec<EnrichmentContext>,
397 ) -> Result<Vec<(u64, Vec<MemoryCard>)>> {
398 let client = self
399 .client
400 .as_ref()
401 .ok_or_else(|| anyhow!("OpenAI engine not initialized (init() not called)"))?
402 .clone();
403 let client = Arc::new(client);
404 let api_key = Arc::new(self.api_key.clone());
405 let model = Arc::new(self.model.clone());
406 let total = contexts.len();
407 let batch_size = self.batch_size;
408
409 let num_batches = (total + batch_size - 1) / batch_size;
411
412 info!(
413 "Starting parallel enrichment of {} frames with {} workers, {} frames per batch ({} batches)",
414 total, self.parallelism, batch_size, num_batches
415 );
416
417 let batches: Vec<Vec<EnrichmentContext>> = contexts
419 .into_iter()
420 .collect::<Vec<_>>()
421 .chunks(batch_size)
422 .map(|chunk| chunk.to_vec())
423 .collect();
424
425 let pool = rayon::ThreadPoolBuilder::new()
427 .num_threads(self.parallelism)
428 .build()
429 .map_err(|err| anyhow!("failed to build enrichment thread pool: {err}"))?;
430
431 let batch_results: Vec<std::collections::HashMap<u64, Vec<MemoryCard>>> =
432 pool.install(|| {
433 batches
434 .into_par_iter()
435 .enumerate()
436 .map(|(batch_idx, batch)| {
437 let non_empty: Vec<&EnrichmentContext> =
439 batch.iter().filter(|ctx| !ctx.text.is_empty()).collect();
440
441 if non_empty.is_empty() {
442 return batch.iter().map(|ctx| (ctx.frame_id, Vec::new())).collect();
444 }
445
446 if batch_idx > 0 && batch_idx % 10 == 0 {
448 info!("Enrichment progress: {} batches processed", batch_idx);
449 }
450
451 match Self::run_batched_inference_blocking(
452 &client, &api_key, &model, &non_empty,
453 ) {
454 Ok(output) => {
455 debug!(
456 "OpenAI batch output (batch {}): {}...",
457 batch_idx,
458 &output[..output.len().min(100)]
459 );
460 Self::parse_batched_output(&output, &non_empty)
461 }
462 Err(err) => {
463 warn!(
464 "OpenAI batch inference failed (batch {}): {}",
465 batch_idx, err
466 );
467 batch.iter().map(|ctx| (ctx.frame_id, Vec::new())).collect()
469 }
470 }
471 })
472 .collect()
473 });
474
475 let mut results: Vec<(u64, Vec<MemoryCard>)> = Vec::with_capacity(total);
477 for batch_map in batch_results {
478 for (frame_id, cards) in batch_map {
479 results.push((frame_id, cards));
480 }
481 }
482
483 info!(
484 "Parallel enrichment complete: {} frames processed in {} batches",
485 results.len(),
486 num_batches
487 );
488 Ok(results)
489 }
490}
491
492fn parse_memory_kind(s: &str) -> Option<MemoryKind> {
494 match s.to_lowercase().as_str() {
495 "fact" => Some(MemoryKind::Fact),
496 "preference" => Some(MemoryKind::Preference),
497 "event" => Some(MemoryKind::Event),
498 "profile" => Some(MemoryKind::Profile),
499 "relationship" => Some(MemoryKind::Relationship),
500 "other" => Some(MemoryKind::Other),
501 _ => None,
502 }
503}
504
505fn parse_polarity(s: &str) -> Polarity {
507 match s.to_lowercase().as_str() {
508 "positive" => Polarity::Positive,
509 "negative" => Polarity::Negative,
510 _ => Polarity::Neutral,
511 }
512}
513
514impl EnrichmentEngine for OpenAiEngine {
515 fn kind(&self) -> &str {
516 "openai:gpt-4o-mini"
517 }
518
519 fn version(&self) -> &str {
520 "1.0.0"
521 }
522
523 fn init(&mut self) -> memvid_core::Result<()> {
524 if self.api_key.is_empty() {
525 return Err(memvid_core::MemvidError::EmbeddingFailed {
526 reason: "OPENAI_API_KEY environment variable not set".into(),
527 });
528 }
529 let client = crate::http::blocking_client(Duration::from_secs(120)).map_err(|err| {
531 memvid_core::MemvidError::EmbeddingFailed {
532 reason: format!("Failed to create OpenAI HTTP client: {err}").into(),
533 }
534 })?;
535 self.client = Some(client);
536 self.ready = true;
537 Ok(())
538 }
539
540 fn is_ready(&self) -> bool {
541 self.ready
542 }
543
544 fn enrich(&self, ctx: &EnrichmentContext) -> EnrichmentResult {
545 if ctx.text.is_empty() {
546 return EnrichmentResult::empty();
547 }
548
549 let client = match self.client.as_ref() {
550 Some(client) => client,
551 None => {
552 return EnrichmentResult::failed(
553 "OpenAI engine not initialized (init() not called)".to_string(),
554 )
555 }
556 };
557
558 match Self::run_inference_blocking(client, &self.api_key, &self.model, &ctx.text) {
559 Ok(output) => {
560 debug!("OpenAI output for frame {}: {}", ctx.frame_id, output);
561 let cards = Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
562 EnrichmentResult::success(cards)
563 }
564 Err(err) => EnrichmentResult::failed(format!("OpenAI inference failed: {}", err)),
565 }
566 }
567}
568
impl Default for OpenAiEngine {
    /// Equivalent to [`OpenAiEngine::new`]: gpt-4o-mini with default
    /// parallelism and batch size.
    fn default() -> Self {
        Self::new()
    }
}