1use super::index::ComponentIndex;
21use crate::model::{CanonicalId, Component, NormalizedSbom};
22use std::collections::{HashMap, HashSet};
23use std::hash::{Hash, Hasher};
24
/// Tuning parameters for the MinHash/LSH candidate index.
#[derive(Debug, Clone)]
pub struct LshConfig {
    /// Total number of MinHash values in each signature.
    pub num_hashes: usize,
    /// Number of bands a signature is split into; each band has its own bucket table.
    pub num_bands: usize,
    /// Number of consecutive characters that form one name shingle.
    pub shingle_size: usize,
    /// Similarity threshold this configuration was derived from.
    pub target_threshold: f64,
    /// Whether an ecosystem marker token is added to a component's shingle set.
    pub include_ecosystem_token: bool,
    /// Whether a group marker token is added to a component's shingle set.
    pub include_group_token: bool,
}

impl LshConfig {
    /// Builds a configuration for the given similarity `threshold`.
    ///
    /// The threshold selects one of five fixed `(bands, rows per band)` layouts;
    /// `num_hashes` is always `bands * rows`. Values below `0.5` (and NaN) fall
    /// into the last tier. Shingles are 3 characters long, the ecosystem token is
    /// enabled and the group token is disabled.
    #[must_use]
    pub fn for_threshold(threshold: f64) -> Self {
        // NOTE(review): with b bands of r rows, a pair of similarity s shares at
        // least one band with probability 1 - (1 - s^r)^b, whose steep part sits
        // near (1/b)^(1/r). That is ~0.14 for (50, 2) and ~0.92 for (5, 20), so
        // lower requested thresholds get the *more* selective layout here
        // (e.g. s = 0.5 under (10, 10) collides in under 1% of cases).
        // Confirm this ordering is intended before relying on `permissive()`
        // for recall.
        let (num_bands, rows_per_band) = match threshold {
            t if t >= 0.9 => (50, 2),
            t if t >= 0.8 => (25, 4),
            t if t >= 0.7 => (20, 5),
            t if t >= 0.5 => (10, 10),
            _ => (5, 20),
        };

        Self {
            num_hashes: num_bands * rows_per_band,
            num_bands,
            shingle_size: 3,
            target_threshold: threshold,
            include_ecosystem_token: true,
            include_group_token: false,
        }
    }

    /// Configuration for a threshold of `0.8`; this is also the `Default`.
    #[must_use]
    pub fn default_balanced() -> Self {
        Self::for_threshold(0.8)
    }

    /// Configuration for a threshold of `0.9`.
    #[must_use]
    pub fn strict() -> Self {
        Self::for_threshold(0.9)
    }

    /// Configuration for a threshold of `0.5`.
    #[must_use]
    pub fn permissive() -> Self {
        Self::for_threshold(0.5)
    }

    /// Number of signature rows in each band (`num_hashes / num_bands`, rounded down).
    ///
    /// Returns `0` when `num_bands` is `0` instead of panicking on the division;
    /// both fields are public, so a hand-built config can contain that value.
    #[must_use]
    pub const fn rows_per_band(&self) -> usize {
        if self.num_bands == 0 {
            0
        } else {
            self.num_hashes / self.num_bands
        }
    }
}

impl Default for LshConfig {
    /// Same as [`LshConfig::default_balanced`].
    fn default() -> Self {
        Self::default_balanced()
    }
}
109
/// A MinHash signature: one minimum hash value per hash function.
#[derive(Debug, Clone)]
pub struct MinHashSignature {
    /// Minimum hash values, in hash-function order.
    pub values: Vec<u64>,
}

impl MinHashSignature {
    /// Fraction of positions at which `self` and `other` hold the same value.
    ///
    /// Returns `0.0` when the signatures have different lengths, and also when
    /// both are empty (there is nothing to compare, and dividing by the zero
    /// length would otherwise yield NaN).
    #[must_use]
    pub fn estimated_similarity(&self, other: &Self) -> f64 {
        let len = self.values.len();
        if len == 0 || len != other.values.len() {
            return 0.0;
        }

        let matching = self
            .values
            .iter()
            .zip(&other.values)
            .filter(|(mine, theirs)| mine == theirs)
            .count();

        matching as f64 / len as f64
    }
}
135
136pub struct LshIndex {
138 config: LshConfig,
140 signatures: HashMap<CanonicalId, MinHashSignature>,
142 buckets: Vec<HashMap<u64, Vec<CanonicalId>>>,
144 hash_coeffs: Vec<(u64, u64)>,
146 prime: u64,
148}
149
150impl LshIndex {
151 #[must_use]
153 pub fn new(config: LshConfig) -> Self {
154 use std::collections::hash_map::RandomState;
155 use std::hash::BuildHasher;
156
157 let mut hash_coeffs = Vec::with_capacity(config.num_hashes);
159 let random_state = RandomState::new();
160
161 for i in 0..config.num_hashes {
162 let a = random_state.hash_one(i as u64 * 31337) | 1; let b = random_state.hash_one(i as u64 * 7919 + 12345);
165
166 hash_coeffs.push((a, b));
167 }
168
169 let buckets = (0..config.num_bands)
171 .map(|_| HashMap::with_capacity(64))
172 .collect();
173
174 Self {
175 config,
176 signatures: HashMap::with_capacity(256),
177 buckets,
178 hash_coeffs,
179 prime: 0xFFFF_FFFF_FFFF_FFC5, }
181 }
182
183 #[must_use]
185 pub fn build(sbom: &NormalizedSbom, config: LshConfig) -> Self {
186 let mut index = Self::new(config);
187
188 for (id, comp) in &sbom.components {
189 index.insert(id.clone(), comp);
190 }
191
192 index
193 }
194
195 pub fn insert(&mut self, id: CanonicalId, component: &Component) {
197 let shingles = self.compute_shingles(component);
199
200 let signature = self.compute_minhash(&shingles);
202
203 self.insert_into_buckets(&id, &signature);
205
206 self.signatures.insert(id, signature);
208 }
209
210 #[must_use]
215 pub fn find_candidates(&self, component: &Component) -> Vec<CanonicalId> {
216 let shingles = self.compute_shingles(component);
217 let signature = self.compute_minhash(&shingles);
218
219 self.find_candidates_by_signature(&signature)
220 }
221
222 #[must_use]
224 pub fn find_candidates_by_signature(&self, signature: &MinHashSignature) -> Vec<CanonicalId> {
225 let mut candidates = HashSet::new();
226 let rows_per_band = self.config.rows_per_band();
227
228 for (band_idx, bucket_map) in self.buckets.iter().enumerate() {
229 let band_hash = self.hash_band(signature, band_idx, rows_per_band);
230
231 if let Some(ids) = bucket_map.get(&band_hash) {
232 for id in ids {
233 candidates.insert(id.clone());
234 }
235 }
236 }
237
238 candidates.into_iter().collect()
239 }
240
241 pub fn find_candidates_for_id(&self, id: &CanonicalId) -> Vec<CanonicalId> {
245 self.signatures.get(id).map_or_else(Vec::new, |signature| {
246 self.find_candidates_by_signature(signature)
247 })
248 }
249
250 #[must_use]
252 pub fn get_signature(&self, id: &CanonicalId) -> Option<&MinHashSignature> {
253 self.signatures.get(id)
254 }
255
256 #[must_use]
258 pub fn estimate_similarity(&self, id_a: &CanonicalId, id_b: &CanonicalId) -> Option<f64> {
259 let sig_a = self.signatures.get(id_a)?;
260 let sig_b = self.signatures.get(id_b)?;
261 Some(sig_a.estimated_similarity(sig_b))
262 }
263
264 pub fn stats(&self) -> LshIndexStats {
266 let total_components = self.signatures.len();
267 let total_buckets: usize = self
268 .buckets
269 .iter()
270 .map(std::collections::HashMap::len)
271 .sum();
272 let max_bucket_size = self
273 .buckets
274 .iter()
275 .flat_map(|b| b.values())
276 .map(std::vec::Vec::len)
277 .max()
278 .unwrap_or(0);
279 let avg_bucket_size = if total_buckets > 0 {
280 self.buckets
281 .iter()
282 .flat_map(|b| b.values())
283 .map(std::vec::Vec::len)
284 .sum::<usize>() as f64
285 / total_buckets as f64
286 } else {
287 0.0
288 };
289
290 LshIndexStats {
291 total_components,
292 num_bands: self.config.num_bands,
293 num_hashes: self.config.num_hashes,
294 total_buckets,
295 max_bucket_size,
296 avg_bucket_size,
297 }
298 }
299
300 fn compute_shingles(&self, component: &Component) -> HashSet<u64> {
306 let ecosystem = component
308 .ecosystem
309 .as_ref()
310 .map(std::string::ToString::to_string);
311 let ecosystem_str = ecosystem.as_deref();
312
313 let normalized = ComponentIndex::normalize_name(&component.name, ecosystem_str);
315 let chars: Vec<char> = normalized.chars().collect();
316
317 let estimated_shingles = chars.len().saturating_sub(self.config.shingle_size) + 3;
319 let mut shingles = HashSet::with_capacity(estimated_shingles);
320
321 if chars.len() < self.config.shingle_size {
323 let mut hasher = std::collections::hash_map::DefaultHasher::new();
325 normalized.hash(&mut hasher);
326 shingles.insert(hasher.finish());
327 } else {
328 for window in chars.windows(self.config.shingle_size) {
330 let mut hasher = std::collections::hash_map::DefaultHasher::new();
331 window.hash(&mut hasher);
332 shingles.insert(hasher.finish());
333 }
334 }
335
336 if self.config.include_ecosystem_token
338 && let Some(ref eco) = ecosystem
339 {
340 let mut hasher = std::collections::hash_map::DefaultHasher::new();
341 "__eco:".hash(&mut hasher);
342 eco.to_lowercase().hash(&mut hasher);
343 shingles.insert(hasher.finish());
344 }
345
346 if self.config.include_group_token
348 && let Some(ref group) = component.group
349 {
350 let mut hasher = std::collections::hash_map::DefaultHasher::new();
351 "__grp:".hash(&mut hasher);
352 group.to_lowercase().hash(&mut hasher);
353 shingles.insert(hasher.finish());
354 }
355
356 shingles
357 }
358
359 fn compute_minhash(&self, shingles: &HashSet<u64>) -> MinHashSignature {
361 let mut min_hashes = vec![u64::MAX; self.config.num_hashes];
362
363 for &shingle in shingles {
364 for (i, &(a, b)) in self.hash_coeffs.iter().enumerate() {
365 let hash = a.wrapping_mul(shingle).wrapping_add(b) % self.prime;
367 if hash < min_hashes[i] {
368 min_hashes[i] = hash;
369 }
370 }
371 }
372
373 MinHashSignature { values: min_hashes }
374 }
375
376 fn insert_into_buckets(&mut self, id: &CanonicalId, signature: &MinHashSignature) {
378 let rows_per_band = self.config.rows_per_band();
379
380 let band_hashes: Vec<u64> = (0..self.config.num_bands)
382 .map(|band_idx| self.hash_band(signature, band_idx, rows_per_band))
383 .collect();
384
385 for (band_idx, bucket_map) in self.buckets.iter_mut().enumerate() {
386 bucket_map
387 .entry(band_hashes[band_idx])
388 .or_default()
389 .push(id.clone());
390 }
391 }
392
393 fn hash_band(
395 &self,
396 signature: &MinHashSignature,
397 band_idx: usize,
398 rows_per_band: usize,
399 ) -> u64 {
400 let start = band_idx * rows_per_band;
401 let end = (start + rows_per_band).min(signature.values.len());
402
403 let mut hasher = std::collections::hash_map::DefaultHasher::new();
404 for &value in &signature.values[start..end] {
405 value.hash(&mut hasher);
406 }
407 hasher.finish()
408 }
409}
410
/// Snapshot of an LSH index's size and bucket occupancy.
#[derive(Debug, Clone)]
pub struct LshIndexStats {
    /// Number of components that have a stored signature.
    pub total_components: usize,
    /// Configured number of bands.
    pub num_bands: usize,
    /// Configured number of hash values per signature.
    pub num_hashes: usize,
    /// Number of buckets, summed over all bands.
    pub total_buckets: usize,
    /// Largest number of ids held by a single bucket.
    pub max_bucket_size: usize,
    /// Mean number of ids per bucket (`0.0` when there are no buckets).
    pub avg_bucket_size: f64,
}

impl std::fmt::Display for LshIndexStats {
    /// Writes a one-line summary; the "hashes" figure is hashes per band.
    ///
    /// The fields are public, so `num_bands` may be `0`; the per-band figure is
    /// then printed as `0` instead of panicking on a division by zero.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hashes_per_band = self.num_hashes.checked_div(self.num_bands).unwrap_or(0);

        write!(
            f,
            "LSH Index: {} components, {} bands × {} hashes, {} buckets (max: {}, avg: {:.1})",
            self.total_components,
            self.num_bands,
            hashes_per_band,
            self.total_buckets,
            self.max_bucket_size,
            self.avg_bucket_size
        )
    }
}
442
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::DocumentMetadata;

    /// Builds a component called `name`, passing `id-<name>` as the second
    /// constructor argument.
    fn make_component(name: &str) -> Component {
        let identifier = format!("id-{}", name);
        Component::new(name.to_string(), identifier)
    }

    /// Builds an SBOM with one component per entry of `names`, added in order.
    fn sbom_of(names: &[&str]) -> NormalizedSbom {
        let mut sbom = NormalizedSbom::new(DocumentMetadata::default());
        for name in names {
            sbom.add_component(make_component(name));
        }
        sbom
    }

    /// The 0.8 tier uses 100 hashes that divide evenly into its bands.
    #[test]
    fn test_lsh_config_for_threshold() {
        let cfg = LshConfig::for_threshold(0.8);
        assert_eq!(cfg.num_hashes, 100);
        assert!(cfg.num_bands > 0);
        assert_eq!(cfg.num_hashes, cfg.num_bands * cfg.rows_per_band());
    }

    /// Identical signatures score 1.0; three of five matching positions score 0.6.
    #[test]
    fn test_minhash_signature_similarity() {
        let reference = MinHashSignature {
            values: vec![1, 2, 3, 4, 5],
        };
        let identical = MinHashSignature {
            values: vec![1, 2, 3, 4, 5],
        };
        assert_eq!(reference.estimated_similarity(&identical), 1.0);

        let partial = MinHashSignature {
            values: vec![1, 2, 3, 6, 7],
        };
        let similarity = reference.estimated_similarity(&partial);
        assert!((similarity - 0.6).abs() < 0.01);
    }

    /// Building from an SBOM indexes every component and fills some buckets.
    #[test]
    fn test_lsh_index_build() {
        let sbom = sbom_of(&["lodash", "lodash-es", "underscore", "react"]);

        let stats = LshIndex::build(&sbom, LshConfig::default_balanced()).stats();

        assert_eq!(stats.total_components, 4);
        assert!(stats.total_buckets > 0);
    }

    /// Querying with a name that is in the index returns at least one candidate.
    #[test]
    fn test_lsh_finds_similar_names() {
        let sbom = sbom_of(&["lodash", "lodash-es", "lodash-fp", "react", "angular"]);
        let index = LshIndex::build(&sbom, LshConfig::for_threshold(0.5));

        let candidates = index.find_candidates(&make_component("lodash"));

        assert!(
            !candidates.is_empty(),
            "Should find at least some candidates"
        );
    }

    /// A name sharing a prefix must score higher than an unrelated name.
    #[test]
    fn test_lsh_signature_estimation() {
        let base = make_component("lodash");
        let near = make_component("lodash-es");
        let far = make_component("completely-different-name");

        // Capture the ids before the components are moved into the SBOM.
        let base_id = base.canonical_id.clone();
        let near_id = near.canonical_id.clone();
        let far_id = far.canonical_id.clone();

        let mut sbom = NormalizedSbom::new(DocumentMetadata::default());
        sbom.add_component(base);
        sbom.add_component(near);
        sbom.add_component(far);

        let index = LshIndex::build(&sbom, LshConfig::default_balanced());

        let near_similarity = index.estimate_similarity(&base_id, &near_id).unwrap();
        let far_similarity = index.estimate_similarity(&base_id, &far_id).unwrap();

        assert!(
            near_similarity > far_similarity,
            "lodash vs lodash-es ({:.2}) should be more similar than lodash vs completely-different ({:.2})",
            near_similarity,
            far_similarity
        );
    }

    /// An empty index reports zero components and echoes its configuration.
    #[test]
    fn test_lsh_index_stats() {
        let stats = LshIndex::new(LshConfig::for_threshold(0.8)).stats();

        assert_eq!(stats.total_components, 0);
        assert_eq!(stats.num_bands, 25);
        assert_eq!(stats.num_hashes, 100);
    }
}