#![allow(dead_code)]
use crate::lsh::{LshConfig, LshIndex, LshResult};
use std::collections::HashMap;
/// A dense numeric vector tagged with the identifier of the item it describes.
#[derive(Debug, Clone)]
pub struct ItemVector {
    /// Identifier of the item this vector belongs to.
    pub id: String,
    /// Raw vector components.
    pub values: Vec<f64>,
}

impl ItemVector {
    /// Creates a vector for item `id` holding the given components.
    #[must_use]
    pub fn new(id: impl Into<String>, values: Vec<f64>) -> Self {
        Self {
            id: id.into(),
            values,
        }
    }

    /// Returns the dot product of `self` and `other`.
    ///
    /// Components are paired positionally. When the two vectors differ in
    /// length, the trailing components of the longer one are ignored.
    #[must_use]
    pub fn dot_product(&self, other: &Self) -> f64 {
        self.values
            .iter()
            .zip(other.values.iter())
            .map(|(a, b)| a * b)
            .sum()
    }

    /// Returns the Euclidean (L2) norm of the vector; `0.0` for an empty vector.
    #[must_use]
    pub fn magnitude(&self) -> f64 {
        self.values.iter().map(|v| v * v).sum::<f64>().sqrt()
    }

    /// Returns the cosine similarity between `self` and `other`, in `[-1.0, 1.0]`.
    ///
    /// The result is `0.0` whenever the similarity is undefined: either vector
    /// has zero magnitude, or a NaN / infinite component makes the ratio
    /// meaningless.
    #[must_use]
    pub fn cosine_similarity(&self, other: &Self) -> f64 {
        let lhs_norm = self.magnitude();
        let rhs_norm = other.magnitude();

        // Each norm is checked on its own rather than comparing their product
        // against an absolute epsilon: cosine similarity is scale-invariant, so
        // small-but-nonzero vectors must still yield a meaningful value.
        // The comparison is also false for NaN norms.
        let norms_usable = lhs_norm > 0.0 && rhs_norm > 0.0;
        if !norms_usable {
            return 0.0;
        }

        // Dividing by one norm at a time avoids underflow of `lhs_norm * rhs_norm`.
        let cosine = self.dot_product(other) / lhs_norm / rhs_norm;
        if cosine.is_nan() {
            // e.g. inf / inf when components overflowed.
            return 0.0;
        }
        // Rounding can push the ratio marginally outside the valid range.
        cosine.clamp(-1.0, 1.0)
    }

    /// Returns the number of components.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.values.len()
    }

    /// Returns `true` when the vector has no components.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.values.is_empty()
    }
}
/// Symmetric table of pairwise similarity scores keyed by item id.
#[derive(Debug, Clone, Default)]
pub struct SimilarityMatrix {
    /// `scores[a][b]` is the score recorded for the pair `(a, b)`; every pair
    /// is stored in both directions.
    scores: HashMap<String, HashMap<String, f64>>,
}

impl SimilarityMatrix {
    /// Creates an empty matrix.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Records `score` for the pair `(id_a, id_b)` in both directions,
    /// overwriting any score previously stored for that pair.
    pub fn insert(&mut self, id_a: impl Into<String>, id_b: impl Into<String>, score: f64) {
        let a = id_a.into();
        let b = id_b.into();
        self.scores
            .entry(a.clone())
            .or_default()
            .insert(b.clone(), score);
        self.scores.entry(b).or_default().insert(a, score);
    }

    /// Returns the score stored for `(id_a, id_b)`, or `None` if the pair was
    /// never inserted.
    #[must_use]
    pub fn get(&self, id_a: &str, id_b: &str) -> Option<f64> {
        self.scores.get(id_a)?.get(id_b).copied()
    }

    /// Returns up to `top_k` `(id, score)` pairs recorded against `query_id`,
    /// ordered from highest to lowest score.
    ///
    /// The query item itself is never part of the result. Equal scores are
    /// ordered by ascending id so the output is deterministic, and NaN scores
    /// are placed after every real score. An unknown `query_id` yields an empty
    /// vector.
    #[must_use]
    pub fn find_similar(&self, query_id: &str, top_k: usize) -> Vec<(String, f64)> {
        let Some(row) = self.scores.get(query_id) else {
            return Vec::new();
        };

        // Rank borrowed ids first so only the surviving `top_k` ids are cloned.
        let mut ranked: Vec<(&String, f64)> = row
            .iter()
            .filter(|(id, _)| id.as_str() != query_id)
            .map(|(id, &score)| (id, score))
            .collect();

        // The comparator must be a total order: `partial_cmp` alone is not one
        // when NaN is present, which makes `sort_by` unspecified (and allowed
        // to panic). NaN is therefore ranked explicitly, below all numbers.
        ranked.sort_by(|a, b| {
            let by_score = match (a.1.is_nan(), b.1.is_nan()) {
                (true, true) => std::cmp::Ordering::Equal,
                (true, false) => std::cmp::Ordering::Greater,
                (false, true) => std::cmp::Ordering::Less,
                (false, false) => b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal),
            };
            // Map iteration order is arbitrary; the id tie-break makes ties stable.
            by_score.then_with(|| a.0.cmp(b.0))
        });

        ranked
            .into_iter()
            .take(top_k)
            .map(|(id, score)| (id.clone(), score))
            .collect()
    }

    /// Returns the number of distinct ids that appear in at least one
    /// inserted pair.
    #[must_use]
    pub fn item_count(&self) -> usize {
        self.scores.len()
    }

    /// Builds a matrix holding the cosine similarity of every unordered pair
    /// of distinct positions in `vectors`.
    ///
    /// Fewer than two vectors produce an empty matrix, since there is no pair
    /// to score.
    #[must_use]
    pub fn from_vectors(vectors: &[ItemVector]) -> Self {
        let mut matrix = Self::new();
        for (i, first) in vectors.iter().enumerate() {
            // Only pair with later vectors: `insert` already stores both directions.
            for second in &vectors[i + 1..] {
                let sim = first.cosine_similarity(second);
                matrix.insert(first.id.clone(), second.id.clone(), sim);
            }
        }
        matrix
    }
}
/// Settings used to construct an [`LshItemIndex`].
///
/// Each field is forwarded unchanged to the `LshConfig` field of the same name.
#[derive(Debug, Clone)]
pub struct LshItemConfig {
    /// Value passed through as `LshConfig::dim`.
    pub dim: usize,
    /// Value passed through as `LshConfig::num_tables`.
    pub num_tables: usize,
    /// Value passed through as `LshConfig::num_planes`.
    pub num_planes: usize,
}

impl Default for LshItemConfig {
    /// Returns the stock configuration: `dim = 64`, `num_tables = 4`,
    /// `num_planes = 8`.
    fn default() -> Self {
        let (dim, num_tables, num_planes) = (64, 4, 8);
        Self {
            dim,
            num_tables,
            num_planes,
        }
    }
}
/// Item-oriented wrapper around an [`LshIndex`] that accepts [`ItemVector`]s.
pub struct LshItemIndex {
    /// Underlying index that stores the vectors and answers queries.
    inner: LshIndex,
}

impl LshItemIndex {
    /// Creates an index configured from `config`; its fields are copied
    /// one-to-one into an [`LshConfig`].
    #[must_use]
    pub fn new(config: LshItemConfig) -> Self {
        let lsh_config = LshConfig {
            dim: config.dim,
            num_tables: config.num_tables,
            num_planes: config.num_planes,
        };
        Self {
            inner: LshIndex::new(lsh_config),
        }
    }

    /// Adds `item` to the index, handing its id and values to the underlying
    /// index.
    pub fn insert(&mut self, item: ItemVector) {
        self.inner.insert(item.id, item.values);
    }

    /// Inserts every item yielded by `items`, in iteration order.
    pub fn bulk_insert(&mut self, items: impl IntoIterator<Item = ItemVector>) {
        for item in items {
            self.insert(item);
        }
    }

    /// Returns the result of `LshIndex::query` for `query_vector` and `top_k`.
    #[must_use]
    pub fn find_similar(&self, query_vector: &[f64], top_k: usize) -> Vec<LshResult> {
        self.inner.query(query_vector, top_k)
    }

    /// Returns up to `top_k` results for the stored vector of `item_id`, with
    /// the item itself removed and `rank` renumbered from 1.
    ///
    /// Returns `None` when the index is empty, when `item_id` is not among the
    /// indexed items, or when its vector cannot be retrieved.
    #[must_use]
    pub fn find_similar_to_item(&self, item_id: &str, top_k: usize) -> Option<Vec<LshResult>> {
        let n = self.inner.len();
        if n == 0 {
            return None;
        }

        // Enumerate every stored item by requesting `n` results for an
        // all-zero query vector.
        let zero = vec![0.0f64; self.inner.dim()];
        let all_items = self.inner.exact_top_k(&zero, n);

        // NOTE(review): `pos` is a position in the *result list*, yet it is
        // used below as the index for `get_vector`. That is only correct if
        // `exact_top_k` returns items in storage order for this query —
        // confirm against `LshIndex`.
        let pos = all_items.iter().position(|r| r.item_id == item_id)?;
        let query_vec = self.inner.get_vector(pos)?.to_vec();

        // Request one extra result because the item is expected to match
        // itself. Saturating keeps `top_k == usize::MAX` ("everything") from
        // overflowing.
        let mut results = self.inner.query(&query_vec, top_k.saturating_add(1));
        results.retain(|r| r.item_id != item_id);
        results.truncate(top_k);

        // Removing the item leaves a gap in the ranks; renumber 1..=len.
        for (i, r) in results.iter_mut().enumerate() {
            r.rank = i + 1;
        }
        Some(results)
    }

    /// Returns the result of `LshIndex::exact_top_k` for `query_vector` and
    /// `top_k`.
    #[must_use]
    pub fn exact_top_k(&self, query_vector: &[f64], top_k: usize) -> Vec<LshResult> {
        self.inner.exact_top_k(query_vector, top_k)
    }

    /// Returns the number of items reported by the underlying index.
    #[must_use]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns `true` when the underlying index reports itself as empty.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Returns the dimension reported by the underlying index.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.inner.dim()
    }
}
impl Default for LshItemIndex {
    /// Builds an index from [`LshItemConfig::default`].
    fn default() -> Self {
        let config = LshItemConfig::default();
        Self::new(config)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Query vector used by the LSH tests: the first axis of a 4-dimensional space.
    const FIRST_AXIS: [f64; 4] = [1.0, 0.0, 0.0, 0.0];

    /// Builds a two-component vector.
    fn vec2(id: &str, x: f64, y: f64) -> ItemVector {
        ItemVector::new(id, vec![x, y])
    }

    /// Builds a 4-dimensional vector that is 1.0 on `axis` and 0.0 elsewhere.
    fn basis4(id: &str, axis: usize) -> ItemVector {
        let mut values = vec![0.0; 4];
        values[axis] = 1.0;
        ItemVector::new(id, values)
    }

    /// Builds an empty 4-dimensional index with 6 tables and 10 planes.
    fn make_lsh_index() -> LshItemIndex {
        let config = LshItemConfig {
            dim: 4,
            num_tables: 6,
            num_planes: 10,
        };
        LshItemIndex::new(config)
    }

    // --- ItemVector -------------------------------------------------------

    #[test]
    fn test_dot_product_basic() {
        let lhs = vec2("a", 1.0, 2.0);
        let rhs = vec2("b", 3.0, 4.0);
        let error = (lhs.dot_product(&rhs) - 11.0).abs();
        assert!(error < 1e-10);
    }

    #[test]
    fn test_dot_product_zero() {
        let origin = vec2("a", 0.0, 0.0);
        let ones = vec2("b", 1.0, 1.0);
        assert!(origin.dot_product(&ones).abs() < 1e-10);
    }

    #[test]
    fn test_magnitude_unit_vector() {
        let unit = vec2("v", 1.0, 0.0);
        let error = (unit.magnitude() - 1.0).abs();
        assert!(error < 1e-10);
    }

    #[test]
    fn test_magnitude_general() {
        // 3-4-5 right triangle.
        let hypotenuse = vec2("v", 3.0, 4.0).magnitude();
        assert!((hypotenuse - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let first = vec2("a", 1.0, 2.0);
        let second = vec2("b", 1.0, 2.0);
        let similarity = first.cosine_similarity(&second);
        assert!((similarity - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let x_axis = vec2("a", 1.0, 0.0);
        let y_axis = vec2("b", 0.0, 1.0);
        let similarity = x_axis.cosine_similarity(&y_axis);
        assert!(similarity.abs() < 1e-10);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let origin = vec2("a", 0.0, 0.0);
        let other = vec2("b", 1.0, 2.0);
        assert!(origin.cosine_similarity(&other).abs() < 1e-10);
    }

    #[test]
    fn test_item_vector_dim() {
        let three_d = ItemVector::new("x", vec![1.0, 2.0, 3.0]);
        assert_eq!(three_d.dim(), 3);
    }

    #[test]
    fn test_item_vector_is_empty() {
        let no_components = ItemVector::new("x", vec![]);
        assert!(no_components.is_empty());
    }

    // --- SimilarityMatrix -------------------------------------------------

    #[test]
    fn test_similarity_matrix_insert_and_get() {
        let mut matrix = SimilarityMatrix::new();
        matrix.insert("a", "b", 0.8);

        // The score must be readable in both directions.
        let forward = matrix.get("a", "b").expect("should succeed in test");
        let backward = matrix.get("b", "a").expect("should succeed in test");
        assert!((forward - 0.8).abs() < 1e-10);
        assert!((backward - 0.8).abs() < 1e-10);
    }

    #[test]
    fn test_similarity_matrix_missing_returns_none() {
        let matrix = SimilarityMatrix::new();
        assert!(matrix.get("x", "y").is_none());
    }

    #[test]
    fn test_find_similar_ordering() {
        let mut matrix = SimilarityMatrix::new();
        for (other, score) in [("b", 0.9), ("c", 0.5), ("d", 0.7)] {
            matrix.insert("a", other, score);
        }

        let ranked = matrix.find_similar("a", 3);
        assert_eq!(ranked[0].0, "b");
        assert_eq!(ranked[1].0, "d");
    }

    #[test]
    fn test_find_similar_top_k_limit() {
        let mut matrix = SimilarityMatrix::new();
        (0..10_u32).for_each(|i| {
            matrix.insert("q", format!("item{i}"), f64::from(i) * 0.1);
        });

        let ranked = matrix.find_similar("q", 3);
        assert_eq!(ranked.len(), 3);
    }

    #[test]
    fn test_find_similar_empty_for_unknown() {
        let matrix = SimilarityMatrix::new();
        let ranked = matrix.find_similar("unknown", 5);
        assert!(ranked.is_empty());
    }

    #[test]
    fn test_item_count() {
        let mut matrix = SimilarityMatrix::new();
        matrix.insert("a", "b", 0.5);
        matrix.insert("a", "c", 0.6);
        // "a", "b" and "c" each appear in at least one pair.
        assert_eq!(matrix.item_count(), 3);
    }

    #[test]
    fn test_from_vectors_builds_correct_similarity() {
        let vectors = vec![
            ItemVector::new("a", vec![1.0, 0.0]),
            ItemVector::new("b", vec![1.0, 0.0]),
            ItemVector::new("c", vec![0.0, 1.0]),
        ];
        let matrix = SimilarityMatrix::from_vectors(&vectors);

        let parallel = matrix.get("a", "b").expect("should succeed in test");
        let orthogonal = matrix.get("a", "c").expect("should succeed in test");
        assert!((parallel - 1.0).abs() < 1e-9);
        assert!(orthogonal.abs() < 1e-9);
    }

    // --- LshItemIndex -----------------------------------------------------

    #[test]
    fn test_lsh_item_index_creation() {
        let index = LshItemIndex::default();
        assert!(index.is_empty());
        assert_eq!(index.len(), 0);
    }

    #[test]
    fn test_lsh_item_index_insert_and_len() {
        let mut index = make_lsh_index();
        index.insert(basis4("a", 0));
        index.insert(basis4("b", 1));
        assert_eq!(index.len(), 2);
        assert!(!index.is_empty());
    }

    #[test]
    fn test_lsh_item_index_bulk_insert() {
        let mut index = make_lsh_index();
        let batch = vec![basis4("x", 0), basis4("y", 1), basis4("z", 2)];
        index.bulk_insert(batch);
        assert_eq!(index.len(), 3);
    }

    #[test]
    fn test_lsh_item_index_find_similar_returns_results() {
        let mut index = make_lsh_index();
        index.insert(basis4("a", 0));
        index.insert(basis4("b", 1));
        index.insert(basis4("c", 2));

        let matches = index.find_similar(&FIRST_AXIS, 2);
        assert!(!matches.is_empty());
        assert!(matches.len() <= 2);
    }

    #[test]
    fn test_lsh_item_index_identical_vector_is_top_result() {
        let mut index = make_lsh_index();
        index.insert(basis4("target", 0));
        index.insert(basis4("other1", 1));
        index.insert(basis4("other2", 2));

        let matches = index.find_similar(&FIRST_AXIS, 1);
        // An empty result is tolerated; only a wrong top hit fails the test.
        if let Some(best) = matches.first() {
            assert_eq!(best.item_id, "target");
        }
    }

    #[test]
    fn test_lsh_item_index_similarity_in_range() {
        let mut index = make_lsh_index();
        index.insert(basis4("a", 0));
        index.insert(ItemVector::new("b", vec![0.707, 0.707, 0.0, 0.0]));

        let matches = index.find_similar(&FIRST_AXIS, 2);
        assert!(matches
            .iter()
            .all(|m| m.similarity >= -1.0 && m.similarity <= 1.0));
    }

    #[test]
    fn test_lsh_item_index_exact_top_k() {
        let mut index = make_lsh_index();
        index.insert(basis4("a", 0));
        index.insert(basis4("b", 0));
        index.insert(basis4("c", 1));

        let exact = index.exact_top_k(&FIRST_AXIS, 2);
        assert_eq!(exact.len(), 2);

        // Results must be non-increasing in similarity (ties allowed).
        let top = &exact[0];
        let runner_up = &exact[1];
        assert!(
            top.similarity > runner_up.similarity
                || (top.similarity - runner_up.similarity).abs() < 1e-9
        );
    }

    #[test]
    fn test_lsh_item_index_empty_query_returns_empty() {
        let index = make_lsh_index();
        let matches = index.find_similar(&FIRST_AXIS, 5);
        assert!(matches.is_empty());
    }

    #[test]
    fn test_lsh_item_index_dim() {
        let index = make_lsh_index();
        assert_eq!(index.dim(), 4);
    }

    #[test]
    fn test_lsh_item_config_default() {
        let LshItemConfig {
            dim,
            num_tables,
            num_planes,
        } = LshItemConfig::default();
        assert_eq!(dim, 64);
        assert_eq!(num_tables, 4);
        assert_eq!(num_planes, 8);
    }
}