1use anyhow::{anyhow, Result};
7use memvid_core::enrich::{EnrichmentContext, EnrichmentEngine, EnrichmentResult};
8use memvid_core::types::{MemoryCard, MemoryCardBuilder, MemoryKind, Polarity};
9use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
10use reqwest::blocking::Client;
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use std::time::Duration;
14use tracing::{debug, info, warn};
15
/// Prompt for single-frame extraction. Instructs the model to emit zero or
/// more `MEMORY_START`..`MEMORY_END` blocks (or `MEMORY_NONE`) in the exact
/// `key: value` line format that `OpenAiEngine::parse_output` expects.
const EXTRACTION_PROMPT: &str = r#"You are a memory extraction assistant. Extract structured facts from the text.

For each distinct fact, preference, event, or relationship mentioned, output a memory card in this exact format:
MEMORY_START
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about, use "user" for the human in the conversation>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If there are no clear facts to extract, output MEMORY_NONE.

Extract memories from this text:
"#;
33
/// Prompt for multi-frame (batched) extraction. Identical contract to
/// `EXTRACTION_PROMPT`, except each block must carry a `frame_id:` line so
/// `OpenAiEngine::parse_batched_output` can route cards back to their frame.
const BATCH_EXTRACTION_PROMPT: &str = r#"You are a memory extraction assistant. Extract structured facts from multiple text blocks.

Each text block is labeled with a FRAME_ID. For each distinct fact in each block, output a memory card with the frame_id field:

MEMORY_START
frame_id: <the FRAME_ID of the source text>
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about, use "user" for the human in the conversation>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If a text block has no facts, output MEMORY_NONE with its frame_id.

Process these text blocks:
"#;
53
/// One message in an OpenAI chat-completions request.
/// Field names are part of the wire format — do not rename.
#[derive(Debug, Serialize, Clone)]
struct ChatMessage {
    /// Chat role; this module only ever sends "user".
    role: String,
    /// Message text (the extraction prompt plus the input text).
    content: String,
}
60
/// Request body for the OpenAI chat-completions endpoint
/// (only the fields this module sends).
#[derive(Debug, Serialize)]
struct ChatRequest {
    /// Model id, e.g. "gpt-4o-mini".
    model: String,
    /// Conversation; this module always sends a single user message.
    messages: Vec<ChatMessage>,
    /// Upper bound on the reply length.
    max_tokens: u32,
    /// Set to 0.0 by callers for deterministic extraction.
    temperature: f32,
}
69
/// Response body from the chat-completions endpoint
/// (only the fields this module reads).
#[derive(Debug, Deserialize)]
struct ChatResponse {
    /// Completion candidates; only the first is used.
    choices: Vec<ChatChoice>,
}
75
/// A single completion candidate within a chat response.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    /// The assistant message for this candidate.
    message: ChatMessageResponse,
}
80
/// The assistant message payload inside a completion choice.
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    /// The generated text.
    content: String,
}
85
/// Enrichment engine that extracts structured memory cards from text by
/// calling the OpenAI chat-completions API.
pub struct OpenAiEngine {
    /// API key read from `OPENAI_API_KEY`; validated in `init()`.
    api_key: String,
    /// Chat model id, e.g. "gpt-4o-mini".
    model: String,
    /// Set true by `init()`; reported by `is_ready()`.
    ready: bool,
    /// Worker-thread count for `enrich_batch`.
    parallelism: usize,
    /// Number of frames packed into one batched API request.
    batch_size: usize,
    /// Blocking HTTP client, created in `init()`.
    client: Option<Client>,
}
101
102impl OpenAiEngine {
103 pub fn new() -> Self {
105 let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
106 Self {
107 api_key,
108 model: "gpt-4o-mini".to_string(),
109 ready: false,
110 parallelism: 20, batch_size: 10, client: None,
113 }
114 }
115
116 pub fn with_model(model: &str) -> Self {
118 let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
119 Self {
120 api_key,
121 model: model.to_string(),
122 ready: false,
123 parallelism: 20,
124 batch_size: 10,
125 client: None,
126 }
127 }
128
129 pub fn with_parallelism(mut self, n: usize) -> Self {
131 self.parallelism = n;
132 self
133 }
134
135 pub fn with_batch_size(mut self, n: usize) -> Self {
137 self.batch_size = n.max(1); self
139 }
140
141 fn run_inference_blocking(client: &Client, api_key: &str, model: &str, text: &str) -> Result<String> {
143 let prompt = format!("{}\n\n{}", EXTRACTION_PROMPT, text);
144
145 let request = ChatRequest {
146 model: model.to_string(),
147 messages: vec![ChatMessage {
148 role: "user".to_string(),
149 content: prompt,
150 }],
151 max_tokens: 1024,
152 temperature: 0.0,
153 };
154
155 let response = client
156 .post("https://api.openai.com/v1/chat/completions")
157 .header("Authorization", format!("Bearer {}", api_key))
158 .header("Content-Type", "application/json")
159 .json(&request)
160 .send()
161 .map_err(|e| anyhow!("OpenAI API request failed: {}", e))?;
162
163 if !response.status().is_success() {
164 let status = response.status();
165 let body = response.text().unwrap_or_default();
166 return Err(anyhow!("OpenAI API error {}: {}", status, body));
167 }
168
169 let chat_response: ChatResponse = response
170 .json()
171 .map_err(|e| anyhow!("Failed to parse OpenAI response: {}", e))?;
172
173 chat_response
174 .choices
175 .first()
176 .map(|c| c.message.content.clone())
177 .ok_or_else(|| anyhow!("No response from OpenAI"))
178 }
179
180 fn parse_output(output: &str, frame_id: u64, uri: &str, timestamp: i64) -> Vec<MemoryCard> {
182 let mut cards = Vec::new();
183
184 if output.contains("MEMORY_NONE") {
186 return cards;
187 }
188
189 for block in output.split("MEMORY_START") {
191 let block = block.trim();
192 if block.is_empty() || !block.contains("MEMORY_END") {
193 continue;
194 }
195
196 let block = block.split("MEMORY_END").next().unwrap_or("").trim();
197
198 let mut kind = None;
200 let mut entity = None;
201 let mut slot = None;
202 let mut value = None;
203 let mut polarity = Polarity::Neutral;
204
205 for line in block.lines() {
206 let line = line.trim();
207 if let Some(rest) = line.strip_prefix("kind:") {
208 kind = parse_memory_kind(rest.trim());
209 } else if let Some(rest) = line.strip_prefix("entity:") {
210 entity = Some(rest.trim().to_string());
211 } else if let Some(rest) = line.strip_prefix("slot:") {
212 slot = Some(rest.trim().to_string());
213 } else if let Some(rest) = line.strip_prefix("value:") {
214 value = Some(rest.trim().to_string());
215 } else if let Some(rest) = line.strip_prefix("polarity:") {
216 polarity = parse_polarity(rest.trim());
217 }
218 }
219
220 if let (Some(k), Some(e), Some(s), Some(v)) = (kind, entity, slot, value) {
222 if !e.is_empty() && !s.is_empty() && !v.is_empty() {
223 match MemoryCardBuilder::new()
224 .kind(k)
225 .entity(&e)
226 .slot(&s)
227 .value(&v)
228 .polarity(polarity)
229 .source(frame_id, Some(uri.to_string()))
230 .document_date(timestamp)
231 .engine("openai:gpt-4o-mini", "1.0.0")
232 .build(0)
233 {
234 Ok(card) => cards.push(card),
235 Err(err) => {
236 warn!("Failed to build memory card: {}", err);
237 }
238 }
239 }
240 }
241 }
242
243 cards
244 }
245
246 fn run_batched_inference_blocking(
248 client: &Client,
249 api_key: &str,
250 model: &str,
251 contexts: &[&EnrichmentContext],
252 ) -> Result<String> {
253 let mut prompt = BATCH_EXTRACTION_PROMPT.to_string();
255 for ctx in contexts {
256 prompt.push_str(&format!(
257 "\n\n=== FRAME_ID: {} ===\n{}",
258 ctx.frame_id, ctx.text
259 ));
260 }
261
262 let max_tokens = 1024 + (contexts.len() as u32 * 512);
264
265 let request = ChatRequest {
266 model: model.to_string(),
267 messages: vec![ChatMessage {
268 role: "user".to_string(),
269 content: prompt,
270 }],
271 max_tokens: max_tokens.min(4096), temperature: 0.0,
273 };
274
275 let response = client
276 .post("https://api.openai.com/v1/chat/completions")
277 .header("Authorization", format!("Bearer {}", api_key))
278 .header("Content-Type", "application/json")
279 .json(&request)
280 .send()
281 .map_err(|e| anyhow!("OpenAI API request failed: {}", e))?;
282
283 if !response.status().is_success() {
284 let status = response.status();
285 let body = response.text().unwrap_or_default();
286 return Err(anyhow!("OpenAI API error {}: {}", status, body));
287 }
288
289 let chat_response: ChatResponse = response
290 .json()
291 .map_err(|e| anyhow!("Failed to parse OpenAI response: {}", e))?;
292
293 chat_response
294 .choices
295 .first()
296 .map(|c| c.message.content.clone())
297 .ok_or_else(|| anyhow!("No response from OpenAI"))
298 }
299
300 fn parse_batched_output(
302 output: &str,
303 contexts: &[&EnrichmentContext],
304 ) -> std::collections::HashMap<u64, Vec<MemoryCard>> {
305 let mut results: std::collections::HashMap<u64, Vec<MemoryCard>> = std::collections::HashMap::new();
306
307 for ctx in contexts {
309 results.insert(ctx.frame_id, Vec::new());
310 }
311
312 let ctx_lookup: std::collections::HashMap<u64, &EnrichmentContext> =
314 contexts.iter().map(|c| (c.frame_id, *c)).collect();
315
316 for block in output.split("MEMORY_START") {
318 let block = block.trim();
319 if block.is_empty() || !block.contains("MEMORY_END") {
320 continue;
321 }
322
323 let block = block.split("MEMORY_END").next().unwrap_or("").trim();
324
325 let mut frame_id: Option<u64> = None;
327 let mut kind = None;
328 let mut entity = None;
329 let mut slot = None;
330 let mut value = None;
331 let mut polarity = Polarity::Neutral;
332
333 for line in block.lines() {
334 let line = line.trim();
335 if let Some(rest) = line.strip_prefix("frame_id:") {
336 frame_id = rest.trim().parse().ok();
337 } else if let Some(rest) = line.strip_prefix("kind:") {
338 kind = parse_memory_kind(rest.trim());
339 } else if let Some(rest) = line.strip_prefix("entity:") {
340 entity = Some(rest.trim().to_string());
341 } else if let Some(rest) = line.strip_prefix("slot:") {
342 slot = Some(rest.trim().to_string());
343 } else if let Some(rest) = line.strip_prefix("value:") {
344 value = Some(rest.trim().to_string());
345 } else if let Some(rest) = line.strip_prefix("polarity:") {
346 polarity = parse_polarity(rest.trim());
347 }
348 }
349
350 if let (Some(fid), Some(k), Some(e), Some(s), Some(v)) = (frame_id, kind, entity, slot, value) {
352 if let Some(ctx) = ctx_lookup.get(&fid) {
353 let uri = &ctx.uri;
354 let timestamp = ctx.timestamp;
355
356 if !e.is_empty() && !s.is_empty() && !v.is_empty() {
357 match MemoryCardBuilder::new()
358 .kind(k)
359 .entity(&e)
360 .slot(&s)
361 .value(&v)
362 .polarity(polarity)
363 .source(fid, Some(uri.to_string()))
364 .document_date(timestamp)
365 .engine("openai:gpt-4o-mini", "1.0.0")
366 .build(0)
367 {
368 Ok(card) => {
369 results.entry(fid).or_default().push(card);
370 }
371 Err(err) => {
372 warn!("Failed to build memory card: {}", err);
373 }
374 }
375 }
376 }
377 }
378 }
379
380 results
381 }
382
383 pub fn enrich_batch(
387 &self,
388 contexts: Vec<EnrichmentContext>,
389 ) -> Result<Vec<(u64, Vec<MemoryCard>)>> {
390 let client = self
391 .client
392 .as_ref()
393 .ok_or_else(|| anyhow!("OpenAI engine not initialized (init() not called)"))?
394 .clone();
395 let client = Arc::new(client);
396 let api_key = Arc::new(self.api_key.clone());
397 let model = Arc::new(self.model.clone());
398 let total = contexts.len();
399 let batch_size = self.batch_size;
400
401 let num_batches = (total + batch_size - 1) / batch_size;
403
404 info!(
405 "Starting parallel enrichment of {} frames with {} workers, {} frames per batch ({} batches)",
406 total, self.parallelism, batch_size, num_batches
407 );
408
409 let batches: Vec<Vec<EnrichmentContext>> = contexts
411 .into_iter()
412 .collect::<Vec<_>>()
413 .chunks(batch_size)
414 .map(|chunk| chunk.to_vec())
415 .collect();
416
417 let pool = rayon::ThreadPoolBuilder::new()
419 .num_threads(self.parallelism)
420 .build()
421 .map_err(|err| anyhow!("failed to build enrichment thread pool: {err}"))?;
422
423 let batch_results: Vec<std::collections::HashMap<u64, Vec<MemoryCard>>> = pool.install(|| {
424 batches
425 .into_par_iter()
426 .enumerate()
427 .map(|(batch_idx, batch)| {
428 let non_empty: Vec<&EnrichmentContext> = batch
430 .iter()
431 .filter(|ctx| !ctx.text.is_empty())
432 .collect();
433
434 if non_empty.is_empty() {
435 return batch.iter().map(|ctx| (ctx.frame_id, Vec::new())).collect();
437 }
438
439 if batch_idx > 0 && batch_idx % 10 == 0 {
441 info!("Enrichment progress: {} batches processed", batch_idx);
442 }
443
444 match Self::run_batched_inference_blocking(&client, &api_key, &model, &non_empty) {
445 Ok(output) => {
446 debug!(
447 "OpenAI batch output (batch {}): {}...",
448 batch_idx,
449 &output[..output.len().min(100)]
450 );
451 Self::parse_batched_output(&output, &non_empty)
452 }
453 Err(err) => {
454 warn!(
455 "OpenAI batch inference failed (batch {}): {}",
456 batch_idx, err
457 );
458 batch.iter().map(|ctx| (ctx.frame_id, Vec::new())).collect()
460 }
461 }
462 })
463 .collect()
464 });
465
466 let mut results: Vec<(u64, Vec<MemoryCard>)> = Vec::with_capacity(total);
468 for batch_map in batch_results {
469 for (frame_id, cards) in batch_map {
470 results.push((frame_id, cards));
471 }
472 }
473
474 info!(
475 "Parallel enrichment complete: {} frames processed in {} batches",
476 results.len(),
477 num_batches
478 );
479 Ok(results)
480 }
481}
482
483fn parse_memory_kind(s: &str) -> Option<MemoryKind> {
485 match s.to_lowercase().as_str() {
486 "fact" => Some(MemoryKind::Fact),
487 "preference" => Some(MemoryKind::Preference),
488 "event" => Some(MemoryKind::Event),
489 "profile" => Some(MemoryKind::Profile),
490 "relationship" => Some(MemoryKind::Relationship),
491 "other" => Some(MemoryKind::Other),
492 _ => None,
493 }
494}
495
496fn parse_polarity(s: &str) -> Polarity {
498 match s.to_lowercase().as_str() {
499 "positive" => Polarity::Positive,
500 "negative" => Polarity::Negative,
501 _ => Polarity::Neutral,
502 }
503}
504
505impl EnrichmentEngine for OpenAiEngine {
506 fn kind(&self) -> &str {
507 "openai:gpt-4o-mini"
508 }
509
510 fn version(&self) -> &str {
511 "1.0.0"
512 }
513
514 fn init(&mut self) -> memvid_core::Result<()> {
515 if self.api_key.is_empty() {
516 return Err(memvid_core::MemvidError::EmbeddingFailed {
517 reason: "OPENAI_API_KEY environment variable not set".into(),
518 });
519 }
520 let client =
522 crate::http::blocking_client(Duration::from_secs(120)).map_err(|err| {
523 memvid_core::MemvidError::EmbeddingFailed {
524 reason: format!("Failed to create OpenAI HTTP client: {err}").into(),
525 }
526 })?;
527 self.client = Some(client);
528 self.ready = true;
529 Ok(())
530 }
531
532 fn is_ready(&self) -> bool {
533 self.ready
534 }
535
536 fn enrich(&self, ctx: &EnrichmentContext) -> EnrichmentResult {
537 if ctx.text.is_empty() {
538 return EnrichmentResult::empty();
539 }
540
541 let client = match self.client.as_ref() {
542 Some(client) => client,
543 None => {
544 return EnrichmentResult::failed(
545 "OpenAI engine not initialized (init() not called)".to_string(),
546 )
547 }
548 };
549
550 match Self::run_inference_blocking(client, &self.api_key, &self.model, &ctx.text) {
551 Ok(output) => {
552 debug!("OpenAI output for frame {}: {}", ctx.frame_id, output);
553 let cards = Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
554 EnrichmentResult::success(cards)
555 }
556 Err(err) => EnrichmentResult::failed(format!("OpenAI inference failed: {}", err)),
557 }
558 }
559}
560
561impl Default for OpenAiEngine {
562 fn default() -> Self {
563 Self::new()
564 }
565}