use bytemuck::cast;
use wide::{u8x16, u16x8};
/// Lookup table from a 4-bit nibble value (index `0..16`) to its ASCII nucleotide code.
///
/// Index 0 is `=`, index 15 is `N`; the order matches the 4-bit `seq` encoding of the
/// SAM/BAM specification. Used both as a scalar lookup table and, loaded into a `u8x16`,
/// as the SIMD swizzle table.
pub(crate) const NIBBLE_TO_BASE: [u8; 16] = *b"=ACMGRSVTWYHKDBN";
/// Decodes a 4-bit packed nucleotide sequence into one ASCII base per byte.
///
/// Each byte of `packed` holds two bases: the high nibble is the earlier base and the low
/// nibble the later one, each mapped through [`NIBBLE_TO_BASE`]. At most `base_count`
/// bases are written to `dst`, which is cleared first (its existing capacity is reused).
/// When `base_count` is odd, the low nibble of the last used byte is ignored, as are any
/// bytes beyond the first `base_count.div_ceil(2)`.
///
/// Full 16-byte chunks are decoded with a SIMD table lookup; leftover bytes go through the
/// scalar path.
///
/// # Panics
///
/// In debug builds, panics if `packed` holds fewer than `base_count.div_ceil(2)` bytes.
/// In release builds a short buffer is not an error: decoding stops when `packed` runs
/// out, so `dst` ends up with fewer than `base_count` bases.
pub fn decode_packed_sequence_into(packed: &[u8], base_count: usize, dst: &mut Vec<u8>) {
    dst.clear();
    if base_count == 0 {
        return;
    }
    debug_assert!(
        packed.len() >= base_count.div_ceil(2),
        "packed buffer too short for {base_count} bases: have {}",
        packed.len()
    );
    // At most two bases come out of each packed byte, so never reserve more than that:
    // an inflated `base_count` paired with a short buffer (only caught by the debug
    // assertion above) must not trigger an oversized allocation in release builds.
    dst.reserve(base_count.min(packed.len().saturating_mul(2)));

    let table = u8x16::new(NIBBLE_TO_BASE);
    let low_nibble_mask = u8x16::splat(0x0F);
    // Applied after shifting the 16-bit lanes right by 4: keeps each byte's former high
    // nibble and drops the bits shifted in from the other byte of the same lane.
    let high_nibble_mask = u16x8::splat(0x0F0F);

    let chunks = packed.chunks_exact(16);
    let tail = chunks.remainder();
    // Bases still owed to `dst`; never underflows because we return once it drops below 32.
    let mut remaining = base_count;
    for chunk in chunks {
        let bytes: [u8; 16] = chunk
            .try_into()
            .expect("chunks_exact(16) only yields 16-byte chunks");
        let packed_v = u8x16::new(bytes);

        let low_nibbles = packed_v & low_nibble_mask;
        let lanes: u16x8 = cast(packed_v);
        let high_nibbles: u8x16 = cast((lanes >> 4) & high_nibble_mask);

        // Both index vectors are masked to 0..16, so the relaxed swizzle's unspecified
        // out-of-range behaviour never applies.
        let decoded_hi = table.swizzle_relaxed(high_nibbles);
        let decoded_lo = table.swizzle_relaxed(low_nibbles);

        // Interleave as hi0, lo0, hi1, lo1, ... to restore the original base order.
        let first: [u8; 16] = cast(u8x16::unpack_low(decoded_hi, decoded_lo));
        let second: [u8; 16] = cast(u8x16::unpack_high(decoded_hi, decoded_lo));

        if remaining < 32 {
            // This chunk covers the end of the sequence; any later bytes are padding.
            let take_first = remaining.min(16);
            dst.extend_from_slice(&first[..take_first]);
            dst.extend_from_slice(&second[..remaining - take_first]);
            return;
        }
        dst.extend_from_slice(&first);
        dst.extend_from_slice(&second);
        remaining -= 32;
    }
    scalar_tail(tail, remaining, dst);
}
#[inline]
fn scalar_tail(packed_tail: &[u8], mut remaining: usize, dst: &mut Vec<u8>) {
for &byte in packed_tail {
if remaining == 0 {
break;
}
let hi = NIBBLE_TO_BASE[(byte >> 4) as usize];
dst.push(hi);
remaining -= 1;
if remaining == 0 {
break;
}
let lo = NIBBLE_TO_BASE[(byte & 0x0F) as usize];
dst.push(lo);
remaining -= 1;
}
}
/// Scalar-only reference decoder with the same contract as `decode_packed_sequence_into`.
///
/// Tests compare against it, and the ignored benchmark uses it as the baseline. It
/// clears `dst`, reserves room for `base_count` bases, and then decodes the whole of
/// `packed` through `scalar_tail`, stopping after `base_count` bases.
#[cfg(test)]
pub(crate) fn decode_packed_sequence_into_scalar(
    packed: &[u8],
    base_count: usize,
    dst: &mut Vec<u8>,
) {
    // Clear before reserving so the reservation is measured from an empty buffer.
    dst.clear();
    dst.reserve(base_count);
    scalar_tail(packed, base_count, dst);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Packs ASCII bases two per byte (first base in the high nibble); an odd final base
    /// gets a zero low nibble.
    fn pack_bases(ascii: &[u8]) -> Vec<u8> {
        let mut packed = Vec::with_capacity(ascii.len().div_ceil(2));
        for chunk in ascii.chunks(2) {
            let hi = base_to_nibble(chunk[0]);
            let lo = if chunk.len() == 2 { base_to_nibble(chunk[1]) } else { 0 };
            packed.push((hi << 4) | lo);
        }
        packed
    }

    /// Inverse lookup into `NIBBLE_TO_BASE`; panics if `b` is not one of its symbols.
    fn base_to_nibble(b: u8) -> u8 {
        u8::try_from(NIBBLE_TO_BASE.iter().position(|&x| x == b).unwrap())
            .expect("nibble index fits in u8 — table is 16 entries")
    }

    /// Runs both the SIMD entry point and the scalar reference on `packed`, asserting that
    /// each reproduces `ascii` exactly.
    fn assert_both_decoders(packed: &[u8], ascii: &[u8]) {
        let mut simd_out = Vec::new();
        decode_packed_sequence_into(packed, ascii.len(), &mut simd_out);
        let mut scalar_out = Vec::new();
        decode_packed_sequence_into_scalar(packed, ascii.len(), &mut scalar_out);
        assert_eq!(simd_out, ascii, "simd mismatch at len={}", ascii.len());
        assert_eq!(scalar_out, ascii, "scalar mismatch at len={}", ascii.len());
    }

    #[test]
    fn scalar_and_simd_agree_on_aligned_sequence() {
        let ascii: Vec<u8> = (0..256_usize).map(|i| NIBBLE_TO_BASE[i % 16]).collect();
        let packed = pack_bases(&ascii);
        assert_both_decoders(&packed, &ascii);
    }

    #[test]
    fn scalar_and_simd_agree_on_odd_length() {
        let ascii: Vec<u8> = b"ACGTACGTNNNNACGTACGTNNNNACGTACGTA".to_vec();
        assert_eq!(ascii.len(), 33);
        let packed = pack_bases(&ascii);
        assert_both_decoders(&packed, &ascii);
    }

    #[test]
    fn short_sequences_under_simd_width() {
        for len in [1usize, 2, 15, 16, 17, 31, 32, 33] {
            let ascii: Vec<u8> = (0..len).map(|i| NIBBLE_TO_BASE[i % 16]).collect();
            let packed = pack_bases(&ascii);
            assert_both_decoders(&packed, &ascii);
        }
    }

    #[test]
    fn trailing_packed_bytes_beyond_base_count_are_ignored() {
        // Two full SIMD chunks of input; shorter base counts must stop early, including
        // partway through the second chunk.
        let ascii: Vec<u8> = (0..64_usize).map(|i| NIBBLE_TO_BASE[(i * 7) % 16]).collect();
        let packed = pack_bases(&ascii);
        assert_eq!(packed.len(), 32);
        for base_count in [1usize, 20, 31, 32, 33, 63] {
            let mut out = Vec::new();
            decode_packed_sequence_into(&packed, base_count, &mut out);
            assert_eq!(out, &ascii[..base_count], "mismatch at base_count={base_count}");
        }
    }

    #[test]
    fn empty_sequence_produces_empty_output() {
        let mut out = vec![1, 2, 3];
        decode_packed_sequence_into(&[], 0, &mut out);
        assert!(out.is_empty(), "empty input should clear the destination");
    }

    #[test]
    fn existing_destination_capacity_is_reused() {
        let ascii: Vec<u8> = (0..150).map(|i| NIBBLE_TO_BASE[i % 16]).collect();
        let packed = pack_bases(&ascii);
        let mut out: Vec<u8> = Vec::with_capacity(200);
        let cap_before = out.capacity();
        decode_packed_sequence_into(&packed, ascii.len(), &mut out);
        let cap_after = out.capacity();
        assert_eq!(out, ascii);
        assert_eq!(cap_after, cap_before);
    }

    #[test]
    fn realistic_150bp_read() {
        let ascii: Vec<u8> = (0..150).map(|i| b"ACGT"[i % 4]).collect();
        let packed = pack_bases(&ascii);
        let mut out = Vec::new();
        decode_packed_sequence_into(&packed, ascii.len(), &mut out);
        assert_eq!(out, ascii);
    }

    #[test]
    #[ignore = "perf instrumentation; run with --release --ignored"]
    fn bench_simd_decode() {
        use std::time::Instant;
        let ascii: Vec<u8> = (0..150).map(|i| b"ACGT"[i % 4]).collect();
        let packed = pack_bases(&ascii);
        let iters: u32 = 5_000_000;
        let mut out = Vec::with_capacity(ascii.len());
        let start = Instant::now();
        for _ in 0..iters {
            decode_packed_sequence_into(&packed, ascii.len(), &mut out);
            std::hint::black_box(&out);
        }
        let simd = start.elapsed();
        let start = Instant::now();
        for _ in 0..iters {
            decode_packed_sequence_into_scalar(&packed, ascii.len(), &mut out);
            std::hint::black_box(&out);
        }
        let scalar = start.elapsed();
        let simd_ns_per = simd.as_nanos() as f64 / f64::from(iters);
        let scalar_ns_per = scalar.as_nanos() as f64 / f64::from(iters);
        eprintln!(
            "150bp decode × {iters}: simd {simd:?} ({simd_ns_per:.1} ns/rec) scalar {scalar:?} ({scalar_ns_per:.1} ns/rec) speedup {:.2}x",
            scalar_ns_per / simd_ns_per,
        );
    }

    #[test]
    fn high_nibble_mask_prevents_cross_byte_leak() {
        // Every odd-indexed byte has a non-zero low nibble. Shifting a 16-bit lane right by
        // 4 moves it into the other byte's high-nibble index, and the 0x0F0F mask must
        // discard it. The pattern is repeated past 16 bytes so that at least one full SIMD
        // chunk is decoded; 4 bytes alone would only ever reach the scalar tail.
        let packed = [0xF8_u8, 0x18, 0xF0, 0x01].repeat(5);
        assert!(packed.len() >= 16, "must cover at least one full SIMD chunk");
        let expected = b"NTATN==A".repeat(5);
        assert_eq!(expected[5], b'=');
        assert_both_decoders(&packed, &expected);
    }
}