neptune/
lib.rs

#![allow(dead_code)]
#![allow(unused_imports)]

pub use crate::poseidon::{Arity, Poseidon};
use crate::round_constants::generate_constants;
use crate::round_numbers::{round_numbers_base, round_numbers_strengthened};
#[cfg(feature = "abomonation")]
use abomonation_derive::Abomonation;
#[cfg(test)]
use blstrs::Scalar as Fr;
pub use error::Error;
use ff::PrimeField;
use generic_array::GenericArray;
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
use std::fmt;
use trait_set::trait_set;
17
// Compile-time guards that reject unsupported feature combinations.

// A GPU backend (`cuda` / `opencl`) is rejected unless at least one `arityN` feature is enabled.
#[cfg(all(
    any(feature = "cuda", feature = "opencl"),
    not(feature = "arity2"),
    not(feature = "arity4"),
    not(feature = "arity8"),
    not(feature = "arity11"),
    not(feature = "arity16"),
    not(feature = "arity24"),
    not(feature = "arity36")
))]
compile_error!("The `cuda` and `opencl` features need at least one arity feature to be set");

// `strengthened` is rejected unless a GPU backend (`cuda` / `opencl`) is enabled too.
#[cfg(all(
    feature = "strengthened",
    not(feature = "cuda"),
    not(feature = "opencl")
))]
compile_error!("The `strengthened` feature needs the `cuda` and/or `opencl` feature to be set");

// A GPU backend is rejected unless a curve feature (`bls` / `pasta`) is enabled too.
#[cfg(all(
    any(feature = "cuda", feature = "opencl"),
    not(feature = "bls"),
    not(feature = "pasta")
))]
compile_error!("The `cuda` and `opencl` features need the `bls` and/or `pasta` feature to be set");
43
/// Poseidon circuit
pub mod circuit;
pub mod circuit2;
pub mod circuit2_witness;
/// Crate error type; `error::Error` is re-exported at the crate root as `Error`.
pub mod error;
mod matrix;
mod mds;

/// Poseidon hash; `Arity` and `Poseidon` are re-exported at the crate root.
pub mod poseidon;
mod poseidon_alt;
mod preprocessing;
// Round-constant generation; provides `generate_constants`.
mod round_constants;
// Round-count tables; provides `round_numbers_base` and `round_numbers_strengthened`.
mod round_numbers;

/// Sponge
pub mod sponge;

/// Hash types and domain separation tags.
pub mod hash_type;

/// Tree Builder
///
/// Only compiled when the `cuda` and/or `opencl` feature is enabled.
#[cfg(any(feature = "cuda", feature = "opencl"))]
pub mod tree_builder;

/// Column Tree Builder
///
/// Only compiled when the `cuda` and/or `opencl` feature is enabled.
#[cfg(any(feature = "cuda", feature = "opencl"))]
pub mod column_tree_builder;

/// Batch Hasher
///
/// Only compiled when the `cuda` and/or `opencl` feature is enabled.
#[cfg(any(feature = "cuda", feature = "opencl"))]
pub mod batch_hasher;

/// Only compiled when the `cuda` and/or `opencl` feature is enabled.
#[cfg(any(feature = "cuda", feature = "opencl"))]
pub mod proteus;
79
80#[cfg(not(any(feature = "cuda", feature = "opencl")))]
81trait_set! {
82   /// Use a trait alias, so that we can have different traits depending on the features.
83   ///
84   /// When `cuda` and/or `opencl` is enabled, then the field also needs to implement `GpuName`.
85   //pub trait NeptuneField = PrimeField + ec_gpu::GpuName;
86   pub trait NeptuneField = PrimeField;
87}
88
89#[cfg(any(feature = "cuda", feature = "opencl"))]
90trait_set! {
91   /// Use a trait alias, so that we can have different traits depending on the features.
92   ///
93   /// When `cuda` and/or `opencl` is enabled, then the field also needs to implement `GpuName`.
94   pub trait NeptuneField = PrimeField + ec_gpu::GpuName;
95}
96
// NOTE(review): private module; by its name, presumably serde (de)serialization impls for
// crate types — confirm in `serde_impl.rs`.
mod serde_impl;
98
/// Fixed 16-byte seed.
///
/// The bytes are those of the 128-bit constant below, most significant byte first:
/// `59 62 be 5d 76 3d 31 8d 17 db 37 32 54 06 bc e5`.
/// NOTE(review): presumably used to seed deterministic RNGs in tests — confirm at call sites.
pub(crate) const TEST_SEED: [u8; 16] =
    0x5962_be5d_763d_318d_17db_3732_5406_bce5_u128.to_be_bytes();
102
/// Selects which round-number table is used for a Poseidon instance.
///
/// `round_numbers` maps `Standard` to `round_numbers_base` and `Strengthened` to
/// `round_numbers_strengthened`.
// Variant order is part of the derived serde representation; do not reorder.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "abomonation", derive(Abomonation))]
pub enum Strength {
    /// Round numbers from `round_numbers_base`.
    Standard,
    /// Round numbers from `round_numbers_strengthened`.
    Strengthened,
}
109
110impl fmt::Display for Strength {
111    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
112        match self {
113            Self::Standard => write!(f, "standard"),
114            Self::Strengthened => write!(f, "strengthened"),
115        }
116    }
117}
118
/// Crate-wide default `Strength` (`Strength::Standard`).
pub(crate) const DEFAULT_STRENGTH: Strength = Strength::Standard;
120
121pub trait BatchHasher<F, A>
122where
123    F: PrimeField,
124    A: Arity<F>,
125{
126    // type State;
127
128    fn hash(&mut self, preimages: &[GenericArray<F, A>]) -> Result<Vec<F>, Error>;
129
130    fn hash_into_slice(
131        &mut self,
132        target_slice: &mut [F],
133        preimages: &[GenericArray<F, A>],
134    ) -> Result<(), Error> {
135        assert_eq!(target_slice.len(), preimages.len());
136        // FIXME: Account for max batch size.
137
138        target_slice.copy_from_slice(self.hash(preimages)?.as_slice());
139        Ok(())
140    }
141
142    /// `max_batch_size` is advisory. Implenters of `BatchHasher` should ensure that up to the returned max hashes can
143    /// be safely performed on the target GPU (currently 2080Ti). The max returned should represent a safe batch size
144    /// optimized for performance.
145    /// `BatchHasher` users are responsible for not attempting to hash batches larger than the advised maximum.
146    fn max_batch_size(&self) -> usize {
147        700000
148    }
149}
150
151pub fn round_numbers(arity: usize, strength: &Strength) -> (usize, usize) {
152    match strength {
153        Strength::Standard => round_numbers_base(arity),
154        Strength::Strengthened => round_numbers_strengthened(arity),
155    }
156}
157
/// Test helper: builds a BLS12-381 scalar from four `u64` limbs.
///
/// `parts[0]` is the least significant limb. Each limb is serialized little-endian into a
/// 32-byte little-endian representation.
///
/// # Panics
///
/// Panics if `from_repr_vartime` rejects the representation (the value exceeds the field
/// modulus).
#[cfg(test)]
pub(crate) fn scalar_from_u64s(parts: [u64; 4]) -> Fr {
    let mut bytes = [0u8; 32];
    for (dst, limb) in bytes.chunks_exact_mut(8).zip(parts.iter()) {
        dst.copy_from_slice(&limb.to_le_bytes());
    }

    let mut repr = <Fr as PrimeField>::Repr::default();
    repr.as_mut().copy_from_slice(&bytes);
    Fr::from_repr_vartime(repr).expect("u64s exceed BLS12-381 scalar field modulus")
}
169
/// S-box identifier passed to `generate_constants`; `1` denotes the quintic S-box.
const SBOX: u8 = 1; // x^5
/// Field identifier passed to `generate_constants`; `1` denotes a prime field.
const FIELD: u8 = 1; // GF(p)
172
173fn round_constants<F: PrimeField>(arity: usize, strength: &Strength) -> Vec<F> {
174    let t = arity + 1;
175
176    let (full_rounds, partial_rounds) = round_numbers(arity, strength);
177
178    let r_f = full_rounds as u16;
179    let r_p = partial_rounds as u16;
180
181    let fr_num_bits = F::NUM_BITS;
182    let field_size = {
183        assert!(fr_num_bits <= u32::from(std::u16::MAX));
184        // It's safe to convert to u16 for compatibility with other types.
185        fr_num_bits as u16
186    };
187
188    generate_constants::<F>(FIELD, SBOX, field_size, t as u16, r_f, r_p)
189}
190
191/// Apply the quintic S-Box (s^5) to a given item
192pub(crate) fn quintic_s_box<F: PrimeField>(l: &mut F, pre_add: Option<&F>, post_add: Option<&F>) {
193    if let Some(x) = pre_add {
194        l.add_assign(x);
195    }
196    let mut tmp = *l;
197    tmp = tmp.square(); // l^2
198    tmp = tmp.square(); // l^4
199    l.mul_assign(&tmp); // l^5
200    if let Some(x) = post_add {
201        l.add_assign(x);
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `round_numbers_strengthened` against a table of expected partial-round counts.
    /// The full-round count must be 8 for every arity.
    #[test]
    fn test_strengthened_round_numbers() {
        // (arity, expected partial rounds)
        let table: [(usize, usize); 15] = [
            (1, 69),
            (2, 69),
            (3, 70),
            (4, 70),
            (5, 70),
            (6, 70),
            (7, 72),
            (8, 72),
            (9, 72),
            (10, 72),
            (11, 72),
            (16, 74),
            (24, 74),
            (36, 75),
            (64, 77),
        ];

        for &(arity, expected) in table.iter() {
            let (full, partial) = round_numbers_strengthened(arity);
            assert_eq!(8, full);
            assert_eq!(
                expected, partial,
                "wrong number of partial rounds for arity {}",
                arity
            );
        }
    }
}