use super::{LCPipe, Message};
use crate::cbus::RecvError;
use crate::fiber::Cond;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
/// Wakes the fiber blocked in `Waker::wait` by posting messages through an
/// `LCPipe`; each message signals a fiber condition variable when executed.
struct Waker {
    /// Condition the receiver waits on. `Some` for the whole life of the
    /// waker; it is only taken out in `Drop`.
    condition: Option<Arc<Cond>>,
    /// Set by `wakeup`, cleared by `wait`. While it is `true`, further
    /// `wakeup` calls do not post additional signal messages.
    woken: AtomicBool,
    /// Pipe through which the signalling messages are posted.
    pipe: LCPipe,
}
impl Waker {
    /// Creates a waker that signals `cond` via messages posted through `pipe`.
    fn new(cond: Cond, pipe: LCPipe) -> Self {
        Self {
            condition: Some(Arc::new(cond)),
            woken: AtomicBool::new(false),
            pipe,
        }
    }

    /// Unconditionally posts a message through the pipe; when the message is
    /// executed it signals `cond`.
    fn force_wakeup(&self, cond: Arc<Cond>) {
        let msg = Message::new(move || {
            cond.signal();
        });
        self.pipe.push_message(msg);
    }

    /// Sets the woken flag and, if it was not already set, posts a signal
    /// message so a fiber blocked in `wait` is resumed.
    fn wakeup(&self) {
        // `swap` is a read-modify-write even when the flag is already `true`,
        // so it always releases the caller's preceding writes (a queue push,
        // the `disconnected` flag) to the receiver's `compare_exchange` in
        // `wait`. A failed `compare_exchange` would only be a load: the
        // receiver could then clear the flag without seeing those writes and
        // block again with no signal message pending.
        let was_woken = self.woken.swap(true, Ordering::AcqRel);
        if !was_woken {
            let cond = Arc::clone(
                self.condition
                    .as_ref()
                    .expect("unreachable: condition never empty"),
            );
            self.force_wakeup(cond);
        }
    }

    /// Waits on the condition until the woken flag is set, then clears it.
    ///
    /// Spurious signals (e.g. a late message from an earlier wakeup) are
    /// handled by re-checking the flag in a loop.
    fn wait(&self) {
        while self
            .woken
            .compare_exchange(true, false, Ordering::AcqRel, Ordering::Acquire)
            .is_err()
        {
            self.condition
                .as_ref()
                .expect("unreachable: condition never empty")
                .wait();
        }
    }
}
impl Drop for Waker {
    /// Moves the waker's last owned `Arc<Cond>` into one final signal message
    /// instead of dropping it here, which also wakes any waiting fiber.
    ///
    /// NOTE(review): presumably this is so the `Cond` is released on the
    /// endpoint side, since the channel (and thus this waker) may be dropped
    /// last on a sender's OS thread — confirm.
    fn drop(&mut self) {
        let last_cond = self.condition.take();
        if let Some(owned) = last_cond {
            self.force_wakeup(owned);
        }
    }
}
/// State shared between the sending and receiving halves of an unbounded
/// channel.
struct Channel<T> {
    /// Unbounded concurrent queue of messages not yet received.
    list: crossbeam_queue::SegQueue<T>,
    /// Wakes the receiver when a message is sent or the senders disconnect.
    waker: Waker,
    /// Set to `true` when the shared `SenderInner` (i.e. the last `Sender`)
    /// is dropped.
    disconnected: AtomicBool,
}
impl<T> Channel<T> {
    /// Builds an empty, connected channel whose receiver is woken through
    /// `pipe`.
    fn new(pipe: LCPipe) -> Self {
        let waker = Waker::new(Cond::new(), pipe);
        Self {
            list: crossbeam_queue::SegQueue::new(),
            waker,
            disconnected: AtomicBool::new(false),
        }
    }
}
/// Creates an unbounded channel whose receiver is woken through the cbus
/// endpoint named `cbus_endpoint`.
///
/// Returns the sending half (clone it for additional producers) and the
/// receiving half. Wakeups are posted through an `LCPipe` to that endpoint;
/// presumably a fiber must be serving the endpoint for them to be delivered
/// (the tests start one with `run_cbus_endpoint`) — confirm.
pub fn channel<T>(cbus_endpoint: &str) -> (Sender<T>, EndpointReceiver<T>) {
    let shared = Arc::new(Channel::new(LCPipe::new(cbus_endpoint)));
    let sender = Sender {
        inner: Arc::new(SenderInner {
            chan: Arc::clone(&shared),
        }),
    };
    let receiver = EndpointReceiver { chan: shared };
    (sender, receiver)
}
/// State shared by every clone of a `Sender` through an `Arc`; dropping it
/// (which happens when the last `Sender` is dropped) disconnects the channel.
struct SenderInner<T> {
    chan: Arc<Channel<T>>,
}
// SAFETY: `SenderInner` only holds an `Arc<Channel<T>>`. Dropping it on another
// thread may drop the last channel reference, together with any `T` still
// queued, so `T: Send` is required. NOTE(review): the manual impl presumably
// exists because `Channel` holds `LCPipe`/`Cond`, which are not auto-`Send` —
// confirm.
unsafe impl<T: Send> Send for SenderInner<T> {}
impl<T> Drop for SenderInner<T> {
fn drop(&mut self) {
self.chan.disconnected.store(true, Ordering::Release);
self.chan.waker.wakeup();
}
}
/// Sending half of an unbounded cbus channel, created by `channel`.
///
/// Clone it to add another producer. All clones share one `SenderInner`, so
/// the channel is disconnected only when the last clone is dropped.
pub struct Sender<T> {
    inner: Arc<SenderInner<T>>,
}
// SAFETY: a `Sender` moves `T` values into a queue that is drained by the
// receiver, possibly on another thread. Both moving a `Sender` across threads
// and sharing `&Sender` between threads (`send` takes `&self`) therefore
// transfer `T` across threads, so `T: Send` is required.
unsafe impl<T: Send> Send for Sender<T> {}
unsafe impl<T: Send> Sync for Sender<T> {}
impl<T> Clone for Sender<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
impl<T> Sender<T> {
    /// Enqueues `msg` and wakes the receiver.
    ///
    /// The message is pushed onto the channel's queue. A signal message is
    /// then posted through the cbus pipe only if the receiver is not already
    /// flagged as woken.
    pub fn send(&self, msg: T) {
        let shared = &self.inner.chan;
        shared.list.push(msg);
        shared.waker.wakeup();
    }
}
/// Receiving half of an unbounded cbus channel, created by `channel`.
///
/// Waiting is done on a fiber condition variable (`crate::fiber::Cond`), so
/// receiving is presumably meant to happen inside a fiber — TODO confirm.
pub struct EndpointReceiver<T> {
    chan: Arc<Channel<T>>,
}
// SAFETY: the receiver takes ownership of the queued `T` values it pops, and
// dropping it may drop the last channel reference together with any `T` still
// queued. Moving it across threads therefore requires `T: Send`.
unsafe impl<T: Send> Send for EndpointReceiver<T> {}
impl<T> EndpointReceiver<T> {
    /// Receives the next message, waiting on the channel's condition while
    /// the queue is empty.
    ///
    /// # Errors
    ///
    /// Returns [`RecvError::Disconnected`] once every [`Sender`] has been
    /// dropped and no messages remain in the queue.
    pub fn receive(&self) -> Result<T, RecvError> {
        loop {
            if let Some(msg) = self.chan.list.pop() {
                return Ok(msg);
            }
            if self.chan.disconnected.load(Ordering::Acquire) {
                // The last sender may have pushed messages and then
                // disconnected between the `pop` above and this load. Reading
                // `true` here (Acquire, pairing with the Release store in
                // `SenderInner::drop`) makes those pushes visible, so check the
                // queue once more instead of losing them.
                return self.chan.list.pop().ok_or(RecvError::Disconnected);
            }
            self.chan.waker.wait();
        }
    }

    /// Returns the number of messages currently queued.
    ///
    /// Senders may push concurrently, so the value is only a snapshot.
    pub fn len(&self) -> usize {
        self.chan.list.len()
    }

    /// Returns `true` if no messages are currently queued (see [`Self::len`]).
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
// Tests for the unbounded channel. Built only with the `internal_test`
// feature and registered through `crate::test` (presumably executed inside a
// Tarantool instance, since receiving relies on fibers — confirm).
#[cfg(feature = "internal_test")]
mod tests {
    use super::super::tests::run_cbus_endpoint;
    use crate::cbus::{unbounded, RecvError};
    use crate::fiber;
    use crate::fiber::{check_yield, YieldResult};
    use std::thread;
    use std::thread::JoinHandle;
    use std::time::Duration;

    // One producer thread sends 0..1000, sleeping 1 s after every multiple of
    // 100 (ten pauses, so this test takes roughly 10 s). The fiber must receive
    // every value in order, and `check_yield` must report that the receiving
    // closure yielded.
    #[crate::test(tarantool = "crate")]
    pub fn unbounded_test() {
        // Presumably serves the cbus endpoint that the channel's pipe posts
        // wakeup messages to; cancelled at the end of the test.
        let mut cbus_fiber = run_cbus_endpoint("unbounded_test");
        let (tx, rx) = unbounded::channel("unbounded_test");
        let thread = thread::spawn(move || {
            for i in 0..1000 {
                tx.send(i);
                if i % 100 == 0 {
                    thread::sleep(Duration::from_millis(1000));
                }
            }
        });
        assert_eq!(
            check_yield(|| {
                let mut recv_results = vec![];
                for _ in 0..1000 {
                    recv_results.push(rx.receive().unwrap());
                }
                recv_results
            }),
            YieldResult::Yielded((0..1000).collect::<Vec<_>>())
        );
        thread.join().unwrap();
        cbus_fiber.cancel();
    }

    // The receiver is dropped after ~1 s while the producer (which pauses at
    // i == 100 and i == 200) is still sending. The test checks that sending
    // into a channel whose receiver is gone does not panic: `join().unwrap()`
    // fails if the producer thread panicked.
    #[crate::test(tarantool = "crate")]
    pub fn unbounded_test_drop_rx_before_tx() {
        let mut cbus_fiber = run_cbus_endpoint("unbounded_test_drop_rx_before_tx");
        let (tx, rx) = unbounded::channel("unbounded_test_drop_rx_before_tx");
        let thread = thread::spawn(move || {
            for i in 1..300 {
                tx.send(i);
                if i % 100 == 0 {
                    thread::sleep(Duration::from_secs(1));
                }
            }
        });
        fiber::sleep(Duration::from_secs(1));
        drop(rx);
        thread.join().unwrap();
        cbus_fiber.cancel();
    }

    // `tx` is moved into the thread and dropped when the closure returns, so
    // after the two sent values the receiver must report `Disconnected`.
    #[crate::test(tarantool = "crate")]
    pub fn unbounded_disconnect_test() {
        let mut cbus_fiber = run_cbus_endpoint("unbounded_disconnect_test");
        let (tx, rx) = unbounded::channel("unbounded_disconnect_test");
        let thread = thread::spawn(move || {
            tx.send(1);
            tx.send(2);
        });
        assert!(matches!(rx.receive(), Ok(1)));
        assert!(matches!(rx.receive(), Ok(2)));
        assert!(matches!(rx.receive(), Err(RecvError::Disconnected)));
        thread.join().unwrap();
        cbus_fiber.cancel();
    }

    // Three producer threads (two clones plus the original `tx`) each send
    // MESSAGES_PER_PRODUCER values. All of them must be received, and the
    // channel reports `Disconnected` only after every producer's sender has
    // been dropped.
    #[crate::test(tarantool = "crate")]
    pub fn unbounded_mpsc_test() {
        const MESSAGES_PER_PRODUCER: i32 = 10_000;
        let mut cbus_fiber = run_cbus_endpoint("unbounded_mpsc_test");
        let (tx, rx) = unbounded::channel("unbounded_mpsc_test");
        // Spawns a thread that sends 0..MESSAGES_PER_PRODUCER and then drops
        // its sender.
        fn create_producer(sender: unbounded::Sender<i32>) -> JoinHandle<()> {
            thread::spawn(move || {
                for i in 0..MESSAGES_PER_PRODUCER {
                    sender.send(i);
                }
            })
        }
        let jh1 = create_producer(tx.clone());
        let jh2 = create_producer(tx.clone());
        let jh3 = create_producer(tx);
        for _ in 0..MESSAGES_PER_PRODUCER * 3 {
            assert!(matches!(rx.receive(), Ok(_)));
        }
        assert!(matches!(rx.receive(), Err(RecvError::Disconnected)));
        jh1.join().unwrap();
        jh2.join().unwrap();
        jh3.join().unwrap();
        cbus_fiber.cancel();
    }
}