// ragc_core/lz_diff.rs

// LZ Diff Encoding
// Encodes a target sequence as differences from a reference sequence:
// literals, runs of N bases, and copies of reference segments.
3
4#![allow(clippy::same_item_push)]
5
6use ragc_common::{hash::MurMur64Hash, types::Contig};
7use std::collections::HashMap;
8
// Constants for LZ diff encoding.

/// Symbol code for an `N` base; only codes 0..=3 are accepted in hashed k-mers.
const N_CODE: u8 = 4;
/// Marker byte that starts an N-run token in the encoded stream.
const N_RUN_STARTER_CODE: u8 = 30;
/// Shortest run of `N_CODE` bases encoded as an N-run token; the token stores
/// the run length minus this value.
const MIN_NRUN_LEN: u32 = 4;
/// Maximum number of candidate reference positions examined per k-mer hash.
const MAX_NO_TRIES: usize = 64;
/// Only every `HASHING_STEP`-th reference position is indexed (sparse hash table mode).
const HASHING_STEP: usize = 4;
15
/// LZ diff encoder/decoder (V2 implementation).
///
/// Encodes a target sequence as literals, N-runs and copies from a prepared
/// reference sequence, and decodes such encodings back.
pub struct LZDiff {
    /// Reference codes; once prepared, followed by `key_len` padding codes so
    /// k-mer windows near the end can be read without bounds issues.
    reference: Vec<u8>,
    /// Length of the reference before padding.
    reference_len: usize,
    /// k-mer hash -> sampled reference positions holding that k-mer, each
    /// stored as `position / HASHING_STEP` (not the raw position).
    ht: HashMap<u64, Vec<u32>>,
    /// Minimum total length of an encoded reference match.
    min_match_len: u32,
    /// Length, in bases, of the hashed k-mers.
    key_len: u32,
    /// Mask keeping the low `2 * key_len` bits of a packed k-mer code
    /// (all bits when `key_len >= 32`).
    key_mask: u64,
}
25
26impl LZDiff {
27    /// Create a new LZ diff encoder with the given minimum match length
28    pub fn new(min_match_len: u32) -> Self {
29        let key_len = min_match_len - (HASHING_STEP as u32) + 1;
30        let key_mask = if key_len >= 32 {
31            !0u64
32        } else {
33            (1u64 << (2 * key_len)) - 1
34        };
35
36        LZDiff {
37            reference: Vec::new(),
38            reference_len: 0,
39            ht: HashMap::new(),
40            min_match_len,
41            key_len,
42            key_mask,
43        }
44    }
45
46    /// Prepare the encoder with a reference sequence
47    pub fn prepare(&mut self, reference: &Contig) {
48        self.reference = reference.clone();
49        self.reference_len = reference.len(); // Store original length before padding
50                                              // Add padding for key_len
51        self.reference
52            .resize(self.reference.len() + self.key_len as usize, 31);
53
54        // Pre-allocate hash table capacity based on reference length
55        // to avoid repeated rehashing (major source of page faults!)
56        // Each entry in ht corresponds to one k-mer position (every HASHING_STEP bases)
57        let expected_entries = (self.reference.len() / HASHING_STEP) + 1;
58        if self.ht.capacity() < expected_entries {
59            self.ht = HashMap::with_capacity(expected_entries);
60        }
61
62        self.build_index();
63    }
64
65    /// Build hash table index for k-mers in reference
66    fn build_index(&mut self) {
67        self.ht.clear();
68        let ref_len = self.reference.len();
69
70        let mut i = 0;
71        while i + (self.key_len as usize) < ref_len {
72            if let Some(code) = self.get_code(&self.reference[i..]) {
73                let hash = MurMur64Hash::hash(code);
74                // Store i / HASHING_STEP (like C++ implementation)
75                // Pre-allocate capacity to avoid repeated 0→4→8→16 growths
76                self.ht
77                    .entry(hash)
78                    .or_insert_with(|| Vec::with_capacity(4))
79                    .push((i / HASHING_STEP) as u32);
80            }
81            i += HASHING_STEP;
82        }
83    }
84
85    /// Extract k-mer code from sequence
86    #[allow(clippy::needless_range_loop)]
87    fn get_code(&self, seq: &[u8]) -> Option<u64> {
88        let mut code = 0u64;
89        for i in 0..(self.key_len as usize) {
90            if seq[i] > 3 {
91                return None; // Invalid base (N or other)
92            }
93            code = (code << 2) | (seq[i] as u64);
94        }
95        Some(code)
96    }
97
98    /// Extract k-mer code using sliding window optimization
99    fn get_code_skip1(&self, prev_code: u64, seq: &[u8]) -> Option<u64> {
100        let last_base_idx = (self.key_len as usize) - 1;
101        if seq[last_base_idx] > 3 {
102            return None;
103        }
104        let code = ((prev_code << 2) & self.key_mask) | (seq[last_base_idx] as u64);
105        Some(code)
106    }
107
108    /// Check for N-run (at least 3 consecutive N bases)
109    fn get_nrun_len(&self, seq: &[u8], max_len: usize) -> u32 {
110        if seq.len() < 3 || seq[0] != N_CODE || seq[1] != N_CODE || seq[2] != N_CODE {
111            return 0;
112        }
113
114        let mut len = 3;
115        while len < max_len && seq[len] == N_CODE {
116            len += 1;
117        }
118        len as u32
119    }
120
121    /// Encode a literal base
122    fn encode_literal(&self, base: u8, encoded: &mut Vec<u8>) {
123        encoded.push(b'A' + base);
124    }
125
126    /// Encode an N-run
127    fn encode_nrun(&self, len: u32, encoded: &mut Vec<u8>) {
128        encoded.push(N_RUN_STARTER_CODE);
129        self.append_int(encoded, (len - MIN_NRUN_LEN) as i64);
130        encoded.push(N_CODE);
131    }
132
133    /// Encode a match
134    fn encode_match(&self, ref_pos: u32, len: Option<u32>, pred_pos: u32, encoded: &mut Vec<u8>) {
135        let dif_pos = (ref_pos as i32) - (pred_pos as i32);
136        self.append_int(encoded, dif_pos as i64);
137
138        if let Some(match_len) = len {
139            encoded.push(b',');
140            self.append_int(encoded, (match_len - self.min_match_len) as i64);
141        }
142
143        encoded.push(b'.');
144    }
145
146    /// Append integer as ASCII decimal
147    fn append_int(&self, text: &mut Vec<u8>, mut x: i64) {
148        if x == 0 {
149            text.push(b'0');
150            return;
151        }
152
153        if x < 0 {
154            text.push(b'-');
155            x = -x;
156        }
157
158        // Write digits directly to output (in reverse), then reverse just that portion
159        let start_pos = text.len();
160        while x > 0 {
161            text.push(b'0' + (x % 10) as u8);
162            x /= 10;
163        }
164
165        // Reverse just the digits we added
166        text[start_pos..].reverse();
167    }
168
169    /// Find best match in reference for the given position
170    fn find_best_match(
171        &self,
172        hash: u64,
173        target: &[u8],
174        text_pos: usize,
175        max_len: usize,
176        no_prev_literals: usize,
177    ) -> Option<(u32, u32, u32)> {
178        // Returns (ref_pos, len_bck, len_fwd)
179
180        let positions = self.ht.get(&hash)?;
181
182        let mut best_ref_pos = 0;
183        let mut best_len_bck = 0;
184        let mut best_len_fwd = 0;
185        let mut min_to_update = self.min_match_len as usize;
186
187        for &pos in positions.iter().take(MAX_NO_TRIES) {
188            let h_pos = (pos as usize) * HASHING_STEP;
189
190            // Bounds check
191            if h_pos >= self.reference.len() {
192                continue;
193            }
194
195            let ref_ptr = &self.reference[h_pos..];
196            let text_ptr = &target[text_pos..];
197
198            // Forward match
199            let f_len = Self::matching_length(text_ptr, ref_ptr, max_len);
200
201            if f_len >= self.key_len as usize {
202                // Backward match
203                let mut b_len = 0;
204                let max_back = no_prev_literals.min(h_pos).min(text_pos);
205                while b_len < max_back {
206                    if target[text_pos - b_len - 1] != self.reference[h_pos - b_len - 1] {
207                        break;
208                    }
209                    b_len += 1;
210                }
211
212                if b_len + f_len > min_to_update {
213                    best_len_bck = b_len as u32;
214                    best_len_fwd = f_len as u32;
215                    best_ref_pos = h_pos as u32;
216                    min_to_update = b_len + f_len;
217                }
218            }
219        }
220
221        if (best_len_bck + best_len_fwd) as usize >= self.min_match_len as usize {
222            Some((best_ref_pos, best_len_bck, best_len_fwd))
223        } else {
224            None
225        }
226    }
227
228    /// Count matching length between two sequences
229    fn matching_length(s1: &[u8], s2: &[u8], max_len: usize) -> usize {
230        let mut len = 0;
231        let max = max_len.min(s1.len()).min(s2.len());
232        while len < max && s1[len] == s2[len] {
233            len += 1;
234        }
235        len
236    }
237
238    /// Encode target sequence relative to reference
239    pub fn encode(&mut self, target: &Contig) -> Vec<u8> {
240        // Pre-allocate capacity to avoid repeated reallocations
241        // Typical LZ compression achieves 2-4:1, so estimate capacity as target_len / 2
242        let mut encoded = Vec::with_capacity(target.len() / 2);
243
244        // Optimization: if target equals reference, return empty
245        if target.len() == self.reference_len
246            && target
247                .iter()
248                .zip(self.reference.iter())
249                .all(|(a, b)| a == b)
250        {
251            return encoded;
252        }
253
254        let text_size = target.len();
255        let mut i = 0;
256        let mut pred_pos = 0u32;
257        let mut no_prev_literals = 0usize;
258        let mut x_prev: Option<u64> = None;
259
260        while i + (self.key_len as usize) < text_size {
261            // Get k-mer code
262            let x = if let Some(prev) = x_prev {
263                if no_prev_literals > 0 {
264                    self.get_code_skip1(prev, &target[i..])
265                } else {
266                    self.get_code(&target[i..])
267                }
268            } else {
269                self.get_code(&target[i..])
270            };
271
272            x_prev = x;
273
274            if x.is_none() {
275                // Check for N-run
276                let nrun_len = self.get_nrun_len(&target[i..], text_size - i);
277
278                if nrun_len >= MIN_NRUN_LEN {
279                    self.encode_nrun(nrun_len, &mut encoded);
280                    i += nrun_len as usize;
281                    no_prev_literals = 0;
282                } else {
283                    // Single literal
284                    self.encode_literal(target[i], &mut encoded);
285                    i += 1;
286                    pred_pos += 1;
287                    no_prev_literals += 1;
288                }
289                continue;
290            }
291
292            // Try to find match
293            let hash = MurMur64Hash::hash(x.unwrap());
294            let max_len = text_size - i;
295
296            if let Some((match_pos, len_bck, len_fwd)) =
297                self.find_best_match(hash, target, i, max_len, no_prev_literals)
298            {
299                // Handle backward extension
300                if len_bck > 0 {
301                    for _ in 0..len_bck {
302                        encoded.pop();
303                    }
304                    i -= len_bck as usize;
305                    pred_pos -= len_bck;
306                }
307
308                // Check if this is a match to end of sequence
309                let total_len = len_bck + len_fwd;
310                let len_to_encode = if i + (total_len as usize) == text_size
311                    && (match_pos as usize) + (total_len as usize) == self.reference_len
312                {
313                    None // Match to end
314                } else {
315                    Some(total_len)
316                };
317
318                self.encode_match(match_pos - len_bck, len_to_encode, pred_pos, &mut encoded);
319
320                pred_pos = match_pos - len_bck + total_len;
321                i += total_len as usize;
322                no_prev_literals = 0;
323            } else {
324                // No match, encode literal
325                self.encode_literal(target[i], &mut encoded);
326                i += 1;
327                pred_pos += 1;
328                no_prev_literals += 1;
329            }
330        }
331
332        // Encode remaining bases as literals
333        while i < text_size {
334            self.encode_literal(target[i], &mut encoded);
335            i += 1;
336        }
337
338        encoded
339    }
340
341    /// Decode encoded sequence using reference
342    pub fn decode(&self, encoded: &[u8]) -> Vec<u8> {
343        let mut decoded = Vec::new();
344        let mut pred_pos = 0usize;
345        let mut i = 0;
346
347        while i < encoded.len() {
348            if self.is_literal(encoded[i]) {
349                let c = self.decode_literal(encoded[i]);
350                let actual_c = if c == b'!' {
351                    self.reference[pred_pos]
352                } else {
353                    c
354                };
355                decoded.push(actual_c);
356                pred_pos += 1;
357                i += 1;
358            } else if encoded[i] == N_RUN_STARTER_CODE {
359                let (len, consumed) = self.decode_nrun(&encoded[i..]);
360                decoded.resize(decoded.len() + len as usize, N_CODE);
361                i += consumed;
362            } else {
363                // It's a match
364                let (ref_pos, len, consumed) = self.decode_match(&encoded[i..], pred_pos);
365                let actual_len = if len == u32::MAX {
366                    // Match to end: use original reference length (before padding)
367                    self.reference_len - ref_pos
368                } else {
369                    len as usize
370                };
371                decoded.extend_from_slice(&self.reference[ref_pos..ref_pos + actual_len]);
372                pred_pos = ref_pos + actual_len;
373                i += consumed;
374            }
375        }
376
377        decoded
378    }
379
380    /// Check if byte is a literal
381    fn is_literal(&self, c: u8) -> bool {
382        (b'A'..=b'A' + 20).contains(&c) || c == b'!'
383    }
384
385    /// Decode a literal
386    fn decode_literal(&self, c: u8) -> u8 {
387        if c == b'!' {
388            b'!'
389        } else {
390            c - b'A'
391        }
392    }
393
394    /// Decode an N-run, returns (length, bytes_consumed)
395    fn decode_nrun(&self, data: &[u8]) -> (u32, usize) {
396        let mut i = 1; // Skip starter code
397        let (raw_len, len_bytes) = self.read_int(&data[i..]);
398        i += len_bytes;
399        i += 1; // Skip N_CODE suffix
400        ((raw_len as u32) + MIN_NRUN_LEN, i)
401    }
402
403    /// Decode a match, returns (ref_pos, length, bytes_consumed)
404    fn decode_match(&self, data: &[u8], pred_pos: usize) -> (usize, u32, usize) {
405        let mut i = 0;
406        let (raw_pos, pos_bytes) = self.read_int(&data[i..]);
407        i += pos_bytes;
408
409        let ref_pos = ((pred_pos as i64) + raw_pos) as usize;
410
411        let len = if data[i] == b',' {
412            i += 1; // Skip comma
413            let (raw_len, len_bytes) = self.read_int(&data[i..]);
414            i += len_bytes;
415            i += 1; // Skip period
416            (raw_len as u32) + self.min_match_len
417        } else {
418            i += 1; // Skip period
419            u32::MAX // Sentinel for "to end of sequence"
420        };
421
422        (ref_pos, len, i)
423    }
424
425    /// Read ASCII decimal integer, returns (value, bytes_consumed)
426    fn read_int(&self, data: &[u8]) -> (i64, usize) {
427        let mut i = 0;
428        let mut is_neg = false;
429
430        if data[i] == b'-' {
431            is_neg = true;
432            i += 1;
433        }
434
435        let mut x = 0i64;
436        while i < data.len() && data[i] >= b'0' && data[i] <= b'9' {
437            x = x * 10 + ((data[i] - b'0') as i64);
438            i += 1;
439        }
440
441        if is_neg {
442            x = -x;
443        }
444
445        (x, i)
446    }
447
448    /// Get coding cost vector for target sequence
449    /// This computes the per-position cost of encoding the target against the reference
450    /// Returns a vector where v_costs[i] is the cost of encoding position i
451    /// If prefix_costs=true, match cost is placed at start of match; otherwise at end
452    pub fn get_coding_cost_vector(&self, target: &Contig, prefix_costs: bool) -> Vec<u32> {
453        let mut v_costs = Vec::with_capacity(target.len());
454
455        if self.reference.is_empty() {
456            return v_costs;
457        }
458
459        let text_size = target.len();
460        let mut i = 0;
461        let mut pred_pos = 0u32;
462        let mut no_prev_literals = 0usize;
463        let mut x_prev: Option<u64> = None;
464
465        while i + (self.key_len as usize) < text_size {
466            // Get k-mer code
467            let x = if let Some(prev) = x_prev {
468                if no_prev_literals > 0 {
469                    self.get_code_skip1(prev, &target[i..])
470                } else {
471                    self.get_code(&target[i..])
472                }
473            } else {
474                self.get_code(&target[i..])
475            };
476
477            x_prev = x;
478
479            if x.is_none() {
480                // Check for N-run
481                let nrun_len = self.get_nrun_len(&target[i..], text_size - i);
482
483                if nrun_len >= MIN_NRUN_LEN {
484                    let tc = self.coding_cost_nrun(nrun_len);
485                    if prefix_costs {
486                        v_costs.push(tc);
487                        for _ in 1..nrun_len {
488                            v_costs.push(0);
489                        }
490                    } else {
491                        for _ in 1..nrun_len {
492                            v_costs.push(0);
493                        }
494                        v_costs.push(tc);
495                    }
496                    i += nrun_len as usize;
497                    no_prev_literals = 0;
498                } else {
499                    // Single literal: cost is 1
500                    v_costs.push(1);
501                    i += 1;
502                    pred_pos += 1;
503                    no_prev_literals += 1;
504                }
505                continue;
506            }
507
508            // Try to find match
509            let hash = MurMur64Hash::hash(x.unwrap());
510            let max_len = text_size - i;
511
512            if let Some((match_pos, len_bck, len_fwd)) =
513                self.find_best_match(hash, target, i, max_len, no_prev_literals)
514            {
515                // Handle backward extension
516                if len_bck > 0 {
517                    for _ in 0..len_bck {
518                        v_costs.pop();
519                    }
520                    i -= len_bck as usize;
521                    pred_pos -= len_bck;
522                }
523
524                let total_len = len_bck + len_fwd;
525                let tc = self.coding_cost_match(match_pos - len_bck, total_len, pred_pos);
526
527                if prefix_costs {
528                    v_costs.push(tc);
529                    for _ in 1..total_len {
530                        v_costs.push(0);
531                    }
532                } else {
533                    for _ in 1..total_len {
534                        v_costs.push(0);
535                    }
536                    v_costs.push(tc);
537                }
538
539                pred_pos = match_pos - len_bck + total_len;
540                i += total_len as usize;
541                no_prev_literals = 0;
542            } else {
543                // No match, literal cost is 1
544                v_costs.push(1);
545                i += 1;
546                pred_pos += 1;
547                no_prev_literals += 1;
548            }
549        }
550
551        // Remaining bases are literals
552        while i < text_size {
553            v_costs.push(1);
554            i += 1;
555        }
556
557        v_costs
558    }
559
560    /// Compute coding cost for N-run
561    fn coding_cost_nrun(&self, len: u32) -> u32 {
562        // Cost: N_RUN_STARTER_CODE + decimal digits + N_CODE suffix
563        let delta = len - MIN_NRUN_LEN;
564        let digits = if delta == 0 {
565            1
566        } else {
567            ((delta as f64).log10().floor() as u32) + 1
568        };
569        1 + digits + 1 // starter + digits + suffix
570    }
571
572    /// Compute coding cost for match
573    fn coding_cost_match(&self, match_pos: u32, len: u32, pred_pos: u32) -> u32 {
574        // Cost: position_diff digits + ',' + length digits + '.'
575        let dif_pos = (match_pos as i32) - (pred_pos as i32);
576        let pos_digits = if dif_pos == 0 {
577            1
578        } else {
579            ((dif_pos.abs() as f64).log10().floor() as u32) + 1 + if dif_pos < 0 { 1 } else { 0 }
580        };
581
582        let delta = len - self.min_match_len;
583        let len_digits = if delta == 0 {
584            1
585        } else {
586            ((delta as f64).log10().floor() as u32) + 1
587        };
588
589        pos_digits + 1 + len_digits + 1 // pos + ',' + len + '.'
590    }
591}
592
593#[cfg(test)]
594mod tests {
595    use super::*;
596
597    #[test]
598    fn test_simple_literal() {
599        let reference = vec![0, 0, 0, 1, 1, 1];
600        let target = vec![0, 1, 2, 3];
601
602        let mut lz = LZDiff::new(18);
603        lz.prepare(&reference);
604
605        let encoded = lz.encode(&target);
606        let decoded = lz.decode(&encoded);
607
608        assert_eq!(target, decoded);
609    }
610
611    #[test]
612    fn test_identical_sequences() {
613        let reference = vec![0, 1, 2, 3, 0, 1, 2, 3];
614        let target = reference.clone();
615
616        let mut lz = LZDiff::new(18);
617        lz.prepare(&reference);
618
619        let encoded = lz.encode(&target);
620        // Should be empty (optimization)
621        assert_eq!(encoded.len(), 0);
622
623        // Special handling for empty encoding
624        let decoded = if encoded.is_empty() && target.len() == reference.len() {
625            reference.clone()
626        } else {
627            lz.decode(&encoded)
628        };
629
630        assert_eq!(target, decoded);
631    }
632}