#![cfg(any(feature = "tokio_components", doc))]
use crate::cbus::{LCPipe, RecvError, SendError};
use crate::fiber::Cond;
use std::cell::RefCell;
use std::num::NonZeroUsize;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Weak};
use std::thread;
use std::time::Duration;
use tokio::sync::Notify;
/// Fiber-side waker reused from the unbounded cbus channel: the receiver blocks in `wait()`
/// while the queue is empty, and senders signal it with `wakeup(&mut LCPipe)`.
type CordWaker = crate::cbus::unbounded::Waker;
/// Wakes tokio tasks (senders) that are blocked because the channel is full.
///
/// The receiver calls [`TaskWaker::wakeup_one`] after popping a message and
/// [`TaskWaker::wakeup_all`] when it is dropped.
struct TaskWaker {
    notify: Notify,
}

impl TaskWaker {
    /// How often a waiting task re-checks the `disconnected` flag.
    ///
    /// `Notify::notify_waiters` (used by [`TaskWaker::wakeup_all`]) does not store a permit, so a
    /// notification sent while no `Notified` future is registered (e.g. between two timeouts) is
    /// lost. Re-checking the flag on every timeout bounds how long such a lost wakeup can stall
    /// a sender, instead of letting it wait forever.
    const RECHECK_INTERVAL: Duration = Duration::from_millis(10);

    fn new() -> Self {
        Self {
            notify: Notify::default(),
        }
    }

    /// Waits until [`TaskWaker::wakeup_one`] or [`TaskWaker::wakeup_all`] is called, or until
    /// `disconnected` is observed to be `true`.
    ///
    /// In the context of this channel, a sender waits here while the queue is full and is woken
    /// when the receiver pops an element or is dropped.
    async fn wait(&self, disconnected: &AtomicBool) {
        loop {
            if disconnected.load(Ordering::Acquire) {
                return;
            }
            if tokio::time::timeout(Self::RECHECK_INTERVAL, self.notify.notified())
                .await
                .is_ok()
            {
                return;
            }
        }
    }

    /// Wakes one waiting task; if none is waiting, `Notify` stores a permit for the next one.
    fn wakeup_one(&self) {
        self.notify.notify_one();
    }

    /// Wakes every task whose `Notified` future is currently registered (no permit is stored).
    fn wakeup_all(&self) {
        self.notify.notify_waiters();
    }
}
/// State shared by every sender and the receiver.
///
/// It holds a bounded multi-producer/multi-consumer queue of pending messages, a flag telling
/// whether either side has gone away, and the name of the cbus endpoint used for wakeups.
struct Channel<T> {
    list: crossbeam_queue::ArrayQueue<T>,
    disconnected: AtomicBool,
    cbus_endpoint: String,
}

impl<T> Channel<T> {
    /// Builds an empty, still-connected channel that can hold at most `cap` messages.
    fn new(cbus_endpoint: &str, cap: NonZeroUsize) -> Self {
        let capacity = cap.get();
        Self {
            list: crossbeam_queue::ArrayQueue::new(capacity),
            disconnected: AtomicBool::new(false),
            cbus_endpoint: String::from(cbus_endpoint),
        }
    }
}
/// Creates a bounded channel with room for `cap` messages.
///
/// The returned [`Sender`] pushes messages from tokio tasks and wakes the receiving side through
/// an `LCPipe` to `cbus_endpoint`. The returned [`EndpointReceiver`] pops them, blocking on a
/// fiber `Cond`-based waker while the queue is empty.
pub fn channel<T>(cbus_endpoint: &str, cap: NonZeroUsize) -> (Sender<T>, EndpointReceiver<T>) {
    let shared = Arc::new(Channel::new(cbus_endpoint, cap));
    let cord_waker = Arc::new(CordWaker::new(Cond::new()));
    let guard = Arc::new(tokio::sync::Mutex::default());
    let task_waker = Arc::new(TaskWaker::new());

    let sender = Sender {
        inner: Arc::new(SenderInner {
            chan: Arc::clone(&shared),
        }),
        // Senders only hold a weak reference: once the receiver drops the waker, sends fail.
        cord_waker: Arc::downgrade(&cord_waker),
        task_waker: Arc::clone(&task_waker),
        lcpipe: RefCell::new(LCPipe::new(&shared.cbus_endpoint)),
        arc_guard: Arc::clone(&guard),
    };

    let receiver = EndpointReceiver {
        chan: shared,
        cord_waker: Some(cord_waker),
        task_waker,
        arc_guard: guard,
    };

    (sender, receiver)
}
/// Part of the sender that is shared (via `Arc`) by all clones of one [`Sender`].
///
/// When the last clone goes away this value is dropped, and the channel is flagged as
/// disconnected.
struct SenderInner<T> {
    chan: Arc<Channel<T>>,
}

unsafe impl<T> Send for SenderInner<T> {}

impl<T> Drop for SenderInner<T> {
    fn drop(&mut self) {
        let chan = &self.chan;
        chan.disconnected.store(true, Ordering::Release);
    }
}
/// Sending half of the channel, used from tokio tasks.
///
/// Each clone opens its own [`LCPipe`] to the channel's cbus endpoint.
pub struct Sender<T> {
    /// Shared by all clones; dropping the last clone marks the channel as disconnected.
    inner: Arc<SenderInner<T>>,
    /// Weak handle to the receiver's waker; failing to upgrade it means the receiver is gone.
    cord_waker: Weak<CordWaker>,
    /// Used to wait for free space while the queue is full.
    task_waker: Arc<TaskWaker>,
    /// Per-sender pipe used to deliver wakeups to the receiving cord.
    lcpipe: RefCell<LCPipe>,
    /// Serializes upgrading/using `cord_waker` with the receiver releasing it on drop.
    arc_guard: Arc<tokio::sync::Mutex<()>>,
}

// SAFETY: messages of type `T` are moved from the sending thread to the receiving cord, so `T`
// itself must be `Send` (otherwise e.g. an `Rc` could be shared across threads). The
// non-thread-safe `RefCell<LCPipe>` is only borrowed while `arc_guard` is held (in `send`) or
// through `&mut self` (in `drop`), which serializes access to it.
unsafe impl<T: Send> Send for Sender<T> {}
unsafe impl<T: Send> Sync for Sender<T> {}
impl<T> Drop for Sender<T> {
fn drop(&mut self) {
let crit = self.arc_guard.clone();
let cord_waker = self.cord_waker.clone();
let lcpipe: &mut LCPipe = &mut self.lcpipe.borrow_mut();
thread::scope(move |s| {
s.spawn(move || {
let _crit_section = crit.blocking_lock();
if let Some(waker) = cord_waker.upgrade() {
waker.wakeup(lcpipe);
}
});
});
}
}
impl<T> Clone for Sender<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
cord_waker: self.cord_waker.clone(),
task_waker: self.task_waker.clone(),
lcpipe: RefCell::new(LCPipe::new(&self.inner.chan.cbus_endpoint)),
arc_guard: self.arc_guard.clone(),
}
}
}
impl<T> Sender<T> {
    /// Pushes `msg` into the channel and wakes the receiving cord.
    ///
    /// While the queue is full, the task releases `arc_guard`, waits for the receiver to free a
    /// slot, and then retries.
    ///
    /// # Errors
    ///
    /// Returns [`SendError`] carrying the message back when the receiver has been dropped,
    /// either before the push or while waiting for free space.
    pub async fn send(&self, mut msg: T) -> Result<(), SendError<T>> {
        loop {
            // The receiver takes this lock before releasing its waker, so holding it keeps the
            // upgrade + wakeup below consistent with receiver shutdown.
            let crit_section = self.arc_guard.lock().await;
            let waker = match self.cord_waker.upgrade() {
                Some(waker) => waker,
                // The receiver is gone.
                None => return Err(SendError(msg)),
            };

            match self.inner.chan.list.push(msg) {
                Ok(()) => {
                    waker.wakeup(&mut self.lcpipe.borrow_mut());
                    return Ok(());
                }
                Err(rejected) => {
                    // Queue is full: release the waker and the lock before sleeping so the
                    // receiver can make progress (or be dropped) in the meantime.
                    drop(waker);
                    drop(crit_section);
                    let disconnected = &self.inner.chan.disconnected;
                    self.task_waker.wait(disconnected).await;
                    if disconnected.load(Ordering::Acquire) {
                        return Err(SendError(rejected));
                    }
                    msg = rejected;
                }
            }
        }
    }
}
/// Receiving half of the channel.
///
/// [`EndpointReceiver::receive`] blocks on the cord waker while the queue is empty.
pub struct EndpointReceiver<T> {
    chan: Arc<Channel<T>>,
    /// The only strong owner of the cord waker (senders hold `Weak` references). It is taken
    /// on drop, after which senders fail to upgrade it and their sends return an error.
    cord_waker: Option<Arc<CordWaker>>,
    /// Used to wake senders waiting for free space.
    task_waker: Arc<TaskWaker>,
    /// Held while releasing `cord_waker`, to serialize with senders that are using it.
    arc_guard: Arc<tokio::sync::Mutex<()>>,
}

// SAFETY: the receiver hands out owned `T` values, so moving it to another thread is only sound
// when `T` itself is `Send`.
unsafe impl<T: Send> Send for EndpointReceiver<T> {}
impl<T> Drop for EndpointReceiver<T> {
    /// Closes the channel from the receiving side.
    ///
    /// Blocked senders are woken so they observe the disconnection. The cord waker is then
    /// released while `arc_guard` is held, so that no sender is in the middle of using it.
    fn drop(&mut self) {
        // Flag first, so that tasks woken by `wakeup_all` already see it.
        self.chan.disconnected.store(true, Ordering::Release);
        self.task_waker.wakeup_all();

        let guard = self.arc_guard.blocking_lock();
        let released = self.cord_waker.take();
        drop(released);
        drop(guard);
    }
}
impl<T> EndpointReceiver<T> {
    /// Takes the next message, waiting on the cord waker while the queue is empty.
    ///
    /// Messages already queued are always delivered before a disconnection is reported.
    ///
    /// # Errors
    ///
    /// Returns [`RecvError::Disconnected`] when the queue is empty and the channel has been
    /// flagged as disconnected.
    pub fn receive(&self) -> Result<T, RecvError> {
        let chan = &self.chan;
        loop {
            match chan.list.pop() {
                Some(msg) => {
                    // A slot was freed: let one sender blocked on a full queue retry.
                    self.task_waker.wakeup_one();
                    return Ok(msg);
                }
                None if chan.disconnected.load(Ordering::Acquire) => {
                    return Err(RecvError::Disconnected);
                }
                None => {
                    let waker = self.cord_waker.as_ref();
                    waker.expect("unreachable: waker must exists").wait();
                }
            }
        }
    }

    /// Number of messages currently queued.
    pub fn len(&self) -> usize {
        self.chan.list.len()
    }

    /// Whether no messages are currently queued.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
// Integration tests for the tokio -> cbus channel. Each test:
// - starts a cbus endpoint fiber with `run_cbus_endpoint`;
// - runs a producer on a separate OS thread inside a current-thread tokio runtime;
// - consumes messages in the test fiber, joins the producer thread, and cancels the endpoint
//   fiber.
// Several tests rely on fixed sleeps to order the producer and the consumer.
#[cfg(feature = "internal_test")]
mod tests {
    use crate::cbus::sync;
    use crate::cbus::tests::run_cbus_endpoint;
    use crate::cbus::RecvError;
    use crate::fiber;
    use crate::fiber::{check_yield, YieldResult};
    use std::num::NonZeroUsize;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::thread;
    use std::time::Duration;

    // One producer sends 0..1000 (pausing every 100 messages); the receiver must get them in
    // order and must yield the fiber while waiting.
    #[crate::test(tarantool = "crate")]
    pub fn single_producer() {
        let cbus_fiber_id = run_cbus_endpoint("tokio_single_producer");
        let cap = NonZeroUsize::new(10).unwrap();
        let (tx, rx) = sync::tokio::channel("tokio_single_producer", cap);
        let tokio_rt = thread::spawn(move || {
            tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap()
                .block_on(async move {
                    for i in 0..1000 {
                        _ = tx.send(i).await;
                        if i % 100 == 0 {
                            tokio::time::sleep(Duration::from_millis(100)).await;
                        }
                    }
                });
        });
        assert_eq!(
            check_yield(|| {
                let mut recv_results = vec![];
                for _ in 0..1000 {
                    recv_results.push(rx.receive().unwrap());
                }
                recv_results
            }),
            YieldResult::Yielded((0..1000).collect::<Vec<_>>())
        );
        tokio_rt.join().unwrap();
        assert!(fiber::cancel(cbus_fiber_id));
    }

    // With capacity 10, the producer must stall once the queue is full: after each pause the
    // number of completed sends is expected to be exactly 10 more than before.
    #[crate::test(tarantool = "crate")]
    pub fn single_producer_lock() {
        let cbus_fiber_id = run_cbus_endpoint("tokio_single_producer_lock");
        static SEND_COUNTER: AtomicU64 = AtomicU64::new(0);
        let cap = NonZeroUsize::new(10).unwrap();
        let (tx, rx) = sync::tokio::channel("tokio_single_producer_lock", cap);
        let tokio_rt = thread::spawn(move || {
            tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap()
                .block_on(async move {
                    for i in 0..100 {
                        _ = tx.send(i).await;
                        SEND_COUNTER.fetch_add(1, Ordering::SeqCst);
                    }
                });
        });
        // Give the producer time to fill the queue before checking the counter.
        fiber::sleep(Duration::from_millis(100));
        let mut recv_results = vec![];
        for i in 0..10 {
            assert_eq!(SEND_COUNTER.load(Ordering::SeqCst), (i + 1) * 10);
            for _ in 0..10 {
                recv_results.push(rx.receive().unwrap());
            }
            fiber::sleep(Duration::from_millis(100));
        }
        assert_eq!((0..100).collect::<Vec<_>>(), recv_results);
        tokio_rt.join().unwrap();
        assert!(fiber::cancel(cbus_fiber_id));
    }

    // The receiver is dropped while the producer is still sending; the producer thread must
    // still terminate (send results are ignored).
    #[crate::test(tarantool = "crate")]
    pub fn drop_rx_before_tx() {
        let cbus_fiber_id = run_cbus_endpoint("tokio_drop_rx_before_tx");
        let cap = NonZeroUsize::new(1000).unwrap();
        let (tx, rx) = sync::tokio::channel("tokio_drop_rx_before_tx", cap);
        let tokio_rt = thread::spawn(move || {
            tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap()
                .block_on(async move {
                    for i in 1..300 {
                        _ = tx.send(i).await;
                        if i % 100 == 0 {
                            tokio::time::sleep(Duration::from_secs(1)).await;
                        }
                    }
                });
        });
        fiber::sleep(Duration::from_secs(1));
        drop(rx);
        tokio_rt.join().unwrap();
        assert!(fiber::cancel(cbus_fiber_id));
    }

    // After the only sender is dropped, queued messages are still delivered, followed by
    // `RecvError::Disconnected`.
    #[crate::test(tarantool = "crate")]
    pub fn tx_disconnect() {
        let cbus_fiber_id = run_cbus_endpoint("tokio_tx_disconnect");
        let cap = NonZeroUsize::new(1).unwrap();
        let (tx, rx) = sync::tokio::channel("tokio_tx_disconnect", cap);
        let tokio_rt = thread::spawn(move || {
            tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap()
                .block_on(async move {
                    _ = tx.send(1).await;
                    _ = tx.send(2).await;
                });
        });
        assert!(matches!(rx.receive(), Ok(1)));
        assert!(matches!(rx.receive(), Ok(2)));
        assert!(matches!(rx.receive(), Err(RecvError::Disconnected)));
        tokio_rt.join().unwrap();
        assert!(fiber::cancel(cbus_fiber_id));
    }

    // A send issued after the receiver has been dropped must return an error.
    #[crate::test(tarantool = "crate")]
    pub fn rx_disconnect() {
        let cbus_fiber_id = run_cbus_endpoint("tokio_rx_disconnect");
        let cap = NonZeroUsize::new(1).unwrap();
        let (tx, rx) = sync::tokio::channel("tokio_rx_disconnect", cap);
        let tokio_rt = thread::spawn(move || {
            tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap()
                .block_on(async move {
                    assert!(tx.send(1).await.is_ok());
                    tokio::time::sleep(Duration::from_millis(100)).await;
                    assert!(tx.send(2).await.is_err());
                });
        });
        assert!(matches!(rx.receive(), Ok(1)));
        drop(rx);
        tokio_rt.join().unwrap();
        assert!(fiber::cancel(cbus_fiber_id));
    }

    // Three cloned senders running as separate tokio tasks; every message must arrive, and the
    // channel reports disconnection once all senders are gone.
    #[crate::test(tarantool = "crate")]
    pub fn multiple_producer() {
        const MESSAGES_PER_PRODUCER: i32 = 10_000;
        let cbus_fiber_id = run_cbus_endpoint("tokio_multiple_producer");
        let cap = NonZeroUsize::new(10).unwrap();
        let (tx, rx) = sync::tokio::channel("tokio_multiple_producer", cap);
        let tokio_rt = thread::spawn(move || {
            tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap()
                .block_on(async move {
                    let mut handles = vec![];
                    for _ in 0..3 {
                        let sender = tx.clone();
                        let jh = tokio::spawn(async move {
                            for i in 0..MESSAGES_PER_PRODUCER {
                                _ = sender.send(i).await;
                            }
                        });
                        handles.push(jh);
                    }
                    for h in handles {
                        h.await.unwrap();
                    }
                });
        });
        for _ in 0..MESSAGES_PER_PRODUCER * 3 {
            assert!(rx.receive().is_ok());
        }
        assert!(matches!(rx.receive(), Err(RecvError::Disconnected)));
        tokio_rt.join().unwrap();
        assert!(fiber::cancel(cbus_fiber_id));
    }

    // Like `single_producer_lock`, but with three producers sharing the capacity-10 queue:
    // each receive batch of 10 is expected to let exactly 10 more sends complete.
    #[crate::test(tarantool = "crate")]
    pub fn multiple_producer_lock() {
        const MESSAGES_PER_PRODUCER: i32 = 100;
        let cbus_fiber_id = run_cbus_endpoint("tokio_multiple_producer_lock");
        let cap = NonZeroUsize::new(10).unwrap();
        let (tx, rx) = sync::tokio::channel("tokio_multiple_producer_lock", cap);
        static SEND_COUNTER: AtomicU64 = AtomicU64::new(0);
        let tokio_rt = thread::spawn(move || {
            tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap()
                .block_on(async move {
                    let mut handles = vec![];
                    for _ in 0..3 {
                        let sender = tx.clone();
                        let jh = tokio::spawn(async move {
                            for i in 0..MESSAGES_PER_PRODUCER {
                                _ = sender.send(i).await;
                                SEND_COUNTER.fetch_add(1, Ordering::SeqCst);
                            }
                        });
                        handles.push(jh);
                    }
                    for h in handles {
                        h.await.unwrap();
                    }
                });
        });
        fiber::sleep(Duration::from_millis(100));
        for i in 0..10 * 3 {
            assert_eq!(SEND_COUNTER.load(Ordering::SeqCst), (i + 1) * 10);
            for _ in 0..10 {
                assert!(rx.receive().is_ok());
            }
            fiber::sleep(Duration::from_millis(100));
        }
        assert!(matches!(rx.receive(), Err(RecvError::Disconnected)));
        tokio_rt.join().unwrap();
        assert!(fiber::cancel(cbus_fiber_id));
    }
}