1use alloc::{string::ToString, vec::Vec};
2
3use rand_core::impls;
4
5use super::{Felt, FeltRng, FieldElement, RandomCoin, RandomCoinError, RngCore, ZERO};
6use crate::{
7 Word,
8 hash::rpo::Rpo256,
9 utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
10};
11
12const STATE_WIDTH: usize = Rpo256::STATE_WIDTH;
16const RATE_START: usize = Rpo256::RATE_RANGE.start;
17const RATE_END: usize = Rpo256::RATE_RANGE.end;
18const HALF_RATE_WIDTH: usize = (Rpo256::RATE_RANGE.end - Rpo256::RATE_RANGE.start) / 2;
19
20#[derive(Debug, Clone, Copy, PartialEq, Eq)]
31pub struct RpoRandomCoin {
32 state: [Felt; STATE_WIDTH],
33 current: usize,
34}
35
36impl RpoRandomCoin {
37 pub fn new(seed: Word) -> Self {
39 let mut state = [ZERO; STATE_WIDTH];
40
41 for i in 0..HALF_RATE_WIDTH {
42 state[RATE_START + i] += seed[i];
43 }
44
45 Rpo256::apply_permutation(&mut state);
47
48 RpoRandomCoin { state, current: RATE_START }
49 }
50
51 pub fn from_parts(state: [Felt; STATE_WIDTH], current: usize) -> Self {
56 assert!(
57 (RATE_START..RATE_END).contains(¤t),
58 "current value outside of valid range"
59 );
60 Self { state, current }
61 }
62
63 pub fn into_parts(self) -> ([Felt; STATE_WIDTH], usize) {
65 (self.state, self.current)
66 }
67
68 pub fn fill_bytes(&mut self, dest: &mut [u8]) {
70 <Self as RngCore>::fill_bytes(self, dest)
71 }
72
73 fn draw_basefield(&mut self) -> Felt {
74 if self.current == RATE_END {
75 Rpo256::apply_permutation(&mut self.state);
76 self.current = RATE_START;
77 }
78
79 self.current += 1;
80 self.state[self.current - 1]
81 }
82}
83
84impl RandomCoin for RpoRandomCoin {
88 type BaseField = Felt;
89 type Hasher = Rpo256;
90
91 fn new(seed: &[Self::BaseField]) -> Self {
92 let digest: Word = Rpo256::hash_elements(seed);
93 Self::new(digest)
94 }
95
96 fn reseed(&mut self, data: Word) {
97 self.current = RATE_START;
99
100 self.state[RATE_START] += data[0];
102 self.state[RATE_START + 1] += data[1];
103 self.state[RATE_START + 2] += data[2];
104 self.state[RATE_START + 3] += data[3];
105
106 Rpo256::apply_permutation(&mut self.state);
108 }
109
110 fn check_leading_zeros(&self, value: u64) -> u32 {
111 let value = Felt::new(value);
112 let mut state_tmp = self.state;
113
114 state_tmp[RATE_START] += value;
115
116 Rpo256::apply_permutation(&mut state_tmp);
117
118 let first_rate_element = state_tmp[RATE_START].as_int();
119 first_rate_element.trailing_zeros()
120 }
121
122 fn draw<E: FieldElement<BaseField = Felt>>(&mut self) -> Result<E, RandomCoinError> {
123 let ext_degree = E::EXTENSION_DEGREE;
124 let mut result = vec![ZERO; ext_degree];
125 for r in result.iter_mut().take(ext_degree) {
126 *r = self.draw_basefield();
127 }
128
129 let result = E::slice_from_base_elements(&result);
130 Ok(result[0])
131 }
132
133 fn draw_integers(
134 &mut self,
135 num_values: usize,
136 domain_size: usize,
137 nonce: u64,
138 ) -> Result<Vec<usize>, RandomCoinError> {
139 assert!(domain_size.is_power_of_two(), "domain size must be a power of two");
140 assert!(num_values < domain_size, "number of values must be smaller than domain size");
141
142 let nonce = Felt::new(nonce);
144 self.state[RATE_START] += nonce;
145 Rpo256::apply_permutation(&mut self.state);
146
147 self.current = RATE_START + 1;
151
152 let v_mask = (domain_size - 1) as u64;
154
155 let mut values = Vec::new();
157 for _ in 0..1000 {
158 let value = self.draw_basefield().as_int();
160
161 let value = (value & v_mask) as usize;
163
164 values.push(value);
165 if values.len() == num_values {
166 break;
167 }
168 }
169
170 if values.len() < num_values {
171 return Err(RandomCoinError::FailedToDrawIntegers(num_values, values.len(), 1000));
172 }
173
174 Ok(values)
175 }
176}
177
178impl FeltRng for RpoRandomCoin {
182 fn draw_element(&mut self) -> Felt {
183 self.draw_basefield()
184 }
185
186 fn draw_word(&mut self) -> Word {
187 let mut output = [ZERO; 4];
188 for o in output.iter_mut() {
189 *o = self.draw_basefield();
190 }
191 Word::new(output)
192 }
193}
194
195impl RngCore for RpoRandomCoin {
199 fn next_u32(&mut self) -> u32 {
200 self.draw_basefield().as_int() as u32
201 }
202
203 fn next_u64(&mut self) -> u64 {
204 impls::next_u64_via_u32(self)
205 }
206
207 fn fill_bytes(&mut self, dest: &mut [u8]) {
208 impls::fill_bytes_via_next(self, dest)
209 }
210}
211
212impl Serializable for RpoRandomCoin {
216 fn write_into<W: ByteWriter>(&self, target: &mut W) {
217 self.state.iter().for_each(|v| v.write_into(target));
218 target.write_u8(self.current as u8);
220 }
221}
222
223impl Deserializable for RpoRandomCoin {
224 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
225 let state = [
226 Felt::read_from(source)?,
227 Felt::read_from(source)?,
228 Felt::read_from(source)?,
229 Felt::read_from(source)?,
230 Felt::read_from(source)?,
231 Felt::read_from(source)?,
232 Felt::read_from(source)?,
233 Felt::read_from(source)?,
234 Felt::read_from(source)?,
235 Felt::read_from(source)?,
236 Felt::read_from(source)?,
237 Felt::read_from(source)?,
238 ];
239 let current = source.read_u8()? as usize;
240 if !(RATE_START..RATE_END).contains(¤t) {
241 return Err(DeserializationError::InvalidValue(
242 "current value outside of valid range".to_string(),
243 ));
244 }
245 Ok(Self { state, current })
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::{Deserializable, FeltRng, RpoRandomCoin, Serializable, ZERO};
    use crate::{ONE, Word};

    /// Builds a coin from the all-zero seed. Tests create two such coins so that a trait method
    /// and the underlying `draw_basefield` start from identical states.
    fn zero_seeded_coin() -> RpoRandomCoin {
        RpoRandomCoin::new([ZERO; 4].into())
    }

    #[test]
    fn test_feltrng_felt() {
        let drawn = zero_seeded_coin().draw_element();
        let reference = zero_seeded_coin().draw_basefield();

        assert_eq!(drawn, reference);
    }

    #[test]
    fn test_feltrng_word() {
        let drawn = zero_seeded_coin().draw_word();

        // Array literal elements are evaluated left to right, matching draw order.
        let mut coin = zero_seeded_coin();
        let reference = Word::new([
            coin.draw_basefield(),
            coin.draw_basefield(),
            coin.draw_basefield(),
            coin.draw_basefield(),
        ]);

        assert_eq!(drawn, reference);
    }

    #[test]
    fn test_feltrng_serialization() {
        let original = RpoRandomCoin::from_parts([ONE; 12], 5);

        let restored = RpoRandomCoin::read_from_bytes(&original.to_bytes()).unwrap();
        assert_eq!(original, restored);
    }
}
291}