miden_crypto/rand/
rpo.rs

1use alloc::{string::ToString, vec::Vec};
2
3use rand_core::impls;
4
5use super::{Felt, FeltRng, FieldElement, RandomCoin, RandomCoinError, RngCore, ZERO};
6use crate::{
7    Word,
8    hash::rpo::Rpo256,
9    utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
10};
11
12// CONSTANTS
13// ================================================================================================
14
/// Number of field elements in the RPO permutation state.
const STATE_WIDTH: usize = Rpo256::STATE_WIDTH;
/// Index of the first rate element within the state.
const RATE_START: usize = Rpo256::RATE_RANGE.start;
/// Exclusive end index of the rate portion of the state.
const RATE_END: usize = Rpo256::RATE_RANGE.end;
/// Half the number of rate elements; seed material is added to this many elements starting at
/// `RATE_START`.
const HALF_RATE_WIDTH: usize = (Rpo256::RATE_RANGE.end - Rpo256::RATE_RANGE.start) / 2;
19
20// RPO RANDOM COIN
21// ================================================================================================
/// A simplified version of the `SPONGE_PRG` reseedable pseudo-random number generator algorithm
/// described in <https://eprint.iacr.org/2011/499.pdf>.
///
/// The simplification is related to the following facts:
/// 1. A call to the reseed method implies one and only one call to the permutation function. This
///    is possible because in our case we never reseed with more than 4 field elements.
/// 2. As a result of the previous point, we don't make use of an input buffer to accumulate seed
///    material.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RpoRandomCoin {
    /// The RPO sponge state; random elements are read from its rate portion.
    state: [Felt; STATE_WIDTH],
    /// Index into `state` of the next element to be drawn. It starts at `RATE_START` and is
    /// incremented on every draw; once it reaches `RATE_END` the permutation is applied and the
    /// index is reset before the next element is returned.
    current: usize,
}
35
36impl RpoRandomCoin {
37    /// Returns a new [RpoRandomCoin] initialize with the specified seed.
38    pub fn new(seed: Word) -> Self {
39        let mut state = [ZERO; STATE_WIDTH];
40
41        for i in 0..HALF_RATE_WIDTH {
42            state[RATE_START + i] += seed[i];
43        }
44
45        // Absorb
46        Rpo256::apply_permutation(&mut state);
47
48        RpoRandomCoin { state, current: RATE_START }
49    }
50
51    /// Returns an [RpoRandomCoin] instantiated from the provided components.
52    ///
53    /// # Panics
54    /// Panics if `current` is smaller than 4 or greater than or equal to 12.
55    pub fn from_parts(state: [Felt; STATE_WIDTH], current: usize) -> Self {
56        assert!(
57            (RATE_START..RATE_END).contains(&current),
58            "current value outside of valid range"
59        );
60        Self { state, current }
61    }
62
63    /// Returns components of this random coin.
64    pub fn into_parts(self) -> ([Felt; STATE_WIDTH], usize) {
65        (self.state, self.current)
66    }
67
68    /// Fills `dest` with random data.
69    pub fn fill_bytes(&mut self, dest: &mut [u8]) {
70        <Self as RngCore>::fill_bytes(self, dest)
71    }
72
73    fn draw_basefield(&mut self) -> Felt {
74        if self.current == RATE_END {
75            Rpo256::apply_permutation(&mut self.state);
76            self.current = RATE_START;
77        }
78
79        self.current += 1;
80        self.state[self.current - 1]
81    }
82}
83
84// RANDOM COIN IMPLEMENTATION
85// ------------------------------------------------------------------------------------------------
86
87impl RandomCoin for RpoRandomCoin {
88    type BaseField = Felt;
89    type Hasher = Rpo256;
90
91    fn new(seed: &[Self::BaseField]) -> Self {
92        let digest: Word = Rpo256::hash_elements(seed);
93        Self::new(digest)
94    }
95
96    fn reseed(&mut self, data: Word) {
97        // Reset buffer
98        self.current = RATE_START;
99
100        // Add the new seed material to the first half of the rate portion of the RPO state
101        self.state[RATE_START] += data[0];
102        self.state[RATE_START + 1] += data[1];
103        self.state[RATE_START + 2] += data[2];
104        self.state[RATE_START + 3] += data[3];
105
106        // Absorb
107        Rpo256::apply_permutation(&mut self.state);
108    }
109
110    fn check_leading_zeros(&self, value: u64) -> u32 {
111        let value = Felt::new(value);
112        let mut state_tmp = self.state;
113
114        state_tmp[RATE_START] += value;
115
116        Rpo256::apply_permutation(&mut state_tmp);
117
118        let first_rate_element = state_tmp[RATE_START].as_int();
119        first_rate_element.trailing_zeros()
120    }
121
122    fn draw<E: FieldElement<BaseField = Felt>>(&mut self) -> Result<E, RandomCoinError> {
123        let ext_degree = E::EXTENSION_DEGREE;
124        let mut result = vec![ZERO; ext_degree];
125        for r in result.iter_mut().take(ext_degree) {
126            *r = self.draw_basefield();
127        }
128
129        let result = E::slice_from_base_elements(&result);
130        Ok(result[0])
131    }
132
133    fn draw_integers(
134        &mut self,
135        num_values: usize,
136        domain_size: usize,
137        nonce: u64,
138    ) -> Result<Vec<usize>, RandomCoinError> {
139        assert!(domain_size.is_power_of_two(), "domain size must be a power of two");
140        assert!(num_values < domain_size, "number of values must be smaller than domain size");
141
142        // absorb the nonce
143        let nonce = Felt::new(nonce);
144        self.state[RATE_START] += nonce;
145        Rpo256::apply_permutation(&mut self.state);
146
147        // reset the buffer and move the next random element pointer to the second rate element.
148        // this is done as the first rate element will be "biased" via the provided `nonce` to
149        // contain some number of leading zeros.
150        self.current = RATE_START + 1;
151
152        // determine how many bits are needed to represent valid values in the domain
153        let v_mask = (domain_size - 1) as u64;
154
155        // draw values from PRNG until we get as many unique values as specified by num_queries
156        let mut values = Vec::new();
157        for _ in 0..1000 {
158            // get the next pseudo-random field element
159            let value = self.draw_basefield().as_int();
160
161            // use the mask to get a value within the range
162            let value = (value & v_mask) as usize;
163
164            values.push(value);
165            if values.len() == num_values {
166                break;
167            }
168        }
169
170        if values.len() < num_values {
171            return Err(RandomCoinError::FailedToDrawIntegers(num_values, values.len(), 1000));
172        }
173
174        Ok(values)
175    }
176}
177
178// FELT RNG IMPLEMENTATION
179// ------------------------------------------------------------------------------------------------
180
181impl FeltRng for RpoRandomCoin {
182    fn draw_element(&mut self) -> Felt {
183        self.draw_basefield()
184    }
185
186    fn draw_word(&mut self) -> Word {
187        let mut output = [ZERO; 4];
188        for o in output.iter_mut() {
189            *o = self.draw_basefield();
190        }
191        Word::new(output)
192    }
193}
194
195// RNGCORE IMPLEMENTATION
196// ------------------------------------------------------------------------------------------------
197
198impl RngCore for RpoRandomCoin {
199    fn next_u32(&mut self) -> u32 {
200        self.draw_basefield().as_int() as u32
201    }
202
203    fn next_u64(&mut self) -> u64 {
204        impls::next_u64_via_u32(self)
205    }
206
207    fn fill_bytes(&mut self, dest: &mut [u8]) {
208        impls::fill_bytes_via_next(self, dest)
209    }
210}
211
212// SERIALIZATION
213// ------------------------------------------------------------------------------------------------
214
215impl Serializable for RpoRandomCoin {
216    fn write_into<W: ByteWriter>(&self, target: &mut W) {
217        self.state.iter().for_each(|v| v.write_into(target));
218        // casting to u8 is OK because `current` is always between 4 and 12.
219        target.write_u8(self.current as u8);
220    }
221}
222
223impl Deserializable for RpoRandomCoin {
224    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
225        let state = [
226            Felt::read_from(source)?,
227            Felt::read_from(source)?,
228            Felt::read_from(source)?,
229            Felt::read_from(source)?,
230            Felt::read_from(source)?,
231            Felt::read_from(source)?,
232            Felt::read_from(source)?,
233            Felt::read_from(source)?,
234            Felt::read_from(source)?,
235            Felt::read_from(source)?,
236            Felt::read_from(source)?,
237            Felt::read_from(source)?,
238        ];
239        let current = source.read_u8()? as usize;
240        if !(RATE_START..RATE_END).contains(&current) {
241            return Err(DeserializationError::InvalidValue(
242                "current value outside of valid range".to_string(),
243            ));
244        }
245        Ok(Self { state, current })
246    }
247}
248
249// TESTS
250// ================================================================================================
251
#[cfg(test)]
mod tests {
    use super::{Deserializable, FeltRng, RpoRandomCoin, Serializable, ZERO};
    use crate::{ONE, Word};

    /// `draw_element` must yield the same element as the internal `draw_basefield`.
    #[test]
    fn test_feltrng_felt() {
        let mut coin_via_trait = RpoRandomCoin::new([ZERO; 4].into());
        let mut coin_direct = RpoRandomCoin::new([ZERO; 4].into());

        let output = coin_via_trait.draw_element();
        let expected = coin_direct.draw_basefield();

        assert_eq!(output, expected);
    }

    /// `draw_word` must yield the same elements as four consecutive `draw_basefield` calls.
    #[test]
    fn test_feltrng_word() {
        let mut coin_via_trait = RpoRandomCoin::new([ZERO; 4].into());
        let mut coin_direct = RpoRandomCoin::new([ZERO; 4].into());

        let output = coin_via_trait.draw_word();

        let mut elements = [ZERO; 4];
        elements.iter_mut().for_each(|slot| *slot = coin_direct.draw_basefield());
        let expected = Word::new(elements);

        assert_eq!(output, expected);
    }

    /// A coin must survive a serialization round trip unchanged.
    #[test]
    fn test_feltrng_serialization() {
        let coin1 = RpoRandomCoin::from_parts([ONE; 12], 5);

        let coin2 = {
            let bytes = coin1.to_bytes();
            RpoRandomCoin::read_from_bytes(&bytes).unwrap()
        };

        assert_eq!(coin1, coin2);
    }
}