// memvid_cli/enrich/openai.rs

//! OpenAI-based enrichment engine using GPT-4o-mini.
//!
//! This engine uses the OpenAI API to extract structured memory cards
//! from text content. Supports parallel batch processing for speed.

6use anyhow::{anyhow, Result};
7use memvid_core::enrich::{EnrichmentContext, EnrichmentEngine, EnrichmentResult};
8use memvid_core::types::{MemoryCard, MemoryCardBuilder, MemoryKind, Polarity};
9use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
10use reqwest::blocking::Client;
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use std::time::Duration;
14use tracing::{debug, info, warn};
15
16/// The extraction prompt for GPT-4o-mini (single frame)
/// The extraction prompt for GPT-4o-mini (single frame).
///
/// Asks the model to emit one `MEMORY_START`/`MEMORY_END` block per
/// extracted fact, or the literal `MEMORY_NONE` when nothing is stated.
/// `parse_output` depends on these exact marker strings and field names.
const EXTRACTION_PROMPT: &str = r#"You are a memory extraction assistant. Extract structured facts from the text.

For each distinct fact, preference, event, or relationship mentioned, output a memory card in this exact format:
MEMORY_START
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about, use "user" for the human in the conversation>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If there are no clear facts to extract, output MEMORY_NONE.

Extract memories from this text:
"#;
33
34/// The extraction prompt for batched frames (multiple frames per API call)
/// The extraction prompt for batched frames (multiple frames per API call).
///
/// Same contract as `EXTRACTION_PROMPT`, plus a `frame_id:` field so
/// `parse_batched_output` can attribute each card back to its source frame.
const BATCH_EXTRACTION_PROMPT: &str = r#"You are a memory extraction assistant. Extract structured facts from multiple text blocks.

Each text block is labeled with a FRAME_ID. For each distinct fact in each block, output a memory card with the frame_id field:

MEMORY_START
frame_id: <the FRAME_ID of the source text>
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about, use "user" for the human in the conversation>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If a text block has no facts, output MEMORY_NONE with its frame_id.

Process these text blocks:
"#;
53
/// A single message in an OpenAI chat-completions request.
#[derive(Debug, Serialize, Clone)]
struct ChatMessage {
    /// Message role; this engine only sends "user".
    role: String,
    /// Message text content (prompt plus frame text).
    content: String,
}
60
/// OpenAI chat-completions API request body (only the fields this engine sets).
#[derive(Debug, Serialize)]
struct ChatRequest {
    /// Model identifier, e.g. "gpt-4o-mini".
    model: String,
    /// Conversation messages; this engine always sends a single user message.
    messages: Vec<ChatMessage>,
    /// Upper bound on completion tokens.
    max_tokens: u32,
    /// Sampling temperature; 0.0 for deterministic extraction.
    temperature: f32,
}
69
/// OpenAI chat-completions API response body (only the fields we read).
#[derive(Debug, Deserialize)]
struct ChatResponse {
    /// Completion choices; only the first is consumed.
    choices: Vec<ChatChoice>,
}
75
/// A single completion choice in an OpenAI response.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    /// The assistant's reply for this choice.
    message: ChatMessageResponse,
}
80
/// The assistant message payload inside a completion choice.
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    /// Raw model output text, parsed by `parse_output`/`parse_batched_output`.
    content: String,
}
85
/// OpenAI enrichment engine using GPT-4o-mini with parallel processing.
pub struct OpenAiEngine {
    /// API key, read from `OPENAI_API_KEY` at construction time.
    api_key: String,
    /// Model to use (e.g. "gpt-4o-mini").
    model: String,
    /// Whether `init` has completed successfully.
    ready: bool,
    /// Number of parallel workers (default: 20).
    parallelism: usize,
    /// Number of frames to batch per API call (default: 10).
    batch_size: usize,
    /// Shared HTTP client (built in `init`; `None` until then).
    client: Option<Client>,
}
101
102impl OpenAiEngine {
103    /// Create a new OpenAI engine.
104    pub fn new() -> Self {
105        let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
106        Self {
107            api_key,
108            model: "gpt-4o-mini".to_string(),
109            ready: false,
110            parallelism: 20,  // 20 concurrent requests (balanced for API rate limits)
111            batch_size: 10,   // 10 frames per API call
112            client: None,
113        }
114    }
115
116    /// Create with a specific model.
117    pub fn with_model(model: &str) -> Self {
118        let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
119        Self {
120            api_key,
121            model: model.to_string(),
122            ready: false,
123            parallelism: 20,
124            batch_size: 10,
125            client: None,
126        }
127    }
128
129    /// Set parallelism level.
130    pub fn with_parallelism(mut self, n: usize) -> Self {
131        self.parallelism = n;
132        self
133    }
134
135    /// Set batch size (number of frames per API call).
136    pub fn with_batch_size(mut self, n: usize) -> Self {
137        self.batch_size = n.max(1); // At least 1
138        self
139    }
140
141    /// Run inference via OpenAI API (blocking, thread-safe).
142    fn run_inference_blocking(client: &Client, api_key: &str, model: &str, text: &str) -> Result<String> {
143        let prompt = format!("{}\n\n{}", EXTRACTION_PROMPT, text);
144
145        let request = ChatRequest {
146            model: model.to_string(),
147            messages: vec![ChatMessage {
148                role: "user".to_string(),
149                content: prompt,
150            }],
151            max_tokens: 1024,
152            temperature: 0.0,
153        };
154
155        let response = client
156            .post("https://api.openai.com/v1/chat/completions")
157            .header("Authorization", format!("Bearer {}", api_key))
158            .header("Content-Type", "application/json")
159            .json(&request)
160            .send()
161            .map_err(|e| anyhow!("OpenAI API request failed: {}", e))?;
162
163        if !response.status().is_success() {
164            let status = response.status();
165            let body = response.text().unwrap_or_default();
166            return Err(anyhow!("OpenAI API error {}: {}", status, body));
167        }
168
169        let chat_response: ChatResponse = response
170            .json()
171            .map_err(|e| anyhow!("Failed to parse OpenAI response: {}", e))?;
172
173        chat_response
174            .choices
175            .first()
176            .map(|c| c.message.content.clone())
177            .ok_or_else(|| anyhow!("No response from OpenAI"))
178    }
179
180    /// Parse the LLM output into memory cards.
181    fn parse_output(output: &str, frame_id: u64, uri: &str, timestamp: i64) -> Vec<MemoryCard> {
182        let mut cards = Vec::new();
183
184        // Check for "no memories" signal
185        if output.contains("MEMORY_NONE") {
186            return cards;
187        }
188
189        // Parse MEMORY_START...MEMORY_END blocks
190        for block in output.split("MEMORY_START") {
191            let block = block.trim();
192            if block.is_empty() || !block.contains("MEMORY_END") {
193                continue;
194            }
195
196            let block = block.split("MEMORY_END").next().unwrap_or("").trim();
197
198            // Parse fields
199            let mut kind = None;
200            let mut entity = None;
201            let mut slot = None;
202            let mut value = None;
203            let mut polarity = Polarity::Neutral;
204
205            for line in block.lines() {
206                let line = line.trim();
207                if let Some(rest) = line.strip_prefix("kind:") {
208                    kind = parse_memory_kind(rest.trim());
209                } else if let Some(rest) = line.strip_prefix("entity:") {
210                    entity = Some(rest.trim().to_string());
211                } else if let Some(rest) = line.strip_prefix("slot:") {
212                    slot = Some(rest.trim().to_string());
213                } else if let Some(rest) = line.strip_prefix("value:") {
214                    value = Some(rest.trim().to_string());
215                } else if let Some(rest) = line.strip_prefix("polarity:") {
216                    polarity = parse_polarity(rest.trim());
217                }
218            }
219
220            // Build memory card if we have required fields
221            if let (Some(k), Some(e), Some(s), Some(v)) = (kind, entity, slot, value) {
222                if !e.is_empty() && !s.is_empty() && !v.is_empty() {
223                    match MemoryCardBuilder::new()
224                        .kind(k)
225                        .entity(&e)
226                        .slot(&s)
227                        .value(&v)
228                        .polarity(polarity)
229                        .source(frame_id, Some(uri.to_string()))
230                        .document_date(timestamp)
231                        .engine("openai:gpt-4o-mini", "1.0.0")
232                        .build(0)
233                    {
234                        Ok(card) => cards.push(card),
235                        Err(err) => {
236                            warn!("Failed to build memory card: {}", err);
237                        }
238                    }
239                }
240            }
241        }
242
243        cards
244    }
245
246    /// Run batched inference for multiple frames in a single API call.
247    fn run_batched_inference_blocking(
248        client: &Client,
249        api_key: &str,
250        model: &str,
251        contexts: &[&EnrichmentContext],
252    ) -> Result<String> {
253        // Build the batched prompt with frame markers
254        let mut prompt = BATCH_EXTRACTION_PROMPT.to_string();
255        for ctx in contexts {
256            prompt.push_str(&format!(
257                "\n\n=== FRAME_ID: {} ===\n{}",
258                ctx.frame_id, ctx.text
259            ));
260        }
261
262        // Use larger max_tokens for batched requests
263        let max_tokens = 1024 + (contexts.len() as u32 * 512);
264
265        let request = ChatRequest {
266            model: model.to_string(),
267            messages: vec![ChatMessage {
268                role: "user".to_string(),
269                content: prompt,
270            }],
271            max_tokens: max_tokens.min(4096), // Cap at 4096
272            temperature: 0.0,
273        };
274
275        let response = client
276            .post("https://api.openai.com/v1/chat/completions")
277            .header("Authorization", format!("Bearer {}", api_key))
278            .header("Content-Type", "application/json")
279            .json(&request)
280            .send()
281            .map_err(|e| anyhow!("OpenAI API request failed: {}", e))?;
282
283        if !response.status().is_success() {
284            let status = response.status();
285            let body = response.text().unwrap_or_default();
286            return Err(anyhow!("OpenAI API error {}: {}", status, body));
287        }
288
289        let chat_response: ChatResponse = response
290            .json()
291            .map_err(|e| anyhow!("Failed to parse OpenAI response: {}", e))?;
292
293        chat_response
294            .choices
295            .first()
296            .map(|c| c.message.content.clone())
297            .ok_or_else(|| anyhow!("No response from OpenAI"))
298    }
299
    /// Parse batched LLM output into memory cards grouped by frame_id.
    ///
    /// Every context passed in gets an entry in the returned map (possibly
    /// empty). Cards whose `frame_id` field is missing, unparseable, or not
    /// among `contexts` are dropped silently; cards with missing/empty
    /// required fields are skipped, and builder failures are logged.
    fn parse_batched_output(
        output: &str,
        contexts: &[&EnrichmentContext],
    ) -> std::collections::HashMap<u64, Vec<MemoryCard>> {
        let mut results: std::collections::HashMap<u64, Vec<MemoryCard>> = std::collections::HashMap::new();

        // Initialize empty results for all frames so callers can distinguish
        // "processed, nothing extracted" from "not processed".
        for ctx in contexts {
            results.insert(ctx.frame_id, Vec::new());
        }

        // Lookup from frame_id to its context, used to recover the source
        // URI and timestamp for each card.
        let ctx_lookup: std::collections::HashMap<u64, &EnrichmentContext> =
            contexts.iter().map(|c| (c.frame_id, *c)).collect();

        // Parse MEMORY_START...MEMORY_END blocks.
        for block in output.split("MEMORY_START") {
            let block = block.trim();
            // Skip the preamble before the first marker and unterminated blocks.
            if block.is_empty() || !block.contains("MEMORY_END") {
                continue;
            }

            let block = block.split("MEMORY_END").next().unwrap_or("").trim();

            // Parse fields including frame_id; polarity defaults to Neutral.
            let mut frame_id: Option<u64> = None;
            let mut kind = None;
            let mut entity = None;
            let mut slot = None;
            let mut value = None;
            let mut polarity = Polarity::Neutral;

            for line in block.lines() {
                let line = line.trim();
                if let Some(rest) = line.strip_prefix("frame_id:") {
                    frame_id = rest.trim().parse().ok();
                } else if let Some(rest) = line.strip_prefix("kind:") {
                    kind = parse_memory_kind(rest.trim());
                } else if let Some(rest) = line.strip_prefix("entity:") {
                    entity = Some(rest.trim().to_string());
                } else if let Some(rest) = line.strip_prefix("slot:") {
                    slot = Some(rest.trim().to_string());
                } else if let Some(rest) = line.strip_prefix("value:") {
                    value = Some(rest.trim().to_string());
                } else if let Some(rest) = line.strip_prefix("polarity:") {
                    polarity = parse_polarity(rest.trim());
                }
            }

            // Build a memory card only when all required fields are present
            // and the frame_id maps back to a known context.
            if let (Some(fid), Some(k), Some(e), Some(s), Some(v)) = (frame_id, kind, entity, slot, value) {
                if let Some(ctx) = ctx_lookup.get(&fid) {
                    let uri = &ctx.uri;
                    let timestamp = ctx.timestamp;

                    if !e.is_empty() && !s.is_empty() && !v.is_empty() {
                        // NOTE(review): engine tag is hard-coded even when a
                        // custom model was set via `with_model` — confirm intent.
                        match MemoryCardBuilder::new()
                            .kind(k)
                            .entity(&e)
                            .slot(&s)
                            .value(&v)
                            .polarity(polarity)
                            .source(fid, Some(uri.to_string()))
                            .document_date(timestamp)
                            .engine("openai:gpt-4o-mini", "1.0.0")
                            .build(0)
                        {
                            Ok(card) => {
                                results.entry(fid).or_default().push(card);
                            }
                            Err(err) => {
                                warn!("Failed to build memory card: {}", err);
                            }
                        }
                    }
                }
            }
        }

        results
    }
382
383    /// Process multiple frames in parallel and return all cards.
384    /// This is the key method for fast enrichment.
385    /// Uses batching to reduce API calls: batch_size frames per API call.
386    pub fn enrich_batch(
387        &self,
388        contexts: Vec<EnrichmentContext>,
389    ) -> Result<Vec<(u64, Vec<MemoryCard>)>> {
390        let client = self
391            .client
392            .as_ref()
393            .ok_or_else(|| anyhow!("OpenAI engine not initialized (init() not called)"))?
394            .clone();
395        let client = Arc::new(client);
396        let api_key = Arc::new(self.api_key.clone());
397        let model = Arc::new(self.model.clone());
398        let total = contexts.len();
399        let batch_size = self.batch_size;
400
401        // Calculate number of batches
402        let num_batches = (total + batch_size - 1) / batch_size;
403
404        info!(
405            "Starting parallel enrichment of {} frames with {} workers, {} frames per batch ({} batches)",
406            total, self.parallelism, batch_size, num_batches
407        );
408
409        // Create batches of contexts
410        let batches: Vec<Vec<EnrichmentContext>> = contexts
411            .into_iter()
412            .collect::<Vec<_>>()
413            .chunks(batch_size)
414            .map(|chunk| chunk.to_vec())
415            .collect();
416
417        // Use rayon for parallel processing of batches
418        let pool = rayon::ThreadPoolBuilder::new()
419            .num_threads(self.parallelism)
420            .build()
421            .map_err(|err| anyhow!("failed to build enrichment thread pool: {err}"))?;
422
423        let batch_results: Vec<std::collections::HashMap<u64, Vec<MemoryCard>>> = pool.install(|| {
424            batches
425                .into_par_iter()
426                .enumerate()
427                .map(|(batch_idx, batch)| {
428                    // Filter out empty texts
429                    let non_empty: Vec<&EnrichmentContext> = batch
430                        .iter()
431                        .filter(|ctx| !ctx.text.is_empty())
432                        .collect();
433
434                    if non_empty.is_empty() {
435                        // Return empty results for all frames in this batch
436                        return batch.iter().map(|ctx| (ctx.frame_id, Vec::new())).collect();
437                    }
438
439                    // Progress logging every 10 batches
440                    if batch_idx > 0 && batch_idx % 10 == 0 {
441                        info!("Enrichment progress: {} batches processed", batch_idx);
442                    }
443
444                    match Self::run_batched_inference_blocking(&client, &api_key, &model, &non_empty) {
445                        Ok(output) => {
446                            debug!(
447                                "OpenAI batch output (batch {}): {}...",
448                                batch_idx,
449                                &output[..output.len().min(100)]
450                            );
451                            Self::parse_batched_output(&output, &non_empty)
452                        }
453                        Err(err) => {
454                            warn!(
455                                "OpenAI batch inference failed (batch {}): {}",
456                                batch_idx, err
457                            );
458                            // Return empty results for all frames in this batch
459                            batch.iter().map(|ctx| (ctx.frame_id, Vec::new())).collect()
460                        }
461                    }
462                })
463                .collect()
464        });
465
466        // Flatten batch results into a single vec
467        let mut results: Vec<(u64, Vec<MemoryCard>)> = Vec::with_capacity(total);
468        for batch_map in batch_results {
469            for (frame_id, cards) in batch_map {
470                results.push((frame_id, cards));
471            }
472        }
473
474        info!(
475            "Parallel enrichment complete: {} frames processed in {} batches",
476            results.len(),
477            num_batches
478        );
479        Ok(results)
480    }
481}
482
483/// Parse a memory kind string into the enum.
484fn parse_memory_kind(s: &str) -> Option<MemoryKind> {
485    match s.to_lowercase().as_str() {
486        "fact" => Some(MemoryKind::Fact),
487        "preference" => Some(MemoryKind::Preference),
488        "event" => Some(MemoryKind::Event),
489        "profile" => Some(MemoryKind::Profile),
490        "relationship" => Some(MemoryKind::Relationship),
491        "other" => Some(MemoryKind::Other),
492        _ => None,
493    }
494}
495
496/// Parse a polarity string into the enum.
497fn parse_polarity(s: &str) -> Polarity {
498    match s.to_lowercase().as_str() {
499        "positive" => Polarity::Positive,
500        "negative" => Polarity::Negative,
501        _ => Polarity::Neutral,
502    }
503}
504
505impl EnrichmentEngine for OpenAiEngine {
506    fn kind(&self) -> &str {
507        "openai:gpt-4o-mini"
508    }
509
510    fn version(&self) -> &str {
511        "1.0.0"
512    }
513
514    fn init(&mut self) -> memvid_core::Result<()> {
515        if self.api_key.is_empty() {
516            return Err(memvid_core::MemvidError::EmbeddingFailed {
517                reason: "OPENAI_API_KEY environment variable not set".into(),
518            });
519        }
520        // Use longer timeout for batched requests (multiple frames per call)
521        let client =
522            crate::http::blocking_client(Duration::from_secs(120)).map_err(|err| {
523                memvid_core::MemvidError::EmbeddingFailed {
524                    reason: format!("Failed to create OpenAI HTTP client: {err}").into(),
525                }
526            })?;
527        self.client = Some(client);
528        self.ready = true;
529        Ok(())
530    }
531
532    fn is_ready(&self) -> bool {
533        self.ready
534    }
535
536    fn enrich(&self, ctx: &EnrichmentContext) -> EnrichmentResult {
537        if ctx.text.is_empty() {
538            return EnrichmentResult::empty();
539        }
540
541        let client = match self.client.as_ref() {
542            Some(client) => client,
543            None => {
544                return EnrichmentResult::failed(
545                    "OpenAI engine not initialized (init() not called)".to_string(),
546                )
547            }
548        };
549
550        match Self::run_inference_blocking(client, &self.api_key, &self.model, &ctx.text) {
551            Ok(output) => {
552                debug!("OpenAI output for frame {}: {}", ctx.frame_id, output);
553                let cards = Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
554                EnrichmentResult::success(cards)
555            }
556            Err(err) => EnrichmentResult::failed(format!("OpenAI inference failed: {}", err)),
557        }
558    }
559}
560
561impl Default for OpenAiEngine {
562    fn default() -> Self {
563        Self::new()
564    }
565}