1use anyhow::{anyhow, Result};
7use memvid_core::enrich::{EnrichmentContext, EnrichmentEngine, EnrichmentResult};
8use memvid_core::types::{MemoryCard, MemoryCardBuilder, MemoryKind, Polarity};
9use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
10use reqwest::blocking::Client;
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use std::time::Duration;
14use tracing::{debug, info, warn};
15
/// Prompt prepended to each frame's text before it is sent to the chat
/// completions endpoint. The model is instructed to emit zero or more
/// MEMORY_START/MEMORY_END blocks (consumed by `OpenAiEngine::parse_output`)
/// or the literal sentinel MEMORY_NONE when nothing is extractable.
const EXTRACTION_PROMPT: &str = r#"You are a memory extraction assistant. Extract structured facts from the text.

For each distinct fact, preference, event, or relationship mentioned, output a memory card in this exact format:
MEMORY_START
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about, use "user" for the human in the conversation>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If there are no clear facts to extract, output MEMORY_NONE.

Extract memories from this text:
"#;
33
/// One message in an OpenAI chat-completions request body.
/// Field names match the API's JSON schema — do not rename.
#[derive(Debug, Serialize, Clone)]
struct ChatMessage {
    // Chat role; this module only ever sends "user".
    role: String,
    // Message text: the extraction prompt plus the frame text.
    content: String,
}
40
/// Request body for POST /v1/chat/completions.
/// Field names match the API's JSON schema — do not rename.
#[derive(Debug, Serialize)]
struct ChatRequest {
    // Model identifier, e.g. "gpt-4o-mini".
    model: String,
    // Conversation history; a single user message in this module.
    messages: Vec<ChatMessage>,
    // Upper bound on completion length.
    max_tokens: u32,
    // Sampling temperature; 0.0 here for deterministic extraction.
    temperature: f32,
}
49
/// Minimal deserialization target for the chat-completions response;
/// only the fields this module reads are declared.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    choices: Vec<ChatChoice>,
}
55
/// One completion choice; only the message payload is read.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    message: ChatMessageResponse,
}
60
/// The assistant message inside a choice; only the text content is read.
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    content: String,
}
65
// Enrichment engine that extracts memory cards from frame text by calling
// the OpenAI chat-completions API over blocking HTTP.
pub struct OpenAiEngine {
    // Read from OPENAI_API_KEY at construction; may be empty until init()
    // rejects it.
    api_key: String,
    // Chat model name sent with each request (default "gpt-4o-mini").
    model: String,
    // Set to true by a successful init().
    ready: bool,
    // Worker-thread count for the rayon pool built by enrich_batch().
    parallelism: usize,
    // Blocking HTTP client; None until init() is called.
    client: Option<Client>,
}
79
80impl OpenAiEngine {
81 pub fn new() -> Self {
83 let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
84 Self {
85 api_key,
86 model: "gpt-4o-mini".to_string(),
87 ready: false,
88 parallelism: 20, client: None,
90 }
91 }
92
93 pub fn with_model(model: &str) -> Self {
95 let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
96 Self {
97 api_key,
98 model: model.to_string(),
99 ready: false,
100 parallelism: 20,
101 client: None,
102 }
103 }
104
105 pub fn with_parallelism(mut self, n: usize) -> Self {
107 self.parallelism = n;
108 self
109 }
110
111 fn run_inference_blocking(client: &Client, api_key: &str, model: &str, text: &str) -> Result<String> {
113 let prompt = format!("{}\n\n{}", EXTRACTION_PROMPT, text);
114
115 let request = ChatRequest {
116 model: model.to_string(),
117 messages: vec![ChatMessage {
118 role: "user".to_string(),
119 content: prompt,
120 }],
121 max_tokens: 1024,
122 temperature: 0.0,
123 };
124
125 let response = client
126 .post("https://api.openai.com/v1/chat/completions")
127 .header("Authorization", format!("Bearer {}", api_key))
128 .header("Content-Type", "application/json")
129 .json(&request)
130 .send()
131 .map_err(|e| anyhow!("OpenAI API request failed: {}", e))?;
132
133 if !response.status().is_success() {
134 let status = response.status();
135 let body = response.text().unwrap_or_default();
136 return Err(anyhow!("OpenAI API error {}: {}", status, body));
137 }
138
139 let chat_response: ChatResponse = response
140 .json()
141 .map_err(|e| anyhow!("Failed to parse OpenAI response: {}", e))?;
142
143 chat_response
144 .choices
145 .first()
146 .map(|c| c.message.content.clone())
147 .ok_or_else(|| anyhow!("No response from OpenAI"))
148 }
149
150 fn parse_output(output: &str, frame_id: u64, uri: &str, timestamp: i64) -> Vec<MemoryCard> {
152 let mut cards = Vec::new();
153
154 if output.contains("MEMORY_NONE") {
156 return cards;
157 }
158
159 for block in output.split("MEMORY_START") {
161 let block = block.trim();
162 if block.is_empty() || !block.contains("MEMORY_END") {
163 continue;
164 }
165
166 let block = block.split("MEMORY_END").next().unwrap_or("").trim();
167
168 let mut kind = None;
170 let mut entity = None;
171 let mut slot = None;
172 let mut value = None;
173 let mut polarity = Polarity::Neutral;
174
175 for line in block.lines() {
176 let line = line.trim();
177 if let Some(rest) = line.strip_prefix("kind:") {
178 kind = parse_memory_kind(rest.trim());
179 } else if let Some(rest) = line.strip_prefix("entity:") {
180 entity = Some(rest.trim().to_string());
181 } else if let Some(rest) = line.strip_prefix("slot:") {
182 slot = Some(rest.trim().to_string());
183 } else if let Some(rest) = line.strip_prefix("value:") {
184 value = Some(rest.trim().to_string());
185 } else if let Some(rest) = line.strip_prefix("polarity:") {
186 polarity = parse_polarity(rest.trim());
187 }
188 }
189
190 if let (Some(k), Some(e), Some(s), Some(v)) = (kind, entity, slot, value) {
192 if !e.is_empty() && !s.is_empty() && !v.is_empty() {
193 match MemoryCardBuilder::new()
194 .kind(k)
195 .entity(&e)
196 .slot(&s)
197 .value(&v)
198 .polarity(polarity)
199 .source(frame_id, Some(uri.to_string()))
200 .document_date(timestamp)
201 .engine("openai:gpt-4o-mini", "1.0.0")
202 .build(0)
203 {
204 Ok(card) => cards.push(card),
205 Err(err) => {
206 warn!("Failed to build memory card: {}", err);
207 }
208 }
209 }
210 }
211 }
212
213 cards
214 }
215
216 pub fn enrich_batch(
219 &self,
220 contexts: Vec<EnrichmentContext>,
221 ) -> Result<Vec<(u64, Vec<MemoryCard>)>> {
222 let client = self
223 .client
224 .as_ref()
225 .ok_or_else(|| anyhow!("OpenAI engine not initialized (init() not called)"))?
226 .clone();
227 let client = Arc::new(client);
228 let api_key = Arc::new(self.api_key.clone());
229 let model = Arc::new(self.model.clone());
230 let total = contexts.len();
231
232 info!(
233 "Starting parallel enrichment of {} frames with {} workers",
234 total, self.parallelism
235 );
236
237 let pool = rayon::ThreadPoolBuilder::new()
239 .num_threads(self.parallelism)
240 .build()
241 .map_err(|err| anyhow!("failed to build enrichment thread pool: {err}"))?;
242
243 let results: Vec<(u64, Vec<MemoryCard>)> = pool.install(|| {
244 contexts
245 .into_par_iter()
246 .enumerate()
247 .map(|(i, ctx)| {
248 if ctx.text.is_empty() {
249 return (ctx.frame_id, vec![]);
250 }
251
252 if i > 0 && i % 50 == 0 {
254 info!("Enrichment progress: {}/{} frames", i, total);
255 }
256
257 match Self::run_inference_blocking(&client, &api_key, &model, &ctx.text) {
258 Ok(output) => {
259 debug!(
260 "OpenAI output for frame {}: {}",
261 ctx.frame_id,
262 &output[..output.len().min(100)]
263 );
264 let cards =
265 Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
266 (ctx.frame_id, cards)
267 }
268 Err(err) => {
269 warn!(
270 "OpenAI inference failed for frame {}: {}",
271 ctx.frame_id, err
272 );
273 (ctx.frame_id, vec![])
274 }
275 }
276 })
277 .collect()
278 });
279
280 info!(
281 "Parallel enrichment complete: {} frames processed",
282 results.len()
283 );
284 Ok(results)
285 }
286}
287
288fn parse_memory_kind(s: &str) -> Option<MemoryKind> {
290 match s.to_lowercase().as_str() {
291 "fact" => Some(MemoryKind::Fact),
292 "preference" => Some(MemoryKind::Preference),
293 "event" => Some(MemoryKind::Event),
294 "profile" => Some(MemoryKind::Profile),
295 "relationship" => Some(MemoryKind::Relationship),
296 "other" => Some(MemoryKind::Other),
297 _ => None,
298 }
299}
300
301fn parse_polarity(s: &str) -> Polarity {
303 match s.to_lowercase().as_str() {
304 "positive" => Polarity::Positive,
305 "negative" => Polarity::Negative,
306 _ => Polarity::Neutral,
307 }
308}
309
310impl EnrichmentEngine for OpenAiEngine {
311 fn kind(&self) -> &str {
312 "openai:gpt-4o-mini"
313 }
314
315 fn version(&self) -> &str {
316 "1.0.0"
317 }
318
319 fn init(&mut self) -> memvid_core::Result<()> {
320 if self.api_key.is_empty() {
321 return Err(memvid_core::MemvidError::EmbeddingFailed {
322 reason: "OPENAI_API_KEY environment variable not set".into(),
323 });
324 }
325 let client =
326 crate::http::blocking_client(Duration::from_secs(60)).map_err(|err| {
327 memvid_core::MemvidError::EmbeddingFailed {
328 reason: format!("Failed to create OpenAI HTTP client: {err}").into(),
329 }
330 })?;
331 self.client = Some(client);
332 self.ready = true;
333 Ok(())
334 }
335
336 fn is_ready(&self) -> bool {
337 self.ready
338 }
339
340 fn enrich(&self, ctx: &EnrichmentContext) -> EnrichmentResult {
341 if ctx.text.is_empty() {
342 return EnrichmentResult::empty();
343 }
344
345 let client = match self.client.as_ref() {
346 Some(client) => client,
347 None => {
348 return EnrichmentResult::failed(
349 "OpenAI engine not initialized (init() not called)".to_string(),
350 )
351 }
352 };
353
354 match Self::run_inference_blocking(client, &self.api_key, &self.model, &ctx.text) {
355 Ok(output) => {
356 debug!("OpenAI output for frame {}: {}", ctx.frame_id, output);
357 let cards = Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
358 EnrichmentResult::success(cards)
359 }
360 Err(err) => EnrichmentResult::failed(format!("OpenAI inference failed: {}", err)),
361 }
362 }
363}
364
// Default is the "gpt-4o-mini" configuration produced by new();
// init() is still required before the engine is usable.
impl Default for OpenAiEngine {
    fn default() -> Self {
        Self::new()
    }
}