1use cyanea_core::{CyaneaError, Result};
9
/// DNA sequence stored at two bits per base, four bases per byte.
///
/// Codes: `A = 0b00`, `C = 0b01`, `G = 0b10`, `T = 0b11`. Within each byte the
/// earliest base occupies the two most significant bits, so base `i` lives in
/// `data[i / 4]` at shift `6 - 2 * (i % 4)`.
///
/// Unused slots in a partially filled final byte are kept at zero. The derived
/// `PartialEq`/`Eq` compare the packed bytes directly, so equal sequences must
/// also have equal (zero) padding.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TwoBitSequence {
    /// Packed 2-bit codes, `(len + 3) / 4` bytes long.
    data: Vec<u8>,
    /// Number of bases (not bytes) stored.
    len: usize,
}
25
26#[inline]
28fn encode_base(b: u8) -> Result<u8> {
29 match b {
30 b'A' | b'a' => Ok(0b00),
31 b'C' | b'c' => Ok(0b01),
32 b'G' | b'g' => Ok(0b10),
33 b'T' | b't' => Ok(0b11),
34 _ => Err(CyaneaError::InvalidInput(format!(
35 "invalid DNA base for 2-bit encoding: '{}' (0x{:02X}). Only A, C, G, T are supported",
36 b as char, b
37 ))),
38 }
39}
40
/// Maps a 2-bit code back to its uppercase nucleotide byte.
///
/// Only the two least-significant bits of `bits` are examined; any higher bits
/// are ignored, so the lookup can never go out of range.
#[inline]
fn decode_base(bits: u8) -> u8 {
    const ALPHABET: [u8; 4] = *b"ACGT";
    ALPHABET[usize::from(bits & 0b11)]
}
52
53impl TwoBitSequence {
54 pub fn encode(seq: &[u8]) -> Result<Self> {
72 let len = seq.len();
73 let num_bytes = (len + 3) / 4;
74 let mut data = vec![0u8; num_bytes];
75
76 for (i, &base) in seq.iter().enumerate() {
77 let bits = encode_base(base)?;
78 let byte_idx = i / 4;
79 let bit_offset = 6 - (i % 4) * 2; data[byte_idx] |= bits << bit_offset;
81 }
82
83 Ok(Self { data, len })
84 }
85
86 pub fn decode(&self) -> Vec<u8> {
90 let mut result = Vec::with_capacity(self.len);
91 for i in 0..self.len {
92 let byte_idx = i / 4;
93 let bit_offset = 6 - (i % 4) * 2;
94 let bits = (self.data[byte_idx] >> bit_offset) & 0b11;
95 result.push(decode_base(bits));
96 }
97 result
98 }
99
100 pub fn get(&self, index: usize) -> Option<u8> {
104 if index >= self.len {
105 return None;
106 }
107 let byte_idx = index / 4;
108 let bit_offset = 6 - (index % 4) * 2;
109 let bits = (self.data[byte_idx] >> bit_offset) & 0b11;
110 Some(decode_base(bits))
111 }
112
113 pub fn len(&self) -> usize {
115 self.len
116 }
117
118 pub fn is_empty(&self) -> bool {
120 self.len == 0
121 }
122
123 pub fn kmer(&self, pos: usize, k: usize) -> Option<u64> {
130 if k == 0 || k > 32 || pos + k > self.len {
131 return None;
132 }
133
134 let mut value: u64 = 0;
135 for i in 0..k {
136 let idx = pos + i;
137 let byte_idx = idx / 4;
138 let bit_offset = 6 - (idx % 4) * 2;
139 let bits = ((self.data[byte_idx] >> bit_offset) & 0b11) as u64;
140 value = (value << 2) | bits;
141 }
142
143 Some(value)
144 }
145
146 pub fn complement(&self) -> Self {
154 let mut data = self.data.clone();
155
156 for byte in &mut data {
158 *byte ^= 0xFF;
159 }
160
161 let remainder = self.len % 4;
163 if remainder != 0 && !data.is_empty() {
164 let last = data.len() - 1;
165 let used_bits = remainder * 2;
167 let mask = !0u8 << (8 - used_bits);
168 data[last] &= mask;
169 }
170
171 Self {
172 data,
173 len: self.len,
174 }
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encode_decode_roundtrip() {
        let original = b"ACGTACGT";
        let encoded = TwoBitSequence::encode(original).unwrap();
        assert_eq!(encoded.decode(), original);
    }

    #[test]
    fn encode_decode_all_bases() {
        for base in [b'A', b'C', b'G', b'T'] {
            let original = [base; 4];
            assert_eq!(
                TwoBitSequence::encode(&original).unwrap().decode(),
                original.to_vec()
            );
        }
    }

    #[test]
    fn encode_decode_non_multiple_of_four() {
        // Lengths 1..=3 leave a partially filled final byte; 5 spills into a second byte.
        let cases: [&[u8]; 4] = [b"A", b"CG", b"ACT", b"ACGTA"];
        for original in cases {
            let seq = TwoBitSequence::encode(original).unwrap();
            assert_eq!(seq.len(), original.len());
            assert_eq!(seq.decode(), original);
        }
    }

    #[test]
    fn encode_case_insensitive() {
        let upper = TwoBitSequence::encode(b"ACGT").unwrap();
        let lower = TwoBitSequence::encode(b"acgt").unwrap();
        let mixed = TwoBitSequence::encode(b"aCgT").unwrap();
        assert_eq!(upper, lower);
        assert_eq!(upper, mixed);
        assert_eq!(lower.decode(), b"ACGT");
    }

    #[test]
    fn encode_non_dna_error() {
        assert!(TwoBitSequence::encode(b"ACGN").is_err());
        assert!(TwoBitSequence::encode(b"ACGR").is_err());
        assert!(TwoBitSequence::encode(b"ACGU").is_err());
        assert!(TwoBitSequence::encode(b"XYZ").is_err());
    }

    #[test]
    fn encode_rejects_whitespace_and_non_ascii() {
        assert!(TwoBitSequence::encode(b"AC GT").is_err());
        assert!(TwoBitSequence::encode(b"ACGT\n").is_err());
        assert!(TwoBitSequence::encode(&[b'A', 0xFF]).is_err());
    }

    #[test]
    fn empty_sequence() {
        let seq = TwoBitSequence::encode(b"").unwrap();
        assert_eq!(seq.len(), 0);
        assert!(seq.is_empty());
        assert_eq!(seq.decode(), Vec::<u8>::new());
        assert_eq!(seq.get(0), None);
    }

    #[test]
    fn get_individual_bases() {
        let seq = TwoBitSequence::encode(b"ACGT").unwrap();
        assert_eq!(seq.get(0), Some(b'A'));
        assert_eq!(seq.get(1), Some(b'C'));
        assert_eq!(seq.get(2), Some(b'G'));
        assert_eq!(seq.get(3), Some(b'T'));
        assert_eq!(seq.get(4), None);
    }

    #[test]
    fn get_bases_non_aligned() {
        let seq = TwoBitSequence::encode(b"TAGCAA").unwrap();
        assert_eq!(seq.get(0), Some(b'T'));
        assert_eq!(seq.get(1), Some(b'A'));
        assert_eq!(seq.get(2), Some(b'G'));
        assert_eq!(seq.get(3), Some(b'C'));
        assert_eq!(seq.get(4), Some(b'A'));
        assert_eq!(seq.get(5), Some(b'A'));
        // Indices inside the padding of the last byte, or far past it, are out of range.
        assert_eq!(seq.get(6), None);
        assert_eq!(seq.get(usize::MAX), None);
    }

    #[test]
    fn kmer_extraction() {
        // A=00, C=01, G=10, T=11; the first base is the most significant.
        let seq = TwoBitSequence::encode(b"ACGT").unwrap();

        assert_eq!(seq.kmer(0, 2), Some(0b0001)); // AC
        assert_eq!(seq.kmer(1, 2), Some(0b0110)); // CG
        assert_eq!(seq.kmer(2, 2), Some(0b1011)); // GT

        assert_eq!(seq.kmer(0, 4), Some(0b00011011)); // ACGT
    }

    #[test]
    fn kmer_spanning_byte_boundary() {
        let seq = TwoBitSequence::encode(b"ACGTTGCA").unwrap();
        assert_eq!(seq.kmer(3, 3), Some(0b111110)); // TTG
        assert_eq!(seq.kmer(2, 4), Some(0b10111110)); // GTTG
    }

    #[test]
    fn kmer_edge_cases() {
        let seq = TwoBitSequence::encode(b"ACGT").unwrap();

        assert_eq!(seq.kmer(0, 0), None); // empty k-mer
        assert_eq!(seq.kmer(0, 33), None); // does not fit in a u64
        assert_eq!(seq.kmer(3, 2), None); // runs past the end
        assert_eq!(seq.kmer(4, 1), None); // starts at the end

        assert_eq!(seq.kmer(0, 1), Some(0b00));
        assert_eq!(seq.kmer(3, 1), Some(0b11));
    }

    #[test]
    fn kmer_max_k32() {
        let seq = TwoBitSequence::encode(&[b'A'; 32]).unwrap();
        assert_eq!(seq.kmer(0, 32), Some(0u64));

        let seq = TwoBitSequence::encode(&[b'T'; 32]).unwrap();
        assert_eq!(seq.kmer(0, 32), Some(u64::MAX));
    }

    #[test]
    fn complement_basic() {
        let seq = TwoBitSequence::encode(b"ACGT").unwrap();
        let comp = seq.complement();
        assert_eq!(comp.decode(), b"TGCA");
    }

    #[test]
    fn complement_all_same() {
        let seq = TwoBitSequence::encode(b"AAAA").unwrap();
        assert_eq!(seq.complement().decode(), b"TTTT");

        let seq = TwoBitSequence::encode(b"CCCC").unwrap();
        assert_eq!(seq.complement().decode(), b"GGGG");
    }

    #[test]
    fn complement_non_aligned() {
        let seq = TwoBitSequence::encode(b"ACG").unwrap();
        let comp = seq.complement();
        assert_eq!(comp.decode(), b"TGC");
        assert_eq!(comp.len(), 3);
    }

    #[test]
    fn complement_clears_padding() {
        // Struct equality compares packed bytes, so it only holds if the padding bits
        // flipped by the complement are cleared again.
        assert_eq!(
            TwoBitSequence::encode(b"ACG").unwrap().complement(),
            TwoBitSequence::encode(b"TGC").unwrap()
        );
        assert_eq!(
            TwoBitSequence::encode(b"ACGTA").unwrap().complement(),
            TwoBitSequence::encode(b"TGCAT").unwrap()
        );
    }

    #[test]
    fn complement_involution() {
        let seq = TwoBitSequence::encode(b"ACGTACGTAA").unwrap();
        let double_comp = seq.complement().complement();
        assert_eq!(seq.decode(), double_comp.decode());
        assert_eq!(seq, double_comp);
    }

    #[test]
    fn complement_empty() {
        let seq = TwoBitSequence::encode(b"").unwrap();
        let comp = seq.complement();
        assert!(comp.is_empty());
        assert_eq!(comp.decode(), Vec::<u8>::new());
    }

    #[test]
    fn compact_storage() {
        // Four bases per byte: 8 bases need 2 bytes, 9 bases need 3.
        let seq = TwoBitSequence::encode(b"ACGTACGT").unwrap();
        assert_eq!(seq.data.len(), 2);

        let seq = TwoBitSequence::encode(b"ACGTACGTA").unwrap();
        assert_eq!(seq.data.len(), 3);
    }

    #[test]
    fn long_sequence_roundtrip() {
        let bases = b"ACGT";
        let long: Vec<u8> = (0..1000).map(|i| bases[i % 4]).collect();
        let seq = TwoBitSequence::encode(&long).unwrap();
        assert_eq!(seq.len(), 1000);
        assert_eq!(seq.decode(), long);
    }
}