use std::collections::BTreeSet;
use openraft::error::{ClientWriteError, RaftError};
use openraft::storage::RaftStateMachine;
use openraft::{Raft, RaftTypeConfig};
use thiserror::Error;
use tracing::info;
#[derive(Debug, Error)]
pub enum MembershipError {
#[error("not leader; redirect to {leader:?}")]
NotLeader { leader: Option<String> },
#[error("membership change conflict: {0}")]
Conflict(String),
#[error("raft error: {0}")]
Other(String),
}
/// Replaces the cluster's voter set with `voters` through openraft's `Raft::change_membership`.
///
/// `retain` is forwarded unchanged to openraft. Per openraft's documentation, it decides
/// whether voters dropped from the set stay on as learners. A successful change is logged at
/// `info` level.
///
/// # Errors
///
/// - [`MembershipError::NotLeader`] when openraft reports `ForwardToLeader`; carries the
///   `Debug` rendering of the leader node when openraft knew it.
/// - [`MembershipError::Conflict`] when openraft returns a `ChangeMembershipError`.
/// - [`MembershipError::Other`] for every other raft error.
pub async fn change_membership<C, SM>(
    raft: &Raft<C, SM>,
    voters: BTreeSet<C::NodeId>,
    retain: bool,
) -> Result<(), MembershipError>
where
    C: RaftTypeConfig,
    SM: RaftStateMachine<C>,
{
    let failure = match raft.change_membership(voters, retain).await {
        Ok(_) => {
            info!("tsoracle-openraft-toolkit: membership change applied");
            return Ok(());
        }
        Err(failure) => failure,
    };

    // Flatten openraft's nested error into the module's string-based error type.
    let mapped = match failure {
        RaftError::APIError(ClientWriteError::ForwardToLeader(redirect)) => {
            MembershipError::NotLeader {
                leader: redirect
                    .leader_node
                    .as_ref()
                    .map(|leader| format!("{leader:?}")),
            }
        }
        RaftError::APIError(ClientWriteError::ChangeMembershipError(rejection)) => {
            MembershipError::Conflict(rejection.to_string())
        }
        other => MembershipError::Other(other.to_string()),
    };
    Err(mapped)
}
/// Adds node `id`, described by `node`, to the cluster as a learner through openraft's
/// `Raft::add_learner`.
///
/// `blocking` is forwarded unchanged to openraft. Per openraft's documentation, `true` waits
/// for the learner to catch up before returning. A successful addition is logged at `info`
/// level.
///
/// # Errors
///
/// - [`MembershipError::NotLeader`] when openraft reports `ForwardToLeader`; carries the
///   `Debug` rendering of the leader node when openraft knew it.
/// - [`MembershipError::Conflict`] when openraft returns a `ChangeMembershipError`.
/// - [`MembershipError::Other`] for every other raft error.
pub async fn add_learner<C, SM>(
    raft: &Raft<C, SM>,
    id: C::NodeId,
    node: C::Node,
    blocking: bool,
) -> Result<(), MembershipError>
where
    C: RaftTypeConfig,
    SM: RaftStateMachine<C>,
{
    let outcome = raft.add_learner(id, node, blocking).await;
    if let Err(failure) = outcome {
        // Flatten openraft's nested error into the module's string-based error type.
        return Err(match failure {
            RaftError::APIError(ClientWriteError::ForwardToLeader(redirect)) => {
                MembershipError::NotLeader {
                    leader: redirect
                        .leader_node
                        .as_ref()
                        .map(|leader| format!("{leader:?}")),
                }
            }
            RaftError::APIError(ClientWriteError::ChangeMembershipError(rejection)) => {
                MembershipError::Conflict(rejection.to_string())
            }
            other => MembershipError::Other(other.to_string()),
        });
    }
    info!("tsoracle-openraft-toolkit: learner added");
    Ok(())
}
#[derive(Debug, Error, PartialEq, Eq)]
pub enum JoinerGateError {
#[error(
"candidate cannot read the cluster active write version: \
candidate max_readable_version={candidate_max_readable_version}, \
active_write_version={active_write_version}"
)]
IncompatibleCandidate {
candidate_max_readable_version: u8,
active_write_version: u8,
},
}
/// Version gate for a joining node. It admits the candidate only if it can read what the
/// cluster is currently writing.
///
/// The candidate passes when `active_write_version <= candidate_max_readable_version`.
/// Equal versions pass.
///
/// # Errors
///
/// Returns [`JoinerGateError::IncompatibleCandidate`], carrying both inputs unchanged, when
/// `candidate_max_readable_version < active_write_version`.
pub fn learner_meets_active_write_version(
    candidate_max_readable_version: u8,
    active_write_version: u8,
) -> Result<(), JoinerGateError> {
    let can_read_active = active_write_version <= candidate_max_readable_version;
    if can_read_active {
        Ok(())
    } else {
        Err(JoinerGateError::IncompatibleCandidate {
            candidate_max_readable_version,
            active_write_version,
        })
    }
}
/// Error returned by [`add_learner_gated`]: either the version gate refused the candidate, or
/// the learner addition that followed it failed.
///
/// Both variants are `#[error(transparent)]`, so `Display` and `source()` are delegated to the
/// wrapped error. The `#[from]` conversions let `?` lift either inner error into this type.
#[derive(Debug, Error)]
pub enum GatedAdmissionError {
    /// The candidate failed [`learner_meets_active_write_version`]. In
    /// [`add_learner_gated`], raft is never contacted in this case.
    #[error(transparent)]
    Gate(#[from] JoinerGateError),
    /// The gate passed but [`add_learner`] failed.
    #[error(transparent)]
    Membership(#[from] MembershipError),
}
pub async fn add_learner_gated<C, SM>(
raft: &Raft<C, SM>,
id: C::NodeId,
node: C::Node,
blocking: bool,
candidate_max_readable_version: u8,
active_write_version: u8,
) -> Result<(), GatedAdmissionError>
where
C: RaftTypeConfig,
SM: RaftStateMachine<C>,
{
learner_meets_active_write_version(candidate_max_readable_version, active_write_version)?;
add_learner(raft, id, node, blocking).await?;
Ok(())
}
/// Unit tests for the joiner version gate and the admission error wrapper.
#[cfg(test)]
mod joiner_gate_tests {
    use super::*;

    #[test]
    fn admits_when_candidate_can_read_active_version() {
        assert_eq!(learner_meets_active_write_version(3, 3), Ok(()));
        assert_eq!(learner_meets_active_write_version(4, 3), Ok(()));
        // Edges of the u8 version space.
        assert_eq!(learner_meets_active_write_version(0, 0), Ok(()));
        assert_eq!(learner_meets_active_write_version(u8::MAX, u8::MAX), Ok(()));
        assert_eq!(learner_meets_active_write_version(u8::MAX, 0), Ok(()));
    }

    #[test]
    fn refuses_when_candidate_below_active_version() {
        assert_eq!(
            learner_meets_active_write_version(3, 4),
            Err(JoinerGateError::IncompatibleCandidate {
                candidate_max_readable_version: 3,
                active_write_version: 4,
            }),
            "a candidate that cannot read the active version must be refused"
        );
        assert_eq!(
            learner_meets_active_write_version(0, u8::MAX),
            Err(JoinerGateError::IncompatibleCandidate {
                candidate_max_readable_version: 0,
                active_write_version: u8::MAX,
            })
        );
    }

    #[test]
    fn refusal_message_is_actionable() {
        let message = learner_meets_active_write_version(2, 5)
            .unwrap_err()
            .to_string();
        // Match the labelled values rather than bare digits, so a stray '2' or '5' elsewhere
        // in the text cannot satisfy the check.
        assert!(
            message.contains("max_readable_version=2"),
            "names the candidate capability: {message}"
        );
        assert!(
            message.contains("active_write_version=5"),
            "names the active write version: {message}"
        );
    }

    #[test]
    fn gated_admission_error_unifies_gate_and_membership_failures() {
        let refusal = JoinerGateError::IncompatibleCandidate {
            candidate_max_readable_version: 3,
            active_write_version: 4,
        };
        let refusal_text = refusal.to_string();
        let gate: GatedAdmissionError = refusal.into();
        assert!(matches!(gate, GatedAdmissionError::Gate(_)));
        // `#[error(transparent)]` must forward Display unchanged.
        assert_eq!(gate.to_string(), refusal_text);

        let not_leader = MembershipError::NotLeader { leader: None };
        let not_leader_text = not_leader.to_string();
        let membership: GatedAdmissionError = not_leader.into();
        assert!(matches!(membership, GatedAdmissionError::Membership(_)));
        assert_eq!(membership.to_string(), not_leader_text);
    }
}