// rabia_core/state_machine.rs
use std::collections::BTreeMap;

use async_trait::async_trait;
use bytes::Bytes;
use serde::{Deserialize, Serialize};

use crate::{Command, Result};

6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct Snapshot {
8 pub version: u64,
9 pub data: Bytes,
10 pub checksum: u32,
11}
12
13impl Snapshot {
14 pub fn new(version: u64, data: impl Into<Bytes>) -> Self {
15 let data = data.into();
16 let checksum = crc32fast::hash(&data);
17 Self {
18 version,
19 data,
20 checksum,
21 }
22 }
23
24 pub fn verify_checksum(&self) -> bool {
25 crc32fast::hash(&self.data) == self.checksum
26 }
27}
28
29#[async_trait]
30pub trait StateMachine: Send + Sync {
31 type State: Clone + Send + Sync;
32
33 async fn apply_command(&mut self, command: &Command) -> Result<Bytes>;
34
35 async fn apply_commands(&mut self, commands: &[Command]) -> Result<Vec<Bytes>> {
36 let mut results = Vec::with_capacity(commands.len());
37 for command in commands {
38 results.push(self.apply_command(command).await?);
39 }
40 Ok(results)
41 }
42
43 async fn create_snapshot(&self) -> Result<Snapshot>;
44
45 async fn restore_snapshot(&mut self, snapshot: &Snapshot) -> Result<()>;
46
47 async fn get_state(&self) -> Self::State;
48
49 fn is_deterministic(&self) -> bool {
50 true
51 }
52}
53
54#[derive(Debug, Clone)]
55pub struct InMemoryStateMachine {
56 pub state: std::collections::HashMap<String, Bytes>,
57 pub version: u64,
58}
59
60impl InMemoryStateMachine {
61 pub fn new() -> Self {
62 Self {
63 state: std::collections::HashMap::new(),
64 version: 0,
65 }
66 }
67}
68
69impl Default for InMemoryStateMachine {
70 fn default() -> Self {
71 Self::new()
72 }
73}
74
75#[async_trait]
76impl StateMachine for InMemoryStateMachine {
77 type State = std::collections::HashMap<String, Bytes>;
78
79 async fn apply_command(&mut self, command: &Command) -> Result<Bytes> {
80 let command_str = String::from_utf8_lossy(&command.data);
81 let parts: Vec<String> = command_str
82 .split_whitespace()
83 .map(|s| s.to_string())
84 .collect();
85
86 if parts.is_empty() {
87 return Ok(Bytes::from("ERROR: Empty command"));
88 }
89
90 match parts[0].as_str() {
91 "SET" if parts.len() == 3 => {
92 let key = parts[1].clone();
93 let value = Bytes::from(parts[2].clone());
94 self.state.insert(key, value.clone());
95 self.version += 1;
96 Ok(Bytes::from("OK"))
97 }
98 "GET" if parts.len() == 2 => {
99 let key = &parts[1];
100 match self.state.get(key) {
101 Some(value) => Ok(value.clone()),
102 None => Ok(Bytes::from("NOT_FOUND")),
103 }
104 }
105 "DEL" if parts.len() == 2 => {
106 let key = &parts[1];
107 match self.state.remove(key) {
108 Some(_) => {
109 self.version += 1;
110 Ok(Bytes::from("OK"))
111 }
112 None => Ok(Bytes::from("NOT_FOUND")),
113 }
114 }
115 _ => Ok(Bytes::from("ERROR: Invalid command")),
116 }
117 }
118
119 async fn create_snapshot(&self) -> Result<Snapshot> {
120 let serialized = serde_json::to_vec(&self.state)?;
121 Ok(Snapshot::new(self.version, serialized))
122 }
123
124 async fn restore_snapshot(&mut self, snapshot: &Snapshot) -> Result<()> {
125 if !snapshot.verify_checksum() {
126 return Err(crate::RabiaError::ChecksumMismatch {
127 expected: snapshot.checksum,
128 actual: crc32fast::hash(&snapshot.data),
129 });
130 }
131
132 self.state = serde_json::from_slice(&snapshot.data)?;
133 self.version = snapshot.version;
134 Ok(())
135 }
136
137 async fn get_state(&self) -> Self::State {
138 self.state.clone()
139 }
140}