use std::sync::Arc;
use d_engine_proto::common::Entry;
use tokio::sync::{mpsc, watch};
use super::StateMachineWorker;
use crate::{Error, MockStateMachineHandler, MockTypeConfig, RaftEvent};
/// Builds a log [`Entry`] at the given `index` and `term` whose payload is a
/// command carrying the fixed bytes `"test_data"`.
fn create_test_entry(index: u64, term: u64) -> Entry {
    let command_bytes = bytes::Bytes::from("test_data");
    let payload = d_engine_proto::common::EntryPayload::command(command_bytes);
    Entry {
        term,
        index,
        payload: Some(payload),
    }
}
/// A batch that the handler applies successfully is reported back as one
/// `ApplyCompleted` event. Its `last_index` is the batch's final index, and its
/// `results` carry the handler's per-entry results in order.
#[tokio::test]
async fn test_apply_success_sends_apply_completed() {
    let mut handler = MockStateMachineHandler::new();
    handler
        .expect_apply_chunk()
        .times(1)
        .returning(|batch| {
            Ok(batch
                .iter()
                .map(|entry| crate::ApplyResult::success(entry.index))
                .collect::<Vec<crate::ApplyResult>>())
        });

    let (apply_tx, apply_rx) = mpsc::unbounded_channel();
    let (event_tx, mut event_rx) = mpsc::channel(10);
    let (_shutdown_tx, shutdown_rx) = watch::channel(());

    let worker = StateMachineWorker::<MockTypeConfig>::new(
        1,
        Arc::new(handler),
        apply_rx,
        event_tx,
        shutdown_rx,
    );
    // Detached: the test observes the worker only through `event_rx`.
    tokio::spawn(async move {
        let _ = worker.run().await;
    });

    apply_tx
        .send(vec![create_test_entry(1, 1), create_test_entry(2, 1)])
        .unwrap();

    let wait = std::time::Duration::from_millis(100);
    let (last_index, results) = match tokio::time::timeout(wait, event_rx.recv()).await {
        Ok(Some(RaftEvent::ApplyCompleted {
            last_index,
            results,
        })) => (last_index, results),
        Ok(Some(other)) => panic!("Expected ApplyCompleted, got {other:?}"),
        Ok(None) => panic!("Event channel closed unexpectedly"),
        Err(_) => panic!("Timeout waiting for ApplyCompleted event"),
    };

    assert_eq!(last_index, 2, "last_index should match last entry");
    assert_eq!(results.len(), 2, "results should contain 2 entries");
    assert_eq!(results[0].index, 1);
    assert_eq!(results[1].index, 2);
}
#[tokio::test]
async fn test_apply_failure_sends_fatal_error() {
let mut mock_smh = MockStateMachineHandler::new();
mock_smh.expect_apply_chunk().times(1).returning(|_| {
Err(Error::Fatal(
"Disk failure - cannot write to storage".to_string(),
))
});
let (sm_apply_tx, sm_apply_rx) = mpsc::unbounded_channel();
let (event_tx, mut event_rx) = mpsc::channel(10);
let (_shutdown_tx, shutdown_rx) = watch::channel(());
let worker = StateMachineWorker::<MockTypeConfig>::new(
1,
Arc::new(mock_smh),
sm_apply_rx,
event_tx,
shutdown_rx,
);
let worker_handle = tokio::spawn(async move { worker.run().await });
let entries = vec![create_test_entry(1, 1)];
sm_apply_tx.send(entries).unwrap();
match tokio::time::timeout(std::time::Duration::from_millis(100), event_rx.recv()).await {
Ok(Some(RaftEvent::FatalError { source, error })) => {
assert_eq!(source, "StateMachine", "source should be StateMachine");
assert!(
error.contains("Disk failure") || error.contains("storage"),
"error should contain failure details, got: {error}"
);
}
Ok(Some(other)) => panic!("Expected FatalError, got {other:?}"),
Ok(None) => panic!("Event channel closed unexpectedly"),
Err(_) => panic!("Timeout waiting for FatalError event"),
}
let result = worker_handle.await.unwrap();
assert!(
result.is_err(),
"Worker should return error after fatal failure"
);
}
/// Entries that are already queued when shutdown is requested must still be
/// applied (one `ApplyCompleted` per batch) before the worker exits with
/// `Ok`.
#[tokio::test]
async fn test_shutdown_drains_remaining_entries() {
    let mut mock_smh = MockStateMachineHandler::new();
    mock_smh.expect_apply_chunk().times(3).returning(|entries| {
        let results: Vec<crate::ApplyResult> =
            entries.iter().map(|e| crate::ApplyResult::success(e.index)).collect();
        Ok(results)
    });

    let (sm_apply_tx, sm_apply_rx) = mpsc::unbounded_channel();
    let (event_tx, mut event_rx) = mpsc::channel(10);
    let (shutdown_tx, shutdown_rx) = watch::channel(());
    let worker = StateMachineWorker::<MockTypeConfig>::new(
        1,
        Arc::new(mock_smh),
        sm_apply_rx,
        event_tx,
        shutdown_rx,
    );

    // Queue every batch and raise the shutdown signal *before* the worker
    // starts, so all entries are still pending when shutdown becomes visible.
    // The previous "sleep 10ms, then signal" let a fast worker apply
    // everything first, which made the drain scenario timing-dependent.
    for i in 1..=3 {
        sm_apply_tx.send(vec![create_test_entry(i, 1)]).unwrap();
    }
    shutdown_tx.send(()).unwrap();

    let worker_handle = tokio::spawn(async move { worker.run().await });

    let mut apply_count = 0;
    while apply_count < 3 {
        match tokio::time::timeout(std::time::Duration::from_millis(100), event_rx.recv()).await {
            Ok(Some(RaftEvent::ApplyCompleted { .. })) => {
                apply_count += 1;
            }
            Ok(Some(other)) => panic!("Expected ApplyCompleted, got {other:?}"),
            // The worker exited and dropped its sender; the count assertion below reports it.
            Ok(None) => break,
            Err(_) => panic!("Timeout waiting for ApplyCompleted events"),
        }
    }
    assert_eq!(
        apply_count, 3,
        "All 3 entries should be applied during shutdown drain"
    );

    // Bound the join so a worker that ignores shutdown fails the test instead of hanging it.
    let result = tokio::time::timeout(std::time::Duration::from_secs(1), worker_handle)
        .await
        .expect("Worker did not exit after shutdown")
        .expect("Worker task panicked");
    assert!(result.is_ok(), "Worker should exit cleanly after shutdown");
}
/// Dropping the only apply-channel sender must make the worker return
/// `Ok(())` within the timeout. The handler mock is given no expectations.
#[tokio::test]
async fn test_channel_closed_worker_exits() {
    let handler = MockStateMachineHandler::new();
    let (apply_tx, apply_rx) = mpsc::unbounded_channel();
    // Both underscore bindings stay alive until the end of the test, so the
    // event and shutdown channels remain open while the worker runs.
    let (event_tx, _event_rx) = mpsc::channel(10);
    let (_shutdown_tx, shutdown_rx) = watch::channel(());

    let worker = StateMachineWorker::<MockTypeConfig>::new(
        1,
        Arc::new(handler),
        apply_rx,
        event_tx,
        shutdown_rx,
    );
    let handle = tokio::spawn(async move { worker.run().await });

    // Closing the apply channel is the only exit trigger in this test.
    drop(apply_tx);

    let outcome = tokio::time::timeout(std::time::Duration::from_millis(100), handle).await;
    let joined = match outcome {
        Ok(joined) => joined,
        Err(_) => panic!("Worker did not exit within timeout"),
    };
    let run_result = match joined {
        Ok(run_result) => run_result,
        Err(e) => panic!("Worker task panicked: {e:?}"),
    };
    if let Err(e) = run_result {
        panic!("Worker returned error: {e:?}");
    }
}
/// Batches queued back-to-back are each applied through a separate handler
/// call and reported in order. Each batch yields its own `ApplyCompleted`
/// event whose `last_index` is that batch's final index and whose `results`
/// hold one entry per applied log entry.
#[tokio::test]
async fn test_multiple_batches_sequential_processing() {
    let mut mock_smh = MockStateMachineHandler::new();
    mock_smh.expect_apply_chunk().times(3).returning(|entries| {
        let results: Vec<crate::ApplyResult> =
            entries.iter().map(|e| crate::ApplyResult::success(e.index)).collect();
        Ok(results)
    });

    let (sm_apply_tx, sm_apply_rx) = mpsc::unbounded_channel();
    let (event_tx, mut event_rx) = mpsc::channel(10);
    let (_shutdown_tx, shutdown_rx) = watch::channel(());
    let worker = StateMachineWorker::<MockTypeConfig>::new(
        1,
        Arc::new(mock_smh),
        sm_apply_rx,
        event_tx,
        shutdown_rx,
    );
    tokio::spawn(async move {
        let _ = worker.run().await;
    });

    // Three two-entry batches: [1, 2], [3, 4], [5, 6].
    for first in [1, 3, 5] {
        sm_apply_tx
            .send(vec![create_test_entry(first, 1), create_test_entry(first + 1, 1)])
            .unwrap();
    }

    // A plain array is enough here; building a `vec!` only to iterate it
    // trips clippy::useless_vec.
    for expected_index in [2, 4, 6] {
        match tokio::time::timeout(std::time::Duration::from_millis(100), event_rx.recv()).await {
            Ok(Some(RaftEvent::ApplyCompleted {
                last_index,
                results,
            })) => {
                assert_eq!(
                    last_index, expected_index,
                    "last_index should match expected value"
                );
                assert_eq!(
                    results.len(),
                    2,
                    "each batch should report one result per entry"
                );
            }
            Ok(Some(other)) => panic!("Expected ApplyCompleted, got {other:?}"),
            Ok(None) => panic!("Event channel closed unexpectedly"),
            Err(_) => panic!("Timeout waiting for ApplyCompleted event"),
        }
    }
}