use std::mem;
use std::sync::mpsc::{self, Receiver, Sender, SyncSender};
use crate::{init_stats_state, ChannelType, StatsEvent};
/// Interposes two instrumented proxy threads between the caller and a bounded `std` channel.
///
/// The returned pair has the same shape as `inner`, but every message travels
/// `outer_tx -> send proxy thread -> inner -> receive proxy thread -> outer_rx`.
/// Along the way the stats state is notified with:
///
/// * `StatsEvent::Created` once, before the proxy threads start, tagged
///   `ChannelType::Bounded(capacity)`;
/// * `StatsEvent::MessageSent` after each message is accepted by the inner sender, carrying
///   whatever `log_on_send` returned for that message;
/// * `StatsEvent::MessageReceived` after each message is forwarded into `outer_rx` (i.e. when
///   the proxy delivers it, not when the caller actually receives it);
/// * `StatsEvent::Closed` when each proxy thread exits. Both threads emit it, so a channel
///   produces two `Closed` events.
///
/// Both proxy channels are created with `capacity`, so messages can be buffered in them in
/// addition to `inner`; the number of messages in flight can exceed `capacity`. `capacity` is
/// only used to size the proxies and for reporting. Presumably it should match the capacity
/// `inner` was created with — TODO confirm with callers.
///
/// Shutdown: the send proxy stops when every outer sender is dropped, when the inner send
/// fails, or when the receive proxy signals or has exited. It re-checks that signal every
/// 10 ms. The receive proxy stops when the inner channel hangs up, or when `outer_rx` was
/// dropped, in which case it signals the send proxy.
fn wrap_sync_channel_impl<T, F>(
    inner: (SyncSender<T>, Receiver<T>),
    channel_id: &'static str,
    label: Option<&'static str>,
    capacity: usize,
    mut log_on_send: F,
) -> (SyncSender<T>, Receiver<T>)
where
    T: Send + 'static,
    F: FnMut(&T) -> Option<String> + Send + 'static,
{
    use std::time::{Duration, SystemTime};

    /// How long the send proxy blocks on the caller's channel before re-checking the stop signal.
    const STOP_POLL_INTERVAL: Duration = Duration::from_millis(10);

    let (inner_tx, inner_rx) = inner;

    // Proxy legs: caller -> send proxy, and receive proxy -> caller.
    let (outer_tx, to_inner_rx) = mpsc::sync_channel::<T>(capacity);
    let (from_inner_tx, outer_rx) = mpsc::sync_channel::<T>(capacity);

    let (stats_tx, _) = init_stats_state();
    let _ = stats_tx.send(StatsEvent::Created {
        id: channel_id,
        display_label: label,
        channel_type: ChannelType::Bounded(capacity),
        type_name: std::any::type_name::<T>(),
        type_size: mem::size_of::<T>(),
    });
    let sent_stats = stats_tx.clone();
    let received_stats = stats_tx.clone();

    // Fired (or disconnected) by the receive proxy to tell the send proxy to quit.
    let (stop_tx, stop_rx) = mpsc::channel::<()>();

    // Send proxy: caller's messages -> inner sender.
    std::thread::spawn(move || {
        loop {
            // Only an empty stop channel means "keep going"; a signal or a vanished
            // receive proxy both end forwarding.
            if !matches!(stop_rx.try_recv(), Err(mpsc::TryRecvError::Empty)) {
                break;
            }
            let msg = match to_inner_rx.recv_timeout(STOP_POLL_INTERVAL) {
                Ok(msg) => msg,
                Err(mpsc::RecvTimeoutError::Timeout) => continue,
                Err(mpsc::RecvTimeoutError::Disconnected) => break,
            };
            // Render the log entry before `msg` is moved into the inner channel.
            let log = log_on_send(&msg);
            if inner_tx.send(msg).is_err() {
                break;
            }
            let _ = sent_stats.send(StatsEvent::MessageSent {
                id: channel_id,
                log,
                timestamp: SystemTime::now(),
            });
        }
        let _ = sent_stats.send(StatsEvent::Closed { id: channel_id });
    });

    // Receive proxy: inner receiver -> caller's receiver.
    std::thread::spawn(move || {
        for msg in inner_rx.iter() {
            if from_inner_tx.send(msg).is_err() {
                // The caller dropped `outer_rx`; stop the send proxy as well.
                let _ = stop_tx.send(());
                break;
            }
            let _ = received_stats.send(StatsEvent::MessageReceived {
                id: channel_id,
                timestamp: SystemTime::now(),
            });
        }
        let _ = received_stats.send(StatsEvent::Closed { id: channel_id });
    });

    (outer_tx, outer_rx)
}
/// Instruments a bounded `std` channel without recording message contents.
///
/// Delegates to `wrap_sync_channel_impl` with a logging callback that always yields `None`,
/// so the `MessageSent` events it emits carry no log payload.
pub(crate) fn wrap_sync_channel<T: Send + 'static>(
    inner: (SyncSender<T>, Receiver<T>),
    channel_id: &'static str,
    label: Option<&'static str>,
    capacity: usize,
) -> (SyncSender<T>, Receiver<T>) {
    let no_log = |_: &T| -> Option<String> { None };
    wrap_sync_channel_impl(inner, channel_id, label, capacity, no_log)
}
/// Instruments a bounded `std` channel and logs every sent message.
///
/// Each message's `Debug` rendering (`{:?}`) is produced before it is forwarded and is
/// attached to the corresponding `MessageSent` event.
pub(crate) fn wrap_sync_channel_log<T: Send + std::fmt::Debug + 'static>(
    inner: (SyncSender<T>, Receiver<T>),
    channel_id: &'static str,
    label: Option<&'static str>,
    capacity: usize,
) -> (SyncSender<T>, Receiver<T>) {
    let debug_log = |msg: &T| -> Option<String> { Some(format!("{:?}", msg)) };
    wrap_sync_channel_impl(inner, channel_id, label, capacity, debug_log)
}
/// Interposes two instrumented proxy threads between the caller and an unbounded `std` channel.
///
/// Messages travel `outer_tx -> send proxy thread -> inner -> receive proxy thread -> outer_rx`.
/// The stats state receives:
///
/// * `StatsEvent::Created` once, before the proxies start, tagged `ChannelType::Unbounded`;
/// * `StatsEvent::MessageSent` after each message is accepted by the inner sender, carrying
///   the value `log_on_send` produced for it;
/// * `StatsEvent::MessageReceived` after each message is forwarded into `outer_rx` (not when
///   the caller actually receives it);
/// * `StatsEvent::Closed` from each proxy thread as it exits. There are two per channel.
///
/// Shutdown: the send proxy stops when every outer sender is dropped, when the inner send
/// fails, or when the receive proxy signals or has exited. It polls for that every 10 ms.
/// The receive proxy stops when the inner channel hangs up, or when `outer_rx` was dropped,
/// in which case it signals the send proxy.
fn wrap_channel_impl<T, F>(
    inner: (Sender<T>, Receiver<T>),
    channel_id: &'static str,
    label: Option<&'static str>,
    mut log_on_send: F,
) -> (Sender<T>, Receiver<T>)
where
    T: Send + 'static,
    F: FnMut(&T) -> Option<String> + Send + 'static,
{
    use std::time::{Duration, SystemTime};

    /// How long the send proxy blocks on the caller's channel before re-checking the stop signal.
    const STOP_POLL_INTERVAL: Duration = Duration::from_millis(10);

    let (inner_tx, inner_rx) = inner;

    // Proxy legs: caller -> send proxy, and receive proxy -> caller.
    let (outer_tx, to_inner_rx) = mpsc::channel::<T>();
    let (from_inner_tx, outer_rx) = mpsc::channel::<T>();

    let (stats_tx, _) = init_stats_state();
    let _ = stats_tx.send(StatsEvent::Created {
        id: channel_id,
        display_label: label,
        channel_type: ChannelType::Unbounded,
        type_name: std::any::type_name::<T>(),
        type_size: mem::size_of::<T>(),
    });
    let sent_stats = stats_tx.clone();
    let received_stats = stats_tx.clone();

    // Fired (or disconnected) by the receive proxy to tell the send proxy to quit.
    let (stop_tx, stop_rx) = mpsc::channel::<()>();

    // Send proxy: caller's messages -> inner sender.
    std::thread::spawn(move || {
        loop {
            // Only an empty stop channel means "keep going".
            if !matches!(stop_rx.try_recv(), Err(mpsc::TryRecvError::Empty)) {
                break;
            }
            let msg = match to_inner_rx.recv_timeout(STOP_POLL_INTERVAL) {
                Ok(msg) => msg,
                Err(mpsc::RecvTimeoutError::Timeout) => continue,
                Err(mpsc::RecvTimeoutError::Disconnected) => break,
            };
            // Render the log entry before `msg` is moved into the inner channel.
            let log = log_on_send(&msg);
            if inner_tx.send(msg).is_err() {
                break;
            }
            let _ = sent_stats.send(StatsEvent::MessageSent {
                id: channel_id,
                log,
                timestamp: SystemTime::now(),
            });
        }
        let _ = sent_stats.send(StatsEvent::Closed { id: channel_id });
    });

    // Receive proxy: inner receiver -> caller's receiver.
    std::thread::spawn(move || {
        for msg in inner_rx.iter() {
            if from_inner_tx.send(msg).is_err() {
                // The caller dropped `outer_rx`; stop the send proxy as well.
                let _ = stop_tx.send(());
                break;
            }
            let _ = received_stats.send(StatsEvent::MessageReceived {
                id: channel_id,
                timestamp: SystemTime::now(),
            });
        }
        let _ = received_stats.send(StatsEvent::Closed { id: channel_id });
    });

    (outer_tx, outer_rx)
}
/// Instruments an unbounded `std` channel without recording message contents.
///
/// Delegates to `wrap_channel_impl` with a logging callback that always yields `None`.
pub(crate) fn wrap_channel<T: Send + 'static>(
    inner: (Sender<T>, Receiver<T>),
    channel_id: &'static str,
    label: Option<&'static str>,
) -> (Sender<T>, Receiver<T>) {
    let no_log = |_: &T| -> Option<String> { None };
    wrap_channel_impl(inner, channel_id, label, no_log)
}
/// Instruments an unbounded `std` channel and logs every sent message.
///
/// Each message's `Debug` rendering (`{:?}`) is attached to its `MessageSent` event.
pub(crate) fn wrap_channel_log<T: Send + std::fmt::Debug + 'static>(
    inner: (Sender<T>, Receiver<T>),
    channel_id: &'static str,
    label: Option<&'static str>,
) -> (Sender<T>, Receiver<T>) {
    let debug_log = |msg: &T| -> Option<String> { Some(format!("{:?}", msg)) };
    wrap_channel_impl(inner, channel_id, label, debug_log)
}
use crate::Instrument;
impl<T: Send + 'static> Instrument for (std::sync::mpsc::Sender<T>, std::sync::mpsc::Receiver<T>) {
type Output = (std::sync::mpsc::Sender<T>, std::sync::mpsc::Receiver<T>);
fn instrument(
self,
channel_id: &'static str,
label: Option<&'static str>,
_capacity: Option<usize>,
) -> Self::Output {
wrap_channel(self, channel_id, label)
}
}
/// Instruments a bounded `std::sync::mpsc` channel.
impl<T: Send + 'static> Instrument
    for (std::sync::mpsc::SyncSender<T>, std::sync::mpsc::Receiver<T>)
{
    type Output = (std::sync::mpsc::SyncSender<T>, std::sync::mpsc::Receiver<T>);

    /// Wraps the channel via `wrap_sync_channel`, using the caller-supplied capacity.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is `None`: `SyncSender` does not expose the capacity it was
    /// created with, so the caller must provide it.
    fn instrument(
        self,
        channel_id: &'static str,
        label: Option<&'static str>,
        capacity: Option<usize>,
    ) -> Self::Output {
        // Destructure once instead of testing `is_none()` and then `unwrap()`-ing.
        let capacity = match capacity {
            Some(capacity) => capacity,
            None => panic!("Capacity is required for bounded std channels, because they don't expose their capacity in a public API"),
        };
        wrap_sync_channel(self, channel_id, label, capacity)
    }
}
use crate::InstrumentLog;
impl<T: Send + std::fmt::Debug + 'static> InstrumentLog
for (std::sync::mpsc::Sender<T>, std::sync::mpsc::Receiver<T>)
{
type Output = (std::sync::mpsc::Sender<T>, std::sync::mpsc::Receiver<T>);
fn instrument_log(
self,
channel_id: &'static str,
label: Option<&'static str>,
_capacity: Option<usize>,
) -> Self::Output {
wrap_channel_log(self, channel_id, label)
}
}
/// Instruments a bounded `std::sync::mpsc` channel, logging each sent message's `Debug` form.
impl<T: Send + std::fmt::Debug + 'static> InstrumentLog
    for (std::sync::mpsc::SyncSender<T>, std::sync::mpsc::Receiver<T>)
{
    type Output = (std::sync::mpsc::SyncSender<T>, std::sync::mpsc::Receiver<T>);

    /// Wraps the channel via `wrap_sync_channel_log`, using the caller-supplied capacity.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is `None`: `SyncSender` does not expose the capacity it was
    /// created with, so the caller must provide it.
    fn instrument_log(
        self,
        channel_id: &'static str,
        label: Option<&'static str>,
        capacity: Option<usize>,
    ) -> Self::Output {
        // Destructure once instead of testing `is_none()` and then `unwrap()`-ing.
        let capacity = match capacity {
            Some(capacity) => capacity,
            None => panic!("Capacity is required for bounded std channels, because they don't expose their capacity in a public API"),
        };
        wrap_sync_channel_log(self, channel_id, label, capacity)
    }
}