use crate::{Error, Result};
use static_assertions::assert_impl_all;
use tor_cell::relaycell::msg::EndReason;
use tor_cell::relaycell::{RelayCellFormat, RelayCmd};
use futures::io::{AsyncRead, AsyncWrite};
use futures::stream::StreamExt;
use futures::task::{Context, Poll};
use futures::{Future, Stream};
use pin_project::pin_project;
use postage::watch;
#[cfg(feature = "tokio")]
use tokio_crate::io::ReadBuf;
#[cfg(feature = "tokio")]
use tokio_crate::io::{AsyncRead as TokioAsyncRead, AsyncWrite as TokioAsyncWrite};
#[cfg(feature = "tokio")]
use tokio_util::compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt};
use tor_cell::restricted_msg;
use std::fmt::Debug;
use std::io::Result as IoResult;
use std::num::NonZero;
use std::pin::Pin;
#[cfg(any(feature = "stream-ctrl", feature = "experimental-api"))]
use std::sync::Arc;
#[cfg(feature = "stream-ctrl")]
use std::sync::{Mutex, Weak};
use educe::Educe;
use crate::client::ClientTunnel;
use crate::client::stream::StreamReceiver;
use crate::memquota::StreamAccount;
use crate::stream::StreamTarget;
use crate::stream::cmdcheck::{AnyCmdChecker, CmdChecker, StreamStatus};
use crate::stream::flow_ctrl::state::StreamRateLimit;
use crate::stream::flow_ctrl::xon_xoff::reader::{BufferIsEmpty, XonXoffReader, XonXoffReaderCtrl};
use crate::util::token_bucket::dynamic_writer::DynamicRateLimitedWriter;
use crate::util::token_bucket::writer::{RateLimitedWriter, RateLimitedWriterConfig};
use tor_basic_utils::skip_fmt;
use tor_cell::relaycell::msg::Data;
use tor_error::internal;
use tor_rtcompat::{CoarseTimeProvider, DynTimeProvider, SleepProvider};
/// Stream of rate-limited-writer configurations derived from rate-limit updates.
///
/// Produced in `DataWriter::new` by fusing a `watch::Receiver<StreamRateLimit>`
/// and mapping each published limit through a plain `fn` pointer (so the type
/// is nameable here) into a `RateLimitedWriterConfig`.
type RateConfigStream = futures::stream::Map<
    futures::stream::Fuse<watch::Receiver<StreamRateLimit>>,
    fn(StreamRateLimit) -> RateLimitedWriterConfig,
>;
/// A bidirectional byte stream carried over a Tor stream.
///
/// Reads are served by the reader half `r` and writes by the writer half `w`;
/// `split` separates the two halves.
#[derive(Debug)]
pub struct DataStream {
    /// Writing half.
    w: DataWriter,
    /// Reading half.
    r: DataReader,
    /// Control handle; the same `Arc` is also held by both halves.
    #[cfg(feature = "stream-ctrl")]
    ctrl: Arc<ClientDataStreamCtrl>,
}
// Compile-time check that a `DataStream` can be sent and shared across threads.
assert_impl_all! { DataStream: Send, Sync }
/// Control handle for a client-side [`DataStream`].
///
/// One `Arc` of this is shared by the `DataStream` and both of its halves.
/// It reports whether the stream is connected and gives access to the
/// tunnel that carries the stream.
#[cfg(feature = "stream-ctrl")]
#[derive(Debug)]
pub struct ClientDataStreamCtrl {
    /// The tunnel carrying this stream. Held weakly, so this handle does not
    /// keep the tunnel alive.
    tunnel: Weak<ClientTunnel>,
    /// Status flags shared with the reader and writer, which update them as
    /// messages are sent and received.
    ///
    /// (No per-field `cfg` is needed: the whole struct is already gated on
    /// `stream-ctrl`.)
    status: Arc<Mutex<DataStreamStatus>>,
    /// Memory-quota account for this stream. Not read here; presumably held
    /// to keep the account alive as long as this handle.
    _memquota: StreamAccount,
}
/// The writer half's state, underneath the rate limiter.
///
/// The poll methods `take()` `state` while they run and put a value back on
/// every return path, so it is `None` only transiently inside those methods.
#[derive(Debug)]
struct DataWriterInner {
    state: Option<DataWriterState>,
    /// Memory-quota account for this stream (not read by the writer).
    _memquota: StreamAccount,
    /// Control handle shared with the reader half and the `DataStream`.
    #[cfg(feature = "stream-ctrl")]
    ctrl: Arc<ClientDataStreamCtrl>,
}
/// The writing half of a [`DataStream`].
///
/// Writes pass through a rate limiter whose configuration is updated from
/// the stream's rate-limit watch channel (see `DataWriter::new`).
#[derive(Debug)]
pub struct DataWriter {
    writer: DynamicRateLimitedWriter<DataWriterInner, RateConfigStream, DynTimeProvider>,
}
impl DataWriter {
    /// Build a rate-limited writer around `inner`.
    ///
    /// The limiter starts from the value currently held by
    /// `rate_limit_updates` and is reconfigured whenever that channel
    /// publishes a new limit.
    fn new(
        inner: DataWriterInner,
        rate_limit_updates: watch::Receiver<StreamRateLimit>,
        time_provider: DynTimeProvider,
    ) -> Self {
        /// Turn a stream rate limit into a writer config whose burst equals its rate.
        fn rate_to_config(limit: StreamRateLimit) -> RateLimitedWriterConfig {
            let bytes_per_sec = limit.bytes_per_sec();
            RateLimitedWriterConfig {
                rate: bytes_per_sec,
                burst: bytes_per_sec,
                wake_when_bytes_available: NonZero::new(200).expect("200 != 0"),
            }
        }

        // Read the current limit before the receiver is consumed by `fuse()`.
        let starting_config = rate_to_config(*rate_limit_updates.borrow());
        let config_updates = rate_limit_updates
            .fuse()
            .map(rate_to_config as fn(_) -> _);
        let base = RateLimitedWriter::new(inner, &starting_config, time_provider);
        Self {
            writer: DynamicRateLimitedWriter::new(base, config_updates),
        }
    }

    /// Return the control handle for this stream.
    ///
    /// Always `Some` for a writer built by this module.
    #[cfg(feature = "stream-ctrl")]
    pub fn client_stream_ctrl(&self) -> Option<&Arc<ClientDataStreamCtrl>> {
        let ctrl = self.writer.inner().client_stream_ctrl();
        Some(ctrl)
    }
}
impl AsyncWrite for DataWriter {
    /// Forward to the rate-limited writer.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<IoResult<usize>> {
        let writer = Pin::new(&mut self.get_mut().writer);
        AsyncWrite::poll_write(writer, cx, buf)
    }

    /// Forward to the rate-limited writer.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        let writer = Pin::new(&mut self.get_mut().writer);
        AsyncWrite::poll_flush(writer, cx)
    }

    /// Forward to the rate-limited writer.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        let writer = Pin::new(&mut self.get_mut().writer);
        AsyncWrite::poll_close(writer, cx)
    }
}
// Tokio-flavoured writes, bridged through `tokio_util`'s compat adapter
// over this type's `futures` `AsyncWrite` implementation.
#[cfg(feature = "tokio")]
impl TokioAsyncWrite for DataWriter {
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
        let mut adapter = self.compat_write();
        TokioAsyncWrite::poll_write(Pin::new(&mut adapter), cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        let mut adapter = self.compat_write();
        TokioAsyncWrite::poll_flush(Pin::new(&mut adapter), cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        let mut adapter = self.compat_write();
        TokioAsyncWrite::poll_shutdown(Pin::new(&mut adapter), cx)
    }
}
/// The reading half of a [`DataStream`].
///
/// Wraps the inner reader in an `XonXoffReader`, constructed with the
/// stream's `XonXoffReaderCtrl` (presumably for XON/XOFF flow control — see
/// that type for details).
#[derive(Debug)]
pub struct DataReader {
    reader: XonXoffReader<DataReaderInner>,
}
impl DataReader {
    /// Wrap `reader` in an `XonXoffReader` driven by `xon_xoff_reader_ctrl`.
    fn new(reader: DataReaderInner, xon_xoff_reader_ctrl: XonXoffReaderCtrl) -> Self {
        let wrapped = XonXoffReader::new(xon_xoff_reader_ctrl, reader);
        Self { reader: wrapped }
    }

    /// Return the control handle for this stream.
    ///
    /// Always `Some` for a reader built by this module.
    #[cfg(feature = "stream-ctrl")]
    pub fn client_stream_ctrl(&self) -> Option<&Arc<ClientDataStreamCtrl>> {
        let ctrl = self.reader.inner().client_stream_ctrl();
        Some(ctrl)
    }
}
impl AsyncRead for DataReader {
    /// Forward to the wrapped XON/XOFF reader.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<IoResult<usize>> {
        let inner = Pin::new(&mut self.get_mut().reader);
        AsyncRead::poll_read(inner, cx, buf)
    }

    /// Forward vectored reads to the wrapped XON/XOFF reader.
    fn poll_read_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &mut [std::io::IoSliceMut<'_>],
    ) -> Poll<IoResult<usize>> {
        let inner = Pin::new(&mut self.get_mut().reader);
        AsyncRead::poll_read_vectored(inner, cx, bufs)
    }
}
// Tokio-flavoured reads, bridged through `tokio_util`'s compat adapter over
// this type's `futures` `AsyncRead` implementation.
#[cfg(feature = "tokio")]
impl TokioAsyncRead for DataReader {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<IoResult<()>> {
        let mut adapter = self.compat();
        TokioAsyncRead::poll_read(Pin::new(&mut adapter), cx, buf)
    }
}
/// The reader half's state, underneath the XON/XOFF wrapper.
///
/// `poll_read` `take()`s `state` while it runs and puts a value back on every
/// return path, so it is `None` only transiently inside that method.
#[derive(Debug)]
pub(crate) struct DataReaderInner {
    state: Option<DataReaderState>,
    /// Memory-quota account for this stream (not read by the reader).
    _memquota: StreamAccount,
    /// Control handle shared with the writer half and the `DataStream`.
    #[cfg(feature = "stream-ctrl")]
    ctrl: Arc<ClientDataStreamCtrl>,
}
impl BufferIsEmpty for DataReaderInner {
    /// Report whether there is nothing left to read right now.
    ///
    /// A closed reader is always empty. An open reader is empty when every
    /// buffered byte has been consumed and the receiver also reports empty.
    ///
    /// # Panics
    ///
    /// Panics if called while `state` has been taken out.
    fn is_empty(mut self: Pin<&mut Self>) -> bool {
        let state = self
            .state
            .as_mut()
            .expect("forgot to put `DataReaderState` back");
        let DataReaderState::Open(imp) = state else {
            return true;
        };
        let nothing_buffered = imp.pending[imp.offset..].is_empty();
        nothing_buffered && imp.s.is_empty()
    }
}
/// Flags describing what has happened on a data stream, shared between the
/// reader, the writer and the control handle.
#[cfg(feature = "stream-ctrl")]
#[derive(Clone, Debug, Default)]
struct DataStreamStatus {
    /// A CONNECTED message was received, or the stream was created connected.
    received_connected: bool,
    /// The writer was closed and told its `StreamTarget` to close.
    sent_end: bool,
    /// An END message was received.
    received_end: bool,
    /// An error was seen (including an END with a reason other than DONE).
    received_err: bool,
}
#[cfg(feature = "stream-ctrl")]
impl DataStreamStatus {
    /// Mark the stream as having received CONNECTED.
    fn record_connected(&mut self) {
        self.received_connected = true;
    }

    /// Fold error `e` into the status flags.
    ///
    /// Any END sets `received_end`; anything other than a clean END (reason
    /// DONE) also sets `received_err`. Flags are only ever set, never cleared.
    fn record_error(&mut self, e: &Error) {
        let got_end = matches!(e, Error::EndReceived(_));
        let clean_end = matches!(e, Error::EndReceived(EndReason::DONE));
        if got_end {
            self.received_end = true;
        }
        if !clean_end {
            self.received_err = true;
        }
    }
}
// The subset of relay messages accepted on a client data stream: DATA, END
// and CONNECTED. `read_cell` and `consume_checked_msg` decode incoming
// messages into this type, so any other command fails to decode.
restricted_msg! {
    enum ClientDataStreamMsg:RelayMsg {
        Data, End, Connected,
    }
}
#[cfg(feature = "stream-ctrl")]
impl super::ctrl::ClientStreamCtrl for ClientDataStreamCtrl {
    /// Return the tunnel carrying this stream, or `None` if it has been dropped.
    fn tunnel(&self) -> Option<Arc<ClientTunnel>> {
        Weak::upgrade(&self.tunnel)
    }
}
#[cfg(feature = "stream-ctrl")]
impl ClientDataStreamCtrl {
    /// Return true if the stream has become connected and has not since
    /// been closed (END sent or received) or failed.
    ///
    /// # Panics
    ///
    /// Panics if the status lock is poisoned.
    pub fn is_connected(&self) -> bool {
        let status = self.status.lock().expect("poisoned lock");
        let finished = status.sent_end || status.received_end || status.received_err;
        status.received_connected && !finished
    }
}
impl DataStream {
pub(crate) fn new<P: SleepProvider + CoarseTimeProvider>(
time_provider: P,
receiver: StreamReceiver,
xon_xoff_reader_ctrl: XonXoffReaderCtrl,
target: StreamTarget,
memquota: StreamAccount,
) -> Self {
Self::new_inner(
time_provider,
receiver,
xon_xoff_reader_ctrl,
target,
false,
memquota,
)
}
#[cfg(any(feature = "hs-service", feature = "relay"))]
pub(crate) fn new_connected<P: SleepProvider + CoarseTimeProvider>(
time_provider: P,
receiver: StreamReceiver,
xon_xoff_reader_ctrl: XonXoffReaderCtrl,
target: StreamTarget,
memquota: StreamAccount,
) -> Self {
Self::new_inner(
time_provider,
receiver,
xon_xoff_reader_ctrl,
target,
true,
memquota,
)
}
fn new_inner<P: SleepProvider + CoarseTimeProvider>(
time_provider: P,
receiver: StreamReceiver,
xon_xoff_reader_ctrl: XonXoffReaderCtrl,
target: StreamTarget,
connected: bool,
memquota: StreamAccount,
) -> Self {
let relay_cell_format = target.relay_cell_format();
let out_buf_len = Data::max_body_len(relay_cell_format);
let rate_limit_stream = target.rate_limit_stream().clone();
#[cfg(feature = "stream-ctrl")]
let status = {
let mut data_stream_status = DataStreamStatus::default();
if connected {
data_stream_status.record_connected();
}
Arc::new(Mutex::new(data_stream_status))
};
#[cfg(feature = "stream-ctrl")]
let ctrl = {
let tunnel = match target.tunnel() {
crate::stream::Tunnel::Client(t) => Arc::downgrade(t),
#[cfg(feature = "relay")]
crate::stream::Tunnel::Relay(_) => panic!("created a relay tunnel in the client?!"),
};
Arc::new(ClientDataStreamCtrl {
tunnel,
status: status.clone(),
_memquota: memquota.clone(),
})
};
let r = DataReaderInner {
state: Some(DataReaderState::Open(DataReaderImpl {
s: receiver,
pending: Vec::new(),
offset: 0,
connected,
#[cfg(feature = "stream-ctrl")]
status: status.clone(),
})),
_memquota: memquota.clone(),
#[cfg(feature = "stream-ctrl")]
ctrl: ctrl.clone(),
};
let w = DataWriterInner {
state: Some(DataWriterState::Ready(DataWriterImpl {
s: target,
buf: vec![0; out_buf_len].into_boxed_slice(),
n_pending: 0,
#[cfg(feature = "stream-ctrl")]
status,
relay_cell_format,
})),
_memquota: memquota,
#[cfg(feature = "stream-ctrl")]
ctrl: ctrl.clone(),
};
let time_provider = DynTimeProvider::new(time_provider);
DataStream {
w: DataWriter::new(w, rate_limit_stream, time_provider),
r: DataReader::new(r, xon_xoff_reader_ctrl),
#[cfg(feature = "stream-ctrl")]
ctrl,
}
}
pub fn split(self) -> (DataReader, DataWriter) {
(self.r, self.w)
}
pub async fn wait_for_connection(&mut self) -> Result<()> {
let state = self
.r
.reader
.inner_mut()
.state
.take()
.expect("Missing state in DataReaderInner");
if let DataReaderState::Open(mut imp) = state {
let result = if imp.connected {
Ok(())
} else {
std::future::poll_fn(|cx| Pin::new(&mut imp).read_cell(cx)).await
};
self.r.reader.inner_mut().state = Some(match result {
Err(_) => DataReaderState::Closed,
Ok(_) => DataReaderState::Open(imp),
});
result
} else {
Err(Error::from(internal!(
"Expected ready state, got {:?}",
state
)))
}
}
#[cfg(feature = "stream-ctrl")]
pub fn client_stream_ctrl(&self) -> Option<&Arc<ClientDataStreamCtrl>> {
Some(&self.ctrl)
}
}
impl AsyncRead for DataStream {
    /// Read from the reader half.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<IoResult<usize>> {
        AsyncRead::poll_read(Pin::new(&mut self.r), cx, buf)
    }

    /// Vectored read from the reader half.
    ///
    /// Forwarded explicitly (as `DataReader` itself does) rather than relying
    /// on the trait's default, so a `DataStream` and its split-off reader
    /// behave the same.
    fn poll_read_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &mut [std::io::IoSliceMut<'_>],
    ) -> Poll<IoResult<usize>> {
        AsyncRead::poll_read_vectored(Pin::new(&mut self.r), cx, bufs)
    }
}
// Tokio-flavoured reads via `tokio_util`'s compat adapter.
#[cfg(feature = "tokio")]
impl TokioAsyncRead for DataStream {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<IoResult<()>> {
        let mut adapter = self.compat();
        TokioAsyncRead::poll_read(Pin::new(&mut adapter), cx, buf)
    }
}
impl AsyncWrite for DataStream {
    /// Write through the writer half.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<IoResult<usize>> {
        let half = Pin::new(&mut self.get_mut().w);
        AsyncWrite::poll_write(half, cx, buf)
    }

    /// Flush the writer half.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        let half = Pin::new(&mut self.get_mut().w);
        AsyncWrite::poll_flush(half, cx)
    }

    /// Close the writer half.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        let half = Pin::new(&mut self.get_mut().w);
        AsyncWrite::poll_close(half, cx)
    }
}
// Tokio-flavoured writes via `tokio_util`'s compat adapter.
#[cfg(feature = "tokio")]
impl TokioAsyncWrite for DataStream {
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
        let mut adapter = self.compat();
        TokioAsyncWrite::poll_write(Pin::new(&mut adapter), cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        let mut adapter = self.compat();
        TokioAsyncWrite::poll_flush(Pin::new(&mut adapter), cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        let mut adapter = self.compat();
        TokioAsyncWrite::poll_shutdown(Pin::new(&mut adapter), cx)
    }
}
/// A pinned, boxed future that is both `Send` and `Sync`.
type BoxSyncFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + Sync + 'a>>;
/// State of the writer half.
#[derive(Educe)]
#[educe(Debug)]
enum DataWriterState {
    /// The writer was closed, or a send failed; further writes and flushes
    /// return `NotConnected`.
    Closed,
    /// The writer is idle and can queue more bytes.
    Ready(DataWriterImpl),
    /// A send of the buffered bytes is in progress. The future hands the
    /// writer back together with the send's result.
    Flushing(
        #[educe(Debug(method = "skip_fmt"))] BoxSyncFuture<'static, (DataWriterImpl, Result<()>)>,
    ),
}
/// Writer state while not flushing: the outgoing buffer and where to send it.
#[derive(Educe)]
#[educe(Debug)]
struct DataWriterImpl {
    /// Destination for outgoing DATA messages.
    s: StreamTarget,
    /// Outgoing byte buffer, allocated at the maximum DATA body length for
    /// `relay_cell_format`.
    #[educe(Debug(method = "skip_fmt"))]
    buf: Box<[u8]>,
    /// Number of bytes at the front of `buf` that are queued but not yet sent.
    n_pending: usize,
    /// Relay cell format used when encoding DATA messages.
    relay_cell_format: RelayCellFormat,
    /// Status shared with the control handle and the reader.
    #[cfg(feature = "stream-ctrl")]
    status: Arc<Mutex<DataStreamStatus>>,
}
impl DataWriterInner {
    /// Return the control handle shared with the reader half.
    #[cfg(feature = "stream-ctrl")]
    fn client_stream_ctrl(&self) -> &Arc<ClientDataStreamCtrl> {
        &self.ctrl
    }

    /// Shared implementation of `poll_flush` and `poll_close`.
    ///
    /// Sends any queued bytes. If `should_close` is true, then once the
    /// buffer is sent the `StreamTarget` is told to close and the writer
    /// moves to `Closed` (with `stream-ctrl`, `sent_end` is also set).
    ///
    /// If sending fails, the writer moves to `Closed`. With `stream-ctrl`,
    /// the error is also recorded in the shared status, matching
    /// `poll_write`. Otherwise `is_connected()` would keep reporting true
    /// after a failed flush.
    ///
    /// # Errors
    ///
    /// `NotConnected` if the writer is already closed; otherwise any error
    /// from sending the buffered DATA message.
    fn poll_flush_impl(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        should_close: bool,
    ) -> Poll<IoResult<()>> {
        // Taken out here; every return path below puts a state back.
        let state = self.state.take().expect("Missing state in DataWriter");
        let mut future: BoxSyncFuture<_> = match state {
            DataWriterState::Ready(imp) => {
                if imp.n_pending == 0 {
                    if should_close {
                        // Nothing to send, but we still need to run the
                        // closing logic below: use an already-ready future.
                        Box::pin(futures::future::ready((imp, Ok(()))))
                    } else {
                        // Nothing to send and nothing to close.
                        self.state = Some(DataWriterState::Ready(imp));
                        return Poll::Ready(Ok(()));
                    }
                } else {
                    Box::pin(imp.flush_buf())
                }
            }
            DataWriterState::Flushing(fut) => fut,
            DataWriterState::Closed => {
                self.state = Some(DataWriterState::Closed);
                return Poll::Ready(Err(Error::NotConnected.into()));
            }
        };
        match future.as_mut().poll(cx) {
            Poll::Ready((_imp, Err(e))) => {
                #[cfg(feature = "stream-ctrl")]
                {
                    _imp.status.lock().expect("lock poisoned").record_error(&e);
                }
                self.state = Some(DataWriterState::Closed);
                Poll::Ready(Err(e.into()))
            }
            Poll::Ready((mut imp, Ok(()))) => {
                if should_close {
                    // Explicitly tell the target to close rather than relying
                    // on dropping `imp.s`.
                    imp.s.close();
                    #[cfg(feature = "stream-ctrl")]
                    {
                        imp.status.lock().expect("lock poisoned").sent_end = true;
                    }
                    self.state = Some(DataWriterState::Closed);
                } else {
                    self.state = Some(DataWriterState::Ready(imp));
                }
                Poll::Ready(Ok(()))
            }
            Poll::Pending => {
                // Keep the in-progress send so the next poll resumes it.
                self.state = Some(DataWriterState::Flushing(future));
                Poll::Pending
            }
        }
    }
}
impl AsyncWrite for DataWriterInner {
    /// Queue bytes from `buf` into the outgoing buffer.
    ///
    /// Bytes are only queued here; a DATA message is sent only when the
    /// buffer is full (then the buffer is flushed before queueing more) or
    /// when the caller flushes/closes. Returns how many bytes were queued.
    ///
    /// # Errors
    ///
    /// `NotConnected` if the writer is closed; otherwise any error from
    /// sending a full buffer, after which the writer is closed.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<IoResult<usize>> {
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        // Taken out here; every return path below puts a state back.
        let state = self.state.take().expect("Missing state in DataWriter");
        let mut future = match state {
            DataWriterState::Ready(mut imp) => {
                let n_queued = imp.queue_bytes(buf);
                if n_queued != 0 {
                    self.state = Some(DataWriterState::Ready(imp));
                    return Poll::Ready(Ok(n_queued));
                }
                // The buffer is full: send it before queueing anything else.
                Box::pin(imp.flush_buf())
            }
            DataWriterState::Flushing(fut) => fut,
            DataWriterState::Closed => {
                self.state = Some(DataWriterState::Closed);
                return Poll::Ready(Err(Error::NotConnected.into()));
            }
        };
        match future.as_mut().poll(cx) {
            Poll::Ready((_imp, Err(e))) => {
                // `_imp` is only read when `stream-ctrl` is enabled.
                #[cfg(feature = "stream-ctrl")]
                {
                    _imp.status.lock().expect("lock poisoned").record_error(&e);
                }
                self.state = Some(DataWriterState::Closed);
                Poll::Ready(Err(e.into()))
            }
            Poll::Ready((mut imp, Ok(()))) => {
                // The buffer is now empty; queue from the caller's `buf`.
                let n_queued = imp.queue_bytes(buf);
                self.state = Some(DataWriterState::Ready(imp));
                Poll::Ready(Ok(n_queued))
            }
            Poll::Pending => {
                self.state = Some(DataWriterState::Flushing(future));
                Poll::Pending
            }
        }
    }
    /// Send any queued bytes without closing.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        self.poll_flush_impl(cx, false)
    }
    /// Send any queued bytes, then close the stream target.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        self.poll_flush_impl(cx, true)
    }
}
// Tokio-flavoured writes via `tokio_util`'s compat adapter.
#[cfg(feature = "tokio")]
impl TokioAsyncWrite for DataWriterInner {
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
        let mut adapter = self.compat_write();
        TokioAsyncWrite::poll_write(Pin::new(&mut adapter), cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        let mut adapter = self.compat_write();
        TokioAsyncWrite::poll_flush(Pin::new(&mut adapter), cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        let mut adapter = self.compat_write();
        TokioAsyncWrite::poll_shutdown(Pin::new(&mut adapter), cx)
    }
}
impl DataWriterImpl {
    /// Send the queued bytes as a single DATA message.
    ///
    /// Consumes the writer and hands it back with the send result, so it can
    /// be stored inside a boxed future. If nothing could be split off into a
    /// DATA message, this resolves to `Ok(())` without sending.
    ///
    /// # Panics
    ///
    /// Panics if the queued bytes do not fit in one DATA message.
    async fn flush_buf(mut self) -> (Self, Result<()>) {
        let cell = {
            let queued = &self.buf[..self.n_pending];
            match Data::try_split_from(self.relay_cell_format, queued) {
                Some((cell, remainder)) => {
                    // The buffer is sized to one DATA body, so nothing may remain.
                    assert!(remainder.is_empty());
                    cell
                }
                None => return (self, Ok(())),
            }
        };
        self.n_pending = 0;
        let outcome = self.s.send(cell.into()).await;
        (self, outcome)
    }

    /// Copy as much of `b` as fits into the free tail of the buffer.
    ///
    /// Returns the number of bytes queued (0 when the buffer is full).
    fn queue_bytes(&mut self, b: &[u8]) -> usize {
        let free = &mut self.buf[self.n_pending..];
        let count = free.len().min(b.len());
        if count == 0 {
            return 0;
        }
        free[..count].copy_from_slice(&b[..count]);
        self.n_pending += count;
        count
    }
}
impl DataReaderInner {
    /// Borrow the control handle shared with the writer half.
    #[cfg(feature = "stream-ctrl")]
    pub(crate) fn client_stream_ctrl(&self) -> &Arc<ClientDataStreamCtrl> {
        let Self { ctrl, .. } = self;
        ctrl
    }
}
/// State of the reader half.
#[derive(Educe)]
#[educe(Debug)]
// NOTE(review): presumably allowed because `Open` is much larger than
// `Closed`; confirm boxing the variant is not preferable.
#[allow(clippy::large_enum_variant)]
enum DataReaderState {
    /// The stream ended or failed; reads return `NotConnected`.
    Closed,
    /// The stream is open for reading.
    Open(DataReaderImpl),
}
/// State of an open reader: the message source and a buffer of received bytes.
#[derive(Educe)]
#[educe(Debug)]
#[pin_project]
struct DataReaderImpl {
    /// Source of incoming relay messages for this stream.
    #[educe(Debug(method = "skip_fmt"))]
    #[pin]
    s: StreamReceiver,
    /// Bytes received in DATA messages; those before `offset` have already
    /// been returned to a reader.
    #[educe(Debug(method = "skip_fmt"))]
    pending: Vec<u8>,
    /// Index of the first unread byte in `pending`.
    offset: usize,
    /// Whether CONNECTED has been received (or the stream was created
    /// connected). DATA messages are rejected until this is true.
    connected: bool,
    /// Status shared with the control handle and the writer.
    #[cfg(feature = "stream-ctrl")]
    status: Arc<Mutex<DataStreamStatus>>,
}
impl AsyncRead for DataReaderInner {
    /// Read buffered bytes into `buf`, pulling more messages as needed.
    ///
    /// A clean END (reason DONE) yields `Ok(0)`. Other errors are returned.
    /// In both cases the reader moves to `Closed`.
    ///
    /// NOTE(review): because `Closed` returns `NotConnected`, EOF is reported
    /// only once. Reads after a clean END get an error rather than `Ok(0)`
    /// again.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<IoResult<usize>> {
        // Taken out here; every return path below puts a state back.
        let mut state = self.state.take().expect("Missing state in DataReaderInner");
        loop {
            let mut imp = match state {
                DataReaderState::Open(mut imp) => {
                    // Serve already-buffered data first.
                    let n_copied = imp.extract_bytes(buf);
                    if n_copied != 0 || buf.is_empty() {
                        self.state = Some(DataReaderState::Open(imp));
                        return Poll::Ready(Ok(n_copied));
                    }
                    // Buffer exhausted: fall through to read another message.
                    imp
                }
                DataReaderState::Closed => {
                    self.state = Some(DataReaderState::Closed);
                    return Poll::Ready(Err(Error::NotConnected.into()));
                }
            };
            match Pin::new(&mut imp).read_cell(cx) {
                Poll::Ready(Err(e)) => {
                    // No error is recoverable: close the reader.
                    self.state = Some(DataReaderState::Closed);
                    #[cfg(feature = "stream-ctrl")]
                    {
                        imp.status.lock().expect("lock poisoned").record_error(&e);
                    }
                    let result = if matches!(e, Error::EndReceived(EndReason::DONE)) {
                        Ok(0)
                    } else {
                        Err(e.into())
                    };
                    return Poll::Ready(result);
                }
                Poll::Ready(Ok(())) => {
                    // A message was handled; loop to serve any new data.
                    state = DataReaderState::Open(imp);
                }
                Poll::Pending => {
                    self.state = Some(DataReaderState::Open(imp));
                    return Poll::Pending;
                }
            }
        }
    }
}
// Tokio-flavoured reads via `tokio_util`'s compat adapter.
#[cfg(feature = "tokio")]
impl TokioAsyncRead for DataReaderInner {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<IoResult<()>> {
        let mut adapter = self.compat();
        TokioAsyncRead::poll_read(Pin::new(&mut adapter), cx, buf)
    }
}
impl DataReaderImpl {
    /// Copy as many unread buffered bytes as fit into `buf`.
    ///
    /// Advances `offset` past the copied bytes and returns how many were
    /// copied (0 if nothing is buffered or `buf` is empty).
    fn extract_bytes(&mut self, buf: &mut [u8]) -> usize {
        let remainder = &self.pending[self.offset..];
        let n_to_copy = std::cmp::min(buf.len(), remainder.len());
        buf[..n_to_copy].copy_from_slice(&remainder[..n_to_copy]);
        self.offset += n_to_copy;
        n_to_copy
    }

    /// Return true if every buffered byte has already been read.
    fn buf_is_empty(&self) -> bool {
        self.pending.len() == self.offset
    }

    /// Poll for and handle one incoming message.
    ///
    /// - CONNECTED (first time): marks the stream connected.
    /// - DATA (once connected): appends its bytes to the buffer.
    /// - END: returned as `Error::EndReceived` with the END's reason.
    ///
    /// # Errors
    ///
    /// - A second CONNECTED, or DATA before CONNECTED, is a protocol error
    ///   (and `protocol_error` is signalled on the receiver).
    /// - An undecodable message.
    /// - A receiver error, or `NotConnected` if the receiver has ended.
    fn read_cell(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        use ClientDataStreamMsg::*;
        let msg = match self.as_mut().project().s.poll_next(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Some(Ok(unparsed))) => match unparsed.decode::<ClientDataStreamMsg>() {
                Ok(cell) => cell.into_msg(),
                Err(e) => {
                    self.s.protocol_error();
                    return Poll::Ready(Err(Error::from_bytes_err(e, "message on a data stream")));
                }
            },
            Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)),
            Poll::Ready(None) => return Poll::Ready(Err(Error::NotConnected)),
        };
        let result = match msg {
            Connected(_) if !self.connected => {
                self.connected = true;
                #[cfg(feature = "stream-ctrl")]
                {
                    self.status
                        .lock()
                        .expect("poisoned lock")
                        .record_connected();
                }
                Ok(())
            }
            Connected(_) => {
                self.s.protocol_error();
                Err(Error::StreamProto(
                    "Received a second connect cell on a data stream".to_string(),
                ))
            }
            Data(d) if self.connected => {
                self.add_data(d.into());
                Ok(())
            }
            Data(_) => {
                self.s.protocol_error();
                Err(Error::StreamProto(
                    "Received a data cell on an unconnected stream".to_string(),
                ))
            }
            End(e) => Err(Error::EndReceived(e.reason())),
        };
        Poll::Ready(result)
    }

    /// Append newly received bytes to the buffer.
    ///
    /// If everything buffered has been read, `d` replaces the buffer outright
    /// (no copy). Otherwise the already-read prefix is discarded before
    /// appending, so `pending` never keeps consumed bytes alive. Either way
    /// `offset` ends at 0 and the unread bytes are unchanged.
    fn add_data(&mut self, mut d: Vec<u8>) {
        if self.buf_is_empty() {
            self.pending = d;
        } else {
            self.pending.drain(..self.offset);
            self.pending.append(&mut d);
        }
        self.offset = 0;
    }
}
/// Command checker for outbound data streams (streams we opened).
///
/// Accepts exactly one CONNECTED, then DATA, and END at any time; see the
/// `CmdChecker` impl.
#[derive(Debug)]
pub(crate) struct OutboundDataCmdChecker {
    /// True until a CONNECTED message has been accepted.
    expecting_connected: bool,
}
impl Default for OutboundDataCmdChecker {
    /// A fresh checker that has not yet seen CONNECTED.
    fn default() -> Self {
        let expecting_connected = true;
        Self { expecting_connected }
    }
}
impl CmdChecker for OutboundDataCmdChecker {
fn check_msg(&mut self, msg: &tor_cell::relaycell::UnparsedRelayMsg) -> Result<StreamStatus> {
use StreamStatus::*;
match msg.cmd() {
RelayCmd::CONNECTED => {
if !self.expecting_connected {
Err(Error::StreamProto(
"Received CONNECTED twice on a stream.".into(),
))
} else {
self.expecting_connected = false;
Ok(Open)
}
}
RelayCmd::DATA => {
if !self.expecting_connected {
Ok(Open)
} else {
Err(Error::StreamProto(
"Received DATA before CONNECTED on a stream".into(),
))
}
}
RelayCmd::END => Ok(Closed),
_ => Err(Error::StreamProto(format!(
"Unexpected {} on a data stream!",
msg.cmd()
))),
}
}
fn consume_checked_msg(&mut self, msg: tor_cell::relaycell::UnparsedRelayMsg) -> Result<()> {
let _ = msg
.decode::<ClientDataStreamMsg>()
.map_err(|err| Error::from_bytes_err(err, "cell on half-closed stream"))?;
Ok(())
}
}
impl OutboundDataCmdChecker {
    /// Return a fresh checker, boxed as an `AnyCmdChecker`.
    pub(crate) fn new_any() -> AnyCmdChecker {
        Box::new(Self::default())
    }
}