use std::{
fmt::{Debug, Display},
io,
net::{SocketAddr, ToSocketAddrs},
ops,
sync::Arc,
};
use quiche::{ConnectionId, RecvInfo};
use rand::{seq::IteratorRandom, thread_rng};
use rasi::{
executor::spawn,
future::FutureExt,
io::{AsyncRead, AsyncWrite},
stream::TryStreamExt,
syscall::{global_network, Network},
time::TimeoutExt,
};
use crate::{
net::udp_group::{self, PathInfo, UdpGroup},
utils::ReadBuf,
};
use super::{
state::{QuicConnState, QuicListenerState, QuicServerStateIncoming},
Config,
};
/// Wrapper around a [`QuicConnState`] that closes the connection when it is dropped.
struct QuicConnFinalizer(QuicConnState);

impl Drop for QuicConnFinalizer {
    /// Spawns a task that calls `close(false, 0, b"")` on a clone of the wrapped state.
    ///
    /// `drop` cannot report failures, so a close error is only logged.
    fn drop(&mut self) {
        let conn_state = self.0.clone();

        spawn(async move {
            if let Err(err) = conn_state.close(false, 0, b"").await {
                log::error!("drop with error: {}", err);
            }
        });
    }
}
/// Client-side connector: bundles the bound UDP group with the client connection state
/// that `connect` drives until the connection is established.
pub struct QuicConnector {
/// UDP group bound to the local addresses; `connect` splits it into sender and receiver.
udp_group: UdpGroup,
/// Connection state created by `QuicConnState::new_client`.
conn_state: QuicConnState,
/// Capacity used for each outgoing packet buffer, copied from
/// `Config::max_send_udp_payload_size`.
max_send_udp_payload_size: usize,
}
impl QuicConnector {
    /// Create a connector that uses the globally registered network implementation.
    ///
    /// This is a thin wrapper over [`new_with`](Self::new_with) passing `global_network()`;
    /// see that function for the meaning of the parameters and the possible errors.
    pub async fn new<L: ToSocketAddrs, R: ToSocketAddrs>(
        server_name: Option<&str>,
        laddrs: L,
        raddrs: R,
        config: &mut Config,
    ) -> io::Result<Self> {
        Self::new_with(server_name, laddrs, raddrs, config, global_network()).await
    }

    /// Create a connector on top of an explicit network implementation.
    ///
    /// Binds a [`UdpGroup`] to `laddrs`, picks one resolved remote address at random, then
    /// picks (again at random) one bound local address of the same IP family and creates
    /// the client connection state for that address pair.
    ///
    /// # Errors
    ///
    /// * Any error from binding the UDP group, resolving `raddrs`, or creating the client
    ///   connection state is returned unchanged.
    /// * `InvalidInput` if `raddrs` resolves to no address at all.
    /// * `AddrNotAvailable` if none of the bound local addresses has the same IP family
    ///   as the selected remote address.
    pub async fn new_with<L: ToSocketAddrs, R: ToSocketAddrs>(
        server_name: Option<&str>,
        laddrs: L,
        raddrs: R,
        config: &mut Config,
        syscall: &'static dyn rasi::syscall::Network,
    ) -> io::Result<Self> {
        let udp_group = UdpGroup::bind_with(laddrs, syscall).await?;

        // `choose` yields `None` for an empty iterator; report that instead of panicking.
        let raddr = raddrs
            .to_socket_addrs()?
            .choose(&mut thread_rng())
            .ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "Remote addresses did not resolve to any socket address",
                )
            })?;

        // Copy the address out so the borrow of `udp_group` ends before it is moved below.
        let laddr = *udp_group
            .local_addrs()
            .filter(|addr| raddr.is_ipv4() == addr.is_ipv4())
            .choose(&mut thread_rng())
            .ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::AddrNotAvailable,
                    "No bound local address matches the IP family of the remote address",
                )
            })?;

        let conn_state = QuicConnState::new_client(server_name, laddr, raddr, config)?;

        Ok(Self {
            conn_state,
            udp_group,
            max_send_udp_payload_size: config.max_send_udp_payload_size,
        })
    }

    /// Drive the handshake to completion and return the established connection.
    ///
    /// Each round sends one packet produced by the connection state, then waits for one
    /// incoming datagram (bounded by the connection timeout instant when there is one) and
    /// feeds it to the connection state. When the wait times out, the loop goes back to the
    /// send step. Once the state reports that the connection is established, the background
    /// receive and send loops are spawned and the [`QuicConn`] handle is returned.
    ///
    /// # Errors
    ///
    /// Returns any error from the connection state or from the UDP sender/receiver, and
    /// `BrokenPipe` if the receiver stream ends before the connection is established.
    pub async fn connect(mut self) -> io::Result<QuicConn> {
        /// Error reported when the receiver stream ends.
        fn udp_closed() -> io::Error {
            io::Error::new(io::ErrorKind::BrokenPipe, "Underlying udp socket closed")
        }

        let (sender, mut receiver) = self.udp_group.split();

        loop {
            let mut read_buf = ReadBuf::with_capacity(self.max_send_udp_payload_size);

            let (read_size, send_info) = self.conn_state.send(read_buf.chunk_mut()).await?;

            let send_size = sender
                .send_to_on_path(
                    &read_buf.into_bytes(Some(read_size)),
                    PathInfo {
                        from: send_info.from,
                        to: send_info.to,
                    },
                )
                .await?;

            log::trace!("Quic connection, {:?}, send data {}", send_info, send_size);

            // Read the deadline in its own statement so that the value returned by
            // `to_inner_conn()` is dropped here rather than being kept alive across the
            // receive await below (an `if let` scrutinee would extend its lifetime).
            let timeout_at = self.conn_state.to_inner_conn().await.timeout_instant();

            let (mut buf, path_info) = match timeout_at {
                Some(timeout_at) => match receiver.try_next().timeout_at(timeout_at).await {
                    Some(received) => received?.ok_or_else(udp_closed)?,
                    // Timed out: go back to the send step.
                    None => continue,
                },
                None => receiver.try_next().await?.ok_or_else(udp_closed)?,
            };

            log::trace!("Quic connection, {:?}, recv data {}", path_info, buf.len());

            self.conn_state
                .recv(
                    &mut buf,
                    RecvInfo {
                        from: path_info.from,
                        to: path_info.to,
                    },
                )
                .await?;

            if self.conn_state.is_established().await {
                self.conn_state.update_dcid().await;
                break;
            }
        }

        // Hand the socket halves over to the background loops of the connection.
        spawn(QuicConn::recv_loop(self.conn_state.clone(), receiver));
        spawn(QuicConn::send_loop(
            self.conn_state.clone(),
            sender,
            self.max_send_udp_payload_size,
        ));

        Ok(QuicConn::new(self.conn_state))
    }
}
/// Handle to a QUIC connection.
///
/// The connection state lives inside a shared `QuicConnFinalizer`; streams created from
/// this handle hold a clone of the same `Arc`.
pub struct QuicConn {
/// Shared finalizer wrapping the connection state.
inner: Arc<QuicConnFinalizer>,
}
impl Display for QuicConn {
    /// Writes the `Display` representation of the wrapped connection state.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let conn_state = &self.inner.0;
        write!(f, "{}", conn_state)
    }
}
impl QuicConn {
    /// Returns the `scid` field of the connection state.
    pub fn source_id(&self) -> &ConnectionId<'static> {
        let conn_state = &self.inner.0;
        &conn_state.scid
    }

    /// Wraps a connection state into a new handle with its own finalizer.
    pub(super) fn new(state: QuicConnState) -> QuicConn {
        let inner = Arc::new(QuicConnFinalizer(state));
        Self { inner }
    }

    /// Waits for the next stream id from the connection state and wraps it in a
    /// [`QuicStream`]; yields `None` when the state yields `None`.
    pub async fn stream_accept(&self) -> Option<QuicStream> {
        let stream_id = self.inner.0.stream_accept().await?;
        Some(QuicStream::new(stream_id, self.inner.clone()))
    }

    /// Opens a new stream through the connection state, forwarding `stream_limits_error`
    /// unchanged, and wraps the resulting id in a [`QuicStream`].
    pub async fn stream_open(&self, stream_limits_error: bool) -> io::Result<QuicStream> {
        let stream_id = self.inner.0.stream_open(stream_limits_error).await?;
        Ok(QuicStream::new(stream_id, self.inner.clone()))
    }

    /// Forwards to `peer_streams_left_bidi` of the connection state.
    pub async fn peer_streams_left_bidi(&self) -> u64 {
        let conn_state = &self.inner.0;
        conn_state.peer_streams_left_bidi().await
    }

    /// Returns the value produced by `to_inner_conn` of the connection state, which
    /// dereferences to the underlying `quiche::Connection`.
    pub async fn to_inner_conn(&self) -> impl ops::Deref<Target = quiche::Connection> + '_ {
        let conn_state = &self.inner.0;
        conn_state.to_inner_conn().await
    }
}
impl QuicConn {
async fn recv_loop(state: QuicConnState, receiver: udp_group::Receiver) {
match Self::recv_loop_inner(state, receiver).await {
Ok(_) => {
log::info!("QuicListener recv_loop stopped.");
}
Err(err) => {
log::error!("QuicListener recv_loop stopped with err: {}", err);
}
}
}
async fn recv_loop_inner(
state: QuicConnState,
mut receiver: udp_group::Receiver,
) -> io::Result<()> {
while let Some((mut buf, path_info)) = receiver.try_next().await? {
log::trace!("QuicConn {:?}, recv data {}", path_info, buf.len());
let _ = match state
.recv(
&mut buf,
RecvInfo {
from: path_info.from,
to: path_info.to,
},
)
.await
{
Ok(r) => r,
Err(err) => {
log::error!(
"Quic client: handle packet from {:?} error: {}",
path_info,
err
);
continue;
}
};
}
Ok(())
}
async fn send_loop(
state: QuicConnState,
sender: udp_group::Sender,
max_send_udp_payload_size: usize,
) {
match Self::send_loop_inner(state, sender, max_send_udp_payload_size).await {
Ok(_) => {
log::info!("Quic client send_loop stopped.");
}
Err(err) => {
log::error!("Quic client send_loop stopped with err: {}", err);
}
}
}
pub(crate) async fn send_loop_inner(
state: QuicConnState,
sender: udp_group::Sender,
max_send_udp_payload_size: usize,
) -> io::Result<()> {
loop {
let mut read_buf = ReadBuf::with_capacity(max_send_udp_payload_size);
let (send_size, send_info) = state.send(read_buf.chunk_mut()).await?;
let send_size = sender
.send_to_on_path(
&read_buf.into_bytes(Some(send_size)),
PathInfo {
from: send_info.from,
to: send_info.to,
},
)
.await?;
log::trace!("QuicConn {:?}, send data {}", send_info, send_size);
}
}
}
/// Pairs a stream id with the shared connection finalizer and closes the stream on drop.
struct QuicStreamFinalizer(u64, Arc<QuicConnFinalizer>);

impl Drop for QuicStreamFinalizer {
    /// Spawns a task that calls `stream_close` for this stream id on a clone of the
    /// connection state; a failure is only logged.
    fn drop(&mut self) {
        let stream_id = self.0;
        let conn_state = self.1 .0.clone();

        spawn(async move {
            if let Err(err) = conn_state.stream_close(stream_id).await {
                log::error!("drop with error: {}", err);
            }
        });
    }
}
/// Handle to a single stream of a QUIC connection.
///
/// Clones share one `QuicStreamFinalizer`, whose `Drop` implementation spawns a task
/// that calls `stream_close` for the stream id.
#[derive(Clone)]
pub struct QuicStream {
/// Shared pair of stream id and connection finalizer.
inner: Arc<QuicStreamFinalizer>,
}
impl Display for QuicStream {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}, stream_id={}", self.inner.1 .0, self.inner.0)
}
}
impl Debug for QuicStream {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("QuicStream")
.field("stream_id", &self.inner.0)
.field("conn", &self.inner.1 .0.to_string())
.finish()
}
}
impl QuicStream {
fn new(stream_id: u64, inner: Arc<QuicConnFinalizer>) -> Self {
Self {
inner: Arc::new(QuicStreamFinalizer(stream_id, inner)),
}
}
pub async fn stream_send(&self, buf: &[u8], fin: bool) -> io::Result<usize> {
self.inner.1 .0.stream_send(self.inner.0, buf, fin).await
}
pub async fn stream_recv(&self, buf: &mut [u8]) -> io::Result<(usize, bool)> {
self.inner.1 .0.stream_recv(self.inner.0, buf).await
}
}
impl AsyncWrite for QuicStream {
    /// Creates a fresh `stream_send` future (with `fin = false`) on every call and polls
    /// it once.
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<io::Result<usize>> {
        let stream_id = self.inner.0;
        let conn_state = &self.inner.1 .0;

        let mut send_fut = Box::pin(conn_state.stream_send(stream_id, buf, false));
        send_fut.poll_unpin(cx)
    }

    /// Always reports success immediately; this implementation does no work on flush.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<io::Result<()>> {
        std::task::Poll::Ready(Ok(()))
    }

    /// Creates a fresh `stream_close` future on every call and polls it once.
    ///
    /// The outcome of `stream_close` is discarded: once the future completes, `Ok(())` is
    /// reported regardless of its result.
    fn poll_close(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<io::Result<()>> {
        let stream_id = self.inner.0;
        let conn_state = &self.inner.1 .0;

        let mut close_fut = Box::pin(conn_state.stream_close(stream_id));
        close_fut.poll_unpin(cx).map(|_| Ok(()))
    }
}
impl AsyncRead for QuicStream {
    /// Creates a fresh `stream_recv` future on every call and polls it once.
    ///
    /// On success only the byte count is returned; the accompanying flag is dropped.
    /// A `BrokenPipe` error is translated into `Ok(0)`; any other error is returned as is.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> std::task::Poll<io::Result<usize>> {
        let stream_id = self.inner.0;
        let conn_state = &self.inner.1 .0;

        let mut recv_fut = Box::pin(conn_state.stream_recv(stream_id, buf));

        recv_fut.poll_unpin(cx).map(|outcome| match outcome {
            Ok((read_size, _)) => Ok(read_size),
            Err(err) if err.kind() == io::ErrorKind::BrokenPipe => Ok(0),
            Err(err) => Err(err),
        })
    }
}
/// Server-side QUIC listener handle returned by `bind` / `bind_with`.
pub struct QuicListener {
/// Source of incoming connection states; `accept` awaits on it.
incoming: QuicServerStateIncoming,
/// Local addresses of the bound UDP group, collected in `bind_with`.
laddrs: Vec<SocketAddr>,
}
impl QuicListener {
pub async fn bind_with<A: ToSocketAddrs>(
laddrs: A,
config: Config,
syscall: &'static dyn Network,
) -> io::Result<Self> {
let max_send_udp_payload_size = config.max_send_udp_payload_size;
let socket = UdpGroup::bind_with(laddrs, syscall).await?;
let laddrs = socket.local_addrs().cloned().collect::<Vec<_>>();
let (sender, receiver) = socket.split();
let (state, incoming) = QuicListenerState::new(config)?;
rasi::executor::spawn(Self::recv_loop(
state,
receiver,
sender,
max_send_udp_payload_size,
));
Ok(Self { incoming, laddrs })
}
pub async fn bind<A: ToSocketAddrs>(laddrs: A, config: Config) -> io::Result<Self> {
Self::bind_with(laddrs, config, global_network()).await
}
pub async fn accept(&self) -> Option<QuicConn> {
self.incoming
.accept()
.await
.map(|state| QuicConn::new(state))
}
pub fn local_addrs(&self) -> impl Iterator<Item = &SocketAddr> {
self.laddrs.iter()
}
}
impl QuicListener {
    /// Background task: runs [`recv_loop_inner`](Self::recv_loop_inner) and logs how it ended.
    async fn recv_loop(
        state: QuicListenerState,
        receiver: udp_group::Receiver,
        sender: udp_group::Sender,
        max_send_udp_payload_size: usize,
    ) {
        let outcome =
            Self::recv_loop_inner(state, receiver, sender, max_send_udp_payload_size).await;

        if let Err(err) = outcome {
            log::error!("QuicListener recv_loop stopped with err: {}", err);
        } else {
            log::info!("QuicListener recv_loop stopped.");
        }
    }

    /// Feeds every datagram yielded by `receiver` into the listener state.
    ///
    /// For each handled packet the state may return a response, which is sent back on the
    /// reversed path, and/or a connection state, for which a dedicated send loop is
    /// spawned. A packet the state rejects is logged and skipped. Returns `Ok(())` when the
    /// receiver stream ends.
    ///
    /// # Errors
    ///
    /// Returns the error yielded by the receiver stream or by sending a response.
    async fn recv_loop_inner(
        state: QuicListenerState,
        mut receiver: udp_group::Receiver,
        sender: udp_group::Sender,
        max_send_udp_payload_size: usize,
    ) -> io::Result<()> {
        while let Some((mut buf, path_info)) = receiver.try_next().await? {
            let recv_info = RecvInfo {
                from: path_info.from,
                to: path_info.to,
            };

            let (_, resp, conn_state) = match state.recv(&mut buf, recv_info).await {
                Ok(handled) => handled,
                Err(err) => {
                    log::error!("handle packet from {:?} error: {}", path_info, err);
                    continue;
                }
            };

            // NOTE(review): a failed response send ends the whole receive loop — confirm
            // this is intended.
            if let Some(resp) = resp {
                sender.send_to_on_path(&resp, path_info.reverse()).await?;
            }

            if let Some(conn_state) = conn_state {
                let send_task = Self::send_loop(
                    state.clone(),
                    conn_state,
                    sender.clone(),
                    max_send_udp_payload_size,
                );
                spawn(send_task);
            }
        }

        Ok(())
    }

    /// Runs the shared send loop for one connection state, logs how it ended, and then
    /// removes the connection (keyed by its `scid`) from the listener state.
    async fn send_loop(
        state: QuicListenerState,
        conn_state: QuicConnState,
        sender: udp_group::Sender,
        max_send_udp_payload_size: usize,
    ) {
        let outcome =
            QuicConn::send_loop_inner(conn_state.clone(), sender, max_send_udp_payload_size).await;

        if let Err(err) = outcome {
            log::error!("{}, stop send loop with error: {}", conn_state, err);
        } else {
            log::trace!("{}, stop send loop", conn_state);
        }

        state.remove_conn(&conn_state.scid).await;
    }
}