use std::arch::aarch64::*;
use super::naive;
use crate::Error;
/// 4-bit code assigned to each nucleotide class.
///
/// The four unambiguous bases use the low codes `0x0..=0x3`; in this file every
/// ambiguous IUPAC symbol is mapped to the all-ones code `N`.
#[repr(u8)]
enum NucleotideBits4 {
    A = 0x0,
    C = 0x1,
    G = 0x2,
    T = 0x3,
    N = 0xF,
}
/// Eight-lane splats of every 4-bit code, used as `vbsl_u8` sources when
/// assembling encoded lanes.
#[repr(align(16))]
struct SimdConstants4 {
    /// Every lane holds the `A` code.
    zeros: uint8x8_t,
    /// Every lane holds the `C` code.
    ones: uint8x8_t,
    /// Every lane holds the `G` code.
    twos: uint8x8_t,
    /// Every lane holds the `T` code.
    threes: uint8x8_t,
    /// Every lane holds the ambiguous (`N`) code.
    ns: uint8x8_t,
}

impl SimdConstants4 {
    /// Broadcasts each `NucleotideBits4` code across all eight lanes.
    // The body consists solely of NEON intrinsic calls, so it is left as one unsafe context
    // instead of wrapping each call.
    #[allow(unsafe_op_in_unsafe_fn)]
    #[inline(always)]
    unsafe fn new() -> Self {
        let zeros = vdup_n_u8(NucleotideBits4::A as u8);
        let ones = vdup_n_u8(NucleotideBits4::C as u8);
        let twos = vdup_n_u8(NucleotideBits4::G as u8);
        let threes = vdup_n_u8(NucleotideBits4::T as u8);
        let ns = vdup_n_u8(NucleotideBits4::N as u8);
        Self {
            zeros,
            ones,
            twos,
            threes,
            ns,
        }
    }
}
/// Returns a lane mask that is `0xFF` wherever `chunk` equals `upper` or `lower`,
/// and `0x00` everywhere else.
// Pure NEON intrinsic body; kept as a single unsafe context.
#[allow(unsafe_op_in_unsafe_fn)]
#[inline(always)]
unsafe fn create_dual_pattern_mask(chunk: uint8x8_t, upper: u8, lower: u8) -> uint8x8_t {
    let hits_upper = vceq_u8(chunk, vdup_n_u8(upper));
    let hits_lower = vceq_u8(chunk, vdup_n_u8(lower));
    vorr_u8(hits_upper, hits_lower)
}
/// Returns a lane mask that is `0xFF` wherever `chunk` holds an ambiguous IUPAC code
/// (`N R Y S W K M B D H V`, either case) and `0x00` everywhere else.
// Pure NEON intrinsic body; kept as a single unsafe context.
#[allow(unsafe_op_in_unsafe_fn)]
#[inline(always)]
unsafe fn create_ambiguous_mask(chunk: uint8x8_t) -> uint8x8_t {
    // OR-ing 0x20 maps an ASCII upper-case letter onto its lower-case form and leaves the
    // lower-case form unchanged. Each code below has bit 5 set, so `x | 0x20 == code`
    // holds for exactly the upper- and lower-case spelling of that code.
    let folded = vorr_u8(chunk, vdup_n_u8(0x20));
    let mut mask = vdup_n_u8(0);
    for &code in b"nryswkmbdhv" {
        mask = vorr_u8(mask, vceq_u8(folded, vdup_n_u8(code)));
    }
    mask
}
/// Merges per-lane class masks into 4-bit codes.
///
/// Every lane starts as the `A` code. Lanes are then overwritten by the `C`, `G`, `T` and
/// ambiguous codes, in that order, wherever the corresponding mask lane is set. If two
/// masks overlap, the later one wins.
// Pure NEON intrinsic body; kept as a single unsafe context.
#[allow(unsafe_op_in_unsafe_fn)]
#[inline(always)]
unsafe fn set_bits_4bit(
    c_mask: uint8x8_t,
    g_mask: uint8x8_t,
    t_mask: uint8x8_t,
    n_mask: uint8x8_t,
    constants: &SimdConstants4,
) -> uint8x8_t {
    let after_c = vbsl_u8(c_mask, constants.ones, constants.zeros);
    let after_g = vbsl_u8(g_mask, constants.twos, after_c);
    let after_t = vbsl_u8(t_mask, constants.threes, after_g);
    vbsl_u8(n_mask, constants.ns, after_t)
}
/// Encodes eight ASCII nucleotides into one 4-bit code per lane.
///
/// `C`, `G` and `T` (either case) get their own codes; ambiguous IUPAC codes map to the
/// ambiguous code. Every other lane, including `A`, keeps the `A` code, so callers are
/// expected to validate the input beforehand.
// Pure NEON intrinsic body; kept as a single unsafe context.
#[allow(unsafe_op_in_unsafe_fn)]
#[inline(always)]
unsafe fn process_simd_chunk_4bit(chunk: uint8x8_t, constants: &SimdConstants4) -> uint8x8_t {
    set_bits_4bit(
        create_dual_pattern_mask(chunk, b'C', b'c'),
        create_dual_pattern_mask(chunk, b'G', b'g'),
        create_dual_pattern_mask(chunk, b'T', b't'),
        create_ambiguous_mask(chunk),
        constants,
    )
}
/// Packs up to 16 nucleotides into a `u64`, four bits per base.
///
/// Base `i` occupies bits `4*i .. 4*i + 4`. `A/C/G/T` (either case) encode as `0..=3`, and
/// the accepted ambiguous IUPAC codes encode as `0b1111`. Sequences shorter than eight
/// bases are delegated to `naive::as_4bit`.
///
/// # Errors
/// * `Error::SequenceTooLong` if `seq` is longer than 16 bytes.
/// * `Error::InvalidBase` with the first unrecognised byte when `seq` has 8..=16 bytes.
///   Shorter inputs report whatever `naive::as_4bit` reports.
pub fn as_4bit(seq: &[u8]) -> Result<u64, Error> {
    let len = seq.len();
    if len > 16 {
        return Err(Error::SequenceTooLong(len));
    }
    if len < 8 {
        return naive::as_4bit(seq);
    }
    if let Some(&bad) = seq.iter().find(|&&b| !is_valid_nucleotide_4bit(b)) {
        return Err(Error::InvalidBase(bad));
    }

    let lanes = seq.chunks_exact(8);
    let rest = lanes.remainder();
    let simd_len = len - rest.len();

    let mut packed = 0u64;
    // SAFETY: only broadcasts constants into NEON registers; no memory is touched.
    let constants = unsafe { SimdConstants4::new() };
    for (block_no, chunk) in lanes.enumerate() {
        let mut nibbles = [0u8; 8];
        // SAFETY: `chunk` comes from `chunks_exact(8)`, so it is exactly the 8 bytes
        // `vld1_u8` reads. `nibbles` is exactly the 8 bytes `vst1_u8` writes.
        unsafe {
            let codes = process_simd_chunk_4bit(vld1_u8(chunk.as_ptr()), &constants);
            vst1_u8(nibbles.as_mut_ptr(), codes);
        }
        let first_base = block_no * 8;
        for (i, &nibble) in nibbles.iter().enumerate() {
            packed |= u64::from(nibble) << ((first_base + i) * 4);
        }
    }

    // Scalar tail: the input is already validated, so anything other than A/C/G/T is an
    // ambiguous code.
    for (i, &base) in rest.iter().enumerate() {
        let bits = match base {
            b'A' | b'a' => NucleotideBits4::A as u64,
            b'C' | b'c' => NucleotideBits4::C as u64,
            b'G' | b'g' => NucleotideBits4::G as u64,
            b'T' | b't' => NucleotideBits4::T as u64,
            _ => NucleotideBits4::N as u64,
        };
        packed |= bits << ((simd_len + i) * 4);
    }
    Ok(packed)
}
/// Returns `true` when `base` is an accepted nucleotide code in either case.
///
/// Accepted codes are `A C G T` and the IUPAC ambiguity codes `N R Y S W K M B D H V`.
#[inline(always)]
fn is_valid_nucleotide_4bit(base: u8) -> bool {
    // `to_ascii_lowercase` only rewrites `A..=Z`, so every other byte is compared unchanged.
    matches!(
        base.to_ascii_lowercase(),
        b'a' | b'c'
            | b'g'
            | b't'
            | b'n'
            | b'r'
            | b'y'
            | b's'
            | b'w'
            | b'k'
            | b'm'
            | b'b'
            | b'd'
            | b'h'
            | b'v'
    )
}
/// Encodes eight ASCII nucleotides held in `nucs` into a `u32`, four bits per base.
///
/// Lane `i` ends up in bits `4*i .. 4*i + 4`. The input is not validated here:
/// unrecognised bytes end up with the `A` code.
///
/// # Safety
/// Calls NEON intrinsics. Presumably the only requirement is that the target supports
/// NEON (TODO confirm against the crate's target policy).
// Intrinsic calls are not individually wrapped; the whole body is one unsafe context.
#[allow(unsafe_op_in_unsafe_fn)]
#[inline(always)]
pub unsafe fn encode_8_nucleotides_4bit(nucs: uint8x8_t) -> u32 {
    let codes = process_simd_chunk_4bit(nucs, &SimdConstants4::new());
    let mut lanes = [0u8; 8];
    vst1_u8(lanes.as_mut_ptr(), codes);
    lanes
        .iter()
        .enumerate()
        .fold(0u32, |acc, (lane, &code)| acc | (u32::from(code) << (4 * lane)))
}
/// Returns `true` when all sixteen lanes of `v` hold an accepted nucleotide code
/// (`A C G T N R Y S W K M B D H V`, either case).
// Pure NEON intrinsic body; kept as a single unsafe context.
#[allow(unsafe_op_in_unsafe_fn)]
#[inline(always)]
unsafe fn valid_block_4bit(v: uint8x16_t) -> bool {
    // OR-ing 0x20 maps ASCII upper-case letters onto lower-case. Every code below has bit 5
    // set, so each compare matches exactly the two case variants of that code.
    let folded = vorrq_u8(v, vdupq_n_u8(0x20));
    let mut accepted = vdupq_n_u8(0);
    for &code in b"acgtnryswkmbdhv" {
        accepted = vorrq_u8(accepted, vceqq_u8(folded, vdupq_n_u8(code)));
    }
    // Each lane is 0xFF (accepted) or 0x00 (rejected); the minimum is 0xFF only if
    // every lane was accepted.
    vminvq_u8(accepted) == 0xFF
}
/// Encodes `input` into `output`, sixteen nucleotides per `u64` word, four bits per base.
///
/// Word `k` holds bases `16*k .. 16*k + 16`; within a word, base `j` occupies bits
/// `4*j .. 4*j + 4`. `A/C/G/T` (either case) encode as `0..=3`, and the accepted ambiguous
/// IUPAC codes (`N R Y S W K M B D H V`, either case) encode as `0b1111`.
///
/// Inputs shorter than 16 bytes are delegated to [`as_4bit`] and the result is written
/// to `output[0]`. For longer inputs the whole of `output` is zeroed first.
///
/// # Errors
/// Returns `Error::InvalidBase` carrying the first byte that is not an accepted code.
/// On error, `output` may already be partially written.
///
/// # Panics
/// * if `input` is shorter than 16 bytes, encodes successfully, and `output` is empty;
/// * if `input` has 16 or more bytes and `output` holds fewer than
///   `input.len().div_ceil(16)` words.
///
/// # Safety
/// Calls NEON intrinsics. Presumably the only requirement is that the target supports
/// NEON (TODO confirm). The output length is checked by an assertion rather than left
/// to the caller.
// Intrinsic calls are not individually wrapped; the whole body is one unsafe context.
#[allow(unsafe_op_in_unsafe_fn)]
#[inline(always)]
pub unsafe fn encode_nucleotides_simd_4bit(input: &[u8], output: &mut [u64]) -> Result<(), Error> {
    if input.len() < 16 {
        let tail = as_4bit(input)?;
        output[0] = tail;
        return Ok(());
    }

    let full_blocks = input.len() / 16;
    let words_needed = input.len().div_ceil(16);
    // Without this check a short `output` would be written past its end.
    assert!(
        output.len() >= words_needed,
        "output holds {} words but {} are needed for {} bases",
        output.len(),
        words_needed,
        input.len()
    );
    output.fill(0);

    for (block, word) in input.chunks_exact(16).zip(output.iter_mut()) {
        // `block` is exactly the 16 bytes that `vld1q_u8` reads.
        let v = vld1q_u8(block.as_ptr());
        if !valid_block_4bit(v) {
            // The vector check only says that *some* lane is bad; report the actual
            // offending byte, not just the first byte of the block.
            let bad = block
                .iter()
                .copied()
                .find(|&b| !is_valid_nucleotide_4bit(b))
                .unwrap_or(block[0]);
            return Err(Error::InvalidBase(bad));
        }
        let low = encode_8_nucleotides_4bit(vget_low_u8(v));
        let high = encode_8_nucleotides_4bit(vget_high_u8(v));
        *word = u64::from(low) | (u64::from(high) << 32);
    }

    let tail = &input[full_blocks * 16..];
    if !tail.is_empty() {
        let mut packed = 0u64;
        for (i, &base) in tail.iter().enumerate() {
            // OR-ing 0x20 folds ASCII upper-case letters onto lower-case.
            let code: u64 = match base | 0x20 {
                b'a' => 0,
                b'c' => 1,
                b'g' => 2,
                b't' => 3,
                b'n' | b'r' | b'y' | b's' | b'w' | b'k' | b'm' | b'b' | b'd' | b'h' | b'v' => 15,
                _ => return Err(Error::InvalidBase(base)),
            };
            packed |= code << (4 * i);
        }
        output[full_blocks] = packed;
    }
    Ok(())
}
/// Encodes `sequence` into `ebuf` as 4-bit codes packed sixteen to a `u64`.
///
/// Sequences shorter than 16 bases are packed by `naive::as_4bit` into one word, and
/// `ebuf` is replaced by that single word. Longer sequences resize `ebuf` to
/// `sequence.len().div_ceil(16)` words and fill it with
/// [`encode_nucleotides_simd_4bit`].
///
/// # Errors
/// Propagates the encoder's `Error`. On the short path `ebuf` is untouched when encoding
/// fails. On the long path it has already been resized and may be partially written.
pub fn encode_internal(sequence: &[u8], ebuf: &mut Vec<u64>) -> Result<(), Error> {
    if sequence.len() >= 16 {
        ebuf.resize(sequence.len().div_ceil(16), 0);
        // SAFETY: `ebuf` now holds exactly `sequence.len().div_ceil(16)` words, which is
        // the number of words the SIMD encoder writes for this input.
        return unsafe { encode_nucleotides_simd_4bit(sequence, ebuf) };
    }
    let word = naive::as_4bit(sequence)?;
    ebuf.clear();
    ebuf.push(word);
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_4bit_basic_encoding() {
        let tests = [
            (b"ACGT", 0b0011001000010000),
            (b"AAAA", 0b0000000000000000),
            (b"TTTT", 0b0011001100110011),
            (b"NNNN", 0b1111111111111111),
        ];
        for (input, expected) in tests {
            assert_eq!(as_4bit(input).unwrap(), expected);
        }
    }

    #[test]
    fn test_4bit_ambiguous_bases() {
        let seq_with_n = b"ACGN";
        let seq_with_r = b"ACGR";
        let n_result = as_4bit(seq_with_n).unwrap();
        let r_result = as_4bit(seq_with_r).unwrap();
        assert_eq!(n_result, r_result);
        let last_bits_n = (n_result >> 12) & 0b1111;
        let last_bits_r = (r_result >> 12) & 0b1111;
        assert_eq!(last_bits_n, 0b1111);
        assert_eq!(last_bits_r, 0b1111);
    }

    #[test]
    fn test_4bit_sequence_too_long() {
        let long_seq = vec![b'A'; 17];
        assert!(matches!(
            as_4bit(&long_seq),
            Err(Error::SequenceTooLong(17))
        ));
    }

    #[test]
    fn test_4bit_case_insensitive() {
        assert_eq!(as_4bit(b"acgt").unwrap(), as_4bit(b"ACGT").unwrap());
        assert_eq!(as_4bit(b"nrys").unwrap(), as_4bit(b"NRYS").unwrap());
    }

    // The tests above use inputs shorter than eight bases, so they only reach
    // `naive::as_4bit`. The tests below exercise the NEON chunk path, the scalar
    // tail and `encode_internal`.

    #[test]
    fn test_4bit_simd_single_chunk() {
        assert_eq!(as_4bit(b"ACGTACGT").unwrap(), 0x3210_3210);
        assert_eq!(as_4bit(b"acgtacgt").unwrap(), 0x3210_3210);
    }

    #[test]
    fn test_4bit_simd_chunk_with_tail() {
        // Eight vectorised bases followed by a four-base scalar tail.
        assert_eq!(as_4bit(b"ACGTACGTNNAC").unwrap(), 0x10FF_3210_3210);
    }

    #[test]
    fn test_4bit_full_sixteen() {
        assert_eq!(
            as_4bit(b"ACGTACGTACGTACGT").unwrap(),
            0x3210_3210_3210_3210
        );
    }

    #[test]
    fn test_4bit_simd_rejects_invalid() {
        assert!(matches!(
            as_4bit(b"ACGTACGX"),
            Err(Error::InvalidBase(b'X'))
        ));
    }

    #[test]
    fn test_encode_internal_multi_word() {
        let mut ebuf = Vec::new();
        encode_internal(b"ACGTACGTACGTACGTacgt", &mut ebuf).unwrap();
        assert_eq!(ebuf, vec![0x3210_3210_3210_3210, 0x3210]);
    }

    #[test]
    fn test_encode_internal_invalid_tail() {
        let mut ebuf = Vec::new();
        let result = encode_internal(b"ACGTACGTACGTACGTACGX", &mut ebuf);
        assert!(matches!(result, Err(Error::InvalidBase(b'X'))));
    }
}