1use anyhow::{anyhow, Result};
7use memvid_core::enrich::{EnrichmentContext, EnrichmentEngine, EnrichmentResult};
8use memvid_core::types::{MemoryCard, MemoryCardBuilder, MemoryKind, Polarity};
9use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
10use reqwest::blocking::Client;
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use std::time::Duration;
14use tracing::{debug, info, warn};
15
/// Instruction prompt placed in front of each frame's text before it is sent to Gemini.
///
/// It asks the model for `MEMORY_START` … `MEMORY_END` records made of `kind`, `entity`,
/// `slot`, `value` and `polarity` lines, or the single token `MEMORY_NONE` when nothing
/// can be extracted. The reply parser depends on this exact format, so keep the two in sync.
const EXTRACTION_PROMPT: &str = r#"You are a memory extraction assistant. Extract structured facts from the text.

For each distinct fact, preference, event, or relationship mentioned, output a memory card in this exact format:
MEMORY_START
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about, use "user" for the human in the conversation>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If there are no clear facts to extract, output MEMORY_NONE.

Extract memories from this text:
"#;
33
34#[derive(Debug, Serialize)]
36struct GeminiRequest {
37 contents: Vec<GeminiContent>,
38 #[serde(rename = "generationConfig")]
39 generation_config: GenerationConfig,
40}
41
42#[derive(Debug, Serialize)]
43struct GeminiContent {
44 parts: Vec<GeminiPart>,
45}
46
47#[derive(Debug, Serialize)]
48struct GeminiPart {
49 text: String,
50}
51
52#[derive(Debug, Serialize)]
53struct GenerationConfig {
54 temperature: f32,
55 #[serde(rename = "maxOutputTokens")]
56 max_output_tokens: u32,
57}
58
59#[derive(Debug, Deserialize)]
61struct GeminiResponse {
62 candidates: Option<Vec<Candidate>>,
63}
64
65#[derive(Debug, Deserialize)]
66struct Candidate {
67 content: CandidateContent,
68}
69
70#[derive(Debug, Deserialize)]
71struct CandidateContent {
72 parts: Vec<CandidatePart>,
73}
74
75#[derive(Debug, Deserialize)]
76struct CandidatePart {
77 text: Option<String>,
78}
79
80pub struct GeminiEngine {
82 api_key: String,
84 model: String,
86 ready: bool,
88 parallelism: usize,
90 client: Option<Client>,
92}
93
94impl GeminiEngine {
95 pub fn new() -> Self {
97 let api_key = std::env::var("GOOGLE_API_KEY")
98 .or_else(|_| std::env::var("GEMINI_API_KEY"))
99 .unwrap_or_default();
100 Self {
101 api_key,
102 model: "gemini-2.5-flash".to_string(),
103 ready: false,
104 parallelism: 20,
105 client: None,
106 }
107 }
108
109 pub fn with_model(model: &str) -> Self {
111 let api_key = std::env::var("GOOGLE_API_KEY")
112 .or_else(|_| std::env::var("GEMINI_API_KEY"))
113 .unwrap_or_default();
114 Self {
115 api_key,
116 model: model.to_string(),
117 ready: false,
118 parallelism: 20,
119 client: None,
120 }
121 }
122
123 pub fn with_parallelism(mut self, n: usize) -> Self {
125 self.parallelism = n;
126 self
127 }
128
129
130 fn run_inference_blocking(
132 client: &Client,
133 api_key: &str,
134 model: &str,
135 text: &str,
136 ) -> Result<String> {
137 let prompt = format!("{}\n\n{}", EXTRACTION_PROMPT, text);
138
139 let request = GeminiRequest {
140 contents: vec![GeminiContent {
141 parts: vec![GeminiPart { text: prompt }],
142 }],
143 generation_config: GenerationConfig {
144 temperature: 0.0,
145 max_output_tokens: 1024,
146 },
147 };
148
149 let url = format!(
150 "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
151 model, api_key
152 );
153
154 let response = client
155 .post(&url)
156 .header("Content-Type", "application/json")
157 .json(&request)
158 .send()
159 .map_err(|e| anyhow!("Gemini API request failed: {}", e))?;
160
161 if !response.status().is_success() {
162 let status = response.status();
163 let body = response.text().unwrap_or_default();
164 return Err(anyhow!("Gemini API error {}: {}", status, body));
165 }
166
167 let gemini_response: GeminiResponse = response
168 .json()
169 .map_err(|e| anyhow!("Failed to parse Gemini response: {}", e))?;
170
171 gemini_response
172 .candidates
173 .and_then(|c| c.into_iter().next())
174 .and_then(|c| c.content.parts.into_iter().next())
175 .and_then(|p| p.text)
176 .ok_or_else(|| anyhow!("No text response from Gemini"))
177 }
178
179 fn parse_output(output: &str, frame_id: u64, uri: &str, timestamp: i64) -> Vec<MemoryCard> {
181 let mut cards = Vec::new();
182
183 if output.contains("MEMORY_NONE") {
184 return cards;
185 }
186
187 for block in output.split("MEMORY_START") {
188 let block = block.trim();
189 if block.is_empty() || !block.contains("MEMORY_END") {
190 continue;
191 }
192
193 let block = block.split("MEMORY_END").next().unwrap_or("").trim();
194
195 let mut kind = None;
196 let mut entity = None;
197 let mut slot = None;
198 let mut value = None;
199 let mut polarity = Polarity::Neutral;
200
201 for line in block.lines() {
202 let line = line.trim();
203 if let Some(rest) = line.strip_prefix("kind:") {
204 kind = parse_memory_kind(rest.trim());
205 } else if let Some(rest) = line.strip_prefix("entity:") {
206 entity = Some(rest.trim().to_string());
207 } else if let Some(rest) = line.strip_prefix("slot:") {
208 slot = Some(rest.trim().to_string());
209 } else if let Some(rest) = line.strip_prefix("value:") {
210 value = Some(rest.trim().to_string());
211 } else if let Some(rest) = line.strip_prefix("polarity:") {
212 polarity = parse_polarity(rest.trim());
213 }
214 }
215
216 if let (Some(k), Some(e), Some(s), Some(v)) = (kind, entity, slot, value) {
217 if !e.is_empty() && !s.is_empty() && !v.is_empty() {
218 match MemoryCardBuilder::new()
219 .kind(k)
220 .entity(&e)
221 .slot(&s)
222 .value(&v)
223 .polarity(polarity)
224 .source(frame_id, Some(uri.to_string()))
225 .document_date(timestamp)
226 .engine("gemini:gemini-2.5-flash", "1.0.0")
227 .build(0)
228 {
229 Ok(card) => cards.push(card),
230 Err(err) => {
231 warn!("Failed to build memory card: {}", err);
232 }
233 }
234 }
235 }
236 }
237
238 cards
239 }
240
241 pub fn enrich_batch(
243 &self,
244 contexts: Vec<EnrichmentContext>,
245 ) -> Result<Vec<(u64, Vec<MemoryCard>)>> {
246 let client = self
247 .client
248 .as_ref()
249 .ok_or_else(|| anyhow!("Gemini engine not initialized (init() not called)"))?
250 .clone();
251 let client = Arc::new(client);
252 let api_key = Arc::new(self.api_key.clone());
253 let model = Arc::new(self.model.clone());
254 let total = contexts.len();
255
256 info!(
257 "Starting parallel enrichment of {} frames with {} workers",
258 total, self.parallelism
259 );
260
261 let pool = rayon::ThreadPoolBuilder::new()
262 .num_threads(self.parallelism)
263 .build()
264 .map_err(|err| anyhow!("failed to build enrichment thread pool: {err}"))?;
265
266 let results: Vec<(u64, Vec<MemoryCard>)> = pool.install(|| {
267 contexts
268 .into_par_iter()
269 .enumerate()
270 .map(|(i, ctx)| {
271 if ctx.text.is_empty() {
272 return (ctx.frame_id, vec![]);
273 }
274
275 if i > 0 && i % 50 == 0 {
276 info!("Enrichment progress: {}/{} frames", i, total);
277 }
278
279 match Self::run_inference_blocking(&client, &api_key, &model, &ctx.text) {
280 Ok(output) => {
281 debug!(
282 "Gemini output for frame {}: {}",
283 ctx.frame_id,
284 &output[..output.len().min(100)]
285 );
286 let cards =
287 Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
288 (ctx.frame_id, cards)
289 }
290 Err(err) => {
291 warn!(
292 "Gemini inference failed for frame {}: {}",
293 ctx.frame_id, err
294 );
295 (ctx.frame_id, vec![])
296 }
297 }
298 })
299 .collect()
300 });
301
302 info!(
303 "Parallel enrichment complete: {} frames processed",
304 results.len()
305 );
306 Ok(results)
307 }
308}
309
310fn parse_memory_kind(s: &str) -> Option<MemoryKind> {
311 match s.to_lowercase().as_str() {
312 "fact" => Some(MemoryKind::Fact),
313 "preference" => Some(MemoryKind::Preference),
314 "event" => Some(MemoryKind::Event),
315 "profile" => Some(MemoryKind::Profile),
316 "relationship" => Some(MemoryKind::Relationship),
317 "other" => Some(MemoryKind::Other),
318 _ => None,
319 }
320}
321
322fn parse_polarity(s: &str) -> Polarity {
323 match s.to_lowercase().as_str() {
324 "positive" => Polarity::Positive,
325 "negative" => Polarity::Negative,
326 _ => Polarity::Neutral,
327 }
328}
329
330impl EnrichmentEngine for GeminiEngine {
331 fn kind(&self) -> &str {
332 "gemini:gemini-2.5-flash"
333 }
334
335 fn version(&self) -> &str {
336 "1.0.0"
337 }
338
339 fn init(&mut self) -> memvid_core::Result<()> {
340 if self.api_key.is_empty() {
341 return Err(memvid_core::MemvidError::EmbeddingFailed {
342 reason: "GOOGLE_API_KEY or GEMINI_API_KEY environment variable not set".into(),
343 });
344 }
345 let client = crate::http::blocking_client(Duration::from_secs(60)).map_err(|err| {
346 memvid_core::MemvidError::EmbeddingFailed {
347 reason: format!("Failed to create Gemini HTTP client: {err}").into(),
348 }
349 })?;
350 self.client = Some(client);
351 self.ready = true;
352 Ok(())
353 }
354
355 fn is_ready(&self) -> bool {
356 self.ready
357 }
358
359 fn enrich(&self, ctx: &EnrichmentContext) -> EnrichmentResult {
360 if ctx.text.is_empty() {
361 return EnrichmentResult::empty();
362 }
363
364 let client = match self.client.as_ref() {
365 Some(client) => client,
366 None => {
367 return EnrichmentResult::failed(
368 "Gemini engine not initialized (init() not called)".to_string(),
369 )
370 }
371 };
372
373 match Self::run_inference_blocking(client, &self.api_key, &self.model, &ctx.text) {
374 Ok(output) => {
375 debug!("Gemini output for frame {}: {}", ctx.frame_id, output);
376 let cards = Self::parse_output(&output, ctx.frame_id, &ctx.uri, ctx.timestamp);
377 EnrichmentResult::success(cards)
378 }
379 Err(err) => EnrichmentResult::failed(format!("Gemini inference failed: {}", err)),
380 }
381 }
382}
383
384impl Default for GeminiEngine {
385 fn default() -> Self {
386 Self::new()
387 }
388}