1use std::cmp::Ordering;
8
9use serde::{Deserialize, Serialize};
10use serde_with::serde_as;
11use tracing::trace;
12
13use crate::peer_message::usig_message::checkpoint::CheckpointHash;
14use crate::{
15 peer_message::usig_message::checkpoint::{Checkpoint, CheckpointCertificate},
16 Config,
17};
18
19use super::CollectorMessages;
20
21pub(crate) type CollectorCheckpoints<Sig> = CollectorMessages<KeyCheckpoints, Checkpoint<Sig>>;
30
31#[serde_as]
34#[derive(Debug, Clone, Hash, PartialEq, Serialize, Deserialize, Eq)]
35pub(crate) struct KeyCheckpoints {
36 #[serde_as(as = "serde_with::Bytes")]
37 state_hash: CheckpointHash,
38 total_amount_accepted_batches: u64,
39}
40
41impl PartialOrd for KeyCheckpoints {
42 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
45 self.total_amount_accepted_batches
46 .partial_cmp(&other.total_amount_accepted_batches)
47 }
48}
49
50impl Ord for KeyCheckpoints {
51 fn cmp(&self, other: &Self) -> Ordering {
53 self.total_amount_accepted_batches
54 .cmp(&other.total_amount_accepted_batches)
55 }
56}
57
58impl<Sig: Clone> CollectorCheckpoints<Sig> {
59 pub(crate) fn collect_checkpoint(&mut self, msg: Checkpoint<Sig>) -> u64 {
61 trace!("Collecting Checkpoint (origin: {:?}, counter latest accepted Prepare: {:?}, amount accepted batches: {:?}) ...", msg.origin, msg.counter_latest_prep, msg.total_amount_accepted_batches);
62 let key = KeyCheckpoints {
63 state_hash: msg.state_hash,
64 total_amount_accepted_batches: msg.total_amount_accepted_batches,
65 };
66 let amount_collected = self.collect(msg.clone(), msg.origin, key);
67 trace!("Successfully collected Checkpoint (origin: {:?}, counter of latest accepted Prepare: {:?}, amount accepted batches: {:?}).", msg.origin, msg.counter_latest_prep, msg.total_amount_accepted_batches);
68 amount_collected
69 }
70
71 pub(crate) fn retrieve_collected_checkpoints(
100 &mut self,
101 msg: &Checkpoint<Sig>,
102 config: &Config,
103 ) -> Option<CheckpointCertificate<Sig>> {
104 trace!(
105 "Retrieving Checkpoints (amount accepted batches: {:?}) from collector ...",
106 msg.total_amount_accepted_batches
107 );
108 let key = KeyCheckpoints {
109 state_hash: msg.state_hash,
110 total_amount_accepted_batches: msg.total_amount_accepted_batches,
111 };
112 let retrieved = self.retrieve(key, config)?;
113
114 let cert = CheckpointCertificate {
115 my_checkpoint: retrieved.0,
116 other_checkpoints: retrieved.1,
117 };
118 Some(cert)
119 }
120}
121
#[cfg(test)]
mod test {
    use std::num::NonZeroU64;

    use rand::Rng;
    use rstest::rstest;
    use usig::{Count, ReplicaId};

    use super::{CollectorCheckpoints, KeyCheckpoints};
    use crate::peer_message::usig_message::checkpoint::test::{
        create_checkpoint, create_rand_state_hash_diff,
    };
    use crate::tests::{
        create_attested_usigs_for_replicas, create_default_configs_for_replicas,
        create_rand_number_diff, create_random_state_hash, get_random_included_index,
        get_random_replica_id, get_shuffled_remaining_replicas,
    };

    /// Asserts that `$collector` holds exactly one checkpoint under `$key`, and
    /// that it matches `$checkpoint` in origin, state hash, counter of the latest
    /// accepted Prepare and total amount of accepted batches.
    ///
    /// A macro rather than a function so the checkpoint's signature type never
    /// has to be named here.
    macro_rules! assert_sole_checkpoint_under_key {
        ($collector:expr, $key:expr, $checkpoint:expr) => {{
            let expected = &$checkpoint;
            let collected_checkpoints = $collector
                .0
                .get(&$key)
                .expect("no checkpoints collected under the expected key");
            assert_eq!(collected_checkpoints.len(), 1);
            let collected = collected_checkpoints
                .get(&expected.origin)
                .expect("checkpoint of the origin was not collected");
            assert_eq!(collected.origin, expected.origin);
            assert_eq!(collected.state_hash, expected.state_hash);
            assert_eq!(collected.counter_latest_prep, expected.counter_latest_prep);
            assert_eq!(
                collected.total_amount_accepted_batches,
                expected.total_amount_accepted_batches
            );
        }};
    }

    /// Collecting a single checkpoint stores it under its key, indexed by its origin.
    #[rstest]
    fn collect_checkpoint_single(#[values(3, 4, 5, 6, 7, 8, 9, 10)] n: u64) {
        let n_parsed = NonZeroU64::new(n).unwrap();

        let mut rng = rand::thread_rng();
        let origin = get_random_replica_id(n_parsed, &mut rng);
        let counter_latest_prep = Count(rng.gen());
        let total_amount_accepted_batches: u64 = rng.gen();
        let state_hash = create_random_state_hash();

        let mut usigs = create_attested_usigs_for_replicas(n_parsed, Vec::new());
        let usig_origin = usigs.get_mut(&origin).unwrap();

        let checkpoint = create_checkpoint(
            origin,
            state_hash,
            counter_latest_prep,
            total_amount_accepted_batches,
            usig_origin,
        );

        let mut collector = CollectorCheckpoints::new();
        collector.collect_checkpoint(checkpoint.clone());

        assert_eq!(collector.0.len(), 1);

        let key = KeyCheckpoints {
            state_hash,
            total_amount_accepted_batches,
        };
        assert_sole_checkpoint_under_key!(collector, key, checkpoint);
    }

    /// `t + 1` replicas send matching checkpoints. No certificate can be
    /// retrieved while at most `t` have been collected. Once all `t + 1` are
    /// in, a certificate is produced whose own checkpoint is the one of the
    /// replica whose config is used.
    #[rstest]
    fn retrieve_checkpoint(#[values(3, 4, 5, 6, 7, 8, 9, 10)] n: u64) {
        let n_parsed = NonZeroU64::new(n).unwrap();
        let t = n / 2;

        let mut rng = rand::thread_rng();
        let counter_latest_prep = Count(rng.gen());
        let total_amount_accepted_batches: u64 = rng.gen();
        let state_hash = create_random_state_hash();

        let configs = create_default_configs_for_replicas(n_parsed, t);
        let mut usigs = create_attested_usigs_for_replicas(n_parsed, Vec::new());

        // `t + 1` distinct replicas, in random order, send a checkpoint.
        let senders: Vec<ReplicaId> = get_shuffled_remaining_replicas(n_parsed, None, &mut rng)
            .iter()
            .take(usize::try_from(t + 1).unwrap())
            .copied()
            .collect();

        let origin_index = get_random_included_index(senders.len(), None, &mut rng);
        let origin = senders[origin_index];
        let config_origin = configs.get(&origin).unwrap();

        let mut collector = CollectorCheckpoints::new();
        let mut last_collected_checkpoint = None;

        for (index, rep_id) in senders.iter().enumerate() {
            let usig_rep_id = usigs.get_mut(rep_id).unwrap();

            let checkpoint = create_checkpoint(
                *rep_id,
                state_hash,
                counter_latest_prep,
                total_amount_accepted_batches,
                usig_rep_id,
            );
            collector.collect_checkpoint(checkpoint.clone());

            // Up to `t` matching checkpoints are not enough for a certificate.
            let amount_collected = u64::try_from(index + 1).unwrap();
            if amount_collected <= t {
                let cp_cert = collector.retrieve_collected_checkpoints(&checkpoint, config_origin);
                assert!(cp_cert.is_none());
            }

            last_collected_checkpoint = Some(checkpoint);
        }

        let last_collected_checkpoint =
            last_collected_checkpoint.expect("at least one checkpoint must have been collected");

        let cp_cert = collector
            .retrieve_collected_checkpoints(&last_collected_checkpoint, config_origin)
            .expect("t + 1 matching checkpoints must yield a certificate");

        assert_eq!(cp_cert.my_checkpoint.origin, origin);
        assert_eq!(cp_cert.my_checkpoint.state_hash, state_hash);
        assert_eq!(
            cp_cert.my_checkpoint.counter_latest_prep,
            counter_latest_prep
        );
        assert_eq!(
            cp_cert.my_checkpoint.total_amount_accepted_batches,
            total_amount_accepted_batches
        );
    }

    /// Two checkpoints that differ only in their state hash are collected under
    /// two separate keys.
    #[rstest]
    fn collect_diff_checkpoints_state_hash(#[values(3, 4, 5, 6, 7, 8, 9, 10)] n: u64) {
        let n_parsed = NonZeroU64::new(n).unwrap();

        let mut rng = rand::thread_rng();
        let origin = get_random_replica_id(n_parsed, &mut rng);
        let counter_latest_prep = Count(rng.gen());
        let total_amount_accepted_batches: u64 = rng.gen();
        let state_hash = create_random_state_hash();

        let mut usigs = create_attested_usigs_for_replicas(n_parsed, Vec::new());
        let usig_origin = usigs.get_mut(&origin).unwrap();

        let mut collector = CollectorCheckpoints::new();

        let checkpoint = create_checkpoint(
            origin,
            state_hash,
            counter_latest_prep,
            total_amount_accepted_batches,
            usig_origin,
        );
        collector.collect_checkpoint(checkpoint.clone());

        let state_hash_diff = create_rand_state_hash_diff(state_hash, &mut rng);

        let checkpoint_diff = create_checkpoint(
            origin,
            state_hash_diff,
            counter_latest_prep,
            total_amount_accepted_batches,
            usig_origin,
        );
        collector.collect_checkpoint(checkpoint_diff.clone());

        assert_eq!(collector.0.len(), 2);

        let key = KeyCheckpoints {
            state_hash,
            total_amount_accepted_batches,
        };
        assert_sole_checkpoint_under_key!(collector, key, checkpoint);

        let key_diff = KeyCheckpoints {
            state_hash: state_hash_diff,
            total_amount_accepted_batches,
        };
        assert_sole_checkpoint_under_key!(collector, key_diff, checkpoint_diff);
    }

    /// Two checkpoints that differ only in their total amount of accepted
    /// batches are collected under two separate keys.
    #[rstest]
    fn collect_diff_checkpoints_amount_accepted_batches(#[values(3, 4, 5, 6, 7, 8, 9, 10)] n: u64) {
        let n_parsed = NonZeroU64::new(n).unwrap();

        let mut rng = rand::thread_rng();
        let origin = get_random_replica_id(n_parsed, &mut rng);
        let counter_latest_prep = Count(rng.gen());
        let total_amount_accepted_batches: u64 = rng.gen();
        let state_hash = create_random_state_hash();

        let mut usigs = create_attested_usigs_for_replicas(n_parsed, Vec::new());
        let usig_origin = usigs.get_mut(&origin).unwrap();

        let mut collector = CollectorCheckpoints::new();

        let checkpoint = create_checkpoint(
            origin,
            state_hash,
            counter_latest_prep,
            total_amount_accepted_batches,
            usig_origin,
        );
        collector.collect_checkpoint(checkpoint.clone());

        let total_amount_accepted_batches_diff =
            create_rand_number_diff(total_amount_accepted_batches, &mut rng);

        let checkpoint_diff = create_checkpoint(
            origin,
            state_hash,
            counter_latest_prep,
            total_amount_accepted_batches_diff,
            usig_origin,
        );
        collector.collect_checkpoint(checkpoint_diff.clone());

        assert_eq!(collector.0.len(), 2);

        let key = KeyCheckpoints {
            state_hash,
            total_amount_accepted_batches,
        };
        assert_sole_checkpoint_under_key!(collector, key, checkpoint);

        let key_diff = KeyCheckpoints {
            state_hash,
            total_amount_accepted_batches: total_amount_accepted_batches_diff,
        };
        assert_sole_checkpoint_under_key!(collector, key_diff, checkpoint_diff);
    }

    /// Two checkpoints that differ in both their state hash and their total
    /// amount of accepted batches are collected under two separate keys.
    #[rstest]
    fn collect_diff_checkpoints_all_state(#[values(3, 4, 5, 6, 7, 8, 9, 10)] n: u64) {
        let n_parsed = NonZeroU64::new(n).unwrap();

        let mut rng = rand::thread_rng();
        let origin = get_random_replica_id(n_parsed, &mut rng);
        let counter_latest_prep = Count(rng.gen());
        let total_amount_accepted_batches: u64 = rng.gen();
        let state_hash = create_random_state_hash();

        let mut usigs = create_attested_usigs_for_replicas(n_parsed, Vec::new());
        let usig_origin = usigs.get_mut(&origin).unwrap();

        let mut collector = CollectorCheckpoints::new();

        let checkpoint = create_checkpoint(
            origin,
            state_hash,
            counter_latest_prep,
            total_amount_accepted_batches,
            usig_origin,
        );
        collector.collect_checkpoint(checkpoint.clone());

        let state_hash_diff = create_rand_state_hash_diff(state_hash, &mut rng);
        let total_amount_accepted_batches_diff =
            create_rand_number_diff(total_amount_accepted_batches, &mut rng);

        let checkpoint_diff = create_checkpoint(
            origin,
            state_hash_diff,
            counter_latest_prep,
            total_amount_accepted_batches_diff,
            usig_origin,
        );
        collector.collect_checkpoint(checkpoint_diff.clone());

        assert_eq!(collector.0.len(), 2);

        let key = KeyCheckpoints {
            state_hash,
            total_amount_accepted_batches,
        };
        assert_sole_checkpoint_under_key!(collector, key, checkpoint);

        let key_diff = KeyCheckpoints {
            state_hash: state_hash_diff,
            total_amount_accepted_batches: total_amount_accepted_batches_diff,
        };
        assert_sole_checkpoint_under_key!(collector, key_diff, checkpoint_diff);
    }
}