use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// A vector embedding tagged with the data point and field it was computed for.
///
/// NOTE(review): the meaning of the identifying fields below is inferred from
/// their names and from the values used in this file's tests — confirm
/// against callers before relying on it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Embedding {
    /// Identifier of the data point this embedding presumably belongs to.
    pub data_point_id: Uuid,
    /// Type name of the source data point (the tests use e.g. `"DocumentChunk"`, `"Entity"`).
    pub data_type: String,
    /// Name of the embedded field (the tests use e.g. `"text"`, `"name"`).
    pub field_name: String,
    /// The embedding components; its length is what `dimensions()` reports.
    pub vector: Vec<f32>,
}
impl Embedding {
    /// Creates an embedding from its identifying metadata and vector.
    ///
    /// `data_type` and `field_name` accept anything convertible into a
    /// `String` (e.g. `&str` or `String`).
    pub fn new(
        data_point_id: Uuid,
        data_type: impl Into<String>,
        field_name: impl Into<String>,
        vector: Vec<f32>,
    ) -> Self {
        Self {
            data_point_id,
            data_type: data_type.into(),
            field_name: field_name.into(),
            vector,
        }
    }

    /// Returns the number of components in the vector.
    pub fn dimensions(&self) -> usize {
        self.vector.len()
    }

    /// Returns the Euclidean (L2) norm of the vector.
    pub fn norm(&self) -> f32 {
        self.vector.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// Returns the cosine similarity between `self` and `other`.
    ///
    /// The result is the dot product of the two vectors divided by the
    /// product of their norms, clamped to `[-1.0, 1.0]` so floating-point
    /// rounding can never push it outside the mathematically valid range.
    ///
    /// Returns `None` when the similarity is undefined:
    /// * the vectors have different dimensions, or
    /// * either vector has zero magnitude (this includes empty vectors).
    pub fn cosine_similarity(&self, other: &Embedding) -> Option<f32> {
        if self.dimensions() != other.dimensions() {
            return None;
        }

        let self_norm = self.norm();
        let other_norm = other.norm();
        // The angle to a zero vector is undefined; dividing would yield NaN.
        if self_norm == 0.0 || other_norm == 0.0 {
            return None;
        }

        let dot: f32 = self
            .vector
            .iter()
            .zip(&other.vector)
            .map(|(a, b)| a * b)
            .sum();

        Some((dot / (self_norm * other_norm)).clamp(-1.0, 1.0))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an embedding with a fresh random id, for tests that don't care about identity.
    fn make(data_type: &str, field_name: &str, vector: Vec<f32>) -> Embedding {
        Embedding::new(Uuid::new_v4(), data_type, field_name, vector)
    }

    #[test]
    fn test_embedding_creation() {
        let expected_id = Uuid::new_v4();
        let emb = Embedding::new(expected_id, "DocumentChunk", "text", vec![0.1, 0.2, 0.3]);

        assert_eq!(emb.data_point_id, expected_id);
        assert_eq!(emb.data_type, "DocumentChunk");
        assert_eq!(emb.field_name, "text");
        assert_eq!(emb.dimensions(), 3);
    }

    #[test]
    fn test_norm() {
        // 0.6^2 + 0.8^2 = 1.0, so the norm should be (approximately) one.
        let emb = make("Entity", "name", vec![0.6, 0.8]);
        let norm = emb.norm();
        assert!((norm - 1.0).abs() < 0.01, "Expected norm ~1.0, got {norm}");
    }

    #[test]
    fn test_cosine_similarity() {
        let x_axis = make("Entity", "name", vec![1.0, 0.0, 0.0]);
        let x_axis_again = make("Entity", "name", vec![1.0, 0.0, 0.0]);
        let y_axis = make("Entity", "name", vec![0.0, 1.0, 0.0]);

        // Same direction => fully similar; orthogonal => no similarity.
        assert_eq!(x_axis.cosine_similarity(&x_axis_again), Some(1.0));
        assert_eq!(x_axis.cosine_similarity(&y_axis), Some(0.0));
    }

    #[test]
    fn test_cosine_similarity_dimension_mismatch() {
        let two_d = make("Entity", "name", vec![1.0, 0.0]);
        let three_d = make("Entity", "name", vec![1.0, 0.0, 0.0]);
        assert!(two_d.cosine_similarity(&three_d).is_none());
    }
}