memvid_cli/enrich/openai.rs

1//! OpenAI-based enrichment engine using GPT-4o-mini.
2//!
3//! This engine uses the OpenAI API to extract structured memory cards
4//! from text content. Supports parallel batch processing for speed.
5
6use anyhow::{anyhow, Result};
7use memvid_core::enrich::{EnrichmentContext, EnrichmentEngine, EnrichmentResult};
8use memvid_core::types::{MemoryCard, MemoryCardBuilder, MemoryKind, Polarity};
9use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
10use reqwest::blocking::Client;
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use std::time::Duration;
14use tracing::{debug, info, warn};
15
/// The extraction prompt for GPT-4o-mini.
///
/// Instructs the model to emit each fact as a `MEMORY_START`/`MEMORY_END`
/// block containing `kind`/`entity`/`slot`/`value`/`polarity` lines, or the
/// sentinel `MEMORY_NONE` when nothing is extractable. The output format is
/// consumed by `parse_output`; the frame text is appended after this prompt
/// in `run_inference_blocking`.
const EXTRACTION_PROMPT: &str = r#"You are a memory extraction assistant. Extract structured facts from the text.

For each distinct fact, preference, event, or relationship mentioned, output a memory card in this exact format:
MEMORY_START
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about, use "user" for the human in the conversation>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If there are no clear facts to extract, output MEMORY_NONE.

Extract memories from this text:
"#;
33
/// One message in the OpenAI chat-completion request body.
/// Field names match the wire format (`{"role": ..., "content": ...}`),
/// so they must not be renamed without a serde attribute.
#[derive(Debug, Serialize, Clone)]
struct ChatMessage {
    /// Sender role; this engine only sends "user".
    role: String,
    /// Message text (extraction prompt plus frame content).
    content: String,
}
40
/// Request body for the OpenAI `/v1/chat/completions` endpoint
/// (only the fields this engine sends).
#[derive(Debug, Serialize)]
struct ChatRequest {
    /// Model identifier, e.g. "gpt-4o-mini".
    model: String,
    /// Conversation messages; this engine sends a single user message.
    messages: Vec<ChatMessage>,
    /// Upper bound on completion length in tokens.
    max_tokens: u32,
    /// Sampling temperature; set to 0.0 here for deterministic extraction.
    temperature: f32,
}
49
/// Response body from the chat-completions endpoint. Only the fields this
/// engine reads are modeled; serde ignores unknown fields by default.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    /// Completion candidates; only the first is used.
    choices: Vec<ChatChoice>,
}

/// A single completion candidate within the response.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    message: ChatMessageResponse,
}

/// The assistant message carried by a completion candidate.
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    /// Raw model output text, expected to contain MEMORY_START/END blocks.
    content: String,
}
65
/// OpenAI enrichment engine using GPT-4o-mini with parallel processing.
///
/// Construct with `new`/`with_model`, then call `init()` (via the
/// `EnrichmentEngine` trait) before `enrich` or `enrich_batch` — the HTTP
/// client is only built there.
pub struct OpenAiEngine {
    /// API key, read from `OPENAI_API_KEY` at construction time.
    api_key: String,
    /// Model identifier sent in each request.
    model: String,
    /// Whether `init()` has completed successfully.
    ready: bool,
    /// Number of parallel workers for `enrich_batch` (default: 20).
    parallelism: usize,
    /// Shared HTTP client (built in `init`; `None` until then).
    client: Option<Client>,
}
79
80impl OpenAiEngine {
81    /// Create a new OpenAI engine.
82    pub fn new() -> Self {
83        let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
84        Self {
85            api_key,
86            model: "gpt-4o-mini".to_string(),
87            ready: false,
88            parallelism: 20, // 20 concurrent requests
89            client: None,
90        }
91    }
92
93    /// Create with a specific model.
94    pub fn with_model(model: &str) -> Self {
95        let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
96        Self {
97            api_key,
98            model: model.to_string(),
99            ready: false,
100            parallelism: 20,
101            client: None,
102        }
103    }
104
105    /// Set parallelism level.
106    pub fn with_parallelism(mut self, n: usize) -> Self {
107        self.parallelism = n;
108        self
109    }
110
111    /// Run inference via OpenAI API (blocking, thread-safe).
112    fn run_inference_blocking(client: &Client, api_key: &str, model: &str, text: &str) -> Result<String> {
113        let prompt = format!("{}\n\n{}", EXTRACTION_PROMPT, text);
114
115        let request = ChatRequest {
116            model: model.to_string(),
117            messages: vec![ChatMessage {
118                role: "user".to_string(),
119                content: prompt,
120            }],
121            max_tokens: 1024,
122            temperature: 0.0,
123        };
124
125        let response = client
126            .post("https://api.openai.com/v1/chat/completions")
127            .header("Authorization", format!("Bearer {}", api_key))
128            .header("Content-Type", "application/json")
129            .json(&request)
130            .send()
131            .map_err(|e| anyhow!("OpenAI API request failed: {}", e))?;
132
133        if !response.status().is_success() {
134            let status = response.status();
135            let body = response.text().unwrap_or_default();
136            return Err(anyhow!("OpenAI API error {}: {}", status, body));
137        }
138
139        let chat_response: ChatResponse = response
140            .json()
141            .map_err(|e| anyhow!("Failed to parse OpenAI response: {}", e))?;
142
143        chat_response
144            .choices
145            .first()
146            .map(|c| c.message.content.clone())
147            .ok_or_else(|| anyhow!("No response from OpenAI"))
148    }
149
150    /// Parse the LLM output into memory cards.
151    fn parse_output(output: &str, frame_id: u64, uri: &str, timestamp: i64) -> Vec<MemoryCard> {
152        let mut cards = Vec::new();
153
154        // Check for "no memories" signal
155        if output.contains("MEMORY_NONE") {
156            return cards;
157        }
158
159        // Parse MEMORY_START...MEMORY_END blocks
160        for block in output.split("MEMORY_START") {
161            let block = block.trim();
162            if block.is_empty() || !block.contains("MEMORY_END") {
163                continue;
164            }
165
166            let block = block.split("MEMORY_END").next().unwrap_or("").trim();
167
168            // Parse fields
169            let mut kind = None;
170            let mut entity = None;
171            let mut slot = None;
172            let mut value = None;
173            let mut polarity = Polarity::Neutral;
174
175            for line in block.lines() {
176                let line = line.trim();
177                if let Some(rest) = line.strip_prefix("kind:") {
178                    kind = parse_memory_kind(rest.trim());
179                } else if let Some(rest) = line.strip_prefix("entity:") {
180                    entity = Some(rest.trim().to_string());
181                } else if let Some(rest) = line.strip_prefix("slot:") {
182                    slot = Some(rest.trim().to_string());
183                } else if let Some(rest) = line.strip_prefix("value:") {
184                    value = Some(rest.trim().to_string());
185                } else if let Some(rest) = line.strip_prefix("polarity:") {
186                    polarity = parse_polarity(rest.trim());
187                }
188            }
189
190            // Build memory card if we have required fields
191            if let (Some(k), Some(e), Some(s), Some(v)) = (kind, entity, slot, value) {
192                if !e.is_empty() && !s.is_empty() && !v.is_empty() {
193                    match MemoryCardBuilder::new()
194                        .kind(k)
195                        .entity(&e)
196                        .slot(&s)
197                        .value(&v)
198                        .polarity(polarity)
199                        .source(frame_id, Some(uri.to_string()))
200                        .document_date(timestamp)
201                        .engine("openai:gpt-4o-mini", "1.0.0")
202                        .build(0)
203                    {
204                        Ok(card) => cards.push(card),
205                        Err(err) => {
206                            warn!("Failed to build memory card: {}", err);
207                        }
208                    }
209                }
210            }
211        }
212
213        cards
214    }
215
216    /// Process multiple frames in parallel and return all cards.
217    /// This is the key method for fast enrichment.
218    pub fn enrich_batch(
219        &self,
220        contexts: Vec<EnrichmentContext>,
221    ) -> Result<Vec<(u64, Vec<MemoryCard>)>> {
222        let client = self
223            .client
224            .as_ref()
225            .ok_or_else(|| anyhow!("OpenAI engine not initialized (init() not called)"))?
226            .clone();
227        let client = Arc::new(client);
228        let api_key = Arc::new(self.api_key.clone());
229        let model = Arc::new(self.model.clone());
230        let total = contexts.len();
231
232        info!(
233            "Starting parallel enrichment of {} frames with {} workers",
234            total, self.parallelism
235        );
236
237        // Use rayon for parallel processing
238        let pool = rayon::ThreadPoolBuilder::new()
239            .num_threads(self.parallelism)
240            .build()
241            .map_err(|err| anyhow!("failed to build enrichment thread pool: {err}"))?;
242
243        let results: Vec<(u64, Vec<MemoryCard>)> = pool.install(|| {
244            contexts
245                .into_par_iter()
246                .enumerate()
247                .map(|(i, ctx)| {
248                    if ctx.text.is_empty() {
249                        return (ctx.frame_id, vec![]);
250                    }
251
252                    // Progress logging every 50 frames
253                    if i > 0 && i % 50 == 0 {
254                        info!("Enrichment progress: {}/{} frames", i, total);
255                    }
256
257                    match Self::run_inference_blocking(&client, &api_key, &model, &ctx.text) {
258                        Ok(output) => {
259                            debug!(
260                                "OpenAI output for frame {}: {}",
261                                ctx.frame_id,
262                                &output[..output.len().min(100)]
263                            );
264                            let cards =
265                                Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
266                            (ctx.frame_id, cards)
267                        }
268                        Err(err) => {
269                            warn!(
270                                "OpenAI inference failed for frame {}: {}",
271                                ctx.frame_id, err
272                            );
273                            (ctx.frame_id, vec![])
274                        }
275                    }
276                })
277                .collect()
278        });
279
280        info!(
281            "Parallel enrichment complete: {} frames processed",
282            results.len()
283        );
284        Ok(results)
285    }
286}
287
288/// Parse a memory kind string into the enum.
289fn parse_memory_kind(s: &str) -> Option<MemoryKind> {
290    match s.to_lowercase().as_str() {
291        "fact" => Some(MemoryKind::Fact),
292        "preference" => Some(MemoryKind::Preference),
293        "event" => Some(MemoryKind::Event),
294        "profile" => Some(MemoryKind::Profile),
295        "relationship" => Some(MemoryKind::Relationship),
296        "other" => Some(MemoryKind::Other),
297        _ => None,
298    }
299}
300
301/// Parse a polarity string into the enum.
302fn parse_polarity(s: &str) -> Polarity {
303    match s.to_lowercase().as_str() {
304        "positive" => Polarity::Positive,
305        "negative" => Polarity::Negative,
306        _ => Polarity::Neutral,
307    }
308}
309
310impl EnrichmentEngine for OpenAiEngine {
311    fn kind(&self) -> &str {
312        "openai:gpt-4o-mini"
313    }
314
315    fn version(&self) -> &str {
316        "1.0.0"
317    }
318
319    fn init(&mut self) -> memvid_core::Result<()> {
320        if self.api_key.is_empty() {
321            return Err(memvid_core::MemvidError::EmbeddingFailed {
322                reason: "OPENAI_API_KEY environment variable not set".into(),
323            });
324        }
325        let client =
326            crate::http::blocking_client(Duration::from_secs(60)).map_err(|err| {
327                memvid_core::MemvidError::EmbeddingFailed {
328                    reason: format!("Failed to create OpenAI HTTP client: {err}").into(),
329                }
330            })?;
331        self.client = Some(client);
332        self.ready = true;
333        Ok(())
334    }
335
336    fn is_ready(&self) -> bool {
337        self.ready
338    }
339
340    fn enrich(&self, ctx: &EnrichmentContext) -> EnrichmentResult {
341        if ctx.text.is_empty() {
342            return EnrichmentResult::empty();
343        }
344
345        let client = match self.client.as_ref() {
346            Some(client) => client,
347            None => {
348                return EnrichmentResult::failed(
349                    "OpenAI engine not initialized (init() not called)".to_string(),
350                )
351            }
352        };
353
354        match Self::run_inference_blocking(client, &self.api_key, &self.model, &ctx.text) {
355            Ok(output) => {
356                debug!("OpenAI output for frame {}: {}", ctx.frame_id, output);
357                let cards = Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
358                EnrichmentResult::success(cards)
359            }
360            Err(err) => EnrichmentResult::failed(format!("OpenAI inference failed: {}", err)),
361        }
362    }
363}
364
impl Default for OpenAiEngine {
    /// Equivalent to `OpenAiEngine::new()`.
    fn default() -> Self {
        Self::new()
    }
}