use std::{hash::Hash, net::SocketAddr, num::ParseIntError, time::Duration};
use serde::{Deserialize, Serialize};
use unix_time::Instant as UnixInstant;
/// Integer type used for inputs passed to plugin functions.
// NOTE(review): presumably the argument type of exported WASM plugin functions — confirm.
pub type PluginInputType = u32;
/// Integer type returned by plugin functions.
// NOTE(review): presumably the return type of exported WASM plugin functions — confirm.
pub type PluginOutputType = i64;
/// Address within a WebAssembly module's memory, represented as a 32-bit value.
pub type WASMPtr = u32;
/// Length of a region of WebAssembly memory.
// NOTE(review): unit presumably bytes — confirm against the host API.
pub type WASMLen = u32;
/// Raw integer result returned by host API calls.
// NOTE(review): encoding of success/error in this value is not visible here — confirm.
pub type APIResult = i64;
/// Error returned when a [`PluginVal`] cannot be converted into a requested type.
///
/// The `TryFrom<PluginVal>` impls in this file return the variant matching the
/// requested target type when the conversion fails (e.g. the value holds a
/// different `PluginVal` variant).
#[derive(Clone, Debug)]
pub enum ConversionError {
    /// Conversion to `bool` failed.
    InvalidBool,
    /// Conversion to `i32` failed.
    InvalidI32,
    /// Conversion to `i64` failed.
    InvalidI64,
    /// Conversion to `u32` failed.
    InvalidU32,
    /// Conversion to `u64` failed.
    InvalidU64,
    /// Conversion to `f32` failed.
    InvalidF32,
    /// Conversion to `f64` failed.
    InvalidF64,
    /// Conversion to `usize` failed.
    InvalidUsize,
    /// Conversion to [`Bytes`] failed.
    InvalidBytes,
    /// Conversion to `Duration` failed.
    InvalidDuration,
    /// Conversion to a UNIX instant failed.
    InvalidInstant,
    // NOTE(review): the following four variants are not produced by any conversion
    // in this file; presumably they are used by the `quic` module — confirm.
    /// Frame conversion failed.
    InvalidFrame,
    /// Frame parameter conversion failed.
    InvalidFrameParam,
    /// Header conversion failed.
    InvalidHeader,
    /// Sent-packet conversion failed.
    InvalidSentPacket,
    /// Conversion to `SocketAddr` failed.
    InvalidSocketAddr,
    /// Conversion to `quic::QVal` failed.
    InvalidQVal,
}
/// A protocol operation that plugin functions can be attached to.
///
/// Variants carrying a `u64` are parameterised families: [`PluginOp::from_name`]
/// parses the parameter as the hexadecimal suffix of the function name.
// Variant order is significant: it determines the derived `PartialOrd` ordering and
// may affect the serialized representation, so append new variants rather than
// reordering existing ones.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Serialize, Deserialize, PartialOrd)]
pub enum PluginOp {
    /// Selected by the function name `init`; always enabled.
    Init,
    // NOTE(review): never produced by `from_name` (a `test` name becomes `Other`).
    Test,
    /// Selected by `plugin_control_<hex>`.
    PluginControl(u64),
    /// Selected by `on_plugin_timeout_<hex>`.
    OnPluginTimeout(u64),
    /// Selected by `decode_transport_parameter_<hex>`; always enabled.
    DecodeTransportParameter(u64),
    /// Selected by `write_transport_parameter_<hex>`; always enabled.
    WriteTransportParameter(u64),
    /// Selected by `log_frame_<hex>`.
    LogFrame(u64),
    /// Selected by `notify_frame_<hex>`.
    NotifyFrame(u64),
    /// Selected by `on_frame_reserved_<hex>`.
    OnFrameReserved(u64),
    /// Selected by `parse_frame_<hex>`.
    ParseFrame(u64),
    /// Selected by `prepare_frame_<hex>`.
    PrepareFrame(u64),
    /// Selected by `process_frame_<hex>`.
    ProcessFrame(u64),
    /// Selected by `should_send_frame_<hex>`.
    ShouldSendFrame(u64),
    /// Selected by `wire_len_<hex>`.
    WireLen(u64),
    /// Selected by `write_frame_<hex>`.
    WriteFrame(u64),
    /// Selected by the function name `update_rtt`.
    #[doc(hidden)]
    UpdateRtt,
    /// Any other name, stored as its bytes zero-padded to 32 bytes.
    Other([u8; 32]),
}
/// Position at which a plugin function is attached to a protocol operation.
///
/// [`PluginOp::from_name`] selects the anchor from an optional name prefix.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Anchor {
    /// Selected by a `pre_` or `before_` name prefix.
    Before,
    /// Selected when the name carries no anchor prefix.
    Define,
    /// Selected by a `post_` or `after_` name prefix.
    After,
}

impl Anchor {
    /// Returns the dense index of this anchor: `Before` = 0, `Define` = 1, `After` = 2.
    pub fn index(&self) -> usize {
        // Fieldless enum without explicit discriminants: variants are numbered
        // 0, 1, 2 in declaration order, so keep the declaration order above stable.
        *self as usize
    }
}
/// Parses the hexadecimal parameter following the last `_` of a protocol
/// operation name, e.g. `"parse_frame_1a"` yields `0x1a`.
///
/// # Errors
///
/// Returns the [`ParseIntError`] produced by `u64::from_str_radix` when the suffix
/// is empty (including when `name` contains no `_`) or is not valid hexadecimal.
fn extract_po_param(name: &str) -> Result<u64, ParseIntError> {
    let suffix = match name.rsplit_once('_') {
        Some((_, tail)) => tail,
        None => "",
    };
    u64::from_str_radix(suffix, 16)
}
impl PluginOp {
    /// Parses an exported plugin function name into the protocol operation it
    /// targets and the [`Anchor`] at which it is attached.
    ///
    /// An optional anchor prefix is stripped first: `pre_`/`before_` select
    /// [`Anchor::Before`], `post_`/`after_` select [`Anchor::After`], and no prefix
    /// selects [`Anchor::Define`]. The remainder is then mapped as follows:
    ///
    /// * `init` and `update_rtt` map to [`PluginOp::Init`] and [`PluginOp::UpdateRtt`];
    /// * names starting with a parameterised prefix (e.g. `process_frame_`) map to the
    ///   corresponding variant, whose `u64` is the hexadecimal suffix after the last
    ///   `_` (so `process_frame_1a` yields `PluginOp::ProcessFrame(0x1a)`);
    /// * anything else becomes [`PluginOp::Other`], holding the name's bytes
    ///   zero-padded (or truncated) to 32 bytes.
    ///
    /// # Panics
    ///
    /// Panics if a name with a parameterised prefix does not end in a valid
    /// hexadecimal `u64`.
    pub fn from_name(name: &str) -> (PluginOp, Anchor) {
        let (name, anchor) = Self::split_anchor_prefix(name);
        (Self::parse_op_name(name), anchor)
    }

    /// Strips an anchor prefix from `name`, returning the bare operation name and
    /// the anchor it selects ([`Anchor::Define`] when there is no prefix).
    fn split_anchor_prefix(name: &str) -> (&str, Anchor) {
        const ANCHOR_PREFIXES: [(&str, Anchor); 4] = [
            ("pre_", Anchor::Before),
            ("before_", Anchor::Before),
            ("post_", Anchor::After),
            ("after_", Anchor::After),
        ];
        ANCHOR_PREFIXES
            .iter()
            .find_map(|&(prefix, anchor)| {
                name.strip_prefix(prefix).map(|rest| (rest, anchor))
            })
            .unwrap_or((name, Anchor::Define))
    }

    /// Maps an operation name whose anchor prefix has already been removed to a
    /// [`PluginOp`].
    ///
    /// # Panics
    ///
    /// Panics if a parameterised name has an invalid hexadecimal suffix.
    fn parse_op_name(name: &str) -> PluginOp {
        // Operation families carrying a hexadecimal `u64` parameter as a name suffix.
        // Tuple-variant constructors coerce to `fn(u64) -> PluginOp`.
        const PARAMETERISED_OPS: &[(&str, fn(u64) -> PluginOp)] = &[
            ("decode_transport_parameter_", PluginOp::DecodeTransportParameter),
            ("write_transport_parameter_", PluginOp::WriteTransportParameter),
            ("log_frame_", PluginOp::LogFrame),
            ("notify_frame_", PluginOp::NotifyFrame),
            ("on_frame_reserved_", PluginOp::OnFrameReserved),
            ("parse_frame_", PluginOp::ParseFrame),
            ("prepare_frame_", PluginOp::PrepareFrame),
            ("process_frame_", PluginOp::ProcessFrame),
            ("should_send_frame_", PluginOp::ShouldSendFrame),
            ("wire_len_", PluginOp::WireLen),
            ("write_frame_", PluginOp::WriteFrame),
            ("plugin_control_", PluginOp::PluginControl),
            ("on_plugin_timeout_", PluginOp::OnPluginTimeout),
        ];

        match name {
            "init" => return PluginOp::Init,
            "update_rtt" => return PluginOp::UpdateRtt,
            _ => {}
        }

        for &(prefix, make_op) in PARAMETERISED_OPS {
            if name.starts_with(prefix) {
                return match extract_po_param(name) {
                    Ok(param) => make_op(param),
                    Err(e) => panic!("Invalid protocol operation name {name:?}: {e}"),
                };
            }
        }

        // Unrecognised operation: keep its name as a zero-padded, fixed-size byte
        // array. Names longer than the array are truncated instead of panicking on an
        // out-of-bounds slice; two long names sharing their first 32 bytes therefore
        // map to the same `Other` value.
        let mut name_array = [0u8; 32];
        let len = name.len().min(name_array.len());
        name_array[..len].copy_from_slice(&name.as_bytes()[..len]);
        PluginOp::Other(name_array)
    }

    /// Returns `true` for [`PluginOp::Init`], [`PluginOp::DecodeTransportParameter`]
    /// and [`PluginOp::WriteTransportParameter`], and `false` for every other operation.
    pub fn always_enabled(&self) -> bool {
        matches!(
            self,
            PluginOp::Init
                | PluginOp::DecodeTransportParameter(_)
                | PluginOp::WriteTransportParameter(_)
        )
    }
}
/// Descriptor carried by [`PluginVal::Bytes`].
// NOTE(review): field semantics are not visible in this file; the notes below are
// inferred from field names only — confirm with the host implementation.
#[derive(Clone, Copy, Debug, Deserialize, Serialize, PartialEq, PartialOrd, Eq, Ord)]
pub struct Bytes {
    /// Presumably an identifier for the underlying buffer — confirm.
    pub tag: u64,
    /// Presumably the maximum number of bytes that may be read — confirm.
    pub max_read_len: u64,
    /// Presumably the maximum number of bytes that may be written — confirm.
    pub max_write_len: u64,
}
/// Dynamically-typed value exchanged with plugins.
///
/// Each variant has matching `From<payload>` and `TryFrom<PluginVal>` impls in this
/// file; a failed `TryFrom` yields the corresponding [`ConversionError`] variant.
// NOTE(review): the `large_enum_variant` allow is presumably there because boxing the
// large variant would make the enum non-`Copy`, which the derive below requires.
#[allow(clippy::large_enum_variant)]
#[derive(Clone, Copy, Debug, Deserialize, Serialize, PartialEq, PartialOrd)]
pub enum PluginVal {
    /// A boolean.
    Bool(bool),
    /// A signed 32-bit integer.
    I32(i32),
    /// A signed 64-bit integer.
    I64(i64),
    /// An unsigned 32-bit integer.
    U32(u32),
    /// An unsigned 64-bit integer.
    U64(u64),
    /// A 32-bit float.
    F32(f32),
    /// A 64-bit float.
    F64(f64),
    /// A `usize`, stored widened to `u64` by `From<usize>`.
    // NOTE(review): presumably `u64` so the representation does not depend on the
    // target's pointer width — confirm.
    Usize(u64),
    /// A byte-buffer descriptor.
    Bytes(Bytes),
    /// A time span.
    Duration(Duration),
    /// An instant from the `unix_time` crate.
    UNIXInstant(UnixInstant),
    /// A socket address.
    SocketAddr(SocketAddr),
    /// A QUIC-specific value defined in the `quic` module.
    QUIC(quic::QVal),
}
/// Generates a pair of conversions for one enum variant:
///
/// * `From<$inner> for $enum_ty`, wrapping the value in `$enum_ty::$variant`;
/// * `TryFrom<$enum_ty> for $inner`, unwrapping `$enum_ty::$variant` and failing with
///   `$err_ty::$err_variant` for any other variant.
macro_rules! impl_from_try_from {
    ($enum_ty:ident, $variant:ident, $inner:ty, $err_ty:ident, $err_variant:ident) => {
        impl From<$inner> for $enum_ty {
            fn from(payload: $inner) -> Self {
                $enum_ty::$variant(payload)
            }
        }

        impl TryFrom<$enum_ty> for $inner {
            type Error = $err_ty;

            fn try_from(wrapped: $enum_ty) -> Result<Self, Self::Error> {
                if let $enum_ty::$variant(payload) = wrapped {
                    Ok(payload)
                } else {
                    Err($err_ty::$err_variant)
                }
            }
        }
    };
}
impl From<usize> for PluginVal {
fn from(value: usize) -> Self {
PluginVal::Usize(value as u64)
}
}
impl TryFrom<PluginVal> for usize {
    type Error = ConversionError;

    /// Extracts the integer stored in [`PluginVal::Usize`].
    ///
    /// # Errors
    ///
    /// Returns [`ConversionError::InvalidUsize`] if `value` is not `PluginVal::Usize`,
    /// or if the stored `u64` does not fit in this target's `usize` (e.g. values above
    /// `u32::MAX` on a 32-bit target), rather than silently truncating it.
    fn try_from(value: PluginVal) -> Result<Self, Self::Error> {
        match value {
            PluginVal::Usize(v) => {
                usize::try_from(v).map_err(|_| ConversionError::InvalidUsize)
            }
            _ => Err(ConversionError::InvalidUsize),
        }
    }
}
impl TryFrom<PluginVal> for () {
    type Error = ConversionError;

    /// Discards the value: every [`PluginVal`] converts successfully into `()`.
    fn try_from(_ignored: PluginVal) -> Result<(), ConversionError> {
        Ok(())
    }
}
// Conversions for variants whose payload type matches the variant exactly.
// `usize` (stored as `u64`) and `()` (accepts anything) are implemented by hand.

// Boolean and numeric payloads.
impl_from_try_from!(PluginVal, Bool, bool, ConversionError, InvalidBool);
impl_from_try_from!(PluginVal, I32, i32, ConversionError, InvalidI32);
impl_from_try_from!(PluginVal, I64, i64, ConversionError, InvalidI64);
impl_from_try_from!(PluginVal, U32, u32, ConversionError, InvalidU32);
impl_from_try_from!(PluginVal, U64, u64, ConversionError, InvalidU64);
impl_from_try_from!(PluginVal, F32, f32, ConversionError, InvalidF32);
impl_from_try_from!(PluginVal, F64, f64, ConversionError, InvalidF64);

// Time-related payloads.
impl_from_try_from!(
    PluginVal,
    Duration,
    Duration,
    ConversionError,
    InvalidDuration
);
impl_from_try_from!(
    PluginVal,
    UNIXInstant,
    UnixInstant,
    ConversionError,
    InvalidInstant
);

// Buffer, network and QUIC payloads.
impl_from_try_from!(PluginVal, Bytes, Bytes, ConversionError, InvalidBytes);
impl_from_try_from!(
    PluginVal,
    SocketAddr,
    SocketAddr,
    ConversionError,
    InvalidSocketAddr
);
impl_from_try_from!(PluginVal, QUIC, quic::QVal, ConversionError, InvalidQVal);
pub mod quic;