use std::sync::Arc;
use async_trait::async_trait;
use infernum_core::Result;
use abaddon::{Engine, InferenceEngine};
#[async_trait]
pub trait Embedder: Send + Sync {
async fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;
async fn embed_single(&self, text: &str) -> Result<Vec<f32>> {
let embeddings = self.embed(&[text]).await?;
embeddings
.into_iter()
.next()
.ok_or_else(|| infernum_core::Error::internal("No embedding generated"))
}
fn dimension(&self) -> usize;
fn model_name(&self) -> &str;
}
/// [`Embedder`] implementation that forwards embedding requests to a shared
/// inference [`Engine`].
pub struct EngineEmbedder {
/// Shared handle to the engine that serves the embedding requests.
engine: Arc<Engine>,
/// Value reported by `dimension()`: taken from the model's `hidden_size` in
/// `new`, or supplied directly by the caller in `with_dimension`.
dimension: usize,
}
impl EngineEmbedder {
    /// Builds an embedder whose reported dimension is the `hidden_size` of the
    /// engine's model info.
    ///
    /// Emits a `debug` trace event carrying the model id and the dimension.
    pub fn new(engine: Arc<Engine>) -> Self {
        let dimension = {
            let info = engine.model_info();
            let hidden = info.hidden_size as usize;
            tracing::debug!(
                model_id = %info.id.0,
                dimension = hidden,
                "Created engine embedder"
            );
            hidden
        };
        Self::with_dimension(engine, dimension)
    }

    /// Builds an embedder that reports the caller-supplied `dimension` instead
    /// of reading it from the engine's model info.
    pub fn with_dimension(engine: Arc<Engine>, dimension: usize) -> Self {
        Self { engine, dimension }
    }

    /// Borrows the shared engine handle this embedder forwards to.
    #[must_use]
    pub fn engine(&self) -> &Arc<Engine> {
        &self.engine
    }
}
#[async_trait]
impl Embedder for EngineEmbedder {
async fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
let mut embeddings = Vec::with_capacity(texts.len());
for text in texts {
let request = infernum_core::EmbedRequest::new(text.to_string());
let response = self.engine.embed(request).await?;
let embedding = response
.data
.into_iter()
.next()
.ok_or_else(|| infernum_core::Error::internal("No embedding in response"))?;
let vec = embedding
.embedding
.as_floats()
.map_err(|e| infernum_core::Error::internal(e))?;
embeddings.push(vec);
}
Ok(embeddings)
}
fn dimension(&self) -> usize {
self.dimension
}
fn model_name(&self) -> &str {
&self.engine.model_info().id.0
}
}
/// Stand-in embedder that derives its vectors from the input bytes, so no
/// model is required and repeated calls give identical output.
///
/// Derives `Debug` and `Clone` so it can be logged, printed in assertion
/// failures and duplicated like any other plain-data value.
#[derive(Debug, Clone)]
pub struct MockEmbedder {
    /// Length of every vector this embedder produces.
    dimension: usize,
}
impl MockEmbedder {
#[must_use]
pub fn new(dimension: usize) -> Self {
Self { dimension }
}
}
#[async_trait]
impl Embedder for MockEmbedder {
    /// Produces one pseudo-embedding per text, derived only from the text's
    /// bytes.
    ///
    /// The wrapping sum of the text's bytes is offset by each element index;
    /// that value modulo 1000 is scaled into the range `[-0.5, 0.499]`. The
    /// same text therefore always maps to the same vector. This
    /// implementation never returns an error.
    async fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut vectors = Vec::with_capacity(texts.len());
        for text in texts {
            let byte_sum = text.bytes().map(u64::from).fold(0u64, u64::wrapping_add);
            let vector: Vec<f32> = (0..self.dimension)
                .map(|index| {
                    let seed = byte_sum.wrapping_add(index as u64);
                    ((seed % 1000) as f32 / 1000.0) - 0.5
                })
                .collect();
            vectors.push(vector);
        }
        Ok(vectors)
    }

    /// Returns the vector length chosen at construction.
    fn dimension(&self) -> usize {
        self.dimension
    }

    /// Always returns the fixed name `"mock-embedder"`.
    fn model_name(&self) -> &str {
        "mock-embedder"
    }
}
/// Wrapper around another [`Embedder`] that can rescale each returned vector
/// to unit L2 norm.
pub struct SentenceEmbedder {
/// Wrapped embedder that produces the raw vectors.
embedder: Arc<dyn Embedder>,
/// Selected pooling strategy.
/// NOTE(review): set by `new` and `with_pooling`, but the `embed`
/// implementation in this file never reads it - confirm whether pooling is
/// meant to be applied elsewhere.
pooling: PoolingStrategy,
/// When `true`, `embed` divides each vector by its L2 norm (vectors whose norm
/// is not above `1e-10` are left unchanged).
normalize: bool,
}
/// Pooling strategy selectable on a sentence embedder; defaults to `Mean`.
///
/// Derives `PartialEq`, `Eq` and `Hash` in addition to the copy/debug traits
/// so that strategies can be compared in assertions and used as map or set
/// keys.
///
/// NOTE(review): nothing in this file reads the selected strategy, so the
/// per-variant descriptions below are inferred from the variant names and
/// should be confirmed against the code that consumes them.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum PoolingStrategy {
    /// Presumably: use the vector of the leading classification token.
    Cls,
    /// Presumably: average the token vectors. This is the default variant.
    #[default]
    Mean,
    /// Presumably: use the vector of the final token.
    Last,
    /// Presumably: take the element-wise maximum over the token vectors.
    Max,
}
impl SentenceEmbedder {
    /// Wraps `embedder` with mean pooling selected and normalization enabled.
    pub fn new(embedder: Arc<dyn Embedder>) -> Self {
        Self {
            normalize: true,
            pooling: PoolingStrategy::Mean,
            embedder,
        }
    }

    /// Returns the embedder with its pooling strategy replaced by `strategy`.
    #[must_use]
    pub fn with_pooling(self, strategy: PoolingStrategy) -> Self {
        Self {
            pooling: strategy,
            ..self
        }
    }

    /// Returns the embedder with L2 normalization switched on or off.
    #[must_use]
    pub fn with_normalize(self, normalize: bool) -> Self {
        Self { normalize, ..self }
    }

    /// Scales `vec` in place to unit L2 norm.
    ///
    /// Vectors whose norm is not above `1e-10` (including all-zero and empty
    /// ones) are left untouched, which avoids dividing by (near) zero.
    fn normalize_vec(vec: &mut [f32]) {
        let squared_sum: f32 = vec.iter().map(|value| value * value).sum();
        let norm = squared_sum.sqrt();
        if norm > 1e-10 {
            vec.iter_mut().for_each(|value| *value /= norm);
        }
    }
}
#[async_trait]
impl Embedder for SentenceEmbedder {
    /// Embeds `texts` with the wrapped embedder and, when normalization is
    /// enabled, rescales every returned vector to unit L2 norm.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped embedder.
    async fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut vectors = self.embedder.embed(texts).await?;
        if !self.normalize {
            return Ok(vectors);
        }
        for vector in &mut vectors {
            Self::normalize_vec(vector.as_mut_slice());
        }
        Ok(vectors)
    }

    /// Reports the dimension of the wrapped embedder.
    fn dimension(&self) -> usize {
        let inner = &self.embedder;
        inner.dimension()
    }

    /// Reports the model name of the wrapped embedder.
    fn model_name(&self) -> &str {
        let inner = &self.embedder;
        inner.model_name()
    }
}
/// Helper that feeds a large list of texts to a wrapped [`Embedder`] in
/// fixed-size chunks.
pub struct BatchEmbedder {
/// Wrapped embedder that receives one chunk per call.
embedder: Arc<dyn Embedder>,
/// Maximum number of texts passed to the wrapped embedder in a single call
/// (`new` sets this to 32).
batch_size: usize,
}
impl BatchEmbedder {
    /// Wraps `embedder` with a batch size of 32 texts per call.
    pub fn new(embedder: Arc<dyn Embedder>) -> Self {
        Self {
            embedder,
            batch_size: 32,
        }
    }

    /// Returns the batcher with its batch size replaced by `size`.
    ///
    /// A `size` of zero is raised to one: `slice::chunks` panics on a chunk
    /// size of zero, so storing it verbatim would make every later
    /// [`BatchEmbedder::embed_batch`] call panic.
    #[must_use]
    pub fn with_batch_size(mut self, size: usize) -> Self {
        self.batch_size = size.max(1);
        self
    }

    /// Embeds `texts` by sending them to the wrapped embedder in consecutive
    /// chunks of at most `batch_size` items, concatenating the results in
    /// input order.
    ///
    /// An empty `texts` slice yields an empty result without calling the
    /// wrapped embedder.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by the wrapped embedder; vectors
    /// gathered from earlier chunks are discarded in that case.
    pub async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        // Guard at the use site too, so a zero that reached the field some
        // other way still cannot trigger the `chunks(0)` panic.
        let chunk_len = self.batch_size.max(1);

        let mut all_embeddings = Vec::with_capacity(texts.len());
        for chunk in texts.chunks(chunk_len) {
            let refs: Vec<&str> = chunk.iter().map(String::as_str).collect();
            let batch_embeddings = self.embedder.embed(&refs).await?;
            all_embeddings.extend(batch_embeddings);
        }
        Ok(all_embeddings)
    }
}
/// Computes the cosine similarity of two vectors.
///
/// Returns `0.0` when the slices differ in length, or when either vector has
/// an L2 norm that is not above `1e-10` (this covers empty and all-zero
/// input). Otherwise returns `dot(a, b) / (|a| * |b|)`, clamped to
/// `[-1.0, 1.0]`.
///
/// The clamp matters because floating-point rounding in the norm product can
/// push the quotient marginally outside the mathematical range, which would
/// turn into `NaN` in callers that feed the result to functions such as
/// `acos`.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a > 1e-10 && norm_b > 1e-10 {
        (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
    } else {
        0.0
    }
}
/// Computes the Euclidean (L2) distance between two vectors.
///
/// Returns `f32::MAX` when the slices differ in length; two empty slices are
/// at distance zero.
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        let sum_of_squares: f32 = a
            .iter()
            .zip(b)
            .map(|(&p, &q)| {
                let delta = p - q;
                delta * delta
            })
            .sum();
        sum_of_squares.sqrt()
    } else {
        f32::MAX
    }
}
/// Computes the dot product of two vectors.
///
/// Returns `0.0` when the slices differ in length.
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs).sum()
    } else {
        0.0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The mock embedder returns one vector of the configured length per
    /// input, and identical vectors for identical input.
    #[tokio::test]
    async fn test_mock_embedder() {
        const DIM: usize = 384;
        let embedder = MockEmbedder::new(DIM);

        let batch = embedder.embed(&["hello", "world"]).await.unwrap();
        assert_eq!(batch.len(), 2);
        assert_eq!(batch[0].len(), DIM);
        assert_eq!(batch[1].len(), DIM);

        let first_run = embedder.embed(&["hello"]).await.unwrap();
        let second_run = embedder.embed(&["hello"]).await.unwrap();
        assert_eq!(first_run[0], second_run[0]);
    }

    /// Identical, orthogonal and opposite unit vectors score 1, 0 and -1.
    #[test]
    fn test_cosine_similarity() {
        let unit_x = [1.0_f32, 0.0, 0.0];
        let same = [1.0_f32, 0.0, 0.0];
        let unit_y = [0.0_f32, 1.0, 0.0];
        let opposite = [-1.0_f32, 0.0, 0.0];

        assert!((cosine_similarity(&unit_x, &same) - 1.0).abs() < 1e-6);
        assert!((cosine_similarity(&unit_x, &unit_y)).abs() < 1e-6);
        assert!((cosine_similarity(&unit_x, &opposite) + 1.0).abs() < 1e-6);
    }

    /// Distances from the origin match hand-computed values.
    #[test]
    fn test_euclidean_distance() {
        let origin = [0.0_f32, 0.0, 0.0];
        let one_away = [1.0_f32, 0.0, 0.0];
        let five_away = [3.0_f32, 4.0, 0.0];

        assert!((euclidean_distance(&origin, &one_away) - 1.0).abs() < 1e-6);
        assert!((euclidean_distance(&origin, &five_away) - 5.0).abs() < 1e-6);
    }

    /// With default settings the sentence embedder emits unit-length vectors.
    #[tokio::test]
    async fn test_sentence_embedder_normalization() {
        let mock = Arc::new(MockEmbedder::new(3));
        let embedder = SentenceEmbedder::new(mock);

        let vectors = embedder.embed(&["test"]).await.unwrap();
        let squared_sum: f32 = vectors[0].iter().map(|value| value * value).sum();
        let norm = squared_sum.sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }
}