use thiserror::Error;
/// Errors reported by the DNA encoding helpers in this module.
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum EncodingError {
/// A byte that is not `A`, `C`, `G` or `T` (in either case) was passed to
/// `encode_base`. Carries the rejected byte; because the message uses `{0:?}`
/// on a `u8`, it is rendered as the numeric byte value, not as a character.
#[error("Invalid DNA base: {0:?}")]
InvalidBase(u8),
/// A sequence could not be encoded. Carries a human-readable description;
/// `encode_sequence` fills it with the offending position and character.
#[error("Invalid k-mer string: {0}")]
InvalidKmer(String),
/// A length check failed.
///
/// NOTE(review): not constructed anywhere in the visible part of this file —
/// presumably raised by k-mer code elsewhere; confirm against callers.
#[error("K-mer length mismatch: expected {expected}, got {actual}")]
LengthMismatch {
/// Length that was required.
expected: usize,
/// Length that was actually supplied.
actual: usize,
},
}
/// Converts one ASCII nucleotide into its 2-bit code.
///
/// The mapping is case-insensitive: `A` → `0b00`, `C` → `0b01`,
/// `T` → `0b10`, `G` → `0b11`.
///
/// # Errors
///
/// Returns [`EncodingError::InvalidBase`] carrying the original, unmodified
/// byte when `base` is anything other than `A`, `C`, `G` or `T`.
#[inline]
pub const fn encode_base(base: u8) -> Result<u8, EncodingError> {
    // Fold lowercase onto uppercase so each nucleotide needs only one arm;
    // bytes outside `a..=z` pass through unchanged and fall to the error arm.
    let code = match base.to_ascii_uppercase() {
        b'A' => 0b00,
        b'C' => 0b01,
        b'G' => 0b11,
        b'T' => 0b10,
        _ => return Err(EncodingError::InvalidBase(base)),
    };
    Ok(code)
}
/// Converts a 2-bit code back into its uppercase ASCII nucleotide.
///
/// Only the two least-significant bits of `bits` are examined; any higher
/// bits are ignored, so the function is total over `u8`.
#[inline]
pub const fn decode_base(bits: u8) -> u8 {
    // Indexed by code: 0b00 = A, 0b01 = C, 0b10 = T, 0b11 = G.
    const BASES: [u8; 4] = [b'A', b'C', b'T', b'G'];
    BASES[(bits & 0b11) as usize]
}
/// Returns the 2-bit code of the complementary nucleotide (A↔T, C↔G).
///
/// With the encoding used by `encode_base`, each complementary pair differs
/// only in the high bit of the code, so complementing is a single bit flip.
/// Bits above the low two are passed through unchanged.
#[inline]
pub const fn complement_base(bits: u8) -> u8 {
    /// The bit that distinguishes a base from its complement.
    const COMPLEMENT_MASK: u8 = 0b10;
    bits ^ COMPLEMENT_MASK
}
/// Packs a DNA string into 2-bit codes, four bases per byte.
///
/// This is a convenience wrapper that forwards the string's bytes to
/// `encode_sequence`; see that function for the packing layout.
///
/// # Errors
///
/// Propagates whatever error `encode_sequence` returns for the input bytes.
pub fn encode_string(s: &str) -> Result<Vec<u8>, EncodingError> {
    let bytes: &[u8] = s.as_bytes();
    encode_sequence(bytes)
}
/// Packs a sequence of ASCII nucleotides into 2-bit codes, four per byte.
///
/// Within each output byte the first base occupies bits 0–1, the second
/// bits 2–3, and so on. When the sequence length is not a multiple of four,
/// the unused high bits of the final byte are left as zero. An empty input
/// yields an empty vector.
///
/// # Errors
///
/// Returns [`EncodingError::InvalidKmer`] describing the zero-based position
/// and character of the first byte that `encode_base` rejects.
pub fn encode_sequence(sequence: &[u8]) -> Result<Vec<u8>, EncodingError> {
    const BASES_PER_BYTE: usize = 4;

    let mut packed = Vec::with_capacity(sequence.len().div_ceil(BASES_PER_BYTE));
    for (chunk_idx, chunk) in sequence.chunks(BASES_PER_BYTE).enumerate() {
        let mut byte = 0u8;
        for (slot, &base) in chunk.iter().enumerate() {
            let code = match encode_base(base) {
                Ok(code) => code,
                Err(_) => {
                    // Report the position within the whole sequence, not the chunk.
                    let position = chunk_idx * BASES_PER_BYTE + slot;
                    return Err(EncodingError::InvalidKmer(format!(
                        "Invalid base at position {}: {:?}",
                        position, base as char
                    )));
                }
            };
            // Each base takes two bits, filled from the least-significant end.
            byte |= code << (2 * slot);
        }
        packed.push(byte);
    }
    Ok(packed)
}
/// Decodes up to `length` bases from 2-bit packed `data` into an uppercase
/// DNA string.
///
/// Bases are read from the least-significant bit pair of each byte upwards,
/// which mirrors the layout written by `encode_sequence`.
///
/// If `data` holds fewer than `length` bases (four per byte), decoding stops
/// at the end of `data` and the shorter string is returned; this function
/// does not panic on short input.
pub fn decode_string(data: &[u8], length: usize) -> String {
    const BASES_PER_BYTE: usize = 4;

    // Size the buffer from what can actually be decoded, not from the caller's
    // `length`: an oversized `length` would otherwise request a huge allocation
    // (or panic with a capacity overflow) for output that can never be produced.
    let count = length.min(data.len().saturating_mul(BASES_PER_BYTE));

    let mut result = String::with_capacity(count);
    result.extend(
        data.iter()
            .flat_map(|&byte| {
                // `decode_base` masks to the low two bits, so a plain shift suffices.
                (0..BASES_PER_BYTE).map(move |slot| decode_base(byte >> (2 * slot)) as char)
            })
            .take(count),
    );
    result
}
/// Unit tests for the 2-bit DNA encoding helpers.
#[cfg(test)]
mod tests {
    use super::*;

    /// `ACGT` packed LSB-first: A = 00, C = 01, G = 11, T = 10.
    const ACGT_PACKED: u8 = 0b1011_0100;

    #[test]
    fn test_encode_base() {
        let valid = [
            (b'A', 0b00u8),
            (b'a', 0b00),
            (b'C', 0b01),
            (b'c', 0b01),
            (b'G', 0b11),
            (b'g', 0b11),
            (b'T', 0b10),
            (b't', 0b10),
        ];
        for (base, code) in valid {
            assert_eq!(encode_base(base).unwrap(), code, "base {:?}", base as char);
        }

        for base in [b'N', b'X', b'0'] {
            assert!(encode_base(base).is_err());
            // The error must carry the rejected byte unchanged.
            assert_eq!(encode_base(base), Err(EncodingError::InvalidBase(base)));
        }
    }

    #[test]
    fn test_decode_base() {
        assert_eq!(decode_base(0b00), b'A');
        assert_eq!(decode_base(0b01), b'C');
        assert_eq!(decode_base(0b11), b'G');
        assert_eq!(decode_base(0b10), b'T');
        // Only the low two bits are significant.
        assert_eq!(decode_base(0b1111_1100), b'A');
    }

    #[test]
    fn test_complement_base() {
        assert_eq!(complement_base(0b00), 0b10);
        assert_eq!(complement_base(0b10), 0b00);
        assert_eq!(complement_base(0b01), 0b11);
        assert_eq!(complement_base(0b11), 0b01);
    }

    #[test]
    fn test_complement_base_pairs() {
        // Check the complement at the nucleotide level, not just the bit level.
        for (base, partner) in [(b'A', b'T'), (b'C', b'G'), (b'G', b'C'), (b'T', b'A')] {
            let code = encode_base(base).unwrap();
            assert_eq!(decode_base(complement_base(code)), partner);
        }
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let sequences = ["ACGT", "AAAA", "TTTT", "ACGTACGT", "GATTACA", "gattaca"];
        for seq in sequences {
            let encoded = encode_string(seq).unwrap();
            let decoded = decode_string(&encoded, seq.len());
            assert_eq!(decoded, seq.to_uppercase());
        }
    }

    #[test]
    fn test_encode_packed_layout() {
        // Pin the on-the-wire layout so a self-consistent bit-order change is caught.
        assert_eq!(encode_string("ACGT").unwrap(), vec![ACGT_PACKED]);
        assert_eq!(encode_string("ACGTA").unwrap(), vec![ACGT_PACKED, 0b0000_0000]);
        assert_eq!(
            encode_string("GATTACA").unwrap(),
            vec![0b1010_0011, 0b0000_0100]
        );
    }

    #[test]
    fn test_encode_decode_empty() {
        assert!(encode_string("").unwrap().is_empty());
        assert_eq!(decode_string(&[], 0), "");
    }

    #[test]
    fn test_encode_mixed_case() {
        let lower = encode_string("acgt").unwrap();
        let upper = encode_string("ACGT").unwrap();
        assert_eq!(lower, upper);
    }

    #[test]
    fn test_encode_invalid() {
        assert!(encode_string("ACGTN").is_err());
        assert!(encode_string("ACGT X").is_err());
        // The error identifies the first offending position and character.
        assert_eq!(
            encode_string("ACGTN"),
            Err(EncodingError::InvalidKmer(
                "Invalid base at position 4: 'N'".to_string()
            ))
        );
    }

    #[test]
    fn test_decode_short_data() {
        // Asking for more bases than `data` holds yields only what is available.
        assert_eq!(decode_string(&[ACGT_PACKED], 10), "ACGT");
        assert_eq!(decode_string(&[ACGT_PACKED], 2), "AC");
    }
}