1use std::cmp::Ordering;
22use std::collections::{HashMap, HashSet};
23use std::fmt::Debug;
24use std::hash::Hash;
25use std::sync::{Arc, LockResult, PoisonError, RwLock, RwLockReadGuard, RwLockWriteGuard};
26
27use crate::estimator::Estimator;
28use crate::justification::LatestMessages;
29use crate::message::Message;
30use crate::util::id::Id;
31use crate::util::weight::{WeightUnit, Zero};
32
/// Marker trait for types that can be used to identify a validator.
///
/// The trait declares no methods of its own; it only bundles the bounds an
/// identifier must satisfy: hashable and totally ordered (it is used as a key
/// in `HashMap`s and `HashSet`s in this file), cloneable, debuggable, safe to
/// send and share between threads, and serializable with `serde`.
pub trait ValidatorName: Hash + Clone + Ord + Eq + Send + Sync + Debug + serde::Serialize {}
65
66impl ValidatorName for u8 {}
68impl ValidatorName for u32 {}
69impl ValidatorName for u64 {}
70impl ValidatorName for i8 {}
71impl ValidatorName for i32 {}
72impl ValidatorName for i64 {}
73
/// Local view of a validator: the weights of the known validators, the latest
/// messages received, and the equivocation bookkeeping derived from them.
///
/// `E` is the estimator type carried by the messages and `U` the unit in which
/// validator weights are expressed.
#[derive(Debug, Clone)]
pub struct State<E, U>
where
    E: Estimator,
    U: WeightUnit,
{
    /// Sum of the weights of the validators that `update` has recorded in
    /// `equivocators`.
    pub(crate) state_fault_weight: U,
    /// Fault weight threshold: `update` only records a new equivocator while
    /// its weight plus `state_fault_weight` does not exceed this value.
    pub(crate) thr: U,
    /// Weight of each validator, looked up by message sender.
    pub(crate) validators_weights: Weights<E::ValidatorName, U>,
    /// Latest messages tracked by this state; fed by every call to `update`.
    pub(crate) latest_messages: LatestMessages<E>,
    /// Validators recorded as equivocating.
    pub(crate) equivocators: HashSet<E::ValidatorName>,
}
132
/// Errors that can occur when accessing the lock-protected value `T`.
///
/// The lifetime `'rwlock` is the lifetime of the guard carried by a poisoned
/// lock error.
pub enum Error<'rwlock, T> {
    /// The lock was poisoned when a write guard was requested.
    WriteLockError(PoisonError<RwLockWriteGuard<'rwlock, T>>),
    /// The lock was poisoned when a read guard was requested.
    ReadLockError(PoisonError<RwLockReadGuard<'rwlock, T>>),
    /// The requested validator has no recorded weight.
    NotFound,
}

impl<'rwlock, T> std::error::Error for Error<'rwlock, T> {}

impl<'rwlock, T> std::fmt::Display for Error<'rwlock, T> {
    /// Writes a human-readable description of the error.
    ///
    /// Poison errors forward to the message of the wrapped `PoisonError`. No
    /// variant emits a trailing newline, so the result can be embedded in
    /// larger messages.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            Error::NotFound => write!(f, "Validator weight not found"),
            Error::WriteLockError(p_err) => std::fmt::Display::fmt(p_err, f),
            Error::ReadLockError(p_err) => std::fmt::Display::fmt(p_err, f),
        }
    }
}

impl<'rwlock, T> std::fmt::Debug for Error<'rwlock, T> {
    /// Uses the same text as `Display`.
    ///
    /// The guard types wrapped by the poison variants would otherwise force a
    /// `T: Debug` bound, so the debug output deliberately stays textual.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        std::fmt::Display::fmt(self, f)
    }
}
167
168impl<E, U> State<E, U>
169where
170 E: Estimator,
171 U: WeightUnit,
172{
173 pub fn new(
174 validators_weights: Weights<E::ValidatorName, U>,
175 state_fault_weight: U,
176 latest_messages: LatestMessages<E>,
177 thr: U,
178 equivocators: HashSet<E::ValidatorName>,
179 ) -> Self {
180 State {
181 validators_weights,
182 equivocators,
183 state_fault_weight,
184 thr,
185 latest_messages,
186 }
187 }
188
189 pub fn new_with_default_state(
190 default_state: Self,
191 validators_weights: Option<Weights<E::ValidatorName, U>>,
192 state_fault_weight: Option<U>,
193 latest_messages: Option<LatestMessages<E>>,
194 thr: Option<U>,
195 equivocators: Option<HashSet<E::ValidatorName>>,
196 ) -> Self {
197 State {
198 validators_weights: validators_weights.unwrap_or(default_state.validators_weights),
199 state_fault_weight: state_fault_weight.unwrap_or(default_state.state_fault_weight),
200 latest_messages: latest_messages.unwrap_or(default_state.latest_messages),
201 thr: thr.unwrap_or(default_state.thr),
202 equivocators: equivocators.unwrap_or(default_state.equivocators),
203 }
204 }
205
206 pub fn update(&mut self, messages: &[&Message<E>]) -> bool {
211 messages.iter().fold(true, |acc, message| {
212 let sender = message.sender();
213 let weight = self
214 .validators_weights
215 .weight(sender)
216 .unwrap_or(U::INFINITY);
217
218 let update_success = self.latest_messages.update(message);
219
220 if self.latest_messages.equivocate(message)
221 && weight + self.state_fault_weight <= self.thr
222 && self.equivocators.insert(sender.clone())
223 {
224 self.state_fault_weight += weight;
225 }
226
227 acc && update_success
228 })
229 }
230
231 pub fn equivocators(&self) -> &HashSet<E::ValidatorName> {
232 &self.equivocators
233 }
234
235 pub fn validators_weights(&self) -> &Weights<E::ValidatorName, U> {
236 &self.validators_weights
237 }
238
239 pub fn latests_messages(&self) -> &LatestMessages<E> {
240 &self.latest_messages
241 }
242
243 pub fn latests_messages_as_mut(&mut self) -> &mut LatestMessages<E> {
244 &mut self.latest_messages
245 }
246
247 pub fn fault_weight(&self) -> U {
248 self.state_fault_weight
249 }
250
251 pub fn sort_by_faultweight<'z>(
255 &self,
256 messages: &HashSet<&'z Message<E>>,
257 ) -> Vec<&'z Message<E>> {
258 let mut messages_sorted_by_faultw: Vec<_> = messages
259 .iter()
260 .filter_map(|&message| {
261 let sender = message.sender();
263 if !self.equivocators.contains(sender) && self.latest_messages.equivocate(message) {
264 self.validators_weights
265 .weight(sender)
266 .map(|weight| (message, weight))
267 .ok()
268 } else {
269 Some((message, <U as Zero<U>>::ZERO))
270 }
271 })
272 .collect();
273
274 messages_sorted_by_faultw.sort_unstable_by(|(m0, w0), (m1, w1)| match w0.partial_cmp(w1) {
275 None | Some(Ordering::Equal) => m0.id().cmp(&m1.id()),
276 Some(ord) => ord,
277 });
278
279 messages_sorted_by_faultw
280 .iter()
281 .map(|(message, _)| message)
282 .cloned()
283 .collect()
284 }
285}
286
287#[derive(Clone, Debug)]
305pub struct Weights<V: self::ValidatorName, U: WeightUnit>(Arc<RwLock<HashMap<V, U>>>);
306
307impl<V: self::ValidatorName, U: WeightUnit> Weights<V, U> {
308 pub fn new(weights: HashMap<V, U>) -> Self {
313 Weights(Arc::new(RwLock::new(weights)))
314 }
315
316 fn read(&self) -> LockResult<RwLockReadGuard<HashMap<V, U>>> {
318 self.0.read()
319 }
320
321 fn write(&self) -> LockResult<RwLockWriteGuard<HashMap<V, U>>> {
323 self.0.write()
324 }
325
326 pub fn insert(&mut self, validator: V, weight: U) -> Result<bool, Error<HashMap<V, U>>> {
328 self.write()
329 .map_err(Error::WriteLockError)
330 .map(|mut hash_map| {
331 hash_map.insert(validator, weight);
332 true
333 })
334 }
335
336 pub fn validators(&self) -> Result<HashSet<V>, Error<HashMap<V, U>>> {
339 self.read().map_err(Error::ReadLockError).map(|hash_map| {
340 hash_map
341 .iter()
342 .filter_map(|(validator, &weight)| {
343 if weight > <U as Zero<U>>::ZERO {
344 Some(validator.clone())
345 } else {
346 None
347 }
348 })
349 .collect()
350 })
351 }
352
353 pub fn weight(&self, validator: &V) -> Result<U, Error<HashMap<V, U>>> {
356 self.read()
357 .map_err(Error::ReadLockError)
358 .and_then(|hash_map| {
359 hash_map
360 .get(validator)
361 .map(Clone::clone)
362 .ok_or(Error::NotFound)
363 })
364 }
365
366 pub fn sum_weight_validators(&self, validators: &HashSet<V>) -> U {
368 validators
369 .iter()
370 .fold(<U as Zero<U>>::ZERO, |acc, validator| {
371 acc + self.weight(validator).unwrap_or(U::NAN)
372 })
373 }
374
375 pub fn sum_all_weights(&self) -> U {
377 if let Ok(validators) = self.validators() {
378 self.sum_weight_validators(&validators)
379 } else {
380 U::NAN
381 }
382 }
383}
384
/// Unit tests for `Weights` and `State`.
#[cfg(test)]
mod tests {
    use super::*;

    use crate::VoteCount;

    use std::iter::FromIterator;

    #[test]
    fn weights_validators_include_positive_weight() {
        let weights = Weights::new(vec![(0, 1.0), (1, 1.0), (2, 1.0)].into_iter().collect());
        assert_eq!(
            weights.validators().unwrap(),
            vec![0, 1, 2].into_iter().collect(),
            "should include validators with valid, positive weight"
        );
    }

    #[test]
    fn weights_validators_exclude_zero_weighted_validators() {
        let weights = Weights::new(vec![(0, 0.0), (1, 1.0), (2, 1.0)].into_iter().collect());
        assert_eq!(
            weights.validators().unwrap(),
            vec![1, 2].into_iter().collect(),
            "should exclude validators with 0 weight"
        );
    }

    #[test]
    fn weights_validators_exclude_negative_weights() {
        let weights = Weights::new(vec![(0, 1.0), (1, -1.0), (2, 1.0)].into_iter().collect());
        assert_eq!(
            weights.validators().unwrap(),
            vec![0, 2].into_iter().collect(),
            "should exclude validators with negative weight"
        );
    }

    #[test]
    fn weights_validators_exclude_nan_weights() {
        let weights = Weights::new(
            vec![(0, f32::NAN), (1, 1.0), (2, 1.0)]
                .into_iter()
                .collect(),
        );
        assert_eq!(
            weights.validators().unwrap(),
            vec![1, 2].into_iter().collect(),
            "should exclude validators with NAN weight"
        );
    }

    #[test]
    fn weights_validators_include_infinity_weighted_validators() {
        let weights = Weights::new(
            vec![(0, f32::INFINITY), (1, 1.0), (2, 1.0)]
                .into_iter()
                .collect(),
        );
        assert_eq!(
            weights.validators().unwrap(),
            vec![0, 1, 2].into_iter().collect(),
            "should include validators with INFINITY weight"
        );
    }

    #[test]
    fn weights_weight() {
        let weights = Weights::new(
            vec![(0, 1.0), (1, -1.0), (2, f32::INFINITY)]
                .into_iter()
                .collect(),
        );
        float_eq!(weights.weight(&0).unwrap(), 1.0);
        // Unlike `validators`, `weight` returns stored values unfiltered.
        float_eq!(weights.weight(&1).unwrap(), -1.0);
        assert!(weights.weight(&2).unwrap().is_infinite());
    }

    #[test]
    fn weights_weight_not_found() {
        let weights = Weights::<u32, f32>::new(vec![].into_iter().collect());
        match weights.weight(&0) {
            Err(Error::NotFound) => (),
            _ => panic!("Expected Error::NotFound"),
        };
    }

    #[test]
    fn weights_sum_weight_validators() {
        let weights = Weights::new(
            vec![(0, 1.0), (1, -1.0), (2, 3.3), (3, f32::INFINITY)]
                .into_iter()
                .collect(),
        );
        assert!(weights
            .sum_weight_validators(&HashSet::from_iter(vec![0, 1, 3]))
            .is_infinite());
        float_eq!(
            weights.sum_weight_validators(&HashSet::from_iter(vec![0, 1])),
            0.0
        );
        float_eq!(
            weights.sum_weight_validators(&HashSet::from_iter(vec![0, 2])),
            4.3
        );
        // Validator 4 has no weight, so it contributes NAN to the sum.
        assert!(weights
            .sum_weight_validators(&HashSet::from_iter(vec![4]))
            .is_nan());
    }

    #[test]
    fn weights_sum_all_weights() {
        let weights = Weights::new(vec![(0, 2.0), (1, -1.0), (2, 3.3)].into_iter().collect());
        // Only validators with a strictly positive weight take part in the
        // total, so validator 1 is left out.
        float_eq!(
            weights.sum_all_weights(),
            weights.sum_weight_validators(&HashSet::from_iter(vec![0, 2]))
        );
    }

    #[test]
    fn validator_state_update() {
        let mut validator_state = State::new(
            Weights::new(vec![(0, 1.0), (1, 1.0)].into_iter().collect()),
            0.0,
            LatestMessages::empty(),
            2.0,
            HashSet::new(),
        );

        let v0 = VoteCount::create_vote_message(0, false);
        let v1 = VoteCount::create_vote_message(1, true);

        let all_valid = validator_state.update(&[&v0, &v1]);

        let hs0 = validator_state
            .latests_messages()
            .get(&0)
            .expect("state should contain validator 0");
        let hs1 = validator_state
            .latests_messages()
            .get(&1)
            .expect("state should contain validator 1");

        assert!(all_valid, "all messages should be valid");
        assert_eq!(
            hs0.len(),
            1,
            "validator_state should have only 1 message for validator 0",
        );
        assert_eq!(
            hs1.len(),
            1,
            "validator_state should have only 1 message for validator 1",
        );
        assert!(hs0.contains(&v0), "validator_state should contain v0");
        assert!(hs1.contains(&v1), "validator_state should contain v1");
        float_eq!(
            validator_state.fault_weight(),
            0.0,
            "fault weight should be 0"
        );
        assert!(
            validator_state.equivocators().is_empty(),
            "no equivocators should exist",
        );
    }

    #[test]
    fn validator_state_update_equivocate_under_threshold() {
        // Threshold 2.0: recording validator 0 (weight 1.0) as an equivocator
        // keeps the fault weight within the threshold.
        let mut validator_state = State::new(
            Weights::new(vec![(0, 1.0), (1, 1.0)].into_iter().collect()),
            0.0,
            LatestMessages::empty(),
            2.0,
            HashSet::new(),
        );

        let v0 = VoteCount::create_vote_message(0, false);
        let v0_prime = VoteCount::create_vote_message(0, true);
        let v1 = VoteCount::create_vote_message(1, true);

        let all_valid = validator_state.update(&[&v0, &v0_prime, &v1]);

        let hs0 = validator_state
            .latests_messages()
            .get(&0)
            .expect("state should contain validator 0");
        let hs1 = validator_state
            .latests_messages()
            .get(&1)
            .expect("state should contain validator 1");

        assert!(all_valid, "all messages should be valid");
        assert_eq!(
            hs0.len(),
            2,
            "validator_state should have 2 messages for validator 0",
        );
        assert_eq!(
            hs1.len(),
            1,
            "validator_state should have only 1 message for validator 1",
        );
        assert!(hs0.contains(&v0), "validator_state should contain v0");
        assert!(
            hs0.contains(&v0_prime),
            "validator_state should contain v0_prime",
        );
        assert!(hs1.contains(&v1), "validator_state should contain v1");
        float_eq!(
            validator_state.fault_weight(),
            1.0,
            "fault weight should be 1"
        );
        assert!(
            validator_state.equivocators().contains(&0),
            "validator 0 should be in equivocators",
        );
    }

    #[test]
    fn validator_state_update_equivocate_at_threshold() {
        // Threshold 0.0: recording validator 0 (weight 1.0) would exceed the
        // threshold, so it must not be recorded as an equivocator.
        let mut validator_state = State::new(
            Weights::new(vec![(0, 1.0), (1, 1.0)].into_iter().collect()),
            0.0,
            LatestMessages::empty(),
            0.0,
            HashSet::new(),
        );

        let v0 = VoteCount::create_vote_message(0, false);
        let v0_prime = VoteCount::create_vote_message(0, true);
        let v1 = VoteCount::create_vote_message(1, true);

        let all_valid = validator_state.update(&[&v0, &v0_prime, &v1]);

        let hs0 = validator_state
            .latests_messages()
            .get(&0)
            .expect("state should contain validator 0");
        let hs1 = validator_state
            .latests_messages()
            .get(&1)
            .expect("state should contain validator 1");

        assert!(all_valid, "all messages should be valid");
        assert_eq!(
            hs0.len(),
            2,
            "validator_state should have 2 messages for validator 0",
        );
        assert_eq!(
            hs1.len(),
            1,
            "validator_state should have only 1 message for validator 1",
        );
        assert!(hs0.contains(&v0), "validator_state should contain v0");
        assert!(
            hs0.contains(&v0_prime),
            "validator_state should contain v0_prime",
        );
        assert!(hs1.contains(&v1), "validator_state should contain v1");
        float_eq!(
            validator_state.fault_weight(),
            0.0,
            "fault weight should be 0"
        );
        assert!(
            validator_state.equivocators().is_empty(),
            "validator 0 should not be in equivocators"
        );
    }

    #[test]
    fn state_sort_by_faultweight_unknown_equivocators() {
        let v0_prime = VoteCount::create_vote_message(0, false);
        let v1_prime = VoteCount::create_vote_message(1, true);
        let v2_prime = VoteCount::create_vote_message(2, true);

        let get_sorted_vec_with_weights = |weights: Vec<(u32, f32)>| {
            let mut state = State::new(
                Weights::new(weights.into_iter().collect()),
                0.0,
                LatestMessages::empty(),
                10.0,
                HashSet::new(),
            );
            let v0 = VoteCount::create_vote_message(0, true);
            let v1 = VoteCount::create_vote_message(1, false);
            let v2 = VoteCount::create_vote_message(2, false);
            state.update(&[&v0, &v1, &v2]);
            state.sort_by_faultweight(&HashSet::from_iter(vec![&v0_prime, &v1_prime, &v2_prime]))
        };

        // The state holds one message per validator and no recorded
        // equivocators. The primed messages carry the opposite vote, so the
        // expected order follows the senders' weights.
        assert_eq!(
            get_sorted_vec_with_weights(vec![(0, 1.0), (1, 2.0), (2, 3.0)]),
            [&v0_prime, &v1_prime, &v2_prime]
        );
        assert_eq!(
            get_sorted_vec_with_weights(vec![(0, 2.0), (1, 1.0), (2, 3.0)]),
            [&v1_prime, &v0_prime, &v2_prime]
        );
    }

    #[test]
    fn state_sort_by_faultweight_known_equivocators() {
        fn test_with_weights(weights: Vec<(u32, f32)>) {
            let mut state = State::new(
                Weights::new(weights.into_iter().collect()),
                0.0,
                LatestMessages::empty(),
                10.0,
                HashSet::new(),
            );
            let v0 = VoteCount::create_vote_message(0, true);
            let v0_prime = VoteCount::create_vote_message(0, false);
            let v1 = VoteCount::create_vote_message(1, false);
            let v1_prime = VoteCount::create_vote_message(1, true);
            let v2 = VoteCount::create_vote_message(2, false);
            let v2_prime = VoteCount::create_vote_message(2, true);
            state.update(&[&v0, &v0_prime, &v1, &v1_prime, &v2, &v2_prime]);

            // Every validator sent two conflicting votes before sorting, so
            // the senders are expected to be recorded equivocators already.
            // All messages then weigh zero and the order is decided by id,
            // regardless of the validator weights.
            assert_eq!(
                state.sort_by_faultweight(&HashSet::from_iter(vec![&v0, &v1, &v2])),
                [&v2, &v0, &v1]
            );

            // Pin the id ordering the expectation above relies on.
            assert!(v2.id() < v0.id());
            assert!(v0.id() < v1.id());
        }

        test_with_weights(vec![(0, 2.0), (1, 1.0), (2, 3.0)]);
        test_with_weights(vec![(0, 2.0), (1, 4.0), (2, 3.0)]);
    }

    #[test]
    fn state_sort_by_faultweight_no_fault() {
        fn test_with_weights(weights: Vec<(u32, f32)>) {
            let mut state = State::new(
                Weights::new(weights.into_iter().collect()),
                0.0,
                LatestMessages::empty(),
                1.0,
                HashSet::new(),
            );
            let v0 = VoteCount::create_vote_message(0, true);
            let v1 = VoteCount::create_vote_message(1, false);
            let v2 = VoteCount::create_vote_message(2, false);
            state.update(&[&v0, &v1, &v2]);

            // The sorted messages are the ones already in the state, so none
            // is expected to add fault weight and the order is decided by id,
            // regardless of the validator weights.
            assert_eq!(
                state.sort_by_faultweight(&HashSet::from_iter(vec![&v0, &v1, &v2])),
                [&v2, &v0, &v1]
            );

            // Pin the id ordering the expectation above relies on.
            assert!(v2.id() < v0.id());
            assert!(v0.id() < v1.id());
        }

        test_with_weights(vec![(0, 2.0), (1, 1.0), (2, 3.0)]);
        test_with_weights(vec![(0, 2.0), (1, 4.0), (2, 3.0)]);
    }
}