use crate::domain::error::{DomainError, DomainResult};
use ndarray::Array1;
use std::any::Any;
use std::fmt::Debug;
use std::io::Cursor;
use tracing::instrument;
/// Abstraction over a component that converts text into an embedding vector.
///
/// Implementors must be `Send + Sync` and implement `Debug`, as required by
/// the supertrait bounds.
pub trait Embedder: Send + Sync + Debug {
/// Produces an embedding for `text`.
///
/// Returns the embedding as a `Vec<f32>` wrapped in `Some` on success, or a
/// domain error on failure.
///
/// NOTE(review): the meaning of `Ok(None)` is not visible from this trait —
/// presumably "no embedding available for this input"; confirm against the
/// implementors.
fn embed(&self, text: &str) -> DomainResult<Option<Vec<f32>>>;
/// Exposes the implementor as `&dyn Any`.
///
/// NOTE(review): presumably used to downcast a `dyn Embedder` to its concrete
/// type — confirm at the call sites.
fn as_any(&self) -> &dyn Any; }
/// Computes the cosine similarity of two vectors.
///
/// The result is the dot product of `vec1` and `vec2` divided by the product
/// of their Euclidean norms. If either vector has a norm of exactly zero the
/// function returns `0.0` instead of dividing by zero.
///
/// Both vectors are expected to have the same length; the underlying
/// `ndarray` dot product does not accept mismatched lengths.
#[instrument(skip_all)]
pub fn cosine_similarity(vec1: &Array1<f32>, vec2: &Array1<f32>) -> f32 {
    // Euclidean (L2) norm: square root of a vector's dot product with itself.
    let l2_norm = |v: &Array1<f32>| v.dot(v).sqrt();

    let numerator = vec1.dot(vec2);
    let norm1 = l2_norm(vec1);
    let norm2 = l2_norm(vec2);

    match (norm1 == 0.0, norm2 == 0.0) {
        (false, false) => numerator / (norm1 * norm2),
        // At least one vector has no direction; similarity is defined as 0.
        _ => 0.0,
    }
}
/// Decodes an embedding previously encoded with bincode's `legacy()`
/// configuration.
///
/// Only the decoded vector is returned; the number of bytes bincode consumed
/// is discarded.
///
/// # Errors
///
/// Returns [`DomainError::DeserializationError`] carrying the decoder's error
/// message when `bytes` cannot be decoded as a `Vec<f32>`.
#[instrument(skip_all)]
pub fn deserialize_embedding(bytes: Vec<u8>) -> Result<Vec<f32>, DomainError> {
    let decoded: Result<(Vec<f32>, usize), _> =
        bincode::decode_from_slice(&bytes, bincode::config::legacy());

    match decoded {
        Ok((embedding, _bytes_read)) => Ok(embedding),
        Err(err) => Err(DomainError::DeserializationError(err.to_string())),
    }
}
/// Encodes an embedding into bytes using bincode's `legacy()` configuration.
///
/// # Errors
///
/// Returns [`DomainError::SerializationError`] carrying the encoder's error
/// message when encoding fails.
#[instrument(skip_all)]
pub fn serialize_embedding(embedding: Vec<f32>) -> Result<Vec<u8>, DomainError> {
    match bincode::encode_to_vec(&embedding, bincode::config::legacy()) {
        Ok(encoded) => Ok(encoded),
        Err(err) => Err(DomainError::SerializationError(err.to_string())),
    }
}
#[instrument(skip_all, level = "debug")]
pub fn bytes_to_array(bytes: &[u8]) -> Result<Array1<f32>, DomainError> {
let mut cursor = Cursor::new(bytes);
let num_floats = bytes.len() / 4;
let mut values = Vec::with_capacity(num_floats);
for _ in 0..num_floats {
match byteorder::ReadBytesExt::read_f32::<byteorder::LittleEndian>(&mut cursor) {
Ok(value) => values.push(value),
Err(e) => return Err(DomainError::Io(e)),
}
}
Ok(Array1::from(values))
}
/// Encodes an array of `f32` values as consecutive little-endian bytes.
///
/// Each element contributes four bytes, written in iteration order.
///
/// # Errors
///
/// Returns [`DomainError::Io`] if writing a value into the output buffer
/// fails.
#[instrument(skip_all, level = "debug")]
pub fn array_to_bytes(array: &Array1<f32>) -> Result<Vec<u8>, DomainError> {
    let mut encoded = Vec::with_capacity(array.len() * 4);

    array.iter().try_for_each(|&component| {
        byteorder::WriteBytesExt::write_f32::<byteorder::LittleEndian>(&mut encoded, component)
            .map_err(DomainError::Io)
    })?;

    Ok(encoded)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Absolute tolerance used when comparing `f32` results.
    const EPSILON: f32 = 1e-6;

    #[test]
    fn given_two_vectors_when_calculate_cosine_similarity_then_returns_similarity_score() {
        // Orthogonal vectors have similarity 0.
        let vec1 = array![1.0, 0.0];
        let vec2 = array![0.0, 1.0];
        let similarity = cosine_similarity(&vec1, &vec2);
        assert!((similarity - 0.0).abs() < EPSILON);

        // Identical vectors have similarity 1.
        let vec3 = array![1.0, 1.0];
        let vec4 = array![1.0, 1.0];
        let similarity = cosine_similarity(&vec3, &vec4);
        assert!((similarity - 1.0).abs() < EPSILON);
    }

    #[test]
    fn given_zero_vector_when_calculate_cosine_similarity_then_returns_zero() {
        // Covers the zero-magnitude guard on either argument.
        let zero = array![0.0f32, 0.0];
        let other = array![1.0f32, 2.0];
        assert_eq!(cosine_similarity(&zero, &other), 0.0);
        assert_eq!(cosine_similarity(&other, &zero), 0.0);
    }

    #[test]
    fn given_embedding_vector_when_serialize_deserialize_then_preserves_data() {
        let original = vec![1.0f32, 2.0, 3.0];
        let bytes = serialize_embedding(original.clone()).unwrap();
        let deserialized = deserialize_embedding(bytes).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn given_array_when_convert_to_from_bytes_then_preserves_data() {
        let original = array![1.0f32, 2.0, 3.0, 4.0];
        let bytes = array_to_bytes(&original).unwrap();
        assert_eq!(bytes.len(), original.len() * 4);

        let reconstructed = bytes_to_array(&bytes).unwrap();
        // `zip` stops at the shorter side, so check the lengths explicitly;
        // otherwise a truncated reconstruction would pass unnoticed.
        assert_eq!(original.len(), reconstructed.len());
        for (a, b) in original.iter().zip(reconstructed.iter()) {
            assert!((a - b).abs() < EPSILON);
        }
    }
}