use futures_channel::mpsc;
use futures_channel::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender};
use futures_channel::oneshot;
use futures_util::sink::SinkExt;
use crate::channels::{
register_channel, send_channel_event, ChannelEvent, ChannelType, Instant, RegisteredChannel, RT,
};
/// Instruments a bounded `futures_channel::mpsc` channel by interposing a forwarding task.
///
/// The caller keeps the original sender. The returned receiver is the receiving half of a
/// proxy channel. A task spawned on `RT` does the following:
///
/// * drains the original receiver;
/// * reports `ChannelEvent::MessageSent` for each message it pulls, with the log produced by
///   `get_msg_log`;
/// * forwards the message into the proxy;
/// * reports `ChannelEvent::MessageReceived` once that send completes.
///
/// When the original receiver yields `None`, or the proxy receiver has been dropped, the task
/// reports `ChannelEvent::Closed` and exits.
///
/// * `inner` - the channel pair being instrumented.
/// * `source` / `label` - passed through to `register_channel`.
/// * `capacity` - the buffer size `inner` was created with. It is recorded as
///   `ChannelType::Bounded(capacity)`.
/// * `get_msg_log` - produces an optional log string for each forwarded message.
fn wrap_channel_impl<T, F>(
    inner: (Sender<T>, Receiver<T>),
    source: &'static str,
    label: Option<String>,
    capacity: usize,
    mut get_msg_log: F,
) -> (Sender<T>, Receiver<T>)
where
    T: Send + 'static,
    F: FnMut(&T) -> Option<String> + Send + 'static,
{
    // `mut` is only needed when the `hotpath-meta` re-binding below is compiled out.
    #[allow(unused_mut)]
    let (inner_tx, mut inner_rx) = inner;
    #[cfg(feature = "hotpath-meta")]
    let mut inner_rx = hotpath_meta::stream!(inner_rx, label = "hp-ftc-bounded-rx");

    // A futures bounded channel holds `buffer + number_of_senders` messages. With the single,
    // never-cloned `proxy_tx`, a buffer of 0 gives exactly one slot. The sender then stays
    // parked after each send until the consumer dequeues the message, so `send(..).await`
    // (which flushes by waiting for readiness) resolves only once the message was actually
    // taken. A buffer of 1 would allow a second slot, report the first message as received
    // while it still sits in the proxy, and add extra buffering.
    let (mut proxy_tx, proxy_rx) = mpsc::channel::<T>(0);
    let RegisteredChannel { id, stats_tx } =
        register_channel::<T>(source, label, ChannelType::Bounded(capacity));

    RT.spawn(async move {
        use futures_util::stream::StreamExt;
        while let Some(msg) = inner_rx.next().await {
            let log = get_msg_log(&msg);
            send_channel_event(
                &stats_tx,
                ChannelEvent::MessageSent {
                    id,
                    log,
                    timestamp: Instant::now(),
                },
            );
            // An error means the proxy receiver was dropped; stop forwarding.
            if proxy_tx.send(msg).await.is_err() {
                break;
            }
            send_channel_event(
                &stats_tx,
                ChannelEvent::MessageReceived {
                    id,
                    timestamp: Instant::now(),
                },
            );
        }
        send_channel_event(&stats_tx, ChannelEvent::Closed { id });
    });

    (inner_tx, proxy_rx)
}
/// Instruments a bounded mpsc channel without recording message contents.
///
/// Every forwarded message is reported with a `None` log entry.
pub(crate) fn wrap_channel<T: Send + 'static>(
    inner: (Sender<T>, Receiver<T>),
    source: &'static str,
    label: Option<String>,
    capacity: usize,
) -> (Sender<T>, Receiver<T>) {
    let no_log = |_: &T| -> Option<String> { None };
    wrap_channel_impl(inner, source, label, capacity, no_log)
}
/// Instruments a bounded mpsc channel and logs each message's `Debug` representation.
///
/// The text attached to every `MessageSent` event comes from
/// `crate::output::format_debug_truncated`.
pub(crate) fn wrap_channel_log<T: Send + std::fmt::Debug + 'static>(
    inner: (Sender<T>, Receiver<T>),
    source: &'static str,
    label: Option<String>,
    capacity: usize,
) -> (Sender<T>, Receiver<T>) {
    let describe =
        |msg: &T| -> Option<String> { Some(crate::output::format_debug_truncated(msg)) };
    wrap_channel_impl(inner, source, label, capacity, describe)
}
/// Wraps an unbounded `futures_channel::mpsc` channel so its traffic is reported as channel
/// events.
///
/// The original sender is returned unchanged. The returned receiver is the receiving half of
/// a new unbounded proxy channel. A task spawned on `RT` does the following for each message:
///
/// * pulls it from the original receiver;
/// * reports `ChannelEvent::MessageSent`, carrying whatever `get_msg_log` produced;
/// * pushes it into the proxy;
/// * if that push succeeded, reports `ChannelEvent::MessageReceived`.
///
/// `unbounded_send` completes without waiting for the consumer, so "received" here means
/// "handed to the proxy". The task reports `ChannelEvent::Closed` and exits once the original
/// receiver yields `None` or the proxy receiver has been dropped.
fn wrap_unbounded_impl<T, F>(
    inner: (UnboundedSender<T>, UnboundedReceiver<T>),
    source: &'static str,
    label: Option<String>,
    mut get_msg_log: F,
) -> (UnboundedSender<T>, UnboundedReceiver<T>)
where
    T: Send + 'static,
    F: FnMut(&T) -> Option<String> + Send + 'static + Clone,
{
    // `mut` is only needed when the `hotpath-meta` re-binding below is compiled out.
    #[allow(unused_mut)]
    let (inner_tx, mut inner_rx) = inner;
    #[cfg(feature = "hotpath-meta")]
    let mut inner_rx = hotpath_meta::stream!(inner_rx, label = "hp-ftc-unbounded-rx");

    let (forward_tx, forward_rx) = mpsc::unbounded::<T>();
    let RegisteredChannel { id, stats_tx } =
        register_channel::<T>(source, label, ChannelType::Unbounded);

    RT.spawn(async move {
        use futures_util::stream::StreamExt;
        loop {
            let msg = match inner_rx.next().await {
                Some(msg) => msg,
                // All original senders are gone and the queue is drained.
                None => break,
            };

            let sent = ChannelEvent::MessageSent {
                id,
                log: get_msg_log(&msg),
                timestamp: Instant::now(),
            };
            send_channel_event(&stats_tx, sent);

            // An error means the proxy receiver was dropped; stop forwarding.
            if forward_tx.unbounded_send(msg).is_err() {
                break;
            }

            let received = ChannelEvent::MessageReceived {
                id,
                timestamp: Instant::now(),
            };
            send_channel_event(&stats_tx, received);
        }
        send_channel_event(&stats_tx, ChannelEvent::Closed { id });
    });

    (inner_tx, forward_rx)
}
/// Instruments an unbounded mpsc channel without recording message contents.
///
/// Every forwarded message is reported with a `None` log entry.
pub(crate) fn wrap_unbounded<T: Send + 'static>(
    inner: (UnboundedSender<T>, UnboundedReceiver<T>),
    source: &'static str,
    label: Option<String>,
) -> (UnboundedSender<T>, UnboundedReceiver<T>) {
    let no_log = |_: &T| -> Option<String> { None };
    wrap_unbounded_impl(inner, source, label, no_log)
}
/// Instruments an unbounded mpsc channel and logs each message's `Debug` representation.
///
/// The text attached to every `MessageSent` event comes from
/// `crate::output::format_debug_truncated`.
pub(crate) fn wrap_unbounded_log<T: Send + std::fmt::Debug + 'static>(
    inner: (UnboundedSender<T>, UnboundedReceiver<T>),
    source: &'static str,
    label: Option<String>,
) -> (UnboundedSender<T>, UnboundedReceiver<T>) {
    let describe =
        |msg: &T| -> Option<String> { Some(crate::output::format_debug_truncated(msg)) };
    wrap_unbounded_impl(inner, source, label, describe)
}
/// Instruments a `futures_channel::oneshot` channel by interposing a forwarding task.
///
/// The original sender is returned to the caller. The returned receiver belongs to a new
/// proxy oneshot channel. A task spawned on `RT` uses `tokio::select!` to race two events:
///
/// * The original receiver resolving.
///   - On `Ok(msg)` the task reports `ChannelEvent::MessageSent` (with the log from
///     `get_msg_log`) and `ChannelEvent::Notified`, then forwards `msg` into the proxy. If
///     that succeeds, it reports `ChannelEvent::MessageReceived`.
///   - On `Err(_)` nothing is forwarded. Presumably this means the original sender was
///     dropped without sending.
/// * The proxy sender's `cancellation()` future resolving. Presumably this means the proxy
///   receiver was dropped before a value arrived.
///
/// `ChannelEvent::Closed` is reported only when the value was not successfully handed to the
/// proxy sender.
fn wrap_oneshot_impl<T, F>(
    inner: (oneshot::Sender<T>, oneshot::Receiver<T>),
    source: &'static str,
    label: Option<String>,
    mut get_msg_log: F,
) -> (oneshot::Sender<T>, oneshot::Receiver<T>)
where
    T: Send + 'static,
    F: FnMut(&T) -> Option<String> + Send + 'static + Clone,
{
    let (inner_tx, inner_rx) = inner;
    let (proxy_tx, proxy_rx) = oneshot::channel::<T>();
    let RegisteredChannel { id, stats_tx } =
        register_channel::<T>(source, label, ChannelType::Oneshot);
    RT.spawn(async move {
        // Both halves are held in `Option`s for two reasons. The `select!` branch futures
        // below can `take()` or borrow them. The first handler can move `proxy_tx` out to
        // send.
        let mut inner_rx = Some(inner_rx);
        let mut proxy_tx = Some(proxy_tx);
        // Set to true only after `send` into the proxy succeeds; suppresses `Closed` below.
        let mut message_completed = false;
        tokio::select! {
            // Waits for the original sender. The precondition guarantees `inner_rx` is
            // `Some`, so the `unwrap` cannot fail.
            msg = async { inner_rx.take().unwrap().await }, if inner_rx.is_some() => {
                match msg {
                    Ok(msg) => {
                        let log = get_msg_log(&msg);
                        send_channel_event(&stats_tx, ChannelEvent::MessageSent {
                            id,
                            log,
                            timestamp: Instant::now(),
                        });
                        send_channel_event(&stats_tx, ChannelEvent::Notified { id });
                        if proxy_tx.take().unwrap().send(msg).is_ok() {
                            send_channel_event(&stats_tx, ChannelEvent::MessageReceived {
                                id,
                                timestamp: Instant::now(),
                            });
                            message_completed = true;
                        }
                    }
                    Err(_) => {
                        // Nothing to forward. The still-held `proxy_tx` is dropped when this
                        // task ends.
                    }
                }
            }
            // Waits on the proxy sender's `cancellation()` future (the proxy receiver side).
            _ = async { proxy_tx.as_mut().unwrap().cancellation().await }, if proxy_tx.is_some() => {
                // Releases the original receiver if it is still held here rather than inside
                // the other branch's future.
                drop(inner_rx);
            }
        }
        if !message_completed {
            send_channel_event(&stats_tx, ChannelEvent::Closed { id });
        }
    });
    (inner_tx, proxy_rx)
}
/// Instruments a oneshot channel without recording message contents.
///
/// The forwarded value, if any, is reported with a `None` log entry.
pub(crate) fn wrap_oneshot<T: Send + 'static>(
    inner: (oneshot::Sender<T>, oneshot::Receiver<T>),
    source: &'static str,
    label: Option<String>,
) -> (oneshot::Sender<T>, oneshot::Receiver<T>) {
    let no_log = |_: &T| -> Option<String> { None };
    wrap_oneshot_impl(inner, source, label, no_log)
}
/// Instruments a oneshot channel and logs the value's `Debug` representation.
///
/// The text attached to the `MessageSent` event comes from
/// `crate::output::format_debug_truncated`.
pub(crate) fn wrap_oneshot_log<T: Send + std::fmt::Debug + 'static>(
    inner: (oneshot::Sender<T>, oneshot::Receiver<T>),
    source: &'static str,
    label: Option<String>,
) -> (oneshot::Sender<T>, oneshot::Receiver<T>) {
    let describe =
        |msg: &T| -> Option<String> { Some(crate::output::format_debug_truncated(msg)) };
    wrap_oneshot_impl(inner, source, label, describe)
}
use crate::channels::InstrumentChannel;
impl<T: Send + 'static> InstrumentChannel
    for (
        futures_channel::mpsc::Sender<T>,
        futures_channel::mpsc::Receiver<T>,
    )
{
    type Output = Self;

    /// Instruments a bounded futures mpsc channel pair.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is `None`. `futures_channel::mpsc` does not expose a bounded
    /// channel's buffer size, so the caller must supply it.
    fn instrument(
        self,
        source: &'static str,
        label: Option<String>,
        capacity: Option<usize>,
    ) -> Self::Output {
        let capacity = match capacity {
            Some(capacity) => capacity,
            None => panic!("Capacity is required for bounded futures channels, because they don't expose their capacity in a public API"),
        };
        wrap_channel(self, source, label, capacity)
    }
}
impl<T: Send + 'static> InstrumentChannel
    for (
        futures_channel::mpsc::UnboundedSender<T>,
        futures_channel::mpsc::UnboundedReceiver<T>,
    )
{
    type Output = Self;

    /// Instruments an unbounded futures mpsc channel pair.
    ///
    /// The capacity argument is ignored because the channel has no buffer limit.
    fn instrument(
        self,
        source: &'static str,
        label: Option<String>,
        _: Option<usize>,
    ) -> Self::Output {
        wrap_unbounded(self, source, label)
    }
}
impl<T: Send + 'static> InstrumentChannel
    for (
        futures_channel::oneshot::Sender<T>,
        futures_channel::oneshot::Receiver<T>,
    )
{
    type Output = Self;

    /// Instruments a futures oneshot channel pair.
    ///
    /// The capacity argument is ignored; a oneshot channel carries at most one value.
    fn instrument(
        self,
        source: &'static str,
        label: Option<String>,
        _: Option<usize>,
    ) -> Self::Output {
        wrap_oneshot(self, source, label)
    }
}
use crate::channels::InstrumentChannelLog;
impl<T: Send + std::fmt::Debug + 'static> InstrumentChannelLog
    for (
        futures_channel::mpsc::Sender<T>,
        futures_channel::mpsc::Receiver<T>,
    )
{
    type Output = Self;

    /// Instruments a bounded futures mpsc channel pair, logging each message via `Debug`.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is `None`. `futures_channel::mpsc` does not expose a bounded
    /// channel's buffer size, so the caller must supply it.
    fn instrument_log(
        self,
        source: &'static str,
        label: Option<String>,
        capacity: Option<usize>,
    ) -> Self::Output {
        let capacity = match capacity {
            Some(capacity) => capacity,
            None => panic!("Capacity is required for bounded futures channels, because they don't expose their capacity in a public API"),
        };
        wrap_channel_log(self, source, label, capacity)
    }
}
impl<T: Send + std::fmt::Debug + 'static> InstrumentChannelLog
    for (
        futures_channel::mpsc::UnboundedSender<T>,
        futures_channel::mpsc::UnboundedReceiver<T>,
    )
{
    type Output = Self;

    /// Instruments an unbounded futures mpsc channel pair, logging each message via `Debug`.
    ///
    /// The capacity argument is ignored because the channel has no buffer limit.
    fn instrument_log(
        self,
        source: &'static str,
        label: Option<String>,
        _: Option<usize>,
    ) -> Self::Output {
        wrap_unbounded_log(self, source, label)
    }
}
impl<T: Send + std::fmt::Debug + 'static> InstrumentChannelLog
    for (
        futures_channel::oneshot::Sender<T>,
        futures_channel::oneshot::Receiver<T>,
    )
{
    type Output = Self;

    /// Instruments a futures oneshot channel pair, logging the value via `Debug`.
    ///
    /// The capacity argument is ignored; a oneshot channel carries at most one value.
    fn instrument_log(
        self,
        source: &'static str,
        label: Option<String>,
        _: Option<usize>,
    ) -> Self::Output {
        wrap_oneshot_log(self, source, label)
    }
}