Skip to main content

sshash_lib/builder/
encode.rs

1//! String encoding module for building the Spectrum-Preserving String Set
2//!
3//! Encodes DNA sequences into the 2-bit representation, extracts k-mers,
4//! and builds the offsets structure for string boundaries.
5
6use crate::encoding;
7use crate::kmer::{Kmer, KmerBits};
8use crate::offsets::OffsetsVector;
9use crate::spectrum_preserving_string_set::SpectrumPreservingStringSet;
10use anyhow::Result;
11
/// Encoder for building SPSS from DNA sequences
///
/// Accumulates sequences, encodes them to 2-bit format, tracks offsets,
/// and builds the final SPSS structure.
///
/// `K` is the k-mer length: it decides which sequences are long enough to be
/// stored and how many k-mers each stored sequence contributes.
///
/// IMPORTANT: Bases are packed contiguously across sequences without
/// byte-boundary padding. This means `base_idx / 4` always gives the correct
/// byte and `(base_idx % 4) * 2` gives the correct bit offset within that byte
/// (i.e. the first base of a byte occupies its two lowest bits).
pub struct Encoder<const K: usize>
where
    Kmer<K>: KmerBits,
{
    /// Encoded strings (2-bit packed, contiguous across all sequences).
    /// Four bases share one byte, so the final byte may be only partly used.
    strings: Vec<u8>,

    /// Offset to start of each string (in bases from the beginning of all strings).
    /// One entry (the running base count) is appended after every stored sequence.
    offsets: OffsetsVector,

    /// Total number of k-mers over all stored sequences
    /// (`len - K + 1` for each sequence of length `len`)
    num_kmers: u64,

    /// Total number of strings added (sequences shorter than `K` are not counted)
    num_strings: u64,

    /// Total number of bases added so far (tracks bit position for contiguous packing)
    total_bases: u64,
}
39
40impl<const K: usize> Encoder<K>
41where
42    Kmer<K>: KmerBits,
43{
44    /// Create a new encoder
45    pub fn new() -> Self {
46        Self {
47            strings: Vec::new(),
48            offsets: OffsetsVector::new(),  // Already starts with [0]
49            num_kmers: 0,
50            num_strings: 0,
51            total_bases: 0,
52        }
53    }
54    
55    /// Add a DNA sequence to the encoder
56    ///
57    /// Bases are packed contiguously into the byte buffer without padding
58    /// between sequences, so decode_kmer can use simple base_idx/4 arithmetic.
59    ///
60    /// # Arguments
61    /// * `sequence` - DNA sequence (A, C, G, T only)
62    ///
63    /// # Errors
64    /// Returns error if sequence contains invalid characters or is too short
65    pub fn add_sequence(&mut self, sequence: &[u8]) -> Result<()> {
66        let seq_len = sequence.len();
67        
68        // Skip sequences too short to contain a k-mer
69        if seq_len < K {
70            return Ok(());
71        }
72        
73        // Pack each base contiguously into the byte buffer
74        for (i, &base) in sequence.iter().enumerate() {
75            let encoded = encoding::encode_base(base).map_err(|_| {
76                anyhow::anyhow!("Invalid base at position {}: {:?}", i, base as char)
77            })?;
78            
79            let base_idx = self.total_bases as usize;
80            let byte_idx = base_idx / 4;
81            let bit_offset = (base_idx % 4) * 2;
82            
83            // Extend buffer if needed
84            if byte_idx >= self.strings.len() {
85                self.strings.push(0);
86            }
87            
88            self.strings[byte_idx] |= encoded << bit_offset;
89            self.total_bases += 1;
90        }
91        
92        // Update offsets (offset is in bases, not bytes)
93        self.offsets.push(self.total_bases);
94        
95        // Count k-mers in this string
96        let kmers_in_string = if seq_len >= K {
97            (seq_len - K + 1) as u64
98        } else {
99            0
100        };
101        
102        self.num_kmers += kmers_in_string;
103        self.num_strings += 1;
104        
105        Ok(())
106    }
107    
108    /// Get the current number of k-mers
109    pub fn num_kmers(&self) -> u64 {
110        self.num_kmers
111    }
112    
113    /// Get the current number of strings
114    pub fn num_strings(&self) -> u64 {
115        self.num_strings
116    }
117    
118    /// Build the final SpectrumPreservingStringSet
119    ///
120    /// Consumes the encoder and returns the SPSS.
121    pub fn build(self, m: usize) -> SpectrumPreservingStringSet {
122        SpectrumPreservingStringSet::from_parts(
123            self.strings,
124            self.offsets,
125            K,
126            m,
127        )
128    }
129}
130
131impl<const K: usize> Default for Encoder<K>
132where
133    Kmer<K>: KmerBits,
134{
135    fn default() -> Self {
136        Self::new()
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh encoder holds no strings and no k-mers.
    #[test]
    fn test_encoder_creation() {
        let encoder = Encoder::<31>::new();
        assert_eq!(encoder.num_kmers(), 0);
        assert_eq!(encoder.num_strings(), 0);
    }

    /// `Default` must behave like `new`.
    #[test]
    fn test_encoder_default_is_empty() {
        let encoder = Encoder::<31>::default();
        assert_eq!(encoder.num_kmers(), 0);
        assert_eq!(encoder.num_strings(), 0);
    }

    #[test]
    fn test_encoder_add_sequence() {
        let mut encoder = Encoder::<7>::new();

        // Add sequence "ACGTACGT" (length 8, contains 2 k=7-mers)
        encoder.add_sequence(b"ACGTACGT").unwrap();

        assert_eq!(encoder.num_strings(), 1);
        assert_eq!(encoder.num_kmers(), 2); // 8 - 7 + 1 = 2
    }

    /// Boundary case: a sequence of exactly K bases holds exactly one k-mer.
    #[test]
    fn test_encoder_exact_length_sequence() {
        let mut encoder = Encoder::<7>::new();

        encoder.add_sequence(b"ACGTACG").unwrap(); // Length 7 == k

        assert_eq!(encoder.num_strings(), 1);
        assert_eq!(encoder.num_kmers(), 1); // 7 - 7 + 1 = 1
    }

    #[test]
    fn test_encoder_skip_short_sequence() {
        let mut encoder = Encoder::<31>::new();

        // Add sequence shorter than k
        encoder.add_sequence(b"ACGT").unwrap(); // Length 4 < 31

        assert_eq!(encoder.num_strings(), 0); // Not counted
        assert_eq!(encoder.num_kmers(), 0);
    }

    /// A non-ACGT base is reported as an error and the sequence is not counted.
    #[test]
    fn test_encoder_rejects_invalid_base() {
        let mut encoder = Encoder::<7>::new();

        assert!(encoder.add_sequence(b"ACGTNCGT").is_err());

        assert_eq!(encoder.num_strings(), 0);
        assert_eq!(encoder.num_kmers(), 0);
    }

    #[test]
    fn test_encoder_multiple_sequences() {
        let mut encoder = Encoder::<5>::new();

        encoder.add_sequence(b"ACGTACGT").unwrap(); // 8 bases, 4 k=5-mers
        encoder.add_sequence(b"TGCA").unwrap(); // 4 bases < 5, skipped
        encoder.add_sequence(b"AAAAAAA").unwrap(); // 7 bases, 3 k=5-mers

        assert_eq!(encoder.num_strings(), 2); // Only 2 sequences >= k
        assert_eq!(encoder.num_kmers(), 7); // 4 + 3 = 7
    }

    #[test]
    fn test_encoder_build_spss() {
        let mut encoder = Encoder::<7>::new();
        encoder.add_sequence(b"ACGTACGT").unwrap();
        encoder.add_sequence(b"TGCATGCA").unwrap();

        let spss = encoder.build(5); // m=5

        assert_eq!(spss.num_strings(), 2);
        assert_eq!(spss.total_bases(), 16); // 8 + 8
    }
}