// memvid_cli/enrich/openai.rs

//! OpenAI-based enrichment engine using GPT-4o-mini.
//!
//! This engine uses the OpenAI API to extract structured memory cards
//! from text content. Supports parallel batch processing for speed.

6use anyhow::{anyhow, Result};
7use memvid_core::enrich::{EnrichmentContext, EnrichmentEngine, EnrichmentResult};
8use memvid_core::types::{MemoryCard, MemoryCardBuilder, MemoryKind, Polarity};
9use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
10use reqwest::blocking::Client;
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use tracing::{debug, info, warn};
14
/// Instruction prompt sent to the model ahead of each text chunk.
///
/// Asks the model for `MEMORY_START` .. `MEMORY_END` delimited blocks
/// (consumed by `OpenAiEngine::parse_output`) or the literal `MEMORY_NONE`
/// sentinel when there is nothing to extract. The field names and delimiter
/// strings here must stay in sync with the parser.
const EXTRACTION_PROMPT: &str = r#"You are a memory extraction assistant. Extract structured facts from the text.

For each distinct fact, preference, event, or relationship mentioned, output a memory card in this exact format:
MEMORY_START
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about, use "user" for the human in the conversation>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If there are no clear facts to extract, output MEMORY_NONE.

Extract memories from this text:
"#;
32
/// One message in the chat-completions request payload.
#[derive(Debug, Serialize, Clone)]
struct ChatMessage {
    /// Speaker role; this engine only sends "user".
    role: String,
    /// Message text (prompt + document content).
    content: String,
}
39
/// Request body for the OpenAI chat-completions endpoint.
/// Field names match the API's JSON schema via serde.
#[derive(Debug, Serialize)]
struct ChatRequest {
    /// Model identifier, e.g. "gpt-4o-mini".
    model: String,
    /// Conversation history; a single user message here.
    messages: Vec<ChatMessage>,
    /// Cap on tokens generated for the completion.
    max_tokens: u32,
    /// Sampling temperature; 0.0 is used for deterministic extraction.
    temperature: f32,
}
48
/// Top-level response body from the chat-completions endpoint.
/// Only the fields this engine reads are modeled; serde ignores the rest.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    /// Completion candidates; only the first is used.
    choices: Vec<ChatChoice>,
}
54
/// A single completion candidate inside a `ChatResponse`.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    /// The generated assistant message.
    message: ChatMessageResponse,
}
59
/// The assistant message payload within a `ChatChoice`.
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    /// Raw text produced by the model.
    content: String,
}
64
/// OpenAI enrichment engine using GPT-4o-mini with parallel processing.
pub struct OpenAiEngine {
    /// API key read from `OPENAI_API_KEY`; may be empty until `init()`
    /// validates it.
    api_key: String,
    /// Model identifier sent with every request.
    model: String,
    /// Whether `init()` has run and validated the API key.
    ready: bool,
    /// Number of parallel workers for batch enrichment (default: 20).
    parallelism: usize,
}
76
77impl OpenAiEngine {
78    /// Create a new OpenAI engine.
79    pub fn new() -> Self {
80        let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
81        Self {
82            api_key,
83            model: "gpt-4o-mini".to_string(),
84            ready: false,
85            parallelism: 20, // 20 concurrent requests
86        }
87    }
88
89    /// Create with a specific model.
90    pub fn with_model(model: &str) -> Self {
91        let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
92        Self {
93            api_key,
94            model: model.to_string(),
95            ready: false,
96            parallelism: 20,
97        }
98    }
99
100    /// Set parallelism level.
101    pub fn with_parallelism(mut self, n: usize) -> Self {
102        self.parallelism = n;
103        self
104    }
105
106    /// Run inference via OpenAI API (blocking, thread-safe).
107    fn run_inference_blocking(api_key: &str, model: &str, text: &str) -> Result<String> {
108        let client = Client::new();
109        let prompt = format!("{}\n\n{}", EXTRACTION_PROMPT, text);
110
111        let request = ChatRequest {
112            model: model.to_string(),
113            messages: vec![ChatMessage {
114                role: "user".to_string(),
115                content: prompt,
116            }],
117            max_tokens: 1024,
118            temperature: 0.0,
119        };
120
121        let response = client
122            .post("https://api.openai.com/v1/chat/completions")
123            .header("Authorization", format!("Bearer {}", api_key))
124            .header("Content-Type", "application/json")
125            .json(&request)
126            .send()
127            .map_err(|e| anyhow!("OpenAI API request failed: {}", e))?;
128
129        if !response.status().is_success() {
130            let status = response.status();
131            let body = response.text().unwrap_or_default();
132            return Err(anyhow!("OpenAI API error {}: {}", status, body));
133        }
134
135        let chat_response: ChatResponse = response
136            .json()
137            .map_err(|e| anyhow!("Failed to parse OpenAI response: {}", e))?;
138
139        chat_response
140            .choices
141            .first()
142            .map(|c| c.message.content.clone())
143            .ok_or_else(|| anyhow!("No response from OpenAI"))
144    }
145
146    /// Parse the LLM output into memory cards.
147    fn parse_output(output: &str, frame_id: u64, uri: &str, timestamp: i64) -> Vec<MemoryCard> {
148        let mut cards = Vec::new();
149
150        // Check for "no memories" signal
151        if output.contains("MEMORY_NONE") {
152            return cards;
153        }
154
155        // Parse MEMORY_START...MEMORY_END blocks
156        for block in output.split("MEMORY_START") {
157            let block = block.trim();
158            if block.is_empty() || !block.contains("MEMORY_END") {
159                continue;
160            }
161
162            let block = block.split("MEMORY_END").next().unwrap_or("").trim();
163
164            // Parse fields
165            let mut kind = None;
166            let mut entity = None;
167            let mut slot = None;
168            let mut value = None;
169            let mut polarity = Polarity::Neutral;
170
171            for line in block.lines() {
172                let line = line.trim();
173                if let Some(rest) = line.strip_prefix("kind:") {
174                    kind = parse_memory_kind(rest.trim());
175                } else if let Some(rest) = line.strip_prefix("entity:") {
176                    entity = Some(rest.trim().to_string());
177                } else if let Some(rest) = line.strip_prefix("slot:") {
178                    slot = Some(rest.trim().to_string());
179                } else if let Some(rest) = line.strip_prefix("value:") {
180                    value = Some(rest.trim().to_string());
181                } else if let Some(rest) = line.strip_prefix("polarity:") {
182                    polarity = parse_polarity(rest.trim());
183                }
184            }
185
186            // Build memory card if we have required fields
187            if let (Some(k), Some(e), Some(s), Some(v)) = (kind, entity, slot, value) {
188                if !e.is_empty() && !s.is_empty() && !v.is_empty() {
189                    match MemoryCardBuilder::new()
190                        .kind(k)
191                        .entity(&e)
192                        .slot(&s)
193                        .value(&v)
194                        .polarity(polarity)
195                        .source(frame_id, Some(uri.to_string()))
196                        .document_date(timestamp)
197                        .engine("openai:gpt-4o-mini", "1.0.0")
198                        .build(0)
199                    {
200                        Ok(card) => cards.push(card),
201                        Err(err) => {
202                            warn!("Failed to build memory card: {}", err);
203                        }
204                    }
205                }
206            }
207        }
208
209        cards
210    }
211
212    /// Process multiple frames in parallel and return all cards.
213    /// This is the key method for fast enrichment.
214    pub fn enrich_batch(&self, contexts: Vec<EnrichmentContext>) -> Vec<(u64, Vec<MemoryCard>)> {
215        let api_key = Arc::new(self.api_key.clone());
216        let model = Arc::new(self.model.clone());
217        let total = contexts.len();
218
219        info!(
220            "Starting parallel enrichment of {} frames with {} workers",
221            total, self.parallelism
222        );
223
224        // Use rayon for parallel processing
225        let pool = rayon::ThreadPoolBuilder::new()
226            .num_threads(self.parallelism)
227            .build()
228            .unwrap();
229
230        let results: Vec<(u64, Vec<MemoryCard>)> = pool.install(|| {
231            contexts
232                .into_par_iter()
233                .enumerate()
234                .map(|(i, ctx)| {
235                    if ctx.text.is_empty() {
236                        return (ctx.frame_id, vec![]);
237                    }
238
239                    // Progress logging every 50 frames
240                    if i > 0 && i % 50 == 0 {
241                        info!("Enrichment progress: {}/{} frames", i, total);
242                    }
243
244                    match Self::run_inference_blocking(&api_key, &model, &ctx.text) {
245                        Ok(output) => {
246                            debug!(
247                                "OpenAI output for frame {}: {}",
248                                ctx.frame_id,
249                                &output[..output.len().min(100)]
250                            );
251                            let cards =
252                                Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
253                            (ctx.frame_id, cards)
254                        }
255                        Err(err) => {
256                            warn!(
257                                "OpenAI inference failed for frame {}: {}",
258                                ctx.frame_id, err
259                            );
260                            (ctx.frame_id, vec![])
261                        }
262                    }
263                })
264                .collect()
265        });
266
267        info!(
268            "Parallel enrichment complete: {} frames processed",
269            results.len()
270        );
271        results
272    }
273}
274
275/// Parse a memory kind string into the enum.
276fn parse_memory_kind(s: &str) -> Option<MemoryKind> {
277    match s.to_lowercase().as_str() {
278        "fact" => Some(MemoryKind::Fact),
279        "preference" => Some(MemoryKind::Preference),
280        "event" => Some(MemoryKind::Event),
281        "profile" => Some(MemoryKind::Profile),
282        "relationship" => Some(MemoryKind::Relationship),
283        "other" => Some(MemoryKind::Other),
284        _ => None,
285    }
286}
287
288/// Parse a polarity string into the enum.
289fn parse_polarity(s: &str) -> Polarity {
290    match s.to_lowercase().as_str() {
291        "positive" => Polarity::Positive,
292        "negative" => Polarity::Negative,
293        _ => Polarity::Neutral,
294    }
295}
296
297impl EnrichmentEngine for OpenAiEngine {
298    fn kind(&self) -> &str {
299        "openai:gpt-4o-mini"
300    }
301
302    fn version(&self) -> &str {
303        "1.0.0"
304    }
305
306    fn init(&mut self) -> memvid_core::Result<()> {
307        if self.api_key.is_empty() {
308            return Err(memvid_core::MemvidError::EmbeddingFailed {
309                reason: "OPENAI_API_KEY environment variable not set".into(),
310            });
311        }
312        self.ready = true;
313        Ok(())
314    }
315
316    fn is_ready(&self) -> bool {
317        self.ready
318    }
319
320    fn enrich(&self, ctx: &EnrichmentContext) -> EnrichmentResult {
321        if ctx.text.is_empty() {
322            return EnrichmentResult::empty();
323        }
324
325        match Self::run_inference_blocking(&self.api_key, &self.model, &ctx.text) {
326            Ok(output) => {
327                debug!("OpenAI output for frame {}: {}", ctx.frame_id, output);
328                let cards = Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
329                EnrichmentResult::success(cards)
330            }
331            Err(err) => EnrichmentResult::failed(format!("OpenAI inference failed: {}", err)),
332        }
333    }
334}
335
336impl Default for OpenAiEngine {
337    fn default() -> Self {
338        Self::new()
339    }
340}