miden_crypto/rand/
coin.rs1use alloc::string::ToString;
2
3use rand_core::impls;
4
5use super::{Felt, FeltRng, RngCore};
6use crate::{
7 Word, ZERO,
8 field::ExtensionField,
9 hash::poseidon2::Poseidon2,
10 utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
11};
12
13const STATE_WIDTH: usize = Poseidon2::STATE_WIDTH;
17const RATE_START: usize = Poseidon2::RATE_RANGE.start;
18const RATE_END: usize = Poseidon2::RATE_RANGE.end;
19const HALF_RATE_WIDTH: usize = (Poseidon2::RATE_RANGE.end - Poseidon2::RATE_RANGE.start) / 2;
20
/// A pseudo-random element generator driven by the Poseidon2 permutation.
///
/// Elements are handed out one at a time from the rate portion of `state`; once every rate
/// element has been handed out, the permutation is applied again before the next draw.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RandomCoin {
    /// The full permutation state from which elements are drawn.
    state: [Felt; STATE_WIDTH],
    /// Index into `state` of the next element to be returned by a draw.
    current: usize,
}
36
37impl RandomCoin {
38 pub fn new(seed: Word) -> Self {
40 let mut state = [ZERO; STATE_WIDTH];
41
42 for i in 0..HALF_RATE_WIDTH {
43 state[RATE_START + i] += seed[i];
44 }
45
46 Poseidon2::apply_permutation(&mut state);
48
49 RandomCoin { state, current: RATE_START }
50 }
51
52 pub fn from_parts(state: [Felt; STATE_WIDTH], current: usize) -> Self {
57 assert!(
58 (RATE_START..RATE_END).contains(¤t),
59 "current value outside of valid range"
60 );
61 Self { state, current }
62 }
63
64 pub fn into_parts(self) -> ([Felt; STATE_WIDTH], usize) {
66 (self.state, self.current)
67 }
68
69 pub fn fill_bytes(&mut self, dest: &mut [u8]) {
71 <Self as RngCore>::fill_bytes(self, dest)
72 }
73
74 pub fn draw_basefield(&mut self) -> Felt {
79 if self.current == RATE_END {
80 Poseidon2::apply_permutation(&mut self.state);
81 self.current = RATE_START;
82 }
83
84 self.current += 1;
85 self.state[self.current - 1]
86 }
87
88 pub fn draw(&mut self) -> Felt {
92 self.draw_basefield()
93 }
94
95 pub fn draw_ext_field<E: ExtensionField<Felt>>(&mut self) -> E {
100 let ext_degree = E::DIMENSION;
101 let mut result = vec![ZERO; ext_degree];
102 for r in result.iter_mut().take(ext_degree) {
103 *r = self.draw_basefield();
104 }
105 E::from_basis_coefficients_slice(&result).expect("failed to draw extension field element")
106 }
107
108 pub fn reseed(&mut self, data: Word) {
114 self.current = RATE_START;
116
117 self.state[RATE_START] += data[0];
119 self.state[RATE_START + 1] += data[1];
120 self.state[RATE_START + 2] += data[2];
121 self.state[RATE_START + 3] += data[3];
122
123 Poseidon2::apply_permutation(&mut self.state);
125 }
126}
127
128impl FeltRng for RandomCoin {
132 fn draw_element(&mut self) -> Felt {
133 self.draw_basefield()
134 }
135
136 fn draw_word(&mut self) -> Word {
137 let mut output = [ZERO; 4];
138 for o in output.iter_mut() {
139 *o = self.draw_basefield();
140 }
141 Word::new(output)
142 }
143}
144
145impl RngCore for RandomCoin {
149 fn next_u32(&mut self) -> u32 {
150 self.draw_basefield().as_canonical_u64() as u32
151 }
152
153 fn next_u64(&mut self) -> u64 {
154 impls::next_u64_via_u32(self)
155 }
156
157 fn fill_bytes(&mut self, dest: &mut [u8]) {
158 impls::fill_bytes_via_next(self, dest)
159 }
160}
161
162impl Serializable for RandomCoin {
166 fn write_into<W: ByteWriter>(&self, target: &mut W) {
167 self.state.iter().for_each(|v| v.write_into(target));
168 target.write_u8(self.current as u8);
170 }
171}
172
173impl Deserializable for RandomCoin {
174 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
175 let state = [
176 Felt::read_from(source)?,
177 Felt::read_from(source)?,
178 Felt::read_from(source)?,
179 Felt::read_from(source)?,
180 Felt::read_from(source)?,
181 Felt::read_from(source)?,
182 Felt::read_from(source)?,
183 Felt::read_from(source)?,
184 Felt::read_from(source)?,
185 Felt::read_from(source)?,
186 Felt::read_from(source)?,
187 Felt::read_from(source)?,
188 ];
189 let current = source.read_u8()? as usize;
190 if !(RATE_START..RATE_END).contains(¤t) {
191 return Err(DeserializationError::InvalidValue(
192 "current value outside of valid range".to_string(),
193 ));
194 }
195 Ok(Self { state, current })
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::{Deserializable, FeltRng, RandomCoin, Serializable, ZERO};
    use crate::{ONE, Word};

    /// Builds a coin seeded with the all-zero word.
    fn zero_seeded_coin() -> RandomCoin {
        RandomCoin::new([ZERO; 4].into())
    }

    /// `draw_element` must agree with `draw_basefield` on identically seeded coins.
    #[test]
    fn test_feltrng_felt() {
        let output = zero_seeded_coin().draw_element();
        let expected = zero_seeded_coin().draw_basefield();

        assert_eq!(output, expected);
    }

    /// `draw_word` must agree with four consecutive `draw_basefield` calls.
    #[test]
    fn test_feltrng_word() {
        let output = zero_seeded_coin().draw_word();

        let mut reference = zero_seeded_coin();
        let expected = Word::new(core::array::from_fn(|_| reference.draw_basefield()));

        assert_eq!(output, expected);
    }

    /// A coin survives a serialize/deserialize round trip unchanged.
    #[test]
    fn test_feltrng_serialization() {
        let coin1 = RandomCoin::from_parts([ONE; 12], 5);

        let coin2 = RandomCoin::read_from_bytes(&coin1.to_bytes()).unwrap();
        assert_eq!(coin1, coin2);
    }
}