hashgraph_like_consensus/
session.rs1use std::{collections::HashMap, time::Duration};
8
9use crate::{
10 error::ConsensusError,
11 protos::consensus::v1::{Proposal, Vote},
12 scope_config::{NetworkType, ScopeConfig},
13 types::SessionTransition,
14 utils::{
15 calculate_consensus_result, calculate_max_rounds, current_timestamp, validate_proposal,
16 validate_proposal_timestamp, validate_vote, validate_vote_chain,
17 },
18};
19
/// Tunable parameters governing how a [`ConsensusSession`] decides and when it
/// gives up.
#[derive(Clone, Debug)]
pub struct ConsensusConfig {
    /// Threshold handed to `calculate_consensus_result` and, for P2P sessions
    /// without an explicit cap, to `calculate_max_rounds`.
    consensus_threshold: f64,
    /// How long a consensus attempt may take (exposed via `consensus_timeout()`).
    consensus_timeout: Duration,
    /// Round cap. For P2P sessions `0` means "derive the cap from the number
    /// of expected voters" (see `max_round_limit`).
    max_rounds: u32,
    /// When `true`, rounds follow the gossipsub scheme (round 1 until the first
    /// vote, round 2 afterwards); otherwise each vote advances the round by one.
    use_gossipsub_rounds: bool,
    /// Configured liveness flag (exposed via `liveness_criteria()`).
    /// NOTE(review): session consensus checks read the proposal's own
    /// `liveness_criteria_yes`, not this field — confirm intended.
    liveness_criteria: bool,
}
44
45impl From<NetworkType> for ConsensusConfig {
46 fn from(network_type: NetworkType) -> Self {
47 ConsensusConfig::from(ScopeConfig::from(network_type))
48 }
49}
50
51impl From<ScopeConfig> for ConsensusConfig {
52 fn from(config: ScopeConfig) -> Self {
53 let (max_rounds, use_gossipsub_rounds) = match config.network_type {
54 NetworkType::Gossipsub => (config.max_rounds_override.unwrap_or(2), true),
55 NetworkType::P2P => (config.max_rounds_override.unwrap_or(0), false),
57 };
58
59 ConsensusConfig::new(
60 config.default_consensus_threshold,
61 config.default_timeout,
62 max_rounds,
63 use_gossipsub_rounds,
64 config.default_liveness_criteria_yes,
65 )
66 }
67}
68
69impl ConsensusConfig {
70 pub fn p2p() -> Self {
73 ConsensusConfig::from(NetworkType::P2P)
74 }
75
76 pub fn gossipsub() -> Self {
78 ConsensusConfig::from(NetworkType::Gossipsub)
79 }
80
81 pub fn with_timeout(mut self, consensus_timeout: Duration) -> Result<Self, ConsensusError> {
83 crate::utils::validate_timeout(consensus_timeout)?;
84 self.consensus_timeout = consensus_timeout;
85 Ok(self)
86 }
87
88 pub fn with_threshold(mut self, consensus_threshold: f64) -> Result<Self, ConsensusError> {
90 crate::utils::validate_threshold(consensus_threshold)?;
91 self.consensus_threshold = consensus_threshold;
92 Ok(self)
93 }
94
95 pub fn with_liveness_criteria(mut self, liveness_criteria: bool) -> Self {
97 self.liveness_criteria = liveness_criteria;
98 self
99 }
100
101 pub(crate) fn new(
104 consensus_threshold: f64,
105 consensus_timeout: Duration,
106 max_rounds: u32,
107 use_gossipsub_rounds: bool,
108 liveness_criteria: bool,
109 ) -> Self {
110 Self {
111 consensus_threshold,
112 consensus_timeout,
113 max_rounds,
114 use_gossipsub_rounds,
115 liveness_criteria,
116 }
117 }
118
119 fn max_round_limit(&self, expected_voters_count: u32) -> u32 {
120 if self.use_gossipsub_rounds {
121 self.max_rounds
122 } else if self.max_rounds == 0 {
123 calculate_max_rounds(expected_voters_count, self.consensus_threshold)
124 } else {
125 self.max_rounds
126 }
127 }
128
129 pub fn consensus_timeout(&self) -> Duration {
131 self.consensus_timeout
132 }
133
134 pub fn consensus_threshold(&self) -> f64 {
136 self.consensus_threshold
137 }
138
139 pub fn liveness_criteria(&self) -> bool {
141 self.liveness_criteria
142 }
143
144 pub fn max_rounds(&self) -> u32 {
146 self.max_rounds
147 }
148
149 pub fn use_gossipsub_rounds(&self) -> bool {
151 self.use_gossipsub_rounds
152 }
153}
154
/// Lifecycle state of a consensus session.
///
/// Derives `PartialEq`/`Eq` so states can be compared directly
/// (`state == ConsensusState::Failed`, `assert_eq!` in tests) instead of only
/// through `matches!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConsensusState {
    /// Still accepting votes; no decision has been reached.
    Active,
    /// A decision was reached. The `bool` is the outcome produced by
    /// `calculate_consensus_result` (presumably `true` = accepted — confirm).
    ConsensusReached(bool),
    /// The session can no longer progress; in this module it is set when the
    /// round limit is exceeded.
    Failed,
}
164
165#[derive(Debug, Clone)]
166pub struct ConsensusSession {
167 pub proposal: Proposal,
169 pub state: ConsensusState,
171 pub votes: HashMap<Vec<u8>, Vote>, pub created_at: u64,
175 pub config: ConsensusConfig,
177}
178
179impl ConsensusSession {
180 fn new(proposal: Proposal, config: ConsensusConfig) -> Self {
183 let now = current_timestamp().unwrap_or(0);
184 Self {
185 proposal,
186 state: ConsensusState::Active,
187 votes: HashMap::new(),
188 created_at: now,
189 config,
190 }
191 }
192
193 pub fn from_proposal(
197 proposal: Proposal,
198 config: ConsensusConfig,
199 ) -> Result<(Self, SessionTransition), ConsensusError> {
200 validate_proposal(&proposal)?;
201
202 let existing_votes = proposal.votes.clone();
204 let mut clean_proposal = proposal.clone();
205 clean_proposal.votes.clear();
206 clean_proposal.round = 1;
208
209 let mut session = Self::new(clean_proposal, config);
210 let transition = session.initialize_with_votes(
211 existing_votes,
212 proposal.expiration_timestamp,
213 proposal.timestamp,
214 )?;
215
216 Ok((session, transition))
217 }
218
219 pub(crate) fn add_vote(&mut self, vote: Vote) -> Result<SessionTransition, ConsensusError> {
221 match self.state {
222 ConsensusState::Active => {
223 validate_proposal_timestamp(self.proposal.expiration_timestamp)?;
224
225 self.check_round_limit(1)?;
227
228 if self.votes.contains_key(&vote.vote_owner) {
229 return Err(ConsensusError::DuplicateVote);
230 }
231 self.votes.insert(vote.vote_owner.clone(), vote.clone());
232 self.proposal.votes.push(vote.clone());
233
234 self.update_round(1);
235 Ok(self.check_consensus())
236 }
237 ConsensusState::ConsensusReached(res) => Ok(SessionTransition::ConsensusReached(res)),
238 _ => Err(ConsensusError::SessionNotActive),
239 }
240 }
241
242 pub(crate) fn initialize_with_votes(
245 &mut self,
246 votes: Vec<Vote>,
247 expiration_timestamp: u64,
248 creation_time: u64,
249 ) -> Result<SessionTransition, ConsensusError> {
250 if !matches!(self.state, ConsensusState::Active) {
251 return Err(ConsensusError::SessionNotActive);
252 }
253
254 validate_proposal_timestamp(expiration_timestamp)?;
255
256 if votes.is_empty() {
257 return Ok(SessionTransition::StillActive);
258 }
259
260 let mut seen_owners = std::collections::HashSet::new();
261 for vote in &votes {
262 if !seen_owners.insert(&vote.vote_owner) {
263 return Err(ConsensusError::DuplicateVote);
264 }
265 }
266
267 validate_vote_chain(&votes)?;
268 for vote in &votes {
269 validate_vote(vote, expiration_timestamp, creation_time)?;
270 }
271
272 self.check_round_limit(votes.len())?;
273 self.update_round(votes.len());
274
275 for vote in votes {
276 self.votes.insert(vote.vote_owner.clone(), vote.clone());
277 self.proposal.votes.push(vote);
278 }
279
280 Ok(self.check_consensus())
281 }
282
283 fn check_round_limit(&mut self, vote_count: usize) -> Result<(), ConsensusError> {
290 let projected_value = if self.config.use_gossipsub_rounds {
292 if self.proposal.round == 2 || (self.proposal.round == 1 && vote_count > 0) {
297 2
298 } else {
299 self.proposal.round }
301 } else {
302 let current_votes = self.proposal.round.saturating_sub(1);
307 current_votes.saturating_add(vote_count as u32)
308 };
309
310 if projected_value
311 > self
312 .config
313 .max_round_limit(self.proposal.expected_voters_count)
314 {
315 self.state = ConsensusState::Failed;
316 return Err(ConsensusError::MaxRoundsExceeded);
317 }
318
319 Ok(())
320 }
321
322 fn update_round(&mut self, vote_count: usize) {
328 if self.config.use_gossipsub_rounds {
329 if self.proposal.round == 1 && vote_count > 0 {
334 self.proposal.round = 2;
335 }
336 } else {
337 self.proposal.round = self.proposal.round.saturating_add(vote_count as u32);
340 }
341 }
342
343 fn check_consensus(&mut self) -> SessionTransition {
348 let expected_voters = self.proposal.expected_voters_count;
349 let threshold = self.config.consensus_threshold;
350 let liveness = self.proposal.liveness_criteria_yes;
351
352 match calculate_consensus_result(&self.votes, expected_voters, threshold, liveness) {
353 Some(result) => {
354 self.state = ConsensusState::ConsensusReached(result);
355 SessionTransition::ConsensusReached(result)
356 }
357 None => {
358 self.state = ConsensusState::Active;
359 SessionTransition::StillActive
360 }
361 }
362 }
363
364 pub fn is_active(&self) -> bool {
366 matches!(self.state, ConsensusState::Active)
367 }
368
369 pub fn get_consensus_result(&self) -> Result<bool, ConsensusError> {
374 if let ConsensusState::ConsensusReached(result) = self.state {
375 Ok(result)
376 } else {
377 Err(ConsensusError::ConsensusNotReached)
378 }
379 }
380}
381
#[cfg(test)]
mod tests {
    use alloy::signers::local::PrivateKeySigner;

    use crate::{
        error::ConsensusError,
        session::{ConsensusConfig, ConsensusSession},
        types::CreateProposalRequest,
        utils::build_vote,
    };

    /// Under gossipsub rounds the first vote moves the session to round 2 and
    /// every later vote leaves it there.
    #[tokio::test]
    async fn enforce_max_rounds_gossipsub() {
        let signers: Vec<PrivateKeySigner> = (0..4).map(|_| PrivateKeySigner::random()).collect();
        let owner = signers[0].address().as_slice().to_vec();

        let proposal = CreateProposalRequest::new("Test".into(), "".into(), owner, 4, 60, false)
            .unwrap()
            .into_proposal()
            .unwrap();
        let mut session = ConsensusSession::new(proposal, ConsensusConfig::gossipsub());

        let choices = [true, false, true, true];
        for (signer, choice) in signers.into_iter().zip(choices.iter().copied()) {
            let vote = build_vote(&session.proposal, choice, signer).await.unwrap();
            session.add_vote(vote).unwrap();
            assert_eq!(session.proposal.round, 2);
        }
        assert_eq!(session.votes.len(), 4);
    }

    /// Under P2P rounds every vote advances the round by one, and the vote
    /// that would exceed the derived cap is rejected.
    #[tokio::test]
    async fn enforce_max_rounds_p2p() {
        let mut signers: Vec<PrivateKeySigner> =
            (0..5).map(|_| PrivateKeySigner::random()).collect();
        let owner = signers[0].address().as_slice().to_vec();

        let proposal = CreateProposalRequest::new("Test".into(), "".into(), owner, 5, 60, false)
            .unwrap()
            .into_proposal()
            .unwrap();
        let mut session = ConsensusSession::new(proposal, ConsensusConfig::p2p());

        // The fifth voter is held back; their vote must overflow the budget.
        let late_signer = signers.pop().expect("five signers were generated");

        // (vote choice, round expected once the vote is accepted)
        let script: [(bool, u32); 4] = [(true, 2), (false, 3), (true, 4), (true, 5)];
        for (idx, (signer, (choice, expected_round))) in signers
            .into_iter()
            .zip(script.iter().copied())
            .enumerate()
        {
            let vote = build_vote(&session.proposal, choice, signer).await.unwrap();
            session.add_vote(vote).unwrap();
            assert_eq!(session.proposal.round, expected_round);
            assert_eq!(session.votes.len(), idx + 1);
        }

        let vote = build_vote(&session.proposal, true, late_signer).await.unwrap();
        let err = session.add_vote(vote).unwrap_err();
        assert!(matches!(err, ConsensusError::MaxRoundsExceeded));
    }
}