1use std::collections::{BTreeMap, BTreeSet};
2
3use blsttc::{PublicKeySet, SignatureShare};
4use core::fmt::Debug;
5use serde::{Deserialize, Serialize};
6
7use crate::sn_membership::Generation;
8use crate::{Candidate, Error, Fault, NodeId, Result, VoteCount};
9
10pub trait Proposition: Ord + Clone + Debug + Serialize {}
11impl<T: Ord + Clone + Debug + Serialize> Proposition for T {}
12
/// The content of a [`Vote`]: a fresh proposal, a merge of earlier votes, or a
/// super-majority ballot.
///
/// NOTE: variant order is significant. It determines the derived `Ord` and the serde variant
/// index (used by `bincode` in `Vote::to_bytes`), so do not reorder variants.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum Ballot<T: Proposition> {
    /// A single proposed value.
    Propose(T),
    /// A merge of previously seen signed votes.
    Merge(BTreeSet<SignedVote<T>>),
    /// Votes that `Vote::validate` requires to form a super-majority, plus the agreed
    /// proposals. Each proposal is paired with a `(NodeId, SignatureShare)` that
    /// `Vote::validate` checks via `verify_sig_share` over that proposal.
    SuperMajority {
        votes: BTreeSet<SignedVote<T>>,
        proposals: BTreeMap<T, (NodeId, SignatureShare)>,
    },
}
22
23impl<T: Proposition> Debug for Ballot<T> {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 match self {
26 Ballot::Propose(r) => write!(f, "P({r:?})"),
27 Ballot::Merge(votes) => write!(f, "M{votes:?}"),
28 Ballot::SuperMajority { votes, proposals } => write!(
29 f,
30 "SM{:?}-{:?}",
31 votes,
32 BTreeSet::from_iter(proposals.keys())
33 ),
34 }
35 }
36}
37
38pub fn simplify_votes<T: Proposition>(
39 signed_votes: &BTreeSet<SignedVote<T>>,
40) -> BTreeSet<SignedVote<T>> {
41 let mut simpler_votes = BTreeSet::new();
42 for v in signed_votes.iter() {
43 let this_vote_is_superseded = signed_votes
44 .iter()
45 .filter(|other_v| other_v != &v)
46 .any(|other_v| other_v.supersedes(v));
47
48 if !this_vote_is_superseded {
49 simpler_votes.insert(v.clone());
50 }
51 }
52 simpler_votes
53}
54
55pub fn proposals<T: Proposition>(
56 votes: &BTreeSet<SignedVote<T>>,
57 known_faulty: &BTreeSet<NodeId>,
58) -> BTreeSet<T> {
59 BTreeSet::from_iter(
60 votes
61 .iter()
62 .flat_map(SignedVote::unpack_votes)
63 .filter(|v| !known_faulty.contains(&v.voter))
64 .filter_map(|v| v.vote.ballot.as_proposal())
65 .cloned(),
66 )
67}
68
69impl<T: Proposition> Ballot<T> {
70 pub fn as_proposal(&self) -> Option<&T> {
71 match &self {
72 Ballot::Propose(p) => Some(p),
73 _ => None,
74 }
75 }
76
77 #[must_use]
78 pub fn simplify(&self) -> Self {
79 match &self {
80 Ballot::Propose(_) => self.clone(), Ballot::Merge(votes) => Ballot::Merge(simplify_votes(votes)),
82 Ballot::SuperMajority { votes, proposals } => Ballot::SuperMajority {
83 votes: simplify_votes(votes),
84 proposals: proposals.clone(),
85 },
86 }
87 }
88}
89
/// An unsigned vote: a ballot cast in a given generation, plus any faults the voter attaches.
///
/// NOTE: field order affects the derived `Ord` and the `bincode` encoding produced by
/// `Vote::to_bytes`, which is what gets signed — do not reorder fields.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct Vote<T: Proposition> {
    /// Generation this vote belongs to; `Vote::validate` requires child votes to match it.
    pub gen: Generation,
    /// What is being voted for.
    pub ballot: Ballot<T>,
    /// Faults reported alongside this vote; their voters are excluded when counting.
    pub faults: BTreeSet<Fault<T>>,
}
96
97impl<T: Proposition> Debug for Vote<T> {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 write!(f, "G{}-{:?}", self.gen, self.ballot)?;
100
101 if !self.faults.is_empty() {
102 write!(f, "-F{:?}", self.faults)?;
103 }
104 Ok(())
105 }
106}
107
108impl<T: Proposition> Vote<T> {
109 pub fn validate(
110 &self,
111 voters: &PublicKeySet,
112 valid_votes_memo: &BTreeSet<SignatureShare>,
113 ) -> Result<()> {
114 let validate_child_votes = |child_votes: &BTreeSet<SignedVote<T>>| {
115 for child_vote in child_votes {
116 let child_gen = child_vote.vote.gen;
117 let merge_gen = self.gen;
118 if child_gen != merge_gen {
119 return Err(Error::ParentAndChildWithDiffGen {
120 child_gen,
121 merge_gen,
122 });
123 }
124
125 if !valid_votes_memo.contains(&child_vote.sig) {
126 child_vote.validate(voters, valid_votes_memo)?;
127 };
128 }
129 Ok(())
130 };
131
132 match &self.ballot {
133 Ballot::Propose(_) => Ok(()),
134 Ballot::Merge(votes) => validate_child_votes(votes),
135 Ballot::SuperMajority { votes, proposals } => {
136 let vote_count = VoteCount::count(votes, &self.faulty_ids());
137
138 let candidate_proposals = vote_count
139 .candidate_with_most_votes()
140 .map(|(c, _)| c.proposals.clone())
141 .unwrap_or_default();
142
143 if !vote_count.do_we_have_supermajority(voters) {
144 Err(Error::SuperMajorityBallotIsNotSuperMajority)
146 } else if !candidate_proposals.iter().eq(proposals.keys()) {
147 Err(Error::SuperMajorityProposalsDoesNotMatchVoteProposals)
149 } else if proposals
150 .iter()
151 .try_for_each(|(p, (id, sig))| crate::verify_sig_share(&p, sig, *id, voters))
152 .is_err()
153 {
154 Err(Error::InvalidElderSignature)
155 } else {
156 validate_child_votes(votes)
157 }
158 }
159 }
160 }
161
162 pub fn is_super_majority_ballot(&self) -> bool {
163 matches!(self.ballot, Ballot::SuperMajority { .. })
164 }
165
166 pub fn to_bytes(&self) -> Result<Vec<u8>> {
167 Ok(bincode::serialize(&self)?)
168 }
169
170 pub fn faulty_ids(&self) -> BTreeSet<NodeId> {
171 BTreeSet::from_iter(self.faults.iter().map(Fault::voter_at_fault))
172 }
173
174 pub fn proposals(&self) -> BTreeSet<T> {
175 self.proposals_with_known_faults(&self.faulty_ids())
176 }
177
178 pub fn proposals_with_known_faults(&self, known_faulty: &BTreeSet<NodeId>) -> BTreeSet<T> {
179 match &self.ballot {
180 Ballot::Propose(proposal) => BTreeSet::from_iter([proposal.clone()]),
181 Ballot::Merge(votes) | Ballot::SuperMajority { votes, .. } => {
182 proposals(votes, known_faulty)
184 }
185 }
186 }
187}
188
/// A [`Vote`] together with its voter and the voter's signature share over the vote.
///
/// NOTE: field order affects the derived `Ord`/serialization — do not reorder fields.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct SignedVote<T: Proposition> {
    /// The vote being signed.
    pub vote: Vote<T>,
    /// Id of the node that cast the vote.
    pub voter: NodeId,
    /// Signature share over `vote`, checked by `SignedVote::validate_signature`.
    pub sig: SignatureShare,
}
195
196impl<T: Proposition> Debug for SignedVote<T> {
197 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
198 write!(f, "{:?}@{}", self.vote, self.voter)
199 }
200}
201
202impl<T: Proposition> SignedVote<T> {
203 pub fn candidate(&self) -> Candidate<T> {
204 match &self.vote.ballot {
205 Ballot::SuperMajority { votes, .. } => VoteCount::count(votes, &self.vote.faulty_ids())
206 .candidate_with_most_votes()
207 .map(|(candidate, _)| candidate.clone())
208 .unwrap_or_default(),
209 _ => Candidate {
210 proposals: self.proposals(),
211 faulty: self.vote.faulty_ids(),
212 },
213 }
214 }
215
216 pub fn validate_signature(&self, voters: &PublicKeySet) -> Result<()> {
217 crate::verify_sig_share(&self.vote, &self.sig, self.voter, voters)
218 }
219
220 pub fn validate(
223 &self,
224 voters: &PublicKeySet,
225 valid_votes_cache: &BTreeSet<SignatureShare>,
226 ) -> Result<()> {
227 self.validate_signature(voters)?;
228 self.vote.validate(voters, valid_votes_cache)?;
229
230 Ok(())
231 }
232
233 pub fn detect_byzantine_faults(
234 &self,
235 voters: &PublicKeySet,
236 existing_votes: &BTreeMap<NodeId, SignedVote<T>>,
237 valid_votes_cache: &BTreeSet<SignatureShare>,
238 ) -> std::result::Result<(), BTreeMap<NodeId, Fault<T>>> {
239 let mut faults = BTreeMap::new();
240 for vote in self.unpack_votes() {
241 if valid_votes_cache.contains(&vote.sig) {
242 continue;
243 }
244
245 if let Some(existing_vote) = existing_votes.get(&vote.voter) {
246 let fault = Fault::ChangedVote {
247 a: existing_vote.clone(),
248 b: vote.clone(),
249 };
250
251 if let Ok(()) = fault.validate(voters) {
252 faults.insert(vote.voter, fault);
253 }
254 }
255
256 {
257 let fault = Fault::InvalidFault {
258 signed_vote: vote.clone(),
259 };
260 if let Ok(()) = fault.validate(voters) {
261 faults.insert(vote.voter, fault);
262 }
263 }
264 }
265
266 if faults.is_empty() {
267 Ok(())
268 } else {
269 Err(faults)
270 }
271 }
272
273 pub fn unpack_votes(&self) -> Box<dyn Iterator<Item = &Self> + '_> {
274 match &self.vote.ballot {
275 Ballot::Propose(_) => Box::new(std::iter::once(self)),
276 Ballot::Merge(votes) | Ballot::SuperMajority { votes, .. } => {
277 Box::new(std::iter::once(self).chain(votes.iter().flat_map(Self::unpack_votes)))
278 }
279 }
280 }
281
282 pub fn proposals(&self) -> BTreeSet<T> {
283 self.vote.proposals()
284 }
285
286 pub fn supersedes(&self, other: &Self) -> bool {
287 let our_faulty = self.vote.faulty_ids();
288 let other_faulty = other.vote.faulty_ids();
289
290 if (&self.voter, self.vote.gen, &self.vote.ballot)
291 == (&other.voter, other.vote.gen, &other.vote.ballot)
292 && our_faulty.is_superset(&other_faulty)
293 {
294 true
295 } else {
296 match &self.vote.ballot {
297 Ballot::Propose(_) => false, Ballot::Merge(votes) | Ballot::SuperMajority { votes, .. } => {
299 votes.iter().any(|v| v.supersedes(other))
300 }
301 }
302 }
303 }
304
305 pub fn vote_count(&self) -> VoteCount<T> {
306 VoteCount::count([self], &self.vote.faulty_ids())
307 }
308}