1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
//! Frame mask flag, masking key, and payload masking helpers.

use crate::error::FrameError;

/// Masking state of a frame payload, keyed by a 32-bit (4-byte) value.
///
/// * `Key` — the mask bit is set and the 4-byte masking key is known.
/// * `Skip` — the mask bit is set but no key is carried. It behaves like an
///   all-zero key, so unmasking is a no-op. The server side uses it to skip
///   unmasking.
/// * `None` — the mask bit is clear.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Mask {
    /// Mask bit set, with an explicit 4-byte key.
    Key([u8; 4]),
    /// Mask bit set, unmasking skipped (acts as an all-zero key).
    Skip,
    /// Mask bit clear.
    None,
}

impl Mask {
    /// Read the flag which indicates whether mask is used.
    #[inline]
    pub const fn from_flag(b: u8) -> Result<Self, FrameError> {
        let mask = match b & 0x80 {
            0x80 => Mask::Skip,
            0x00 => Mask::None,
            _ => return Err(FrameError::IllegalMask),
        };
        Ok(mask)
    }

    /// Get the flag byte.
    #[inline]
    pub const fn to_flag(&self) -> u8 {
        use Mask::*;
        match self {
            Key(_) | Skip => 0x80,
            None => 0x00,
        }
    }

    /// Get inner mask key.
    #[inline]
    pub const fn to_key(&self) -> [u8; 4] {
        use Mask::*;
        match self {
            Key(k) => *k,
            Skip => [0u8; 4],
            None => unreachable!(),
        }
    }
}

/// Produce a fresh, randomly generated 4-byte masking key.
///
/// The key is drawn via `rand::random`.
#[inline]
pub fn new_mask_key() -> [u8; 4] {
    rand::random()
}

/// Mask the buffer one byte at a time.
///
/// Byte `i` of `buf` is XOR-ed with `key[i % 4]`, cycling through the four
/// key bytes. Because XOR is its own inverse, applying the same key twice
/// restores the original data.
#[inline]
pub fn apply_mask(key: [u8; 4], buf: &mut [u8]) {
    // Each chunk of up to 4 bytes starts on a key boundary, so zipping the
    // chunk with the key pairs every byte with `key[i % 4]`.
    for chunk in buf.chunks_mut(4) {
        for (byte, k) in chunk.iter_mut().zip(key.iter()) {
            *byte ^= *k;
        }
    }
}

/// Mask the buffer, 4 bytes at a time.
///
/// Produces exactly the same result as [`apply_mask`]: byte `i` of `buf` is
/// XOR-ed with `key[i % 4]`. The aligned middle of the buffer is processed
/// as `u32` words. The unaligned head and tail fall back to [`apply_mask`].
#[inline]
pub fn apply_mask4(key: [u8; 4], buf: &mut [u8]) {
    // SAFETY: `u8` has no alignment or validity requirements and every bit
    // pattern is a valid `u32`, so reinterpreting the aligned middle of a
    // byte slice as `u32` words is sound.
    let (prefix, middle, suffix) = unsafe { buf.align_to_mut::<u32>() };

    apply_mask(key, prefix);

    // The first byte of `middle` is at offset `prefix.len()`, so it must be
    // XOR-ed with `key[prefix.len() % 4]`. Rotating the key *bytes* left by
    // that amount lines them up with the word boundaries. Building the word
    // from the rotated bytes with `from_ne_bytes` keeps the in-memory byte
    // order correct on both little- and big-endian targets, with no
    // endian-specific integer rotation.
    let mut rotated = key;
    rotated.rotate_left(prefix.len() % 4);
    let key4 = u32::from_ne_bytes(rotated);

    for word in middle.iter_mut() {
        *word ^= key4;
    }

    // `middle` spans a multiple of 4 bytes, so `suffix` starts at the same
    // key phase as `middle` did.
    apply_mask(rotated, suffix);
}

#[cfg(test)]
mod test {
    use super::*;

    /// Build `len` independently random bytes.
    ///
    /// `std::iter::repeat(rand::random())` would repeat a single random
    /// byte, which hides position-dependent masking bugs.
    fn random_bytes(len: usize) -> Vec<u8> {
        (0..len).map(|_| rand::random::<u8>()).collect()
    }

    #[test]
    fn mask_store() {
        for v in [0x00, 0x80] {
            assert_eq!(Mask::from_flag(v).unwrap().to_flag(), v);
        }
    }

    #[test]
    fn mask_byte() {
        let key: [u8; 4] = rand::random();
        let buf = random_bytes(1024);

        // Masking twice with the same key must restore the input.
        let mut buf2 = buf.clone();
        apply_mask(key, &mut buf2);
        apply_mask(key, &mut buf2);

        assert_eq!(buf, buf2);
    }

    #[test]
    fn mask_byte_known_vector() {
        let mut buf = [0x00u8, 0x01, 0x02, 0x03, 0x04, 0x05, 0xff];
        apply_mask([0x10, 0x20, 0x30, 0x40], &mut buf);
        assert_eq!(buf, [0x10, 0x21, 0x32, 0x43, 0x14, 0x25, 0xcf]);
    }

    #[test]
    fn mask_byte4() {
        // Generate the random data once and slice it, to keep the test fast.
        let data = random_bytes(4096);
        for i in 0..4096 {
            let key: [u8; 4] = rand::random();
            let buf = &data[..i];

            let mut buf2 = buf.to_vec();
            apply_mask4(key, &mut buf2);
            apply_mask4(key, &mut buf2);

            assert_eq!(buf, &buf2[..]);
        }
    }

    #[test]
    fn mask_byte4_matches_bytewise() {
        // A round trip alone cannot catch a wrongly rotated key: XOR-ing with
        // any key twice is the identity. So compare against `apply_mask`.
        const MAX_LEN: usize = 64;
        let key: [u8; 4] = rand::random();
        let buf = random_bytes(MAX_LEN + 3);

        // Starting at offsets 0..4 shifts the slice address, which in practice
        // exercises every prefix length `align_to_mut` can produce.
        for offset in 0..4 {
            for len in 0..=MAX_LEN {
                let range = offset..offset + len;

                let mut expected = buf.clone();
                apply_mask(key, &mut expected[range.clone()]);

                let mut actual = buf.clone();
                apply_mask4(key, &mut actual[range.clone()]);

                // Whole-buffer comparison also checks that bytes outside
                // `range` were left untouched.
                assert_eq!(actual, expected, "offset={} len={}", offset, len);
            }
        }
    }
}
}