use super::error::QuicError;
use super::stream::{RecvStream, SendStream, StreamTracker};
use crate::cx::Cx;
use std::future::{Future, poll_fn};
use std::net::SocketAddr;
use std::sync::Arc;
use std::task::Poll;
/// A QUIC connection handle built on top of a [`quinn::Connection`].
///
/// Streams opened or accepted through this type are wrapped in
/// [`SendStream`] / [`RecvStream`] and handed this connection's stream tracker.
#[derive(Debug)]
pub struct QuicConnection {
    /// The underlying quinn connection that every operation delegates to.
    inner: quinn::Connection,
    /// Stream tracker passed to every `SendStream`/`RecvStream` this wrapper
    /// creates; marked closing on `close`, `close_immediately`, and drop.
    tracker: Arc<StreamTracker>,
}
impl QuicConnection {
    /// Wraps an established quinn connection, pairing it with a fresh
    /// [`StreamTracker`] that every stream opened or accepted through this
    /// wrapper is attached to.
    pub(crate) fn new(inner: quinn::Connection) -> Self {
        Self {
            inner,
            tracker: StreamTracker::new(),
        }
    }

    /// Returns the peer's socket address as reported by quinn.
    #[must_use]
    pub fn remote_address(&self) -> SocketAddr {
        self.inner.remote_address()
    }

    /// Returns quinn's stable identifier for this connection.
    #[must_use]
    pub fn stable_id(&self) -> usize {
        self.inner.stable_id()
    }

    /// Returns the ALPN protocol negotiated during the handshake.
    ///
    /// Yields `None` when quinn has no handshake data, when that data is not
    /// rustls handshake data, or when no protocol was negotiated.
    #[must_use]
    pub fn alpn_protocol(&self) -> Option<Vec<u8>> {
        self.inner
            .handshake_data()?
            .downcast::<quinn::crypto::rustls::HandshakeData>()
            .ok()
            // The downcast hands back an owned `Box`, so the protocol bytes can
            // be moved out rather than cloned.
            .and_then(|hs| hs.protocol)
    }

    /// Opens a new outgoing bidirectional stream and wraps both halves with
    /// this connection's tracker.
    ///
    /// # Errors
    ///
    /// Returns an error if `cx.checkpoint()` fails while waiting (for example
    /// on cancellation), or if quinn fails to open the stream.
    pub async fn open_bi(&self, cx: &Cx) -> Result<(SendStream, RecvStream), QuicError> {
        let (send, recv) = wait_with_cx(cx, self.inner.open_bi()).await??;
        Ok((
            SendStream::new(send, &self.tracker),
            RecvStream::new(recv, &self.tracker),
        ))
    }

    /// Opens a new outgoing unidirectional stream wrapped with this
    /// connection's tracker.
    ///
    /// # Errors
    ///
    /// Returns an error if `cx.checkpoint()` fails while waiting, or if quinn
    /// fails to open the stream.
    pub async fn open_uni(&self, cx: &Cx) -> Result<SendStream, QuicError> {
        let send = wait_with_cx(cx, self.inner.open_uni()).await??;
        Ok(SendStream::new(send, &self.tracker))
    }

    /// Waits for the peer to open a bidirectional stream and wraps both halves
    /// with this connection's tracker.
    ///
    /// # Errors
    ///
    /// Returns an error if `cx.checkpoint()` fails while waiting, or if quinn
    /// reports a connection error instead of a stream.
    pub async fn accept_bi(&self, cx: &Cx) -> Result<(SendStream, RecvStream), QuicError> {
        let (send, recv) = wait_with_cx(cx, self.inner.accept_bi()).await??;
        Ok((
            SendStream::new(send, &self.tracker),
            RecvStream::new(recv, &self.tracker),
        ))
    }

    /// Waits for the peer to open a unidirectional stream and wraps it with
    /// this connection's tracker.
    ///
    /// # Errors
    ///
    /// Returns an error if `cx.checkpoint()` fails while waiting, or if quinn
    /// reports a connection error instead of a stream.
    pub async fn accept_uni(&self, cx: &Cx) -> Result<RecvStream, QuicError> {
        let recv = wait_with_cx(cx, self.inner.accept_uni()).await??;
        Ok(RecvStream::new(recv, &self.tracker))
    }

    /// Closes the connection with `code`/`reason`, then waits until quinn
    /// reports it closed.
    ///
    /// The tracker is marked closing before the close is issued. The final
    /// close reason reported by quinn is discarded.
    ///
    /// # Errors
    ///
    /// Returns an error only if `cx.checkpoint()` fails before the connection
    /// finishes closing; the close has already been issued to quinn by then.
    pub async fn close(&self, cx: &Cx, code: u32, reason: &[u8]) -> Result<(), QuicError> {
        self.tracker.mark_closing();
        self.inner.close(code.into(), reason);
        self.closed(cx).await
    }

    /// Marks the tracker closing and issues a close to quinn without waiting
    /// for the connection to finish closing.
    pub fn close_immediately(&self, code: u32, reason: &[u8]) {
        self.tracker.mark_closing();
        self.inner.close(code.into(), reason);
    }

    /// Returns `true` unless this handle's tracker has been marked closing or
    /// quinn already reports a close reason.
    #[must_use]
    pub fn is_open(&self) -> bool {
        !self.tracker.is_closing() && self.inner.close_reason().is_none()
    }

    /// Waits until quinn reports the connection closed, discarding the reason.
    ///
    /// # Errors
    ///
    /// Returns an error only if `cx.checkpoint()` fails before the connection
    /// is closed.
    pub async fn closed(&self, cx: &Cx) -> Result<(), QuicError> {
        // The `ConnectionError` describing why the connection closed is
        // intentionally dropped; only completion matters to callers.
        let _reason = wait_with_cx(cx, self.inner.closed()).await?;
        Ok(())
    }

    /// Returns the largest datagram payload quinn currently allows, or `None`
    /// when quinn reports datagrams as unavailable.
    #[must_use]
    pub fn max_datagram_size(&self) -> Option<usize> {
        self.inner.max_datagram_size()
    }

    /// Sends an unreliable datagram, copying `data` into an owned buffer.
    ///
    /// # Errors
    ///
    /// Returns an error if quinn rejects the datagram.
    pub fn send_datagram(&self, data: &[u8]) -> Result<(), QuicError> {
        self.inner.send_datagram(data.to_vec().into())?;
        Ok(())
    }

    /// Waits for the next incoming datagram and returns a copy of its payload.
    ///
    /// # Errors
    ///
    /// Returns an error if `cx.checkpoint()` fails while waiting, or if quinn
    /// reports an error instead of a datagram.
    pub async fn read_datagram(&self, cx: &Cx) -> Result<Vec<u8>, QuicError> {
        let data = wait_with_cx(cx, self.inner.read_datagram()).await??;
        Ok(data.to_vec())
    }

    /// Returns quinn's current round-trip-time estimate.
    #[must_use]
    pub fn rtt(&self) -> std::time::Duration {
        self.inner.rtt()
    }

    /// Borrows the underlying quinn connection.
    ///
    /// Streams opened directly through it bypass this wrapper's
    /// `SendStream`/`RecvStream` types.
    #[must_use]
    pub fn inner(&self) -> &quinn::Connection {
        &self.inner
    }
}
impl Drop for QuicConnection {
    /// Flags this connection's stream tracker as closing when the handle is
    /// dropped. The quinn connection itself is not explicitly closed here.
    fn drop(&mut self) {
        let tracker = &*self.tracker;
        tracker.mark_closing();
    }
}
/// Drives `future` to completion while honouring `cx`'s checkpoint.
///
/// `cx.checkpoint()` is consulted on every poll *before* the inner future is
/// polled. If it fails, the error is converted into a [`QuicError`] and
/// returned without polling the inner future on that turn. Otherwise the inner
/// future's output is returned wrapped in `Ok` once it completes.
///
/// NOTE(review): the checkpoint is only evaluated when this future is polled;
/// no wake-up is registered with `cx` here. Confirm that `Cx` wakes parked
/// tasks on cancellation, otherwise a pending wait may not observe it promptly.
async fn wait_with_cx<T, F>(cx: &Cx, future: F) -> Result<T, QuicError>
where
    F: Future<Output = T>,
{
    let mut inner = std::pin::pin!(future);
    poll_fn(move |task_cx| match cx.checkpoint() {
        Err(err) => Poll::Ready(Err(QuicError::from(err))),
        Ok(_) => match inner.as_mut().poll(task_cx) {
            Poll::Ready(output) => Poll::Ready(Ok(output)),
            Poll::Pending => Poll::Pending,
        },
    })
    .await
}
#[cfg(test)]
mod tests {
    // Test-only code: relax the crate's stricter lint groups so fixtures can
    // stay short and explicit.
    #![allow(
        clippy::pedantic,
        clippy::nursery,
        clippy::expect_fun_call,
        clippy::map_unwrap_or,
        clippy::cast_possible_wrap,
        clippy::future_not_send
    )]
    use super::*;
    use std::pin::Pin;
    use std::task::Context;

    /// Returns an owned no-op waker; these tests drive futures by hand, so
    /// nothing ever needs to be woken.
    fn noop_waker() -> std::task::Waker {
        std::task::Waker::noop().clone()
    }

    /// Polls `future` exactly once using a no-op waker.
    fn poll_once<F: Future>(future: Pin<&mut F>) -> Poll<F::Output> {
        let waker = noop_waker();
        let mut poll_cx = Context::from_waker(&waker);
        future.poll(&mut poll_cx)
    }

    /// Future that is `Pending` on its first poll and `Ready(())` afterwards.
    /// It never registers the waker, which is fine under manual polling.
    struct PendingOnce {
        polled: bool,
    }

    impl Future for PendingOnce {
        type Output = ();
        fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
            if self.polled {
                Poll::Ready(())
            } else {
                self.polled = true;
                Poll::Pending
            }
        }
    }

    #[test]
    fn wait_with_cx_returns_cancelled_when_context_is_cancelled_between_polls() {
        let cx = Cx::for_testing();
        let mut future = std::pin::pin!(wait_with_cx(&cx, PendingOnce { polled: false }));
        assert!(matches!(poll_once(future.as_mut()), Poll::Pending));
        cx.set_cancel_requested(true);
        let cancelled = matches!(
            poll_once(future.as_mut()),
            Poll::Ready(Err(QuicError::Cancelled))
        );
        assert!(
            cancelled,
            "future should return cancelled after Cx cancellation"
        );
    }

    #[test]
    fn wait_with_cx_returns_cancelled_without_polling_when_already_cancelled() {
        let cx = Cx::for_testing();
        cx.set_cancel_requested(true);
        let inner_polled = std::cell::Cell::new(false);
        let probe = std::future::poll_fn(|_| {
            inner_polled.set(true);
            Poll::Ready(())
        });
        let mut future = std::pin::pin!(wait_with_cx(&cx, probe));
        assert!(matches!(
            poll_once(future.as_mut()),
            Poll::Ready(Err(QuicError::Cancelled))
        ));
        assert!(
            !inner_polled.get(),
            "inner future must not be polled once cancellation is requested"
        );
    }

    #[test]
    fn wait_with_cx_passes_through_ready_output() {
        let cx = Cx::for_testing();
        let mut future = std::pin::pin!(wait_with_cx(&cx, std::future::ready(42_u32)));
        assert!(matches!(poll_once(future.as_mut()), Poll::Ready(Ok(42))));
    }

    #[test]
    fn wait_with_cx_resolves_after_pending_when_not_cancelled() {
        let cx = Cx::for_testing();
        let mut future = std::pin::pin!(wait_with_cx(&cx, PendingOnce { polled: false }));
        assert!(matches!(poll_once(future.as_mut()), Poll::Pending));
        assert!(matches!(poll_once(future.as_mut()), Poll::Ready(Ok(()))));
    }
}