miden_crypto/rand/
rpo.rs

1use alloc::{string::ToString, vec::Vec};
2
3use rand_core::impls;
4
5use super::{Felt, FeltRng, FieldElement, RandomCoin, RandomCoinError, RngCore, Word, ZERO};
6use crate::{
7    hash::rpo::{Rpo256, RpoDigest},
8    utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
9};
10
11// CONSTANTS
12// ================================================================================================
13
14const STATE_WIDTH: usize = Rpo256::STATE_WIDTH;
15const RATE_START: usize = Rpo256::RATE_RANGE.start;
16const RATE_END: usize = Rpo256::RATE_RANGE.end;
17const HALF_RATE_WIDTH: usize = (Rpo256::RATE_RANGE.end - Rpo256::RATE_RANGE.start) / 2;
18
19// RPO RANDOM COIN
20// ================================================================================================
21/// A simplified version of the `SPONGE_PRG` reseedable pseudo-random number generator algorithm
22/// described in <https://eprint.iacr.org/2011/499.pdf>.
23///
24/// The simplification is related to the following facts:
25/// 1. A call to the reseed method implies one and only one call to the permutation function. This
26///    is possible because in our case we never reseed with more than 4 field elements.
27/// 2. As a result of the previous point, we don't make use of an input buffer to accumulate seed
28///    material.
29#[derive(Debug, Clone, Copy, PartialEq, Eq)]
30pub struct RpoRandomCoin {
31    state: [Felt; STATE_WIDTH],
32    current: usize,
33}
34
35impl RpoRandomCoin {
36    /// Returns a new [RpoRandomCoin] initialize with the specified seed.
37    pub fn new(seed: Word) -> Self {
38        let mut state = [ZERO; STATE_WIDTH];
39
40        for i in 0..HALF_RATE_WIDTH {
41            state[RATE_START + i] += seed[i];
42        }
43
44        // Absorb
45        Rpo256::apply_permutation(&mut state);
46
47        RpoRandomCoin { state, current: RATE_START }
48    }
49
50    /// Returns an [RpoRandomCoin] instantiated from the provided components.
51    ///
52    /// # Panics
53    /// Panics if `current` is smaller than 4 or greater than or equal to 12.
54    pub fn from_parts(state: [Felt; STATE_WIDTH], current: usize) -> Self {
55        assert!(
56            (RATE_START..RATE_END).contains(&current),
57            "current value outside of valid range"
58        );
59        Self { state, current }
60    }
61
62    /// Returns components of this random coin.
63    pub fn into_parts(self) -> ([Felt; STATE_WIDTH], usize) {
64        (self.state, self.current)
65    }
66
67    /// Fills `dest` with random data.
68    pub fn fill_bytes(&mut self, dest: &mut [u8]) {
69        <Self as RngCore>::fill_bytes(self, dest)
70    }
71
72    fn draw_basefield(&mut self) -> Felt {
73        if self.current == RATE_END {
74            Rpo256::apply_permutation(&mut self.state);
75            self.current = RATE_START;
76        }
77
78        self.current += 1;
79        self.state[self.current - 1]
80    }
81}
82
83// RANDOM COIN IMPLEMENTATION
84// ------------------------------------------------------------------------------------------------
85
86impl RandomCoin for RpoRandomCoin {
87    type BaseField = Felt;
88    type Hasher = Rpo256;
89
90    fn new(seed: &[Self::BaseField]) -> Self {
91        let digest: Word = Rpo256::hash_elements(seed).into();
92        Self::new(digest)
93    }
94
95    fn reseed(&mut self, data: RpoDigest) {
96        // Reset buffer
97        self.current = RATE_START;
98
99        // Add the new seed material to the first half of the rate portion of the RPO state
100        let data: Word = data.into();
101
102        self.state[RATE_START] += data[0];
103        self.state[RATE_START + 1] += data[1];
104        self.state[RATE_START + 2] += data[2];
105        self.state[RATE_START + 3] += data[3];
106
107        // Absorb
108        Rpo256::apply_permutation(&mut self.state);
109    }
110
111    fn check_leading_zeros(&self, value: u64) -> u32 {
112        let value = Felt::new(value);
113        let mut state_tmp = self.state;
114
115        state_tmp[RATE_START] += value;
116
117        Rpo256::apply_permutation(&mut state_tmp);
118
119        let first_rate_element = state_tmp[RATE_START].as_int();
120        first_rate_element.trailing_zeros()
121    }
122
123    fn draw<E: FieldElement<BaseField = Felt>>(&mut self) -> Result<E, RandomCoinError> {
124        let ext_degree = E::EXTENSION_DEGREE;
125        let mut result = vec![ZERO; ext_degree];
126        for r in result.iter_mut().take(ext_degree) {
127            *r = self.draw_basefield();
128        }
129
130        let result = E::slice_from_base_elements(&result);
131        Ok(result[0])
132    }
133
134    fn draw_integers(
135        &mut self,
136        num_values: usize,
137        domain_size: usize,
138        nonce: u64,
139    ) -> Result<Vec<usize>, RandomCoinError> {
140        assert!(domain_size.is_power_of_two(), "domain size must be a power of two");
141        assert!(num_values < domain_size, "number of values must be smaller than domain size");
142
143        // absorb the nonce
144        let nonce = Felt::new(nonce);
145        self.state[RATE_START] += nonce;
146        Rpo256::apply_permutation(&mut self.state);
147
148        // reset the buffer and move the next random element pointer to the second rate element.
149        // this is done as the first rate element will be "biased" via the provided `nonce` to
150        // contain some number of leading zeros.
151        self.current = RATE_START + 1;
152
153        // determine how many bits are needed to represent valid values in the domain
154        let v_mask = (domain_size - 1) as u64;
155
156        // draw values from PRNG until we get as many unique values as specified by num_queries
157        let mut values = Vec::new();
158        for _ in 0..1000 {
159            // get the next pseudo-random field element
160            let value = self.draw_basefield().as_int();
161
162            // use the mask to get a value within the range
163            let value = (value & v_mask) as usize;
164
165            values.push(value);
166            if values.len() == num_values {
167                break;
168            }
169        }
170
171        if values.len() < num_values {
172            return Err(RandomCoinError::FailedToDrawIntegers(num_values, values.len(), 1000));
173        }
174
175        Ok(values)
176    }
177}
178
179// FELT RNG IMPLEMENTATION
180// ------------------------------------------------------------------------------------------------
181
182impl FeltRng for RpoRandomCoin {
183    fn draw_element(&mut self) -> Felt {
184        self.draw_basefield()
185    }
186
187    fn draw_word(&mut self) -> Word {
188        let mut output = [ZERO; 4];
189        for o in output.iter_mut() {
190            *o = self.draw_basefield();
191        }
192        output
193    }
194}
195
196// RNGCORE IMPLEMENTATION
197// ------------------------------------------------------------------------------------------------
198
199impl RngCore for RpoRandomCoin {
200    fn next_u32(&mut self) -> u32 {
201        self.draw_basefield().as_int() as u32
202    }
203
204    fn next_u64(&mut self) -> u64 {
205        impls::next_u64_via_u32(self)
206    }
207
208    fn fill_bytes(&mut self, dest: &mut [u8]) {
209        impls::fill_bytes_via_next(self, dest)
210    }
211
212    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
213        self.fill_bytes(dest);
214        Ok(())
215    }
216}
217
218// SERIALIZATION
219// ------------------------------------------------------------------------------------------------
220
221impl Serializable for RpoRandomCoin {
222    fn write_into<W: ByteWriter>(&self, target: &mut W) {
223        self.state.iter().for_each(|v| v.write_into(target));
224        // casting to u8 is OK because `current` is always between 4 and 12.
225        target.write_u8(self.current as u8);
226    }
227}
228
229impl Deserializable for RpoRandomCoin {
230    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
231        let state = [
232            Felt::read_from(source)?,
233            Felt::read_from(source)?,
234            Felt::read_from(source)?,
235            Felt::read_from(source)?,
236            Felt::read_from(source)?,
237            Felt::read_from(source)?,
238            Felt::read_from(source)?,
239            Felt::read_from(source)?,
240            Felt::read_from(source)?,
241            Felt::read_from(source)?,
242            Felt::read_from(source)?,
243            Felt::read_from(source)?,
244        ];
245        let current = source.read_u8()? as usize;
246        if !(RATE_START..RATE_END).contains(&current) {
247            return Err(DeserializationError::InvalidValue(
248                "current value outside of valid range".to_string(),
249            ));
250        }
251        Ok(Self { state, current })
252    }
253}
254
255// TESTS
256// ================================================================================================
257
#[cfg(test)]
mod tests {
    use super::{Deserializable, FeltRng, RpoRandomCoin, Serializable, ZERO};
    use crate::ONE;

    /// Fresh coin seeded with the all-zero word; every test starts from this state.
    fn zero_seeded_coin() -> RpoRandomCoin {
        RpoRandomCoin::new([ZERO; 4])
    }

    #[test]
    fn test_feltrng_felt() {
        let drawn = zero_seeded_coin().draw_element();
        let expected = zero_seeded_coin().draw_basefield();

        assert_eq!(drawn, expected);
    }

    #[test]
    fn test_feltrng_word() {
        let drawn = zero_seeded_coin().draw_word();

        // Array literal elements are evaluated left to right, matching draw order.
        let mut reference = zero_seeded_coin();
        let expected = [
            reference.draw_basefield(),
            reference.draw_basefield(),
            reference.draw_basefield(),
            reference.draw_basefield(),
        ];

        assert_eq!(drawn, expected);
    }

    #[test]
    fn test_feltrng_serialization() {
        let original = RpoRandomCoin::from_parts([ONE; 12], 5);

        let restored = RpoRandomCoin::read_from_bytes(&original.to_bytes()).unwrap();
        assert_eq!(original, restored);
    }
}