use std::{collections::HashMap, sync::OnceLock};
use crate::types::MessageRef;
use async_channel::Receiver;
use async_lock::RwLock;
use futures::StreamExt;
use memberlist_core::{CheapClone, proto::SecretKey, tracing, transport::Transport};
use smol_str::{SmolStr, format_smolstr};
use crate::event::{
INTERNAL_INSTALL_KEY, INTERNAL_LIST_KEYS, INTERNAL_REMOVE_KEY, INTERNAL_USE_KEY,
InternalQueryEvent,
};
use super::{
Serf, SerfWeakRef,
delegate::Delegate,
error::Error,
serf::{NodeResponse, QueryResponse},
types::KeyRequestMessage,
};
#[viewit::viewit(
vis_all = "pub(crate)",
getters(style = "move", vis_all = "pub"),
setters(skip)
)]
#[derive(Default, Debug)]
pub struct KeyResponse<I> {
#[viewit(getter(
const,
style = "ref",
attrs(doc = "Returns a map of node id to response message.")
))]
messages: HashMap<I, SmolStr>,
#[viewit(getter(const, attrs(doc = "Returns the total nodes memberlist knows of.")))]
num_nodes: usize,
#[viewit(getter(const, attrs(doc = "Returns the total responses received.")))]
num_resp: usize,
#[viewit(getter(const, attrs(doc = "Returns the total errors from request.")))]
num_err: usize,
#[viewit(getter(
const,
style = "ref",
attrs(
doc = "Returns a mapping of the value of the key bytes to the number of nodes that have the key installed.."
)
))]
keys: HashMap<SecretKey, usize>,
#[viewit(getter(
const,
style = "ref",
attrs(
doc = "Returns a mapping of the value of the primary key bytes to the number of nodes that have the key installed."
)
))]
primary_keys: HashMap<SecretKey, usize>,
}
pub struct KeyRequestOptions {
pub relay_factor: u8,
}
pub struct KeyManager<T, D>
where
D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
T: Transport,
{
serf: OnceLock<SerfWeakRef<T, D>>,
l: RwLock<()>,
}
impl<T, D> KeyManager<T, D>
where
D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
T: Transport,
{
pub(crate) fn new() -> Self {
Self {
serf: OnceLock::new(),
l: RwLock::new(()),
}
}
pub(crate) fn store(&self, serf: SerfWeakRef<T, D>) {
let _ = self.serf.set(serf);
}
fn this(&self) -> Option<Serf<T, D>> {
self.serf.get().and_then(|weak_ref| weak_ref.upgrade())
}
pub async fn install_key(
&self,
key: SecretKey,
opts: Option<KeyRequestOptions>,
) -> Result<KeyResponse<T::Id>, Error<T, D>> {
let _mu = self.l.write().await;
self
.handle_key_request(
Some(key),
INTERNAL_INSTALL_KEY,
opts,
InternalQueryEvent::InstallKey,
)
.await
}
pub async fn use_key(
&self,
key: SecretKey,
opts: Option<KeyRequestOptions>,
) -> Result<KeyResponse<T::Id>, Error<T, D>> {
let _mu = self.l.write().await;
self
.handle_key_request(
Some(key),
INTERNAL_USE_KEY,
opts,
InternalQueryEvent::UseKey,
)
.await
}
pub async fn remove_key(
&self,
key: SecretKey,
opts: Option<KeyRequestOptions>,
) -> Result<KeyResponse<T::Id>, Error<T, D>> {
let _mu = self.l.write().await;
self
.handle_key_request(
Some(key),
INTERNAL_REMOVE_KEY,
opts,
InternalQueryEvent::RemoveKey,
)
.await
}
pub async fn list_keys(&self) -> Result<KeyResponse<T::Id>, Error<T, D>> {
let _mu = self.l.read().await;
self
.handle_key_request(None, INTERNAL_LIST_KEYS, None, InternalQueryEvent::ListKey)
.await
}
pub(crate) async fn handle_key_request(
&self,
key: Option<SecretKey>,
ty: &str,
opts: Option<KeyRequestOptions>,
event: InternalQueryEvent<T::Id>,
) -> Result<KeyResponse<T::Id>, Error<T, D>> {
let kr = KeyRequestMessage { key };
let buf = crate::types::encode_message_to_bytes(&kr)?;
let Some(this) = self.this() else {
return Ok(KeyResponse {
num_nodes: 0,
messages: HashMap::new(),
num_resp: 0,
num_err: 0,
keys: HashMap::new(),
primary_keys: HashMap::new(),
});
};
let mut q_param = this.default_query_param().await;
if let Some(opts) = opts {
q_param.relay_factor = opts.relay_factor;
}
let qresp: QueryResponse<T::Id, T::ResolvedAddress> = this
.internal_query(SmolStr::new(ty), buf, Some(q_param), event)
.await?;
let resp = Self::stream_key_response(&this, qresp.response_rx()).await;
if resp.num_err > 0 {
tracing::error!(
"serf: {}/{} nodes reported failure",
resp.num_err,
resp.num_nodes
);
}
if resp.num_resp != resp.num_nodes {
tracing::error!(
"serf: {}/{} nodes responded success",
resp.num_resp,
resp.num_nodes
);
}
Ok(resp)
}
async fn stream_key_response(
this: &Serf<T, D>,
ch: Receiver<NodeResponse<T::Id, T::ResolvedAddress>>,
) -> KeyResponse<T::Id> {
let mut resp = KeyResponse {
num_nodes: this.num_members().await,
messages: HashMap::new(),
num_resp: 0,
num_err: 0,
keys: HashMap::new(),
primary_keys: HashMap::new(),
};
futures::pin_mut!(ch);
while let Some(r) = ch.next().await {
resp.num_resp += 1;
if r.payload.is_empty() {
resp
.messages
.insert(r.from.id().cheap_clone(), SmolStr::new("empty payload"));
resp.num_err += 1;
if resp.num_resp == resp.num_nodes {
return resp;
}
continue;
}
let node_response =
match crate::types::decode_message::<T::Id, T::ResolvedAddress>(&r.payload) {
Ok(msg) => match msg {
MessageRef::KeyResponse(kr) => kr,
msg => {
tracing::error!(type=%msg.ty(), "serf: invalid key query response type");
resp.messages.insert(
r.from.id().cheap_clone(),
format_smolstr!("invalid key query response: {}", msg.ty()),
);
resp.num_err += 1;
if resp.num_resp == resp.num_nodes {
return resp;
}
continue;
}
},
Err(e) => {
tracing::error!(err=%e, "serf: failed to decode key query response");
resp
.messages
.insert(r.from.id().cheap_clone(), format_smolstr!("{e}"));
resp.num_err += 1;
if resp.num_resp == resp.num_nodes {
return resp;
}
continue;
}
};
if !node_response.result() {
resp.messages.insert(
r.from.id().cheap_clone(),
SmolStr::new(node_response.message()),
);
resp.num_err += 1;
}
if node_response.result() && !node_response.message().is_empty() {
tracing::warn!("serf: {}", node_response.message());
resp.messages.insert(
r.from.id().cheap_clone(),
SmolStr::new(node_response.message()),
);
}
let res = node_response
.keys()
.iter::<SecretKey>()
.try_for_each(|res| {
res.map(|k| {
let count = resp.keys.entry(k).or_insert(0);
*count += 1;
})
});
if let Err(e) = res {
resp.messages.insert(
r.from.id().cheap_clone(),
SmolStr::new(format!("Failed to decode key query response: {:?}", e)),
);
resp.num_err += 1;
if resp.num_resp == resp.num_nodes {
return resp;
}
continue;
}
if let Some(pk) = node_response.primary_key() {
let ctr = resp.primary_keys.entry(*pk).or_insert(0);
*ctr += 1;
}
if resp.num_resp == resp.num_nodes {
return resp;
}
}
resp
}
}