1use alloc::{string::ToString, vec::Vec};
2
3use p3_field::{ExtensionField, PrimeField64};
4use rand_core::impls;
5
6use super::{Felt, FeltRng, RngCore};
7use crate::{
8 Word, ZERO,
9 hash::rpo::Rpo256,
10 utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
11};
12
13const STATE_WIDTH: usize = Rpo256::STATE_WIDTH;
17const RATE_START: usize = Rpo256::RATE_RANGE.start;
18const RATE_END: usize = Rpo256::RATE_RANGE.end;
19const HALF_RATE_WIDTH: usize = (Rpo256::RATE_RANGE.end - Rpo256::RATE_RANGE.start) / 2;
20
21#[derive(Debug, Clone, Copy, PartialEq, Eq)]
32pub struct RpoRandomCoin {
33 state: [Felt; STATE_WIDTH],
34 current: usize,
35}
36
37impl RpoRandomCoin {
38 pub fn new(seed: Word) -> Self {
40 let mut state = [ZERO; STATE_WIDTH];
41
42 for i in 0..HALF_RATE_WIDTH {
43 state[RATE_START + i] += seed[i];
44 }
45
46 Rpo256::apply_permutation(&mut state);
48
49 RpoRandomCoin { state, current: RATE_START }
50 }
51
52 pub fn from_parts(state: [Felt; STATE_WIDTH], current: usize) -> Self {
57 assert!(
58 (RATE_START..RATE_END).contains(¤t),
59 "current value outside of valid range"
60 );
61 Self { state, current }
62 }
63
64 pub fn into_parts(self) -> ([Felt; STATE_WIDTH], usize) {
66 (self.state, self.current)
67 }
68
69 pub fn fill_bytes(&mut self, dest: &mut [u8]) {
71 <Self as RngCore>::fill_bytes(self, dest)
72 }
73
74 pub fn draw_basefield(&mut self) -> Felt {
79 if self.current == RATE_END {
80 Rpo256::apply_permutation(&mut self.state);
81 self.current = RATE_START;
82 }
83
84 self.current += 1;
85 self.state[self.current - 1]
86 }
87
88 pub fn draw(&mut self) -> Felt {
92 self.draw_basefield()
93 }
94
95 pub fn draw_ext_field<E: ExtensionField<Felt>>(&mut self) -> E {
100 let ext_degree = E::DIMENSION;
101 let mut result = vec![ZERO; ext_degree];
102 for r in result.iter_mut().take(ext_degree) {
103 *r = self.draw_basefield();
104 }
105 E::from_basis_coefficients_slice(&result).expect("failed to draw extension field element")
106 }
107
108 pub fn reseed(&mut self, data: Word) {
114 self.current = RATE_START;
116
117 self.state[RATE_START] += data[0];
119 self.state[RATE_START + 1] += data[1];
120 self.state[RATE_START + 2] += data[2];
121 self.state[RATE_START + 3] += data[3];
122
123 Rpo256::apply_permutation(&mut self.state);
125 }
126
127 pub fn check_leading_zeros(&self, value: u64) -> u32 {
133 let value = Felt::new(value);
134 let mut state_tmp = self.state;
135
136 state_tmp[RATE_START] += value;
137
138 Rpo256::apply_permutation(&mut state_tmp);
139
140 let first_rate_element = state_tmp[RATE_START].as_canonical_u64();
141 first_rate_element.trailing_zeros()
142 }
143
144 pub fn draw_integers(
157 &mut self,
158 num_values: usize,
159 domain_size: usize,
160 nonce: u64,
161 ) -> Vec<usize> {
162 assert!(domain_size.is_power_of_two(), "domain size must be a power of two");
163 assert!(num_values < domain_size, "number of values must be smaller than domain size");
164
165 let nonce = Felt::new(nonce);
167 self.state[RATE_START] += nonce;
168 Rpo256::apply_permutation(&mut self.state);
169
170 self.current = RATE_START + 1;
174
175 let v_mask = (domain_size - 1) as u64;
177
178 let mut values = Vec::new();
180 for _ in 0..1000 {
181 let value = self.draw_basefield().as_canonical_u64();
183
184 let value = (value & v_mask) as usize;
186
187 values.push(value);
188 if values.len() == num_values {
189 break;
190 }
191 }
192
193 assert_eq!(
194 values.len(),
195 num_values,
196 "failed to draw {} integers after 1000 iterations (got {})",
197 num_values,
198 values.len()
199 );
200
201 values
202 }
203}
204
205impl FeltRng for RpoRandomCoin {
209 fn draw_element(&mut self) -> Felt {
210 self.draw_basefield()
211 }
212
213 fn draw_word(&mut self) -> Word {
214 let mut output = [ZERO; 4];
215 for o in output.iter_mut() {
216 *o = self.draw_basefield();
217 }
218 Word::new(output)
219 }
220}
221
222impl RngCore for RpoRandomCoin {
226 fn next_u32(&mut self) -> u32 {
227 self.draw_basefield().as_canonical_u64() as u32
228 }
229
230 fn next_u64(&mut self) -> u64 {
231 impls::next_u64_via_u32(self)
232 }
233
234 fn fill_bytes(&mut self, dest: &mut [u8]) {
235 impls::fill_bytes_via_next(self, dest)
236 }
237}
238
239impl Serializable for RpoRandomCoin {
243 fn write_into<W: ByteWriter>(&self, target: &mut W) {
244 self.state.iter().for_each(|v| v.write_into(target));
245 target.write_u8(self.current as u8);
247 }
248}
249
250impl Deserializable for RpoRandomCoin {
251 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
252 let state = [
253 Felt::read_from(source)?,
254 Felt::read_from(source)?,
255 Felt::read_from(source)?,
256 Felt::read_from(source)?,
257 Felt::read_from(source)?,
258 Felt::read_from(source)?,
259 Felt::read_from(source)?,
260 Felt::read_from(source)?,
261 Felt::read_from(source)?,
262 Felt::read_from(source)?,
263 Felt::read_from(source)?,
264 Felt::read_from(source)?,
265 ];
266 let current = source.read_u8()? as usize;
267 if !(RATE_START..RATE_END).contains(¤t) {
268 return Err(DeserializationError::InvalidValue(
269 "current value outside of valid range".to_string(),
270 ));
271 }
272 Ok(Self { state, current })
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::{Deserializable, FeltRng, RpoRandomCoin, Serializable, ZERO};
    use crate::{ONE, Word};

    /// Returns a fresh coin seeded with the all-zero word, so each comparison starts from the
    /// same state.
    fn zero_seeded_coin() -> RpoRandomCoin {
        RpoRandomCoin::new([ZERO; 4].into())
    }

    #[test]
    fn test_feltrng_felt() {
        let via_trait = zero_seeded_coin().draw_element();
        let via_inherent = zero_seeded_coin().draw_basefield();

        assert_eq!(via_trait, via_inherent);
    }

    #[test]
    fn test_feltrng_word() {
        let via_trait = zero_seeded_coin().draw_word();

        let mut reference = zero_seeded_coin();
        let via_inherent = Word::new([
            reference.draw_basefield(),
            reference.draw_basefield(),
            reference.draw_basefield(),
            reference.draw_basefield(),
        ]);

        assert_eq!(via_trait, via_inherent);
    }

    #[test]
    fn test_feltrng_serialization() {
        let original = RpoRandomCoin::from_parts([ONE; 12], 5);

        let restored = RpoRandomCoin::read_from_bytes(&original.to_bytes()).unwrap();

        assert_eq!(original, restored);
    }
}