use std::arch::aarch64::*;
use crate::Error;
/// Decodes the four lowest 4-bit codes of `packed` through `lookup`.
///
/// Lane `k` (for `k` in `0..4`) of the result is the `lookup` byte selected by
/// bits `4k..4k + 4` of `packed`. Lanes 4–7 are looked up with index 0, so they
/// carry the table's first entry and are padding from the caller's point of view.
///
/// # Safety
///
/// Calls NEON intrinsics; the only pointer involved refers to a local 8-byte array.
// The intrinsic calls below are unsafe operations written directly in the body.
#[allow(unsafe_op_in_unsafe_fn)]
#[inline(always)]
unsafe fn unpack_4_bases(packed: u64, lookup: uint8x16_t) -> uint8x8_t {
    // Pull out one nibble, counting from the least-significant end.
    let nibble = |slot: u32| ((packed >> (slot * 4)) & 0xF) as u8;
    let table_indices: [u8; 8] = [nibble(0), nibble(1), nibble(2), nibble(3), 0, 0, 0, 0];
    vqtbl1_u8(lookup, vld1_u8(table_indices.as_ptr()))
}
/// Decodes the eight lowest 4-bit codes of `packed` through `lookup`.
///
/// Lane `k` of the result is the `lookup` byte selected by bits `4k..4k + 4`
/// of `packed`, for `k` in `0..8`.
///
/// # Safety
///
/// Calls NEON intrinsics; the only pointer involved refers to a local 8-byte array.
// The intrinsic calls below are unsafe operations written directly in the body.
#[allow(unsafe_op_in_unsafe_fn)]
#[inline(always)]
unsafe fn unpack_8_bases(packed: u64, lookup: uint8x16_t) -> uint8x8_t {
    // One table index per output lane, least-significant nibble first.
    let table_indices: [u8; 8] =
        std::array::from_fn(|lane| ((packed >> (lane * 4)) & 0b1111) as u8);
    let index_lanes = vld1_u8(table_indices.as_ptr());
    vqtbl1_u8(lookup, index_lanes)
}
/// Decodes all sixteen 4-bit codes of `packed` through `lookup`.
///
/// Lane `k` of the result is the `lookup` byte selected by bits `4k..4k + 4`
/// of `packed`, for `k` in `0..16`.
///
/// # Safety
///
/// Calls NEON intrinsics; the only pointer involved refers to a local 16-byte array.
// The intrinsic calls below are unsafe operations written directly in the body.
#[allow(unsafe_op_in_unsafe_fn)]
#[inline(always)]
unsafe fn unpack_16_bases(packed: u64, lookup: uint8x16_t) -> uint8x16_t {
    let mut table_indices = [0u8; 16];
    // Peel nibbles off the low end, shifting the remainder down each step.
    let mut remaining = packed;
    for slot in &mut table_indices {
        *slot = (remaining & 0xF) as u8;
        remaining >>= 4;
    }
    vqtbl1q_u8(lookup, vld1q_u8(table_indices.as_ptr()))
}
/// Appends the decoded bases for nibble positions `start..end` of `packed` to `sequence`.
///
/// Position `p` is the 4-bit field at bit offset `4 * p`. Codes 0–3 decode to
/// `A`, `C`, `G`, `T`; every other code decodes to `N`. An empty or inverted
/// range (`end <= start`) appends nothing.
///
/// Every decoded position must be below 16: a larger position would shift past
/// the width of `packed`, which panics in builds with overflow checks enabled.
///
/// # Safety
///
/// The body performs no unsafe operations. The function remains `unsafe` only so
/// that its signature, and therefore its existing call sites, stay unchanged.
#[inline(always)]
unsafe fn process_remainder_4bit(packed: u64, start: usize, end: usize, sequence: &mut Vec<u8>) {
    static LOOKUP: [u8; 16] = *b"ACGTNNNNNNNNNNNN";
    // `extend` over an exact-size range reserves once and writes in place, so the
    // manual reserve / raw-pointer write / `set_len` sequence is not needed.
    sequence.extend((start..end).map(|position| {
        let code = (packed >> (position * 4)) & 0b1111;
        LOOKUP[code as usize]
    }));
}
/// Decodes the first `expected_size` 4-bit codes of `packed` and appends the
/// resulting bases to `sequence`.
///
/// Codes are read from the least-significant nibble upwards. Codes 0–3 decode
/// to `A`, `C`, `G`, `T`; codes 4–15 decode to `N`. Existing contents of
/// `sequence` are kept; decoded bases are appended after them.
///
/// The widest NEON helper that fits is used for the leading bases (16, 8 or 4
/// at a time) and the scalar helper finishes whatever is left.
///
/// # Errors
///
/// Returns [`Error::InvalidLength`] carrying `expected_size` when it is greater
/// than 16; nothing is appended in that case.
///
/// # Safety
///
/// Calls NEON intrinsics. All pointers handed to them refer to a static table or
/// to local arrays of the size each intrinsic reads or writes.
// The intrinsic and helper calls below are unsafe operations written directly in the body.
#[allow(unsafe_op_in_unsafe_fn)]
pub unsafe fn from_4bit_simd(
    packed: u64,
    expected_size: usize,
    sequence: &mut Vec<u8>,
) -> Result<(), Error> {
    // Code -> base table shared by every SIMD path.
    static BASE_TABLE: [u8; 16] = *b"ACGTNNNNNNNNNNNN";

    if expected_size > 16 {
        return Err(Error::InvalidLength(expected_size));
    }
    sequence.reserve(expected_size);

    match expected_size {
        // Too short for any SIMD helper: purely scalar.
        0..=3 => process_remainder_4bit(packed, 0, expected_size, sequence),
        // Exactly one group of four fits; the rest (at most three) is scalar.
        4..=7 => {
            let table = vld1q_u8(BASE_TABLE.as_ptr());
            let mut lanes = [0u8; 8];
            vst1_u8(lanes.as_mut_ptr(), unpack_4_bases(packed, table));
            sequence.extend_from_slice(&lanes[..4]);
            process_remainder_4bit(packed, 4, expected_size, sequence);
        }
        // Exactly one group of eight fits; the rest (at most seven) is scalar.
        8..=15 => {
            let table = vld1q_u8(BASE_TABLE.as_ptr());
            let mut lanes = [0u8; 8];
            vst1_u8(lanes.as_mut_ptr(), unpack_8_bases(packed, table));
            sequence.extend_from_slice(&lanes);
            process_remainder_4bit(packed, 8, expected_size, sequence);
        }
        // Only 16 can reach this arm: larger sizes were rejected above.
        _ => {
            let table = vld1q_u8(BASE_TABLE.as_ptr());
            let mut lanes = [0u8; 16];
            vst1q_u8(lanes.as_mut_ptr(), unpack_16_bases(packed, table));
            sequence.extend_from_slice(&lanes[..expected_size]);
        }
    }
    Ok(())
}
/// Decodes `n_bases` bases from the packed words in `ebuf` and appends them to
/// `sequence`.
///
/// Each `u64` holds sixteen 4-bit codes, least-significant nibble first. Codes
/// 0–3 decode to `A`, `C`, `G`, `T`; codes 4–15 decode to `N`. When `n_bases`
/// is not a multiple of 16, only the leading `n_bases % 16` codes of the final
/// word are used. Words in `ebuf` beyond those needed are ignored.
///
/// # Errors
///
/// Returns [`Error::InvalidLength`] carrying `n_bases` when `ebuf` holds fewer
/// than `ceil(n_bases / 16)` words. The check happens before anything is
/// reserved or appended, so `sequence` is left untouched on error.
///
/// # Safety
///
/// Calls NEON intrinsics. All pointers handed to them refer to a static table or
/// to a local 16-byte array.
// The intrinsic and helper calls below are unsafe operations written directly in the body.
#[allow(unsafe_op_in_unsafe_fn)]
#[inline(always)]
pub unsafe fn decode_internal(
    ebuf: &[u64],
    n_bases: usize,
    sequence: &mut Vec<u8>,
) -> Result<(), Error> {
    let full_chunks = n_bases / 16;
    let remaining_bases = n_bases % 16;

    // Validate the buffer up front: a short buffer used to either panic on the
    // trailing-word index or silently produce fewer than `n_bases` bases.
    let words_needed = full_chunks + usize::from(remaining_bases > 0);
    if ebuf.len() < words_needed {
        return Err(Error::InvalidLength(n_bases));
    }

    sequence.reserve(n_bases);
    let lookup = vld1q_u8(b"ACGTNNNNNNNNNNNN".as_ptr());
    let mut temp = [0u8; 16];

    // Whole words: sixteen bases each.
    for &chunk in &ebuf[..full_chunks] {
        let result = unpack_16_bases(chunk, lookup);
        vst1q_u8(temp.as_mut_ptr(), result);
        sequence.extend_from_slice(&temp);
    }

    // Trailing partial word: decode all sixteen lanes, keep only the leading ones.
    if remaining_bases > 0 {
        let result = unpack_16_bases(ebuf[full_chunks], lookup);
        vst1q_u8(temp.as_mut_ptr(), result);
        sequence.extend_from_slice(&temp[..remaining_bases]);
    }
    Ok(())
}
#[cfg(test)]
mod testing {
    use super::*;
    use crate::fourbit::as_4bit;

    /// Sixteen-base pattern reused by the round-trip tests.
    const ACGT_X4: &[u8; 16] = b"ACGTACGTACGTACGT";

    /// Packs `bases` with `as_4bit` and decodes all of them into a fresh vector.
    fn roundtrip(bases: &[u8]) -> Vec<u8> {
        let packed = as_4bit(bases).unwrap();
        let mut decoded = Vec::new();
        unsafe {
            from_4bit_simd(packed, bases.len(), &mut decoded).unwrap();
        }
        decoded
    }

    #[test]
    fn test_from_4bit_simd_basic() {
        assert_eq!(roundtrip(b"ACGT").as_slice(), &b"ACGT"[..]);
    }

    #[test]
    fn test_from_4bit_simd_with_n() {
        assert_eq!(roundtrip(b"ACGN").as_slice(), &b"ACGN"[..]);
    }

    #[test]
    fn test_from_4bit_simd_max_length() {
        assert_eq!(roundtrip(ACGT_X4).as_slice(), &ACGT_X4[..]);
    }

    #[test]
    fn test_various_lengths() {
        for len in 1..=16 {
            let prefix = &ACGT_X4[..len];
            assert_eq!(roundtrip(prefix).as_slice(), prefix, "Failed at length {}", len);
        }
    }

    #[test]
    fn test_append() {
        // Decoding the same eight leading bases twice must append, not overwrite.
        let packed = as_4bit(ACGT_X4).unwrap();
        let mut observed = Vec::new();
        for _ in 0..2 {
            unsafe {
                from_4bit_simd(packed, 8, &mut observed).unwrap();
            }
        }
        assert_eq!(observed.as_slice(), &b"ACGTACGTACGTACGT"[..]);
    }

    #[test]
    fn test_multi_chunk_decoding() {
        let first = b"ACGTACGTACGTACGT";
        let second = b"TGCATGCATGCATGCA";
        let ebuf = vec![as_4bit(first).unwrap(), as_4bit(second).unwrap()];
        let mut decoded = Vec::new();
        unsafe {
            decode_internal(&ebuf, 32, &mut decoded).unwrap();
        }
        let expected: Vec<u8> = [&first[..], &second[..]].concat();
        assert_eq!(decoded, expected);
    }

    #[test]
    fn test_partial_last_chunk() {
        // Twenty bases: one full word plus a four-base tail in the second word.
        let sequence = b"ACGTACGTACGTACGTACGT";
        let ebuf: Vec<_> = sequence
            .chunks(16)
            .map(|word| as_4bit(word).unwrap())
            .collect();
        let mut decoded = Vec::new();
        unsafe {
            decode_internal(&ebuf, 20, &mut decoded).unwrap();
        }
        assert_eq!(decoded.as_slice(), &sequence[..]);
    }
}