use std::pin::Pin;
use std::task::{self, Context, Poll};
use futures::{FutureExt as _, SinkExt as _, Stream, StreamExt as _};
use oneshot_fused_workaround as oneshot;
use tor_basic_utils::assert_val_impl_trait;
use tor_cell::chancell::msg::{AnyChanMsg, Destroy};
use tor_memquota::mq_queue::{self, ChannelSpec, MpscSpec};
/// Sending half of the pair returned by [`channel`].
///
/// Wraps an optional [`CircuitRxSenderInner`]. The inner state is taken (leaving
/// `None`) when a `Destroy` message is passed to [`CircuitRxSender::send`]; every
/// later `send` then fails with [`SendError::Closed`].
#[derive(Debug)]
pub(crate) struct CircuitRxSender(Option<CircuitRxSenderInner>);
/// The live state of a [`CircuitRxSender`] that has not yet sent a `Destroy`.
#[derive(Debug)]
struct CircuitRxSenderInner {
/// Oneshot used to deliver the single `Destroy` message, separately from the
/// cell queue.
destroy_tx: oneshot::Sender<Destroy>,
/// Queue sender used for every message that is not a `Destroy`.
cell_tx: mq_queue::Sender<AnyChanMsg, MpscSpec>,
}
/// Receiving half of the pair returned by [`channel`].
///
/// Implements `Stream<Item = AnyChanMsg>`. The inner state is replaced with `None`
/// once a `Destroy` message has been yielded; from then on the stream only
/// yields `None`.
#[derive(Debug)]
pub(crate) struct CircuitRxReceiver(Option<CircuitRxReceiverInner>);
/// The live state of a [`CircuitRxReceiver`] that has not yet yielded a `Destroy`.
#[derive(Debug)]
struct CircuitRxReceiverInner {
/// Oneshot on which the `Destroy` message arrives; polled before `cell_rx`
/// on every `poll_next` call.
destroy_rx: oneshot::Receiver<Destroy>,
/// Queue receiver for every message that is not a `Destroy`.
cell_rx: mq_queue::Receiver<AnyChanMsg, MpscSpec>,
}
/// Wrap an existing cell queue into a [`CircuitRxSender`] / [`CircuitRxReceiver`] pair.
///
/// A fresh oneshot channel is created alongside the supplied queue halves; it is
/// the path `Destroy` messages take, while all other messages keep using
/// `cell_tx` / `cell_rx`.
pub(crate) fn channel(
    cell_tx: mq_queue::Sender<AnyChanMsg, MpscSpec>,
    cell_rx: mq_queue::Receiver<AnyChanMsg, MpscSpec>,
) -> (CircuitRxSender, CircuitRxReceiver) {
    let (destroy_tx, destroy_rx) = oneshot::channel();

    (
        CircuitRxSender(Some(CircuitRxSenderInner { destroy_tx, cell_tx })),
        CircuitRxReceiver(Some(CircuitRxReceiverInner { destroy_rx, cell_rx })),
    )
}
impl Stream for CircuitRxReceiver {
    type Item = AnyChanMsg;

    /// Yield the next message for the circuit.
    ///
    /// The destroy oneshot is checked before the cell queue, so a `Destroy`
    /// is returned ahead of any cells still waiting in the queue. After a
    /// `Destroy` has been yielded the stream is finished and returns `None`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // No inner state left: a `Destroy` was already handed out.
        let Some(inner) = self.0.as_mut() else {
            return Poll::Ready(None);
        };

        // `destroy_rx` is polled again on later calls even if it already
        // resolved with an error, so it has to be a `FusedFuture`.
        assert_val_impl_trait!(inner.destroy_rx, futures_util::future::FusedFuture);

        // Only a successfully received `Destroy` short-circuits; a pending or
        // failed oneshot falls through to the cell queue.
        if let Poll::Ready(Ok(destroy)) = inner.destroy_rx.poll_unpin(cx) {
            self.0 = None;
            return Poll::Ready(Some(AnyChanMsg::Destroy(destroy)));
        }

        let next_cell = task::ready!(inner.cell_rx.poll_next_unpin(cx));
        // `CircuitRxSender::send` routes `Destroy` over the oneshot, so one
        // showing up on the cell queue would be a bug.
        debug_assert!(!matches!(next_cell, Some(AnyChanMsg::Destroy(_))));
        Poll::Ready(next_cell)
    }
}
/// Error returned by [`CircuitRxSender::send`].
#[derive(thiserror::Error, Clone, Debug)]
pub(crate) enum SendError {
/// The underlying cell queue rejected a non-`Destroy` message.
#[error("{0}")]
Channel(#[from] mq_queue::SendError<<MpscSpec as ChannelSpec>::SendError>),
/// Sending a `Destroy` over the oneshot failed.
#[error("the receiver has dropped")]
Disconnected,
/// The sender no longer has its inner state, which `send` takes when it is
/// given a `Destroy` message.
#[error("sender has closed")]
Closed,
}
impl CircuitRxSender {
    /// Send `msg` to the paired [`CircuitRxReceiver`].
    ///
    /// A `Destroy` message goes over the dedicated oneshot and consumes this
    /// sender's inner state, so any later call fails with
    /// [`SendError::Closed`]. Every other message is pushed onto the cell
    /// queue, waiting for the queue to accept it.
    ///
    /// # Errors
    ///
    /// * [`SendError::Closed`] if a `Destroy` was already sent.
    /// * [`SendError::Disconnected`] if the `Destroy` oneshot send fails.
    /// * [`SendError::Channel`] if the cell queue reports an error.
    pub(crate) async fn send(&mut self, msg: AnyChanMsg) -> Result<(), SendError> {
        match msg {
            AnyChanMsg::Destroy(destroy) => {
                // `inner` is kept whole until the end of this arm, so the
                // cell queue sender is dropped only after the oneshot send.
                let inner = self.take_inner()?;
                inner
                    .destroy_tx
                    .send(destroy)
                    .map_err(|_| SendError::Disconnected)
            }
            other => {
                let inner = self.borrow_for_sending()?;
                inner.cell_tx.send(other).await?;
                Ok(())
            }
        }
    }

    /// Borrow the inner state, or fail with [`SendError::Closed`] if it is gone.
    fn borrow_for_sending(&mut self) -> Result<&mut CircuitRxSenderInner, SendError> {
        self.0.as_mut().ok_or(SendError::Closed)
    }

    /// Move the inner state out (leaving `None` behind), or fail with
    /// [`SendError::Closed`] if it is already gone.
    fn take_inner(&mut self) -> Result<CircuitRxSenderInner, SendError> {
        self.0.take().ok_or(SendError::Closed)
    }
}
/// Tests for the `CircuitRxSender` / `CircuitRxReceiver` pair, plus a
/// `fake_mpsc` helper that is `pub(crate)`.
#[cfg(test)]
pub(crate) mod test {
// Lint relaxations for test code (unwrap, dbg!, printing, and similar).
#![allow(clippy::bool_assert_comparison)]
#![allow(clippy::clone_on_copy)]
#![allow(clippy::dbg_macro)]
#![allow(clippy::mixed_attributes_style)]
#![allow(clippy::print_stderr)]
#![allow(clippy::print_stdout)]
#![allow(clippy::single_char_pattern)]
#![allow(clippy::unwrap_used)]
#![allow(clippy::unchecked_time_subtraction)]
#![allow(clippy::useless_vec)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::string_slice)]
use super::*;
use tor_cell::chancell::msg::{self, DestroyReason};
use tor_rtmock::MockRuntime;
use std::task::Waker;
/// Build a sender/receiver pair on top of the queue returned by
/// `crate::fake_mpsc(buffer)`.
// NOTE(review): `buffer` is presumably the queue capacity; confirm against
// `crate::fake_mpsc`.
#[cfg(test)]
pub(crate) fn fake_mpsc(buffer: usize) -> (CircuitRxSender, CircuitRxReceiver) {
let (tx, rx) = crate::fake_mpsc(buffer);
crate::circuit::circ_sender::channel(tx, rx)
}
/// Build a `Destroy` channel message carrying `reason`.
fn destroy_msg(reason: DestroyReason) -> AnyChanMsg {
AnyChanMsg::Destroy(msg::Destroy::new(reason))
}
/// Build a `Relay` channel message with the fixed body `b"hello"`.
fn relay_msg() -> AnyChanMsg {
AnyChanMsg::Relay(msg::Relay::new(b"hello"))
}
/// Assert that the pair is finished: the receiver yields `None` and a
/// further relay send fails with `SendError::Closed`.
///
/// Must be used inside an async context, since it awaits on both halves.
macro_rules! assert_eos {
($tx:expr, $rx:expr) => {{
assert!($rx.next().await.is_none());
let err = $tx.send(relay_msg()).await.unwrap_err();
assert!(matches!(err, SendError::Closed));
}};
}
/// Buffer size passed to `fake_mpsc` by every test below.
const BUFFER_SIZE: usize = 16;
/// A `Destroy` sent after a relay message is received first, ahead of the
/// relay message that is still queued.
#[test]
fn destroy_skips_queue() {
MockRuntime::test_with_various(|_rt| async move {
let (mut tx, mut rx) = fake_mpsc(BUFFER_SIZE);
tx.send(relay_msg()).await.unwrap();
tx.send(destroy_msg(DestroyReason::HIBERNATING))
.await
.unwrap();
// The very first item is the Destroy, not the earlier relay message.
let destroy = rx.next().await.unwrap();
assert!(matches!(destroy, AnyChanMsg::Destroy(_)));
assert_eos!(tx, rx);
});
}
/// A `Destroy` sent with nothing else queued is received, then the pair is
/// finished.
#[test]
fn destroy_on_empty_queue() {
MockRuntime::test_with_various(|_rt| async move {
let (mut tx, mut rx) = fake_mpsc(BUFFER_SIZE);
tx.send(destroy_msg(DestroyReason::HIBERNATING))
.await
.unwrap();
let destroy = rx.next().await.unwrap();
assert!(matches!(destroy, AnyChanMsg::Destroy(_)));
assert_eos!(tx, rx);
});
}
/// A `Destroy` sent after all queued data was consumed, and after the
/// receiver has already been polled once while idle, is still received.
#[test]
fn destroy_after_data() {
MockRuntime::test_with_various(|_rt| async move {
let (mut tx, mut rx) = fake_mpsc(BUFFER_SIZE);
for _ in 0..3 {
tx.send(relay_msg()).await.unwrap();
}
for _ in 0..3 {
let data = rx.next().await.unwrap();
assert!(matches!(data, AnyChanMsg::Relay(_)));
}
// Poll once with a no-op waker: nothing is queued and no Destroy has been
// sent, so the receiver must report Pending.
let mut noop_cx = Context::from_waker(Waker::noop());
assert!(rx.poll_next_unpin(&mut noop_cx).is_pending());
tx.send(destroy_msg(DestroyReason::PROTOCOL)).await.unwrap();
let destroy = rx.next().await.unwrap();
assert!(matches!(destroy, AnyChanMsg::Destroy(_)));
assert_eos!(tx, rx);
});
}
/// A `Destroy` can be sent and received even when a relay send would not
/// complete immediately.
#[test]
fn destroy_full_queue() {
MockRuntime::test_with_various(|_rt| async move {
let (mut tx, mut rx) = fake_mpsc(BUFFER_SIZE);
// Keep sending relay messages until a single poll of `send` returns
// Pending; that unfinished send future is then dropped.
loop {
let fut = Box::pin(tx.send(relay_msg()));
match futures::poll!(fut) {
Poll::Pending => {
break;
}
Poll::Ready(res) => {
let () = res.unwrap();
}
}
}
// The Destroy send still completes, and it is the first item received.
tx.send(destroy_msg(DestroyReason::INTERNAL)).await.unwrap();
let destroy = rx.next().await.unwrap();
assert!(matches!(destroy, AnyChanMsg::Destroy(_)));
assert_eos!(tx, rx);
});
}
}