use std::{
pin::{Pin, pin},
task::{Context, Poll},
};
use cfg_if::cfg_if;
use futures::Sink;
use pin_project::pin_project;
use tor_rtcompat::DynTimeProvider;
use tracing::instrument;
use crate::{
HopNum,
channel::{ChanCellQueueEntry, ChannelSender},
congestion::CongestionSignals,
util::{SinkExt, sometimes_unbounded_sink::SometimesUnboundedSink},
};
cfg_if! {
if #[cfg(feature="circ-padding")] {
use crate::util::sink_blocker::{BooleanPolicy, SinkBlocker};
type InnerSink = SinkBlocker<
SometimesUnbounded, BooleanPolicy,
>;
type SometimesUnbounded = SometimesUnboundedSink<
ChanCellQueueEntry,
ChannelSender
>;
} else {
type InnerSink = SometimesUnboundedSink<ChanCellQueueEntry, ChannelSender>;
type SometimesUnbounded = InnerSink;
}
}
/// A sink for `ChanCellQueueEntry` items, built on top of a `ChannelSender`.
///
/// The wrapped sink is structurally pinned (via `pin_project`) so that the
/// `Sink` implementation can forward to it through a pinned projection.
#[pin_project]
pub(crate) struct CircuitCellSender {
/// The wrapped sink that all cells are sent through.
///
/// Its concrete type is the `InnerSink` alias, which depends on whether
/// the `circ-padding` feature is enabled.
#[pin]
sink: InnerSink,
}
impl CircuitCellSender {
pub(crate) fn from_channel_sender(inner: ChannelSender) -> Self {
cfg_if! {
if #[cfg(feature="circ-padding")] {
let sink = SinkBlocker::new(
SometimesUnboundedSink::new(
inner
),
BooleanPolicy::Unblocked
);
} else {
let sink = SometimesUnboundedSink::new(inner);
}
}
Self { sink }
}
pub(crate) fn n_queued(&self) -> usize {
self.sometimes_unbounded().n_queued()
}
#[cfg(feature = "circ-padding")]
pub(crate) fn have_queued_cell_for_hop_or_later(&self, hop: HopNum) -> bool {
if hop.is_first_hop() && self.chan_sender().approx_count() > 0 {
return true;
}
self.sometimes_unbounded()
.iter_queue()
.any(|(_, info)| info.is_some_and(|inf| inf.target_hop >= hop))
}
#[instrument(level = "trace", skip_all)]
pub(crate) async fn send_unbounded(&mut self, entry: ChanCellQueueEntry) -> crate::Result<()> {
Pin::new(self.sometimes_unbounded_mut())
.send_unbounded(entry)
.await?;
self.chan_sender().note_cell_queued();
Ok(())
}
pub(crate) fn time_provider(&self) -> &DynTimeProvider {
self.chan_sender().time_provider()
}
#[cfg(feature = "circ-padding")]
pub(crate) fn start_blocking(&mut self) {
self.pre_queue_blocker_mut().set_blocked();
}
#[cfg(feature = "circ-padding")]
pub(crate) fn stop_blocking(&mut self) {
self.pre_queue_blocker_mut().set_unblocked();
}
#[instrument(level = "trace", skip_all)]
pub(crate) async fn congestion_signals(&mut self) -> CongestionSignals {
futures::future::poll_fn(|cx| -> Poll<CongestionSignals> {
let channel_ready = self
.chan_sender_mut()
.poll_ready_unpin_bool(cx)
.unwrap_or(false);
Poll::Ready(CongestionSignals::new(
!channel_ready,
self.n_queued(),
))
})
.await
}
fn sometimes_unbounded(&self) -> &SometimesUnbounded {
cfg_if! {
if #[cfg(feature="circ-padding")] {
self.sink.as_inner()
} else {
&self.sink
}
}
}
fn sometimes_unbounded_mut(&mut self) -> &mut SometimesUnbounded {
cfg_if! {
if #[cfg(feature="circ-padding")] {
self.sink.as_inner_mut()
} else {
&mut self.sink
}
}
}
fn chan_sender(&self) -> &ChannelSender {
cfg_if! {
if #[cfg(feature="circ-padding")] {
self.sink.as_inner().as_inner()
} else {
self.sink.as_inner()
}
}
}
fn chan_sender_mut(&mut self) -> &mut ChannelSender {
cfg_if! {
if #[cfg(feature="circ-padding")] {
self.sink.as_inner_mut().as_inner_mut()
} else {
self.sink.as_inner_mut()
}
}
}
#[cfg(feature = "circ-padding")]
fn pre_queue_blocker_mut(&mut self) -> &mut InnerSink {
&mut self.sink
}
}
impl Sink<ChanCellQueueEntry> for CircuitCellSender {
    /// Errors are exactly those produced by the underlying `ChannelSender`.
    type Error = <ChannelSender as Sink<ChanCellQueueEntry>>::Error;

    /// Report whether the wrapped sink can accept another entry.
    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // With circuit padding, the wrapped sink is the blocker rather than
        // the queueing layer, so the queueing layer is polled separately
        // here and its result discarded.
        // NOTE(review): presumably this lets the queue make progress even
        // while the blocker is refusing new entries -- confirm.
        #[cfg(feature = "circ-padding")]
        {
            let _ignore = pin!(self.sometimes_unbounded_mut()).poll_ready(cx);
        }

        let this = self.project();
        this.sink.poll_ready(cx)
    }

    /// Hand `item` to the wrapped sink and, if that succeeds, tell the
    /// channel sender that a cell was queued.
    fn start_send(mut self: Pin<&mut Self>, item: ChanCellQueueEntry) -> Result<(), Self::Error> {
        let this = self.as_mut().project();
        this.sink.start_send(item)?;

        self.chan_sender().note_cell_queued();
        Ok(())
    }

    /// Flush the wrapped sink.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.project();
        this.sink.poll_flush(cx)
    }

    /// Close the wrapped sink.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.project();
        this.sink.poll_close(cx)
    }
}