Skip to main content

rabia_counter_example/
lib.rs

1//! # Counter SMR Example
2//!
3//! A simple counter implementation that demonstrates how to create a State Machine
4//! Replication (SMR) application using the Rabia consensus protocol.
5//!
6//! This example shows the minimal implementation needed to create an SMR application:
7//! - Command types for operations
8//! - Response types for results
9//! - State type for the machine state
10//! - StateMachine trait implementation
11//!
12//! ## Example Usage
13//!
14//! ```rust
15//! use rabia_counter_example::{CounterSMR, CounterCommand};
16//! use rabia_core::smr::StateMachine;
17//!
18//! #[tokio::main]
19//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
20//!     let mut counter = CounterSMR::new();
21//!     
22//!     let command = CounterCommand::Increment(5);
23//!     let response = counter.apply_command(command).await;
24//!     println!("Counter value: {}", response.value);
25//!     
26//!     Ok(())
27//! }
28//! ```
29
30use async_trait::async_trait;
31use rabia_core::smr::StateMachine;
32use serde::{Deserialize, Serialize};
33
34/// Commands that can be applied to the counter
35#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
36pub enum CounterCommand {
37    /// Increment the counter by the given value
38    Increment(i64),
39    /// Decrement the counter by the given value
40    Decrement(i64),
41    /// Set the counter to a specific value
42    Set(i64),
43    /// Get the current value (read-only operation)
44    Get,
45    /// Reset the counter to zero
46    Reset,
47}
48
49/// Response from counter operations
50#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
51pub struct CounterResponse {
52    /// The current value after the operation
53    pub value: i64,
54    /// Whether the operation was successful
55    pub success: bool,
56    /// Optional message (e.g., for errors)
57    pub message: Option<String>,
58}
59
60impl CounterResponse {
61    pub fn success(value: i64) -> Self {
62        Self {
63            value,
64            success: true,
65            message: None,
66        }
67    }
68
69    pub fn error(value: i64, message: String) -> Self {
70        Self {
71            value,
72            success: false,
73            message: Some(message),
74        }
75    }
76}
77
78/// State of the counter state machine
79#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
80pub struct CounterState {
81    /// The current counter value
82    pub value: i64,
83    /// Total number of operations performed
84    pub operation_count: u64,
85}
86
87/// Simple counter state machine implementation
88#[derive(Debug, Clone)]
89pub struct CounterSMR {
90    state: CounterState,
91}
92
93impl CounterSMR {
94    /// Create a new counter state machine
95    pub fn new() -> Self {
96        Self {
97            state: CounterState::default(),
98        }
99    }
100
101    /// Create a new counter with an initial value
102    pub fn with_value(initial_value: i64) -> Self {
103        Self {
104            state: CounterState {
105                value: initial_value,
106                operation_count: 0,
107            },
108        }
109    }
110
111    /// Get the current counter value
112    pub fn value(&self) -> i64 {
113        self.state.value
114    }
115
116    /// Get the total number of operations performed
117    pub fn operation_count(&self) -> u64 {
118        self.state.operation_count
119    }
120}
121
122impl Default for CounterSMR {
123    fn default() -> Self {
124        Self::new()
125    }
126}
127
128#[async_trait]
129impl StateMachine for CounterSMR {
130    type Command = CounterCommand;
131    type Response = CounterResponse;
132    type State = CounterState;
133
134    async fn apply_command(&mut self, command: Self::Command) -> Self::Response {
135        self.state.operation_count += 1;
136
137        match command {
138            CounterCommand::Increment(value) => {
139                // Check for overflow
140                match self.state.value.checked_add(value) {
141                    Some(new_value) => {
142                        self.state.value = new_value;
143                        CounterResponse::success(self.state.value)
144                    }
145                    None => CounterResponse::error(
146                        self.state.value,
147                        "Overflow: cannot increment counter".to_string(),
148                    ),
149                }
150            }
151            CounterCommand::Decrement(value) => {
152                // Check for underflow
153                match self.state.value.checked_sub(value) {
154                    Some(new_value) => {
155                        self.state.value = new_value;
156                        CounterResponse::success(self.state.value)
157                    }
158                    None => CounterResponse::error(
159                        self.state.value,
160                        "Underflow: cannot decrement counter".to_string(),
161                    ),
162                }
163            }
164            CounterCommand::Set(value) => {
165                self.state.value = value;
166                CounterResponse::success(self.state.value)
167            }
168            CounterCommand::Get => {
169                // Read-only operation, don't change state
170                CounterResponse::success(self.state.value)
171            }
172            CounterCommand::Reset => {
173                self.state.value = 0;
174                CounterResponse::success(self.state.value)
175            }
176        }
177    }
178
179    fn get_state(&self) -> Self::State {
180        self.state.clone()
181    }
182
183    fn set_state(&mut self, state: Self::State) {
184        self.state = state;
185    }
186
187    fn serialize_state(&self) -> Vec<u8> {
188        bincode::serialize(&self.state).unwrap_or_default()
189    }
190
191    fn deserialize_state(&mut self, data: &[u8]) -> Result<(), Box<dyn std::error::Error>> {
192        self.state = bincode::deserialize(data)?;
193        Ok(())
194    }
195
196    async fn apply_commands(&mut self, commands: Vec<Self::Command>) -> Vec<Self::Response> {
197        let mut responses = Vec::with_capacity(commands.len());
198        for command in commands {
199            responses.push(self.apply_command(command).await);
200        }
201        responses
202    }
203
204    fn is_deterministic(&self) -> bool {
205        true
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks through every command once and checks both the response and the
    /// value exposed by the counter after each step.
    #[tokio::test]
    async fn test_counter_basic_operations() {
        let mut counter = CounterSMR::new();

        // (command, value expected in the response and in the counter afterwards)
        let steps = vec![
            (CounterCommand::Increment(5), 5),
            (CounterCommand::Decrement(2), 3),
            (CounterCommand::Set(10), 10),
            (CounterCommand::Get, 10),
            (CounterCommand::Reset, 0),
        ];

        for (command, expected) in steps {
            let reply = counter.apply_command(command.clone()).await;
            assert!(reply.success, "{:?} should succeed", command);
            assert_eq!(reply.value, expected, "response value after {:?}", command);
            assert_eq!(counter.value(), expected, "counter value after {:?}", command);
        }
    }

    /// Arithmetic past either end of the `i64` range is rejected and leaves
    /// the value where it was.
    #[tokio::test]
    async fn test_counter_overflow_underflow() {
        // Overflow at the upper bound.
        let mut at_max = CounterSMR::with_value(i64::MAX);
        let overflow = at_max.apply_command(CounterCommand::Increment(1)).await;
        assert!(!overflow.success);
        assert_eq!(overflow.value, i64::MAX);
        assert!(matches!(
            overflow.message.as_deref(),
            Some(text) if text.contains("Overflow")
        ));

        // Underflow at the lower bound.
        let mut at_min = CounterSMR::with_value(i64::MIN);
        let underflow = at_min.apply_command(CounterCommand::Decrement(1)).await;
        assert!(!underflow.success);
        assert_eq!(underflow.value, i64::MIN);
        assert!(matches!(
            underflow.message.as_deref(),
            Some(text) if text.contains("Underflow")
        ));
    }

    /// A serialized state can be loaded into a fresh counter and yields the
    /// same value and operation count.
    #[tokio::test]
    async fn test_counter_state_serialization() {
        let mut source = CounterSMR::new();
        for command in vec![CounterCommand::Increment(42), CounterCommand::Decrement(10)] {
            source.apply_command(command).await;
        }

        let bytes = source.serialize_state();
        assert!(!bytes.is_empty());

        let mut restored = CounterSMR::new();
        restored.deserialize_state(&bytes).unwrap();

        assert_eq!(restored.value(), 32);
        assert_eq!(restored.operation_count(), 2);
        assert_eq!(restored.get_state(), source.get_state());
    }

    /// A batch is applied in order and produces one response per command.
    #[tokio::test]
    async fn test_counter_multiple_commands() {
        let mut counter = CounterSMR::new();

        let batch = vec![
            CounterCommand::Increment(10),
            CounterCommand::Increment(5),
            CounterCommand::Decrement(3),
            CounterCommand::Set(100),
            CounterCommand::Get,
        ];

        let replies = counter.apply_commands(batch).await;
        assert_eq!(replies.len(), 5);

        // All operations should succeed.
        assert!(replies.iter().all(|reply| reply.success));

        // Final state of the counter.
        assert_eq!(counter.value(), 100);
        assert_eq!(counter.operation_count(), 5);

        // Value reported by each step, in order:
        // Increment 10, Increment 5, Decrement 3, Set 100, Get.
        let reported: Vec<i64> = replies.iter().map(|reply| reply.value).collect();
        assert_eq!(reported, vec![10, 15, 12, 100, 100]);
    }

    /// The counter reports itself as deterministic.
    #[test]
    fn test_counter_deterministic() {
        assert!(CounterSMR::new().is_deterministic());
    }
}
324}