use std::sync::{Arc, RwLock};
use hashgraph_like_consensus::protos::consensus::v1::Proposal;
use prost::Message;
use tracing::{error, info};
use crate::{
app::{
ConversationState, LockExt, SessionRunner, UserError,
session::{
consensus::build_vote_banner_event,
consensus_bridge::{forward_incoming_proposal, forward_incoming_vote},
runner::send_packet,
},
},
core::{
ConsensusPlugin, ConversationPluginsFactory, PeerScoringPlugin, ProcessResult,
ProposalKind, ScoreSnapshot, SessionEvent, StewardList, StewardListConfig,
StewardListPlugin, member_set,
},
mls_crypto::MlsService,
protos::de_mls::messages::v1::{
AppMessage, ConversationMessage, ConversationSync, ConversationUpdateRequest, TimingConfig,
conversation_update_request,
},
};
/// Outcome reported by `SessionRunner::dispatch_inbound_result` after one
/// inbound `ProcessResult` has been handled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DispatchOutcome {
/// The result was handled; returned for every `ProcessResult` variant
/// except `LeaveConversation`.
Done,
/// Returned for `ProcessResult::LeaveConversation` once the self-leave
/// preparation has succeeded. What the caller does next is not visible in
/// this file (presumably it tears the session down — confirm at call site).
LeaveRequested,
}
impl<P: ConsensusPlugin, CP: ConversationPluginsFactory> SessionRunner<P, CP> {
pub(crate) async fn dispatch_inbound_result(
arc: &Arc<RwLock<Self>>,
result: ProcessResult,
) -> Result<DispatchOutcome, UserError> {
match result {
ProcessResult::AppMessage(msg) => {
arc.read_or_err("session")?
.emit_event(SessionEvent::AppMessage(*msg));
Ok(DispatchOutcome::Done)
}
ProcessResult::Proposal(proposal) => {
Self::on_incoming_proposal(arc, *proposal).await?;
Ok(DispatchOutcome::Done)
}
ProcessResult::Vote(vote) => {
let proposal_id = vote.proposal_id;
let (consensus, conversation_name, outcome_applied) = {
let s = arc.read_or_err("session")?;
(
s.consensus.clone(),
s.conversation_name.clone(),
s.handle
.conversation
.is_consensus_outcome_applied(proposal_id),
)
};
forward_incoming_vote::<P>(&conversation_name, *vote, &consensus, outcome_applied)
.await?;
Ok(DispatchOutcome::Done)
}
ProcessResult::MembershipChangeReceived(request) => {
Self::handle_incoming_update_request(arc, *request).await?;
Ok(DispatchOutcome::Done)
}
ProcessResult::JoinedConversation(_name) => {
Self::on_joined_conversation(arc).await?;
Ok(DispatchOutcome::Done)
}
ProcessResult::ConversationUpdated => {
Self::on_conversation_updated(arc).await?;
Ok(DispatchOutcome::Done)
}
ProcessResult::LeaveConversation => {
Self::prepare_self_leave(arc)?;
Ok(DispatchOutcome::LeaveRequested)
}
ProcessResult::CommitCandidateReceived { steward } => {
Self::on_commit_candidate_received(arc, &steward).await?;
Ok(DispatchOutcome::Done)
}
ProcessResult::ConversationSyncReceived(sync) => {
Self::on_conversation_sync(arc, *sync)?;
Ok(DispatchOutcome::Done)
}
ProcessResult::Noop(reason) => {
let conv_name = arc.read_or_err("session")?.conversation_name.clone();
tracing::debug!(
conversation = %conv_name,
?reason,
"inbound dispatched as noop"
);
Ok(DispatchOutcome::Done)
}
}
}
/// Handles a consensus proposal received from a peer.
///
/// Steps, in order:
/// 1. Try to decode the proposal payload as a [`ConversationUpdateRequest`].
///    When it decodes, record it on the conversation: emergency-criteria
///    proposals are observed by id, invite/remove requests are buffered as
///    pending updates tagged with the current MLS epoch (`0` when no MLS
///    state is present).
/// 2. Forward the proposal to the consensus plugin.
/// 3. When more than one voter is expected, emit a vote banner event and
///    register an auto-vote using the configured delay for the proposal kind.
///
/// A payload that does not decode is treated as [`ProposalKind::Commit`].
///
/// # Errors
/// Propagates session-lock errors, `current_epoch` failures and errors from
/// `forward_incoming_proposal`.
async fn on_incoming_proposal(
    arc: &Arc<RwLock<Self>>,
    proposal: Proposal,
) -> Result<(), UserError> {
    let decoded = ConversationUpdateRequest::decode(proposal.payload.as_slice()).ok();
    if let Some(req) = decoded.as_ref() {
        // Guard is scoped to this block so it is released before the await below.
        let mut s = arc.write_or_err("session")?;
        let current_epoch = match s.handle.mls() {
            Some(mls) => mls.current_epoch()?,
            None => 0,
        };
        match &req.payload {
            Some(conversation_update_request::Payload::EmergencyCriteria(_)) => {
                s.handle
                    .conversation
                    .observe_emergency(proposal.proposal_id);
            }
            Some(conversation_update_request::Payload::InviteMember(_))
            | Some(conversation_update_request::Payload::RemoveMember(_)) => {
                s.handle
                    .conversation
                    .buffer_pending_update(req.clone(), current_epoch);
            }
            _ => {}
        }
    }

    let proposal_id = proposal.proposal_id;
    // The payload is only needed for the vote banner, so copy it only when a
    // banner will actually be built (`proposal` is moved into the forwarder).
    let banner_payload = (proposal.expected_voters_count > 1).then(|| proposal.payload.clone());
    let kind = decoded
        .as_ref()
        .map(ProposalKind::of)
        .unwrap_or(ProposalKind::Commit);
    let (consensus, conversation_name) = {
        let s = arc.read_or_err("session")?;
        (s.consensus.clone(), s.conversation_name.clone())
    };
    forward_incoming_proposal::<P>(&conversation_name, proposal, &consensus).await?;

    if let Some(payload) = banner_payload {
        let banner = build_vote_banner_event(&conversation_name, proposal_id, payload);
        arc.read_or_err("session")?
            .emit_event(SessionEvent::AppMessage(banner));

        // Read the voting parameters and register the auto-vote under one
        // write guard, so the configuration cannot change between the read
        // and the registration. No `.await` follows, so holding it is safe.
        let mut s = arc.write_or_err("session")?;
        let delay = s.handle.config.voting_delay_for(kind);
        let vote = s.handle.config.liveness_criteria_yes;
        s.register_auto_vote(proposal_id, delay, vote);
    }
    Ok(())
}
/// Runs the local follow-up work after this user has joined the conversation.
///
/// Prunes pending updates, builds and sends a `SYSTEM` announcement message,
/// emits [`SessionEvent::Joined`], syncs the scoring members with the MLS
/// member list, switches to the working phase and emits the resulting
/// [`SessionEvent::PhaseChange`].
///
/// A failure to read the MLS member list is tolerated: an empty list is used.
///
/// # Errors
/// Propagates session-lock errors, missing MLS state, message-build failures
/// and transport send failures.
async fn on_joined_conversation(arc: &Arc<RwLock<Self>>) -> Result<(), UserError> {
    arc.write_or_err("session")?
        .prune_pending_updates_after_commit()?;

    let (join_packet, roster, conv_name) = {
        let mut session = arc.write_or_err("session")?;
        let announcement: AppMessage = ConversationMessage {
            message: format!("User {} joined the conversation", session.identity_display)
                .into_bytes(),
            sender: "SYSTEM".to_string(),
            conversation_name: session.conversation_name.clone(),
        }
        .into();
        // Cloned up front because `mls` below mutably borrows the guard.
        let app_id = Arc::clone(&session.app_id);
        let conv_name = session.conversation_name.clone();
        let mls = session.handle.expect_mls_mut()?;
        let roster = mls.members().unwrap_or_default();
        let join_packet = mls.build_message(&announcement, &app_id)?;
        (join_packet, roster, conv_name)
    };

    let transport = Arc::clone(arc.read_or_err("session")?.transport());
    send_packet(&transport, join_packet)?;

    arc.read_or_err("session")?.emit_event(SessionEvent::Joined);
    arc.write_or_err("session")?.sync_scoring_members(&roster);

    let phase = arc.write_or_err("session")?.start_working();
    arc.read_or_err("session")?
        .emit_event(SessionEvent::PhaseChange(phase));

    info!(conversation = %conv_name, "joined conversation");
    Ok(())
}
/// Reacts to a conversation update reported by inbound processing.
///
/// Syncs scoring members with the MLS member list (empty when MLS state is
/// absent or the member lookup fails), prunes pending updates, resets the
/// steward-list retry state and — when the current state is `Working`,
/// `Freezing`, `Selection` or `Reelection` — switches to the working phase.
/// Afterwards it runs steward-list housekeeping, processes buffered updates,
/// runs the recovery-window check and finally emits the `PhaseChange` event
/// if a phase switch happened.
///
/// # Errors
/// Propagates session-lock errors and errors from pruning, housekeeping and
/// buffered-update processing.
async fn on_conversation_updated(arc: &Arc<RwLock<Self>>) -> Result<(), UserError> {
    let roster = {
        let session = arc.read_or_err("session")?;
        session
            .handle
            .mls()
            .map(|mls| mls.members().unwrap_or_default())
            .unwrap_or_default()
    };
    arc.write_or_err("session")?.sync_scoring_members(&roster);
    arc.write_or_err("session")?
        .prune_pending_updates_after_commit()?;

    // The phase event is produced here but only emitted after the
    // follow-up steps below have completed.
    let phase_change = {
        let mut session = arc.write_or_err("session")?;
        session.handle.steward_list.reset_retry();
        let restart = matches!(
            session.handle.current_state(),
            ConversationState::Working
                | ConversationState::Freezing
                | ConversationState::Selection
                | ConversationState::Reelection
        );
        restart.then(|| session.start_working())
    };

    Self::steward_list_housekeeping(arc).await?;
    Self::process_buffered_updates(arc).await?;
    Self::maybe_close_recovery_window(arc).await;

    if let Some(event) = phase_change {
        arc.read_or_err("session")?
            .emit_event(SessionEvent::PhaseChange(event));
    }
    Ok(())
}
/// Best-effort attempt to start a steward election while the session handle
/// reports recovery mode.
///
/// Never fails: a session-lock error is logged as a warning and the check is
/// skipped; a failed election attempt is logged at info level as deferred.
async fn maybe_close_recovery_window(arc: &Arc<RwLock<Self>>) {
    // The guard is consumed inside the closure, so no lock is held across
    // the `.await` below.
    match arc
        .read_or_err("session")
        .map(|session| session.handle.is_in_recovery_mode())
    {
        Ok(true) => {}
        Ok(false) => return,
        Err(e) => {
            tracing::warn!(error = %e, "recovery window check skipped: session lock poisoned");
            return;
        }
    }

    let Err(e) = Self::try_initiate_steward_election(arc, true, None).await else {
        return;
    };

    let conv_name = match arc.read_or_err("session") {
        Ok(session) => session.conversation_name.clone(),
        Err(_) => "<poisoned>".to_string(),
    };
    info!(
        conversation = %conv_name,
        error = %e,
        "post-recovery election deferred"
    );
}
/// Prepares the local session for leaving the conversation.
///
/// Emits [`SessionEvent::Leaving`], then takes the MLS state out of the
/// session handle and deletes it when one was present.
///
/// # Errors
/// Propagates session-lock errors and any error from deleting the MLS state.
fn prepare_self_leave(arc: &Arc<RwLock<Self>>) -> Result<(), UserError> {
    arc.read_or_err("session")?
        .emit_event(SessionEvent::Leaving);

    // Bound with a plain `let` so the write guard is released at the end of
    // this statement, before `delete` runs.
    let detached = arc.write_or_err("session")?.handle.take_mls();
    let Some(mut mls) = detached else {
        return Ok(());
    };
    mls.delete()?;
    Ok(())
}
/// Handles a commit candidate announced by a peer steward.
///
/// Does nothing unless the conversation is currently in
/// [`ConversationState::Working`]. Otherwise it switches to the freezing
/// phase, ensures a freeze round exists for the current MLS epoch and, when
/// the local identity is itself a steward, builds and sends its own commit
/// candidate. A failure to build the own candidate is logged and otherwise
/// ignored. The resulting `PhaseChange` event is emitted after the write
/// lock has been released.
///
/// # Errors
/// Propagates session-lock errors, missing MLS state / epoch lookup failures
/// and transport send failures.
async fn on_commit_candidate_received(
    arc: &Arc<RwLock<Self>>,
    steward: &[u8],
) -> Result<(), UserError> {
    {
        let conv_name = arc.read_or_err("session")?.conversation_name.clone();
        tracing::debug!(
            conversation = %conv_name,
            steward = ?steward,
            "candidate received from peer steward"
        );
    }
    let (event, outbound) = {
        let mut s = arc.write_or_err("session")?;
        if s.handle.current_state() != ConversationState::Working {
            return Ok(());
        }
        // Do the fallible epoch lookup BEFORE switching phase: if it fails we
        // bail out with the session untouched, instead of leaving it in the
        // freezing phase with no freeze round and no `PhaseChange` emitted.
        // NOTE(review): assumes `start_freezing` does not influence the MLS
        // epoch — confirm against its implementation.
        let epoch = s.handle.expect_mls()?.current_epoch()?;
        let event = s.start_freezing();
        s.handle.conversation.ensure_freeze_round(epoch);
        let self_identity = Arc::clone(&s.self_identity);
        let app_id = Arc::clone(&s.app_id);
        let outbound = if s.handle.steward_list.is_steward(&self_identity) {
            match s.handle.create_commit_candidate(&self_identity, &app_id) {
                Ok(packets) => packets,
                Err(e) => {
                    // Deliberately best-effort: the peer's candidate is still
                    // processed even when our own cannot be built.
                    error!(
                        conversation = %s.conversation_name,
                        error = %e,
                        "own commit candidate build failed"
                    );
                    None
                }
            }
        } else {
            None
        };
        (event, outbound)
    };
    arc.read_or_err("session")?
        .emit_event(SessionEvent::PhaseChange(event));
    if let Some(message) = outbound {
        let transport = Arc::clone(arc.read_or_err("session")?.transport());
        send_packet(&transport, message)?;
    }
    Ok(())
}
/// Applies a [`ConversationSync`] received from a peer, if it is needed and
/// passes validation.
///
/// The sync is ignored (returning `Ok(())`) when a steward list is already
/// installed locally or when `validate_conversation_sync` rejects it.
///
/// # Errors
/// Propagates session-lock errors, missing MLS state, member/epoch lookup
/// failures, validation errors and errors from applying the sync.
fn on_conversation_sync(
    arc: &Arc<RwLock<Self>>,
    sync: ConversationSync,
) -> Result<(), UserError> {
    // Gather the local view under a read lock; the write lock is only taken
    // once the sync has been accepted.
    let (roster, epoch, default_score, conv_name) = {
        let session = arc.read_or_err("session")?;
        if session.handle.steward_list.current_list().is_some() {
            return Ok(());
        }
        let mls = session.handle.expect_mls()?;
        let roster = mls.members()?;
        let epoch = mls.current_epoch()?;
        let default_score = session.handle.scoring.default_score();
        (roster, epoch, default_score, session.conversation_name.clone())
    };

    let accepted = validate_conversation_sync(&conv_name, &sync, epoch, &roster, default_score)?;
    if !accepted {
        return Ok(());
    }

    let steward_count = sync.steward_members.len();
    arc.write_or_err("session")?
        .apply_conversation_sync_to_entry(&sync)?;
    info!(
        conversation = %conv_name,
        election_epoch = sync.election_epoch,
        stewards = steward_count,
        scores = sync.peer_scores.len(),
        timing = sync.timing.is_some(),
        "conversation sync applied"
    );
    Ok(())
}
/// Copies the state carried by a [`ConversationSync`] into this session.
///
/// In order: installs the steward-list configuration and the steward list
/// itself, sets the re-election retry limit, sets the scoring threshold and
/// applies the peer-score snapshot, then updates the liveness, pending-update
/// and (when present) timing configuration.
///
/// # Errors
/// Fails when the steward-list configuration cannot be built from
/// `sn_min`/`sn_max` or when installing the steward list fails.
/// NOTE(review): the configuration is installed before `install_list` runs,
/// so a failing install leaves the new configuration in place — confirm this
/// is acceptable.
fn apply_conversation_sync_to_entry(
    &mut self,
    sync: &ConversationSync,
) -> Result<(), UserError> {
    let steward_count = sync.steward_members.len();
    let list_config = {
        let mut cfg = StewardListConfig::new(sync.sn_min as usize, sync.sn_max as usize)?;
        cfg.allow_subset_candidates = sync.allow_subset_candidates;
        cfg
    };

    // Steward list.
    self.handle.steward_list.set_config(list_config);
    let _install_events = self.handle.steward_list.install_list(
        sync.election_epoch,
        &sync.steward_members,
        steward_count,
        sync.retry_round,
    )?;
    self.handle
        .steward_list
        .set_max_retries(sync.max_reelection_attempts);

    // Peer scoring.
    self.handle.scoring.set_threshold(sync.threshold_peer_score);
    let diverged = sync
        .peer_scores
        .iter()
        .map(|entry| (entry.member_id.clone(), entry.score))
        .collect();
    let _score_events = self
        .handle
        .scoring
        .apply_snapshot(&ScoreSnapshot { diverged });

    // Remaining session configuration.
    self.handle.config.liveness_criteria_yes = sync.liveness_criteria_yes;
    self.handle.config.pending_update_max_epochs = sync.pending_update_max_epochs;
    if let Some(timing) = sync.timing.as_ref() {
        self.handle.config.apply_timing(timing);
    }
    Ok(())
}
}
fn validate_conversation_sync(
conversation_name: &str,
sync: &ConversationSync,
current_epoch: u64,
members: &[Vec<u8>],
local_default_peer_score: i64,
) -> Result<bool, UserError> {
if sync.election_epoch > current_epoch {
info!(
conversation = conversation_name,
election_epoch = sync.election_epoch,
current_epoch,
"conversation sync rejected: election_epoch > current_epoch"
);
return Ok(false);
}
let members_set = member_set(members);
let any_present = sync
.steward_members
.iter()
.any(|s| members_set.contains(s.as_slice()));
let ordering_valid = StewardList::validate(
&sync.steward_members,
sync.election_epoch,
conversation_name.as_bytes(),
&sync.steward_members,
&StewardListConfig::new(sync.sn_min as usize, sync.sn_max as usize)?,
sync.retry_round,
)?;
if !(any_present && ordering_valid) {
info!(
conversation = conversation_name,
any_present,
ordering = ordering_valid,
"conversation sync rejected: invalid"
);
return Ok(false);
}
if let Some(timing) = &sync.timing
&& let Some(zero_field) = first_zero_timing_field(timing)
{
info!(
conversation = conversation_name,
field = zero_field,
"conversation sync rejected: zero-valued timing field"
);
return Ok(false);
}
if local_default_peer_score <= sync.threshold_peer_score {
info!(
conversation = conversation_name,
local_default_peer_score,
threshold_peer_score = sync.threshold_peer_score,
"conversation sync rejected: default_peer_score at or below threshold would mark new members removable on add"
);
return Ok(false);
}
Ok(true)
}
/// Returns the name of the first timing field whose value is zero, or `None`
/// when every field is non-zero.
///
/// Fields are checked in this fixed order: `commit_inactivity_duration_ms`,
/// `freeze_duration_ms`, `proposal_expiration_ms`, `consensus_timeout_ms`,
/// `recovery_inactivity_duration_ms`.
fn first_zero_timing_field(timing: &TimingConfig) -> Option<&'static str> {
    // Ordered table: the first entry flagged as zero is the one reported.
    let checks = [
        (
            "commit_inactivity_duration_ms",
            timing.commit_inactivity_duration_ms == 0,
        ),
        ("freeze_duration_ms", timing.freeze_duration_ms == 0),
        ("proposal_expiration_ms", timing.proposal_expiration_ms == 0),
        ("consensus_timeout_ms", timing.consensus_timeout_ms == 0),
        (
            "recovery_inactivity_duration_ms",
            timing.recovery_inactivity_duration_ms == 0,
        ),
    ];
    checks
        .iter()
        .find_map(|&(name, is_zero)| is_zero.then_some(name))
}
/// Unit tests for the conversation-sync validation helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protos::de_mls::messages::v1::TimingConfig;

    /// Timing configuration in which every field is non-zero.
    fn nonzero_timing() -> TimingConfig {
        TimingConfig {
            commit_inactivity_duration_ms: 60_000,
            freeze_duration_ms: 30_000,
            proposal_expiration_ms: 3_600_000,
            consensus_timeout_ms: 30_000,
            recovery_inactivity_duration_ms: 5_000,
        }
    }

    /// Sync message with a single steward (`alice`) at election epoch 0 and
    /// the given peer-score threshold.
    fn valid_sync_with(threshold: i64) -> ConversationSync {
        ConversationSync {
            steward_members: vec![b"alice".to_vec()],
            election_epoch: 0,
            sn_min: 1,
            sn_max: 5,
            allow_subset_candidates: false,
            peer_scores: vec![],
            timing: Some(nonzero_timing()),
            retry_round: 0,
            max_reelection_attempts: 1,
            liveness_criteria_yes: true,
            threshold_peer_score: threshold,
            pending_update_max_epochs: 3,
        }
    }

    #[test]
    fn nonzero_timing_passes() {
        assert!(first_zero_timing_field(&nonzero_timing()).is_none());
    }

    #[test]
    fn validate_accepts_default_above_threshold() {
        let sync = valid_sync_with(0);
        assert!(validate_conversation_sync("g", &sync, 0, &[b"alice".to_vec()], 100).unwrap());
    }

    #[test]
    fn validate_rejects_default_equal_to_threshold() {
        let sync = valid_sync_with(50);
        assert!(!validate_conversation_sync("g", &sync, 0, &[b"alice".to_vec()], 50).unwrap());
    }

    #[test]
    fn validate_rejects_default_below_threshold() {
        let sync = valid_sync_with(100);
        assert!(!validate_conversation_sync("g", &sync, 0, &[b"alice".to_vec()], 50).unwrap());
    }

    /// A sync whose election epoch lies ahead of the local epoch is rejected
    /// before any other check runs.
    #[test]
    fn validate_rejects_future_election_epoch() {
        let sync = ConversationSync {
            election_epoch: 1,
            ..valid_sync_with(0)
        };
        assert!(!validate_conversation_sync("g", &sync, 0, &[b"alice".to_vec()], 100).unwrap());
    }

    /// An otherwise acceptable sync is rejected when its timing carries a
    /// zero-valued field.
    #[test]
    fn validate_rejects_zero_timing_field() {
        let sync = ConversationSync {
            timing: Some(TimingConfig {
                freeze_duration_ms: 0,
                ..nonzero_timing()
            }),
            ..valid_sync_with(0)
        };
        assert!(!validate_conversation_sync("g", &sync, 0, &[b"alice".to_vec()], 100).unwrap());
    }

    #[test]
    fn each_zero_field_is_detected() {
        let cases = [
            (
                "commit_inactivity_duration_ms",
                TimingConfig {
                    commit_inactivity_duration_ms: 0,
                    ..nonzero_timing()
                },
            ),
            (
                "freeze_duration_ms",
                TimingConfig {
                    freeze_duration_ms: 0,
                    ..nonzero_timing()
                },
            ),
            (
                "proposal_expiration_ms",
                TimingConfig {
                    proposal_expiration_ms: 0,
                    ..nonzero_timing()
                },
            ),
            (
                "consensus_timeout_ms",
                TimingConfig {
                    consensus_timeout_ms: 0,
                    ..nonzero_timing()
                },
            ),
            (
                "recovery_inactivity_duration_ms",
                TimingConfig {
                    recovery_inactivity_duration_ms: 0,
                    ..nonzero_timing()
                },
            ),
        ];
        for (name, timing) in cases {
            assert_eq!(
                first_zero_timing_field(&timing),
                Some(name),
                "expected field {name} to be detected as zero"
            );
        }
    }

    /// When several fields are zero, the one checked first is reported.
    #[test]
    fn first_zero_field_wins_when_several_are_zero() {
        let timing = TimingConfig {
            freeze_duration_ms: 0,
            consensus_timeout_ms: 0,
            ..nonzero_timing()
        };
        assert_eq!(
            first_zero_timing_field(&timing),
            Some("freeze_duration_ms")
        );
    }
}