use super::{proto, Metadata, State};
use futures::stream::{unfold, Stream};
use rand::{thread_rng, Rng};
use std::{
collections::HashMap,
convert::TryInto,
net::SocketAddr,
ops::Index,
result,
sync::{Arc, Weak},
};
use thiserror::Error;
use tokio::sync::{
broadcast::{error::RecvError, Receiver},
RwLock,
};
use tonic::transport::{self, Channel, ClientTlsConfig};
pub(crate) type Result = result::Result<(), Closed>;
#[derive(Copy, Clone, Debug, Error)]
#[error("closed")]
pub struct Closed;
/// A handle for receiving view changes ([`MultiNodeCut`]s) from a broadcast channel.
///
/// Derives `Debug` so it can be embedded in other debuggable types; neither
/// field's `Debug` impl requires the inner type to be `Debug`.
#[derive(Debug)]
pub struct Subscription {
    /// Weak handle to the shared state, read to fetch the latest cut after the
    /// receiver lags behind the channel.
    state: Weak<RwLock<State>>,
    /// Broadcast receiver delivering view changes.
    rx: Receiver<MultiNodeCut>,
}
impl Subscription {
pub(crate) fn new(state: Weak<RwLock<State>>, rx: Receiver<MultiNodeCut>) -> Self {
Self { state, rx }
}
pub async fn recv(&mut self) -> result::Result<MultiNodeCut, Closed> {
let n = match self.rx.recv().await {
Ok(view_change) => {
return Ok(view_change);
}
Err(RecvError::Closed) => {
return Err(Closed);
}
Err(RecvError::Lagged(n)) => n,
};
let state = self.state.upgrade().ok_or(Closed)?;
let state = state.read().await;
let mut cut = state.last_cut.clone().ok_or(Closed)?;
cut.skipped = n;
Ok(cut)
}
pub fn into_stream(self) -> impl Stream<Item = MultiNodeCut> {
unfold(self, |mut s| async { Some((s.recv().await.ok()?, s)) })
}
pub fn as_stream(&mut self) -> impl Stream<Item = MultiNodeCut> + '_ {
unfold(self, |s| async { Some((s.recv().await.ok()?, s)) })
}
}
/// A snapshot of cluster membership, delivered to subscribers as a view change.
#[derive(Clone, Debug)]
pub struct MultiNodeCut {
/// Number of broadcast messages the subscriber missed; set by
/// `Subscription::recv` when its receiver lagged.
pub(crate) skipped: u64,
/// Address reported by `local_addr()` (presumably this node's own address — TODO confirm).
pub(crate) local_addr: SocketAddr,
/// Configuration identifier for this cut.
pub(crate) conf_id: u64,
/// Whether this cut is flagged as degraded (semantics defined by the producer).
pub(crate) degraded: bool,
/// All members of the cut. Must be sorted by address: `lookup` uses a binary search.
pub(crate) members: Arc<[Member]>,
/// Members reported as joined (presumably since the previous cut — verify with producer).
pub(crate) joined: Arc<[Member]>,
/// Members reported as kicked (presumably since the previous cut — verify with producer).
pub(crate) kicked: Arc<[Member]>,
}
impl Index<SocketAddr> for MultiNodeCut {
    type Output = Member;

    /// Returns the member whose address is `addr`.
    ///
    /// # Panics
    ///
    /// Panics, naming the address, if no member of this cut has address `addr`.
    /// Use [`MultiNodeCut::lookup`] for a non-panicking alternative.
    #[inline]
    fn index(&self, addr: SocketAddr) -> &Self::Output {
        match self.lookup(addr) {
            Some(member) => member,
            None => panic!("no member with address {} in this cut", addr),
        }
    }
}
impl MultiNodeCut {
pub fn skipped(&self) -> u64 {
self.skipped
}
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
}
pub fn conf_id(&self) -> u64 {
self.conf_id
}
pub fn is_degraded(&self) -> bool {
self.degraded
}
pub(crate) fn random_member(&self) -> &Member {
&self.members[thread_rng().gen_range(0..self.members.len())]
}
pub fn members(&self) -> &Arc<[Member]> {
&self.members
}
pub fn joined(&self) -> &Arc<[Member]> {
&self.joined
}
pub fn kicked(&self) -> &Arc<[Member]> {
&self.kicked
}
pub fn with_meta<K: AsRef<str>>(&self, key: K) -> impl Iterator<Item = (&Member, &[u8])> {
self.members.iter().filter_map(move |m| {
let val = m.meta.get(key.as_ref())?;
Some((m, val.as_ref()))
})
}
pub fn lookup(&self, addr: SocketAddr) -> Option<&Member> {
self.members
.binary_search_by_key(&addr, |m| m.addr())
.ok()
.map(|i| &self.members[i])
}
}
/// A cluster member: its address, optional client TLS settings, metadata, and a
/// `tonic` transport channel to it.
#[derive(Clone, Debug)]
pub struct Member {
/// Address the member is reached at; also the sort/search key used by `MultiNodeCut::lookup`.
addr: SocketAddr,
/// Shared client TLS configuration; when present the endpoint uses `https`, otherwise `http`.
tls: Option<Arc<ClientTlsConfig>>,
/// Member metadata; its `keys` map is exposed via `metadata()`.
meta: Metadata,
/// Lazily connected channel to `addr`, created in `Member::new`.
chan: Channel,
}
impl From<&Member> for proto::Endpoint {
    /// Builds the protocol endpoint for a member, flagging whether it is
    /// configured for TLS.
    #[inline]
    fn from(member: &Member) -> Self {
        let uses_tls = member.tls.is_some();
        Self::from(member.addr).tls(uses_tls)
    }
}
impl From<&Member> for transport::Endpoint {
    /// Builds a `tonic` transport endpoint for a member, using `https` with
    /// the member's TLS configuration when one is present.
    #[inline]
    fn from(member: &Member) -> Self {
        let tls = member.tls.as_deref();
        endpoint(member.addr, tls)
    }
}
/// Builds a transport endpoint for `addr`.
///
/// With a TLS config the URI scheme is `https` and the config is applied;
/// otherwise the scheme is `http`.
///
/// # Panics
///
/// Panics if the formatted URI does not parse, or if the TLS configuration
/// cannot be applied to the endpoint.
#[inline]
fn endpoint(addr: SocketAddr, tls: Option<&ClientTlsConfig>) -> transport::Endpoint {
    let scheme = if tls.is_some() { "https" } else { "http" };
    // NOTE(review): an IPv6 address with a non-zero scope id formats as
    // `[addr%scope]:port`, which may not parse as a URI. Confirm that such
    // addresses never reach this function.
    let ep: transport::Endpoint = format!("{}://{}", scheme, addr)
        .try_into()
        .expect("formatted socket address should form a valid endpoint URI");
    match tls {
        Some(tls) => ep
            .tls_config(tls.clone())
            .expect("failed to apply client TLS configuration to endpoint"),
        None => ep,
    }
}
impl Member {
    /// Creates a member whose channel to `addr` connects lazily, on first use.
    ///
    /// When `tls` is present, the channel uses `https` with that configuration.
    #[inline]
    pub(crate) fn new(addr: SocketAddr, tls: Option<Arc<ClientTlsConfig>>, meta: Metadata) -> Self {
        let chan = endpoint(addr, tls.as_deref()).connect_lazy();
        Self { addr, tls, meta, chan }
    }

    /// The member's address.
    pub fn addr(&self) -> SocketAddr {
        self.addr
    }

    /// The member's client TLS configuration, if any.
    pub fn tls_config(&self) -> Option<&ClientTlsConfig> {
        self.tls.as_deref()
    }

    /// The member's metadata as a key → bytes map.
    pub fn metadata(&self) -> &HashMap<String, Vec<u8>> {
        &self.meta.keys
    }

    /// A handle to the member's channel.
    ///
    /// The handle is a clone of the `Channel` stored in the member.
    pub fn channel(&self) -> Channel {
        self.chan.clone()
    }
}