rabia_core/
state_machine.rs

1use crate::{Command, Result};
2use async_trait::async_trait;
3use bytes::Bytes;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct Snapshot {
8    pub version: u64,
9    pub data: Bytes,
10    pub checksum: u32,
11}
12
13impl Snapshot {
14    pub fn new(version: u64, data: impl Into<Bytes>) -> Self {
15        let data = data.into();
16        let checksum = crc32fast::hash(&data);
17        Self {
18            version,
19            data,
20            checksum,
21        }
22    }
23
24    pub fn verify_checksum(&self) -> bool {
25        crc32fast::hash(&self.data) == self.checksum
26    }
27}
28
29#[async_trait]
30pub trait StateMachine: Send + Sync {
31    type State: Clone + Send + Sync;
32
33    async fn apply_command(&mut self, command: &Command) -> Result<Bytes>;
34
35    async fn apply_commands(&mut self, commands: &[Command]) -> Result<Vec<Bytes>> {
36        let mut results = Vec::with_capacity(commands.len());
37        for command in commands {
38            results.push(self.apply_command(command).await?);
39        }
40        Ok(results)
41    }
42
43    async fn create_snapshot(&self) -> Result<Snapshot>;
44
45    async fn restore_snapshot(&mut self, snapshot: &Snapshot) -> Result<()>;
46
47    async fn get_state(&self) -> Self::State;
48
49    fn is_deterministic(&self) -> bool {
50        true
51    }
52}
53
54#[derive(Debug, Clone)]
55pub struct InMemoryStateMachine {
56    pub state: std::collections::HashMap<String, Bytes>,
57    pub version: u64,
58}
59
60impl InMemoryStateMachine {
61    pub fn new() -> Self {
62        Self {
63            state: std::collections::HashMap::new(),
64            version: 0,
65        }
66    }
67}
68
69impl Default for InMemoryStateMachine {
70    fn default() -> Self {
71        Self::new()
72    }
73}
74
75#[async_trait]
76impl StateMachine for InMemoryStateMachine {
77    type State = std::collections::HashMap<String, Bytes>;
78
79    async fn apply_command(&mut self, command: &Command) -> Result<Bytes> {
80        let command_str = String::from_utf8_lossy(&command.data);
81        let parts: Vec<String> = command_str
82            .split_whitespace()
83            .map(|s| s.to_string())
84            .collect();
85
86        if parts.is_empty() {
87            return Ok(Bytes::from("ERROR: Empty command"));
88        }
89
90        match parts[0].as_str() {
91            "SET" if parts.len() == 3 => {
92                let key = parts[1].clone();
93                let value = Bytes::from(parts[2].clone());
94                self.state.insert(key, value.clone());
95                self.version += 1;
96                Ok(Bytes::from("OK"))
97            }
98            "GET" if parts.len() == 2 => {
99                let key = &parts[1];
100                match self.state.get(key) {
101                    Some(value) => Ok(value.clone()),
102                    None => Ok(Bytes::from("NOT_FOUND")),
103                }
104            }
105            "DEL" if parts.len() == 2 => {
106                let key = &parts[1];
107                match self.state.remove(key) {
108                    Some(_) => {
109                        self.version += 1;
110                        Ok(Bytes::from("OK"))
111                    }
112                    None => Ok(Bytes::from("NOT_FOUND")),
113                }
114            }
115            _ => Ok(Bytes::from("ERROR: Invalid command")),
116        }
117    }
118
119    async fn create_snapshot(&self) -> Result<Snapshot> {
120        let serialized = serde_json::to_vec(&self.state)?;
121        Ok(Snapshot::new(self.version, serialized))
122    }
123
124    async fn restore_snapshot(&mut self, snapshot: &Snapshot) -> Result<()> {
125        if !snapshot.verify_checksum() {
126            return Err(crate::RabiaError::ChecksumMismatch {
127                expected: snapshot.checksum,
128                actual: crc32fast::hash(&snapshot.data),
129            });
130        }
131
132        self.state = serde_json::from_slice(&snapshot.data)?;
133        self.version = snapshot.version;
134        Ok(())
135    }
136
137    async fn get_state(&self) -> Self::State {
138        self.state.clone()
139    }
140}