Skip to main content

provekit_common/skyscraper/
whir.rs

1// Use the fastest available compress_many for this platform.
2#[cfg(target_arch = "aarch64")]
3use skyscraper::block4::compress_many;
4#[cfg(not(target_arch = "aarch64"))]
5use skyscraper::simple::compress_many;
6use {
7    std::borrow::Cow,
8    whir::{
9        engines::EngineId,
10        hash::{Hash, HashEngine},
11    },
12};
13
/// Pre-computed `EngineId` for the Skyscraper hash engine.
///
/// Derived as `SHA3-256("whir::hash" || "skyscraper")`. The `"skyscraper"` suffix is
/// the same string [`SkyscraperHashEngine`]'s `name()` returns.
///
/// Being a `const`, it is usable in `const` contexts and costs nothing at runtime.
/// Because the bytes are hard-coded, they must be kept in sync by hand. The
/// `engine_id_matches` unit test compares this value against
/// `SkyscraperHashEngine.engine_id()`.
pub const SKYSCRAPER: EngineId = EngineId::new([
    0xa5, 0x0d, 0x5e, 0xe2, 0xa3, 0xfc, 0x52, 0xe9, 0x6f, 0x11, 0x10, 0x3c, 0xbb, 0x8a, 0x65, 0xa3,
    0x77, 0xb5, 0x82, 0xb0, 0xb2, 0xdd, 0x42, 0x1c, 0x66, 0x19, 0x13, 0xe6, 0xa5, 0x63, 0xf8, 0xa1,
]);
21
/// Stateless [`HashEngine`] backed by the Skyscraper compression function.
///
/// The type carries no data (it is a unit struct), so it is freely `Copy`able
/// and every value is interchangeable.
#[derive(Debug, Clone, Copy)]
pub struct SkyscraperHashEngine;
24
25impl HashEngine for SkyscraperHashEngine {
26    fn name(&self) -> Cow<'_, str> {
27        "skyscraper".into()
28    }
29
30    fn supports_size(&self, size: usize) -> bool {
31        size > 0 && size % 32 == 0
32    }
33
34    fn preferred_batch_size(&self) -> usize {
35        skyscraper::WIDTH_LCM
36    }
37
38    fn hash_many(&self, size: usize, input: &[u8], output: &mut [Hash]) {
39        assert!(
40            self.supports_size(size),
41            "skyscraper: unsupported message size {size} (must be a positive multiple of 32)"
42        );
43
44        let count = output.len();
45        assert_eq!(
46            input.len(),
47            size * count,
48            "skyscraper: input length {} != size {size} * count {count}",
49            input.len()
50        );
51
52        // SAFETY: `output` is `&mut [[u8; 32]]` with `count` elements, so it occupies
53        // exactly `count * 32` contiguous bytes. We reinterpret as a flat `&mut [u8]`
54        // to interface with `compress_many` which operates on byte slices.
55        let out_bytes =
56            unsafe { std::slice::from_raw_parts_mut(output.as_mut_ptr().cast::<u8>(), count * 32) };
57
58        if size == 32 {
59            out_bytes.copy_from_slice(input);
60            return;
61        }
62
63        if size == 64 {
64            compress_many(input, out_bytes);
65            return;
66        }
67
68        // Leaf hashing: left-fold 32-byte chunks, batched across messages
69        // for SIMD throughput. Equivalent to main's SkyscraperCRH::evaluate:
70        //   elements.reduce(compress)
71        // Processes in fixed-size groups to avoid heap allocation.
72        const GROUP: usize = 4 * skyscraper::WIDTH_LCM; // fits in 3 KiB on stack
73        let chunks_per_msg = size / 32;
74        let mut pair_buf = [0u8; GROUP * 64];
75
76        for start in (0..count).step_by(GROUP) {
77            let n = (count - start).min(GROUP);
78            let pairs = &mut pair_buf[..n * 64];
79            let accs = &mut out_bytes[start * 32..(start + n) * 32];
80
81            for i in 0..n {
82                let msg = &input[(start + i) * size..];
83                pairs[i * 64..i * 64 + 32].copy_from_slice(&msg[..32]);
84                pairs[i * 64 + 32..i * 64 + 64].copy_from_slice(&msg[32..64]);
85            }
86            compress_many(pairs, accs);
87
88            for k in 2..chunks_per_msg {
89                for i in 0..n {
90                    let msg = &input[(start + i) * size..];
91                    pairs[i * 64..i * 64 + 32].copy_from_slice(&accs[i * 32..i * 32 + 32]);
92                    pairs[i * 64 + 32..i * 64 + 64].copy_from_slice(&msg[k * 32..k * 32 + 32]);
93                }
94                compress_many(pairs, accs);
95            }
96        }
97    }
98}
99
#[cfg(test)]
mod tests {
    // NOTE: inputs are built with `as_bytes()` (native-endian) while expected
    // outputs use `to_le_bytes`, so these tests assume a little-endian host.
    use {super::*, zerocopy::IntoBytes};

    /// Serializes four `u64` limbs, least significant first, into 32 little-endian bytes.
    fn limbs_to_bytes(limbs: [u64; 4]) -> [u8; 32] {
        let mut out = [0u8; 32];
        out[0..8].copy_from_slice(&limbs[0].to_le_bytes());
        out[8..16].copy_from_slice(&limbs[1].to_le_bytes());
        out[16..24].copy_from_slice(&limbs[2].to_le_bytes());
        out[24..32].copy_from_slice(&limbs[3].to_le_bytes());
        out
    }

    #[test]
    fn engine_id_matches() {
        use whir::engines::Engine;
        assert_eq!(SkyscraperHashEngine.engine_id(), SKYSCRAPER);
    }

    #[test]
    fn supports_expected_sizes() {
        let e = SkyscraperHashEngine;
        assert!(!e.supports_size(0));
        assert!(!e.supports_size(1));
        assert!(!e.supports_size(31));
        assert!(e.supports_size(32));
        assert!(e.supports_size(64));
        assert!(e.supports_size(512));
        assert!(e.supports_size(1024));
    }

    #[test]
    fn single_element_is_identity() {
        let elems: [[u64; 4]; 2] = [[1, 2, 3, 4], [5, 6, 7, 8]];

        let mut output = [Hash::default(); 2];
        SkyscraperHashEngine.hash_many(32, elems.as_bytes(), &mut output);

        assert_eq!(output[0].0, limbs_to_bytes(elems[0]));
        assert_eq!(output[1].0, limbs_to_bytes(elems[1]));
    }

    #[test]
    fn empty_batch_is_noop() {
        let mut output: [Hash; 0] = [];
        SkyscraperHashEngine.hash_many(32, &[], &mut output);
        SkyscraperHashEngine.hash_many(64, &[], &mut output);
        SkyscraperHashEngine.hash_many(512, &[], &mut output);
    }

    #[test]
    #[should_panic(expected = "unsupported message size")]
    fn rejects_unsupported_size() {
        SkyscraperHashEngine.hash_many(48, &[0u8; 48], &mut [Hash::default()]);
    }

    #[test]
    #[should_panic(expected = "input length")]
    fn rejects_mismatched_input_length() {
        SkyscraperHashEngine.hash_many(64, &[0u8; 63], &mut [Hash::default()]);
    }

    #[test]
    fn two_to_one_matches_simple_compress() {
        let l: [u64; 4] = [1, 2, 3, 4];
        let r: [u64; 4] = [5, 6, 7, 8];
        let expected = skyscraper::simple::compress(l, r);

        let mut input = [0u8; 64];
        input[0..32].copy_from_slice(&limbs_to_bytes(l));
        input[32..64].copy_from_slice(&limbs_to_bytes(r));

        let mut output = [Hash::default()];
        SkyscraperHashEngine.hash_many(64, &input, &mut output);

        assert_eq!(output[0].0, limbs_to_bytes(expected));
    }

    #[test]
    fn leaf_hash_matches_fold() {
        let elems: [[u64; 4]; 4] = [[1, 0, 0, 0], [2, 0, 0, 0], [3, 0, 0, 0], [4, 0, 0, 0]];

        let expected = elems
            .into_iter()
            .reduce(skyscraper::simple::compress)
            .unwrap();

        let mut output = [Hash::default()];
        SkyscraperHashEngine.hash_many(128, elems.as_bytes(), &mut output);

        assert_eq!(output[0].0, limbs_to_bytes(expected));
    }

    #[test]
    fn batch_two_to_one_consistency() {
        let pairs: [[[u64; 4]; 2]; 3] = [
            [[1, 2, 3, 4], [5, 6, 7, 8]],
            [[9, 10, 11, 12], [13, 14, 15, 16]],
            [[17, 18, 19, 20], [21, 22, 23, 24]],
        ];

        let mut batch_output = [Hash::default(); 3];
        SkyscraperHashEngine.hash_many(64, pairs.as_bytes(), &mut batch_output);

        for (i, pair) in pairs.iter().enumerate() {
            let expected = skyscraper::simple::compress(pair[0], pair[1]);
            assert_eq!(batch_output[i].0, limbs_to_bytes(expected));
        }
    }

    #[test]
    fn batch_leaf_hash_consistency() {
        // 3 messages of 16 field elements each (512 bytes per message).
        // Verify batched result matches per-message scalar reduce(compress).
        let msgs: [[[u64; 4]; 16]; 3] =
            std::array::from_fn(|i| std::array::from_fn(|j| [(i * 16 + j + 1) as u64, 0, 0, 0]));

        let mut batch_output = [Hash::default(); 3];
        SkyscraperHashEngine.hash_many(512, msgs.as_bytes(), &mut batch_output);

        for (i, msg) in msgs.iter().enumerate() {
            let expected = msg
                .iter()
                .copied()
                .reduce(skyscraper::simple::compress)
                .unwrap();
            assert_eq!(batch_output[i].0, limbs_to_bytes(expected));
        }
    }

    #[test]
    fn batch_leaf_hash_spans_multiple_groups() {
        // More messages than two internal leaf-hash groups (`4 * WIDTH_LCM`
        // messages each) plus a ragged tail. This exercises the group boundaries
        // and a final partial group.
        let count = 2 * 4 * skyscraper::WIDTH_LCM + 5;
        let msgs: Vec<[[u64; 4]; 3]> = (0..count)
            .map(|i| std::array::from_fn(|j| [(i * 3 + j + 1) as u64, i as u64, 0, 0]))
            .collect();

        let mut batch_output = vec![Hash::default(); count];
        SkyscraperHashEngine.hash_many(96, msgs.as_slice().as_bytes(), &mut batch_output);

        for (msg, out) in msgs.iter().zip(&batch_output) {
            let expected = msg
                .iter()
                .copied()
                .reduce(skyscraper::simple::compress)
                .unwrap();
            assert_eq!(out.0, limbs_to_bytes(expected));
        }
    }
}