use bytes::{Bytes, BytesMut};
use quinn::{RecvStream, SendStream, VarInt};
use tracing::debug;
use crate::{
error::NetError,
proto::{
ProtoError,
op::{DnsResponse, Message},
},
};
/// ALPN token identifying DNS over QUIC (`doq`, the token registered by RFC 9250).
pub(crate) const DOQ_ALPN: &[u8] = b"doq";
/// Application error codes used by DNS over QUIC (RFC 9250, "DoQ Error Codes").
///
/// These are carried in QUIC stream resets, stop-sending requests and
/// connection closes. Values that do not correspond to a named code are
/// preserved in [`DoqErrorCode::Unknown`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DoqErrorCode {
    /// `DOQ_NO_ERROR` (0x0): the stream or connection is being closed without an error to signal.
    NoError,
    /// `DOQ_INTERNAL_ERROR` (0x1): the implementation hit an internal error and cannot continue.
    InternalError,
    /// `DOQ_PROTOCOL_ERROR` (0x2): the peer violated the DoQ protocol.
    ProtocolError,
    /// `DOQ_REQUEST_CANCELLED` (0x3): a client cancels an outstanding transaction.
    RequestCancelled,
    /// `DOQ_EXCESSIVE_LOAD` (0x4): the connection is being closed due to excessive load.
    ExcessiveLoad,
    /// `DOQ_ERROR_RESERVED` (0xd098ea5e): alternative error code reserved for testing.
    ErrorReserved,
    /// Any other 32-bit code that has no dedicated variant.
    Unknown(u32),
}
// Numeric wire values of the DoQ error codes; each one is matched against
// the `DoqErrorCode` variant of the same name in the conversions below.
/// Wire value mapped to `DoqErrorCode::NoError`.
const NO_ERROR: u32 = 0x0;
/// Wire value mapped to `DoqErrorCode::InternalError`.
const INTERNAL_ERROR: u32 = 0x1;
/// Wire value mapped to `DoqErrorCode::ProtocolError`.
const PROTOCOL_ERROR: u32 = 0x2;
/// Wire value mapped to `DoqErrorCode::RequestCancelled`.
const REQUEST_CANCELLED: u32 = 0x3;
/// Wire value mapped to `DoqErrorCode::ExcessiveLoad`.
const EXCESSIVE_LOAD: u32 = 0x4;
/// Wire value mapped to `DoqErrorCode::ErrorReserved`.
const ERROR_RESERVED: u32 = 0xd098ea5e;
impl From<DoqErrorCode> for VarInt {
fn from(doq_error: DoqErrorCode) -> Self {
use DoqErrorCode::*;
match doq_error {
NoError => Self::from_u32(NO_ERROR),
InternalError => Self::from_u32(INTERNAL_ERROR),
ProtocolError => Self::from_u32(PROTOCOL_ERROR),
RequestCancelled => Self::from_u32(REQUEST_CANCELLED),
ExcessiveLoad => Self::from_u32(EXCESSIVE_LOAD),
ErrorReserved => Self::from_u32(ERROR_RESERVED),
Unknown(code) => Self::from_u32(code),
}
}
}
impl From<VarInt> for DoqErrorCode {
    /// Interprets a QUIC application error code as a DoQ error code.
    ///
    /// Values that do not fit in 32 bits are reported as `ProtocolError`;
    /// 32-bit values without a named variant are kept as `Unknown`.
    fn from(doq_error: VarInt) -> Self {
        match u32::try_from(doq_error.into_inner()) {
            Ok(NO_ERROR) => Self::NoError,
            Ok(INTERNAL_ERROR) => Self::InternalError,
            Ok(PROTOCOL_ERROR) => Self::ProtocolError,
            Ok(REQUEST_CANCELLED) => Self::RequestCancelled,
            Ok(EXCESSIVE_LOAD) => Self::ExcessiveLoad,
            Ok(ERROR_RESERVED) => Self::ErrorReserved,
            Ok(other) => Self::Unknown(other),
            // The code is wider than u32, so it cannot be any known DoQ code.
            Err(_) => Self::ProtocolError,
        }
    }
}
/// A pair of QUIC send/receive streams used to exchange length-prefixed DNS messages.
pub struct QuicStream {
// Outgoing half: messages are written here, and `reset` acts on it.
send_stream: SendStream,
// Incoming half: messages are read from here, and `stop` acts on it.
receive_stream: RecvStream,
}
impl QuicStream {
    /// Builds a `QuicStream` from an already-opened send stream and receive stream.
    pub(crate) fn new(send_stream: SendStream, receive_stream: RecvStream) -> Self {
        Self {
            send_stream,
            receive_stream,
        }
    }

    /// Encodes `message` and writes it to the send stream.
    ///
    /// The message ID is overwritten with 0 before encoding; RFC 9250 requires
    /// DNS messages sent over QUIC to carry a zero message ID.
    ///
    /// # Errors
    ///
    /// Returns an error if the message cannot be serialized, or for any reason
    /// [`Self::send_bytes`] fails.
    pub async fn send(&mut self, mut message: Message) -> Result<(), NetError> {
        message.metadata.id = 0;

        let bytes = Bytes::from(message.to_vec()?);
        self.send_bytes(bytes).await
    }

    /// Writes `bytes` to the send stream, preceded by a 2-byte big-endian length.
    ///
    /// # Errors
    ///
    /// Returns `ProtoError::MaxBufferSizeExceeded` (wrapped in `NetError`) if
    /// `bytes` is longer than `u16::MAX`, or the stream's write error.
    pub async fn send_bytes(&mut self, bytes: Bytes) -> Result<(), NetError> {
        // The length prefix is a u16, so larger payloads cannot be framed.
        let bytes_len = u16::try_from(bytes.len())
            .map_err(|_e| NetError::from(ProtoError::MaxBufferSizeExceeded(bytes.len())))?;
        let len = Bytes::from(bytes_len.to_be_bytes().to_vec());

        debug!("sending packet len: {} bytes: {:x?}", bytes_len, bytes);
        // Prefix and payload are handed over as two chunks in one write call.
        self.send_stream.write_all_chunks(&mut [len, bytes]).await?;
        Ok(())
    }

    /// Finishes the send stream, signalling that no more data will be written.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the underlying stream's `finish`.
    pub async fn finish(&mut self) -> Result<(), NetError> {
        self.send_stream.finish()?;
        Ok(())
    }

    /// Reads one length-prefixed message and returns it as a [`DnsResponse`].
    ///
    /// If the decoded message ID is not 0, the send stream is reset with
    /// `DoqErrorCode::ProtocolError` (a failure to reset is only logged) and
    /// `NetError::QuicMessageIdNot0` is returned.
    ///
    /// # Errors
    ///
    /// Returns an error if reading fails, the bytes cannot be decoded, or the
    /// message ID is non-zero.
    pub async fn receive(&mut self) -> Result<DnsResponse, NetError> {
        let bytes = self.receive_bytes().await?;
        let message = Message::from_vec(&bytes)?;

        if message.id != 0 {
            if let Err(error) = self.reset(DoqErrorCode::ProtocolError) {
                debug!(%error, "stream already closed");
            }
            return Err(NetError::QuicMessageIdNot0(message.id));
        }

        Ok(DnsResponse::from_buffer(bytes.to_vec())?)
    }

    /// Reads one length-prefixed frame: a 2-byte big-endian length, then that many bytes.
    ///
    /// If the payload cannot be read in full, the send stream is reset with
    /// `DoqErrorCode::ProtocolError` (a failure to reset is only logged) and
    /// the read error is returned. A failure while reading the length prefix
    /// itself is returned without resetting the stream.
    ///
    /// # Errors
    ///
    /// Returns the read error from the receive stream.
    pub async fn receive_bytes(&mut self) -> Result<BytesMut, NetError> {
        let mut len = [0u8; 2];
        self.receive_stream.read_exact(&mut len).await?;
        let len = u16::from_be_bytes(len) as usize;

        // Zero-fill so `read_exact` has an initialized buffer of exactly `len` bytes.
        let mut bytes = BytesMut::with_capacity(len);
        bytes.resize(len, 0);
        if let Err(e) = self.receive_stream.read_exact(&mut bytes[..len]).await {
            debug!("received bad packet len: {} bytes: {:?}", len, bytes);

            if let Err(error) = self.reset(DoqErrorCode::ProtocolError) {
                debug!(%error, "stream already closed");
            }
            return Err(NetError::from(e));
        }

        debug!("received packet len: {} bytes: {:x?}", len, bytes);
        Ok(bytes)
    }

    /// Resets the send stream with the given DoQ error code.
    ///
    /// # Errors
    ///
    /// Returns `NetError::QuinnUnknownStreamError` if the underlying reset fails.
    pub fn reset(&mut self, code: DoqErrorCode) -> Result<(), NetError> {
        self.send_stream
            .reset(code.into())
            .map_err(|_| NetError::QuinnUnknownStreamError)
    }

    /// Stops the receive stream with the given DoQ error code.
    ///
    /// # Errors
    ///
    /// Returns `NetError::QuinnUnknownStreamError` if the underlying stop fails.
    pub fn stop(&mut self, code: DoqErrorCode) -> Result<(), NetError> {
        self.receive_stream
            .stop(code.into())
            .map_err(|_| NetError::QuinnUnknownStreamError)
    }
}