all_pairs_hamming/
chunked_join.rs1use hashbrown::HashSet;
3
4use crate::errors::{AllPairsHammingError, Result};
5use crate::multi_sort::MultiSort;
6use crate::sketch::Sketch;
7
/// All-pairs similarity joiner over sketches split into a fixed number of chunks.
///
/// Storage is column-oriented: `chunks[c][id]` holds chunk `c` of the sketch
/// with id `id`.
pub struct ChunkedJoiner<S> {
    /// Whether progress messages are written to stderr.
    shows_progress: bool,
    /// One vector per chunk position; a sketch id indexes into each of them.
    chunks: Vec<Vec<S>>,
}
41
42impl<S> ChunkedJoiner<S>
43where
44 S: Sketch,
45{
46 pub fn new(num_chunks: usize) -> Self {
49 Self {
50 chunks: vec![vec![]; num_chunks],
51 shows_progress: false,
52 }
53 }
54
55 pub const fn shows_progress(mut self, yes: bool) -> Self {
57 self.shows_progress = yes;
58 self
59 }
60
61 pub fn add<I>(&mut self, sketch: I) -> Result<()>
65 where
66 I: IntoIterator<Item = S>,
67 {
68 let num_chunks = self.num_chunks();
69 let mut iter = sketch.into_iter();
70 for chunk in self.chunks.iter_mut() {
71 chunk.push(iter.next().ok_or_else(|| {
72 let msg = format!("The input sketch must include {num_chunks} chunks at least.");
73 AllPairsHammingError::input(msg)
74 })?);
75 }
76 Ok(())
77 }
78
79 pub fn similar_pairs(&self, radius: f64) -> Vec<(usize, usize, f64)> {
82 let dimension = S::dim() * self.num_chunks();
83 let hamradius = (dimension as f64 * radius).ceil() as usize;
84 if self.shows_progress {
85 eprintln!(
86 "[ChunkedJoiner::similar_pairs] #dimensions={dimension}, hamradius={hamradius}"
87 );
88 }
89
90 let mut candidates = HashSet::new();
92 for (j, chunk) in self.chunks.iter().enumerate() {
93 if j + hamradius + 1 < self.chunks.len() {
96 continue;
97 }
98 let r = (j + hamradius + 1 - self.chunks.len()) / self.chunks.len();
99 MultiSort::new().similar_pairs(chunk, r, &mut candidates);
100
101 if self.shows_progress {
102 eprintln!(
103 "[ChunkedJoiner::similar_pairs] Processed {}/{}...",
104 j + 1,
105 self.chunks.len()
106 );
107 eprintln!(
108 "[ChunkedJoiner::similar_pairs] #candidates={}",
109 candidates.len()
110 );
111 }
112 }
113 if self.shows_progress {
114 eprintln!("[ChunkedJoiner::similar_pairs] Done");
115 }
116
117 let mut candidates: Vec<_> = candidates.into_iter().collect();
118 candidates.sort_unstable();
119
120 let bound = (dimension as f64 * radius) as usize;
121 let mut matched = vec![];
122
123 for (i, j) in candidates {
124 if let Some(dist) = self.hamming_distance(i, j, bound) {
125 let dist = dist as f64 / dimension as f64;
126 if dist <= radius {
127 matched.push((i, j, dist));
128 }
129 }
130 }
131 if self.shows_progress {
132 eprintln!("[ChunkedJoiner::similar_pairs] #matched={}", matched.len());
133 }
134 matched
135 }
136
137 pub fn num_chunks(&self) -> usize {
139 self.chunks.len()
140 }
141
142 pub fn num_sketches(&self) -> usize {
144 self.chunks.get(0).map(|v| v.len()).unwrap_or(0)
145 }
146
147 pub fn memory_in_bytes(&self) -> usize {
149 self.num_chunks() * self.num_sketches() * std::mem::size_of::<S>()
150 }
151
152 fn hamming_distance(&self, i: usize, j: usize, bound: usize) -> Option<usize> {
153 let mut dist = 0;
154 for chunk in &self.chunks {
155 dist += chunk[i].hamdist(chunk[j]);
156 if bound < dist {
157 return None;
158 }
159 }
160 Some(dist)
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed corpus of ten 16-bit sketches.
    fn example_sketches() -> Vec<u16> {
        vec![
            0b_1110_0011_1111_1011,
            0b_0001_0111_0111_1101,
            0b_1100_1101_1000_1100,
            0b_1100_1101_0001_0100,
            0b_1010_1110_0010_1010,
            0b_0111_1001_0011_1111,
            0b_1110_0011_0001_0000,
            0b_1000_0111_1001_0101,
            0b_1110_1101_1000_1101,
            0b_0111_1001_0011_1001,
        ]
    }

    /// Brute-force reference: checks every pair `(i, j)` with `i < j`.
    fn naive_search(sketches: &[u16], radius: f64) -> Vec<(usize, usize, f64)> {
        let mut found = Vec::new();
        for (i, &x) in sketches.iter().enumerate() {
            for (j, &y) in sketches.iter().enumerate().skip(i + 1) {
                let normalized = x.hamdist(y) as f64 / 16.;
                if normalized <= radius {
                    found.push((i, j, normalized));
                }
            }
        }
        found
    }

    /// Splits each 16-bit sketch into two 8-bit chunks and compares the
    /// joiner's output against the brute-force reference.
    fn test_similar_pairs(radius: f64) {
        let sketches = example_sketches();
        let expected = naive_search(&sketches, radius);

        let mut joiner = ChunkedJoiner::new(2);
        for &s in &sketches {
            let low = (s & 0xFF) as u8;
            let high = (s >> 8) as u8;
            joiner.add([low, high]).unwrap();
        }

        let mut results = joiner.similar_pairs(radius);
        results.sort_by(|a, b| (a.0, a.1).cmp(&(b.0, b.1)));
        assert_eq!(results, expected);
    }

    #[test]
    fn test_similar_pairs_for_all() {
        (0..=10)
            .map(|step| step as f64 / 10.)
            .for_each(test_similar_pairs);
    }

    #[test]
    fn test_short_sketch() {
        let mut joiner = ChunkedJoiner::new(2);
        assert!(joiner.add([0u64]).is_err());
    }
}