use super::{LCPipe, Message, SendError, UnsafeCond};
use crate::cbus::RecvError;
use crate::fiber::Cond;
use std::cell::RefCell;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, Weak};
use std::time::Duration;
/// Wakes the receiving fiber from other threads by queueing a cbus message
/// that signals a fiber condition variable.
///
/// The `woken` flag coalesces wakeups: while it is set, further calls to
/// [`Waker::wakeup`] do not queue another message; [`Waker::wait`] clears it.
pub(super) struct Waker {
    /// Held in an `Arc` so every queued wakeup message owns a reference to
    /// the condition and keeps it alive until the message has run.
    condition: Option<Arc<UnsafeCond>>,
    /// `true` while a wakeup has been requested but not yet consumed by `wait`.
    woken: AtomicBool,
}

impl Waker {
    /// Creates a waker around `cond` with no pending wakeup.
    pub(super) fn new(cond: Cond) -> Self {
        let shared = Arc::new(UnsafeCond(cond));
        Self {
            woken: AtomicBool::new(false),
            condition: Some(shared),
        }
    }

    /// Queues a message on `pipe` that signals `cond` when the endpoint
    /// processes it. This ignores (and does not touch) the `woken` flag.
    pub(super) fn force_wakeup(&self, cond: Arc<UnsafeCond>, pipe: &mut LCPipe) {
        let signal_cond = move || {
            // SAFETY: NOTE(review) presumably sound because the message runs
            // in the cbus endpoint on the thread that owns the `Cond`; confirm
            // against `UnsafeCond`'s contract.
            unsafe { (*cond).as_ref().signal() };
        };
        pipe.push_message(Message::new(signal_cond));
    }

    /// Queues a wakeup message on `pipe` unless one is already pending.
    pub(super) fn wakeup(&self, pipe: &mut LCPipe) {
        let already_pending = self
            .woken
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
            .is_err();
        if already_pending {
            return;
        }
        let cond = self
            .condition
            .clone()
            .expect("unreachable: condition never empty");
        self.force_wakeup(cond, pipe);
    }

    /// Consumes a pending wakeup if there is one and returns immediately.
    /// Otherwise waits on the condition for at most 1 ms, so the caller
    /// re-checks its state periodically even if no signal arrives.
    pub(super) fn wait(&self) {
        let had_pending = self
            .woken
            .compare_exchange(true, false, Ordering::AcqRel, Ordering::Acquire)
            .is_ok();
        if had_pending {
            return;
        }
        let cond = self
            .condition
            .as_ref()
            .expect("unreachable: condition never empty");
        // SAFETY: NOTE(review) presumably sound because `wait` is called from
        // the receiving fiber on the thread that owns the `Cond`; confirm.
        unsafe { (**cond).as_ref().wait_timeout(Duration::from_millis(1)) };
    }
}
/// State shared by every sender and the receiver of one channel.
struct Channel<T> {
    /// Unbounded multi-producer queue of messages not yet received.
    list: crossbeam_queue::SegQueue<T>,
    /// Set once the last sender handle is gone (by `SenderInner`'s `Drop`).
    disconnected: AtomicBool,
    /// Name of the cbus endpoint that senders open their pipes to.
    cbus_endpoint: String,
}

impl<T> Channel<T> {
    /// Creates an empty, connected channel bound to `cbus_endpoint`.
    fn new(cbus_endpoint: &str) -> Self {
        let list = crossbeam_queue::SegQueue::new();
        let disconnected = AtomicBool::new(false);
        Self {
            list,
            disconnected,
            cbus_endpoint: String::from(cbus_endpoint),
        }
    }
}
/// Creates an unbounded channel and returns its sending and receiving halves.
///
/// Senders deliver wakeups through cbus pipes opened to the endpoint named
/// `cbus_endpoint`. NOTE(review): presumably that endpoint must be running
/// (the tests start one with `run_cbus_endpoint`) for wakeups to be delivered.
pub fn channel<T>(cbus_endpoint: &str) -> (Sender<T>, EndpointReceiver<T>) {
    let shared_chan = Arc::new(Channel::new(cbus_endpoint));
    // The receiver keeps the only long-lived strong reference to the waker;
    // senders hold a `Weak`, so dropping the receiver makes sends fail.
    let waker = Arc::new(Waker::new(Cond::new()));
    // Guards no data; it only serializes senders with the receiver's drop.
    let guard: Arc<Mutex<()>> = Arc::new(Mutex::default());

    let sender_inner = Arc::new(SenderInner {
        chan: Arc::clone(&shared_chan),
    });
    let sender = Sender {
        inner: sender_inner,
        waker: Arc::downgrade(&waker),
        lcpipe: RefCell::new(LCPipe::new(&shared_chan.cbus_endpoint)),
        arc_guard: Arc::clone(&guard),
    };
    let receiver = EndpointReceiver {
        chan: Arc::clone(&shared_chan),
        waker: Some(Arc::clone(&waker)),
        arc_guard: guard,
    };
    (sender, receiver)
}
/// The sending side's handle on the shared [`Channel`].
///
/// All clones of a [`Sender`] share one `SenderInner` through an `Arc`, so its
/// `Drop` runs exactly once, when the last sender is gone.
struct SenderInner<T> {
    chan: Arc<Channel<T>>,
}

// SAFETY: the only field is `Arc<Channel<T>>`. Values of `T` pushed into the
// channel are popped by the receiver on another thread, so `T: Send` is
// required. Without the bound, non-`Send` values (e.g. `Rc`) could cross
// threads.
unsafe impl<T: Send> Send for SenderInner<T> {}

impl<T> Drop for SenderInner<T> {
    fn drop(&mut self) {
        // Release pairs with the Acquire load in `EndpointReceiver::receive`.
        // Together with `Arc`'s drop synchronization, this makes every message
        // pushed by any sender visible to a receiver that observes the flag.
        self.chan.disconnected.store(true, Ordering::Release);
    }
}
/// Sending half of an unbounded cbus channel created by [`channel`].
///
/// Senders can be cloned; each clone opens its own `LCPipe` to the channel's
/// cbus endpoint. Sending fails once the [`EndpointReceiver`] is dropped.
pub struct Sender<T> {
    /// Shared handle whose final drop marks the channel disconnected.
    inner: Arc<SenderInner<T>>,
    /// Weak, so that dropping the receiver makes `upgrade` fail in `send`.
    waker: Weak<Waker>,
    /// Pipe used to deliver wakeup messages; only borrowed while `arc_guard`
    /// is held.
    lcpipe: RefCell<LCPipe>,
    /// Guards no data; it serializes senders and the receiver's drop.
    arc_guard: Arc<Mutex<()>>,
}

// SAFETY: messages of type `T` are moved to the receiver's thread, hence the
// `T: Send` bound. The remaining fields (`waker`, `lcpipe`) are only used
// while `arc_guard` is held (in `send` and `Drop`).
unsafe impl<T: Send> Send for Sender<T> {}
// SAFETY: through `&Sender` only `send` and `clone` are reachable. `send`
// upgrades `waker` and borrows the `RefCell` only while holding `arc_guard`,
// so the cell is never borrowed from two threads at once. `clone` only clones
// `Arc`/`Weak` handles. Any thread holding `&Sender` can push a `T`, so
// `T: Send` is required.
unsafe impl<T: Send> Sync for Sender<T> {}
impl<T> Drop for Sender<T> {
    /// Wakes the receiver so it can notice that a sender went away.
    ///
    /// The `inner` field is dropped only after this body returns. So when this
    /// is the last sender, the channel's `disconnected` flag is set after the
    /// wakeup is queued. The receiver may then only observe the disconnect on
    /// a later poll (`Waker::wait` times out after 1 ms).
    fn drop(&mut self) {
        // `arc_guard` protects no data, so a poisoned lock is still safe to
        // use. Panicking here instead could abort the process if this drop
        // runs during unwinding.
        let _crit_section = self
            .arc_guard
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if let Some(waker) = self.waker.upgrade() {
            // `&mut self` gives exclusive access, so no runtime borrow is needed.
            waker.wakeup(self.lcpipe.get_mut());
        }
    }
}
impl<T> Clone for Sender<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
waker: self.waker.clone(),
lcpipe: RefCell::new(LCPipe::new(&self.inner.chan.cbus_endpoint)),
arc_guard: self.arc_guard.clone(),
}
}
}
impl<T> Sender<T> {
    /// Pushes `msg` into the channel and wakes the receiver, unless a wakeup
    /// is already pending.
    ///
    /// # Errors
    /// Returns the message back inside [`SendError`] if the
    /// [`EndpointReceiver`] has been dropped.
    pub fn send(&self, msg: T) -> Result<(), SendError<T>> {
        // The lock guards no data (see `Drop`), so recover from poisoning
        // instead of panicking.
        let _crit_section = self
            .arc_guard
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        match self.waker.upgrade() {
            Some(waker) => {
                self.inner.chan.list.push(msg);
                // Safe to borrow: `arc_guard` serializes all users of `lcpipe`.
                waker.wakeup(&mut self.lcpipe.borrow_mut());
                Ok(())
            }
            None => Err(SendError(msg)),
        }
    }
}
/// Receiving half of an unbounded cbus channel created by [`channel`].
pub struct EndpointReceiver<T> {
    chan: Arc<Channel<T>>,
    /// The long-lived strong reference to the waker (senders hold only a
    /// `Weak`). `None` once the receiver has been dropped.
    waker: Option<Arc<Waker>>,
    /// Guards no data; it serializes the receiver's drop with senders.
    arc_guard: Arc<Mutex<()>>,
}

// SAFETY: the receiver yields values of `T` that were produced on other
// threads, so `T: Send` is required. Without the bound, non-`Send` values
// could cross threads.
unsafe impl<T: Send> Send for EndpointReceiver<T> {}

impl<T> Drop for EndpointReceiver<T> {
    fn drop(&mut self) {
        // The lock guards no data, so a poisoned lock is still usable.
        // Panicking in `drop` could abort the process during unwinding.
        let _crit_section = self
            .arc_guard
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // Releasing the waker under the lock means no sender is mid-wakeup.
        // Afterwards `Weak::upgrade` fails in senders and `send` returns
        // `SendError`.
        drop(self.waker.take());
    }
}
impl<T> EndpointReceiver<T> {
    /// Receives the next message, yielding the current fiber while the
    /// channel is empty.
    ///
    /// # Errors
    /// Returns [`RecvError::Disconnected`] once every [`Sender`] has been
    /// dropped and all messages they sent have been received.
    pub fn receive(&self) -> Result<T, RecvError> {
        loop {
            // Load the flag *before* popping. If it was already set, every
            // push by every sender happened before it (Release/Acquire), so an
            // empty pop below means the queue is truly drained. Checking the
            // flag only after a failed pop could report a disconnect while a
            // final message, pushed between the pop and the load, is still
            // queued.
            let disconnected = self.chan.disconnected.load(Ordering::Acquire);
            if let Some(msg) = self.chan.list.pop() {
                return Ok(msg);
            }
            if disconnected {
                return Err(RecvError::Disconnected);
            }
            self.waker
                .as_ref()
                .expect("unreachable: waker must exists")
                .wait();
        }
    }

    /// Number of messages currently queued. This is a snapshot; senders may
    /// change it concurrently.
    pub fn len(&self) -> usize {
        self.chan.list.len()
    }

    /// Whether no messages are currently queued (a snapshot, like `len`).
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
// Tests for the unbounded cbus channel, compiled only with the
// `internal_test` feature and run through `crate::test(tarantool = "crate")`.
#[cfg(feature = "internal_test")]
// The asserts use `matches!(x, Ok(_))` for symmetry with
// `matches!(x, Err(RecvError::Disconnected))`; clippy would prefer `.is_ok()`.
#[allow(clippy::redundant_pattern_matching)]
mod tests {
use super::super::tests::run_cbus_endpoint;
use crate::cbus::{unbounded, RecvError};
use crate::fiber;
use crate::fiber::{check_yield, YieldResult};
use std::thread;
use std::thread::JoinHandle;
use std::time::Duration;
// A producer thread sends 0..1000, sleeping 1 s after every 100th message
// (including i == 0). The receiving fiber must get every value in order and
// must have yielded while waiting.
// NOTE(review): the sleeps add roughly 10 s of wall time to this test.
#[crate::test(tarantool = "crate")]
pub fn unbounded_test() {
let cbus_fiber_id = run_cbus_endpoint("unbounded_test");
let (tx, rx) = unbounded::channel("unbounded_test");
let thread = thread::spawn(move || {
for i in 0..1000 {
_ = tx.send(i);
if i % 100 == 0 {
thread::sleep(Duration::from_millis(1000));
}
}
});
assert_eq!(
check_yield(|| {
let mut recv_results = vec![];
for _ in 0..1000 {
recv_results.push(rx.receive().unwrap());
}
recv_results
}),
YieldResult::Yielded((0..1000).collect::<Vec<_>>())
);
thread.join().unwrap();
assert!(fiber::cancel(cbus_fiber_id));
}
// The receiver is dropped while the producer is still sending. Later sends
// return `SendError` (ignored via `_ =`), and the producer thread must still
// finish.
#[crate::test(tarantool = "crate")]
pub fn unbounded_test_drop_rx_before_tx() {
let cbus_fiber_id = run_cbus_endpoint("unbounded_test_drop_rx_before_tx");
let (tx, rx) = unbounded::channel("unbounded_test_drop_rx_before_tx");
let thread = thread::spawn(move || {
for i in 1..300 {
_ = tx.send(i);
if i % 100 == 0 {
thread::sleep(Duration::from_secs(1));
}
}
});
fiber::sleep(Duration::from_secs(1));
drop(rx);
thread.join().unwrap();
assert!(fiber::cancel(cbus_fiber_id));
}
// The only sender is moved into the thread and dropped when the thread ends.
// The receiver must drain both messages and then report `Disconnected`.
#[crate::test(tarantool = "crate")]
pub fn unbounded_disconnect_test() {
let cbus_fiber_id = run_cbus_endpoint("unbounded_disconnect_test");
let (tx, rx) = unbounded::channel("unbounded_disconnect_test");
let thread = thread::spawn(move || {
_ = tx.send(1);
_ = tx.send(2);
});
assert!(matches!(rx.receive(), Ok(1)));
assert!(matches!(rx.receive(), Ok(2)));
assert!(matches!(rx.receive(), Err(RecvError::Disconnected)));
thread.join().unwrap();
assert!(fiber::cancel(cbus_fiber_id));
}
// Three producers (two clones plus the original sender) each send
// MESSAGES_PER_PRODUCER values. The receiver must get exactly three times
// that many messages, then see `Disconnected` once all senders are dropped.
#[crate::test(tarantool = "crate")]
pub fn unbounded_mpsc_test() {
const MESSAGES_PER_PRODUCER: i32 = 10_000;
let cbus_fiber_id = run_cbus_endpoint("unbounded_mpsc_test");
let (tx, rx) = unbounded::channel("unbounded_mpsc_test");
// Spawns a thread that owns `sender` and drops it after its last send.
fn create_producer(sender: unbounded::Sender<i32>) -> JoinHandle<()> {
thread::spawn(move || {
for i in 0..MESSAGES_PER_PRODUCER {
_ = sender.send(i);
}
})
}
let jh1 = create_producer(tx.clone());
let jh2 = create_producer(tx.clone());
let jh3 = create_producer(tx);
for _ in 0..MESSAGES_PER_PRODUCER * 3 {
assert!(matches!(rx.receive(), Ok(_)));
}
assert!(matches!(rx.receive(), Err(RecvError::Disconnected)));
jh1.join().unwrap();
jh2.join().unwrap();
jh3.join().unwrap();
assert!(fiber::cancel(cbus_fiber_id));
}
}