use crate::encoding::{decode_base, encode_base, EncodingError};
use std::fmt;
use std::hash::{Hash, Hasher};
use std::ops::{BitAnd, BitOr, Not, Shl, Shr};
/// Storage word and bit-level primitives backing a [`Kmer`] of a particular length.
///
/// In this file the trait is implemented for `Kmer<K>` by the
/// `impl_kmer_bits_u64!` and `impl_kmer_bits_u128!` macros, which pick `u64`
/// or `u128` as the storage word.
pub trait KmerBits: Sized {
    /// Unsigned integer type that holds the packed bases.
    type Storage: Copy
        + Ord
        + Hash
        + fmt::Debug
        + fmt::Display
        + fmt::Binary
        + From<u8>
        + From<u64>
        + BitAnd<Output = Self::Storage>
        + BitOr<Output = Self::Storage>
        + Not<Output = Self::Storage>
        + Shl<usize, Output = Self::Storage>
        + Shr<usize, Output = Self::Storage>
        + Shr<i32, Output = Self::Storage>
        + Send
        + Sync;
    /// Width of [`Self::Storage`] in bits.
    const BITS: usize;
    /// Largest supported k-mer length; defaults to `BITS / 2 - 1`, one less
    /// than the number of 2-bit slots in the storage word.
    const MAX_K: usize = Self::BITS / 2 - 1;
    /// Converts a storage word to `u8` (the implementations in this file truncate).
    fn to_u8(val: Self::Storage) -> u8;
    /// Converts a storage word to `u64` (truncating when the storage is wider).
    fn to_u64(val: Self::Storage) -> u64;
    /// Converts a storage word to `u128`.
    fn to_u128(val: Self::Storage) -> u128;
    /// Builds a storage word from a `u8`.
    fn from_u8(val: u8) -> Self::Storage;
    /// Builds a storage word from a `u64`.
    fn from_u64(val: u64) -> Self::Storage;
    /// Builds a storage word from a `u128` (truncating when the storage is narrower).
    fn from_u128(val: u128) -> Self::Storage;
    /// Shifts `val` left by `bits`.
    fn shl(val: Self::Storage, bits: usize) -> Self::Storage;
    /// Shifts `val` right by `bits`.
    fn shr(val: Self::Storage, bits: usize) -> Self::Storage;
    /// Bitwise AND of `a` and `b`.
    fn bitand(a: Self::Storage, b: Self::Storage) -> Self::Storage;
    /// Bitwise OR of `a` and `b`.
    fn bitor(a: Self::Storage, b: Self::Storage) -> Self::Storage;
    /// Bitwise complement of `a`.
    fn bitnot(a: Self::Storage) -> Self::Storage;
}
/// Implements [`KmerBits`] with `u64` storage for every listed k-mer length.
///
/// All operations forward directly to the primitive integer; narrowing
/// conversions keep only the low-order bits.
macro_rules! impl_kmer_bits_u64 {
    ($($len:literal),* $(,)?) => {
        $(
            impl KmerBits for Kmer<$len> {
                type Storage = u64;
                const BITS: usize = 64;

                #[inline]
                fn to_u8(val: u64) -> u8 {
                    // Narrowing on purpose: only the low 8 bits survive.
                    val as u8
                }

                #[inline]
                fn to_u64(val: u64) -> u64 {
                    val
                }

                #[inline]
                fn to_u128(val: u64) -> u128 {
                    u128::from(val)
                }

                #[inline]
                fn from_u8(val: u8) -> u64 {
                    u64::from(val)
                }

                #[inline]
                fn from_u64(val: u64) -> u64 {
                    val
                }

                #[inline]
                fn from_u128(val: u128) -> u64 {
                    // Narrowing on purpose: only the low 64 bits survive.
                    val as u64
                }

                #[inline]
                fn shl(val: u64, bits: usize) -> u64 {
                    val << bits
                }

                #[inline]
                fn shr(val: u64, bits: usize) -> u64 {
                    val >> bits
                }

                #[inline]
                fn bitand(a: u64, b: u64) -> u64 {
                    a & b
                }

                #[inline]
                fn bitor(a: u64, b: u64) -> u64 {
                    a | b
                }

                #[inline]
                fn bitnot(a: u64) -> u64 {
                    !a
                }
            }
        )*
    };
}
// Odd lengths 3..=31 use the 64-bit storage word.
impl_kmer_bits_u64!(3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31);
/// Implements [`KmerBits`] with `u128` storage for every listed k-mer length.
///
/// All operations forward directly to the primitive integer; narrowing
/// conversions keep only the low-order bits.
macro_rules! impl_kmer_bits_u128 {
    ($($len:literal),* $(,)?) => {
        $(
            impl KmerBits for Kmer<$len> {
                type Storage = u128;
                const BITS: usize = 128;

                #[inline]
                fn to_u8(val: u128) -> u8 {
                    // Narrowing on purpose: only the low 8 bits survive.
                    val as u8
                }

                #[inline]
                fn to_u64(val: u128) -> u64 {
                    // Narrowing on purpose: only the low 64 bits survive.
                    val as u64
                }

                #[inline]
                fn to_u128(val: u128) -> u128 {
                    val
                }

                #[inline]
                fn from_u8(val: u8) -> u128 {
                    u128::from(val)
                }

                #[inline]
                fn from_u64(val: u64) -> u128 {
                    u128::from(val)
                }

                #[inline]
                fn from_u128(val: u128) -> u128 {
                    val
                }

                #[inline]
                fn shl(val: u128, bits: usize) -> u128 {
                    val << bits
                }

                #[inline]
                fn shr(val: u128, bits: usize) -> u128 {
                    val >> bits
                }

                #[inline]
                fn bitand(a: u128, b: u128) -> u128 {
                    a & b
                }

                #[inline]
                fn bitor(a: u128, b: u128) -> u128 {
                    a | b
                }

                #[inline]
                fn bitnot(a: u128) -> u128 {
                    !a
                }
            }
        )*
    };
}
// Odd lengths 33..=63 use the 128-bit storage word.
impl_kmer_bits_u128!(33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63);
/// A k-mer of `K` bases packed two bits per base into an unsigned integer.
///
/// The storage word is selected by the [`KmerBits`] implementation for
/// `Kmer<K>`; the base at position `i` occupies bits `2 * i` and `2 * i + 1`.
#[derive(Clone, Copy)]
pub struct Kmer<const K: usize>
where
    Kmer<K>: KmerBits,
{
    /// Packed base codes, position 0 in the least-significant bits.
    bits: <Kmer<K> as KmerBits>::Storage,
}
impl<const K: usize> Kmer<K>
where
Kmer<K>: KmerBits,
{
/// Wraps an already-packed storage word without masking or validating it.
#[inline]
pub fn new(bits: <Self as KmerBits>::Storage) -> Self {
    Self { bits }
}
/// Builds a k-mer from a raw `u128` bit pattern.
///
/// The value goes through [`KmerBits::from_u128`], so with 64-bit storage
/// only the low 64 bits are kept. No masking to `2 * K` bits is applied.
#[inline]
pub fn from_bits(bits: u128) -> Self {
    let packed = <Self as KmerBits>::from_u128(bits);
    Self { bits: packed }
}
#[inline]
pub fn bits(&self) -> <Kmer<K> as KmerBits>::Storage {
self.bits
}
#[inline]
pub fn as_u64(&self) -> u64 {
if K > 31 {
panic!("Cannot convert Kmer<{}> to u64, use as_u128()", K);
}
unsafe { *(&self.bits as *const _ as *const u64) }
}
#[inline]
pub fn as_u128(&self) -> u128 {
if K <= 31 {
unsafe { *(&self.bits as *const _ as *const u64) as u128 }
} else {
unsafe { *(&self.bits as *const _ as *const u128) }
}
}
/// Parses a k-mer from text by delegating to the [`std::str::FromStr`]
/// implementation for this type.
///
/// # Errors
///
/// Returns whatever [`EncodingError`] the `FromStr` implementation produces.
#[inline]
#[allow(clippy::should_implement_trait)]
pub fn from_str(s: &str) -> Result<Self, EncodingError> {
    s.parse::<Self>()
}
/// Alias for `from_str`: parses a k-mer from text via the
/// [`std::str::FromStr`] implementation for this type.
///
/// # Errors
///
/// Returns whatever [`EncodingError`] the `FromStr` implementation produces.
#[inline]
pub fn from_string(s: &str) -> Result<Self, EncodingError> {
    <Self as std::str::FromStr>::from_str(s)
}
/// Packs ASCII nucleotide bytes into a k-mer without validating them.
///
/// Each byte is reduced to a 2-bit code with `(byte >> 1) & 3`, which maps
/// `A`/`a` to 0, `C`/`c` to 1, `T`/`t` to 2 and `G`/`g` to 3. Any other byte
/// is folded onto one of those codes rather than rejected. Byte `i` lands in
/// bits `2 * i` and `2 * i + 1`.
///
/// The slice length is only checked against `K` in debug builds.
#[inline]
pub fn from_ascii_unchecked(bytes: &[u8]) -> Self {
    debug_assert_eq!(bytes.len(), K);
    let bits = if K <= 31 {
        // Short k-mers are accumulated in a 64-bit word.
        let packed = bytes.iter().enumerate().fold(0u64, |acc, (idx, &byte)| {
            acc | (u64::from((byte >> 1) & 3) << (idx * 2))
        });
        <Self as KmerBits>::from_u64(packed)
    } else {
        // Longer k-mers need the 128-bit accumulator.
        let packed = bytes.iter().enumerate().fold(0u128, |acc, (idx, &byte)| {
            acc | (u128::from((byte >> 1) & 3) << (idx * 2))
        });
        <Self as KmerBits>::from_u128(packed)
    };
    Self { bits }
}
/// Returns the reverse complement of this k-mer.
///
/// XOR with the `10` pattern complements every 2-bit code (0 <-> 2, 1 <-> 3).
/// The order of the 2-bit groups is then reversed across the whole word:
/// swap adjacent pairs inside each nibble, swap nibbles inside each byte,
/// then swap bytes. A final right shift drops the slots that were unused
/// padding before the reversal.
#[inline]
pub fn reverse_complement(&self) -> Self {
    if K <= 31 {
        const COMPLEMENT: u64 = 0xAAAA_AAAA_AAAA_AAAAu64;
        const PAIRS: u64 = 0x3333_3333_3333_3333u64;
        const NIBBLES: u64 = 0x0F0F_0F0F_0F0F_0F0Fu64;

        let flipped = <Self as KmerBits>::to_u64(self.bits) ^ COMPLEMENT;
        let pairs_swapped = ((flipped >> 2) & PAIRS) | ((flipped & PAIRS) << 2);
        let nibbles_swapped =
            ((pairs_swapped >> 4) & NIBBLES) | ((pairs_swapped & NIBBLES) << 4);
        let reversed = nibbles_swapped.swap_bytes();
        // The k-mer now sits in the top 2*K bits; move it back down.
        let aligned = reversed >> (64 - K * 2);
        Self { bits: <Self as KmerBits>::from_u64(aligned) }
    } else {
        const COMPLEMENT: u128 = 0xAAAA_AAAA_AAAA_AAAA_AAAA_AAAA_AAAA_AAAAu128;
        const PAIRS: u128 = 0x3333_3333_3333_3333_3333_3333_3333_3333u128;
        const NIBBLES: u128 = 0x0F0F_0F0F_0F0F_0F0F_0F0F_0F0F_0F0F_0F0Fu128;

        let flipped = <Self as KmerBits>::to_u128(self.bits) ^ COMPLEMENT;
        let pairs_swapped = ((flipped >> 2) & PAIRS) | ((flipped & PAIRS) << 2);
        let nibbles_swapped =
            ((pairs_swapped >> 4) & NIBBLES) | ((pairs_swapped & NIBBLES) << 4);
        let reversed = nibbles_swapped.swap_bytes();
        // The k-mer now sits in the top 2*K bits; move it back down.
        let aligned = reversed >> (128 - K * 2);
        Self { bits: <Self as KmerBits>::from_u128(aligned) }
    }
}
/// Returns whichever of this k-mer and its reverse complement has the
/// smaller packed value.
pub fn canonical(&self) -> Self {
    let revcomp = self.reverse_complement();
    if revcomp.bits <= self.bits {
        revcomp
    } else {
        *self
    }
}
pub fn get_base(&self, pos: usize) -> u8 {
assert!(pos < K, "Position {} out of bounds for k-mer of length {}", pos, K);
let shift = pos * 2;
<Kmer<K> as KmerBits>::to_u8(
<Kmer<K> as KmerBits>::bitand(
<Kmer<K> as KmerBits>::shr(self.bits, shift),
<Kmer<K> as KmerBits>::from_u8(0b11u8)
)
)
}
pub fn set_base(&mut self, pos: usize, base: u8) {
assert!(pos < K, "Position {} out of bounds for k-mer of length {}", pos, K);
assert!(base <= 0b11, "Base value must be 0-3");
let shift = pos * 2;
let mask = <Kmer<K> as KmerBits>::bitnot(
<Kmer<K> as KmerBits>::shl(
<Kmer<K> as KmerBits>::from_u8(0x3u8),
shift
)
);
self.bits = <Kmer<K> as KmerBits>::bitand(self.bits, mask);
let new_bits = <Kmer<K> as KmerBits>::shl(
<Kmer<K> as KmerBits>::from_u8(base),
shift
);
self.bits = <Kmer<K> as KmerBits>::bitor(self.bits, new_bits);
}
/// Returns a k-mer whose storage word is all zero bits.
#[inline]
pub fn empty() -> Self {
    let zero = <Self as KmerBits>::from_u8(0);
    Self { bits: zero }
}
/// Shifts every stored base up by one slot and writes `base` into slot 0.
///
/// The packed value is shifted left by two bits, anything above the low
/// `2 * K` bits is discarded, and `base` is OR-ed into bits 0..2 (the slot
/// that `get_base(0)` reads).
///
/// NOTE(review): because slot 0 is the first character in the textual form,
/// this inserts at the front of the displayed string; confirm that this is
/// the intended meaning of "append" for callers.
///
/// # Panics
///
/// Panics if `base > 3`.
#[inline]
pub fn append_base(self, base: u8) -> Self {
    assert!(base <= 0b11, "Base value must be 0-3");
    let used_bits = 2 * K;
    let mask = if used_bits >= <Self as KmerBits>::BITS {
        // The k-mer fills the whole storage word: keep every bit.
        <Self as KmerBits>::bitnot(<Self as KmerBits>::from_u8(0))
    } else {
        // Build the mask in `u128` so the shift stays in range for 128-bit
        // storage (a `u64` shift overflows as soon as `2 * K >= 64`). For
        // 64-bit storage `used_bits < 64`, so narrowing loses nothing.
        <Self as KmerBits>::from_u128((1u128 << used_bits) - 1)
    };
    let shifted = <Self as KmerBits>::shl(self.bits, 2);
    let kept = <Self as KmerBits>::bitand(shifted, mask);
    let bits = <Self as KmerBits>::bitor(kept, <Self as KmerBits>::from_u8(base));
    Self { bits }
}
#[inline]
pub fn write_ascii(&self, out: &mut [u8]) {
debug_assert!(out.len() >= K);
let mut bits = self.bits;
for slot in out.iter_mut().take(K) {
let base_bits = <Kmer<K> as KmerBits>::to_u8(
<Kmer<K> as KmerBits>::bitand(
bits,
<Kmer<K> as KmerBits>::from_u8(0b11u8),
),
);
*slot = decode_base(base_bits);
bits = <Kmer<K> as KmerBits>::shr(bits, 2);
}
}
/// Slides the k-mer window one base forward.
///
/// The base in slot 0 is dropped, every other base moves down one slot, and
/// `new_base` is placed in the last slot (`K - 1`). The range of `new_base`
/// is only checked in debug builds.
#[inline]
pub fn roll_right_base(self, new_base: u8) -> Self {
    debug_assert!(new_base <= 0b11);
    let top_slot = 2 * (K - 1);
    let incoming = <Self as KmerBits>::shl(<Self as KmerBits>::from_u8(new_base), top_slot);
    let remaining = <Self as KmerBits>::shr(self.bits, 2);
    Self { bits: <Self as KmerBits>::bitor(remaining, incoming) }
}
}
/// Two k-mers are equal exactly when their packed storage words are equal.
impl<const K: usize> PartialEq for Kmer<K>
where
    Kmer<K>: KmerBits,
{
    fn eq(&self, rhs: &Self) -> bool {
        PartialEq::eq(&self.bits, &rhs.bits)
    }
}

// Equality on the packed word is a full equivalence relation.
impl<const K: usize> Eq for Kmer<K> where Kmer<K>: KmerBits {}
/// Partial ordering that always agrees with the total order from `Ord`.
impl<const K: usize> PartialOrd for Kmer<K>
where
    Kmer<K>: KmerBits,
{
    fn partial_cmp(&self, rhs: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, rhs))
    }
}
/// K-mers are ordered by the numeric value of their packed storage word.
impl<const K: usize> Ord for Kmer<K>
where
    Kmer<K>: KmerBits,
{
    fn cmp(&self, rhs: &Self) -> std::cmp::Ordering {
        Ord::cmp(&self.bits, &rhs.bits)
    }
}
/// Hashes the packed storage word, consistent with the `PartialEq` impl.
impl<const K: usize> Hash for Kmer<K>
where
    Kmer<K>: KmerBits,
{
    fn hash<H: Hasher>(&self, hasher: &mut H) {
        Hash::hash(&self.bits, hasher);
    }
}
/// Debug output has the form `Kmer<K>("<bases>")`, where the bases come from
/// the `Display` implementation.
impl<const K: usize> fmt::Debug for Kmer<K>
where
    Kmer<K>: KmerBits,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("Kmer<{}>(\"{}\")", K, self))
    }
}
/// Writes the `K` bases in slot order (slot 0 first), one character each,
/// using `decode_base` to turn every 2-bit code into its character.
impl<const K: usize> fmt::Display for Kmer<K>
where
    Kmer<K>: KmerBits,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let low_two = <Self as KmerBits>::from_u8(0b11u8);
        for pos in 0..K {
            // Slot `pos` lives in bits 2*pos .. 2*pos + 2.
            let aligned = <Self as KmerBits>::shr(self.bits, pos * 2);
            let code = <Self as KmerBits>::to_u8(<Self as KmerBits>::bitand(aligned, low_two));
            write!(f, "{}", decode_base(code) as char)?;
        }
        Ok(())
    }
}
/// The default k-mer has an all-zero storage word.
impl<const K: usize> Default for Kmer<K>
where
    Kmer<K>: KmerBits,
{
    fn default() -> Self {
        // Goes through the `From<u8>` bound on the storage type.
        let zero: <Self as KmerBits>::Storage = 0u8.into();
        Self { bits: zero }
    }
}
/// Parses a k-mer from text of exactly `K` bytes.
///
/// Each byte is translated with `encode_base` and stored in slot order, so
/// the first character ends up in bits 0..2.
///
/// # Errors
///
/// Returns `EncodingError::LengthMismatch` when the byte length of the input
/// differs from `K`, and propagates any error from `encode_base`.
impl<const K: usize> std::str::FromStr for Kmer<K>
where
    Kmer<K>: KmerBits,
{
    type Err = EncodingError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let actual = s.len();
        if actual != K {
            return Err(EncodingError::LengthMismatch { expected: K, actual });
        }

        let mut bits = <Self as KmerBits>::from_u8(0);
        for (pos, byte) in s.bytes().enumerate() {
            let code = <Self as KmerBits>::from_u8(encode_base(byte)?);
            let placed = <Self as KmerBits>::shl(code, pos * 2);
            bits = <Self as KmerBits>::bitor(bits, placed);
        }
        Ok(Self { bits })
    }
}
/// 31-base k-mer (64-bit storage in this file).
pub type Kmer31 = Kmer<31>;
/// 21-base k-mer (64-bit storage in this file).
pub type Kmer21 = Kmer<21>;
/// 63-base k-mer (128-bit storage in this file).
pub type Kmer63 = Kmer<63>;
#[cfg(test)]
mod tests {
    use super::*;

    /// Storage width and in-memory size follow the k-mer length.
    #[test]
    fn test_kmer_storage_types() {
        assert_eq!(<Kmer<3> as KmerBits>::BITS, 64);
        assert_eq!(<Kmer<31> as KmerBits>::BITS, 64);
        assert_eq!(std::mem::size_of::<Kmer<31>>(), 8);

        assert_eq!(<Kmer<33> as KmerBits>::BITS, 128);
        assert_eq!(<Kmer<63> as KmerBits>::BITS, 128);
        assert_eq!(std::mem::size_of::<Kmer<63>>(), 16);
    }

    /// Parsing followed by formatting round-trips the input text.
    #[test]
    fn test_kmer_from_str() {
        let short = Kmer::<5>::from_str("ACGTG").unwrap();
        assert_eq!(short.to_string(), "ACGTG");

        let long = Kmer::<31>::from_str("ACGTACGTACGTACGTACGTACGTACGTACG").unwrap();
        assert_eq!(long.to_string(), "ACGTACGTACGTACGTACGTACGTACGTACG");
    }

    /// Reverse complement reverses the sequence and complements each base.
    #[test]
    fn test_kmer_reverse_complement() {
        let five = Kmer::<5>::from_str("ACGTG").unwrap();
        assert_eq!(five.reverse_complement().to_string(), "CACGT");

        let seven = Kmer::<7>::from_str("ACGTACG").unwrap();
        assert_eq!(seven.reverse_complement().to_string(), "CGTACGT");
    }

    /// Rolling drops the first base and places the new one at the end.
    #[test]
    fn test_roll_right_base() {
        let five = Kmer::<5>::from_str("ACGTG").unwrap();
        assert_eq!(five.roll_right_base(0).to_string(), "CGTGA");

        let seven = Kmer::<7>::from_str("ACGTACG").unwrap();
        assert_eq!(seven.roll_right_base(2).to_string(), "CGTACGT");
    }

    /// The canonical form is the smaller of the k-mer and its reverse complement.
    #[test]
    fn test_kmer_canonical() {
        let original = Kmer::<5>::from_str("ACGTG").unwrap();
        let canonical = original.canonical();
        let revcomp = original.reverse_complement();
        assert!(canonical == original || canonical == revcomp);
        assert!(canonical.bits <= original.bits && canonical.bits <= revcomp.bits);
    }

    /// Lower- and upper-case input parse to the same k-mer.
    #[test]
    fn test_kmer_case_insensitive() {
        let lower = Kmer::<5>::from_str("acgtg").unwrap();
        let upper = Kmer::<5>::from_str("ACGTG").unwrap();
        assert_eq!(lower, upper);
    }

    /// Input that is too short or too long is rejected.
    #[test]
    fn test_kmer_length_mismatch() {
        let too_short = Kmer::<5>::from_str("ACGT");
        assert!(too_short.is_err());

        let too_long = Kmer::<5>::from_str("ACGTGG");
        assert!(too_long.is_err());
    }

    /// Individual bases can be read and overwritten by position.
    #[test]
    fn test_kmer_get_set_base() {
        let mut kmer = Kmer::<5>::from_str("AAAAA").unwrap();
        assert_eq!(kmer.get_base(0), 0b00);

        kmer.set_base(2, 0b10);
        assert_eq!(kmer.to_string(), "AATAA");

        kmer.set_base(4, 0b11);
        assert_eq!(kmer.to_string(), "AATAG");
    }

    /// Ordering follows the packed numeric value.
    #[test]
    fn test_kmer_ordering() {
        let all_a = Kmer::<5>::from_str("AAAAA").unwrap();
        let ends_c = Kmer::<5>::from_str("AAAAC").unwrap();
        let all_t = Kmer::<5>::from_str("TTTTT").unwrap();
        assert!(all_a < ends_c);
        assert!(ends_c < all_t);
        assert!(all_a < all_t);
    }
}