use bytes::BytesMut;
use tor_cell::chancell::{AnyChanCell, ChanCell, ChanMsg, codec, msg::AnyChanMsg};
use crate::{Error, channel::ChannelType};
/// Restricted message sets, and the matching cell encode/decode dispatch, for link protocol
/// version 4.
pub(super) mod linkv4 {
    use bytes::BytesMut;
    use tor_cell::{
        chancell::{AnyChanCell, codec},
        restricted_msg,
    };

    use super::MessageStage;
    use crate::{
        Error,
        channel::{
            ChannelType,
            msg::{decode_as_any, encode_as_any},
        },
    };

    restricted_msg! {
        /// Handshake messages that a relay initiator may send.
        ///
        /// Encoded by a `RelayInitiator`, and decoded by a `RelayResponder`, while the channel
        /// is in the handshake stage.
        #[derive(Clone, Debug)]
        pub(super) enum HandshakeRelayInitiatorMsg: ChanMsg {
            Authenticate,
            Certs,
            Netinfo,
            Vpadding,
        }
    }

    restricted_msg! {
        /// Handshake messages that a relay responder may send.
        ///
        /// Encoded by a `RelayResponder`, and decoded by either kind of initiator, while the
        /// channel is in the handshake stage.
        #[derive(Clone, Debug)]
        pub(super) enum HandshakeRelayResponderMsg: ChanMsg {
            AuthChallenge,
            Certs,
            Netinfo,
            Vpadding,
        }
    }

    restricted_msg! {
        /// Handshake messages that a client initiator may send.
        ///
        /// This is a subset of `HandshakeRelayInitiatorMsg`: it omits `Certs` and
        /// `Authenticate`.
        #[derive(Clone, Debug)]
        pub(super) enum HandshakeClientInitiatorMsg: ChanMsg {
            Netinfo,
            Vpadding,
        }
    }

    restricted_msg! {
        /// Messages a client may send to a relay on an open channel.
        ///
        /// Encoded by a `ClientInitiator`, and decoded by an unauthenticated `RelayResponder`.
        #[derive(Clone, Debug)]
        pub(super) enum OpenChanMsgC2R: ChanMsg {
            Create2,
            CreateFast,
            Destroy,
            Padding,
            Vpadding,
            Relay,
            RelayEarly,
        }
    }

    restricted_msg! {
        /// Messages a relay may send to a client on an open channel.
        ///
        /// Decoded by a `ClientInitiator`, and encoded by an unauthenticated `RelayResponder`.
        #[derive(Clone, Debug)]
        pub(super) enum OpenChanMsgR2C : ChanMsg {
            CreatedFast,
            Created2,
            Relay,
            Destroy,
            Padding,
            Vpadding,
        }
    }

    restricted_msg! {
        /// Messages that may be exchanged on an open relay-to-relay channel.
        ///
        /// Used in both directions by a `RelayInitiator` and by an authenticated
        /// `RelayResponder`.
        #[derive(Clone, Debug)]
        pub(super) enum OpenChanMsgR2R : ChanMsg {
            CreateFast,
            CreatedFast,
            Create2,
            Created2,
            Destroy,
            Padding,
            Vpadding,
            Relay,
            RelayEarly,
        }
    }

    /// Decode one cell from `src`, accepting only the link-v4 messages that the peer of a
    /// `chan_type` channel may send at `stage`.
    ///
    /// The result is widened back to an [`AnyChanCell`]. Errors are reported through
    /// `stage`, as done by `decode_as_any`.
    pub(super) fn decode_cell(
        chan_type: ChannelType,
        stage: &MessageStage,
        codec: &mut codec::ChannelCodec,
        src: &mut BytesMut,
    ) -> Result<Option<AnyChanCell>, Error> {
        use ChannelType::*;
        use MessageStage::*;
        match (chan_type, stage) {
            (ClientInitiator, Handshake) => {
                decode_as_any::<HandshakeRelayResponderMsg>(stage, codec, src)
            }
            (ClientInitiator, Open) => decode_as_any::<OpenChanMsgR2C>(stage, codec, src),
            (RelayInitiator, Handshake) => {
                decode_as_any::<HandshakeRelayResponderMsg>(stage, codec, src)
            }
            (RelayInitiator, Open) => decode_as_any::<OpenChanMsgR2R>(stage, codec, src),
            // `authenticated` is ignored while handshaking. This set is a superset of
            // `HandshakeClientInitiatorMsg`, so it accepts handshakes from both client and
            // relay initiators.
            (RelayResponder { authenticated: _ }, Handshake) => {
                decode_as_any::<HandshakeRelayInitiatorMsg>(stage, codec, src)
            }
            // An unauthenticated peer is treated as a client.
            (
                RelayResponder {
                    authenticated: false,
                },
                Open,
            ) => decode_as_any::<OpenChanMsgC2R>(stage, codec, src),
            (
                RelayResponder {
                    authenticated: true,
                },
                Open,
            ) => decode_as_any::<OpenChanMsgR2R>(stage, codec, src),
        }
    }

    /// Encode `cell` into `dst`, refusing any message that a `chan_type` channel may not send
    /// at `stage` under link protocol 4.
    ///
    /// A disallowed message, or a codec failure, is reported through `stage`, as done by
    /// `encode_as_any`.
    pub(super) fn encode_cell(
        chan_type: ChannelType,
        stage: &MessageStage,
        cell: AnyChanCell,
        codec: &mut codec::ChannelCodec,
        dst: &mut BytesMut,
    ) -> Result<(), Error> {
        use ChannelType::*;
        use MessageStage::*;
        match (chan_type, stage) {
            (ClientInitiator, Handshake) => {
                encode_as_any::<HandshakeClientInitiatorMsg>(stage, cell, codec, dst)
            }
            (ClientInitiator, Open) => encode_as_any::<OpenChanMsgC2R>(stage, cell, codec, dst),
            (RelayInitiator, Handshake) => {
                encode_as_any::<HandshakeRelayInitiatorMsg>(stage, cell, codec, dst)
            }
            (RelayInitiator, Open) => encode_as_any::<OpenChanMsgR2R>(stage, cell, codec, dst),
            (RelayResponder { authenticated: _ }, Handshake) => {
                encode_as_any::<HandshakeRelayResponderMsg>(stage, cell, codec, dst)
            }
            // An unauthenticated peer is treated as a client.
            (
                RelayResponder {
                    authenticated: false,
                },
                Open,
            ) => encode_as_any::<OpenChanMsgR2C>(stage, cell, codec, dst),
            (
                RelayResponder {
                    authenticated: true,
                },
                Open,
            ) => encode_as_any::<OpenChanMsgR2R>(stage, cell, codec, dst),
        }
    }
}
pub(super) mod linkv5 {
use bytes::BytesMut;
use tor_cell::{
chancell::{AnyChanCell, codec},
restricted_msg,
};
use super::MessageStage;
use crate::{
Error,
channel::{
ChannelType,
msg::{decode_as_any, encode_as_any},
},
};
restricted_msg! {
#[derive(Clone,Debug)]
pub(super) enum HandshakeRelayInitiatorMsg: ChanMsg {
Authenticate,
Certs,
Netinfo,
Vpadding,
}
}
restricted_msg! {
#[derive(Clone,Debug)]
pub(super) enum HandshakeRelayResponderMsg: ChanMsg {
AuthChallenge,
Certs,
Netinfo,
Vpadding,
}
}
restricted_msg! {
#[derive(Clone,Debug)]
pub(super) enum HandshakeClientInitiatorMsg: ChanMsg {
Netinfo,
Vpadding,
}
}
restricted_msg! {
#[derive(Clone, Debug)]
pub(super) enum OpenChanMsgC2R: ChanMsg {
Create2,
CreateFast,
Destroy,
Padding,
PaddingNegotiate,
Vpadding,
Relay,
RelayEarly,
}
}
restricted_msg! {
#[derive(Clone, Debug)]
pub(super) enum OpenChanMsgR2C : ChanMsg {
CreatedFast,
Created2,
Destroy,
Padding,
Vpadding,
Relay,
}
}
restricted_msg! {
#[derive(Clone, Debug)]
pub(super) enum OpenChanMsgR2R : ChanMsg {
CreateFast,
CreatedFast,
Create2,
Created2,
Destroy,
Padding,
Vpadding,
Relay,
RelayEarly,
}
}
pub(super) fn decode_cell(
chan_type: ChannelType,
stage: &MessageStage,
codec: &mut codec::ChannelCodec,
src: &mut BytesMut,
) -> Result<Option<AnyChanCell>, Error> {
use ChannelType::*;
use MessageStage::*;
match (chan_type, stage) {
(ClientInitiator, Handshake) => {
decode_as_any::<HandshakeRelayResponderMsg>(stage, codec, src)
}
(ClientInitiator, Open) => decode_as_any::<OpenChanMsgR2C>(stage, codec, src),
(RelayInitiator, Handshake) => {
decode_as_any::<HandshakeRelayResponderMsg>(stage, codec, src)
}
(RelayInitiator, Open) => decode_as_any::<OpenChanMsgR2R>(stage, codec, src),
(RelayResponder { authenticated: _ }, Handshake) => {
decode_as_any::<HandshakeRelayInitiatorMsg>(stage, codec, src)
}
(
RelayResponder {
authenticated: false,
},
Open,
) => decode_as_any::<OpenChanMsgC2R>(stage, codec, src),
(
RelayResponder {
authenticated: true,
},
Open,
) => decode_as_any::<OpenChanMsgR2R>(stage, codec, src),
}
}
pub(super) fn encode_cell(
chan_type: ChannelType,
stage: &MessageStage,
cell: AnyChanCell,
codec: &mut codec::ChannelCodec,
dst: &mut BytesMut,
) -> Result<(), Error> {
use ChannelType::*;
use MessageStage::*;
match (chan_type, stage) {
(ClientInitiator, Handshake) => {
encode_as_any::<HandshakeClientInitiatorMsg>(stage, cell, codec, dst)
}
(ClientInitiator, Open) => encode_as_any::<OpenChanMsgC2R>(stage, cell, codec, dst),
(RelayInitiator, Handshake) => {
encode_as_any::<HandshakeRelayInitiatorMsg>(stage, cell, codec, dst)
}
(RelayInitiator, Open) => encode_as_any::<OpenChanMsgR2R>(stage, cell, codec, dst),
(RelayResponder { authenticated: _ }, Handshake) => {
encode_as_any::<HandshakeRelayResponderMsg>(stage, cell, codec, dst)
}
(
RelayResponder {
authenticated: false,
},
Open,
) => encode_as_any::<OpenChanMsgR2C>(stage, cell, codec, dst),
(
RelayResponder {
authenticated: true,
},
Open,
) => encode_as_any::<OpenChanMsgR2R>(stage, cell, codec, dst),
}
}
}
/// Decode one cell from `src` as the restricted message type `R`, then widen it into an
/// [`AnyChanCell`].
///
/// An `Ok(None)` from the codec is passed through unchanged. Presumably this means `src` does
/// not yet hold a whole cell; confirm against `ChannelCodec::decode_cell`.
///
/// # Errors
///
/// Any codec error is converted with `MessageStage::to_err`, so it becomes a handshake or
/// channel protocol error depending on `stage`.
fn decode_as_any<R>(
    stage: &MessageStage,
    codec: &mut codec::ChannelCodec,
    src: &mut BytesMut,
) -> Result<Option<AnyChanCell>, Error>
where
    R: Into<AnyChanMsg> + ChanMsg,
{
    let decoded = match codec.decode_cell::<R>(src) {
        Ok(decoded) => decoded,
        Err(e) => return Err(stage.to_err(format!("Decoding cell error: {e}"))),
    };
    Ok(decoded.map(|restricted| {
        let (circid, msg) = restricted.into_circid_and_msg();
        ChanCell::new(circid, msg.into())
    }))
}
/// Narrow `cell` to the restricted message type `R` and write it into `dst` using `codec`.
///
/// # Errors
///
/// Returns an error built by `MessageStage::to_err` in two cases:
/// - the message cannot be converted into `R` (the command is not allowed here), or
/// - the codec fails to write the cell.
fn encode_as_any<R>(
    stage: &MessageStage,
    cell: AnyChanCell,
    codec: &mut codec::ChannelCodec,
    dst: &mut BytesMut,
) -> Result<(), Error>
where
    R: ChanMsg + TryFrom<AnyChanMsg, Error = AnyChanMsg>,
{
    let (circ_id, any_msg) = cell.into_circid_and_msg();
    // On failure the conversion hands back the original message, so its command can be named.
    let restricted = R::try_from(any_msg).map_err(|rejected| {
        stage.to_err(format!("Disallowed cell command {}", rejected.cmd()))
    })?;
    codec
        .write_cell(ChanCell::<R>::new(circ_id, restricted), dst)
        .map_err(|e| stage.to_err(format!("Encoding cell error: {e}")))
}
/// A link protocol version for which this module defines message sets.
#[derive(Clone, Copy, Debug)]
pub(super) enum LinkVersion {
    /// Link protocol version 4 (message sets in `linkv4`).
    V4,
    /// Link protocol version 5 (message sets in `linkv5`).
    V5,
}
impl LinkVersion {
pub(super) fn value(&self) -> u16 {
match self {
Self::V4 => 4,
Self::V5 => 5,
}
}
}
impl TryFrom<u16> for LinkVersion {
type Error = Error;
fn try_from(value: u16) -> Result<Self, Self::Error> {
Ok(match value {
4 => Self::V4,
5 => Self::V5,
_ => {
return Err(Error::HandshakeProto(format!(
"Unknown link version {value}"
)));
}
})
}
}
/// The stage a channel is in. It selects which message sets apply, and which kind of error is
/// raised when a message is rejected.
pub(super) enum MessageStage {
    /// The channel handshake is in progress. Failures become `Error::HandshakeProto`.
    Handshake,
    /// The channel is open. Failures become `Error::ChanProto`.
    Open,
}
impl MessageStage {
fn to_err(&self, msg: String) -> Error {
match self {
Self::Handshake => Error::HandshakeProto(msg),
Self::Open => Error::ChanProto(msg),
}
}
}
/// Restricts the cells encoded and decoded on a channel to the messages permitted for its
/// link protocol version, channel type, and stage.
pub(super) struct MessageFilter {
    /// Link protocol version; selects between the `linkv4` and `linkv5` message sets.
    link_version: LinkVersion,
    /// Our role on the channel (see [`ChannelType`]).
    channel_type: ChannelType,
    /// Whether the channel is still handshaking or already open.
    stage: MessageStage,
}
impl MessageFilter {
pub(super) fn new(
link_version: LinkVersion,
channel_type: ChannelType,
stage: MessageStage,
) -> Self {
Self {
link_version,
channel_type,
stage,
}
}
pub(super) fn channel_type(&self) -> ChannelType {
self.channel_type
}
pub(super) fn channel_type_mut(&mut self) -> &mut ChannelType {
&mut self.channel_type
}
pub(super) fn decode_cell(
&self,
codec: &mut codec::ChannelCodec,
src: &mut BytesMut,
) -> Result<Option<AnyChanCell>, Error> {
match self.link_version {
LinkVersion::V4 => linkv4::decode_cell(self.channel_type, &self.stage, codec, src),
LinkVersion::V5 => linkv5::decode_cell(self.channel_type, &self.stage, codec, src),
}
}
pub(super) fn encode_cell(
&self,
cell: AnyChanCell,
codec: &mut codec::ChannelCodec,
dst: &mut BytesMut,
) -> Result<(), Error> {
match self.link_version {
LinkVersion::V4 => {
linkv4::encode_cell(self.channel_type, &self.stage, cell, codec, dst)
}
LinkVersion::V5 => {
linkv5::encode_cell(self.channel_type, &self.stage, cell, codec, dst)
}
}
}
}