use super::CoordEvent;
use crate::consts::{MAX_LOOPS, STRICT_CTRL};
use crate::types::{
outstream_init, tagstream::TaggedInStream, ConecConn, ControlMsg, CtrlStream, InStream, OutStream, StreamTo,
};
use crate::util;
use err_derive::Error;
use futures::{channel::mpsc, prelude::*};
use quinn::{ConnectionError, IncomingBiStreams, RecvStream, SendStream};
use std::collections::{HashMap, HashSet, VecDeque};
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use tokio_serde::{formats::SymmetricalBincode, SymmetricallyFramed};
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
#[derive(Debug, Error)]
pub enum CoordChanError {
#[error(display = "Unknown streamid")]
UnknownSid,
#[error(display = "Peer closed connection")]
PeerClosed,
#[error(display = "Stream poll: {:?}", _0)]
StreamPoll(#[error(source, no_from)] io::Error),
#[error(display = "Control sink: {:?}", _0)]
Sink(#[error(source, no_from)] util::SinkError),
#[error(display = "Unexpected message from coordinator")]
WrongMessage(ControlMsg),
#[error(display = "Sending CoordEvent: {:?}", _0)]
SendCoordEvent(#[source] mpsc::SendError),
#[error(display = "Unexpected end of Bi stream")]
EndOfBiStream,
#[error(display = "Accepting Bi stream: {:?}", _0)]
AcceptBiStream(#[source] ConnectionError),
#[error(display = "Opening Bi stream: {:?}", _0)]
OpenBiStream(#[error(no_from, source)] ConnectionError),
#[error(display = "Events channel closed")]
EventsClosed,
}
def_into_error!(CoordChanError);
/// Shared state of the coordinator's channel to a single peer; advanced by
/// `CoordChanInner::run_driver`.
///
/// Fields are dropped in declaration order, so their order is significant.
pub(super) struct CoordChanInner {
    /// Connection to the peer (used to open bi-streams, read its certificate and address, and close it).
    conn: ConecConn,
    /// Control-message stream with the peer.
    ctrl: CtrlStream,
    /// Bi-directional streams opened by the peer.
    ibi: IncomingBiStreams,
    /// Peer name; used as the originator in `CoordEvent`s and to tag broadcast input.
    peer: String,
    /// Channel to the coordinator.
    coord: mpsc::UnboundedSender<CoordEvent>,
    /// Sending half of `events`; cloned into the tasks spawned by `drive_ibi_recv`.
    sender: mpsc::UnboundedSender<CoordChanEvent>,
    /// This channel's event queue, drained by `handle_events`.
    events: mpsc::UnboundedReceiver<CoordChanEvent>,
    // NOTE(review): not touched in this file; presumably maintained by `def_ref!` — confirm.
    ref_count: usize,
    // NOTE(review): not touched in this file; presumably the driver task's waker, managed by `def_driver!` — confirm.
    driver: Option<Waker>,
    /// Outbound control messages waiting to be written by `drive_ctrl_send`.
    to_send: VecDeque<ControlMsg>,
    /// Flush state passed through to `util::drive_ctrl_send`.
    flushing: bool,
    /// Streams delivered by `NSRes`, keyed by stream id, awaiting the peer's matching `StreamTo::Client` stream.
    new_streams: HashMap<u64, (SendStream, RecvStream)>,
    /// Channel names from `NewBroadcastReq`, keyed by stream id, awaiting the peer's matching `StreamTo::Broadcast` stream.
    new_broadcasts: HashMap<u64, String>,
    // NOTE(review): only ever cleared (in `Drop`) within this file — confirm it is still needed.
    sids: HashSet<u64>,
}
/// Work items delivered to a peer's channel driver through its event queue and
/// processed by `CoordChanInner::handle_events`.
// Some variants carry whole stream handles and dwarf the others; the size
// disparity is tolerated here rather than boxing those payloads.
#[allow(clippy::large_enum_variant)]
pub(super) enum CoordChanEvent {
    // --- client-to-client streams ---
    /// `(to, sid)`: open a bi-stream on this connection and report the outcome to
    /// the coordinator as `CoordEvent::NewStreamRes(to, sid, ..)`.
    NSReq(String, u64),
    /// `(sid, result)`: on success, stash the stream and tell the peer `NewStreamOk`;
    /// on failure tell the peer `NewStreamErr`.
    NSRes(u64, Result<(SendStream, RecvStream), ConnectionError>),
    /// Tell the peer `NewStreamErr(sid)`.
    NSErr(u64),

    // --- direct channels ---
    /// `(to, sid, cert, addr)`: forward to the peer as `CertReq`.
    NCReq(String, u64, Vec<u8>, SocketAddr),
    /// `(sid, addr, cert)`: forward to the peer as `NewChannelOk`.
    NCRes(u64, SocketAddr, Vec<u8>),
    /// Tell the peer `NewChannelErr(sid)`.
    NCErr(u64),

    // --- broadcast counts ---
    /// `(sid, counts)`: forward to the peer as `BroadcastCountRes`.
    BCRes(u64, (usize, usize)),
    /// Tell the peer `BroadcastCountErr(sid)`.
    BCErr(u64),

    // --- incoming streams ---
    /// A bi-stream opened by the peer, with its decoded `StreamTo` header.
    BiIn(StreamTo, OutStream, InStream),
}
impl CoordChanInner {
    /// Forwards `event` to the coordinator.
    ///
    /// # Errors
    /// [`CoordChanError::SendCoordEvent`] if the coordinator's receiver has been dropped.
    fn send_coord(&self, event: CoordEvent) -> Result<(), CoordChanError> {
        self.coord
            .unbounded_send(event)
            .map_err(|e| CoordChanError::SendCoordEvent(e.into_send_error()))
    }

    /// Reads and dispatches control messages from the peer until the control stream
    /// is pending or `MAX_LOOPS` messages have been handled.
    ///
    /// Requests that need the coordinator are forwarded as `CoordEvent`s. Broadcast
    /// setup and keep-alives are answered locally by queueing a reply in `to_send`.
    ///
    /// Returns `Ok(true)` if the `MAX_LOOPS` budget was exhausted (more messages may be
    /// waiting) and `Ok(false)` once the stream returned `Pending`.
    ///
    /// # Errors
    /// - `PeerClosed` when the control stream ends.
    /// - `StreamPoll` on a read error.
    /// - `SendCoordEvent` if the coordinator channel is closed.
    /// - `WrongMessage` for an unexpected message when `STRICT_CTRL` is set; otherwise
    ///   it is only logged.
    fn drive_ctrl_recv(&mut self, cx: &mut Context) -> Result<bool, CoordChanError> {
        let mut recvd = 0;
        loop {
            let msg = match self.ctrl.poll_next_unpin(cx) {
                Poll::Pending => break,
                Poll::Ready(None) => Err(CoordChanError::PeerClosed),
                Poll::Ready(Some(Err(e))) => Err(CoordChanError::StreamPoll(e)),
                Poll::Ready(Some(Ok(msg))) => Ok(msg),
            }?;
            match msg {
                ControlMsg::NewStreamReq(to, sid) => {
                    self.send_coord(CoordEvent::NewStreamReq(self.peer.clone(), to, sid))
                }
                ControlMsg::NewChannelReq(to, sid) => self.send_coord(CoordEvent::NewChannelReq(
                    self.peer.clone(),
                    to,
                    sid,
                    self.conn.get_cert_bytes().to_vec(),
                    self.conn.remote_addr(),
                )),
                ControlMsg::CertOk(to, sid) => {
                    let addr = self.conn.remote_addr();
                    let cert = self.conn.get_cert_bytes().to_vec();
                    self.send_coord(CoordEvent::NewChannelRes(to, sid, addr, cert))
                }
                ControlMsg::CertNok(to, sid) => self.send_coord(CoordEvent::NewChannelErr(to, sid)),
                ControlMsg::NewBroadcastReq(chan, sid) => {
                    // Accept immediately and remember the channel name until the peer
                    // opens the matching `StreamTo::Broadcast` stream (see `handle_events`).
                    self.to_send.push_back(ControlMsg::NewBroadcastOk(sid));
                    self.new_broadcasts.insert(sid, chan);
                    Ok(())
                }
                ControlMsg::BroadcastCountReq(chan, sid) => {
                    self.send_coord(CoordEvent::BroadcastCountReq(chan, self.peer.clone(), sid))
                }
                ControlMsg::KeepAlive => {
                    self.to_send.push_back(ControlMsg::KeepAlive);
                    Ok(())
                }
                _ => {
                    let err = CoordChanError::WrongMessage(msg);
                    if STRICT_CTRL {
                        Err(err)
                    } else {
                        tracing::warn!("CoordChanInner::drive_ctrl_recv: {:?}", err);
                        Ok(())
                    }
                }
            }?;
            recvd += 1;
            if recvd >= MAX_LOOPS {
                return Ok(true);
            }
        }
        Ok(false)
    }

    /// Writes queued `to_send` messages to the control stream via `util::drive_ctrl_send`.
    ///
    /// NOTE(review): the returned flag's meaning is defined by that helper (presumably
    /// "more work remains", like the other drivers) — confirm.
    ///
    /// # Errors
    /// `Sink` if writing or flushing the control stream fails.
    fn drive_ctrl_send(&mut self, cx: &mut Context) -> Result<bool, CoordChanError> {
        util::drive_ctrl_send(cx, &mut self.flushing, &mut self.ctrl, &mut self.to_send)
            .map_err(CoordChanError::Sink)
    }

    /// Processes queued `CoordChanEvent`s until the queue is pending or `MAX_LOOPS`
    /// events have been handled.
    ///
    /// Failures of individual events are reported to the peer with the matching `*Err`
    /// control message (and logged) rather than returned.
    ///
    /// Returns `Ok(true)` if the budget was exhausted and `Ok(false)` once the queue
    /// returned `Pending`.
    ///
    /// # Errors
    /// `EventsClosed` if the event queue has ended.
    fn handle_events(&mut self, cx: &mut Context) -> Result<bool, CoordChanError> {
        use CoordChanEvent::*;
        let mut accepted = 0;
        loop {
            let event = match self.events.poll_next_unpin(cx) {
                Poll::Pending => break,
                Poll::Ready(None) => Err(CoordChanError::EventsClosed),
                Poll::Ready(Some(event)) => Ok(event),
            }?;
            match event {
                NSErr(sid) => self.to_send.push_back(ControlMsg::NewStreamErr(sid)),
                NSRes(sid, res) => match res {
                    Err(e) => {
                        tracing::warn!("CoordChan::handle_events: NSRes: {:?}", e);
                        self.to_send.push_back(ControlMsg::NewStreamErr(sid));
                    }
                    Ok(send_recv) => {
                        // Held until the peer opens the matching `StreamTo::Client` stream.
                        self.to_send.push_back(ControlMsg::NewStreamOk(sid));
                        self.new_streams.insert(sid, send_recv);
                    }
                },
                NSReq(to, sid) => {
                    let coord = self.coord.clone();
                    let bi = self.conn.open_bi();
                    tokio::spawn(async move {
                        coord.unbounded_send(CoordEvent::NewStreamRes(to, sid, bi.await)).ok();
                    });
                }
                NCErr(sid) => self.to_send.push_back(ControlMsg::NewChannelErr(sid)),
                NCReq(to, sid, cert, addr) => self.to_send.push_back(ControlMsg::CertReq(to, sid, cert, addr)),
                NCRes(sid, addr, cert) => self.to_send.push_back(ControlMsg::NewChannelOk(sid, addr, cert)),
                BiIn(sid, n_send, n_recv) => match sid {
                    StreamTo::Broadcast(sid) => {
                        let tagged_recv = TaggedInStream::new(n_recv, self.peer.clone());
                        let res = match self.new_broadcasts.remove(&sid) {
                            Some(chan) => self.send_coord(CoordEvent::NewBroadcastReq(chan, n_send, tagged_recv)),
                            None => Err(CoordChanError::UnknownSid),
                        };
                        if let Err(e) = res {
                            tracing::warn!("CoordChan::handle_events: BiIn: Broadcast: {:?}", e);
                            self.to_send.push_back(ControlMsg::NewBroadcastErr(sid));
                        }
                    }
                    StreamTo::Client(sid) => {
                        if let Some((send, recv)) = self.new_streams.remove(&sid) {
                            // Splice the peer's new stream with the one stashed by `NSRes`,
                            // forwarding in both directions on a background task.
                            let from = self.peer.clone();
                            tokio::spawn(async move {
                                let send = match outstream_init(send, from, sid).await {
                                    Ok(send) => send,
                                    Err(e) => {
                                        tracing::warn!("CoordChan::handle_events: BiIn: Client: {:?}", e);
                                        return;
                                    }
                                };
                                let recv = FramedRead::new(recv, LengthDelimitedCodec::new());
                                let fw1 = n_recv.map(|b| b.map(|bb| bb.freeze())).forward(send);
                                let fw2 = recv.map(|b| b.map(|bb| bb.freeze())).forward(n_send);
                                let (sf, rf) = futures::future::join(fw1, fw2).await;
                                sf.ok();
                                rf.ok();
                            });
                        } else {
                            self.to_send.push_back(ControlMsg::NewStreamErr(sid));
                        }
                    }
                },
                BCErr(sid) => self.to_send.push_back(ControlMsg::BroadcastCountErr(sid)),
                BCRes(sid, counts) => self.to_send.push_back(ControlMsg::BroadcastCountRes(sid, counts)),
            };
            accepted += 1;
            if accepted >= MAX_LOOPS {
                return Ok(true);
            }
        }
        Ok(false)
    }

    /// Accepts bi-directional streams opened by the peer until the acceptor is pending
    /// or `MAX_LOOPS` streams have been accepted.
    ///
    /// For each stream, a task reads the first length-delimited bincode frame as the
    /// `StreamTo` header. It then posts `CoordChanEvent::BiIn` on this channel's own
    /// event queue. Streams whose header cannot be read are logged and dropped.
    ///
    /// # Errors
    /// `EndOfBiStream` if the acceptor ends, `AcceptBiStream` if accepting fails.
    fn drive_ibi_recv(&mut self, cx: &mut Context) -> Result<bool, CoordChanError> {
        let mut recvd = 0;
        loop {
            let (send, recv) = match self.ibi.poll_next_unpin(cx) {
                Poll::Pending => break,
                Poll::Ready(None) => Err(CoordChanError::EndOfBiStream),
                Poll::Ready(Some(r)) => r.map_err(CoordChanError::AcceptBiStream),
            }?;
            let sender = self.sender.clone();
            tokio::spawn(async move {
                let mut read_stream = SymmetricallyFramed::new(
                    FramedRead::new(recv, LengthDelimitedCodec::new()),
                    SymmetricalBincode::<StreamTo>::default(),
                );
                let sid = match read_stream.try_next().await {
                    Err(e) => {
                        tracing::warn!("drive_ibi_recv: {:?}", e);
                        return;
                    }
                    Ok(None) => {
                        tracing::warn!("drive_ibi_recv: unexpected end of stream");
                        return;
                    }
                    Ok(Some(sid)) => sid,
                };
                let recv = read_stream.into_inner();
                let send = FramedWrite::new(send, LengthDelimitedCodec::new());
                sender.unbounded_send(CoordChanEvent::BiIn(sid, send, recv)).ok();
            });
            recvd += 1;
            if recvd >= MAX_LOOPS {
                return Ok(true);
            }
        }
        Ok(false)
    }

    /// Runs one driver pass until no sub-driver reports leftover work.
    ///
    /// Each round polls, in order, the control reader, the event queue, the control
    /// writer, and the bi-stream acceptor. The writer is only polled when something is
    /// queued or a flush is in progress.
    ///
    /// After `MAX_LOOPS` rounds it wakes its own task and returns, so the remaining
    /// work continues on the next poll instead of in this call.
    ///
    /// # Errors
    /// The first error returned by any sub-driver.
    fn run_driver(&mut self, cx: &mut Context) -> Result<(), CoordChanError> {
        let mut iters = 0;
        loop {
            let mut keep_going = false;
            keep_going |= self.drive_ctrl_recv(cx)?;
            keep_going |= self.handle_events(cx)?;
            if !self.to_send.is_empty() || self.flushing {
                keep_going |= self.drive_ctrl_send(cx)?;
            }
            keep_going |= self.drive_ibi_recv(cx)?;
            if !keep_going {
                break;
            }
            iters += 1;
            if iters >= MAX_LOOPS {
                cx.waker().wake_by_ref();
                break;
            }
        }
        Ok(())
    }
}
def_ref!(CoordChanInner, CoordChanRef);
impl CoordChanRef {
pub(super) fn new(
conn: ConecConn,
ctrl: CtrlStream,
ibi: IncomingBiStreams,
peer: String,
coord: mpsc::UnboundedSender<CoordEvent>,
) -> (Self, mpsc::UnboundedSender<CoordChanEvent>) {
let mut to_send = VecDeque::new();
to_send.push_back(ControlMsg::CoHello);
let (sender, events) = mpsc::unbounded();
(
Self(Arc::new(Mutex::new(CoordChanInner {
conn,
ctrl,
ibi,
peer,
coord,
sender: sender.clone(),
events,
ref_count: 0,
driver: None,
to_send,
flushing: false,
new_streams: HashMap::new(),
new_broadcasts: HashMap::new(),
sids: HashSet::new(),
}))),
sender,
)
}
}
def_driver!(CoordChanRef, CoordChanDriver, CoordChanError);
impl Drop for CoordChanDriver {
    /// Tears the channel down when the driver goes away.
    ///
    /// It notifies the coordinator with `ChanClose` and closes the connection. It then
    /// shuts the coordinator and event channels and discards queued control messages
    /// and any half-established streams or broadcasts.
    fn drop(&mut self) {
        // If a thread panicked while holding the lock, the mutex is poisoned and
        // `unwrap()` would panic again here (aborting the process if this runs during
        // unwinding). Recover the guard so cleanup still happens.
        let mut inner = self.0.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        inner
            .coord
            .unbounded_send(CoordEvent::ChanClose(inner.peer.clone()))
            .ok();
        inner.conn.close(b"coord driver died");
        inner.coord.disconnect();
        inner.sender.close_channel();
        inner.events.close();
        inner.to_send.clear();
        inner.new_streams.clear();
        inner.new_broadcasts.clear();
        inner.sids.clear();
    }
}
/// Handle for one peer's coordinator channel: a ref to its shared state plus a
/// sender into the driver's event queue.
pub(super) struct CoordChan {
    // Never read directly; holding the ref presumably keeps the shared
    // `CoordChanInner` alive (see `def_ref!`) — confirm.
    #[allow(dead_code)]
    pub(super) inner: CoordChanRef,
    /// Sender into the driver's event queue; used by `CoordChan::send`.
    pub(super) sender: mpsc::UnboundedSender<CoordChanEvent>,
}
impl CoordChan {
    /// Enqueues `event` for the channel driver.
    ///
    /// This is best-effort: if the driver's event queue has been closed (e.g. after
    /// the driver was dropped), the event is silently discarded.
    pub(super) fn send(&self, event: CoordChanEvent) {
        let _ = self.sender.unbounded_send(event);
    }
}