// rabia_core/messages.rs

use crate::state_machine::Snapshot;
use crate::{BatchId, CommandBatch, NodeId, PhaseId, StateValue};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};

6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ProtocolMessage {
8    pub id: uuid::Uuid,
9    pub from: NodeId,
10    pub to: Option<NodeId>, // None for broadcast
11    pub timestamp: u64,
12    pub message_type: MessageType,
13}
14
15impl ProtocolMessage {
16    pub fn new(from: NodeId, to: Option<NodeId>, message_type: MessageType) -> Self {
17        Self {
18            id: uuid::Uuid::new_v4(),
19            from,
20            to,
21            timestamp: std::time::SystemTime::now()
22                .duration_since(std::time::UNIX_EPOCH)
23                .unwrap()
24                .as_millis() as u64,
25            message_type,
26        }
27    }
28
29    pub fn propose(from: NodeId, proposal: ProposeMessage) -> Self {
30        Self::new(from, None, MessageType::Propose(proposal))
31    }
32
33    pub fn vote_round1(from: NodeId, to: NodeId, vote: VoteRound1Message) -> Self {
34        Self::new(from, Some(to), MessageType::VoteRound1(vote))
35    }
36
37    pub fn vote_round2(from: NodeId, to: NodeId, vote: VoteRound2Message) -> Self {
38        Self::new(from, Some(to), MessageType::VoteRound2(vote))
39    }
40
41    pub fn decision(from: NodeId, decision: DecisionMessage) -> Self {
42        Self::new(from, None, MessageType::Decision(decision))
43    }
44
45    pub fn sync_request(from: NodeId, to: NodeId, request: SyncRequestMessage) -> Self {
46        Self::new(from, Some(to), MessageType::SyncRequest(request))
47    }
48
49    pub fn sync_response(from: NodeId, to: NodeId, response: SyncResponseMessage) -> Self {
50        Self::new(from, Some(to), MessageType::SyncResponse(response))
51    }
52
53    pub fn new_batch(from: NodeId, batch: NewBatchMessage) -> Self {
54        Self::new(from, None, MessageType::NewBatch(batch))
55    }
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub enum MessageType {
60    Propose(ProposeMessage),
61    VoteRound1(VoteRound1Message),
62    VoteRound2(VoteRound2Message),
63    Decision(DecisionMessage),
64    SyncRequest(SyncRequestMessage),
65    SyncResponse(SyncResponseMessage),
66    NewBatch(NewBatchMessage),
67    HeartBeat(HeartBeatMessage),
68    QuorumNotification(QuorumNotificationMessage),
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct ProposeMessage {
73    pub phase_id: PhaseId,
74    pub batch_id: BatchId,
75    pub value: StateValue,
76    pub batch: Option<CommandBatch>,
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct VoteRound1Message {
81    pub phase_id: PhaseId,
82    pub batch_id: BatchId,
83    pub vote: StateValue,
84    pub voter_id: NodeId,
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct VoteRound2Message {
89    pub phase_id: PhaseId,
90    pub batch_id: BatchId,
91    pub vote: StateValue,
92    pub voter_id: NodeId,
93    pub round1_votes: HashMap<NodeId, StateValue>,
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct DecisionMessage {
98    pub phase_id: PhaseId,
99    pub batch_id: BatchId,
100    pub decision: StateValue,
101    pub batch: Option<CommandBatch>,
102}
103
104#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct SyncRequestMessage {
106    pub requester_phase: PhaseId,
107    pub requester_state_version: u64,
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct SyncResponseMessage {
112    pub responder_phase: PhaseId,
113    pub responder_state_version: u64,
114    pub state_snapshot: Option<Snapshot>,
115    pub pending_batches: Vec<(BatchId, CommandBatch)>,
116    pub committed_phases: Vec<(PhaseId, BatchId, StateValue)>,
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct NewBatchMessage {
121    pub batch: CommandBatch,
122    pub originator: NodeId,
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct HeartBeatMessage {
127    pub current_phase: PhaseId,
128    pub last_committed_phase: PhaseId,
129    pub active: bool,
130}
131
132#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct QuorumNotificationMessage {
134    pub has_quorum: bool,
135    pub active_nodes: Vec<NodeId>,
136}
137
138#[derive(Debug, Clone, Serialize, Deserialize)]
139pub struct PhaseData {
140    pub phase_id: PhaseId,
141    pub batch_id: Option<BatchId>,
142    pub proposed_value: Option<StateValue>,
143    pub round1_votes: HashMap<NodeId, StateValue>,
144    pub round2_votes: HashMap<NodeId, StateValue>,
145    pub decision: Option<StateValue>,
146    pub batch: Option<CommandBatch>,
147    pub timestamp: u64,
148    pub is_committed: bool,
149}
150
151impl PhaseData {
152    pub fn new(phase_id: PhaseId) -> Self {
153        Self {
154            phase_id,
155            batch_id: None,
156            proposed_value: None,
157            round1_votes: HashMap::new(),
158            round2_votes: HashMap::new(),
159            decision: None,
160            batch: None,
161            timestamp: std::time::SystemTime::now()
162                .duration_since(std::time::UNIX_EPOCH)
163                .unwrap()
164                .as_millis() as u64,
165            is_committed: false,
166        }
167    }
168
169    pub fn add_round1_vote(&mut self, voter: NodeId, vote: StateValue) {
170        self.round1_votes.insert(voter, vote);
171    }
172
173    pub fn add_round2_vote(&mut self, voter: NodeId, vote: StateValue) {
174        self.round2_votes.insert(voter, vote);
175    }
176
177    pub fn has_round1_majority(&self, quorum_size: usize) -> Option<StateValue> {
178        self.count_votes(&self.round1_votes, quorum_size)
179    }
180
181    pub fn has_round2_majority(&self, quorum_size: usize) -> Option<StateValue> {
182        self.count_votes(&self.round2_votes, quorum_size)
183    }
184
185    fn count_votes(
186        &self,
187        votes: &HashMap<NodeId, StateValue>,
188        quorum_size: usize,
189    ) -> Option<StateValue> {
190        let mut v0_count = 0;
191        let mut v1_count = 0;
192        let mut vq_count = 0;
193
194        for vote in votes.values() {
195            match vote {
196                StateValue::V0 => v0_count += 1,
197                StateValue::V1 => v1_count += 1,
198                StateValue::VQuestion => vq_count += 1,
199            }
200        }
201
202        if v0_count >= quorum_size {
203            Some(StateValue::V0)
204        } else if v1_count >= quorum_size {
205            Some(StateValue::V1)
206        } else if vq_count >= quorum_size {
207            Some(StateValue::VQuestion)
208        } else {
209            None
210        }
211    }
212
213    pub fn total_votes(&self) -> usize {
214        self.round1_votes.len().max(self.round2_votes.len())
215    }
216
217    pub fn set_decision(&mut self, decision: StateValue) {
218        self.decision = Some(decision);
219        if decision != StateValue::VQuestion {
220            self.is_committed = true;
221        }
222    }
223}
224
225#[derive(Debug, Clone, Serialize, Deserialize)]
226pub struct PendingBatch {
227    pub batch: CommandBatch,
228    pub originator: NodeId,
229    pub received_timestamp: u64,
230    pub retry_count: usize,
231}
232
233impl PendingBatch {
234    pub fn new(batch: CommandBatch, originator: NodeId) -> Self {
235        Self {
236            batch,
237            originator,
238            received_timestamp: std::time::SystemTime::now()
239                .duration_since(std::time::UNIX_EPOCH)
240                .unwrap()
241                .as_millis() as u64,
242            retry_count: 0,
243        }
244    }
245
246    pub fn increment_retry(&mut self) {
247        self.retry_count += 1;
248    }
249
250    pub fn age_millis(&self) -> u64 {
251        let now = std::time::SystemTime::now()
252            .duration_since(std::time::UNIX_EPOCH)
253            .unwrap()
254            .as_millis() as u64;
255        now.saturating_sub(self.received_timestamp)
256    }
257}