use crate::Error;
#[cfg(all(target_arch = "aarch64", not(feature = "nosimd")))]
mod aarch64;
#[cfg(all(target_arch = "x86_64", not(feature = "nosimd")))]
mod avx;
mod naive;
#[cfg(all(target_arch = "x86_64", not(feature = "nosimd")))]
mod sse;
/// Packs a nucleotide sequence into a single `u64`, two bits per base.
///
/// This is the strict entry point: it forwards to `impl_as_2bit` with
/// `allow_invalid` set to `false`.
///
/// The unit tests in this file pin the layout as `A = 0b00`, `C = 0b01`,
/// `G = 0b10`, `T = 0b11`, with the first base stored in the lowest bits
/// (`b"ACGT"` encodes to `0b11100100`), and show that lowercase input encodes
/// the same as uppercase.
///
/// # Errors
///
/// Propagates whatever error the selected backend returns. The tests in this
/// file expect `Error::InvalidBase` for a byte such as `b'N'` and
/// `Error::SequenceTooLong` for a 33-base input.
#[inline(always)]
pub fn as_2bit(seq: &[u8]) -> Result<u64, Error> {
    // Strict mode: unknown bases are not tolerated.
    const ALLOW_INVALID: bool = false;
    impl_as_2bit(seq, ALLOW_INVALID)
}
/// Packs a nucleotide sequence into a single `u64`, two bits per base,
/// tolerating bytes that are not valid bases.
///
/// This is the lossy entry point: it forwards to `impl_as_2bit` with
/// `allow_invalid` set to `true`.
///
/// For valid input the tests in this file expect the same result as
/// [`as_2bit`]. For invalid bytes (for example `N`, `R`, `Y`) the tests
/// expect the same encoding as if the byte had been `A`.
///
/// # Errors
///
/// Propagates whatever error the selected backend returns. The tests in this
/// file expect `Error::SequenceTooLong` for a 33-base input even in lossy
/// mode.
#[inline(always)]
pub fn as_2bit_lossy(seq: &[u8]) -> Result<u64, Error> {
    // Lossy mode: unknown bases are accepted rather than reported.
    const ALLOW_INVALID: bool = true;
    impl_as_2bit(seq, ALLOW_INVALID)
}
/// Shared implementation behind [`as_2bit`] and [`as_2bit_lossy`].
///
/// Chooses a backend and forwards `seq` and `allow_invalid` to its `as_2bit`
/// function unchanged:
///
/// * aarch64 (without the `nosimd` feature): `aarch64` when NEON is detected
///   at runtime.
/// * x86_64 (without the `nosimd` feature): `avx` when AVX2 is detected at
///   runtime, otherwise `sse` when SSE2 is detected.
/// * Everything else, including any build with the `nosimd` feature and any
///   case where runtime detection fails: `naive`.
#[inline(always)]
fn impl_as_2bit(seq: &[u8], allow_invalid: bool) -> Result<u64, Error> {
    // Each SIMD section is compiled only for its own architecture and returns
    // early when the CPU supports it; otherwise control falls through to the
    // portable implementation at the bottom.
    #[cfg(all(target_arch = "aarch64", not(feature = "nosimd")))]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            return aarch64::as_2bit(seq, allow_invalid);
        }
    }

    #[cfg(all(target_arch = "x86_64", not(feature = "nosimd")))]
    {
        // Prefer the AVX2 backend over the SSE2 one.
        if is_x86_feature_detected!("avx2") {
            return avx::as_2bit(seq, allow_invalid);
        }
        if is_x86_feature_detected!("sse2") {
            return sse::as_2bit(seq, allow_invalid);
        }
    }

    // Portable fallback: `nosimd` builds, other architectures, or no
    // supported SIMD extension detected.
    naive::as_2bit(seq, allow_invalid)
}
/// Encodes `seq` into `ebuf` using the best backend available.
///
/// `seq`, `ebuf` and `allow_invalid` are forwarded unchanged to the chosen
/// backend's `encode_internal`; how `ebuf` is filled is defined by those
/// backends and is not visible here.
///
/// Backend selection:
///
/// * aarch64 (without the `nosimd` feature): `aarch64` when NEON is detected
///   at runtime.
/// * x86_64 (without the `nosimd` feature): `avx` when AVX2 is detected at
///   runtime, otherwise `sse` when SSE2 is detected.
/// * Everything else, including any build with the `nosimd` feature and any
///   case where runtime detection fails: `naive`.
///
/// # Errors
///
/// Propagates whatever error the selected backend returns.
#[inline(always)]
pub fn encode_internal(seq: &[u8], ebuf: &mut Vec<u64>, allow_invalid: bool) -> Result<(), Error> {
    // Each SIMD section is compiled only for its own architecture and returns
    // early when the CPU supports it; otherwise control falls through to the
    // portable implementation at the bottom.
    #[cfg(all(target_arch = "aarch64", not(feature = "nosimd")))]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            return aarch64::encode_internal(seq, ebuf, allow_invalid);
        }
    }

    #[cfg(all(target_arch = "x86_64", not(feature = "nosimd")))]
    {
        // Prefer the AVX2 backend over the SSE2 one.
        if is_x86_feature_detected!("avx2") {
            return avx::encode_internal(seq, ebuf, allow_invalid);
        }
        if is_x86_feature_detected!("sse2") {
            return sse::encode_internal(seq, ebuf, allow_invalid);
        }
    }

    // Portable fallback: `nosimd` builds, other architectures, or no
    // supported SIMD extension detected.
    naive::encode_internal(seq, ebuf, allow_invalid)
}
#[cfg(test)]
mod testing {
    use super::*;

    /// Four-base inputs paired with their expected packed value.
    ///
    /// Shared by the strict and lossy "valid sequence" tests, which must
    /// agree on well-formed input.
    const FOUR_BASE_CASES: [(&[u8; 4], u64); 5] = [
        (b"ACGT", 0b11100100),
        (b"AAAA", 0b00000000),
        (b"TTTT", 0b11111111),
        (b"GGGG", 0b10101010),
        (b"CCCC", 0b01010101),
    ];

    /// Number of bases used by the "too long" tests.
    const TOO_LONG: usize = 33;

    #[test]
    fn test_as_2bit_valid_sequence() {
        for &(input, expected) in FOUR_BASE_CASES.iter() {
            assert_eq!(as_2bit(input).unwrap(), expected);
        }
    }

    #[test]
    fn test_as_2bit_longer_sequence() {
        // 16 bases occupy the low 32 bits of the result.
        let packed = as_2bit(b"ACTGACTGACTGACTG").unwrap();
        assert_eq!(packed, 0b10110100101101001011010010110100);
    }

    #[test]
    fn test_as_2bit_alignments() {
        // 17 bases: a length that is not a multiple of 4 or 16.
        let packed = as_2bit(b"ACTGGAAAATTTTAAGG").unwrap();
        assert_eq!(packed, 0b1010000011111111000000001010110100);
    }

    #[test]
    fn test_as_2bit_lowercase() {
        let lower = as_2bit(b"acgt").unwrap();
        let upper = as_2bit(b"ACGT").unwrap();
        assert_eq!(lower, upper);
    }

    #[test]
    fn test_as_2bit_invalid_base() {
        // Strict mode reports the offending byte.
        assert!(matches!(
            as_2bit(b"ACGN"),
            Err(Error::InvalidBase(b'N'))
        ));
    }

    #[test]
    fn test_as_2bit_sequence_too_long() {
        let long_seq = [b'A'; TOO_LONG];
        let outcome = as_2bit(&long_seq);
        assert!(matches!(outcome, Err(Error::SequenceTooLong(33))));
    }

    #[test]
    fn test_as_2bit_lossy_valid_sequence() {
        for &(input, expected) in FOUR_BASE_CASES.iter() {
            // Lossy mode must match both the fixed expectation and the
            // strict encoder on valid input.
            assert_eq!(as_2bit_lossy(input).unwrap(), expected);
            assert_eq!(as_2bit_lossy(input).unwrap(), as_2bit(input).unwrap());
        }
    }

    #[test]
    fn test_as_2bit_lossy_invalid_characters() {
        let lossy = |seq: &[u8]| as_2bit_lossy(seq).unwrap();

        // Each invalid byte is expected to encode exactly like `A`.
        assert_eq!(lossy(b"ACGN"), lossy(b"ACGA"));
        assert_eq!(lossy(b"NNNN"), lossy(b"AAAA"));
        assert_eq!(lossy(b"ACGTNRY"), lossy(b"ACGTAAA"));
    }

    #[test]
    fn test_as_2bit_lossy_mixed_case_and_invalid() {
        let mixed = as_2bit_lossy(b"AcGtNrY").unwrap();
        let reference = as_2bit_lossy(b"ACGTAAA").unwrap();
        assert_eq!(mixed, reference);
    }

    #[test]
    fn test_as_2bit_lossy_sequence_too_long() {
        // The length limit applies in lossy mode as well.
        let long_seq = [b'A'; TOO_LONG];
        let outcome = as_2bit_lossy(&long_seq);
        assert!(matches!(outcome, Err(Error::SequenceTooLong(33))));
    }
}