use crate::fingerprint::{AcoustidFingerprint, FingerprintMatcher};
use std::collections::HashMap;
/// Tunable parameters for a [`SimilarityIndex`].
#[derive(Debug, Clone)]
pub struct SimilarityIndexConfig {
    /// Number of band keys computed for every fingerprint.
    pub n_bands: usize,
    /// Number of consecutive fingerprint words hashed into each band key.
    pub band_width: usize,
    /// Candidates scoring below this similarity are dropped from search results.
    pub min_similarity: f32,
    /// Maximum number of matches a single search returns.
    pub top_k: usize,
}

impl Default for SimilarityIndexConfig {
    /// Returns the stock settings: 20 bands of 4 words each, a 0.5 similarity
    /// cut-off and at most 10 results per search.
    fn default() -> Self {
        const DEFAULT_BANDS: usize = 20;
        const DEFAULT_BAND_WIDTH: usize = 4;
        const DEFAULT_MIN_SIMILARITY: f32 = 0.5;
        const DEFAULT_TOP_K: usize = 10;

        SimilarityIndexConfig {
            n_bands: DEFAULT_BANDS,
            band_width: DEFAULT_BAND_WIDTH,
            min_similarity: DEFAULT_MIN_SIMILARITY,
            top_k: DEFAULT_TOP_K,
        }
    }
}
/// One stored item: the caller-supplied identifier together with the
/// fingerprint it was inserted with.
#[derive(Debug, Clone)]
struct IndexEntry {
    /// Identifier handed back to the caller in search results.
    id: String,
    /// Fingerprint compared against queries when this entry is a candidate.
    fingerprint: AcoustidFingerprint,
}
/// A single search hit.
#[derive(Debug, Clone)]
pub struct SimilarityMatch {
    /// Identifier the matching fingerprint was inserted under.
    pub id: String,
    /// Similarity between the query and the matching fingerprint; higher
    /// scores rank earlier in search results.
    pub score: f32,
}
/// In-memory fingerprint index that narrows searches with band-hash buckets
/// before scoring candidates.
///
/// Derives `Debug` and `Clone` like every other type in this module; all of
/// its fields already support both.
#[derive(Debug, Clone)]
pub struct SimilarityIndex {
    /// Banding and search parameters fixed at construction time.
    config: SimilarityIndexConfig,
    /// Stored fingerprints, in insertion order.
    entries: Vec<IndexEntry>,
    /// Maps a band key to the positions in `entries` that produced that key.
    buckets: HashMap<u64, Vec<usize>>,
}
impl SimilarityIndex {
    /// Creates an empty index that uses `config` for banding and search limits.
    #[must_use]
    pub fn new(config: SimilarityIndexConfig) -> Self {
        Self {
            config,
            entries: Vec::new(),
            buckets: HashMap::new(),
        }
    }

    /// Creates an empty index with [`SimilarityIndexConfig::default`].
    #[must_use]
    pub fn default_index() -> Self {
        Self::new(SimilarityIndexConfig::default())
    }

    /// Returns the number of fingerprints stored in the index.
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when the index holds no fingerprints.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Stores `fingerprint` under `id` and registers it in one bucket per band key.
    ///
    /// `id` is not checked for uniqueness; inserting the same id twice creates
    /// two independent entries. A fingerprint with no words produces no band
    /// keys, so it is only reachable through the full-scan fallback of
    /// [`search`](Self::search).
    pub fn insert(&mut self, id: String, fingerprint: AcoustidFingerprint) {
        // The new entry will live at the current end of `entries`.
        let entry_idx = self.entries.len();
        for key in self.compute_band_keys(&fingerprint) {
            self.buckets.entry(key).or_default().push(entry_idx);
        }
        self.entries.push(IndexEntry { id, fingerprint });
    }

    /// Returns the stored fingerprints most similar to `query`.
    ///
    /// Candidates are the entries sharing at least one band key with `query`;
    /// when no bucket matches, every entry is scored instead. Candidates whose
    /// similarity is below `min_similarity` are discarded, the remainder is
    /// sorted by descending score (ties keep candidate discovery order) and
    /// truncated to `top_k`.
    ///
    /// Returns an empty vector when the index or the query is empty.
    #[must_use]
    pub fn search(&self, query: &AcoustidFingerprint) -> Vec<SimilarityMatch> {
        if self.entries.is_empty() || query.is_empty() {
            return Vec::new();
        }

        let band_keys = self.compute_band_keys(query);

        // An entry usually shares several bands with the query. The set makes
        // the "already collected?" check O(1) while the vector keeps the order
        // in which candidates were first discovered.
        let mut seen = std::collections::HashSet::<usize>::new();
        let mut candidates: Vec<usize> = band_keys
            .iter()
            .filter_map(|key| self.buckets.get(key))
            .flatten()
            .copied()
            .filter(|&idx| seen.insert(idx))
            .collect();

        // No bucket hit: fall back to scoring every stored entry.
        if candidates.is_empty() {
            candidates.extend(0..self.entries.len());
        }

        let min_similarity = self.config.min_similarity;
        let mut scored: Vec<(usize, f32)> = candidates
            .into_iter()
            .map(|idx| {
                let sim = FingerprintMatcher::bit_similarity(query, &self.entries[idx].fingerprint);
                (idx, sim)
            })
            .filter(|&(_, sim)| sim >= min_similarity)
            .collect();

        // Stable sort, best score first; incomparable scores are treated as equal.
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored.truncate(self.config.top_k);

        scored
            .into_iter()
            .map(|(idx, score)| SimilarityMatch {
                id: self.entries[idx].id.clone(),
                score,
            })
            .collect()
    }

    /// Removes every entry and bucket; the configuration is kept.
    pub fn clear(&mut self) {
        self.entries.clear();
        self.buckets.clear();
    }

    /// Computes one 64-bit key per configured band for `fp`.
    ///
    /// Band `b` covers `band_width` consecutive words starting at
    /// `(b * band_width) % len`, wrapping around the end of the fingerprint.
    /// The low four bytes of each word (least significant first) are folded
    /// into an FNV-1a style hash, and the band number is mixed in last so the
    /// same words in different bands yield different keys.
    ///
    /// Returns an empty vector for a fingerprint with no words.
    fn compute_band_keys(&self, fp: &AcoustidFingerprint) -> Vec<u64> {
        // 64-bit FNV-1a parameters.
        const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
        const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

        let words = &fp.fingerprint;
        let n = words.len();
        if n == 0 {
            return Vec::new();
        }

        let bands = self.config.n_bands;
        let bw = self.config.band_width;

        (0..bands)
            .map(|band| {
                let start = (band * bw) % n;
                let mut hash = FNV_OFFSET_BASIS;
                for i in 0..bw {
                    let word = words[(start + i) % n];
                    for byte_shift in [0_u32, 8, 16, 24] {
                        let byte = ((word >> byte_shift) & 0xFF) as u64;
                        hash ^= byte;
                        hash = hash.wrapping_mul(FNV_PRIME);
                    }
                }
                // Salt with the band number so identical content in different
                // bands lands in different buckets.
                hash ^= band as u64;
                hash.wrapping_mul(FNV_PRIME)
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fingerprint::AcoustidEncoder;
    use std::f32::consts::TAU;

    /// Builds the fingerprint of a pure sine tone of `freq` Hz lasting
    /// `dur_secs` seconds, sampled at 8 kHz and encoded with
    /// `AcoustidEncoder::compute_with_frame_size(.., 256)`.
    fn sine_fingerprint(freq: f32, dur_secs: f32) -> AcoustidFingerprint {
        let sr = 8000_u32;
        let n = (sr as f32 * dur_secs) as usize;
        let samples: Vec<f32> = (0..n)
            .map(|i| (TAU * freq * i as f32 / sr as f32).sin())
            .collect();
        AcoustidEncoder::compute_with_frame_size(&samples, sr, 256)
    }

    /// A freshly created index reports itself as empty.
    #[test]
    fn test_index_empty() {
        let index = SimilarityIndex::default_index();
        assert!(index.is_empty());
        assert_eq!(index.len(), 0);
    }

    /// Inserting one fingerprint increases the length to one.
    #[test]
    fn test_insert_and_len() {
        let mut index = SimilarityIndex::default_index();
        let fp = sine_fingerprint(440.0, 1.0);
        index.insert("a".to_string(), fp);
        assert_eq!(index.len(), 1);
    }

    /// Searching with a stored fingerprint finds it with a score of ~1.0.
    #[test]
    fn test_search_self_similarity() {
        let mut index = SimilarityIndex::default_index();
        let fp = sine_fingerprint(440.0, 1.0);
        index.insert("a440".to_string(), fp.clone());
        let results = index.search(&fp);
        assert!(
            !results.is_empty(),
            "Self-search must return at least one result"
        );
        assert_eq!(results[0].id, "a440");
        assert!(
            (results[0].score - 1.0).abs() < 1e-5,
            "Self-similarity should be ~1.0"
        );
    }

    /// An entry matching the query on many bands must appear only once.
    #[test]
    fn test_search_returns_each_entry_once() {
        let mut index = SimilarityIndex::default_index();
        let fp = sine_fingerprint(440.0, 1.0);
        index.insert("a440".to_string(), fp.clone());
        let results = index.search(&fp);
        assert_eq!(results.len(), 1, "Entry must not be reported more than once");
    }

    /// With two different tones stored, the query's own tone ranks first.
    #[test]
    fn test_search_different_signals() {
        let mut index = SimilarityIndex::new(SimilarityIndexConfig {
            min_similarity: 0.0,
            ..SimilarityIndexConfig::default()
        });
        let fp_a = sine_fingerprint(440.0, 1.0);
        let fp_b = sine_fingerprint(523.25, 1.0);
        index.insert("A4".to_string(), fp_a.clone());
        // `fp_b` is not used again, so it is moved rather than cloned.
        index.insert("C5".to_string(), fp_b);
        let results = index.search(&fp_a);
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "A4");
    }

    /// An empty query yields no results even when the index has entries.
    #[test]
    fn test_search_empty_query() {
        let mut index = SimilarityIndex::default_index();
        let fp = sine_fingerprint(440.0, 1.0);
        index.insert("a".to_string(), fp);
        let empty = AcoustidFingerprint {
            fingerprint: vec![],
            duration_secs: 0.0,
        };
        let results = index.search(&empty);
        assert!(results.is_empty());
    }

    /// Searching an empty index yields no results.
    #[test]
    fn test_search_empty_index() {
        let index = SimilarityIndex::default_index();
        let fp = sine_fingerprint(440.0, 1.0);
        let results = index.search(&fp);
        assert!(results.is_empty());
    }

    /// `clear` removes every stored entry.
    #[test]
    fn test_clear_index() {
        let mut index = SimilarityIndex::default_index();
        let fp = sine_fingerprint(440.0, 1.0);
        index.insert("a".to_string(), fp);
        assert_eq!(index.len(), 1);
        index.clear();
        assert!(index.is_empty());
    }

    /// A search never returns more than `top_k` matches.
    #[test]
    fn test_top_k_respected() {
        let config = SimilarityIndexConfig {
            top_k: 2,
            min_similarity: 0.0,
            ..SimilarityIndexConfig::default()
        };
        let mut index = SimilarityIndex::new(config);
        for i in 0..5 {
            let fp = sine_fingerprint(200.0 + i as f32 * 50.0, 0.5);
            index.insert(format!("track_{i}"), fp);
        }
        let query = sine_fingerprint(300.0, 0.5);
        let results = index.search(&query);
        assert!(results.len() <= 2);
    }

    /// A non-empty fingerprint produces exactly `n_bands` band keys.
    #[test]
    fn test_band_keys_not_empty() {
        let index = SimilarityIndex::default_index();
        let fp = sine_fingerprint(440.0, 0.5);
        let keys = index.compute_band_keys(&fp);
        assert!(!keys.is_empty());
        assert_eq!(keys.len(), index.config.n_bands);
    }

    /// Results are ordered from highest to lowest score.
    #[test]
    fn test_results_sorted_descending() {
        let config = SimilarityIndexConfig {
            min_similarity: 0.0,
            ..SimilarityIndexConfig::default()
        };
        let mut index = SimilarityIndex::new(config);
        let fp_ref = sine_fingerprint(440.0, 1.0);
        index.insert("ref".to_string(), fp_ref.clone());
        let fp_other = sine_fingerprint(1000.0, 1.0);
        index.insert("other".to_string(), fp_other);
        let results = index.search(&fp_ref);
        for w in results.windows(2) {
            assert!(
                w[0].score >= w[1].score,
                "Results must be sorted descending: {:.4} < {:.4}",
                w[0].score,
                w[1].score
            );
        }
    }
}