// memvid_cli/enrich/openai.rs
1//! OpenAI-based enrichment engine using GPT-4o-mini.
2//!
3//! This engine uses the OpenAI API to extract structured memory cards
4//! from text content. Supports parallel batch processing for speed.
5
6use anyhow::{anyhow, Result};
7use memvid_core::enrich::{EnrichmentContext, EnrichmentEngine, EnrichmentResult};
8use memvid_core::types::{MemoryCard, MemoryCardBuilder, MemoryKind, Polarity};
9use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
10use reqwest::blocking::Client;
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use std::time::Duration;
14use tracing::{debug, info, warn};
15
/// Prompt for single-frame extraction.
///
/// Instructs the model to emit `MEMORY_START`/`MEMORY_END` blocks with the
/// exact `kind:`/`entity:`/`slot:`/`value:`/`polarity:` fields that
/// `parse_output` scans for, or `MEMORY_NONE` when nothing is extractable.
/// The frame's text is appended after this prompt at call time.
const EXTRACTION_PROMPT: &str = r#"You are a memory extraction assistant. Extract structured facts from the text.

For each distinct fact, preference, event, or relationship mentioned, output a memory card in this exact format:
MEMORY_START
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about, use "user" for the human in the conversation>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If there are no clear facts to extract, output MEMORY_NONE.

Extract memories from this text:
"#;
33
/// Prompt for batched extraction (multiple frames per API call).
///
/// Same card format as `EXTRACTION_PROMPT`, plus a `frame_id:` field so
/// `parse_batched_output` can route each card back to its source frame.
/// Frame texts are appended after this prompt, each under a
/// `=== FRAME_ID: n ===` marker.
const BATCH_EXTRACTION_PROMPT: &str = r#"You are a memory extraction assistant. Extract structured facts from multiple text blocks.

Each text block is labeled with a FRAME_ID. For each distinct fact in each block, output a memory card with the frame_id field:

MEMORY_START
frame_id: <the FRAME_ID of the source text>
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about, use "user" for the human in the conversation>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If a text block has no facts, output MEMORY_NONE with its frame_id.

Process these text blocks:
"#;
53
/// One message in an OpenAI chat-completions request body.
#[derive(Debug, Serialize, Clone)]
struct ChatMessage {
    /// Speaker role; this engine only ever sends "user".
    role: String,
    /// Message text (extraction prompt plus frame content).
    content: String,
}
60
/// Request payload for `POST /v1/chat/completions`.
#[derive(Debug, Serialize)]
struct ChatRequest {
    /// Model name, e.g. "gpt-4o-mini".
    model: String,
    /// Conversation history; this engine sends a single user message.
    messages: Vec<ChatMessage>,
    /// Completion-token cap (1024 for single frames, scaled and capped at
    /// 4096 for batched calls).
    max_tokens: u32,
    /// Sampling temperature; always 0.0 here for deterministic extraction.
    temperature: f32,
}
69
/// Top-level OpenAI chat-completions response (only the fields we read).
#[derive(Debug, Deserialize)]
struct ChatResponse {
    /// Candidate completions; only the first is used.
    choices: Vec<ChatChoice>,
}
75
/// One completion choice inside a `ChatResponse`.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    /// The assistant message for this choice.
    message: ChatMessageResponse,
}
80
/// Assistant message payload inside a `ChatChoice`.
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    /// Raw model output text, parsed later for MEMORY_* blocks.
    content: String,
}
85
/// OpenAI enrichment engine using GPT-4o-mini with parallel processing.
///
/// Frames are grouped into batches of `batch_size` and dispatched across
/// `parallelism` worker threads, so one run makes roughly
/// `frames / batch_size` API calls.
pub struct OpenAiEngine {
    /// API key, read from `OPENAI_API_KEY` at construction time.
    /// May be empty until `init` validates it.
    api_key: String,
    /// Model to use (default: "gpt-4o-mini")
    model: String,
    /// Whether `init` has completed successfully
    ready: bool,
    /// Number of parallel workers (default: 20, as set in `new`)
    parallelism: usize,
    /// Number of frames to batch per API call (default: 10)
    batch_size: usize,
    /// Shared HTTP client (built in `init`; `None` before then)
    client: Option<Client>,
}
101
102impl OpenAiEngine {
103    /// Create a new OpenAI engine.
104    pub fn new() -> Self {
105        let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
106        Self {
107            api_key,
108            model: "gpt-4o-mini".to_string(),
109            ready: false,
110            parallelism: 20, // 20 concurrent requests (balanced for API rate limits)
111            batch_size: 10,  // 10 frames per API call
112            client: None,
113        }
114    }
115
116    /// Create with a specific model.
117    pub fn with_model(model: &str) -> Self {
118        let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
119        Self {
120            api_key,
121            model: model.to_string(),
122            ready: false,
123            parallelism: 20,
124            batch_size: 10,
125            client: None,
126        }
127    }
128
129    /// Set parallelism level.
130    pub fn with_parallelism(mut self, n: usize) -> Self {
131        self.parallelism = n;
132        self
133    }
134
135    /// Set batch size (number of frames per API call).
136    pub fn with_batch_size(mut self, n: usize) -> Self {
137        self.batch_size = n.max(1); // At least 1
138        self
139    }
140
141    /// Run inference via OpenAI API (blocking, thread-safe).
142    fn run_inference_blocking(
143        client: &Client,
144        api_key: &str,
145        model: &str,
146        text: &str,
147    ) -> Result<String> {
148        let prompt = format!("{}\n\n{}", EXTRACTION_PROMPT, text);
149
150        let request = ChatRequest {
151            model: model.to_string(),
152            messages: vec![ChatMessage {
153                role: "user".to_string(),
154                content: prompt,
155            }],
156            max_tokens: 1024,
157            temperature: 0.0,
158        };
159
160        let response = client
161            .post("https://api.openai.com/v1/chat/completions")
162            .header("Authorization", format!("Bearer {}", api_key))
163            .header("Content-Type", "application/json")
164            .json(&request)
165            .send()
166            .map_err(|e| anyhow!("OpenAI API request failed: {}", e))?;
167
168        if !response.status().is_success() {
169            let status = response.status();
170            let body = response.text().unwrap_or_default();
171            return Err(anyhow!("OpenAI API error {}: {}", status, body));
172        }
173
174        let chat_response: ChatResponse = response
175            .json()
176            .map_err(|e| anyhow!("Failed to parse OpenAI response: {}", e))?;
177
178        chat_response
179            .choices
180            .first()
181            .map(|c| c.message.content.clone())
182            .ok_or_else(|| anyhow!("No response from OpenAI"))
183    }
184
185    /// Parse the LLM output into memory cards.
186    fn parse_output(output: &str, frame_id: u64, uri: &str, timestamp: i64) -> Vec<MemoryCard> {
187        let mut cards = Vec::new();
188
189        // Check for "no memories" signal
190        if output.contains("MEMORY_NONE") {
191            return cards;
192        }
193
194        // Parse MEMORY_START...MEMORY_END blocks
195        for block in output.split("MEMORY_START") {
196            let block = block.trim();
197            if block.is_empty() || !block.contains("MEMORY_END") {
198                continue;
199            }
200
201            let block = block.split("MEMORY_END").next().unwrap_or("").trim();
202
203            // Parse fields
204            let mut kind = None;
205            let mut entity = None;
206            let mut slot = None;
207            let mut value = None;
208            let mut polarity = Polarity::Neutral;
209
210            for line in block.lines() {
211                let line = line.trim();
212                if let Some(rest) = line.strip_prefix("kind:") {
213                    kind = parse_memory_kind(rest.trim());
214                } else if let Some(rest) = line.strip_prefix("entity:") {
215                    entity = Some(rest.trim().to_string());
216                } else if let Some(rest) = line.strip_prefix("slot:") {
217                    slot = Some(rest.trim().to_string());
218                } else if let Some(rest) = line.strip_prefix("value:") {
219                    value = Some(rest.trim().to_string());
220                } else if let Some(rest) = line.strip_prefix("polarity:") {
221                    polarity = parse_polarity(rest.trim());
222                }
223            }
224
225            // Build memory card if we have required fields
226            if let (Some(k), Some(e), Some(s), Some(v)) = (kind, entity, slot, value) {
227                if !e.is_empty() && !s.is_empty() && !v.is_empty() {
228                    match MemoryCardBuilder::new()
229                        .kind(k)
230                        .entity(&e)
231                        .slot(&s)
232                        .value(&v)
233                        .polarity(polarity)
234                        .source(frame_id, Some(uri.to_string()))
235                        .document_date(timestamp)
236                        .engine("openai:gpt-4o-mini", "1.0.0")
237                        .build(0)
238                    {
239                        Ok(card) => cards.push(card),
240                        Err(err) => {
241                            warn!("Failed to build memory card: {}", err);
242                        }
243                    }
244                }
245            }
246        }
247
248        cards
249    }
250
251    /// Run batched inference for multiple frames in a single API call.
252    fn run_batched_inference_blocking(
253        client: &Client,
254        api_key: &str,
255        model: &str,
256        contexts: &[&EnrichmentContext],
257    ) -> Result<String> {
258        // Build the batched prompt with frame markers
259        let mut prompt = BATCH_EXTRACTION_PROMPT.to_string();
260        for ctx in contexts {
261            prompt.push_str(&format!(
262                "\n\n=== FRAME_ID: {} ===\n{}",
263                ctx.frame_id, ctx.text
264            ));
265        }
266
267        // Use larger max_tokens for batched requests
268        let max_tokens = 1024 + (contexts.len() as u32 * 512);
269
270        let request = ChatRequest {
271            model: model.to_string(),
272            messages: vec![ChatMessage {
273                role: "user".to_string(),
274                content: prompt,
275            }],
276            max_tokens: max_tokens.min(4096), // Cap at 4096
277            temperature: 0.0,
278        };
279
280        let response = client
281            .post("https://api.openai.com/v1/chat/completions")
282            .header("Authorization", format!("Bearer {}", api_key))
283            .header("Content-Type", "application/json")
284            .json(&request)
285            .send()
286            .map_err(|e| anyhow!("OpenAI API request failed: {}", e))?;
287
288        if !response.status().is_success() {
289            let status = response.status();
290            let body = response.text().unwrap_or_default();
291            return Err(anyhow!("OpenAI API error {}: {}", status, body));
292        }
293
294        let chat_response: ChatResponse = response
295            .json()
296            .map_err(|e| anyhow!("Failed to parse OpenAI response: {}", e))?;
297
298        chat_response
299            .choices
300            .first()
301            .map(|c| c.message.content.clone())
302            .ok_or_else(|| anyhow!("No response from OpenAI"))
303    }
304
305    /// Parse batched LLM output into memory cards grouped by frame_id.
306    fn parse_batched_output(
307        output: &str,
308        contexts: &[&EnrichmentContext],
309    ) -> std::collections::HashMap<u64, Vec<MemoryCard>> {
310        let mut results: std::collections::HashMap<u64, Vec<MemoryCard>> =
311            std::collections::HashMap::new();
312
313        // Initialize empty results for all frames
314        for ctx in contexts {
315            results.insert(ctx.frame_id, Vec::new());
316        }
317
318        // Build a lookup for context metadata
319        let ctx_lookup: std::collections::HashMap<u64, &EnrichmentContext> =
320            contexts.iter().map(|c| (c.frame_id, *c)).collect();
321
322        // Parse MEMORY_START...MEMORY_END blocks
323        for block in output.split("MEMORY_START") {
324            let block = block.trim();
325            if block.is_empty() || !block.contains("MEMORY_END") {
326                continue;
327            }
328
329            let block = block.split("MEMORY_END").next().unwrap_or("").trim();
330
331            // Parse fields including frame_id
332            let mut frame_id: Option<u64> = None;
333            let mut kind = None;
334            let mut entity = None;
335            let mut slot = None;
336            let mut value = None;
337            let mut polarity = Polarity::Neutral;
338
339            for line in block.lines() {
340                let line = line.trim();
341                if let Some(rest) = line.strip_prefix("frame_id:") {
342                    frame_id = rest.trim().parse().ok();
343                } else if let Some(rest) = line.strip_prefix("kind:") {
344                    kind = parse_memory_kind(rest.trim());
345                } else if let Some(rest) = line.strip_prefix("entity:") {
346                    entity = Some(rest.trim().to_string());
347                } else if let Some(rest) = line.strip_prefix("slot:") {
348                    slot = Some(rest.trim().to_string());
349                } else if let Some(rest) = line.strip_prefix("value:") {
350                    value = Some(rest.trim().to_string());
351                } else if let Some(rest) = line.strip_prefix("polarity:") {
352                    polarity = parse_polarity(rest.trim());
353                }
354            }
355
356            // Build memory card if we have required fields
357            if let (Some(fid), Some(k), Some(e), Some(s), Some(v)) =
358                (frame_id, kind, entity, slot, value)
359            {
360                if let Some(ctx) = ctx_lookup.get(&fid) {
361                    let uri = &ctx.uri;
362                    let timestamp = ctx.timestamp;
363
364                    if !e.is_empty() && !s.is_empty() && !v.is_empty() {
365                        match MemoryCardBuilder::new()
366                            .kind(k)
367                            .entity(&e)
368                            .slot(&s)
369                            .value(&v)
370                            .polarity(polarity)
371                            .source(fid, Some(uri.to_string()))
372                            .document_date(timestamp)
373                            .engine("openai:gpt-4o-mini", "1.0.0")
374                            .build(0)
375                        {
376                            Ok(card) => {
377                                results.entry(fid).or_default().push(card);
378                            }
379                            Err(err) => {
380                                warn!("Failed to build memory card: {}", err);
381                            }
382                        }
383                    }
384                }
385            }
386        }
387
388        results
389    }
390
391    /// Process multiple frames in parallel and return all cards.
392    /// This is the key method for fast enrichment.
393    /// Uses batching to reduce API calls: batch_size frames per API call.
394    pub fn enrich_batch(
395        &self,
396        contexts: Vec<EnrichmentContext>,
397    ) -> Result<Vec<(u64, Vec<MemoryCard>)>> {
398        let client = self
399            .client
400            .as_ref()
401            .ok_or_else(|| anyhow!("OpenAI engine not initialized (init() not called)"))?
402            .clone();
403        let client = Arc::new(client);
404        let api_key = Arc::new(self.api_key.clone());
405        let model = Arc::new(self.model.clone());
406        let total = contexts.len();
407        let batch_size = self.batch_size;
408
409        // Calculate number of batches
410        let num_batches = (total + batch_size - 1) / batch_size;
411
412        info!(
413            "Starting parallel enrichment of {} frames with {} workers, {} frames per batch ({} batches)",
414            total, self.parallelism, batch_size, num_batches
415        );
416
417        // Create batches of contexts
418        let batches: Vec<Vec<EnrichmentContext>> = contexts
419            .into_iter()
420            .collect::<Vec<_>>()
421            .chunks(batch_size)
422            .map(|chunk| chunk.to_vec())
423            .collect();
424
425        // Use rayon for parallel processing of batches
426        let pool = rayon::ThreadPoolBuilder::new()
427            .num_threads(self.parallelism)
428            .build()
429            .map_err(|err| anyhow!("failed to build enrichment thread pool: {err}"))?;
430
431        let batch_results: Vec<std::collections::HashMap<u64, Vec<MemoryCard>>> =
432            pool.install(|| {
433                batches
434                    .into_par_iter()
435                    .enumerate()
436                    .map(|(batch_idx, batch)| {
437                        // Filter out empty texts
438                        let non_empty: Vec<&EnrichmentContext> =
439                            batch.iter().filter(|ctx| !ctx.text.is_empty()).collect();
440
441                        if non_empty.is_empty() {
442                            // Return empty results for all frames in this batch
443                            return batch.iter().map(|ctx| (ctx.frame_id, Vec::new())).collect();
444                        }
445
446                        // Progress logging every 10 batches
447                        if batch_idx > 0 && batch_idx % 10 == 0 {
448                            info!("Enrichment progress: {} batches processed", batch_idx);
449                        }
450
451                        match Self::run_batched_inference_blocking(
452                            &client, &api_key, &model, &non_empty,
453                        ) {
454                            Ok(output) => {
455                                debug!(
456                                    "OpenAI batch output (batch {}): {}...",
457                                    batch_idx,
458                                    &output[..output.len().min(100)]
459                                );
460                                Self::parse_batched_output(&output, &non_empty)
461                            }
462                            Err(err) => {
463                                warn!(
464                                    "OpenAI batch inference failed (batch {}): {}",
465                                    batch_idx, err
466                                );
467                                // Return empty results for all frames in this batch
468                                batch.iter().map(|ctx| (ctx.frame_id, Vec::new())).collect()
469                            }
470                        }
471                    })
472                    .collect()
473            });
474
475        // Flatten batch results into a single vec
476        let mut results: Vec<(u64, Vec<MemoryCard>)> = Vec::with_capacity(total);
477        for batch_map in batch_results {
478            for (frame_id, cards) in batch_map {
479                results.push((frame_id, cards));
480            }
481        }
482
483        info!(
484            "Parallel enrichment complete: {} frames processed in {} batches",
485            results.len(),
486            num_batches
487        );
488        Ok(results)
489    }
490}
491
492/// Parse a memory kind string into the enum.
493fn parse_memory_kind(s: &str) -> Option<MemoryKind> {
494    match s.to_lowercase().as_str() {
495        "fact" => Some(MemoryKind::Fact),
496        "preference" => Some(MemoryKind::Preference),
497        "event" => Some(MemoryKind::Event),
498        "profile" => Some(MemoryKind::Profile),
499        "relationship" => Some(MemoryKind::Relationship),
500        "other" => Some(MemoryKind::Other),
501        _ => None,
502    }
503}
504
505/// Parse a polarity string into the enum.
506fn parse_polarity(s: &str) -> Polarity {
507    match s.to_lowercase().as_str() {
508        "positive" => Polarity::Positive,
509        "negative" => Polarity::Negative,
510        _ => Polarity::Neutral,
511    }
512}
513
514impl EnrichmentEngine for OpenAiEngine {
515    fn kind(&self) -> &str {
516        "openai:gpt-4o-mini"
517    }
518
519    fn version(&self) -> &str {
520        "1.0.0"
521    }
522
523    fn init(&mut self) -> memvid_core::Result<()> {
524        if self.api_key.is_empty() {
525            return Err(memvid_core::MemvidError::EmbeddingFailed {
526                reason: "OPENAI_API_KEY environment variable not set".into(),
527            });
528        }
529        // Use longer timeout for batched requests (multiple frames per call)
530        let client = crate::http::blocking_client(Duration::from_secs(120)).map_err(|err| {
531            memvid_core::MemvidError::EmbeddingFailed {
532                reason: format!("Failed to create OpenAI HTTP client: {err}").into(),
533            }
534        })?;
535        self.client = Some(client);
536        self.ready = true;
537        Ok(())
538    }
539
540    fn is_ready(&self) -> bool {
541        self.ready
542    }
543
544    fn enrich(&self, ctx: &EnrichmentContext) -> EnrichmentResult {
545        if ctx.text.is_empty() {
546            return EnrichmentResult::empty();
547        }
548
549        let client = match self.client.as_ref() {
550            Some(client) => client,
551            None => {
552                return EnrichmentResult::failed(
553                    "OpenAI engine not initialized (init() not called)".to_string(),
554                )
555            }
556        };
557
558        match Self::run_inference_blocking(client, &self.api_key, &self.model, &ctx.text) {
559            Ok(output) => {
560                debug!("OpenAI output for frame {}: {}", ctx.frame_id, output);
561                let cards = Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
562                EnrichmentResult::success(cards)
563            }
564            Err(err) => EnrichmentResult::failed(format!("OpenAI inference failed: {}", err)),
565        }
566    }
567}
568
569impl Default for OpenAiEngine {
570    fn default() -> Self {
571        Self::new()
572    }
573}