1use super::{AuthorityId, RoundId, Vote, VoteValue};
7use crate::error::{Error, Result};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10use std::time::{Duration, Instant};
11
/// Tunable policy for a single round of vote collection.
///
/// `Default` yields a 30-second timeout, one signed vote per authority and no duplicates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VotingConfig {
    /// How long the collector accepts votes, measured from the moment the
    /// collector is created; later votes are rejected with a timeout error.
    pub vote_timeout: Duration,
    /// Whether an authority that has already voted may submit further votes.
    pub allow_duplicate_votes: bool,
    /// Whether votes with an empty signature are rejected.
    pub require_signatures: bool,
    /// Upper bound on the votes kept per authority. It is compared against the
    /// authority's already-recorded votes, so a first vote is always checked
    /// only by the other rules.
    pub max_votes_per_authority: usize,
}
24
25impl Default for VotingConfig {
26 fn default() -> Self {
27 VotingConfig {
28 vote_timeout: Duration::from_secs(30),
29 allow_duplicate_votes: false,
30 require_signatures: true,
31 max_votes_per_authority: 1,
32 }
33 }
34}
35
/// Running tally for one candidate value within a round.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VoteAggregation {
    /// The value this tally is for.
    pub value: VoteValue,
    /// Combined weight credited to `value` by `VoteCollector::add_vote`.
    pub total_weight: u64,
    /// Number of accepted votes for `value` (may exceed the number of
    /// distinct authorities when duplicate votes are allowed).
    pub vote_count: usize,
    /// Distinct authorities that voted for `value`.
    pub authorities: HashSet<AuthorityId>,
}
44
/// Collects and validates the votes of a single round and keeps per-value tallies.
pub struct VoteCollector {
    /// The only round whose votes are accepted.
    round_id: RoundId,
    /// Acceptance policy (timeout, signatures, duplicates).
    config: VotingConfig,
    /// Accepted votes per authority, in the order they were added.
    votes: HashMap<AuthorityId, Vec<Vote>>,
    /// Tally per distinct voted value.
    vote_by_value: HashMap<VoteValue, VoteAggregation>,
    /// Set at construction; the timeout is measured from here.
    start_time: Instant,
    /// Number of accepted votes across all authorities.
    total_votes: usize,
}
54
55impl VoteCollector {
56 pub fn new(round_id: RoundId, config: VotingConfig) -> Self {
57 VoteCollector {
58 round_id,
59 config,
60 votes: HashMap::new(),
61 vote_by_value: HashMap::new(),
62 start_time: Instant::now(),
63 total_votes: 0,
64 }
65 }
66
67 pub fn add_vote(&mut self, vote: Vote) -> Result<()> {
69 if vote.round_id != self.round_id {
71 return Err(Error::InvalidVote {
72 authority: vote.authority.0.clone(),
73 reason: format!(
74 "Wrong round: expected {}, got {}",
75 self.round_id, vote.round_id
76 ),
77 });
78 }
79
80 if self.start_time.elapsed() > self.config.vote_timeout {
82 return Err(Error::Timeout {
83 operation: "Vote collection".to_string(),
84 duration: self.config.vote_timeout,
85 });
86 }
87
88 if self.config.require_signatures && vote.signature.is_empty() {
90 return Err(Error::InvalidVote {
91 authority: vote.authority.0.clone(),
92 reason: "Missing signature".to_string(),
93 });
94 }
95
96 if let Some(existing_votes) = self.votes.get(&vote.authority) {
98 if !self.config.allow_duplicate_votes && !existing_votes.is_empty() {
99 return Err(Error::DuplicateVote {
100 authority: vote.authority.0.clone(),
101 });
102 }
103
104 if existing_votes.len() >= self.config.max_votes_per_authority {
105 return Err(Error::InvalidVote {
106 authority: vote.authority.0.clone(),
107 reason: format!(
108 "Exceeded max votes per authority: {}",
109 self.config.max_votes_per_authority
110 ),
111 });
112 }
113
114 if let Some(first_vote) = existing_votes.first() {
116 if first_vote.value != vote.value {
117 return Err(Error::ByzantineFault {
118 message: format!(
119 "Authority {} voted for multiple values",
120 vote.authority
121 ),
122 });
123 }
124 }
125 }
126
127 let aggregation = self
129 .vote_by_value
130 .entry(vote.value.clone())
131 .or_insert_with(|| VoteAggregation {
132 value: vote.value.clone(),
133 total_weight: 0,
134 vote_count: 0,
135 authorities: HashSet::new(),
136 });
137
138 aggregation.total_weight += vote.weight;
139 aggregation.vote_count += 1;
140 aggregation.authorities.insert(vote.authority.clone());
141
142 self.votes
144 .entry(vote.authority.clone())
145 .or_insert_with(Vec::new)
146 .push(vote);
147
148 self.total_votes += 1;
149
150 Ok(())
151 }
152
153 pub fn get_aggregation(&self, value: &VoteValue) -> Option<&VoteAggregation> {
155 self.vote_by_value.get(value)
156 }
157
158 pub fn get_all_aggregations(&self) -> Vec<&VoteAggregation> {
160 self.vote_by_value.values().collect()
161 }
162
163 pub fn get_leading_value(&self) -> Option<&VoteAggregation> {
165 self.vote_by_value
166 .values()
167 .max_by_key(|agg| agg.total_weight)
168 }
169
170 pub fn get_authority_votes(&self, authority: &AuthorityId) -> Vec<&Vote> {
172 self.votes
173 .get(authority)
174 .map(|votes| votes.iter().collect())
175 .unwrap_or_default()
176 }
177
178 pub fn get_voting_authorities(&self) -> HashSet<AuthorityId> {
180 self.votes.keys().cloned().collect()
181 }
182
183 pub fn get_total_weight(&self) -> u64 {
185 self.vote_by_value
186 .values()
187 .map(|agg| agg.total_weight)
188 .sum()
189 }
190
191 pub fn get_vote_count(&self) -> usize {
193 self.total_votes
194 }
195
196 pub fn has_timed_out(&self) -> bool {
198 self.start_time.elapsed() > self.config.vote_timeout
199 }
200
201 pub fn remaining_time(&self) -> Duration {
203 self.config
204 .vote_timeout
205 .saturating_sub(self.start_time.elapsed())
206 }
207
208 pub fn detect_byzantine_authorities(&self) -> Vec<AuthorityId> {
210 let mut byzantine = Vec::new();
211
212 for (authority, votes) in &self.votes {
213 if votes.len() > 1 {
215 let unique_values: HashSet<_> = votes.iter().map(|v| &v.value).collect();
216 if unique_values.len() > 1 {
217 byzantine.push(authority.clone());
218 }
219 }
220 }
221
222 byzantine
223 }
224
225 pub fn get_statistics(&self) -> VoteStatistics {
227 let unique_values = self.vote_by_value.len();
228 let participating_authorities = self.votes.len();
229 let leading = self.get_leading_value();
230
231 VoteStatistics {
232 round_id: self.round_id,
233 total_votes: self.total_votes,
234 unique_values,
235 participating_authorities,
236 total_weight: self.get_total_weight(),
237 leading_value_weight: leading.map(|agg| agg.total_weight),
238 elapsed_time: self.start_time.elapsed(),
239 timed_out: self.has_timed_out(),
240 }
241 }
242}
243
/// Point-in-time summary of a `VoteCollector`, as produced by `get_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VoteStatistics {
    /// Round the collector was created for.
    pub round_id: RoundId,
    /// Number of accepted votes.
    pub total_votes: usize,
    /// Number of distinct values that received at least one vote.
    pub unique_values: usize,
    /// Number of distinct authorities with at least one accepted vote.
    pub participating_authorities: usize,
    /// Combined weight across all values.
    pub total_weight: u64,
    /// Weight of the leading value; `None` when no votes were recorded.
    pub leading_value_weight: Option<u64>,
    /// Time since the collector was created.
    pub elapsed_time: Duration,
    /// Whether the collection window had closed when the snapshot was taken.
    pub timed_out: bool,
}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a signed vote so that signature checks pass by default.
    fn create_vote(round: u64, authority: &str, value: &str, weight: u64) -> Vote {
        Vote::new(
            RoundId(round),
            AuthorityId::from(authority),
            VoteValue::from_string(value),
            weight,
        )
        .with_signature(vec![1, 2, 3])
    }

    #[test]
    fn test_add_vote() {
        let config = VotingConfig::default();
        let mut collector = VoteCollector::new(RoundId(1), config);

        let vote = create_vote(1, "auth-1", "value-a", 100);
        assert!(collector.add_vote(vote).is_ok());
        assert_eq!(collector.get_vote_count(), 1);
    }

    #[test]
    fn test_wrong_round() {
        let config = VotingConfig::default();
        let mut collector = VoteCollector::new(RoundId(1), config);

        let vote = create_vote(2, "auth-1", "value-a", 100);
        let result = collector.add_vote(vote);

        assert!(matches!(result, Err(Error::InvalidVote { .. })));
        // A rejected vote must not be recorded.
        assert_eq!(collector.get_vote_count(), 0);
        assert!(collector.get_voting_authorities().is_empty());
    }

    #[test]
    fn test_duplicate_vote_not_allowed() {
        let config = VotingConfig {
            allow_duplicate_votes: false,
            ..Default::default()
        };
        let mut collector = VoteCollector::new(RoundId(1), config);

        let vote1 = create_vote(1, "auth-1", "value-a", 100);
        let vote2 = create_vote(1, "auth-1", "value-a", 100);

        assert!(collector.add_vote(vote1).is_ok());
        let result = collector.add_vote(vote2);

        assert!(matches!(result, Err(Error::DuplicateVote { .. })));
        assert_eq!(collector.get_vote_count(), 1);
    }

    #[test]
    fn test_byzantine_detection() {
        let config = VotingConfig {
            allow_duplicate_votes: true,
            max_votes_per_authority: 2,
            ..Default::default()
        };
        let mut collector = VoteCollector::new(RoundId(1), config);

        let vote1 = create_vote(1, "auth-1", "value-a", 100);
        let vote2 = create_vote(1, "auth-1", "value-b", 100);

        assert!(collector.add_vote(vote1).is_ok());
        let result = collector.add_vote(vote2);

        assert!(matches!(result, Err(Error::ByzantineFault { .. })));
        // The conflicting vote was rejected, so the stored votes stay consistent.
        assert_eq!(collector.get_vote_count(), 1);
        assert!(collector.detect_byzantine_authorities().is_empty());
        assert!(collector
            .get_aggregation(&VoteValue::from_string("value-b"))
            .is_none());
    }

    #[test]
    fn test_max_votes_per_authority() {
        let config = VotingConfig {
            allow_duplicate_votes: true,
            max_votes_per_authority: 2,
            ..Default::default()
        };
        let mut collector = VoteCollector::new(RoundId(1), config);

        collector.add_vote(create_vote(1, "auth-1", "value-a", 100)).unwrap();
        collector.add_vote(create_vote(1, "auth-1", "value-a", 100)).unwrap();
        let result = collector.add_vote(create_vote(1, "auth-1", "value-a", 100));

        assert!(matches!(result, Err(Error::InvalidVote { .. })));
        assert_eq!(collector.get_vote_count(), 2);
        assert_eq!(
            collector
                .get_authority_votes(&AuthorityId::from("auth-1"))
                .len(),
            2
        );
    }

    #[test]
    fn test_vote_aggregation() {
        let config = VotingConfig::default();
        let mut collector = VoteCollector::new(RoundId(1), config);

        collector.add_vote(create_vote(1, "auth-1", "value-a", 100)).unwrap();
        collector.add_vote(create_vote(1, "auth-2", "value-a", 150)).unwrap();
        collector.add_vote(create_vote(1, "auth-3", "value-b", 200)).unwrap();

        let agg_a = collector
            .get_aggregation(&VoteValue::from_string("value-a"))
            .unwrap();
        assert_eq!(agg_a.total_weight, 250);
        assert_eq!(agg_a.vote_count, 2);

        let agg_b = collector
            .get_aggregation(&VoteValue::from_string("value-b"))
            .unwrap();
        assert_eq!(agg_b.total_weight, 200);
        assert_eq!(agg_b.vote_count, 1);
    }

    #[test]
    fn test_leading_value() {
        let config = VotingConfig::default();
        let mut collector = VoteCollector::new(RoundId(1), config);

        collector.add_vote(create_vote(1, "auth-1", "value-a", 100)).unwrap();
        collector.add_vote(create_vote(1, "auth-2", "value-b", 200)).unwrap();
        collector.add_vote(create_vote(1, "auth-3", "value-b", 150)).unwrap();

        let leading = collector.get_leading_value().unwrap();
        assert_eq!(leading.value, VoteValue::from_string("value-b"));
        assert_eq!(leading.total_weight, 350);
    }

    #[test]
    fn test_voting_authorities() {
        let config = VotingConfig::default();
        let mut collector = VoteCollector::new(RoundId(1), config);

        collector.add_vote(create_vote(1, "auth-1", "value-a", 100)).unwrap();
        collector.add_vote(create_vote(1, "auth-2", "value-a", 100)).unwrap();

        let authorities = collector.get_voting_authorities();
        assert_eq!(authorities.len(), 2);
        assert!(authorities.contains(&AuthorityId::from("auth-1")));
        assert!(authorities.contains(&AuthorityId::from("auth-2")));
    }

    #[test]
    fn test_statistics() {
        let config = VotingConfig::default();
        let mut collector = VoteCollector::new(RoundId(1), config);

        collector.add_vote(create_vote(1, "auth-1", "value-a", 100)).unwrap();
        collector.add_vote(create_vote(1, "auth-2", "value-b", 200)).unwrap();

        let stats = collector.get_statistics();
        assert_eq!(stats.round_id, RoundId(1));
        assert_eq!(stats.total_votes, 2);
        assert_eq!(stats.unique_values, 2);
        assert_eq!(stats.participating_authorities, 2);
        assert_eq!(stats.total_weight, 300);
        assert_eq!(stats.leading_value_weight, Some(200));
        assert!(!stats.timed_out);
    }

    #[test]
    fn test_remaining_time_with_zero_timeout() {
        let config = VotingConfig {
            vote_timeout: Duration::ZERO,
            ..Default::default()
        };
        let collector = VoteCollector::new(RoundId(1), config);

        // `remaining_time` saturates rather than underflowing.
        assert_eq!(collector.remaining_time(), Duration::ZERO);
    }

    #[test]
    fn test_missing_signature() {
        let config = VotingConfig {
            require_signatures: true,
            ..Default::default()
        };
        let mut collector = VoteCollector::new(RoundId(1), config);

        let vote = Vote::new(
            RoundId(1),
            AuthorityId::from("auth-1"),
            VoteValue::from_string("value-a"),
            100,
        );

        let result = collector.add_vote(vote);
        assert!(matches!(result, Err(Error::InvalidVote { .. })));
        assert_eq!(collector.get_vote_count(), 0);
    }

    #[test]
    fn test_get_authority_votes() {
        let config = VotingConfig::default();
        let mut collector = VoteCollector::new(RoundId(1), config);

        let vote = create_vote(1, "auth-1", "value-a", 100);
        collector.add_vote(vote).unwrap();

        let votes = collector.get_authority_votes(&AuthorityId::from("auth-1"));
        assert_eq!(votes.len(), 1);
        assert_eq!(votes[0].weight, 100);
        assert!(collector
            .get_authority_votes(&AuthorityId::from("auth-unknown"))
            .is_empty());
    }
}