use std::{
collections::HashMap,
fmt,
future::Future,
mem,
net::SocketAddr,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll, Waker},
time::Instant,
};
use bytes::Bytes;
use err_derive::Error;
use futures::{
channel::{mpsc, oneshot},
FutureExt, StreamExt,
};
use proto::{ConnectionError, ConnectionHandle, Dir, StreamId};
use tokio::time::{delay_until, Delay, Instant as TokioInstant};
use tracing::info_span;
use crate::{
broadcast::{self, Broadcast},
streams::{RecvStream, SendStream, WriteError},
ConnectionEvent, EndpointEvent, VarInt,
};
/// An in-progress connection attempt.
///
/// Awaiting it yields a [`NewConnection`] once the connection is established,
/// or the [`ConnectionError`] that ended it.
#[derive(Debug)]
pub struct Connecting<S: proto::crypto::Session> {
    /// Shared connection state; taken (left `None`) when the future completes
    /// or the attempt is converted by `into_0rtt`.
    conn: Option<ConnectionRef<S>>,
    /// Fires when the connection is established or terminated.
    connected: oneshot::Receiver<bool>,
}
impl<S: proto::crypto::Session> Connecting<S> {
    pub(crate) fn new(conn: ConnectionRef<S>, connected: oneshot::Receiver<bool>) -> Self {
        let conn = Some(conn);
        Self { conn, connected }
    }

    /// Converts the attempt into a usable connection without waiting for the
    /// connection to be established.
    ///
    /// This succeeds on the server side, or on a client whose protocol state
    /// reports `has_0rtt()`. The returned [`ZeroRttAccepted`] resolves to
    /// whether the peer accepted the 0-RTT data. When conversion is not
    /// possible, `self` is returned unchanged so it can still be awaited.
    ///
    /// # Panics
    /// Panics if the connection state has already been taken, i.e. after the
    /// future yielded `Ready`.
    pub fn into_0rtt(mut self) -> Result<(NewConnection<S>, ZeroRttAccepted), Self> {
        // Scope the lock so it is released before `self` is moved or returned.
        let eligible = {
            let guard = (self.conn.as_ref().unwrap().0).lock().unwrap();
            guard.inner.has_0rtt() || guard.inner.side().is_server()
        };
        if !eligible {
            return Err(self);
        }
        let conn = self.conn.take().unwrap();
        Ok((NewConnection::new(conn), ZeroRttAccepted(self.connected)))
    }

    /// Returns the authentication data exposed by the connection's crypto
    /// session.
    ///
    /// # Panics
    /// Panics if called after the future yielded `Ready`.
    pub fn authentication_data(&self) -> S::AuthenticationData {
        let conn = self.conn.as_ref().unwrap();
        let guard = conn.0.lock().unwrap();
        guard.inner.crypto_session().authentication_data()
    }
}
impl<S> Future for Connecting<S>
where
    S: proto::crypto::Session,
{
    type Output = Result<NewConnection<S>, ConnectionError>;

    /// Waits until the connection is established or has failed.
    ///
    /// The value carried by the `connected` channel is not needed here. The
    /// outcome is read from the shared state: `connected` for success, or
    /// `error` for failure.
    ///
    /// # Panics
    /// Panics with "used after yielding Ready" if polled again after returning
    /// `Ready`. It also panics if completion was signalled with neither
    /// success nor an error recorded.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        self.connected.poll_unpin(cx).map(|_| {
            let conn = self.conn.take().expect("used after yielding Ready");
            let inner = conn.lock().unwrap();
            if inner.connected {
                // Release the lock: building `NewConnection` clones `conn`,
                // which locks the same mutex.
                drop(inner);
                Ok(NewConnection::new(conn))
            } else {
                Err(inner
                    .error
                    .clone()
                    .expect("connected signaled without connection success or error"))
            }
        })
    }
}
impl<S> Connecting<S>
where
    S: proto::crypto::Session,
{
    /// Returns the remote address of the connection, as reported by the
    /// protocol layer.
    ///
    /// # Panics
    /// Panics with "used after yielding Ready" if called after the future
    /// completed.
    pub fn remote_address(&self) -> SocketAddr {
        let conn_ref = self.conn.as_ref().expect("used after yielding Ready");
        conn_ref.lock().unwrap().inner.remote_address()
    }
}
/// Future resolving to whether the peer accepted 0-RTT data.
///
/// Resolves to `false` if the sending side was dropped without reporting a
/// result.
pub struct ZeroRttAccepted(oneshot::Receiver<bool>);

impl Future for ZeroRttAccepted {
    type Output = bool;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        match self.0.poll_unpin(cx) {
            Poll::Ready(Ok(accepted)) => Poll::Ready(accepted),
            // Sender dropped without a value: treat as not accepted.
            Poll::Ready(Err(_)) => Poll::Ready(false),
            Poll::Pending => Poll::Pending,
        }
    }
}
/// Components of a newly established connection.
///
/// Because the struct is `#[non_exhaustive]`, destructuring it requires a
/// trailing `..`.
#[cfg_attr(
    feature = "rustls",
    doc = "```rust
# use quinn::NewConnection;
# fn dummy(new_connection: NewConnection) {
let NewConnection { connection, .. } = { new_connection };
# }
```"
)]
#[derive(Debug)]
#[non_exhaustive]
pub struct NewConnection<S: proto::crypto::Session> {
    /// Handle for opening streams, sending datagrams and closing the connection.
    pub connection: Connection<S>,
    /// Unidirectional streams opened by the peer.
    pub uni_streams: IncomingUniStreams<S>,
    /// Bidirectional streams opened by the peer.
    pub bi_streams: IncomingBiStreams<S>,
    /// Datagrams received from the peer.
    pub datagrams: Datagrams<S>,
}
impl<S: proto::crypto::Session> NewConnection<S> {
    /// Fans one shared connection handle out into the public handle and the
    /// three incoming-data streams, each holding its own reference.
    fn new(conn: ConnectionRef<S>) -> Self {
        let connection = Connection(conn.clone());
        let uni_streams = IncomingUniStreams(conn.clone());
        let bi_streams = IncomingBiStreams(conn.clone());
        Self {
            connection,
            uni_streams,
            bi_streams,
            datagrams: Datagrams(conn),
        }
    }
}
/// Future that performs a connection's event processing, transmission and
/// timer handling; it must be polled (spawned) for the connection to make
/// progress.
#[must_use = "connection drivers must be spawned for their connections to function"]
#[derive(Debug)]
pub(crate) struct ConnectionDriver<S>(pub(crate) ConnectionRef<S>)
where
    S: proto::crypto::Session;
impl<S> Future for ConnectionDriver<S>
where
    S: proto::crypto::Session,
{
    type Output = ();
    /// Performs one round of connection processing while holding the
    /// connection lock for the whole call.
    ///
    /// Completes once the connection is drained. It also completes
    /// immediately, after recording the error via `terminate`, if the
    /// endpoint's event channel has ended. Otherwise it stores the task's
    /// waker in `driver`, so `ConnectionInner::wake` can reschedule it, and
    /// returns `Pending`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let conn = &mut *self.0.lock().unwrap();
        let span = info_span!("drive", id = conn.handle.0);
        let _guard = span.enter();
        loop {
            let mut keep_going = false;
            // Apply everything delivered by the endpoint; a closed channel is fatal.
            if let Err(e) = conn.process_conn_events(cx) {
                conn.terminate(e);
                return Poll::Ready(());
            }
            conn.drive_transmit();
            // `drive_timer` returns true when a timeout was handled; in that case
            // run another pass so anything it produced is flushed in this poll.
            keep_going |= conn.drive_timer(cx);
            conn.forward_endpoint_events();
            conn.forward_app_events();
            if !keep_going || conn.inner.is_drained() {
                break;
            }
        }
        if !conn.inner.is_drained() {
            // Park; `ConnectionInner::wake` takes and wakes this waker.
            conn.driver = Some(cx.waker().clone());
            return Poll::Pending;
        }
        // Drained: the driver finishes no matter why the connection ended.
        match conn.error {
            Some(ConnectionError::LocallyClosed) => Poll::Ready(()),
            Some(_) => Poll::Ready(()),
            None => unreachable!("drained connections always have an error"),
        }
    }
}
/// A cloneable handle to an established connection.
///
/// Used to open streams, send datagrams, query connection properties and
/// close the connection.
#[derive(Debug)]
pub struct Connection<S>(ConnectionRef<S>)
where
    S: proto::crypto::Session;
impl<S: proto::crypto::Session> Connection<S> {
    /// Begins opening an outgoing unidirectional stream.
    ///
    /// The returned future resolves to a [`SendStream`] once a stream can be
    /// opened. It resolves to the connection's error if the connection has
    /// terminated.
    pub fn open_uni(&self) -> OpenUni<S> {
        let conn = self.0.clone();
        OpenUni {
            conn,
            state: broadcast::State::default(),
        }
    }

    /// Begins opening an outgoing bidirectional stream.
    ///
    /// The returned future resolves to a `(SendStream, RecvStream)` pair once
    /// a stream can be opened. It resolves to the connection's error if the
    /// connection has terminated.
    pub fn open_bi(&self) -> OpenBi<S> {
        let conn = self.0.clone();
        OpenBi {
            conn,
            state: broadcast::State::default(),
        }
    }

    /// Closes the connection immediately.
    ///
    /// `error_code` and `reason` are passed to the protocol layer's close.
    /// Waiting tasks are woken, and the connection's recorded error becomes
    /// `ConnectionError::LocallyClosed`.
    pub fn close(&self, error_code: VarInt, reason: &[u8]) {
        let reason = Bytes::copy_from_slice(reason);
        let conn = &mut *self.0.lock().unwrap();
        conn.close(error_code, reason);
    }

    /// Queues a datagram for transmission and wakes the connection driver.
    ///
    /// # Errors
    /// Returns [`SendDatagramError::ConnectionClosed`] if the connection has
    /// already terminated. Otherwise it maps the protocol layer's refusal to
    /// `UnsupportedByPeer`, `Disabled` or `TooLarge`.
    pub fn send_datagram(&self, data: Bytes) -> Result<(), SendDatagramError> {
        let conn = &mut *self.0.lock().unwrap();
        if let Some(ref err) = conn.error {
            return Err(SendDatagramError::ConnectionClosed(err.clone()));
        }
        conn.inner.send_datagram(data).map_err(|err| match err {
            proto::SendDatagramError::UnsupportedByPeer => SendDatagramError::UnsupportedByPeer,
            proto::SendDatagramError::Disabled => SendDatagramError::Disabled,
            proto::SendDatagramError::TooLarge => SendDatagramError::TooLarge,
        })?;
        // Only wake the driver when something was actually queued.
        conn.wake();
        Ok(())
    }

    /// Returns the protocol layer's current maximum datagram size, if any.
    pub fn max_datagram_size(&self) -> Option<usize> {
        let conn = self.0.lock().unwrap();
        conn.inner.max_datagram_size()
    }

    /// Returns the remote address of the connection.
    pub fn remote_address(&self) -> SocketAddr {
        let conn = self.0.lock().unwrap();
        conn.inner.remote_address()
    }

    /// Returns the authentication data exposed by the connection's crypto
    /// session.
    pub fn authentication_data(&self) -> S::AuthenticationData {
        let conn = self.0.lock().unwrap();
        conn.inner.crypto_session().authentication_data()
    }

    /// Asks the protocol layer to initiate a key update (testing aid).
    #[doc(hidden)]
    pub fn force_key_update(&self) {
        let conn = &mut *self.0.lock().unwrap();
        conn.inner.initiate_key_update();
    }
}
impl<S> Clone for Connection<S>
where
S: proto::crypto::Session,
{
fn clone(&self) -> Self {
Connection(self.0.clone())
}
}
/// Stream of unidirectional streams opened by the peer.
///
/// Ends with `None` after a local close. If the connection ended for any
/// other reason, it yields that error.
#[derive(Debug)]
pub struct IncomingUniStreams<S>(ConnectionRef<S>)
where
    S: proto::crypto::Session;

impl<S: proto::crypto::Session> futures::Stream for IncomingUniStreams<S> {
    type Item = Result<RecvStream<S>, ConnectionError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let mut guard = self.0.lock().unwrap();
        match guard.inner.accept(Dir::Uni) {
            Some(id) => {
                // Wake the driver, presumably so it can react to the accepted
                // stream (e.g. extend stream credit) — TODO confirm.
                guard.wake();
                // Release the lock first: cloning the handle locks it again.
                mem::drop(guard);
                Poll::Ready(Some(Ok(RecvStream::new(self.0.clone(), id, false))))
            }
            None => match guard.error {
                Some(ConnectionError::LocallyClosed) => Poll::Ready(None),
                Some(ref err) => Poll::Ready(Some(Err(err.clone()))),
                None => {
                    guard.incoming_uni_streams_reader = Some(cx.waker().clone());
                    Poll::Pending
                }
            },
        }
    }
}
/// Stream of bidirectional streams opened by the peer.
///
/// Ends with `None` after a local close. If the connection ended for any
/// other reason, it yields that error.
#[derive(Debug)]
pub struct IncomingBiStreams<S>(ConnectionRef<S>)
where
    S: proto::crypto::Session;

impl<S: proto::crypto::Session> futures::Stream for IncomingBiStreams<S> {
    type Item = Result<(SendStream<S>, RecvStream<S>), ConnectionError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let mut conn = self.0.lock().unwrap();
        let stream = match conn.inner.accept(Dir::Bi) {
            Some(stream) => stream,
            None => {
                return match conn.error {
                    Some(ConnectionError::LocallyClosed) => Poll::Ready(None),
                    Some(ref e) => Poll::Ready(Some(Err(e.clone()))),
                    None => {
                        conn.incoming_bi_streams_reader = Some(cx.waker().clone());
                        Poll::Pending
                    }
                };
            }
        };
        // Streams accepted while still handshaking are flagged as 0-RTT.
        let is_0rtt = conn.inner.is_handshaking();
        conn.wake();
        // Release the lock first: cloning the handle locks it again.
        mem::drop(conn);
        let send = SendStream::new(self.0.clone(), stream, is_0rtt);
        let recv = RecvStream::new(self.0.clone(), stream, is_0rtt);
        Poll::Ready(Some(Ok((send, recv))))
    }
}
/// Stream of datagrams received on the connection.
///
/// Ends with `None` after a local close. If the connection ended for any
/// other reason, it yields that error.
#[derive(Debug)]
pub struct Datagrams<S>(ConnectionRef<S>)
where
    S: proto::crypto::Session;

impl<S: proto::crypto::Session> futures::Stream for Datagrams<S> {
    type Item = Result<Bytes, ConnectionError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let mut conn = self.0.lock().unwrap();
        // Already-received datagrams are delivered even after termination.
        if let Some(datagram) = conn.inner.recv_datagram() {
            return Poll::Ready(Some(Ok(datagram)));
        }
        match conn.error {
            Some(ConnectionError::LocallyClosed) => Poll::Ready(None),
            Some(ref err) => Poll::Ready(Some(Err(err.clone()))),
            None => {
                conn.datagram_reader = Some(cx.waker().clone());
                Poll::Pending
            }
        }
    }
}
/// Future that opens an outgoing unidirectional stream; created by
/// `Connection::open_uni`.
pub struct OpenUni<S: proto::crypto::Session> {
    conn: ConnectionRef<S>,
    state: broadcast::State,
}

impl<S: proto::crypto::Session> Future for OpenUni<S> {
    type Output = Result<SendStream<S>, ConnectionError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let OpenUni { conn: handle, state } = self.get_mut();
        let mut guard = handle.lock().unwrap();
        if let Some(ref err) = guard.error {
            return Poll::Ready(Err(err.clone()));
        }
        match guard.inner.open(Dir::Uni) {
            Some(id) => {
                let is_0rtt = guard.inner.side().is_client() && guard.inner.is_handshaking();
                // Release the lock first: cloning the handle locks it again.
                drop(guard);
                Poll::Ready(Ok(SendStream::new(handle.clone(), id, is_0rtt)))
            }
            None => {
                // No stream available yet; wait to be woken via `uni_opening`.
                guard.uni_opening.register(cx, state);
                Poll::Pending
            }
        }
    }
}
/// Future that opens an outgoing bidirectional stream; created by
/// `Connection::open_bi`.
pub struct OpenBi<S: proto::crypto::Session> {
    conn: ConnectionRef<S>,
    state: broadcast::State,
}

impl<S: proto::crypto::Session> Future for OpenBi<S> {
    type Output = Result<(SendStream<S>, RecvStream<S>), ConnectionError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let OpenBi { conn: handle, state } = self.get_mut();
        let mut guard = handle.lock().unwrap();
        if let Some(ref err) = guard.error {
            return Poll::Ready(Err(err.clone()));
        }
        let id = match guard.inner.open(Dir::Bi) {
            Some(id) => id,
            None => {
                // No stream available yet; wait to be woken via `bi_opening`.
                guard.bi_opening.register(cx, state);
                return Poll::Pending;
            }
        };
        let is_0rtt = guard.inner.side().is_client() && guard.inner.is_handshaking();
        // Release the lock first: cloning the handle locks it again.
        drop(guard);
        let send = SendStream::new(handle.clone(), id, is_0rtt);
        let recv = RecvStream::new(handle.clone(), id, is_0rtt);
        Poll::Ready(Ok((send, recv)))
    }
}
/// Shared handle to a connection's mutable state.
#[derive(Debug)]
pub struct ConnectionRef<S>(Arc<Mutex<ConnectionInner<S>>>)
where
    S: proto::crypto::Session;

impl<S: proto::crypto::Session> ConnectionRef<S> {
    /// Wraps a protocol connection together with the channels that link it
    /// to its endpoint.
    ///
    /// - `handle`: identifier attached to events sent to the endpoint.
    /// - `conn`: the protocol-level connection state.
    /// - `endpoint_events`: channel for transmits and protocol events bound
    ///   for the endpoint.
    /// - `conn_events`: channel of events arriving from the endpoint.
    /// - `on_connected`: notified once with the 0-RTT acceptance flag when the
    ///   connection is established, or with `false` if it terminates first.
    pub(crate) fn new(
        handle: ConnectionHandle,
        conn: proto::generic::Connection<S>,
        endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
        conn_events: mpsc::UnboundedReceiver<ConnectionEvent>,
        on_connected: oneshot::Sender<bool>,
    ) -> Self {
        let state = ConnectionInner {
            inner: conn,
            driver: None,
            handle,
            on_connected: Some(on_connected),
            connected: false,
            timer: None,
            conn_events,
            endpoint_events,
            blocked_writers: HashMap::new(),
            blocked_readers: HashMap::new(),
            uni_opening: Broadcast::new(),
            bi_opening: Broadcast::new(),
            incoming_uni_streams_reader: None,
            incoming_bi_streams_reader: None,
            datagram_reader: None,
            finishing: HashMap::new(),
            error: None,
            // This first handle is not counted; each clone adds one.
            ref_count: 0,
        };
        Self(Arc::new(Mutex::new(state)))
    }
}
impl<S> Clone for ConnectionRef<S>
where
S: proto::crypto::Session,
{
fn clone(&self) -> Self {
self.0.lock().unwrap().ref_count += 1;
Self(self.0.clone())
}
}
impl<S: proto::crypto::Session> Drop for ConnectionRef<S> {
    /// Decrements `ref_count`. When it reaches zero and the connection is not
    /// already closed, the connection is closed implicitly.
    ///
    /// `new` starts the count at 0 and each clone adds 1, so it tracks the
    /// number of live handles minus one. Reaching 0 therefore means a single
    /// handle remains (presumably the driver's — TODO confirm). A drop that
    /// would underflow leaves the count untouched.
    fn drop(&mut self) {
        let conn = &mut *self.0.lock().unwrap();
        let remaining = match conn.ref_count.checked_sub(1) {
            Some(n) => n,
            None => return,
        };
        conn.ref_count = remaining;
        if remaining == 0 && !conn.inner.is_closed() {
            conn.implicit_close();
        }
    }
}
/// Lets a handle be locked directly, e.g. `conn_ref.lock()`.
impl<S: proto::crypto::Session> std::ops::Deref for ConnectionRef<S> {
    type Target = Mutex<ConnectionInner<S>>;

    fn deref(&self) -> &Mutex<ConnectionInner<S>> {
        &*self.0
    }
}
/// Mutable state of a connection, shared behind a `Mutex` by every
/// `ConnectionRef`.
pub struct ConnectionInner<S: proto::crypto::Session> {
    /// Protocol-level connection state.
    pub(crate) inner: proto::generic::Connection<S>,
    /// Waker of the driver task, stored when the driver returns `Pending`.
    driver: Option<Waker>,
    /// Identifier attached to every event sent to the endpoint.
    handle: ConnectionHandle,
    /// Notified once: with the 0-RTT acceptance flag on connection, or with
    /// `false` on termination.
    on_connected: Option<oneshot::Sender<bool>>,
    /// Set once the protocol reports the `Connected` event.
    connected: bool,
    /// Delay tracking the protocol's next timeout, if any.
    timer: Option<Delay>,
    /// Events delivered by the endpoint.
    conn_events: mpsc::UnboundedReceiver<ConnectionEvent>,
    /// Transmits and protocol events bound for the endpoint.
    endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
    /// Tasks waiting for a stream to become writable, keyed by stream.
    pub(crate) blocked_writers: HashMap<StreamId, Waker>,
    /// Tasks waiting for a stream to become readable, keyed by stream.
    pub(crate) blocked_readers: HashMap<StreamId, Waker>,
    /// Tasks waiting to open a unidirectional stream.
    uni_opening: Broadcast,
    /// Tasks waiting to open a bidirectional stream.
    bi_opening: Broadcast,
    /// Task waiting for the next incoming unidirectional stream.
    incoming_uni_streams_reader: Option<Waker>,
    /// Task waiting for the next incoming bidirectional stream.
    incoming_bi_streams_reader: Option<Waker>,
    /// Task waiting for the next incoming datagram.
    datagram_reader: Option<Waker>,
    /// Completion channels for streams being finished, keyed by stream.
    pub(crate) finishing: HashMap<StreamId, oneshot::Sender<Option<WriteError>>>,
    /// Why the connection ended; set by `terminate`, `None` until then.
    pub(crate) error: Option<ConnectionError>,
    /// Handle count maintained by `ConnectionRef`'s `Clone` and `Drop`.
    ref_count: usize,
}
impl<S> ConnectionInner<S>
where
    S: proto::crypto::Session,
{
    /// Hands every packet the protocol layer has ready to the endpoint.
    ///
    /// Each one is sent as `EndpointEvent::Transmit`, tagged with this
    /// connection's handle. Send failures are deliberately ignored.
    fn drive_transmit(&mut self) {
        // One timestamp for the whole batch.
        let now = Instant::now();
        while let Some(t) = self.inner.poll_transmit(now) {
            let _ = self
                .endpoint_events
                .unbounded_send((self.handle, EndpointEvent::Transmit(t)));
        }
    }
    /// Forwards all pending protocol-level endpoint events to the endpoint,
    /// wrapped in `EndpointEvent::Proto`. Send failures are ignored.
    fn forward_endpoint_events(&mut self) {
        while let Some(event) = self.inner.poll_endpoint_events() {
            let _ = self
                .endpoint_events
                .unbounded_send((self.handle, EndpointEvent::Proto(event)));
        }
    }
    /// Applies every event currently available from the endpoint.
    ///
    /// Returns `Ok(())` once no more events are ready; `cx` is then registered
    /// for the next one. Returns an `INTERNAL_ERROR` transport error if the
    /// channel has ended, i.e. every sender was dropped.
    fn process_conn_events(&mut self, cx: &mut Context) -> Result<(), ConnectionError> {
        loop {
            match self.conn_events.poll_next_unpin(cx) {
                Poll::Ready(Some(ConnectionEvent::Proto(event))) => {
                    self.inner.handle_event(event);
                }
                // Endpoint-requested close: handled exactly like a local close.
                Poll::Ready(Some(ConnectionEvent::Close { reason, error_code })) => {
                    self.close(error_code, reason);
                }
                Poll::Ready(None) => {
                    return Err(ConnectionError::TransportError(proto::TransportError {
                        code: proto::TransportErrorCode::INTERNAL_ERROR,
                        frame: None,
                        reason: "endpoint driver future was dropped".to_string(),
                    }));
                }
                Poll::Pending => {
                    return Ok(());
                }
            }
        }
    }
    /// Drains the protocol's application-facing events and updates local
    /// state, waking whichever task is waiting on each kind of event.
    fn forward_app_events(&mut self) {
        while let Some(event) = self.inner.poll() {
            use proto::Event::*;
            match event {
                Connected => {
                    self.connected = true;
                    // Report the 0-RTT acceptance flag to the `on_connected` receiver (fires at most once).
                    if let Some(x) = self.on_connected.take() {
                        let _ = x.send(self.inner.accepted_0rtt());
                    }
                }
                ConnectionLost { reason } => {
                    self.terminate(reason);
                }
                StreamWritable { stream } => {
                    if let Some(writer) = self.blocked_writers.remove(&stream) {
                        writer.wake();
                    }
                }
                StreamOpened { dir: Dir::Uni } => {
                    if let Some(x) = self.incoming_uni_streams_reader.take() {
                        x.wake();
                    }
                }
                StreamOpened { dir: Dir::Bi } => {
                    if let Some(x) = self.incoming_bi_streams_reader.take() {
                        x.wake();
                    }
                }
                DatagramReceived => {
                    if let Some(x) = self.datagram_reader.take() {
                        x.wake();
                    }
                }
                StreamReadable { stream } => {
                    if let Some(reader) = self.blocked_readers.remove(&stream) {
                        reader.wake();
                    }
                }
                // Wake every task waiting to open a stream in this direction.
                StreamAvailable { dir } => {
                    let tasks = match dir {
                        Dir::Uni => &mut self.uni_opening,
                        Dir::Bi => &mut self.bi_opening,
                    };
                    tasks.wake();
                }
                // Resolve a pending finish, passing along the peer's stop reason if any.
                StreamFinished {
                    stream,
                    stop_reason,
                } => {
                    if let Some(finishing) = self.finishing.remove(&stream) {
                        let _ = finishing.send(stop_reason.map(WriteError::Stopped));
                    }
                }
            }
        }
    }
    /// Keeps `timer` in sync with the protocol's next timeout.
    ///
    /// If the stored delay has elapsed, this calls `handle_timeout` and clears
    /// it. Then it creates, resets or clears the delay to match
    /// `poll_timeout`. It loops until the stored delay has been polled with
    /// the deadline currently reported (so `cx` is registered), or until no
    /// timeout remains. Returns `true` if at least one timeout was handled.
    fn drive_timer(&mut self, cx: &mut Context) -> bool {
        let mut keep_going = false;
        loop {
            if let Some(ref mut delay) = self.timer {
                if delay.poll_unpin(cx) == Poll::Ready(()) {
                    self.inner.handle_timeout(Instant::now());
                    self.timer = None;
                    keep_going = true;
                }
            }
            match (
                self.inner.poll_timeout().map(TokioInstant::from_std),
                &mut self.timer,
            ) {
                // New deadline and no delay: create one; the next iteration polls it.
                (Some(timeout), &mut None) => self.timer = Some(delay_until(timeout)),
                // Deadline moved: reset; the next iteration polls the updated delay.
                (Some(timeout), &mut Some(ref mut delay)) if delay.deadline() != timeout => {
                    delay.reset(timeout);
                }
                (None, _) => {
                    self.timer = None;
                    break;
                }
                // Deadline unchanged and the delay was already polled above.
                _ => break,
            }
        }
        keep_going
    }
    /// Wakes the driver task, if it has registered a waker.
    pub(crate) fn wake(&mut self) {
        if let Some(x) = self.driver.take() {
            x.wake();
        }
    }
    /// Records `reason` as the connection's error and notifies every waiter.
    ///
    /// It wakes blocked readers and writers, stream openers, and the
    /// incoming-stream and datagram readers. Pending finishes receive
    /// `ConnectionClosed`, and `on_connected` receives `false` if it has not
    /// fired yet.
    fn terminate(&mut self, reason: ConnectionError) {
        self.error = Some(reason.clone());
        for (_, writer) in self.blocked_writers.drain() {
            writer.wake()
        }
        for (_, reader) in self.blocked_readers.drain() {
            reader.wake()
        }
        self.uni_opening.wake();
        self.bi_opening.wake();
        if let Some(x) = self.incoming_uni_streams_reader.take() {
            x.wake();
        }
        if let Some(x) = self.incoming_bi_streams_reader.take() {
            x.wake();
        }
        if let Some(x) = self.datagram_reader.take() {
            x.wake();
        }
        for (_, x) in self.finishing.drain() {
            let _ = x.send(Some(WriteError::ConnectionClosed(reason.clone())));
        }
        if let Some(x) = self.on_connected.take() {
            let _ = x.send(false);
        }
    }
    /// Closes the protocol connection now, terminates with `LocallyClosed`,
    /// and wakes the driver (presumably so the close gets transmitted).
    fn close(&mut self, error_code: VarInt, reason: Bytes) {
        self.inner.close(Instant::now(), error_code, reason);
        self.terminate(ConnectionError::LocallyClosed);
        self.wake();
    }
    /// Closes with error code 0 and an empty reason. `ConnectionRef::drop`
    /// calls this when its handle count reaches zero.
    pub fn implicit_close(&mut self) {
        self.close(0u32.into(), Bytes::new());
    }
    /// Returns `Ok(())` while handshaking, after 0-RTT was accepted, or on the
    /// server side. Returns `Err(())` for a client that has finished
    /// handshaking without 0-RTT being accepted.
    pub(crate) fn check_0rtt(&self) -> Result<(), ()> {
        if self.inner.is_handshaking()
            || self.inner.accepted_0rtt()
            || self.inner.side().is_server()
        {
            Ok(())
        } else {
            Err(())
        }
    }
}
impl<S: proto::crypto::Session> Drop for ConnectionInner<S> {
    /// Sends a `drained` endpoint event for this connection's handle, unless
    /// the protocol layer already reports the connection as drained.
    fn drop(&mut self) {
        if self.inner.is_drained() {
            return;
        }
        let event = EndpointEvent::Proto(proto::EndpointEvent::drained());
        // Best effort: the endpoint side may already be gone.
        let _ = self.endpoint_events.unbounded_send((self.handle, event));
    }
}
impl<S: proto::crypto::Session> fmt::Debug for ConnectionInner<S> {
    /// Shows only the protocol-level state; wakers and channels are omitted.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("ConnectionInner");
        out.field("inner", &self.inner);
        out.finish()
    }
}
/// Reasons `Connection::send_datagram` can fail.
#[derive(Debug, Error, Clone, Eq, PartialEq)]
pub enum SendDatagramError {
    /// The peer does not support datagrams.
    #[error(display = "datagrams not supported by peer")]
    UnsupportedByPeer,
    /// Datagram support is disabled.
    #[error(display = "datagram support disabled")]
    Disabled,
    /// The datagram is too large to send.
    #[error(display = "datagram too large")]
    TooLarge,
    /// The connection had already terminated with the contained error.
    #[error(display = "connection closed: {}", _0)]
    ConnectionClosed(ConnectionError),
}