use std::sync::mpsc::{Receiver as MpscReceiver, RecvError, SendError, SyncSender, sync_channel};
use std::sync::{Arc, Condvar, Mutex};
/// Default byte budget for a byte-bounded channel: 64 MiB.
pub const BYTE_CHANNEL_DEFAULT_CAPACITY: usize = 64 << 20;
/// Item-count bound handed to the underlying `sync_channel`; it applies in addition to the
/// byte budget.
const INNER_SLOT_CAPACITY: usize = 200_000;
/// Reports how many bytes a value is charged against a byte-bounded channel's budget.
///
/// The channel charges this amount when a value is sent and refunds it when the value is
/// received, so implementations should return the same number for the same value each time.
pub trait HasByteSize {
    /// Number of bytes this value counts for.
    fn byte_size(&self) -> usize;
}
impl HasByteSize for crate::pes::PesFrame {
    /// Charges the length of `data` plus a constant 14 bytes.
    fn byte_size(&self) -> usize {
        // NOTE(review): presumably a fixed allowance for per-frame fields beyond `data`; confirm.
        const FIXED_OVERHEAD: usize = 14;
        FIXED_OVERHEAD + self.data.len()
    }
}
/// Mutable byte-budget state; only ever accessed through `Accounting::state`.
struct BudgetState {
    /// Sum of `byte_size()` over items admitted by `Sender::send` and not yet returned by
    /// `Receiver::recv`.
    used: usize,
    /// Set when the `Receiver` is dropped. After that nothing will ever free budget, so
    /// senders must stop waiting and fail instead.
    closed: bool,
}

/// Byte-budget bookkeeping shared by every `Sender` clone and the `Receiver`.
struct Accounting {
    state: Mutex<BudgetState>,
    /// Notified whenever `used` decreases or `closed` is set.
    cv: Condvar,
    /// Budget that `used` stays within, except for one item admitted into an empty channel.
    capacity: usize,
}

/// Sending half of a byte-bounded channel created by [`channel`].
///
/// Clones share the same underlying channel and the same byte budget.
pub struct Sender<T: HasByteSize> {
    tx: SyncSender<T>,
    acct: Arc<Accounting>,
}

// Hand-written rather than derived: `#[derive(Clone)]` would needlessly require `T: Clone`.
impl<T: HasByteSize> Clone for Sender<T> {
    fn clone(&self) -> Self {
        Sender {
            tx: self.tx.clone(),
            acct: Arc::clone(&self.acct),
        }
    }
}

impl<T: HasByteSize> Sender<T> {
    /// Sends `item`, blocking while the channel's byte budget is exhausted.
    ///
    /// `item` is admitted once `used + item.byte_size()` fits within the capacity, or as soon
    /// as nothing is in flight (`used == 0`). The second rule lets an item larger than the
    /// whole capacity go through instead of blocking forever. After admission the call may
    /// additionally block if the underlying slot-bounded channel is full.
    ///
    /// # Errors
    ///
    /// Returns `Err(SendError(item))`, handing the item back, if the [`Receiver`] has been
    /// dropped. This includes the case where it is dropped while this call waits for budget.
    ///
    /// # Panics
    ///
    /// Panics if the accounting mutex is poisoned.
    pub fn send(&self, item: T) -> Result<(), SendError<T>> {
        let sz = item.byte_size();
        let capacity = self.acct.capacity;
        {
            let guard = self.acct.state.lock().expect("byte_channel poisoned");
            // Wait until the item fits, the channel drains completely, or the receiver is gone.
            let mut state = self
                .acct
                .cv
                .wait_while(guard, |s| {
                    !s.closed && s.used > 0 && s.used.saturating_add(sz) > capacity
                })
                .expect("byte_channel cv poisoned");
            if state.closed {
                return Err(SendError(item));
            }
            // Cannot overflow: either `used == 0`, or `used + sz <= capacity` was just checked.
            state.used += sz;
        }
        self.tx.send(item).map_err(|err| {
            // The receiver disconnected after admission: give the reserved bytes back.
            let mut state = self.acct.state.lock().expect("byte_channel poisoned");
            state.used = state.used.saturating_sub(sz);
            self.acct.cv.notify_all();
            err
        })
    }
}

/// Receiving half of a byte-bounded channel created by [`channel`].
///
/// Dropping it wakes every sender blocked on the byte budget, so their `send` returns `Err`.
pub struct Receiver<T: HasByteSize> {
    rx: MpscReceiver<T>,
    acct: Arc<Accounting>,
}

impl<T: HasByteSize> Receiver<T> {
    /// Blocks until an item is available, then releases its bytes from the budget.
    ///
    /// # Errors
    ///
    /// Returns `RecvError` once every [`Sender`] has been dropped and no items remain.
    ///
    /// # Panics
    ///
    /// Panics if the accounting mutex is poisoned.
    pub fn recv(&self) -> Result<T, RecvError> {
        let item = self.rx.recv()?;
        let sz = item.byte_size();
        let mut state = self.acct.state.lock().expect("byte_channel poisoned");
        state.used = state.used.saturating_sub(sz);
        // notify_all: waiting senders need different amounts of room, so any of them may fit now.
        self.acct.cv.notify_all();
        Ok(item)
    }
}

impl<T: HasByteSize> Drop for Receiver<T> {
    fn drop(&mut self) {
        // Runs before `rx` itself is dropped. Tolerate poisoning here: panicking inside `drop`
        // while already unwinding would abort the process.
        let mut state = self
            .acct
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        state.closed = true;
        self.acct.cv.notify_all();
    }
}

/// Creates a channel whose in-flight items are bounded by their total `byte_size()`.
///
/// `capacity_bytes` is the byte budget; see [`Sender::send`] for how oversized items are
/// admitted. The underlying channel is additionally bounded to `INNER_SLOT_CAPACITY` items.
pub fn channel<T: HasByteSize>(capacity_bytes: usize) -> (Sender<T>, Receiver<T>) {
    let (tx, rx) = sync_channel::<T>(INNER_SLOT_CAPACITY);
    let acct = Arc::new(Accounting {
        state: Mutex::new(BudgetState {
            used: 0,
            closed: false,
        }),
        cv: Condvar::new(),
        capacity: capacity_bytes,
    });
    (
        Sender {
            tx,
            acct: Arc::clone(&acct),
        },
        Receiver { rx, acct },
    )
}
// Unit tests for the byte-bounded channel: ordering, byte accounting, blocking at capacity,
// oversized items, multi-sender stress, and sending after the receiver is gone.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;
    use std::time::{Duration, Instant};
    // Test payload: `sz` is what the channel charges for it, `tag` identifies it in assertions.
    #[derive(Clone, Debug, PartialEq, Eq)]
    struct Item {
        sz: usize,
        tag: u32,
    }
    impl HasByteSize for Item {
        fn byte_size(&self) -> usize {
            self.sz
        }
    }
    // Five 100-byte items fit in a 1024-byte budget, so no send blocks; items arrive in FIFO order.
    #[test]
    fn send_recv_round_trip() {
        let (tx, rx) = channel::<Item>(1024);
        for i in 0..5 {
            tx.send(Item { sz: 100, tag: i }).unwrap();
        }
        for i in 0..5 {
            let got = rx.recv().unwrap();
            assert_eq!(got, Item { sz: 100, tag: i });
        }
    }
    // Fill the budget exactly, drain it, then check that a full-capacity item is admitted
    // promptly. The 100 ms bound is a heuristic for "did not block".
    #[test]
    fn byte_accounting_decrements_on_recv() {
        let (tx, rx) = channel::<Item>(1024);
        for _ in 0..4 {
            tx.send(Item { sz: 256, tag: 0 }).unwrap();
        }
        for _ in 0..4 {
            rx.recv().unwrap();
        }
        let start = Instant::now();
        tx.send(Item { sz: 1024, tag: 99 }).unwrap();
        assert!(start.elapsed() < Duration::from_millis(100));
        let got = rx.recv().unwrap();
        assert_eq!(got.tag, 99);
    }
    // With 200 of 200 bytes in flight, a third 100-byte send must wait until a recv frees room.
    #[test]
    fn send_blocks_at_capacity_unblocks_on_recv() {
        let (tx, rx) = channel::<Item>(200);
        tx.send(Item { sz: 100, tag: 0 }).unwrap();
        tx.send(Item { sz: 100, tag: 1 }).unwrap();
        let tx2 = tx.clone();
        // Records when the blocked sender finally returns.
        let sent_at = Arc::new(Mutex::new(None::<Instant>));
        let sent_at2 = sent_at.clone();
        let h = thread::spawn(move || {
            tx2.send(Item { sz: 100, tag: 2 }).unwrap();
            *sent_at2.lock().unwrap() = Some(Instant::now());
        });
        // Give the spawned sender time to reach the budget wait before checking it is still blocked.
        thread::sleep(Duration::from_millis(100));
        assert!(
            sent_at.lock().unwrap().is_none(),
            "third send should be blocked at capacity"
        );
        // Captured before the recv, so any send completion caused by the recv is timestamped later.
        let recv_at = Instant::now();
        let got = rx.recv().unwrap();
        assert_eq!(got.tag, 0);
        h.join().unwrap();
        let sent_when = sent_at.lock().unwrap().unwrap();
        assert!(
            sent_when >= recv_at,
            "sender must complete AFTER receiver freed capacity"
        );
        assert_eq!(rx.recv().unwrap().tag, 1);
        assert_eq!(rx.recv().unwrap().tag, 2);
    }
    // An item bigger than the whole budget is admitted when the channel is empty instead of
    // blocking forever.
    #[test]
    fn item_larger_than_capacity_still_goes_through() {
        let (tx, rx) = channel::<Item>(100);
        tx.send(Item { sz: 1000, tag: 7 }).unwrap();
        let got = rx.recv().unwrap();
        assert_eq!(got, Item { sz: 1000, tag: 7 });
    }
    // Four sender threads push 1000 items each (sizes 1..=256) through an 8 KiB budget; every
    // item must be both sent and received.
    #[test]
    fn concurrent_send_recv_stress() {
        const SENDERS: u32 = 4;
        const PER_SENDER: u32 = 1000;
        const TOTAL: u32 = SENDERS * PER_SENDER;
        let (tx, rx) = channel::<Item>(8 * 1024);
        let sent = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::new();
        for s in 0..SENDERS {
            let tx = tx.clone();
            let sent = sent.clone();
            handles.push(thread::spawn(move || {
                for i in 0..PER_SENDER {
                    let sz = 1 + ((i as usize) % 256);
                    tx.send(Item {
                        sz,
                        tag: s * PER_SENDER + i,
                    })
                    .unwrap();
                    sent.fetch_add(1, Ordering::SeqCst);
                }
            }));
        }
        // Drop the original sender so `recv` returns Err once every cloned sender has finished.
        drop(tx);
        let mut received = 0u32;
        while let Ok(_item) = rx.recv() {
            received += 1;
        }
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(received, TOTAL);
        assert_eq!(sent.load(Ordering::SeqCst) as u32, TOTAL);
    }
    // With nothing in flight, sending to a dropped receiver reports an error.
    #[test]
    fn send_after_recv_dropped_returns_err() {
        let (tx, rx) = channel::<Item>(1024);
        drop(rx);
        let r = tx.send(Item { sz: 10, tag: 0 });
        assert!(r.is_err());
    }
}