use super::extlist::{Ext, ExtList, ExtListRef, decl_extension_group};
#[cfg(feature = "hs")]
use super::hs::pow::ProofOfWork;
use caret::caret_int;
use itertools::Itertools as _;
use tor_bytes::{EncodeResult, Reader, Writeable as _, Writer};
use tor_protover::NumberedSubver;
caret_int! {
/// Numeric type code identifying one entry in a circuit *request* extension list.
///
/// Declared through `caret_int!` as a wrapper around a `u8`, with the named
/// values listed below.
#[derive(PartialOrd,Ord)]
pub struct CircRequestExtType(u8) {
/// Type code reported by [`CcRequest`].
CC_REQUEST = 1,
/// Type code for a proof-of-work extension.
// NOTE(review): presumably the code used by `hs::pow::ProofOfWork` (only
// compiled with the "hs" feature) -- confirm against that module.
PROOF_OF_WORK = 2,
/// Type code reported by [`SubprotocolRequest`].
SUBPROTOCOL_REQUEST = 3,
}
}
caret_int! {
/// Numeric type code identifying one entry in a circuit *response* extension list.
///
/// Declared through `caret_int!` as a wrapper around a `u8`.
#[derive(PartialOrd,Ord)]
pub struct CircResponseExtType(u8) {
/// Type code reported by [`CcResponse`].
CC_RESPONSE = 2
}
}
/// Circuit request extension with type code `CC_REQUEST`.
///
/// The extension has no fields: its `Ext` implementation reads nothing when
/// decoding and writes nothing when encoding, so its presence in an extension
/// list is the whole message.
// NOTE(review): by its name this asks the other side to enable congestion
// control -- confirm against the protocol specification.
#[derive(Clone, Debug, PartialEq, Eq, Default)]
#[non_exhaustive]
pub struct CcRequest {}
impl Ext for CcRequest {
type Id = CircRequestExtType;
fn type_id(&self) -> Self::Id {
CircRequestExtType::CC_REQUEST
}
fn take_body_from(_b: &mut Reader<'_>) -> tor_bytes::Result<Self> {
Ok(Self {})
}
fn write_body_onto<B: Writer + ?Sized>(&self, _b: &mut B) -> EncodeResult<()> {
Ok(())
}
}
/// Circuit response extension with type code `CC_RESPONSE`.
///
/// Its body is a single byte, stored here as `sendme_inc`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CcResponse {
    /// The one-byte value carried in the extension body.
    // NOTE(review): presumably the SENDME increment chosen by the responder --
    // confirm against the congestion-control specification.
    sendme_inc: u8,
}

impl CcResponse {
    /// Construct a response carrying the given `sendme_inc` value.
    pub fn new(sendme_inc: u8) -> Self {
        Self { sendme_inc }
    }

    /// Return the `sendme_inc` value carried by this response.
    pub fn sendme_inc(&self) -> u8 {
        let CcResponse { sendme_inc } = self;
        *sendme_inc
    }
}
impl Ext for CcResponse {
type Id = CircResponseExtType;
fn type_id(&self) -> Self::Id {
CircResponseExtType::CC_RESPONSE
}
fn take_body_from(b: &mut Reader<'_>) -> tor_bytes::Result<Self> {
let sendme_inc = b.take_u8()?;
Ok(Self { sendme_inc })
}
fn write_body_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
b.write_u8(self.sendme_inc);
Ok(())
}
}
/// Circuit request extension with type code `SUBPROTOCOL_REQUEST`: a list of
/// numbered subprotocol versions.
///
/// Invariant: `protocols` is kept sorted in strictly ascending order with no
/// duplicates. The `FromIterator` impl sorts and deduplicates its input, the
/// decoder rejects any list that is not strictly ascending, and `contains`
/// relies on the ordering for its binary search.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SubprotocolRequest {
/// The requested subprotocol versions, sorted and deduplicated.
protocols: Vec<tor_protover::NumberedSubver>,
}
impl<A> FromIterator<A> for SubprotocolRequest
where
    A: Into<tor_protover::NumberedSubver>,
{
    /// Build a request from any sequence of items convertible to
    /// `NumberedSubver`.
    ///
    /// The input may be in any order and may contain repeats; the stored list
    /// is sorted and deduplicated.
    fn from_iter<T: IntoIterator<Item = A>>(iter: T) -> Self {
        let mut requested: Vec<tor_protover::NumberedSubver> =
            iter.into_iter().map(A::into).collect();
        // Sort first so that equal entries become adjacent and `dedup` removes them all.
        requested.sort();
        requested.dedup();
        Self {
            protocols: requested,
        }
    }
}
impl Ext for SubprotocolRequest {
type Id = CircRequestExtType;
fn type_id(&self) -> Self::Id {
CircRequestExtType::SUBPROTOCOL_REQUEST
}
fn take_body_from(b: &mut Reader<'_>) -> tor_bytes::Result<Self> {
let mut protocols = Vec::new();
while b.remaining() != 0 {
protocols.push(b.extract()?);
}
if !is_strictly_ascending(&protocols) {
return Err(tor_bytes::Error::InvalidMessage(
"SubprotocolRequest not sorted and deduplicated.".into(),
));
}
Ok(Self { protocols })
}
fn write_body_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
for p in self.protocols.iter() {
b.write(p)?;
}
Ok(())
}
}
impl SubprotocolRequest {
    /// Return true if `cap` is one of the subprotocol versions in this request.
    pub fn contains(&self, cap: tor_protover::NamedSubver) -> bool {
        let wanted: tor_protover::NumberedSubver = cap.into();
        // `protocols` is sorted, so a binary search is valid here.
        self.protocols.binary_search(&wanted).is_ok()
    }

    /// Return true if every subprotocol version in this request is supported
    /// according to `list`.
    ///
    /// An empty request trivially satisfies this.
    pub fn contains_only(&self, list: &tor_protover::Protocols) -> bool {
        for requested in &self.protocols {
            if !list.supports_numbered_subver(*requested) {
                return false;
            }
        }
        true
    }
}
decl_extension_group! {
/// Any single extension that may appear in a circuit request extension list
/// (decoded under the name "CREATE2 extension list").
///
/// Entries are tagged with a [`CircRequestExtType`] code.
#[derive(Debug,Clone,PartialEq)]
#[non_exhaustive]
pub enum CircRequestExt [ CircRequestExtType ] {
// The empty `CcRequest` extension.
CcRequest,
// Proof-of-work extension; only present when the "hs" feature is enabled.
[ feature: #[cfg(feature = "hs")] ]
ProofOfWork,
// A list of requested subprotocol versions.
SubprotocolRequest,
}
}
decl_extension_group! {
/// Any single extension that may appear in a circuit response extension list
/// (decoded under the name "CREATED2 extension list").
///
/// Entries are tagged with a [`CircResponseExtType`] code.
#[derive(Debug,Clone,PartialEq)]
#[non_exhaustive]
pub enum CircResponseExt [ CircResponseExtType ] {
// The one-byte `CcResponse` extension.
CcResponse,
}
}
// Generate `write_many_onto` and `decode` helpers for an extension-group enum.
//
// `$extgroup` is the enum type to extend; `$name` is the description stored in
// the `parsed` field of `crate::Error::BytesErr` when decoding fails.
macro_rules! impl_encode_decode {
($extgroup:ty, $name:expr) => {
impl $extgroup {
/// Encode the extensions in `exts` onto `out` as a single extension list.
///
/// Returns any error produced while writing the list.
pub fn write_many_onto<W: Writer>(exts: &[Self], out: &mut W) -> EncodeResult<()> {
ExtListRef::from(exts).write_onto(out)?;
Ok(())
}
/// Decode `message` as one complete extension list and return its entries.
///
/// # Errors
///
/// Returns `crate::Error::BytesErr` if the list cannot be parsed, or if
/// any bytes remain in `message` after the list.
pub fn decode(message: &[u8]) -> crate::Result<Vec<Self>> {
// Wrap low-level byte errors with the name of what was being parsed.
let err_cvt = |err| crate::Error::BytesErr { err, parsed: $name };
let mut r = tor_bytes::Reader::from_slice(message);
let list: ExtList<_> = r.extract().map_err(err_cvt)?;
// Trailing bytes after the list are treated as an error.
r.should_be_exhausted().map_err(err_cvt)?;
Ok(list.into_vec())
}
}
};
}
// Request extensions are reported as belonging to a CREATE2 extension list,
// response extensions to a CREATED2 extension list.
impl_encode_decode!(CircRequestExt, "CREATE2 extension list");
impl_encode_decode!(CircResponseExt, "CREATED2 extension list");
/// Return true if every element of `vers` is strictly less than the one after it.
///
/// This holds exactly when the slice is sorted with no duplicates; empty and
/// single-element slices have no adjacent pairs and therefore pass.
fn is_strictly_ascending(vers: &[NumberedSubver]) -> bool {
    for (prev, next) in vers.iter().tuple_windows::<(_, _)>() {
        let in_order = prev < next;
        if !in_order {
            return false;
        }
    }
    true
}
// Unit tests for `SubprotocolRequest` encoding, decoding and membership queries.
#[cfg(test)]
mod test {
// Lint relaxations for test code (e.g. `unwrap` and `dbg!` are used below).
#![allow(clippy::bool_assert_comparison)]
#![allow(clippy::clone_on_copy)]
#![allow(clippy::dbg_macro)]
#![allow(clippy::mixed_attributes_style)]
#![allow(clippy::print_stderr)]
#![allow(clippy::print_stdout)]
#![allow(clippy::single_char_pattern)]
#![allow(clippy::unwrap_used)]
#![allow(clippy::unchecked_time_subtraction)]
#![allow(clippy::useless_vec)]
#![allow(clippy::needless_pass_by_value)]
use super::*;
// Collecting sorts and deduplicates, and the encoded body round-trips through the decoder.
#[test]
fn subproto_ext_valid() {
use tor_protover::named::*;
// Input is deliberately unsorted and contains a duplicate.
let sp: SubprotocolRequest = [RELAY_NTORV3, RELAY_NTORV3, LINK_V4].into_iter().collect();
let mut v = Vec::new();
sp.write_body_onto(&mut v).unwrap();
// Two entries remain after deduplication.
// NOTE(review): each entry appears to encode as two bytes (protocol id, version) -- confirm.
assert_eq!(&v[..], [0, 4, 2, 4]);
let mut r = Reader::from_slice(&v[..]);
let sp2: SubprotocolRequest = SubprotocolRequest::take_body_from(&mut r).unwrap();
assert_eq!(sp, sp2);
}
// The decoder rejects truncated, duplicated, and out-of-order bodies.
#[test]
fn subproto_invalid() {
// Truncated: the final entry is incomplete.
let mut r = Reader::from_slice(&[0, 4, 2]);
let e = SubprotocolRequest::take_body_from(&mut r).unwrap_err();
dbg!(e.to_string());
assert!(e.to_string().contains("too short"));
// Duplicate: the same entry appears twice.
let mut r = Reader::from_slice(&[0, 4, 0, 4]);
let e = SubprotocolRequest::take_body_from(&mut r).unwrap_err();
dbg!(e.to_string());
assert!(e.to_string().contains("deduplicated"));
// Out of order: entries are descending. (Same error message as the duplicate case.)
let mut r = Reader::from_slice(&[2, 4, 0, 4]);
let e = SubprotocolRequest::take_body_from(&mut r).unwrap_err();
dbg!(e.to_string());
assert!(e.to_string().contains("sorted"));
}
// `contains` checks membership; `contains_only` checks the request is a subset of a supported set.
#[test]
fn subproto_supported() {
use tor_protover::named::*;
let sp: SubprotocolRequest = [RELAY_NTORV3, RELAY_NTORV3, LINK_V4].into_iter().collect();
assert!(sp.contains(LINK_V4));
assert!(!sp.contains(LINK_V2));
// Supersets and the exact set are accepted ...
assert!(sp.contains_only(&[RELAY_NTORV3, LINK_V4, CONFLUX_BASE].into_iter().collect()));
assert!(sp.contains_only(&[RELAY_NTORV3, LINK_V4].into_iter().collect()));
// ... while any set missing a requested entry is rejected.
assert!(!sp.contains_only(&[LINK_V4].into_iter().collect()));
assert!(!sp.contains_only(&[LINK_V4, CONFLUX_BASE].into_iter().collect()));
assert!(!sp.contains_only(&[CONFLUX_BASE].into_iter().collect()));
}
}