1use alloc::{string::ToString, vec::Vec};
2
3use rand_core::impls;
4
5use super::{Felt, FeltRng, FieldElement, RandomCoin, RandomCoinError, RngCore, Word, ZERO};
6use crate::{
7 hash::rpx::{Rpx256, RpxDigest},
8 utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
9};
10
11const STATE_WIDTH: usize = Rpx256::STATE_WIDTH;
15const RATE_START: usize = Rpx256::RATE_RANGE.start;
16const RATE_END: usize = Rpx256::RATE_RANGE.end;
17const HALF_RATE_WIDTH: usize = (Rpx256::RATE_RANGE.end - Rpx256::RATE_RANGE.start) / 2;
18
/// Pseudo-random coin built on the RPX permutation.
///
/// The coin keeps the full RPX sponge state. Field elements are drawn one at a time from the
/// rate portion of `state`. Once every rate element has been drawn, the permutation is applied
/// again to refresh the state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RpxRandomCoin {
    /// Full RPX sponge state (`STATE_WIDTH` elements).
    state: [Felt; STATE_WIDTH],
    /// Index in `state` of the next rate element to draw.
    ///
    /// Drawing advances it up to `RATE_END`; the permutation then resets it to `RATE_START`.
    current: usize,
}
34
35impl RpxRandomCoin {
36 pub fn new(seed: Word) -> Self {
38 let mut state = [ZERO; STATE_WIDTH];
39
40 for i in 0..HALF_RATE_WIDTH {
41 state[RATE_START + i] += seed[i];
42 }
43
44 Rpx256::apply_permutation(&mut state);
46
47 RpxRandomCoin { state, current: RATE_START }
48 }
49
50 pub fn from_parts(state: [Felt; STATE_WIDTH], current: usize) -> Self {
55 assert!(
56 (RATE_START..RATE_END).contains(¤t),
57 "current value outside of valid range"
58 );
59 Self { state, current }
60 }
61
62 pub fn into_parts(self) -> ([Felt; STATE_WIDTH], usize) {
64 (self.state, self.current)
65 }
66
67 pub fn fill_bytes(&mut self, dest: &mut [u8]) {
69 <Self as RngCore>::fill_bytes(self, dest)
70 }
71
72 fn draw_basefield(&mut self) -> Felt {
73 if self.current == RATE_END {
74 Rpx256::apply_permutation(&mut self.state);
75 self.current = RATE_START;
76 }
77
78 self.current += 1;
79 self.state[self.current - 1]
80 }
81}
82
83impl RandomCoin for RpxRandomCoin {
87 type BaseField = Felt;
88 type Hasher = Rpx256;
89
90 fn new(seed: &[Self::BaseField]) -> Self {
91 let digest: Word = Rpx256::hash_elements(seed).into();
92 Self::new(digest)
93 }
94
95 fn reseed(&mut self, data: RpxDigest) {
96 self.current = RATE_START;
98
99 let data: Word = data.into();
101
102 self.state[RATE_START] += data[0];
103 self.state[RATE_START + 1] += data[1];
104 self.state[RATE_START + 2] += data[2];
105 self.state[RATE_START + 3] += data[3];
106
107 Rpx256::apply_permutation(&mut self.state);
109 }
110
111 fn check_leading_zeros(&self, value: u64) -> u32 {
112 let value = Felt::new(value);
113 let mut state_tmp = self.state;
114
115 state_tmp[RATE_START] += value;
116
117 Rpx256::apply_permutation(&mut state_tmp);
118
119 let first_rate_element = state_tmp[RATE_START].as_int();
120 first_rate_element.trailing_zeros()
121 }
122
123 fn draw<E: FieldElement<BaseField = Felt>>(&mut self) -> Result<E, RandomCoinError> {
124 let ext_degree = E::EXTENSION_DEGREE;
125 let mut result = vec![ZERO; ext_degree];
126 for r in result.iter_mut().take(ext_degree) {
127 *r = self.draw_basefield();
128 }
129
130 let result = E::slice_from_base_elements(&result);
131 Ok(result[0])
132 }
133
134 fn draw_integers(
135 &mut self,
136 num_values: usize,
137 domain_size: usize,
138 nonce: u64,
139 ) -> Result<Vec<usize>, RandomCoinError> {
140 assert!(domain_size.is_power_of_two(), "domain size must be a power of two");
141 assert!(num_values < domain_size, "number of values must be smaller than domain size");
142
143 let nonce = Felt::new(nonce);
145 self.state[RATE_START] += nonce;
146 Rpx256::apply_permutation(&mut self.state);
147
148 self.current = RATE_START;
150
151 let v_mask = (domain_size - 1) as u64;
153
154 let mut values = Vec::new();
156 for _ in 0..1000 {
157 let value = self.draw_basefield().as_int();
159
160 let value = (value & v_mask) as usize;
162
163 values.push(value);
164 if values.len() == num_values {
165 break;
166 }
167 }
168
169 if values.len() < num_values {
170 return Err(RandomCoinError::FailedToDrawIntegers(num_values, values.len(), 1000));
171 }
172
173 Ok(values)
174 }
175}
176
177impl FeltRng for RpxRandomCoin {
181 fn draw_element(&mut self) -> Felt {
182 self.draw_basefield()
183 }
184
185 fn draw_word(&mut self) -> Word {
186 let mut output = [ZERO; 4];
187 for o in output.iter_mut() {
188 *o = self.draw_basefield();
189 }
190 output
191 }
192}
193
194impl RngCore for RpxRandomCoin {
198 fn next_u32(&mut self) -> u32 {
199 self.draw_basefield().as_int() as u32
200 }
201
202 fn next_u64(&mut self) -> u64 {
203 impls::next_u64_via_u32(self)
204 }
205
206 fn fill_bytes(&mut self, dest: &mut [u8]) {
207 impls::fill_bytes_via_next(self, dest)
208 }
209
210 fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
211 self.fill_bytes(dest);
212 Ok(())
213 }
214}
215
216impl Serializable for RpxRandomCoin {
220 fn write_into<W: ByteWriter>(&self, target: &mut W) {
221 self.state.iter().for_each(|v| v.write_into(target));
222 target.write_u8(self.current as u8);
224 }
225}
226
227impl Deserializable for RpxRandomCoin {
228 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
229 let state = [
230 Felt::read_from(source)?,
231 Felt::read_from(source)?,
232 Felt::read_from(source)?,
233 Felt::read_from(source)?,
234 Felt::read_from(source)?,
235 Felt::read_from(source)?,
236 Felt::read_from(source)?,
237 Felt::read_from(source)?,
238 Felt::read_from(source)?,
239 Felt::read_from(source)?,
240 Felt::read_from(source)?,
241 Felt::read_from(source)?,
242 ];
243 let current = source.read_u8()? as usize;
244 if !(RATE_START..RATE_END).contains(¤t) {
245 return Err(DeserializationError::InvalidValue(
246 "current value outside of valid range".to_string(),
247 ));
248 }
249 Ok(Self { state, current })
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::{Deserializable, FeltRng, RpxRandomCoin, Serializable, ZERO};
    use crate::ONE;

    /// `draw_element` must yield the same element as the underlying `draw_basefield`.
    #[test]
    fn test_feltrng_felt() {
        let drawn = RpxRandomCoin::new([ZERO; 4]).draw_element();
        let reference = RpxRandomCoin::new([ZERO; 4]).draw_basefield();

        assert_eq!(drawn, reference);
    }

    /// `draw_word` must yield four consecutive `draw_basefield` outputs, in order.
    #[test]
    fn test_feltrng_word() {
        let word = RpxRandomCoin::new([ZERO; 4]).draw_word();

        let mut reference_coin = RpxRandomCoin::new([ZERO; 4]);
        let mut reference = [ZERO; 4];
        reference.iter_mut().for_each(|slot| *slot = reference_coin.draw_basefield());

        assert_eq!(word, reference);
    }

    /// A coin must survive a serialize/deserialize round trip unchanged.
    #[test]
    fn test_feltrng_serialization() {
        let original = RpxRandomCoin::from_parts([ONE; 12], 5);

        let restored = RpxRandomCoin::read_from_bytes(&original.to_bytes()).unwrap();

        assert_eq!(original, restored);
    }
}