use super::naive;
use crate::error::Error;
use std::arch::aarch64::*;
/// Two-bit code assigned to each nucleotide.
///
/// The discriminants are the values written into the packed output, so the
/// enum is `repr(u8)` and is cast with `as` wherever a code is needed.
#[repr(u8)]
enum NucleotideBits {
    /// Adenine, bit pattern `00`.
    A = 0,
    /// Cytosine, bit pattern `01`.
    C = 1,
    /// Guanine, bit pattern `10`.
    G = 2,
    /// Thymine, bit pattern `11`.
    T = 3,
}
/// Bundle of 8-lane NEON vectors, one per nucleotide code.
///
/// `SimdConstants::new` broadcasts a single `NucleotideBits` value into every
/// lane of each field, so the vectors can be used as the "selected" operand of
/// a per-lane bit-select without re-creating them for every chunk.
#[repr(align(16))] struct SimdConstants {
/// Every lane holds `NucleotideBits::A`.
zeros: uint8x8_t,
/// Every lane holds `NucleotideBits::C`.
ones: uint8x8_t,
/// Every lane holds `NucleotideBits::G`.
twos: uint8x8_t,
/// Every lane holds `NucleotideBits::T`.
threes: uint8x8_t,
}
impl SimdConstants {
    /// Builds the four broadcast vectors, one per nucleotide code.
    ///
    /// # Safety
    ///
    /// Calls NEON intrinsics; the executing CPU must support NEON.
    // The intrinsics are called directly in the body of this `unsafe fn`.
    #[allow(unsafe_op_in_unsafe_fn)]
    #[inline(always)]
    unsafe fn new() -> Self {
        let zeros = vdup_n_u8(NucleotideBits::A as u8);
        let ones = vdup_n_u8(NucleotideBits::C as u8);
        let twos = vdup_n_u8(NucleotideBits::G as u8);
        let threes = vdup_n_u8(NucleotideBits::T as u8);
        Self {
            zeros,
            ones,
            twos,
            threes,
        }
    }
}
/// Returns a per-lane mask that is `0xFF` where `chunk` equals `upper` or
/// `lower`, and `0x00` everywhere else.
///
/// # Safety
///
/// Calls NEON intrinsics; the executing CPU must support NEON.
// The intrinsics are called directly in the body of this `unsafe fn`.
#[allow(unsafe_op_in_unsafe_fn)]
#[inline(always)]
unsafe fn create_dual_pattern_mask(chunk: uint8x8_t, upper: u8, lower: u8) -> uint8x8_t {
    let upper_hits = vceq_u8(chunk, vdup_n_u8(upper));
    let lower_hits = vceq_u8(chunk, vdup_n_u8(lower));
    vorr_u8(upper_hits, lower_hits)
}
/// Turns the three per-lane match masks into 2-bit codes.
///
/// Lanes start at the `zeros` constant and are overwritten, in order, by the
/// `ones`, `twos` and `threes` constants wherever the corresponding mask is
/// set. A lane that is set in none of the masks keeps the `zeros` value.
///
/// # Safety
///
/// Calls NEON intrinsics; the executing CPU must support NEON.
// The intrinsics are called directly in the body of this `unsafe fn`.
#[allow(unsafe_op_in_unsafe_fn)]
#[inline(always)]
unsafe fn set_bits(
    c_mask: uint8x8_t,
    g_mask: uint8x8_t,
    t_mask: uint8x8_t,
    constants: &SimdConstants,
) -> uint8x8_t {
    let after_c = vbsl_u8(c_mask, constants.ones, constants.zeros);
    let after_g = vbsl_u8(g_mask, constants.twos, after_c);
    vbsl_u8(t_mask, constants.threes, after_g)
}
/// Encodes eight ASCII nucleotides (one per lane) into eight 2-bit codes.
///
/// `C`/`c`, `G`/`g` and `T`/`t` are detected explicitly; every other byte,
/// including `A`/`a`, ends up with the `zeros` constant.
///
/// # Safety
///
/// Calls NEON intrinsics; the executing CPU must support NEON.
// The helpers are `unsafe fn`s called directly in the body of this `unsafe fn`.
#[allow(unsafe_op_in_unsafe_fn)]
#[inline(always)]
unsafe fn process_simd_chunk(chunk: uint8x8_t, constants: &SimdConstants) -> uint8x8_t {
    let is_c = create_dual_pattern_mask(chunk, b'C', b'c');
    let is_g = create_dual_pattern_mask(chunk, b'G', b'g');
    let is_t = create_dual_pattern_mask(chunk, b'T', b't');
    set_bits(is_c, is_g, is_t, constants)
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
pub fn as_2bit(seq: &[u8], allow_invalid: bool) -> Result<u64, Error> {
if seq.len() > 32 {
return Err(Error::SequenceTooLong(seq.len()));
}
if seq.len() < 8 {
return naive::as_2bit(seq, allow_invalid);
}
if !allow_invalid
&& let Some(&invalid) = seq
.iter()
.find(|&&b| !matches!(b, b'A' | b'a' | b'C' | b'c' | b'G' | b'g' | b'T' | b't'))
{
return Err(Error::InvalidBase(invalid));
}
let mut packed = 0u64;
let len = seq.len();
let simd_len = len - (len % 8);
unsafe {
let constants = SimdConstants::new();
for chunk_idx in (0..simd_len).step_by(8) {
let chunk = vld1_u8(seq[chunk_idx..].as_ptr());
let result = process_simd_chunk(chunk, &constants);
let mut temp = [0u8; 8];
vst1_u8(temp.as_mut_ptr(), result);
for (i, &val) in temp.iter().enumerate() {
packed |= (val as u64) << ((chunk_idx + i) * 2);
}
}
for (i, &base) in seq.iter().skip(simd_len).enumerate() {
let bits = match base {
b'A' | b'a' => NucleotideBits::A as u64,
b'C' | b'c' => NucleotideBits::C as u64,
b'G' | b'g' => NucleotideBits::G as u64,
b'T' | b't' => NucleotideBits::T as u64,
_ => NucleotideBits::A as u64, };
packed |= bits << ((simd_len + i) * 2);
}
}
Ok(packed)
}
/// Packs sixteen ASCII nucleotides into a `u32`, base `i` at bits `2*i..2*i+2`.
///
/// The 2-bit code of each byte `b` is `((b >> 1) ^ (b >> 2)) & 3`, which gives
/// 0, 1, 2, 3 for `A`, `C`, `G`, `T` in either case. No validation happens
/// here: any other byte is encoded as whatever that formula produces.
///
/// # Safety
///
/// Calls NEON intrinsics; the executing CPU must support NEON.
// The intrinsics are called directly in the body of this `unsafe fn`.
#[allow(unsafe_op_in_unsafe_fn)]
#[inline(always)]
pub unsafe fn encode_16_nucleotides(nucs: uint8x16_t) -> u32 {
    // One 2-bit code per byte lane.
    let two_bit = vandq_u8(
        veorq_u8(vshrq_n_u8(nucs, 1), vshrq_n_u8(nucs, 2)),
        vdupq_n_u8(3),
    );

    // Merge neighbouring lanes: (even, odd) pairs become 4-bit values.
    let lo_codes = vuzp1q_u8(two_bit, two_bit);
    let hi_codes = vuzp2q_u8(two_bit, two_bit);
    let four_bit = vorrq_u8(lo_codes, vshlq_n_u8(hi_codes, 2));

    // Merge again: pairs of 4-bit values become full bytes.
    let lo_nibbles = vuzp1q_u8(four_bit, four_bit);
    let hi_nibbles = vuzp2q_u8(four_bit, four_bit);
    let bytes = vorrq_u8(lo_nibbles, vshlq_n_u8(hi_nibbles, 4));

    // The first four bytes now carry all sixteen codes.
    vgetq_lane_u32(vreinterpretq_u32_u8(bytes), 0)
}
/// Reports whether all sixteen bytes of `v` are one of `ACGT`/`acgt`.
///
/// Bit `0x20` is forced on first so a single comparison covers both cases of
/// each letter.
///
/// # Safety
///
/// Calls NEON intrinsics; the executing CPU must support NEON.
// The intrinsics are called directly in the body of this `unsafe fn`.
#[allow(unsafe_op_in_unsafe_fn)]
#[inline(always)]
unsafe fn valid_block(v: uint8x16_t) -> bool {
    let folded = vorrq_u8(v, vdupq_n_u8(0x20));
    let a_or_c = vorrq_u8(
        vceqq_u8(folded, vdupq_n_u8(b'a')),
        vceqq_u8(folded, vdupq_n_u8(b'c')),
    );
    let g_or_t = vorrq_u8(
        vceqq_u8(folded, vdupq_n_u8(b'g')),
        vceqq_u8(folded, vdupq_n_u8(b't')),
    );
    let recognised = vorrq_u8(a_or_c, g_or_t);
    // Every lane is 0xFF or 0x00, so the minimum is 0xFF only if all matched.
    vminvq_u8(recognised) == u8::MAX
}
/// Encodes `input` into 2-bit codes, 32 bases per `u64` word of `output`.
///
/// Inputs shorter than 32 bases are handed to [`as_2bit`] and stored in
/// `output[0]`. Longer inputs are processed in 32-byte blocks with NEON, and a
/// final partial block (if any) is packed by a scalar loop into the next word.
/// For inputs of 32 bases or more, the whole of `output` is zeroed before
/// encoding starts.
///
/// With `allow_invalid`, unrecognised bytes in the scalar tail are encoded as
/// 0, while unrecognised bytes inside a full 32-byte block are encoded as
/// whatever `encode_16_nucleotides` yields for them.
///
/// # Errors
///
/// Returns `Error::InvalidBase` carrying the first offending byte when
/// `allow_invalid` is `false` and a byte other than `ACGT`/`acgt` is found.
/// When that happens on the block/tail path, `output` is zeroed before
/// returning. Errors from [`as_2bit`] are propagated for short inputs.
///
/// # Panics
///
/// Panics if `output` is empty for a short input, or holds fewer than
/// `input.len().div_ceil(32)` words for a long one.
///
/// # Safety
///
/// Calls NEON intrinsics; the executing CPU must support NEON. All memory
/// accesses are bounds-checked or derived from exactly sized slices.
#[cfg(target_arch = "aarch64")]
// The intrinsics/helpers are called directly in the body of this `unsafe fn`.
#[allow(unsafe_op_in_unsafe_fn)]
#[inline(always)]
pub unsafe fn encode_nucleotides_simd(
    input: &[u8],
    output: &mut [u64],
    allow_invalid: bool,
) -> Result<(), Error> {
    if input.len() < 32 {
        let tail = as_2bit(input, allow_invalid)?;
        output[0] = tail;
        return Ok(());
    }

    // Fail loudly up front instead of writing past the end of `output`.
    let words_needed = input.len().div_ceil(32);
    assert!(
        output.len() >= words_needed,
        "output holds {} words but {} are required",
        output.len(),
        words_needed
    );
    output.fill(0);

    let blocks = input.chunks_exact(32);
    let remainder = blocks.remainder();

    for (word_idx, block) in blocks.enumerate() {
        // `block` is exactly 32 bytes, so both 16-byte loads are in bounds.
        let v0 = vld1q_u8(block.as_ptr());
        let v1 = vld1q_u8(block[16..].as_ptr());
        if !allow_invalid && (!valid_block(v0) || !valid_block(v1)) {
            // Report the byte that actually failed validation rather than the
            // first byte of the block. The fallback to `block[0]` cannot be hit
            // while this predicate matches `valid_block`.
            let offender = block
                .iter()
                .copied()
                .find(|&b| !matches!(b | 0x20, b'a' | b'c' | b'g' | b't'))
                .unwrap_or(block[0]);
            output.fill(0);
            return Err(Error::InvalidBase(offender));
        }
        output[word_idx] =
            u64::from(encode_16_nucleotides(v0)) | (u64::from(encode_16_nucleotides(v1)) << 32);
    }

    if !remainder.is_empty() {
        let mut tail = 0u64;
        for (i, &base) in remainder.iter().enumerate() {
            let code = match base | 0x20 {
                b'a' => 0u64,
                b'c' => 1,
                b'g' => 2,
                b't' => 3,
                _ if allow_invalid => 0,
                _ => {
                    // Match the block path: never leave partial results behind.
                    output.fill(0);
                    return Err(Error::InvalidBase(base));
                }
            };
            tail |= code << (2 * i);
        }
        // A non-empty remainder always lands in the last required word.
        output[words_needed - 1] = tail;
    }
    Ok(())
}
#[inline(always)]
pub fn encode_internal(
sequence: &[u8],
ebuf: &mut Vec<u64>,
allow_invalid: bool,
) -> Result<(), Error> {
if sequence.len() < 32 {
let bits = as_2bit(sequence, allow_invalid)?;
ebuf.push(bits);
return Ok(());
}
#[cfg(all(target_arch = "aarch64", not(feature = "nosimd")))]
if std::arch::is_aarch64_feature_detected!("neon") {
unsafe {
let n_chunks = sequence.len().div_ceil(32);
ebuf.resize(n_chunks, 0);
encode_nucleotides_simd(sequence, ebuf, allow_invalid)?;
}
return Ok(());
}
Ok(())
}