Skip to main content

fgumi_dna/
bitenc.rs

//! A 2-bit DNA encoding for fast UMI comparison.
//!
//! This module provides efficient storage and comparison of DNA sequences
//! using 2 bits per base. This is particularly useful for UMI correction,
//! where each observed UMI must be compared against thousands of expected UMIs.
//!
//! # Example
//!
//! ```
//! use fgumi_dna::bitenc::BitEnc;
//!
//! let umi1 = BitEnc::from_bytes(b"ACGT").unwrap();
//! let umi2 = BitEnc::from_bytes(b"ACTT").unwrap();
//! assert_eq!(umi1.hamming_distance(&umi2), 1);
//! ```
16
/// A DNA sequence packed two bits per base into a single `u64`.
///
/// Holds at most 32 bases (64 bits). Base codes are A=0, C=1, G=2, T=3, and
/// the base at index `i` occupies bits `2 * i` and `2 * i + 1`.
///
/// Equality and hashing cover both the packed bits and the length, so `"A"`
/// and `"AA"` are distinct values even though both pack to zero.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct BitEnc {
    /// The encoded sequence, with bases packed from LSB.
    bits: u64,
    /// Number of bases in the sequence.
    len: u8,
}

impl BitEnc {
    /// Largest number of bases that fit in the 64-bit payload at 2 bits each.
    const MAX_BASES: usize = 32;

    /// Map one ASCII nucleotide (upper or lower case) to its 2-bit code.
    ///
    /// Returns `None` for every byte other than `A`, `C`, `G`, `T` and their
    /// lowercase forms.
    #[inline]
    const fn encode_base(base: u8) -> Option<u64> {
        match base {
            b'A' | b'a' => Some(0),
            b'C' | b'c' => Some(1),
            b'G' | b'g' => Some(2),
            b'T' | b't' => Some(3),
            _ => None,
        }
    }

    /// Create a `BitEnc` from a byte slice of ASCII bases.
    ///
    /// Returns `None` if the slice is longer than 32 bases or contains any
    /// byte that is not `A`, `C`, `G` or `T` (case-insensitive). An empty
    /// slice yields an empty sequence.
    #[inline]
    #[must_use]
    pub fn from_bytes(seq: &[u8]) -> Option<Self> {
        if seq.len() > Self::MAX_BASES {
            return None;
        }

        let mut bits: u64 = 0;
        for (i, &base) in seq.iter().enumerate() {
            // `i < 32`, so the shift is at most 62 and cannot overflow.
            bits |= Self::encode_base(base)? << (i * 2);
        }

        #[expect(clippy::cast_possible_truncation, reason = "guarded by seq.len() <= 32")]
        let len = seq.len() as u8;
        Some(Self { bits, len })
    }

    /// Create a `BitEnc` from a UMI string, skipping dash separators.
    ///
    /// This is useful for paired UMIs like "ACGT-TGCA": every `-` is ignored
    /// and the remaining bases are packed contiguously, so the result equals
    /// `from_bytes(b"ACGTTGCA")`.
    ///
    /// Returns `None` if the string contains any character other than
    /// `A`/`C`/`G`/`T` (case-insensitive) or `-`, or if it holds more than
    /// 32 bases (dashes do not count towards the limit).
    #[inline]
    #[must_use]
    pub fn from_umi_str(umi: &str) -> Option<Self> {
        let mut bits: u64 = 0;
        let mut base_count: usize = 0;

        for &byte in umi.as_bytes() {
            match Self::encode_base(byte) {
                Some(encoded) => {
                    if base_count >= Self::MAX_BASES {
                        return None;
                    }
                    bits |= encoded << (base_count * 2);
                    base_count += 1;
                }
                // A dash only separates the halves of a paired UMI; skip it.
                None if byte == b'-' => {}
                // Anything else (e.g. `N`) makes the UMI unencodable.
                None => return None,
            }
        }

        #[expect(clippy::cast_possible_truncation, reason = "guarded by base_count <= 32")]
        let len = base_count as u8;
        Some(Self { bits, len })
    }

    /// Get the number of bases in this sequence.
    #[inline]
    #[must_use]
    pub const fn len(&self) -> usize {
        self.len as usize
    }

    /// Check if the sequence is empty.
    #[inline]
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Compute the Hamming distance (number of differing bases) between two
    /// encoded sequences.
    ///
    /// # Panics
    ///
    /// Both sequences must have the same length. This is checked with a debug
    /// assertion only; builds without debug assertions perform no length check.
    #[inline]
    #[must_use]
    pub fn hamming_distance(&self, other: &Self) -> u32 {
        debug_assert_eq!(self.len, other.len, "Sequences must have equal length");

        // XOR leaves a non-zero 2-bit group exactly where the bases differ.
        let diff = self.bits ^ other.bits;

        // Fold each group's high bit onto its low bit, then keep only the low
        // bits: one set bit per differing base, regardless of whether one or
        // both bits of the group differed.
        let differs = (diff | (diff >> 1)) & 0x5555_5555_5555_5555;

        differs.count_ones()
    }

    /// Get the 2-bit encoded base at the given position.
    ///
    /// Returns a value in 0..4 representing A, C, G, T respectively.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions, panics if `pos` is not less than the
    /// sequence length.
    #[inline]
    #[must_use]
    pub fn base_at(&self, pos: usize) -> u8 {
        debug_assert!(pos < self.len as usize, "Position out of bounds");
        #[allow(clippy::cast_possible_truncation)] // masked to 2 bits
        let base = ((self.bits >> (pos * 2)) & 0b11) as u8;
        base
    }

    /// Return a copy of this sequence with a different base at the given position.
    ///
    /// `base` must be in 0..4 (A=0, C=1, G=2, T=3). The length is unchanged.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions, panics if `pos` is not less than the
    /// sequence length or if `base` is 4 or greater.
    #[inline]
    #[must_use]
    pub fn with_base_at(&self, pos: usize, base: u8) -> Self {
        debug_assert!(pos < self.len as usize, "Position out of bounds");
        debug_assert!(base < 4, "Invalid base value");
        let bit_pos = pos * 2;
        // Clear the two bits of the target base, then OR in the new code.
        let cleared = self.bits & !(0b11u64 << bit_pos);
        Self { bits: cleared | (u64::from(base) << bit_pos), len: self.len }
    }

    /// Extract bits for bases `[start_base, start_base + num_bases)` as a u32.
    ///
    /// Each base is 2 bits, so this can extract up to 16 bases into a u32.
    /// The base at `start_base` ends up in the two lowest bits of the result.
    /// An empty window (`num_bases == 0`) yields 0.
    /// Used for N-gram partitioning in similarity search.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions, panics if `num_bases > 16` or if the
    /// requested range extends past the end of the sequence.
    #[inline]
    #[must_use]
    pub fn extract_bits(&self, start_base: usize, num_bases: usize) -> u32 {
        debug_assert!(num_bases <= 16, "Can only extract up to 16 bases into u32");
        debug_assert!(
            start_base + num_bases <= self.len as usize,
            "Extraction range exceeds sequence length"
        );
        // An empty window is always zero. Returning early also keeps the shift
        // below in range: on a full 32-base sequence `start_base == 32` is a
        // valid start for an empty window, and `u64 >> 64` would panic in
        // builds with overflow checks.
        if num_bases == 0 {
            return 0;
        }
        let start_bit = start_base * 2;
        let mask = (1u64 << (num_bases * 2)) - 1;
        #[expect(clippy::cast_possible_truncation, reason = "masked to at most 32 bits")]
        {
            ((self.bits >> start_bit) & mask) as u32
        }
    }
}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode a byte string that the test knows to be valid.
    fn encode(seq: &[u8]) -> BitEnc {
        BitEnc::from_bytes(seq).unwrap()
    }

    /// Encode a (possibly dash-separated) UMI string that the test knows to be valid.
    fn encode_umi(umi: &str) -> BitEnc {
        BitEnc::from_umi_str(umi).unwrap()
    }

    #[test]
    fn test_from_bytes() {
        // A valid sequence packs from the LSB: A=0 lowest, then C=1, G=2, T=3.
        let acgt = encode(b"ACGT");
        assert_eq!(acgt.len(), 4);
        assert_eq!(acgt.bits, 0b11_10_01_00);

        // Lowercase input is accepted.
        assert!(BitEnc::from_bytes(b"acgt").is_some());

        // Any base outside ACGT is rejected.
        assert!(BitEnc::from_bytes(b"ACGN").is_none());

        // 33 bases is one too many; exactly 32 still fits.
        assert!(BitEnc::from_bytes(&[b'A'; 33]).is_none());
        assert!(BitEnc::from_bytes(&[b'A'; 32]).is_some());
    }

    #[test]
    fn test_hamming_distance() {
        let reference = encode(b"ACGTACGT");

        // Same sequence: distance 0.
        let same = encode(b"ACGTACGT");
        assert_eq!(reference.hamming_distance(&same), 0);

        // A single substituted base: distance 1.
        let one_off = encode(b"ACGTACTT");
        assert_eq!(reference.hamming_distance(&one_off), 1);

        // Every position differs.
        let poly_a = encode(b"AAAA");
        let poly_t = encode(b"TTTT");
        assert_eq!(poly_a.hamming_distance(&poly_t), 4);

        // A typical 18 bp UMI pair differing only in the last base.
        let left = encode(b"AACAACACATCTACCTTC");
        let right = encode(b"AACAACACATCTACCTTA");
        assert_eq!(left.hamming_distance(&right), 1);
    }

    #[test]
    fn test_extract_bits() {
        let repeated = encode(b"ACGTACGT");

        // Both halves are ACGT, so their 4-base windows match.
        assert_eq!(repeated.extract_bits(0, 4), repeated.extract_bits(4, 4));

        // One-base windows give the plain base codes A=0, C=1, G=2, T=3.
        for (pos, code) in (0..4usize).zip(0..4u32) {
            assert_eq!(repeated.extract_bits(pos, 1), code);
        }

        // A window straddling the middle (GTAC).
        assert_eq!(repeated.extract_bits(2, 4), 0b01_00_11_10);
    }

    #[test]
    fn test_from_umi_str() {
        // Without a dash the result matches `from_bytes`.
        let plain = encode_umi("ACGT");
        assert_eq!(plain.len(), 4);
        assert_eq!(plain, encode(b"ACGT"));

        // The dash of a paired UMI is dropped.
        let paired = encode_umi("ACGT-TGCA");
        assert_eq!(paired.len(), 8);
        assert_eq!(paired, encode(b"ACGTTGCA"));

        // A real paired UMI from test data.
        let real = encode_umi("GTCTGAGATC-AATCTTTAAT");
        assert_eq!(real.len(), 20);

        // Lowercase encodes the same as uppercase.
        let lower = encode_umi("acgt-tgca");
        assert_eq!(lower, paired);

        // A character that is neither a base nor a dash is rejected.
        assert!(BitEnc::from_umi_str("ACGT-NGCA").is_none());

        // More than one dash is fine.
        let multi = encode_umi("AC-GT-TG");
        assert_eq!(multi.len(), 6);
        assert_eq!(multi, encode(b"ACGTTG"));
    }

    #[rstest::rstest]
    #[case(0, 0)] // A
    #[case(1, 1)] // C
    #[case(2, 2)] // G
    #[case(3, 3)] // T
    fn test_base_at(#[case] pos: usize, #[case] expected: u8) {
        let acgt = encode(b"ACGT");
        assert_eq!(acgt.base_at(pos), expected);
    }

    #[test]
    fn test_with_base_at() {
        let poly_a = encode(b"AAAA");

        // Position 1 becomes C.
        let with_c = poly_a.with_base_at(1, 1);
        assert_eq!(with_c, encode(b"ACAA"));

        // Position 3 becomes T.
        let with_t = poly_a.with_base_at(3, 3);
        assert_eq!(with_t, encode(b"AAAT"));

        // Writing back the base already present leaves the value unchanged.
        let acgt = encode(b"ACGT");
        for pos in 0..4 {
            let current = acgt.base_at(pos);
            assert_eq!(acgt.with_base_at(pos, current), acgt);
        }
    }

    #[test]
    fn test_paired_umi_hamming() {
        // Paired UMIs that differ in the final base only (T->C).
        let first = encode_umi("GTCTGAGATC-AATCTTTAAT");
        let second = encode_umi("GTCTGAGATC-AATCTTTAAC");
        assert_eq!(first.hamming_distance(&second), 1);

        // Identical paired UMIs.
        let third = encode_umi("AAAGCGATGC-CCAGTTAACC");
        let fourth = encode_umi("AAAGCGATGC-CCAGTTAACC");
        assert_eq!(third.hamming_distance(&fourth), 0);
    }
}