// d_engine/storage/state_machine_test.rs

1use std::sync::Arc;
2
3use prost::Message;
4use tempfile::TempDir;
5use tonic::async_trait;
6
7use crate::proto::client::write_command::Delete;
8use crate::proto::client::write_command::Insert;
9use crate::proto::client::write_command::Operation;
10use crate::proto::client::WriteCommand;
11use crate::proto::common::entry_payload::Payload;
12use crate::proto::common::Entry;
13use crate::proto::common::EntryPayload;
14use crate::proto::common::LogId;
15use crate::proto::storage::SnapshotMetadata;
16use crate::storage::StateMachine;
17use crate::Error;
18
/// Test suite for StateMachine implementations
///
/// This suite provides tests that can be used to validate any
/// `StateMachine` implementation. Developers implement the
/// [`StateMachineBuilder`] trait and then call
/// [`StateMachineTestSuite::run_all_tests`] with their builder.
#[derive(Debug, Clone, Copy, Default)]
pub struct StateMachineTestSuite;
26
27/// Builder trait for creating StateMachine instances for testing
28#[async_trait]
29pub trait StateMachineBuilder: Send + Sync {
30    /// Create a new StateMachine instance for testing
31    async fn build(&self) -> Result<Arc<dyn StateMachine>, Error>;
32
33    /// Clean up any resources after testing
34    async fn cleanup(&self) -> Result<(), Error>;
35}
36
37impl StateMachineTestSuite {
38    /// Run all state machine tests
39    pub async fn run_all_tests<B: StateMachineBuilder>(builder: B) -> Result<(), Error> {
40        Self::test_start_stop(builder.build().await?).await?;
41        Self::test_basic_kv_operations(builder.build().await?).await?;
42        Self::test_apply_chunk_functionality(builder.build().await?).await?;
43        Self::test_last_applied_detection(builder.build().await?).await?;
44        Self::test_snapshot_operations(builder.build().await?).await?;
45        Self::test_persistence(builder.build().await?).await?;
46        Self::test_reset_operation(builder.build().await?).await?;
47
48        builder.cleanup().await?;
49        Ok(())
50    }
51
52    /// Test start/stop functionality
53    async fn test_start_stop(state_machine: Arc<dyn StateMachine>) -> Result<(), Error> {
54        // Test default state
55        assert!(state_machine.is_running());
56
57        // Test explicit start/stop
58        state_machine.start()?;
59        assert!(state_machine.is_running());
60        state_machine.stop()?;
61        assert!(!state_machine.is_running());
62        state_machine.start()?;
63        assert!(state_machine.is_running());
64
65        Ok(())
66    }
67
68    /// Test basic key-value operations
69    async fn test_basic_kv_operations(state_machine: Arc<dyn StateMachine>) -> Result<(), Error> {
70        let test_key = b"test_key";
71        let test_value = b"test_value";
72
73        // Create an insert entry
74        let entries = vec![create_insert_entry(
75            1,
76            test_key.to_vec(),
77            test_value.to_vec(),
78        )];
79        state_machine.apply_chunk(entries).await?;
80
81        // Verify the value was inserted
82        match state_machine.get(test_key)? {
83            Some(value) => assert_eq!(value, test_value),
84            None => panic!("Value not found after insert"),
85        }
86
87        // Create a delete entry
88        let entries = vec![create_delete_entry(2, test_key.to_vec())];
89        state_machine.apply_chunk(entries).await?;
90
91        // Verify the value was deleted
92        assert!(state_machine.get(test_key)?.is_none());
93
94        Ok(())
95    }
96
97    /// Test chunk application functionality
98    async fn test_apply_chunk_functionality(
99        state_machine: Arc<dyn StateMachine>
100    ) -> Result<(), Error> {
101        // Create a mix of insert and delete operations
102        let entries = vec![
103            create_insert_entry(1, b"key1".to_vec(), b"value1".to_vec()),
104            create_insert_entry(2, b"key2".to_vec(), b"value2".to_vec()),
105            create_delete_entry(3, b"key1".to_vec()),
106            create_insert_entry(4, b"key3".to_vec(), b"value3".to_vec()),
107        ];
108
109        state_machine.apply_chunk(entries).await?;
110
111        // Verify the final state
112        assert!(state_machine.get(b"key1")?.is_none());
113        assert_eq!(state_machine.get(b"key2")?, Some(b"value2".to_vec()));
114        assert_eq!(state_machine.get(b"key3")?, Some(b"value3".to_vec()));
115        assert_eq!(state_machine.last_applied(), LogId { index: 4, term: 1 });
116
117        Ok(())
118    }
119
120    /// Test last applied index detection
121    async fn test_last_applied_detection(
122        state_machine: Arc<dyn StateMachine>
123    ) -> Result<(), Error> {
124        assert!(state_machine.reset().await.is_ok());
125        // Initial state
126        assert_eq!(state_machine.last_applied(), LogId { index: 0, term: 0 });
127
128        // Apply entries with different terms
129        let entries = vec![
130            create_insert_entry(1, b"key1".to_vec(), b"value1".to_vec()),
131            create_insert_entry(2, b"key2".to_vec(), b"value2".to_vec()),
132            create_insert_entry(3, b"key3".to_vec(), b"value3".to_vec()),
133        ];
134
135        state_machine.apply_chunk(entries).await?;
136        assert_eq!(state_machine.last_applied(), LogId { index: 3, term: 1 });
137
138        Ok(())
139    }
140
141    /// Test snapshot operations
142    async fn test_snapshot_operations(state_machine: Arc<dyn StateMachine>) -> Result<(), Error> {
143        // Add some test data
144        let entries = vec![
145            create_insert_entry(1, b"key1".to_vec(), b"value1".to_vec()),
146            create_insert_entry(2, b"key2".to_vec(), b"value2".to_vec()),
147            create_insert_entry(3, b"key3".to_vec(), b"value3".to_vec()),
148        ];
149        state_machine.apply_chunk(entries).await?;
150
151        // Create a temporary directory for the snapshot
152        let temp_dir = TempDir::new()?;
153        let snapshot_dir = temp_dir.path().to_path_buf();
154
155        // Generate snapshot
156        let last_included = LogId { index: 3, term: 1 };
157        let checksum = state_machine
158            .generate_snapshot_data(snapshot_dir.clone(), last_included)
159            .await?;
160
161        // Verify snapshot metadata was updated
162        let metadata = state_machine.snapshot_metadata();
163        assert!(metadata.is_some());
164        assert_eq!(metadata.unwrap().last_included, Some(last_included));
165
166        // Apply snapshot (simulate receiving from leader)
167        // let snapshot_path = snapshot_dir.join("snapshot.bin");
168        let metadata = SnapshotMetadata {
169            last_included: Some(last_included),
170            checksum: checksum.to_vec(),
171        };
172
173        //Reset State Machine to make sure it is fresh
174        state_machine.reset().await?;
175        assert_eq!(state_machine.get(b"key1")?, None);
176        assert_eq!(state_machine.get(b"key2")?, None);
177        assert_eq!(state_machine.get(b"key3")?, None);
178        assert_eq!(state_machine.last_applied(), LogId::default());
179
180        // Assume there were some entries. After applying snapshot, all the old ones should be
181        // cleared.
182        let entries = vec![
183            create_insert_entry(1, b"key1".to_vec(), b"old_value1".to_vec()),
184            create_insert_entry(2, b"old_key2".to_vec(), b"vold_alue2".to_vec()),
185            create_insert_entry(3, b"old_key3".to_vec(), b"old_value3".to_vec()),
186        ];
187        state_machine.apply_chunk(entries).await?;
188
189        state_machine.apply_snapshot_from_file(&metadata, snapshot_dir).await?;
190
191        // Verify state was preserved after snapshot application
192        assert_eq!(state_machine.get(b"key1")?, Some(b"value1".to_vec()));
193        assert_eq!(state_machine.get(b"key2")?, Some(b"value2".to_vec()));
194        assert_eq!(state_machine.get(b"key3")?, Some(b"value3".to_vec()));
195        assert_eq!(state_machine.get(b"old_key2")?, None);
196        assert_eq!(state_machine.get(b"old_key3")?, None);
197        assert_eq!(state_machine.last_applied(), last_included);
198
199        Ok(())
200    }
201
202    /// Test data persistence
203    async fn test_persistence(state_machine: Arc<dyn StateMachine>) -> Result<(), Error> {
204        // Add test data
205        let entries = vec![
206            create_insert_entry(1, b"key1".to_vec(), b"value1".to_vec()),
207            create_insert_entry(2, b"key2".to_vec(), b"value2".to_vec()),
208        ];
209        state_machine.apply_chunk(entries).await?;
210
211        // Update last applied
212        let last_applied = LogId { index: 2, term: 1 };
213        state_machine.persist_last_applied(last_applied)?;
214
215        // Update snapshot metadata
216        let snapshot_metadata = SnapshotMetadata {
217            last_included: Some(last_applied),
218            checksum: vec![0; 32],
219        };
220        state_machine.persist_last_snapshot_metadata(&snapshot_metadata)?;
221
222        // Flush to ensure persistence
223        state_machine.flush()?;
224
225        Ok(())
226    }
227
228    /// Test reset operation functionality
229    ///
230    /// This test verifies that the reset operation:
231    /// 1. Clears all data from memory
232    /// 2. Resets Raft state to initial values
233    /// 3. Clears all persisted files
234    /// 4. Maintains operational state (running status)
235    async fn test_reset_operation(state_machine: Arc<dyn StateMachine>) -> Result<(), Error> {
236        // Add test data
237        let entries = vec![
238            create_insert_entry(1, b"key1".to_vec(), b"value1".to_vec()),
239            create_insert_entry(2, b"key2".to_vec(), b"value2".to_vec()),
240            create_insert_entry(3, b"key3".to_vec(), b"value3".to_vec()),
241        ];
242        state_machine.apply_chunk(entries).await?;
243
244        // Verify data exists
245        assert_eq!(state_machine.get(b"key1")?, Some(b"value1".to_vec()));
246        assert_eq!(state_machine.get(b"key2")?, Some(b"value2".to_vec()));
247        assert_eq!(state_machine.get(b"key3")?, Some(b"value3".to_vec()));
248        assert_eq!(state_machine.last_applied(), LogId { index: 3, term: 1 });
249        assert!(state_machine.snapshot_metadata().is_none());
250
251        // Store running state for verification
252        let was_running = state_machine.is_running();
253
254        // Perform reset
255        state_machine.reset().await?;
256
257        // Verify all data is cleared
258        assert!(state_machine.get(b"key1")?.is_none());
259        assert!(state_machine.get(b"key2")?.is_none());
260        assert!(state_machine.get(b"key3")?.is_none());
261
262        // Verify Raft state is reset
263        assert_eq!(state_machine.last_applied(), LogId { index: 0, term: 0 });
264        assert!(state_machine.snapshot_metadata().is_none());
265
266        // Verify operational state is maintained
267        assert_eq!(state_machine.is_running(), was_running);
268
269        // Test that we can add new data after reset
270        let new_entries = vec![
271            create_insert_entry(1, b"new_key1".to_vec(), b"new_value1".to_vec()),
272            create_insert_entry(2, b"new_key2".to_vec(), b"new_value2".to_vec()),
273        ];
274        state_machine.apply_chunk(new_entries).await?;
275
276        // Verify new data exists
277        assert_eq!(
278            state_machine.get(b"new_key1")?,
279            Some(b"new_value1".to_vec())
280        );
281        assert_eq!(
282            state_machine.get(b"new_key2")?,
283            Some(b"new_value2".to_vec())
284        );
285        assert_eq!(state_machine.last_applied(), LogId { index: 2, term: 1 });
286
287        Ok(())
288    }
289}
290
291/// Helper function to create an insert entry
292fn create_insert_entry(
293    index: u64,
294    key: Vec<u8>,
295    value: Vec<u8>,
296) -> Entry {
297    let insert = Insert { key, value };
298    let operation = Operation::Insert(insert);
299    let write_cmd = WriteCommand {
300        operation: Some(operation),
301    };
302
303    Entry {
304        index,
305        term: 1,
306        payload: Some(EntryPayload {
307            payload: Some(Payload::Command(write_cmd.encode_to_vec())),
308        }),
309    }
310}
311
312/// Helper function to create a delete entry
313fn create_delete_entry(
314    index: u64,
315    key: Vec<u8>,
316) -> Entry {
317    let delete = Delete { key };
318    let operation = Operation::Delete(delete);
319    let write_cmd = WriteCommand {
320        operation: Some(operation),
321    };
322
323    Entry {
324        index,
325        term: 1,
326        payload: Some(EntryPayload {
327            payload: Some(Payload::Command(write_cmd.encode_to_vec())),
328        }),
329    }
330}