1use anyhow::{anyhow, Result};
7use memvid_core::enrich::{EnrichmentContext, EnrichmentEngine, EnrichmentResult};
8use memvid_core::types::{MemoryCard, MemoryCardBuilder, MemoryKind, Polarity};
9use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
10use reqwest::blocking::Client;
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use std::time::Duration;
14use tracing::{debug, info, warn};
15
/// Prompt prepended to each frame's text before it is sent to the model.
///
/// Instructs the model to emit facts as `MEMORY_START` ... `MEMORY_END`
/// blocks of fixed `key: value` lines, or the literal `MEMORY_NONE` when
/// nothing is extractable — exactly the format `parse_output` expects.
const EXTRACTION_PROMPT: &str = r#"You are a memory extraction assistant. Extract structured facts from the text.

For each distinct fact, preference, event, or relationship mentioned, output a memory card in this exact format:
MEMORY_START
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about, use "user" for the human in the conversation>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If there are no clear facts to extract, output MEMORY_NONE.

Extract memories from this text:
"#;
33
/// One chat turn in a Mistral chat-completions request.
#[derive(Debug, Serialize, Clone)]
struct ChatMessage {
    // Chat role; this module only ever sends "user".
    role: String,
    content: String,
}
40
/// Request body for `POST /v1/chat/completions`.
#[derive(Debug, Serialize)]
struct ChatRequest {
    model: String,
    messages: Vec<ChatMessage>,
    // Upper bound on generated tokens per extraction call.
    max_tokens: u32,
    // Sent as 0.0 for deterministic, parseable output.
    temperature: f32,
}
49
/// Minimal subset of the chat-completions response this module reads.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    choices: Vec<ChatChoice>,
}
55
/// One completion choice; only the message payload is used.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    message: ChatMessageResponse,
}
60
/// Assistant message inside a completion choice; only the text is kept.
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    content: String,
}
65
/// Enrichment engine backed by the Mistral chat-completions API.
pub struct MistralEngine {
    // Read from `MISTRAL_API_KEY` at construction; may be empty until
    // `init()` validates it.
    api_key: String,
    // Model identifier sent with every request (e.g. "mistral-large-latest").
    model: String,
    // Set to true by `init()` once the HTTP client has been built.
    ready: bool,
    // Number of rayon worker threads used by `enrich_batch`.
    parallelism: usize,
    // Blocking HTTP client; `None` until `init()` is called.
    client: Option<Client>,
}
79
80impl MistralEngine {
81 pub fn new() -> Self {
83 let api_key = std::env::var("MISTRAL_API_KEY").unwrap_or_default();
84 Self {
85 api_key,
86 model: "mistral-large-latest".to_string(),
87 ready: false,
88 parallelism: 20,
89 client: None,
90 }
91 }
92
93 pub fn with_model(model: &str) -> Self {
95 let api_key = std::env::var("MISTRAL_API_KEY").unwrap_or_default();
96 Self {
97 api_key,
98 model: model.to_string(),
99 ready: false,
100 parallelism: 20,
101 client: None,
102 }
103 }
104
105 pub fn with_parallelism(mut self, n: usize) -> Self {
107 self.parallelism = n;
108 self
109 }
110
111 fn run_inference_blocking(
113 client: &Client,
114 api_key: &str,
115 model: &str,
116 text: &str,
117 ) -> Result<String> {
118 let prompt = format!("{}\n\n{}", EXTRACTION_PROMPT, text);
119
120 let request = ChatRequest {
121 model: model.to_string(),
122 messages: vec![ChatMessage {
123 role: "user".to_string(),
124 content: prompt,
125 }],
126 max_tokens: 1024,
127 temperature: 0.0,
128 };
129
130 let response = client
131 .post("https://api.mistral.ai/v1/chat/completions")
132 .header("Authorization", format!("Bearer {}", api_key))
133 .header("Content-Type", "application/json")
134 .json(&request)
135 .send()
136 .map_err(|e| anyhow!("Mistral API request failed: {}", e))?;
137
138 if !response.status().is_success() {
139 let status = response.status();
140 let body = response.text().unwrap_or_default();
141 return Err(anyhow!("Mistral API error {}: {}", status, body));
142 }
143
144 let chat_response: ChatResponse = response
145 .json()
146 .map_err(|e| anyhow!("Failed to parse Mistral response: {}", e))?;
147
148 chat_response
149 .choices
150 .first()
151 .map(|c| c.message.content.clone())
152 .ok_or_else(|| anyhow!("No response from Mistral"))
153 }
154
155 fn parse_output(output: &str, frame_id: u64, uri: &str, timestamp: i64) -> Vec<MemoryCard> {
157 let mut cards = Vec::new();
158
159 if output.contains("MEMORY_NONE") {
160 return cards;
161 }
162
163 for block in output.split("MEMORY_START") {
164 let block = block.trim();
165 if block.is_empty() || !block.contains("MEMORY_END") {
166 continue;
167 }
168
169 let block = block.split("MEMORY_END").next().unwrap_or("").trim();
170
171 let mut kind = None;
172 let mut entity = None;
173 let mut slot = None;
174 let mut value = None;
175 let mut polarity = Polarity::Neutral;
176
177 for line in block.lines() {
178 let line = line.trim();
179 if let Some(rest) = line.strip_prefix("kind:") {
180 kind = parse_memory_kind(rest.trim());
181 } else if let Some(rest) = line.strip_prefix("entity:") {
182 entity = Some(rest.trim().to_string());
183 } else if let Some(rest) = line.strip_prefix("slot:") {
184 slot = Some(rest.trim().to_string());
185 } else if let Some(rest) = line.strip_prefix("value:") {
186 value = Some(rest.trim().to_string());
187 } else if let Some(rest) = line.strip_prefix("polarity:") {
188 polarity = parse_polarity(rest.trim());
189 }
190 }
191
192 if let (Some(k), Some(e), Some(s), Some(v)) = (kind, entity, slot, value) {
193 if !e.is_empty() && !s.is_empty() && !v.is_empty() {
194 match MemoryCardBuilder::new()
195 .kind(k)
196 .entity(&e)
197 .slot(&s)
198 .value(&v)
199 .polarity(polarity)
200 .source(frame_id, Some(uri.to_string()))
201 .document_date(timestamp)
202 .engine("mistral:mistral-large", "1.0.0")
203 .build(0)
204 {
205 Ok(card) => cards.push(card),
206 Err(err) => {
207 warn!("Failed to build memory card: {}", err);
208 }
209 }
210 }
211 }
212 }
213
214 cards
215 }
216
217 pub fn enrich_batch(
219 &self,
220 contexts: Vec<EnrichmentContext>,
221 ) -> Result<Vec<(u64, Vec<MemoryCard>)>> {
222 let client = self
223 .client
224 .as_ref()
225 .ok_or_else(|| anyhow!("Mistral engine not initialized (init() not called)"))?
226 .clone();
227 let client = Arc::new(client);
228 let api_key = Arc::new(self.api_key.clone());
229 let model = Arc::new(self.model.clone());
230 let total = contexts.len();
231
232 info!(
233 "Starting parallel enrichment of {} frames with {} workers",
234 total, self.parallelism
235 );
236
237 let pool = rayon::ThreadPoolBuilder::new()
238 .num_threads(self.parallelism)
239 .build()
240 .map_err(|err| anyhow!("failed to build enrichment thread pool: {err}"))?;
241
242 let results: Vec<(u64, Vec<MemoryCard>)> = pool.install(|| {
243 contexts
244 .into_par_iter()
245 .enumerate()
246 .map(|(i, ctx)| {
247 if ctx.text.is_empty() {
248 return (ctx.frame_id, vec![]);
249 }
250
251 if i > 0 && i % 50 == 0 {
252 info!("Enrichment progress: {}/{} frames", i, total);
253 }
254
255 match Self::run_inference_blocking(&client, &api_key, &model, &ctx.text) {
256 Ok(output) => {
257 debug!(
258 "Mistral output for frame {}: {}",
259 ctx.frame_id,
260 &output[..output.len().min(100)]
261 );
262 let cards =
263 Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
264 (ctx.frame_id, cards)
265 }
266 Err(err) => {
267 warn!(
268 "Mistral inference failed for frame {}: {}",
269 ctx.frame_id, err
270 );
271 (ctx.frame_id, vec![])
272 }
273 }
274 })
275 .collect()
276 });
277
278 info!(
279 "Parallel enrichment complete: {} frames processed",
280 results.len()
281 );
282 Ok(results)
283 }
284}
285
286fn parse_memory_kind(s: &str) -> Option<MemoryKind> {
287 match s.to_lowercase().as_str() {
288 "fact" => Some(MemoryKind::Fact),
289 "preference" => Some(MemoryKind::Preference),
290 "event" => Some(MemoryKind::Event),
291 "profile" => Some(MemoryKind::Profile),
292 "relationship" => Some(MemoryKind::Relationship),
293 "other" => Some(MemoryKind::Other),
294 _ => None,
295 }
296}
297
298fn parse_polarity(s: &str) -> Polarity {
299 match s.to_lowercase().as_str() {
300 "positive" => Polarity::Positive,
301 "negative" => Polarity::Negative,
302 _ => Polarity::Neutral,
303 }
304}
305
306impl EnrichmentEngine for MistralEngine {
307 fn kind(&self) -> &str {
308 "mistral:mistral-large"
309 }
310
311 fn version(&self) -> &str {
312 "1.0.0"
313 }
314
315 fn init(&mut self) -> memvid_core::Result<()> {
316 if self.api_key.is_empty() {
317 return Err(memvid_core::MemvidError::EmbeddingFailed {
318 reason: "MISTRAL_API_KEY environment variable not set".into(),
319 });
320 }
321 let client = crate::http::blocking_client(Duration::from_secs(60)).map_err(|err| {
322 memvid_core::MemvidError::EmbeddingFailed {
323 reason: format!("Failed to create Mistral HTTP client: {err}").into(),
324 }
325 })?;
326 self.client = Some(client);
327 self.ready = true;
328 Ok(())
329 }
330
331 fn is_ready(&self) -> bool {
332 self.ready
333 }
334
335 fn enrich(&self, ctx: &EnrichmentContext) -> EnrichmentResult {
336 if ctx.text.is_empty() {
337 return EnrichmentResult::empty();
338 }
339
340 let client = match self.client.as_ref() {
341 Some(client) => client,
342 None => {
343 return EnrichmentResult::failed(
344 "Mistral engine not initialized (init() not called)".to_string(),
345 )
346 }
347 };
348
349 match Self::run_inference_blocking(client, &self.api_key, &self.model, &ctx.text) {
350 Ok(output) => {
351 debug!("Mistral output for frame {}: {}", ctx.frame_id, output);
352 let cards = Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
353 EnrichmentResult::success(cards)
354 }
355 Err(err) => EnrichmentResult::failed(format!("Mistral inference failed: {}", err)),
356 }
357 }
358}
359
impl Default for MistralEngine {
    /// Equivalent to [`MistralEngine::new`].
    fn default() -> Self {
        Self::new()
    }
}