#![forbid(unsafe_code)]
use crate::{
crypto::tls,
event::metrics::aggregate,
frame::ConnectionClose,
varint::{VarInt, VarIntError},
};
use core::fmt;
use s2n_codec::DecoderError;
/// A QUIC transport error: an error [`Code`], the type of the frame associated
/// with it, and an optional static reason string.
///
/// Equality and ordering are derived field-by-field in declaration order
/// (`code`, then `frame_type`, then `reason`), so the field order is significant.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct Error {
    /// The transport error code.
    pub code: Code,
    /// The frame type associated with the error; `Error::new` initializes it to 0.
    pub frame_type: VarInt,
    /// Static reason text. An empty string is treated as "no reason" by the
    /// `Display`/`Debug` impls and by the `ConnectionClose` conversion.
    pub reason: &'static str,
}
// Marker impl so `Error` can be used as a standard error value.
impl core::error::Error for Error {}
impl Error {
    /// Creates an error carrying `code`, with frame type 0 and an empty reason.
    pub const fn new(code: VarInt) -> Self {
        let code = Code::new(code);
        Self {
            code,
            frame_type: VarInt::from_u8(0),
            reason: "",
        }
    }

    /// Returns a copy of this error whose frame type is replaced by `frame_type`.
    #[must_use]
    pub const fn with_frame_type(self, frame_type: VarInt) -> Self {
        Self { frame_type, ..self }
    }

    /// Returns a copy of this error whose reason is replaced by `reason`.
    #[must_use]
    pub const fn with_reason(self, reason: &'static str) -> Self {
        Self { reason, ..self }
    }
}
impl fmt::Display for Error {
    /// Writes the reason when one was supplied; otherwise falls back to the
    /// `Display` form of the error code.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        if self.reason.is_empty() {
            return fmt::Display::fmt(&self.code, f);
        }
        fmt::Display::fmt(self.reason, f)
    }
}
impl fmt::Debug for Error {
    /// Formats as `transport::Error { code, [description], [reason], frame_type }`,
    /// omitting `description` for unrecognized codes and `reason` when it is empty.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("transport::Error");
        builder.field("code", &self.code.as_u64());

        // The symbolic name is only available for recognized codes.
        if let Some(name) = self.description() {
            builder.field("description", &name);
        }

        // An empty reason means "no reason", so it is left out entirely.
        if !self.reason.is_empty() {
            builder.field("reason", &self.reason);
        }

        builder.field("frame_type", &self.frame_type).finish()
    }
}
impl From<Error> for ConnectionClose<'_> {
fn from(error: Error) -> Self {
ConnectionClose {
error_code: error.code.0,
frame_type: Some(error.frame_type),
reason: Some(error.reason.as_bytes()).filter(|reason| !reason.is_empty()),
}
}
}
/// A QUIC transport error code, stored as a [`VarInt`].
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct Code(VarInt);

impl Code {
    /// Wraps a raw error-code value.
    #[doc(hidden)]
    pub const fn new(code: VarInt) -> Self {
        Code(code)
    }

    /// Returns the numeric value of this code.
    #[inline]
    pub fn as_u64(self) -> u64 {
        self.as_varint().as_u64()
    }

    /// Returns the underlying [`VarInt`] value of this code.
    #[inline]
    pub fn as_varint(self) -> VarInt {
        let Self(raw) = self;
        raw
    }
}
impl fmt::Debug for Code {
    /// Formats as `transport::error::Code(value)`, appending the symbolic name
    /// as a second tuple field when the code is recognized.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("transport::error::Code");
        tuple.field(&self.0);
        match self.description() {
            Some(name) => tuple.field(&name).finish(),
            None => tuple.finish(),
        }
    }
}
impl fmt::Display for Code {
    /// Writes the symbolic name of a recognized code, or `error(<hex value>)`
    /// for anything else.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.description() {
            Some(name) => fmt::Display::fmt(name, f),
            None => write!(f, "error({:x?})", self.as_u64()),
        }
    }
}
/// Frame type (0) attached to the predefined errors generated by `impl_errors!`
/// and to errors built with `Error::crypto_error`.
const UNKNOWN_FRAME_TYPE: u32 = 0;
/// Inclusive range of transport codes that wrap a TLS error code (`0x100 | code`).
const CRYPTO_ERROR_RANGE: core::ops::RangeInclusive<u64> =
    core::ops::RangeInclusive::new(0x100, 0x1ff);
// Generates everything derived from the list of transport error entries
// (`NAME = code [.with_frame_type(frame)]`, each optionally preceded by doc attributes):
//   * an associated `Code::NAME` constant per entry, plus `Code::description`;
//   * an `aggregate::AsVariant` impl for `Code` (metrics variant table + index lookup);
//   * an associated `Error::NAME` constant per entry, plus `Error::description`;
//   * two unit tests covering the names and the variant table.
// Doc attributes (`///` comments) on an entry are copied onto both generated constants.
macro_rules! impl_errors {
    ($($(#[doc = $doc:expr])* $name:ident = $code:literal $(.with_frame_type($frame:expr))?),* $(,)?) => {
        impl Code {
            $(
                $(#[doc = $doc])*
                pub const $name: Self = Self::new(VarInt::from_u32($code));
            )*
            /// Returns the entry name for a code in the list, the TLS error's
            /// description for a code in `CRYPTO_ERROR_RANGE`, or `None` otherwise.
            pub fn description(&self) -> Option<&'static str> {
                match self.0.as_u64() {
                    $(
                        $code => Some(stringify!($name)),
                    )*
                    // Truncating to `u8` keeps the low byte, i.e. the TLS error code
                    // without the 0x100 prefix.
                    code if CRYPTO_ERROR_RANGE.contains(&code) => tls::Error::new(code as u8).description(),
                    _ => None
                }
            }
        }
        impl aggregate::AsVariant for Code {
            // Metrics variant table, laid out as:
            //   [0, QUIC_VARIANTS)           one `QUIC_<NAME>` per entry, in list order
            //   next TLS.len() slots         the TLS variants, renumbered to their slot
            //   QUIC_VARIANTS + TLS.len()    a trailing `QUIC_UNKNOWN_ERROR`
            // Every `id` equals its array index. Names are NUL-terminated,
            // presumably because `Str::new` requires it — TODO confirm.
            const VARIANTS: &'static [aggregate::info::Variant] = &{
                use aggregate::info::{Variant, Str};
                // Expands to 1 per entry, so `QUIC_VARIANTS` counts the list.
                const fn count(_v: u64) -> usize {
                    1
                }
                const QUIC_VARIANTS: usize = 0 $( + count($code))*;
                const TLS: &'static [Variant] = tls::Error::VARIANTS;
                // Placeholder fill; every slot is overwritten below.
                let mut array = [
                    Variant { name: Str::new("\0"), id: 0 };
                    QUIC_VARIANTS + TLS.len() + 1
                ];
                let mut id = 0;
                $(
                    array[id] = Variant {
                        name: Str::new(concat!("QUIC_", stringify!($name), "\0")),
                        id,
                    };
                    id += 1;
                )*
                // `while` rather than `for` because this runs in a const context.
                let mut tls_idx = 0;
                while tls_idx < TLS.len() {
                    let variant = TLS[tls_idx];
                    array[id] = Variant {
                        name: variant.name,
                        id,
                    };
                    id += 1;
                    tls_idx += 1;
                }
                array[id] = Variant {
                    name: Str::new("QUIC_UNKNOWN_ERROR\0"),
                    id,
                };
                array
            };
            // Must mirror the `VARIANTS` layout above: listed codes map to their
            // position, crypto codes are offset past the QUIC entries into the TLS
            // block, and everything else maps to the trailing unknown slot.
            #[inline]
            fn variant_idx(&self) -> usize {
                let mut idx = 0;
                let code = self.0.as_u64();
                $(
                    if code == $code {
                        return idx;
                    }
                    idx += 1;
                )*
                // Here `idx == QUIC_VARIANTS`.
                if CRYPTO_ERROR_RANGE.contains(&code) {
                    return tls::Error::new(code as _).variant_idx() + idx;
                }
                idx + tls::Error::VARIANTS.len()
            }
        }
        impl Error {
            $(
                $(#[doc = $doc])*
                pub const $name: Self = Self::new(VarInt::from_u32($code))
                    $( .with_frame_type(VarInt::from_u32($frame)) )?;
            )*
            /// Returns the description of this error's code (see `Code::description`).
            pub fn description(&self) -> Option<&'static str> {
                self.code.description()
            }
        }
        // Each predefined error (empty reason) displays as its own name, and a
        // converted TLS error displays as the TLS error's name.
        #[test]
        fn description_test() {
            $(
                assert_eq!(&Error::$name.to_string(), stringify!($name));
            )*
            assert_eq!(&Error::from(tls::Error::DECODE_ERROR).to_string(), "DECODE_ERROR");
        }
        // Snapshots the variant table and checks that all ids are unique.
        #[test]
        #[cfg_attr(miri, ignore)]
        fn variants_test() {
            use aggregate::AsVariant;
            insta::assert_debug_snapshot!(Code::VARIANTS);
            let mut seen = std::collections::HashSet::new();
            for variant in Code::VARIANTS {
                assert!(seen.insert(variant.id));
            }
        }
    };
}
// Transport error codes from RFC 9000 §20.1. Their order here fixes their metrics
// variant ids (`impl_errors!` numbers them in list order), so append new codes
// rather than reordering. An entry without `.with_frame_type(...)` still gets
// frame type 0 from `Error::new`.
impl_errors! {
    /// RFC 9000: the connection is being closed abruptly in the absence of any error.
    NO_ERROR = 0x0.with_frame_type(UNKNOWN_FRAME_TYPE),
    /// RFC 9000: the endpoint encountered an internal error and cannot continue
    /// with the connection.
    INTERNAL_ERROR = 0x1.with_frame_type(UNKNOWN_FRAME_TYPE),
    /// RFC 9000: the server refused to accept a new connection.
    CONNECTION_REFUSED = 0x2.with_frame_type(UNKNOWN_FRAME_TYPE),
    /// RFC 9000: an endpoint received more data than it permitted in its advertised
    /// data limits.
    FLOW_CONTROL_ERROR = 0x3.with_frame_type(UNKNOWN_FRAME_TYPE),
    /// RFC 9000: an endpoint received a frame for a stream identifier that exceeded
    /// its advertised stream limit for the corresponding stream type.
    STREAM_LIMIT_ERROR = 0x4.with_frame_type(UNKNOWN_FRAME_TYPE),
    /// RFC 9000: an endpoint received a frame for a stream that was not in a state
    /// that permitted that frame.
    STREAM_STATE_ERROR = 0x5.with_frame_type(UNKNOWN_FRAME_TYPE),
    /// RFC 9000: a received final size conflicts with data already received or with
    /// a previously established final size.
    FINAL_SIZE_ERROR = 0x6.with_frame_type(UNKNOWN_FRAME_TYPE),
    /// RFC 9000: an endpoint received a frame that was badly formatted.
    FRAME_ENCODING_ERROR = 0x7.with_frame_type(UNKNOWN_FRAME_TYPE),
    /// RFC 9000: an endpoint received transport parameters that were badly formatted,
    /// invalid, missing a mandatory parameter, or otherwise in error.
    TRANSPORT_PARAMETER_ERROR = 0x8.with_frame_type(UNKNOWN_FRAME_TYPE),
    /// RFC 9000: the number of connection IDs provided by the peer exceeds the
    /// advertised `active_connection_id_limit`.
    CONNECTION_ID_LIMIT_ERROR = 0x9.with_frame_type(UNKNOWN_FRAME_TYPE),
    /// RFC 9000: an endpoint detected a protocol-compliance error not covered by a
    /// more specific error code.
    PROTOCOL_VIOLATION = 0xA.with_frame_type(UNKNOWN_FRAME_TYPE),
    /// RFC 9000: a server received a client Initial that contained an invalid Token field.
    INVALID_TOKEN = 0xB.with_frame_type(UNKNOWN_FRAME_TYPE),
    /// RFC 9000: the application or application protocol caused the connection to be closed.
    APPLICATION_ERROR = 0xC,
    /// RFC 9000: an endpoint received more data in CRYPTO frames than it can buffer.
    CRYPTO_BUFFER_EXCEEDED = 0xD.with_frame_type(UNKNOWN_FRAME_TYPE),
    /// RFC 9000: an endpoint detected errors in performing key updates.
    KEY_UPDATE_ERROR = 0xe.with_frame_type(UNKNOWN_FRAME_TYPE),
    /// RFC 9000: an endpoint reached the confidentiality or integrity limit for the
    /// AEAD algorithm used by the connection.
    AEAD_LIMIT_REACHED = 0xf.with_frame_type(UNKNOWN_FRAME_TYPE),
}
impl Error {
#[inline]
pub const fn crypto_error(code: u8) -> Self {
Self::new(VarInt::from_u16(0x100 | (code as u16)))
.with_frame_type(VarInt::from_u32(UNKNOWN_FRAME_TYPE))
}
#[inline]
pub fn try_into_tls_error(self) -> Option<tls::Error> {
let code = self.code.as_u64();
if (0x100..=0x1ff).contains(&code) {
Some(tls::Error::new(code as u8).with_reason(self.reason))
} else {
None
}
}
}
impl Error {
#[inline]
pub const fn application_error(code: VarInt) -> Self {
Self::new(code)
}
}
impl From<DecoderError> for Error {
    /// Maps every decoding failure to `PROTOCOL_VIOLATION`. An invariant violation
    /// keeps its own reason; every other variant uses "malformed packet".
    fn from(decoder_error: DecoderError) -> Self {
        let reason = match decoder_error {
            DecoderError::InvariantViolation(reason) => reason,
            _ => "malformed packet",
        };
        Self::PROTOCOL_VIOLATION.with_reason(reason)
    }
}
impl From<tls::Error> for Error {
    /// Wraps a TLS error as a crypto transport error, keeping its reason.
    fn from(tls_error: tls::Error) -> Self {
        let alert = tls_error.code;
        let reason = tls_error.reason;
        Self::crypto_error(alert).with_reason(reason)
    }
}
impl From<VarIntError> for Error {
fn from(_: VarIntError) -> Self {
Self::INTERNAL_ERROR.with_reason("varint encoding limit exceeded")
}
}