use std::sync::Arc;
use std::sync::atomic::{AtomicU16, Ordering};
use std::time::{Duration, Instant};
use crate::error::{RepError, Result};
use crate::node_state::NodeState;
use crate::node_type::NodeType;
use crate::quorum_policy::QuorumPolicy;
use crate::rep_config::RepConfig;
use crate::replicated_environment::ReplicatedEnvironment;
use crate::state_change_listener::{StateChangeEvent, StateChangeListener};
/// First port of the window from which test groups draw their base ports.
const PORT_WINDOW_START: u16 = 40_000;
/// Once the allocation cursor would reach this port it wraps back towards the window start.
const PORT_WINDOW_END: u16 = 60_000;
/// Spare ports reserved behind every group, on top of one port per node.
const PORT_SLACK: u16 = 16;

/// Process-wide allocation cursor: the base port the next `alloc_base_port` call returns.
static NEXT_BASE_PORT: AtomicU16 = AtomicU16::new(PORT_WINDOW_START);

/// Reserves a block of ports for a group of `group_size` nodes and returns its first port.
///
/// Each call atomically advances the shared cursor by `group_size + 16` and returns the
/// value the cursor had before the advance, so concurrent callers receive distinct bases.
/// When the advanced cursor would reach 60_000 it is wrapped to `40_000 + span` instead.
///
/// Sizes that do not fit in `u16` are clamped rather than truncated, and a span too large
/// to be added to the window start resets the cursor to the window start instead of
/// overflowing `u16`.
fn alloc_base_port(group_size: usize) -> u16 {
    // Clamp instead of `as u16`, which would silently truncate sizes above `u16::MAX`.
    let span = u16::try_from(group_size)
        .unwrap_or(u16::MAX)
        .saturating_add(PORT_SLACK);

    // Computes where the cursor should point after handing out `span` ports.
    let advance = |cursor: u16| {
        let next = cursor.saturating_add(span);
        if next < PORT_WINDOW_END {
            next
        } else {
            // Wrap around; `checked_add` guards the `u16` overflow an oversized span
            // would otherwise cause here.
            PORT_WINDOW_START
                .checked_add(span)
                .unwrap_or(PORT_WINDOW_START)
        }
    };

    // The closure always yields `Some`, so both arms carry the pre-update cursor value.
    match NEXT_BASE_PORT.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |cursor| {
        Some(advance(cursor))
    }) {
        Ok(previous) | Err(previous) => previous,
    }
}
/// Bookkeeping for one member of a test replication group: the configuration it is
/// opened with, its numeric id, and the environment handle while one is open.
pub struct RepEnvInfo {
    /// Configuration cloned into every environment opened for this node.
    config: RepConfig,
    /// Numeric node id supplied at construction.
    node_id: u32,
    /// `Some` between a successful `open_env` and the next close; `None` otherwise.
    env: Option<Arc<ReplicatedEnvironment>>,
}

impl RepEnvInfo {
    /// Creates a node record that has no open environment yet.
    pub fn new(config: RepConfig, node_id: u32) -> Self {
        Self {
            env: None,
            node_id,
            config,
        }
    }

    /// Opens the replicated environment for this node and returns a handle to it.
    ///
    /// # Errors
    /// Returns `RepError::StateError` if an environment is already open, and forwards
    /// any error from `ReplicatedEnvironment::new`.
    pub fn open_env(&mut self) -> Result<Arc<ReplicatedEnvironment>> {
        if self.env.is_none() {
            let opened = Arc::new(ReplicatedEnvironment::new(self.config.clone())?);
            opened.init_self_weak();
            // Keep one handle here and give the caller its own.
            self.env = Some(Arc::clone(&opened));
            return Ok(opened);
        }
        Err(RepError::StateError("rep env already exists".to_string()))
    }

    /// Closes the environment if one is open; does nothing otherwise.
    ///
    /// The handle is removed from this record before `close` is called, so the record
    /// reports no environment afterwards even when `close` fails.
    ///
    /// # Errors
    /// Forwards the error returned by the environment's `close`.
    pub fn close_env(&mut self) -> Result<()> {
        match self.env.take() {
            None => Ok(()),
            Some(open) => {
                open.close()?;
                Ok(())
            }
        }
    }

    /// Drops this record's handle without calling `close` on the environment.
    pub fn abnormal_close_env(&mut self) {
        self.env = None;
    }

    /// Returns a new handle to the open environment.
    ///
    /// # Panics
    /// Panics if no environment is currently held (never opened, or already closed).
    pub fn get_env(&self) -> Arc<ReplicatedEnvironment> {
        let held = self.env.as_ref().expect("open_env not called yet");
        Arc::clone(held)
    }

    /// Borrows the environment handle, or `None` when no environment is held.
    pub fn env(&self) -> Option<&Arc<ReplicatedEnvironment>> {
        self.env.as_ref()
    }

    /// Configuration this node was created with.
    pub fn rep_config(&self) -> &RepConfig {
        &self.config
    }

    /// Node name taken from the configuration.
    pub fn node_name(&self) -> &str {
        &self.config.node_name
    }

    /// Numeric node id supplied at construction.
    pub fn node_id(&self) -> u32 {
        self.node_id
    }

    /// `true` when an environment is held and it reports `NodeState::Master`.
    pub fn is_master(&self) -> bool {
        matches!(self.state(), Some(NodeState::Master))
    }

    /// `true` when an environment is held and it reports `NodeState::Replica`.
    pub fn is_replica(&self) -> bool {
        matches!(self.state(), Some(NodeState::Replica))
    }

    /// `true` when an environment is held and it reports `NodeState::Unknown`.
    ///
    /// A node without an environment yields `false`, not `true`.
    pub fn is_unknown(&self) -> bool {
        matches!(self.state(), Some(NodeState::Unknown))
    }

    /// State reported by the held environment, or `None` when no environment is held.
    pub fn state(&self) -> Option<NodeState> {
        let held = self.env.as_ref()?;
        Some(held.get_state())
    }

    /// Current VLSN reported by the held environment, or `0` when none is held.
    pub fn current_vlsn(&self) -> u64 {
        self.env.as_ref().map_or(0, |held| held.get_current_vlsn())
    }
}
pub struct RepTestBase {
group_name: String,
nodes: Vec<RepEnvInfo>,
next_term: std::cell::Cell<u64>,
}
impl RepTestBase {
pub fn builder(group_name: impl Into<String>) -> RepTestBaseBuilder {
RepTestBaseBuilder::new(group_name)
}
pub fn group_size(&self) -> usize {
self.nodes.len()
}
pub fn node(&self, idx: usize) -> &RepEnvInfo {
&self.nodes[idx]
}
pub fn node_mut(&mut self, idx: usize) -> &mut RepEnvInfo {
&mut self.nodes[idx]
}
pub fn nodes(&self) -> &[RepEnvInfo] {
&self.nodes
}
pub fn nodes_mut(&mut self) -> &mut [RepEnvInfo] {
&mut self.nodes
}
pub fn group_name(&self) -> &str {
&self.group_name
}
pub fn create_group(&mut self, term: u64) -> Result<()> {
self.create_group_of_size(self.nodes.len(), term)
}
pub fn create_group_of_size(
&mut self,
first_n: usize,
term: u64,
) -> Result<()> {
if first_n == 0 || first_n > self.nodes.len() {
return Err(RepError::ConfigError(format!(
"first_n ({first_n}) must be in 1..={}",
self.nodes.len()
)));
}
for node in &mut self.nodes[..first_n] {
if node.env.is_none() {
node.open_env()?;
}
}
let peer_specs: Vec<crate::rep_node::RepNode> = self.nodes[..first_n]
.iter()
.map(|n| {
crate::rep_node::RepNode::new(
n.config.node_name.clone(),
n.config.node_type,
n.config.node_host.clone(),
n.config.node_port,
n.node_id,
)
})
.collect();
for node in &self.nodes[..first_n] {
let env = node.get_env();
for peer in &peer_specs {
if peer.name == node.config.node_name {
continue;
}
let _ = env.add_peer(peer.clone());
}
}
self.nodes[0].get_env().become_master(term)?;
let master_name = self.nodes[0].config.node_name.clone();
for node in &self.nodes[1..first_n] {
node.get_env().become_replica(&master_name)?;
}
self.next_term.set(term + 1);
Ok(())
}
pub fn shutdown_all(&mut self) {
let mut master_idx: Option<usize> = None;
for (idx, node) in self.nodes.iter_mut().enumerate() {
if node.is_master() {
master_idx = Some(idx);
continue;
}
let _ = node.close_env();
}
if let Some(idx) = master_idx {
let _ = self.nodes[idx].close_env();
}
}
pub fn find_master(&self) -> Option<&RepEnvInfo> {
self.nodes.iter().find(|n| n.is_master())
}
pub fn find_master_mut(&mut self) -> Option<&mut RepEnvInfo> {
self.nodes.iter_mut().find(|n| n.is_master())
}
pub fn find_master_idx(&self) -> Option<usize> {
self.nodes.iter().position(|n| n.is_master())
}
pub fn replicas(&self) -> Vec<&RepEnvInfo> {
self.nodes.iter().filter(|n| n.is_replica()).collect()
}
pub fn await_master(&self, timeout: Duration) -> Result<usize> {
let deadline = Instant::now() + timeout;
loop {
if let Some(idx) = self.find_master_idx() {
return Ok(idx);
}
if Instant::now() >= deadline {
return Err(RepError::StateError(format!(
"timeout: no master after {:?}",
timeout
)));
}
std::thread::sleep(Duration::from_millis(20));
}
}
pub fn await_state(
&self,
idx: usize,
target: NodeState,
timeout: Duration,
) -> Result<()> {
let deadline = Instant::now() + timeout;
loop {
if self.nodes[idx].state() == Some(target) {
return Ok(());
}
if Instant::now() >= deadline {
return Err(RepError::StateError(format!(
"timeout: node {} did not reach {:?} after {:?} (current: {:?})",
idx,
target,
timeout,
self.nodes[idx].state(),
)));
}
std::thread::sleep(Duration::from_millis(20));
}
}
pub fn await_vlsn_at_least(
&self,
idx: usize,
vlsn: u64,
timeout: Duration,
) -> Result<()> {
let deadline = Instant::now() + timeout;
loop {
if self.nodes[idx].current_vlsn() >= vlsn {
return Ok(());
}
if Instant::now() >= deadline {
return Err(RepError::StateError(format!(
"timeout: node {} did not reach VLSN {} after {:?} (current: {})",
idx,
vlsn,
timeout,
self.nodes[idx].current_vlsn(),
)));
}
std::thread::sleep(Duration::from_millis(20));
}
}
pub fn replicate_one(
&self,
vlsn: u64,
file: u32,
offset: u32,
entry_type: u8,
) -> Result<()> {
let master_idx = self.find_master_idx().ok_or_else(|| {
RepError::StateError("no master to replicate from".to_string())
})?;
let master = self.nodes[master_idx].get_env();
master.register_vlsn(vlsn, file, offset);
for (i, node) in self.nodes.iter().enumerate() {
if i == master_idx || !node.is_replica() {
continue;
}
node.get_env().apply_entry(vlsn, entry_type, vec![0u8; 8])?;
}
Ok(())
}
pub fn populate_db(&self, start_vlsn: u64, count: u64) -> Result<()> {
for offset in 0..count {
let vlsn = start_vlsn + offset;
self.replicate_one(vlsn, 0, (vlsn as u32).wrapping_mul(16), 0)?;
}
Ok(())
}
pub fn populate_master_only(
&self,
start_vlsn: u64,
count: u64,
) -> Result<()> {
let master = self.find_master().ok_or_else(|| {
RepError::StateError("no master to populate".to_string())
})?;
for offset in 0..count {
let vlsn = start_vlsn + offset;
master.get_env().register_vlsn(
vlsn,
0,
(vlsn as u32).wrapping_mul(16),
);
}
Ok(())
}
pub fn catch_up_replica(
&self,
replica_idx: usize,
start_vlsn: u64,
count: u64,
) -> Result<()> {
let env = self.nodes[replica_idx].get_env();
for offset in 0..count {
let vlsn = start_vlsn + offset;
env.apply_entry(vlsn, 0, vec![0u8; 8])?;
}
Ok(())
}
pub fn close_master(&mut self) -> Result<usize> {
let idx = self.find_master_idx().ok_or_else(|| {
RepError::StateError("no master to close".to_string())
})?;
self.nodes[idx].close_env()?;
Ok(idx)
}
pub fn failover_to(&mut self, replica_idx: usize) -> Result<()> {
let term = self.next_term.get();
self.next_term.set(term + 1);
let target_env = self.nodes[replica_idx].get_env();
target_env.ensure_unknown_state()?;
target_env.become_master(term)?;
let new_master_name = self.nodes[replica_idx].config.node_name.clone();
for (i, node) in self.nodes.iter().enumerate() {
if i == replica_idx {
continue;
}
if node.env.is_none() {
continue;
}
let env = node.get_env();
let s = env.get_state();
if matches!(s, NodeState::Detached | NodeState::Shutdown) {
continue;
}
env.ensure_unknown_state()?;
env.become_replica(&new_master_name)?;
}
Ok(())
}
pub fn assert_all_at_vlsn(&self, vlsn: u64) {
for node in &self.nodes {
if !(node.is_master() || node.is_replica()) {
continue;
}
assert_eq!(
node.current_vlsn(),
vlsn,
"node {} ({:?}) at unexpected VLSN",
node.node_name(),
node.state(),
);
}
}
pub fn assert_state(&self, idx: usize, state: NodeState) {
assert_eq!(
self.nodes[idx].state(),
Some(state),
"node {} ({}) wrong state",
idx,
self.nodes[idx].node_name(),
);
}
}
impl Drop for RepTestBase {
fn drop(&mut self) {
self.shutdown_all();
}
}
pub struct RepTestBaseBuilder {
group_name: String,
group_size: usize,
base_port: Option<u16>,
node_type: NodeType,
election_timeout: Option<Duration>,
quorum_policy: Option<QuorumPolicy>,
name_prefix: Option<String>,
node_type_overrides: Vec<(usize, NodeType)>,
}
impl RepTestBaseBuilder {
fn new(group_name: impl Into<String>) -> Self {
Self {
group_name: group_name.into(),
group_size: 3,
base_port: None,
node_type: NodeType::Electable,
election_timeout: None,
quorum_policy: None,
name_prefix: None,
node_type_overrides: Vec::new(),
}
}
pub fn group_size(mut self, n: usize) -> Self {
self.group_size = n;
self
}
pub fn base_port(mut self, p: u16) -> Self {
self.base_port = Some(p);
self
}
pub fn node_type(mut self, t: NodeType) -> Self {
self.node_type = t;
self
}
pub fn override_node_type(mut self, idx: usize, t: NodeType) -> Self {
self.node_type_overrides.push((idx, t));
self
}
pub fn election_timeout(mut self, t: Duration) -> Self {
self.election_timeout = Some(t);
self
}
pub fn quorum_policy(mut self, q: QuorumPolicy) -> Self {
self.quorum_policy = Some(q);
self
}
pub fn name_prefix(mut self, p: impl Into<String>) -> Self {
self.name_prefix = Some(p.into());
self
}
pub fn build(self) -> RepTestBase {
let base_port =
self.base_port.unwrap_or_else(|| alloc_base_port(self.group_size));
let prefix = self
.name_prefix
.unwrap_or_else(|| format!("{}_n", self.group_name));
let mut overrides = std::collections::HashMap::new();
for (idx, t) in self.node_type_overrides {
overrides.insert(idx, t);
}
let mut nodes = Vec::with_capacity(self.group_size);
for i in 0..self.group_size {
let node_name = format!("{}{}", prefix, i + 1);
let node_type = *overrides.get(&i).unwrap_or(&self.node_type);
let port = base_port + i as u16;
let mut b =
RepConfig::builder(&self.group_name, &node_name, "127.0.0.1")
.node_port(port)
.node_type(node_type);
if let Some(t) = self.election_timeout {
b = b.election_timeout(t);
}
if let Some(q) = self.quorum_policy.clone() {
b = b.quorum_policy(q);
}
let config = b.build();
nodes.push(RepEnvInfo::new(config, (i + 1) as u32));
}
RepTestBase {
group_name: self.group_name,
nodes,
next_term: std::cell::Cell::new(1),
}
}
}
/// Set of counters, one per node state, all starting at zero.
///
/// The `StateChangeListener` implementation for this type adds one to the counter
/// matching each reported new state.
#[derive(Default)]
pub struct CountingListener {
    /// Counter for `NodeState::Master`.
    pub master: std::sync::atomic::AtomicUsize,
    /// Counter for `NodeState::Replica`.
    pub replica: std::sync::atomic::AtomicUsize,
    /// Counter for `NodeState::Unknown`.
    pub unknown: std::sync::atomic::AtomicUsize,
    /// Counter for `NodeState::Detached`.
    pub detached: std::sync::atomic::AtomicUsize,
    /// Counter for `NodeState::Shutdown`.
    pub shutdown: std::sync::atomic::AtomicUsize,
}

impl CountingListener {
    /// Creates a shared listener with every counter at zero.
    pub fn new() -> Arc<Self> {
        Arc::default()
    }

    /// Reads one counter with sequentially consistent ordering.
    fn read(counter: &std::sync::atomic::AtomicUsize) -> usize {
        counter.load(Ordering::SeqCst)
    }

    /// Current value of the `master` counter.
    pub fn master_count(&self) -> usize {
        Self::read(&self.master)
    }

    /// Current value of the `replica` counter.
    pub fn replica_count(&self) -> usize {
        Self::read(&self.replica)
    }

    /// Current value of the `unknown` counter.
    pub fn unknown_count(&self) -> usize {
        Self::read(&self.unknown)
    }

    /// Current value of the `detached` counter.
    pub fn detached_count(&self) -> usize {
        Self::read(&self.detached)
    }

    /// Current value of the `shutdown` counter.
    pub fn shutdown_count(&self) -> usize {
        Self::read(&self.shutdown)
    }
}
impl StateChangeListener for CountingListener {
    /// Adds one to the counter that corresponds to `ev.new_state`.
    fn on_state_change(&self, ev: StateChangeEvent) {
        let bump = |slot: &std::sync::atomic::AtomicUsize| {
            slot.fetch_add(1, Ordering::SeqCst);
        };
        match ev.new_state {
            NodeState::Master => bump(&self.master),
            NodeState::Replica => bump(&self.replica),
            NodeState::Unknown => bump(&self.unknown),
            NodeState::Detached => bump(&self.detached),
            NodeState::Shutdown => bump(&self.shutdown),
        }
    }
}
// Unit tests for the harness: builder output, group formation, replication,
// failover, master polling, listener counting and replica catch-up.
#[cfg(test)]
mod tests {
use super::*;
// The builder creates the requested number of nodes, names them
// `<group>_n<1-based index>` and assigns consecutive ports.
#[test]
fn builder_produces_n_nodes_with_disjoint_names() {
let group = RepTestBase::builder("hs1").group_size(4).build();
assert_eq!(group.group_size(), 4);
let names: Vec<&str> =
group.nodes().iter().map(|n| n.node_name()).collect();
assert_eq!(names, vec!["hs1_n1", "hs1_n2", "hs1_n3", "hs1_n4"]);
let ports: Vec<u16> =
group.nodes().iter().map(|n| n.rep_config().node_port).collect();
for w in ports.windows(2) {
assert!(w[1] == w[0] + 1, "ports must be consecutive: {:?}", ports);
}
}
// `create_group` makes node 0 the master and every other node a replica.
#[test]
fn create_group_elects_master_and_replicas() {
let mut group = RepTestBase::builder("hs2").group_size(3).build();
group.create_group(1).unwrap();
assert!(group.nodes()[0].is_master(), "node 0 must be master");
assert!(group.nodes()[1].is_replica(), "node 1 must be replica");
assert!(group.nodes()[2].is_replica(), "node 2 must be replica");
let m = group.find_master().unwrap();
assert_eq!(m.node_name(), "hs2_n1");
}
// Replicating VLSNs 1..=50 leaves the master and all replicas at VLSN 50.
#[test]
fn populate_db_advances_all_replicas() {
let mut group = RepTestBase::builder("hs3").group_size(3).build();
group.create_group(1).unwrap();
group.populate_db(1, 50).unwrap();
group.assert_all_at_vlsn(50);
}
// After the original master (node 0) is closed, failing over to node 1 makes it
// master, keeps node 2 a replica, and node 1 still reports at least VLSN 10.
#[test]
fn failover_drives_replica_to_master() {
let mut group = RepTestBase::builder("hs4").group_size(3).build();
group.create_group(1).unwrap();
group.populate_db(1, 10).unwrap();
group.assert_all_at_vlsn(10);
let old_master = group.close_master().unwrap();
assert_eq!(old_master, 0);
group.failover_to(1).unwrap();
assert!(group.nodes()[1].is_master());
assert!(group.nodes()[2].is_replica());
assert!(group.nodes()[1].current_vlsn() >= 10);
}
// With a master already in place, `await_master` returns its index.
#[test]
fn await_master_finds_already_elected_master() {
let mut group = RepTestBase::builder("hs5").group_size(3).build();
group.create_group(1).unwrap();
let idx = group.await_master(Duration::from_millis(200)).unwrap();
assert_eq!(idx, 0);
}
// No environment is ever opened here, so no node can report itself as master
// and `await_master` must give up with an error.
#[test]
fn await_master_times_out_when_no_master() {
let group = RepTestBase::builder("hs6").group_size(3).build();
let r = group.await_master(Duration::from_millis(50));
assert!(r.is_err(), "must time out");
}
// The listener is installed only after the group has formed.
// NOTE(review): the expected count of 1 presumably relies on
// `set_state_change_listener` reporting the current state at registration —
// confirm against `ReplicatedEnvironment`.
#[test]
fn counting_listener_counts_transitions() {
let mut group = RepTestBase::builder("hs7").group_size(2).build();
group.create_group(1).unwrap();
let listener = CountingListener::new();
group.nodes()[0]
.get_env()
.set_state_change_listener(
Arc::clone(&listener) as Arc<dyn StateChangeListener>
);
assert_eq!(listener.master_count(), 1);
}
// VLSNs 6..=15 are registered on the master only, so the replica lags at 5
// until `catch_up_replica` applies the same range to it.
#[test]
fn catch_up_replica_after_partition() {
let mut group = RepTestBase::builder("hs8").group_size(2).build();
group.create_group(1).unwrap();
group.populate_db(1, 5).unwrap();
group.assert_all_at_vlsn(5);
group.populate_master_only(6, 10).unwrap();
assert_eq!(group.nodes()[0].current_vlsn(), 15);
assert_eq!(group.nodes()[1].current_vlsn(), 5);
group.catch_up_replica(1, 6, 10).unwrap();
group.assert_all_at_vlsn(15);
}
}