// rabia_counter_example/lib.rs
use async_trait::async_trait;
use rabia_core::smr::StateMachine;
use serde::{Deserialize, Serialize};

34#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
36pub enum CounterCommand {
37 Increment(i64),
39 Decrement(i64),
41 Set(i64),
43 Get,
45 Reset,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
51pub struct CounterResponse {
52 pub value: i64,
54 pub success: bool,
56 pub message: Option<String>,
58}
59
60impl CounterResponse {
61 pub fn success(value: i64) -> Self {
62 Self {
63 value,
64 success: true,
65 message: None,
66 }
67 }
68
69 pub fn error(value: i64, message: String) -> Self {
70 Self {
71 value,
72 success: false,
73 message: Some(message),
74 }
75 }
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
80pub struct CounterState {
81 pub value: i64,
83 pub operation_count: u64,
85}
86
87#[derive(Debug, Clone)]
89pub struct CounterSMR {
90 state: CounterState,
91}
92
93impl CounterSMR {
94 pub fn new() -> Self {
96 Self {
97 state: CounterState::default(),
98 }
99 }
100
101 pub fn with_value(initial_value: i64) -> Self {
103 Self {
104 state: CounterState {
105 value: initial_value,
106 operation_count: 0,
107 },
108 }
109 }
110
111 pub fn value(&self) -> i64 {
113 self.state.value
114 }
115
116 pub fn operation_count(&self) -> u64 {
118 self.state.operation_count
119 }
120}
121
122impl Default for CounterSMR {
123 fn default() -> Self {
124 Self::new()
125 }
126}
127
128#[async_trait]
129impl StateMachine for CounterSMR {
130 type Command = CounterCommand;
131 type Response = CounterResponse;
132 type State = CounterState;
133
134 async fn apply_command(&mut self, command: Self::Command) -> Self::Response {
135 self.state.operation_count += 1;
136
137 match command {
138 CounterCommand::Increment(value) => {
139 match self.state.value.checked_add(value) {
141 Some(new_value) => {
142 self.state.value = new_value;
143 CounterResponse::success(self.state.value)
144 }
145 None => CounterResponse::error(
146 self.state.value,
147 "Overflow: cannot increment counter".to_string(),
148 ),
149 }
150 }
151 CounterCommand::Decrement(value) => {
152 match self.state.value.checked_sub(value) {
154 Some(new_value) => {
155 self.state.value = new_value;
156 CounterResponse::success(self.state.value)
157 }
158 None => CounterResponse::error(
159 self.state.value,
160 "Underflow: cannot decrement counter".to_string(),
161 ),
162 }
163 }
164 CounterCommand::Set(value) => {
165 self.state.value = value;
166 CounterResponse::success(self.state.value)
167 }
168 CounterCommand::Get => {
169 CounterResponse::success(self.state.value)
171 }
172 CounterCommand::Reset => {
173 self.state.value = 0;
174 CounterResponse::success(self.state.value)
175 }
176 }
177 }
178
179 fn get_state(&self) -> Self::State {
180 self.state.clone()
181 }
182
183 fn set_state(&mut self, state: Self::State) {
184 self.state = state;
185 }
186
187 fn serialize_state(&self) -> Vec<u8> {
188 bincode::serialize(&self.state).unwrap_or_default()
189 }
190
191 fn deserialize_state(&mut self, data: &[u8]) -> Result<(), Box<dyn std::error::Error>> {
192 self.state = bincode::deserialize(data)?;
193 Ok(())
194 }
195
196 async fn apply_commands(&mut self, commands: Vec<Self::Command>) -> Vec<Self::Response> {
197 let mut responses = Vec::with_capacity(commands.len());
198 for command in commands {
199 responses.push(self.apply_command(command).await);
200 }
201 responses
202 }
203
204 fn is_deterministic(&self) -> bool {
205 true
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_counter_basic_operations() {
        let mut counter = CounterSMR::new();

        let response = counter.apply_command(CounterCommand::Increment(5)).await;
        assert!(response.success);
        assert_eq!(response.value, 5);
        assert_eq!(counter.value(), 5);

        let response = counter.apply_command(CounterCommand::Decrement(2)).await;
        assert!(response.success);
        assert_eq!(response.value, 3);
        assert_eq!(counter.value(), 3);

        let response = counter.apply_command(CounterCommand::Set(10)).await;
        assert!(response.success);
        assert_eq!(response.value, 10);
        assert_eq!(counter.value(), 10);

        let response = counter.apply_command(CounterCommand::Get).await;
        assert!(response.success);
        assert_eq!(response.value, 10);

        let response = counter.apply_command(CounterCommand::Reset).await;
        assert!(response.success);
        assert_eq!(response.value, 0);
        assert_eq!(counter.value(), 0);
    }

    #[tokio::test]
    async fn test_counter_overflow_underflow() {
        let mut counter = CounterSMR::with_value(i64::MAX);

        let response = counter.apply_command(CounterCommand::Increment(1)).await;
        assert!(!response.success);
        assert_eq!(response.value, i64::MAX);
        assert!(response.message.as_ref().unwrap().contains("Overflow"));

        counter = CounterSMR::with_value(i64::MIN);

        let response = counter.apply_command(CounterCommand::Decrement(1)).await;
        assert!(!response.success);
        assert_eq!(response.value, i64::MIN);
        assert!(response.message.as_ref().unwrap().contains("Underflow"));
    }

    /// A rejected command must not change the value but still counts as applied.
    #[tokio::test]
    async fn test_rejected_command_counts_as_operation() {
        let mut counter = CounterSMR::with_value(i64::MAX);

        let response = counter.apply_command(CounterCommand::Increment(1)).await;
        assert!(!response.success);
        assert_eq!(counter.value(), i64::MAX);
        assert_eq!(counter.operation_count(), 1);
    }

    /// Negative deltas go through the same checked arithmetic as positive ones.
    #[tokio::test]
    async fn test_counter_negative_deltas() {
        let mut counter = CounterSMR::new();

        let response = counter.apply_command(CounterCommand::Increment(-4)).await;
        assert!(response.success);
        assert_eq!(response.value, -4);

        let response = counter.apply_command(CounterCommand::Decrement(-10)).await;
        assert!(response.success);
        assert_eq!(response.value, 6);

        let mut floor = CounterSMR::with_value(i64::MIN);
        let response = floor.apply_command(CounterCommand::Increment(-1)).await;
        assert!(!response.success);
        assert_eq!(response.value, i64::MIN);
    }

    #[tokio::test]
    async fn test_counter_state_serialization() {
        let mut counter = CounterSMR::new();

        counter.apply_command(CounterCommand::Increment(42)).await;
        counter.apply_command(CounterCommand::Decrement(10)).await;

        let serialized = counter.serialize_state();
        assert!(!serialized.is_empty());

        let mut new_counter = CounterSMR::new();
        new_counter.deserialize_state(&serialized).unwrap();

        assert_eq!(new_counter.value(), 32);
        assert_eq!(new_counter.operation_count(), 2);
        assert_eq!(new_counter.get_state(), counter.get_state());
    }

    /// Truncated input must be rejected without clobbering the existing state.
    #[test]
    fn test_deserialize_invalid_data_keeps_state() {
        let mut counter = CounterSMR::with_value(7);

        assert!(counter.deserialize_state(&[]).is_err());
        assert_eq!(counter.value(), 7);
        assert_eq!(counter.operation_count(), 0);
    }

    #[test]
    fn test_get_set_state_roundtrip() {
        let mut counter = CounterSMR::new();
        let snapshot = CounterState {
            value: -3,
            operation_count: 9,
        };

        counter.set_state(snapshot.clone());
        assert_eq!(counter.get_state(), snapshot);
        assert_eq!(counter.value(), -3);
        assert_eq!(counter.operation_count(), 9);
    }

    #[tokio::test]
    async fn test_counter_multiple_commands() {
        let mut counter = CounterSMR::new();

        let commands = vec![
            CounterCommand::Increment(10),
            CounterCommand::Increment(5),
            CounterCommand::Decrement(3),
            CounterCommand::Set(100),
            CounterCommand::Get,
        ];

        let responses = counter.apply_commands(commands).await;
        assert_eq!(responses.len(), 5);
        assert!(responses.iter().all(|r| r.success));

        assert_eq!(counter.value(), 100);
        assert_eq!(counter.operation_count(), 5);

        let values: Vec<i64> = responses.iter().map(|r| r.value).collect();
        assert_eq!(values, vec![10, 15, 12, 100, 100]);
    }

    #[test]
    fn test_counter_deterministic() {
        let counter = CounterSMR::new();
        assert!(counter.is_deterministic());
    }

    #[test]
    fn test_default_matches_new() {
        assert_eq!(
            CounterSMR::default().get_state(),
            CounterSMR::new().get_state()
        );
    }
}