miden_crypto/rand/
coin.rs1use alloc::string::ToString;
2
3use rand::{
4 Rng,
5 rand_core::{Infallible, TryRng, utils},
6};
7
8use super::{Felt, FeltRng};
9use crate::{
10 Word, ZERO,
11 field::ExtensionField,
12 hash::poseidon2::Poseidon2,
13 utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
14};
15
16const STATE_WIDTH: usize = Poseidon2::STATE_WIDTH;
20const RATE_START: usize = Poseidon2::RATE_RANGE.start;
21const RATE_END: usize = Poseidon2::RATE_RANGE.end;
22const HALF_RATE_WIDTH: usize = (Poseidon2::RATE_RANGE.end - Poseidon2::RATE_RANGE.start) / 2;
23
24#[derive(Debug, Clone, Copy, PartialEq, Eq)]
35pub struct RandomCoin {
36 state: [Felt; STATE_WIDTH],
37 current: usize,
38}
39
40impl RandomCoin {
41 pub fn new(seed: Word) -> Self {
43 let mut state = [ZERO; STATE_WIDTH];
44
45 for i in 0..HALF_RATE_WIDTH {
46 state[RATE_START + i] += seed[i];
47 }
48
49 Poseidon2::apply_permutation(&mut state);
51
52 RandomCoin { state, current: RATE_START }
53 }
54
55 pub fn from_parts(state: [Felt; STATE_WIDTH], current: usize) -> Self {
60 assert!(
61 (RATE_START..RATE_END).contains(¤t),
62 "current value outside of valid range"
63 );
64 Self { state, current }
65 }
66
67 pub fn into_parts(self) -> ([Felt; STATE_WIDTH], usize) {
69 (self.state, self.current)
70 }
71
72 pub fn fill_bytes(&mut self, dest: &mut [u8]) {
74 <Self as Rng>::fill_bytes(self, dest)
75 }
76
77 pub fn draw_basefield(&mut self) -> Felt {
82 if self.current == RATE_END {
83 Poseidon2::apply_permutation(&mut self.state);
84 self.current = RATE_START;
85 }
86
87 self.current += 1;
88 self.state[self.current - 1]
89 }
90
91 pub fn draw(&mut self) -> Felt {
95 self.draw_basefield()
96 }
97
98 pub fn draw_ext_field<E: ExtensionField<Felt>>(&mut self) -> E {
103 let ext_degree = E::DIMENSION;
104 let mut result = vec![ZERO; ext_degree];
105 for r in result.iter_mut().take(ext_degree) {
106 *r = self.draw_basefield();
107 }
108 E::from_basis_coefficients_slice(&result).expect("failed to draw extension field element")
109 }
110
111 pub fn reseed(&mut self, data: Word) {
117 self.current = RATE_START;
119
120 self.state[RATE_START] += data[0];
122 self.state[RATE_START + 1] += data[1];
123 self.state[RATE_START + 2] += data[2];
124 self.state[RATE_START + 3] += data[3];
125
126 Poseidon2::apply_permutation(&mut self.state);
128 }
129}
130
131impl FeltRng for RandomCoin {
135 fn draw_element(&mut self) -> Felt {
136 self.draw_basefield()
137 }
138
139 fn draw_word(&mut self) -> Word {
140 let mut output = [ZERO; 4];
141 for o in output.iter_mut() {
142 *o = self.draw_basefield();
143 }
144 Word::new(output)
145 }
146}
147
148impl TryRng for RandomCoin {
152 type Error = Infallible;
153
154 fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
155 Ok(self.draw_basefield().as_canonical_u64() as u32)
156 }
157
158 fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
159 utils::next_u64_via_u32(self)
160 }
161
162 fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> {
163 utils::fill_bytes_via_next_word(dest, || self.try_next_u32())
164 }
165}
166
167impl Serializable for RandomCoin {
171 fn write_into<W: ByteWriter>(&self, target: &mut W) {
172 self.state.iter().for_each(|v| v.write_into(target));
173 target.write_u8(self.current as u8);
175 }
176}
177
178impl Deserializable for RandomCoin {
179 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
180 let state = [
181 Felt::read_from(source)?,
182 Felt::read_from(source)?,
183 Felt::read_from(source)?,
184 Felt::read_from(source)?,
185 Felt::read_from(source)?,
186 Felt::read_from(source)?,
187 Felt::read_from(source)?,
188 Felt::read_from(source)?,
189 Felt::read_from(source)?,
190 Felt::read_from(source)?,
191 Felt::read_from(source)?,
192 Felt::read_from(source)?,
193 ];
194 let current = source.read_u8()? as usize;
195 if !(RATE_START..RATE_END).contains(¤t) {
196 return Err(DeserializationError::InvalidValue(
197 "current value outside of valid range".to_string(),
198 ));
199 }
200 Ok(Self { state, current })
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::{Deserializable, FeltRng, RandomCoin, Serializable, ZERO};
    use crate::{ONE, Word};

    /// `draw_element` must return the same element as `draw_basefield` for identical seeds.
    #[test]
    fn test_feltrng_felt() {
        let drawn = RandomCoin::new([ZERO; 4].into()).draw_element();
        let reference = RandomCoin::new([ZERO; 4].into()).draw_basefield();

        assert_eq!(drawn, reference);
    }

    /// `draw_word` must match four consecutive `draw_basefield` calls for identical seeds.
    #[test]
    fn test_feltrng_word() {
        let drawn = RandomCoin::new([ZERO; 4].into()).draw_word();

        let mut reference_coin = RandomCoin::new([ZERO; 4].into());
        let reference = Word::new(core::array::from_fn(|_| reference_coin.draw_basefield()));

        assert_eq!(drawn, reference);
    }

    /// A coin must survive a serialization round trip unchanged.
    #[test]
    fn test_feltrng_serialization() {
        let original = RandomCoin::from_parts([ONE; 12], 5);

        let restored = RandomCoin::read_from_bytes(&original.to_bytes()).unwrap();
        assert_eq!(original, restored);
    }
}