1use crate::error::{Error, Result};
4use crate::vector::VectorStore;
5use rusqlite::Connection;
6use serde::{Deserialize, Serialize};
7use std::io::Write;
8use std::process::{Command, Stdio};
9
/// Settings that select the embedding model and state the vector length it is expected to produce.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Model identifier handed to `SentenceTransformer` by the Python helper script.
    pub model_name: String,
    /// Length every parsed embedding must have; `parse_embeddings` rejects any other length.
    pub dimension: usize,
}
16
17impl Default for EmbeddingConfig {
18 fn default() -> Self {
19 Self {
20 model_name: "all-MiniLM-L6-v2".to_string(),
21 dimension: 384,
22 }
23 }
24}
25
/// Generates text embeddings by running a Python subprocess and stores them for database entities.
pub struct EmbeddingGenerator {
    /// Model name and expected vector dimension used for generation and validation.
    config: EmbeddingConfig,
    /// When `true`, only entities whose vector is missing or all-zero are (re)embedded.
    /// `with_force(true)` sets this to `false` so every entity is regenerated.
    pub skip_existing: bool,
}
32
33impl EmbeddingGenerator {
34 pub fn new() -> Self {
36 Self {
37 config: EmbeddingConfig::default(),
38 skip_existing: true,
39 }
40 }
41
42 pub fn with_config(config: EmbeddingConfig) -> Self {
44 Self {
45 config,
46 skip_existing: true,
47 }
48 }
49
50 pub fn with_force(mut self, force: bool) -> Self {
53 self.skip_existing = !force;
54 self
55 }
56
57 pub fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
59 if texts.is_empty() {
60 return Ok(Vec::new());
61 }
62
63 let python_script = self.generate_python_script()?;
64
65 let texts_json = serde_json::to_string(&texts)
67 .map_err(|e| Error::Other(format!("Failed to serialize texts: {}", e)))?;
68
69 let mut child = Command::new("python3")
71 .arg("-c")
72 .arg(&python_script)
73 .stdin(Stdio::piped())
74 .stdout(Stdio::piped())
75 .stderr(Stdio::piped())
76 .spawn()
77 .map_err(|e| Error::Other(format!("Failed to spawn Python: {}", e)))?;
78
79 if let Some(mut stdin) = child.stdin.take() {
81 stdin
82 .write_all(texts_json.as_bytes())
83 .map_err(|e| Error::Other(format!("Failed to write to stdin: {}", e)))?;
84 }
85
86 let output = child
88 .wait_with_output()
89 .map_err(|e| Error::Other(format!("Failed to read Python output: {}", e)))?;
90
91 if !output.status.success() {
92 let stderr = String::from_utf8_lossy(&output.stderr);
93 return Err(Error::Other(format!("Python script failed: {}", stderr)));
94 }
95
96 let stdout = String::from_utf8_lossy(&output.stdout);
98 self.parse_embeddings(&stdout)
99 }
100
101 fn generate_python_script(&self) -> Result<String> {
103 let script = format!(
104 r#"
105import sys
106import json
107import numpy as np
108
109try:
110 from sentence_transformers import SentenceTransformer
111
112 # Load model
113 model = SentenceTransformer('{}')
114
115 # Read texts from stdin
116 texts_json = sys.stdin.read()
117 texts = json.loads(texts_json)
118
119 # Generate embeddings
120 embeddings = model.encode(texts, convert_to_numpy=True)
121
122 # Convert to list and print as JSON
123 embeddings_list = embeddings.tolist()
124 print(json.dumps(embeddings_list))
125
126except ImportError:
127 print("{{\"error\": \"sentence-transformers not installed. Run: pip install sentence-transformers\"}}", file=sys.stderr)
128 sys.exit(1)
129except Exception as e:
130 print("{{\"error\": \"{{}}\"}}".format(str(e)), file=sys.stderr)
131 sys.exit(1)
132"#,
133 self.config.model_name
134 );
135
136 Ok(script)
137 }
138
139 fn parse_embeddings(&self, output: &str) -> Result<Vec<Vec<f32>>> {
141 let embeddings: Vec<Vec<f32>> = serde_json::from_str(output)
142 .map_err(|e| Error::Other(format!("Failed to parse embeddings: {}", e)))?;
143
144 for embedding in &embeddings {
146 if embedding.len() != self.config.dimension {
147 return Err(Error::InvalidVectorDimension {
148 expected: self.config.dimension,
149 actual: embedding.len(),
150 });
151 }
152 }
153
154 Ok(embeddings)
155 }
156
157 pub fn generate_for_papers(&self, conn: &Connection) -> Result<EmbeddingStats> {
159 let entities = get_entities_needing_embedding(conn, "paper", !self.skip_existing)?;
160 let total_count = count_entities(conn, "paper")?;
161 let skipped_count = total_count - entities.len() as i64;
162
163 self.generate_and_store(conn, entities, total_count, skipped_count, "paper")
164 }
165
166 pub fn generate_for_skills(&self, conn: &Connection) -> Result<EmbeddingStats> {
168 let entities = get_entities_needing_embedding(conn, "skill", !self.skip_existing)?;
169 let total_count = count_entities(conn, "skill")?;
170 let skipped_count = total_count - entities.len() as i64;
171
172 self.generate_and_store(conn, entities, total_count, skipped_count, "skill")
173 }
174
175 pub fn generate_for_all(&self, conn: &Connection) -> Result<EmbeddingStats> {
177 let papers_stats = self.generate_for_papers(conn)?;
178 let skills_stats = self.generate_for_skills(conn)?;
179
180 Ok(EmbeddingStats {
181 total_count: papers_stats.total_count + skills_stats.total_count,
182 processed_count: papers_stats.processed_count + skills_stats.processed_count,
183 skipped_count: papers_stats.skipped_count + skills_stats.skipped_count,
184 dimension: self.config.dimension,
185 })
186 }
187
188 fn generate_and_store(
190 &self,
191 conn: &Connection,
192 entities: Vec<(i64, String)>,
193 total_count: i64,
194 skipped_count: i64,
195 label: &str,
196 ) -> Result<EmbeddingStats> {
197 if entities.is_empty() {
198 println!(
199 "All {} entities already have real embeddings, skipping.",
200 label
201 );
202 return Ok(EmbeddingStats {
203 total_count,
204 processed_count: 0,
205 skipped_count,
206 dimension: self.config.dimension,
207 });
208 }
209
210 let (entity_ids, texts): (Vec<i64>, Vec<String>) = entities.into_iter().unzip();
211
212 println!(
213 "Generating embeddings for {} {} titles ({} already have real embeddings, skipping)...",
214 texts.len(),
215 label,
216 skipped_count
217 );
218
219 let batch_size = 100;
220 let mut processed_count = 0;
221
222 let store = VectorStore::new();
223 let tx = conn.unchecked_transaction()?;
224
225 for batch_start in (0..texts.len()).step_by(batch_size) {
226 let batch_end = (batch_start + batch_size).min(texts.len());
227 let batch_texts = texts[batch_start..batch_end].to_vec();
228 let batch_ids = entity_ids[batch_start..batch_end].to_vec();
229
230 println!(
231 "Processing batch: {}s {}-{}",
232 label,
233 batch_start + 1,
234 batch_end
235 );
236
237 let embeddings = self.generate_embeddings(batch_texts)?;
238
239 for (entity_id, embedding) in batch_ids.iter().zip(embeddings.iter()) {
240 store.insert_vector(&tx, *entity_id, embedding.clone())?;
241 }
242
243 processed_count += embeddings.len();
244 println!(" Generated {} embeddings", embeddings.len());
245 }
246
247 tx.commit()?;
248
249 println!("✓ Generated {} embeddings for {}s", processed_count, label);
250
251 Ok(EmbeddingStats {
252 total_count,
253 processed_count: processed_count as i64,
254 skipped_count,
255 dimension: self.config.dimension,
256 })
257 }
258}
259
260impl Default for EmbeddingGenerator {
261 fn default() -> Self {
262 Self::new()
263 }
264}
265
266pub fn get_entities_needing_embedding(
271 conn: &Connection,
272 entity_type: &str,
273 force: bool,
274) -> Result<Vec<(i64, String)>> {
275 let mut stmt = conn.prepare(
276 r#"
277 SELECT e.id, e.name, v.vector
278 FROM kg_entities e
279 LEFT JOIN kg_vectors v ON e.id = v.entity_id
280 WHERE e.entity_type = ?1
281 ORDER BY e.id
282 "#,
283 )?;
284
285 let rows = stmt.query_map([entity_type], |row| {
286 Ok((
287 row.get::<_, i64>(0)?,
288 row.get::<_, String>(1)?,
289 row.get::<_, Option<Vec<u8>>>(2)?,
290 ))
291 })?;
292
293 let mut result = Vec::new();
294 for row in rows {
295 let (id, name, blob) = row?;
296 let needs_embedding = force || is_placeholder_or_missing(blob.as_deref());
297 if needs_embedding {
298 result.push((id, name));
299 }
300 }
301
302 Ok(result)
303}
304
/// Reports whether a stored vector blob is absent or a placeholder.
///
/// Returns `true` for `None` and for any blob containing no non-zero byte
/// (which includes the empty blob); returns `false` otherwise.
fn is_placeholder_or_missing(blob: Option<&[u8]>) -> bool {
    blob.map_or(true, |bytes| !bytes.iter().any(|&byte| byte != 0))
}
312
313fn count_entities(conn: &Connection, entity_type: &str) -> Result<i64> {
315 let count: i64 = conn.query_row(
316 "SELECT COUNT(*) FROM kg_entities WHERE entity_type = ?1",
317 [entity_type],
318 |row| row.get(0),
319 )?;
320 Ok(count)
321}
322
/// Summary of one embedding-generation run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingStats {
    /// Number of entities of the processed type(s) present in `kg_entities`.
    pub total_count: i64,
    /// Number of embeddings generated and written during the run.
    pub processed_count: i64,
    /// Entities not selected for generation (`total_count` minus the selected ones).
    pub skipped_count: i64,
    /// Vector dimension taken from the generator's configuration.
    pub dimension: usize,
}
331
332pub fn check_dependencies() -> Result<bool> {
334 let output = Command::new("python3")
335 .arg("-c")
336 .arg("import sentence_transformers")
337 .stdout(Stdio::piped())
338 .stderr(Stdio::piped())
339 .output()
340 .map_err(|e| Error::Other(format!("Failed to check Python dependencies: {}", e)))?;
341
342 Ok(output.status.success())
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{insert_entity, Entity};
    use crate::schema::create_schema;

    /// Vector length used by every fixture; matches the default configuration.
    const DIM: usize = 384;

    /// Opens a fresh in-memory database with the schema installed.
    fn make_in_memory_conn() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        create_schema(&conn).unwrap();
        conn
    }

    /// Inserts a `paper` entity without any vector and returns its id.
    fn add_paper(conn: &Connection, name: &str) -> i64 {
        insert_entity(conn, &Entity::new("paper", name)).unwrap()
    }

    /// Inserts a `paper` entity plus a `DIM`-long vector filled with `fill`.
    fn add_paper_with_vector(conn: &Connection, name: &str, fill: f32) -> i64 {
        let id = add_paper(conn, name);
        VectorStore::new()
            .insert_vector(conn, id, vec![fill; DIM])
            .unwrap();
        id
    }

    // ---- configuration and construction -------------------------------------

    #[test]
    fn test_embedding_config_default() {
        let EmbeddingConfig {
            model_name,
            dimension,
        } = EmbeddingConfig::default();
        assert_eq!(model_name, "all-MiniLM-L6-v2");
        assert_eq!(dimension, 384);
    }

    #[test]
    fn test_embedding_generator_new() {
        let generator = EmbeddingGenerator::new();
        assert_eq!(generator.config.model_name, "all-MiniLM-L6-v2");
        assert_eq!(generator.config.dimension, 384);
        assert!(generator.skip_existing);
    }

    #[test]
    fn test_with_force_sets_skip_existing() {
        // `force` and `skip_existing` are always opposites.
        for &(force, expect_skip) in &[(true, false), (false, true)] {
            let generator = EmbeddingGenerator::new().with_force(force);
            assert_eq!(generator.skip_existing, expect_skip);
        }
    }

    // ---- parse_embeddings ---------------------------------------------------

    #[test]
    fn test_parse_embeddings_dimension_mismatch() {
        // Three-element vectors cannot satisfy the 384-dimension default.
        let outcome =
            EmbeddingGenerator::new().parse_embeddings("[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_parse_embeddings_valid_384() {
        let vec384: Vec<f32> = (0..384).map(|i| i as f32 / 1000.0).collect();
        let json = serde_json::to_string(&[&vec384]).unwrap();

        let parsed = EmbeddingGenerator::new().parse_embeddings(&json).unwrap();

        assert_eq!(parsed.len(), 1);
        let first = &parsed[0];
        assert_eq!(first.len(), 384);
        assert!((first[0] - 0.0).abs() < 1e-6);
        assert!((first[1] - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_parse_embeddings_batch_of_three() {
        let batch = vec![vec![0.5f32; 384]; 3];
        let json = serde_json::to_string(&batch).unwrap();

        let parsed = EmbeddingGenerator::new().parse_embeddings(&json).unwrap();

        assert_eq!(parsed.len(), 3);
        assert!(parsed.iter().all(|emb| emb.len() == 384));
    }

    #[test]
    fn test_parse_embeddings_invalid_json() {
        let outcome = EmbeddingGenerator::new().parse_embeddings("not valid json");
        assert!(outcome.is_err());
    }

    // ---- is_placeholder_or_missing ------------------------------------------

    #[test]
    fn test_is_placeholder_missing() {
        assert!(is_placeholder_or_missing(None));
    }

    #[test]
    fn test_is_placeholder_zero_bytes() {
        let blob = vec![0u8; 384 * 4];
        assert!(is_placeholder_or_missing(Some(&blob)));
    }

    #[test]
    fn test_is_placeholder_real_vector() {
        // Little-endian bytes of 384 copies of 0.1f32.
        let blob: Vec<u8> = 0.1f32.to_le_bytes().repeat(384);
        assert_eq!(blob.len(), 384 * 4);
        assert!(!is_placeholder_or_missing(Some(&blob)));
    }

    // ---- get_entities_needing_embedding -------------------------------------

    #[test]
    fn test_get_entities_needing_embedding_no_vector() {
        let conn = make_in_memory_conn();
        add_paper(&conn, "Paper Without Vector");

        let pending = get_entities_needing_embedding(&conn, "paper", false).unwrap();

        assert_eq!(pending.len(), 1);
        assert_eq!(pending[0].1, "Paper Without Vector");
    }

    #[test]
    fn test_get_entities_needing_embedding_placeholder_vector() {
        let conn = make_in_memory_conn();
        add_paper_with_vector(&conn, "Paper With Placeholder", 0.0);

        let pending = get_entities_needing_embedding(&conn, "paper", false).unwrap();

        assert_eq!(pending.len(), 1);
    }

    #[test]
    fn test_get_entities_needing_embedding_skip_real_vector() {
        let conn = make_in_memory_conn();
        add_paper_with_vector(&conn, "Paper With Real Embedding", 0.1);

        let pending = get_entities_needing_embedding(&conn, "paper", false).unwrap();

        assert!(pending.is_empty());
    }

    #[test]
    fn test_get_entities_needing_embedding_force_returns_all() {
        let conn = make_in_memory_conn();
        add_paper_with_vector(&conn, "Paper With Real Embedding", 0.1);

        let pending = get_entities_needing_embedding(&conn, "paper", true).unwrap();

        assert_eq!(pending.len(), 1);
    }

    #[test]
    fn test_get_entities_needing_embedding_mixed() {
        let conn = make_in_memory_conn();
        add_paper_with_vector(&conn, "Has Real Embedding", 0.1);
        add_paper_with_vector(&conn, "Has Placeholder", 0.0);
        add_paper(&conn, "No Vector");

        let pending = get_entities_needing_embedding(&conn, "paper", false).unwrap();

        assert_eq!(pending.len(), 2);
        let names: Vec<&str> = pending.iter().map(|(_, name)| name.as_str()).collect();
        assert!(names.contains(&"Has Placeholder"));
        assert!(names.contains(&"No Vector"));
        assert!(!names.contains(&"Has Real Embedding"));
    }

    // ---- generate_for_* paths that never reach Python -----------------------

    #[test]
    fn test_generate_for_papers_empty() {
        let conn = make_in_memory_conn();

        let stats = EmbeddingGenerator::new().generate_for_papers(&conn).unwrap();

        assert_eq!(stats.total_count, 0);
        assert_eq!(stats.processed_count, 0);
        assert_eq!(stats.skipped_count, 0);
    }

    #[test]
    fn test_generate_for_skills_empty() {
        let conn = make_in_memory_conn();

        let stats = EmbeddingGenerator::new().generate_for_skills(&conn).unwrap();

        assert_eq!(stats.total_count, 0);
        assert_eq!(stats.processed_count, 0);
        assert_eq!(stats.skipped_count, 0);
    }

    #[test]
    fn test_generate_for_papers_all_real_embeddings_are_skipped() {
        let conn = make_in_memory_conn();
        for i in 0..3 {
            add_paper_with_vector(&conn, &format!("Paper {}", i), 0.1);
        }

        let generator = EmbeddingGenerator::new();
        let stats = generator.generate_for_papers(&conn).unwrap();

        assert_eq!(stats.total_count, 3);
        assert_eq!(stats.processed_count, 0);
        assert_eq!(stats.skipped_count, 3);
    }

    #[test]
    fn test_get_entities_batch_boundary() {
        // 105 entities: more than one batch of 100 would hold.
        let conn = make_in_memory_conn();
        for i in 0..105 {
            add_paper(&conn, &format!("Paper {}", i));
        }

        let pending = get_entities_needing_embedding(&conn, "paper", false).unwrap();

        assert_eq!(pending.len(), 105);
    }

    // ---- EmbeddingStats -----------------------------------------------------

    #[test]
    fn test_embedding_stats_fields() {
        let stats = EmbeddingStats {
            total_count: 100,
            processed_count: 80,
            skipped_count: 20,
            dimension: 384,
        };

        assert_eq!(stats.total_count, 100);
        assert_eq!(stats.processed_count, 80);
        assert_eq!(stats.skipped_count, 20);
        assert_eq!(stats.dimension, 384);
    }
}