1#![allow(dead_code)]
2#![allow(unused_imports)]
3
4pub use crate::poseidon::{Arity, Poseidon};
5use crate::round_constants::generate_constants;
6use crate::round_numbers::{round_numbers_base, round_numbers_strengthened};
7#[cfg(feature = "abomonation")]
8use abomonation_derive::Abomonation;
9#[cfg(test)]
10use blstrs::Scalar as Fr;
11pub use error::Error;
12use ff::PrimeField;
13use generic_array::GenericArray;
14use serde::{Deserialize, Serialize};
15use std::fmt;
16use trait_set::trait_set;
17
// Compile-time feature-combination checks. Each guard turns an unsupported Cargo feature set
// into a hard compile error with an explanatory message instead of a confusing downstream failure.

// GPU backends (`cuda`/`opencl`) require at least one `arityN` feature to be enabled.
#[cfg(all(
    any(feature = "cuda", feature = "opencl"),
    not(any(
        feature = "arity2",
        feature = "arity4",
        feature = "arity8",
        feature = "arity11",
        feature = "arity16",
        feature = "arity24",
        feature = "arity36"
    ))
))]
compile_error!("The `cuda` and `opencl` features need at least one arity feature to be set");

// `strengthened` is rejected unless a GPU backend is also enabled.
#[cfg(all(
    feature = "strengthened",
    not(any(feature = "cuda", feature = "opencl"))
))]
compile_error!("The `strengthened` feature needs the `cuda` and/or `opencl` feature to be set");

// GPU backends also require at least one of the `bls` / `pasta` features.
#[cfg(all(
    any(feature = "cuda", feature = "opencl"),
    not(any(feature = "bls", feature = "pasta",))
))]
compile_error!("The `cuda` and `opencl` features need the `bls` and/or `pasta` feature to be set");
43
// Always-compiled submodules (public and crate-private).
pub mod circuit;
pub mod circuit2;
pub mod circuit2_witness;
pub mod error;
mod matrix;
mod mds;

pub mod poseidon;
mod poseidon_alt;
mod preprocessing;
mod round_constants;
mod round_numbers;

pub mod sponge;

pub mod hash_type;

// The following modules are compiled only when a GPU backend (`cuda` and/or `opencl`) is enabled.
#[cfg(any(feature = "cuda", feature = "opencl"))]
pub mod tree_builder;

#[cfg(any(feature = "cuda", feature = "opencl"))]
pub mod column_tree_builder;

#[cfg(any(feature = "cuda", feature = "opencl"))]
pub mod batch_hasher;

#[cfg(any(feature = "cuda", feature = "opencl"))]
pub mod proteus;
79
// `NeptuneField` has two feature-dependent definitions; exactly one is compiled.
#[cfg(not(any(feature = "cuda", feature = "opencl")))]
trait_set! {
    /// Field trait alias used when no GPU backend (`cuda`/`opencl`) is enabled: any `PrimeField`.
    pub trait NeptuneField = PrimeField;
}

#[cfg(any(feature = "cuda", feature = "opencl"))]
trait_set! {
    /// Field trait alias used when a GPU backend (`cuda`/`opencl`) is enabled: a `PrimeField`
    /// that additionally implements `ec_gpu::GpuName`.
    pub trait NeptuneField = PrimeField + ec_gpu::GpuName;
}
96
// Crate-private module; presumably holds serde impls for crate types (TODO confirm).
mod serde_impl;

/// Fixed 16-byte seed constant, visible crate-wide.
///
/// NOTE(review): presumably used to seed deterministic RNGs in tests — confirm at call sites.
pub(crate) const TEST_SEED: [u8; 16] = [
    0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06, 0xbc, 0xe5,
];
102
103#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
104#[cfg_attr(feature = "abomonation", derive(Abomonation))]
105pub enum Strength {
106 Standard,
107 Strengthened,
108}
109
110impl fmt::Display for Strength {
111 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
112 match self {
113 Self::Standard => write!(f, "standard"),
114 Self::Strengthened => write!(f, "strengthened"),
115 }
116 }
117}
118
119pub(crate) const DEFAULT_STRENGTH: Strength = Strength::Standard;
120
121pub trait BatchHasher<F, A>
122where
123 F: PrimeField,
124 A: Arity<F>,
125{
126 fn hash(&mut self, preimages: &[GenericArray<F, A>]) -> Result<Vec<F>, Error>;
129
130 fn hash_into_slice(
131 &mut self,
132 target_slice: &mut [F],
133 preimages: &[GenericArray<F, A>],
134 ) -> Result<(), Error> {
135 assert_eq!(target_slice.len(), preimages.len());
136 target_slice.copy_from_slice(self.hash(preimages)?.as_slice());
139 Ok(())
140 }
141
142 fn max_batch_size(&self) -> usize {
147 700000
148 }
149}
150
151pub fn round_numbers(arity: usize, strength: &Strength) -> (usize, usize) {
152 match strength {
153 Strength::Standard => round_numbers_base(arity),
154 Strength::Strengthened => round_numbers_strengthened(arity),
155 }
156}
157
/// Test helper: builds an `Fr` from four 64-bit limbs, least-significant limb first.
///
/// # Panics
///
/// Panics if the resulting 256-bit integer is not a canonical field element.
#[cfg(test)]
pub(crate) fn scalar_from_u64s(parts: [u64; 4]) -> Fr {
    // Limb `i` occupies bytes `8*i..8*i + 8`, each limb stored little-endian.
    let mut bytes = [0u8; 32];
    for (dst, limb) in bytes.chunks_exact_mut(8).zip(parts.iter()) {
        dst.copy_from_slice(&limb.to_le_bytes());
    }

    let mut repr = <Fr as PrimeField>::Repr::default();
    repr.as_mut().copy_from_slice(&bytes);
    Fr::from_repr_vartime(repr).expect("u64s exceed BLS12-381 scalar field modulus")
}
169
170const SBOX: u8 = 1; const FIELD: u8 = 1; fn round_constants<F: PrimeField>(arity: usize, strength: &Strength) -> Vec<F> {
174 let t = arity + 1;
175
176 let (full_rounds, partial_rounds) = round_numbers(arity, strength);
177
178 let r_f = full_rounds as u16;
179 let r_p = partial_rounds as u16;
180
181 let fr_num_bits = F::NUM_BITS;
182 let field_size = {
183 assert!(fr_num_bits <= u32::from(std::u16::MAX));
184 fr_num_bits as u16
186 };
187
188 generate_constants::<F>(FIELD, SBOX, field_size, t as u16, r_f, r_p)
189}
190
191pub(crate) fn quintic_s_box<F: PrimeField>(l: &mut F, pre_add: Option<&F>, post_add: Option<&F>) {
193 if let Some(x) = pre_add {
194 l.add_assign(x);
195 }
196 let mut tmp = *l;
197 tmp = tmp.square(); tmp = tmp.square(); l.mul_assign(&tmp); if let Some(x) = post_add {
201 l.add_assign(x);
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the `(arity, partial_rounds)` pairs produced by `round_numbers_strengthened`;
    /// every case is also expected to use 8 full rounds.
    #[test]
    fn test_strengthened_round_numbers() {
        const CASES: [(usize, usize); 15] = [
            (1, 69),
            (2, 69),
            (3, 70),
            (4, 70),
            (5, 70),
            (6, 70),
            (7, 72),
            (8, 72),
            (9, 72),
            (10, 72),
            (11, 72),
            (16, 74),
            (24, 74),
            (36, 75),
            (64, 77),
        ];

        for &(arity, expected_partial) in CASES.iter() {
            let (full, partial) = round_numbers_strengthened(arity);
            assert_eq!(8, full);
            assert_eq!(
                expected_partial, partial,
                "wrong number of partial rounds for arity {}",
                arity
            );
        }
    }
}