1use anyhow::{anyhow, Result};
7use memvid_core::enrich::{EnrichmentContext, EnrichmentEngine, EnrichmentResult};
8use memvid_core::types::{MemoryCard, MemoryCardBuilder, MemoryKind, Polarity};
9use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
10use reqwest::blocking::Client;
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use std::time::Duration;
14use tracing::{debug, info, warn};
15
/// Instruction prompt placed in front of each frame's text to build the user message.
///
/// It asks the model to emit one `MEMORY_START` … `MEMORY_END` block per memory,
/// each with `kind:`, `entity:`, `slot:`, `value:` and `polarity:` lines. When
/// nothing can be extracted, the model should emit `MEMORY_NONE`.
/// `XaiEngine::parse_output` consumes exactly this format, so keep the marker
/// and key spellings in sync with it.
const EXTRACTION_PROMPT: &str = r#"You are a memory extraction assistant. Extract structured facts from the text.

For each distinct fact, preference, event, or relationship mentioned, output a memory card in this exact format:
MEMORY_START
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about, use "user" for the human in the conversation>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If there are no clear facts to extract, output MEMORY_NONE.

Extract memories from this text:
"#;
33
/// One entry of the `input` array sent to the xAI Responses endpoint.
#[derive(Debug, Serialize, Clone)]
struct InputMessage {
    /// Message role; this file sends `"system"` and `"user"`.
    role: String,
    /// Message text.
    content: String,
}
40
/// JSON body POSTed to `https://api.x.ai/v1/responses`.
#[derive(Debug, Serialize)]
struct XaiRequest {
    /// Model name, e.g. `grok-4-fast`.
    model: String,
    /// Conversation sent to the model (this file sends a system message then the user prompt).
    input: Vec<InputMessage>,
}
47
/// Minimal view of the xAI Responses API reply.
///
/// Only the fields needed to reach the generated text are declared. serde
/// ignores every other field of the payload.
#[derive(Debug, Deserialize)]
struct XaiResponse {
    /// Output items; `None` when the field is absent or `null`.
    output: Option<Vec<OutputItem>>,
}
53
/// One element of the response's `output` array.
#[derive(Debug, Deserialize)]
struct OutputItem {
    /// Content parts of this item; `None` for items that carry no content.
    content: Option<Vec<ContentItem>>,
}
58
/// One content part of an output item.
#[derive(Debug, Deserialize)]
struct ContentItem {
    /// Generated text, if this part carries any.
    text: Option<String>,
}
63
/// Enrichment engine that sends frame text to xAI's Responses API and parses
/// the reply into [`MemoryCard`]s.
pub struct XaiEngine {
    /// Bearer token read from `XAI_API_KEY` at construction; empty if unavailable.
    api_key: String,
    /// Model name sent with every request.
    model: String,
    /// `true` once `init()` has succeeded; `false` after construction.
    ready: bool,
    /// Number of worker threads used by `enrich_batch`.
    parallelism: usize,
    /// Blocking HTTP client; `None` until `init()` builds it.
    client: Option<Client>,
}
77
78impl XaiEngine {
79 pub fn new() -> Self {
81 let api_key = std::env::var("XAI_API_KEY").unwrap_or_default();
82 Self {
83 api_key,
84 model: "grok-4-fast".to_string(),
85 ready: false,
86 parallelism: 20,
87 client: None,
88 }
89 }
90
91 pub fn with_model(model: &str) -> Self {
93 let api_key = std::env::var("XAI_API_KEY").unwrap_or_default();
94 Self {
95 api_key,
96 model: model.to_string(),
97 ready: false,
98 parallelism: 20,
99 client: None,
100 }
101 }
102
103 pub fn with_parallelism(mut self, n: usize) -> Self {
105 self.parallelism = n;
106 self
107 }
108
109 fn run_inference_blocking(
111 client: &Client,
112 api_key: &str,
113 model: &str,
114 text: &str,
115 ) -> Result<String> {
116 let prompt = format!("{}\n\n{}", EXTRACTION_PROMPT, text);
117
118 let request = XaiRequest {
119 model: model.to_string(),
120 input: vec![
121 InputMessage {
122 role: "system".to_string(),
123 content:
124 "You are a memory extraction assistant that extracts structured facts."
125 .to_string(),
126 },
127 InputMessage {
128 role: "user".to_string(),
129 content: prompt,
130 },
131 ],
132 };
133
134 let response = client
135 .post("https://api.x.ai/v1/responses")
136 .header("Authorization", format!("Bearer {}", api_key))
137 .header("Content-Type", "application/json")
138 .json(&request)
139 .send()
140 .map_err(|e| anyhow!("xAI API request failed: {}", e))?;
141
142 if !response.status().is_success() {
143 let status = response.status();
144 let body = response.text().unwrap_or_default();
145 return Err(anyhow!("xAI API error {}: {}", status, body));
146 }
147
148 let xai_response: XaiResponse = response
149 .json()
150 .map_err(|e| anyhow!("Failed to parse xAI response: {}", e))?;
151
152 xai_response
154 .output
155 .and_then(|outputs| outputs.into_iter().next())
156 .and_then(|output| output.content)
157 .and_then(|contents| contents.into_iter().next())
158 .and_then(|content| content.text)
159 .ok_or_else(|| anyhow!("No response from xAI"))
160 }
161
162 fn parse_output(output: &str, frame_id: u64, uri: &str, timestamp: i64) -> Vec<MemoryCard> {
164 let mut cards = Vec::new();
165
166 if output.contains("MEMORY_NONE") {
167 return cards;
168 }
169
170 for block in output.split("MEMORY_START") {
171 let block = block.trim();
172 if block.is_empty() || !block.contains("MEMORY_END") {
173 continue;
174 }
175
176 let block = block.split("MEMORY_END").next().unwrap_or("").trim();
177
178 let mut kind = None;
179 let mut entity = None;
180 let mut slot = None;
181 let mut value = None;
182 let mut polarity = Polarity::Neutral;
183
184 for line in block.lines() {
185 let line = line.trim();
186 if let Some(rest) = line.strip_prefix("kind:") {
187 kind = parse_memory_kind(rest.trim());
188 } else if let Some(rest) = line.strip_prefix("entity:") {
189 entity = Some(rest.trim().to_string());
190 } else if let Some(rest) = line.strip_prefix("slot:") {
191 slot = Some(rest.trim().to_string());
192 } else if let Some(rest) = line.strip_prefix("value:") {
193 value = Some(rest.trim().to_string());
194 } else if let Some(rest) = line.strip_prefix("polarity:") {
195 polarity = parse_polarity(rest.trim());
196 }
197 }
198
199 if let (Some(k), Some(e), Some(s), Some(v)) = (kind, entity, slot, value) {
200 if !e.is_empty() && !s.is_empty() && !v.is_empty() {
201 match MemoryCardBuilder::new()
202 .kind(k)
203 .entity(&e)
204 .slot(&s)
205 .value(&v)
206 .polarity(polarity)
207 .source(frame_id, Some(uri.to_string()))
208 .document_date(timestamp)
209 .engine("xai:grok-4-fast", "1.0.0")
210 .build(0)
211 {
212 Ok(card) => cards.push(card),
213 Err(err) => {
214 warn!("Failed to build memory card: {}", err);
215 }
216 }
217 }
218 }
219 }
220
221 cards
222 }
223
224 pub fn enrich_batch(
226 &self,
227 contexts: Vec<EnrichmentContext>,
228 ) -> Result<Vec<(u64, Vec<MemoryCard>)>> {
229 let client = self
230 .client
231 .as_ref()
232 .ok_or_else(|| anyhow!("xAI engine not initialized (init() not called)"))?
233 .clone();
234 let client = Arc::new(client);
235 let api_key = Arc::new(self.api_key.clone());
236 let model = Arc::new(self.model.clone());
237 let total = contexts.len();
238
239 info!(
240 "Starting parallel enrichment of {} frames with {} workers",
241 total, self.parallelism
242 );
243
244 let pool = rayon::ThreadPoolBuilder::new()
245 .num_threads(self.parallelism)
246 .build()
247 .map_err(|err| anyhow!("failed to build enrichment thread pool: {err}"))?;
248
249 let results: Vec<(u64, Vec<MemoryCard>)> = pool.install(|| {
250 contexts
251 .into_par_iter()
252 .enumerate()
253 .map(|(i, ctx)| {
254 if ctx.text.is_empty() {
255 return (ctx.frame_id, vec![]);
256 }
257
258 if i > 0 && i % 50 == 0 {
259 info!("Enrichment progress: {}/{} frames", i, total);
260 }
261
262 match Self::run_inference_blocking(&client, &api_key, &model, &ctx.text) {
263 Ok(output) => {
264 debug!(
265 "xAI output for frame {}: {}",
266 ctx.frame_id,
267 &output[..output.len().min(100)]
268 );
269 let cards =
270 Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
271 (ctx.frame_id, cards)
272 }
273 Err(err) => {
274 warn!("xAI inference failed for frame {}: {}", ctx.frame_id, err);
275 (ctx.frame_id, vec![])
276 }
277 }
278 })
279 .collect()
280 });
281
282 info!(
283 "Parallel enrichment complete: {} frames processed",
284 results.len()
285 );
286 Ok(results)
287 }
288}
289
290fn parse_memory_kind(s: &str) -> Option<MemoryKind> {
291 match s.to_lowercase().as_str() {
292 "fact" => Some(MemoryKind::Fact),
293 "preference" => Some(MemoryKind::Preference),
294 "event" => Some(MemoryKind::Event),
295 "profile" => Some(MemoryKind::Profile),
296 "relationship" => Some(MemoryKind::Relationship),
297 "other" => Some(MemoryKind::Other),
298 _ => None,
299 }
300}
301
302fn parse_polarity(s: &str) -> Polarity {
303 match s.to_lowercase().as_str() {
304 "positive" => Polarity::Positive,
305 "negative" => Polarity::Negative,
306 _ => Polarity::Neutral,
307 }
308}
309
310impl EnrichmentEngine for XaiEngine {
311 fn kind(&self) -> &str {
312 "xai:grok-4-fast"
313 }
314
315 fn version(&self) -> &str {
316 "1.0.0"
317 }
318
319 fn init(&mut self) -> memvid_core::Result<()> {
320 if self.api_key.is_empty() {
321 return Err(memvid_core::MemvidError::EmbeddingFailed {
322 reason: "XAI_API_KEY environment variable not set".into(),
323 });
324 }
325 let client = crate::http::blocking_client(Duration::from_secs(60)).map_err(|err| {
326 memvid_core::MemvidError::EmbeddingFailed {
327 reason: format!("Failed to create xAI HTTP client: {err}").into(),
328 }
329 })?;
330 self.client = Some(client);
331 self.ready = true;
332 Ok(())
333 }
334
335 fn is_ready(&self) -> bool {
336 self.ready
337 }
338
339 fn enrich(&self, ctx: &EnrichmentContext) -> EnrichmentResult {
340 if ctx.text.is_empty() {
341 return EnrichmentResult::empty();
342 }
343
344 let client = match self.client.as_ref() {
345 Some(client) => client,
346 None => {
347 return EnrichmentResult::failed(
348 "xAI engine not initialized (init() not called)".to_string(),
349 )
350 }
351 };
352
353 match Self::run_inference_blocking(client, &self.api_key, &self.model, &ctx.text) {
354 Ok(output) => {
355 debug!("xAI output for frame {}: {}", ctx.frame_id, output);
356 let cards = Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
357 EnrichmentResult::success(cards)
358 }
359 Err(err) => EnrichmentResult::failed(format!("xAI inference failed: {}", err)),
360 }
361 }
362}
363
364impl Default for XaiEngine {
365 fn default() -> Self {
366 Self::new()
367 }
368}