#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
#[allow(dead_code)]
/// SIMD translation between raw symbol indices and their encoded byte
/// representation, 16 bytes per 128-bit call (32 per AVX2 call on x86_64).
///
/// Method signatures are duplicated per architecture so each can use the
/// native vector type (`__m128i` on x86_64, `uint8x16_t` on aarch64).
///
/// # Safety
/// All methods are `unsafe`: the caller must verify at runtime that the CPU
/// supports the feature named in the implementation's `#[target_feature]`
/// attribute before calling.
pub trait SimdTranslate {
    /// Converts 16 symbol indices into encoded characters.
    #[cfg(target_arch = "x86_64")]
    unsafe fn translate_encode(&self, indices: __m128i) -> __m128i;
    /// Converts 16 symbol indices into encoded characters.
    #[cfg(target_arch = "aarch64")]
    unsafe fn translate_encode(&self, indices: uint8x16_t) -> uint8x16_t;
    /// Converts 16 encoded characters back into indices; `None` if any
    /// byte is not a valid encoded character.
    #[cfg(target_arch = "x86_64")]
    unsafe fn translate_decode(&self, chars: __m128i) -> Option<__m128i>;
    /// Converts 16 encoded characters back into indices; `None` if any
    /// byte is not a valid encoded character.
    #[cfg(target_arch = "aarch64")]
    unsafe fn translate_decode(&self, chars: uint8x16_t) -> Option<uint8x16_t>;
    /// Returns true when all 16 bytes are valid encoded characters.
    #[cfg(target_arch = "x86_64")]
    unsafe fn validate(&self, chars: __m128i) -> bool;
    /// Returns true when all 16 bytes are valid encoded characters.
    #[cfg(target_arch = "aarch64")]
    unsafe fn validate(&self, chars: uint8x16_t) -> bool;
    /// 256-bit (32-byte) AVX2 variant of `translate_encode`.
    #[cfg(target_arch = "x86_64")]
    unsafe fn translate_encode_256(&self, indices: __m256i) -> __m256i;
    /// 256-bit (32-byte) AVX2 variant of `translate_decode`.
    #[cfg(target_arch = "x86_64")]
    unsafe fn translate_decode_256(&self, chars: __m256i) -> Option<__m256i>;
}
/// Maps symbol indices `0..2^bits_per_symbol` onto a contiguous run of
/// codepoints starting at `start_codepoint` (alphabets that are sequential
/// in the character set).
#[derive(Debug, Clone, Copy)]
pub struct SequentialTranslate {
    // First codepoint of the contiguous alphabet. Only the low byte is
    // used by the SIMD translation paths.
    start_codepoint: u32,
    // Width of one symbol in bits; must be 1..=8 so every index fits in a
    // single byte lane.
    bits_per_symbol: u8,
}

impl SequentialTranslate {
    /// Creates a translator for a contiguous alphabet beginning at
    /// `start_codepoint` with `bits_per_symbol`-bit symbols.
    ///
    /// `bits_per_symbol` must be in `1..=8`; larger values overflow the
    /// shift in `max_index`.
    #[allow(dead_code)]
    pub const fn new(start_codepoint: u32, bits_per_symbol: u8) -> Self {
        Self {
            start_codepoint,
            bits_per_symbol,
        }
    }

    /// First codepoint of the alphabet.
    #[allow(dead_code)]
    pub const fn start_codepoint(&self) -> u32 {
        self.start_codepoint
    }

    /// Number of bits encoded by one symbol.
    #[allow(dead_code)]
    pub const fn bits_per_symbol(&self) -> u8 {
        self.bits_per_symbol
    }

    /// Largest valid symbol index, `2^bits_per_symbol - 1`.
    ///
    /// Computed in `u16` so that `bits_per_symbol == 8` yields 255 instead
    /// of overflowing a `1u8 << 8` shift (panic in debug builds, masked
    /// shift producing 0 in release builds).
    #[allow(dead_code)]
    const fn max_index(&self) -> u8 {
        ((1u16 << self.bits_per_symbol) - 1) as u8
    }
}
#[cfg(target_arch = "x86_64")]
impl SimdTranslate for SequentialTranslate {
    /// Encodes 16 indices by adding the alphabet's starting codepoint to
    /// every byte lane.
    #[target_feature(enable = "ssse3")]
    unsafe fn translate_encode(&self, indices: __m128i) -> __m128i {
        // Only the low byte of start_codepoint participates; byte addition
        // wraps, which matches u8 arithmetic.
        let offset = _mm_set1_epi8(self.start_codepoint as i8);
        _mm_add_epi8(indices, offset)
    }

    /// Decodes 16 encoded bytes back to indices, or `None` if any lane
    /// falls outside the alphabet.
    #[target_feature(enable = "ssse3")]
    unsafe fn translate_decode(&self, chars: __m128i) -> Option<__m128i> {
        let offset = _mm_set1_epi8(self.start_codepoint as i8);
        let indices = _mm_sub_epi8(chars, offset);
        let max_valid = _mm_set1_epi8(self.max_index() as i8);
        // SSE has no unsigned byte compare: adding -128 (0x80) to both
        // sides maps unsigned ordering onto the signed ordering used by
        // _mm_cmpgt_epi8.
        let bias = _mm_set1_epi8(-128_i8);
        let indices_biased = _mm_add_epi8(indices, bias);
        let max_biased = _mm_add_epi8(max_valid, bias);
        let too_large = _mm_cmpgt_epi8(indices_biased, max_biased);
        // movemask gathers each lane's top bit: non-zero means at least
        // one out-of-range byte.
        let invalid_mask = _mm_movemask_epi8(too_large);
        if invalid_mask == 0 {
            Some(indices)
        } else {
            None
        }
    }

    /// Returns true when every byte lies within the alphabet's range.
    #[target_feature(enable = "ssse3")]
    unsafe fn validate(&self, chars: __m128i) -> bool {
        let start = _mm_set1_epi8(self.start_codepoint as i8);
        let chars_offset = _mm_sub_epi8(chars, start);
        // Range size is 2^bits_per_symbol; the compare below uses
        // range-1, the same upper bound as max_index in translate_decode.
        let range_size = _mm_set1_epi8((1 << self.bits_per_symbol) as i8);
        // Same -128 bias trick as translate_decode: emulates an unsigned
        // compare with the signed _mm_cmpgt_epi8.
        let bias = _mm_set1_epi8(-128_i8);
        let offset_biased = _mm_add_epi8(chars_offset, bias);
        let range_biased = _mm_add_epi8(range_size, bias);
        let too_large = _mm_cmpgt_epi8(offset_biased, _mm_sub_epi8(range_biased, _mm_set1_epi8(1)));
        let invalid_mask = _mm_movemask_epi8(too_large);
        invalid_mask == 0
    }

    /// AVX2 variant of `translate_encode`, processing 32 bytes per call.
    #[target_feature(enable = "avx2")]
    unsafe fn translate_encode_256(&self, indices: __m256i) -> __m256i {
        let offset = _mm256_set1_epi8(self.start_codepoint as i8);
        _mm256_add_epi8(indices, offset)
    }

    /// AVX2 variant of `translate_decode`, processing 32 bytes per call.
    #[target_feature(enable = "avx2")]
    unsafe fn translate_decode_256(&self, chars: __m256i) -> Option<__m256i> {
        let offset = _mm256_set1_epi8(self.start_codepoint as i8);
        let indices = _mm256_sub_epi8(chars, offset);
        let max_valid = _mm256_set1_epi8(self.max_index() as i8);
        // Same -128 bias trick as the 128-bit path: emulates an unsigned
        // compare using the signed _mm256_cmpgt_epi8.
        let bias = _mm256_set1_epi8(-128_i8);
        let indices_biased = _mm256_add_epi8(indices, bias);
        let max_biased = _mm256_add_epi8(max_valid, bias);
        let too_large = _mm256_cmpgt_epi8(indices_biased, max_biased);
        let invalid_mask = _mm256_movemask_epi8(too_large);
        if invalid_mask == 0 {
            Some(indices)
        } else {
            None
        }
    }
}
#[cfg(target_arch = "aarch64")]
impl SimdTranslate for SequentialTranslate {
    /// Adds the alphabet's starting codepoint to all 16 index bytes.
    #[target_feature(enable = "neon")]
    unsafe fn translate_encode(&self, indices: uint8x16_t) -> uint8x16_t {
        vaddq_u8(indices, vdupq_n_u8(self.start_codepoint as u8))
    }

    /// Subtracts the starting codepoint from all 16 bytes and rejects the
    /// batch if any result exceeds the largest valid index.
    #[target_feature(enable = "neon")]
    unsafe fn translate_decode(&self, chars: uint8x16_t) -> Option<uint8x16_t> {
        let base = vdupq_n_u8(self.start_codepoint as u8);
        let decoded = vsubq_u8(chars, base);
        let limit = vdupq_n_u8(self.max_index());
        // Out-of-range lanes become 0xFF; the horizontal max is non-zero
        // exactly when at least one lane failed the range check.
        let out_of_range = vcgtq_u8(decoded, limit);
        match vmaxvq_u8(out_of_range) {
            0 => Some(decoded),
            _ => None,
        }
    }

    /// Returns true when every byte lies within the alphabet's range.
    #[target_feature(enable = "neon")]
    unsafe fn validate(&self, chars: uint8x16_t) -> bool {
        let base = vdupq_n_u8(self.start_codepoint as u8);
        let limit = vdupq_n_u8(self.max_index());
        let out_of_range = vcgtq_u8(vsubq_u8(chars, base), limit);
        vmaxvq_u8(out_of_range) == 0
    }
}
#[cfg(test)]
mod tests {
    #[cfg(target_arch = "x86_64")]
    use super::{SequentialTranslate, SimdTranslate};
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    /// Dumps a 128-bit vector into a plain byte array so it can be
    /// compared with `assert_eq!`.
    #[cfg(target_arch = "x86_64")]
    unsafe fn lanes(v: __m128i) -> [u8; 16] {
        let mut out = [0u8; 16];
        _mm_storeu_si128(out.as_mut_ptr() as *mut __m128i, v);
        out
    }

    /// Checks for SSSE3 support, logging a skip notice when missing.
    #[cfg(target_arch = "x86_64")]
    fn ssse3_ready() -> bool {
        let ready = crate::simd::has_ssse3();
        if !ready {
            eprintln!("SSSE3 not available, skipping test");
        }
        ready
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_sequential_encode_at_sign() {
        if !ssse3_ready() {
            return;
        }
        let translator = SequentialTranslate::new(0x40, 6);
        unsafe {
            let indices =
                _mm_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
            let encoded = lanes(translator.translate_encode(indices));
            // Indices 0..=15 should land on 0x40..=0x4F.
            let mut expected = [0u8; 16];
            for (i, slot) in expected.iter_mut().enumerate() {
                *slot = 0x40 + i as u8;
            }
            assert_eq!(encoded, expected);
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_sequential_decode_at_sign() {
        if !ssse3_ready() {
            return;
        }
        let translator = SequentialTranslate::new(0x40, 6);
        unsafe {
            let chars = _mm_setr_epi8(
                0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4A, 0x4B,
                0x4C, 0x4D, 0x4E, 0x4F,
            );
            let decoded = translator.translate_decode(chars).expect("Valid characters");
            let mut expected = [0u8; 16];
            for (i, slot) in expected.iter_mut().enumerate() {
                *slot = i as u8;
            }
            assert_eq!(lanes(decoded), expected);
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_sequential_decode_invalid() {
        if !ssse3_ready() {
            return;
        }
        let translator = SequentialTranslate::new(0x40, 6);
        unsafe {
            // Final lane (0x81) falls outside the 0x40..=0x7F alphabet.
            let chars = _mm_setr_epi8(
                0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4A, 0x4B,
                0x4C, 0x4D, 0x4E, -0x7F_i8,
            );
            let result = translator.translate_decode(chars);
            assert!(result.is_none(), "Should reject invalid characters");
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_sequential_validate_valid() {
        if !ssse3_ready() {
            return;
        }
        let translator = SequentialTranslate::new(0x30, 4);
        unsafe {
            let chars = _mm_setr_epi8(
                0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B,
                0x3C, 0x3D, 0x3E, 0x3F,
            );
            assert!(translator.validate(chars));
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_sequential_validate_invalid() {
        if !ssse3_ready() {
            return;
        }
        let translator = SequentialTranslate::new(0x30, 4);
        unsafe {
            // 0x40 is one past the valid 0x30..=0x3F range.
            let chars = _mm_setr_epi8(
                0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B,
                0x3C, 0x3D, 0x3E, 0x40,
            );
            assert!(!translator.validate(chars));
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_sequential_round_trip() {
        if !ssse3_ready() {
            return;
        }
        let translator = SequentialTranslate::new(0x41, 6);
        unsafe {
            let original =
                _mm_setr_epi8(0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 63, 0, 0);
            let encoded = translator.translate_encode(original);
            let decoded = translator.translate_decode(encoded).expect("Valid round trip");
            assert_eq!(lanes(original), lanes(decoded));
        }
    }
}