// memvid_cli/enrich/mistral.rs

//! Mistral enrichment engine using Mistral Large.
//!
//! This engine uses the Mistral API to extract structured memory cards
//! from text content. Supports parallel batch processing for speed.
6use anyhow::{anyhow, Result};
7use memvid_core::enrich::{EnrichmentContext, EnrichmentEngine, EnrichmentResult};
8use memvid_core::types::{MemoryCard, MemoryCardBuilder, MemoryKind, Polarity};
9use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
10use reqwest::blocking::Client;
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use std::time::Duration;
14use tracing::{debug, info, warn};
15
/// The extraction prompt for Mistral.
///
/// Instructs the model to emit each fact as a `MEMORY_START`/`MEMORY_END`
/// block containing `kind`/`entity`/`slot`/`value`/`polarity` lines, or the
/// sentinel `MEMORY_NONE` when nothing is extractable. `parse_output`
/// depends on these exact markers, so keep them in sync.
const EXTRACTION_PROMPT: &str = r#"You are a memory extraction assistant. Extract structured facts from the text.

For each distinct fact, preference, event, or relationship mentioned, output a memory card in this exact format:
MEMORY_START
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about, use "user" for the human in the conversation>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If there are no clear facts to extract, output MEMORY_NONE.

Extract memories from this text:
"#;
33
/// Mistral API request message (one entry in the `messages` array).
#[derive(Debug, Serialize, Clone)]
struct ChatMessage {
    /// Message role, e.g. "user".
    role: String,
    /// Message text content.
    content: String,
}
40
/// Mistral API request body for the chat-completions endpoint.
#[derive(Debug, Serialize)]
struct ChatRequest {
    /// Model identifier, e.g. "mistral-large-latest".
    model: String,
    /// Conversation messages to send.
    messages: Vec<ChatMessage>,
    /// Upper bound on generated tokens.
    max_tokens: u32,
    /// Sampling temperature; 0.0 is used for deterministic extraction.
    temperature: f32,
}
49
/// Mistral API response; only the `choices` field is read.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    /// Candidate completions; the first one is used.
    choices: Vec<ChatChoice>,
}
55
/// One completion choice within a Mistral response.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    /// The assistant message for this choice.
    message: ChatMessageResponse,
}
60
/// Assistant message payload within a response choice.
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    /// Generated text content.
    content: String,
}
65
/// Mistral enrichment engine using Mistral Large with parallel processing.
///
/// Construct via `new`/`with_model`, then call `init()` — which validates
/// the API key and builds the shared HTTP client — before `enrich` or
/// `enrich_batch`.
pub struct MistralEngine {
    /// API key, read from `MISTRAL_API_KEY` at construction; may be empty
    /// until `init` rejects it
    api_key: String,
    /// Model to use
    model: String,
    /// Whether the engine is initialized
    ready: bool,
    /// Number of parallel workers (default: 20)
    parallelism: usize,
    /// Shared HTTP client (built in `init`)
    client: Option<Client>,
}
79
80impl MistralEngine {
81    /// Create a new Mistral engine.
82    pub fn new() -> Self {
83        let api_key = std::env::var("MISTRAL_API_KEY").unwrap_or_default();
84        Self {
85            api_key,
86            model: "mistral-large-latest".to_string(),
87            ready: false,
88            parallelism: 20,
89            client: None,
90        }
91    }
92
93    /// Create with a specific model.
94    pub fn with_model(model: &str) -> Self {
95        let api_key = std::env::var("MISTRAL_API_KEY").unwrap_or_default();
96        Self {
97            api_key,
98            model: model.to_string(),
99            ready: false,
100            parallelism: 20,
101            client: None,
102        }
103    }
104
105    /// Set parallelism level.
106    pub fn with_parallelism(mut self, n: usize) -> Self {
107        self.parallelism = n;
108        self
109    }
110
111
112    /// Run inference via Mistral API (blocking, thread-safe).
113    fn run_inference_blocking(
114        client: &Client,
115        api_key: &str,
116        model: &str,
117        text: &str,
118    ) -> Result<String> {
119        let prompt = format!("{}\n\n{}", EXTRACTION_PROMPT, text);
120
121        let request = ChatRequest {
122            model: model.to_string(),
123            messages: vec![ChatMessage {
124                role: "user".to_string(),
125                content: prompt,
126            }],
127            max_tokens: 1024,
128            temperature: 0.0,
129        };
130
131        let response = client
132            .post("https://api.mistral.ai/v1/chat/completions")
133            .header("Authorization", format!("Bearer {}", api_key))
134            .header("Content-Type", "application/json")
135            .json(&request)
136            .send()
137            .map_err(|e| anyhow!("Mistral API request failed: {}", e))?;
138
139        if !response.status().is_success() {
140            let status = response.status();
141            let body = response.text().unwrap_or_default();
142            return Err(anyhow!("Mistral API error {}: {}", status, body));
143        }
144
145        let chat_response: ChatResponse = response
146            .json()
147            .map_err(|e| anyhow!("Failed to parse Mistral response: {}", e))?;
148
149        chat_response
150            .choices
151            .first()
152            .map(|c| c.message.content.clone())
153            .ok_or_else(|| anyhow!("No response from Mistral"))
154    }
155
156    /// Parse the LLM output into memory cards.
157    fn parse_output(output: &str, frame_id: u64, uri: &str, timestamp: i64) -> Vec<MemoryCard> {
158        let mut cards = Vec::new();
159
160        if output.contains("MEMORY_NONE") {
161            return cards;
162        }
163
164        for block in output.split("MEMORY_START") {
165            let block = block.trim();
166            if block.is_empty() || !block.contains("MEMORY_END") {
167                continue;
168            }
169
170            let block = block.split("MEMORY_END").next().unwrap_or("").trim();
171
172            let mut kind = None;
173            let mut entity = None;
174            let mut slot = None;
175            let mut value = None;
176            let mut polarity = Polarity::Neutral;
177
178            for line in block.lines() {
179                let line = line.trim();
180                if let Some(rest) = line.strip_prefix("kind:") {
181                    kind = parse_memory_kind(rest.trim());
182                } else if let Some(rest) = line.strip_prefix("entity:") {
183                    entity = Some(rest.trim().to_string());
184                } else if let Some(rest) = line.strip_prefix("slot:") {
185                    slot = Some(rest.trim().to_string());
186                } else if let Some(rest) = line.strip_prefix("value:") {
187                    value = Some(rest.trim().to_string());
188                } else if let Some(rest) = line.strip_prefix("polarity:") {
189                    polarity = parse_polarity(rest.trim());
190                }
191            }
192
193            if let (Some(k), Some(e), Some(s), Some(v)) = (kind, entity, slot, value) {
194                if !e.is_empty() && !s.is_empty() && !v.is_empty() {
195                    match MemoryCardBuilder::new()
196                        .kind(k)
197                        .entity(&e)
198                        .slot(&s)
199                        .value(&v)
200                        .polarity(polarity)
201                        .source(frame_id, Some(uri.to_string()))
202                        .document_date(timestamp)
203                        .engine("mistral:mistral-large", "1.0.0")
204                        .build(0)
205                    {
206                        Ok(card) => cards.push(card),
207                        Err(err) => {
208                            warn!("Failed to build memory card: {}", err);
209                        }
210                    }
211                }
212            }
213        }
214
215        cards
216    }
217
218    /// Process multiple frames in parallel and return all cards.
219    pub fn enrich_batch(
220        &self,
221        contexts: Vec<EnrichmentContext>,
222    ) -> Result<Vec<(u64, Vec<MemoryCard>)>> {
223        let client = self
224            .client
225            .as_ref()
226            .ok_or_else(|| anyhow!("Mistral engine not initialized (init() not called)"))?
227            .clone();
228        let client = Arc::new(client);
229        let api_key = Arc::new(self.api_key.clone());
230        let model = Arc::new(self.model.clone());
231        let total = contexts.len();
232
233        info!(
234            "Starting parallel enrichment of {} frames with {} workers",
235            total, self.parallelism
236        );
237
238        let pool = rayon::ThreadPoolBuilder::new()
239            .num_threads(self.parallelism)
240            .build()
241            .map_err(|err| anyhow!("failed to build enrichment thread pool: {err}"))?;
242
243        let results: Vec<(u64, Vec<MemoryCard>)> = pool.install(|| {
244            contexts
245                .into_par_iter()
246                .enumerate()
247                .map(|(i, ctx)| {
248                    if ctx.text.is_empty() {
249                        return (ctx.frame_id, vec![]);
250                    }
251
252                    if i > 0 && i % 50 == 0 {
253                        info!("Enrichment progress: {}/{} frames", i, total);
254                    }
255
256                    match Self::run_inference_blocking(&client, &api_key, &model, &ctx.text) {
257                        Ok(output) => {
258                            debug!(
259                                "Mistral output for frame {}: {}",
260                                ctx.frame_id,
261                                &output[..output.len().min(100)]
262                            );
263                            let cards =
264                                Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
265                            (ctx.frame_id, cards)
266                        }
267                        Err(err) => {
268                            warn!(
269                                "Mistral inference failed for frame {}: {}",
270                                ctx.frame_id, err
271                            );
272                            (ctx.frame_id, vec![])
273                        }
274                    }
275                })
276                .collect()
277        });
278
279        info!(
280            "Parallel enrichment complete: {} frames processed",
281            results.len()
282        );
283        Ok(results)
284    }
285}
286
287fn parse_memory_kind(s: &str) -> Option<MemoryKind> {
288    match s.to_lowercase().as_str() {
289        "fact" => Some(MemoryKind::Fact),
290        "preference" => Some(MemoryKind::Preference),
291        "event" => Some(MemoryKind::Event),
292        "profile" => Some(MemoryKind::Profile),
293        "relationship" => Some(MemoryKind::Relationship),
294        "other" => Some(MemoryKind::Other),
295        _ => None,
296    }
297}
298
299fn parse_polarity(s: &str) -> Polarity {
300    match s.to_lowercase().as_str() {
301        "positive" => Polarity::Positive,
302        "negative" => Polarity::Negative,
303        _ => Polarity::Neutral,
304    }
305}
306
307impl EnrichmentEngine for MistralEngine {
308    fn kind(&self) -> &str {
309        "mistral:mistral-large"
310    }
311
312    fn version(&self) -> &str {
313        "1.0.0"
314    }
315
316    fn init(&mut self) -> memvid_core::Result<()> {
317        if self.api_key.is_empty() {
318            return Err(memvid_core::MemvidError::EmbeddingFailed {
319                reason: "MISTRAL_API_KEY environment variable not set".into(),
320            });
321        }
322        let client = crate::http::blocking_client(Duration::from_secs(60)).map_err(|err| {
323            memvid_core::MemvidError::EmbeddingFailed {
324                reason: format!("Failed to create Mistral HTTP client: {err}").into(),
325            }
326        })?;
327        self.client = Some(client);
328        self.ready = true;
329        Ok(())
330    }
331
332    fn is_ready(&self) -> bool {
333        self.ready
334    }
335
336    fn enrich(&self, ctx: &EnrichmentContext) -> EnrichmentResult {
337        if ctx.text.is_empty() {
338            return EnrichmentResult::empty();
339        }
340
341        let client = match self.client.as_ref() {
342            Some(client) => client,
343            None => {
344                return EnrichmentResult::failed(
345                    "Mistral engine not initialized (init() not called)".to_string(),
346                )
347            }
348        };
349
350        match Self::run_inference_blocking(client, &self.api_key, &self.model, &ctx.text) {
351            Ok(output) => {
352                debug!("Mistral output for frame {}: {}", ctx.frame_id, output);
353                let cards = Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
354                EnrichmentResult::success(cards)
355            }
356            Err(err) => EnrichmentResult::failed(format!("Mistral inference failed: {}", err)),
357        }
358    }
359}
360
impl Default for MistralEngine {
    /// Equivalent to [`MistralEngine::new`].
    fn default() -> Self {
        Self::new()
    }
}