use std::{
collections::HashSet,
sync::Arc,
time::{Duration, Instant},
};
use crate::types::FilterRef;
use async_channel::{Receiver, Sender};
use async_lock::RwLock;
use either::Either;
use futures::{FutureExt, StreamExt, stream::FuturesUnordered};
use memberlist_core::{
CheapClone,
bytes::Bytes,
proto::{Data, RepeatedDecoder, SmallVec, TinyVec},
tracing,
transport::{Node, Transport},
};
use crate::{
delegate::Delegate,
error::Error,
types::{Filter, LamportTime, Member, MemberStatus, QueryMessage, QueryResponseMessage},
};
use super::Serf;
#[viewit::viewit(
vis_all = "pub(crate)",
getters(vis_all = "pub", style = "ref"),
setters(vis_all = "pub", prefix = "with")
)]
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct QueryParam<I> {
#[viewit(
getter(const, attrs(doc = "Returns the filters of the query")),
setter(attrs(doc = "Sets the filters of the query"))
)]
filters: TinyVec<Filter<I>>,
#[viewit(
getter(
const,
style = "move",
attrs(
doc = "Returns if we are requesting an delivery acknowledgement from every node that meets the filter requirement. This means nodes the receive the message but do not pass the filters, will not send an ack."
)
),
setter(attrs(
doc = "Sets if we are requesting an delivery acknowledgement from every node that meets the filter requirement. This means nodes the receive the message but do not pass the filters, will not send an ack."
))
)]
request_ack: bool,
#[viewit(
getter(
const,
style = "move",
attrs(
doc = "Returns the number of duplicate responses to relay back to the sender through other nodes for redundancy."
)
),
setter(attrs(
doc = "Sets the number of duplicate responses to relay back to the sender through other nodes for redundancy."
))
)]
relay_factor: u8,
#[viewit(
getter(
const,
style = "move",
attrs(
doc = "Returns timeout limits how long the query is left open. If not provided, then a default timeout is used based on the configuration of [`Serf`]"
)
),
setter(attrs(doc = "Sets timeout limits how long the query is left open."))
)]
#[cfg_attr(feature = "serde", serde(with = "humantime_serde"))]
timeout: Duration,
}
struct QueryResponseChannel<I, A> {
ack_ch: Option<(Sender<Node<I, A>>, Receiver<Node<I, A>>)>,
resp_ch: (Sender<NodeResponse<I, A>>, Receiver<NodeResponse<I, A>>),
}
pub(crate) struct QueryResponseCore<I, A> {
closed: bool,
acks: HashSet<Node<I, A>>,
responses: HashSet<Node<I, A>>,
}
pub(crate) struct QueryResponseInner<I, A> {
core: RwLock<QueryResponseCore<I, A>>,
channel: QueryResponseChannel<I, A>,
}
#[viewit::viewit(vis_all = "pub(crate)")]
#[derive(Clone)]
pub struct QueryResponse<I, A> {
#[viewit(
getter(
style = "move",
const,
attrs(doc = "Returns the ending deadline of the query")
),
setter(skip)
)]
deadline: Instant,
#[viewit(
getter(style = "move", const, attrs(doc = "Returns the id of the query")),
setter(skip)
)]
id: u32,
#[viewit(
getter(
style = "move",
const,
attrs(doc = "Returns the Lamport Time of the query")
),
setter(skip)
)]
ltime: LamportTime,
#[viewit(getter(vis = "pub(crate)", const, style = "ref"), setter(skip))]
inner: Arc<QueryResponseInner<I, A>>,
}
impl<I, A> QueryResponse<I, A> {
pub(crate) fn from_query(q: &QueryMessage<I, A>, num_nodes: usize) -> Self {
QueryResponse::new(
q.id(),
q.ltime(),
num_nodes,
Instant::now() + q.timeout(),
q.ack(),
)
}
}
impl<I, A> QueryResponse<I, A> {
#[inline]
pub(crate) fn new(
id: u32,
ltime: LamportTime,
num_nodes: usize,
deadline: Instant,
ack: bool,
) -> Self {
let (ack_ch, acks) = if ack {
(
Some(async_channel::bounded(num_nodes)),
HashSet::with_capacity(num_nodes),
)
} else {
(None, HashSet::new())
};
Self {
deadline,
id,
ltime,
inner: Arc::new(QueryResponseInner {
core: RwLock::new(QueryResponseCore {
closed: false,
acks,
responses: HashSet::with_capacity(num_nodes),
}),
channel: QueryResponseChannel {
ack_ch,
resp_ch: async_channel::bounded(num_nodes),
},
}),
}
}
#[inline]
pub fn ack_rx(&self) -> Option<async_channel::Receiver<Node<I, A>>> {
self.inner.channel.ack_ch.as_ref().map(|(_, r)| r.clone())
}
#[inline]
pub fn response_rx(&self) -> async_channel::Receiver<NodeResponse<I, A>> {
self.inner.channel.resp_ch.1.clone()
}
#[inline]
pub async fn finished(&self) -> bool {
let c = self.inner.core.read().await;
c.closed || (Instant::now() > self.deadline)
}
#[inline]
pub async fn close(&self) {
let mut c = self.inner.core.write().await;
if c.closed {
return;
}
c.closed = true;
if let Some((tx, _)) = &self.inner.channel.ack_ch {
tx.close();
}
self.inner.channel.resp_ch.0.close();
}
#[inline]
pub(crate) async fn handle_query_response<T, D>(
&self,
resp: QueryResponseMessage<I, A>,
_local: &T::Id,
#[cfg(feature = "metrics")] metrics_labels: &memberlist_core::proto::MetricLabels,
) where
I: Eq + std::hash::Hash + CheapClone + core::fmt::Debug,
A: Eq + std::hash::Hash + CheapClone + core::fmt::Debug,
D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
T: Transport,
{
let c = self.inner.core.read().await;
if c.closed || (Instant::now() > self.deadline) {
return;
}
if resp.ack() {
if c.acks.contains(&resp.from) {
#[cfg(feature = "metrics")]
{
metrics::counter!("serf.query.duplicate_acks", metrics_labels.iter()).increment(1);
}
return;
}
#[cfg(feature = "metrics")]
{
metrics::counter!("serf.query.acks", metrics_labels.iter()).increment(1);
}
drop(c);
if let Err(e) = self.send_ack::<T, D>(&resp).await {
tracing::warn!("serf: {}", e);
}
} else {
if c.responses.contains(&resp.from) {
#[cfg(feature = "metrics")]
{
metrics::counter!("serf.query.duplicate_responses", metrics_labels.iter()).increment(1);
}
return;
}
#[cfg(feature = "metrics")]
{
metrics::counter!("serf.query.responses", metrics_labels.iter()).increment(1);
}
drop(c);
if let Err(e) = self
.send_response::<T, D>(NodeResponse {
from: resp.from,
payload: resp.payload,
})
.await
{
tracing::warn!("serf: {}", e);
}
}
}
#[inline]
pub(crate) async fn send_response<T, D>(&self, nr: NodeResponse<I, A>) -> Result<(), Error<T, D>>
where
I: Eq + std::hash::Hash + CheapClone + core::fmt::Debug,
A: Eq + std::hash::Hash + CheapClone + core::fmt::Debug,
D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
T: Transport,
{
let mut c = self.inner.core.write().await;
if c.responses.contains(&nr.from) {
return Ok(());
}
if c.closed {
Ok(())
} else {
let id = nr.from.cheap_clone();
futures::select! {
_ = self.inner.channel.resp_ch.0.send(nr).fuse() => {
c.responses.insert(id);
Ok(())
},
default => {
Err(Error::query_response_delivery_failed())
}
}
}
}
#[inline]
pub(crate) async fn send_ack<T, D>(
&self,
nr: &QueryResponseMessage<I, A>,
) -> Result<(), Error<T, D>>
where
I: Eq + std::hash::Hash + CheapClone,
A: Eq + std::hash::Hash + CheapClone,
D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
T: Transport,
{
let mut c = self.inner.core.write().await;
if c.acks.contains(&nr.from) {
return Ok(());
}
if c.closed {
Ok(())
} else if let Some((tx, _)) = &self.inner.channel.ack_ch {
futures::select! {
_ = tx.send(nr.from.cheap_clone()).fuse() => {
c.acks.insert(nr.from.clone());
Ok(())
},
default => {
Err(Error::query_response_delivery_failed())
}
}
} else {
Ok(())
}
}
}
#[viewit::viewit(
vis_all = "pub(crate)",
setters(skip),
getters(vis_all = "pub", style = "ref")
)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct NodeResponse<I, A> {
#[viewit(getter(attrs(doc = "Returns the node that sent the response")))]
from: Node<I, A>,
#[viewit(getter(attrs(doc = "Returns the payload of the response")))]
payload: Bytes,
}
#[inline]
fn random_members<I, A>(k: usize, mut members: SmallVec<Member<I, A>>) -> SmallVec<Member<I, A>> {
let n = members.len();
if n == 0 {
return SmallVec::new();
}
let rounds = 3 * n;
let mut i = 0;
while i < rounds && i < n {
let j = (rand::random::<u32>() as usize) % (n - i) + i;
members.swap(i, j);
i += 1;
if i >= k && i >= rounds {
break;
}
}
members.truncate(k);
members
}
impl<T, D> Serf<T, D>
where
D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
T: Transport,
{
pub async fn default_query_timeout(&self) -> Duration {
let n = self.inner.memberlist.num_online_members().await;
let mut timeout = self.inner.opts.memberlist_options.gossip_interval();
timeout *= self.inner.opts.query_timeout_mult as u32;
timeout *= ((n + 1) as f64).log10().ceil() as u32; timeout
}
pub async fn default_query_param(&self) -> QueryParam<T::Id> {
QueryParam {
filters: TinyVec::new(),
request_ack: false,
relay_factor: 0,
timeout: self.default_query_timeout().await,
}
}
pub(crate) fn should_process_query(
&self,
filters: Either<RepeatedDecoder<'_>, &[Filter<T::Id>]>,
) -> Result<bool, memberlist_core::proto::DecodeError> {
match filters {
Either::Left(filters) => {
for filter in filters.iter::<Filter<T::Id>>() {
let filter = filter?;
match filter {
FilterRef::Id(ids) => {
let mut found = false;
for id in ids.iter::<T::Id>() {
let id = id?;
if <T::Id as Data>::from_ref(id)?.eq(self.inner.memberlist.local_id()) {
found = true;
break;
}
}
if !found {
return Ok(false);
}
}
FilterRef::Tag(tag) => {
let tags = self.inner.opts.tags.load();
if !tags.is_empty() {
if let Some(expr) = tags.get(tag.tag()) {
if let Some(re) = tag.expr() {
if !regex::Regex::new(re)
.map_err(|_| memberlist_core::proto::DecodeError::custom("invalid regex"))?
.is_match(expr)
{
return Ok(false);
}
}
} else {
return Ok(false);
}
} else {
return Ok(false);
}
}
}
}
Ok(true)
}
Either::Right(filters) => {
for filter in filters.iter() {
match &filter {
Filter::Id(nodes) => {
let found = nodes
.iter()
.any(|n: &T::Id| n.eq(self.inner.memberlist.local_id()));
if !found {
return Ok(false);
}
}
Filter::Tag(tag) => {
let tags = self.inner.opts.tags.load();
if !tags.is_empty() {
if let Some(expr) = tags.get(tag.tag()) {
if let Some(re) = tag.expr() {
if !re.is_match(expr) {
return Ok(false);
}
}
} else {
return Ok(false);
}
} else {
return Ok(false);
}
}
}
}
Ok(true)
}
}
}
pub(crate) async fn relay_response(
&self,
relay_factor: u8,
node: Node<T::Id, T::ResolvedAddress>,
resp: QueryResponseMessage<T::Id, T::ResolvedAddress>,
) -> Result<(), Error<T, D>> {
if relay_factor == 0 {
return Ok(());
}
let members = {
let members = self.inner.members.read().await;
if members.states.len() < relay_factor as usize + 1 {
return Ok(());
}
members
.states
.iter()
.filter_map(|(id, m)| {
if m.member.status == MemberStatus::Alive && id != self.inner.memberlist.local_id() {
Some(m.member.clone())
} else {
None
}
})
.collect::<SmallVec<_>>()
};
if members.is_empty() {
return Ok(());
}
let encoded_len = crate::types::encoded_relay_message_len(&resp, &node);
if encoded_len > self.inner.opts.query_response_size_limit {
return Err(Error::relayed_response_too_large(
self.inner.opts.query_response_size_limit,
));
}
let raw = crate::types::encode_relay_message_to_bytes(&resp, &node)?;
let relay_members = random_members(relay_factor as usize, members);
let futs: FuturesUnordered<_> = relay_members
.into_iter()
.map(|m| {
let raw = raw.clone();
async move {
self
.inner
.memberlist
.send(m.node.address(), raw)
.await
.map_err(|e| (m, e))
}
})
.collect();
let mut errs = TinyVec::new();
let stream = StreamExt::filter_map(futs, |res| async move {
if let Err((m, e)) = res {
Some((m, e))
} else {
None
}
});
futures::pin_mut!(stream);
while let Some(err) = stream.next().await {
errs.push(err);
}
Ok(())
}
}