use digest::Digest;
use tor_bytes::Reader;
use tor_cell::chancell::{
AnyChanCell, ChanCell, ChanCmd, ChanMsg, codec,
msg::{self, AnyChanMsg},
};
use tor_error::internal;
use tor_llcrypto as ll;
use asynchronous_codec as futures_codec;
use bytes::BytesMut;
use crate::{channel::msg::LinkVersion, util::err::Error as ChanError};
use super::{ChannelType, msg::MessageFilter};
/// Codec for a channel's cells, whose behavior depends on how far the channel
/// has progressed.
///
/// The stage advances from `New` to `Handshake` via
/// [`ChannelCellHandler::set_link_version`], and from `Handshake` to `Open` via
/// [`ChannelCellHandler::set_open`].
pub(crate) enum ChannelCellHandler {
    /// No link protocol version negotiated yet: only VERSIONS messages can be
    /// encoded or decoded.
    New(NewChannelHandler),
    /// Link protocol version known; the rest of the handshake is in progress.
    Handshake(HandshakeChannelHandler),
    /// The channel is open.
    Open(OpenChannelHandler),
}
impl From<super::ChannelType> for ChannelCellHandler {
fn from(ty: ChannelType) -> Self {
Self::New(ty.into())
}
}
impl ChannelCellHandler {
    /// Return the type of the channel this handler serves.
    pub(crate) fn channel_type(&self) -> ChannelType {
        match self {
            Self::New(h) => h.channel_type,
            Self::Handshake(h) => h.channel_type(),
            Self::Open(h) => h.channel_type(),
        }
    }

    /// Advance from the `New` stage to the `Handshake` stage, using the
    /// negotiated link protocol version `link_version`.
    ///
    /// Any send/receive logs kept by the `New` handler are carried over.
    ///
    /// # Errors
    ///
    /// Returns [`ChanError::Bug`] if the handler is not in the `New` stage, or
    /// an error if `link_version` cannot be converted into a [`LinkVersion`]
    /// (in which case the handler is left unchanged).
    pub(crate) fn set_link_version(&mut self, link_version: u16) -> Result<(), ChanError> {
        let Self::New(new_handler) = self else {
            return Err(ChanError::Bug(internal!(
                "Setting link protocol without a new handler",
            )));
        };
        *self = Self::Handshake(new_handler.next_handler(link_version.try_into()?));
        Ok(())
    }

    /// Advance from the `Handshake` stage to the `Open` stage.
    ///
    /// # Errors
    ///
    /// Returns [`ChanError::Bug`] if the handler is not in the `Handshake` stage.
    pub(crate) fn set_open(&mut self) -> Result<(), ChanError> {
        let Self::Handshake(handler) = self else {
            return Err(ChanError::Bug(internal!(
                "Setting open without a handshake handler"
            )));
        };
        *self = Self::Open(handler.next_handler());
        Ok(())
    }

    /// Mark the channel as authenticated (handshake stage only).
    ///
    /// # Errors
    ///
    /// Returns [`ChanError::Bug`] if the handler is not in the `Handshake` stage.
    pub(crate) fn set_authenticated(&mut self) -> Result<(), ChanError> {
        let Self::Handshake(handler) = self else {
            return Err(ChanError::Bug(internal!(
                "Setting authenticated without a handshake handler"
            )));
        };
        handler.set_authenticated();
        Ok(())
    }

    /// Take the send log and return its SHA-256 digest.
    ///
    /// # Errors
    ///
    /// Returns [`ChanError::Bug`] if the handler is not in the `Handshake`
    /// stage, or if there is no send log (never kept, or already taken).
    pub(crate) fn take_send_log_digest(&mut self) -> Result<[u8; 32], ChanError> {
        let Self::Handshake(handler) = self else {
            return Err(ChanError::Bug(internal!(
                "Getting send log digest without a handshake handler"
            )));
        };
        // `ok_or_else`: only construct the `Bug` error when the digest is
        // actually missing, not on every successful call.
        handler.take_send_log_digest().ok_or_else(|| {
            ChanError::Bug(internal!(
                "No send log digest on channel, or already taken"
            ))
        })
    }

    /// Take the receive log and return its SHA-256 digest.
    ///
    /// # Errors
    ///
    /// Returns [`ChanError::Bug`] if the handler is not in the `Handshake`
    /// stage, or if there is no receive log (never kept, or already taken).
    pub(crate) fn take_recv_log_digest(&mut self) -> Result<[u8; 32], ChanError> {
        let Self::Handshake(handler) = self else {
            return Err(ChanError::Bug(internal!(
                "Getting recv log digest without a handshake handler"
            )));
        };
        // `ok_or_else`: only construct the `Bug` error when the digest is
        // actually missing, not on every successful call.
        handler.take_recv_log_digest().ok_or_else(|| {
            ChanError::Bug(internal!(
                "No recv log digest on channel, or already taken"
            ))
        })
    }
}
impl futures_codec::Decoder for ChannelCellHandler {
    type Item = AnyChanCell;
    type Error = ChanError;

    /// Decode the next cell from `src` with the handler for the current stage.
    ///
    /// In the `New` stage only a VERSIONS message can be decoded; it is wrapped
    /// into a cell with no circuit ID.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        match self {
            Self::Handshake(handler) => handler.decode(src),
            Self::Open(handler) => handler.decode(src),
            Self::New(handler) => {
                let versions = handler.decode(src)?;
                Ok(versions.map(|versions| ChanCell::new(None, versions.into())))
            }
        }
    }
}
impl futures_codec::Encoder for ChannelCellHandler {
    type Item<'a> = AnyChanCell;
    type Error = ChanError;

    /// Encode `item` onto `dst` with the handler for the current stage.
    ///
    /// # Errors
    ///
    /// In the `New` stage the cell's circuit ID is discarded and any message
    /// other than VERSIONS is rejected with [`ChanError::HandshakeProto`].
    fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let new_handler = match self {
            Self::Handshake(handler) => return handler.encode(item, dst),
            Self::Open(handler) => return handler.encode(item, dst),
            Self::New(handler) => handler,
        };

        let (_circ_id, msg) = item.into_circid_and_msg();
        match msg {
            AnyChanMsg::Versions(versions) => new_handler.encode(versions, dst),
            _ => Err(Self::Error::HandshakeProto(
                "Non VERSIONS cell for new handler".into(),
            )),
        }
    }
}
/// Codec used before the link protocol version has been negotiated.
///
/// It only handles VERSIONS cells, framed with a 2-byte circuit ID.
pub(crate) struct NewChannelHandler {
    /// The type of channel this handler is used for.
    channel_type: ChannelType,
    /// Running SHA-256 over every byte encoded for sending, if this channel
    /// type keeps one.
    send_log: Option<ll::d::Sha256>,
    /// Running SHA-256 over every byte of each decoded cell, if this channel
    /// type keeps one.
    recv_log: Option<ll::d::Sha256>,
}
impl NewChannelHandler {
    /// Build the handshake-stage handler that follows this one, for the
    /// negotiated `link_version`.
    ///
    /// `HandshakeChannelHandler::new` takes the send/receive logs out of
    /// `self`, leaving `None` behind.
    fn next_handler(&mut self, link_version: LinkVersion) -> HandshakeChannelHandler {
        HandshakeChannelHandler::new(self, link_version)
    }
}
impl From<ChannelType> for NewChannelHandler {
    /// Create the `New`-stage handler for a channel of type `channel_type`.
    ///
    /// Relay channel types (initiator or responder) start fresh SHA-256 logs
    /// for both directions; a client initiator keeps no logs.
    fn from(channel_type: ChannelType) -> Self {
        let (send_log, recv_log) = match channel_type {
            ChannelType::ClientInitiator => (None, None),
            ChannelType::RelayInitiator | ChannelType::RelayResponder { .. } => {
                (Some(ll::d::Sha256::new()), Some(ll::d::Sha256::new()))
            }
        };
        Self {
            channel_type,
            send_log,
            recv_log,
        }
    }
}
impl futures_codec::Decoder for NewChannelHandler {
    type Item = msg::Versions;
    type Error = ChanError;

    /// Decode a VERSIONS cell from the start of `src`.
    ///
    /// The expected framing is a variable-length cell header consisting of:
    /// - a 2-byte circuit ID, which must be zero;
    /// - a 1-byte command, which must be VERSIONS;
    /// - a 2-byte body length, which must be even.
    ///
    /// Returns `Ok(None)` until the whole cell is buffered. The complete cell,
    /// header included, is fed into the receive log (if any) before the body
    /// is parsed.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        // CircID (2) + command (1) + body length (2).
        const HEADER_SIZE: usize = 5;

        // `get` yields exactly HEADER_SIZE bytes when enough are buffered, so
        // this pattern only fails while the header is still incomplete.
        let Some(&[circ_hi, circ_lo, raw_cmd, len_hi, len_lo]) = src.get(..HEADER_SIZE) else {
            return Ok(None);
        };

        if u16::from_be_bytes([circ_hi, circ_lo]) != 0 {
            return Err(Self::Error::HandshakeProto(
                "Invalid CircID in variable cell".into(),
            ));
        }

        let cmd = ChanCmd::from(raw_cmd);
        if cmd != ChanCmd::VERSIONS {
            return Err(Self::Error::HandshakeProto(format!(
                "Invalid command {cmd} variable cell, expected a VERSIONS."
            )));
        }

        // Each version is a u16, so a well-formed body has an even length.
        let body_len = usize::from(u16::from_be_bytes([len_hi, len_lo]));
        if body_len % 2 == 1 {
            return Err(Self::Error::HandshakeProto(
                "VERSIONS cell body length is odd. Rejecting.".into(),
            ));
        }

        let cell_len = HEADER_SIZE + body_len;
        if src.len() < cell_len {
            return Ok(None);
        }

        let mut cell_bytes = src.split_to(cell_len);
        if let Some(log) = &mut self.recv_log {
            log.update(&cell_bytes);
        }

        let body = cell_bytes.split_off(HEADER_SIZE).freeze();
        let mut reader = Reader::from_bytes(&body);
        msg::Versions::decode_from_reader(cmd, &mut reader)
            .map(Some)
            .map_err(|e| Self::Error::from_bytes_err(e, "new cell handler"))
    }
}
impl futures_codec::Encoder for NewChannelHandler {
    type Item<'a> = msg::Versions;
    type Error = ChanError;

    /// Encode a VERSIONS cell in its handshake wire form onto `dst`.
    ///
    /// The exact bytes appended are also fed into the send log (if any).
    fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let wire = match item.encode_for_handshake() {
            Ok(bytes) => bytes,
            Err(e) => return Err(Self::Error::from_bytes_enc(e, "new cell handler")),
        };

        if let Some(log) = &mut self.send_log {
            log.update(&wire);
        }
        dst.extend_from_slice(&wire);
        Ok(())
    }
}
/// Codec used during the channel handshake, after the link protocol version
/// has been negotiated.
pub(crate) struct HandshakeChannelHandler {
    /// Handshake-stage message filter; every cell is encoded/decoded through it.
    filter: MessageFilter,
    /// Underlying cell codec for the negotiated link version.
    inner: codec::ChannelCodec,
    /// Running SHA-256 over bytes sent, carried over from the `New` stage
    /// (if this channel type keeps one).
    send_log: Option<ll::d::Sha256>,
    /// Running SHA-256 over bytes received, carried over from the `New` stage
    /// (if this channel type keeps one).
    recv_log: Option<ll::d::Sha256>,
}
impl HandshakeChannelHandler {
    /// Build a handshake-stage handler from `new_handler`, for the negotiated
    /// `link_version`.
    ///
    /// The send/receive logs (if any) are taken out of `new_handler`, leaving
    /// `None` behind, so the running digests continue across the transition.
    fn new(new_handler: &mut NewChannelHandler, link_version: LinkVersion) -> Self {
        let stage = super::msg::MessageStage::Handshake;
        let filter = MessageFilter::new(link_version, new_handler.channel_type, stage);
        let send_log = new_handler.send_log.take();
        let recv_log = new_handler.recv_log.take();
        let inner = codec::ChannelCodec::new(link_version.value());
        Self {
            filter,
            inner,
            send_log,
            recv_log,
        }
    }

    /// Finish a running log and return its 32-byte SHA-256 digest, or `None`
    /// if there was no log.
    fn finalize_log(log: Option<ll::d::Sha256>) -> Option<[u8; 32]> {
        let digest = log?.finalize();
        Some(digest.into())
    }

    /// Build the open-stage handler that follows this one, reusing this
    /// handler's link version and channel type.
    ///
    /// # Panics
    ///
    /// Panics if the inner codec's link version is not a valid [`LinkVersion`].
    fn next_handler(&mut self) -> OpenChannelHandler {
        let link_version: LinkVersion = self
            .inner
            .link_version()
            .try_into()
            .expect("Channel Codec with unknown link version");
        OpenChannelHandler::new(link_version, self.channel_type())
    }

    /// Take the send log and return its digest; `None` if no log was kept or
    /// it was already taken.
    pub(crate) fn take_send_log_digest(&mut self) -> Option<[u8; 32]> {
        let log = self.send_log.take();
        Self::finalize_log(log)
    }

    /// Take the receive log and return its digest; `None` if no log was kept
    /// or it was already taken.
    pub(crate) fn take_recv_log_digest(&mut self) -> Option<[u8; 32]> {
        let log = self.recv_log.take();
        Self::finalize_log(log)
    }

    /// Return the type of the channel this handler serves.
    pub(crate) fn channel_type(&self) -> ChannelType {
        self.filter.channel_type()
    }

    /// Mark the channel type held by the message filter as authenticated.
    pub(crate) fn set_authenticated(&mut self) {
        let channel_type = self.filter.channel_type_mut();
        channel_type.set_authenticated();
    }
}
impl futures_codec::Encoder for HandshakeChannelHandler {
    type Item<'a> = AnyChanCell;
    type Error = ChanError;

    /// Encode `item` through the handshake-stage filter.
    ///
    /// Exactly the bytes that were appended to `dst` are then fed into the
    /// send log (if any).
    fn encode(
        &mut self,
        item: Self::Item<'_>,
        dst: &mut BytesMut,
    ) -> std::result::Result<(), Self::Error> {
        let start = dst.len();
        self.filter.encode_cell(item, &mut self.inner, dst)?;
        if let Some(log) = &mut self.send_log {
            let appended = &dst[start..];
            log.update(appended);
        }
        Ok(())
    }
}
impl futures_codec::Decoder for HandshakeChannelHandler {
    type Item = AnyChanCell;
    type Error = ChanError;

    /// Decode a handshake-stage cell from `src` through the message filter.
    ///
    /// Every byte consumed from `src` by a successful decode is fed into the
    /// receive log (if any).
    fn decode(
        &mut self,
        src: &mut BytesMut,
    ) -> std::result::Result<Option<Self::Item>, Self::Error> {
        // Without a receive log there is nothing to record, so decode directly.
        // This avoids `BytesMut::clone`, which deep-copies the whole buffer on
        // every decode attempt.
        let Some(recv_log) = self.recv_log.as_mut() else {
            return self.filter.decode_cell(&mut self.inner, src);
        };

        // The codec consumes a prefix of `src`; keep a snapshot so that exact
        // prefix can be logged afterwards.
        let orig = src.clone();
        let cell = self.filter.decode_cell(&mut self.inner, src)?;
        let n_used = orig.len() - src.len();
        recv_log.update(&orig[..n_used]);
        Ok(cell)
    }
}
/// Codec used once the channel is open.
pub(crate) struct OpenChannelHandler {
    /// Open-stage message filter; every cell is encoded/decoded through it.
    filter: MessageFilter,
    /// Underlying cell codec for the negotiated link version.
    inner: codec::ChannelCodec,
}
impl OpenChannelHandler {
    /// Create the codec for an open channel of type `channel_type` that uses
    /// link protocol `link_version`.
    fn new(link_version: LinkVersion, channel_type: ChannelType) -> Self {
        let inner = codec::ChannelCodec::new(link_version.value());
        let stage = super::msg::MessageStage::Open;
        let filter = MessageFilter::new(link_version, channel_type, stage);
        Self { filter, inner }
    }

    /// Return the type of the channel this handler serves.
    fn channel_type(&self) -> ChannelType {
        let filter = &self.filter;
        filter.channel_type()
    }
}
impl futures_codec::Encoder for OpenChannelHandler {
    type Item<'a> = AnyChanCell;
    type Error = ChanError;

    /// Encode `item` onto `dst` through the open-stage message filter.
    fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let Self { filter, inner } = self;
        filter.encode_cell(item, inner, dst)
    }
}
impl futures_codec::Decoder for OpenChannelHandler {
    type Item = AnyChanCell;
    type Error = ChanError;

    /// Decode the next cell from `src` through the open-stage message filter.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        let Self { filter, inner } = self;
        filter.decode_cell(inner, src)
    }
}
#[cfg(test)]
pub(crate) mod test {
    //! Tests for the channel cell handlers, plus an in-memory stream helper
    //! ([`MsgBuf`]).
    #![allow(clippy::unwrap_used)]

    use bytes::BytesMut;
    use digest::Digest;
    use futures::io::{AsyncRead, AsyncWrite, Cursor, Result};
    use futures::sink::SinkExt;
    use futures::stream::StreamExt;
    use futures::task::{Context, Poll};
    use hex_literal::hex;
    use std::pin::Pin;
    use tor_bytes::Writer;
    use tor_llcrypto as ll;
    use tor_rtcompat::StreamOps;

    use crate::channel::msg::LinkVersion;
    use crate::channel::{ChannelType, new_frame};

    use super::{ChannelCellHandler, OpenChannelHandler, futures_codec};
    use tor_cell::chancell::{AnyChanCell, ChanCmd, ChanMsg, CircId, msg};

    /// In-memory duplex stream for tests.
    ///
    /// Reads are served from a fixed input buffer; writes are collected into a
    /// separate output buffer, retrievable with [`MsgBuf::into_response`].
    pub(crate) struct MsgBuf {
        /// Data returned to readers.
        inbuf: futures::io::Cursor<Vec<u8>>,
        /// Data written by the code under test.
        outbuf: futures::io::Cursor<Vec<u8>>,
    }

    impl AsyncRead for MsgBuf {
        fn poll_read(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<Result<usize>> {
            Pin::new(&mut self.inbuf).poll_read(cx, buf)
        }
    }

    impl AsyncWrite for MsgBuf {
        fn poll_write(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize>> {
            Pin::new(&mut self.outbuf).poll_write(cx, buf)
        }

        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
            Pin::new(&mut self.outbuf).poll_flush(cx)
        }

        fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
            Pin::new(&mut self.outbuf).poll_close(cx)
        }
    }

    impl StreamOps for MsgBuf {}

    impl MsgBuf {
        /// Create a buffer whose readers will see exactly `output`.
        pub(crate) fn new<T: Into<Vec<u8>>>(output: T) -> Self {
            let inbuf = Cursor::new(output.into());
            let outbuf = Cursor::new(Vec::new());
            MsgBuf { inbuf, outbuf }
        }

        /// Return how many input bytes have been read so far.
        pub(crate) fn consumed(&self) -> usize {
            self.inbuf.position() as usize
        }

        /// Return true if every input byte has been read.
        pub(crate) fn all_consumed(&self) -> bool {
            self.inbuf.get_ref().len() == self.consumed()
        }

        /// Consume the buffer, returning everything that was written to it.
        pub(crate) fn into_response(self) -> Vec<u8> {
            self.outbuf.into_inner()
        }
    }

    /// Wrap `mbuf` in a frame whose codec is already in the `Open` stage for a
    /// client channel using link protocol 5.
    fn new_client_open_frame(mbuf: MsgBuf) -> futures_codec::Framed<MsgBuf, ChannelCellHandler> {
        let open_handler = ChannelCellHandler::Open(OpenChannelHandler::new(
            LinkVersion::V5,
            ChannelType::ClientInitiator,
        ));
        futures_codec::Framed::new(mbuf, open_handler)
    }

    #[test]
    fn check_client_encoding() {
        tor_rtcompat::test_with_all_runtimes!(|_rt| async move {
            let mb = MsgBuf::new(&b""[..]);
            let mut framed = new_client_open_frame(mb);

            let destroycell = msg::Destroy::new(2.into());
            framed
                .send(AnyChanCell::new(CircId::new(7), destroycell.into()))
                .await
                .unwrap();
            framed.flush().await.unwrap();

            let data = framed.into_inner().into_response();
            // 4-byte CircID 7, command 04 (DESTROY), reason 02, then padding.
            assert_eq!(&data[0..10], &hex!("00000007 04 0200000000")[..]);
        });
    }

    #[test]
    fn check_client_decoding() {
        tor_rtcompat::test_with_all_runtimes!(|_rt| async move {
            let mut dat = Vec::new();
            dat.extend_from_slice(&hex!("00000007 04 0200000000")[..]);
            // Pad out to a full fixed-length cell.
            dat.resize(514, 0);
            let mb = MsgBuf::new(&dat[..]);
            let mut framed = new_client_open_frame(mb);

            let destroy = framed.next().await.unwrap().unwrap();

            let circ_id = CircId::new(7);
            assert_eq!(destroy.circid(), circ_id);
            assert_eq!(destroy.msg().cmd(), ChanCmd::DESTROY);
            assert!(framed.into_inner().all_consumed());
        });
    }

    #[test]
    fn handler_transition() {
        let mut handler: ChannelCellHandler = ChannelType::ClientInitiator.into();
        assert!(matches!(handler, ChannelCellHandler::New(_)));

        let r = handler.set_link_version(5);
        assert!(r.is_ok());
        assert!(matches!(handler, ChannelCellHandler::Handshake(_)));

        let r = handler.set_open();
        assert!(r.is_ok());
        assert!(matches!(handler, ChannelCellHandler::Open(_)));
    }

    /// The send log must cover every byte written, across the New -> Handshake
    /// transition.
    #[test]
    fn clog_digest() {
        tor_rtcompat::test_with_all_runtimes!(|_rt| async move {
            let mut our_clog = ll::d::Sha256::new();
            let mbuf = MsgBuf::new(*b"");
            let mut frame = new_frame(mbuf, ChannelType::RelayInitiator);

            // VERSIONS: 2-byte CircID 0, command 07, length 2, version 5.
            our_clog.update(hex!("0000 07 0002 0005"));
            let version_cell = AnyChanCell::new(
                None,
                msg::Versions::new(vec![5]).expect("Fail VERSIONS").into(),
            );
            frame.send(version_cell).await.unwrap();

            frame
                .codec_mut()
                .set_link_version(5)
                .expect("Fail link version set");

            // Empty CERTS: 4-byte CircID 0, command 0x81, length 1, zero certs.
            our_clog.update(hex!("0000 0000 81 0001 00"));
            let certs_cell = msg::Certs::new_empty();
            frame
                .send(AnyChanCell::new(None, certs_cell.into()))
                .await
                .unwrap();

            let clog_hash: [u8; 32] = our_clog.finalize().into();
            assert_eq!(frame.codec_mut().take_send_log_digest().unwrap(), clog_hash);
        });
    }

    /// The receive log must cover every byte read, across the New -> Handshake
    /// transition.
    #[test]
    fn slog_digest() {
        tor_rtcompat::test_with_all_runtimes!(|_rt| async move {
            let mut our_slog = ll::d::Sha256::new();

            let mut data = BytesMut::new();
            data.extend_from_slice(
                msg::Versions::new(vec![5])
                    .unwrap()
                    .encode_for_handshake()
                    .expect("Fail VERSIONS encoding")
                    .as_slice(),
            );
            our_slog.update(&data);
            let mbuf = MsgBuf::new(data);
            let mut frame = new_frame(mbuf, ChannelType::RelayInitiator);
            let _ = frame.next().await.transpose().expect("Fail to get cell");

            frame
                .codec_mut()
                .set_link_version(5)
                .expect("Fail link version set");

            let mut data = BytesMut::new();
            data.write_u32(0);
            data.write_u8(ChanCmd::AUTH_CHALLENGE.into());
            // Body length: 32-byte challenge + 2-byte method count + one 2-byte method.
            data.write_u16(36);
            msg::AuthChallenge::new([42_u8; 32], vec![3])
                .encode_onto(&mut data)
                .expect("Fail AUTH_CHALLENGE encoding");
            our_slog.update(&data);

            // Swap in a fresh input stream holding the AUTH_CHALLENGE cell.
            *frame = MsgBuf::new(data);
            let _ = frame.next().await.transpose().expect("Fail to get cell");

            let slog_hash: [u8; 32] = our_slog.finalize().into();
            assert_eq!(frame.codec_mut().take_recv_log_digest().unwrap(), slog_hash);
        });
    }
}