use core::{
cmp,
future::Future,
mem,
net::SocketAddr,
pin::Pin,
task::{Context, Poll, ready},
};
use std::{io, sync::Arc};
use quinn::{ClientConfig, Connection, Endpoint, RecvStream, SendStream, crypto::rustls::QuicClientConfig};
use xitca_io::{
bytes::{Buf, Bytes},
io::{AsyncIo, Interest, Ready},
};
use crate::error::Error;
/// ALPN protocol identifier passed to the TLS client config built in `prepare_endpoint`.
pub(crate) const QUIC_ALPN: &[u8] = &[b'q', b'u', b'i', b'c'];
/// Heap-allocated, type-erased `Send` future used to hold in-progress stream reads and writes.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
/// A bidirectional QUIC stream exposed through [`AsyncIo`] and the non-blocking
/// `io::Read` / `io::Write` traits (operations that cannot complete yet return
/// `WouldBlock`).
pub struct QuicStream {
// Send half; writes are handed to a boxed future and driven by `poll_ready`/`poll_shutdown`.
writer: Writer,
// Receive half; chunks are fetched by a boxed future and buffered until read.
reader: Reader,
}
/// Send half of a [`QuicStream`].
///
/// `io::Write::write` is synchronous, so each write copies the caller's bytes
/// into a boxed future (`InFlight`) that is driven by `poll_ready` /
/// `poll_shutdown`. A failure of that future is stored in `Error` and reported
/// by the next `write` or `poll_shutdown`.
enum Writer {
    /// Idle send stream, ready to accept the next write.
    Tx(SendStream),
    /// A previous in-flight write failed; not yet reported to the caller.
    Error(io::Error),
    /// A write is in progress; resolves back to the send stream on success.
    InFlight(BoxFuture<io::Result<SendStream>>),
    /// The stream is unusable (an error has already been reported).
    Closed,
}

impl Writer {
    /// Drive any in-flight write and record write readiness in `ready`.
    ///
    /// `Ready::WRITABLE` is added only when `interest` requests it, and only once
    /// no write is pending. In every other state the caller's next `write` /
    /// `poll_shutdown` can make progress or observe the outcome. The in-flight
    /// future is polled even when write interest was not requested, so queued
    /// data keeps moving.
    fn poll_ready(&mut self, interest: Interest, ready: &mut Ready, cx: &mut Context<'_>) {
        if let Self::InFlight(fut) = self {
            let Poll::Ready(res) = fut.as_mut().poll(cx) else {
                return;
            };
            *self = match res {
                Ok(tx) => Self::Tx(tx),
                Err(e) => Self::Error(e),
            };
        }
        if interest.is_writable() {
            *ready |= Ready::WRITABLE;
        }
    }

    /// Queue `buf` for sending.
    ///
    /// The bytes are copied and the full length is reported immediately. Any
    /// failure of the actual send surfaces on a later `write` or `poll_shutdown`.
    ///
    /// # Errors
    /// - `WouldBlock` while a previous write is still in flight.
    /// - The stored error from a failed previous write (reported once).
    /// - `BrokenPipe` once that error has been reported.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self {
            Self::Error(_) => {
                let Self::Error(e) = mem::replace(self, Self::Closed) else {
                    unreachable!()
                };
                Err(e)
            }
            // Nothing to send; don't allocate a future for an empty chunk.
            Self::Tx(_) if buf.is_empty() => Ok(0),
            Self::Tx(_) => {
                let Self::Tx(mut tx) = mem::replace(self, Self::Closed) else {
                    unreachable!()
                };
                let bytes = Bytes::copy_from_slice(buf);
                *self = Self::InFlight(Box::pin(async move {
                    tx.write_chunk(bytes).await?;
                    Ok(tx)
                }));
                Ok(buf.len())
            }
            Self::InFlight(_) => Err(io::ErrorKind::WouldBlock.into()),
            // Reached after an error was already returned. `poll_ready` reports
            // this state as writable, so it must produce an error, not panic.
            Self::Closed => Err(io::ErrorKind::BrokenPipe.into()),
        }
    }

    /// Finish any in-flight write, then finish the send stream.
    ///
    /// A failed in-flight write, or a stored error from an earlier failure, is
    /// returned here and the writer moves to `Closed`. This means the completed
    /// future is never polled again and the error is not silently dropped.
    fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        loop {
            match self {
                Self::Tx(tx) => return Poll::Ready(tx.finish().map_err(io::Error::other)),
                Self::InFlight(fut) => match ready!(fut.as_mut().poll(cx)) {
                    // Loop back to finish the now-idle stream.
                    Ok(tx) => *self = Self::Tx(tx),
                    Err(e) => {
                        *self = Self::Closed;
                        return Poll::Ready(Err(e));
                    }
                },
                Self::Error(_) => {
                    let Self::Error(e) = mem::replace(self, Self::Closed) else {
                        unreachable!()
                    };
                    return Poll::Ready(Err(e));
                }
                Self::Closed => return Poll::Ready(Ok(())),
            }
        }
    }
}
enum Reader {
Buffered((Bytes, RecvStream)),
InFlight(BoxFuture<io::Result<Option<(Bytes, RecvStream)>>>),
Error(io::Error),
Closed,
}
impl Reader {
fn in_flight(mut rx: RecvStream) -> Self {
Self::InFlight(Box::pin(async move {
let chunk = rx.read_chunk(4096, true).await?;
Ok(chunk.map(|c| (c.bytes, rx)))
}))
}
fn poll_ready_once(&mut self, cx: &mut Context<'_>, ready: &mut Ready) {
match self {
Self::Buffered((bytes, _)) => {
if !bytes.is_empty() {
*ready |= Ready::READABLE;
return;
}
let Self::Buffered((_, rx)) = mem::replace(self, Self::Closed) else {
unreachable!()
};
*self = Self::in_flight(rx);
self.poll_ready_once(cx, ready);
}
Self::InFlight(fut) => {
if let Poll::Ready(res) = fut.as_mut().poll(cx) {
*ready |= Ready::READABLE;
match res {
Ok(Some(res)) => *self = Self::Buffered(res),
Ok(None) => *self = Self::Closed,
Err(e) => *self = Self::Error(e),
}
}
}
Self::Error(_) => *ready |= Ready::READABLE,
Self::Closed => {}
}
}
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Self::Buffered((bytes, _)) => {
let len = cmp::min(buf.len(), bytes.len());
buf[..len].copy_from_slice(&bytes[..len]);
bytes.advance(len);
Ok(len)
}
Self::Error(_) => {
let Self::Error(e) = mem::replace(self, Self::Closed) else {
unreachable!()
};
Err(e)
}
Self::Closed => Ok(0),
_ => Err(io::ErrorKind::WouldBlock.into()),
}
}
}
impl From<(SendStream, RecvStream)> for QuicStream {
    /// Wrap both halves of a bidirectional stream.
    ///
    /// The writer starts idle. The reader starts with a pending chunk read, so
    /// the first readiness poll issues `read_chunk`.
    fn from(halves: (SendStream, RecvStream)) -> Self {
        let (send, recv) = halves;
        let writer = Writer::Tx(send);
        let reader = Reader::in_flight(recv);
        Self { writer, reader }
    }
}
impl AsyncIo for QuicStream {
async fn ready(&mut self, interest: Interest) -> io::Result<Ready> {
core::future::poll_fn(|cx| self.poll_ready(interest, cx)).await
}
fn poll_ready(&mut self, interest: Interest, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> {
let mut ready = Ready::EMPTY;
if interest.is_readable() {
self.reader.poll_ready_once(cx, &mut ready);
}
self.writer.poll_ready(interest, &mut ready, cx);
if ready.is_empty() {
Poll::Pending
} else {
Poll::Ready(Ok(ready))
}
}
fn is_vectored_write(&self) -> bool {
false
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.get_mut().writer.poll_shutdown(cx)
}
}
impl io::Read for QuicStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.reader.read(buf)
}
}
impl io::Write for QuicStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.writer.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
#[cold]
#[inline(never)]
pub(crate) async fn connect_quic(host: &str, ports: &[u16]) -> Result<(QuicStream, SocketAddr), Error> {
let (conn, addr) = _connect_quic(host, ports).await?;
let stream = conn.open_bi().await.map_err(|_| Error::todo())?;
Ok((stream.into(), addr))
}
#[cold]
#[inline(never)]
pub(crate) async fn connect_quic_addr(host: &str, addr: SocketAddr) -> Result<QuicStream, Error> {
let endpoint = prepare_endpoint()?;
let conn = endpoint.connect(addr, host).map_err(|_| Error::todo())?;
let conn = conn.await.map_err(|_| Error::todo())?;
let stream = conn.open_bi().await.map_err(|_| Error::todo())?;
Ok(stream.into())
}
async fn _connect_quic(host: &str, ports: &[u16]) -> Result<(Connection, SocketAddr), Error> {
let addrs = super::dns_resolve(host, ports).await?;
let endpoint = prepare_endpoint()?;
let mut err = None;
for addr in addrs {
match endpoint.connect(addr, host) {
Ok(conn) => match conn.await {
Ok(conn) => return Ok((conn, addr)),
Err(_) => err = Some(Error::todo()),
},
Err(_) => err = Some(Error::todo()),
}
}
Err(err.unwrap())
}
fn prepare_endpoint() -> Result<Endpoint, Error> {
let mut endpoint = Endpoint::client("0.0.0.0:0".parse().unwrap())?;
let cfg = super::tls::dangerous_config(vec![QUIC_ALPN.to_vec()]);
let cfg = QuicClientConfig::try_from(cfg).unwrap();
endpoint.set_default_client_config(ClientConfig::new(Arc::new(cfg)));
Ok(endpoint)
}