Skip to main content

miden_crypto/rand/
coin.rs

1use alloc::string::ToString;
2
3use rand::{
4    Rng,
5    rand_core::{Infallible, TryRng, utils},
6};
7
8use super::{Felt, FeltRng};
9use crate::{
10    Word, ZERO,
11    field::ExtensionField,
12    hash::poseidon2::Poseidon2,
13    utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
14};
15
16// CONSTANTS
17// ================================================================================================
18
19const STATE_WIDTH: usize = Poseidon2::STATE_WIDTH;
20const RATE_START: usize = Poseidon2::RATE_RANGE.start;
21const RATE_END: usize = Poseidon2::RATE_RANGE.end;
22const HALF_RATE_WIDTH: usize = (Poseidon2::RATE_RANGE.end - Poseidon2::RATE_RANGE.start) / 2;
23
24// POSEIDON2 RANDOM COIN
25// ================================================================================================
26/// A simplified version of the `SPONGE_PRG` reseedable pseudo-random number generator algorithm
27/// described in <https://eprint.iacr.org/2011/499.pdf>.
28///
29/// The simplification is related to the following facts:
30/// 1. A call to the reseed method implies one and only one call to the permutation function. This
31///    is possible because in our case we never reseed with more than 4 field elements.
32/// 2. As a result of the previous point, we don't make use of an input buffer to accumulate seed
33///    material.
34#[derive(Debug, Clone, Copy, PartialEq, Eq)]
35pub struct RandomCoin {
36    state: [Felt; STATE_WIDTH],
37    current: usize,
38}
39
40impl RandomCoin {
41    /// Returns a new [RandomCoin] initialized with the specified seed.
42    pub fn new(seed: Word) -> Self {
43        let mut state = [ZERO; STATE_WIDTH];
44
45        for i in 0..HALF_RATE_WIDTH {
46            state[RATE_START + i] += seed[i];
47        }
48
49        // Absorb
50        Poseidon2::apply_permutation(&mut state);
51
52        RandomCoin { state, current: RATE_START }
53    }
54
55    /// Returns a [RandomCoin] instantiated from the provided components.
56    ///
57    /// # Panics
58    /// Panics if `current` is outside of the rate range.
59    pub fn from_parts(state: [Felt; STATE_WIDTH], current: usize) -> Self {
60        assert!(
61            (RATE_START..RATE_END).contains(&current),
62            "current value outside of valid range"
63        );
64        Self { state, current }
65    }
66
67    /// Returns components of this random coin.
68    pub fn into_parts(self) -> ([Felt; STATE_WIDTH], usize) {
69        (self.state, self.current)
70    }
71
72    /// Fills `dest` with random data.
73    pub fn fill_bytes(&mut self, dest: &mut [u8]) {
74        <Self as Rng>::fill_bytes(self, dest)
75    }
76
77    /// Draws a random base field element from the random coin.
78    ///
79    /// This method applies the Poseidon2 permutation when the rate portion of the state is
80    /// exhausted, then returns the next element from the rate portion.
81    pub fn draw_basefield(&mut self) -> Felt {
82        if self.current == RATE_END {
83            Poseidon2::apply_permutation(&mut self.state);
84            self.current = RATE_START;
85        }
86
87        self.current += 1;
88        self.state[self.current - 1]
89    }
90
91    /// Draws a random field element.
92    ///
93    /// This is an alias for [Self::draw_basefield].
94    pub fn draw(&mut self) -> Felt {
95        self.draw_basefield()
96    }
97
98    /// Draws a random extension field element.
99    ///
100    /// The extension field element is constructed by drawing `E::DIMENSION` base field elements
101    /// and interpreting them as basis coefficients.
102    pub fn draw_ext_field<E: ExtensionField<Felt>>(&mut self) -> E {
103        let ext_degree = E::DIMENSION;
104        let mut result = vec![ZERO; ext_degree];
105        for r in result.iter_mut().take(ext_degree) {
106            *r = self.draw_basefield();
107        }
108        E::from_basis_coefficients_slice(&result).expect("failed to draw extension field element")
109    }
110
111    /// Reseeds the random coin with additional entropy.
112    ///
113    /// The provided `data` is added to the first half of the rate portion of the state,
114    /// then the Poseidon2 permutation is applied. The buffer pointer is reset to the start
115    /// of the rate portion.
116    pub fn reseed(&mut self, data: Word) {
117        // Reset buffer
118        self.current = RATE_START;
119
120        // Add the new seed material to the first half of the rate portion of the Poseidon2 state
121        self.state[RATE_START] += data[0];
122        self.state[RATE_START + 1] += data[1];
123        self.state[RATE_START + 2] += data[2];
124        self.state[RATE_START + 3] += data[3];
125
126        // Absorb
127        Poseidon2::apply_permutation(&mut self.state);
128    }
129}
130
131// FELT RNG IMPLEMENTATION
132// ------------------------------------------------------------------------------------------------
133
134impl FeltRng for RandomCoin {
135    fn draw_element(&mut self) -> Felt {
136        self.draw_basefield()
137    }
138
139    fn draw_word(&mut self) -> Word {
140        let mut output = [ZERO; 4];
141        for o in output.iter_mut() {
142            *o = self.draw_basefield();
143        }
144        Word::new(output)
145    }
146}
147
148// RNG IMPLEMENTATION
149// ------------------------------------------------------------------------------------------------
150
151impl TryRng for RandomCoin {
152    type Error = Infallible;
153
154    fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
155        Ok(self.draw_basefield().as_canonical_u64() as u32)
156    }
157
158    fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
159        utils::next_u64_via_u32(self)
160    }
161
162    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> {
163        utils::fill_bytes_via_next_word(dest, || self.try_next_u32())
164    }
165}
166
167// SERIALIZATION
168// ------------------------------------------------------------------------------------------------
169
170impl Serializable for RandomCoin {
171    fn write_into<W: ByteWriter>(&self, target: &mut W) {
172        self.state.iter().for_each(|v| v.write_into(target));
173        // casting to u8 is OK because `current` is always within the rate range.
174        target.write_u8(self.current as u8);
175    }
176}
177
178impl Deserializable for RandomCoin {
179    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
180        let state = [
181            Felt::read_from(source)?,
182            Felt::read_from(source)?,
183            Felt::read_from(source)?,
184            Felt::read_from(source)?,
185            Felt::read_from(source)?,
186            Felt::read_from(source)?,
187            Felt::read_from(source)?,
188            Felt::read_from(source)?,
189            Felt::read_from(source)?,
190            Felt::read_from(source)?,
191            Felt::read_from(source)?,
192            Felt::read_from(source)?,
193        ];
194        let current = source.read_u8()? as usize;
195        if !(RATE_START..RATE_END).contains(&current) {
196            return Err(DeserializationError::InvalidValue(
197                "current value outside of valid range".to_string(),
198            ));
199        }
200        Ok(Self { state, current })
201    }
202}
203
204// TESTS
205// ================================================================================================
206
#[cfg(test)]
mod tests {
    use super::{Deserializable, FeltRng, RandomCoin, Serializable, ZERO};
    use crate::{ONE, Word};

    /// Builds a coin seeded with the all-zero word.
    fn zero_seeded_coin() -> RandomCoin {
        RandomCoin::new([ZERO; 4].into())
    }

    #[test]
    fn test_feltrng_felt() {
        // On a fresh coin, the trait method must agree with the inherent one.
        let via_trait = zero_seeded_coin().draw_element();
        let via_inherent = zero_seeded_coin().draw_basefield();

        assert_eq!(via_trait, via_inherent);
    }

    #[test]
    fn test_feltrng_word() {
        // A drawn word must equal four consecutive base field draws, in order.
        let via_trait = zero_seeded_coin().draw_word();

        let mut reference = zero_seeded_coin();
        let elements = [
            reference.draw_basefield(),
            reference.draw_basefield(),
            reference.draw_basefield(),
            reference.draw_basefield(),
        ];

        assert_eq!(via_trait, Word::new(elements));
    }

    #[test]
    fn test_feltrng_serialization() {
        let original = RandomCoin::from_parts([ONE; 12], 5);

        let decoded = RandomCoin::read_from_bytes(&original.to_bytes()).unwrap();
        assert_eq!(original, decoded);
    }
}