// rabia_engine/state.rs

use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;

use bytes::Bytes;
use dashmap::DashMap;
use parking_lot::RwLock;
use rabia_core::{
    messages::{PendingBatch, PhaseData, SyncResponseMessage},
    BatchId, CommandBatch, NodeId, PhaseId, RabiaError, Result,
};
use tokio::sync::{mpsc, oneshot};

13#[derive(Debug)]
14pub struct EngineState {
15    pub current_phase: Arc<AtomicU64>,
16    pub last_committed_phase: Arc<AtomicU64>,
17    pub is_active: Arc<AtomicBool>,
18    pub has_quorum: Arc<AtomicBool>,
19
20    pub pending_batches: Arc<DashMap<BatchId, PendingBatch>>,
21    pub phases: Arc<DashMap<PhaseId, PhaseData>>,
22    pub sync_responses: Arc<DashMap<NodeId, SyncResponseMessage>>,
23
24    pub active_nodes: Arc<RwLock<std::collections::HashSet<NodeId>>>,
25    pub quorum_size: usize,
26
27    pub state_version: Arc<AtomicU64>,
28    pub last_cleanup: Arc<AtomicU64>,
29}
30
31impl EngineState {
32    pub fn new(quorum_size: usize) -> Self {
33        Self {
34            current_phase: Arc::new(AtomicU64::new(0)),
35            last_committed_phase: Arc::new(AtomicU64::new(0)),
36            is_active: Arc::new(AtomicBool::new(true)),
37            has_quorum: Arc::new(AtomicBool::new(true)),
38
39            pending_batches: Arc::new(DashMap::new()),
40            phases: Arc::new(DashMap::new()),
41            sync_responses: Arc::new(DashMap::new()),
42
43            active_nodes: Arc::new(RwLock::new(std::collections::HashSet::new())),
44            quorum_size,
45
46            state_version: Arc::new(AtomicU64::new(1)),
47            last_cleanup: Arc::new(AtomicU64::new(0)),
48        }
49    }
50
51    pub fn current_phase(&self) -> PhaseId {
52        PhaseId::new(self.current_phase.load(Ordering::Acquire))
53    }
54
55    pub fn last_committed_phase(&self) -> PhaseId {
56        PhaseId::new(self.last_committed_phase.load(Ordering::Acquire))
57    }
58
59    pub fn advance_phase(&self) -> PhaseId {
60        let new_phase = self.current_phase.fetch_add(1, Ordering::AcqRel) + 1;
61        self.increment_version();
62        PhaseId::new(new_phase)
63    }
64
65    pub fn commit_phase(&self, phase_id: PhaseId) -> Result<bool> {
66        let phase_value = phase_id.value();
67        let current_phase_value = self.current_phase.load(Ordering::Acquire);
68
69        // Validate phase ordering - can only commit phases <= current phase
70        if phase_value > current_phase_value {
71            return Err(RabiaError::InvalidStateTransition {
72                from: format!("current_phase={}", current_phase_value),
73                to: format!("commit_phase={}", phase_value),
74            });
75        }
76
77        let mut current = self.last_committed_phase.load(Ordering::Acquire);
78
79        // Only allow monotonic increases in committed phase
80        while current < phase_value {
81            match self.last_committed_phase.compare_exchange_weak(
82                current,
83                phase_value,
84                Ordering::AcqRel,
85                Ordering::Acquire,
86            ) {
87                Ok(_) => {
88                    self.increment_version();
89                    return Ok(true);
90                }
91                Err(actual) => {
92                    current = actual;
93                    // If someone else committed a higher phase, we're done
94                    if current >= phase_value {
95                        return Ok(false);
96                    }
97                }
98            }
99        }
100
101        // Phase was already committed or higher phase was committed
102        Ok(false)
103    }
104
105    pub fn is_active(&self) -> bool {
106        self.is_active.load(Ordering::Acquire)
107    }
108
109    pub fn set_active(&self, active: bool) {
110        if self.is_active.swap(active, Ordering::AcqRel) != active {
111            self.increment_version();
112        }
113    }
114
115    pub fn has_quorum(&self) -> bool {
116        self.has_quorum.load(Ordering::Acquire)
117    }
118
119    pub fn set_quorum(&self, has_quorum: bool) {
120        if self.has_quorum.swap(has_quorum, Ordering::AcqRel) != has_quorum {
121            self.increment_version();
122        }
123    }
124
125    pub fn get_active_nodes(&self) -> std::collections::HashSet<NodeId> {
126        self.active_nodes.read().clone()
127    }
128
129    pub fn update_active_nodes(&self, nodes: std::collections::HashSet<NodeId>) {
130        let has_quorum = nodes.len() >= self.quorum_size;
131
132        {
133            let mut active_nodes = self.active_nodes.write();
134            if *active_nodes != nodes {
135                *active_nodes = nodes;
136                self.increment_version();
137            }
138        }
139
140        self.set_quorum(has_quorum);
141        self.set_active(has_quorum);
142    }
143
144    pub fn add_pending_batch(&self, batch: CommandBatch, originator: NodeId) -> BatchId {
145        let pending = PendingBatch::new(batch, originator);
146        let batch_id = pending.batch.id;
147        self.pending_batches.insert(batch_id, pending);
148        self.increment_version();
149        batch_id
150    }
151
152    pub fn remove_pending_batch(&self, batch_id: &BatchId) -> Option<PendingBatch> {
153        let result = self.pending_batches.remove(batch_id).map(|(_, v)| v);
154        if result.is_some() {
155            self.increment_version();
156        }
157        result
158    }
159
160    pub fn get_pending_batch(&self, batch_id: &BatchId) -> Option<PendingBatch> {
161        self.pending_batches
162            .get(batch_id)
163            .map(|entry| entry.value().clone())
164    }
165
166    pub fn get_or_create_phase(&self, phase_id: PhaseId) -> PhaseData {
167        self.phases
168            .entry(phase_id)
169            .or_insert_with(|| {
170                self.increment_version();
171                PhaseData::new(phase_id)
172            })
173            .clone()
174    }
175
176    pub fn update_phase<F>(&self, phase_id: PhaseId, update_fn: F) -> Result<()>
177    where
178        F: FnOnce(&mut PhaseData),
179    {
180        if let Some(mut entry) = self.phases.get_mut(&phase_id) {
181            update_fn(&mut entry);
182            self.increment_version();
183        }
184        Ok(())
185    }
186
187    pub fn get_phase(&self, phase_id: &PhaseId) -> Option<PhaseData> {
188        self.phases.get(phase_id).map(|entry| entry.value().clone())
189    }
190
191    pub fn cleanup_old_phases(&self, max_phase_history: usize) -> usize {
192        let current_phase = self.current_phase();
193        let cutoff_phase = if current_phase.value() > max_phase_history as u64 {
194            PhaseId::new(current_phase.value() - max_phase_history as u64)
195        } else {
196            PhaseId::new(0)
197        };
198
199        let mut removed_count = 0;
200        self.phases.retain(|&phase_id, _| {
201            let should_keep = phase_id >= cutoff_phase;
202            if !should_keep {
203                removed_count += 1;
204            }
205            should_keep
206        });
207
208        if removed_count > 0 {
209            self.increment_version();
210            self.last_cleanup.store(
211                std::time::SystemTime::now()
212                    .duration_since(std::time::UNIX_EPOCH)
213                    .unwrap()
214                    .as_secs(),
215                Ordering::Release,
216            );
217        }
218
219        removed_count
220    }
221
222    pub fn cleanup_old_pending_batches(&self, max_age_secs: u64) -> usize {
223        let now = std::time::SystemTime::now()
224            .duration_since(std::time::UNIX_EPOCH)
225            .unwrap()
226            .as_millis() as u64;
227        let cutoff = now.saturating_sub(max_age_secs * 1000);
228
229        let mut removed_count = 0;
230        self.pending_batches.retain(|_, pending| {
231            let should_keep = pending.received_timestamp >= cutoff;
232            if !should_keep {
233                removed_count += 1;
234            }
235            should_keep
236        });
237
238        if removed_count > 0 {
239            self.increment_version();
240        }
241
242        removed_count
243    }
244
245    pub fn get_state_version(&self) -> u64 {
246        self.state_version.load(Ordering::Acquire)
247    }
248
249    fn increment_version(&self) {
250        self.state_version.fetch_add(1, Ordering::AcqRel);
251    }
252
253    pub fn add_sync_response(&self, node_id: NodeId, response: SyncResponseMessage) {
254        self.sync_responses.insert(node_id, response);
255    }
256
257    pub fn get_sync_responses(&self) -> HashMap<NodeId, SyncResponseMessage> {
258        self.sync_responses
259            .iter()
260            .map(|entry| (*entry.key(), entry.value().clone()))
261            .collect()
262    }
263
264    pub fn clear_sync_responses(&self) {
265        self.sync_responses.clear();
266    }
267
268    pub fn get_statistics(&self) -> EngineStatistics {
269        EngineStatistics {
270            current_phase: self.current_phase(),
271            last_committed_phase: self.last_committed_phase(),
272            pending_batches_count: self.pending_batches.len(),
273            phases_count: self.phases.len(),
274            active_nodes_count: self.active_nodes.read().len(),
275            has_quorum: self.has_quorum(),
276            is_active: self.is_active(),
277            state_version: self.get_state_version(),
278        }
279    }
280}
281
282#[derive(Debug, Clone)]
283pub struct EngineStatistics {
284    pub current_phase: PhaseId,
285    pub last_committed_phase: PhaseId,
286    pub pending_batches_count: usize,
287    pub phases_count: usize,
288    pub active_nodes_count: usize,
289    pub has_quorum: bool,
290    pub is_active: bool,
291    pub state_version: u64,
292}
293
294#[derive(Debug)]
295pub struct CommandRequest {
296    pub batch: CommandBatch,
297    pub response_tx: oneshot::Sender<Result<Vec<Bytes>>>,
298}
299
300#[derive(Debug)]
301pub enum EngineCommand {
302    ProcessBatch(CommandRequest),
303    Shutdown,
304    ForcePhaseAdvance,
305    TriggerSync,
306    GetStatistics(oneshot::Sender<EngineStatistics>),
307}
308
309pub type EngineCommandSender = mpsc::UnboundedSender<EngineCommand>;
310pub type EngineCommandReceiver = mpsc::UnboundedReceiver<EngineCommand>;