use std::{
collections::HashMap,
future::Future,
sync::{atomic::Ordering, Arc},
time::{Duration, Instant},
};
use agnostic::Runtime;
use bytes::Bytes;
use futures::{FutureExt, Stream, StreamExt};
use super::{
base::Memberlist,
delegate::{Delegate, VoidDelegate},
error::{Error, JoinError},
network::META_MAX_SIZE,
state::AckMessage,
transport::{AddressResolver, CheapClone, MaybeResolvedAddress, Node, Transport},
types::{Alive, Dead, Message, Meta, NodeState, Ping, SmallVec},
Options,
};
impl<T, D> Memberlist<T, D>
where
    D: Delegate<Id = T::Id, Address = <T::Resolver as AddressResolver>::ResolvedAddress>,
    T: Transport,
{
    /// Returns the ID of the local node.
    #[inline]
    pub fn local_id(&self) -> &T::Id {
        &self.inner.id
    }

    /// Returns the (possibly unresolved) address the local transport is bound to.
    #[inline]
    pub fn local_addr(&self) -> &<T::Resolver as AddressResolver>::Address {
        self.inner.transport.local_address()
    }

    /// Returns the node this instance advertises to peers
    /// (local ID paired with the resolved advertise address).
    #[inline]
    pub fn advertise_node(&self) -> Node<T::Id, <T::Resolver as AddressResolver>::ResolvedAddress> {
        Node::new(self.inner.id.clone(), self.inner.advertise.clone())
    }

    /// Returns the resolved address this node advertises to peers.
    #[inline]
    pub fn advertise_address(&self) -> &<T::Resolver as AddressResolver>::ResolvedAddress {
        &self.inner.advertise
    }

    /// Returns the user delegate, if one was configured.
    #[inline]
    pub fn delegate(&self) -> Option<&D> {
        self.delegate.as_deref()
    }

    /// Returns the state of the local node.
    ///
    /// # Panics
    ///
    /// Panics if the local node is missing from the member list, which would
    /// violate an internal invariant (the local node is inserted at creation
    /// and never removed while this instance is alive).
    #[inline]
    pub async fn local_state(
        &self,
    ) -> Arc<NodeState<T::Id, <T::Resolver as AddressResolver>::ResolvedAddress>> {
        let nodes = self.inner.nodes.read().await;
        nodes
            .node_map
            .get(&self.inner.id)
            .map(|&idx| nodes.nodes[idx].state.server.clone())
            // A miss here is a logic bug, not a recoverable condition.
            .expect("local node must be registered in the member list")
    }

    /// Returns the state of the node with the given ID, or `None` if the node
    /// is not a member.
    pub async fn by_id(
        &self,
        id: &T::Id,
    ) -> Option<Arc<NodeState<T::Id, <T::Resolver as AddressResolver>::ResolvedAddress>>> {
        let members = self.inner.nodes.read().await;
        members
            .node_map
            .get(id)
            .map(|&idx| members.nodes[idx].state.server.clone())
    }

    /// Returns all known members, including nodes that are dead or have left.
    #[inline]
    pub async fn members(
        &self,
    ) -> SmallVec<Arc<NodeState<T::Id, <T::Resolver as AddressResolver>::ResolvedAddress>>> {
        self
            .inner
            .nodes
            .read()
            .await
            .nodes
            .iter()
            .map(|n| n.state.server.clone())
            .collect()
    }

    /// Returns the total number of known members, including dead/left nodes.
    #[inline]
    pub async fn num_members(&self) -> usize {
        self.inner.nodes.read().await.nodes.len()
    }

    /// Returns the members that are neither dead nor have left the cluster.
    pub async fn online_members(
        &self,
    ) -> SmallVec<Arc<NodeState<T::Id, <T::Resolver as AddressResolver>::ResolvedAddress>>> {
        self
            .inner
            .nodes
            .read()
            .await
            .nodes
            .iter()
            .filter(|n| !n.dead_or_left())
            .map(|n| n.state.server.clone())
            .collect()
    }

    /// Returns the number of members that are neither dead nor have left.
    pub async fn num_online_members(&self) -> usize {
        self
            .inner
            .nodes
            .read()
            .await
            .nodes
            .iter()
            .filter(|n| !n.dead_or_left())
            .count()
    }

    /// Returns the members for which the predicate `f` returns `true`.
    pub async fn members_by(
        &self,
        mut f: impl FnMut(&NodeState<T::Id, <T::Resolver as AddressResolver>::ResolvedAddress>) -> bool,
    ) -> SmallVec<Arc<NodeState<T::Id, <T::Resolver as AddressResolver>::ResolvedAddress>>> {
        self
            .inner
            .nodes
            .read()
            .await
            .nodes
            .iter()
            .filter(|n| f(&n.state))
            .map(|n| n.state.server.clone())
            .collect()
    }

    /// Returns the number of members for which the predicate `f` returns `true`.
    pub async fn num_members_by(
        &self,
        mut f: impl FnMut(&NodeState<T::Id, <T::Resolver as AddressResolver>::ResolvedAddress>) -> bool,
    ) -> usize {
        self
            .inner
            .nodes
            .read()
            .await
            .nodes
            .iter()
            .filter(|n| f(&n.state))
            .count()
    }

    /// Maps each member through `f`, collecting the `Some` results.
    pub async fn members_map_by<O>(
        &self,
        mut f: impl FnMut(&NodeState<T::Id, <T::Resolver as AddressResolver>::ResolvedAddress>) -> Option<O>,
    ) -> SmallVec<O> {
        self
            .inner
            .nodes
            .read()
            .await
            .nodes
            .iter()
            .filter_map(|n| f(&n.state))
            .collect()
    }
}
impl<T> Memberlist<T>
where
T: Transport,
<<T::Runtime as Runtime>::Sleep as Future>::Output: Send,
<<T::Runtime as Runtime>::Interval as Stream>::Item: Send,
{
#[inline]
pub async fn new(
transport: T,
opts: Options,
) -> Result<Self, Error<T, VoidDelegate<T::Id, <T::Resolver as AddressResolver>::ResolvedAddress>>>
{
Self::create(transport, None, opts).await
}
}
impl<T, D> Memberlist<T, D>
where
D: Delegate<Id = T::Id, Address = <T::Resolver as AddressResolver>::ResolvedAddress>,
T: Transport,
<<T::Runtime as Runtime>::Sleep as Future>::Output: Send,
<<T::Runtime as Runtime>::Interval as Stream>::Item: Send,
{
#[inline]
pub async fn with_delegate(
transport: T,
delegate: D,
opts: Options,
) -> Result<Self, Error<T, D>> {
Self::create(transport, Some(delegate), opts).await
}
pub(crate) async fn create(
transport: T,
delegate: Option<D>,
opts: Options,
) -> Result<Self, Error<T, D>> {
let (shutdown_rx, advertise, this) = Self::new_in(transport, delegate, opts).await?;
let meta = if let Some(d) = &this.delegate {
d.node_meta(META_MAX_SIZE).await
} else {
Meta::empty()
};
if meta.len() > META_MAX_SIZE {
panic!("NodeState meta data provided is longer than the limit");
}
let alive = Alive::new(
this.next_incarnation(),
Node::new(this.inner.id.clone(), this.inner.advertise.clone()),
)
.with_meta(meta)
.with_protocol_version(this.inner.opts.protocol_version)
.with_delegate_version(this.inner.opts.delegate_version);
this.alive_node(alive, None, true).await;
this.schedule(shutdown_rx).await;
tracing::debug!(target = "memberlist", local = %this.inner.id, advertise_addr = %advertise, "node is living");
Ok(this)
}
pub async fn leave(&self, timeout: Duration) -> Result<(), Error<T, D>> {
let _mu = self.inner.leave_lock.lock().await;
if self.has_shutdown() {
panic!("leave after shutdown");
}
if !self.has_left() {
self.inner.hot.leave.store(true, Ordering::Release);
let mut memberlist = self.inner.nodes.write().await;
if let Some(&idx) = memberlist.node_map.get(&self.inner.id) {
let state = &memberlist.nodes[idx];
let d = Dead::new(
state.state.incarnation.load(Ordering::Acquire),
state.id().cheap_clone(),
state.id().cheap_clone(),
);
self.dead_node(&mut memberlist, d).await?;
let any_alive = memberlist.any_alive();
drop(memberlist);
if any_alive {
if timeout > Duration::ZERO {
futures::select! {
_ = self.inner.leave_broadcast_rx.recv().fuse() => {},
_ = <T::Runtime as Runtime>::sleep(timeout).fuse() => {
return Err(Error::LeaveTimeout);
}
}
} else if let Err(e) = self.inner.leave_broadcast_rx.recv().await {
tracing::error!(
target: "memberlist",
"failed to receive leave broadcast: {}",
e
);
}
}
} else {
tracing::warn!(target = "memberlist", "leave but we're not a member");
}
}
Ok(())
}
pub async fn join(&self, node: Node<T::Id, MaybeResolvedAddress<T>>) -> Result<(), Error<T, D>> {
if self.has_left() || self.has_shutdown() {
return Err(Error::NotRunning);
}
let (id, addr) = node.into_components();
let addr = match addr {
MaybeResolvedAddress::Resolved(addr) => addr,
MaybeResolvedAddress::Unresolved(addr) => self
.inner
.transport
.resolve(&addr)
.await
.map_err(Error::Transport)?,
};
self.push_pull_node(Node::new(id, addr), true).await
}
pub async fn join_many(
&self,
existing: impl Iterator<Item = Node<T::Id, MaybeResolvedAddress<T>>>,
) -> Result<
SmallVec<Node<T::Id, <T::Resolver as AddressResolver>::ResolvedAddress>>,
JoinError<T, D>,
> {
if self.has_left() || self.has_shutdown() {
return Err(JoinError {
joined: SmallVec::new(),
errors: existing
.into_iter()
.map(|n| (n, Error::NotRunning))
.collect(),
});
}
let estimated_total = existing.size_hint().0;
let futs = existing
.into_iter()
.map(|node| {
async move {
let (id, addr) = node.into_components();
let resolved_addr = match addr {
MaybeResolvedAddress::Resolved(addr) => addr,
MaybeResolvedAddress::Unresolved(addr) => {
match self.inner.transport.resolve(&addr).await {
Ok(addr) => addr,
Err(e) => {
tracing::debug!(
target: "memberlist",
err = %e,
"failed to resolve address {}",
addr,
);
return Err((Node::new(id, MaybeResolvedAddress::unresolved(addr)), Error::<T, D>::transport(e)))
}
}
}
};
let node = Node::new(id, resolved_addr);
tracing::info!(target = "memberlist", local = %self.inner.transport.local_id(), peer = %node, "start join...");
if let Err(e) = self.push_pull_node(node.cheap_clone(), true).await {
tracing::debug!(
target: "memberlist",
local = %self.inner.id,
err = %e,
"failed to join {}",
node,
);
let (id, addr) = node.into_components();
Err((Node::new(id, MaybeResolvedAddress::Resolved(addr)), e))
} else {
Ok(node)
}
}
}).collect::<futures::stream::FuturesUnordered<_>>();
let num_success = std::cell::RefCell::new(SmallVec::with_capacity(estimated_total));
let errors = futs
.filter_map(|rst| async {
match rst {
Ok(node) => {
num_success.borrow_mut().push(node);
None
}
Err((node, e)) => Some((node, e)),
}
})
.collect::<HashMap<_, _>>()
.await;
if errors.is_empty() {
return Ok(num_success.into_inner());
}
Err(JoinError {
joined: num_success.into_inner(),
errors,
})
}
#[inline]
pub fn health_score(&self) -> usize {
self.inner.awareness.get_health_score() as usize
}
pub async fn update_node(&self, timeout: Duration) -> Result<(), Error<T, D>> {
if self.has_left() || self.has_shutdown() {
return Err(Error::NotRunning);
}
let meta = if let Some(delegate) = &self.delegate {
let meta = delegate.node_meta(META_MAX_SIZE).await;
if meta.len() > META_MAX_SIZE {
panic!("node meta data provided is longer than the limit");
}
meta
} else {
Meta::empty()
};
let node = {
let members = self.inner.nodes.read().await;
let idx = *members.node_map.get(&self.inner.id).unwrap();
let state = &members.nodes[idx].state;
Node::new(state.id().cheap_clone(), state.address().cheap_clone())
};
let alive = Alive::new(self.next_incarnation(), node)
.with_meta(meta)
.with_protocol_version(self.inner.opts.protocol_version)
.with_delegate_version(self.inner.opts.delegate_version);
let (notify_tx, notify_rx) = async_channel::bounded(1);
self.alive_node(alive, Some(notify_tx), true).await;
if self.any_alive().await {
if timeout > Duration::ZERO {
let _ = <T::Runtime as Runtime>::timeout(timeout, notify_rx.recv())
.await
.map_err(|_| Error::UpdateTimeout)?;
} else {
let _ = notify_rx.recv().await;
}
}
Ok(())
}
#[inline]
pub async fn send(
&self,
to: &<T::Resolver as AddressResolver>::ResolvedAddress,
msg: Bytes,
) -> Result<(), Error<T, D>> {
if self.has_left() || self.has_shutdown() {
return Err(Error::NotRunning);
}
self.transport_send_packet(to, Message::UserData(msg)).await
}
#[inline]
pub async fn send_reliable(
&self,
to: &<T::Resolver as AddressResolver>::ResolvedAddress,
msg: Bytes,
) -> Result<(), Error<T, D>> {
if self.has_left() || self.has_shutdown() {
return Err(Error::NotRunning);
}
self.send_user_msg(to, msg).await
}
pub async fn ping(
&self,
node: Node<T::Id, <T::Resolver as AddressResolver>::ResolvedAddress>,
) -> Result<Duration, Error<T, D>> {
let self_addr = self.get_advertise();
let ping = Ping::new(
self.next_sequence_number(),
Node::new(self.inner.transport.local_id().clone(), self_addr.clone()),
node.clone(),
);
let (ack_tx, ack_rx) = async_channel::bounded(self.inner.opts.indirect_checks + 1);
self.inner.ack_manager.set_probe_channels::<T::Runtime>(
ping.sequence_number(),
ack_tx,
None,
Instant::now(),
self.inner.opts.probe_interval,
);
match <T::Runtime as Runtime>::timeout(
self.inner.opts.probe_timeout,
self.send_msg(node.address(), ping.into()),
)
.await
{
Ok(Ok(())) => {}
Ok(Err(e)) => return Err(e),
Err(_) => {
tracing::debug!(
target = "memberlist",
"failed ping {} by packet (timeout reached)",
node
);
return Err(Error::Lost(node));
}
}
let sent = Instant::now();
futures::select! {
v = ack_rx.recv().fuse() => {
if let Ok(AckMessage { complete, .. }) = v {
if complete {
return Ok(sent.elapsed());
}
}
}
_ = <T::Runtime as Runtime>::sleep(self.inner.opts.probe_timeout).fuse() => {}
}
tracing::debug!(
target = "memberlist",
"failed ping {} by packet (timeout reached)",
node
);
Err(Error::Lost(node))
}
pub async fn shutdown(&self) -> Result<(), Error<T, D>> {
self.inner.shutdown().await.map_err(Error::Transport)
}
}