use crate::ClientCmd;
use crate::Error;
use crate::MaybeCloneOneshot;
use crate::MockBuilder;
use crate::MockMembership;
use crate::MockStateMachineHandler;
use crate::NewCommitData;
use crate::RaftEvent;
use crate::RaftOneshot;
use crate::RoleEvent;
use crate::raft_role::learner_state::LearnerState;
use crate::raft_role::role_state::RaftRoleState;
use crate::test_utils::mock::MockTypeConfig;
use crate::test_utils::mock::mock_raft_context;
use crate::test_utils::mock::mock_raft_context_with_temp;
use crate::test_utils::node_config;
use d_engine_proto::client::WriteCommand;
use d_engine_proto::common::LogId;
use d_engine_proto::common::NodeRole;
use d_engine_proto::common::NodeStatus;
use d_engine_proto::server::cluster::ClusterConfChangeRequest;
use d_engine_proto::server::cluster::ClusterConfUpdateResponse;
use d_engine_proto::server::cluster::MetadataRequest;
use d_engine_proto::server::cluster::NodeMeta;
use d_engine_proto::server::cluster::cluster_conf_update_response;
use d_engine_proto::server::election::VoteRequest;
use mockall::predicate::eq;
use std::sync::Arc;
use tokio::sync::{mpsc, watch};
use tonic::Code;
/// Draining the read buffer on a learner must fail, and the error's `Debug`
/// rendering must mention `NotLeader`.
#[tokio::test]
async fn test_learner_drain_read_buffer_returns_error() {
    let config = Arc::new(node_config("/tmp/test_learner_drain"));
    let mut learner = LearnerState::<MockTypeConfig>::new(1, config);

    match learner.drain_read_buffer() {
        Ok(_) => panic!("Learner drain_read_buffer should return error"),
        Err(e) => {
            let rendered = format!("{e:?}");
            assert!(
                rendered.contains("NotLeader"),
                "Error should be NotLeader, got: {rendered}"
            );
        }
    }
}
/// A single `tick` on a freshly constructed learner completes without error.
#[tokio::test]
async fn test_learner_tick_succeeds() {
    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);
    let mut learner = LearnerState::<MockTypeConfig>::new(1, context.node_config.clone());

    // Receivers are kept alive so sends from `tick` cannot fail on a closed channel.
    let (role_tx, _role_rx) = mpsc::unbounded_channel();
    let (event_tx, _event_rx) = mpsc::channel(1);

    let outcome = learner.tick(&role_tx, &event_tx, &context).await;
    assert!(outcome.is_ok(), "Learner tick should succeed");
}
/// A vote request from a higher term makes the learner adopt that term, but
/// the learner still refuses to grant its vote.
#[tokio::test]
async fn test_learner_rejects_vote_request_updates_term() {
    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);
    let mut learner = LearnerState::<MockTypeConfig>::new(1, context.node_config.clone());

    // Pick a term well ahead of the learner's own so the term bump is observable.
    let higher_term = learner.current_term() + 10;

    let (resp_tx, mut resp_rx) = MaybeCloneOneshot::new();
    let (role_tx, _role_rx) = mpsc::unbounded_channel();
    let vote_request = VoteRequest {
        term: higher_term,
        candidate_id: 2,
        last_log_index: 11,
        last_log_term: 0,
    };

    let handled = learner
        .handle_raft_event(
            RaftEvent::ReceiveVoteRequest(vote_request, resp_tx),
            &context,
            role_tx,
        )
        .await;
    assert!(handled.is_ok(), "handle_raft_event should succeed");
    assert_eq!(
        learner.current_term(),
        higher_term,
        "Should update to request term"
    );

    let vote = resp_rx.recv().await.unwrap().unwrap();
    assert!(!vote.vote_granted, "Learner should never grant votes");
}
/// A cluster-metadata request is handled without an event-loop error, but the
/// reply sent to the caller is a `PermissionDenied` status.
#[tokio::test]
async fn test_learner_rejects_cluster_conf_request() {
    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);
    let mut learner = LearnerState::<MockTypeConfig>::new(1, context.node_config.clone());

    let (role_tx, _role_rx) = mpsc::unbounded_channel();
    let (resp_tx, mut resp_rx) = MaybeCloneOneshot::new();

    let metadata_event = RaftEvent::ClusterConf(MetadataRequest {}, resp_tx);
    let handled = learner.handle_raft_event(metadata_event, &context, role_tx).await;
    assert!(handled.is_ok(), "handle_raft_event should succeed");

    let status = resp_rx.recv().await.unwrap().unwrap_err();
    assert_eq!(
        status.code(),
        Code::PermissionDenied,
        "Should return PermissionDenied"
    );
}
/// A cluster-configuration update is passed to membership through
/// `update_cluster_conf_from_leader` exactly once. The membership's successful
/// response is relayed to the caller unchanged.
#[tokio::test]
async fn test_learner_handles_cluster_conf_update_success() {
    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (mut context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);

    let mut mock_membership = MockMembership::new();
    mock_membership
        .expect_update_cluster_conf_from_leader()
        .times(1)
        .returning(|_, _, _, _, _| {
            let accepted = ClusterConfUpdateResponse {
                id: 1,
                term: 1,
                version: 1,
                success: true,
                error_code: cluster_conf_update_response::ErrorCode::Unspecified.into(),
            };
            Ok(accepted)
        });
    mock_membership.expect_get_cluster_conf_version().returning(|| 1);
    context.membership = Arc::new(mock_membership);

    let mut learner = LearnerState::<MockTypeConfig>::new(1, context.node_config.clone());
    let (resp_tx, mut resp_rx) = MaybeCloneOneshot::new();
    let (role_tx, _role_rx) = mpsc::unbounded_channel();

    let change_request = ClusterConfChangeRequest {
        id: 2,
        term: 1,
        version: 1,
        change: None,
    };
    let handled = learner
        .handle_raft_event(
            RaftEvent::ClusterConfUpdate(change_request, resp_tx),
            &context,
            role_tx,
        )
        .await;
    assert!(handled.is_ok(), "handle_raft_event should succeed");

    let update = resp_rx.recv().await.unwrap().unwrap();
    assert!(update.success, "Update should succeed");
    assert_eq!(
        update.error_code,
        cluster_conf_update_response::ErrorCode::Unspecified as i32
    );
}
/// AppendEntries from a leader with a newer term has several effects:
/// - the learner announces the leader and then signals a new commit index;
/// - it adopts the leader's term;
/// - it takes the commit index reported by the replication handler;
/// - it replies with success.
#[tokio::test]
async fn test_learner_handles_append_entries_success() {
    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (mut context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);

    let learner_term = 1;
    let leader_term = learner_term + 1;
    let expected_commit = 2;

    // The replication core reports a successful append plus a commit-index bump.
    let mut replication_core = crate::MockReplicationCore::new();
    replication_core
        .expect_handle_append_entries()
        .returning(move |_, _, _| {
            let response = d_engine_proto::server::replication::AppendEntriesResponse::success(
                1,
                leader_term,
                Some(LogId {
                    term: leader_term,
                    index: 1,
                }),
            );
            Ok(crate::AppendResponseWithUpdates {
                response,
                commit_index_update: Some(expected_commit),
            })
        });
    context.membership = Arc::new(MockMembership::new());
    context.handlers.replication_handler = replication_core;

    let mut learner = LearnerState::<MockTypeConfig>::new(1, context.node_config.clone());
    learner.update_current_term(learner_term);

    let append_request = d_engine_proto::server::replication::AppendEntriesRequest {
        term: leader_term,
        leader_id: 5,
        prev_log_index: 0,
        prev_log_term: 1,
        entries: vec![],
        leader_commit_index: 0,
    };
    let (resp_tx, mut resp_rx) = MaybeCloneOneshot::new();
    let (role_tx, mut role_rx) = mpsc::unbounded_channel();

    let handled = learner
        .handle_raft_event(
            RaftEvent::AppendEntries(append_request, resp_tx),
            &context,
            role_tx,
        )
        .await;
    assert!(handled.is_ok(), "handle_raft_event should succeed");

    // Role events are expected in exactly this order.
    let first = role_rx.try_recv().unwrap();
    assert!(matches!(first, crate::RoleEvent::LeaderDiscovered(5, _)));
    let second = role_rx.try_recv().unwrap();
    assert!(matches!(second, crate::RoleEvent::NotifyNewCommitIndex(_)));

    assert_eq!(learner.current_term(), leader_term);
    assert_eq!(learner.commit_index(), expected_commit);

    let reply = resp_rx.recv().await.unwrap().unwrap();
    assert!(reply.is_success());
}
/// AppendEntries carrying a term older than the learner's is rejected:
/// - no role events are emitted;
/// - the learner keeps its term;
/// - the reply reports a higher term.
#[tokio::test]
async fn test_learner_rejects_append_entries_stale_term() {
    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (mut context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);
    context.membership = Arc::new(MockMembership::new());

    let learner_term = 2;
    let stale_term = learner_term - 1;

    let mut learner = LearnerState::<MockTypeConfig>::new(1, context.node_config.clone());
    learner.update_current_term(learner_term);

    let stale_request = d_engine_proto::server::replication::AppendEntriesRequest {
        term: stale_term,
        leader_id: 5,
        prev_log_index: 0,
        prev_log_term: 1,
        entries: vec![],
        leader_commit_index: 0,
    };
    let (resp_tx, mut resp_rx) = MaybeCloneOneshot::new();
    let (role_tx, mut role_rx) = mpsc::unbounded_channel();

    let handled = learner
        .handle_raft_event(
            RaftEvent::AppendEntries(stale_request, resp_tx),
            &context,
            role_tx,
        )
        .await;
    assert!(handled.is_ok(), "handle_raft_event should succeed");

    assert!(role_rx.try_recv().is_err(), "No events should be sent");
    assert_eq!(learner.current_term(), learner_term);

    let reply = resp_rx.recv().await.unwrap().unwrap();
    assert!(reply.is_higher_term());
}
/// When the replication handler fails on AppendEntries from a newer-term
/// leader:
/// - the event loop sees an error;
/// - only the `LeaderDiscovered` role event is emitted;
/// - the learner has adopted the leader's term;
/// - the reply is not a success.
#[tokio::test]
async fn test_learner_handles_append_entries_handler_error() {
    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (mut context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);

    let learner_term = 1;
    let leader_term = learner_term + 1;

    let mut replication_core = crate::MockReplicationCore::new();
    replication_core
        .expect_handle_append_entries()
        .returning(|_, _, _| Err(crate::Error::Fatal("test".to_string())));
    context.membership = Arc::new(MockMembership::new());
    context.handlers.replication_handler = replication_core;

    let mut learner = LearnerState::<MockTypeConfig>::new(1, context.node_config.clone());
    learner.update_current_term(learner_term);

    let append_request = d_engine_proto::server::replication::AppendEntriesRequest {
        term: leader_term,
        leader_id: 5,
        prev_log_index: 0,
        prev_log_term: 1,
        entries: vec![],
        leader_commit_index: 0,
    };
    let (resp_tx, mut resp_rx) = MaybeCloneOneshot::new();
    let (role_tx, mut role_rx) = mpsc::unbounded_channel();

    let handled = learner
        .handle_raft_event(
            RaftEvent::AppendEntries(append_request, resp_tx),
            &context,
            role_tx,
        )
        .await;
    assert!(handled.is_err(), "handle_raft_event should return error");

    // Exactly one role event: the leader announcement, with nothing after it.
    let announced = role_rx.try_recv().unwrap();
    assert!(matches!(announced, crate::RoleEvent::LeaderDiscovered(5, _)));
    assert!(role_rx.try_recv().is_err());

    assert_eq!(learner.current_term(), leader_term);

    let reply = resp_rx.recv().await.unwrap().unwrap();
    assert!(!reply.is_success());
}
/// A client write pushed to a learner is answered with `FailedPrecondition`,
/// and the status message contains "Not leader".
#[tokio::test]
async fn test_learner_rejects_client_write_request() {
    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);
    let mut learner = LearnerState::<MockTypeConfig>::new(1, context.node_config.clone());

    let (resp_tx, mut resp_rx) = MaybeCloneOneshot::new();
    let write_request = d_engine_proto::client::ClientWriteRequest {
        client_id: 1,
        command: Some(WriteCommand::default()),
    };
    learner.push_client_cmd(ClientCmd::Propose(write_request, resp_tx), &context);

    let outcome = resp_rx.recv().await.expect("channel should not be closed");
    assert!(outcome.is_err());
    let status = outcome.unwrap_err();
    assert_eq!(status.code(), tonic::Code::FailedPrecondition);
    assert!(status.message().contains("Not leader"));
}
/// A client read with no consistency policy, pushed to a learner, is answered
/// with `FailedPrecondition` and a status message containing "Not leader".
#[tokio::test]
async fn test_learner_rejects_client_read_request() {
    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);
    let mut learner = LearnerState::<MockTypeConfig>::new(1, context.node_config.clone());

    let (resp_tx, mut resp_rx) = MaybeCloneOneshot::new();
    let read_request = d_engine_proto::client::ClientReadRequest {
        client_id: 1,
        consistency_policy: None,
        keys: vec![],
    };
    learner.push_client_cmd(ClientCmd::Read(read_request, resp_tx), &context);

    let outcome = resp_rx.recv().await.expect("channel should not be closed");
    assert!(outcome.is_err());
    let status = outcome.unwrap_err();
    assert_eq!(status.code(), tonic::Code::FailedPrecondition);
    assert!(status.message().contains("Not leader"));
}
#[tokio::test]
async fn test_learner_rejects_join_cluster() {
let (_graceful_tx, graceful_rx) = watch::channel(());
let (context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);
let mut state = LearnerState::<MockTypeConfig>::new(1, context.node_config.clone());
let request = d_engine_proto::server::cluster::JoinRequest {
status: d_engine_proto::common::NodeStatus::Promotable as i32,
node_id: 2,
node_role: d_engine_proto::common::NodeRole::Learner.into(),
address: "127.0.0.1:9090".to_string(),
};
let (resp_tx, mut resp_rx) = MaybeCloneOneshot::new();
let raft_event = RaftEvent::JoinCluster(request, resp_tx);
let (role_tx, _role_rx) = mpsc::unbounded_channel();
assert!(
state.handle_raft_event(raft_event, &context, role_tx).await.is_err(),
"handle_raft_event should return error"
);
let response = resp_rx.recv().await.expect("should receive response");
assert!(response.is_err());
let status = response.unwrap_err();
assert_eq!(status.code(), Code::PermissionDenied);
}
/// A leader-discovery request is handled without an event-loop error, but
/// the learner answers it with `PermissionDenied`.
#[tokio::test]
async fn test_learner_rejects_leader_discovery() {
    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);
    let mut learner = LearnerState::<MockTypeConfig>::new(1, context.node_config.clone());

    let discovery_request = d_engine_proto::server::cluster::LeaderDiscoveryRequest {
        node_id: 2,
        requester_address: "127.0.0.1:9090".to_string(),
    };
    let (resp_tx, mut resp_rx) = MaybeCloneOneshot::new();
    let (role_tx, _role_rx) = mpsc::unbounded_channel();

    let handled = learner
        .handle_raft_event(
            RaftEvent::DiscoverLeader(discovery_request, resp_tx),
            &context,
            role_tx,
        )
        .await;
    assert!(handled.is_ok(), "handle_raft_event should succeed");

    let reply = resp_rx.recv().await.expect("should receive response");
    assert!(reply.is_err());
    let status = reply.unwrap_err();
    assert_eq!(status.code(), Code::PermissionDenied);
}
/// `broadcast_discovery` succeeds when the transport's first `discover_leader`
/// call already returns a response naming a leader.
#[tokio::test]
async fn test_broadcast_discovery_succeeds_first_attempt() {
    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (mut context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);

    let mut mock_transport = crate::MockTransport::new();
    mock_transport.expect_discover_leader().returning(|_, _, _| {
        let leader = d_engine_proto::server::cluster::LeaderDiscoveryResponse {
            leader_id: 5,
            leader_address: "127.0.0.1:5005".to_string(),
            term: 3,
        };
        Ok(vec![leader])
    });
    context.transport = Arc::new(mock_transport);

    let learner = LearnerState::<MockTypeConfig>::new(1, context.node_config.clone());
    let membership = context.membership.clone();
    let discovered = learner.broadcast_discovery(membership, &context).await;
    assert!(discovered.is_ok(), "Should return leader info");
}
/// If every discovery round comes back empty, `broadcast_discovery` gives up
/// with a network `RetryTimeoutError`.
#[tokio::test]
async fn test_broadcast_discovery_fails_after_retries() {
    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (mut context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);

    let mut silent_transport = crate::MockTransport::new();
    silent_transport.expect_discover_leader().returning(|_, _, _| Ok(vec![]));
    context.transport = Arc::new(silent_transport);

    let learner = LearnerState::<MockTypeConfig>::new(1, context.node_config.clone());
    let membership = context.membership.clone();
    let discovered = learner.broadcast_discovery(membership, &context).await;

    assert!(discovered.is_err(), "Should error after retries");
    assert!(matches!(
        discovered.unwrap_err(),
        crate::Error::System(crate::SystemError::Network(
            crate::NetworkError::RetryTimeoutError(_)
        ))
    ));
}
/// Given several discovery responses, `select_valid_leader` returns a leader
/// from the highest term. Two candidates share term 7 here, and the test
/// expects leader 5, the first of them in the input.
#[tokio::test]
async fn test_select_valid_leader_prioritizes_highest_term() {
    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);
    let learner = LearnerState::<MockTypeConfig>::new(1, context.node_config.clone());

    let response = |leader_id, term, address: &str| {
        d_engine_proto::server::cluster::LeaderDiscoveryResponse {
            leader_id,
            term,
            leader_address: address.to_string(),
        }
    };
    let responses = vec![
        response(3, 5, "127.0.0.1:5003"),
        response(5, 7, "127.0.0.1:5005"),
        response(4, 7, "127.0.0.1:5004"),
    ];

    let chosen = learner.select_valid_leader(responses).await;
    assert!(chosen.is_some());
    assert_eq!(chosen.unwrap(), 5, "Should select highest term");
}
/// Responses with a zero leader id or a zero term are discarded. When only
/// such responses are given, no leader is selected.
#[tokio::test]
async fn test_select_valid_leader_filters_invalid_responses() {
    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);
    let learner = LearnerState::<MockTypeConfig>::new(1, context.node_config.clone());

    let response = |leader_id, term| d_engine_proto::server::cluster::LeaderDiscoveryResponse {
        leader_id,
        term,
        leader_address: "127.0.0.1:5003".to_string(),
    };
    let responses = vec![response(0, 5), response(3, 0)];

    let chosen = learner.select_valid_leader(responses).await;
    assert!(chosen.is_none(), "Should filter invalid responses");
}
/// With a current leader already recorded in shared state, `join_cluster`
/// succeeds against a transport that accepts the join. The transport mock is
/// given no discovery expectation.
#[tokio::test]
async fn test_join_cluster_succeeds_with_known_leader() {
    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (mut context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);
    context.membership = Arc::new(MockMembership::new());

    let mut mock_transport = crate::MockTransport::new();
    mock_transport.expect_join_cluster().returning(|_, _, _, _| {
        let accepted = d_engine_proto::server::cluster::JoinResponse {
            success: true,
            error: "".to_string(),
            config: None,
            config_version: 1,
            snapshot_metadata: None,
            leader_id: 3,
        };
        Ok(accepted)
    });
    context.transport = Arc::new(mock_transport);

    let learner = LearnerState::<MockTypeConfig>::new(100, context.node_config.clone());
    learner.shared_state.set_current_leader(5);

    let joined = learner.join_cluster(&context).await;
    assert!(joined.is_ok(), "Join should succeed with known leader");
}
/// With no known leader, `join_cluster` first discovers one through the
/// transport and then joins successfully.
#[tokio::test]
async fn test_join_cluster_succeeds_after_discovery() {
    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (mut context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);
    context.membership = Arc::new(MockMembership::new());

    let mut mock_transport = crate::MockTransport::new();
    mock_transport.expect_discover_leader().returning(|_, _, _| {
        let leader = d_engine_proto::server::cluster::LeaderDiscoveryResponse {
            leader_id: 5,
            leader_address: "127.0.0.1:5005".to_string(),
            term: 3,
        };
        Ok(vec![leader])
    });
    mock_transport.expect_join_cluster().returning(|_, _, _, _| {
        let accepted = d_engine_proto::server::cluster::JoinResponse {
            success: true,
            error: "".to_string(),
            config: None,
            config_version: 0,
            snapshot_metadata: None,
            leader_id: 2,
        };
        Ok(accepted)
    });
    context.transport = Arc::new(mock_transport);

    let learner = LearnerState::<MockTypeConfig>::new(100, context.node_config.clone());
    let joined = learner.join_cluster(&context).await;
    assert!(joined.is_ok(), "Join should succeed after discovery");
}
/// When discovery never yields a leader, `join_cluster` fails with the network
/// `RetryTimeoutError` produced by discovery.
#[tokio::test]
async fn test_join_cluster_fails_discovery_timeout() {
    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (mut context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);
    context.membership = Arc::new(MockMembership::new());

    let mut silent_transport = crate::MockTransport::new();
    silent_transport.expect_discover_leader().returning(|_, _, _| Ok(vec![]));
    context.transport = Arc::new(silent_transport);

    let learner = LearnerState::<MockTypeConfig>::new(100, context.node_config.clone());
    let joined = learner.join_cluster(&context).await;

    assert!(joined.is_err(), "Should timeout during discovery");
    assert!(matches!(
        joined.unwrap_err(),
        crate::Error::System(crate::SystemError::Network(
            crate::NetworkError::RetryTimeoutError(_)
        ))
    ));
}
/// A transport-level failure of the join RPC reaches the caller as the same
/// `ServiceUnavailable` network error.
#[tokio::test]
async fn test_join_cluster_fails_on_rpc_failure() {
    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (mut context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);
    context.membership = Arc::new(MockMembership::new());

    let mut failing_transport = crate::MockTransport::new();
    failing_transport.expect_join_cluster().returning(|_, _, _, _| {
        let unavailable =
            crate::NetworkError::ServiceUnavailable("Service unavailable".to_string());
        Err(unavailable.into())
    });
    context.transport = Arc::new(failing_transport);

    let learner = LearnerState::<MockTypeConfig>::new(100, context.node_config.clone());
    learner.shared_state.set_current_leader(5);

    let joined = learner.join_cluster(&context).await;
    assert!(joined.is_err());
    assert!(matches!(
        joined.unwrap_err(),
        crate::Error::System(crate::SystemError::Network(
            crate::NetworkError::ServiceUnavailable(_)
        ))
    ));
}
/// A join response with `success: false` makes `join_cluster` fail with a
/// membership `JoinClusterFailed` consensus error.
#[tokio::test]
async fn test_join_cluster_fails_invalid_response() {
    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (mut context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);
    context.membership = Arc::new(MockMembership::new());

    let mut rejecting_transport = crate::MockTransport::new();
    rejecting_transport.expect_join_cluster().returning(|_, _, _, _| {
        let rejected = d_engine_proto::server::cluster::JoinResponse {
            success: false,
            error: "Node rejected".to_string(),
            config: None,
            config_version: 0,
            snapshot_metadata: None,
            leader_id: 0,
        };
        Ok(rejected)
    });
    context.transport = Arc::new(rejecting_transport);

    let learner = LearnerState::<MockTypeConfig>::new(100, context.node_config.clone());
    learner.shared_state.set_current_leader(5);

    let joined = learner.join_cluster(&context).await;
    assert!(joined.is_err());
    assert!(matches!(
        joined.unwrap_err(),
        crate::Error::Consensus(crate::ConsensusError::Membership(
            crate::MembershipError::JoinClusterFailed(_)
        ))
    ));
}
/// When the join RPC fails with a "Not leader" service error, `join_cluster`
/// returns an error. Only the fact that it is an error is asserted.
#[tokio::test]
async fn test_join_cluster_handles_redirect() {
    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (mut context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);
    context.membership = Arc::new(MockMembership::new());

    let mut redirecting_transport = crate::MockTransport::new();
    redirecting_transport.expect_join_cluster().returning(|_, _, _, _| {
        let not_leader = crate::NetworkError::ServiceUnavailable("Not leader".to_string());
        Err(not_leader.into())
    });
    context.transport = Arc::new(redirecting_transport);

    let learner = LearnerState::<MockTypeConfig>::new(100, context.node_config.clone());
    learner.shared_state.set_current_leader(5);

    let joined = learner.join_cluster(&context).await;
    assert!(joined.is_err(), "Should handle redirect scenario");
}
/// Discovery followed by a successful join. The membership mock here carries
/// no expectations, so the "large cluster" in the name is not modelled. This
/// path matches the plain discovery-then-join case.
#[tokio::test]
async fn test_join_cluster_succeeds_large_cluster() {
    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (mut context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);
    context.membership = Arc::new(MockMembership::new());

    let mut mock_transport = crate::MockTransport::new();
    mock_transport.expect_discover_leader().returning(|_, _, _| {
        let leader = d_engine_proto::server::cluster::LeaderDiscoveryResponse {
            leader_id: 5,
            leader_address: "127.0.0.1:5005".to_string(),
            term: 3,
        };
        Ok(vec![leader])
    });
    mock_transport.expect_join_cluster().returning(|_, _, _, _| {
        let accepted = d_engine_proto::server::cluster::JoinResponse {
            success: true,
            error: "".to_string(),
            config: None,
            config_version: 1,
            snapshot_metadata: None,
            leader_id: 3,
        };
        Ok(accepted)
    });
    context.transport = Arc::new(mock_transport);

    let learner = LearnerState::<MockTypeConfig>::new(100, context.node_config.clone());
    let joined = learner.join_cluster(&context).await;
    assert!(joined.is_ok(), "Should handle large cluster");
}
/// Once applied membership lists this node as an active Follower, the learner
/// requests `BecomeFollower(None)`.
#[tokio::test]
async fn test_learner_promotion_on_membership_applied() {
    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (mut context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);

    let promoted_meta = NodeMeta {
        id: 3,
        address: "127.0.0.1:8003".to_string(),
        role: NodeRole::Follower as i32,
        status: NodeStatus::Active as i32,
    };
    let mut mock_membership = MockMembership::<MockTypeConfig>::new();
    mock_membership
        .expect_retrieve_node_meta()
        .with(eq(3))
        .times(1)
        .return_once(move |_| Some(promoted_meta));
    context.membership = Arc::new(mock_membership);

    let mut learner = LearnerState::<MockTypeConfig>::new(3, context.node_config.clone());
    let (role_tx, mut role_rx) = mpsc::unbounded_channel();

    let outcome = learner
        .handle_raft_event(RaftEvent::MembershipApplied, &context, role_tx)
        .await;
    assert!(outcome.is_ok(), "MembershipApplied should succeed");

    let wait = std::time::Duration::from_millis(100);
    let role_event = tokio::time::timeout(wait, role_rx.recv())
        .await
        .expect("Should receive event within timeout")
        .expect("Channel should not be closed");

    match role_event {
        crate::RoleEvent::BecomeFollower(leader_id) => {
            assert_eq!(leader_id, None, "Should not specify leader on promotion");
        }
        other => panic!("Expected BecomeFollower event, got: {other:?}"),
    }
}
/// If applied membership still lists this node as a Learner, no role
/// transition is requested within a short observation window.
#[tokio::test]
async fn test_learner_stays_learner_on_membership_applied() {
    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (mut context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);

    let learner_meta = NodeMeta {
        id: 3,
        address: "127.0.0.1:8003".to_string(),
        role: NodeRole::Learner as i32,
        status: NodeStatus::Promotable as i32,
    };
    let mut mock_membership = MockMembership::<MockTypeConfig>::new();
    mock_membership
        .expect_retrieve_node_meta()
        .with(eq(3))
        .times(1)
        .return_once(move |_| Some(learner_meta));
    context.membership = Arc::new(mock_membership);

    let mut learner = LearnerState::<MockTypeConfig>::new(3, context.node_config.clone());
    let (role_tx, mut role_rx) = mpsc::unbounded_channel();

    let outcome = learner
        .handle_raft_event(RaftEvent::MembershipApplied, &context, role_tx)
        .await;
    assert!(outcome.is_ok(), "MembershipApplied should succeed");

    // A timeout, or a closed channel with no pending event, both mean "no transition".
    let wait = std::time::Duration::from_millis(100);
    if let Ok(Some(event)) = tokio::time::timeout(wait, role_rx.recv()).await {
        panic!("Should not send role transition when still Learner, got: {event:?}");
    }
}
/// If applied membership has no entry for this node, handling still succeeds
/// and no role transition is requested.
#[tokio::test]
async fn test_learner_node_not_found_on_membership_applied() {
    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (mut context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);

    let mut mock_membership = MockMembership::<MockTypeConfig>::new();
    mock_membership
        .expect_retrieve_node_meta()
        .with(eq(3))
        .times(1)
        .return_once(move |_| None);
    context.membership = Arc::new(mock_membership);

    let mut learner = LearnerState::<MockTypeConfig>::new(3, context.node_config.clone());
    let (role_tx, mut role_rx) = mpsc::unbounded_channel();

    let outcome = learner
        .handle_raft_event(RaftEvent::MembershipApplied, &context, role_tx)
        .await;
    assert!(
        outcome.is_ok(),
        "MembershipApplied should succeed even when node not found"
    );

    let wait = std::time::Duration::from_millis(100);
    if let Ok(Some(event)) = tokio::time::timeout(wait, role_rx.recv()).await {
        panic!("Should not send role transition when node not found, got: {event:?}");
    }
}
mod role_violation_tests {
use super::*;
#[tokio::test]
async fn test_learner_rejects_leader_only_events() {
let (_graceful_tx, graceful_rx) = watch::channel(());
let (context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);
let mut state = LearnerState::<MockTypeConfig>::new(1, context.node_config.clone());
let (role_tx, _role_rx) = mpsc::unbounded_channel();
let raft_event = RaftEvent::LogPurgeCompleted(LogId { term: 1, index: 1 });
let e = state.handle_raft_event(raft_event, &context, role_tx).await.unwrap_err();
assert!(
matches!(
e,
crate::Error::Consensus(crate::ConsensusError::RoleViolation { .. })
),
"LogPurgeCompleted should return RoleViolation"
);
}
#[tokio::test]
async fn test_learner_ignores_duplicate_create_snapshot_event() {
let (_graceful_tx, graceful_rx) = watch::channel(());
let (context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);
let mut state = LearnerState::<MockTypeConfig>::new(1, context.node_config.clone());
let (role_tx, _role_rx) = mpsc::unbounded_channel();
let result1 = state
.handle_raft_event(RaftEvent::CreateSnapshotEvent, &context, role_tx.clone())
.await;
assert!(result1.is_ok(), "First CreateSnapshotEvent should succeed");
assert!(
state.snapshot_in_progress.load(std::sync::atomic::Ordering::SeqCst),
"snapshot_in_progress should be true after first event"
);
let result2 =
state.handle_raft_event(RaftEvent::CreateSnapshotEvent, &context, role_tx).await;
assert!(
result2.is_ok(),
"Second CreateSnapshotEvent should return Ok (ignored)"
);
assert!(
state.snapshot_in_progress.load(std::sync::atomic::Ordering::SeqCst),
"snapshot_in_progress should remain true"
);
}
#[tokio::test]
async fn test_learner_resets_snapshot_flag_on_success() {
let (_graceful_tx, graceful_rx) = watch::channel(());
let (context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);
let mut state = LearnerState::<MockTypeConfig>::new(1, context.node_config.clone());
state.snapshot_in_progress.store(true, std::sync::atomic::Ordering::SeqCst);
let metadata = d_engine_proto::server::storage::SnapshotMetadata {
last_included: Some(LogId { term: 1, index: 50 }),
checksum: bytes::Bytes::new(),
};
let snapshot_result = Ok((metadata, std::path::PathBuf::from("/tmp/test_snapshot.bin")));
let (role_tx, _role_rx) = mpsc::unbounded_channel();
let result = state
.handle_raft_event(
RaftEvent::SnapshotCreated(snapshot_result),
&context,
role_tx,
)
.await;
assert!(result.is_ok(), "SnapshotCreated should succeed");
assert!(
!state.snapshot_in_progress.load(std::sync::atomic::Ordering::SeqCst),
"snapshot_in_progress should be false after SnapshotCreated"
);
}
#[tokio::test]
async fn test_learner_resets_snapshot_flag_on_failure() {
let (_graceful_tx, graceful_rx) = watch::channel(());
let (context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);
let mut state = LearnerState::<MockTypeConfig>::new(1, context.node_config.clone());
state.snapshot_in_progress.store(true, std::sync::atomic::Ordering::SeqCst);
let snapshot_result = Err(crate::Error::Fatal("Snapshot creation failed".to_string()));
let (role_tx, _role_rx) = mpsc::unbounded_channel();
let result = state
.handle_raft_event(
RaftEvent::SnapshotCreated(snapshot_result),
&context,
role_tx,
)
.await;
assert!(
result.is_ok(),
"SnapshotCreated with error should return Ok"
);
assert!(
!state.snapshot_in_progress.load(std::sync::atomic::Ordering::SeqCst),
"snapshot_in_progress should be false after failed SnapshotCreated"
);
}
}
#[tokio::test]
async fn test_learner_handles_fatal_error_returns_error() {
let (_graceful_tx, graceful_rx) = watch::channel(());
let context = mock_raft_context(
"/tmp/test_learner_handles_fatal_error_returns_error",
graceful_rx,
None,
);
let mut learner = LearnerState::<MockTypeConfig>::new(1, context.node_config.clone());
let fatal_error = RaftEvent::FatalError {
source: "StateMachine".to_string(),
error: "Disk failure".to_string(),
};
let (role_tx, mut role_rx) = mpsc::unbounded_channel::<RoleEvent>();
let result = learner.handle_raft_event(fatal_error, &context, role_tx).await;
assert!(
result.is_err(),
"Expected handle_raft_event to return Err, got: {result:?}"
);
match result.unwrap_err() {
Error::Fatal(msg) => {
assert!(
msg.contains("StateMachine"),
"Error message should mention source, got: {msg}"
);
}
other => panic!("Expected Error::Fatal, got: {other:?}"),
}
assert!(
role_rx.try_recv().is_err(),
"No role transition events should be sent during FatalError handling"
);
}
/// Eventual-consistency reads are served by the learner itself. The mocked
/// state machine handler is asked for `eventual_key` exactly once, and its
/// value must come back to the client as `ReadData`.
///
/// The latency check only guards against the read path blocking on something
/// slow. It is deliberately loose, because millisecond-scale wall-clock bounds
/// fail intermittently on shared or heavily loaded CI runners.
#[tokio::test]
async fn test_learner_serves_eventual_read_locally() {
    // Upper bound for answering a local read. It is generous on purpose: the
    // goal is "no blocking round-trip", not micro-benchmarking.
    const MAX_LOCAL_READ_LATENCY: std::time::Duration = std::time::Duration::from_millis(100);

    let (_graceful_tx, graceful_rx) = watch::channel(());
    let (mut context, _temp_dir) = mock_raft_context_with_temp(graceful_rx, None);

    let mut state_machine_handler = MockStateMachineHandler::new();
    state_machine_handler
        .expect_read_from_state_machine()
        .times(1)
        .withf(|keys| keys.len() == 1 && keys[0] == "eventual_key")
        .returning(|_| {
            Some(vec![d_engine_proto::client::ClientResult {
                key: bytes::Bytes::from("eventual_key"),
                value: bytes::Bytes::from("eventual_value"),
            }])
        });
    context.handlers.state_machine_handler = Arc::new(state_machine_handler);

    let mut state = LearnerState::<MockTypeConfig>::new(1, context.node_config.clone());
    let (response_tx, mut response_rx) = MaybeCloneOneshot::new();
    let read_req = d_engine_proto::client::ClientReadRequest {
        client_id: 1,
        keys: vec![bytes::Bytes::from("eventual_key")],
        consistency_policy: Some(
            d_engine_proto::client::ReadConsistencyPolicy::EventualConsistency as i32,
        ),
    };

    let start = tokio::time::Instant::now();
    state.push_client_cmd(ClientCmd::Read(read_req, response_tx), &context);
    let received = response_rx.recv().await;
    let elapsed = start.elapsed();

    let response = received.expect("Eventual read should return response");
    assert!(
        elapsed < MAX_LOCAL_READ_LATENCY,
        "Eventual read latency should be <{}ms, got {}ms",
        MAX_LOCAL_READ_LATENCY.as_millis(),
        elapsed.as_millis()
    );

    let read_response = match response {
        Ok(read_response) => read_response,
        Err(e) => panic!("Eventual read should succeed on Learner, got error: {e:?}"),
    };
    match read_response.success_result {
        Some(d_engine_proto::client::client_response::SuccessResult::ReadData(read_data)) => {
            let first = read_data
                .results
                .first()
                .expect("Should have read results");
            assert_eq!(first.value, bytes::Bytes::from("eventual_value"));
        }
        other => panic!("Expected ReadData variant, got: {other:?}"),
    }
}
/// If `should_snapshot` returns `true` for the learner's new commit data, the
/// learner re-queues exactly one `CreateSnapshotEvent` through
/// `RoleEvent::ReprocessEvent`.
#[tokio::test]
async fn test_apply_completed_triggers_snapshot_when_condition_met() {
    let (_graceful_tx, graceful_rx) = watch::channel(());

    let expected_commit = NewCommitData {
        new_commit_index: 100,
        role: NodeRole::Learner as i32,
        current_term: 1,
    };
    let mut sm_handler = MockStateMachineHandler::new();
    sm_handler
        .expect_should_snapshot()
        .with(eq(expected_commit))
        .times(1)
        .returning(|_| true);

    let context = MockBuilder::new(graceful_rx)
        .with_state_machine_handler(sm_handler)
        .build_context();
    let mut learner = LearnerState::<MockTypeConfig>::new(1, context.node_config.clone());
    let (role_tx, mut role_rx) = mpsc::unbounded_channel::<RoleEvent>();

    let applied = RaftEvent::ApplyCompleted {
        last_index: 100,
        results: vec![],
    };
    let result = learner.handle_raft_event(applied, &context, role_tx).await;
    assert!(
        result.is_ok(),
        "ApplyCompleted should be handled successfully, got: {result:?}"
    );

    match role_rx.try_recv().expect("Should receive snapshot event") {
        RoleEvent::ReprocessEvent(inner) => match *inner {
            RaftEvent::CreateSnapshotEvent => {}
            other => panic!("Expected CreateSnapshotEvent, got: {other:?}"),
        },
        other => panic!("Expected RoleEvent::ReprocessEvent, got: {other:?}"),
    }

    assert!(
        role_rx.try_recv().is_err(),
        "Should only send one snapshot event"
    );
}
/// If `should_snapshot` returns `false`, handling `ApplyCompleted` emits no
/// role events.
#[tokio::test]
async fn test_apply_completed_does_not_trigger_snapshot_when_condition_not_met() {
    let (_graceful_tx, graceful_rx) = watch::channel(());

    let expected_commit = NewCommitData {
        new_commit_index: 50,
        role: NodeRole::Learner as i32,
        current_term: 1,
    };
    let mut sm_handler = MockStateMachineHandler::new();
    sm_handler
        .expect_should_snapshot()
        .with(eq(expected_commit))
        .times(1)
        .returning(|_| false);

    let context = MockBuilder::new(graceful_rx)
        .with_state_machine_handler(sm_handler)
        .build_context();
    let mut learner = LearnerState::<MockTypeConfig>::new(1, context.node_config.clone());
    let (role_tx, mut role_rx) = mpsc::unbounded_channel::<RoleEvent>();

    let applied = RaftEvent::ApplyCompleted {
        last_index: 50,
        results: vec![],
    };
    let result = learner.handle_raft_event(applied, &context, role_tx).await;
    assert!(
        result.is_ok(),
        "ApplyCompleted should be handled successfully"
    );
    assert!(
        role_rx.try_recv().is_err(),
        "Should not send snapshot event when condition is not met"
    );
}
/// With `raft.snapshot.enable = false`, `ApplyCompleted` emits no snapshot
/// event. The state machine handler mock is given no expectations. Under
/// mockall semantics, any call into it (such as `should_snapshot`) would
/// therefore fail the test.
#[tokio::test]
async fn test_apply_completed_respects_snapshot_disabled_config() {
    let (_graceful_tx, graceful_rx) = watch::channel(());

    let mut config = node_config("/tmp/test_learner_snapshot_disabled");
    config.raft.snapshot.enable = false;

    let context = MockBuilder::new(graceful_rx)
        .with_state_machine_handler(MockStateMachineHandler::new())
        .with_node_config(config)
        .build_context();
    let mut learner = LearnerState::<MockTypeConfig>::new(1, context.node_config.clone());
    let (role_tx, mut role_rx) = mpsc::unbounded_channel::<RoleEvent>();

    let applied = RaftEvent::ApplyCompleted {
        last_index: 100,
        results: vec![],
    };
    let result = learner.handle_raft_event(applied, &context, role_tx).await;
    assert!(
        result.is_ok(),
        "ApplyCompleted should be handled successfully"
    );
    assert!(
        role_rx.try_recv().is_err(),
        "Should not send snapshot event when snapshot is disabled in config"
    );
}