use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use super::distance::{self, DistanceMetric};
/// A scored match produced by a similarity search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityResult {
    /// Similarity score (higher means more similar).
    pub score: f32,
    /// Zero-based position of this result in the sorted result list.
    pub rank: usize,
    /// Index of the matched candidate in the original collection.
    pub index: usize,
    /// Arbitrary key/value annotations attached to the match.
    pub metadata: HashMap<String, String>,
}

impl SimilarityResult {
    /// Builds a result with no metadata attached.
    pub fn new(score: f32, rank: usize, index: usize) -> Self {
        Self::with_metadata(score, rank, index, HashMap::new())
    }

    /// Builds a result carrying caller-supplied metadata.
    pub fn with_metadata(
        score: f32,
        rank: usize,
        index: usize,
        metadata: HashMap<String, String>,
    ) -> Self {
        Self {
            score,
            rank,
            index,
            metadata,
        }
    }
}
/// Computes similarities and distances between embeddings under a fixed metric.
pub struct EmbeddingSimilarity {
    metric: DistanceMetric,
}

impl EmbeddingSimilarity {
    /// Creates a calculator bound to the given distance metric.
    pub fn new(metric: DistanceMetric) -> Self {
        Self { metric }
    }

    /// Returns the metric this calculator was configured with.
    pub fn metric(&self) -> DistanceMetric {
        self.metric
    }

    /// Similarity between `a` and `b` (higher = more similar).
    pub fn similarity(&self, a: &[f32], b: &[f32]) -> f32 {
        distance::compute_similarity(self.metric, a, b)
    }

    /// Distance between `a` and `b` (lower = more similar).
    pub fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
        distance::compute_distance(self.metric, a, b)
    }

    /// Scores `query` against every candidate and returns results sorted by
    /// descending score, with `rank` set to each result's sorted position and
    /// `index` pointing back into `candidates`.
    pub fn score_all(&self, query: &[f32], candidates: &[Vec<f32>]) -> Vec<SimilarityResult> {
        let mut results: Vec<SimilarityResult> = candidates
            .iter()
            .enumerate()
            .map(|(i, c)| SimilarityResult::new(self.similarity(query, c), 0, i))
            .collect();
        // `total_cmp` is a true total order over f32 (NaN sorts consistently).
        // The previous `partial_cmp(..).unwrap_or(Equal)` comparator is not
        // transitive when NaN scores appear, which `slice::sort_by` may
        // punish with a panic or an arbitrarily scrambled order.
        results.sort_by(|a, b| b.score.total_cmp(&a.score));
        for (rank, r) in results.iter_mut().enumerate() {
            r.rank = rank;
        }
        results
    }
}
/// A dense `rows x cols` matrix of pairwise similarity scores.
pub struct PairwiseSimilarityMatrix {
    pub matrix: Vec<Vec<f32>>,
    pub rows: usize,
    pub cols: usize,
}

impl PairwiseSimilarityMatrix {
    /// Builds the full similarity matrix between two embedding collections:
    /// entry (i, j) scores `embeddings_a[i]` against `embeddings_b[j]`.
    pub fn compute(
        embeddings_a: &[Vec<f32>],
        embeddings_b: &[Vec<f32>],
        metric: DistanceMetric,
    ) -> Self {
        let rows = embeddings_a.len();
        let cols = embeddings_b.len();
        let mut matrix = Vec::with_capacity(rows);
        for a in embeddings_a {
            let mut row = Vec::with_capacity(cols);
            for b in embeddings_b {
                row.push(distance::compute_similarity(metric, a, b));
            }
            matrix.push(row);
        }
        Self { matrix, rows, cols }
    }

    /// Builds the self-similarity matrix of one collection, exploiting
    /// symmetry so each off-diagonal pair is scored only once.
    pub fn compute_symmetric(embeddings: &[Vec<f32>], metric: DistanceMetric) -> Self {
        let n = embeddings.len();
        let mut matrix = vec![vec![0.0f32; n]; n];
        for row_idx in 0..n {
            // The diagonal is still computed explicitly: not every metric is
            // guaranteed to map self-similarity to exactly 1.0.
            matrix[row_idx][row_idx] =
                distance::compute_similarity(metric, &embeddings[row_idx], &embeddings[row_idx]);
            for col_idx in (row_idx + 1)..n {
                let score = distance::compute_similarity(
                    metric,
                    &embeddings[row_idx],
                    &embeddings[col_idx],
                );
                matrix[row_idx][col_idx] = score;
                matrix[col_idx][row_idx] = score;
            }
        }
        Self {
            matrix,
            rows: n,
            cols: n,
        }
    }

    /// Returns the score at (i, j). Panics when out of bounds.
    pub fn get(&self, i: usize, j: usize) -> f32 {
        self.matrix[i][j]
    }

    /// For each row, the (column index, score) of its best match.
    /// An empty row yields `(0, 0.0)`.
    pub fn most_similar_per_row(&self) -> Vec<(usize, f32)> {
        let mut best = Vec::with_capacity(self.rows);
        for row in &self.matrix {
            let winner = row
                .iter()
                .copied()
                .enumerate()
                .max_by(|(_, x), (_, y)| x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal))
                .unwrap_or((0, 0.0));
            best.push(winner);
        }
        best
    }

    /// For each column, the (row index, score) of its best match.
    /// An empty matrix yields `(0, 0.0)` for every column.
    pub fn most_similar_per_col(&self) -> Vec<(usize, f32)> {
        let mut best = Vec::with_capacity(self.cols);
        for col_idx in 0..self.cols {
            let winner = self
                .matrix
                .iter()
                .enumerate()
                .map(|(row_idx, row)| (row_idx, row[col_idx]))
                .max_by(|(_, x), (_, y)| x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal))
                .unwrap_or((0, 0.0));
            best.push(winner);
        }
        best
    }
}
/// Exhaustive (brute-force) nearest-neighbor search over an in-memory set.
pub struct KNearestNeighbors {
    embeddings: Vec<Vec<f32>>,
    metric: DistanceMetric,
}

impl KNearestNeighbors {
    /// Wraps a collection of embeddings for repeated querying.
    pub fn new(embeddings: Vec<Vec<f32>>, metric: DistanceMetric) -> Self {
        Self { embeddings, metric }
    }

    /// Number of stored embeddings.
    pub fn len(&self) -> usize {
        self.embeddings.len()
    }

    /// Whether the index holds no embeddings.
    pub fn is_empty(&self) -> bool {
        self.embeddings.is_empty()
    }

    /// Returns the top `k` matches for `query`, best first. Fewer than `k`
    /// results come back when the index holds fewer than `k` embeddings.
    pub fn search(&self, query: &[f32], k: usize) -> Vec<SimilarityResult> {
        let scorer = EmbeddingSimilarity::new(self.metric);
        let mut ranked = scorer.score_all(query, &self.embeddings);
        ranked.truncate(k);
        ranked
    }

    /// Like `Self::search`, but attaches per-candidate metadata (looked up by
    /// each result's original index) to the returned matches. Results whose
    /// index has no metadata entry keep an empty metadata map.
    pub fn search_with_metadata(
        &self,
        query: &[f32],
        k: usize,
        metadata: &[HashMap<String, String>],
    ) -> Vec<SimilarityResult> {
        let mut ranked = self.search(query, k);
        for result in ranked.iter_mut() {
            if let Some(entry) = metadata.get(result.index) {
                result.metadata = entry.clone();
            }
        }
        ranked
    }

    /// Adds one embedding to the index.
    pub fn add(&mut self, embedding: Vec<f32>) {
        self.embeddings.push(embedding);
    }

    /// Adds a batch of embeddings to the index.
    pub fn add_batch(&mut self, embeddings: Vec<Vec<f32>>) {
        self.embeddings.extend(embeddings);
    }
}
/// Result of assigning embeddings to their most similar centroid — one
/// assignment + update step of a k-means-style loop.
pub struct ClusterAssignment {
    /// `assignments[i]` is the cluster chosen for `embeddings[i]`.
    pub assignments: Vec<usize>,
    /// Recomputed centroids: the mean of each cluster's assigned members.
    pub centroids: Vec<Vec<f32>>,
    /// Number of clusters.
    pub k: usize,
}

impl ClusterAssignment {
    /// Assigns each embedding to the most similar initial centroid, then
    /// recomputes every centroid as the mean of its assigned members.
    ///
    /// A cluster that receives no members keeps its initial centroid instead
    /// of collapsing to the zero vector (an all-zero centroid is degenerate
    /// for similarity metrics such as cosine).
    ///
    /// # Panics
    /// Panics when `initial_centroids` or `embeddings` is empty. All
    /// embeddings are assumed to share the dimensionality of `embeddings[0]`.
    pub fn assign(
        embeddings: &[Vec<f32>],
        initial_centroids: &[Vec<f32>],
        metric: DistanceMetric,
    ) -> Self {
        let k = initial_centroids.len();
        assert!(k > 0, "must have at least one centroid");
        assert!(
            !embeddings.is_empty(),
            "must have at least one embedding to cluster"
        );
        // Assignment step: pick the centroid with the highest similarity.
        let assignments: Vec<usize> = embeddings
            .iter()
            .map(|emb| {
                initial_centroids
                    .iter()
                    .enumerate()
                    .map(|(ci, c)| (ci, distance::compute_similarity(metric, emb, c)))
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                    .map(|(ci, _)| ci)
                    .expect("k > 0 guarantees a best centroid exists")
            })
            .collect();
        // Update step: accumulate member sums per cluster, then divide by
        // the member count to get the mean.
        let dim = embeddings[0].len();
        let mut centroids = vec![vec![0.0f32; dim]; k];
        let mut counts = vec![0usize; k];
        for (i, &label) in assignments.iter().enumerate() {
            counts[label] += 1;
            for (d, &val) in embeddings[i].iter().enumerate() {
                centroids[label][d] += val;
            }
        }
        for (ci, centroid) in centroids.iter_mut().enumerate() {
            if counts[ci] > 0 {
                let n = counts[ci] as f32;
                for val in centroid.iter_mut() {
                    *val /= n;
                }
            } else {
                // Empty cluster: fall back to the caller-provided centroid
                // rather than leaving an all-zero vector behind.
                centroid.clone_from(&initial_centroids[ci]);
            }
        }
        Self {
            assignments,
            centroids,
            k,
        }
    }

    /// Indices of all embeddings assigned to `cluster`.
    pub fn members(&self, cluster: usize) -> Vec<usize> {
        self.assignments
            .iter()
            .enumerate()
            .filter(|(_, &c)| c == cluster)
            .map(|(i, _)| i)
            .collect()
    }

    /// Member count per cluster, indexed by cluster id.
    pub fn cluster_sizes(&self) -> Vec<usize> {
        let mut sizes = vec![0usize; self.k];
        for &c in &self.assignments {
            sizes[c] += 1;
        }
        sizes
    }
}
pub struct EmbeddingNormalizer;
impl EmbeddingNormalizer {
pub fn l2_normalize(vector: &[f32]) -> Vec<f32> {
distance::normalize(vector)
}
pub fn l2_normalize_batch(vectors: &[Vec<f32>]) -> Vec<Vec<f32>> {
vectors.iter().map(|v| Self::l2_normalize(v)).collect()
}
pub fn mean_center(vectors: &[Vec<f32>]) -> Vec<Vec<f32>> {
if vectors.is_empty() {
return vec![];
}
let mean = distance::mean_vector(vectors);
vectors
.iter()
.map(|v| v.iter().zip(mean.iter()).map(|(a, b)| a - b).collect())
.collect()
}
pub fn unit_variance(vectors: &[Vec<f32>]) -> Vec<Vec<f32>> {
if vectors.is_empty() {
return vec![];
}
let dim = vectors[0].len();
let n = vectors.len() as f32;
let mean = distance::mean_vector(vectors);
let mut variance = vec![0.0f32; dim];
for v in vectors {
for (d, &val) in v.iter().enumerate() {
let diff = val - mean[d];
variance[d] += diff * diff;
}
}
let stddev: Vec<f32> = variance.iter().map(|v| (v / n).sqrt()).collect();
vectors
.iter()
.map(|v| {
v.iter()
.enumerate()
.map(|(d, &val)| {
if stddev[d] == 0.0 {
val
} else {
val / stddev[d]
}
})
.collect()
})
.collect()
}
pub fn normalize_full(vectors: &[Vec<f32>]) -> Vec<Vec<f32>> {
let centered = Self::mean_center(vectors);
Self::l2_normalize_batch(¢ered)
}
}
/// Random-projection dimensionality reducer with deterministic, seeded signs.
///
/// Each projection entry is `±1/sqrt(output_dim)`, with signs drawn from a
/// xorshift64 stream, so a given `(input_dim, output_dim, seed)` triple
/// always produces the same projection matrix.
pub struct DimensionalityReducer {
    projection: Vec<Vec<f32>>,
    pub output_dim: usize,
    pub input_dim: usize,
}

impl DimensionalityReducer {
    /// Builds the sign-random projection matrix.
    ///
    /// A `seed` of zero is remapped to a fixed non-zero constant: zero is a
    /// fixed point of xorshift64, so seeding with 0 previously produced an
    /// all-positive (rank-1, degenerate) projection matrix.
    ///
    /// # Panics
    /// Panics when either dimension is zero.
    pub fn new(input_dim: usize, output_dim: usize, seed: u64) -> Self {
        assert!(output_dim > 0, "output_dim must be positive");
        assert!(input_dim > 0, "input_dim must be positive");
        let scale = 1.0 / (output_dim as f32).sqrt();
        let mut projection = Vec::with_capacity(output_dim);
        // xorshift64 must start from a non-zero state; 0x9E3779B97F4A7C15 is
        // the 64-bit golden-ratio constant, a conventional default seed.
        let mut state = if seed == 0 { 0x9E37_79B9_7F4A_7C15 } else { seed };
        for _ in 0..output_dim {
            let mut row = Vec::with_capacity(input_dim);
            for _ in 0..input_dim {
                // One xorshift64 step per matrix entry.
                state ^= state << 13;
                state ^= state >> 7;
                state ^= state << 17;
                // The low bit of the state decides this entry's sign.
                let val = if state & 1 == 0 { scale } else { -scale };
                row.push(val);
            }
            projection.push(row);
        }
        Self {
            projection,
            output_dim,
            input_dim,
        }
    }

    /// Projects `vector` into the reduced space (matrix–vector product).
    ///
    /// # Panics
    /// Panics when `vector.len() != input_dim`.
    pub fn project(&self, vector: &[f32]) -> Vec<f32> {
        assert_eq!(
            vector.len(),
            self.input_dim,
            "vector dimension mismatch: expected {}, got {}",
            self.input_dim,
            vector.len()
        );
        self.projection
            .iter()
            .map(|row| row.iter().zip(vector.iter()).map(|(a, b)| a * b).sum())
            .collect()
    }

    /// Projects every vector in the batch.
    pub fn project_batch(&self, vectors: &[Vec<f32>]) -> Vec<Vec<f32>> {
        vectors.iter().map(|v| self.project(v)).collect()
    }
}
/// Filters search results by a minimum similarity score.
pub struct SimilarityThreshold {
    pub min_score: f32,
}

impl SimilarityThreshold {
    /// Creates a threshold that keeps results scoring `>= min_score`.
    pub fn new(min_score: f32) -> Self {
        Self { min_score }
    }

    /// Drops every result scoring below the threshold, preserving order and
    /// leaving the surviving results' ranks untouched.
    pub fn filter(&self, results: Vec<SimilarityResult>) -> Vec<SimilarityResult> {
        let mut kept = results;
        kept.retain(|r| r.score >= self.min_score);
        kept
    }

    /// Filters, then renumbers ranks so survivors run 0..len-1.
    pub fn filter_and_rerank(&self, results: Vec<SimilarityResult>) -> Vec<SimilarityResult> {
        let mut surviving = self.filter(results);
        for (new_rank, result) in surviving.iter_mut().enumerate() {
            result.rank = new_rank;
        }
        surviving
    }

    /// Scores `query` against all candidates under `metric`, then applies the
    /// threshold and renumbers ranks.
    pub fn search(
        &self,
        query: &[f32],
        candidates: &[Vec<f32>],
        metric: DistanceMetric,
    ) -> Vec<SimilarityResult> {
        let scorer = EmbeddingSimilarity::new(metric);
        let scored = scorer.score_all(query, candidates);
        self.filter_and_rerank(scored)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const EPSILON: f32 = 1e-5;

    /// Approximate float equality for assertions on f32 results.
    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < EPSILON
    }

    // --- SimilarityResult ---

    #[test]
    fn test_similarity_result_new() {
        let r = SimilarityResult::new(0.95, 0, 3);
        assert!(approx_eq(r.score, 0.95));
        assert_eq!(r.rank, 0);
        assert_eq!(r.index, 3);
        assert!(r.metadata.is_empty());
    }

    #[test]
    fn test_similarity_result_with_metadata() {
        let mut meta = HashMap::new();
        meta.insert("label".to_string(), "cat".to_string());
        let r = SimilarityResult::with_metadata(0.8, 1, 5, meta);
        assert_eq!(r.metadata.get("label").unwrap(), "cat");
        assert_eq!(r.index, 5);
    }

    #[test]
    fn test_similarity_result_serialization() {
        let r = SimilarityResult::new(0.5, 2, 7);
        let json = serde_json::to_string(&r).unwrap();
        let decoded: SimilarityResult = serde_json::from_str(&json).unwrap();
        assert!(approx_eq(decoded.score, 0.5));
        assert_eq!(decoded.rank, 2);
        assert_eq!(decoded.index, 7);
    }

    // --- EmbeddingSimilarity ---

    #[test]
    fn test_embedding_similarity_cosine_identical() {
        let calc = EmbeddingSimilarity::new(DistanceMetric::Cosine);
        let v = vec![1.0, 2.0, 3.0];
        assert!(approx_eq(calc.similarity(&v, &v), 1.0));
    }

    #[test]
    fn test_embedding_similarity_cosine_orthogonal() {
        let calc = EmbeddingSimilarity::new(DistanceMetric::Cosine);
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(approx_eq(calc.similarity(&a, &b), 0.0));
    }

    #[test]
    fn test_embedding_similarity_euclidean() {
        let calc = EmbeddingSimilarity::new(DistanceMetric::Euclidean);
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        assert!(approx_eq(calc.similarity(&a, &b), 1.0 / 6.0));
    }

    #[test]
    fn test_embedding_similarity_dot_product() {
        let calc = EmbeddingSimilarity::new(DistanceMetric::DotProduct);
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        assert!(approx_eq(calc.similarity(&a, &b), 32.0));
    }

    #[test]
    fn test_embedding_similarity_manhattan() {
        let calc = EmbeddingSimilarity::new(DistanceMetric::Manhattan);
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 6.0, 3.0];
        assert!(approx_eq(calc.similarity(&a, &b), 0.125));
    }

    #[test]
    fn test_embedding_similarity_chebyshev() {
        let calc = EmbeddingSimilarity::new(DistanceMetric::Chebyshev);
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 6.0, 3.0];
        assert!(approx_eq(calc.similarity(&a, &b), 0.2));
    }

    #[test]
    fn test_embedding_similarity_hamming() {
        let calc = EmbeddingSimilarity::new(DistanceMetric::Hamming);
        let a = vec![1.0, 0.0, 1.0, 0.0];
        let b = vec![1.0, 1.0, 0.0, 0.0];
        assert!(approx_eq(calc.similarity(&a, &b), 0.5));
    }

    #[test]
    fn test_embedding_similarity_distance() {
        let calc = EmbeddingSimilarity::new(DistanceMetric::Cosine);
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(approx_eq(calc.distance(&a, &b), 1.0));
    }

    #[test]
    fn test_embedding_similarity_metric_getter() {
        let calc = EmbeddingSimilarity::new(DistanceMetric::Manhattan);
        assert_eq!(calc.metric(), DistanceMetric::Manhattan);
    }

    #[test]
    fn test_score_all_ordering() {
        let calc = EmbeddingSimilarity::new(DistanceMetric::Cosine);
        let query = vec![1.0, 0.0, 0.0];
        let candidates = vec![
            vec![0.0, 1.0, 0.0],
            vec![1.0, 0.0, 0.0],
            vec![0.7, 0.7, 0.0],
        ];
        let results = calc.score_all(&query, &candidates);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].index, 1);
        assert_eq!(results[0].rank, 0);
        assert_eq!(results[1].rank, 1);
        assert_eq!(results[2].rank, 2);
        assert!(results[0].score >= results[1].score);
        assert!(results[1].score >= results[2].score);
    }

    #[test]
    fn test_score_all_empty_candidates() {
        let calc = EmbeddingSimilarity::new(DistanceMetric::Cosine);
        let query = vec![1.0, 0.0];
        let candidates: Vec<Vec<f32>> = vec![];
        let results = calc.score_all(&query, &candidates);
        assert!(results.is_empty());
    }

    // --- PairwiseSimilarityMatrix ---

    #[test]
    fn test_pairwise_matrix_identity() {
        let a = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let b = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let m = PairwiseSimilarityMatrix::compute(&a, &b, DistanceMetric::Cosine);
        assert_eq!(m.rows, 2);
        assert_eq!(m.cols, 2);
        assert!(approx_eq(m.get(0, 0), 1.0));
        assert!(approx_eq(m.get(0, 1), 0.0));
        assert!(approx_eq(m.get(1, 0), 0.0));
        assert!(approx_eq(m.get(1, 1), 1.0));
    }

    #[test]
    fn test_pairwise_matrix_asymmetric_shape() {
        let a = vec![vec![1.0, 0.0, 0.0]];
        let b = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        let m = PairwiseSimilarityMatrix::compute(&a, &b, DistanceMetric::Cosine);
        assert_eq!(m.rows, 1);
        assert_eq!(m.cols, 3);
        assert!(approx_eq(m.get(0, 0), 1.0));
        assert!(approx_eq(m.get(0, 1), 0.0));
        assert!(approx_eq(m.get(0, 2), 0.0));
    }

    #[test]
    fn test_pairwise_matrix_symmetric() {
        let vecs = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
        let m = PairwiseSimilarityMatrix::compute_symmetric(&vecs, DistanceMetric::Cosine);
        assert_eq!(m.rows, 3);
        assert_eq!(m.cols, 3);
        for i in 0..3 {
            for j in 0..3 {
                assert!(approx_eq(m.get(i, j), m.get(j, i)));
            }
        }
        assert!(approx_eq(m.get(0, 0), 1.0));
        assert!(approx_eq(m.get(1, 1), 1.0));
    }

    #[test]
    fn test_most_similar_per_row() {
        let a = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let b = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
        let m = PairwiseSimilarityMatrix::compute(&a, &b, DistanceMetric::Cosine);
        let best = m.most_similar_per_row();
        assert_eq!(best[0].0, 1);
        assert_eq!(best[1].0, 0);
    }

    #[test]
    fn test_most_similar_per_col() {
        let a = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let b = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
        let m = PairwiseSimilarityMatrix::compute(&a, &b, DistanceMetric::Cosine);
        let best = m.most_similar_per_col();
        assert_eq!(best[0].0, 1);
        assert_eq!(best[1].0, 0);
    }

    // --- KNearestNeighbors ---

    #[test]
    fn test_knn_search_basic() {
        let embeddings = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![0.9, 0.1, 0.0],
        ];
        let knn = KNearestNeighbors::new(embeddings, DistanceMetric::Cosine);
        let results = knn.search(&[1.0, 0.0, 0.0], 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].index, 0);
        assert_eq!(results[1].index, 3);
    }

    #[test]
    fn test_knn_search_k_larger_than_collection() {
        let embeddings = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let knn = KNearestNeighbors::new(embeddings, DistanceMetric::Cosine);
        let results = knn.search(&[1.0, 0.0], 10);
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_knn_len_and_empty() {
        let knn = KNearestNeighbors::new(vec![], DistanceMetric::Cosine);
        assert!(knn.is_empty());
        assert_eq!(knn.len(), 0);
        let knn2 = KNearestNeighbors::new(vec![vec![1.0]], DistanceMetric::Cosine);
        assert!(!knn2.is_empty());
        assert_eq!(knn2.len(), 1);
    }

    #[test]
    fn test_knn_add() {
        let mut knn = KNearestNeighbors::new(vec![vec![1.0, 0.0]], DistanceMetric::Cosine);
        assert_eq!(knn.len(), 1);
        knn.add(vec![0.0, 1.0]);
        assert_eq!(knn.len(), 2);
    }

    #[test]
    fn test_knn_add_batch() {
        let mut knn = KNearestNeighbors::new(vec![], DistanceMetric::Cosine);
        knn.add_batch(vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]]);
        assert_eq!(knn.len(), 3);
    }

    #[test]
    fn test_knn_search_with_metadata() {
        let embeddings = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.7, 0.7]];
        let metadata = vec![
            {
                let mut m = HashMap::new();
                m.insert("id".to_string(), "a".to_string());
                m
            },
            {
                let mut m = HashMap::new();
                m.insert("id".to_string(), "b".to_string());
                m
            },
            {
                let mut m = HashMap::new();
                m.insert("id".to_string(), "c".to_string());
                m
            },
        ];
        let knn = KNearestNeighbors::new(embeddings, DistanceMetric::Cosine);
        let results = knn.search_with_metadata(&[1.0, 0.0], 2, &metadata);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].metadata.get("id").unwrap(), "a");
    }

    #[test]
    fn test_knn_search_ranks_are_sequential() {
        let embeddings = vec![
            vec![1.0, 0.0],
            vec![0.7, 0.7],
            vec![0.0, 1.0],
            vec![0.9, 0.1],
        ];
        let knn = KNearestNeighbors::new(embeddings, DistanceMetric::Cosine);
        let results = knn.search(&[1.0, 0.0], 4);
        for (i, r) in results.iter().enumerate() {
            assert_eq!(r.rank, i);
        }
    }

    // --- ClusterAssignment ---
    // Note: the `¢roids` mojibake (mangled `&centroids`) in the original
    // test bodies is repaired throughout this module.

    #[test]
    fn test_cluster_assignment_two_clusters() {
        let embeddings = vec![
            vec![1.0, 0.0],
            vec![0.9, 0.1],
            vec![0.0, 1.0],
            vec![0.1, 0.9],
        ];
        let centroids = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let ca = ClusterAssignment::assign(&embeddings, &centroids, DistanceMetric::Cosine);
        assert_eq!(ca.k, 2);
        assert_eq!(ca.assignments[0], 0);
        assert_eq!(ca.assignments[1], 0);
        assert_eq!(ca.assignments[2], 1);
        assert_eq!(ca.assignments[3], 1);
    }

    #[test]
    fn test_cluster_assignment_members() {
        let embeddings = vec![vec![1.0, 0.0], vec![0.9, 0.1], vec![0.0, 1.0]];
        let centroids = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let ca = ClusterAssignment::assign(&embeddings, &centroids, DistanceMetric::Cosine);
        let members_0 = ca.members(0);
        let members_1 = ca.members(1);
        assert!(members_0.contains(&0));
        assert!(members_0.contains(&1));
        assert!(members_1.contains(&2));
    }

    #[test]
    fn test_cluster_sizes() {
        let embeddings = vec![
            vec![1.0, 0.0],
            vec![0.9, 0.1],
            vec![0.8, 0.2],
            vec![0.0, 1.0],
        ];
        let centroids = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let ca = ClusterAssignment::assign(&embeddings, &centroids, DistanceMetric::Cosine);
        let sizes = ca.cluster_sizes();
        assert_eq!(sizes[0], 3);
        assert_eq!(sizes[1], 1);
    }

    #[test]
    fn test_cluster_centroids_recomputed() {
        let embeddings = vec![vec![2.0, 0.0], vec![4.0, 0.0]];
        let centroids = vec![vec![1.0, 0.0]];
        let ca = ClusterAssignment::assign(&embeddings, &centroids, DistanceMetric::Cosine);
        assert!(approx_eq(ca.centroids[0][0], 3.0));
        assert!(approx_eq(ca.centroids[0][1], 0.0));
    }

    #[test]
    fn test_cluster_single_centroid() {
        let embeddings = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
        let centroids = vec![vec![0.5, 0.5]];
        let ca = ClusterAssignment::assign(&embeddings, &centroids, DistanceMetric::Cosine);
        assert!(ca.assignments.iter().all(|&c| c == 0));
        assert_eq!(ca.cluster_sizes(), vec![3]);
    }

    // --- EmbeddingNormalizer ---

    #[test]
    fn test_l2_normalize() {
        let v = vec![3.0, 4.0];
        let n = EmbeddingNormalizer::l2_normalize(&v);
        let mag: f32 = n.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(approx_eq(mag, 1.0));
    }

    #[test]
    fn test_l2_normalize_zero_vector() {
        let v = vec![0.0, 0.0, 0.0];
        let n = EmbeddingNormalizer::l2_normalize(&v);
        assert!(n.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_l2_normalize_batch() {
        let vecs = vec![vec![3.0, 4.0], vec![0.0, 5.0]];
        let normed = EmbeddingNormalizer::l2_normalize_batch(&vecs);
        for n in &normed {
            let mag: f32 = n.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!(approx_eq(mag, 1.0));
        }
    }

    #[test]
    fn test_mean_center() {
        let vecs = vec![vec![2.0, 4.0], vec![4.0, 6.0]];
        let centered = EmbeddingNormalizer::mean_center(&vecs);
        assert!(approx_eq(centered[0][0], -1.0));
        assert!(approx_eq(centered[0][1], -1.0));
        assert!(approx_eq(centered[1][0], 1.0));
        assert!(approx_eq(centered[1][1], 1.0));
    }

    #[test]
    fn test_mean_center_empty() {
        let centered = EmbeddingNormalizer::mean_center(&[]);
        assert!(centered.is_empty());
    }

    #[test]
    fn test_unit_variance() {
        let vecs = vec![vec![1.0, 10.0], vec![3.0, 20.0], vec![5.0, 30.0]];
        let scaled = EmbeddingNormalizer::unit_variance(&vecs);
        assert_eq!(scaled.len(), 3);
        for v in &scaled {
            assert_eq!(v.len(), 2);
        }
    }

    #[test]
    fn test_unit_variance_empty() {
        let scaled = EmbeddingNormalizer::unit_variance(&[]);
        assert!(scaled.is_empty());
    }

    #[test]
    fn test_normalize_full() {
        let vecs = vec![vec![2.0, 4.0], vec![4.0, 6.0]];
        let result = EmbeddingNormalizer::normalize_full(&vecs);
        assert_eq!(result.len(), 2);
        for v in &result {
            let mag: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!(approx_eq(mag, 1.0));
        }
    }

    // --- DimensionalityReducer ---

    #[test]
    fn test_reducer_output_dimensions() {
        let reducer = DimensionalityReducer::new(100, 10, 42);
        assert_eq!(reducer.input_dim, 100);
        assert_eq!(reducer.output_dim, 10);
        let v = vec![1.0; 100];
        let projected = reducer.project(&v);
        assert_eq!(projected.len(), 10);
    }

    #[test]
    fn test_reducer_deterministic() {
        let r1 = DimensionalityReducer::new(50, 5, 123);
        let r2 = DimensionalityReducer::new(50, 5, 123);
        let v = vec![1.0; 50];
        let p1 = r1.project(&v);
        let p2 = r2.project(&v);
        for (a, b) in p1.iter().zip(p2.iter()) {
            assert!(approx_eq(*a, *b));
        }
    }

    #[test]
    fn test_reducer_different_seeds() {
        let r1 = DimensionalityReducer::new(50, 5, 100);
        let r2 = DimensionalityReducer::new(50, 5, 200);
        let v = vec![1.0; 50];
        let p1 = r1.project(&v);
        let p2 = r2.project(&v);
        let any_diff = p1.iter().zip(p2.iter()).any(|(a, b)| !approx_eq(*a, *b));
        assert!(
            any_diff,
            "different seeds should produce different projections"
        );
    }

    #[test]
    fn test_reducer_batch() {
        let reducer = DimensionalityReducer::new(20, 3, 42);
        let vecs = vec![vec![1.0; 20], vec![2.0; 20], vec![0.5; 20]];
        let projected = reducer.project_batch(&vecs);
        assert_eq!(projected.len(), 3);
        for p in &projected {
            assert_eq!(p.len(), 3);
        }
    }

    #[test]
    #[should_panic(expected = "vector dimension mismatch")]
    fn test_reducer_dimension_mismatch() {
        let reducer = DimensionalityReducer::new(10, 3, 42);
        let v = vec![1.0; 5];
        reducer.project(&v);
    }

    #[test]
    fn test_reducer_preserves_relative_similarity() {
        let reducer = DimensionalityReducer::new(100, 20, 42);
        let a = vec![1.0; 100];
        let mut b = vec![1.0; 100];
        b[0] = 0.9;
        let c = vec![-1.0; 100];
        let pa = reducer.project(&a);
        let pb = reducer.project(&b);
        let pc = reducer.project(&c);
        let sim_ab = distance::cosine_similarity(&pa, &pb);
        let sim_ac = distance::cosine_similarity(&pa, &pc);
        assert!(
            sim_ab > sim_ac,
            "similar vectors should project closer together"
        );
    }

    // --- SimilarityThreshold ---

    #[test]
    fn test_threshold_filter() {
        let threshold = SimilarityThreshold::new(0.5);
        let results = vec![
            SimilarityResult::new(0.9, 0, 0),
            SimilarityResult::new(0.3, 1, 1),
            SimilarityResult::new(0.7, 2, 2),
            SimilarityResult::new(0.1, 3, 3),
        ];
        let filtered = threshold.filter(results);
        assert_eq!(filtered.len(), 2);
        assert!(filtered.iter().all(|r| r.score >= 0.5));
    }

    #[test]
    fn test_threshold_filter_all_pass() {
        let threshold = SimilarityThreshold::new(0.0);
        let results = vec![
            SimilarityResult::new(0.5, 0, 0),
            SimilarityResult::new(0.1, 1, 1),
        ];
        let filtered = threshold.filter(results);
        assert_eq!(filtered.len(), 2);
    }

    #[test]
    fn test_threshold_filter_none_pass() {
        let threshold = SimilarityThreshold::new(1.0);
        let results = vec![
            SimilarityResult::new(0.5, 0, 0),
            SimilarityResult::new(0.9, 1, 1),
        ];
        let filtered = threshold.filter(results);
        assert!(filtered.is_empty());
    }

    #[test]
    fn test_threshold_filter_and_rerank() {
        let threshold = SimilarityThreshold::new(0.5);
        let results = vec![
            SimilarityResult::new(0.9, 0, 0),
            SimilarityResult::new(0.3, 1, 1),
            SimilarityResult::new(0.7, 2, 2),
        ];
        let filtered = threshold.filter_and_rerank(results);
        assert_eq!(filtered.len(), 2);
        assert_eq!(filtered[0].rank, 0);
        assert_eq!(filtered[1].rank, 1);
    }

    #[test]
    fn test_threshold_search() {
        let threshold = SimilarityThreshold::new(0.5);
        let query = vec![1.0, 0.0];
        let candidates = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![0.7, 0.7],
        ];
        let results = threshold.search(&query, &candidates, DistanceMetric::Cosine);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].rank, 0);
        assert_eq!(results[1].rank, 1);
    }

    #[test]
    fn test_threshold_exact_boundary() {
        // Threshold comparison is inclusive: score == min_score is kept.
        let threshold = SimilarityThreshold::new(0.5);
        let results = vec![SimilarityResult::new(0.5, 0, 0)];
        let filtered = threshold.filter(results);
        assert_eq!(filtered.len(), 1);
    }

    // --- Integration-style tests combining several components ---

    #[test]
    fn test_knn_with_threshold() {
        let embeddings = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.9, 0.1, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        let knn = KNearestNeighbors::new(embeddings, DistanceMetric::Cosine);
        let results = knn.search(&[1.0, 0.0, 0.0], 4);
        let threshold = SimilarityThreshold::new(0.5);
        let filtered = threshold.filter_and_rerank(results);
        assert_eq!(filtered.len(), 2);
    }

    #[test]
    fn test_cluster_then_search_within() {
        let embeddings = vec![
            vec![1.0, 0.0],
            vec![0.9, 0.1],
            vec![0.0, 1.0],
            vec![0.1, 0.9],
        ];
        let centroids = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let ca = ClusterAssignment::assign(&embeddings, &centroids, DistanceMetric::Cosine);
        let cluster_0_members = ca.members(0);
        let cluster_0_embeddings: Vec<Vec<f32>> = cluster_0_members
            .iter()
            .map(|&i| embeddings[i].clone())
            .collect();
        let knn = KNearestNeighbors::new(cluster_0_embeddings, DistanceMetric::Cosine);
        let results = knn.search(&[1.0, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert!(results[0].score > 0.9);
    }

    #[test]
    fn test_normalize_then_compare() {
        // Parallel vectors stay parallel after normalization.
        let vecs = vec![vec![3.0, 4.0], vec![6.0, 8.0]];
        let normed = EmbeddingNormalizer::l2_normalize_batch(&vecs);
        let calc = EmbeddingSimilarity::new(DistanceMetric::Cosine);
        let sim = calc.similarity(&normed[0], &normed[1]);
        assert!(approx_eq(sim, 1.0));
    }

    #[test]
    fn test_reduce_then_knn() {
        let reducer = DimensionalityReducer::new(50, 5, 42);
        let embeddings: Vec<Vec<f32>> = (0..10)
            .map(|i| {
                let mut v = vec![0.0f32; 50];
                v[i % 50] = 1.0;
                v
            })
            .collect();
        let projected = reducer.project_batch(&embeddings);
        let knn = KNearestNeighbors::new(projected, DistanceMetric::Cosine);
        let query = reducer.project(&{
            let mut v = vec![0.0f32; 50];
            v[0] = 1.0;
            v
        });
        let results = knn.search(&query, 3);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].index, 0);
    }

    #[test]
    fn test_pairwise_then_threshold() {
        let a = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let b = vec![vec![1.0, 0.0], vec![0.7, 0.7]];
        let m = PairwiseSimilarityMatrix::compute(&a, &b, DistanceMetric::Cosine);
        let threshold = SimilarityThreshold::new(0.5);
        let row_0_results: Vec<SimilarityResult> = m.matrix[0]
            .iter()
            .enumerate()
            .map(|(j, &score)| SimilarityResult::new(score, j, j))
            .collect();
        let filtered = threshold.filter(row_0_results);
        assert_eq!(filtered.len(), 2);
    }
}