miden_crypto/rand/
rpo.rs

1use alloc::{string::ToString, vec::Vec};
2
3use p3_field::{ExtensionField, PrimeField64};
4use rand_core::impls;
5
6use super::{Felt, FeltRng, RngCore};
7use crate::{
8    Word, ZERO,
9    hash::rpo::Rpo256,
10    utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
11};
12
13// CONSTANTS
14// ================================================================================================
15
16const STATE_WIDTH: usize = Rpo256::STATE_WIDTH;
17const RATE_START: usize = Rpo256::RATE_RANGE.start;
18const RATE_END: usize = Rpo256::RATE_RANGE.end;
19const HALF_RATE_WIDTH: usize = (Rpo256::RATE_RANGE.end - Rpo256::RATE_RANGE.start) / 2;
20
21// RPO RANDOM COIN
22// ================================================================================================
23/// A simplified version of the `SPONGE_PRG` reseedable pseudo-random number generator algorithm
24/// described in <https://eprint.iacr.org/2011/499.pdf>.
25///
26/// The simplification is related to the following facts:
27/// 1. A call to the reseed method implies one and only one call to the permutation function. This
28///    is possible because in our case we never reseed with more than 4 field elements.
29/// 2. As a result of the previous point, we don't make use of an input buffer to accumulate seed
30///    material.
31#[derive(Debug, Clone, Copy, PartialEq, Eq)]
32pub struct RpoRandomCoin {
33    state: [Felt; STATE_WIDTH],
34    current: usize,
35}
36
37impl RpoRandomCoin {
38    /// Returns a new [RpoRandomCoin] initialize with the specified seed.
39    pub fn new(seed: Word) -> Self {
40        let mut state = [ZERO; STATE_WIDTH];
41
42        for i in 0..HALF_RATE_WIDTH {
43            state[RATE_START + i] += seed[i];
44        }
45
46        // Absorb
47        Rpo256::apply_permutation(&mut state);
48
49        RpoRandomCoin { state, current: RATE_START }
50    }
51
52    /// Returns an [RpoRandomCoin] instantiated from the provided components.
53    ///
54    /// # Panics
55    /// Panics if `current` is smaller than 4 or greater than or equal to 12.
56    pub fn from_parts(state: [Felt; STATE_WIDTH], current: usize) -> Self {
57        assert!(
58            (RATE_START..RATE_END).contains(&current),
59            "current value outside of valid range"
60        );
61        Self { state, current }
62    }
63
64    /// Returns components of this random coin.
65    pub fn into_parts(self) -> ([Felt; STATE_WIDTH], usize) {
66        (self.state, self.current)
67    }
68
69    /// Fills `dest` with random data.
70    pub fn fill_bytes(&mut self, dest: &mut [u8]) {
71        <Self as RngCore>::fill_bytes(self, dest)
72    }
73
74    /// Draws a random base field element from the random coin.
75    ///
76    /// This method applies the Rpo256 permutation when the rate portion of the state is exhausted,
77    /// then returns the next element from the rate portion.
78    pub fn draw_basefield(&mut self) -> Felt {
79        if self.current == RATE_END {
80            Rpo256::apply_permutation(&mut self.state);
81            self.current = RATE_START;
82        }
83
84        self.current += 1;
85        self.state[self.current - 1]
86    }
87
88    /// Draws a random field element.
89    ///
90    /// This is an alias for [Self::draw_basefield].
91    pub fn draw(&mut self) -> Felt {
92        self.draw_basefield()
93    }
94
95    /// Draws a random extension field element.
96    ///
97    /// The extension field element is constructed by drawing `E::DIMENSION` base field elements
98    /// and interpreting them as basis coefficients.
99    pub fn draw_ext_field<E: ExtensionField<Felt>>(&mut self) -> E {
100        let ext_degree = E::DIMENSION;
101        let mut result = vec![ZERO; ext_degree];
102        for r in result.iter_mut().take(ext_degree) {
103            *r = self.draw_basefield();
104        }
105        E::from_basis_coefficients_slice(&result).expect("failed to draw extension field element")
106    }
107
108    /// Reseeds the random coin with additional entropy.
109    ///
110    /// The provided `data` is added to the first half of the rate portion of the state,
111    /// then the Rpo256 permutation is applied. The buffer pointer is reset to the start
112    /// of the rate portion.
113    pub fn reseed(&mut self, data: Word) {
114        // Reset buffer
115        self.current = RATE_START;
116
117        // Add the new seed material to the first half of the rate portion of the RPO state
118        self.state[RATE_START] += data[0];
119        self.state[RATE_START + 1] += data[1];
120        self.state[RATE_START + 2] += data[2];
121        self.state[RATE_START + 3] += data[3];
122
123        // Absorb
124        Rpo256::apply_permutation(&mut self.state);
125    }
126
127    /// Checks how many leading zeros a value would produce when hashed with the current state.
128    ///
129    /// This method creates a temporary copy of the state, adds the provided `value` to the first
130    /// rate element, applies the Rpo256 permutation, and returns the number of trailing zeros
131    /// in the resulting first rate element. This is useful for proof-of-work style computations.
132    pub fn check_leading_zeros(&self, value: u64) -> u32 {
133        let value = Felt::new(value);
134        let mut state_tmp = self.state;
135
136        state_tmp[RATE_START] += value;
137
138        Rpo256::apply_permutation(&mut state_tmp);
139
140        let first_rate_element = state_tmp[RATE_START].as_canonical_u64();
141        first_rate_element.trailing_zeros()
142    }
143
144    /// Draws a specified number of unique random integers from a domain of a given size.
145    ///
146    /// # Arguments
147    /// * `num_values` - The number of unique integers to draw (must be less than `domain_size`)
148    /// * `domain_size` - The size of the domain (must be a power of two)
149    /// * `nonce` - A nonce value that is absorbed into the state before drawing
150    ///
151    /// # Returns
152    /// A vector of `num_values` unique integers in the range `[0, domain_size)`
153    ///
154    /// # Panics
155    /// Panics if `domain_size` is not a power of two or if `num_values >= domain_size`.
156    pub fn draw_integers(
157        &mut self,
158        num_values: usize,
159        domain_size: usize,
160        nonce: u64,
161    ) -> Vec<usize> {
162        assert!(domain_size.is_power_of_two(), "domain size must be a power of two");
163        assert!(num_values < domain_size, "number of values must be smaller than domain size");
164
165        // absorb the nonce
166        let nonce = Felt::new(nonce);
167        self.state[RATE_START] += nonce;
168        Rpo256::apply_permutation(&mut self.state);
169
170        // reset the buffer and move the next random element pointer to the second rate element.
171        // this is done as the first rate element will be "biased" via the provided `nonce` to
172        // contain some number of leading zeros.
173        self.current = RATE_START + 1;
174
175        // determine how many bits are needed to represent valid values in the domain
176        let v_mask = (domain_size - 1) as u64;
177
178        // draw values from PRNG until we get as many unique values as specified by num_queries
179        let mut values = Vec::new();
180        for _ in 0..1000 {
181            // get the next pseudo-random field element
182            let value = self.draw_basefield().as_canonical_u64();
183
184            // use the mask to get a value within the range
185            let value = (value & v_mask) as usize;
186
187            values.push(value);
188            if values.len() == num_values {
189                break;
190            }
191        }
192
193        assert_eq!(
194            values.len(),
195            num_values,
196            "failed to draw {} integers after 1000 iterations (got {})",
197            num_values,
198            values.len()
199        );
200
201        values
202    }
203}
204
205// FELT RNG IMPLEMENTATION
206// ------------------------------------------------------------------------------------------------
207
208impl FeltRng for RpoRandomCoin {
209    fn draw_element(&mut self) -> Felt {
210        self.draw_basefield()
211    }
212
213    fn draw_word(&mut self) -> Word {
214        let mut output = [ZERO; 4];
215        for o in output.iter_mut() {
216            *o = self.draw_basefield();
217        }
218        Word::new(output)
219    }
220}
221
222// RNGCORE IMPLEMENTATION
223// ------------------------------------------------------------------------------------------------
224
225impl RngCore for RpoRandomCoin {
226    fn next_u32(&mut self) -> u32 {
227        self.draw_basefield().as_canonical_u64() as u32
228    }
229
230    fn next_u64(&mut self) -> u64 {
231        impls::next_u64_via_u32(self)
232    }
233
234    fn fill_bytes(&mut self, dest: &mut [u8]) {
235        impls::fill_bytes_via_next(self, dest)
236    }
237}
238
239// SERIALIZATION
240// ------------------------------------------------------------------------------------------------
241
242impl Serializable for RpoRandomCoin {
243    fn write_into<W: ByteWriter>(&self, target: &mut W) {
244        self.state.iter().for_each(|v| v.write_into(target));
245        // casting to u8 is OK because `current` is always between 4 and 12.
246        target.write_u8(self.current as u8);
247    }
248}
249
250impl Deserializable for RpoRandomCoin {
251    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
252        let state = [
253            Felt::read_from(source)?,
254            Felt::read_from(source)?,
255            Felt::read_from(source)?,
256            Felt::read_from(source)?,
257            Felt::read_from(source)?,
258            Felt::read_from(source)?,
259            Felt::read_from(source)?,
260            Felt::read_from(source)?,
261            Felt::read_from(source)?,
262            Felt::read_from(source)?,
263            Felt::read_from(source)?,
264            Felt::read_from(source)?,
265        ];
266        let current = source.read_u8()? as usize;
267        if !(RATE_START..RATE_END).contains(&current) {
268            return Err(DeserializationError::InvalidValue(
269                "current value outside of valid range".to_string(),
270            ));
271        }
272        Ok(Self { state, current })
273    }
274}
275
276// TESTS
277// ================================================================================================
278
#[cfg(test)]
mod tests {
    use super::{Deserializable, FeltRng, RpoRandomCoin, Serializable, ZERO};
    use crate::{ONE, Word};

    /// Builds a coin seeded with the all-zero word so that every comparison starts from the
    /// same state.
    fn zero_seeded_coin() -> RpoRandomCoin {
        RpoRandomCoin::new([ZERO; 4].into())
    }

    #[test]
    fn test_feltrng_felt() {
        let actual = zero_seeded_coin().draw_element();
        let expected = zero_seeded_coin().draw_basefield();

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_feltrng_word() {
        let actual = zero_seeded_coin().draw_word();

        // array operands are evaluated left to right, so this matches four sequential draws
        let mut reference = zero_seeded_coin();
        let expected = Word::new([
            reference.draw_basefield(),
            reference.draw_basefield(),
            reference.draw_basefield(),
            reference.draw_basefield(),
        ]);

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_feltrng_serialization() {
        let original = RpoRandomCoin::from_parts([ONE; 12], 5);
        let encoded = original.to_bytes();
        let decoded = RpoRandomCoin::read_from_bytes(&encoded).unwrap();

        assert_eq!(original, decoded);
    }
}