1use super::index::ComponentIndex;
21use crate::model::{CanonicalId, Component, NormalizedSbom};
22use std::collections::{HashMap, HashSet};
23use std::hash::{Hash, Hasher};
24
/// Tuning parameters for the MinHash/LSH index.
#[derive(Debug, Clone)]
pub struct LshConfig {
    /// Total number of MinHash functions, i.e. the length of each signature.
    pub num_hashes: usize,
    /// Number of bands a signature is split into for bucketing.
    pub num_bands: usize,
    /// Number of characters in each name shingle.
    pub shingle_size: usize,
    /// Similarity threshold this configuration was requested for.
    pub target_threshold: f64,
    /// Whether an ecosystem token is added to a component's shingle set.
    pub include_ecosystem_token: bool,
    /// Whether a group token is added to a component's shingle set.
    pub include_group_token: bool,
}

impl LshConfig {
    /// Builds a configuration for the given similarity `threshold`.
    ///
    /// Every entry of the table uses 100 hash functions in total; only the
    /// split into bands and rows per band changes. A `threshold` that matches
    /// no entry (including NaN) falls back to 5 bands of 20 rows.
    ///
    /// NOTE(review): with the usual approximation `t ≈ (1/b)^(1/r)`, the pair
    /// (50 bands, 2 rows) corresponds to `t ≈ 0.14` and (5 bands, 20 rows) to
    /// `t ≈ 0.92`, so this table looks inverted relative to `threshold`.
    /// The values are kept as they are because callers and tests depend on
    /// them — confirm the intended mapping.
    pub fn for_threshold(threshold: f64) -> Self {
        // (minimum threshold, bands, rows per band), checked top to bottom.
        const TABLE: [(f64, usize, usize); 4] = [
            (0.9, 50, 2),
            (0.8, 25, 4),
            (0.7, 20, 5),
            (0.5, 10, 10),
        ];
        const FALLBACK: (usize, usize) = (5, 20);

        let (num_bands, rows_per_band) = TABLE
            .iter()
            .find(|&&(minimum, _, _)| threshold >= minimum)
            .map_or(FALLBACK, |&(_, bands, rows)| (bands, rows));

        Self {
            num_hashes: num_bands * rows_per_band,
            num_bands,
            shingle_size: 3,
            target_threshold: threshold,
            include_ecosystem_token: true,
            include_group_token: false,
        }
    }

    /// Configuration for a threshold of 0.8.
    pub fn default_balanced() -> Self {
        Self::for_threshold(0.8)
    }

    /// Configuration for a threshold of 0.9.
    pub fn strict() -> Self {
        Self::for_threshold(0.9)
    }

    /// Configuration for a threshold of 0.5.
    pub fn permissive() -> Self {
        Self::for_threshold(0.5)
    }

    /// Number of signature rows in each band (`num_hashes / num_bands`,
    /// rounded down).
    ///
    /// Returns 0 when `num_bands` is 0 instead of panicking on the division;
    /// the fields are public, so such a configuration can be constructed.
    pub fn rows_per_band(&self) -> usize {
        self.num_hashes.checked_div(self.num_bands).unwrap_or(0)
    }
}

impl Default for LshConfig {
    /// Same as [`LshConfig::default_balanced`].
    fn default() -> Self {
        Self::default_balanced()
    }
}
104
/// MinHash signature: one minimum hash value per hash function.
#[derive(Debug, Clone)]
pub struct MinHashSignature {
    /// Minimum hash values, in hash-function order.
    pub values: Vec<u64>,
}

impl MinHashSignature {
    /// Estimates similarity as the fraction of positions at which the two
    /// signatures hold the same value.
    ///
    /// Returns 0.0 when the signatures differ in length, and also when both
    /// are empty (the ratio would otherwise be `0 / 0`, i.e. NaN).
    pub fn estimated_similarity(&self, other: &MinHashSignature) -> f64 {
        let len = self.values.len();
        if len == 0 || len != other.values.len() {
            return 0.0;
        }

        let matching = self
            .values
            .iter()
            .zip(&other.values)
            .filter(|(mine, theirs)| mine == theirs)
            .count();

        matching as f64 / len as f64
    }
}
129
130pub struct LshIndex {
132 config: LshConfig,
134 signatures: HashMap<CanonicalId, MinHashSignature>,
136 buckets: Vec<HashMap<u64, Vec<CanonicalId>>>,
138 hash_coeffs: Vec<(u64, u64)>,
140 prime: u64,
142}
143
144impl LshIndex {
145 pub fn new(config: LshConfig) -> Self {
147 use std::collections::hash_map::RandomState;
148 use std::hash::BuildHasher;
149
150 let mut hash_coeffs = Vec::with_capacity(config.num_hashes);
152 let random_state = RandomState::new();
153
154 for i in 0..config.num_hashes {
155 let a = random_state.hash_one(i as u64 * 31337) | 1; let b = random_state.hash_one(i as u64 * 7919 + 12345);
158
159 hash_coeffs.push((a, b));
160 }
161
162 let buckets = (0..config.num_bands)
164 .map(|_| HashMap::with_capacity(64))
165 .collect();
166
167 Self {
168 config,
169 signatures: HashMap::with_capacity(256),
170 buckets,
171 hash_coeffs,
172 prime: 0xFFFFFFFFFFFFFFC5, }
174 }
175
176 pub fn build(sbom: &NormalizedSbom, config: LshConfig) -> Self {
178 let mut index = Self::new(config);
179
180 for (id, comp) in &sbom.components {
181 index.insert(id.clone(), comp);
182 }
183
184 index
185 }
186
187 pub fn insert(&mut self, id: CanonicalId, component: &Component) {
189 let shingles = self.compute_shingles(component);
191
192 let signature = self.compute_minhash(&shingles);
194
195 self.insert_into_buckets(&id, &signature);
197
198 self.signatures.insert(id, signature);
200 }
201
202 pub fn find_candidates(&self, component: &Component) -> Vec<CanonicalId> {
207 let shingles = self.compute_shingles(component);
208 let signature = self.compute_minhash(&shingles);
209
210 self.find_candidates_by_signature(&signature)
211 }
212
213 pub fn find_candidates_by_signature(&self, signature: &MinHashSignature) -> Vec<CanonicalId> {
215 let mut candidates = HashSet::new();
216 let rows_per_band = self.config.rows_per_band();
217
218 for (band_idx, bucket_map) in self.buckets.iter().enumerate() {
219 let band_hash = self.hash_band(signature, band_idx, rows_per_band);
220
221 if let Some(ids) = bucket_map.get(&band_hash) {
222 for id in ids {
223 candidates.insert(id.clone());
224 }
225 }
226 }
227
228 candidates.into_iter().collect()
229 }
230
231 pub fn find_candidates_for_id(&self, id: &CanonicalId) -> Vec<CanonicalId> {
235 if let Some(signature) = self.signatures.get(id) {
236 self.find_candidates_by_signature(signature)
237 } else {
238 Vec::new()
239 }
240 }
241
242 pub fn get_signature(&self, id: &CanonicalId) -> Option<&MinHashSignature> {
244 self.signatures.get(id)
245 }
246
247 pub fn estimate_similarity(&self, id_a: &CanonicalId, id_b: &CanonicalId) -> Option<f64> {
249 let sig_a = self.signatures.get(id_a)?;
250 let sig_b = self.signatures.get(id_b)?;
251 Some(sig_a.estimated_similarity(sig_b))
252 }
253
254 pub fn stats(&self) -> LshIndexStats {
256 let total_components = self.signatures.len();
257 let total_buckets: usize = self.buckets.iter().map(|b| b.len()).sum();
258 let max_bucket_size = self
259 .buckets
260 .iter()
261 .flat_map(|b| b.values())
262 .map(|v| v.len())
263 .max()
264 .unwrap_or(0);
265 let avg_bucket_size = if total_buckets > 0 {
266 self.buckets
267 .iter()
268 .flat_map(|b| b.values())
269 .map(|v| v.len())
270 .sum::<usize>() as f64
271 / total_buckets as f64
272 } else {
273 0.0
274 };
275
276 LshIndexStats {
277 total_components,
278 num_bands: self.config.num_bands,
279 num_hashes: self.config.num_hashes,
280 total_buckets,
281 max_bucket_size,
282 avg_bucket_size,
283 }
284 }
285
286 fn compute_shingles(&self, component: &Component) -> HashSet<u64> {
292 let ecosystem = component.ecosystem.as_ref().map(|e| e.to_string());
294 let ecosystem_str = ecosystem.as_deref();
295
296 let normalized = ComponentIndex::normalize_name(&component.name, ecosystem_str);
298 let chars: Vec<char> = normalized.chars().collect();
299
300 let estimated_shingles = chars.len().saturating_sub(self.config.shingle_size) + 3;
302 let mut shingles = HashSet::with_capacity(estimated_shingles);
303
304 if chars.len() < self.config.shingle_size {
306 let mut hasher = std::collections::hash_map::DefaultHasher::new();
308 normalized.hash(&mut hasher);
309 shingles.insert(hasher.finish());
310 } else {
311 for window in chars.windows(self.config.shingle_size) {
313 let mut hasher = std::collections::hash_map::DefaultHasher::new();
314 window.hash(&mut hasher);
315 shingles.insert(hasher.finish());
316 }
317 }
318
319 if self.config.include_ecosystem_token {
321 if let Some(ref eco) = ecosystem {
322 let mut hasher = std::collections::hash_map::DefaultHasher::new();
323 "__eco:".hash(&mut hasher);
324 eco.to_lowercase().hash(&mut hasher);
325 shingles.insert(hasher.finish());
326 }
327 }
328
329 if self.config.include_group_token {
331 if let Some(ref group) = component.group {
332 let mut hasher = std::collections::hash_map::DefaultHasher::new();
333 "__grp:".hash(&mut hasher);
334 group.to_lowercase().hash(&mut hasher);
335 shingles.insert(hasher.finish());
336 }
337 }
338
339 shingles
340 }
341
342 fn compute_minhash(&self, shingles: &HashSet<u64>) -> MinHashSignature {
344 let mut min_hashes = vec![u64::MAX; self.config.num_hashes];
345
346 for &shingle in shingles {
347 for (i, &(a, b)) in self.hash_coeffs.iter().enumerate() {
348 let hash = a.wrapping_mul(shingle).wrapping_add(b) % self.prime;
350 if hash < min_hashes[i] {
351 min_hashes[i] = hash;
352 }
353 }
354 }
355
356 MinHashSignature { values: min_hashes }
357 }
358
359 fn insert_into_buckets(&mut self, id: &CanonicalId, signature: &MinHashSignature) {
361 let rows_per_band = self.config.rows_per_band();
362
363 let band_hashes: Vec<u64> = (0..self.config.num_bands)
365 .map(|band_idx| self.hash_band(signature, band_idx, rows_per_band))
366 .collect();
367
368 for (band_idx, bucket_map) in self.buckets.iter_mut().enumerate() {
369 bucket_map
370 .entry(band_hashes[band_idx])
371 .or_default()
372 .push(id.clone());
373 }
374 }
375
376 fn hash_band(
378 &self,
379 signature: &MinHashSignature,
380 band_idx: usize,
381 rows_per_band: usize,
382 ) -> u64 {
383 let start = band_idx * rows_per_band;
384 let end = (start + rows_per_band).min(signature.values.len());
385
386 let mut hasher = std::collections::hash_map::DefaultHasher::new();
387 for &value in &signature.values[start..end] {
388 value.hash(&mut hasher);
389 }
390 hasher.finish()
391 }
392}
393
/// Size statistics of an LSH index.
#[derive(Debug, Clone)]
pub struct LshIndexStats {
    /// Number of components that have a signature in the index.
    pub total_components: usize,
    /// Number of bands each signature is split into.
    pub num_bands: usize,
    /// Total number of hash functions per signature.
    pub num_hashes: usize,
    /// Number of buckets summed over all bands.
    pub total_buckets: usize,
    /// Number of ids in the largest bucket.
    pub max_bucket_size: usize,
    /// Mean number of ids per bucket.
    pub avg_bucket_size: f64,
}

impl std::fmt::Display for LshIndexStats {
    /// Writes a one-line summary. The number shown after `×` is the number of
    /// hashes per band (`num_hashes / num_bands`), reported as 0 when there
    /// are no bands rather than panicking on the division.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hashes_per_band = self.num_hashes.checked_div(self.num_bands).unwrap_or(0);

        write!(
            f,
            "LSH Index: {} components, {} bands × {} hashes, {} buckets (max: {}, avg: {:.1})",
            self.total_components,
            self.num_bands,
            hashes_per_band,
            self.total_buckets,
            self.max_bucket_size,
            self.avg_bucket_size
        )
    }
}
425
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::DocumentMetadata;

    /// Creates a component called `name` whose id string is `id-<name>`.
    fn make_component(name: &str) -> Component {
        Component::new(name.to_string(), format!("id-{}", name))
    }

    /// Creates an SBOM with default metadata and one component per name.
    fn sbom_with(names: &[&str]) -> NormalizedSbom {
        let mut sbom = NormalizedSbom::new(DocumentMetadata::default());
        for name in names {
            sbom.add_component(make_component(name));
        }
        sbom
    }

    /// The 0.8 preset uses 100 hashes that divide evenly into its bands.
    #[test]
    fn test_lsh_config_for_threshold() {
        let config = LshConfig::for_threshold(0.8);

        assert_eq!(config.num_hashes, 100);
        assert!(config.num_bands > 0);
        assert_eq!(config.num_hashes, config.num_bands * config.rows_per_band());
    }

    /// Similarity is the fraction of equal positions: 5/5 and 3/5.
    #[test]
    fn test_minhash_signature_similarity() {
        let signature = |values: Vec<u64>| MinHashSignature { values };

        let sig_a = signature(vec![1, 2, 3, 4, 5]);
        let sig_b = signature(vec![1, 2, 3, 4, 5]);
        let sig_c = signature(vec![1, 2, 3, 6, 7]);

        assert_eq!(sig_a.estimated_similarity(&sig_b), 1.0);
        assert!((sig_a.estimated_similarity(&sig_c) - 0.6).abs() < 0.01);
    }

    /// Building from an SBOM indexes every component and fills some buckets.
    #[test]
    fn test_lsh_index_build() {
        let sbom = sbom_with(&["lodash", "lodash-es", "underscore", "react"]);

        let stats = LshIndex::build(&sbom, LshConfig::default_balanced()).stats();

        assert_eq!(stats.total_components, 4);
        assert!(stats.total_buckets > 0);
    }

    /// Querying with a name that is itself in the index returns candidates.
    #[test]
    fn test_lsh_finds_similar_names() {
        let sbom = sbom_with(&["lodash", "lodash-es", "lodash-fp", "react", "angular"]);
        let index = LshIndex::build(&sbom, LshConfig::for_threshold(0.5));

        let candidates = index.find_candidates(&make_component("lodash"));

        assert!(
            !candidates.is_empty(),
            "Should find at least some candidates"
        );
    }

    /// A near-identical name should score higher than an unrelated one.
    #[test]
    fn test_lsh_signature_estimation() {
        let components = [
            make_component("lodash"),
            make_component("lodash-es"),
            make_component("completely-different-name"),
        ];
        // Remember the ids before the components are moved into the SBOM.
        let ids: Vec<_> = components
            .iter()
            .map(|comp| comp.canonical_id.clone())
            .collect();

        let mut sbom = NormalizedSbom::new(DocumentMetadata::default());
        for comp in components {
            sbom.add_component(comp);
        }

        let index = LshIndex::build(&sbom, LshConfig::default_balanced());

        let sim_12 = index.estimate_similarity(&ids[0], &ids[1]).unwrap();
        let sim_13 = index.estimate_similarity(&ids[0], &ids[2]).unwrap();

        assert!(
            sim_12 > sim_13,
            "lodash vs lodash-es ({:.2}) should be more similar than lodash vs completely-different ({:.2})",
            sim_12, sim_13
        );
    }

    /// An empty index reports its configuration and no components.
    #[test]
    fn test_lsh_index_stats() {
        let stats = LshIndex::new(LshConfig::for_threshold(0.8)).stats();

        assert_eq!(stats.total_components, 0);
        assert_eq!(stats.num_bands, 25);
        assert_eq!(stats.num_hashes, 100);
    }
}