1use std::sync::Arc;
2
3use prost::Message;
4use tempfile::TempDir;
5use tonic::async_trait;
6
7use crate::proto::client::write_command::Delete;
8use crate::proto::client::write_command::Insert;
9use crate::proto::client::write_command::Operation;
10use crate::proto::client::WriteCommand;
11use crate::proto::common::entry_payload::Payload;
12use crate::proto::common::Entry;
13use crate::proto::common::EntryPayload;
14use crate::proto::common::LogId;
15use crate::proto::storage::SnapshotMetadata;
16use crate::storage::StateMachine;
17use crate::Error;
18
/// Namespace type for a reusable suite of [`StateMachine`] tests.
///
/// The type carries no data; the test cases are associated functions, with
/// [`StateMachineTestSuite::run_all_tests`] as the public entry point.
pub struct StateMachineTestSuite;
26
/// Factory that supplies the state machine instances exercised by
/// [`StateMachineTestSuite`].
#[async_trait]
pub trait StateMachineBuilder: Send + Sync {
    /// Produces a state machine for one test case.
    ///
    /// `StateMachineTestSuite::run_all_tests` calls this once per test case.
    /// NOTE(review): most test cases appear to assume the returned instance
    /// starts with empty state — confirm that implementations guarantee this.
    async fn build(&self) -> Result<Arc<dyn StateMachine>, Error>;

    /// Cleanup hook invoked by `StateMachineTestSuite::run_all_tests` after
    /// the test cases have run; what it releases is up to the implementation.
    async fn cleanup(&self) -> Result<(), Error>;
}
36
37impl StateMachineTestSuite {
38 pub async fn run_all_tests<B: StateMachineBuilder>(builder: B) -> Result<(), Error> {
40 Self::test_start_stop(builder.build().await?).await?;
41 Self::test_basic_kv_operations(builder.build().await?).await?;
42 Self::test_apply_chunk_functionality(builder.build().await?).await?;
43 Self::test_last_applied_detection(builder.build().await?).await?;
44 Self::test_snapshot_operations(builder.build().await?).await?;
45 Self::test_persistence(builder.build().await?).await?;
46 Self::test_reset_operation(builder.build().await?).await?;
47
48 builder.cleanup().await?;
49 Ok(())
50 }
51
52 async fn test_start_stop(state_machine: Arc<dyn StateMachine>) -> Result<(), Error> {
54 assert!(state_machine.is_running());
56
57 state_machine.start()?;
59 assert!(state_machine.is_running());
60 state_machine.stop()?;
61 assert!(!state_machine.is_running());
62 state_machine.start()?;
63 assert!(state_machine.is_running());
64
65 Ok(())
66 }
67
68 async fn test_basic_kv_operations(state_machine: Arc<dyn StateMachine>) -> Result<(), Error> {
70 let test_key = b"test_key";
71 let test_value = b"test_value";
72
73 let entries = vec![create_insert_entry(
75 1,
76 test_key.to_vec(),
77 test_value.to_vec(),
78 )];
79 state_machine.apply_chunk(entries).await?;
80
81 match state_machine.get(test_key)? {
83 Some(value) => assert_eq!(value, test_value),
84 None => panic!("Value not found after insert"),
85 }
86
87 let entries = vec![create_delete_entry(2, test_key.to_vec())];
89 state_machine.apply_chunk(entries).await?;
90
91 assert!(state_machine.get(test_key)?.is_none());
93
94 Ok(())
95 }
96
97 async fn test_apply_chunk_functionality(
99 state_machine: Arc<dyn StateMachine>
100 ) -> Result<(), Error> {
101 let entries = vec![
103 create_insert_entry(1, b"key1".to_vec(), b"value1".to_vec()),
104 create_insert_entry(2, b"key2".to_vec(), b"value2".to_vec()),
105 create_delete_entry(3, b"key1".to_vec()),
106 create_insert_entry(4, b"key3".to_vec(), b"value3".to_vec()),
107 ];
108
109 state_machine.apply_chunk(entries).await?;
110
111 assert!(state_machine.get(b"key1")?.is_none());
113 assert_eq!(state_machine.get(b"key2")?, Some(b"value2".to_vec()));
114 assert_eq!(state_machine.get(b"key3")?, Some(b"value3".to_vec()));
115 assert_eq!(state_machine.last_applied(), LogId { index: 4, term: 1 });
116
117 Ok(())
118 }
119
120 async fn test_last_applied_detection(
122 state_machine: Arc<dyn StateMachine>
123 ) -> Result<(), Error> {
124 assert!(state_machine.reset().await.is_ok());
125 assert_eq!(state_machine.last_applied(), LogId { index: 0, term: 0 });
127
128 let entries = vec![
130 create_insert_entry(1, b"key1".to_vec(), b"value1".to_vec()),
131 create_insert_entry(2, b"key2".to_vec(), b"value2".to_vec()),
132 create_insert_entry(3, b"key3".to_vec(), b"value3".to_vec()),
133 ];
134
135 state_machine.apply_chunk(entries).await?;
136 assert_eq!(state_machine.last_applied(), LogId { index: 3, term: 1 });
137
138 Ok(())
139 }
140
141 async fn test_snapshot_operations(state_machine: Arc<dyn StateMachine>) -> Result<(), Error> {
143 let entries = vec![
145 create_insert_entry(1, b"key1".to_vec(), b"value1".to_vec()),
146 create_insert_entry(2, b"key2".to_vec(), b"value2".to_vec()),
147 create_insert_entry(3, b"key3".to_vec(), b"value3".to_vec()),
148 ];
149 state_machine.apply_chunk(entries).await?;
150
151 let temp_dir = TempDir::new()?;
153 let snapshot_dir = temp_dir.path().to_path_buf();
154
155 let last_included = LogId { index: 3, term: 1 };
157 let checksum = state_machine
158 .generate_snapshot_data(snapshot_dir.clone(), last_included)
159 .await?;
160
161 let metadata = state_machine.snapshot_metadata();
163 assert!(metadata.is_some());
164 assert_eq!(metadata.unwrap().last_included, Some(last_included));
165
166 let metadata = SnapshotMetadata {
169 last_included: Some(last_included),
170 checksum: checksum.to_vec(),
171 };
172
173 state_machine.reset().await?;
175 assert_eq!(state_machine.get(b"key1")?, None);
176 assert_eq!(state_machine.get(b"key2")?, None);
177 assert_eq!(state_machine.get(b"key3")?, None);
178 assert_eq!(state_machine.last_applied(), LogId::default());
179
180 let entries = vec![
183 create_insert_entry(1, b"key1".to_vec(), b"old_value1".to_vec()),
184 create_insert_entry(2, b"old_key2".to_vec(), b"vold_alue2".to_vec()),
185 create_insert_entry(3, b"old_key3".to_vec(), b"old_value3".to_vec()),
186 ];
187 state_machine.apply_chunk(entries).await?;
188
189 state_machine.apply_snapshot_from_file(&metadata, snapshot_dir).await?;
190
191 assert_eq!(state_machine.get(b"key1")?, Some(b"value1".to_vec()));
193 assert_eq!(state_machine.get(b"key2")?, Some(b"value2".to_vec()));
194 assert_eq!(state_machine.get(b"key3")?, Some(b"value3".to_vec()));
195 assert_eq!(state_machine.get(b"old_key2")?, None);
196 assert_eq!(state_machine.get(b"old_key3")?, None);
197 assert_eq!(state_machine.last_applied(), last_included);
198
199 Ok(())
200 }
201
202 async fn test_persistence(state_machine: Arc<dyn StateMachine>) -> Result<(), Error> {
204 let entries = vec![
206 create_insert_entry(1, b"key1".to_vec(), b"value1".to_vec()),
207 create_insert_entry(2, b"key2".to_vec(), b"value2".to_vec()),
208 ];
209 state_machine.apply_chunk(entries).await?;
210
211 let last_applied = LogId { index: 2, term: 1 };
213 state_machine.persist_last_applied(last_applied)?;
214
215 let snapshot_metadata = SnapshotMetadata {
217 last_included: Some(last_applied),
218 checksum: vec![0; 32],
219 };
220 state_machine.persist_last_snapshot_metadata(&snapshot_metadata)?;
221
222 state_machine.flush()?;
224
225 Ok(())
226 }
227
228 async fn test_reset_operation(state_machine: Arc<dyn StateMachine>) -> Result<(), Error> {
236 let entries = vec![
238 create_insert_entry(1, b"key1".to_vec(), b"value1".to_vec()),
239 create_insert_entry(2, b"key2".to_vec(), b"value2".to_vec()),
240 create_insert_entry(3, b"key3".to_vec(), b"value3".to_vec()),
241 ];
242 state_machine.apply_chunk(entries).await?;
243
244 assert_eq!(state_machine.get(b"key1")?, Some(b"value1".to_vec()));
246 assert_eq!(state_machine.get(b"key2")?, Some(b"value2".to_vec()));
247 assert_eq!(state_machine.get(b"key3")?, Some(b"value3".to_vec()));
248 assert_eq!(state_machine.last_applied(), LogId { index: 3, term: 1 });
249 assert!(state_machine.snapshot_metadata().is_none());
250
251 let was_running = state_machine.is_running();
253
254 state_machine.reset().await?;
256
257 assert!(state_machine.get(b"key1")?.is_none());
259 assert!(state_machine.get(b"key2")?.is_none());
260 assert!(state_machine.get(b"key3")?.is_none());
261
262 assert_eq!(state_machine.last_applied(), LogId { index: 0, term: 0 });
264 assert!(state_machine.snapshot_metadata().is_none());
265
266 assert_eq!(state_machine.is_running(), was_running);
268
269 let new_entries = vec![
271 create_insert_entry(1, b"new_key1".to_vec(), b"new_value1".to_vec()),
272 create_insert_entry(2, b"new_key2".to_vec(), b"new_value2".to_vec()),
273 ];
274 state_machine.apply_chunk(new_entries).await?;
275
276 assert_eq!(
278 state_machine.get(b"new_key1")?,
279 Some(b"new_value1".to_vec())
280 );
281 assert_eq!(
282 state_machine.get(b"new_key2")?,
283 Some(b"new_value2".to_vec())
284 );
285 assert_eq!(state_machine.last_applied(), LogId { index: 2, term: 1 });
286
287 Ok(())
288 }
289}
290
291fn create_insert_entry(
293 index: u64,
294 key: Vec<u8>,
295 value: Vec<u8>,
296) -> Entry {
297 let insert = Insert { key, value };
298 let operation = Operation::Insert(insert);
299 let write_cmd = WriteCommand {
300 operation: Some(operation),
301 };
302
303 Entry {
304 index,
305 term: 1,
306 payload: Some(EntryPayload {
307 payload: Some(Payload::Command(write_cmd.encode_to_vec())),
308 }),
309 }
310}
311
312fn create_delete_entry(
314 index: u64,
315 key: Vec<u8>,
316) -> Entry {
317 let delete = Delete { key };
318 let operation = Operation::Delete(delete);
319 let write_cmd = WriteCommand {
320 operation: Some(operation),
321 };
322
323 Entry {
324 index,
325 term: 1,
326 payload: Some(EntryPayload {
327 payload: Some(Payload::Command(write_cmd.encode_to_vec())),
328 }),
329 }
330}