netidx 0.31.5

Secure, fast, pub/sub messaging
Documentation
use super::common::{
    krb5_authentication, DesiredAuth, Response, ResponseChan, FROMREADPOOL, HELLO_TO,
    PUBLISHERPOOL, RAWFROMREADPOOL,
};
use crate::{
    channel::{self, Channel, K5CtxWrap},
    os::local_auth::AuthClient,
    protocol::resolver::{
        Auth, AuthRead, ClientHello, FromRead, Publisher, Referral, ToRead,
    },
    tls,
    utils::Either,
};
use anyhow::{Context, Error, Result};
use cross_krb5::ClientCtx;
use futures::{
    channel::{mpsc, oneshot},
    prelude::*,
};
use fxhash::FxHashSet;
use log::{info, warn};
use poolshark::global::GPooled;
use rand::{rng, seq::SliceRandom, RngExt};
use std::{
    cmp::max, collections::HashSet, fmt::Debug, net::SocketAddr, sync::Arc,
    time::Duration,
};
use tokio::{net::TcpStream, task, time};

// continue with timeout
// `cwt` = "continue with timeout". Awaits `$e` bounded by `HELLO_TO`.
// Both the timeout result and the inner result are passed through
// `try_cf!(.., continue, ..)` (defined elsewhere) together with `$msg`,
// which presumably serves as log context. On the success path the macro
// evaluates to the value produced by `$e`.
// Must be used inside a loop in an async context.
macro_rules! cwt {
    ($msg:expr, $e:expr) => {{
        let bounded = time::timeout(HELLO_TO, $e).await;
        let outcome = try_cf!($msg, continue, bounded);
        try_cf!($msg, continue, outcome)
    }};
}

/// Open an authenticated channel to one of the resolver servers in `resolver`.
///
/// The server list is shuffled and then tried round robin for up to three
/// full passes, with a random 1-11 second pause between passes. Network steps
/// (TCP connect, version exchange, auth hello/reply, TLS handshake) are each
/// bounded by `HELLO_TO`. A server that fails one of them is abandoned in
/// favour of the next one.
///
/// `bad_addrs` carries state between calls:
/// - On entry it holds the servers whose connection attempt failed during the
///   previous call. Those are skipped on this call's first pass, unless that
///   would skip every server.
/// - On return it holds the servers that failed during this call, minus the
///   one that finally succeeded.
///
/// # Errors
///
/// Fails if:
/// - the server list is empty;
/// - no server could be reached within three passes;
/// - `desired_auth` is incompatible with the server's advertised `auth`;
/// - a server answers the hello with a different auth mechanism;
/// - TLS setup fails (no connector cache, connector load, server name, or
///   handshake error).
async fn connect(
    bad_addrs: &mut FxHashSet<SocketAddr>,
    resolver: &Referral,
    desired_auth: &DesiredAuth,
    tls: &Option<tls::CachedConnector>,
) -> Result<Channel> {
    let mut addrs = resolver.addrs.clone();
    if addrs.is_empty() {
        // the round-robin index below is `n % addrs.len()`
        bail!("no resolver servers to connect to");
    }
    addrs.as_mut_slice().shuffle(&mut rng());
    // Servers that failed during the previous call are skipped on this call's
    // first pass. `bad_addrs` itself starts over empty so that it accumulates
    // exactly this call's failures.
    let recently_failed = std::mem::take(bad_addrs);
    // If every server failed last time, skipping them would leave the first
    // pass with nothing to try, so in that case skip nothing.
    let skip_recent = addrs.iter().any(|(a, _)| !recently_failed.contains(a));
    let mut n = 0;
    loop {
        let (addr, auth) = &addrs[n % addrs.len()];
        let tries = n / addrs.len();
        if tries >= 3 {
            bail!("can't connect to any resolver servers");
        }
        if tries == 0 && skip_recent && recently_failed.contains(addr) {
            n += 1;
            continue;
        }
        // back off between full passes over the server list
        if n % addrs.len() == 0 && tries > 0 {
            let wait = rng().random_range(1..12);
            time::sleep(Duration::from_secs(wait)).await;
        }
        n += 1;
        let mut con = match time::timeout(HELLO_TO, TcpStream::connect(*addr)).await {
            Ok(Ok(con)) => con,
            Err(_) => {
                warn!(
                    "failed to connect to resolver server {} connection timed out",
                    addr
                );
                bad_addrs.insert(*addr);
                continue;
            }
            Ok(Err(e)) => {
                warn!("failed to connect to resolver server {} error: {}", addr, e);
                bad_addrs.insert(*addr);
                continue;
            }
        };
        try_cf!("no delay", con.set_nodelay(true));
        // protocol version handshake
        cwt!("send version", channel::write_raw(&mut con, &3u64));
        let version = cwt!("recv version", channel::read_raw::<u64, _, 1024>(&mut con));
        if version != 3 {
            warn!(
                "resolver server {} speaks unsupported protocol version {}",
                addr, version
            );
            continue;
        }
        let con = match (desired_auth, auth) {
            (DesiredAuth::Anonymous, _) => {
                let mut con = Channel::new::<ClientCtx, TcpStream>(None, con);
                cwt!("hello", con.send_one(&ClientHello::ReadOnly(AuthRead::Anonymous)));
                match cwt!("reply", con.receive::<AuthRead>()) {
                    AuthRead::Anonymous => (),
                    AuthRead::Local | AuthRead::Krb5 | AuthRead::Tls => {
                        bail!("protocol error")
                    }
                }
                con
            }
            (
                DesiredAuth::Krb5 { .. } | DesiredAuth::Local | DesiredAuth::Tls { .. },
                Auth::Anonymous,
            ) => {
                bail!("requested authentication mechanism not supported")
            }
            (
                DesiredAuth::Local | DesiredAuth::Krb5 { .. } | DesiredAuth::Tls { .. },
                Auth::Local { path },
            ) => {
                let mut con = Channel::new::<ClientCtx, TcpStream>(None, con);
                let tok = cwt!("local token", AuthClient::token(&*path));
                cwt!("hello", con.send_one(&ClientHello::ReadOnly(AuthRead::Local)));
                cwt!("token", con.send_one(&tok));
                match cwt!("reply", con.receive::<AuthRead>()) {
                    AuthRead::Local => (),
                    AuthRead::Krb5 | AuthRead::Anonymous | AuthRead::Tls => {
                        bail!("protocol error")
                    }
                }
                con
            }
            (DesiredAuth::Local, Auth::Krb5 { .. } | Auth::Tls { .. }) => {
                bail!("local auth not supported")
            }
            (DesiredAuth::Krb5 { .. }, Auth::Tls { .. }) => {
                bail!("krb5 authentication is not supported")
            }
            (DesiredAuth::Krb5 { upn, .. }, Auth::Krb5 { spn }) => {
                let upn = upn.as_ref().map(|s| s.as_str());
                let hello = ClientHello::ReadOnly(AuthRead::Krb5);
                cwt!("hello", channel::write_raw(&mut con, &hello));
                let ctx = cwt!("k5auth", krb5_authentication(upn, &*spn, &mut con));
                match cwt!("reply", channel::read_raw::<AuthRead, _, 1024>(&mut con)) {
                    AuthRead::Krb5 => Channel::new(Some(K5CtxWrap::new(ctx)), con),
                    AuthRead::Local | AuthRead::Anonymous | AuthRead::Tls => {
                        bail!("protocol error")
                    }
                }
            }
            (DesiredAuth::Tls { .. }, Auth::Krb5 { .. }) => {
                bail!("tls authentication is not supported")
            }
            (DesiredAuth::Tls { .. }, Auth::Tls { name }) => {
                let tls = tls.as_ref().ok_or_else(|| anyhow!("no tls cache"))?;
                // loading the connector may block, so keep it off the runtime
                let ctx = task::spawn_blocking({
                    let tls = tls.clone();
                    let name = name.clone();
                    move || tls.load(&name)
                })
                .await
                .context("loading tls connector")??;
                let hello = ClientHello::ReadOnly(AuthRead::Tls);
                cwt!("hello", channel::write_raw(&mut con, &hello));
                let name = rustls_pki_types::ServerName::try_from(&**name)
                    .context("creating rustls servername")?
                    .to_owned();
                // Bound the handshake like every other network step so an
                // unresponsive server can't stall the caller forever. A
                // timeout moves on to the next server; a handshake error is
                // still returned to the caller.
                let tls = match time::timeout(HELLO_TO, ctx.connect(name, con)).await {
                    Ok(r) => r.context("tls handshake")?,
                    Err(_) => {
                        warn!("tls handshake with resolver server {} timed out", addr);
                        continue;
                    }
                };
                let mut con = Channel::new::<
                    ClientCtx,
                    tokio_rustls::client::TlsStream<TcpStream>,
                >(None, tls);
                match cwt!("reply", con.receive::<AuthRead>()) {
                    AuthRead::Tls => con,
                    AuthRead::Local | AuthRead::Anonymous | AuthRead::Krb5 { .. } => {
                        bail!("protocol error")
                    }
                }
            }
        };
        // A server that failed earlier in this call but has now succeeded
        // should not be skipped next time.
        bad_addrs.remove(addr);
        break Ok(con);
    }
}

/// One unit of work for the read task. It pairs the requests, each tagged
/// with a `usize` that is copied onto the corresponding reply, with the
/// oneshot on which the replies are delivered.
type Batch = (
    GPooled<Vec<(usize, ToRead)>>,
    oneshot::Sender<Response<FromRead>>,
);

/// Separate `Publisher` records from every other kind of `FromRead` reply.
///
/// Publishers are returned as `Right`; all other variants are returned
/// unchanged as `Left`. The match is kept exhaustive on purpose, so that
/// adding a `FromRead` variant forces a decision here.
fn partition_publishers(m: FromRead) -> Either<FromRead, Publisher> {
    match m {
        other @ (FromRead::Denied
        | FromRead::Error(_)
        | FromRead::GetChangeNr(_)
        | FromRead::List(_)
        | FromRead::ListMatching(_)
        | FromRead::Referral(_)
        | FromRead::Resolved(_)
        | FromRead::Table(_)) => Either::Left(other),
        FromRead::Publisher(publisher) => Either::Right(publisher),
    }
}

async fn connection(
    mut receiver: mpsc::UnboundedReceiver<Batch>,
    resolver: Arc<Referral>,
    desired_auth: DesiredAuth,
    tls: Option<tls::CachedConnector>,
) {
    let mut con: Option<Channel> = None;
    let mut bad_addrs: FxHashSet<SocketAddr> = HashSet::default();
    'main: loop {
        match receiver.next().await {
            None => break,
            Some((tx_batch, reply)) => {
                let mut tries: usize = 0;
                'batch: loop {
                    if tries > 3 {
                        break;
                    }
                    if tries > 1 {
                        let wait = rng().random_range(1..12);
                        time::sleep(Duration::from_secs(wait)).await
                    }
                    tries += 1;
                    let c = match con {
                        Some(ref mut c) => c,
                        None => {
                            match connect(&mut bad_addrs, &resolver, &desired_auth, &tls)
                                .await
                            {
                                Ok(c) => {
                                    con = Some(c);
                                    con.as_mut().unwrap()
                                }
                                Err(e) => {
                                    con = None;
                                    warn!(
                                        "connect_read failed: {}, {}",
                                        e,
                                        e.root_cause()
                                    );
                                    continue;
                                }
                            }
                        }
                    };
                    let mut timeout =
                        max(HELLO_TO, Duration::from_micros(tx_batch.len() as u64 * 50));
                    for (_, m) in &*tx_batch {
                        match m {
                            ToRead::List(_) | ToRead::ListMatching(_) => {
                                timeout += HELLO_TO;
                            }
                            _ => (),
                        }
                        match c.queue_send(m) {
                            Ok(()) => (),
                            Err(e) => {
                                warn!("failed to encode {:?}", e);
                                c.clear();
                                continue 'main;
                            }
                        }
                    }
                    match c.flush_timeout(timeout).await {
                        Err(e) => {
                            warn!("read connection send error: {}", e);
                            con = None;
                        }
                        Ok(()) => {
                            let mut rx_batch = RAWFROMREADPOOL.take();
                            let mut publishers = PUBLISHERPOOL.take();
                            while rx_batch.len() < tx_batch.len() {
                                let f =
                                    c.receive_batch_fn(|m| {
                                        match partition_publishers(m) {
                                            Either::Left(m) => rx_batch.push(m),
                                            Either::Right(p) => {
                                                publishers.insert(p.id, p);
                                            }
                                        }
                                    });
                                match time::timeout(timeout, f).await {
                                    Ok(Ok(())) => (),
                                    Ok(Err(e)) => {
                                        warn!("read connection failed {}", e);
                                        con = None;
                                        continue 'batch;
                                    }
                                    Err(e) => {
                                        warn!("read connection timeout: {}", e);
                                        con = None;
                                        continue 'batch;
                                    }
                                }
                            }
                            let mut result = FROMREADPOOL.take();
                            result.extend(
                                rx_batch
                                    .drain(..)
                                    .enumerate()
                                    .map(|(i, m)| (tx_batch[i].0, m)),
                            );
                            let _ = reply.send((publishers, result));
                            break;
                        }
                    }
                }
            }
        }
    }
}

/// Cloneable handle to a background resolver read task.
///
/// Cloning the handle clones the underlying sender, so all clones feed the
/// same task (which keeps at most one server connection at a time).
#[derive(Debug, Clone)]
pub(super) struct ReadClient(mpsc::UnboundedSender<Batch>);

impl ReadClient {
    /// Spawn the read task for `resolver` and return a handle to it.
    ///
    /// Uses `tokio::task::spawn`, so it must be called from within a tokio
    /// runtime. The task logs and exits once every handle has been dropped and
    /// the queue is drained.
    pub(super) fn new(
        resolver: Arc<Referral>,
        desired_auth: DesiredAuth,
        tls: Option<tls::CachedConnector>,
    ) -> Self {
        let (sender, batches) = mpsc::unbounded();
        let worker = async move {
            connection(batches, resolver, desired_auth, tls).await;
            info!("read task shutting down")
        };
        task::spawn(worker);
        ReadClient(sender)
    }

    /// Queue `batch` for the read task and return the channel on which its
    /// replies will arrive.
    ///
    /// If the task has already exited, the rejected message is dropped along
    /// with its reply sender, so the returned receiver resolves as cancelled.
    pub(crate) fn send(
        &mut self,
        batch: GPooled<Vec<(usize, ToRead)>>,
    ) -> ResponseChan<FromRead> {
        let (reply_tx, reply_rx) = oneshot::channel();
        let _ = self.0.unbounded_send((batch, reply_tx));
        reply_rx
    }
}