// zsync_rs/matcher.rs

1use crate::checksum::calc_md4;
2use crate::control::{ControlFile, HashLengths};
3use crate::rsum::{Rsum, calc_rsum_block};
4
5#[derive(Debug, thiserror::Error)]
6pub enum MatchError {
7    #[error("IO error: {0}")]
8    Io(#[from] std::io::Error),
9}
10
/// Sentinel stored in `hash_table` / `hash_next` slots that hold no block
/// index (empty bucket or end of a chain).
const HASH_EMPTY: u32 = u32::MAX;
/// Extra bits given to the bit-filter mask on top of the hash-table width,
/// capped by the number of checksum bits actually available.
const BITHASH_BITS: u32 = 3;
13
14#[derive(Debug, Clone, Copy)]
15struct TargetBlock {
16    rsum: Rsum,
17    checksum: [u8; 16],
18}
19
20/// Read-only scan state shared across threads.
21struct ScanState<'a> {
22    targets: &'a [TargetBlock],
23    hash_table: &'a [u32],
24    hash_next: &'a [u32],
25    bithash: &'a [u8],
26    blocksize: usize,
27    blockshift: u8,
28    seq_matches: usize,
29    checksum_bytes: usize,
30    rsum_a_mask: u16,
31    hash_func_shift: u32,
32    hash_mask: u32,
33    bithash_mask: u32,
34}
35
36impl ScanState<'_> {
37    #[inline(always)]
38    fn calc_hash_rolling(&self, r0: &Rsum, r1: &Rsum) -> u32 {
39        let mut h = r0.b as u32;
40        if self.seq_matches > 1 {
41            h ^= (r1.b as u32) << self.hash_func_shift;
42        } else {
43            h ^= ((r0.a & self.rsum_a_mask) as u32) << self.hash_func_shift;
44        }
45        h
46    }
47
48    #[inline(always)]
49    fn rsum_match(&self, target: &Rsum, rolling: &Rsum) -> bool {
50        target.a == (rolling.a & self.rsum_a_mask) && target.b == rolling.b
51    }
52
53    /// Scan a chunk of data for matching blocks. Pure read-only, no mutations.
54    /// `base_offset` is the absolute byte offset of `data[0]` within the source file.
55    /// Returns vec of (target_block_id, absolute_source_offset).
56    fn scan_chunk(&self, data: &[u8], base_offset: usize) -> Vec<(usize, usize)> {
57        let blocksize = self.blocksize;
58        let blockshift = self.blockshift;
59        let seq_matches = self.seq_matches;
60        let checksum_bytes = self.checksum_bytes;
61        let context = blocksize * seq_matches;
62        let mut matched_blocks = Vec::new();
63
64        if data.len() < context {
65            return matched_blocks;
66        }
67
68        let x_limit = data.len() - context;
69        let mut x = 0usize;
70        let mut next_match_id: Option<usize> = None;
71
72        let mut r0 = calc_rsum_block(&data[0..blocksize]);
73        let mut r1 = if seq_matches > 1 {
74            calc_rsum_block(&data[blocksize..blocksize * 2])
75        } else {
76            Rsum { a: 0, b: 0 }
77        };
78
79        while x < x_limit {
80            let mut blocks_matched = 0usize;
81
82            if let Some(hint_id) = next_match_id.take()
83                && seq_matches > 1
84                && hint_id < self.targets.len()
85            {
86                let target = &self.targets[hint_id];
87                if self.rsum_match(&target.rsum, &r0) {
88                    let checksum = calc_md4(&data[x..x + blocksize]);
89                    if checksum[..checksum_bytes] == target.checksum[..checksum_bytes] {
90                        matched_blocks.push((hint_id, base_offset + x));
91                        blocks_matched = 1;
92                        if hint_id + 1 < self.targets.len() {
93                            next_match_id = Some(hint_id + 1);
94                        }
95                    }
96                }
97            }
98
99            while blocks_matched == 0 && x < x_limit {
100                let hash = self.calc_hash_rolling(&r0, &r1);
101
102                let bh = (hash & self.bithash_mask) as usize;
103                if self.bithash[bh >> 3] & (1 << (bh & 7)) != 0 {
104                    let mut block_idx = self.hash_table[(hash & self.hash_mask) as usize];
105
106                    while block_idx != HASH_EMPTY {
107                        let block_id = block_idx as usize;
108                        block_idx = self.hash_next[block_id];
109
110                        let target = &self.targets[block_id];
111                        if !self.rsum_match(&target.rsum, &r0) {
112                            continue;
113                        }
114
115                        if seq_matches > 1 && block_id + 1 < self.targets.len() {
116                            let next_target = &self.targets[block_id + 1];
117                            if !self.rsum_match(&next_target.rsum, &r1) {
118                                continue;
119                            }
120
121                            let checksum = calc_md4(&data[x..x + blocksize]);
122                            if checksum[..checksum_bytes] != target.checksum[..checksum_bytes] {
123                                continue;
124                            }
125
126                            let next_checksum = calc_md4(&data[x + blocksize..x + blocksize * 2]);
127                            if next_checksum[..checksum_bytes]
128                                == next_target.checksum[..checksum_bytes]
129                            {
130                                matched_blocks.push((block_id, base_offset + x));
131                                matched_blocks.push((block_id + 1, base_offset + x + blocksize));
132                                blocks_matched = seq_matches;
133
134                                if block_id + 2 < self.targets.len() {
135                                    next_match_id = Some(block_id + 2);
136                                }
137                                break;
138                            }
139                        } else {
140                            let checksum = calc_md4(&data[x..x + blocksize]);
141                            if checksum[..checksum_bytes] == target.checksum[..checksum_bytes] {
142                                matched_blocks.push((block_id, base_offset + x));
143                                blocks_matched = 1;
144                                break;
145                            }
146                        }
147                    }
148                }
149
150                if blocks_matched == 0 {
151                    let oc = data[x];
152                    let nc = data[x + blocksize];
153                    r0.a = r0.a.wrapping_add(u16::from(nc)).wrapping_sub(u16::from(oc));
154                    r0.b =
155                        r0.b.wrapping_add(r0.a)
156                            .wrapping_sub(u16::from(oc) << blockshift);
157
158                    if seq_matches > 1 {
159                        let nc2 = data[x + blocksize * 2];
160                        r1.a =
161                            r1.a.wrapping_add(u16::from(nc2))
162                                .wrapping_sub(u16::from(nc));
163                        r1.b =
164                            r1.b.wrapping_add(r1.a)
165                                .wrapping_sub(u16::from(nc) << blockshift);
166                    }
167
168                    x += 1;
169                }
170            }
171
172            if blocks_matched > 0 {
173                x += blocksize * blocks_matched;
174
175                if x >= x_limit {
176                    // Can't calculate rsums for remaining data
177                } else {
178                    if seq_matches > 1 && blocks_matched == 1 {
179                        r0 = r1;
180                    } else {
181                        r0 = calc_rsum_block(&data[x..x + blocksize]);
182                    }
183                    if seq_matches > 1 {
184                        r1 = calc_rsum_block(&data[x + blocksize..x + blocksize * 2]);
185                    }
186                }
187            }
188        }
189
190        matched_blocks
191    }
192}
193
194pub struct BlockMatcher {
195    blocksize: usize,
196    blockshift: u8,
197    hash_lengths: HashLengths,
198    rsum_a_mask: u16,
199    hash_func_shift: u32,
200    targets: Vec<TargetBlock>,
201    known_blocks: Vec<bool>,
202    hash_table: Vec<u32>,
203    hash_next: Vec<u32>,
204    hash_mask: u32,
205    bithash: Vec<u8>,
206    bithash_mask: u32,
207}
208
209impl BlockMatcher {
210    pub fn new(control: &ControlFile) -> Self {
211        let num_blocks = control.block_checksums.len();
212        let seq_matches = control.hash_lengths.seq_matches as u32;
213        let rsum_bytes = control.hash_lengths.rsum_bytes as u32;
214
215        let rsum_a_mask: u16 = match rsum_bytes {
216            0..=2 => 0,
217            3 => 0x00ff,
218            _ => 0xffff,
219        };
220
221        let targets: Vec<TargetBlock> = control
222            .block_checksums
223            .iter()
224            .map(|bc| TargetBlock {
225                rsum: Rsum {
226                    a: bc.rsum.a & rsum_a_mask,
227                    b: bc.rsum.b,
228                },
229                checksum: bc.checksum,
230            })
231            .collect();
232
233        let rsum_bits = rsum_bytes * 8;
234        let avail_bits = if seq_matches > 1 {
235            rsum_bits.min(16) * 2
236        } else {
237            rsum_bits
238        };
239
240        let mut hash_bits = avail_bits;
241        while hash_bits > 5 && (1u32 << (hash_bits - 1)) > num_blocks as u32 {
242            hash_bits -= 1;
243        }
244        let hash_mask = (1u32 << hash_bits) - 1;
245
246        let bithash_bits_total = (hash_bits + BITHASH_BITS).min(avail_bits);
247        let bithash_mask = (1u32 << bithash_bits_total) - 1;
248
249        let hash_func_shift = if seq_matches > 1 && avail_bits < 24 {
250            bithash_bits_total.saturating_sub(avail_bits / 2)
251        } else {
252            bithash_bits_total.saturating_sub(avail_bits - 16)
253        };
254
255        let blockshift = control.blocksize.trailing_zeros() as u8;
256
257        let mut matcher = Self {
258            blocksize: control.blocksize,
259            blockshift,
260            hash_lengths: control.hash_lengths,
261            rsum_a_mask,
262            hash_func_shift,
263            targets,
264            known_blocks: vec![false; num_blocks],
265            hash_table: vec![HASH_EMPTY; (hash_mask + 1) as usize],
266            hash_next: vec![HASH_EMPTY; num_blocks],
267            hash_mask,
268            bithash: vec![0u8; ((bithash_mask + 1) >> 3) as usize + 1],
269            bithash_mask,
270        };
271
272        for id in (0..num_blocks).rev() {
273            let h = matcher.calc_hash(id);
274            let bucket = (h & hash_mask) as usize;
275            matcher.hash_next[id] = matcher.hash_table[bucket];
276            matcher.hash_table[bucket] = id as u32;
277            let bh = (h & bithash_mask) as usize;
278            matcher.bithash[bh >> 3] |= 1 << (bh & 7);
279        }
280
281        matcher
282    }
283
284    fn calc_hash(&self, block_id: usize) -> u32 {
285        let mut h = self.targets[block_id].rsum.b as u32;
286        if self.hash_lengths.seq_matches > 1 {
287            let next_b = if block_id + 1 < self.targets.len() {
288                self.targets[block_id + 1].rsum.b as u32
289            } else {
290                0
291            };
292            h ^= next_b << self.hash_func_shift;
293        } else {
294            h ^= (self.targets[block_id].rsum.a as u32) << self.hash_func_shift;
295        }
296        h
297    }
298
299    fn remove_block_from_hash(&mut self, id: usize) {
300        let h = self.calc_hash(id);
301        let bucket = (h & self.hash_mask) as usize;
302
303        let mut prev = HASH_EMPTY;
304        let mut curr = self.hash_table[bucket];
305
306        while curr != HASH_EMPTY {
307            if curr as usize == id {
308                if prev == HASH_EMPTY {
309                    self.hash_table[bucket] = self.hash_next[id];
310                } else {
311                    self.hash_next[prev as usize] = self.hash_next[id];
312                }
313                return;
314            }
315            prev = curr;
316            curr = self.hash_next[curr as usize];
317        }
318    }
319
320    fn scan_state(&self) -> ScanState<'_> {
321        ScanState {
322            targets: &self.targets,
323            hash_table: &self.hash_table,
324            hash_next: &self.hash_next,
325            bithash: &self.bithash,
326            blocksize: self.blocksize,
327            blockshift: self.blockshift,
328            seq_matches: self.hash_lengths.seq_matches as usize,
329            checksum_bytes: self.hash_lengths.checksum_bytes as usize,
330            rsum_a_mask: self.rsum_a_mask,
331            hash_func_shift: self.hash_func_shift,
332            hash_mask: self.hash_mask,
333            bithash_mask: self.bithash_mask,
334        }
335    }
336
337    pub fn submit_blocks(&mut self, data: &[u8], block_start: usize) -> Result<bool, MatchError> {
338        let blocksize = self.blocksize;
339        let checksum_bytes = self.hash_lengths.checksum_bytes as usize;
340        let num_blocks = data.len() / blocksize;
341
342        for i in 0..num_blocks {
343            let block_data = &data[i * blocksize..(i + 1) * blocksize];
344            let block_id = block_start + i;
345
346            if block_id >= self.targets.len() {
347                break;
348            }
349
350            let checksum = calc_md4(block_data);
351            if checksum[..checksum_bytes] == self.targets[block_id].checksum[..checksum_bytes] {
352                self.known_blocks[block_id] = true;
353            } else {
354                return Ok(false);
355            }
356        }
357
358        Ok(true)
359    }
360
361    pub fn submit_source_data(&mut self, data: &[u8], offset: u64) -> Vec<(usize, usize)> {
362        let context = self.blocksize * self.hash_lengths.seq_matches as usize;
363        if data.len() < context {
364            return Vec::new();
365        }
366
367        let num_threads = std::thread::available_parallelism()
368            .map(|n| n.get())
369            .unwrap_or(1);
370
371        let min_per_thread = 16 * 1024 * 1024; // 16 MB per thread minimum
372        let scannable = data.len() - context;
373
374        let candidates = if num_threads > 1 && scannable >= min_per_thread * 2 {
375            let state = self.scan_state();
376            let actual_threads = num_threads.min(scannable / min_per_thread);
377            let chunk_size = scannable / actual_threads;
378
379            std::thread::scope(|s| {
380                let handles: Vec<_> = (0..actual_threads)
381                    .map(|i| {
382                        let start = i * chunk_size;
383                        let end = if i == actual_threads - 1 {
384                            data.len()
385                        } else {
386                            (i + 1) * chunk_size + context
387                        };
388                        let chunk = &data[start..end];
389                        let state = &state;
390                        let base = offset as usize + start;
391                        s.spawn(move || state.scan_chunk(chunk, base))
392                    })
393                    .collect();
394
395                let mut all: Vec<(usize, usize)> = Vec::new();
396                for h in handles {
397                    all.extend(h.join().unwrap());
398                }
399                all
400            })
401        } else {
402            let state = self.scan_state();
403            state.scan_chunk(data, offset as usize)
404        };
405
406        // Deduplicate: first match per block_id wins
407        let mut seen = vec![false; self.targets.len()];
408        let mut matched_blocks = Vec::new();
409        for (block_id, offset) in candidates {
410            if !seen[block_id] {
411                seen[block_id] = true;
412                self.known_blocks[block_id] = true;
413                self.remove_block_from_hash(block_id);
414                matched_blocks.push((block_id, offset));
415            }
416        }
417
418        matched_blocks
419    }
420
421    pub fn needed_block_ranges(&self) -> Vec<(usize, usize)> {
422        let mut ranges = Vec::new();
423        let mut start: Option<usize> = None;
424
425        for (i, &known) in self.known_blocks.iter().enumerate() {
426            if !known && start.is_none() {
427                start = Some(i);
428            } else if known && start.is_some() {
429                ranges.push((start.unwrap(), i));
430                start = None;
431            }
432        }
433
434        if let Some(s) = start {
435            ranges.push((s, self.known_blocks.len()));
436        }
437
438        ranges
439    }
440
441    pub fn is_block_known(&self, block_id: usize) -> bool {
442        block_id < self.known_blocks.len() && self.known_blocks[block_id]
443    }
444
445    pub fn blocks_todo(&self) -> usize {
446        self.known_blocks.iter().filter(|&&k| !k).count()
447    }
448
449    pub fn is_complete(&self) -> bool {
450        self.known_blocks.iter().all(|&k| k)
451    }
452
453    pub fn total_blocks(&self) -> usize {
454        self.targets.len()
455    }
456}
457
#[cfg(test)]
mod tests {
    use super::*;
    use crate::control::{BlockChecksum, ControlFile, HashLengths};

    /// Block size used by every test in this module.
    const BLOCKSIZE: usize = 4;

    /// Build a control file for `data` split into `blocksize`-byte blocks,
    /// with a short final block padded with zeros before checksumming.
    fn make_control(data: &[u8], blocksize: usize) -> ControlFile {
        let block_checksums = data
            .chunks(blocksize)
            .map(|chunk| {
                let mut block = chunk.to_vec();
                block.resize(blocksize, 0);
                BlockChecksum {
                    rsum: calc_rsum_block(&block),
                    checksum: calc_md4(&block),
                }
            })
            .collect();

        ControlFile {
            version: String::from("0.6.2"),
            filename: Some(String::from("test.bin")),
            mtime: None,
            blocksize,
            length: data.len() as u64,
            hash_lengths: HashLengths {
                seq_matches: 1,
                rsum_bytes: 4,
                checksum_bytes: 16,
            },
            urls: vec![String::from("http://example.com/test.bin")],
            sha1: None,
            block_checksums,
        }
    }

    /// A fresh matcher has every block outstanding.
    #[test]
    fn test_matcher_new() {
        let data: Vec<u8> = (1..=8).collect();
        let matcher = BlockMatcher::new(&make_control(&data, BLOCKSIZE));
        assert_eq!(matcher.blocks_todo(), 2);
    }

    /// Scanning the target data itself finds every block.
    #[test]
    fn test_submit_source_data() {
        let data: Vec<u8> = (1..=12).collect();
        let mut matcher = BlockMatcher::new(&make_control(&data, BLOCKSIZE));

        // Append context bytes (blocksize * seq_matches) so that the final
        // block's window start is inside the scanned range.
        let padded: Vec<u8> = data
            .iter()
            .copied()
            .chain(std::iter::repeat(0).take(BLOCKSIZE))
            .collect();

        let got = matcher.submit_source_data(&padded, 0);
        assert_eq!(got.len(), 3);
        assert!(matcher.is_complete());
    }

    /// With nothing known, a single range covers all blocks.
    #[test]
    fn test_needed_block_ranges() {
        let data: Vec<u8> = (1..=8).collect();
        let matcher = BlockMatcher::new(&make_control(&data, BLOCKSIZE));
        assert_eq!(matcher.needed_block_ranges(), vec![(0, 2)]);
    }
}