use std::fmt::Debug;
use std::mem;
use std::num::{NonZeroUsize, TryFromIntError};
use derive_deftly::define_derive_deftly;
use educe::Educe;
use tor_bytes::Reader;
use tor_error::{Bug, internal};
use crate::SOCKS_BUF_LEN;
use crate::{Action, Error, Truncated};
/// Marker trait selecting how much a [`RecvStep`] lets the caller read at once.
///
/// Two implementations exist here:
/// `()` offers all of the buffer's free space, while [`PreciseReads`] offers
/// only as many bytes as the handshake is known to still need.
///
/// The behaviour itself lives in the `ReadPrecisionSealed` supertrait.
pub trait ReadPrecision
where
    Self: ReadPrecisionSealed + Default + Copy + Debug,
{
}

impl ReadPrecision for () {}

impl ReadPrecision for PreciseReads {}
/// The behaviour behind [`ReadPrecision`].
///
/// This is kept separate from `ReadPrecision` so that outside code cannot
/// implement `ReadPrecision`, provided this trait is not nameable outside the
/// crate. NOTE(review): confirm the enclosing module's visibility.
pub trait ReadPrecisionSealed {
    /// Choose which part of `buf` (the buffer's free space) the caller may read into,
    /// given that the handshake still needs `deficit` more bytes.
    fn recv_step_buf(buf: &mut [u8], deficit: NonZeroUsize) -> &mut [u8];
}

/// Imprecise reads: offer all of the free space.
impl ReadPrecisionSealed for () {
    fn recv_step_buf(buf: &mut [u8], _deficit: NonZeroUsize) -> &mut [u8] {
        buf
    }
}

/// Precise reads: offer exactly the number of bytes still needed, so nothing
/// after the current handshake message is read.
impl ReadPrecisionSealed for PreciseReads {
    /// # Panics
    ///
    /// Panics if `deficit` exceeds `buf.len()`. `Handshake::step` returns
    /// `Error::MessageTooLong` instead of building a `RecvStep` in that case,
    /// so this does not happen when called via `RecvStep::buf`.
    fn recv_step_buf(buf: &mut [u8], deficit: NonZeroUsize) -> &mut [u8] {
        &mut buf[..deficit.get()]
    }
}
/// Read-precision marker: never read past the current handshake message.
///
/// With this marker, [`RecvStep::buf`] offers only as many bytes as the handshake
/// reported it still needs. A caller that reads only into that slice therefore
/// never pulls data that follows the handshake into the [`Buffer`].
/// [`Finished::into_output`] is available only for this marker, and it treats
/// leftover data as a bug.
#[derive(Default, Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
#[allow(clippy::exhaustive_structs)]
pub struct PreciseReads;
/// Buffer for data read from the peer during a handshake.
///
/// `buf[..filled]` holds bytes that have been received but not yet consumed by
/// the handshake. `buf[filled..]` is free space for further reads.
///
/// `P` selects the read precision: `()` (the default) or [`PreciseReads`].
#[derive(Educe)]
#[educe(Debug)]
pub struct Buffer<P: ReadPrecision = ()> {
    // Excluded from `Debug` output (presumably to keep handshake bytes out of logs).
    #[educe(Debug(ignore))]
    buf: Box<[u8]>,
    // Number of bytes at the start of `buf` that hold received data.
    filled: usize,
    // Only ever set to `P::default()` and never read (hence `dead_code`).
    #[allow(dead_code)]
    precision: P,
}
/// What the caller should do next, as returned by [`Handshake::step`].
#[derive(Debug)]
#[allow(clippy::exhaustive_enums)] pub enum NextStep<'b, O, P: ReadPrecision> {
    /// Send these bytes to the peer.
    Send(Vec<u8>),
    /// Read more data from the peer into the buffer, as described by the [`RecvStep`].
    Recv(RecvStep<'b, P>),
    /// The handshake is complete; extract its output from the [`Finished`].
    Finished(Finished<'b, O, P>),
}
/// A completed handshake: its output, plus the buffer that was used.
///
/// The buffer may still hold bytes that the handshake implementation did not
/// consume, i.e. data that followed the handshake. Use one of the
/// `into_output*` methods to extract the output and decide what to do with
/// that leftover data.
#[derive(Debug)]
#[must_use]
pub struct Finished<'b, O, P: ReadPrecision> {
    // The handshake's buffer; its filled region is whatever followed the handshake.
    buffer: &'b mut Buffer<P>,
    // The value produced by the handshake.
    output: O,
}
impl<'b, O> Finished<'b, O, PreciseReads> {
    /// Return the handshake output, checking that no extra data was read.
    ///
    /// With [`PreciseReads`], each read is limited to the bytes the handshake
    /// still needs. Leftover bytes in the buffer therefore mean the stream has
    /// become misframed.
    ///
    /// # Errors
    ///
    /// Returns an internal [`Bug`] if the buffer still holds any bytes.
    pub fn into_output(self) -> Result<O, Bug> {
        let leftover = self.buffer.filled_slice().len();
        match NonZeroUsize::new(leftover) {
            None => Ok(self.output),
            Some(nonzero) => Err(internal!(
                "handshake complete, but we read too much earlier, and are now misframed by {nonzero} bytes!"
            )),
        }
    }
}
impl<'b, O, P: ReadPrecision> Finished<'b, O, P> {
    /// Return the output, and any data that followed the handshake, as a slice.
    ///
    /// The buffer's filled length is reset to zero, so the returned data no
    /// longer counts as unconsumed. The slice still borrows the buffer's
    /// storage for `'b`.
    pub fn into_output_and_slice(self) -> (O, &'b [u8]) {
        let filled = mem::take(&mut self.buffer.filled);
        let data = &self.buffer.buf[0..filled];
        (self.output, data)
    }

    /// Return the output, and any data that followed the handshake, as a `Vec`.
    ///
    /// This takes the buffer's storage, reusing its allocation for the `Vec`.
    /// The buffer is left empty: zero-length storage and nothing filled.
    pub fn into_output_and_vec(self) -> (O, Vec<u8>) {
        // Reset `filled` as well as taking the storage. Leaving it set while `buf`
        // is empty would break the `filled <= buf.len()` invariant and make any
        // later `filled_slice`/`unfilled_slice` call on this buffer panic.
        let filled = mem::take(&mut self.buffer.filled);
        let mut data = mem::take(&mut self.buffer.buf).into_vec();
        data.truncate(filled);
        (self.output, data)
    }

    /// Return the output, failing if any data followed the handshake.
    ///
    /// # Errors
    ///
    /// [`Error::ForbiddenPipelining`] if the buffer holds any bytes beyond
    /// the handshake message.
    pub fn into_output_forbid_pipelining(self) -> Result<O, Error> {
        if self.buffer.filled_slice().is_empty() {
            Ok(self.output)
        } else {
            Err(Error::ForbiddenPipelining)
        }
    }
}
/// Instruction to read more data from the peer into a [`Buffer`].
///
/// Get the destination with [`buf`](RecvStep::buf) and read into it. Then
/// report how many bytes were read with [`note_received`](RecvStep::note_received).
#[derive(Debug)]
pub struct RecvStep<'b, P: ReadPrecision> {
    // The buffer to read into.
    buffer: &'b mut Buffer<P>,
    // How many more bytes the handshake implementation reported it needs.
    deficit: NonZeroUsize,
}
impl<'b, P: ReadPrecision> RecvStep<'b, P> {
pub fn buf(&mut self) -> &mut [u8] {
P::recv_step_buf(self.buffer.unfilled_slice(), self.deficit)
}
pub fn note_received(self, len: usize) -> Result<(), Error> {
let len = len
.try_into()
.map_err(|_: TryFromIntError| Error::UnexpectedEof)?;
self.buffer.note_received(len);
Ok(())
}
}
impl<P: ReadPrecision> Default for Buffer<P> {
    /// An empty buffer with room for [`SOCKS_BUF_LEN`] bytes.
    fn default() -> Self {
        Self::with_size(SOCKS_BUF_LEN)
    }
}
impl Buffer<()> {
    /// Create an empty, default-sized buffer whose reads may fill all of its
    /// free space.
    ///
    /// Same as [`Buffer::default`].
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}

impl Buffer<PreciseReads> {
    /// Create an empty, default-sized buffer that uses [`PreciseReads`].
    ///
    /// Same as [`Buffer::default`].
    pub fn new_precise() -> Self {
        <Self as Default>::default()
    }
}
impl<P: ReadPrecision> Buffer<P> {
    /// Create an empty buffer with room for `size` bytes.
    ///
    /// The storage is initialised to `0xaa`; none of it counts as received data.
    pub fn with_size(size: usize) -> Self {
        Buffer {
            buf: vec![0xaa; size].into(),
            filled: 0,
            precision: P::default(),
        }
    }

    /// Reassemble a buffer from storage and a filled-byte count, as returned by
    /// [`into_parts`](Buffer::into_parts).
    ///
    /// The first `filled` bytes of `buf` are treated as received, unconsumed data.
    /// `filled` must not exceed `buf.len()`. This is not checked here, but the
    /// slice accessors will panic later if it is violated.
    pub fn from_parts(buf: Box<[u8]>, filled: usize) -> Self {
        Buffer {
            buf,
            filled,
            precision: P::default(),
        }
    }

    /// Split the buffer into its storage and the number of filled bytes at its start.
    pub fn into_parts(self) -> (Box<[u8]>, usize) {
        let Buffer {
            buf,
            filled,
            precision: _,
        } = self;
        (buf, filled)
    }

    /// The free space after the received data, into which more data may be read.
    ///
    /// After writing into it, call [`note_received`](Buffer::note_received).
    pub fn unfilled_slice(&mut self) -> &mut [u8] {
        &mut self.buf[self.filled..]
    }

    /// The data received so far that has not yet been consumed.
    ///
    /// This is a read-only view, so only `&self` is required. Callers holding
    /// a `&mut Buffer` can still call it unchanged.
    pub fn filled_slice(&self) -> &[u8] {
        &self.buf[..self.filled]
    }

    /// Record that `len` bytes have been written to the start of
    /// [`unfilled_slice`](Buffer::unfilled_slice).
    ///
    /// # Panics
    ///
    /// Panics if `len` is larger than the remaining free space.
    pub fn note_received(&mut self, len: NonZeroUsize) {
        let len = usize::from(len);
        assert!(
            len <= self.unfilled_slice().len(),
            "note_received: more bytes reported than there was free space for"
        );
        self.filled += len;
    }
}
// derive-deftly template `Handshake`, for handshake state structs.
//
// It generates:
//  * an impl of `HasHandshakeState`, whose `set_failed` overwrites the struct's
//    `state` field with `State::Failed {}`. The struct therefore needs a `state`
//    field, and a `State` type with a `Failed` variant must be in scope where
//    the template is expanded.
//  * for the field marked `#[deftly(handshake(output))]`, impls of `Handshake`
//    and `HasHandshakeOutput`. That field must be an `Option<T>`: `Output` is
//    its `IntoIterator::Item` (i.e. `T`), and `take_output` moves the value out
//    with `Option::take`.
define_derive_deftly! {
    Handshake for struct, expect items:
    impl $crate::handshake::framework::HasHandshakeState for $ttype {
        fn set_failed(&mut self) {
            self.state = State::Failed {};
        }
    }
    // Repeated for each field; only the field carrying the
    // `handshake(output)` attribute produces any output.
    $(
        ${when fmeta(handshake(output))}
        // The handshake's output type: the item type of the output field.
        ${define OUTPUT { <$ftype as IntoIterator>::Item }}
        impl $crate::handshake::framework::Handshake for $ttype {
            type Output = $OUTPUT;
        }
        impl $crate::handshake::framework::HasHandshakeOutput<$OUTPUT> for $ttype {
            fn take_output(&mut self) -> Option<$OUTPUT> {
                Option::take(&mut self.$fname)
            }
        }
    )
}
#[allow(unused_imports)] #[allow(clippy::single_component_path_imports)] use derive_deftly_template_Handshake;
/// Result of one successful call to [`HandshakeImpl::handshake_impl`].
pub(crate) enum ImplNextStep {
    /// Send `reply` to the peer; the handshake is not yet complete.
    Reply {
        // Bytes to send. May be empty, but only if the implementation consumed
        // some input (see `HandshakeImpl::call_handshake_impl`).
        reply: Vec<u8>,
    },
    /// The handshake is complete; its output is available via
    /// [`HasHandshakeOutput::take_output`].
    Finished,
}
/// Access to a handshake's state, so that it can be marked as failed.
///
/// The `Handshake` derive-deftly template in this file generates an implementation.
pub(super) trait HasHandshakeState {
    /// Put the handshake into its failed state.
    fn set_failed(&mut self);
}
/// Access to a handshake's final output.
pub(super) trait HasHandshakeOutput<O> {
    /// Take the output out of the handshake, if it is available.
    ///
    /// The template-generated implementation uses `Option::take`, so after the
    /// first successful call further calls return `None`.
    fn take_output(&mut self) -> Option<O>;
}
/// The protocol-specific part of a handshake.
pub(super) trait HandshakeImpl: HasHandshakeState {
    /// Try to make progress using the bytes available from `r`.
    ///
    /// `r` is created with `Reader::from_possibly_incomplete_slice`, so it may
    /// hold only part of a message. Running out of data is expected to surface
    /// as `Error::Decode(tor_bytes::Error::Incomplete { .. })`, which both
    /// `Handshake::step` and `Handshake::handshake` treat as needing more input.
    fn handshake_impl(&mut self, r: &mut tor_bytes::Reader<'_>) -> crate::Result<ImplNextStep>;

    /// Run [`handshake_impl`](HandshakeImpl::handshake_impl) over `input`.
    ///
    /// Returns the number of bytes of `input` the implementation consumed,
    /// together with its result. If the implementation consumed nothing and
    /// produced an empty reply, it would make no progress. That case is turned
    /// into an internal error, with a consumed count of 0.
    fn call_handshake_impl(&mut self, input: &[u8]) -> (usize, crate::Result<ImplNextStep>) {
        let mut reader = Reader::from_possibly_incomplete_slice(input);
        let outcome = self.handshake_impl(&mut reader);
        let consumed = reader.consumed();

        let made_no_progress = consumed == 0
            && matches!(&outcome, Ok(ImplNextStep::Reply { reply }) if reply.is_empty());
        if made_no_progress {
            let bug = internal!("protocol implementation drained nothing, replied nothing");
            return (0, Err(bug.into()));
        }

        (consumed, outcome)
    }
}
/// A handshake that can be driven incrementally over a [`Buffer`].
///
/// The protocol logic comes from [`HandshakeImpl`]; the final result is
/// retrieved via [`HasHandshakeOutput`].
// `private_bounds` is allowed because the supertraits are only `pub(super)`,
// while this trait is `pub`.
#[allow(private_bounds)] pub trait Handshake: HandshakeImpl + HasHandshakeOutput<Self::Output> {
    /// The value produced when the handshake completes.
    type Output: Debug;

    /// Make progress using the data already in `buffer`.
    ///
    /// Runs the protocol implementation over `buffer`'s filled bytes and returns:
    ///
    ///  * [`NextStep::Recv`] if the implementation needs more input. Nothing is
    ///    drained from `buffer`; the next call re-runs the implementation over the
    ///    whole filled region.
    ///  * [`NextStep::Send`] with bytes to transmit to the peer.
    ///  * [`NextStep::Finished`] once the handshake is complete. Bytes left in the
    ///    buffer are those the implementation did not consume.
    ///
    /// In the `Send` and `Finished` cases, the consumed bytes are removed from the
    /// front of `buffer` and the remaining data is moved down to the start.
    ///
    /// # Errors
    ///
    ///  * [`Error::MessageTooLong`] if more bytes are needed than `buffer` has free space for.
    ///  * Any other error from the implementation.
    ///  * An internal error if the implementation reports completion without an output.
    ///
    /// Unlike the deprecated [`handshake`](Handshake::handshake), this does not call
    /// `set_failed` when an error occurs.
    fn step<'b, P: ReadPrecision>(
        &mut self,
        buffer: &'b mut Buffer<P>,
    ) -> Result<NextStep<'b, <Self as Handshake>::Output, P>, Error> {
        let (drain, rv) = self.call_handshake_impl(buffer.filled_slice());
        // Running out of input is not a failure: ask the caller to read more,
        // unless the buffer has no room left for the missing bytes.
        if let Err(Error::Decode(tor_bytes::Error::Incomplete { deficit, .. })) = rv {
            let deficit = deficit.into_inner();
            return if usize::from(deficit) > buffer.unfilled_slice().len() {
                Err(Error::MessageTooLong {
                    limit: buffer.buf.len(),
                })
            } else {
                Ok(NextStep::Recv(RecvStep { buffer, deficit }))
            };
        };
        let rv = rv?;
        // Discard the consumed bytes, moving any remaining data to the front.
        buffer.buf.copy_within(drain..buffer.filled, 0);
        buffer.filled -= drain;
        Ok(match rv {
            ImplNextStep::Reply { reply } => NextStep::Send(reply),
            ImplNextStep::Finished => {
                let output = self.take_output().ok_or_else(|| internal!("no output!"))?;
                NextStep::Finished(Finished { buffer, output })
            }
        })
    }

    /// Legacy one-shot interface: process `input` and report what to do.
    ///
    ///  * If `input` is incomplete or truncated, returns `Err` with a [`Truncated`].
    ///  * If the implementation fails, marks the handshake failed via `set_failed`
    ///    and returns the error as `Ok(Err(e))`.
    ///  * Otherwise returns an [`Action`] giving how many bytes of `input` to drain,
    ///    the reply to send (empty when finished), and whether the handshake is finished.
    #[deprecated = "use the new Handshake::step API instead"]
    fn handshake(&mut self, input: &[u8]) -> crate::TResult<Action> {
        let (drain, rv) = self.call_handshake_impl(input);
        match rv {
            // NOTE(review): presumably `tor_bytes::Error::Truncated` is the deprecated item here.
            #[allow(deprecated)]
            Err(Error::Decode(
                tor_bytes::Error::Incomplete { .. } | tor_bytes::Error::Truncated,
            )) => Err(Truncated::new()),
            Err(e) => {
                self.set_failed();
                Ok(Err(e))
            }
            Ok(ImplNextStep::Reply { reply }) => Ok(Ok(Action {
                drain,
                reply,
                finished: false,
            })),
            Ok(ImplNextStep::Finished) => Ok(Ok(Action {
                drain,
                reply: vec![],
                finished: true,
            })),
        }
    }

    /// Test-only wrapper around the deprecated [`handshake`](Handshake::handshake),
    /// so tests can call it without a deprecation warning at each call site.
    #[cfg(test)]
    #[allow(deprecated)]
    fn handshake_for_tests(&mut self, input: &[u8]) -> crate::TResult<Action> {
        self.handshake(input)
    }
}