1use message::AcceptedData;
4use message::Message;
5use message::Messenger;
6use std::collections::hash_map::HashMap;
7use std::collections::hash_set::HashSet;
8use std::hash::Hash;
9use std::sync::Arc;
10
11pub struct Learner<T> {
16 pub id: u64,
18 pub messenger: Option<Box<Messenger<T>>>,
20 pub last_accepted_n: u64,
22 pub accepted_received: HashMap<u64, HashSet<AcceptedData<T>>>,
24 pub value: Option<Arc<T>>,
26 pub quorum: u8,
28}
29
30impl<T> Learner<T>
31where
32 T: Hash + Eq,
33{
34 pub fn new(id: u64, quorum: u8) -> Self {
35 Self {
36 id,
37 messenger: None,
38 last_accepted_n: 0,
39 accepted_received: HashMap::new(),
40 value: None,
41 quorum,
42 }
43 }
44
45 pub fn receive_accepted(&mut self, msg: Message<T>) {
47 if let Message::Accepted(data) = msg {
48 let id = data.id;
49 if id == self.last_accepted_n {
50 if let Some(ref val) = self.value {
51 if *val != data.value {
52 panic!("Value mismatch for proposal {}", id);
53 }
54 }
55 }
56
57 self.accepted_received.entry(id).or_insert(HashSet::new());
58
59 self.accepted_received.get_mut(&id).unwrap().insert(data);
60
61 if self.accepted_received.get(&id).unwrap().len() == self.quorum as usize {
62 self.value = Some(
63 self.accepted_received
64 .get(&id)
65 .unwrap()
66 .iter()
67 .next()
68 .unwrap()
69 .value
70 .clone(),
71 );
72 self.last_accepted_n = id;
73 if let Some(ref mut messenger) = self.messenger {
74 messenger.on_resolution(id, self.value.clone().unwrap());
75 }
76 }
77 }
78 }
79}
80
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `Accepted` message for proposal `id` carrying `value`, with
    /// the `from` field set to `from`.
    fn accepted(id: u64, value: u64, from: u64) -> Message<u64> {
        Message::Accepted(AcceptedData {
            id,
            value: Arc::new(value),
            from,
        })
    }

    #[test]
    fn learner_new() {
        let l: Learner<u64> = Learner::new(1, 7);

        assert_eq!(l.id, 1);
        assert_eq!(l.quorum, 7);
        assert!(l.messenger.is_none());
        assert_eq!(l.last_accepted_n, 0);
        assert!(l.value.is_none());
        assert!(l.accepted_received.is_empty());
    }

    #[test]
    fn learner_receive_accepted() {
        let mut l: Learner<u64> = Learner::new(1, 7);
        let id = 1;

        // A single message is recorded but is not enough to learn a value.
        l.receive_accepted(accepted(id, 10, 0));

        assert!(l.value.is_none());
        assert_eq!(l.accepted_received.get(&id).unwrap().len(), 1);

        // Deliver the remaining messages needed to reach the quorum.
        for from in 1..u64::from(l.quorum) {
            l.receive_accepted(accepted(id, 10, from));
        }

        assert_eq!(l.last_accepted_n, 1);
        assert_eq!(l.value, Some(Arc::new(10)));
    }

    // The mismatch check only applies to a proposal whose value has already
    // been learned, so the quorum must be reached before the conflicting
    // value is delivered.
    #[test]
    #[should_panic(expected = "Value mismatch for proposal 1")]
    fn learner_receive_accepted_mismatch() {
        let mut l: Learner<u64> = Learner::new(1, 7);
        let id = 1;

        for from in 0..u64::from(l.quorum) {
            l.receive_accepted(accepted(id, 10, from));
        }
        assert_eq!(l.value, Some(Arc::new(10)));

        // Same proposal id, different value: must panic.
        l.receive_accepted(accepted(id, 8, u64::from(l.quorum)));
    }
}