/// XORs `buf` in place with the 4-byte `mask`, cycling through the mask so that
/// `buf[i] ^= mask[i % 4]` (the WebSocket / RFC 6455 masking transform).
///
/// XOR is its own inverse, so applying the same mask twice restores the
/// original bytes: masking and unmasking are the same call.
///
/// An all-zero mask cannot change any byte, so it returns immediately.
///
/// On `x86_64` the work goes to an AVX2 kernel when the CPU supports it, and to
/// an SSE2 kernel otherwise. Every other architecture uses the portable scalar
/// kernel.
#[inline]
pub fn apply_mask(buf: &mut [u8], mask: [u8; 4]) {
    if mask == [0; 4] {
        return;
    }
    #[cfg(target_arch = "x86_64")]
    {
        // `is_x86_feature_detected!` already evaluates to `true` at compile time
        // when AVX2 is statically enabled (`-C target-feature=+avx2`). A separate
        // `cfg(target_feature)` branch is therefore redundant, and it would leave
        // the SSE2 call unreachable in such builds.
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was just confirmed, statically or at runtime.
            unsafe { apply_mask_avx2(buf, mask) };
        } else {
            // SAFETY: SSE2 is part of the x86_64 baseline, so it is always available.
            unsafe { apply_mask_sse2(buf, mask) };
        }
    }
    #[cfg(not(target_arch = "x86_64"))]
    apply_mask_scalar(buf, mask);
}
/// Portable kernel for [`apply_mask`]: XORs `buf[i]` with `mask[i % 4]`.
///
/// The slice is split into three parts:
/// - an unaligned head, masked byte by byte;
/// - a run of aligned `u64` words, each XORed with one 8-byte mask pattern;
/// - an unaligned tail, masked byte by byte.
///
/// A word is 8 bytes, a multiple of 4. So every word, and the tail, starts at
/// the same mask phase: the one left over after the head.
fn apply_mask_scalar(buf: &mut [u8], mask: [u8; 4]) {
    // SAFETY: every bit pattern is a valid `u64`, so viewing the aligned middle
    // of a byte slice as `u64`s is sound. `align_to_mut` guarantees the
    // alignment of the middle part.
    let (head, words, tail) = unsafe { buf.align_to_mut::<u64>() };

    for (idx, byte) in head.iter_mut().enumerate() {
        *byte ^= mask[idx % 4];
    }

    // Mask phase at which the first aligned word begins. It is also the phase
    // of every later word and of the tail, because each word is 8 bytes long.
    let phase = head.len() % 4;

    // Lay the rotated mask out in memory order. `from_ne_bytes` then yields a
    // word whose in-memory bytes line up with the data on any endianness.
    let mut pattern = [0u8; 8];
    for (j, slot) in pattern.iter_mut().enumerate() {
        *slot = mask[(phase + j) % 4];
    }
    let word_mask = u64::from_ne_bytes(pattern);
    words.iter_mut().for_each(|w| *w ^= word_mask);

    for (idx, byte) in tail.iter_mut().enumerate() {
        *byte ^= mask[(phase + idx) % 4];
    }
}
/// SSE2 kernel for [`apply_mask`]: XORs `buf` with `mask` 16 bytes at a time,
/// then hands the remainder (fewer than 16 bytes) to the scalar kernel.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSE2. This always holds on `x86_64`,
/// where SSE2 is part of the baseline instruction set.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn apply_mask_sse2(buf: &mut [u8], mask: [u8; 4]) {
    use std::arch::x86_64::*;
    // Broadcast the mask into all four 32-bit lanes. Each chunk is 16 bytes, a
    // multiple of 4, so every chunk starts at mask phase 0 and this one vector
    // applies to all of them.
    let mask_vec = _mm_set1_epi32(i32::from_ne_bytes(mask));
    let mut chunks = buf.chunks_exact_mut(16);
    for chunk in &mut chunks {
        let p = chunk.as_mut_ptr();
        // SAFETY: `chunk` is exactly 16 valid, exclusively borrowed bytes, and
        // the `loadu`/`storeu` intrinsics have no alignment requirement.
        unsafe {
            let data = _mm_loadu_si128(p as *const __m128i);
            _mm_storeu_si128(p as *mut __m128i, _mm_xor_si128(data, mask_vec));
        }
    }
    // The remainder starts at a multiple of 16, i.e. mask phase 0. That is
    // exactly where the scalar kernel starts its own mask cycle.
    apply_mask_scalar(chunks.into_remainder(), mask);
}
/// AVX2 kernel for [`apply_mask`]: XORs `buf` with `mask` 32 bytes at a time,
/// then hands the remainder (fewer than 32 bytes) to the SSE2 kernel.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2, e.g. by checking
/// `is_x86_feature_detected!("avx2")` first.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn apply_mask_avx2(buf: &mut [u8], mask: [u8; 4]) {
    use std::arch::x86_64::*;
    // Broadcast the mask into all eight 32-bit lanes. Each chunk is 32 bytes, a
    // multiple of 4, so every chunk starts at mask phase 0 and this one vector
    // applies to all of them.
    let mask_vec = _mm256_set1_epi32(i32::from_ne_bytes(mask));
    let mut chunks = buf.chunks_exact_mut(32);
    for chunk in &mut chunks {
        let p = chunk.as_mut_ptr();
        // SAFETY: `chunk` is exactly 32 valid, exclusively borrowed bytes. The
        // unaligned load/store intrinsics have no alignment requirement, and
        // the caller guarantees AVX2 is available.
        unsafe {
            let data = _mm256_loadu_si256(p as *const __m256i);
            _mm256_storeu_si256(p as *mut __m256i, _mm256_xor_si256(data, mask_vec));
        }
    }
    // The remainder starts at a multiple of 32, i.e. mask phase 0, which is
    // where the SSE2 kernel starts its own mask cycle. For buffers shorter than
    // 32 bytes the whole buffer is the remainder.
    // SAFETY: AVX2 implies SSE2, and SSE2 is baseline on x86_64 anyway.
    unsafe { apply_mask_sse2(chunks.into_remainder(), mask) };
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Byte-at-a-time reference implementation: `data[i] ^ mask[i % 4]`.
    fn reference(data: &[u8], mask: [u8; 4]) -> Vec<u8> {
        data.iter()
            .enumerate()
            .map(|(i, &b)| b ^ mask[i % 4])
            .collect()
    }

    #[test]
    fn round_trip() {
        let original = b"Hello, WebSocket!".to_vec();
        let mask = [0x37, 0xFA, 0x21, 0x3D];
        let mut buf = original.clone();
        apply_mask(&mut buf, mask);
        assert_ne!(&buf, &original);
        apply_mask(&mut buf, mask);
        assert_eq!(&buf, &original);
    }

    #[test]
    fn known_answer() {
        let mask = [0x37, 0xFA, 0x21, 0x3D];
        let mut buf = vec![0x48, 0x65, 0x6C, 0x6C, 0x6F];
        apply_mask(&mut buf, mask);
        assert_eq!(buf, vec![0x7F, 0x9F, 0x4D, 0x51, 0x58]);
    }

    #[test]
    fn empty_payload() {
        let mask = [0x37, 0xFA, 0x21, 0x3D];
        let mut buf = vec![];
        apply_mask(&mut buf, mask);
        assert!(buf.is_empty());
    }

    #[test]
    fn one_byte() {
        let mask = [0xAA, 0xBB, 0xCC, 0xDD];
        let mut buf = vec![0x55];
        apply_mask(&mut buf, mask);
        assert_eq!(buf, vec![0x55 ^ 0xAA]);
    }

    #[test]
    fn two_bytes() {
        let mask = [0xAA, 0xBB, 0xCC, 0xDD];
        let mut buf = vec![0x11, 0x22];
        apply_mask(&mut buf, mask);
        assert_eq!(buf, vec![0x11 ^ 0xAA, 0x22 ^ 0xBB]);
    }

    #[test]
    fn three_bytes() {
        let mask = [0xAA, 0xBB, 0xCC, 0xDD];
        let mut buf = vec![0x11, 0x22, 0x33];
        apply_mask(&mut buf, mask);
        assert_eq!(buf, vec![0x11 ^ 0xAA, 0x22 ^ 0xBB, 0x33 ^ 0xCC]);
    }

    #[test]
    fn exactly_four_bytes() {
        let mask = [0xAA, 0xBB, 0xCC, 0xDD];
        let original = vec![0x11, 0x22, 0x33, 0x44];
        let mut buf = original.clone();
        apply_mask(&mut buf, mask);
        assert_eq!(
            buf,
            vec![0x11 ^ 0xAA, 0x22 ^ 0xBB, 0x33 ^ 0xCC, 0x44 ^ 0xDD]
        );
        apply_mask(&mut buf, mask);
        assert_eq!(buf, original);
    }

    #[test]
    fn large_payload_round_trip() {
        let mask = [0xDE, 0xAD, 0xBE, 0xEF];
        let original: Vec<u8> = (0..4096).map(|i| (i & 0xFF) as u8).collect();
        let mut buf = original.clone();
        apply_mask(&mut buf, mask);
        assert_ne!(&buf, &original);
        apply_mask(&mut buf, mask);
        assert_eq!(&buf, &original);
    }

    #[test]
    fn zero_mask_is_noop() {
        let original = vec![0x48, 0x65, 0x6C, 0x6C, 0x6F];
        let mut buf = original.clone();
        apply_mask(&mut buf, [0, 0, 0, 0]);
        assert_eq!(buf, original);
    }

    #[test]
    fn simd_matches_scalar() {
        let mask = [0x12, 0x34, 0x56, 0x78];
        let original: Vec<u8> = (0..257).map(|i| (i & 0xFF) as u8).collect();
        let mut scalar = original.clone();
        apply_mask_scalar(&mut scalar, mask);
        let mut dispatch = original;
        apply_mask(&mut dispatch, mask);
        assert_eq!(scalar, dispatch, "SIMD path must match scalar");
    }

    /// Masks sub-slices starting at every offset in 0..8, at lengths that
    /// straddle the 8/16/32-byte chunk sizes.
    ///
    /// Depending on the allocator, a whole fresh `Vec` may well be
    /// word-aligned. The scalar kernel's `align_to_mut` head would then always
    /// be empty, and its rotated-mask branch would never run. Offsetting the
    /// start forces a non-empty head, and the check confirms no byte before
    /// the sub-slice is touched.
    #[test]
    fn unaligned_subslices_match_reference() {
        let mask = [0x9C, 0x01, 0xE7, 0x42];
        let lengths = [
            0usize, 1, 3, 7, 8, 9, 15, 16, 17, 31, 32, 33, 47, 48, 63, 64, 65, 100,
        ];
        for &len in lengths.iter() {
            for start in 0..8usize {
                let backing: Vec<u8> = (0..start + len).map(|i| (i * 31 + 7) as u8).collect();
                let expected = reference(&backing[start..], mask);

                let mut via_dispatch = backing.clone();
                apply_mask(&mut via_dispatch[start..], mask);
                assert_eq!(
                    &via_dispatch[start..],
                    &expected[..],
                    "apply_mask len={} start={}",
                    len,
                    start
                );
                assert_eq!(&via_dispatch[..start], &backing[..start]);

                let mut via_scalar = backing.clone();
                apply_mask_scalar(&mut via_scalar[start..], mask);
                assert_eq!(
                    &via_scalar[start..],
                    &expected[..],
                    "scalar len={} start={}",
                    len,
                    start
                );
                assert_eq!(&via_scalar[..start], &backing[..start]);
            }
        }
    }

    /// Exercises each x86_64 kernel directly. On an AVX2 machine the dispatcher
    /// never calls the SSE2 kernel with a buffer of 32 bytes or more, so this
    /// test is what covers that kernel's main loop there.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn each_simd_kernel_matches_reference() {
        let mask = [0x12, 0x34, 0x56, 0x78];
        for len in 0..100usize {
            let original: Vec<u8> = (0..len).map(|i| (i * 13 + 5) as u8).collect();
            let expected = reference(&original, mask);

            let mut sse2 = original.clone();
            // SAFETY: SSE2 is part of the x86_64 baseline.
            unsafe { apply_mask_sse2(&mut sse2, mask) };
            assert_eq!(sse2, expected, "sse2 len={}", len);

            if is_x86_feature_detected!("avx2") {
                let mut avx2 = original.clone();
                // SAFETY: AVX2 availability was just checked.
                unsafe { apply_mask_avx2(&mut avx2, mask) };
                assert_eq!(avx2, expected, "avx2 len={}", len);
            }
        }
    }
}