// fastwebsockets/mask.rs

// Copyright 2023 Divy Srivastava <dj.srivastava23@gmail.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
// Byte-at-a-time reference implementation: XORs byte `i` of `payload` with
// `mask[i % 4]`, in place. Also used by `unmask_fallback` for the bytes that
// do not fit into aligned `u32` words.
#[inline]
fn unmask_easy(payload: &mut [u8], mask: [u8; 4]) {
  for (byte, key) in payload.iter_mut().zip(mask.iter().cycle()) {
    *byte ^= *key;
  }
}
21
// TODO(@littledivy): The compiler already does a good job of auto-vectorizing
// `unmask_fallback` with `-C target-cpu=native`. Below is a manual SSE2
// implementation, currently commented out.
//
// #[cfg(all(target_arch = "x86_64", feature = "simd"))]
// #[inline]
// fn unmask_x86_64(payload: &mut [u8], mask: [u8; 4]) {
//   #[inline]
//   fn sse2(payload: &mut [u8], mask: [u8; 4]) {
//     const ALIGNMENT: usize = 16;
//     unsafe {
//       use std::arch::x86_64::*;
//
//       let len = payload.len();
//       if len < ALIGNMENT {
//         return unmask_fallback(payload, mask);
//       }
//
//       let start = len - len % ALIGNMENT;
//
//       let mut aligned_mask = [0; ALIGNMENT];
//
//       for j in (0..ALIGNMENT).step_by(4) {
//         aligned_mask[j] = mask[j % 4];
//         aligned_mask[j + 1] = mask[(j % 4) + 1];
//         aligned_mask[j + 2] = mask[(j % 4) + 2];
//         aligned_mask[j + 3] = mask[(j % 4) + 3];
//       }
//
//       let mask_m = _mm_loadu_si128(aligned_mask.as_ptr() as *const _);
//
//       for index in (0..start).step_by(ALIGNMENT) {
//         let ptr = payload.as_mut_ptr().add(index);
//         let mut v = _mm_loadu_si128(ptr as *const _);
//         v = _mm_xor_si128(v, mask_m);
//         _mm_storeu_si128(ptr as *mut _, v);
//       }
//
//       if len != start {
//         unmask_fallback(&mut payload[start..], mask);
//       }
//     }
//   }
//   #[cfg(target_feature = "sse2")]
//   {
//     return sse2(payload, mask);
//   }
//
//   #[cfg(not(target_feature = "sse2"))]
//   {
//     use core::mem;
//     use std::sync::atomic::AtomicPtr;
//     use std::sync::atomic::Ordering;
//
//     type FnRaw = *mut ();
//     type FnImpl = unsafe fn(&mut [u8], [u8; 4]);
//
//     unsafe fn get_impl(input: &mut [u8], mask: [u8; 4]) {
//       let fun = if std::is_x86_feature_detected!("sse2") {
//         sse2
//       } else {
//         unmask_fallback
//       };
//       FN.store(fun as FnRaw, Ordering::Relaxed);
//       (fun)(input, mask);
//     }
//
//     static FN: AtomicPtr<()> = AtomicPtr::new(get_impl as FnRaw);
//
//     if payload.len() < 16 {
//       return unmask_fallback(payload, mask);
//     }
//
//     let fun = FN.load(Ordering::Relaxed);
//     unsafe { mem::transmute::<FnRaw, FnImpl>(fun)(payload, mask) }
//   }
// }
98
99// Faster version of `unmask_easy()` which operates on 4-byte blocks.
100// https://github.com/snapview/tungstenite-rs/blob/e5efe537b87a6705467043fe44bb220ddf7c1ce8/src/protocol/frame/mask.rs#L23
101//
102// https://godbolt.org/z/EPTYo5jK8
103#[inline]
104fn unmask_fallback(buf: &mut [u8], mask: [u8; 4]) {
105  let mask_u32 = u32::from_ne_bytes(mask);
106
107  let (prefix, words, suffix) = unsafe { buf.align_to_mut::<u32>() };
108  unmask_easy(prefix, mask);
109  let head = prefix.len() & 3;
110  let mask_u32 = if head > 0 {
111    if cfg!(target_endian = "big") {
112      mask_u32.rotate_left(8 * head as u32)
113    } else {
114      mask_u32.rotate_right(8 * head as u32)
115    }
116  } else {
117    mask_u32
118  };
119  for word in words.iter_mut() {
120    *word ^= mask_u32;
121  }
122  unmask_easy(suffix, mask_u32.to_ne_bytes());
123}
124
125/// Unmask a payload using the given 4-byte mask.
126#[inline]
127pub fn unmask(payload: &mut [u8], mask: [u8; 4]) {
128  unmask_fallback(payload, mask)
129}
130
#[cfg(test)]
mod tests {
  use super::*;

  #[test]
  fn test_unmask() {
    let mut payload = [0u8; 33];
    let mask = [1, 2, 3, 4];
    unmask(&mut payload, mask);
    assert_eq!(
      &payload,
      &[
        1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
        1, 2, 3, 4, 1, 2, 3, 4, 1
      ]
    );
  }

  #[test]
  fn length_variation_unmask() {
    for len in &[0, 2, 3, 8, 16, 18, 31, 32, 40] {
      let mut payload = vec![0u8; *len];
      let mask = [1, 2, 3, 4];
      unmask(&mut payload, mask);

      let expected = (0..*len).map(|i| (i & 3) as u8 + 1).collect::<Vec<_>>();
      assert_eq!(payload, expected);
    }
  }

  #[test]
  fn length_variation_unmask_2() {
    for len in &[0, 2, 3, 8, 16, 18, 31, 32, 40] {
      let mut payload = vec![0u8; *len];
      let mask = rand::random::<[u8; 4]>();
      unmask(&mut payload, mask);

      let expected = (0..*len).map(|i| mask[i & 3]).collect::<Vec<_>>();
      assert_eq!(payload, expected);
    }
  }

  // The tests above unmask buffers that start at the beginning of a fresh
  // allocation. Allocators typically hand back word-aligned memory, so
  // `unmask_fallback` may never see a non-empty unaligned head and its
  // rotated-mask path goes untested.
  //
  // Unmasking sub-slices that start at offsets 0..=3 covers every 4-byte
  // alignment phase, wherever the buffer happens to be placed. Non-zero input
  // also checks that bytes are XORed, not overwritten.
  #[test]
  fn unaligned_offsets_match_byte_reference() {
    let mask = [0x12, 0x34, 0x56, 0x78];
    for offset in 0..4usize {
      for &len in &[0usize, 1, 3, 4, 5, 7, 8, 9, 15, 16, 17, 31, 32, 33, 67] {
        // Arbitrary non-zero byte pattern; truncation to u8 is intended.
        let input: Vec<u8> =
          (0..offset + len).map(|i| (i * 7 + 3) as u8).collect();

        let mut actual = input.clone();
        unmask(&mut actual[offset..], mask);

        let mut expected = input;
        unmask_easy(&mut expected[offset..], mask);

        assert_eq!(actual, expected, "offset={} len={}", offset, len);
      }
    }
  }

  // Masking is an XOR, so applying it twice must be the identity. The slice
  // starts at offset 1 so an unaligned head is exercised.
  #[test]
  fn unmask_twice_restores_input() {
    let mask = [0xde, 0xad, 0xbe, 0xef];
    let input: Vec<u8> = (0..=255u8).collect();
    let mut buf = input.clone();
    unmask(&mut buf[1..], mask);
    unmask(&mut buf[1..], mask);
    assert_eq!(buf, input);
  }
}