use anyhow::Result;
use async_trait::async_trait;
/// An embedding vector together with optional details about how it was produced.
#[derive(Debug, Clone)]
pub struct EmbeddingResult {
    /// The embedding vector itself.
    pub embedding: Vec<f32>,
    /// Token count reported for the input, when one was supplied.
    pub token_count: Option<usize>,
    /// Name of the model associated with this embedding.
    pub model: String,
    /// Generation time in milliseconds, when one was supplied.
    pub generation_time_ms: Option<u64>,
}

impl EmbeddingResult {
    /// Builds a result that carries only the vector and the model name.
    ///
    /// Both optional fields (`token_count`, `generation_time_ms`) are left as `None`.
    #[must_use]
    pub fn new(embedding: Vec<f32>, model: String) -> Self {
        Self {
            embedding,
            model,
            token_count: None,
            generation_time_ms: None,
        }
    }

    /// Builds a result with every field populated.
    ///
    /// `token_count` and `generation_time_ms` are stored wrapped in `Some`.
    #[must_use]
    pub fn detailed(
        embedding: Vec<f32>,
        model: String,
        token_count: usize,
        generation_time_ms: u64,
    ) -> Self {
        // Start from the minimal result and fill in the two optional fields.
        Self {
            token_count: Some(token_count),
            generation_time_ms: Some(generation_time_ms),
            ..Self::new(embedding, model)
        }
    }
}
/// Interface for anything that can turn text into embedding vectors.
///
/// Implementors supply [`embed_text`](EmbeddingProvider::embed_text),
/// [`embedding_dimension`](EmbeddingProvider::embedding_dimension) and
/// [`model_name`](EmbeddingProvider::model_name); every other method has a
/// default implementation built on top of those three.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Produces the embedding vector for a single piece of text.
    async fn embed_text(&self, text: &str) -> Result<Vec<f32>>;

    /// Embeds every entry of `texts`, returning the vectors in input order.
    ///
    /// The default implementation awaits `embed_text` for one entry at a time
    /// and returns the first error it encounters.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let mut vectors = Vec::with_capacity(texts.len());
        for entry in texts {
            vectors.push(self.embed_text(entry).await?);
        }
        Ok(vectors)
    }

    /// Embeds both texts and returns the result of `cosine_similarity` on the
    /// two vectors.
    ///
    /// # Errors
    ///
    /// Fails if either call to `embed_text` fails.
    async fn similarity(&self, text1: &str, text2: &str) -> Result<f32> {
        let left = self.embed_text(text1).await?;
        let right = self.embed_text(text2).await?;
        let score = crate::embeddings::similarity::cosine_similarity(&left, &right);
        Ok(score)
    }

    /// Length of the vectors this provider reports producing.
    fn embedding_dimension(&self) -> usize;

    /// Name of the model behind this provider.
    fn model_name(&self) -> &str;

    /// Reports whether the provider can currently embed text.
    ///
    /// The default implementation embeds the literal `"test"` and returns
    /// whether that call succeeded.
    async fn is_available(&self) -> bool {
        matches!(self.embed_text("test").await, Ok(_))
    }

    /// Issues one throwaway embedding request (`"warmup test"`), discarding the
    /// vector and propagating any error.
    async fn warmup(&self) -> Result<()> {
        self.embed_text("warmup test").await.map(|_| ())
    }

    /// Describes the provider as a JSON object with `model` and `dimension` keys.
    fn metadata(&self) -> serde_json::Value {
        let model = self.model_name();
        let dimension = self.embedding_dimension();
        serde_json::json!({
            "model": model,
            "dimension": dimension
        })
    }
}
pub mod utils {
use anyhow::Result;
/// Scales `vector` to unit Euclidean length and returns it.
///
/// The vector is returned unchanged when its magnitude is not strictly
/// positive (for example an empty or all-zero vector), which avoids a
/// division by zero.
pub fn normalize_vector(mut vector: Vec<f32>) -> Vec<f32> {
    let sum_of_squares: f32 = vector.iter().map(|component| component * component).sum();
    let norm = sum_of_squares.sqrt();

    if norm > 0.0 {
        vector
            .iter_mut()
            .for_each(|component| *component /= norm);
    }

    vector
}
/// Checks that `embedding` has exactly `expected` components.
///
/// # Errors
///
/// Returns an error naming both the actual and the expected length when they
/// differ.
#[allow(dead_code)]
pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<()> {
    let actual = embedding.len();
    if actual == expected {
        return Ok(());
    }

    anyhow::bail!(
        "Embedding dimension mismatch: got {}, expected {}",
        actual,
        expected
    );
}
/// Splits `text` into whitespace-delimited chunks of at most `max_chars` bytes.
///
/// Despite the parameter name, the limit is measured with `str::len`, i.e. in
/// UTF-8 bytes rather than Unicode scalar values.
///
/// * Text that already fits is returned unchanged as a single chunk (its
///   whitespace is not normalised).
/// * Otherwise words are packed greedily, joined by a single space.
/// * A word that is longer than the limit on its own is hard-split on
///   character boundaries so that no chunk exceeds the limit. The only
///   exception is a single character wider than `max_chars` bytes (including
///   the degenerate `max_chars == 0`), which cannot be split further and is
///   emitted as its own chunk.
#[allow(dead_code)]
pub fn chunk_text(text: &str, max_chars: usize) -> Vec<String> {
    if text.len() <= max_chars {
        return vec![text.to_string()];
    }

    let mut chunks = Vec::new();
    let mut current_chunk = String::new();

    for word in text.split_whitespace() {
        let mut remainder = word;

        // A word that can never fit in one chunk is cut into limit-sized pieces.
        while remainder.len() > max_chars {
            if !current_chunk.is_empty() {
                chunks.push(std::mem::take(&mut current_chunk));
            }

            // Back off to the nearest char boundary so `split_at` cannot panic.
            let mut cut = max_chars;
            while cut > 0 && !remainder.is_char_boundary(cut) {
                cut -= 1;
            }
            if cut == 0 {
                // The first character alone is wider than the limit; emit it
                // whole so the loop always makes progress.
                cut = remainder
                    .chars()
                    .next()
                    .map_or(remainder.len(), char::len_utf8);
            }

            let (head, tail) = remainder.split_at(cut);
            chunks.push(head.to_string());
            remainder = tail;
        }

        if remainder.is_empty() {
            continue;
        }

        if !current_chunk.is_empty() {
            // +1 accounts for the joining space.
            if current_chunk.len() + 1 + remainder.len() > max_chars {
                chunks.push(std::mem::take(&mut current_chunk));
            } else {
                current_chunk.push(' ');
            }
        }
        current_chunk.push_str(remainder);
    }

    if !current_chunk.is_empty() {
        chunks.push(current_chunk);
    }
    chunks
}
/// Computes the element-wise mean of `embeddings` and returns it after passing
/// it through [`normalize_vector`].
///
/// # Errors
///
/// Fails when `embeddings` is empty, or when any vector's length differs from
/// that of the first one.
#[allow(dead_code)]
pub fn average_embeddings(embeddings: &[Vec<f32>]) -> Result<Vec<f32>> {
    // The first vector fixes the dimension every other vector must match.
    let dimension = match embeddings.first() {
        Some(first) => first.len(),
        None => anyhow::bail!("Cannot average empty embeddings list"),
    };

    let mut sums = vec![0.0_f32; dimension];
    for embedding in embeddings {
        if embedding.len() != dimension {
            anyhow::bail!("Inconsistent embedding dimensions");
        }
        for (total, value) in sums.iter_mut().zip(embedding) {
            *total += *value;
        }
    }

    let count = embeddings.len() as f32;
    let mean: Vec<f32> = sums.into_iter().map(|total| total / count).collect();
    Ok(normalize_vector(mean))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A 3-4-5 triangle gives unit components that are easy to check by hand.
    #[test]
    fn test_normalize_vector() {
        let unit = utils::normalize_vector(vec![3.0, 4.0]);
        assert!((unit[0] - 0.6).abs() < 0.001);
        assert!((unit[1] - 0.8).abs() < 0.001);

        // The result must have length one.
        let length = unit.iter().map(|c| c * c).sum::<f32>().sqrt();
        assert!((length - 1.0).abs() < 0.001);
    }

    /// Chunks respect the size limit and lose no words.
    #[test]
    fn test_chunk_text() {
        let text =
            "This is a long text that needs to be chunked into smaller pieces for processing";
        let chunks = utils::chunk_text(text, 25);

        assert!(chunks.len() > 1);
        assert!(chunks.iter().all(|chunk| chunk.len() <= 25));

        // Joining the chunks back together must reproduce the word sequence.
        let rejoined = chunks.join(" ");
        assert_eq!(
            text.split_whitespace().collect::<Vec<&str>>(),
            rejoined.split_whitespace().collect::<Vec<&str>>()
        );
    }

    /// The mean of the three inputs is (2, 4, 6); the result is that vector normalised.
    #[test]
    fn test_average_embeddings() {
        let inputs = vec![
            vec![1.0, 2.0, 3.0],
            vec![2.0, 4.0, 6.0],
            vec![3.0, 6.0, 9.0],
        ];
        let averaged = utils::average_embeddings(&inputs)
            .expect("average_embeddings should succeed with valid embedding vectors");

        let norm = (4.0 + 16.0 + 36.0_f32).sqrt();
        let expected = [2.0 / norm, 4.0 / norm, 6.0 / norm];
        for (got, want) in averaged.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 0.001);
        }
    }

    /// A matching length passes and a mismatched one is rejected.
    #[test]
    fn test_validate_dimension() {
        let sample = vec![1.0, 2.0, 3.0];
        assert!(utils::validate_dimension(&sample, 3).is_ok());
        assert!(utils::validate_dimension(&sample, 4).is_err());
    }
}