use derive_deftly::Deftly;
use tor_bytes::{EncodeError, EncodeResult, Readable, Reader, Result, Writeable, Writer};
use tor_memquota::{HasMemoryCostStructural, derive_deftly_template_HasMemoryCost};
/// An owned list of extensions of type `T`.
///
/// When `T: ExtGroup`, this is encoded as a one-byte count followed by the
/// extensions themselves; encoding fails if there are more than 255.
///
/// Dereferences to the inner `Vec<T>`, so the usual `Vec` methods
/// (`push`, `retain`, `iter`, ...) can be called on it directly.
#[derive(Clone, Debug, derive_more::Deref, derive_more::DerefMut, Deftly)]
#[derive_deftly(HasMemoryCost)]
#[deftly(has_memory_cost(bounds = "T: HasMemoryCostStructural"))]
pub(super) struct ExtList<T> {
/// The extensions, in decode or insertion order.
///
/// Note that encoding emits them sorted by type id, regardless of this order.
pub(super) extensions: Vec<T>,
}
impl<T> Default for ExtList<T> {
fn default() -> Self {
Self {
extensions: Vec::new(),
}
}
}
/// A borrowed view of a slice of extensions.
///
/// It is encoded exactly like an owned `ExtList`, but without requiring
/// ownership of the extensions. It can be built from a `&[T]` via `From`.
#[derive(Clone, Debug, derive_more::Deref, derive_more::DerefMut, derive_more::From)]
pub(super) struct ExtListRef<'a, T> {
/// The borrowed extensions, in caller-supplied order.
extensions: &'a [T],
}
/// A type that can hold any one extension from a group of extensions sharing
/// a single type-id space.
///
/// Enums generated by `decl_extension_group!` implement this. Values are
/// individually readable and writable, including their type-id/length header.
pub(super) trait ExtGroup: Readable + Writeable {
/// The type of this group's extension type ids.
///
/// The `Ord` bound lets lists of extensions be sorted by id before encoding.
type Id: From<u8> + Into<u8> + Eq + PartialEq + Ord + Copy;
/// Return the type id of this extension.
fn type_id(&self) -> Self::Id;
}
/// A single kind of extension.
///
/// Implementors handle only the extension *body*. The one-byte type id and
/// one-byte length that precede it on the wire are read and written by the
/// code generated from `decl_extension_group!`.
pub(super) trait Ext: Sized {
/// The type of this extension's type id.
type Id: From<u8> + Into<u8>;
/// Return the type id of this extension.
fn type_id(&self) -> Self::Id;
/// Decode this extension's body from `b`.
///
/// `b` is expected to contain exactly the body; leftover bytes are
/// presumably rejected by the length-delimited reader the caller uses.
fn take_body_from(b: &mut Reader<'_>) -> Result<Self>;
/// Encode this extension's body (without type id or length) onto `b`.
fn write_body_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()>;
}
impl<T: ExtGroup> Readable for ExtList<T> {
    /// Decode a one-byte extension count, then that many extensions.
    ///
    /// Stops at, and returns, the first decoding error.
    fn take_from(b: &mut Reader<'_>) -> Result<Self> {
        let count = usize::from(b.take_u8()?);
        let mut extensions = Vec::with_capacity(count);
        for _ in 0..count {
            extensions.push(b.extract::<T>()?);
        }
        Ok(Self { extensions })
    }
}
impl<'a, T: ExtGroup> Writeable for ExtListRef<'a, T> {
fn write_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
let n_extensions = self
.extensions
.len()
.try_into()
.map_err(|_| EncodeError::BadLengthValue)?;
b.write_u8(n_extensions);
let mut exts_sorted: Vec<&T> = self.extensions.iter().collect();
exts_sorted.sort_by_key(|ext| ext.type_id());
exts_sorted.iter().try_for_each(|ext| ext.write_onto(b))?;
Ok(())
}
}
impl<T: ExtGroup> Writeable for ExtList<T> {
fn write_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
ExtListRef::from(&self.extensions[..]).write_onto(b)
}
}
impl<T: ExtGroup> ExtList<T> {
    /// Add `ext` to this list, first removing every existing extension with
    /// the same type id.
    ///
    /// Afterwards, `ext` is the only extension of its type in the list, and
    /// it is the last element.
    #[cfg(feature = "hs")]
    pub(super) fn replace_by_type(&mut self, ext: T) {
        // Compute the id once rather than once per element inside `retain`.
        let id = ext.type_id();
        self.extensions.retain(|e| e.type_id() != id);
        self.extensions.push(ext);
    }

    /// Consume this list, returning its extensions in their stored order.
    pub(super) fn into_vec(self) -> Vec<T> {
        self.extensions
    }
}
/// An extension whose type id was not recognized, kept as raw body bytes.
///
/// Keeping the body means an unknown extension can be re-encoded unchanged.
/// `ID` is the type-id type of the extension group it was found in.
#[derive(Clone, Debug, Deftly, Eq, PartialEq)]
#[derive_deftly(HasMemoryCost)]
#[deftly(has_memory_cost(bounds = "ID: Copy + 'static"))]
pub struct UnrecognizedExt<ID> {
/// The (unrecognized) type id of this extension.
#[deftly(has_memory_cost(copy))]
pub(super) type_id: ID,
/// The extension body, excluding the type id and length header.
pub(super) body: Vec<u8>,
}
impl<ID> UnrecognizedExt<ID> {
pub fn new(type_id: ID, body: impl Into<Vec<u8>>) -> Self {
Self {
type_id,
body: body.into(),
}
}
}
/// Declare an enum wrapping a group of extension types that share one
/// type-id space. The macro also generates `Readable`, `Writeable`,
/// `ExtGroup` and `From` implementations for it.
///
/// # Syntax
///
/// ```ignore
/// decl_extension_group! {
///     /// Attributes here apply to the generated enum.
///     pub enum MyExtGroup [ MyExtType ] {
///         /// Attributes here apply only to the `FooExt` variant.
///         FooExt,
///         [feature: #[cfg(feature = "bar")]]
///         BarExt,
///     }
/// }
/// ```
///
/// # Variants
///
/// Every listed `Case` becomes a variant `Case(Case)`, and each `Case` type is
/// expected to implement `Ext`. The id type (`MyExtType` above) must provide an
/// item usable as a pattern, named after the case in SCREAMING_SNAKE_CASE
/// (`FooExt` -> `MyExtType::FOO_EXT`; presumably an associated constant).
/// That value is the type id decoded as that case.
///
/// The enum also gets an `Unrecognized(UnrecognizedExt<Id>)` variant. It holds
/// any extension whose type id matches no case, keeping the raw body so it can
/// be re-encoded unchanged.
///
/// # Attributes
///
/// The optional `[feature: #[...]]` attribute is applied to the variant *and*
/// to every generated match arm and `From` impl for it, so a case can be
/// compiled out entirely. Plain attributes before it apply only to the variant.
///
/// # Encoding
///
/// Each extension is a one-byte type id, followed by its body prefixed with a
/// one-byte length.
///
/// # Hygiene
///
/// The generated code names `tor_bytes` items by full path. Apart from the
/// case types and the id type, the invoking module needs no imports.
macro_rules! decl_extension_group {
    {
        $( #[$meta:meta] )*
        $v:vis enum $id:ident [ $type_id:ty ] {
            $(
                $(#[$cmeta:meta])*
                $([feature: #[$fmeta:meta]])?
                $case:ident
            ),*
            $(,)?
        }
    } => {paste::paste!{
        $( #[$meta] )*
        $v enum $id {
            $( $(#[$cmeta])*
               $( #[$fmeta] )?
               $case($case),
            )*
            /// An extension whose type id did not match any known case.
            Unrecognized(crate::relaycell::extlist::UnrecognizedExt<$type_id>)
        }

        impl tor_bytes::Readable for $id {
            fn take_from(b: &mut tor_bytes::Reader<'_>) -> tor_bytes::Result<Self> {
                // `Ext` is needed for `take_body_from`; it is unused if every
                // case has been compiled out.
                #![allow(unused_imports)]
                use crate::relaycell::extlist::Ext as _;
                let type_id = b.take_u8()?.into();
                Ok(match type_id {
                    $(
                        $( #[$fmeta] )?
                        $type_id::[< $case:snake:upper >] => {
                            Self::$case( b.read_nested_u8len(|r| $case::take_body_from(r))? )
                        }
                    )*
                    _ => {
                        Self::Unrecognized(crate::relaycell::extlist::UnrecognizedExt {
                            type_id,
                            body: b.read_nested_u8len(|r| Ok(r.take_rest().into()))?,
                        })
                    }
                })
            }
        }

        impl tor_bytes::Writeable for $id {
            fn write_onto<B: tor_bytes::Writer + ?Sized>(
                &self,
                b: &mut B,
            ) -> tor_bytes::EncodeResult<()> {
                #![allow(unused_imports)]
                use crate::relaycell::extlist::Ext as _;
                use tor_bytes::Writeable as _;
                // Needed for `write_all` on the nested length-prefixed buffer.
                use tor_bytes::Writer as _;
                use std::ops::DerefMut;
                match self {
                    $(
                        $( #[$fmeta] )?
                        Self::$case(val) => {
                            b.write_u8(val.type_id().into());
                            let mut nested = b.write_nested_u8len();
                            val.write_body_onto(nested.deref_mut())?;
                            // Fails if the body is too long for a one-byte length.
                            nested.finish()?;
                        }
                    )*
                    Self::Unrecognized(unrecognized) => {
                        b.write_u8(unrecognized.type_id.into());
                        let mut nested = b.write_nested_u8len();
                        nested.write_all(&unrecognized.body[..]);
                        nested.finish()?;
                    }
                }
                Ok(())
            }
        }

        impl crate::relaycell::extlist::ExtGroup for $id {
            type Id = $type_id;
            fn type_id(&self) -> Self::Id {
                #![allow(unused_imports)]
                use crate::relaycell::extlist::Ext as _;
                match self {
                    $(
                        $( #[$fmeta] )?
                        Self::$case(val) => val.type_id(),
                    )*
                    Self::Unrecognized(unrecognized) => unrecognized.type_id,
                }
            }
        }

        $(
            $( #[$fmeta] )?
            impl From<$case> for $id {
                fn from(val: $case) -> $id {
                    $id :: $case ( val )
                }
            }
        )*

        impl From<crate::relaycell::extlist::UnrecognizedExt<$type_id>> for $id {
            fn from(val: crate::relaycell::extlist::UnrecognizedExt<$type_id>) -> $id {
                $id :: Unrecognized(val)
            }
        }
    }}
}
pub(super) use decl_extension_group;