use super::{LCPipe, Message, UnsafeCond};
use crate::cbus::RecvError;
use crate::fiber::Cond;
use std::cell::{RefCell, UnsafeCell};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, Weak};
/// State shared between a [`Sender`] and its [`EndpointReceiver`].
///
/// The receiver holds the long-lived strong reference; the sender only holds a
/// `Weak` and upgrades it temporarily while holding the pair's shared mutex.
struct Channel<T> {
    /// Slot for the single value: written by `Sender::send`, taken by
    /// `EndpointReceiver::receive`.
    message: UnsafeCell<Option<T>>,
    /// Fiber condition the receiver waits on. It is reference-counted so the
    /// wake-up closure pushed from `Sender`'s `Drop` keeps its own handle to it,
    /// independent of the channel's lifetime.
    cond: Arc<UnsafeCond>,
    /// Set with `Release` by the sender once `message` has been written;
    /// swapped back to `false` with `Acquire` by the receiver.
    ready: AtomicBool,
}
// SAFETY: the `UnsafeCell` slot is written only by the single sender (which is
// consumed by `send`) and read only by the single receiver (consumed by
// `receive`), with `ready` used to hand the value over. The payload crosses
// threads, so `T: Send` is required for both sharing and moving the channel.
unsafe impl<T: Send> Send for Channel<T> {}
unsafe impl<T: Send> Sync for Channel<T> {}
impl<T> Channel<T> {
fn new() -> Self {
Self {
message: UnsafeCell::new(None),
ready: AtomicBool::new(false),
cond: Arc::new(UnsafeCond(Cond::new())),
}
}
}
/// Sending half of a oneshot channel; may be moved to another thread.
///
/// Consumed by [`Sender::send`]. Dropping it (which also happens when `send`
/// returns) pushes a message through `pipe` that signals the receiver's cond,
/// waking a receiver that is waiting for the value.
pub struct Sender<T> {
    /// Non-owning handle to the shared state; upgrading fails once the
    /// receiver has released its reference.
    channel: Weak<Channel<T>>,
    /// cbus pipe to the endpoint named in [`channel`]; used from `Drop` to
    /// deliver the wake-up.
    pipe: RefCell<LCPipe>,
    /// Mutex shared with the receiver. It is held while the `Weak` is upgraded
    /// and the temporary `Arc` released, so the receiver cannot drop its
    /// reference concurrently. Presumably this keeps the `Channel` (and its
    /// fiber cond) from being freed on the sender's thread — NOTE(review): confirm.
    arc_guard: Arc<Mutex<()>>,
}
// SAFETY: a `Sender` moves a `T` from its thread into the channel read by the
// receiver, so `T: Send` is required (the same bound `Channel` uses); without
// it e.g. an `Rc` could be smuggled to another thread. The non-thread-safe
// parts (`RefCell<LCPipe>`) are only reached through `send(self)` and `Drop`,
// i.e. with exclusive ownership, and there is no `&self` method, so sharing a
// `&Sender` exposes nothing. `LCPipe` is assumed usable from a non-TX thread,
// which is the purpose of this type — NOTE(review): confirm against LCPipe.
unsafe impl<T> Send for Sender<T> where T: Send {}
unsafe impl<T> Sync for Sender<T> where T: Send {}
/// Receiving half of a oneshot channel, consumed by
/// [`EndpointReceiver::receive`].
///
/// Waking up relies on the cbus endpoint named in [`channel`] processing the
/// sender's wake-up message, so that endpoint must be running (the tests start
/// it with `run_cbus_endpoint`).
pub struct EndpointReceiver<T> {
    /// Owning handle to the shared state; `Some` until `Drop` releases it.
    channel: Option<Arc<Channel<T>>>,
    /// Mutex shared with the sender; held while releasing `channel` so the
    /// release cannot interleave with the sender's temporary upgrade.
    arc_guard: Arc<Mutex<()>>,
}
pub fn channel<T>(cbus_endpoint: &str) -> (Sender<T>, EndpointReceiver<T>) {
let channel = Arc::new(Channel::new());
let arc_guard = Arc::new(Mutex::default());
(
Sender {
channel: Arc::downgrade(&channel),
pipe: RefCell::new(LCPipe::new(cbus_endpoint)),
arc_guard: Arc::clone(&arc_guard),
},
EndpointReceiver {
channel: Some(channel),
arc_guard,
},
)
}
impl<T> Sender<T> {
    /// Publishes `message` to the receiver, consuming the sender.
    ///
    /// If the receiver has already gone away, `message` is simply dropped.
    /// The receiver is woken afterwards by this sender's `Drop`, which runs as
    /// `send` returns.
    pub fn send(self, message: T) {
        let _guard = self.arc_guard.lock().unwrap();
        let shared = match self.channel.upgrade() {
            Some(shared) => shared,
            None => return,
        };
        // SAFETY: `send` consumes the only `Sender`, so this is the single write
        // to the slot; the receiver reads it only after the hand-over below.
        unsafe { *shared.message.get() = Some(message) };
        // Release pairs with the receiver's Acquire swap on `ready`.
        shared.ready.store(true, Ordering::Release);
    }
}
impl<T> Drop for Sender<T> {
fn drop(&mut self) {
let _crit_sect = self.arc_guard.lock().unwrap();
let mb_chan = self.channel.upgrade();
let mb_cond = mb_chan.map(|chan| chan.cond.clone());
if let Some(cond) = mb_cond {
let msg = Message::new(move || {
unsafe { (*cond).as_ref().signal() };
});
self.pipe.borrow_mut().push_message(msg);
}
}
}
impl<T> EndpointReceiver<T> {
    /// Waits (yielding the current fiber) until the paired [`Sender`] delivers
    /// a value or goes away, and returns that value.
    ///
    /// Fast path: if the sender has already published (`ready` is set), the
    /// value is taken without yielding. Otherwise the fiber waits on the
    /// channel's cond, which is signalled via the cbus pipe when the sender is
    /// dropped (this also happens at the end of `Sender::send`).
    ///
    /// # Errors
    ///
    /// Returns [`RecvError::Disconnected`] if the sender was dropped without
    /// sending, or if the wait ended before a value was published (fiber conds
    /// can wake spuriously, e.g. on fiber wakeup or cancellation).
    pub fn receive(self) -> Result<T, RecvError> {
        let channel = self
            .channel
            .as_ref()
            .expect("unreachable: channel must exists");

        // `ready` is stored with Release only after the slot is written, so an
        // Acquire observation of `true` makes that write visible to us.
        let mut published = channel.ready.swap(false, Ordering::Acquire);
        if !published {
            // SAFETY: the cond belongs to this channel and is kept alive by the
            // `Arc` held above.
            unsafe {
                (*channel.cond).as_ref().wait();
            }
            // The wait may end spuriously. Re-check the flag before touching
            // the slot: reading it without this Acquire could race with a
            // sender that is still writing the value.
            published = channel.ready.swap(false, Ordering::Acquire);
        }
        if !published {
            return Err(RecvError::Disconnected);
        }

        // SAFETY: we observed `ready == true` with Acquire after the sender's
        // single write (the sender is consumed by `send`), and this receiver is
        // consumed by `receive`, so no other access to the slot can overlap.
        let slot = unsafe { &mut *channel.message.get() };
        slot.take().ok_or(RecvError::Disconnected)
    }
}
impl<T> Drop for EndpointReceiver<T> {
    /// Releases the receiver's strong reference to the channel while holding
    /// the guard shared with the sender, so the release cannot interleave
    /// with the sender's temporary upgrade.
    fn drop(&mut self) {
        let _guard = self.arc_guard.lock().unwrap();
        // Overwriting with `None` drops the previous `Arc` right here, while
        // the guard is still held.
        self.channel = None;
    }
}
impl<T> Default for Channel<T> {
fn default() -> Self {
Self::new()
}
}
// Tests for the oneshot channel, compiled only with the `internal_test`
// feature and run through the crate's `#[crate::test]` harness. Each test
// starts its own cbus endpoint fiber and cancels it at the end.
#[cfg(feature = "internal_test")]
mod tests {
    use super::super::tests::run_cbus_endpoint;
    use crate::cbus::{oneshot, RecvError};
    use crate::fiber;
    use crate::fiber::{check_yield, YieldResult};
    use std::time::Duration;
    use std::{mem, thread};

    // Case 1: the value is sent from another thread after a delay, so the
    // receiver has to yield while waiting for it.
    // Case 2: the sender thread is joined (value sent, sender dropped) before
    // `receive`, so the receiver gets the value without yielding.
    #[crate::test(tarantool = "crate")]
    pub fn oneshot_test() {
        let cbus_fiber_id = run_cbus_endpoint("oneshot_test");
        let (sender, receiver) = oneshot::channel("oneshot_test");
        let thread = thread::spawn(move || {
            thread::sleep(Duration::from_secs(1));
            sender.send(1);
        });
        assert_eq!(
            check_yield(|| { receiver.receive().unwrap() }),
            YieldResult::Yielded(1)
        );
        thread.join().unwrap();
        let (sender, receiver) = oneshot::channel("oneshot_test");
        let thread = thread::spawn(move || {
            sender.send(2);
        });
        thread.join().unwrap();
        assert_eq!(
            check_yield(|| { receiver.receive().unwrap() }),
            YieldResult::DidntYield(2)
        );
        // Stop the endpoint fiber started at the top of the test.
        assert!(fiber::cancel(cbus_fiber_id));
    }

    // Two channels share one endpoint. The receivers are awaited in the
    // opposite order to the senders' delays, and each must still get its own
    // value.
    #[crate::test(tarantool = "crate")]
    pub fn oneshot_multiple_channels_test() {
        let cbus_fiber_id = run_cbus_endpoint("oneshot_multiple_channels_test");
        let (sender1, receiver1) = oneshot::channel("oneshot_multiple_channels_test");
        let (sender2, receiver2) = oneshot::channel("oneshot_multiple_channels_test");
        let thread1 = thread::spawn(move || {
            thread::sleep(Duration::from_secs(1));
            sender1.send("1");
        });
        let thread2 = thread::spawn(move || {
            thread::sleep(Duration::from_secs(2));
            sender2.send("2");
        });
        let result2 = receiver2.receive();
        let result1 = receiver1.receive();
        assert!(matches!(result1, Ok("1")));
        assert!(matches!(result2, Ok("2")));
        thread1.join().unwrap();
        thread2.join().unwrap();
        assert!(fiber::cancel(cbus_fiber_id));
    }

    // Dropping the sender without sending must wake the waiting receiver and
    // make it report `RecvError::Disconnected`.
    #[crate::test(tarantool = "crate")]
    pub fn oneshot_sender_drop_test() {
        let cbus_fiber_id = run_cbus_endpoint("oneshot_sender_drop_test");
        let (sender, receiver) = oneshot::channel::<()>("oneshot_sender_drop_test");
        let thread = thread::spawn(move || {
            thread::sleep(Duration::from_secs(1));
            mem::drop(sender)
        });
        let result = receiver.receive();
        assert!(matches!(result, Err(RecvError::Disconnected)));
        thread.join().unwrap();
        assert!(fiber::cancel(cbus_fiber_id));
    }
}