use std::collections::BTreeSet;
use openraft::error::{ClientWriteError, RaftError};
use openraft::storage::RaftStateMachine;
use openraft::{Raft, RaftTypeConfig};
use thiserror::Error;
use tracing::info;
#[derive(Debug, Error)]
pub enum MembershipError {
#[error("not leader; redirect to {leader:?}")]
NotLeader { leader: Option<String> },
#[error("membership change conflict: {0}")]
Conflict(String),
#[error("raft error: {0}")]
Other(String),
}
/// Submits a membership change to `raft` and translates the outcome into
/// this module's [`MembershipError`].
///
/// `voters` and `retain` are passed through unchanged to
/// `Raft::change_membership`; see the openraft documentation for their exact
/// semantics. The successful response value is discarded, and an `info`
/// event is emitted on success.
///
/// # Errors
///
/// * [`MembershipError::NotLeader`] when openraft returns a
///   forward-to-leader error. `leader` is the `Debug` rendering of the
///   reported leader node, if any.
/// * [`MembershipError::Conflict`] when openraft returns a
///   change-membership error.
/// * [`MembershipError::Other`] for every other Raft error.
pub async fn change_membership<C, SM>(
    raft: &Raft<C, SM>,
    voters: BTreeSet<C::NodeId>,
    retain: bool,
) -> Result<(), MembershipError>
where
    C: RaftTypeConfig,
    SM: RaftStateMachine<C>,
{
    raft.change_membership(voters, retain)
        .await
        .map_err(|failure| match failure {
            RaftError::APIError(ClientWriteError::ForwardToLeader(redirect)) => {
                MembershipError::NotLeader {
                    leader: redirect
                        .leader_node
                        .as_ref()
                        .map(|node| format!("{node:?}")),
                }
            }
            RaftError::APIError(ClientWriteError::ChangeMembershipError(conflict)) => {
                MembershipError::Conflict(conflict.to_string())
            }
            other => MembershipError::Other(other.to_string()),
        })
        .map(|_response| {
            info!("tsoracle-openraft-toolkit: membership change applied");
        })
}
/// Adds a learner to the cluster through `raft` and translates the outcome
/// into this module's [`MembershipError`].
///
/// `id`, `node` and `blocking` are passed through unchanged to
/// `Raft::add_learner`; see the openraft documentation for their exact
/// semantics. The successful response value is discarded, and an `info`
/// event is emitted on success.
///
/// # Errors
///
/// * [`MembershipError::NotLeader`] when openraft returns a
///   forward-to-leader error. `leader` is the `Debug` rendering of the
///   reported leader node, if any.
/// * [`MembershipError::Conflict`] when openraft returns a
///   change-membership error.
/// * [`MembershipError::Other`] for every other Raft error.
pub async fn add_learner<C, SM>(
    raft: &Raft<C, SM>,
    id: C::NodeId,
    node: C::Node,
    blocking: bool,
) -> Result<(), MembershipError>
where
    C: RaftTypeConfig,
    SM: RaftStateMachine<C>,
{
    // Handle the success path first so the rest of the function deals
    // exclusively with error classification.
    let failure = match raft.add_learner(id, node, blocking).await {
        Ok(_) => {
            info!("tsoracle-openraft-toolkit: learner added");
            return Ok(());
        }
        Err(failure) => failure,
    };

    let mapped = match failure {
        RaftError::APIError(ClientWriteError::ChangeMembershipError(conflict)) => {
            MembershipError::Conflict(conflict.to_string())
        }
        RaftError::APIError(ClientWriteError::ForwardToLeader(redirect)) => {
            let leader = redirect
                .leader_node
                .as_ref()
                .map(|node| format!("{node:?}"));
            MembershipError::NotLeader { leader }
        }
        other => MembershipError::Other(other.to_string()),
    };
    Err(mapped)
}