use super::index::ComponentIndex;
use crate::model::{CanonicalId, Component, NormalizedSbom};
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};
/// Tuning parameters for the MinHash / LSH candidate index.
#[derive(Debug, Clone)]
pub struct LshConfig {
    /// Number of MinHash values in every signature.
    pub num_hashes: usize,
    /// Number of bands a signature is split into (one bucket table per band).
    pub num_bands: usize,
    /// Number of consecutive characters that form one name shingle.
    pub shingle_size: usize,
    /// Similarity threshold this configuration was derived from.
    pub target_threshold: f64,
    /// Whether a token built from the component's ecosystem joins the shingle set.
    pub include_ecosystem_token: bool,
    /// Whether a token built from the component's group joins the shingle set.
    pub include_group_token: bool,
}

impl LshConfig {
    /// Chooses a band layout for the given similarity threshold.
    ///
    /// The `(bands, rows per band)` presets are: `>= 0.9` → `(50, 2)`,
    /// `>= 0.8` → `(25, 4)`, `>= 0.7` → `(20, 5)`, `>= 0.5` → `(10, 10)`, and
    /// `(5, 20)` for everything else (NaN included, since every comparison
    /// with NaN is false). Each preset yields 100 hashes, a shingle size of 3,
    /// the ecosystem token enabled and the group token disabled.
    #[must_use]
    pub fn for_threshold(threshold: f64) -> Self {
        let (num_bands, rows_per_band) = if threshold >= 0.9 {
            (50, 2)
        } else if threshold >= 0.8 {
            (25, 4)
        } else if threshold >= 0.7 {
            (20, 5)
        } else if threshold >= 0.5 {
            (10, 10)
        } else {
            (5, 20)
        };

        Self {
            num_hashes: num_bands * rows_per_band,
            num_bands,
            shingle_size: 3,
            target_threshold: threshold,
            include_ecosystem_token: true,
            include_group_token: false,
        }
    }

    /// Configuration for a threshold of `0.8`.
    #[must_use]
    pub fn default_balanced() -> Self {
        Self::for_threshold(0.8)
    }

    /// Configuration for a threshold of `0.9`.
    #[must_use]
    pub fn strict() -> Self {
        Self::for_threshold(0.9)
    }

    /// Configuration for a threshold of `0.5`.
    #[must_use]
    pub fn permissive() -> Self {
        Self::for_threshold(0.5)
    }

    /// Number of signature rows in each band (`num_hashes / num_bands`,
    /// rounded down).
    ///
    /// Returns `0` when `num_bands` is zero: the fields are public, so a
    /// hand-built configuration must not trigger a division-by-zero panic.
    #[must_use]
    pub const fn rows_per_band(&self) -> usize {
        if self.num_bands == 0 {
            0
        } else {
            self.num_hashes / self.num_bands
        }
    }
}

impl Default for LshConfig {
    /// Same as [`LshConfig::default_balanced`].
    fn default() -> Self {
        Self::default_balanced()
    }
}
/// A MinHash signature: one minimum hash value per hash function.
#[derive(Debug, Clone)]
pub struct MinHashSignature {
    /// Minimum hash values, one slot per hash function.
    pub values: Vec<u64>,
}

impl MinHashSignature {
    /// Estimates similarity as the fraction of positions at which both
    /// signatures hold the same value.
    ///
    /// Returns `0.0` when the signatures differ in length, and also when both
    /// are empty: with no slots to compare, the ratio would otherwise be
    /// `0 / 0`, i.e. NaN.
    #[must_use]
    pub fn estimated_similarity(&self, other: &Self) -> f64 {
        let slots = self.values.len();
        if slots == 0 || slots != other.values.len() {
            return 0.0;
        }

        let matching = self
            .values
            .iter()
            .zip(&other.values)
            .filter(|(a, b)| a == b)
            .count();

        matching as f64 / slots as f64
    }
}
pub struct LshIndex {
config: LshConfig,
signatures: HashMap<CanonicalId, MinHashSignature>,
buckets: Vec<HashMap<u64, Vec<CanonicalId>>>,
hash_coeffs: Vec<(u64, u64)>,
prime: u64,
}
impl LshIndex {
#[must_use]
pub fn new(config: LshConfig) -> Self {
use std::collections::hash_map::RandomState;
use std::hash::BuildHasher;
let mut hash_coeffs = Vec::with_capacity(config.num_hashes);
let random_state = RandomState::new();
for i in 0..config.num_hashes {
let a = random_state.hash_one(i as u64 * 31337) | 1;
let b = random_state.hash_one(i as u64 * 7919 + 12345);
hash_coeffs.push((a, b));
}
let buckets = (0..config.num_bands)
.map(|_| HashMap::with_capacity(64))
.collect();
Self {
config,
signatures: HashMap::with_capacity(256),
buckets,
hash_coeffs,
prime: 0xFFFF_FFFF_FFFF_FFC5, }
}
#[must_use]
pub fn build(sbom: &NormalizedSbom, config: LshConfig) -> Self {
let mut index = Self::new(config);
for (id, comp) in &sbom.components {
index.insert(id.clone(), comp);
}
index
}
pub fn insert(&mut self, id: CanonicalId, component: &Component) {
let shingles = self.compute_shingles(component);
let signature = self.compute_minhash(&shingles);
self.insert_into_buckets(&id, &signature);
self.signatures.insert(id, signature);
}
#[must_use]
pub fn find_candidates(&self, component: &Component) -> Vec<CanonicalId> {
let shingles = self.compute_shingles(component);
let signature = self.compute_minhash(&shingles);
self.find_candidates_by_signature(&signature)
}
#[must_use]
pub fn find_candidates_by_signature(&self, signature: &MinHashSignature) -> Vec<CanonicalId> {
let mut candidates = HashSet::new();
let rows_per_band = self.config.rows_per_band();
for (band_idx, bucket_map) in self.buckets.iter().enumerate() {
let band_hash = self.hash_band(signature, band_idx, rows_per_band);
if let Some(ids) = bucket_map.get(&band_hash) {
for id in ids {
candidates.insert(id.clone());
}
}
}
candidates.into_iter().collect()
}
pub fn find_candidates_for_id(&self, id: &CanonicalId) -> Vec<CanonicalId> {
self.signatures.get(id).map_or_else(Vec::new, |signature| {
self.find_candidates_by_signature(signature)
})
}
#[must_use]
pub fn get_signature(&self, id: &CanonicalId) -> Option<&MinHashSignature> {
self.signatures.get(id)
}
#[must_use]
pub fn estimate_similarity(&self, id_a: &CanonicalId, id_b: &CanonicalId) -> Option<f64> {
let sig_a = self.signatures.get(id_a)?;
let sig_b = self.signatures.get(id_b)?;
Some(sig_a.estimated_similarity(sig_b))
}
pub fn stats(&self) -> LshIndexStats {
let total_components = self.signatures.len();
let total_buckets: usize = self
.buckets
.iter()
.map(std::collections::HashMap::len)
.sum();
let max_bucket_size = self
.buckets
.iter()
.flat_map(|b| b.values())
.map(std::vec::Vec::len)
.max()
.unwrap_or(0);
let avg_bucket_size = if total_buckets > 0 {
self.buckets
.iter()
.flat_map(|b| b.values())
.map(std::vec::Vec::len)
.sum::<usize>() as f64
/ total_buckets as f64
} else {
0.0
};
LshIndexStats {
total_components,
num_bands: self.config.num_bands,
num_hashes: self.config.num_hashes,
total_buckets,
max_bucket_size,
avg_bucket_size,
}
}
fn compute_shingles(&self, component: &Component) -> HashSet<u64> {
let ecosystem = component
.ecosystem
.as_ref()
.map(std::string::ToString::to_string);
let ecosystem_str = ecosystem.as_deref();
let normalized = ComponentIndex::normalize_name(&component.name, ecosystem_str);
let chars: Vec<char> = normalized.chars().collect();
let estimated_shingles = chars.len().saturating_sub(self.config.shingle_size) + 3;
let mut shingles = HashSet::with_capacity(estimated_shingles);
if chars.len() < self.config.shingle_size {
let mut hasher = std::collections::hash_map::DefaultHasher::new();
normalized.hash(&mut hasher);
shingles.insert(hasher.finish());
} else {
for window in chars.windows(self.config.shingle_size) {
let mut hasher = std::collections::hash_map::DefaultHasher::new();
window.hash(&mut hasher);
shingles.insert(hasher.finish());
}
}
if self.config.include_ecosystem_token
&& let Some(ref eco) = ecosystem
{
let mut hasher = std::collections::hash_map::DefaultHasher::new();
"__eco:".hash(&mut hasher);
eco.to_lowercase().hash(&mut hasher);
shingles.insert(hasher.finish());
}
if self.config.include_group_token
&& let Some(ref group) = component.group
{
let mut hasher = std::collections::hash_map::DefaultHasher::new();
"__grp:".hash(&mut hasher);
group.to_lowercase().hash(&mut hasher);
shingles.insert(hasher.finish());
}
shingles
}
fn compute_minhash(&self, shingles: &HashSet<u64>) -> MinHashSignature {
let mut min_hashes = vec![u64::MAX; self.config.num_hashes];
for &shingle in shingles {
for (i, &(a, b)) in self.hash_coeffs.iter().enumerate() {
let hash = a.wrapping_mul(shingle).wrapping_add(b) % self.prime;
if hash < min_hashes[i] {
min_hashes[i] = hash;
}
}
}
MinHashSignature { values: min_hashes }
}
fn insert_into_buckets(&mut self, id: &CanonicalId, signature: &MinHashSignature) {
let rows_per_band = self.config.rows_per_band();
let band_hashes: Vec<u64> = (0..self.config.num_bands)
.map(|band_idx| self.hash_band(signature, band_idx, rows_per_band))
.collect();
for (band_idx, bucket_map) in self.buckets.iter_mut().enumerate() {
bucket_map
.entry(band_hashes[band_idx])
.or_default()
.push(id.clone());
}
}
fn hash_band(
&self,
signature: &MinHashSignature,
band_idx: usize,
rows_per_band: usize,
) -> u64 {
let start = band_idx * rows_per_band;
let end = (start + rows_per_band).min(signature.values.len());
let mut hasher = std::collections::hash_map::DefaultHasher::new();
for &value in &signature.values[start..end] {
value.hash(&mut hasher);
}
hasher.finish()
}
}
/// Summary statistics describing an LSH index.
#[derive(Debug, Clone)]
pub struct LshIndexStats {
    /// Number of components with a stored signature.
    pub total_components: usize,
    /// Number of bands in the index configuration.
    pub num_bands: usize,
    /// Number of hash values per signature in the index configuration.
    pub num_hashes: usize,
    /// Number of buckets summed over all bands.
    pub total_buckets: usize,
    /// Number of ids in the largest bucket.
    pub max_bucket_size: usize,
    /// Mean number of ids per bucket.
    pub avg_bucket_size: f64,
}

impl std::fmt::Display for LshIndexStats {
    /// Writes a one-line summary; the "hashes" figure is the number of hashes
    /// per band (`num_hashes / num_bands`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `num_bands` is a public field and may be zero; formatting must not
        // panic on a division by zero, so report zero hashes per band then.
        let hashes_per_band = self.num_hashes.checked_div(self.num_bands).unwrap_or(0);

        write!(
            f,
            "LSH Index: {} components, {} bands × {} hashes, {} buckets (max: {}, avg: {:.1})",
            self.total_components,
            self.num_bands,
            hashes_per_band,
            self.total_buckets,
            self.max_bucket_size,
            self.avg_bucket_size
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::DocumentMetadata;

    /// Creates a component whose identifier string is derived from its name.
    fn make_component(name: &str) -> Component {
        Component::new(name.to_string(), format!("id-{}", name))
    }

    /// Creates an SBOM with default metadata holding one component per name,
    /// added in the given order.
    fn sbom_of(names: &[&str]) -> NormalizedSbom {
        let mut sbom = NormalizedSbom::new(DocumentMetadata::default());
        for name in names {
            sbom.add_component(make_component(name));
        }
        sbom
    }

    /// The 0.8 preset produces 100 hashes split evenly across its bands.
    #[test]
    fn test_lsh_config_for_threshold() {
        let cfg = LshConfig::for_threshold(0.8);

        assert_eq!(cfg.num_hashes, 100);
        assert!(cfg.num_bands > 0);
        assert_eq!(cfg.num_hashes, cfg.num_bands * cfg.rows_per_band());
    }

    /// Identical signatures score 1.0; three matching slots out of five score 0.6.
    #[test]
    fn test_minhash_signature_similarity() {
        let base = MinHashSignature {
            values: vec![1, 2, 3, 4, 5],
        };
        let identical = MinHashSignature {
            values: vec![1, 2, 3, 4, 5],
        };
        assert_eq!(base.estimated_similarity(&identical), 1.0);

        let partial = MinHashSignature {
            values: vec![1, 2, 3, 6, 7],
        };
        let similarity = base.estimated_similarity(&partial);
        assert!((similarity - 0.6).abs() < 0.01);
    }

    /// Building from an SBOM indexes every component and fills some buckets.
    #[test]
    fn test_lsh_index_build() {
        let sbom = sbom_of(&["lodash", "lodash-es", "underscore", "react"]);

        let stats = LshIndex::build(&sbom, LshConfig::default_balanced()).stats();

        assert_eq!(stats.total_components, 4);
        assert!(stats.total_buckets > 0);
    }

    /// Querying with a name that is in the index yields at least one candidate.
    #[test]
    fn test_lsh_finds_similar_names() {
        let sbom = sbom_of(&["lodash", "lodash-es", "lodash-fp", "react", "angular"]);
        let index = LshIndex::build(&sbom, LshConfig::for_threshold(0.5));

        let candidates = index.find_candidates(&make_component("lodash"));

        assert!(
            !candidates.is_empty(),
            "Should find at least some candidates"
        );
    }

    /// A near-identical name gets a higher estimate than an unrelated one.
    #[test]
    fn test_lsh_signature_estimation() {
        let base = make_component("lodash");
        let near = make_component("lodash-es");
        let far = make_component("completely-different-name");

        // Capture the ids before the components are moved into the SBOM.
        let base_id = base.canonical_id.clone();
        let near_id = near.canonical_id.clone();
        let far_id = far.canonical_id.clone();

        let mut sbom = NormalizedSbom::new(DocumentMetadata::default());
        sbom.add_component(base);
        sbom.add_component(near);
        sbom.add_component(far);

        let index = LshIndex::build(&sbom, LshConfig::default_balanced());
        let near_similarity = index.estimate_similarity(&base_id, &near_id).unwrap();
        let far_similarity = index.estimate_similarity(&base_id, &far_id).unwrap();

        assert!(
            near_similarity > far_similarity,
            "lodash vs lodash-es ({:.2}) should be more similar than lodash vs completely-different ({:.2})",
            near_similarity,
            far_similarity
        );
    }

    /// A fresh index reports no components and echoes its band layout.
    #[test]
    fn test_lsh_index_stats() {
        let stats = LshIndex::new(LshConfig::for_threshold(0.8)).stats();

        assert_eq!(stats.total_components, 0);
        assert_eq!(stats.num_bands, 25);
        assert_eq!(stats.num_hashes, 100);
    }
}