agentic_payments/consensus/
voting.rs

//! Vote Collection and Aggregation
//!
//! Implements per-round vote collection with Byzantine fault (equivocation)
//! detection and weighted vote aggregation.
5
6use super::{AuthorityId, RoundId, Vote, VoteValue};
7use crate::error::{Error, Result};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10use std::time::{Duration, Instant};
11
/// Acceptance policy applied by a `VoteCollector` during one round.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VotingConfig {
    /// Maximum time, measured from the collector's creation, during which
    /// votes are accepted.
    pub vote_timeout: Duration,
    /// Accept further votes from an authority that has already voted (up to
    /// `max_votes_per_authority`). A repeat vote must carry the same value as
    /// the authority's first vote, otherwise it is rejected as a Byzantine fault.
    pub allow_duplicate_votes: bool,
    /// Reject votes whose signature is empty.
    ///
    /// NOTE(review): only the presence of a signature is checked by the
    /// collector; it is not cryptographically verified there.
    pub require_signatures: bool,
    /// Maximum number of votes accepted from one authority per round. Only
    /// consulted once an authority already has at least one accepted vote.
    pub max_votes_per_authority: usize,
}
24
25impl Default for VotingConfig {
26    fn default() -> Self {
27        VotingConfig {
28            vote_timeout: Duration::from_secs(30),
29            allow_duplicate_votes: false,
30            require_signatures: true,
31            max_votes_per_authority: 1,
32        }
33    }
34}
35
/// Running tally of the votes cast for one [`VoteValue`] in a round.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VoteAggregation {
    /// The value these votes were cast for.
    pub value: VoteValue,
    /// Sum of the weights of the votes counted for `value`.
    pub total_weight: u64,
    /// Number of votes counted for `value`.
    pub vote_count: usize,
    /// Authorities that voted for `value`.
    pub authorities: HashSet<AuthorityId>,
}
44
45/// Vote collector for a single consensus round
46pub struct VoteCollector {
47    round_id: RoundId,
48    config: VotingConfig,
49    votes: HashMap<AuthorityId, Vec<Vote>>,
50    vote_by_value: HashMap<VoteValue, VoteAggregation>,
51    start_time: Instant,
52    total_votes: usize,
53}
54
55impl VoteCollector {
56    pub fn new(round_id: RoundId, config: VotingConfig) -> Self {
57        VoteCollector {
58            round_id,
59            config,
60            votes: HashMap::new(),
61            vote_by_value: HashMap::new(),
62            start_time: Instant::now(),
63            total_votes: 0,
64        }
65    }
66
67    /// Add a vote to the collection
68    pub fn add_vote(&mut self, vote: Vote) -> Result<()> {
69        // Validate round ID
70        if vote.round_id != self.round_id {
71            return Err(Error::InvalidVote {
72                authority: vote.authority.0.clone(),
73                reason: format!(
74                    "Wrong round: expected {}, got {}",
75                    self.round_id, vote.round_id
76                ),
77            });
78        }
79
80        // Check timeout
81        if self.start_time.elapsed() > self.config.vote_timeout {
82            return Err(Error::Timeout {
83                operation: "Vote collection".to_string(),
84                duration: self.config.vote_timeout,
85            });
86        }
87
88        // Validate signature if required
89        if self.config.require_signatures && vote.signature.is_empty() {
90            return Err(Error::InvalidVote {
91                authority: vote.authority.0.clone(),
92                reason: "Missing signature".to_string(),
93            });
94        }
95
96        // Check for duplicate votes
97        if let Some(existing_votes) = self.votes.get(&vote.authority) {
98            if !self.config.allow_duplicate_votes && !existing_votes.is_empty() {
99                return Err(Error::DuplicateVote {
100                    authority: vote.authority.0.clone(),
101                });
102            }
103
104            if existing_votes.len() >= self.config.max_votes_per_authority {
105                return Err(Error::InvalidVote {
106                    authority: vote.authority.0.clone(),
107                    reason: format!(
108                        "Exceeded max votes per authority: {}",
109                        self.config.max_votes_per_authority
110                    ),
111                });
112            }
113
114            // Check for Byzantine behavior (voting for different values)
115            if let Some(first_vote) = existing_votes.first() {
116                if first_vote.value != vote.value {
117                    return Err(Error::ByzantineFault {
118                        message: format!(
119                            "Authority {} voted for multiple values",
120                            vote.authority
121                        ),
122                    });
123                }
124            }
125        }
126
127        // Update aggregation
128        let aggregation = self
129            .vote_by_value
130            .entry(vote.value.clone())
131            .or_insert_with(|| VoteAggregation {
132                value: vote.value.clone(),
133                total_weight: 0,
134                vote_count: 0,
135                authorities: HashSet::new(),
136            });
137
138        aggregation.total_weight += vote.weight;
139        aggregation.vote_count += 1;
140        aggregation.authorities.insert(vote.authority.clone());
141
142        // Store vote
143        self.votes
144            .entry(vote.authority.clone())
145            .or_insert_with(Vec::new)
146            .push(vote);
147
148        self.total_votes += 1;
149
150        Ok(())
151    }
152
153    /// Get vote aggregation for a specific value
154    pub fn get_aggregation(&self, value: &VoteValue) -> Option<&VoteAggregation> {
155        self.vote_by_value.get(value)
156    }
157
158    /// Get all vote aggregations
159    pub fn get_all_aggregations(&self) -> Vec<&VoteAggregation> {
160        self.vote_by_value.values().collect()
161    }
162
163    /// Get leading vote value by weight
164    pub fn get_leading_value(&self) -> Option<&VoteAggregation> {
165        self.vote_by_value
166            .values()
167            .max_by_key(|agg| agg.total_weight)
168    }
169
170    /// Get votes from a specific authority
171    pub fn get_authority_votes(&self, authority: &AuthorityId) -> Vec<&Vote> {
172        self.votes
173            .get(authority)
174            .map(|votes| votes.iter().collect())
175            .unwrap_or_default()
176    }
177
178    /// Get all authorities that voted
179    pub fn get_voting_authorities(&self) -> HashSet<AuthorityId> {
180        self.votes.keys().cloned().collect()
181    }
182
183    /// Get total vote weight
184    pub fn get_total_weight(&self) -> u64 {
185        self.vote_by_value
186            .values()
187            .map(|agg| agg.total_weight)
188            .sum()
189    }
190
191    /// Get total number of votes
192    pub fn get_vote_count(&self) -> usize {
193        self.total_votes
194    }
195
196    /// Check if voting has timed out
197    pub fn has_timed_out(&self) -> bool {
198        self.start_time.elapsed() > self.config.vote_timeout
199    }
200
201    /// Get remaining time for voting
202    pub fn remaining_time(&self) -> Duration {
203        self.config
204            .vote_timeout
205            .saturating_sub(self.start_time.elapsed())
206    }
207
208    /// Detect potential Byzantine authorities
209    pub fn detect_byzantine_authorities(&self) -> Vec<AuthorityId> {
210        let mut byzantine = Vec::new();
211
212        for (authority, votes) in &self.votes {
213            // Check for multiple different votes
214            if votes.len() > 1 {
215                let unique_values: HashSet<_> = votes.iter().map(|v| &v.value).collect();
216                if unique_values.len() > 1 {
217                    byzantine.push(authority.clone());
218                }
219            }
220        }
221
222        byzantine
223    }
224
225    /// Get vote statistics
226    pub fn get_statistics(&self) -> VoteStatistics {
227        let unique_values = self.vote_by_value.len();
228        let participating_authorities = self.votes.len();
229        let leading = self.get_leading_value();
230
231        VoteStatistics {
232            round_id: self.round_id,
233            total_votes: self.total_votes,
234            unique_values,
235            participating_authorities,
236            total_weight: self.get_total_weight(),
237            leading_value_weight: leading.map(|agg| agg.total_weight),
238            elapsed_time: self.start_time.elapsed(),
239            timed_out: self.has_timed_out(),
240        }
241    }
242}
243
/// Point-in-time summary of a `VoteCollector`, as produced by
/// `VoteCollector::get_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VoteStatistics {
    /// Round the collector belongs to.
    pub round_id: RoundId,
    /// Number of votes the collector has accepted.
    pub total_votes: usize,
    /// Number of distinct values that have received votes.
    pub unique_values: usize,
    /// Number of distinct authorities with at least one accepted vote.
    pub participating_authorities: usize,
    /// Combined weight across all voted values.
    pub total_weight: u64,
    /// Weight of the heaviest value, or `None` if no votes were accepted.
    pub leading_value_weight: Option<u64>,
    /// Time since the collector was created, measured when the statistics
    /// were taken.
    pub elapsed_time: Duration,
    /// Whether the voting window had closed when the statistics were taken.
    pub timed_out: bool,
}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a vote from `authority` for `value` in round `round`, carrying a
    /// dummy non-empty signature so it passes the signature-presence check.
    fn create_vote(round: u64, authority: &str, value: &str, weight: u64) -> Vote {
        let unsigned = Vote::new(
            RoundId(round),
            AuthorityId::from(authority),
            VoteValue::from_string(value),
            weight,
        );
        unsigned.with_signature(vec![1, 2, 3])
    }

    /// A collector for round 1 using the default configuration.
    fn round_one_collector() -> VoteCollector {
        VoteCollector::new(RoundId(1), VotingConfig::default())
    }

    #[test]
    fn test_add_vote() {
        let mut collector = round_one_collector();

        collector
            .add_vote(create_vote(1, "auth-1", "value-a", 100))
            .expect("a well-formed vote must be accepted");

        assert_eq!(collector.get_vote_count(), 1);
    }

    #[test]
    fn test_wrong_round() {
        let mut collector = round_one_collector();

        let outcome = collector.add_vote(create_vote(2, "auth-1", "value-a", 100));

        assert!(matches!(outcome, Err(Error::InvalidVote { .. })));
    }

    #[test]
    fn test_duplicate_vote_not_allowed() {
        let strict = VotingConfig {
            allow_duplicate_votes: false,
            ..Default::default()
        };
        let mut collector = VoteCollector::new(RoundId(1), strict);

        collector
            .add_vote(create_vote(1, "auth-1", "value-a", 100))
            .expect("first vote must be accepted");
        let outcome = collector.add_vote(create_vote(1, "auth-1", "value-a", 100));

        assert!(matches!(outcome, Err(Error::DuplicateVote { .. })));
    }

    #[test]
    fn test_byzantine_detection() {
        let lenient = VotingConfig {
            allow_duplicate_votes: true,
            max_votes_per_authority: 2,
            ..Default::default()
        };
        let mut collector = VoteCollector::new(RoundId(1), lenient);

        collector
            .add_vote(create_vote(1, "auth-1", "value-a", 100))
            .expect("first vote must be accepted");
        // Same authority, different value: equivocation.
        let outcome = collector.add_vote(create_vote(1, "auth-1", "value-b", 100));

        assert!(matches!(outcome, Err(Error::ByzantineFault { .. })));
    }

    #[test]
    fn test_vote_aggregation() {
        let mut collector = round_one_collector();
        let ballots = [
            ("auth-1", "value-a", 100),
            ("auth-2", "value-a", 150),
            ("auth-3", "value-b", 200),
        ];
        for &(authority, value, weight) in ballots.iter() {
            collector
                .add_vote(create_vote(1, authority, value, weight))
                .unwrap();
        }

        // (total_weight, vote_count) for a given value.
        let tally = |value: &str| {
            let agg = collector
                .get_aggregation(&VoteValue::from_string(value))
                .unwrap();
            (agg.total_weight, agg.vote_count)
        };
        assert_eq!(tally("value-a"), (250, 2));
        assert_eq!(tally("value-b"), (200, 1));
    }

    #[test]
    fn test_leading_value() {
        let mut collector = round_one_collector();
        collector.add_vote(create_vote(1, "auth-1", "value-a", 100)).unwrap();
        collector.add_vote(create_vote(1, "auth-2", "value-b", 200)).unwrap();
        collector.add_vote(create_vote(1, "auth-3", "value-b", 150)).unwrap();

        let leader = collector
            .get_leading_value()
            .expect("at least one value has votes");

        assert_eq!(leader.value, VoteValue::from_string("value-b"));
        assert_eq!(leader.total_weight, 350);
    }

    #[test]
    fn test_voting_authorities() {
        let names = ["auth-1", "auth-2"];
        let mut collector = round_one_collector();
        for name in names.iter() {
            collector
                .add_vote(create_vote(1, name, "value-a", 100))
                .unwrap();
        }

        let voters = collector.get_voting_authorities();

        assert_eq!(voters.len(), 2);
        assert!(names
            .iter()
            .all(|name| voters.contains(&AuthorityId::from(*name))));
    }

    #[test]
    fn test_statistics() {
        let mut collector = round_one_collector();
        collector.add_vote(create_vote(1, "auth-1", "value-a", 100)).unwrap();
        collector.add_vote(create_vote(1, "auth-2", "value-b", 200)).unwrap();

        let summary = collector.get_statistics();

        assert_eq!(summary.round_id, RoundId(1));
        assert_eq!(summary.total_votes, 2);
        assert_eq!(summary.unique_values, 2);
        assert_eq!(summary.participating_authorities, 2);
        assert_eq!(summary.total_weight, 300);
        assert_eq!(summary.leading_value_weight, Some(200));
    }

    #[test]
    fn test_missing_signature() {
        let signatures_required = VotingConfig {
            require_signatures: true,
            ..Default::default()
        };
        let mut collector = VoteCollector::new(RoundId(1), signatures_required);

        // No `with_signature` call, so this vote is expected to be rejected
        // as unsigned.
        let unsigned = Vote::new(
            RoundId(1),
            AuthorityId::from("auth-1"),
            VoteValue::from_string("value-a"),
            100,
        );

        assert!(collector.add_vote(unsigned).is_err());
    }

    #[test]
    fn test_get_authority_votes() {
        let mut collector = round_one_collector();
        collector
            .add_vote(create_vote(1, "auth-1", "value-a", 100))
            .unwrap();

        let cast = collector.get_authority_votes(&AuthorityId::from("auth-1"));

        assert_eq!(cast.len(), 1);
        assert_eq!(cast[0].weight, 100);
    }
}