use crate::bitmask::FastqBitmask;
/// Returns the newline bitmask of `block`: bit `i` is set iff `block[i] == b'\n'`.
///
/// Dispatches to the fastest kernel available for the target:
/// - NEON on `aarch64`;
/// - AVX2 on `x86_64` when the CPU reports it at runtime;
/// - the portable scalar loop otherwise.
#[inline]
#[must_use]
pub fn lex_block(block: &[u8; 64]) -> u64 {
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: the NEON kernel only requires the `neon` target feature.
        // This assumes it is enabled, as it is by default on Rust's standard
        // aarch64 targets.
        unsafe { lex_block_newlines_neon(block) }
    }
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was just confirmed at runtime.
            unsafe { lex_block_newlines_avx2(block) }
        } else {
            lex_block_newlines_scalar(block)
        }
    }
    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    {
        lex_block_newlines_scalar(block)
    }
}
/// Classifies all 64 bytes of `block` in one pass.
///
/// The returned masks are indexed by byte position `i`:
/// - `newlines`: bit `i` is set for a `\n` byte;
/// - `is_acgt`: bit `i` is set for a case-insensitive `A`/`C`/`G`/`T`;
/// - `two_bits`: bits `2i..2i+2` carry the base code (A=0, C=1, G=2, T=3),
///   meaningful where `is_acgt` is set.
///
/// The kernel is chosen per target:
/// - AVX2 on `x86_64` when detected at runtime;
/// - NEON on `aarch64`;
/// - the scalar reference implementation otherwise.
#[inline]
#[must_use]
pub fn lex_block_full(block: &[u8; 64]) -> FastqBitmask {
    #[cfg(target_arch = "x86_64")]
    {
        let has_avx2 = is_x86_feature_detected!("avx2");
        if has_avx2 {
            // SAFETY: AVX2 support was confirmed at runtime just above.
            unsafe { lex_block_avx2(block) }
        } else {
            lex_block_scalar(block)
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: the NEON kernel only needs the `neon` target feature,
        // assumed enabled as on Rust's standard aarch64 targets.
        unsafe { lex_block_neon(block) }
    }
    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    {
        lex_block_scalar(block)
    }
}
/// 2-bit base codes indexed by `(byte >> 1) & 0x0F`.
///
/// That index drops the ASCII case bit (0x20 lands in bit 4 after the shift
/// and is masked off), so upper- and lower-case letters share a slot.
/// A, C, G and T map to slots 0, 1, 3 and 10, holding codes 0, 1, 2 and 3;
/// every other slot is 0. The table is exactly 16 bytes so the SIMD kernels
/// can load it as one 128-bit lookup vector.
const ENCODE_LUT: [u8; 16] = {
    // The slot a byte selects, matching the index the lexers compute.
    const fn slot(byte: u8) -> usize {
        ((byte >> 1) & 0x0F) as usize
    }
    let mut lut = [0u8; 16];
    lut[slot(b'A')] = 0;
    lut[slot(b'C')] = 1;
    lut[slot(b'G')] = 2;
    lut[slot(b'T')] = 3;
    lut
};
/// Portable newline scan: bit `i` of the result is set iff `block[i] == b'\n'`.
///
/// Fallback for targets or CPUs without a SIMD kernel.
#[allow(dead_code)] // not referenced on aarch64, where `lex_block` always uses NEON
#[inline]
fn lex_block_newlines_scalar(block: &[u8; 64]) -> u64 {
    // Walk from the last byte to the first, shifting each flag in at bit 0,
    // so that after all 64 steps byte `i` sits at bit `i`.
    block
        .iter()
        .rev()
        .fold(0u64, |mask, &byte| (mask << 1) | u64::from(byte == b'\n'))
}
/// Portable reference classifier; the tests compare the SIMD kernels against it.
///
/// For each byte position `i` in `block`:
/// - `newlines` bit `i` is set when the byte is `\n`;
/// - `is_acgt` bit `i` is set when the byte, ignoring ASCII case, is `A`,
///   `C`, `G` or `T`;
/// - `two_bits` bits `2i..2i+2` hold that base's code (A=0, C=1, G=2, T=3)
///   and stay 0 for every other byte.
#[allow(dead_code)]
#[inline]
pub fn lex_block_scalar(block: &[u8; 64]) -> FastqBitmask {
    let mut nl_mask: u64 = 0;
    let mut base_mask: u64 = 0;
    let mut packed: u128 = 0;
    for (pos, &b) in block.iter().enumerate() {
        let bit = 1u64 << pos;
        if b == b'\n' {
            nl_mask |= bit;
        }
        // Clearing 0x20 folds lower-case ASCII letters onto upper-case.
        if matches!(b & 0xDF, b'A' | b'C' | b'G' | b'T') {
            base_mask |= bit;
            let code = ENCODE_LUT[usize::from((b >> 1) & 0x0F)];
            packed |= u128::from(code) << (2 * pos);
        }
    }
    FastqBitmask {
        newlines: nl_mask,
        is_acgt: base_mask,
        two_bits: packed,
    }
}
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::uint8x16_t;
/// Packs the most significant bit of each of the 16 lanes of `cmp` into the
/// low 16 bits of the result (lane `i` -> bit `i`), like x86 `movemask`.
///
/// Intended for comparison results whose lanes are `0x00` or `0xFF`.
///
/// # Safety
/// The CPU must support NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn neon_movemask(cmp: uint8x16_t) -> u64 {
    use std::arch::aarch64::{
        vaddv_u8, vcreate_u8, vget_high_u8, vget_low_u8, vmul_u8, vshrq_n_u8,
    };
    // Reduce every lane to its top bit: 0 or 1.
    let flags = vshrq_n_u8(cmp, 7);
    // Lane k of each 8-lane half is weighted by 2^k (lane 0 is the low byte
    // of this constant). The weighted flags occupy disjoint bits, so a
    // horizontal add of one half is exactly that half's 8-bit mask.
    let place_values = vcreate_u8(0x8040_2010_0804_0201);
    let low_byte = vaddv_u8(vmul_u8(vget_low_u8(flags), place_values));
    let high_byte = vaddv_u8(vmul_u8(vget_high_u8(flags), place_values));
    (u64::from(high_byte) << 8) | u64::from(low_byte)
}
/// NEON kernel behind [`lex_block`]: newline mask of `block`, built from four
/// 16-byte loads.
///
/// # Safety
/// The CPU must support NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn lex_block_newlines_neon(block: &[u8; 64]) -> u64 {
    use std::arch::aarch64::{vceqq_u8, vdupq_n_u8, vld1q_u8};
    // SAFETY: every `chunk` from `chunks_exact(16)` is exactly 16 bytes, so
    // each 16-byte `vld1q_u8` load stays in bounds. NEON availability is the
    // caller's obligation per `# Safety`.
    unsafe {
        let needle = vdupq_n_u8(b'\n');
        let mut mask: u64 = 0;
        for (group, chunk) in block.chunks_exact(16).enumerate() {
            let bytes = vld1q_u8(chunk.as_ptr());
            mask |= neon_movemask(vceqq_u8(bytes, needle)) << (group * 16);
        }
        mask
    }
}
/// NEON kernel behind [`lex_block_full`]: classifies `block` in four 16-byte
/// chunks.
///
/// Produces the same masks as [`lex_block_scalar`]. In particular `two_bits`
/// is zero for every byte that is not a case-insensitive `A`/`C`/`G`/`T`.
///
/// # Safety
/// The CPU must support NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn lex_block_neon(block: &[u8; 64]) -> FastqBitmask {
    use std::arch::aarch64::{
        vandq_u8, vceqq_u8, vdupq_n_u8, vld1q_u8, vorrq_u8, vqtbl1q_u8, vshrq_n_u8,
    };
    // Packs the 2-bit base code of each of the 16 bytes in `v` into a `u32`
    // (byte `i` -> bits `2i..2i+2`). Lanes where `acgt_cmp` is 0x00 encode as 0.
    #[inline]
    #[allow(clippy::cast_possible_truncation)]
    unsafe fn neon_two_bits(v: uint8x16_t, lut: uint8x16_t, acgt_cmp: uint8x16_t) -> u32 {
        unsafe {
            let idx = vandq_u8(vshrq_n_u8::<1>(v), vdupq_n_u8(0x0F));
            // Non-base bytes can land on a non-zero LUT slot: `B`/`b` share
            // C's slot, `F`/`f` share G's, and `4`/`5`/`U`/`u` share T's.
            // Masking with the ACGT compare zeroes those lanes, matching
            // `lex_block_scalar`.
            let encoded = vandq_u8(vqtbl1q_u8(lut, idx), acgt_cmp);
            let bit0_mask = vdupq_n_u8(0x01);
            let bit0 = vandq_u8(encoded, bit0_mask);
            let bit1 = vandq_u8(vshrq_n_u8::<1>(encoded), bit0_mask);
            // `neon_movemask` only sets the low 16 bits, so narrowing is lossless.
            let bit0_packed = neon_movemask(vceqq_u8(bit0, bit0_mask)) as u16;
            let bit1_packed = neon_movemask(vceqq_u8(bit1, bit0_mask)) as u16;
            interleave_bits(bit0_packed, bit1_packed)
        }
    }
    // SAFETY: `chunk_idx` ranges over 0..4, so every 16-byte load at
    // `chunk_idx * 16` stays inside the 64-byte block. `ENCODE_LUT` is exactly
    // 16 bytes. NEON availability is the caller's obligation per `# Safety`.
    unsafe {
        let newline_vec = vdupq_n_u8(b'\n');
        // Clearing 0x20 folds lower-case ASCII letters onto upper-case.
        let case_mask = vdupq_n_u8(0xDF);
        let a_vec = vdupq_n_u8(b'A');
        let c_vec = vdupq_n_u8(b'C');
        let g_vec = vdupq_n_u8(b'G');
        let t_vec = vdupq_n_u8(b'T');
        let lut = vld1q_u8(ENCODE_LUT.as_ptr());
        let mut newlines: u64 = 0;
        let mut is_acgt: u64 = 0;
        let mut two_bits: u128 = 0;
        for chunk_idx in 0..4 {
            let v = vld1q_u8(block.as_ptr().add(chunk_idx * 16));
            let nl_cmp = vceqq_u8(v, newline_vec);
            newlines |= neon_movemask(nl_cmp) << (chunk_idx * 16);
            let upper = vandq_u8(v, case_mask);
            let acgt_cmp = vorrq_u8(
                vorrq_u8(vceqq_u8(upper, a_vec), vceqq_u8(upper, c_vec)),
                vorrq_u8(vceqq_u8(upper, g_vec), vceqq_u8(upper, t_vec)),
            );
            is_acgt |= neon_movemask(acgt_cmp) << (chunk_idx * 16);
            let chunk_two_bits = u128::from(neon_two_bits(v, lut, acgt_cmp));
            two_bits |= chunk_two_bits << (chunk_idx * 32);
        }
        FastqBitmask { newlines, is_acgt, two_bits }
    }
}
/// Reinterprets the `i32` returned by `_mm256_movemask_epi8` as a `u32` mask.
///
/// No bits change: the flag for byte 31 stays in bit 31.
#[cfg(target_arch = "x86_64")]
#[inline]
const fn movemask_u32(mask: i32) -> u32 {
    u32::from_ne_bytes(mask.to_ne_bytes())
}
/// Reinterprets a byte as `i8` (two's complement, bits unchanged).
///
/// Used for the `i8`-typed `_mm256_set1_epi8` splat.
#[cfg(target_arch = "x86_64")]
#[inline]
const fn byte_as_i8(b: u8) -> i8 {
    i8::from_le_bytes(b.to_le_bytes())
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn lex_block_newlines_avx2(block: &[u8; 64]) -> u64 {
use std::arch::x86_64::{
_mm256_cmpeq_epi8, _mm256_loadu_si256, _mm256_movemask_epi8, _mm256_set1_epi8,
};
unsafe {
let newline = _mm256_set1_epi8(i8::from_ne_bytes([b'\n']));
let v1 = _mm256_loadu_si256(block.as_ptr().cast());
let v2 = _mm256_loadu_si256(block.as_ptr().add(32).cast());
let nl1 = movemask_u32(_mm256_movemask_epi8(_mm256_cmpeq_epi8(v1, newline)));
let nl2 = movemask_u32(_mm256_movemask_epi8(_mm256_cmpeq_epi8(v2, newline)));
u64::from(nl1) | (u64::from(nl2) << 32)
}
}
/// AVX2 kernel behind [`lex_block_full`]: classifies `block` as two 32-byte
/// halves.
///
/// Produces the same masks as [`lex_block_scalar`]. In particular `two_bits`
/// is zero for every byte that is not a case-insensitive `A`/`C`/`G`/`T`.
///
/// # Safety
/// The CPU must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn lex_block_avx2(block: &[u8; 64]) -> FastqBitmask {
    use std::arch::x86_64::{
        __m256i, _mm256_and_si256, _mm256_cmpeq_epi8, _mm256_loadu_si256, _mm256_movemask_epi8,
        _mm256_or_si256, _mm256_set1_epi8, _mm256_shuffle_epi8, _mm256_srli_epi16,
    };
    // Packs the 2-bit base code of each of the 32 bytes in `v` into a `u64`
    // (byte `i` -> bits `2i..2i+2`). Lanes where `acgt` is 0x00 encode as 0.
    //
    // `target_feature` is repeated here because a nested fn does not inherit
    // it from `lex_block_avx2`. Without it the AVX2 intrinsics could not be
    // inlined into this helper.
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn avx2_two_bits(v: __m256i, lut: __m256i, acgt: __m256i, one_mask: __m256i) -> u64 {
        unsafe {
            // The 16-bit shift can pull a bit of the neighbouring byte into
            // bit 7. The 0x0F mask discards it.
            let idx = _mm256_and_si256(_mm256_srli_epi16(v, 1), _mm256_set1_epi8(0x0F));
            // `_mm256_shuffle_epi8` looks up within each 128-bit lane, hence
            // the duplicated LUT. Non-base bytes can land on a non-zero slot:
            // `B`/`b` share C's slot, `F`/`f` share G's, and `4`/`5`/`U`/`u`
            // share T's. Masking with the ACGT compare zeroes those lanes,
            // matching `lex_block_scalar`.
            let encoded = _mm256_and_si256(_mm256_shuffle_epi8(lut, idx), acgt);
            let bit0 = _mm256_and_si256(encoded, one_mask);
            let bit1 = _mm256_and_si256(_mm256_srli_epi16(encoded, 1), one_mask);
            let bit0_mask = movemask_u32(_mm256_movemask_epi8(_mm256_cmpeq_epi8(bit0, one_mask)));
            let bit1_mask = movemask_u32(_mm256_movemask_epi8(_mm256_cmpeq_epi8(bit1, one_mask)));
            interleave_bits_32(bit0_mask, bit1_mask)
        }
    }
    // SAFETY: the two unaligned 32-byte loads start at offsets 0 and 32 of the
    // 64-byte block. The LUT load reads the 32-byte `[ENCODE_LUT; 2]`
    // temporary, which lives until the end of that statement. AVX2 availability
    // is the caller's obligation per `# Safety`.
    unsafe {
        let newline = _mm256_set1_epi8(byte_as_i8(b'\n'));
        // Clearing 0x20 folds lower-case ASCII letters onto upper-case.
        let case_mask = _mm256_set1_epi8(byte_as_i8(0xDF));
        let a_vec = _mm256_set1_epi8(byte_as_i8(b'A'));
        let c_vec = _mm256_set1_epi8(byte_as_i8(b'C'));
        let g_vec = _mm256_set1_epi8(byte_as_i8(b'G'));
        let t_vec = _mm256_set1_epi8(byte_as_i8(b'T'));
        let lut = _mm256_loadu_si256([ENCODE_LUT, ENCODE_LUT].as_ptr().cast());
        let one_mask = _mm256_set1_epi8(0x01);
        let v1 = _mm256_loadu_si256(block.as_ptr().cast());
        let v2 = _mm256_loadu_si256(block.as_ptr().add(32).cast());
        let nl1 = movemask_u32(_mm256_movemask_epi8(_mm256_cmpeq_epi8(v1, newline)));
        let nl2 = movemask_u32(_mm256_movemask_epi8(_mm256_cmpeq_epi8(v2, newline)));
        let newlines = u64::from(nl1) | (u64::from(nl2) << 32);
        let upper1 = _mm256_and_si256(v1, case_mask);
        let upper2 = _mm256_and_si256(v2, case_mask);
        let acgt1 = _mm256_or_si256(
            _mm256_or_si256(_mm256_cmpeq_epi8(upper1, a_vec), _mm256_cmpeq_epi8(upper1, c_vec)),
            _mm256_or_si256(_mm256_cmpeq_epi8(upper1, g_vec), _mm256_cmpeq_epi8(upper1, t_vec)),
        );
        let acgt2 = _mm256_or_si256(
            _mm256_or_si256(_mm256_cmpeq_epi8(upper2, a_vec), _mm256_cmpeq_epi8(upper2, c_vec)),
            _mm256_or_si256(_mm256_cmpeq_epi8(upper2, g_vec), _mm256_cmpeq_epi8(upper2, t_vec)),
        );
        let acgt_mask1 = movemask_u32(_mm256_movemask_epi8(acgt1));
        let acgt_mask2 = movemask_u32(_mm256_movemask_epi8(acgt2));
        let is_acgt = u64::from(acgt_mask1) | (u64::from(acgt_mask2) << 32);
        let tb1 = avx2_two_bits(v1, lut, acgt1, one_mask);
        let tb2 = avx2_two_bits(v2, lut, acgt2, one_mask);
        let two_bits = u128::from(tb1) | (u128::from(tb2) << 64);
        FastqBitmask { newlines, is_acgt, two_bits }
    }
}
/// Interleaves the low `N` bits of `lo` and `hi`: bit `i` of `lo` goes to
/// bit `2i` and bit `i` of `hi` to bit `2i + 1` of the result.
///
/// `N` must be at most 32 so the `2N` output bits fit in a `u64`. Higher
/// input bits are ignored.
///
/// This uses the standard mask-and-shift "bit spread" (5 steps per input)
/// instead of a per-bit loop. It runs for every 16/32-byte SIMD chunk.
#[cfg(any(target_arch = "aarch64", target_arch = "x86_64"))]
#[inline]
const fn interleave_bits_n<const N: usize>(lo: u64, hi: u64) -> u64 {
    // Moves bit `i` of the low 32 bits of `x` to bit `2i`, leaving odd bits 0.
    const fn spread(x: u64) -> u64 {
        let mut x = x & 0xFFFF_FFFF;
        x = (x | (x << 16)) & 0x0000_FFFF_0000_FFFF;
        x = (x | (x << 8)) & 0x00FF_00FF_00FF_00FF;
        x = (x | (x << 4)) & 0x0F0F_0F0F_0F0F_0F0F;
        x = (x | (x << 2)) & 0x3333_3333_3333_3333;
        x = (x | (x << 1)) & 0x5555_5555_5555_5555;
        x
    }
    debug_assert!(N <= 32, "interleave_bits_n supports at most 32 bits per input");
    let keep: u64 = if N >= 32 { 0xFFFF_FFFF } else { (1u64 << N) - 1 };
    spread(lo & keep) | (spread(hi & keep) << 1)
}
/// Interleaves two 16-bit masks into 32 bits: bit `i` of `lo` goes to bit
/// `2i`, bit `i` of `hi` to bit `2i + 1`.
#[cfg(target_arch = "aarch64")]
#[expect(clippy::cast_possible_truncation, reason = "16 interleaved bits fit in u32")]
#[inline]
const fn interleave_bits(lo: u16, hi: u16) -> u32 {
    let widened = interleave_bits_n::<16>(lo as u64, hi as u64);
    // With N = 16 only the low 32 bits can be set, so this keeps every bit.
    widened as u32
}
/// Interleaves two 32-bit masks into 64 bits: bit `i` of `lo` goes to bit
/// `2i`, bit `i` of `hi` to bit `2i + 1`.
#[cfg(target_arch = "x86_64")]
#[inline]
const fn interleave_bits_32(lo: u32, hi: u32) -> u64 {
    let (even, odd) = (lo as u64, hi as u64);
    interleave_bits_n::<32>(even, odd)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Copies up to 64 bytes of `data` into a zero-padded 64-byte block.
    fn make_block(data: &[u8]) -> [u8; 64] {
        let mut block = [0u8; 64];
        let len = data.len().min(64);
        block[..len].copy_from_slice(&data[..len]);
        block
    }

    #[test]
    fn test_scalar_no_newlines() {
        let block = make_block(b"ACGTACGTACGTACGTACGTACGTACGTACGT");
        let bm = lex_block_scalar(&block);
        assert_eq!(bm.newlines, 0);
    }

    #[test]
    fn test_scalar_all_newlines() {
        let block = [b'\n'; 64];
        let bm = lex_block_scalar(&block);
        assert_eq!(bm.newlines, u64::MAX);
    }

    #[test]
    fn test_scalar_known_positions() {
        let mut block = [b'A'; 64];
        block[0] = b'\n';
        block[3] = b'\n';
        block[63] = b'\n';
        let bm = lex_block_scalar(&block);
        assert_eq!(bm.newlines, (1 << 0) | (1 << 3) | (1 << 63));
    }

    #[test]
    fn test_scalar_fastq_record() {
        let block = make_block(b"@r1\nACGT\n+\nIIII\n");
        let bm = lex_block_scalar(&block);
        assert_eq!(bm.newlines, (1 << 3) | (1 << 8) | (1 << 10) | (1 << 15));
    }

    #[test]
    fn test_scalar_acgt_classification() {
        let block = make_block(b"ACGTacgt");
        let bm = lex_block_scalar(&block);
        assert_eq!(bm.is_acgt & 0xFF, 0xFF);
        assert_eq!(bm.is_acgt >> 8, 0);
    }

    #[test]
    fn test_scalar_non_acgt() {
        let block = make_block(b"@+N\nIIII");
        let bm = lex_block_scalar(&block);
        assert_eq!(bm.is_acgt & 0xFF, 0);
    }

    #[test]
    fn test_scalar_mixed_acgt() {
        let block = make_block(b"@ACNGT\n");
        let bm = lex_block_scalar(&block);
        assert_eq!(bm.is_acgt & 0x7F, 0b011_0110);
    }

    #[test]
    fn test_scalar_two_bit_encoding() {
        let block = make_block(b"ACGT");
        let bm = lex_block_scalar(&block);
        assert_eq!(bm.two_bits & 0x3, 0);
        assert_eq!((bm.two_bits >> 2) & 0x3, 1);
        assert_eq!((bm.two_bits >> 4) & 0x3, 2);
        assert_eq!((bm.two_bits >> 6) & 0x3, 3);
    }

    #[test]
    fn test_scalar_two_bit_case_insensitive() {
        let upper = make_block(b"ACGT");
        let lower = make_block(b"acgt");
        let bm_upper = lex_block_scalar(&upper);
        let bm_lower = lex_block_scalar(&lower);
        assert_eq!(bm_upper.two_bits & 0xFF, bm_lower.two_bits & 0xFF);
    }

    #[test]
    fn test_scalar_two_bits_zero_for_non_acgt() {
        // These bytes index ENCODE_LUT slots that hold non-zero base codes
        // (`B`/`b` -> C's slot, `F`/`f` -> G's, `4`/`5`/`U`/`u` -> T's), but
        // they are not bases, so the reference must leave them at 0.
        let block = make_block(b"BF45Ubf45u");
        let bm = lex_block_scalar(&block);
        assert_eq!(bm.is_acgt, 0);
        assert_eq!(bm.two_bits, 0);
    }

    #[test]
    fn test_simd_matches_scalar_newlines() {
        let mut block = [b'A'; 64];
        for (i, byte) in block.iter_mut().enumerate() {
            if i % 7 == 0 {
                *byte = b'\n';
            }
        }
        let simd_result = lex_block(&block);
        let scalar_result = lex_block_scalar(&block).newlines;
        assert_eq!(simd_result, scalar_result, "Newline bitmask mismatch");
    }

    #[test]
    fn test_simd_newlines_every_position() {
        for pos in 0..64 {
            let mut block = [b'X'; 64];
            block[pos] = b'\n';
            let simd_result = lex_block(&block);
            let scalar_result = lex_block_scalar(&block).newlines;
            assert_eq!(simd_result, scalar_result, "Mismatch at position {pos}");
        }
    }

    #[test]
    fn test_simd_full_matches_scalar_all_fields() {
        let block = make_block(b"@read1\nACGTACGTacgtNNNN\n+\nIIIIIIIIIIIIIIII\n");
        let simd_result = lex_block_full(&block);
        let scalar_result = lex_block_scalar(&block);
        assert_eq!(simd_result.newlines, scalar_result.newlines, "newlines mismatch");
        assert_eq!(simd_result.is_acgt, scalar_result.is_acgt, "is_acgt mismatch");
        assert_eq!(simd_result.two_bits, scalar_result.two_bits, "two_bits mismatch");
    }

    #[test]
    fn test_simd_full_matches_scalar_all_acgt() {
        let block = {
            let mut b = [0u8; 64];
            for (i, byte) in b.iter_mut().enumerate() {
                *byte = b"ACGTacgt"[i % 8];
            }
            b
        };
        let simd_result = lex_block_full(&block);
        let scalar_result = lex_block_scalar(&block);
        assert_eq!(simd_result.is_acgt, scalar_result.is_acgt, "is_acgt mismatch");
        assert_eq!(simd_result.two_bits, scalar_result.two_bits, "two_bits mismatch");
    }

    #[test]
    fn test_simd_full_matches_scalar_every_position() {
        for base in [b'A', b'C', b'G', b'T', b'a', b'c', b'g', b't'] {
            // Render the base as a character in failure messages, not as its
            // numeric byte value.
            let base_char = char::from(base);
            for pos in 0..64 {
                let mut block = [b'N'; 64];
                block[pos] = base;
                let simd = lex_block_full(&block);
                let scalar = lex_block_scalar(&block);
                assert_eq!(
                    simd.is_acgt, scalar.is_acgt,
                    "is_acgt mismatch for {base_char} at pos {pos}"
                );
                let mask = 0x3u128 << (pos * 2);
                assert_eq!(
                    simd.two_bits & mask,
                    scalar.two_bits & mask,
                    "two_bits mismatch for {base_char} at pos {pos}",
                );
            }
        }
    }
}