use std::arch::x86_64::*;
use crate::Error;
/// Decodes the four lowest 4-bit codes of `packed` into ASCII lanes 0..4 via `lookup`.
///
/// Code `k` is taken from bits `4k..4k+4`. Lanes 4..16 are shuffled with index 0, so
/// callers should only read the first four bytes of the result.
///
/// # Safety
/// The CPU must support SSSE3 (`_mm_shuffle_epi8`).
#[allow(unsafe_op_in_unsafe_fn)]
#[inline(always)]
unsafe fn unpack_4_bases(packed: u64, lookup: __m128i) -> __m128i {
    let mut nibbles = [0u8; 16];
    let mut bits = packed;
    for slot in &mut nibbles[..4] {
        *slot = (bits & 0b1111) as u8;
        bits >>= 4;
    }
    _mm_shuffle_epi8(lookup, _mm_loadu_si128(nibbles.as_ptr().cast()))
}
/// Decodes the eight lowest 4-bit codes of `packed` into ASCII lanes 0..8 via `lookup`.
///
/// Code `k` is taken from bits `4k..4k+4`. Lanes 8..16 are shuffled with index 0, so
/// callers should only read the first eight bytes of the result.
///
/// # Safety
/// The CPU must support SSSE3 (`_mm_shuffle_epi8`).
#[allow(unsafe_op_in_unsafe_fn)]
#[inline(always)]
unsafe fn unpack_8_bases(packed: u64, lookup: __m128i) -> __m128i {
    let mut nibbles = [0u8; 16];
    // Each little-endian byte carries two codes: low nibble first, then high nibble.
    for (pair, &byte) in nibbles[..8]
        .chunks_exact_mut(2)
        .zip(&packed.to_le_bytes()[..4])
    {
        pair[0] = byte & 0x0F;
        pair[1] = byte >> 4;
    }
    _mm_shuffle_epi8(lookup, _mm_loadu_si128(nibbles.as_ptr().cast()))
}
/// Decodes all sixteen 4-bit codes of `packed` into ASCII lanes 0..16 via `lookup`.
///
/// Code `k` is taken from bits `4k..4k+4`.
///
/// # Safety
/// The CPU must support SSSE3 (`_mm_shuffle_epi8`).
#[allow(unsafe_op_in_unsafe_fn)]
#[inline(always)]
unsafe fn unpack_16_bases(packed: u64, lookup: __m128i) -> __m128i {
    let mut nibbles = [0u8; 16];
    for (slot, shift) in nibbles.iter_mut().zip((0u32..64).step_by(4)) {
        *slot = ((packed >> shift) & 0b1111) as u8;
    }
    _mm_shuffle_epi8(lookup, _mm_loadu_si128(nibbles.as_ptr().cast()))
}
/// Decodes 32 packed 4-bit codes into 32 ASCII bases with one AVX2 shuffle.
///
/// `packed1` supplies bases 0..16 and `packed2` supplies bases 16..32. Within each word,
/// base `k` sits in bits `4k..4k+4` (the low nibble of each little-endian byte comes first).
/// `_mm256_shuffle_epi8` looks up each 128-bit lane independently, so `lookup` must hold
/// the same 16-entry table in both lanes for the two halves to decode identically.
///
/// # Safety
/// The CPU must support AVX2 (`_mm256_shuffle_epi8`).
#[allow(unsafe_op_in_unsafe_fn)]
#[inline(always)]
unsafe fn unpack_32_bases(packed1: u64, packed2: u64, lookup: __m256i) -> __m256i {
    // Treat the two words as one 128-bit little-endian nibble stream: nibble k of the
    // combined value is base k of the output, for k in 0..32.
    let combined = u128::from(packed1) | (u128::from(packed2) << 64);
    let mut indices = [0u8; 32];
    for (k, slot) in indices.iter_mut().enumerate() {
        *slot = ((combined >> (k * 4)) & 0b1111) as u8;
    }
    let index_vec = _mm256_loadu_si256(indices.as_ptr() as *const __m256i);
    _mm256_shuffle_epi8(lookup, index_vec)
}
/// Scalar path: appends the bases encoded in nibbles `start..end` of `packed` to `sequence`.
///
/// Nibble `k` occupies bits `4k..4k+4`. Codes 0-3 decode to `A`, `C`, `G`, `T` and all
/// other codes decode to `N`. An empty (or inverted) range appends nothing.
///
/// # Safety
/// No unsafe operations are performed; the function keeps its `unsafe` signature for
/// compatibility with existing callers. `end` must not exceed 16, the number of
/// nibbles in a `u64`; larger values overflow the shift.
#[inline(always)]
unsafe fn process_remainder_4bit(packed: u64, start: usize, end: usize, sequence: &mut Vec<u8>) {
    static LOOKUP: [u8; 16] = [
        b'A', b'C', b'G', b'T', b'N', b'N', b'N', b'N', b'N', b'N', b'N', b'N', b'N', b'N', b'N',
        b'N',
    ];
    // `extend` reserves once from the range's exact size hint and writes through safe code,
    // replacing a manual reserve / raw-pointer write / set_len sequence.
    sequence.extend((start..end).map(|k| LOOKUP[((packed >> (k * 4)) & 0b1111) as usize]));
}
/// Decodes up to 16 bases packed as 4-bit codes in `packed`, appending ASCII to `sequence`.
///
/// Base `k` is read from bits `4k..4k+4`. Codes 0-3 decode to `A`, `C`, `G`, `T` and any
/// other code decodes to `N`. The widest SIMD unpacker (16, 8 or 4 lanes) that fits within
/// `expected_size` decodes the leading bases, and any tail is decoded by the scalar path.
///
/// # Errors
/// Returns [`Error::InvalidLength`] if `expected_size` exceeds 16 (the number of 4-bit
/// codes in a `u64`). `sequence` is left untouched in that case.
///
/// # Safety
/// The CPU must support SSSE3 (`_mm_shuffle_epi8`).
#[allow(unsafe_op_in_unsafe_fn)]
pub unsafe fn from_4bit_simd(
    packed: u64,
    expected_size: usize,
    sequence: &mut Vec<u8>,
) -> Result<(), Error> {
    if expected_size > 16 {
        return Err(Error::InvalidLength(expected_size));
    }
    sequence.reserve(expected_size);

    // Number of leading bases handled by a single SIMD unpack; the remainder
    // `simd_len..expected_size` goes through the scalar path.
    let simd_len = match expected_size {
        16 => 16,
        8..=15 => 8,
        4..=7 => 4,
        _ => 0,
    };
    if simd_len > 0 {
        // Same code table as the scalar path: 0-3 -> ACGT, 4-15 -> N.
        let lookup = _mm_loadu_si128(b"ACGTNNNNNNNNNNNN".as_ptr() as *const __m128i);
        let decoded = match simd_len {
            16 => unpack_16_bases(packed, lookup),
            8 => unpack_8_bases(packed, lookup),
            _ => unpack_4_bases(packed, lookup),
        };
        let mut lanes = [0u8; 16];
        _mm_storeu_si128(lanes.as_mut_ptr() as *mut __m128i, decoded);
        sequence.extend_from_slice(&lanes[..simd_len]);
    }
    process_remainder_4bit(packed, simd_len, expected_size, sequence);
    Ok(())
}
#[allow(unsafe_op_in_unsafe_fn)]
#[inline(always)]
pub unsafe fn decode_internal(
ebuf: &[u64],
n_bases: usize,
sequence: &mut Vec<u8>,
) -> Result<(), Error> {
let lookup = _mm256_setr_epi8(
b'A' as i8, b'C' as i8, b'G' as i8, b'T' as i8, b'N' as i8, b'N' as i8, b'N' as i8,
b'N' as i8, b'N' as i8, b'N' as i8, b'N' as i8, b'N' as i8, b'N' as i8, b'N' as i8,
b'N' as i8, b'N' as i8, b'A' as i8, b'C' as i8, b'G' as i8, b'T' as i8, b'N' as i8, b'N' as i8, b'N' as i8,
b'N' as i8, b'N' as i8, b'N' as i8, b'N' as i8, b'N' as i8, b'N' as i8, b'N' as i8,
b'N' as i8, b'N' as i8,
);
let old_len = sequence.len();
sequence.reserve(n_bases);
let mut out_ptr = sequence.as_mut_ptr().add(old_len);
let full_chunks = n_bases / 32;
for i in 0..full_chunks {
let packed1 = ebuf[i * 2];
let packed2 = ebuf[i * 2 + 1];
let result = unpack_32_bases(packed1, packed2, lookup);
_mm256_storeu_si256(out_ptr as *mut __m256i, result);
out_ptr = out_ptr.add(32);
}
let remaining_bases = n_bases % 32;
if remaining_bases > 0 {
let offset = full_chunks * 2;
let packed1 = ebuf.get(offset).copied().unwrap_or(0);
let packed2 = ebuf.get(offset + 1).copied().unwrap_or(0);
let result = unpack_32_bases(packed1, packed2, lookup);
let mut temp = [0u8; 32];
_mm256_storeu_si256(temp.as_mut_ptr() as *mut __m256i, result);
std::ptr::copy_nonoverlapping(temp.as_ptr(), out_ptr, remaining_bases);
}
sequence.set_len(old_len + n_bases);
Ok(())
}
#[cfg(test)]
mod testing {
    use super::*;
    use crate::as_4bit;

    /// `from_4bit_simd` executes SSSE3 shuffles; tests skip on hosts without SSSE3
    /// instead of dying with an illegal-instruction signal.
    fn ssse3_available() -> bool {
        is_x86_feature_detected!("ssse3")
    }

    /// `decode_internal` executes AVX2 shuffles; tests skip on hosts without AVX2
    /// instead of dying with an illegal-instruction signal.
    fn avx2_available() -> bool {
        is_x86_feature_detected!("avx2")
    }

    #[test]
    fn test_from_4bit_simd_basic() {
        if !ssse3_available() {
            return;
        }
        let expected = b"ACGT";
        let packed = as_4bit(expected).unwrap();
        let mut observed = Vec::new();
        unsafe {
            from_4bit_simd(packed, 4, &mut observed).unwrap();
        }
        assert_eq!(&observed, expected);
    }

    #[test]
    fn test_from_4bit_simd_with_n() {
        if !ssse3_available() {
            return;
        }
        let expected = b"ACGN";
        let packed = as_4bit(expected).unwrap();
        let mut observed = Vec::new();
        unsafe {
            from_4bit_simd(packed, 4, &mut observed).unwrap();
        }
        assert_eq!(&observed, expected);
    }

    #[test]
    fn test_from_4bit_simd_max_length() {
        if !ssse3_available() {
            return;
        }
        let expected = b"ACGTACGTACGTACGT";
        let packed = as_4bit(expected).unwrap();
        let mut observed = Vec::new();
        unsafe {
            from_4bit_simd(packed, 16, &mut observed).unwrap();
        }
        assert_eq!(&observed, expected);
    }

    #[test]
    fn test_from_4bit_simd_rejects_oversize() {
        // The length check runs before any SIMD instruction, so no feature guard is needed.
        let mut observed = Vec::new();
        let result = unsafe { from_4bit_simd(0, 17, &mut observed) };
        assert!(result.is_err());
        assert!(observed.is_empty());
    }

    #[test]
    fn test_various_lengths() {
        if !ssse3_available() {
            return;
        }
        let input = b"ACGTACGTACGTACGT";
        for len in 1..=16 {
            let packed = as_4bit(&input[..len]).unwrap();
            let mut observed = Vec::new();
            unsafe {
                from_4bit_simd(packed, len, &mut observed).unwrap();
            }
            assert_eq!(&observed, &input[..len], "Failed at length {}", len);
        }
    }

    #[test]
    fn test_append() {
        if !ssse3_available() {
            return;
        }
        let sequence = b"ACGTACGTACGTACGT";
        let packed = as_4bit(sequence).unwrap();
        let mut observed = Vec::new();
        unsafe {
            from_4bit_simd(packed, 8, &mut observed).unwrap();
            from_4bit_simd(packed, 8, &mut observed).unwrap();
        }
        let expected = b"ACGTACGTACGTACGT";
        assert_eq!(&observed, expected);
    }

    #[test]
    fn test_multi_chunk_decoding() {
        if !avx2_available() {
            return;
        }
        let sequence1 = b"ACGTACGTACGTACGT";
        let sequence2 = b"TGCATGCATGCATGCA";
        let ebuf = vec![as_4bit(sequence1).unwrap(), as_4bit(sequence2).unwrap()];
        let mut decoded = Vec::new();
        unsafe {
            decode_internal(&ebuf, 32, &mut decoded).unwrap();
        }
        let expected: Vec<u8> = sequence1.iter().chain(sequence2.iter()).cloned().collect();
        assert_eq!(decoded, expected);
    }

    #[test]
    fn test_partial_last_chunk() {
        if !avx2_available() {
            return;
        }
        let sequence = b"ACGTACGTACGTACGTACGT";
        let ebuf = vec![
            as_4bit(&sequence[..16]).unwrap(),
            as_4bit(&sequence[16..]).unwrap(),
        ];
        let mut decoded = Vec::new();
        unsafe {
            decode_internal(&ebuf, 20, &mut decoded).unwrap();
        }
        assert_eq!(&decoded, sequence);
    }

    #[test]
    fn test_decode_empty() {
        if !avx2_available() {
            return;
        }
        let mut decoded = Vec::new();
        unsafe {
            decode_internal(&[], 0, &mut decoded).unwrap();
        }
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_decode_appends_to_existing() {
        if !avx2_available() {
            return;
        }
        let sequence1 = b"ACGTACGTACGTACGT";
        let sequence2 = b"TGCATGCATGCATGCA";
        let ebuf = vec![as_4bit(sequence1).unwrap(), as_4bit(sequence2).unwrap()];
        let mut decoded = b"xx".to_vec();
        unsafe {
            decode_internal(&ebuf, 32, &mut decoded).unwrap();
        }
        let expected: Vec<u8> = b"xx"
            .iter()
            .chain(sequence1.iter())
            .chain(sequence2.iter())
            .cloned()
            .collect();
        assert_eq!(decoded, expected);
    }
}