memvid_cli/enrich/
gemini.rs

1//! Gemini (Google) enrichment engine using Gemini 2.5 Flash.
2//!
3//! This engine uses the Google AI Studio API to extract structured memory cards
4//! from text content. Supports parallel batch processing for speed.
5
6use anyhow::{anyhow, Result};
7use memvid_core::enrich::{EnrichmentContext, EnrichmentEngine, EnrichmentResult};
8use memvid_core::types::{MemoryCard, MemoryCardBuilder, MemoryKind, Polarity};
9use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
10use reqwest::blocking::Client;
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use std::time::Duration;
14use tracing::{debug, info, warn};
15
/// Instruction prompt that is sent ahead of the frame text on every request.
///
/// It asks the model to emit one `MEMORY_START` … `MEMORY_END` block per
/// extracted memory, each made of `kind:`, `entity:`, `slot:`, `value:` and
/// `polarity:` lines, or the single token `MEMORY_NONE` when nothing can be
/// extracted. `GeminiEngine::parse_output` relies on exactly these markers and
/// keys, so the prompt and the parser must be changed together.
const EXTRACTION_PROMPT: &str = r#"You are a memory extraction assistant. Extract structured facts from the text.

For each distinct fact, preference, event, or relationship mentioned, output a memory card in this exact format:
MEMORY_START
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about, use "user" for the human in the conversation>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If there are no clear facts to extract, output MEMORY_NONE.

Extract memories from this text:
"#;
33
34/// Gemini API request
35#[derive(Debug, Serialize)]
36struct GeminiRequest {
37    contents: Vec<GeminiContent>,
38    #[serde(rename = "generationConfig")]
39    generation_config: GenerationConfig,
40}
41
42#[derive(Debug, Serialize)]
43struct GeminiContent {
44    parts: Vec<GeminiPart>,
45}
46
47#[derive(Debug, Serialize)]
48struct GeminiPart {
49    text: String,
50}
51
52#[derive(Debug, Serialize)]
53struct GenerationConfig {
54    temperature: f32,
55    #[serde(rename = "maxOutputTokens")]
56    max_output_tokens: u32,
57}
58
59/// Gemini API response
60#[derive(Debug, Deserialize)]
61struct GeminiResponse {
62    candidates: Option<Vec<Candidate>>,
63}
64
65#[derive(Debug, Deserialize)]
66struct Candidate {
67    content: CandidateContent,
68}
69
70#[derive(Debug, Deserialize)]
71struct CandidateContent {
72    parts: Vec<CandidatePart>,
73}
74
75#[derive(Debug, Deserialize)]
76struct CandidatePart {
77    text: Option<String>,
78}
79
/// Gemini enrichment engine with parallel batch processing.
///
/// The model defaults to `gemini-2.5-flash` and can be overridden with
/// `with_model`. The engine must be initialized through
/// `EnrichmentEngine::init` before it can issue requests, because that is
/// where the HTTP client is created.
pub struct GeminiEngine {
    /// API key, read from `GOOGLE_API_KEY` or `GEMINI_API_KEY` at construction;
    /// empty when neither is available.
    api_key: String,
    /// Model name interpolated into the endpoint URL.
    model: String,
    /// Whether the engine is initialized (set by `init`).
    ready: bool,
    /// Number of threads for the batch thread pool (default: 20).
    parallelism: usize,
    /// Shared HTTP client (built in `init`); `None` until then.
    client: Option<Client>,
}
93
94impl GeminiEngine {
95    /// Create a new Gemini engine.
96    pub fn new() -> Self {
97        let api_key = std::env::var("GOOGLE_API_KEY")
98            .or_else(|_| std::env::var("GEMINI_API_KEY"))
99            .unwrap_or_default();
100        Self {
101            api_key,
102            model: "gemini-2.5-flash".to_string(),
103            ready: false,
104            parallelism: 20,
105            client: None,
106        }
107    }
108
109    /// Create with a specific model.
110    pub fn with_model(model: &str) -> Self {
111        let api_key = std::env::var("GOOGLE_API_KEY")
112            .or_else(|_| std::env::var("GEMINI_API_KEY"))
113            .unwrap_or_default();
114        Self {
115            api_key,
116            model: model.to_string(),
117            ready: false,
118            parallelism: 20,
119            client: None,
120        }
121    }
122
123    /// Set parallelism level.
124    pub fn with_parallelism(mut self, n: usize) -> Self {
125        self.parallelism = n;
126        self
127    }
128
129
130    /// Run inference via Gemini API (blocking, thread-safe).
131    fn run_inference_blocking(
132        client: &Client,
133        api_key: &str,
134        model: &str,
135        text: &str,
136    ) -> Result<String> {
137        let prompt = format!("{}\n\n{}", EXTRACTION_PROMPT, text);
138
139        let request = GeminiRequest {
140            contents: vec![GeminiContent {
141                parts: vec![GeminiPart { text: prompt }],
142            }],
143            generation_config: GenerationConfig {
144                temperature: 0.0,
145                max_output_tokens: 1024,
146            },
147        };
148
149        let url = format!(
150            "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
151            model, api_key
152        );
153
154        let response = client
155            .post(&url)
156            .header("Content-Type", "application/json")
157            .json(&request)
158            .send()
159            .map_err(|e| anyhow!("Gemini API request failed: {}", e))?;
160
161        if !response.status().is_success() {
162            let status = response.status();
163            let body = response.text().unwrap_or_default();
164            return Err(anyhow!("Gemini API error {}: {}", status, body));
165        }
166
167        let gemini_response: GeminiResponse = response
168            .json()
169            .map_err(|e| anyhow!("Failed to parse Gemini response: {}", e))?;
170
171        gemini_response
172            .candidates
173            .and_then(|c| c.into_iter().next())
174            .and_then(|c| c.content.parts.into_iter().next())
175            .and_then(|p| p.text)
176            .ok_or_else(|| anyhow!("No text response from Gemini"))
177    }
178
179    /// Parse the LLM output into memory cards.
180    fn parse_output(output: &str, frame_id: u64, uri: &str, timestamp: i64) -> Vec<MemoryCard> {
181        let mut cards = Vec::new();
182
183        if output.contains("MEMORY_NONE") {
184            return cards;
185        }
186
187        for block in output.split("MEMORY_START") {
188            let block = block.trim();
189            if block.is_empty() || !block.contains("MEMORY_END") {
190                continue;
191            }
192
193            let block = block.split("MEMORY_END").next().unwrap_or("").trim();
194
195            let mut kind = None;
196            let mut entity = None;
197            let mut slot = None;
198            let mut value = None;
199            let mut polarity = Polarity::Neutral;
200
201            for line in block.lines() {
202                let line = line.trim();
203                if let Some(rest) = line.strip_prefix("kind:") {
204                    kind = parse_memory_kind(rest.trim());
205                } else if let Some(rest) = line.strip_prefix("entity:") {
206                    entity = Some(rest.trim().to_string());
207                } else if let Some(rest) = line.strip_prefix("slot:") {
208                    slot = Some(rest.trim().to_string());
209                } else if let Some(rest) = line.strip_prefix("value:") {
210                    value = Some(rest.trim().to_string());
211                } else if let Some(rest) = line.strip_prefix("polarity:") {
212                    polarity = parse_polarity(rest.trim());
213                }
214            }
215
216            if let (Some(k), Some(e), Some(s), Some(v)) = (kind, entity, slot, value) {
217                if !e.is_empty() && !s.is_empty() && !v.is_empty() {
218                    match MemoryCardBuilder::new()
219                        .kind(k)
220                        .entity(&e)
221                        .slot(&s)
222                        .value(&v)
223                        .polarity(polarity)
224                        .source(frame_id, Some(uri.to_string()))
225                        .document_date(timestamp)
226                        .engine("gemini:gemini-2.5-flash", "1.0.0")
227                        .build(0)
228                    {
229                        Ok(card) => cards.push(card),
230                        Err(err) => {
231                            warn!("Failed to build memory card: {}", err);
232                        }
233                    }
234                }
235            }
236        }
237
238        cards
239    }
240
241    /// Process multiple frames in parallel and return all cards.
242    pub fn enrich_batch(
243        &self,
244        contexts: Vec<EnrichmentContext>,
245    ) -> Result<Vec<(u64, Vec<MemoryCard>)>> {
246        let client = self
247            .client
248            .as_ref()
249            .ok_or_else(|| anyhow!("Gemini engine not initialized (init() not called)"))?
250            .clone();
251        let client = Arc::new(client);
252        let api_key = Arc::new(self.api_key.clone());
253        let model = Arc::new(self.model.clone());
254        let total = contexts.len();
255
256        info!(
257            "Starting parallel enrichment of {} frames with {} workers",
258            total, self.parallelism
259        );
260
261        let pool = rayon::ThreadPoolBuilder::new()
262            .num_threads(self.parallelism)
263            .build()
264            .map_err(|err| anyhow!("failed to build enrichment thread pool: {err}"))?;
265
266        let results: Vec<(u64, Vec<MemoryCard>)> = pool.install(|| {
267            contexts
268                .into_par_iter()
269                .enumerate()
270                .map(|(i, ctx)| {
271                    if ctx.text.is_empty() {
272                        return (ctx.frame_id, vec![]);
273                    }
274
275                    if i > 0 && i % 50 == 0 {
276                        info!("Enrichment progress: {}/{} frames", i, total);
277                    }
278
279                    match Self::run_inference_blocking(&client, &api_key, &model, &ctx.text) {
280                        Ok(output) => {
281                            debug!(
282                                "Gemini output for frame {}: {}",
283                                ctx.frame_id,
284                                &output[..output.len().min(100)]
285                            );
286                            let cards =
287                                Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
288                            (ctx.frame_id, cards)
289                        }
290                        Err(err) => {
291                            warn!(
292                                "Gemini inference failed for frame {}: {}",
293                                ctx.frame_id, err
294                            );
295                            (ctx.frame_id, vec![])
296                        }
297                    }
298                })
299                .collect()
300        });
301
302        info!(
303            "Parallel enrichment complete: {} frames processed",
304            results.len()
305        );
306        Ok(results)
307    }
308}
309
310fn parse_memory_kind(s: &str) -> Option<MemoryKind> {
311    match s.to_lowercase().as_str() {
312        "fact" => Some(MemoryKind::Fact),
313        "preference" => Some(MemoryKind::Preference),
314        "event" => Some(MemoryKind::Event),
315        "profile" => Some(MemoryKind::Profile),
316        "relationship" => Some(MemoryKind::Relationship),
317        "other" => Some(MemoryKind::Other),
318        _ => None,
319    }
320}
321
322fn parse_polarity(s: &str) -> Polarity {
323    match s.to_lowercase().as_str() {
324        "positive" => Polarity::Positive,
325        "negative" => Polarity::Negative,
326        _ => Polarity::Neutral,
327    }
328}
329
330impl EnrichmentEngine for GeminiEngine {
331    fn kind(&self) -> &str {
332        "gemini:gemini-2.5-flash"
333    }
334
335    fn version(&self) -> &str {
336        "1.0.0"
337    }
338
339    fn init(&mut self) -> memvid_core::Result<()> {
340        if self.api_key.is_empty() {
341            return Err(memvid_core::MemvidError::EmbeddingFailed {
342                reason: "GOOGLE_API_KEY or GEMINI_API_KEY environment variable not set".into(),
343            });
344        }
345        let client = crate::http::blocking_client(Duration::from_secs(60)).map_err(|err| {
346            memvid_core::MemvidError::EmbeddingFailed {
347                reason: format!("Failed to create Gemini HTTP client: {err}").into(),
348            }
349        })?;
350        self.client = Some(client);
351        self.ready = true;
352        Ok(())
353    }
354
355    fn is_ready(&self) -> bool {
356        self.ready
357    }
358
359    fn enrich(&self, ctx: &EnrichmentContext) -> EnrichmentResult {
360        if ctx.text.is_empty() {
361            return EnrichmentResult::empty();
362        }
363
364        let client = match self.client.as_ref() {
365            Some(client) => client,
366            None => {
367                return EnrichmentResult::failed(
368                    "Gemini engine not initialized (init() not called)".to_string(),
369                )
370            }
371        };
372
373        match Self::run_inference_blocking(client, &self.api_key, &self.model, &ctx.text) {
374            Ok(output) => {
375                debug!("Gemini output for frame {}: {}", ctx.frame_id, output);
376                let cards = Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
377                EnrichmentResult::success(cards)
378            }
379            Err(err) => EnrichmentResult::failed(format!("Gemini inference failed: {}", err)),
380        }
381    }
382}
383
384impl Default for GeminiEngine {
385    fn default() -> Self {
386        Self::new()
387    }
388}