use bytes::Bytes;
use crate::error::KmerLengthError;
/// Lookup table from an ASCII byte to its 2-bit nucleotide code.
///
/// `A`/`a` -> 0, `C`/`c` -> 1, `G`/`g` -> 2, `T`/`t` -> 3. Every other byte
/// maps to 0, which is indistinguishable from `A`, so input should be
/// validated before it is packed (as `from_sub` and `validate_and_pack` do).
const PACK_TABLE: [u64; 256] = {
    let mut table = [0u64; 256];
    let mut code = 0;
    // Derive both cases of every base from the decode table so the two
    // tables cannot drift apart.
    while code < UNPACK_TABLE.len() {
        let upper = UNPACK_TABLE[code];
        table[upper as usize] = code as u64;
        table[upper.to_ascii_lowercase() as usize] = code as u64;
        code += 1;
    }
    table
};
/// Inverse of `PACK_TABLE`: maps a 2-bit code to its uppercase ASCII base.
const UNPACK_TABLE: [u8; 4] = *b"ACGT";
/// A k-mer length validated to lie in `KmerLength::MIN..=KmerLength::MAX`.
///
/// 32 bases at 2 bits per base exactly fill the `u64` used for packed k-mers
/// in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct KmerLength(u8);

impl KmerLength {
    /// Smallest accepted k-mer length.
    pub const MIN: u8 = 1;
    /// Largest accepted k-mer length.
    pub const MAX: u8 = 32;

    /// Validates `k` and wraps it in a `KmerLength`.
    ///
    /// # Errors
    ///
    /// Returns a [`KmerLengthError`] carrying `k` together with [`Self::MIN`]
    /// and [`Self::MAX`] when `k` lies outside that inclusive range.
    // The `as u8` cast below cannot truncate: `k <= MAX` (32) is checked first.
    #[allow(clippy::cast_possible_truncation)]
    pub const fn new(k: usize) -> Result<Self, KmerLengthError> {
        if k < Self::MIN as usize || k > Self::MAX as usize {
            return Err(KmerLengthError {
                k,
                min: Self::MIN,
                max: Self::MAX,
            });
        }
        Ok(Self(k as u8))
    }

    /// Wraps `k` without the range check performed by [`Self::new`].
    ///
    /// # Safety
    ///
    /// The caller must guarantee `KmerLength::MIN <= k <= KmerLength::MAX`.
    /// Every consumer of a `KmerLength` (such as the packing and unpacking
    /// helpers) assumes this invariant holds. Debug builds assert it.
    #[inline]
    #[allow(unsafe_code)]
    pub const unsafe fn new_unchecked(k: u8) -> Self {
        debug_assert!(
            k >= Self::MIN && k <= Self::MAX,
            "KmerLength::new_unchecked called with out-of-range k"
        );
        Self(k)
    }

    /// Returns the length as a `usize`, convenient for indexing and slicing.
    #[inline]
    pub const fn get(self) -> usize {
        self.0 as usize
    }

    /// Returns the length in its compact `u8` storage form.
    #[inline]
    pub const fn as_u8(self) -> u8 {
        self.0
    }
}

impl TryFrom<usize> for KmerLength {
    type Error = KmerLengthError;

    /// Equivalent to [`KmerLength::new`].
    fn try_from(k: usize) -> Result<Self, Self::Error> {
        Self::new(k)
    }
}

impl From<KmerLength> for usize {
    /// Equivalent to [`KmerLength::get`].
    fn from(k: KmerLength) -> Self {
        k.get()
    }
}
use std::marker::PhantomData;
use crate::error::InvalidBaseError;
/// Type-state marker: the k-mer holds validated, uppercased ASCII bases only.
#[derive(Clone, Copy, Debug, Default)]
pub struct Unpacked;

/// Type-state marker: the k-mer also carries its 2-bit packed encoding.
#[derive(Clone, Copy, Debug, Default)]
pub struct Packed;

/// Type-state marker: the k-mer has been reduced to its canonical orientation.
#[derive(Clone, Copy, Debug, Default)]
pub struct Canonical;
/// A DNA k-mer whose processing stage is tracked at the type level by `S`
/// (one of `Unpacked`, `Packed`, `Canonical`).
#[derive(Debug, Clone)]
pub struct Kmer<S = Unpacked> {
    /// ASCII bases of the k-mer; `from_sub` stores them uppercased.
    bytes: Bytes,
    /// 2-bit packed encoding of `bytes` (`A=0, C=1, G=2, T=3`). It stays `0`
    /// until `pack` computes it.
    packed_bits: u64,
    /// Set to `true` only by `Kmer::<Packed>::canonical` when the canonical
    /// form is the reverse complement of the input; `false` otherwise.
    is_reverse_complement: bool,
    /// Zero-sized marker carrying the type-state `S`.
    _state: PhantomData<S>,
}
impl Default for Kmer<Unpacked> {
    /// An empty unpacked k-mer: no bases, zero packed bits, forward strand.
    fn default() -> Self {
        let bytes = Bytes::new();
        Self {
            _state: PhantomData,
            is_reverse_complement: false,
            packed_bits: 0,
            bytes,
        }
    }
}
impl<S> Kmer<S> {
    /// Borrows the k-mer's ASCII bases, available in every type-state.
    #[inline]
    pub const fn bytes(&self) -> &Bytes {
        let Self { bytes, .. } = self;
        bytes
    }
}
impl Kmer<Unpacked> {
    /// Builds an unpacked k-mer from raw sequence bytes and uppercases
    /// soft-masked (lowercase) bases.
    ///
    /// The length is not checked here: empty input is accepted, and nothing
    /// limits it to 32 bases.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidBaseError`] for the first byte that is not one of
    /// `ACGTacgt`, with its zero-based `position` in `sub`.
    // Only a borrow of `sub` is used; it stays by-value to keep the public
    // signature unchanged.
    #[allow(clippy::needless_pass_by_value)]
    pub fn from_sub(sub: Bytes) -> Result<Self, InvalidBaseError> {
        let mut upper = Vec::with_capacity(sub.len());
        for (position, &base) in sub.iter().enumerate() {
            let normalized = match base {
                b'A' | b'C' | b'G' | b'T' => base,
                b'a' | b'c' | b'g' | b't' => base.to_ascii_uppercase(),
                _ => return Err(InvalidBaseError { base, position }),
            };
            upper.push(normalized);
        }
        Ok(Self {
            bytes: Bytes::from(upper),
            packed_bits: 0,
            is_reverse_complement: false,
            _state: PhantomData,
        })
    }

    /// Computes the 2-bit packed encoding and advances to the `Packed` state.
    ///
    /// Bits for earlier bases are shifted out of the `u64`, so for k-mers
    /// longer than 32 bases only the last 32 bases are represented.
    pub fn pack(self) -> Kmer<Packed> {
        let Self { bytes, .. } = self;
        let packed_bits = pack_bytes(&bytes);
        Kmer {
            bytes,
            packed_bits,
            is_reverse_complement: false,
            _state: PhantomData,
        }
    }
}
impl Kmer<Packed> {
    /// Returns the 2-bit packed encoding (`A=0, C=1, G=2, T=3`), with the
    /// first base in the most significant occupied bits.
    #[inline]
    pub const fn packed_bits(&self) -> u64 {
        self.packed_bits
    }

    /// Converts this k-mer into canonical form: whichever of the forward
    /// sequence and its reverse complement has the smaller packed value.
    ///
    /// Ties (reverse-complement palindromes) keep the forward orientation,
    /// which agrees with `canonical_bits`. An empty k-mer is returned
    /// unchanged in its forward orientation.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if the k-mer is longer than 32 bases (via the
    /// `debug_assert!` in `reverse_complement_bits`).
    pub fn canonical(self) -> Kmer<Canonical> {
        let k = self.bytes.len();
        // `reverse_complement_bits` requires `k >= 1`. An empty k-mer (which
        // `from_sub` accepts) must not reach it; it is its own reverse
        // complement anyway.
        let smaller_rc = if k == 0 {
            None
        } else {
            let rc_bits = reverse_complement_bits(self.packed_bits, k);
            // Strict `<`: palindromes stay on the forward strand.
            if rc_bits < self.packed_bits {
                Some(rc_bits)
            } else {
                None
            }
        };
        match smaller_rc {
            Some(rc_bits) => Kmer {
                bytes: unpack_to_bytes_from_raw(rc_bits, k),
                packed_bits: rc_bits,
                is_reverse_complement: true,
                _state: PhantomData,
            },
            None => Kmer {
                bytes: self.bytes,
                packed_bits: self.packed_bits,
                is_reverse_complement: false,
                _state: PhantomData,
            },
        }
    }
}
impl Kmer<Canonical> {
    /// The packed encoding of the canonical orientation.
    #[inline]
    pub const fn packed_bits(&self) -> u64 {
        let Self { packed_bits, .. } = self;
        *packed_bits
    }

    /// Whether the canonical orientation is the reverse complement of the
    /// sequence this k-mer was built from.
    #[inline]
    pub const fn is_reverse_complement(&self) -> bool {
        let Self {
            is_reverse_complement,
            ..
        } = self;
        *is_reverse_complement
    }
}
pub fn unpack_to_bytes(packed_bits: u64, k: KmerLength) -> Bytes {
let k = k.get();
(0..k)
.map(|i| {
let shift = (k - 1 - i) * 2;
let bits = ((packed_bits >> shift) & 0b11) as usize;
UNPACK_TABLE[bits]
})
.collect()
}
/// Decodes `k` packed bases into a `String` of uppercase `A`/`C`/`G`/`T`.
///
/// Every byte produced by `unpack_to_bytes` is ASCII, so each one maps to
/// a single one-byte `char`. This avoids the previous `unsafe` UTF-8
/// assumption and the extra buffer copy.
pub fn unpack_to_string(packed_bits: u64, k: KmerLength) -> String {
    unpack_to_bytes(packed_bits, k)
        .iter()
        .map(|&base| char::from(base))
        .collect()
}
/// Packs ASCII bases into a `u64`, two bits per base, with the first base in
/// the most significant occupied position.
///
/// Performs no validation: bytes outside `ACGTacgt` encode as `0` because
/// `PACK_TABLE` zero-fills them. For inputs longer than 32 bases, earlier bits
/// are shifted out and only the last 32 bases remain.
#[inline]
fn pack_bytes(bytes: &[u8]) -> u64 {
    let mut packed = 0u64;
    for &base in bytes {
        packed = (packed << 2) | PACK_TABLE[usize::from(base)];
    }
    packed
}
/// Decodes the low `2 * k` bits of `packed_bits` into `k` uppercase ASCII
/// bases, most significant base first.
///
/// Unlike `unpack_to_bytes`, `k` is an unchecked `usize`. Values above 32
/// overflow the shift.
#[inline]
fn unpack_to_bytes_from_raw(packed_bits: u64, k: usize) -> Bytes {
    // Walk base slots from the highest (first base) down to slot 0.
    (0..k)
        .rev()
        .map(|slot| UNPACK_TABLE[((packed_bits >> (slot * 2)) & 0b11) as usize])
        .collect()
}
/// Computes the reverse complement of a 2-bit packed k-mer of length `k`.
///
/// With the encoding `A=00, C=01, G=10, T=11`, complementing a base is a
/// bitwise NOT of its two bits. Reversing the base order means reversing the
/// order of the 2-bit groups in the word. The result is then shifted down so
/// the `k` bases occupy the low `2 * k` bits again.
///
/// `k` must be in `1..=32`; this is checked only in debug builds.
#[inline]
pub const fn reverse_complement_bits(bits: u64, k: usize) -> u64 {
    debug_assert!(k >= 1 && k <= 32, "k must be 1..=32");
    let width = 2 * k;
    // Keep only the occupied low bits. `1 << 64` would overflow, so the full
    // 32-base case is special-cased.
    let occupied = if k == 32 { u64::MAX } else { (1u64 << width) - 1 };
    let complement = !bits & occupied;
    // Swap neighbouring 2-bit groups, then neighbouring nibbles: this reverses
    // the four bases stored inside every byte.
    let pairs = ((complement >> 2) & 0x3333_3333_3333_3333)
        | ((complement & 0x3333_3333_3333_3333) << 2);
    let nibbles =
        ((pairs >> 4) & 0x0F0F_0F0F_0F0F_0F0F) | ((pairs & 0x0F0F_0F0F_0F0F_0F0F) << 4);
    // Reversing byte order completes the reversal of all 32 base slots.
    nibbles.swap_bytes() >> (64 - width)
}
/// Returns the canonical packed value of a k-mer: the smaller of `packed` and
/// its reverse complement. On a tie the forward value `packed` is returned.
#[inline]
pub const fn canonical_bits(packed: u64, k: usize) -> u64 {
    let reverse = reverse_complement_bits(packed, k);
    if reverse < packed {
        reverse
    } else {
        packed
    }
}
/// Validates `seq` as `ACGTacgt` bases and packs it two bits per base, with
/// the first base in the most significant occupied position.
///
/// `seq` must be non-empty and at most 32 bases long. Both conditions are
/// asserted only in debug builds.
///
/// # Errors
///
/// Returns [`InvalidBaseError`] for the first invalid byte, with its
/// zero-based `position`.
#[inline]
pub fn validate_and_pack(seq: &[u8]) -> Result<u64, InvalidBaseError> {
    debug_assert!(
        !seq.is_empty(),
        "validate_and_pack requires a non-empty slice"
    );
    debug_assert!(
        seq.len() <= 32,
        "validate_and_pack: sequence length {len} exceeds maximum k-mer size 32",
        len = seq.len()
    );
    seq.iter()
        .enumerate()
        .try_fold(0u64, |packed, (position, &base)| match base {
            b'A' | b'a' | b'C' | b'c' | b'G' | b'g' | b'T' | b't' => {
                Ok((packed << 2) | PACK_TABLE[usize::from(base)])
            }
            _ => Err(InvalidBaseError { base, position }),
        })
}
/// Validates, packs and canonicalises `seq` in a single pass, without going
/// through the `Kmer` type-state pipeline.
///
/// # Errors
///
/// Propagates [`InvalidBaseError`] from `validate_and_pack`.
#[inline]
pub fn pack_canonical(seq: &[u8]) -> Result<u64, InvalidBaseError> {
    validate_and_pack(seq).map(|forward| canonical_bits(forward, seq.len()))
}
/// A single nucleotide from the `A`/`C`/`G`/`T` alphabet.
///
/// Variants are declared in the same order as their 2-bit codes
/// (`A = 0`, `C = 1`, `G = 2`, `T = 3`), as the `u64` conversions below show.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KmerByte {
    A,
    C,
    G,
    T,
}

impl From<&u8> for KmerByte {
    /// Converts an ASCII nucleotide of either case into a [`KmerByte`].
    ///
    /// # Panics
    ///
    /// Panics if `val` is not one of `ACGTacgt`. Use
    /// `KmerByte::try_from_byte` for untrusted input.
    fn from(val: &u8) -> Self {
        debug_assert!(
            matches!(val, b'A' | b'a' | b'C' | b'c' | b'G' | b'g' | b'T' | b't'),
            "KmerByte::from called with invalid base: {val:#x}"
        );
        match val.to_ascii_uppercase() {
            b'A' => Self::A,
            b'C' => Self::C,
            b'G' => Self::G,
            b'T' => Self::T,
            _ => unreachable!("invalid base passed to KmerByte::from"),
        }
    }
}

impl From<KmerByte> for u8 {
    /// Converts a [`KmerByte`] to its uppercase ASCII letter.
    fn from(val: KmerByte) -> Self {
        match val {
            KmerByte::A => b'A',
            KmerByte::C => b'C',
            KmerByte::G => b'G',
            KmerByte::T => b'T',
        }
    }
}

impl From<u64> for KmerByte {
    /// Converts a 2-bit code (`0..=3`) into a [`KmerByte`].
    ///
    /// # Panics
    ///
    /// Panics if `val > 3`. Use `KmerByte::try_from_bits` for untrusted
    /// input.
    fn from(val: u64) -> Self {
        debug_assert!(
            val <= 3,
            "KmerByte::from called with invalid 2-bit value: {val}"
        );
        match val {
            0 => Self::A,
            1 => Self::C,
            2 => Self::G,
            3 => Self::T,
            _ => unreachable!("invalid 2-bit value passed to KmerByte::from"),
        }
    }
}

impl From<KmerByte> for u64 {
    /// Converts a [`KmerByte`] to its 2-bit code.
    fn from(val: KmerByte) -> Self {
        match val {
            KmerByte::A => 0,
            KmerByte::C => 1,
            KmerByte::G => 2,
            KmerByte::T => 3,
        }
    }
}
impl KmerByte {
    /// Fallible, `const` counterpart of `From<&u8>`; accepts either case.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidBaseError`] with `position: 0` when `val` is not one
    /// of `ACGTacgt`.
    pub const fn try_from_byte(val: u8) -> Result<Self, InvalidBaseError> {
        let nucleotide = match val.to_ascii_uppercase() {
            b'A' => Self::A,
            b'C' => Self::C,
            b'G' => Self::G,
            b'T' => Self::T,
            _ => {
                return Err(InvalidBaseError {
                    base: val,
                    position: 0,
                })
            }
        };
        Ok(nucleotide)
    }

    /// Fallible, `const` counterpart of `From<u64>`.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidBaseError`] with `position: 0` when `val > 3`. The
    /// reported `base` is `val` truncated to its low 8 bits.
    // Truncation is intentional: `base` is a `u8` and only the error path casts.
    #[allow(clippy::cast_possible_truncation)]
    pub const fn try_from_bits(val: u64) -> Result<Self, InvalidBaseError> {
        if val > 3 {
            return Err(InvalidBaseError {
                base: val as u8,
                position: 0,
            });
        }
        Ok(match val {
            0 => Self::A,
            1 => Self::C,
            2 => Self::G,
            _ => Self::T,
        })
    }

    /// Returns the Watson–Crick complement (`A`<->`T`, `C`<->`G`).
    #[must_use]
    pub const fn reverse_complement(self) -> Self {
        match self {
            Self::T => Self::A,
            Self::G => Self::C,
            Self::C => Self::G,
            Self::A => Self::T,
        }
    }
}
// Unit tests for k-mer parsing, packing/unpacking, canonicalisation, and the
// `KmerByte` / `KmerLength` helpers.
#[cfg(test)]
// Tests unwrap inputs that are known to be valid; a panic is the intended
// failure signal.
#[allow(clippy::unwrap_used)]
pub mod test {
    use super::*;
    // An already-uppercase valid sequence is stored unchanged.
    #[test]
    fn bytes_from_valid_substring() {
        let sub = b"GATTACA";
        let kmer = Kmer::from_sub(Bytes::copy_from_slice(sub)).unwrap();
        insta::assert_snapshot!(format!("{:?}", kmer.bytes()), @r#"b"GATTACA""#);
    }
    // A non-ACGT byte (here `N`) is rejected.
    #[test]
    fn from_substring_returns_err_for_invalid_substring() {
        let sub = b"N";
        let result = Kmer::from_sub(Bytes::copy_from_slice(sub));
        assert!(result.is_err());
    }
    // The error reports the FIRST invalid byte and its zero-based position.
    #[test]
    fn from_sub_finds_invalid_byte_position() {
        let test_cases = [
            ("NACNN", 0, b'N'),
            ("ANCNG", 1, b'N'),
            ("AANTG", 2, b'N'),
            ("CCCNG", 3, b'N'),
            ("AACTN", 4, b'N'),
        ];
        for (dna, expected_pos, expected_base) in test_cases {
            let res = Kmer::from_sub(Bytes::copy_from_slice(dna.as_bytes()));
            let err = res.unwrap_err();
            assert_eq!(err.position, expected_pos, "for sequence {dna}");
            assert_eq!(err.base, expected_base, "for sequence {dna}");
        }
    }
    // pack() followed by unpack_to_bytes() reproduces the input sequence.
    #[test]
    fn pack_unpack_roundtrip() {
        let sequences = ["ACGT", "AAAA", "TTTT", "CCCC", "GGGG", "GATTACA"];
        for seq in sequences {
            let kmer = Kmer::from_sub(Bytes::copy_from_slice(seq.as_bytes())).unwrap();
            let packed = kmer.pack();
            let k = KmerLength::new(seq.len()).unwrap();
            let unpacked = unpack_to_bytes(packed.packed_bits(), k);
            assert_eq!(unpacked.as_ref(), seq.as_bytes());
        }
    }
    // Round-trip at every supported length, including the k = 32 boundary.
    #[test]
    fn pack_unpack_roundtrip_various_lengths() {
        for k_val in 1..=32 {
            let seq = "A".repeat(k_val);
            let kmer = Kmer::from_sub(Bytes::copy_from_slice(seq.as_bytes())).unwrap();
            let packed = kmer.pack();
            let k = KmerLength::new(k_val).unwrap();
            let unpacked = unpack_to_bytes(packed.packed_bits(), k);
            assert_eq!(unpacked.as_ref(), seq.as_bytes());
        }
    }
    // canonical() keeps the forward strand when it is smaller or tied, and
    // otherwise switches to the reverse complement and sets the flag.
    #[test]
    fn canonical_selects_lexicographically_smaller() {
        // ACGT is its own reverse complement: forward is kept.
        let kmer = Kmer::from_sub(Bytes::copy_from_slice(b"ACGT"))
            .unwrap()
            .pack()
            .canonical();
        assert_eq!(kmer.bytes().as_ref(), b"ACGT");
        assert!(!kmer.is_reverse_complement());
        let kmer = Kmer::from_sub(Bytes::copy_from_slice(b"AAA"))
            .unwrap()
            .pack()
            .canonical();
        assert_eq!(kmer.bytes().as_ref(), b"AAA");
        assert!(!kmer.is_reverse_complement());
        // TTT's reverse complement AAA is smaller.
        let kmer = Kmer::from_sub(Bytes::copy_from_slice(b"TTT"))
            .unwrap()
            .pack()
            .canonical();
        assert_eq!(kmer.bytes().as_ref(), b"AAA");
        assert!(kmer.is_reverse_complement());
        let kmer = Kmer::from_sub(Bytes::copy_from_slice(b"GATTACA"))
            .unwrap()
            .pack()
            .canonical();
        assert_eq!(kmer.bytes().as_ref(), b"GATTACA");
        assert!(!kmer.is_reverse_complement());
        // TGTAATC is the reverse complement of GATTACA.
        let kmer = Kmer::from_sub(Bytes::copy_from_slice(b"TGTAATC"))
            .unwrap()
            .pack()
            .canonical();
        assert_eq!(kmer.bytes().as_ref(), b"GATTACA");
        assert!(kmer.is_reverse_complement());
    }
    #[test]
    fn kmer_byte_reverse_complement() {
        assert!(matches!(KmerByte::A.reverse_complement(), KmerByte::T));
        assert!(matches!(KmerByte::T.reverse_complement(), KmerByte::A));
        assert!(matches!(KmerByte::C.reverse_complement(), KmerByte::G));
        assert!(matches!(KmerByte::G.reverse_complement(), KmerByte::C));
    }
    // ASCII -> KmerByte -> 2-bit code -> KmerByte -> ASCII is lossless.
    #[test]
    fn kmer_byte_to_u64_roundtrip() {
        for (byte, expected) in [(b'A', 0u64), (b'C', 1u64), (b'G', 2u64), (b'T', 3u64)] {
            let kmer_byte: KmerByte = (&byte).into();
            let val: u64 = kmer_byte.into();
            assert_eq!(val, expected);
            let back: KmerByte = val.into();
            let back_byte: u8 = back.into();
            assert_eq!(back_byte, byte);
        }
    }
    // from_sub accepts an empty sequence.
    #[test]
    fn empty_kmer() {
        let kmer = Kmer::from_sub(Bytes::new()).unwrap();
        assert!(kmer.bytes().is_empty());
    }
    #[test]
    fn single_base_kmers() {
        for base in [b'A', b'C', b'G', b'T'] {
            let kmer = Kmer::from_sub(Bytes::copy_from_slice(&[base])).unwrap();
            assert_eq!(kmer.bytes().as_ref(), &[base]);
        }
    }
    // Soft-masked (lowercase) bases are normalised to uppercase.
    #[test]
    fn soft_masked_bases_converted_to_uppercase() {
        let sub = b"AAAa";
        let kmer = Kmer::from_sub(Bytes::copy_from_slice(sub)).unwrap();
        insta::assert_snapshot!(format!("{:?}", kmer.bytes()), @r#"b"AAAA""#);
    }
    #[test]
    fn soft_masked_all_lowercase() {
        let sub = b"gattaca";
        let kmer = Kmer::from_sub(Bytes::copy_from_slice(sub)).unwrap();
        insta::assert_snapshot!(format!("{:?}", kmer.bytes()), @r#"b"GATTACA""#);
    }
    #[test]
    fn soft_masked_mixed_case() {
        let sub = b"AcGt";
        let kmer = Kmer::from_sub(Bytes::copy_from_slice(sub)).unwrap();
        insta::assert_snapshot!(format!("{:?}", kmer.bytes()), @r#"b"ACGT""#);
    }
    // KmerLength accepts exactly 1..=32.
    #[test]
    fn kmer_length_valid_range() {
        for k in 1..=32 {
            let result = KmerLength::new(k);
            assert!(result.is_ok(), "k={k} should be valid");
            assert_eq!(result.unwrap().get(), k);
        }
    }
    // The error for k = 0 carries the rejected value and the accepted bounds.
    #[test]
    fn kmer_length_rejects_zero() {
        let result = KmerLength::new(0);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.k, 0);
        assert_eq!(err.min, KmerLength::MIN);
        assert_eq!(err.max, KmerLength::MAX);
    }
    #[test]
    fn kmer_length_rejects_too_large() {
        for k in [33, 64, 100, 1000] {
            let result = KmerLength::new(k);
            assert!(result.is_err(), "k={k} should be invalid");
        }
    }
    #[test]
    fn kmer_length_try_from() {
        let k: Result<KmerLength, _> = 21usize.try_into();
        assert!(k.is_ok());
        assert_eq!(k.unwrap().get(), 21);
    }
    #[test]
    fn kmer_length_into_usize() {
        let k = KmerLength::new(21).unwrap();
        let n: usize = k.into();
        assert_eq!(n, 21);
    }
    #[test]
    fn kmer_length_as_u8() {
        let k = KmerLength::new(21).unwrap();
        assert_eq!(k.as_u8(), 21);
    }
    // 0b00_01_10_11 encodes A, C, G, T from most to least significant pair.
    #[test]
    fn unpack_to_string_works() {
        let k = KmerLength::new(4).unwrap();
        let s = unpack_to_string(0b00_01_10_11, k);
        assert_eq!(s, "ACGT");
    }
    // For k = 1 the reverse complement is just the base complement.
    #[test]
    fn reverse_complement_bits_single_base() {
        assert_eq!(reverse_complement_bits(0b00, 1), 0b11);
        assert_eq!(reverse_complement_bits(0b11, 1), 0b00);
        assert_eq!(reverse_complement_bits(0b01, 1), 0b10);
        assert_eq!(reverse_complement_bits(0b10, 1), 0b01);
    }
    // k = 32 exercises the full-width mask path (no shift by 64).
    #[test]
    fn reverse_complement_bits_max_k() {
        let all_a = 0u64;
        let all_t = u64::MAX; assert_eq!(reverse_complement_bits(all_a, 32), all_t);
        assert_eq!(reverse_complement_bits(all_t, 32), all_a);
    }
    // ACGT and AT are their own reverse complements.
    #[test]
    fn reverse_complement_bits_palindrome() {
        assert_eq!(reverse_complement_bits(0b00_01_10_11, 4), 0b00_01_10_11);
        assert_eq!(reverse_complement_bits(0b00_11, 2), 0b00_11);
    }
    #[test]
    fn canonical_bits_returns_smaller() {
        assert_eq!(canonical_bits(0b00_00_00, 3), 0b00_00_00);
        assert_eq!(canonical_bits(0b11_11_11, 3), 0b00_00_00);
    }
    #[test]
    fn canonical_bits_palindrome_returns_forward() {
        let acgt = 0b00_01_10_11u64;
        assert_eq!(canonical_bits(acgt, 4), acgt);
    }
    // The empty-slice precondition is a debug-only assertion, so these two
    // tests only exist when debug assertions are enabled.
    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "validate_and_pack requires a non-empty slice")]
    fn pack_canonical_panics_on_empty_slice() {
        let _ = pack_canonical(b"");
    }
    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "validate_and_pack requires a non-empty slice")]
    fn validate_and_pack_panics_on_empty_slice() {
        let _ = validate_and_pack(b"");
    }
    #[test]
    fn validate_and_pack_rejects_invalid() {
        let err = validate_and_pack(b"ACN").unwrap_err();
        assert_eq!(err.position, 2);
        assert_eq!(err.base, b'N');
    }
    // The single-pass fast path must agree with the type-state pipeline;
    // the list ends with a long all-`A` sequence (presumably the 32-base
    // maximum).
    #[test]
    fn pack_canonical_matches_type_state_pipeline() {
        let sequences = [
            "GATTACA",
            "TGTAATC",
            "AAA",
            "TTT",
            "ACGT",
            "A",
            "T",
            "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA", ];
        for seq in sequences {
            let via_pipeline = Kmer::from_sub(Bytes::copy_from_slice(seq.as_bytes()))
                .unwrap()
                .pack()
                .canonical()
                .packed_bits();
            let via_fast_path = pack_canonical(seq.as_bytes()).unwrap();
            assert_eq!(via_pipeline, via_fast_path, "mismatch for {seq}");
        }
    }
    // Compile-time check that each state transition yields the expected type.
    #[test]
    fn type_state_flow() {
        let unpacked: Kmer<Unpacked> = Kmer::from_sub(Bytes::from_static(b"GATTACA")).unwrap();
        let packed: Kmer<Packed> = unpacked.pack();
        let canonical: Kmer<Canonical> = packed.canonical();
        assert!(canonical.packed_bits() > 0);
        assert!(!canonical.is_reverse_complement());
    }
}