1use anyhow::{anyhow, Result};
7use memvid_core::enrich::{EnrichmentContext, EnrichmentEngine, EnrichmentResult};
8use memvid_core::types::{MemoryCard, MemoryCardBuilder, MemoryKind, Polarity};
9use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
10use reqwest::blocking::Client;
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use tracing::{debug, info, warn};
14
/// Instruction prompt prepended to every frame's text before it is sent to
/// the chat-completions endpoint.
///
/// It asks the model to emit zero or more `MEMORY_START`/`MEMORY_END` blocks
/// (or the literal `MEMORY_NONE` when nothing is extractable); `parse_output`
/// below consumes exactly this format. Do not change the marker strings or
/// field names here without updating the parser to match.
const EXTRACTION_PROMPT: &str = r#"You are a memory extraction assistant. Extract structured facts from the text.

For each distinct fact, preference, event, or relationship mentioned, output a memory card in this exact format:
MEMORY_START
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about, use "user" for the human in the conversation>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If there are no clear facts to extract, output MEMORY_NONE.

Extract memories from this text:
"#;
32
/// One message in the chat-completions request payload.
///
/// Field names mirror the OpenAI wire format, so they must not be renamed
/// without a serde attribute.
#[derive(Debug, Serialize, Clone)]
struct ChatMessage {
    // Speaker role; this engine only sends "user".
    role: String,
    // Message text (the extraction prompt plus the frame's text).
    content: String,
}
39
/// Request body for `POST /v1/chat/completions`.
///
/// Only the fields this engine actually sets are modeled; names match the
/// OpenAI wire format.
#[derive(Debug, Serialize)]
struct ChatRequest {
    // Target model identifier, e.g. "gpt-4o-mini".
    model: String,
    // Conversation so far; this engine sends a single user message.
    messages: Vec<ChatMessage>,
    // Upper bound on the completion length.
    max_tokens: u32,
    // Sampling temperature; 0.0 is used for deterministic extraction.
    temperature: f32,
}
48
/// Minimal deserialization of the chat-completions response: only the
/// fields this engine reads are modeled; serde ignores the rest.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    // Candidate completions; only the first is used.
    choices: Vec<ChatChoice>,
}
54
/// One completion candidate inside a [`ChatResponse`].
#[derive(Debug, Deserialize)]
struct ChatChoice {
    // The assistant message produced for this choice.
    message: ChatMessageResponse,
}
59
/// Assistant message payload within a choice; only the text is needed.
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    // Raw completion text, expected to contain MEMORY_START/MEMORY_END blocks.
    content: String,
}
64
/// Enrichment engine that extracts structured memory cards from frame text
/// by calling the OpenAI chat-completions API.
pub struct OpenAiEngine {
    // API key read from OPENAI_API_KEY; may be empty until init() validates it.
    api_key: String,
    // Chat model identifier sent with every request (default "gpt-4o-mini").
    model: String,
    // Set by init() once the API key has been validated as non-empty.
    ready: bool,
    // Worker-thread count for the dedicated rayon pool in enrich_batch().
    parallelism: usize,
}
76
77impl OpenAiEngine {
78 pub fn new() -> Self {
80 let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
81 Self {
82 api_key,
83 model: "gpt-4o-mini".to_string(),
84 ready: false,
85 parallelism: 20, }
87 }
88
89 pub fn with_model(model: &str) -> Self {
91 let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
92 Self {
93 api_key,
94 model: model.to_string(),
95 ready: false,
96 parallelism: 20,
97 }
98 }
99
100 pub fn with_parallelism(mut self, n: usize) -> Self {
102 self.parallelism = n;
103 self
104 }
105
106 fn run_inference_blocking(api_key: &str, model: &str, text: &str) -> Result<String> {
108 let client = Client::new();
109 let prompt = format!("{}\n\n{}", EXTRACTION_PROMPT, text);
110
111 let request = ChatRequest {
112 model: model.to_string(),
113 messages: vec![ChatMessage {
114 role: "user".to_string(),
115 content: prompt,
116 }],
117 max_tokens: 1024,
118 temperature: 0.0,
119 };
120
121 let response = client
122 .post("https://api.openai.com/v1/chat/completions")
123 .header("Authorization", format!("Bearer {}", api_key))
124 .header("Content-Type", "application/json")
125 .json(&request)
126 .send()
127 .map_err(|e| anyhow!("OpenAI API request failed: {}", e))?;
128
129 if !response.status().is_success() {
130 let status = response.status();
131 let body = response.text().unwrap_or_default();
132 return Err(anyhow!("OpenAI API error {}: {}", status, body));
133 }
134
135 let chat_response: ChatResponse = response
136 .json()
137 .map_err(|e| anyhow!("Failed to parse OpenAI response: {}", e))?;
138
139 chat_response
140 .choices
141 .first()
142 .map(|c| c.message.content.clone())
143 .ok_or_else(|| anyhow!("No response from OpenAI"))
144 }
145
146 fn parse_output(output: &str, frame_id: u64, uri: &str, timestamp: i64) -> Vec<MemoryCard> {
148 let mut cards = Vec::new();
149
150 if output.contains("MEMORY_NONE") {
152 return cards;
153 }
154
155 for block in output.split("MEMORY_START") {
157 let block = block.trim();
158 if block.is_empty() || !block.contains("MEMORY_END") {
159 continue;
160 }
161
162 let block = block.split("MEMORY_END").next().unwrap_or("").trim();
163
164 let mut kind = None;
166 let mut entity = None;
167 let mut slot = None;
168 let mut value = None;
169 let mut polarity = Polarity::Neutral;
170
171 for line in block.lines() {
172 let line = line.trim();
173 if let Some(rest) = line.strip_prefix("kind:") {
174 kind = parse_memory_kind(rest.trim());
175 } else if let Some(rest) = line.strip_prefix("entity:") {
176 entity = Some(rest.trim().to_string());
177 } else if let Some(rest) = line.strip_prefix("slot:") {
178 slot = Some(rest.trim().to_string());
179 } else if let Some(rest) = line.strip_prefix("value:") {
180 value = Some(rest.trim().to_string());
181 } else if let Some(rest) = line.strip_prefix("polarity:") {
182 polarity = parse_polarity(rest.trim());
183 }
184 }
185
186 if let (Some(k), Some(e), Some(s), Some(v)) = (kind, entity, slot, value) {
188 if !e.is_empty() && !s.is_empty() && !v.is_empty() {
189 match MemoryCardBuilder::new()
190 .kind(k)
191 .entity(&e)
192 .slot(&s)
193 .value(&v)
194 .polarity(polarity)
195 .source(frame_id, Some(uri.to_string()))
196 .document_date(timestamp)
197 .engine("openai:gpt-4o-mini", "1.0.0")
198 .build(0)
199 {
200 Ok(card) => cards.push(card),
201 Err(err) => {
202 warn!("Failed to build memory card: {}", err);
203 }
204 }
205 }
206 }
207 }
208
209 cards
210 }
211
212 pub fn enrich_batch(&self, contexts: Vec<EnrichmentContext>) -> Vec<(u64, Vec<MemoryCard>)> {
215 let api_key = Arc::new(self.api_key.clone());
216 let model = Arc::new(self.model.clone());
217 let total = contexts.len();
218
219 info!(
220 "Starting parallel enrichment of {} frames with {} workers",
221 total, self.parallelism
222 );
223
224 let pool = rayon::ThreadPoolBuilder::new()
226 .num_threads(self.parallelism)
227 .build()
228 .unwrap();
229
230 let results: Vec<(u64, Vec<MemoryCard>)> = pool.install(|| {
231 contexts
232 .into_par_iter()
233 .enumerate()
234 .map(|(i, ctx)| {
235 if ctx.text.is_empty() {
236 return (ctx.frame_id, vec![]);
237 }
238
239 if i > 0 && i % 50 == 0 {
241 info!("Enrichment progress: {}/{} frames", i, total);
242 }
243
244 match Self::run_inference_blocking(&api_key, &model, &ctx.text) {
245 Ok(output) => {
246 debug!(
247 "OpenAI output for frame {}: {}",
248 ctx.frame_id,
249 &output[..output.len().min(100)]
250 );
251 let cards =
252 Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
253 (ctx.frame_id, cards)
254 }
255 Err(err) => {
256 warn!(
257 "OpenAI inference failed for frame {}: {}",
258 ctx.frame_id, err
259 );
260 (ctx.frame_id, vec![])
261 }
262 }
263 })
264 .collect()
265 });
266
267 info!(
268 "Parallel enrichment complete: {} frames processed",
269 results.len()
270 );
271 results
272 }
273}
274
275fn parse_memory_kind(s: &str) -> Option<MemoryKind> {
277 match s.to_lowercase().as_str() {
278 "fact" => Some(MemoryKind::Fact),
279 "preference" => Some(MemoryKind::Preference),
280 "event" => Some(MemoryKind::Event),
281 "profile" => Some(MemoryKind::Profile),
282 "relationship" => Some(MemoryKind::Relationship),
283 "other" => Some(MemoryKind::Other),
284 _ => None,
285 }
286}
287
288fn parse_polarity(s: &str) -> Polarity {
290 match s.to_lowercase().as_str() {
291 "positive" => Polarity::Positive,
292 "negative" => Polarity::Negative,
293 _ => Polarity::Neutral,
294 }
295}
296
297impl EnrichmentEngine for OpenAiEngine {
298 fn kind(&self) -> &str {
299 "openai:gpt-4o-mini"
300 }
301
302 fn version(&self) -> &str {
303 "1.0.0"
304 }
305
306 fn init(&mut self) -> memvid_core::Result<()> {
307 if self.api_key.is_empty() {
308 return Err(memvid_core::MemvidError::EmbeddingFailed {
309 reason: "OPENAI_API_KEY environment variable not set".into(),
310 });
311 }
312 self.ready = true;
313 Ok(())
314 }
315
316 fn is_ready(&self) -> bool {
317 self.ready
318 }
319
320 fn enrich(&self, ctx: &EnrichmentContext) -> EnrichmentResult {
321 if ctx.text.is_empty() {
322 return EnrichmentResult::empty();
323 }
324
325 match Self::run_inference_blocking(&self.api_key, &self.model, &ctx.text) {
326 Ok(output) => {
327 debug!("OpenAI output for frame {}: {}", ctx.frame_id, output);
328 let cards = Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
329 EnrichmentResult::success(cards)
330 }
331 Err(err) => EnrichmentResult::failed(format!("OpenAI inference failed: {}", err)),
332 }
333 }
334}
335
336impl Default for OpenAiEngine {
337 fn default() -> Self {
338 Self::new()
339 }
340}