fastwebsockets/mask.rs
// Copyright 2023 Divy Srivastava <dj.srivastava23@gmail.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Unmask the payload in place by XORing each byte with the corresponding
// mask byte; the 4-byte mask repeats for the whole payload.
#[inline]
fn unmask_easy(payload: &mut [u8], mask: [u8; 4]) {
  payload.iter_mut().enumerate().for_each(|(i, v)| {
    *v ^= mask[i & 3];
  });
}

// TODO(@littledivy): The compiler does a good job of auto-vectorizing
// `unmask_fallback` with -C target-cpu=native. Below is a manual SIMD
// implementation, kept commented out for reference.
//
// #[cfg(all(target_arch = "x86_64", feature = "simd"))]
// #[inline]
// fn unmask_x86_64(payload: &mut [u8], mask: [u8; 4]) {
//   #[inline]
//   fn sse2(payload: &mut [u8], mask: [u8; 4]) {
//     const ALIGNMENT: usize = 16;
//     unsafe {
//       use std::arch::x86_64::*;
//
//       let len = payload.len();
//       if len < ALIGNMENT {
//         return unmask_fallback(payload, mask);
//       }
//
//       let start = len - len % ALIGNMENT;
//
//       let mut aligned_mask = [0; ALIGNMENT];
//
//       for j in (0..ALIGNMENT).step_by(4) {
//         aligned_mask[j] = mask[j % 4];
//         aligned_mask[j + 1] = mask[(j % 4) + 1];
//         aligned_mask[j + 2] = mask[(j % 4) + 2];
//         aligned_mask[j + 3] = mask[(j % 4) + 3];
//       }
//
//       let mask_m = _mm_loadu_si128(aligned_mask.as_ptr() as *const _);
//
//       for index in (0..start).step_by(ALIGNMENT) {
//         let ptr = payload.as_mut_ptr().add(index);
//         let mut v = _mm_loadu_si128(ptr as *const _);
//         v = _mm_xor_si128(v, mask_m);
//         _mm_storeu_si128(ptr as *mut _, v);
//       }
//
//       if len != start {
//         unmask_fallback(&mut payload[start..], mask);
//       }
//     }
//   }
//
//   #[cfg(target_feature = "sse2")]
//   {
//     return sse2(payload, mask);
//   }
//
//   #[cfg(not(target_feature = "sse2"))]
//   {
//     use core::mem;
//     use std::sync::atomic::AtomicPtr;
//     use std::sync::atomic::Ordering;
//
//     type FnRaw = *mut ();
//     type FnImpl = unsafe fn(&mut [u8], [u8; 4]);
//
//     unsafe fn get_impl(input: &mut [u8], mask: [u8; 4]) {
//       let fun = if std::is_x86_feature_detected!("sse2") {
//         sse2
//       } else {
//         unmask_fallback
//       };
//       FN.store(fun as FnRaw, Ordering::Relaxed);
//       (fun)(input, mask);
//     }
//
//     static FN: AtomicPtr<()> = AtomicPtr::new(get_impl as FnRaw);
//
//     if payload.len() < 16 {
//       return unmask_fallback(payload, mask);
//     }
//
//     let fun = FN.load(Ordering::Relaxed);
//     unsafe { mem::transmute::<FnRaw, FnImpl>(fun)(payload, mask) }
//   }
// }
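
// For reference, the auto-vectorization mentioned in the TODO above is
// typically enabled by compiling for the host CPU, e.g. (assumed invocation,
// adjust to your build setup):
//
//   RUSTFLAGS="-C target-cpu=native" cargo build --release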

// Faster version of `unmask_easy()` which operates on 4-byte blocks.
// https://github.com/snapview/tungstenite-rs/blob/e5efe537b87a6705467043fe44bb220ddf7c1ce8/src/protocol/frame/mask.rs#L23
//
// https://godbolt.org/z/EPTYo5jK8
#[inline]
fn unmask_fallback(buf: &mut [u8], mask: [u8; 4]) {
  let mask_u32 = u32::from_ne_bytes(mask);

  // Split the buffer into an unaligned prefix, a run of u32-aligned words,
  // and an unaligned suffix; `align_to_mut` guarantees the prefix and suffix
  // are each shorter than 4 bytes. Reinterpreting `u8` as `u32` is sound for
  // any byte pattern.
  let (prefix, words, suffix) = unsafe { buf.align_to_mut::<u32>() };
  unmask_easy(prefix, mask);
  // The prefix consumed `head` bytes of the mask, so rotate the 32-bit mask
  // to line it up with the first aligned word.
  let head = prefix.len() & 3;
  let mask_u32 = if head > 0 {
    if cfg!(target_endian = "big") {
      mask_u32.rotate_left(8 * head as u32)
    } else {
      mask_u32.rotate_right(8 * head as u32)
    }
  } else {
    mask_u32
  };
  // XOR whole 4-byte words at a time; this is the part the compiler can
  // auto-vectorize.
  for word in words.iter_mut() {
    *word ^= mask_u32;
  }
  unmask_easy(suffix, mask_u32.to_ne_bytes());
}

/// Unmask a payload using the given 4-byte mask.
#[inline]
pub fn unmask(payload: &mut [u8], mask: [u8; 4]) {
  unmask_fallback(payload, mask)
}
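
// Note: per RFC 6455, masking is a symmetric XOR with the 4-byte key, so
// `unmask` can just as well be used to apply a mask to an outgoing payload;
// applying it twice with the same mask restores the original bytes.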

#[cfg(test)]
mod tests {
  use super::*;

  #[test]
  fn test_unmask() {
    let mut payload = [0u8; 33];
    let mask = [1, 2, 3, 4];
    unmask(&mut payload, mask);
    assert_eq!(
      &payload,
      &[
        1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3,
        4, 1, 2, 3, 4, 1, 2, 3, 4, 1
      ]
    );
  }

  #[test]
  fn length_variation_unmask() {
    for len in &[0, 2, 3, 8, 16, 18, 31, 32, 40] {
      let mut payload = vec![0u8; *len];
      let mask = [1, 2, 3, 4];
      unmask(&mut payload, mask);

      let expected = (0..*len).map(|i| (i & 3) as u8 + 1).collect::<Vec<_>>();
      assert_eq!(payload, expected);
    }
  }

  #[test]
  fn length_variation_unmask_2() {
    for len in &[0, 2, 3, 8, 16, 18, 31, 32, 40] {
      let mut payload = vec![0u8; *len];
      let mask = rand::random::<[u8; 4]>();
      unmask(&mut payload, mask);

      let expected = (0..*len).map(|i| mask[i & 3]).collect::<Vec<_>>();
      assert_eq!(payload, expected);
    }
  }
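
  // Illustrative extra checks, relying only on properties of XOR masking:
  // it is an involution (applying the same mask twice is the identity), and
  // `unmask_fallback` must agree with the byte-wise `unmask_easy` however
  // the payload slice happens to be aligned.

  #[test]
  fn unmask_roundtrip() {
    let original: Vec<u8> = (0..40u8).collect();
    let mask = rand::random::<[u8; 4]>();
    let mut payload = original.clone();
    unmask(&mut payload, mask);
    unmask(&mut payload, mask);
    assert_eq!(payload, original);
  }

  #[test]
  fn unmask_fallback_matches_unmask_easy() {
    // Slices starting at offsets 0..4 shift the alignment of the data, so
    // the prefix/suffix paths and the mask rotation in `unmask_fallback`
    // get exercised (which paths fire depends on the allocation's address).
    let mask = rand::random::<[u8; 4]>();
    for offset in 0..4 {
      let mut a: Vec<u8> = (0..48u8).collect();
      let mut b = a.clone();
      unmask_fallback(&mut a[offset..], mask);
      unmask_easy(&mut b[offset..], mask);
      assert_eq!(a, b);
    }
  }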
}