1use bytes::Bytes;
2use dashmap::DashMap;
3use parking_lot::RwLock;
4use rabia_core::{
5 messages::{PendingBatch, PhaseData, SyncResponseMessage},
6 BatchId, CommandBatch, NodeId, PhaseId, RabiaError, Result,
7};
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use tokio::sync::{mpsc, oneshot};
12
/// Shared, concurrently-accessible state of the consensus engine.
///
/// Every field is wrapped in `Arc` (atomics, `DashMap`s, or a `RwLock`), so the
/// individual pieces can be read and updated through `&self` from multiple tasks.
#[derive(Debug)]
pub struct EngineState {
    /// Phase currently in progress; incremented by `advance_phase`.
    pub current_phase: Arc<AtomicU64>,
    /// Highest phase recorded as committed; `commit_phase` only ever moves it forward.
    pub last_committed_phase: Arc<AtomicU64>,
    /// Whether the engine is active; `update_active_nodes` sets it to the quorum result.
    pub is_active: Arc<AtomicBool>,
    /// Whether enough nodes are active (see `quorum_size`); starts as `true`.
    pub has_quorum: Arc<AtomicBool>,

    /// Pending batches keyed by their `BatchId`; aged out by `cleanup_old_pending_batches`.
    pub pending_batches: Arc<DashMap<BatchId, PendingBatch>>,
    /// Per-phase data keyed by `PhaseId`; pruned by `cleanup_old_phases`.
    pub phases: Arc<DashMap<PhaseId, PhaseData>>,
    /// Most recent sync response per node (a new response from a node replaces the old one).
    pub sync_responses: Arc<DashMap<NodeId, SyncResponseMessage>>,

    /// Set of nodes currently considered active.
    pub active_nodes: Arc<RwLock<std::collections::HashSet<NodeId>>>,
    /// Minimum number of active nodes required for `has_quorum` to be `true`.
    pub quorum_size: usize,

    /// Counter bumped on every state change made through the methods on this type; starts at 1.
    pub state_version: Arc<AtomicU64>,
    /// Unix time (seconds) of the last `cleanup_old_phases` pass that removed anything; 0 if none.
    pub last_cleanup: Arc<AtomicU64>,
}
30
31impl EngineState {
32 pub fn new(quorum_size: usize) -> Self {
33 Self {
34 current_phase: Arc::new(AtomicU64::new(0)),
35 last_committed_phase: Arc::new(AtomicU64::new(0)),
36 is_active: Arc::new(AtomicBool::new(true)),
37 has_quorum: Arc::new(AtomicBool::new(true)),
38
39 pending_batches: Arc::new(DashMap::new()),
40 phases: Arc::new(DashMap::new()),
41 sync_responses: Arc::new(DashMap::new()),
42
43 active_nodes: Arc::new(RwLock::new(std::collections::HashSet::new())),
44 quorum_size,
45
46 state_version: Arc::new(AtomicU64::new(1)),
47 last_cleanup: Arc::new(AtomicU64::new(0)),
48 }
49 }
50
51 pub fn current_phase(&self) -> PhaseId {
52 PhaseId::new(self.current_phase.load(Ordering::Acquire))
53 }
54
55 pub fn last_committed_phase(&self) -> PhaseId {
56 PhaseId::new(self.last_committed_phase.load(Ordering::Acquire))
57 }
58
59 pub fn advance_phase(&self) -> PhaseId {
60 let new_phase = self.current_phase.fetch_add(1, Ordering::AcqRel) + 1;
61 self.increment_version();
62 PhaseId::new(new_phase)
63 }
64
65 pub fn commit_phase(&self, phase_id: PhaseId) -> Result<bool> {
66 let phase_value = phase_id.value();
67 let current_phase_value = self.current_phase.load(Ordering::Acquire);
68
69 if phase_value > current_phase_value {
71 return Err(RabiaError::InvalidStateTransition {
72 from: format!("current_phase={}", current_phase_value),
73 to: format!("commit_phase={}", phase_value),
74 });
75 }
76
77 let mut current = self.last_committed_phase.load(Ordering::Acquire);
78
79 while current < phase_value {
81 match self.last_committed_phase.compare_exchange_weak(
82 current,
83 phase_value,
84 Ordering::AcqRel,
85 Ordering::Acquire,
86 ) {
87 Ok(_) => {
88 self.increment_version();
89 return Ok(true);
90 }
91 Err(actual) => {
92 current = actual;
93 if current >= phase_value {
95 return Ok(false);
96 }
97 }
98 }
99 }
100
101 Ok(false)
103 }
104
105 pub fn is_active(&self) -> bool {
106 self.is_active.load(Ordering::Acquire)
107 }
108
109 pub fn set_active(&self, active: bool) {
110 if self.is_active.swap(active, Ordering::AcqRel) != active {
111 self.increment_version();
112 }
113 }
114
115 pub fn has_quorum(&self) -> bool {
116 self.has_quorum.load(Ordering::Acquire)
117 }
118
119 pub fn set_quorum(&self, has_quorum: bool) {
120 if self.has_quorum.swap(has_quorum, Ordering::AcqRel) != has_quorum {
121 self.increment_version();
122 }
123 }
124
125 pub fn get_active_nodes(&self) -> std::collections::HashSet<NodeId> {
126 self.active_nodes.read().clone()
127 }
128
129 pub fn update_active_nodes(&self, nodes: std::collections::HashSet<NodeId>) {
130 let has_quorum = nodes.len() >= self.quorum_size;
131
132 {
133 let mut active_nodes = self.active_nodes.write();
134 if *active_nodes != nodes {
135 *active_nodes = nodes;
136 self.increment_version();
137 }
138 }
139
140 self.set_quorum(has_quorum);
141 self.set_active(has_quorum);
142 }
143
144 pub fn add_pending_batch(&self, batch: CommandBatch, originator: NodeId) -> BatchId {
145 let pending = PendingBatch::new(batch, originator);
146 let batch_id = pending.batch.id;
147 self.pending_batches.insert(batch_id, pending);
148 self.increment_version();
149 batch_id
150 }
151
152 pub fn remove_pending_batch(&self, batch_id: &BatchId) -> Option<PendingBatch> {
153 let result = self.pending_batches.remove(batch_id).map(|(_, v)| v);
154 if result.is_some() {
155 self.increment_version();
156 }
157 result
158 }
159
160 pub fn get_pending_batch(&self, batch_id: &BatchId) -> Option<PendingBatch> {
161 self.pending_batches
162 .get(batch_id)
163 .map(|entry| entry.value().clone())
164 }
165
166 pub fn get_or_create_phase(&self, phase_id: PhaseId) -> PhaseData {
167 self.phases
168 .entry(phase_id)
169 .or_insert_with(|| {
170 self.increment_version();
171 PhaseData::new(phase_id)
172 })
173 .clone()
174 }
175
176 pub fn update_phase<F>(&self, phase_id: PhaseId, update_fn: F) -> Result<()>
177 where
178 F: FnOnce(&mut PhaseData),
179 {
180 if let Some(mut entry) = self.phases.get_mut(&phase_id) {
181 update_fn(&mut entry);
182 self.increment_version();
183 }
184 Ok(())
185 }
186
187 pub fn get_phase(&self, phase_id: &PhaseId) -> Option<PhaseData> {
188 self.phases.get(phase_id).map(|entry| entry.value().clone())
189 }
190
191 pub fn cleanup_old_phases(&self, max_phase_history: usize) -> usize {
192 let current_phase = self.current_phase();
193 let cutoff_phase = if current_phase.value() > max_phase_history as u64 {
194 PhaseId::new(current_phase.value() - max_phase_history as u64)
195 } else {
196 PhaseId::new(0)
197 };
198
199 let mut removed_count = 0;
200 self.phases.retain(|&phase_id, _| {
201 let should_keep = phase_id >= cutoff_phase;
202 if !should_keep {
203 removed_count += 1;
204 }
205 should_keep
206 });
207
208 if removed_count > 0 {
209 self.increment_version();
210 self.last_cleanup.store(
211 std::time::SystemTime::now()
212 .duration_since(std::time::UNIX_EPOCH)
213 .unwrap()
214 .as_secs(),
215 Ordering::Release,
216 );
217 }
218
219 removed_count
220 }
221
222 pub fn cleanup_old_pending_batches(&self, max_age_secs: u64) -> usize {
223 let now = std::time::SystemTime::now()
224 .duration_since(std::time::UNIX_EPOCH)
225 .unwrap()
226 .as_millis() as u64;
227 let cutoff = now.saturating_sub(max_age_secs * 1000);
228
229 let mut removed_count = 0;
230 self.pending_batches.retain(|_, pending| {
231 let should_keep = pending.received_timestamp >= cutoff;
232 if !should_keep {
233 removed_count += 1;
234 }
235 should_keep
236 });
237
238 if removed_count > 0 {
239 self.increment_version();
240 }
241
242 removed_count
243 }
244
245 pub fn get_state_version(&self) -> u64 {
246 self.state_version.load(Ordering::Acquire)
247 }
248
249 fn increment_version(&self) {
250 self.state_version.fetch_add(1, Ordering::AcqRel);
251 }
252
253 pub fn add_sync_response(&self, node_id: NodeId, response: SyncResponseMessage) {
254 self.sync_responses.insert(node_id, response);
255 }
256
257 pub fn get_sync_responses(&self) -> HashMap<NodeId, SyncResponseMessage> {
258 self.sync_responses
259 .iter()
260 .map(|entry| (*entry.key(), entry.value().clone()))
261 .collect()
262 }
263
264 pub fn clear_sync_responses(&self) {
265 self.sync_responses.clear();
266 }
267
268 pub fn get_statistics(&self) -> EngineStatistics {
269 EngineStatistics {
270 current_phase: self.current_phase(),
271 last_committed_phase: self.last_committed_phase(),
272 pending_batches_count: self.pending_batches.len(),
273 phases_count: self.phases.len(),
274 active_nodes_count: self.active_nodes.read().len(),
275 has_quorum: self.has_quorum(),
276 is_active: self.is_active(),
277 state_version: self.get_state_version(),
278 }
279 }
280}
281
/// Counters describing an `EngineState`, as produced by `EngineState::get_statistics`.
///
/// Fields are read one at a time, so they are not guaranteed to be mutually consistent.
#[derive(Debug, Clone)]
pub struct EngineStatistics {
    /// Phase currently in progress.
    pub current_phase: PhaseId,
    /// Highest phase recorded as committed.
    pub last_committed_phase: PhaseId,
    /// Number of entries in the pending-batch map.
    pub pending_batches_count: usize,
    /// Number of phases currently retained.
    pub phases_count: usize,
    /// Size of the active-node set.
    pub active_nodes_count: usize,
    /// Whether the active-node set met the quorum size.
    pub has_quorum: bool,
    /// Whether the engine is active.
    pub is_active: bool,
    /// State version counter at the time of reading.
    pub state_version: u64,
}
293
/// A batch submitted to the engine together with the channel for its reply.
#[derive(Debug)]
pub struct CommandRequest {
    /// The commands to process.
    pub batch: CommandBatch,
    /// One-shot channel on which the processing result is sent back.
    /// NOTE(review): presumably one `Bytes` result per command in `batch` — confirm with the engine loop.
    pub response_tx: oneshot::Sender<Result<Vec<Bytes>>>,
}
299
/// Control messages sent to the engine over an [`EngineCommandSender`].
///
/// How each variant is handled is defined by the consumer of the channel, which is
/// not visible here; the descriptions below reflect the variant names and payloads.
#[derive(Debug)]
pub enum EngineCommand {
    /// Submit a batch for processing; the result goes back on the request's channel.
    ProcessBatch(CommandRequest),
    /// Ask the engine to shut down.
    Shutdown,
    /// Ask the engine to move to the next phase.
    ForcePhaseAdvance,
    /// Ask the engine to start a synchronization round.
    TriggerSync,
    /// Request current engine statistics on the supplied one-shot channel.
    GetStatistics(oneshot::Sender<EngineStatistics>),
}
308
/// Sending half of the unbounded channel that carries [`EngineCommand`]s to the engine.
pub type EngineCommandSender = mpsc::UnboundedSender<EngineCommand>;
/// Receiving half of the unbounded channel that carries [`EngineCommand`]s to the engine.
pub type EngineCommandReceiver = mpsc::UnboundedReceiver<EngineCommand>;