use crate::error::ImgFprintError;
#[cfg(feature = "local-embedding")]
pub mod local;
#[cfg(feature = "local-embedding")]
pub use local::{LocalProvider, LocalProviderConfig};
/// A dense `f32` feature vector, optionally tagged with the identifier of the
/// model that produced it.
///
/// Both fields are private; values are normally built through
/// [`Embedding::new`] or [`Embedding::new_with_model`], which reject empty
/// vectors and non-finite components.
///
/// With the `serde` feature enabled the type derives `Serialize` and
/// `Deserialize`, unknown fields are rejected on input, and `model_id` is
/// omitted from the output when it is `None`.
///
/// NOTE(review): no validation hook is visible on the derived `Deserialize`,
/// so deserialized values do not appear to pass through the constructor
/// checks — confirm whether that is intended.
#[derive(Debug, Clone, PartialEq)]
// `Vec<f32>` is only `PartialEq`, so `Eq` is intentionally not derived.
#[allow(clippy::derive_partial_eq_without_eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(deny_unknown_fields))]
pub struct Embedding {
    /// Components of the embedding.
    vector: Vec<f32>,
    /// Identifier of the producing model, when known.
    #[cfg_attr(feature = "serde", serde(skip_serializing_if = "Option::is_none"))]
    model_id: Option<String>,
}
impl Embedding {
    /// Builds an embedding with no model identifier.
    ///
    /// # Errors
    ///
    /// Same conditions as [`Embedding::new_with_model`].
    pub fn new(vector: Vec<f32>) -> Result<Self, ImgFprintError> {
        Self::new_with_model(vector, None)
    }

    /// Builds an embedding from `vector`, optionally tagged with `model_id`.
    ///
    /// # Errors
    ///
    /// Returns [`ImgFprintError::InvalidEmbedding`] when `vector` is empty or
    /// when any component is NaN or infinite. The emptiness check runs first.
    pub fn new_with_model(
        vector: Vec<f32>,
        model_id: Option<String>,
    ) -> Result<Self, ImgFprintError> {
        // Pick the first applicable rejection reason, if any.
        let rejection = if vector.is_empty() {
            Some("embedding vector cannot be empty")
        } else if !vector.iter().all(|component| component.is_finite()) {
            Some("embedding contains non-finite values (NaN or infinity)")
        } else {
            None
        };

        match rejection {
            Some(reason) => Err(ImgFprintError::InvalidEmbedding(reason.to_string())),
            None => Ok(Self { vector, model_id }),
        }
    }

    /// Returns the model identifier, if one was supplied.
    #[inline]
    pub fn model_id(&self) -> Option<&str> {
        self.model_id.as_ref().map(String::as_str)
    }

    /// Borrows the components without copying.
    #[inline]
    pub fn as_slice(&self) -> &[f32] {
        self.vector.as_slice()
    }

    /// Returns an owned copy of the components.
    ///
    /// This allocates; use [`Embedding::as_slice`] when a borrow is enough.
    pub fn vector(&self) -> Vec<f32> {
        self.as_slice().to_vec()
    }

    /// Number of components in the embedding.
    #[inline]
    // No `is_empty` counterpart: the constructors reject empty vectors.
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        self.as_slice().len()
    }

    /// Alias for [`Embedding::len`].
    #[inline]
    pub fn dimension(&self) -> usize {
        self.vector.len()
    }
}
/// A source of [`Embedding`]s for images.
///
/// Implementors take a byte buffer and either return an [`Embedding`] or
/// report an [`ImgFprintError`].
pub trait EmbeddingProvider {
/// Computes an embedding for `image`.
///
/// NOTE(review): the signature only fixes `image` as a byte slice; whether
/// that is an encoded image file or raw pixel data is up to the
/// implementor — confirm against the concrete providers.
///
/// # Errors
///
/// Returns whatever [`ImgFprintError`] the implementor reports on failure.
fn embed(&self, image: &[u8]) -> Result<Embedding, ImgFprintError>;
}
/// Computes the cosine similarity between two embeddings.
///
/// The result lies in `[-1.0, 1.0]`: `1.0` for vectors pointing the same way,
/// `0.0` for orthogonal vectors and `-1.0` for opposite vectors.
///
/// When both embeddings carry a model identifier the identifiers must match;
/// if either side has none, the check is skipped.
///
/// # Errors
///
/// * [`ImgFprintError::InvalidEmbedding`] if both model identifiers are
///   present and differ, or if either vector consists only of zeros.
/// * [`ImgFprintError::EmbeddingDimensionMismatch`] if the vectors have
///   different lengths (`expected` is the length of `a`, `actual` of `b`).
pub fn semantic_similarity(a: &Embedding, b: &Embedding) -> Result<f32, ImgFprintError> {
    if let (Some(a_model), Some(b_model)) = (a.model_id(), b.model_id()) {
        if a_model != b_model {
            return Err(ImgFprintError::InvalidEmbedding(format!(
                "model ID mismatch: '{}' vs '{}'",
                a_model, b_model
            )));
        }
    }

    let a_vec = a.as_slice();
    let b_vec = b.as_slice();
    if a_vec.len() != b_vec.len() {
        return Err(ImgFprintError::EmbeddingDimensionMismatch {
            expected: a_vec.len(),
            actual: b_vec.len(),
        });
    }

    // Accumulate in f64. Squaring an f32 component can overflow to infinity
    // (giving a NaN result) or underflow to zero (making a non-zero vector
    // look like a zero vector); every finite f32 squares to a normal f64.
    let (dot_product, norm_a_sq, norm_b_sq) = a_vec.iter().zip(b_vec).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return Err(ImgFprintError::InvalidEmbedding(
            "cannot compute similarity for zero vector".to_string(),
        ));
    }

    let cosine = dot_product / (norm_a_sq.sqrt() * norm_b_sq.sqrt());
    // Rounding can push the ratio a hair outside the mathematical range.
    // Narrowing to f32 is intentional: the public return type is f32.
    Ok(cosine.clamp(-1.0, 1.0) as f32)
}
#[cfg(test)]
mod tests {
    //! Unit tests for `Embedding` construction, `semantic_similarity`, and
    //! the `EmbeddingProvider` trait contract.

    use super::*;

    /// Builds an embedding without a model identifier, panicking on invalid input.
    fn emb(vector: Vec<f32>) -> Embedding {
        Embedding::new(vector).unwrap()
    }

    // --- Embedding construction -------------------------------------------

    #[test]
    fn test_embedding_new_valid() {
        let vector = vec![0.1, 0.2, 0.3, 0.4];
        let embedding = Embedding::new(vector.clone()).unwrap();
        assert_eq!(embedding.len(), 4);
        assert_eq!(embedding.dimension(), 4);
        assert_eq!(embedding.as_slice(), &vector);
        assert_eq!(embedding.vector(), vector);
    }

    #[test]
    fn test_embedding_empty_vector() {
        let result = Embedding::new(vec![]);
        assert!(matches!(
            result,
            Err(ImgFprintError::InvalidEmbedding(msg)) if msg.contains("empty")
        ));
    }

    #[test]
    fn test_embedding_nan_values() {
        let result = Embedding::new(vec![0.1, f32::NAN, 0.3]);
        assert!(matches!(
            result,
            Err(ImgFprintError::InvalidEmbedding(msg)) if msg.contains("non-finite")
        ));
    }

    #[test]
    fn test_embedding_infinity_values() {
        let result = Embedding::new(vec![0.1, f32::INFINITY, 0.3]);
        assert!(matches!(
            result,
            Err(ImgFprintError::InvalidEmbedding(msg)) if msg.contains("non-finite")
        ));
    }

    #[test]
    fn test_embedding_negative_infinity() {
        let result = Embedding::new(vec![0.1, f32::NEG_INFINITY, 0.3]);
        assert!(matches!(
            result,
            Err(ImgFprintError::InvalidEmbedding(msg)) if msg.contains("non-finite")
        ));
    }

    // --- Cosine similarity -------------------------------------------------

    #[test]
    fn test_cosine_similarity_identical() {
        let a = emb(vec![1.0, 0.0, 0.0, 0.0]);
        let b = emb(vec![1.0, 0.0, 0.0, 0.0]);
        let sim = semantic_similarity(&a, &b).unwrap();
        assert!(
            (sim - 1.0).abs() < 1e-6,
            "Identical vectors should have similarity 1.0, got {}",
            sim
        );
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = emb(vec![1.0, 0.0, 0.0]);
        let b = emb(vec![0.0, 1.0, 0.0]);
        let sim = semantic_similarity(&a, &b).unwrap();
        assert!(
            sim.abs() < 1e-6,
            "Orthogonal vectors should have similarity ~0.0, got {}",
            sim
        );
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = emb(vec![1.0, 0.0, 0.0]);
        let b = emb(vec![-1.0, 0.0, 0.0]);
        let sim = semantic_similarity(&a, &b).unwrap();
        assert!(
            (sim - (-1.0)).abs() < 1e-6,
            "Opposite vectors should have similarity -1.0, got {}",
            sim
        );
    }

    #[test]
    fn test_cosine_similarity_45_degrees() {
        let a = emb(vec![1.0, 0.0]);
        let b = emb(vec![1.0, 1.0]);
        let sim = semantic_similarity(&a, &b).unwrap();
        let expected = 1.0 / f32::sqrt(2.0);
        assert!(
            (sim - expected).abs() < 1e-5,
            "45-degree angle similarity should be ~0.707, got {}",
            sim
        );
    }

    #[test]
    fn test_cosine_similarity_dimension_mismatch() {
        let a = emb(vec![1.0, 0.0, 0.0]);
        let b = emb(vec![1.0, 0.0]);
        let result = semantic_similarity(&a, &b);
        assert!(matches!(
            result,
            Err(ImgFprintError::EmbeddingDimensionMismatch {
                expected: 3,
                actual: 2
            })
        ));
    }

    #[test]
    fn test_cosine_similarity_with_normalization() {
        // Neither input is unit length; the result must still be exactly aligned.
        let a = emb(vec![0.5, 0.5, 0.5, 0.5]);
        let b = emb(vec![0.5, 0.5, 0.5, 0.5]);
        let sim = semantic_similarity(&a, &b).unwrap();
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_various_dimensions() {
        for dim in [512, 768, 1024] {
            let a = emb(vec![1.0; dim]);
            let b = emb(vec![1.0; dim]);
            let sim = semantic_similarity(&a, &b).unwrap();
            assert!((sim - 1.0).abs() < 1e-6, "Failed for dimension {}", dim);
        }
    }

    #[test]
    fn test_cosine_similarity_negative_values() {
        let a = emb(vec![1.0, -1.0, 1.0, -1.0]);
        let b = emb(vec![-1.0, 1.0, -1.0, 1.0]);
        let sim = semantic_similarity(&a, &b).unwrap();
        assert!(
            (sim - (-1.0)).abs() < 1e-6,
            "Opposite signs should give -1.0, got {}",
            sim
        );
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        // An all-zero vector is a valid embedding (finite and non-empty), so
        // it is built through the public constructor; only the similarity
        // computation rejects it.
        let a = emb(vec![1.0, 0.0, 0.0]);
        let zero_embedding = emb(vec![0.0; 3]);
        let result = semantic_similarity(&a, &zero_embedding);
        assert!(matches!(
            result,
            Err(ImgFprintError::InvalidEmbedding(msg)) if msg.contains("zero vector")
        ));
    }

    // --- Derived traits ----------------------------------------------------

    #[test]
    fn test_embedding_clone() {
        let a = emb(vec![0.1, 0.2, 0.3]);
        let b = a.clone();
        assert_eq!(a.as_slice(), b.as_slice());
        assert_eq!(a.len(), b.len());
    }

    #[test]
    fn test_embedding_partial_eq() {
        let a = emb(vec![0.1, 0.2, 0.3]);
        let b = emb(vec![0.1, 0.2, 0.3]);
        let c = emb(vec![0.3, 0.2, 0.1]);
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    // --- EmbeddingProvider -------------------------------------------------

    /// Provider that ignores its input and always yields `return_value`.
    struct MockProvider {
        return_value: Vec<f32>,
    }

    impl EmbeddingProvider for MockProvider {
        fn embed(&self, _image: &[u8]) -> Result<Embedding, ImgFprintError> {
            Embedding::new(self.return_value.clone())
        }
    }

    #[test]
    fn test_embedding_provider_mock() {
        let provider = MockProvider {
            return_value: vec![0.1, 0.2, 0.3],
        };
        let image_bytes = vec![0u8; 100];
        let embedding = provider.embed(&image_bytes).unwrap();
        assert_eq!(embedding.as_slice(), &[0.1, 0.2, 0.3]);
    }

    #[test]
    fn test_embedding_provider_error_propagation() {
        /// Provider that always fails.
        struct FailingProvider;

        impl EmbeddingProvider for FailingProvider {
            fn embed(&self, _image: &[u8]) -> Result<Embedding, ImgFprintError> {
                Err(ImgFprintError::ProviderError("network timeout".to_string()))
            }
        }

        let provider = FailingProvider;
        let image_bytes = vec![0u8; 100];
        let result = provider.embed(&image_bytes);
        assert!(matches!(
            result,
            Err(ImgFprintError::ProviderError(msg)) if msg == "network timeout"
        ));
    }

    // --- Model identifiers -------------------------------------------------

    #[test]
    fn test_embedding_new_with_model() {
        let vector = vec![0.1, 0.2, 0.3, 0.4];
        let embedding =
            Embedding::new_with_model(vector.clone(), Some("clip-vit-base-patch32".to_string()))
                .unwrap();
        assert_eq!(embedding.len(), 4);
        assert_eq!(embedding.model_id(), Some("clip-vit-base-patch32"));
        assert_eq!(embedding.as_slice(), &vector);
    }

    #[test]
    fn test_embedding_model_id_mismatch() {
        let a = Embedding::new_with_model(vec![0.1; 512], Some("model-a".to_string())).unwrap();
        let b = Embedding::new_with_model(vec![0.1; 512], Some("model-b".to_string())).unwrap();
        let result = semantic_similarity(&a, &b);
        assert!(matches!(
            result,
            Err(ImgFprintError::InvalidEmbedding(msg)) if msg.contains("model ID mismatch")
        ));
    }

    #[test]
    fn test_embedding_same_model_id_ok() {
        let a =
            Embedding::new_with_model(vec![1.0; 512], Some("clip-vit-base-patch32".to_string()))
                .unwrap();
        let b =
            Embedding::new_with_model(vec![1.0; 512], Some("clip-vit-base-patch32".to_string()))
                .unwrap();
        let sim = semantic_similarity(&a, &b).unwrap();
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_embedding_missing_model_id_ok() {
        // The model check only applies when both sides carry an identifier.
        let a = Embedding::new(vec![1.0; 512]).unwrap();
        let b = Embedding::new_with_model(vec![1.0; 512], None).unwrap();
        let c = Embedding::new_with_model(vec![1.0; 512], Some("model-a".to_string())).unwrap();
        let sim1 = semantic_similarity(&a, &b).unwrap();
        assert!((sim1 - 1.0).abs() < 1e-6);
        let sim2 = semantic_similarity(&a, &c).unwrap();
        assert!((sim2 - 1.0).abs() < 1e-6);
    }
}