use std::collections::HashMap;
use std::sync::Arc;
use arc_swap::ArcSwap;
use d_engine_core::MembershipError;
use d_engine_core::Result;
use d_engine_proto::server::cluster::NodeMeta;
use tokio::sync::Mutex;
use tracing::info;
pub struct MembershipGuard {
state: ArcSwap<InnerState>,
write_mutex: Mutex<()>,
}
#[derive(Clone)]
pub struct InnerState {
pub nodes: HashMap<u32, NodeMeta>,
pub cluster_conf_version: u64,
}
impl MembershipGuard {
pub fn new(
initial_nodes: Vec<NodeMeta>,
initial_version: u64,
) -> Self {
info!(
"Initializing membership: {:?}, version: {}",
initial_nodes, initial_version
);
let state = ArcSwap::new(Arc::new(InnerState {
nodes: initial_nodes.into_iter().map(|node| (node.id, node)).collect(),
cluster_conf_version: initial_version,
}));
Self {
state,
write_mutex: Mutex::new(()),
}
}
pub async fn blocking_read<R>(
&self,
f: impl FnOnce(&InnerState) -> R,
) -> R {
let state = self.state.load();
f(&state)
}
pub async fn blocking_write<R>(
&self,
f: impl FnOnce(&mut InnerState) -> R,
) -> R {
let _guard = self.write_mutex.lock().await;
let mut new_state = (**self.state.load()).clone();
let result = f(&mut new_state);
self.state.store(Arc::new(new_state));
result
}
#[allow(unused)]
pub async fn update_node(
&self,
node_id: u32,
f: impl FnOnce(&mut NodeMeta),
) -> Result<()> {
self.blocking_write(|state| {
if let Some(node) = state.nodes.get_mut(&node_id) {
f(node);
Ok(())
} else {
Err(MembershipError::NoMetadataFoundForNode { node_id }.into())
}
})
.await
}
#[allow(unused)]
pub async fn contains_node(
&self,
node_id: u32,
) -> bool {
self.blocking_read(|state| state.nodes.contains_key(&node_id)).await
}
}