1use core::ops::Div;
16use core::ptr;
17
18use byteorder::{ByteOrder, LittleEndian};
19use typenum::consts::{U8, U72, U104, U136, U144, U224, U256, U384, U512};
20use typenum::uint::Unsigned;
21use static_buffer::{FixedBuf, FixedBuffer, StandardPadding};
22
23use Digest;
24
/// Keccak-f[1600] sponge state used by the SHA-3 hashers in this module.
// `Clone` is derived rather than hand-written: for a `Copy` type the derived
// impl is exactly the field-by-field copy the manual impl performed.
#[derive(Clone, Copy)]
struct State {
    /// The 1600-bit permutation state as 25 64-bit lanes; lane `x + 5 * y`
    /// belongs to column `x` (see the column parities computed in `permutation`).
    hash: [u64; 25],
    /// Set to 0 by `State::init`.
    /// NOTE(review): nothing in this file reads this field — confirm whether it
    /// is still needed.
    rest: usize,
    /// Sponge rate in bytes, i.e. how many input bytes one `process` call absorbs.
    block_size: usize,
}
41
/// Iota-step round constants for the 24 rounds of Keccak-f[1600].
///
/// Entry `i` is XORed into lane 0 at the end of round `i`.
const ROUND_CONSTS: [u64; 24] = [
    // rounds 0..=3
    0x0000_0000_0000_0001, 0x0000_0000_0000_8082, 0x8000_0000_0000_808a, 0x8000_0000_8000_8000,
    // rounds 4..=7
    0x0000_0000_0000_808b, 0x0000_0000_8000_0001, 0x8000_0000_8000_8081, 0x8000_0000_0000_8009,
    // rounds 8..=11
    0x0000_0000_0000_008a, 0x0000_0000_0000_0088, 0x0000_0000_8000_8009, 0x0000_0000_8000_000a,
    // rounds 12..=15
    0x0000_0000_8000_808b, 0x8000_0000_0000_008b, 0x8000_0000_0000_8089, 0x8000_0000_0000_8003,
    // rounds 16..=19
    0x8000_0000_0000_8002, 0x8000_0000_0000_0080, 0x0000_0000_0000_800a, 0x8000_0000_8000_000a,
    // rounds 20..=23
    0x8000_0000_8000_8081, 0x8000_0000_0000_8080, 0x0000_0000_8000_0001, 0x8000_0000_8000_8008,
];
66
67impl State {
68    fn init(bits: usize) -> Self {
69        let rate = 1600 - bits * 2;
70        assert!(rate <= 1600 && (rate % 64) == 0);
71        State {
72            hash: [0; 25],
73            rest: 0,
74            block_size: rate / 8,
75        }
76    }
77
78    fn permutation(&mut self) {
79        let mut a: [u64; 25] = self.hash;
80        let mut c: [u64; 5] = [a[0] ^ a[5] ^ a[10] ^ a[15] ^ a[20],
81                               a[1] ^ a[6] ^ a[11] ^ a[16] ^ a[21],
82                               a[2] ^ a[7] ^ a[12] ^ a[17] ^ a[22],
83                               a[3] ^ a[8] ^ a[13] ^ a[18] ^ a[23],
84                               a[4] ^ a[9] ^ a[14] ^ a[19] ^ a[24]];
85        for i in 0..12 {
86            self.round(i * 2, &mut a, &mut c);
87            self.round(i * 2 + 1, &mut a, &mut c);
88        }
89        self.hash = a;
90    }
91
92    #[inline(always)]
93    fn round(&self, i: usize, a: &mut [u64; 25], c: &mut [u64; 5]) {
94        let d0 = c[4] ^ c[1].rotate_left(1);
95        let d1 = c[0] ^ c[2].rotate_left(1);
96        let d2 = c[1] ^ c[3].rotate_left(1);
97        let d3 = c[2] ^ c[4].rotate_left(1);
98        let d4 = c[3] ^ c[0].rotate_left(1);
99
100        let b0 = a[0] ^ d0;
101        let b10 = (a[1] ^ d1).rotate_left(1);
102        let b20 = (a[2] ^ d2).rotate_left(62);
103        let b5 = (a[3] ^ d3).rotate_left(28);
104        let b15 = (a[4] ^ d4).rotate_left(27);
105        let b16 = (a[5] ^ d0).rotate_left(36);
106        let b1 = (a[6] ^ d1).rotate_left(44);
107        let b11 = (a[7] ^ d2).rotate_left(6);
108        let b21 = (a[8] ^ d3).rotate_left(55);
109        let b6 = (a[9] ^ d4).rotate_left(20);
110        let b7 = (a[10] ^ d0).rotate_left(3);
111        let b17 = (a[11] ^ d1).rotate_left(10);
112        let b2 = (a[12] ^ d2).rotate_left(43);
113        let b12 = (a[13] ^ d3).rotate_left(25);
114        let b22 = (a[14] ^ d4).rotate_left(39);
115        let b23 = (a[15] ^ d0).rotate_left(41);
116        let b8 = (a[16] ^ d1).rotate_left(45);
117        let b18 = (a[17] ^ d2).rotate_left(15);
118        let b3 = (a[18] ^ d3).rotate_left(21);
119        let b13 = (a[19] ^ d4).rotate_left(8);
120        let b14 = (a[20] ^ d0).rotate_left(18);
121        let b24 = (a[21] ^ d1).rotate_left(2);
122        let b9 = (a[22] ^ d2).rotate_left(61);
123        let b19 = (a[23] ^ d3).rotate_left(56);
124        let b4 = (a[24] ^ d4).rotate_left(14);
125
126        a[0] = (b0 ^ ((!b1) & b2)) ^ ROUND_CONSTS[i];
127        c[0] = a[0];
128        a[1] = b1 ^ ((!b2) & b3);
129        c[1] = a[1];
130        a[2] = b2 ^ ((!b3) & b4);
131        c[2] = a[2];
132        a[3] = b3 ^ ((!b4) & b0);
133        c[3] = a[3];
134        a[4] = b4 ^ ((!b0) & b1);
135        c[4] = a[4];
136
137        a[5] = b5 ^ ((!b6) & b7);
138        c[0] ^= a[5];
139        a[6] = b6 ^ ((!b7) & b8);
140        c[1] ^= a[6];
141        a[7] = b7 ^ ((!b8) & b9);
142        c[2] ^= a[7];
143        a[8] = b8 ^ ((!b9) & b5);
144        c[3] ^= a[8];
145        a[9] = b9 ^ ((!b5) & b6);
146        c[4] ^= a[9];
147
148        a[10] = b10 ^ ((!b11) & b12);
149        c[0] ^= a[10];
150        a[11] = b11 ^ ((!b12) & b13);
151        c[1] ^= a[11];
152        a[12] = b12 ^ ((!b13) & b14);
153        c[2] ^= a[12];
154        a[13] = b13 ^ ((!b14) & b10);
155        c[3] ^= a[13];
156        a[14] = b14 ^ ((!b10) & b11);
157        c[4] ^= a[14];
158
159        a[15] = b15 ^ ((!b16) & b17);
160        c[0] ^= a[15];
161        a[16] = b16 ^ ((!b17) & b18);
162        c[1] ^= a[16];
163        a[17] = b17 ^ ((!b18) & b19);
164        c[2] ^= a[17];
165        a[18] = b18 ^ ((!b19) & b15);
166        c[3] ^= a[18];
167        a[19] = b19 ^ ((!b15) & b16);
168        c[4] ^= a[19];
169
170        a[20] = b20 ^ ((!b21) & b22);
171        c[0] ^= a[20];
172        a[21] = b21 ^ ((!b22) & b23);
173        c[1] ^= a[21];
174        a[22] = b22 ^ ((!b23) & b24);
175        c[2] ^= a[22];
176        a[23] = b23 ^ ((!b24) & b20);
177        c[3] ^= a[23];
178        a[24] = b24 ^ ((!b20) & b21);
179        c[4] ^= a[24];
180    }
181
182    fn process(&mut self, data: &[u8]) {
183        let max = self.block_size / 8;
184        for (h, c) in self.hash[0..max].iter_mut().zip(data.chunks(8)) {
185            *h ^= LittleEndian::read_u64(c)
186        }
187
188        self.permutation();
189    }
190}
191
macro_rules! sha3_impl {
    ($(#[$attr:meta])* struct $name:ident -> $size:ty, $bsize:ty) => {
        // Re-emit the attributes captured at the invocation site (e.g. `///`
        // doc comments); they used to be matched and then silently dropped.
        $(#[$attr])*
        #[derive(Clone)]
        pub struct $name {
            state: State,
            buffer: FixedBuffer<$bsize>,
        }

        impl Default for $name {
            /// Returns a hasher with a zeroed sponge state (sized for
            /// `$size` output bits) and a newly created input buffer.
            fn default() -> Self {
                $name {
                    state: State::init(<$size as Unsigned>::to_usize()),
                    buffer: FixedBuffer::new(),
                }
            }
        }

        impl Digest for $name {
            type OutputBits = $size;
            type OutputBytes = <$size as Div<U8>>::Output;

            type BlockSize = $bsize;

            /// Feeds `data` through the block buffer; every chunk the buffer
            /// hands to the callback is absorbed with `State::process`
            /// (presumably whole `BlockSize` blocks — confirm in `FixedBuffer`).
            fn update<T>(&mut self, data: T) where T: AsRef<[u8]> {
                let state = &mut self.state;
                self.buffer.input(data.as_ref(), |d| state.process(d));
            }

            /// Pads and absorbs the final block, then writes the digest into
            /// the first `output_bytes()` bytes of `out`.
            ///
            /// # Panics
            ///
            /// Panics if `out` is shorter than `output_bytes()`.
            fn result<T>(mut self, mut out: T) where T: AsMut<[u8]> {
                let ret = out.as_mut();
                let out_len = Self::output_bytes();
                assert!(ret.len() >= out_len);
                let state = &mut self.state;

                // SHA-3 padding: `pad` is given the 0x06 domain/padding byte
                // (presumably zero-filling the rest of the block), then the top
                // bit of the block's last byte is set to close pad10*1.
                self.buffer.pad(0b00000110, 0, |d| state.process(d));
                let buf = self.buffer.full_buffer();
                let last = buf.len() - 1;
                buf[last] |= 0b10000000;
                state.process(buf);

                // The digest is the state lanes serialised little-endian.
                // Copying raw `u64` memory is only correct on little-endian
                // targets, so normalise each lane first; this is a no-op there
                // and a byte swap on big-endian targets.
                let mut lanes = state.hash;
                for lane in lanes.iter_mut() {
                    *lane = lane.to_le();
                }
                assert!(out_len <= 8 * lanes.len());

                // SAFETY: `lanes` provides `8 * lanes.len()` initialised bytes
                // and `ret` at least `out_len` bytes (both asserted above);
                // `lanes` is a local copy, so the regions cannot overlap.
                unsafe {
                    ptr::copy_nonoverlapping(
                        lanes.as_ptr() as *const u8,
                        ret.as_mut_ptr(),
                        out_len)
                };
            }
        }
    }
}
241
242sha3_impl!(
243    struct Sha224 -> U224, U144);
247sha3_impl!(
248    struct Sha256 -> U256, U136);
252sha3_impl!(
253    struct Sha384 -> U384, U104);
257sha3_impl!(
258    struct Sha512 -> U512, U72);