mod bchan;
mod chan;
pub(crate) mod config;
mod tls;
use crate::consts::{ALPN_CONEC, MAX_LOOPS};
use crate::types::{tagstream::TaggedInStream, ConecConn, ConecConnError, ControlMsg, CtrlStream, OutStream};
use bchan::{BroadcastChan, BroadcastChanDriver, BroadcastChanEvent, BroadcastChanRef};
pub use chan::CoordChanError;
use chan::{CoordChan, CoordChanDriver, CoordChanEvent, CoordChanRef};
use config::CoordConfig;
use err_derive::Error;
use futures::{channel::mpsc, prelude::*};
use quinn::{
crypto::rustls::TLSError, Certificate, CertificateChain, ConnectionError, Endpoint, EndpointError, Incoming,
IncomingBiStreams, PrivateKey, RecvStream, SendStream, ServerConfig, ServerConfigBuilder,
};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use tokio::task::{JoinError, JoinHandle};
#[derive(Debug, Error)]
pub enum CoordError {
#[error(display = "Unexpected end of Incoming stream")]
EndOfIncomingStream,
#[error(display = "Connection error: {:?}", _0)]
Connect(#[source] ConnectionError),
#[error(display = "Error connecting control channel: {:?}", _0)]
Control(#[source] ConecConnError),
#[error(display = "Certificate: {:?}", _0)]
CertificateChain(#[source] TLSError),
#[error(display = "Binding port: {:?}", _0)]
Bind(#[source] EndpointError),
#[error(display = "Join eror: {:?}", _0)]
Join(#[source] JoinError),
#[error(display = "Error sending to driver: {:?}", _0)]
Driver(#[source] mpsc::SendError),
#[error(display = "Events channel closed")]
EventsClosed,
}
def_into_error!(CoordError);
/// Messages delivered to the coordinator driver over its internal unbounded channel
/// and consumed by `CoordInner::handle_events`.
///
/// `Accepted` is sent by the per-connection tasks spawned in `drive_accept`. Client and
/// broadcast channels are handed a clone of the same sender when they are created,
/// presumably so they can emit the remaining variants.
enum CoordEvent {
    // ---- client lifecycle ----
    /// A connection completed its control handshake:
    /// (connection, control stream, incoming bidirectional streams, peer name).
    Accepted(ConecConn, CtrlStream, IncomingBiStreams, String),
    /// The named client's channel closed; it is removed from the client table.
    ChanClose(String),

    // ---- stream setup between clients ----
    /// (from, to, sid): `from` requests a new stream to client `to`.
    NewStreamReq(String, String, u64),
    /// (to, sid, result): outcome of a stream request, forwarded to client `to`.
    NewStreamRes(String, u64, Result<(SendStream, RecvStream), ConnectionError>),

    // ---- channel setup between clients ----
    /// (from, to, sid, cert, addr): `from` requests a new channel to client `to`.
    NewChannelReq(String, String, u64, Vec<u8>, SocketAddr),
    /// (to, sid, addr, cert): positive reply to a channel request, forwarded to `to`.
    NewChannelRes(String, u64, SocketAddr, Vec<u8>),
    /// (to, sid): negative reply to a channel request, forwarded to `to`.
    NewChannelErr(String, u64),

    // ---- broadcast channels ----
    /// (chan, send, recv): join broadcast channel `chan`, creating it if absent.
    NewBroadcastReq(String, OutStream, TaggedInStream),
    /// (chan, from, sid): client `from` asks for the member counts of `chan`.
    BroadcastCountReq(String, String, u64),
    /// (to, sid, counts): broadcast count reply, forwarded to client `to`.
    BroadcastCountRes(String, u64, (usize, usize)),
    /// The named broadcast channel closed; it is removed from the broadcast table.
    BroadcastClose(String),
}
/// Shared mutable state of the coordinator; `CoordRef::new` wraps it in `Arc<Mutex<_>>`.
///
/// NOTE: fields are dropped in declaration order, so keep this order stable.
struct CoordInner {
// Stream of incoming QUIC connections from the bound endpoint (polled by `drive_accept`).
incoming: Incoming,
// Registered clients, keyed by the peer name returned from `accept_ctrl`.
clients: HashMap<String, CoordChan>,
// Active broadcast channels, keyed by channel name.
broadcasts: HashMap<String, BroadcastChan>,
// Handle count for `CoordRef` — presumably maintained by `def_ref!`; verify against the macro.
ref_count: usize,
// Driver task waker — presumably stored/used by `def_ref!`/`def_driver!`; verify against the macros.
driver: Option<Waker>,
// Sending half of the internal event channel; cloned for accept tasks and channel drivers.
sender: mpsc::UnboundedSender<CoordEvent>,
// Receiving half of the internal event channel, drained by `handle_events`.
events: mpsc::UnboundedReceiver<CoordEvent>,
}
impl CoordInner {
    /// Pulls newly arriving connections off `incoming` and hands each to a spawned task.
    ///
    /// Each task awaits the connection, wraps it with `ConecConn::new`, accepts the
    /// control stream with `accept_ctrl` (which yields the peer's name), and reports the
    /// result to the driver as `CoordEvent::Accepted`. Failures inside a task are only
    /// logged with `tracing::warn!`; they do not stop the coordinator.
    ///
    /// Returns `Ok(true)` when `MAX_LOOPS` connections were taken in this call (more may
    /// be ready), and `Ok(false)` once `incoming` returns `Pending`.
    ///
    /// # Errors
    /// `CoordError::EndOfIncomingStream` if the incoming-connection stream has ended.
    fn drive_accept(&mut self, cx: &mut Context) -> Result<bool, CoordError> {
        let mut taken = 0;
        while let Poll::Ready(next) = self.incoming.poll_next_unpin(cx) {
            let conn = next.ok_or(CoordError::EndOfIncomingStream)?;
            let sender = self.sender.clone();
            tokio::spawn(async move {
                use CoordError::*;
                let outcome = async {
                    let (mut conn, mut ibi) = conn.await.map_err(Connect).map(ConecConn::new)?;
                    let (ctrl, peer) = conn.accept_ctrl(&mut ibi).await.map_err(Control)?;
                    sender
                        .unbounded_send(CoordEvent::Accepted(conn, ctrl, ibi, peer))
                        .map_err(|e| Driver(e.into_send_error()))
                }
                .await;
                if let Err(e) = outcome {
                    tracing::warn!("coord drive_accept: {:?}", e);
                }
            });
            taken += 1;
            if taken >= MAX_LOOPS {
                return Ok(true);
            }
        }
        Ok(false)
    }

    /// Drains queued `CoordEvent`s and applies each one with `handle_event`.
    ///
    /// Returns `Ok(true)` when `MAX_LOOPS` events were processed in this call (more may
    /// be queued), and `Ok(false)` once the event queue returns `Pending`.
    ///
    /// # Errors
    /// `CoordError::EventsClosed` if the event channel has been closed.
    fn handle_events(&mut self, cx: &mut Context) -> Result<bool, CoordError> {
        let mut handled = 0;
        while let Poll::Ready(next) = self.events.poll_next_unpin(cx) {
            let event = next.ok_or(CoordError::EventsClosed)?;
            self.handle_event(event);
            handled += 1;
            if handled >= MAX_LOOPS {
                return Ok(true);
            }
        }
        Ok(false)
    }

    /// Applies a single event to the coordinator state.
    ///
    /// Requests and replies are forwarded to the addressed client's or broadcast
    /// channel's handle. When the addressee is missing, the requester (if still
    /// registered) is sent the matching error event; otherwise a warning is logged.
    fn handle_event(&mut self, event: CoordEvent) {
        use CoordEvent::*;
        match event {
            Accepted(conn, ctrl, ibi, peer) => self.accept_client(conn, ctrl, ibi, peer),
            ChanClose(client) => {
                self.clients.remove(&client);
            }
            BroadcastClose(chan) => {
                self.broadcasts.remove(&chan);
            }
            NewStreamReq(from, to, sid) => {
                if let Some(c_to) = self.clients.get(&to) {
                    c_to.send(CoordChanEvent::NSReq(from, sid));
                } else if let Some(c_from) = self.clients.get(&from) {
                    c_from.send(CoordChanEvent::NSErr(sid));
                } else {
                    tracing::warn!("NSReq clients disappeared: {}:{} -> {}", from, sid, to);
                }
            }
            NewStreamRes(to, sid, result) => {
                if let Some(c_to) = self.clients.get(&to) {
                    c_to.send(CoordChanEvent::NSRes(sid, result));
                } else {
                    tracing::warn!("NSRes client disappeared: {}:{}", to, sid);
                }
            }
            NewChannelReq(from, to, sid, cert, addr) => {
                if let Some(c_to) = self.clients.get(&to) {
                    c_to.send(CoordChanEvent::NCReq(from, sid, cert, addr));
                } else if let Some(c_from) = self.clients.get(&from) {
                    c_from.send(CoordChanEvent::NCErr(sid));
                } else {
                    tracing::warn!("NCReq clients disappeared: {}:{} -> {}", from, sid, to);
                }
            }
            NewChannelRes(to, sid, addr, cert) => {
                if let Some(c_to) = self.clients.get(&to) {
                    c_to.send(CoordChanEvent::NCRes(sid, addr, cert));
                } else {
                    tracing::warn!("NCRes client disappeared: {}:{}", to, sid);
                }
            }
            NewChannelErr(to, sid) => {
                if let Some(c_to) = self.clients.get(&to) {
                    c_to.send(CoordChanEvent::NCErr(sid));
                } else {
                    tracing::warn!("NCErr client disappeared: {}:{}", to, sid);
                }
            }
            NewBroadcastReq(chan, send, recv) => {
                if let Some(c_chan) = self.broadcasts.get(&chan) {
                    c_chan.send(BroadcastChanEvent::New(send, recv));
                } else {
                    // First member: create the channel, spawn its driver, and register it.
                    let (inner, sender) = BroadcastChanRef::new(chan.clone(), self.sender.clone(), send, recv);
                    let driver = BroadcastChanDriver::new(inner.clone());
                    tokio::spawn(async move { driver.await });
                    self.broadcasts.insert(chan, BroadcastChan { inner, sender });
                }
            }
            BroadcastCountReq(chan, from, sid) => {
                if let Some(c_chan) = self.broadcasts.get(&chan) {
                    c_chan.send(BroadcastChanEvent::Count(from, sid));
                } else if let Some(c_from) = self.clients.get(&from) {
                    c_from.send(CoordChanEvent::BCErr(sid));
                } else {
                    tracing::warn!("BCReq client disappeared: {}:{} -> {}", from, sid, chan);
                }
            }
            BroadcastCountRes(to, sid, counts) => {
                if let Some(c_to) = self.clients.get(&to) {
                    c_to.send(CoordChanEvent::BCRes(sid, counts));
                } else {
                    tracing::warn!("BCRes client disappeared: {}:{}", to, sid);
                }
            }
        }
    }

    /// Registers a newly connected client under `peer`, or rejects it if the name is taken.
    ///
    /// A rejected client is sent `ControlMsg::HelloError("name in use")` and its control
    /// stream is finished from a spawned task, after which the stream and connection are
    /// dropped. An accepted client gets a `CoordChan` whose driver is spawned here.
    fn accept_client(&mut self, conn: ConecConn, ctrl: CtrlStream, ibi: IncomingBiStreams, peer: String) {
        if self.clients.contains_key(&peer) {
            tokio::spawn(async move {
                let mut ctrl = ctrl;
                // Best effort: the rejected peer may already be gone.
                ctrl.send(ControlMsg::HelloError("name in use".to_string())).await.ok();
                ctrl.finish().await.ok();
                drop(ctrl);
                drop(conn);
            });
        } else {
            let (inner, sender) = CoordChanRef::new(conn, ctrl, ibi, peer.clone(), self.sender.clone());
            let driver = CoordChanDriver(inner.clone());
            tokio::spawn(async move { driver.await });
            self.clients.insert(peer, CoordChan { inner, sender });
        }
    }

    /// Alternates between accepting connections and handling events until neither
    /// reports more ready work.
    ///
    /// At most `MAX_LOOPS` rounds run per call. If that budget is exhausted while work
    /// may remain, the task wakes itself so it is polled again instead of stalling.
    ///
    /// # Errors
    /// Propagates errors from `drive_accept` and `handle_events`.
    fn run_driver(&mut self, cx: &mut Context) -> Result<(), CoordError> {
        let mut iters = 0;
        loop {
            // Poll both sources every round (no short-circuit): a source only registers
            // the waker when it returns `Pending`, so skipping one could leave it unwatched.
            let accept_more = self.drive_accept(cx)?;
            let events_more = self.handle_events(cx)?;
            if !accept_more && !events_more {
                break;
            }
            iters += 1;
            if iters >= MAX_LOOPS {
                cx.waker().wake_by_ref();
                break;
            }
        }
        Ok(())
    }
}
def_ref!(CoordInner, CoordRef, pub(self));
impl CoordRef {
    /// Creates the shared coordinator state around the endpoint's `incoming` stream.
    ///
    /// The state starts with empty client and broadcast tables, no driver waker, a zero
    /// handle count, and a fresh unbounded event channel whose sender is kept in the
    /// state so it can be cloned for tasks that report back to the coordinator.
    fn new(incoming: Incoming) -> Self {
        let (sender, events) = mpsc::unbounded();
        let state = CoordInner {
            sender,
            events,
            incoming,
            clients: HashMap::new(),
            broadcasts: HashMap::new(),
            driver: None,
            ref_count: 0,
        };
        Self(Arc::new(Mutex::new(state)))
    }
}
def_driver!(pub(self), CoordRef; pub(self), CoordDriver; CoordError);
impl Drop for CoordDriver {
    /// Tears down coordinator state when the driver is dropped.
    ///
    /// Both the client table and the broadcast-channel table are emptied, releasing
    /// their handles. Both ends of the internal event channel are then closed so no
    /// further `CoordEvent`s are accepted.
    fn drop(&mut self) {
        // A destructor must not panic: if the driver panicked while holding the lock,
        // recover the guard from the poisoned mutex and still perform the teardown.
        let mut inner = self.0.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        inner.clients.clear();
        // Release broadcast channels as well, mirroring `clients`. Otherwise their handles,
        // and the streams handed to them via `NewBroadcastReq`, stay alive as long as the
        // `Coord` that owns this state.
        inner.broadcasts.clear();
        inner.sender.close_channel();
        inner.events.close();
    }
}
/// A running coordinator: a bound QUIC server endpoint plus the spawned driver task
/// that accepts client connections and routes requests between them.
///
/// NOTE: fields are dropped in declaration order, so keep this order stable.
pub struct Coord {
// QUIC endpoint bound in `Coord::new`; also queried by `local_addr`.
endpoint: Endpoint,
// Shared coordinator state (client table, broadcast table, event channel).
inner: CoordRef,
// Handle of the spawned `CoordDriver` task.
driver_handle: JoinHandle<Result<(), CoordError>>,
}
impl Coord {
pub(crate) fn build_config(
stateless_retry: bool,
keylog: bool,
certs: CertificateChain,
key: PrivateKey,
client_ca: Option<Certificate>,
) -> Result<ServerConfig, TLSError> {
let mut qscb = ServerConfigBuilder::new({
let mut qsc = ServerConfig::default();
qsc.crypto = tls::build_rustls_server_config(client_ca.map(|c| c.as_der().to_vec()));
qsc
});
qscb.protocols(ALPN_CONEC);
qscb.use_stateless_retry(stateless_retry);
if keylog {
qscb.enable_keylog();
}
qscb.certificate(certs, key)?;
Ok(qscb.build())
}
pub async fn new(config: CoordConfig) -> Result<Self, CoordError> {
let (cert, key) = config.cert_and_key;
let qsc = Self::build_config(config.stateless_retry, config.keylog, cert, key, config.client_ca)?;
let mut endpoint = Endpoint::builder();
endpoint.listen(qsc);
let (endpoint, incoming) = endpoint.bind(&config.laddr)?;
let inner = CoordRef::new(incoming);
let driver = CoordDriver(inner.clone());
let driver_handle = tokio::spawn(async move { driver.await });
Ok(Self {
endpoint,
inner,
driver_handle,
})
}
pub fn num_clients(&self) -> usize {
let inner = self.inner.lock().unwrap();
inner.clients.len()
}
pub fn num_broadcasts(&self) -> usize {
let inner = self.inner.lock().unwrap();
inner.broadcasts.len()
}
pub fn local_addr(&self) -> std::net::SocketAddr {
self.endpoint.local_addr().unwrap()
}
}
def_flat_future!(Coord, (), CoordError, Join, driver_handle);