use std::{
collections::HashMap,
fmt,
future::Future,
mem,
net::SocketAddr,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll, Waker},
time::Instant,
};
use bytes::Bytes;
use err_derive::Error;
use futures::{
channel::{mpsc, oneshot},
FutureExt, StreamExt,
};
use proto::{ConnectionError, ConnectionHandle, ConnectionId, Dir, StreamId, TimerUpdate};
use tokio::time::{delay_until, Delay, Instant as TokioInstant};
use tracing::{info_span, trace};
use crate::{
broadcast::{self, Broadcast},
streams::{RecvStream, SendStream, WriteError},
ConnectionEvent, EndpointEvent, VarInt,
};
/// Future for an in-progress connection attempt.
///
/// Holds the connection's driver until the connection is handed off, either when this future
/// resolves to a [`NewConnection`] or through [`Connecting::into_0rtt`].
pub struct Connecting(Option<ConnectionDriver>);
impl Connecting {
    pub(crate) fn new(conn: ConnectionRef) -> Self {
        Self(Some(ConnectionDriver(conn)))
    }

    /// Converts this attempt into a usable connection without waiting for the handshake.
    ///
    /// This succeeds when the protocol state reports 0-RTT keys (`has_0rtt`) or when this is
    /// the server side. Otherwise `self` is handed back unchanged in `Err`.
    ///
    /// The returned [`ZeroRttAccepted`] resolves immediately to `true` if the connection is
    /// already established. If it is not, it resolves once the `Connected` event is processed.
    pub fn into_0rtt(mut self) -> Result<(NewConnection, ZeroRttAccepted), Self> {
        // The guard borrows `self`, so it must be released before `self` is moved or taken.
        let mut guard = self.0.as_mut().unwrap().0.lock().unwrap();
        let eligible = guard.inner.has_0rtt() || guard.inner.side().is_server();
        if !eligible {
            drop(guard);
            return Err(self);
        }
        let (tx, rx) = oneshot::channel();
        if guard.connected {
            // `rx` is still alive here, so this send cannot fail.
            tx.send(true).unwrap();
        } else {
            // Resolved later by the `Connected` event, or with `false` on termination.
            guard.on_connected = Some(tx);
        }
        drop(guard);
        let ConnectionDriver(conn) = self.0.take().unwrap();
        Ok((NewConnection::new(conn), ZeroRttAccepted(rx)))
    }
}
impl Future for Connecting {
type Output = Result<NewConnection, ConnectionError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let connected = match self.0 {
Some(ref mut driver) => {
let r = driver.poll_unpin(cx)?;
let driver = self.0.as_mut().unwrap().0.lock().unwrap();
match r {
Poll::Ready(()) => {
return Poll::Ready(Err(driver.error.as_ref().unwrap().clone()));
}
Poll::Pending => driver.connected,
}
}
None => panic!("polled after yielding Ready"),
};
if connected {
let ConnectionDriver(conn) = self.0.take().unwrap();
Poll::Ready(Ok(NewConnection::new(conn)))
} else {
Poll::Pending
}
}
}
impl Connecting {
pub fn remote_address(&self) -> SocketAddr {
let conn_ref: &ConnectionRef = &self.0.as_ref().expect("used after yielding Ready").0;
conn_ref.lock().unwrap().inner.remote()
}
}
/// Future that resolves to the `bool` the connection sends through its oneshot channel.
///
/// The possible values are:
/// - `accepted_0rtt()`, when the `Connected` event is processed;
/// - `true`, if the connection was already connected when `into_0rtt` was called;
/// - `false`, if the connection terminates first or the sender is dropped without a value.
///
/// Presumably this indicates whether the peer accepted 0-RTT data (confirm against proto).
pub struct ZeroRttAccepted(oneshot::Receiver<bool>);
impl Future for ZeroRttAccepted {
    type Output = bool;
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        match self.0.poll_unpin(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(accepted)) => Poll::Ready(accepted),
            // A canceled channel means nobody will ever report acceptance.
            Poll::Ready(Err(_canceled)) => Poll::Ready(false),
        }
    }
}
/// The pieces of a newly usable connection.
///
/// Every field is a handle onto the same shared connection state.
#[derive(Debug)]
pub struct NewConnection {
    /// Future that drives the connection; it must be spawned for the connection to function.
    pub driver: ConnectionDriver,
    /// Handle for opening streams, sending datagrams and closing the connection.
    pub connection: Connection,
    /// Stream of unidirectional streams opened by the peer.
    pub uni_streams: IncomingUniStreams,
    /// Stream of bidirectional streams opened by the peer.
    pub bi_streams: IncomingBiStreams,
    /// Stream of datagrams received from the peer.
    pub datagrams: Datagrams,
    _non_exhaustive: (),
}
impl NewConnection {
    /// Splits one connection reference into all of the per-purpose handles.
    fn new(conn: ConnectionRef) -> Self {
        let driver = ConnectionDriver(conn.clone());
        let connection = Connection(conn.clone());
        let uni_streams = IncomingUniStreams(conn.clone());
        let bi_streams = IncomingBiStreams(conn.clone());
        Self {
            driver,
            connection,
            uni_streams,
            bi_streams,
            // The last handle takes ownership of the original reference instead of cloning it.
            datagrams: Datagrams(conn),
            _non_exhaustive: (),
        }
    }
}
/// Future that drives a connection's protocol state machine.
///
/// Each poll does the following:
/// - feeds events from the endpoint into the state machine;
/// - forwards transmits and endpoint events to the endpoint;
/// - services protocol timers;
/// - wakes application tasks that are waiting on the connection.
///
/// Resolves to `Ok(())` once the connection has drained after a local close, or to the
/// connection's error otherwise.
#[must_use = "connection drivers must be spawned for their connections to function"]
#[derive(Debug)]
pub struct ConnectionDriver(pub(crate) ConnectionRef);
impl Future for ConnectionDriver {
    type Output = Result<(), ConnectionError>;

    // `self` is only read through the shared `ConnectionRef`, so no `mut` binding is needed.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let conn = &mut *self.0.lock().unwrap();
        let span = info_span!("drive", id = conn.handle.0);
        let _guard = span.enter();
        loop {
            let now = Instant::now();
            let mut keep_going = false;
            if let Err(e) = conn.process_conn_events(cx) {
                conn.terminate(e.clone());
                return Poll::Ready(Err(e));
            }
            conn.drive_transmit(now);
            // Expired timers and freshly reported timer updates can each create more work,
            // so keep iterating until neither reports progress.
            keep_going |= conn.drive_timers(cx, now);
            keep_going |= conn.handle_timer_updates();
            conn.forward_endpoint_events();
            conn.forward_app_events();
            if !keep_going || conn.inner.is_drained() {
                break;
            }
        }
        if !conn.inner.is_drained() {
            // Stored so application-side handles can re-schedule the driver via `wake`.
            conn.driver = Some(cx.waker().clone());
            return Poll::Pending;
        }
        match conn.error {
            Some(ConnectionError::LocallyClosed) => Poll::Ready(Ok(())),
            Some(ref e) => Poll::Ready(Err(e.clone())),
            None => unreachable!("drained connections always have an error"),
        }
    }
}
/// A handle to a connection.
///
/// Cloning a `Connection` yields another handle to the same shared connection state.
#[derive(Clone, Debug)]
pub struct Connection(ConnectionRef);
impl Connection {
    /// Returns a future that opens an outgoing unidirectional stream.
    ///
    /// The future resolves once the protocol state allows a new stream to be opened, or fails
    /// with the connection's error if the connection has terminated.
    pub fn open_uni(&self) -> OpenUni {
        let conn = self.0.clone();
        OpenUni {
            conn,
            state: Default::default(),
        }
    }

    /// Returns a future that opens an outgoing bidirectional stream.
    ///
    /// The future resolves once the protocol state allows a new stream to be opened, or fails
    /// with the connection's error if the connection has terminated.
    pub fn open_bi(&self) -> OpenBi {
        let conn = self.0.clone();
        OpenBi {
            conn,
            state: Default::default(),
        }
    }

    /// Closes the connection immediately.
    ///
    /// `error_code` and a copy of `reason` are passed to the protocol state. Every pending
    /// operation on the connection then fails with `ConnectionError::LocallyClosed`, and the
    /// driver task is woken.
    pub fn close(&self, error_code: VarInt, reason: &[u8]) {
        let reason = Bytes::copy_from_slice(reason);
        self.0.lock().unwrap().close(error_code, reason);
    }

    /// Returns a future that queues `data` as a datagram.
    ///
    /// The future waits while datagram sending is blocked. It fails if the connection is
    /// closed, the datagram is too large, or datagrams are unsupported/disabled.
    pub fn send_datagram(&self, data: Bytes) -> SendDatagram<'_> {
        let conn = &self.0;
        SendDatagram {
            conn,
            data,
            state: Default::default(),
        }
    }

    /// Returns a future that resolves once a datagram could be sent, without sending one.
    pub fn send_datagram_ready(&self) -> SendDatagramReady<'_> {
        let conn = &self.0;
        SendDatagramReady {
            conn,
            state: Default::default(),
        }
    }

    /// Largest datagram payload currently accepted, as reported by the protocol state.
    ///
    /// Presumably `None` when datagrams cannot be sent at all; confirm against proto.
    pub fn max_datagram_size(&self) -> Option<usize> {
        let state = self.0.lock().unwrap();
        state.inner.max_datagram_size()
    }

    /// The peer's socket address as reported by the protocol state.
    pub fn remote_address(&self) -> SocketAddr {
        let state = self.0.lock().unwrap();
        state.inner.remote()
    }

    /// The peer's connection ID (`rem_cid`) as reported by the protocol state.
    pub fn remote_id(&self) -> ConnectionId {
        let state = self.0.lock().unwrap();
        state.inner.rem_cid()
    }

    /// The negotiated protocol identifier, copied into an owned buffer, if any.
    ///
    /// Presumably the ALPN protocol; confirm against proto.
    pub fn protocol(&self) -> Option<Box<[u8]>> {
        let state = self.0.lock().unwrap();
        state.inner.protocol().map(Into::into)
    }

    /// Asks the protocol state to initiate a key update (testing aid).
    #[doc(hidden)]
    pub fn force_key_update(&self) {
        let mut state = self.0.lock().unwrap();
        state.inner.initiate_key_update()
    }
}
/// Stream of unidirectional streams opened by the peer.
///
/// Ends (`None`) after a local close. After any other termination, it yields the connection's
/// error.
#[derive(Debug)]
pub struct IncomingUniStreams(ConnectionRef);
impl futures::Stream for IncomingUniStreams {
    type Item = Result<RecvStream, ConnectionError>;
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let mut conn = self.0.lock().unwrap();
        if let Some(id) = conn.inner.accept(Dir::Uni) {
            // Accepting changes protocol state, so give the driver a chance to act on it.
            conn.wake();
            drop(conn);
            return Poll::Ready(Some(Ok(RecvStream::new(self.0.clone(), id, false))));
        }
        match conn.error {
            Some(ConnectionError::LocallyClosed) => Poll::Ready(None),
            Some(ref e) => Poll::Ready(Some(Err(e.clone()))),
            None => {
                conn.incoming_uni_streams_reader = Some(cx.waker().clone());
                Poll::Pending
            }
        }
    }
}
/// Stream of bidirectional streams opened by the peer, yielded as send/receive halves.
///
/// Ends (`None`) after a local close. After any other termination, it yields the connection's
/// error.
#[derive(Debug)]
pub struct IncomingBiStreams(ConnectionRef);
impl futures::Stream for IncomingBiStreams {
    type Item = Result<(SendStream, RecvStream), ConnectionError>;
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let mut conn = self.0.lock().unwrap();
        if let Some(id) = conn.inner.accept(Dir::Bi) {
            // Accepting changes protocol state, so give the driver a chance to act on it.
            conn.wake();
            drop(conn);
            let send = SendStream::new(self.0.clone(), id, false);
            let recv = RecvStream::new(self.0.clone(), id, false);
            return Poll::Ready(Some(Ok((send, recv))));
        }
        match conn.error {
            Some(ConnectionError::LocallyClosed) => Poll::Ready(None),
            Some(ref e) => Poll::Ready(Some(Err(e.clone()))),
            None => {
                conn.incoming_bi_streams_reader = Some(cx.waker().clone());
                Poll::Pending
            }
        }
    }
}
/// Stream of datagrams received from the peer.
///
/// Ends (`None`) after a local close. After any other termination, it yields the connection's
/// error.
#[derive(Debug)]
pub struct Datagrams(ConnectionRef);
impl futures::Stream for Datagrams {
    type Item = Result<Bytes, ConnectionError>;
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let mut conn = self.0.lock().unwrap();
        if let Some(datagram) = conn.inner.recv_datagram() {
            return Poll::Ready(Some(Ok(datagram)));
        }
        match conn.error {
            Some(ConnectionError::LocallyClosed) => Poll::Ready(None),
            Some(ref e) => Poll::Ready(Some(Err(e.clone()))),
            None => {
                conn.datagram_reader = Some(cx.waker().clone());
                Poll::Pending
            }
        }
    }
}
/// Future that opens an outgoing unidirectional stream.
///
/// Pending tasks are registered in the connection's `uni_opening` broadcast.
pub struct OpenUni {
    conn: ConnectionRef,
    state: broadcast::State,
}
impl Future for OpenUni {
    type Output = Result<SendStream, ConnectionError>;
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let OpenUni {
            conn: conn_ref,
            state,
        } = self.get_mut();
        let mut conn = conn_ref.lock().unwrap();
        if let Some(ref e) = conn.error {
            return Poll::Ready(Err(e.clone()));
        }
        match conn.inner.open(Dir::Uni) {
            None => {
                conn.uni_opening.register(cx, state);
                Poll::Pending
            }
            Some(id) => {
                // A stream opened by a client that is still handshaking is flagged as 0-RTT.
                let is_0rtt = conn.inner.side().is_client() && conn.inner.is_handshaking();
                drop(conn);
                Poll::Ready(Ok(SendStream::new(conn_ref.clone(), id, is_0rtt)))
            }
        }
    }
}
/// Future that opens an outgoing bidirectional stream.
///
/// Pending tasks are registered in the connection's `bi_opening` broadcast.
pub struct OpenBi {
    conn: ConnectionRef,
    state: broadcast::State,
}
impl Future for OpenBi {
    type Output = Result<(SendStream, RecvStream), ConnectionError>;
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let OpenBi {
            conn: conn_ref,
            state,
        } = self.get_mut();
        let mut conn = conn_ref.lock().unwrap();
        if let Some(ref e) = conn.error {
            return Poll::Ready(Err(e.clone()));
        }
        match conn.inner.open(Dir::Bi) {
            None => {
                conn.bi_opening.register(cx, state);
                Poll::Pending
            }
            Some(id) => {
                // A stream opened by a client that is still handshaking is flagged as 0-RTT.
                let is_0rtt = conn.inner.side().is_client() && conn.inner.is_handshaking();
                drop(conn);
                let send = SendStream::new(conn_ref.clone(), id, is_0rtt);
                let recv = RecvStream::new(conn_ref.clone(), id, is_0rtt);
                Poll::Ready(Ok((send, recv)))
            }
        }
    }
}
pub struct SendDatagramReady<'a> {
conn: &'a ConnectionRef,
state: broadcast::State,
}
impl<'a> Future for SendDatagramReady<'a> {
type Output = Result<(), SendDatagramError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.get_mut();
let mut conn = this.conn.lock().unwrap();
if let Some(ref e) = conn.error {
return Poll::Ready(Err(SendDatagramError::ConnectionClosed(e.clone())));
}
match conn.inner.send_datagram() {
Ok(_) => Poll::Ready(Ok(())),
Err(e) => conn.handle_datagram_err(cx, &mut this.state, e),
}
}
}
pub struct SendDatagram<'a> {
conn: &'a ConnectionRef,
data: Bytes,
state: broadcast::State,
}
impl<'a> Future for SendDatagram<'a> {
type Output = Result<(), SendDatagramError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.get_mut();
let mut conn = this.conn.lock().unwrap();
if let Some(ref e) = conn.error {
return Poll::Ready(Err(SendDatagramError::ConnectionClosed(e.clone())));
}
match conn.inner.send_datagram() {
Ok(sender) => match sender.send(mem::replace(&mut this.data, Bytes::new())) {
Ok(()) => {
conn.wake();
Poll::Ready(Ok(()))
}
Err(proto::DatagramTooLarge) => Poll::Ready(Err(SendDatagramError::TooLarge)),
},
Err(e) => conn.handle_datagram_err(cx, &mut this.state, e),
}
}
}
/// Errors that can arise when sending a datagram.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum SendDatagramError {
    /// The connection has terminated; carries the connection's recorded error.
    #[error(display = "connection closed: {}", 0)]
    ConnectionClosed(ConnectionError),
    /// The protocol state rejected the payload as too large (`proto::DatagramTooLarge`).
    #[error(display = "datagram too large")]
    TooLarge,
    /// The peer does not support datagrams (`proto::SendDatagramError::UnsupportedByPeer`).
    #[error(display = "datagrams not supported by peer")]
    UnsupportedByPeer,
    /// Datagram support is disabled (`proto::SendDatagramError::Disabled`).
    #[error(display = "datagram support disabled")]
    Disabled,
}
/// Shared, mutex-protected handle to a connection's state.
///
/// Every clone points at the same `ConnectionInner`. Clones are tallied in its `ref_count`;
/// see the `Clone` and `Drop` impls for `ConnectionRef`.
#[derive(Debug)]
pub struct ConnectionRef(Arc<Mutex<ConnectionInner>>);
impl ConnectionRef {
    /// Wraps a freshly created protocol connection in shared state.
    ///
    /// `endpoint_events` carries events from this connection (tagged with `handle`) to the
    /// endpoint. `conn_events` delivers events from the endpoint to this connection.
    pub(crate) fn new(
        handle: ConnectionHandle,
        conn: proto::Connection,
        endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
        conn_events: mpsc::UnboundedReceiver<ConnectionEvent>,
    ) -> Self {
        let state = ConnectionInner {
            epoch: Instant::now(),
            inner: conn,
            handle,
            endpoint_events,
            conn_events,
            driver: None,
            on_connected: None,
            connected: false,
            timers: Default::default(),
            blocked_writers: HashMap::new(),
            blocked_readers: HashMap::new(),
            finishing: HashMap::new(),
            uni_opening: Broadcast::new(),
            bi_opening: Broadcast::new(),
            send_datagram_blocked: Broadcast::new(),
            incoming_uni_streams_reader: None,
            incoming_bi_streams_reader: None,
            datagram_reader: None,
            error: None,
            // The first handle is not counted; each clone adds one.
            ref_count: 0,
        };
        Self(Arc::new(Mutex::new(state)))
    }
}
impl Clone for ConnectionRef {
fn clone(&self) -> Self {
self.0.lock().unwrap().ref_count += 1;
Self(self.0.clone())
}
}
impl Drop for ConnectionRef {
    /// Decrements `ref_count` and implicitly closes the connection when it reaches zero.
    ///
    /// `ref_count` starts at 0 for the first handle and grows by one per clone. Reaching 0 here
    /// therefore means exactly one other handle remains (presumably the driver; confirm). An
    /// already-zero count is left untouched.
    fn drop(&mut self) {
        let mut conn = self.0.lock().unwrap();
        let remaining = match conn.ref_count.checked_sub(1) {
            Some(n) => n,
            None => return,
        };
        conn.ref_count = remaining;
        if remaining == 0 && !conn.inner.is_closed() {
            conn.implicit_close();
        }
    }
}
impl std::ops::Deref for ConnectionRef {
    type Target = Mutex<ConnectionInner>;
    /// Exposes the shared mutex so callers can `lock()` the connection state directly.
    fn deref(&self) -> &Self::Target {
        &*self.0
    }
}
/// Mutable per-connection state shared by the driver and all application-facing handles.
///
/// Access goes through the `Mutex` in `ConnectionRef`.
pub struct ConnectionInner {
    /// Creation instant; used as the reference point in timer trace logs.
    epoch: Instant,
    /// The protocol state machine.
    pub(crate) inner: proto::Connection,
    /// Waker of the task polling `ConnectionDriver`, stored while it is pending.
    driver: Option<Waker>,
    /// Identifies this connection in events sent to the endpoint.
    handle: ConnectionHandle,
    /// Resolves a pending `ZeroRttAccepted` when connected (or with `false` on termination).
    on_connected: Option<oneshot::Sender<bool>>,
    /// Set once the protocol reports the `Connected` event.
    connected: bool,
    /// One optional tokio delay per protocol timer.
    timers: proto::TimerTable<Option<Delay>>,
    /// Events delivered from the endpoint to this connection.
    conn_events: mpsc::UnboundedReceiver<ConnectionEvent>,
    /// Events sent from this connection to the endpoint, tagged with `handle`.
    endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
    /// Tasks waiting for a stream to become writable.
    pub(crate) blocked_writers: HashMap<StreamId, Waker>,
    /// Tasks waiting for a stream to become readable.
    pub(crate) blocked_readers: HashMap<StreamId, Waker>,
    /// Tasks waiting in `OpenUni` for a unidirectional stream to become available.
    uni_opening: Broadcast,
    /// Tasks waiting in `OpenBi` for a bidirectional stream to become available.
    bi_opening: Broadcast,
    /// Waker of the task polling `IncomingUniStreams`.
    incoming_uni_streams_reader: Option<Waker>,
    /// Waker of the task polling `IncomingBiStreams`.
    incoming_bi_streams_reader: Option<Waker>,
    /// Waker of the task polling `Datagrams`.
    datagram_reader: Option<Waker>,
    /// Senders notified when a stream finishes, is stopped, or the connection terminates.
    pub(crate) finishing: HashMap<StreamId, oneshot::Sender<Option<WriteError>>>,
    /// Set by `terminate`; `Some` once the connection is closed for any reason.
    pub(crate) error: Option<ConnectionError>,
    /// Number of `ConnectionRef` clones; the first handle is not counted.
    ref_count: usize,
    /// Tasks waiting for datagram sending to become unblocked.
    send_datagram_blocked: Broadcast,
}
impl ConnectionInner {
fn drive_transmit(&mut self, now: Instant) {
while let Some(t) = self.inner.poll_transmit(now) {
let _ = self
.endpoint_events
.unbounded_send((self.handle, EndpointEvent::Transmit(t)));
}
}
fn forward_endpoint_events(&mut self) {
while let Some(event) = self.inner.poll_endpoint_events() {
let _ = self
.endpoint_events
.unbounded_send((self.handle, EndpointEvent::Proto(event)));
}
}
fn process_conn_events(&mut self, cx: &mut Context) -> Result<(), ConnectionError> {
loop {
match self.conn_events.poll_next_unpin(cx) {
Poll::Ready(Some(ConnectionEvent::Proto(event))) => {
self.inner.handle_event(event);
}
Poll::Ready(Some(ConnectionEvent::Close { reason, error_code })) => {
self.close(error_code, reason);
}
Poll::Ready(None) => {
return Err(ConnectionError::TransportError(proto::TransportError {
code: proto::TransportErrorCode::INTERNAL_ERROR,
frame: None,
reason: "endpoint driver future was dropped".to_string(),
}));
}
Poll::Pending => {
return Ok(());
}
}
}
}
fn forward_app_events(&mut self) {
while let Some(event) = self.inner.poll() {
use proto::Event::*;
match event {
Connected { .. } => {
self.connected = true;
if let Some(x) = self.on_connected.take() {
let _ = x.send(self.inner.accepted_0rtt());
}
}
ConnectionLost { reason } => {
self.terminate(reason);
}
StreamWritable { stream } => {
if let Some(writer) = self.blocked_writers.remove(&stream) {
writer.wake();
}
}
StreamOpened { dir: Dir::Uni } => {
if let Some(x) = self.incoming_uni_streams_reader.take() {
x.wake();
}
}
StreamOpened { dir: Dir::Bi } => {
if let Some(x) = self.incoming_bi_streams_reader.take() {
x.wake();
}
}
DatagramReceived => {
if let Some(x) = self.datagram_reader.take() {
x.wake();
}
}
StreamReadable { stream } => {
if let Some(reader) = self.blocked_readers.remove(&stream) {
reader.wake();
}
}
StreamAvailable { dir } => {
let tasks = match dir {
Dir::Uni => &mut self.uni_opening,
Dir::Bi => &mut self.bi_opening,
};
tasks.wake();
}
StreamFinished {
stream,
stop_reason,
} => {
if let Some(finishing) = self.finishing.remove(&stream) {
let _ = finishing
.send(stop_reason.map(|e| WriteError::Stopped { error_code: e }));
}
}
DatagramSendUnblocked => {
self.send_datagram_blocked.wake();
}
}
}
}
fn drive_timers(&mut self, cx: &mut Context, now: Instant) -> bool {
let mut keep_going = false;
for (timer, slot) in &mut self.timers {
if let Some(ref mut delay) = slot {
match delay.poll_unpin(cx) {
Poll::Ready(()) => {
*slot = None;
trace!("{:?} timeout", timer);
self.inner.handle_timeout(now, timer);
keep_going = true;
}
Poll::Pending => {}
}
}
}
keep_going
}
fn handle_timer_updates(&mut self) -> bool {
let mut keep_going = false;
while let Some(update) = self.inner.poll_timers() {
keep_going = true;
match update {
TimerUpdate {
timer,
update: proto::TimerSetting::Start(time),
} => match self.timers[timer] {
ref mut x @ None => {
trace!(time = ?time.duration_since(self.epoch), "{:?} timer start", timer);
*x = Some(delay_until(TokioInstant::from_std(time)));
}
Some(ref mut x) => {
trace!(time = ?time.duration_since(self.epoch), "{:?} timer reset", timer);
x.reset(TokioInstant::from_std(time));
}
},
TimerUpdate {
timer,
update: proto::TimerSetting::Stop,
} => {
if self.timers[timer].take().is_some() {
trace!("{:?} timer stop", timer);
}
}
}
}
keep_going
}
pub(crate) fn wake(&mut self) {
if let Some(x) = self.driver.take() {
x.wake();
}
}
fn terminate(&mut self, reason: ConnectionError) {
self.error = Some(reason.clone());
for (_, writer) in self.blocked_writers.drain() {
writer.wake()
}
for (_, reader) in self.blocked_readers.drain() {
reader.wake()
}
self.uni_opening.wake();
self.bi_opening.wake();
if let Some(x) = self.incoming_uni_streams_reader.take() {
x.wake();
}
if let Some(x) = self.incoming_bi_streams_reader.take() {
x.wake();
}
if let Some(x) = self.datagram_reader.take() {
x.wake();
}
for (_, x) in self.finishing.drain() {
let _ = x.send(Some(WriteError::ConnectionClosed(reason.clone())));
}
self.send_datagram_blocked.wake();
if let Some(x) = self.on_connected.take() {
let _ = x.send(false);
}
}
fn close(&mut self, error_code: VarInt, reason: Bytes) {
self.inner.close(Instant::now(), error_code, reason);
self.terminate(ConnectionError::LocallyClosed);
self.wake();
}
pub fn implicit_close(&mut self) {
self.close(0u32.into(), Bytes::new());
}
pub(crate) fn check_0rtt(&self) -> Result<(), ()> {
if self.inner.is_handshaking() || self.inner.accepted_0rtt() {
Ok(())
} else {
Err(())
}
}
fn handle_datagram_err(
&mut self,
cx: &mut Context,
state: &mut broadcast::State,
e: proto::SendDatagramError,
) -> Poll<Result<(), SendDatagramError>> {
match e {
proto::SendDatagramError::Blocked => {
self.send_datagram_blocked.register(cx, state);
Poll::Pending
}
proto::SendDatagramError::UnsupportedByPeer => {
Poll::Ready(Err(SendDatagramError::UnsupportedByPeer))
}
proto::SendDatagramError::Disabled => Poll::Ready(Err(SendDatagramError::Disabled)),
}
}
}
impl Drop for ConnectionInner {
    /// Tells the endpoint this connection is drained, unless the protocol state already
    /// reports it drained.
    ///
    /// Send failures on the endpoint channel are ignored.
    fn drop(&mut self) {
        if self.inner.is_drained() {
            return;
        }
        let event = EndpointEvent::Proto(proto::EndpointEvent::drained());
        let _ = self.endpoint_events.unbounded_send((self.handle, event));
    }
}
impl fmt::Debug for ConnectionInner {
    /// Debug output shows only the protocol state (`inner`).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("ConnectionInner");
        out.field("inner", &self.inner);
        out.finish()
    }
}