use std::{
collections::VecDeque,
future::Future,
io,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll, Waker, ready},
};
use bytes::{BufMut, Bytes};
use qbase::{
error::{Error, ErrorKind, QuicError},
frame::{DatagramFrame, EncodeSize, GetFrameType},
};
/// Receive-side state shared by [`DatagramIncoming`] and every [`DatagramReader`]:
/// the local size limit, the queue of payloads not yet read, and the waker of a
/// reader that is waiting for the next payload.
#[derive(Debug)]
struct RawDatagarmReader {
    /// Upper bound on `frame.encoding_size() + payload.len()` for an accepted
    /// DATAGRAM frame; `0` makes `DatagramIncoming::new_reader` refuse to create readers.
    local_max_size: usize,
    /// Payloads received but not yet consumed, oldest at the front.
    rcvd_datagrams: VecDeque<Bytes>,
    /// Waker stored by the last `poll_recv` that found the queue empty.
    read_waker: Option<Waker>,
}

impl RawDatagarmReader {
    /// Builds an empty state bounded by `local_max_size`, with no waiting reader.
    fn new(local_max_size: usize) -> Self {
        let rcvd_datagrams = VecDeque::default();
        let read_waker = None;
        Self {
            local_max_size,
            rcvd_datagrams,
            read_waker,
        }
    }
}
/// Connection-facing half of the unreliable datagram receive path.
///
/// The connection feeds DATAGRAM payloads in through
/// [`recv_datagram`](Self::recv_datagram), and applications obtain
/// [`DatagramReader`]s that share the same queue. After
/// [`on_conn_error`](Self::on_conn_error) the shared slot holds the connection
/// error instead of the queue.
#[derive(Debug, Clone)]
pub struct DatagramIncoming(Arc<Mutex<Result<RawDatagarmReader, Error>>>);

impl DatagramIncoming {
    /// Creates the receive state, accepting frames whose encoding plus payload is
    /// at most `local_max_size` bytes.
    pub fn new(local_max_size: usize) -> Self {
        let state = RawDatagarmReader::new(local_max_size);
        Self(Arc::new(Mutex::new(Ok(state))))
    }

    /// Creates a reader that shares this incoming queue.
    ///
    /// # Errors
    ///
    /// - The stored connection error, converted to [`io::Error`], if the
    ///   connection has already failed.
    /// - [`io::ErrorKind::Unsupported`] when `local_max_size` is `0`.
    pub fn new_reader(&self) -> io::Result<DatagramReader> {
        let local_max_size = match self.0.lock().unwrap().as_ref() {
            Ok(reader) => reader.local_max_size,
            Err(error) => return Err(error.clone().into()),
        };

        if local_max_size == 0 {
            tracing::error!(" Cause by: DatagramIncoming::new_reader local_max_size is 0");
            return Err(io::Error::new(
                io::ErrorKind::Unsupported,
                "Unreliable Datagram Extension was disenabled by local parameters",
            ));
        }

        Ok(DatagramReader(Arc::clone(&self.0)))
    }

    /// Queues the payload of a received DATAGRAM frame and wakes a waiting reader.
    ///
    /// # Errors
    ///
    /// - The stored connection error if the connection has already failed.
    /// - A `ProtocolViolation` [`QuicError`] if `frame.encoding_size() + data.len()`
    ///   exceeds `local_max_size`. In that case the payload is not queued.
    pub fn recv_datagram(&self, frame: &DatagramFrame, data: bytes::Bytes) -> Result<(), Error> {
        let mut state = self.0.lock().unwrap();
        let reader = state.as_mut().map_err(|e| e.clone())?;

        let frame_size = frame.encoding_size() + data.len();
        if frame_size > reader.local_max_size {
            tracing::error!(" Cause by: DatagramIncoming::recv_datagram");
            let reason = format!(
                "datagram size {} exceeds the maximum size {}",
                frame_size, reader.local_max_size
            );
            let violation = QuicError::new(
                ErrorKind::ProtocolViolation,
                frame.frame_type().into(),
                reason,
            );
            return Err(violation.into());
        }

        reader.rcvd_datagrams.push_back(data);
        if let Some(waker) = reader.read_waker.take() {
            waker.wake();
        }
        Ok(())
    }

    /// Replaces the shared state with `error` and wakes a waiting reader.
    ///
    /// Only the first error is recorded; if the state already holds an error,
    /// this call does nothing. Payloads still queued are dropped.
    pub fn on_conn_error(&self, error: &Error) {
        let mut state = self.0.lock().unwrap();
        let Ok(reader) = state.as_mut() else {
            return;
        };
        if let Some(waker) = reader.read_waker.take() {
            waker.wake();
        }
        *state = Err(error.clone());
    }
}
/// Application-facing handle for receiving unreliable datagrams.
///
/// Every clone, and every reader from the same `DatagramIncoming`, pops from the
/// same shared queue.
/// NOTE(review): the shared state has a single waker slot, so when several
/// readers wait concurrently only the last one to register is woken.
#[derive(Debug, Clone)]
pub struct DatagramReader(Arc<Mutex<Result<RawDatagarmReader, Error>>>);

impl DatagramReader {
    /// Polls for the oldest queued datagram.
    ///
    /// Returns one of:
    /// - `Ready(Ok(payload))` when a datagram is queued.
    /// - `Ready(Err(_))` with the connection error, converted to [`io::Error`], once
    ///   the connection has failed.
    /// - `Pending` after storing `cx`'s waker, when the queue is empty.
    pub fn poll_recv(&self, cx: &mut Context<'_>) -> Poll<io::Result<Bytes>> {
        let mut state = self.0.lock().unwrap();
        let reader = match state.as_mut() {
            Ok(reader) => reader,
            Err(e) => return Poll::Ready(Err(io::Error::from(e.clone()))),
        };

        if let Some(bytes) = reader.rcvd_datagrams.pop_front() {
            return Poll::Ready(Ok(bytes));
        }

        // Nothing queued yet: remember who to wake when `recv_datagram` pushes one.
        reader.read_waker = Some(cx.waker().clone());
        Poll::Pending
    }

    /// Waits for the next datagram and yields its payload.
    pub fn recv(&mut self) -> RecvDatagram<'_> {
        RecvDatagram { reader: self }
    }

    /// Waits for the next datagram and copies it into `buf`.
    ///
    /// If the datagram is longer than `buf`, the excess is discarded.
    pub fn read<'b>(&'b mut self, buf: &'b mut [u8]) -> ReadIntoSlice<'b> {
        ReadIntoSlice { reader: self, buf }
    }

    /// Waits for the next datagram and writes it into the [`BufMut`] `buf`.
    pub fn read_buf<'b, B: BufMut>(&'b mut self, buf: &'b mut B) -> ReadIntoBuf<'b, B> {
        ReadIntoBuf { reader: self, buf }
    }
}
/// Future returned by [`DatagramReader::recv`]; resolves to the next datagram payload.
pub struct RecvDatagram<'a> {
    reader: &'a mut DatagramReader,
}

impl Future for RecvDatagram<'_> {
    /// The datagram payload, or the connection error.
    type Output = io::Result<Bytes>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Holding only a reference makes this future `Unpin`, so unwrapping the pin is fine.
        let this = Pin::into_inner(self);
        this.reader.poll_recv(cx)
    }
}
pub struct ReadIntoSlice<'a> {
reader: &'a mut DatagramReader,
buf: &'a mut [u8],
}
impl Future for ReadIntoSlice<'_> {
type Output = io::Result<usize>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let s = self.get_mut();
let bytes = ready!(s.reader.poll_recv(cx)?);
let len = bytes.len().min(s.buf.len());
s.buf[..len].copy_from_slice(&bytes[..len]);
Poll::Ready(Ok(len))
}
}
/// Future returned by [`DatagramReader::read_buf`].
///
/// Writes the next datagram into a [`BufMut`], limited to the buffer's
/// remaining capacity. If the datagram does not fit, the excess is discarded.
/// This matches the truncating behaviour of [`ReadIntoSlice`].
pub struct ReadIntoBuf<'a, B> {
    reader: &'a mut DatagramReader,
    buf: &'a mut B,
}

impl<B> Future for ReadIntoBuf<'_, B>
where
    B: BufMut,
{
    /// Number of bytes written into `buf`, or the connection error.
    type Output = io::Result<usize>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let s = self.get_mut();
        let bytes = ready!(s.reader.poll_recv(cx)?);
        // `BufMut::put` panics when the destination cannot hold all of `src`,
        // which happens with fixed-capacity buffers such as `&mut [u8]`. By this
        // point the datagram has already been dequeued, so a panic would also lose it.
        // Clamp to the available capacity and report what was actually written.
        let len = bytes.len().min(s.buf.remaining_mut());
        s.buf.put_slice(&bytes[..len]);
        Poll::Ready(Ok(len))
    }
}
#[cfg(test)]
mod tests {
    use qbase::{frame::FrameType, varint::VarInt};

    use super::*;

    /// A reader waiting on an empty queue is woken by `recv_datagram` and
    /// receives the whole 11-byte payload.
    #[tokio::test]
    async fn test_datagram_reader_recv_buf() {
        let incoming = DatagramIncoming::new(1024);
        let mut reader = incoming.new_reader().unwrap();
        let recv = tokio::spawn(async move {
            let mut buf = [0u8; 1024];
            let n = reader.read(&mut buf).await.unwrap();
            assert_eq!(n, 11);
        });

        let frame = DatagramFrame::new(false, VarInt::from_u32(11));
        let payload = Bytes::from_static(b"hello world");
        incoming.recv_datagram(&frame, payload).unwrap();

        recv.await.unwrap();
    }

    /// After a connection error, creating a reader fails with `BrokenPipe`.
    #[tokio::test]
    async fn test_datagram_reader_on_conn_error() {
        let incoming = DatagramIncoming::new(1024);
        let error: Error = QuicError::new(
            ErrorKind::ProtocolViolation,
            FrameType::Datagram(0).into(),
            "protocol violation",
        )
        .into();
        incoming.on_conn_error(&error);

        let err = incoming
            .new_reader()
            .expect_err("new_reader must fail once the connection has errored");
        assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
    }
}