use std::io;
use tokio::io::AsyncWriteExt;
use tokio::process::Command;
use crate::embed::{EmbeddingConfig, EmbeddingStats};
use crate::error::{Error, Result};
/// Produces text embeddings through an external `python3` process and can store
/// them as vectors for knowledge-graph entities.
pub struct AsyncEmbeddingGenerator {
    /// When `true`, `generate_for_type` leaves out entities for which a vector can
    /// already be read; `with_force(true)` sets this to `false`.
    pub skip_existing: bool,
    /// Embedding settings (model name and expected vector dimension).
    config: EmbeddingConfig,
}
impl AsyncEmbeddingGenerator {
pub fn new() -> Self {
Self {
config: EmbeddingConfig::default(),
skip_existing: true,
}
}
pub fn with_config(config: EmbeddingConfig) -> Self {
Self {
config,
skip_existing: true,
}
}
pub fn with_force(mut self, force: bool) -> Self {
self.skip_existing = !force;
self
}
pub async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}
let python_script = self.build_python_script();
let texts_json = serde_json::to_string(&texts)
.map_err(|e| Error::Other(format!("failed to serialize texts: {e}")))?;
let mut child = Command::new("python3")
.arg("-c")
.arg(&python_script)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
.map_err(|e| Error::Other(format!("failed to spawn Python: {e}")))?;
if let Some(mut stdin) = child.stdin.take() {
stdin
.write_all(texts_json.as_bytes())
.await
.map_err(|e| Error::Io(io::Error::new(io::ErrorKind::BrokenPipe, e)))?;
}
let output = child
.wait_with_output()
.await
.map_err(|e| Error::Other(format!("failed to read Python output: {e}")))?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(Error::Other(format!("Python script failed: {stderr}")));
}
let stdout = String::from_utf8_lossy(&output.stdout);
self.parse_embeddings(&stdout)
}
pub async fn generate_for_type(
&self,
kg: &super::AsyncKnowledgeGraph,
entity_type: &str,
) -> Result<EmbeddingStats> {
let entities = kg
.list_entities(Some(entity_type.to_string()), None)
.await?;
let total_count = entities.len() as i64;
let to_process = if self.skip_existing {
let mut need_embedding = Vec::new();
for entity in entities {
if let Some(id) = entity.id {
let has_vector = kg.search_vectors(vec![0.0; 1], 1).await;
let inner = kg.inner();
let skip = {
let kg_lock = inner.lock().map_err(|e| {
crate::error::Error::TaskPanicked(format!("mutex poisoned: {e}"))
})?;
let store = crate::vector::VectorStore::new();
store.get_vector(kg_lock.connection(), id).is_ok()
};
drop(has_vector);
if !skip {
need_embedding.push(entity);
}
}
}
need_embedding
} else {
entities
};
if to_process.is_empty() {
return Ok(EmbeddingStats {
total_count,
processed_count: 0,
skipped_count: total_count,
dimension: self.config.dimension,
});
}
let texts: Vec<String> = to_process.iter().map(|e| e.name.clone()).collect();
let embeddings = self.generate_embeddings(texts).await?;
let mut processed_count = 0i64;
for (entity, embedding) in to_process.iter().zip(embeddings.iter()) {
if let Some(id) = entity.id {
kg.insert_vector(id, embedding.clone()).await?;
processed_count += 1;
}
}
Ok(EmbeddingStats {
total_count,
processed_count,
skipped_count: total_count - processed_count,
dimension: self.config.dimension,
})
}
fn build_python_script(&self) -> String {
format!(
r#"
import sys
import json
try:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('{model}')
texts_json = sys.stdin.read()
texts = json.loads(texts_json)
embeddings = model.encode(texts, convert_to_numpy=True)
print(json.dumps(embeddings.tolist()))
except ImportError:
print('{{"error": "sentence-transformers not installed. Run: pip install sentence-transformers"}}', file=sys.stderr)
sys.exit(1)
except Exception as e:
print(f'{{"error": "{{}}"}}".format(str(e)), file=sys.stderr)
sys.exit(1)
"#,
model = self.config.model_name
)
}
fn parse_embeddings(&self, output: &str) -> Result<Vec<Vec<f32>>> {
let embeddings: Vec<Vec<f32>> = serde_json::from_str(output.trim())
.map_err(|e| Error::Other(format!("failed to parse embeddings: {e}")))?;
for embedding in &embeddings {
if embedding.len() != self.config.dimension {
return Err(Error::InvalidVectorDimension {
expected: self.config.dimension,
actual: embedding.len(),
});
}
}
Ok(embeddings)
}
}
impl Default for AsyncEmbeddingGenerator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Locals are named `generator`: `gen` is a reserved keyword in Rust 2024.

    #[test]
    fn test_default_config() {
        let generator = AsyncEmbeddingGenerator::new();
        assert_eq!(generator.config.model_name, "all-MiniLM-L6-v2");
        assert_eq!(generator.config.dimension, 384);
        assert!(generator.skip_existing);
    }

    #[test]
    fn test_with_force() {
        let generator = AsyncEmbeddingGenerator::new().with_force(true);
        assert!(!generator.skip_existing);
    }

    #[tokio::test]
    async fn test_empty_texts_returns_empty() {
        let generator = AsyncEmbeddingGenerator::new();
        let result = generator.generate_embeddings(vec![]).await.unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_parse_embeddings_accepts_configured_dimension() {
        let generator = AsyncEmbeddingGenerator::new();
        let row = vec![0.25_f32; generator.config.dimension];
        let json = serde_json::to_string(&vec![row.clone(), row]).unwrap();
        let parsed = generator.parse_embeddings(&format!("  {json}\n")).unwrap();
        assert_eq!(parsed.len(), 2);
        assert!(parsed.iter().all(|v| v.len() == generator.config.dimension));
    }

    #[test]
    fn test_parse_embeddings_rejects_wrong_dimension() {
        let generator = AsyncEmbeddingGenerator::new();
        let result = generator.parse_embeddings("[[1.0, 2.0]]");
        assert!(matches!(
            result,
            Err(Error::InvalidVectorDimension {
                expected: 384,
                actual: 2
            })
        ));
    }

    #[test]
    fn test_parse_embeddings_rejects_non_json() {
        let generator = AsyncEmbeddingGenerator::new();
        assert!(generator.parse_embeddings("not json").is_err());
    }
}