Skip to main content

memvid_cli/enrich/
xai.rs

1//! xAI (Grok) enrichment engine using Grok-4 Fast.
2//!
3//! This engine uses the xAI API to extract structured memory cards
4//! from text content. Supports parallel batch processing for speed.
5
6use anyhow::{anyhow, Result};
7use memvid_core::enrich::{EnrichmentContext, EnrichmentEngine, EnrichmentResult};
8use memvid_core::types::{MemoryCard, MemoryCardBuilder, MemoryKind, Polarity};
9use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
10use reqwest::blocking::Client;
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use std::time::Duration;
14use tracing::{debug, info, warn};
15
/// The extraction prompt for Grok.
///
/// It instructs the model to emit one `MEMORY_START` ... `MEMORY_END` block per
/// extracted item, each holding `kind:`, `entity:`, `slot:`, `value:` and
/// `polarity:` lines, or the token `MEMORY_NONE` when nothing can be extracted.
/// `XaiEngine::parse_output` matches on exactly these markers and keys, so the
/// prompt and the parser have to be changed together.
const EXTRACTION_PROMPT: &str = r#"You are a memory extraction assistant. Extract structured facts from the text.

For each distinct fact, preference, event, or relationship mentioned, output a memory card in this exact format:
MEMORY_START
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about, use "user" for the human in the conversation>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If there are no clear facts to extract, output MEMORY_NONE.

Extract memories from this text:
"#;
33
/// One entry of the `input` array in an xAI API request.
#[derive(Debug, Serialize, Clone)]
struct InputMessage {
    /// Author of the message; this file sends `"system"` and `"user"`.
    role: String,
    /// Text of the message.
    content: String,
}
40
/// xAI API request body (serialized as JSON for the `/v1/responses` endpoint).
#[derive(Debug, Serialize)]
struct XaiRequest {
    /// Identifier of the model that should answer the request.
    model: String,
    /// Messages that make up the prompt, in the order they are sent.
    input: Vec<InputMessage>,
}
47
/// xAI API response, reduced to the fields this engine reads.
///
/// Only `output` is declared; any other keys in the JSON body are ignored
/// during deserialization.
#[derive(Debug, Deserialize)]
struct XaiResponse {
    /// Items produced by the model; `None` when the key is missing or null.
    output: Option<Vec<OutputItem>>,
}
53
/// One element of the response's `output` array.
#[derive(Debug, Deserialize)]
struct OutputItem {
    /// Content parts of this item; `None` when the item carries no `content` key.
    content: Option<Vec<ContentItem>>,
}
58
/// One element of an output item's `content` array.
#[derive(Debug, Deserialize)]
struct ContentItem {
    /// Generated text of this part; `None` when the part has no `text` key.
    text: Option<String>,
}
63
/// xAI enrichment engine using Grok-4 Fast with parallel processing.
///
/// Build one with `XaiEngine::new` or `XaiEngine::with_model`, then call
/// `init()` before enriching: the HTTP client is only created there, and both
/// `enrich` and `enrich_batch` refuse to run without it.
pub struct XaiEngine {
    /// API key, read from the `XAI_API_KEY` environment variable by the
    /// constructors (empty when the variable is unset).
    api_key: String,
    /// Model identifier sent with every request.
    model: String,
    /// Whether the engine is initialized (set by a successful `init()`).
    ready: bool,
    /// Number of parallel workers used by `enrich_batch` (default: 20).
    parallelism: usize,
    /// Shared HTTP client (built in `init`); `None` until then.
    client: Option<Client>,
}
77
78impl XaiEngine {
79    /// Create a new xAI engine.
80    pub fn new() -> Self {
81        let api_key = std::env::var("XAI_API_KEY").unwrap_or_default();
82        Self {
83            api_key,
84            model: "grok-4-fast".to_string(),
85            ready: false,
86            parallelism: 20,
87            client: None,
88        }
89    }
90
91    /// Create with a specific model.
92    pub fn with_model(model: &str) -> Self {
93        let api_key = std::env::var("XAI_API_KEY").unwrap_or_default();
94        Self {
95            api_key,
96            model: model.to_string(),
97            ready: false,
98            parallelism: 20,
99            client: None,
100        }
101    }
102
103    /// Set parallelism level.
104    pub fn with_parallelism(mut self, n: usize) -> Self {
105        self.parallelism = n;
106        self
107    }
108
109    /// Run inference via xAI API (blocking, thread-safe).
110    fn run_inference_blocking(
111        client: &Client,
112        api_key: &str,
113        model: &str,
114        text: &str,
115    ) -> Result<String> {
116        let prompt = format!("{}\n\n{}", EXTRACTION_PROMPT, text);
117
118        let request = XaiRequest {
119            model: model.to_string(),
120            input: vec![
121                InputMessage {
122                    role: "system".to_string(),
123                    content:
124                        "You are a memory extraction assistant that extracts structured facts."
125                            .to_string(),
126                },
127                InputMessage {
128                    role: "user".to_string(),
129                    content: prompt,
130                },
131            ],
132        };
133
134        let response = client
135            .post("https://api.x.ai/v1/responses")
136            .header("Authorization", format!("Bearer {}", api_key))
137            .header("Content-Type", "application/json")
138            .json(&request)
139            .send()
140            .map_err(|e| anyhow!("xAI API request failed: {}", e))?;
141
142        if !response.status().is_success() {
143            let status = response.status();
144            let body = response.text().unwrap_or_default();
145            return Err(anyhow!("xAI API error {}: {}", status, body));
146        }
147
148        let xai_response: XaiResponse = response
149            .json()
150            .map_err(|e| anyhow!("Failed to parse xAI response: {}", e))?;
151
152        // Extract text from the nested response structure
153        xai_response
154            .output
155            .and_then(|outputs| outputs.into_iter().next())
156            .and_then(|output| output.content)
157            .and_then(|contents| contents.into_iter().next())
158            .and_then(|content| content.text)
159            .ok_or_else(|| anyhow!("No response from xAI"))
160    }
161
162    /// Parse the LLM output into memory cards.
163    fn parse_output(output: &str, frame_id: u64, uri: &str, timestamp: i64) -> Vec<MemoryCard> {
164        let mut cards = Vec::new();
165
166        if output.contains("MEMORY_NONE") {
167            return cards;
168        }
169
170        for block in output.split("MEMORY_START") {
171            let block = block.trim();
172            if block.is_empty() || !block.contains("MEMORY_END") {
173                continue;
174            }
175
176            let block = block.split("MEMORY_END").next().unwrap_or("").trim();
177
178            let mut kind = None;
179            let mut entity = None;
180            let mut slot = None;
181            let mut value = None;
182            let mut polarity = Polarity::Neutral;
183
184            for line in block.lines() {
185                let line = line.trim();
186                if let Some(rest) = line.strip_prefix("kind:") {
187                    kind = parse_memory_kind(rest.trim());
188                } else if let Some(rest) = line.strip_prefix("entity:") {
189                    entity = Some(rest.trim().to_string());
190                } else if let Some(rest) = line.strip_prefix("slot:") {
191                    slot = Some(rest.trim().to_string());
192                } else if let Some(rest) = line.strip_prefix("value:") {
193                    value = Some(rest.trim().to_string());
194                } else if let Some(rest) = line.strip_prefix("polarity:") {
195                    polarity = parse_polarity(rest.trim());
196                }
197            }
198
199            if let (Some(k), Some(e), Some(s), Some(v)) = (kind, entity, slot, value) {
200                if !e.is_empty() && !s.is_empty() && !v.is_empty() {
201                    match MemoryCardBuilder::new()
202                        .kind(k)
203                        .entity(&e)
204                        .slot(&s)
205                        .value(&v)
206                        .polarity(polarity)
207                        .source(frame_id, Some(uri.to_string()))
208                        .document_date(timestamp)
209                        .engine("xai:grok-4-fast", "1.0.0")
210                        .build(0)
211                    {
212                        Ok(card) => cards.push(card),
213                        Err(err) => {
214                            warn!("Failed to build memory card: {}", err);
215                        }
216                    }
217                }
218            }
219        }
220
221        cards
222    }
223
224    /// Process multiple frames in parallel and return all cards.
225    pub fn enrich_batch(
226        &self,
227        contexts: Vec<EnrichmentContext>,
228    ) -> Result<Vec<(u64, Vec<MemoryCard>)>> {
229        let client = self
230            .client
231            .as_ref()
232            .ok_or_else(|| anyhow!("xAI engine not initialized (init() not called)"))?
233            .clone();
234        let client = Arc::new(client);
235        let api_key = Arc::new(self.api_key.clone());
236        let model = Arc::new(self.model.clone());
237        let total = contexts.len();
238
239        info!(
240            "Starting parallel enrichment of {} frames with {} workers",
241            total, self.parallelism
242        );
243
244        let pool = rayon::ThreadPoolBuilder::new()
245            .num_threads(self.parallelism)
246            .build()
247            .map_err(|err| anyhow!("failed to build enrichment thread pool: {err}"))?;
248
249        let results: Vec<(u64, Vec<MemoryCard>)> = pool.install(|| {
250            contexts
251                .into_par_iter()
252                .enumerate()
253                .map(|(i, ctx)| {
254                    if ctx.text.is_empty() {
255                        return (ctx.frame_id, vec![]);
256                    }
257
258                    if i > 0 && i % 50 == 0 {
259                        info!("Enrichment progress: {}/{} frames", i, total);
260                    }
261
262                    match Self::run_inference_blocking(&client, &api_key, &model, &ctx.text) {
263                        Ok(output) => {
264                            debug!(
265                                "xAI output for frame {}: {}",
266                                ctx.frame_id,
267                                &output[..output.len().min(100)]
268                            );
269                            let cards =
270                                Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
271                            (ctx.frame_id, cards)
272                        }
273                        Err(err) => {
274                            warn!("xAI inference failed for frame {}: {}", ctx.frame_id, err);
275                            (ctx.frame_id, vec![])
276                        }
277                    }
278                })
279                .collect()
280        });
281
282        info!(
283            "Parallel enrichment complete: {} frames processed",
284            results.len()
285        );
286        Ok(results)
287    }
288}
289
290fn parse_memory_kind(s: &str) -> Option<MemoryKind> {
291    match s.to_lowercase().as_str() {
292        "fact" => Some(MemoryKind::Fact),
293        "preference" => Some(MemoryKind::Preference),
294        "event" => Some(MemoryKind::Event),
295        "profile" => Some(MemoryKind::Profile),
296        "relationship" => Some(MemoryKind::Relationship),
297        "other" => Some(MemoryKind::Other),
298        _ => None,
299    }
300}
301
302fn parse_polarity(s: &str) -> Polarity {
303    match s.to_lowercase().as_str() {
304        "positive" => Polarity::Positive,
305        "negative" => Polarity::Negative,
306        _ => Polarity::Neutral,
307    }
308}
309
310impl EnrichmentEngine for XaiEngine {
311    fn kind(&self) -> &str {
312        "xai:grok-4-fast"
313    }
314
315    fn version(&self) -> &str {
316        "1.0.0"
317    }
318
319    fn init(&mut self) -> memvid_core::Result<()> {
320        if self.api_key.is_empty() {
321            return Err(memvid_core::MemvidError::EmbeddingFailed {
322                reason: "XAI_API_KEY environment variable not set".into(),
323            });
324        }
325        let client = crate::http::blocking_client(Duration::from_secs(60)).map_err(|err| {
326            memvid_core::MemvidError::EmbeddingFailed {
327                reason: format!("Failed to create xAI HTTP client: {err}").into(),
328            }
329        })?;
330        self.client = Some(client);
331        self.ready = true;
332        Ok(())
333    }
334
335    fn is_ready(&self) -> bool {
336        self.ready
337    }
338
339    fn enrich(&self, ctx: &EnrichmentContext) -> EnrichmentResult {
340        if ctx.text.is_empty() {
341            return EnrichmentResult::empty();
342        }
343
344        let client = match self.client.as_ref() {
345            Some(client) => client,
346            None => {
347                return EnrichmentResult::failed(
348                    "xAI engine not initialized (init() not called)".to_string(),
349                )
350            }
351        };
352
353        match Self::run_inference_blocking(client, &self.api_key, &self.model, &ctx.text) {
354            Ok(output) => {
355                debug!("xAI output for frame {}: {}", ctx.frame_id, output);
356                let cards = Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
357                EnrichmentResult::success(cards)
358            }
359            Err(err) => EnrichmentResult::failed(format!("xAI inference failed: {}", err)),
360        }
361    }
362}
363
364impl Default for XaiEngine {
365    fn default() -> Self {
366        Self::new()
367    }
368}