// bellpepper/gadgets/sha256.rs

//! Circuits for the [SHA-256] hash function and its internal compression
//! function.
//!
//! [SHA-256]: https://tools.ietf.org/html/rfc6234
5
6#![allow(clippy::many_single_char_names)]
7
8use ff::PrimeField;
9
10use super::boolean::Boolean;
11use super::multieq::MultiEq;
12use super::uint32::UInt32;
13use bellpepper_core::{ConstraintSystem, SynthesisError};
14
/// SHA-256 round constants `K[0..64]`; `ROUND_CONSTANTS[i]` is one of the
/// operands of `temp1` in compression round `i`.
///
/// Each word is the first 32 bits of the fractional part of the cube root of
/// one of the first 64 primes (2, 3, 5, ..., 311), as listed in FIPS 180-4
/// §4.2.2. Digit separators keep the hex words readable.
#[rustfmt::skip]
const ROUND_CONSTANTS: [u32; 64] = [
    0x428a_2f98, 0x7137_4491, 0xb5c0_fbcf, 0xe9b5_dba5,
    0x3956_c25b, 0x59f1_11f1, 0x923f_82a4, 0xab1c_5ed5,
    0xd807_aa98, 0x1283_5b01, 0x2431_85be, 0x550c_7dc3,
    0x72be_5d74, 0x80de_b1fe, 0x9bdc_06a7, 0xc19b_f174,
    0xe49b_69c1, 0xefbe_4786, 0x0fc1_9dc6, 0x240c_a1cc,
    0x2de9_2c6f, 0x4a74_84aa, 0x5cb0_a9dc, 0x76f9_88da,
    0x983e_5152, 0xa831_c66d, 0xb003_27c8, 0xbf59_7fc7,
    0xc6e0_0bf3, 0xd5a7_9147, 0x06ca_6351, 0x1429_2967,
    0x27b7_0a85, 0x2e1b_2138, 0x4d2c_6dfc, 0x5338_0d13,
    0x650a_7354, 0x766a_0abb, 0x81c2_c92e, 0x9272_2c85,
    0xa2bf_e8a1, 0xa81a_664b, 0xc24b_8b70, 0xc76c_51a3,
    0xd192_e819, 0xd699_0624, 0xf40e_3585, 0x106a_a070,
    0x19a4_c116, 0x1e37_6c08, 0x2748_774c, 0x34b0_bcb5,
    0x391c_0cb3, 0x4ed8_aa4a, 0x5b9c_ca4f, 0x682e_6ff3,
    0x748f_82ee, 0x78a5_636f, 0x84c8_7814, 0x8cc7_0208,
    0x90be_fffa, 0xa450_6ceb, 0xbef9_a3f7, 0xc671_78f2,
];
26
/// SHA-256 initial hash value `H(0)`: the chaining state fed into the
/// compression of the first message block.
///
/// Each word is the first 32 bits of the fractional part of the square root
/// of one of the first eight primes (2 through 19), as listed in FIPS 180-4
/// §5.3.3.
#[rustfmt::skip]
const IV: [u32; 8] = [
    0x6a09_e667, 0xbb67_ae85, 0x3c6e_f372, 0xa54f_f53a,
    0x510e_527f, 0x9b05_688c, 0x1f83_d9ab, 0x5be0_cd19,
];
31
32pub fn sha256_block_no_padding<Scalar, CS>(
33    mut cs: CS,
34    input: &[Boolean],
35) -> Result<Vec<Boolean>, SynthesisError>
36where
37    Scalar: PrimeField,
38    CS: ConstraintSystem<Scalar>,
39{
40    assert_eq!(input.len(), 512);
41
42    Ok(
43        sha256_compression_function(&mut cs, input, &get_sha256_iv())?
44            .into_iter()
45            .flat_map(|e| e.into_bits_be())
46            .collect(),
47    )
48}
49
50pub fn sha256<Scalar, CS>(mut cs: CS, input: &[Boolean]) -> Result<Vec<Boolean>, SynthesisError>
51where
52    Scalar: PrimeField,
53    CS: ConstraintSystem<Scalar>,
54{
55    assert!(input.len() % 8 == 0);
56
57    let mut padded = input.to_vec();
58    let plen = padded.len() as u64;
59    // append a single '1' bit
60    padded.push(Boolean::constant(true));
61    // append K '0' bits, where K is the minimum number >= 0 such that L + 1 + K + 64 is a multiple of 512
62    while (padded.len() + 64) % 512 != 0 {
63        padded.push(Boolean::constant(false));
64    }
65    // append L as a 64-bit big-endian integer, making the total post-processed length a multiple of 512 bits
66    for b in (0..64).rev().map(|i| (plen >> i) & 1 == 1) {
67        padded.push(Boolean::constant(b));
68    }
69    assert!(padded.len() % 512 == 0);
70
71    let mut cur = get_sha256_iv();
72    for (i, block) in padded.chunks(512).enumerate() {
73        cur = sha256_compression_function(cs.namespace(|| format!("block {}", i)), block, &cur)?;
74    }
75
76    Ok(cur.into_iter().flat_map(|e| e.into_bits_be()).collect())
77}
78
79fn get_sha256_iv() -> Vec<UInt32> {
80    IV.iter().map(|&v| UInt32::constant(v)).collect()
81}
82
83pub fn sha256_compression_function<Scalar, CS>(
84    cs: CS,
85    input: &[Boolean],
86    current_hash_value: &[UInt32],
87) -> Result<Vec<UInt32>, SynthesisError>
88where
89    Scalar: PrimeField,
90    CS: ConstraintSystem<Scalar>,
91{
92    assert_eq!(input.len(), 512);
93    assert_eq!(current_hash_value.len(), 8);
94
95    let mut w = input
96        .chunks(32)
97        .map(UInt32::from_bits_be)
98        .collect::<Vec<_>>();
99
100    // We can save some constraints by combining some of
101    // the constraints in different u32 additions
102    let mut cs = MultiEq::new(cs);
103
104    for i in 16..64 {
105        let cs = &mut cs.namespace(|| format!("w extension {}", i));
106
107        // s0 := (w[i-15] rightrotate 7) xor (w[i-15] rightrotate 18) xor (w[i-15] rightshift 3)
108        let mut s0 = w[i - 15].rotr(7);
109        s0 = s0.xor(cs.namespace(|| "first xor for s0"), &w[i - 15].rotr(18))?;
110        s0 = s0.xor(cs.namespace(|| "second xor for s0"), &w[i - 15].shr(3))?;
111
112        // s1 := (w[i-2] rightrotate 17) xor (w[i-2] rightrotate 19) xor (w[i-2] rightshift 10)
113        let mut s1 = w[i - 2].rotr(17);
114        s1 = s1.xor(cs.namespace(|| "first xor for s1"), &w[i - 2].rotr(19))?;
115        s1 = s1.xor(cs.namespace(|| "second xor for s1"), &w[i - 2].shr(10))?;
116
117        let tmp = UInt32::addmany(
118            cs.namespace(|| "computation of w[i]"),
119            &[w[i - 16].clone(), s0, w[i - 7].clone(), s1],
120        )?;
121
122        // w[i] := w[i-16] + s0 + w[i-7] + s1
123        w.push(tmp);
124    }
125
126    assert_eq!(w.len(), 64);
127
128    enum Maybe {
129        Deferred(Vec<UInt32>),
130        Concrete(UInt32),
131    }
132
133    impl Maybe {
134        fn compute<Scalar, CS, M>(self, cs: M, others: &[UInt32]) -> Result<UInt32, SynthesisError>
135        where
136            Scalar: PrimeField,
137            CS: ConstraintSystem<Scalar>,
138            M: ConstraintSystem<Scalar, Root = MultiEq<Scalar, CS>>,
139        {
140            Ok(match self {
141                Maybe::Concrete(ref v) => return Ok(v.clone()),
142                Maybe::Deferred(mut v) => {
143                    v.extend(others.iter().cloned());
144                    UInt32::addmany(cs, &v)?
145                }
146            })
147        }
148    }
149
150    let mut a = Maybe::Concrete(current_hash_value[0].clone());
151    let mut b = current_hash_value[1].clone();
152    let mut c = current_hash_value[2].clone();
153    let mut d = current_hash_value[3].clone();
154    let mut e = Maybe::Concrete(current_hash_value[4].clone());
155    let mut f = current_hash_value[5].clone();
156    let mut g = current_hash_value[6].clone();
157    let mut h = current_hash_value[7].clone();
158
159    for i in 0..64 {
160        let cs = &mut cs.namespace(|| format!("compression round {}", i));
161
162        // S1 := (e rightrotate 6) xor (e rightrotate 11) xor (e rightrotate 25)
163        let new_e = e.compute(cs.namespace(|| "deferred e computation"), &[])?;
164        let mut s1 = new_e.rotr(6);
165        s1 = s1.xor(cs.namespace(|| "first xor for s1"), &new_e.rotr(11))?;
166        s1 = s1.xor(cs.namespace(|| "second xor for s1"), &new_e.rotr(25))?;
167
168        // ch := (e and f) xor ((not e) and g)
169        let ch = UInt32::sha256_ch(cs.namespace(|| "ch"), &new_e, &f, &g)?;
170
171        // temp1 := h + S1 + ch + k[i] + w[i]
172        let temp1 = vec![
173            h.clone(),
174            s1,
175            ch,
176            UInt32::constant(ROUND_CONSTANTS[i]),
177            w[i].clone(),
178        ];
179
180        // S0 := (a rightrotate 2) xor (a rightrotate 13) xor (a rightrotate 22)
181        let new_a = a.compute(cs.namespace(|| "deferred a computation"), &[])?;
182        let mut s0 = new_a.rotr(2);
183        s0 = s0.xor(cs.namespace(|| "first xor for s0"), &new_a.rotr(13))?;
184        s0 = s0.xor(cs.namespace(|| "second xor for s0"), &new_a.rotr(22))?;
185
186        // maj := (a and b) xor (a and c) xor (b and c)
187        let maj = UInt32::sha256_maj(cs.namespace(|| "maj"), &new_a, &b, &c)?;
188
189        // temp2 := S0 + maj
190        let temp2 = vec![s0, maj];
191
192        /*
193        h := g
194        g := f
195        f := e
196        e := d + temp1
197        d := c
198        c := b
199        b := a
200        a := temp1 + temp2
201        */
202
203        h = g;
204        g = f;
205        f = new_e;
206        e = Maybe::Deferred(temp1.iter().cloned().chain(Some(d)).collect::<Vec<_>>());
207        d = c;
208        c = b;
209        b = new_a;
210        a = Maybe::Deferred(
211            temp1
212                .iter()
213                .cloned()
214                .chain(temp2.iter().cloned())
215                .collect::<Vec<_>>(),
216        );
217    }
218
219    /*
220        Add the compressed chunk to the current hash value:
221        h0 := h0 + a
222        h1 := h1 + b
223        h2 := h2 + c
224        h3 := h3 + d
225        h4 := h4 + e
226        h5 := h5 + f
227        h6 := h6 + g
228        h7 := h7 + h
229    */
230
231    let h0 = a.compute(
232        cs.namespace(|| "deferred h0 computation"),
233        &[current_hash_value[0].clone()],
234    )?;
235
236    let h1 = UInt32::addmany(
237        cs.namespace(|| "new h1"),
238        &[current_hash_value[1].clone(), b],
239    )?;
240
241    let h2 = UInt32::addmany(
242        cs.namespace(|| "new h2"),
243        &[current_hash_value[2].clone(), c],
244    )?;
245
246    let h3 = UInt32::addmany(
247        cs.namespace(|| "new h3"),
248        &[current_hash_value[3].clone(), d],
249    )?;
250
251    let h4 = e.compute(
252        cs.namespace(|| "deferred h4 computation"),
253        &[current_hash_value[4].clone()],
254    )?;
255
256    let h5 = UInt32::addmany(
257        cs.namespace(|| "new h5"),
258        &[current_hash_value[5].clone(), f],
259    )?;
260
261    let h6 = UInt32::addmany(
262        cs.namespace(|| "new h6"),
263        &[current_hash_value[6].clone(), g],
264    )?;
265
266    let h7 = UInt32::addmany(
267        cs.namespace(|| "new h7"),
268        &[current_hash_value[7].clone(), h],
269    )?;
270
271    Ok(vec![h0, h1, h2, h3, h4, h5, h6, h7])
272}
273
#[cfg(test)]
mod test {
    use super::*;
    use crate::gadgets::boolean::AllocatedBit;
    use bellpepper_core::test_cs::*;
    use blstrs::Scalar as Fr;
    use hex_literal::hex;
    use rand_core::{RngCore, SeedableRng};
    use rand_xorshift::XorShiftRng;

    /// Compressing the single padded block of the empty message with
    /// all-constant inputs must reproduce SHA-256("") and add no constraints.
    #[test]
    fn test_blank_hash() {
        let iv = get_sha256_iv();

        let mut cs = TestConstraintSystem::<Fr>::new();
        // The padded empty message: a single `1` bit followed by zeros
        // (the trailing 64-bit length field is zero as well).
        let mut input_bits: Vec<_> = (0..512).map(|_| Boolean::Constant(false)).collect();
        input_bits[0] = Boolean::Constant(true);
        let out = sha256_compression_function(&mut cs, &input_bits, &iv).unwrap();
        let out_bits: Vec<_> = out.into_iter().flat_map(|e| e.into_bits_be()).collect();

        assert!(cs.is_satisfied());
        assert_eq!(cs.num_constraints(), 0);

        let expected = hex!("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855");
        // The zip below stops at the shorter side, so pin the length first.
        assert_eq!(out_bits.len(), expected.len() * 8);

        let expected_bits = expected
            .iter()
            .flat_map(|&byte| (0..8).rev().map(move |i| (byte >> i) & 1u8 == 1u8));
        for (bit, want) in out_bits.iter().zip(expected_bits) {
            assert_eq!(bit.get_value().unwrap(), want);
        }
    }

    #[test]
    fn test_full_block() {
        let mut rng = XorShiftRng::from_seed([
            0x59, 0x62, 0xbe, 0x3d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06,
            0xbc, 0xe5,
        ]);

        let iv = get_sha256_iv();

        let mut cs = TestConstraintSystem::<Fr>::new();
        let input_bits: Vec<_> = (0..512)
            .map(|i| {
                Boolean::from(
                    AllocatedBit::alloc(
                        cs.namespace(|| format!("input bit {}", i)),
                        Some(rng.next_u32() % 2 != 0),
                    )
                    .unwrap(),
                )
            })
            .collect();

        sha256_compression_function(cs.namespace(|| "sha256"), &input_bits, &iv).unwrap();

        assert!(cs.is_satisfied());
        assert_eq!(cs.num_constraints() - 512, 25840);
    }

    #[test]
    fn test_full_hash() {
        let mut rng = XorShiftRng::from_seed([
            0x59, 0x62, 0xbe, 0x3d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06,
            0xbc, 0xe5,
        ]);

        let mut cs = TestConstraintSystem::<Fr>::new();
        let input_bits: Vec<_> = (0..512)
            .map(|i| {
                Boolean::from(
                    AllocatedBit::alloc(
                        cs.namespace(|| format!("input bit {}", i)),
                        Some(rng.next_u32() % 2 != 0),
                    )
                    .unwrap(),
                )
            })
            .collect();

        sha256(cs.namespace(|| "sha256"), &input_bits).unwrap();

        assert!(cs.is_satisfied());
        assert_eq!(cs.num_constraints() - 512, 44874);
    }

    /// Compares the gadget against the `sha2` reference implementation for
    /// every byte length 0..32 and every multiple of 8 in 32..256.
    #[test]
    fn test_against_vectors() {
        use sha2::{Digest, Sha256};

        let mut rng = XorShiftRng::from_seed([
            0x59, 0x62, 0xbe, 0x3d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06,
            0xbc, 0xe5,
        ]);

        for input_len in (0..32).chain((32..256).filter(|a| a % 8 == 0)) {
            let mut h = Sha256::new();
            let data: Vec<u8> = (0..input_len).map(|_| rng.next_u32() as u8).collect();
            h.update(&data);
            let hash_result = h.finalize();

            let mut cs = TestConstraintSystem::<Fr>::new();
            let mut input_bits = vec![];

            for (byte_i, input_byte) in data.into_iter().enumerate() {
                for bit_i in (0..8).rev() {
                    let cs = cs.namespace(|| format!("input bit {} {}", byte_i, bit_i));

                    input_bits.push(
                        AllocatedBit::alloc(cs, Some((input_byte >> bit_i) & 1u8 == 1u8))
                            .unwrap()
                            .into(),
                    );
                }
            }

            let r = sha256(&mut cs, &input_bits).unwrap();

            assert!(cs.is_satisfied());
            // Without this, a truncated digest would pass: the loop below only
            // consumes as many reference bits as the gadget produced.
            assert_eq!(r.len(), hash_result.len() * 8);

            let mut s = hash_result
                .iter()
                .flat_map(|&byte| (0..8).rev().map(move |i| (byte >> i) & 1u8 == 1u8));

            for b in r {
                match b {
                    Boolean::Is(b) => {
                        assert_eq!(s.next().unwrap(), b.get_value().unwrap());
                    }
                    Boolean::Not(b) => {
                        assert_ne!(s.next().unwrap(), b.get_value().unwrap());
                    }
                    Boolean::Constant(b) => {
                        // Only the empty message yields a fully constant digest.
                        assert_eq!(input_len, 0);
                        assert_eq!(s.next().unwrap(), b);
                    }
                }
            }
        }
    }
}