1use alloc::{string::ToString, vec::Vec};
2
3use p3_field::{ExtensionField, PrimeField64};
4use rand_core::impls;
5
6use super::{Felt, FeltRng, RngCore, Word};
7use crate::{
8 ZERO,
9 hash::rpx::Rpx256,
10 utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
11};
12
13const STATE_WIDTH: usize = Rpx256::STATE_WIDTH;
17const RATE_START: usize = Rpx256::RATE_RANGE.start;
18const RATE_END: usize = Rpx256::RATE_RANGE.end;
19const HALF_RATE_WIDTH: usize = (Rpx256::RATE_RANGE.end - Rpx256::RATE_RANGE.start) / 2;
20
21#[derive(Debug, Clone, Copy, PartialEq, Eq)]
32pub struct RpxRandomCoin {
33 state: [Felt; STATE_WIDTH],
34 current: usize,
35}
36
37impl RpxRandomCoin {
38 pub fn new(seed: Word) -> Self {
40 let mut state = [ZERO; STATE_WIDTH];
41
42 for i in 0..HALF_RATE_WIDTH {
43 state[RATE_START + i] += seed[i];
44 }
45
46 Rpx256::apply_permutation(&mut state);
48
49 RpxRandomCoin { state, current: RATE_START }
50 }
51
52 pub fn from_parts(state: [Felt; STATE_WIDTH], current: usize) -> Self {
57 assert!(
58 (RATE_START..RATE_END).contains(¤t),
59 "current value outside of valid range"
60 );
61 Self { state, current }
62 }
63
64 pub fn into_parts(self) -> ([Felt; STATE_WIDTH], usize) {
66 (self.state, self.current)
67 }
68
69 pub fn fill_bytes(&mut self, dest: &mut [u8]) {
71 <Self as RngCore>::fill_bytes(self, dest)
72 }
73
74 pub fn draw_basefield(&mut self) -> Felt {
75 if self.current == RATE_END {
76 Rpx256::apply_permutation(&mut self.state);
77 self.current = RATE_START;
78 }
79
80 self.current += 1;
81 self.state[self.current - 1]
82 }
83
84 pub fn draw(&mut self) -> Felt {
88 self.draw_basefield()
89 }
90
91 pub fn draw_ext_field<E: ExtensionField<Felt>>(&mut self) -> E {
92 let ext_degree = E::DIMENSION;
93 let mut result = vec![ZERO; ext_degree];
94 for r in result.iter_mut().take(ext_degree) {
95 *r = self.draw_basefield();
96 }
97 E::from_basis_coefficients_slice(&result).expect("failed to draw extension field element")
98 }
99
100 pub fn reseed(&mut self, data: Word) {
101 self.current = RATE_START;
103
104 let data: Word = (*data).into();
106
107 self.state[RATE_START] += data[0];
108 self.state[RATE_START + 1] += data[1];
109 self.state[RATE_START + 2] += data[2];
110 self.state[RATE_START + 3] += data[3];
111
112 Rpx256::apply_permutation(&mut self.state);
114 }
115
116 pub fn check_leading_zeros(&self, value: u64) -> u32 {
117 let value = Felt::new(value);
118 let mut state_tmp = self.state;
119
120 state_tmp[RATE_START] += value;
121
122 Rpx256::apply_permutation(&mut state_tmp);
123
124 let first_rate_element = state_tmp[RATE_START].as_canonical_u64();
125 first_rate_element.trailing_zeros()
126 }
127
128 pub fn draw_integers(
129 &mut self,
130 num_values: usize,
131 domain_size: usize,
132 nonce: u64,
133 ) -> Vec<usize> {
134 assert!(domain_size.is_power_of_two(), "domain size must be a power of two");
135 assert!(num_values < domain_size, "number of values must be smaller than domain size");
136
137 let nonce = Felt::new(nonce);
139 self.state[RATE_START] += nonce;
140 Rpx256::apply_permutation(&mut self.state);
141
142 self.current = RATE_START;
144
145 let v_mask = (domain_size - 1) as u64;
147
148 let mut values = Vec::new();
150 for _ in 0..1000 {
151 let value = self.draw_basefield().as_canonical_u64();
153
154 let value = (value & v_mask) as usize;
156
157 values.push(value);
158 if values.len() == num_values {
159 break;
160 }
161 }
162
163 assert_eq!(
164 values.len(),
165 num_values,
166 "failed to draw {} integers after 1000 iterations (got {})",
167 num_values,
168 values.len()
169 );
170
171 values
172 }
173}
174
175impl FeltRng for RpxRandomCoin {
179 fn draw_element(&mut self) -> Felt {
180 self.draw_basefield()
181 }
182
183 fn draw_word(&mut self) -> Word {
184 let mut output = [ZERO; 4];
185 for o in output.iter_mut() {
186 *o = self.draw_basefield();
187 }
188 Word::new(output)
189 }
190}
191
192impl RngCore for RpxRandomCoin {
196 fn next_u32(&mut self) -> u32 {
197 self.draw_basefield().as_canonical_u64() as u32
198 }
199
200 fn next_u64(&mut self) -> u64 {
201 impls::next_u64_via_u32(self)
202 }
203
204 fn fill_bytes(&mut self, dest: &mut [u8]) {
205 impls::fill_bytes_via_next(self, dest)
206 }
207}
208
209impl Serializable for RpxRandomCoin {
213 fn write_into<W: ByteWriter>(&self, target: &mut W) {
214 self.state.iter().for_each(|v| v.write_into(target));
215 target.write_u8(self.current as u8);
217 }
218}
219
220impl Deserializable for RpxRandomCoin {
221 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
222 let state = [
223 Felt::read_from(source)?,
224 Felt::read_from(source)?,
225 Felt::read_from(source)?,
226 Felt::read_from(source)?,
227 Felt::read_from(source)?,
228 Felt::read_from(source)?,
229 Felt::read_from(source)?,
230 Felt::read_from(source)?,
231 Felt::read_from(source)?,
232 Felt::read_from(source)?,
233 Felt::read_from(source)?,
234 Felt::read_from(source)?,
235 ];
236 let current = source.read_u8()? as usize;
237 if !(RATE_START..RATE_END).contains(¤t) {
238 return Err(DeserializationError::InvalidValue(
239 "current value outside of valid range".to_string(),
240 ));
241 }
242 Ok(Self { state, current })
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::{Deserializable, FeltRng, RpxRandomCoin, Serializable, ZERO};
    use crate::{ONE, Word};

    /// `draw_element` must agree with `draw_basefield` on a freshly seeded coin.
    #[test]
    fn test_feltrng_felt() {
        let mut rpxcoin = RpxRandomCoin::new([ZERO; 4].into());
        let output = rpxcoin.draw_element();

        let mut rpxcoin = RpxRandomCoin::new([ZERO; 4].into());
        let expected = rpxcoin.draw_basefield();

        assert_eq!(output, expected);
    }

    /// `draw_word` must equal four consecutive `draw_basefield` calls, in order.
    #[test]
    fn test_feltrng_word() {
        let mut rpxcoin = RpxRandomCoin::new([ZERO; 4].into());
        let output = rpxcoin.draw_word();

        let mut reference_coin = RpxRandomCoin::new([ZERO; 4].into());
        let mut expected = [ZERO; 4];
        for o in expected.iter_mut() {
            *o = reference_coin.draw_basefield();
        }
        let expected = Word::new(expected);

        assert_eq!(output, expected);
    }

    /// A coin must survive a serialization round trip unchanged.
    #[test]
    fn test_feltrng_serialization() {
        let coin1 = RpxRandomCoin::from_parts([ONE; 12], 5);

        let bytes = coin1.to_bytes();
        let coin2 = RpxRandomCoin::read_from_bytes(&bytes).unwrap();
        assert_eq!(coin1, coin2);
    }

    /// Deserialization must reject a read cursor far outside the rate.
    #[test]
    fn test_deserialization_rejects_invalid_current() {
        let coin = RpxRandomCoin::from_parts([ONE; 12], 5);
        let mut bytes = coin.to_bytes();
        // The read cursor is serialized as the final byte.
        *bytes.last_mut().expect("serialized coin is non-empty") = u8::MAX;
        assert!(RpxRandomCoin::read_from_bytes(&bytes).is_err());
    }

    /// `from_parts` must panic on a read cursor far outside the rate.
    #[test]
    #[should_panic(expected = "current value outside of valid range")]
    fn test_from_parts_rejects_invalid_current() {
        let _ = RpxRandomCoin::from_parts([ONE; 12], usize::MAX);
    }

    /// `draw_integers` returns the requested number of values, all inside the domain.
    #[test]
    fn test_draw_integers_within_domain() {
        let mut coin = RpxRandomCoin::new([ONE; 4].into());
        let values = coin.draw_integers(5, 16, 42);
        assert_eq!(values.len(), 5);
        assert!(values.iter().all(|&v| v < 16));
    }

    /// Identical seeds and reseed data must yield identical coins and draws.
    #[test]
    fn test_reseed_is_deterministic() {
        let mut coin1 = RpxRandomCoin::new([ZERO; 4].into());
        let mut coin2 = RpxRandomCoin::new([ZERO; 4].into());
        coin1.reseed([ONE; 4].into());
        coin2.reseed([ONE; 4].into());
        assert_eq!(coin1, coin2);
        assert_eq!(coin1.draw_element(), coin2.draw_element());
    }
}