// sapling_crypto_ce/circuit/sha256.rs

1use super::uint32::UInt32;
2use super::multieq::MultiEq;
3use super::boolean::Boolean;
4use bellman::{ConstraintSystem, SynthesisError};
5use bellman::pairing::Engine;
6
/// SHA-256 round constants `K[0..64]` (FIPS 180-4, section 4.2.2): the first
/// 32 bits of the fractional parts of the cube roots of the first 64 primes.
/// Round `i` of the compression function adds `ROUND_CONSTANTS[i]` into `temp1`.
const ROUND_CONSTANTS: [u32; 64] = [
    // rounds 0..16
    0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5,
    0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
    0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,
    0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
    // rounds 16..32
    0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc,
    0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
    0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
    0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
    // rounds 32..48
    0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13,
    0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
    0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3,
    0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
    // rounds 48..64
    0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5,
    0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
    0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,
    0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
];
17
/// Initial SHA-256 hash value `H(0)` (FIPS 180-4, section 5.3.3): the first
/// 32 bits of the fractional parts of the square roots of the first 8 primes.
const IV: [u32; 8] = [
    0x6a09e667, // H0
    0xbb67ae85, // H1
    0x3c6ef372, // H2
    0xa54ff53a, // H3
    0x510e527f, // H4
    0x9b05688c, // H5
    0x1f83d9ab, // H6
    0x5be0cd19, // H7
];
22
23pub fn sha256_block_no_padding<E, CS>(
24    mut cs: CS,
25    input: &[Boolean]
26) -> Result<Vec<Boolean>, SynthesisError>
27    where E: Engine, CS: ConstraintSystem<E>
28{
29    assert_eq!(input.len(), 512);
30
31    Ok(sha256_compression_function(
32        &mut cs,
33        &input,
34        &get_sha256_iv()
35    )?
36    .into_iter()
37    .flat_map(|e| e.into_bits_be())
38    .collect())
39}
40
41pub fn sha256<E, CS>(
42    mut cs: CS,
43    input: &[Boolean]
44) -> Result<Vec<Boolean>, SynthesisError>
45    where E: Engine, CS: ConstraintSystem<E>
46{
47    assert!(input.len() % 8 == 0);
48
49    let mut padded = input.to_vec();
50    let plen = padded.len() as u64;
51    // append a single '1' bit
52    padded.push(Boolean::constant(true));
53    // append K '0' bits, where K is the minimum number >= 0 such that L + 1 + K + 64 is a multiple of 512
54    while (padded.len() + 64) % 512 != 0 {
55        padded.push(Boolean::constant(false));
56    }
57    // append L as a 64-bit big-endian integer, making the total post-processed length a multiple of 512 bits
58    for b in (0..64).rev().map(|i| (plen >> i) & 1 == 1) {
59        padded.push(Boolean::constant(b));
60    }
61    assert!(padded.len() % 512 == 0);
62
63    let mut cur = get_sha256_iv();
64    for (i, block) in padded.chunks(512).enumerate() {
65        cur = sha256_compression_function(
66            cs.namespace(|| format!("block {}", i)),
67            block,
68            &cur
69        )?;
70    }
71
72    Ok(cur.into_iter()
73    .flat_map(|e| e.into_bits_be())
74    .collect())
75}
76
77pub fn get_sha256_iv() -> Vec<UInt32> {
78    IV.iter().map(|&v| UInt32::constant(v)).collect()
79}
80
81pub fn sha256_compression_function<E, CS>(
82    cs: CS,
83    input: &[Boolean],
84    current_hash_value: &[UInt32]
85) -> Result<Vec<UInt32>, SynthesisError>
86    where E: Engine, CS: ConstraintSystem<E>
87{
88    assert_eq!(input.len(), 512);
89    assert_eq!(current_hash_value.len(), 8);
90
91    let mut w = input.chunks(32)
92                     .map(|e| UInt32::from_bits_be(e))
93                     .collect::<Vec<_>>();
94
95    // We can save some constraints by combining some of
96    // the constraints in different u32 additions
97    let mut cs = MultiEq::new(cs);
98
99    for i in 16..64 {
100        let cs = &mut cs.namespace(|| format!("w extension {}", i));
101
102        // s0 := (w[i-15] rightrotate 7) xor (w[i-15] rightrotate 18) xor (w[i-15] rightshift 3)
103        let mut s0 = w[i-15].rotr(7);
104        s0 = s0.xor(
105            cs.namespace(|| "first xor for s0"),
106            &w[i-15].rotr(18)
107        )?;
108        s0 = s0.xor(
109            cs.namespace(|| "second xor for s0"),
110            &w[i-15].shr(3)
111        )?;
112
113        // s1 := (w[i-2] rightrotate 17) xor (w[i-2] rightrotate 19) xor (w[i-2] rightshift 10)
114        let mut s1 = w[i-2].rotr(17);
115        s1 = s1.xor(
116            cs.namespace(|| "first xor for s1"),
117            &w[i-2].rotr(19)
118        )?;
119        s1 = s1.xor(
120            cs.namespace(|| "second xor for s1"),
121            &w[i-2].shr(10)
122        )?;
123
124        let tmp = UInt32::addmany(
125            cs.namespace(|| "computation of w[i]"),
126            &[w[i-16].clone(), s0, w[i-7].clone(), s1]
127        )?;
128
129        // w[i] := w[i-16] + s0 + w[i-7] + s1
130        w.push(tmp);
131    }
132
133    assert_eq!(w.len(), 64);
134
135    enum Maybe {
136        Deferred(Vec<UInt32>),
137        Concrete(UInt32)
138    }
139
140    impl Maybe {
141        fn compute<E, CS, M>(
142            self,
143            cs: M,
144            others: &[UInt32]
145        ) -> Result<UInt32, SynthesisError>
146            where E: Engine,
147                  CS: ConstraintSystem<E>,
148                  M: ConstraintSystem<E, Root=MultiEq<E, CS>>
149        {
150            Ok(match self {
151                Maybe::Concrete(ref v) => {
152                    return Ok(v.clone())
153                },
154                Maybe::Deferred(mut v) => {
155                    v.extend(others.into_iter().cloned());
156                    UInt32::addmany(
157                        cs,
158                        &v
159                    )?
160                }
161            })
162        }
163    }
164
165    let mut a = Maybe::Concrete(current_hash_value[0].clone());
166    let mut b = current_hash_value[1].clone();
167    let mut c = current_hash_value[2].clone();
168    let mut d = current_hash_value[3].clone();
169    let mut e = Maybe::Concrete(current_hash_value[4].clone());
170    let mut f = current_hash_value[5].clone();
171    let mut g = current_hash_value[6].clone();
172    let mut h = current_hash_value[7].clone();
173
174    for i in 0..64 {
175        let cs = &mut cs.namespace(|| format!("compression round {}", i));
176
177        // S1 := (e rightrotate 6) xor (e rightrotate 11) xor (e rightrotate 25)
178        let new_e = e.compute(cs.namespace(|| "deferred e computation"), &[])?;
179        let mut s1 = new_e.rotr(6);
180        s1 = s1.xor(
181            cs.namespace(|| "first xor for s1"),
182            &new_e.rotr(11)
183        )?;
184        s1 = s1.xor(
185            cs.namespace(|| "second xor for s1"),
186            &new_e.rotr(25)
187        )?;
188
189        // ch := (e and f) xor ((not e) and g)
190        let ch = UInt32::sha256_ch(
191            cs.namespace(|| "ch"),
192            &new_e,
193            &f,
194            &g
195        )?;
196
197        // temp1 := h + S1 + ch + k[i] + w[i]
198        let temp1 = vec![
199            h.clone(),
200            s1,
201            ch,
202            UInt32::constant(ROUND_CONSTANTS[i]),
203            w[i].clone()
204        ];
205
206        // S0 := (a rightrotate 2) xor (a rightrotate 13) xor (a rightrotate 22)
207        let new_a = a.compute(cs.namespace(|| "deferred a computation"), &[])?;
208        let mut s0 = new_a.rotr(2);
209        s0 = s0.xor(
210            cs.namespace(|| "first xor for s0"),
211            &new_a.rotr(13)
212        )?;
213        s0 = s0.xor(
214            cs.namespace(|| "second xor for s0"),
215            &new_a.rotr(22)
216        )?;
217
218        // maj := (a and b) xor (a and c) xor (b and c)
219        let maj = UInt32::sha256_maj(
220            cs.namespace(|| "maj"),
221            &new_a,
222            &b,
223            &c
224        )?;
225
226        // temp2 := S0 + maj
227        let temp2 = vec![s0, maj];
228
229        /*
230        h := g
231        g := f
232        f := e
233        e := d + temp1
234        d := c
235        c := b
236        b := a
237        a := temp1 + temp2
238        */
239
240        h = g;
241        g = f;
242        f = new_e;
243        e = Maybe::Deferred(temp1.iter().cloned().chain(Some(d)).collect::<Vec<_>>());
244        d = c;
245        c = b;
246        b = new_a;
247        a = Maybe::Deferred(temp1.iter().cloned().chain(temp2.iter().cloned()).collect::<Vec<_>>());
248    }
249
250    /*
251        Add the compressed chunk to the current hash value:
252        h0 := h0 + a
253        h1 := h1 + b
254        h2 := h2 + c
255        h3 := h3 + d
256        h4 := h4 + e
257        h5 := h5 + f
258        h6 := h6 + g
259        h7 := h7 + h
260    */
261
262    let h0 = a.compute(
263        cs.namespace(|| "deferred h0 computation"),
264        &[current_hash_value[0].clone()]
265    )?;
266
267    let h1 = UInt32::addmany(
268        cs.namespace(|| "new h1"),
269        &[current_hash_value[1].clone(), b]
270    )?;
271
272    let h2 = UInt32::addmany(
273        cs.namespace(|| "new h2"),
274        &[current_hash_value[2].clone(), c]
275    )?;
276
277    let h3 = UInt32::addmany(
278        cs.namespace(|| "new h3"),
279        &[current_hash_value[3].clone(), d]
280    )?;
281
282    let h4 = e.compute(
283        cs.namespace(|| "deferred h4 computation"),
284        &[current_hash_value[4].clone()]
285    )?;
286
287    let h5 = UInt32::addmany(
288        cs.namespace(|| "new h5"),
289        &[current_hash_value[5].clone(), f]
290    )?;
291
292    let h6 = UInt32::addmany(
293        cs.namespace(|| "new h6"),
294        &[current_hash_value[6].clone(), g]
295    )?;
296
297    let h7 = UInt32::addmany(
298        cs.namespace(|| "new h7"),
299        &[current_hash_value[7].clone(), h]
300    )?;
301
302    Ok(vec![h0, h1, h2, h3, h4, h5, h6, h7])
303}
304
#[cfg(test)]
mod test {
    use super::*;
    use circuit::boolean::AllocatedBit;
    use bellman::pairing::bls12_381::Bls12;
    use circuit::test::TestConstraintSystem;
    use rand::{XorShiftRng, SeedableRng, Rng};

    /// Compressing the pre-padded empty message (a single `1` bit followed by
    /// zeros, all constants) from the IV must yield SHA-256("") and allocate
    /// no constraints.
    #[test]
    fn test_blank_hash() {
        let iv = get_sha256_iv();

        let mut cs = TestConstraintSystem::<Bls12>::new();
        let mut input_bits: Vec<_> = (0..512).map(|_| Boolean::Constant(false)).collect();
        input_bits[0] = Boolean::Constant(true);
        let out = sha256_compression_function(
            &mut cs,
            &input_bits,
            &iv
        ).unwrap();
        let out_bits: Vec<_> = out.into_iter().flat_map(|e| e.into_bits_be()).collect();

        assert!(cs.is_satisfied());
        assert_eq!(cs.num_constraints(), 0);

        let expected = hex!("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855");

        // The digest must be exactly 256 bits; compare it MSB-first, byte by byte.
        assert_eq!(out_bits.len(), expected.len() * 8);
        let mut out = out_bits.into_iter();
        for b in expected.into_iter() {
            for i in (0..8).rev() {
                let c = out.next().unwrap().get_value().unwrap();

                assert_eq!(c, (b >> i) & 1u8 == 1u8);
            }
        }
    }

    /// One compression over 512 allocated random bits must be satisfied and
    /// cost exactly 25840 constraints beyond the 512 input-bit allocations.
    #[test]
    fn test_full_block() {
        let mut rng = XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]);

        let iv = get_sha256_iv();

        let mut cs = TestConstraintSystem::<Bls12>::new();
        let input_bits: Vec<_> = (0..512).map(|i| {
            Boolean::from(
                AllocatedBit::alloc(
                    cs.namespace(|| format!("input bit {}", i)),
                    Some(rng.gen())
                ).unwrap()
            )
        }).collect();

        sha256_compression_function(
            cs.namespace(|| "sha256"),
            &input_bits,
            &iv
        ).unwrap();

        assert!(cs.is_satisfied());
        assert_eq!(cs.num_constraints() - 512, 25840);
    }

    /// Compares the padded `sha256` gadget against the `sha2` crate for random
    /// messages of 0..32 bytes and every multiple of 8 bytes in 32..256, which
    /// covers both single-block and multi-block paddings.
    #[test]
    fn test_against_vectors() {
        use sha2::{Sha256, Digest};

        let mut rng = XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]);

        for input_len in (0..32).chain((32..256).filter(|a| a % 8 == 0))
        {
            let mut h = Sha256::new();
            let data: Vec<u8> = (0..input_len).map(|_| rng.gen()).collect();
            h.input(&data);
            let result = h.result();
            let hash_result = result.as_slice();

            let mut cs = TestConstraintSystem::<Bls12>::new();
            let mut input_bits = vec![];

            for (byte_i, input_byte) in data.into_iter().enumerate() {
                for bit_i in (0..8).rev() {
                    let cs = cs.namespace(|| format!("input bit {} {}", byte_i, bit_i));

                    input_bits.push(AllocatedBit::alloc(cs, Some((input_byte >> bit_i) & 1u8 == 1u8)).unwrap().into());
                }
            }

            let r = sha256(&mut cs, &input_bits).unwrap();

            assert!(cs.is_satisfied());

            // Without this, a circuit emitting too few bits would pass silently.
            assert_eq!(r.len(), hash_result.len() * 8);

            let mut s = hash_result.as_ref().iter()
                                            .flat_map(|&byte| (0..8).rev().map(move |i| (byte >> i) & 1u8 == 1u8));

            for b in r {
                match b {
                    Boolean::Is(b) => {
                        assert_eq!(s.next().unwrap(), b.get_value().unwrap());
                    },
                    Boolean::Not(b) => {
                        assert_ne!(s.next().unwrap(), b.get_value().unwrap());
                    },
                    Boolean::Constant(b) => {
                        // Constant output bits are only expected when no input
                        // bits were allocated, i.e. for the empty message.
                        assert!(input_len == 0);
                        assert_eq!(s.next().unwrap(), b);
                    }
                }
            }
        }
    }
}