use std::sync::Arc;
use nodecraft::CheapClone;
use smol_str::SmolStr;
use crate::delegate::DelegateError;
use super::*;
impl<D, T> Memberlist<T, D>
where
D: Delegate<Id = T::Id, Address = <T::Resolver as AddressResolver>::ResolvedAddress>,
T: Transport,
<<T::Runtime as Runtime>::Interval as Stream>::Item: Send,
<<T::Runtime as Runtime>::Sleep as Future>::Output: Send,
{
/// Spawns the background task that accepts inbound stream connections.
///
/// The task loops over two event sources:
///
/// - `shutdown_rx`: as soon as a receive on it resolves (with a value or an
///   error), the listener logs and exits.
/// - the transport's stream channel: every accepted `(address, stream)` pair
///   is served by `handle_conn` in its own detached task, so a slow peer
///   never blocks the accept loop.
///
/// If receiving from the transport fails, the listener exits; the failure is
/// logged only when `shutdown_tx` is not already closed.
///
/// Returns the join handle of the spawned listener task.
pub(crate) fn stream_listener(
    &self,
    shutdown_rx: async_channel::Receiver<()>,
) -> <T::Runtime as Runtime>::JoinHandle<()> {
    let me = self.clone();
    let incoming = me.inner.transport.stream();

    <T::Runtime as Runtime>::spawn(async move {
        tracing::debug!("memberlist stream listener start");

        loop {
            futures::select! {
                _ = shutdown_rx.recv().fuse() => {
                    tracing::debug!("memberlist stream listener shutting down");
                    return;
                }
                accepted = incoming.recv().fuse() => {
                    // Guard clause: a receive error ends the listener.
                    let (remote_addr, stream) = match accepted {
                        Ok(pair) => pair,
                        Err(e) => {
                            if !me.inner.shutdown_tx.is_closed() {
                                tracing::error!(local = %me.inner.id, "memberlist stream listener failed to accept connection: {}", e);
                            }
                            return;
                        }
                    };

                    // Serve the connection without blocking the accept loop.
                    let handler = me.clone();
                    <T::Runtime as Runtime>::spawn_detach(async move {
                        handler.handle_conn(remote_addr, stream).await;
                    });
                }
            }
        }
    })
}
/// Merges the state received from a remote node during a push/pull exchange.
///
/// Steps, in order:
///
/// 1. The remote states are checked with `verify_protocol`; a failure aborts
///    the merge.
/// 2. If the exchange is a join and a delegate is configured, the delegate's
///    `notify_merge` is called with a snapshot of the remote nodes. An error
///    from it is wrapped as a delegate merge error and aborts the merge.
/// 3. The remote states are merged via `merge_state`.
/// 4. If a delegate is configured and the remote user data is non-empty, the
///    data is passed to the delegate's `merge_remote_state`.
///
/// # Errors
///
/// Returns an error when protocol verification fails or when the delegate
/// rejects the merge.
pub(crate) async fn merge_remote_state(
    &self,
    node_state: PushPull<T::Id, <T::Resolver as AddressResolver>::ResolvedAddress>,
) -> Result<(), Error<T, D>> {
    self.verify_protocol(node_state.states().as_slice()).await?;

    // Only a join with a configured delegate triggers the merge notification.
    if let Some(delegate) = self.delegate.as_ref().filter(|_| node_state.join()) {
        let peers = node_state
            .states()
            .iter()
            .map(|peer| {
                let snapshot = NodeState::new(
                    peer.id().cheap_clone(),
                    peer.address().cheap_clone(),
                    peer.state(),
                )
                .with_meta(peer.meta().cheap_clone())
                .with_protocol_version(peer.protocol_version())
                .with_delegate_version(peer.delegate_version());
                Arc::new(snapshot)
            })
            .collect::<SmallVec<_>>();

        delegate
            .notify_merge(peers)
            .await
            .map_err(|e| Error::delegate(DelegateError::merge(e)))?;
    }

    let (join, user_data, states) = node_state.into_components();
    self.merge_state(states.as_slice()).await;

    // Forward user data only when there is both a delegate and a payload.
    match &self.delegate {
        Some(delegate) if !user_data.is_empty() => {
            delegate.merge_remote_state(user_data, join).await;
        }
        _ => {}
    }

    Ok(())
}
/// Sends a user-data message to `addr` over a stream connection.
///
/// Dials `addr` with a deadline of now plus the configured timeout, writes
/// `msg` as a `Message::UserData`, and finally hands the stream to the
/// transport through `cache_stream`.
///
/// # Errors
///
/// Dial and cache failures are returned as transport errors; a failure from
/// `send_message` is propagated as-is.
pub(crate) async fn send_user_msg(
    &self,
    addr: &<T::Resolver as AddressResolver>::ResolvedAddress,
    msg: Bytes,
) -> Result<(), Error<T, D>> {
    let transport = &self.inner.transport;
    let deadline = Instant::now() + self.inner.opts.timeout;

    let mut stream = transport
        .dial_with_deadline(addr, deadline)
        .await
        .map_err(Error::transport)?;

    self.send_message(&mut stream, Message::UserData(msg)).await?;

    transport
        .cache_stream(addr, stream)
        .await
        .map_err(Error::transport)
}
}
impl<D, T> Memberlist<T, D>
where
D: Delegate<Id = T::Id, Address = <T::Resolver as AddressResolver>::ResolvedAddress>,
T: Transport,
<<T::Runtime as Runtime>::Interval as Stream>::Item: Send,
<<T::Runtime as Runtime>::Sleep as Future>::Output: Send,
{
/// Sends this node's view of the cluster to the peer on `conn`.
///
/// Sets the stream deadline to now plus the configured timeout, snapshots
/// every known node under the nodes read lock into a `PushNodeState`, and
/// wraps the snapshot in a `PushPull` message. When a delegate is configured,
/// its `local_state(join)` is attached as user data. The message is then
/// written with `send_message`.
///
/// With the `metrics` feature enabled, this also publishes the per-state node
/// counts (`memberlist.node.instances`) and the encoded message size
/// (`memberlist.size.local`).
///
/// # Errors
///
/// Returns whatever error `send_message` reports.
pub(super) async fn send_local_state(
    &self,
    conn: &mut T::Stream,
    join: bool,
) -> Result<(), Error<T, D>> {
    let deadline = Instant::now() + self.inner.opts.timeout;
    conn.set_deadline(Some(deadline));

    #[cfg(feature = "metrics")]
    let mut node_state_counts = State::metrics_array();

    // Snapshot all known nodes while holding the read lock.
    let local_nodes = {
        self
            .inner
            .nodes
            .read()
            .await
            .nodes
            .iter()
            .map(|member| {
                let node = &member.state;
                let pushed = PushNodeState::new(
                    node.incarnation.load(Ordering::Acquire),
                    node.id().cheap_clone(),
                    node.address().cheap_clone(),
                    node.state,
                )
                .with_meta(node.meta().cheap_clone())
                .with_protocol_version(node.protocol_version())
                .with_delegate_version(node.delegate_version());

                // Tally nodes per state; the state discriminant is the slot index.
                #[cfg(feature = "metrics")]
                {
                    node_state_counts[pushed.state() as u8 as usize].1 += 1;
                }

                pushed
            })
            .collect::<TinyVec<_>>()
    };

    // Attach the delegate's user data only when a delegate is configured.
    let push_pull = PushPull::new(join, local_nodes);
    let msg: Message<_, _> = match &self.delegate {
        Some(delegate) => push_pull
            .with_user_data(delegate.local_state(join).await)
            .into(),
        None => push_pull.into(),
    };

    #[cfg(feature = "metrics")]
    {
        // Per-thread label buffer, initialized on first use from the
        // configured metric labels with room for one extra label.
        std::thread_local! {
            #[cfg(not(target_family = "wasm"))]
            static NODE_INSTANCES_GAUGE: std::cell::OnceCell<std::cell::RefCell<crate::types::MetricLabels>> = const { std::cell::OnceCell::new() };
            #[cfg(target_family = "wasm")]
            static NODE_INSTANCES_GAUGE: once_cell::sync::OnceCell<std::cell::RefCell<crate::types::MetricLabels>> = once_cell::sync::OnceCell::new();
        }

        NODE_INSTANCES_GAUGE.with(|cell| {
            let mut labels = cell
                .get_or_init(|| {
                    let mut base = (*self.inner.opts.metric_labels).clone();
                    base.reserve_exact(1);
                    std::cell::RefCell::new(base)
                })
                .borrow_mut();

            for (idx, (node_state, cnt)) in node_state_counts.into_iter().enumerate() {
                let label = metrics::Label::new("node_state", node_state);
                // The first state appends the extra label; later states overwrite it.
                match idx {
                    0 => labels.push(label),
                    _ => *labels.last_mut().unwrap() = label,
                }
                let iter = labels.iter();
                metrics::gauge!("memberlist.node.instances", iter).set(cnt as f64);
            }

            // Remove the extra label so the buffer is back to the base labels.
            labels.pop();
        });
    }

    #[cfg(feature = "metrics")]
    {
        use crate::transport::Wire;

        let encoded_len = <T::Wire as Wire>::encoded_len(&msg);
        metrics::gauge!(
            "memberlist.size.local",
            self.inner.opts.metric_labels.iter()
        )
        .set(encoded_len as f64);
    }

    self.send_message(conn, msg).await
}
}
impl<D, T> Memberlist<T, D>
where
D: Delegate<Id = T::Id, Address = <T::Resolver as AddressResolver>::ResolvedAddress>,
T: Transport,
<<T::Runtime as Runtime>::Interval as Stream>::Item: Send,
<<T::Runtime as Runtime>::Sleep as Future>::Output: Send,
{
async fn handle_conn(
self,
addr: <T::Resolver as AddressResolver>::ResolvedAddress,
mut conn: T::Stream,
) {
tracing::debug!(target = "memberlist.stream", local = %self.inner.id, peer = %addr, "handle stream connection");
#[cfg(feature = "metrics")]
{
metrics::counter!(
"memberlist.promised.accept",
self.inner.opts.metric_labels.iter()
)
.increment(1);
}
conn.set_deadline(Some(Instant::now() + self.inner.opts.timeout));
let msg = match self.read_message(&addr, &mut conn).await {
Ok((_read, msg)) => {
#[cfg(feature = "metrics")]
{
metrics::histogram!(
"memberlist.size.remote",
self.inner.opts.metric_labels.iter()
)
.record(_read as f64);
}
msg
}
Err(e) => {
tracing::error!(target = "memberlist.stream", err=%e, local = %self.inner.id, remote_node = %addr, "failed to receive");
let err_resp = ErrorResponse::new(SmolStr::new(e.to_string()));
if let Err(e) = self.send_message(&mut conn, err_resp.into()).await {
tracing::error!(target = "memberlist.stream", err=%e, local = %self.inner.id, remote_node = %addr, "failed to send error response");
return;
}
return;
}
};
match msg {
Message::Ping(ping) => {
if ping.target().id().ne(self.local_id()) {
tracing::error!(target = "memberlist.stream", local=%self.inner.id, remote = %addr, "got ping for unexpected node {}", ping.target());
return;
}
let ack = Ack::new(ping.sequence_number());
if let Err(e) = self.send_message(&mut conn, ack.into()).await {
tracing::error!(target = "memberlist.stream", err=%e, remote_node = %addr, "failed to send ack response");
}
if let Err(e) = self.inner.transport.cache_stream(&addr, conn).await {
tracing::warn!(target = "memberlist.stream", err=%e, remote_node = %addr, "failed to cache stream");
}
}
Message::PushPull(pp) => {
let num_concurrent = self.inner.hot.push_pull_req.fetch_add(1, Ordering::SeqCst);
scopeguard::defer! {
self.inner.hot.push_pull_req.fetch_sub(1, Ordering::SeqCst);
}
if num_concurrent >= MAX_PUSH_PULL_REQUESTS {
tracing::error!(
target: "memberlist.stream",
"too many pending push/pull requests"
);
return;
}
if let Err(e) = self.send_local_state(&mut conn, pp.join()).await {
tracing::error!(target = "memberlist.stream", err=%e, remote_node = %addr, "failed to push local state");
return;
}
if let Err(e) = self.merge_remote_state(pp).await {
tracing::error!(target = "memberlist.stream", err=%e, remote_node = %addr, "failed to push/pull merge");
}
if let Err(e) = self.inner.transport.cache_stream(&addr, conn).await {
tracing::warn!(target = "memberlist.stream", err=%e, remote_node = %addr, "failed to cache stream");
}
}
Message::UserData(data) => {
if let Some(d) = &self.delegate {
tracing::trace!(target = "memberlist.stream", remote_node = %addr, data=?data.as_ref(), "notify user message");
d.notify_message(data).await
}
}
msg => {
tracing::error!(target = "memberlist.stream", remote_node = %addr, "received invalid msg type {}", msg.kind());
}
}
}
}