use std::fmt;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::{Context, Poll};
use diatomic_waker::{WakeSink, WakeSource};
use futures_core::Stream;
use futures_util::stream::StreamExt;
use serde::Serialize;
use crate::model::Message;
use crate::path::Path;
use crate::simulation::{DuplicateEventSinkError, SimInit};
use crate::util::shared_cell::SharedCell;
use super::{EventSinkReader, EventSinkWriter, SinkState};
/// Creates a connected writer/reader pair backed by a single-event slot.
///
/// The slot starts enabled if `state` is `SinkState::Enabled` and disabled
/// otherwise. The returned writer is initially the only writer; additional
/// writers are obtained by cloning it.
pub fn event_slot<T: Send>(state: SinkState) -> (EventSlotWriter<T>, EventSlotReader<T>) {
    let enabled = state == SinkState::Enabled;
    let shared = Arc::new(EventSlot {
        is_enabled: AtomicBool::new(enabled),
        // Accounts for the single writer returned below.
        writer_count: AtomicUsize::new(1),
        cell: SharedCell::new(),
    });

    let wake_sink = WakeSink::new();
    let writer = EventSlotWriter {
        inner: Arc::clone(&shared),
        wake_source: wake_sink.source(),
    };
    let reader = EventSlotReader {
        inner: shared,
        wake_sink,
    };

    (writer, reader)
}
/// Creates an event slot, binds its reader to `sim_init` at `path` via
/// `SimInit::bind_event_sink`, and returns the corresponding writer.
///
/// # Errors
///
/// Propagates the error returned by `SimInit::bind_event_sink`, presumably
/// when a sink is already bound at `path` (TODO confirm against `SimInit`).
pub fn event_slot_endpoint<T: Message + Serialize + Send + 'static>(
    sim_init: &mut SimInit,
    state: SinkState,
    path: impl Into<Path>,
) -> Result<EventSlotWriter<T>, DuplicateEventSinkError> {
    let (writer, reader) = event_slot(state);
    sim_init.bind_event_sink(reader, path)?;

    Ok(writer)
}
/// Creates an event slot, binds its reader to `sim_init` at `path` via
/// `SimInit::bind_event_sink_raw`, and returns the corresponding writer.
///
/// Unlike `event_slot_endpoint`, this does not require `T: Message`.
///
/// # Errors
///
/// Propagates the error returned by `SimInit::bind_event_sink_raw`,
/// presumably when a sink is already bound at `path` (TODO confirm).
pub fn event_slot_endpoint_raw<T: Serialize + Send + 'static>(
    sim_init: &mut SimInit,
    state: SinkState,
    path: impl Into<Path>,
) -> Result<EventSlotWriter<T>, DuplicateEventSinkError> {
    let (writer, reader) = event_slot(state);
    sim_init.bind_event_sink_raw(reader, path)?;

    Ok(writer)
}
/// State shared between an `EventSlotReader` and its `EventSlotWriter`s.
#[derive(Default)]
struct EventSlot<T: Send> {
    /// Whether writers currently accept events; when `false`,
    /// `EventSlotWriter::write` returns without touching `cell`.
    is_enabled: AtomicBool,
    /// Number of live `EventSlotWriter` handles: 1 after `event_slot`,
    /// incremented on clone and decremented on drop. The reader reports
    /// end-of-stream when this is 0 and `cell` is empty.
    writer_count: AtomicUsize,
    /// Single-event storage. A later write replaces an unread event (see the
    /// `event_slot_try_read_single_threaded` test); `try_write` can also
    /// fail, in which case the event is dropped (the failure conditions are
    /// defined by `SharedCell` and are not visible here).
    cell: SharedCell<T>,
}
impl<T: Send> fmt::Debug for EventSlot<T> {
    /// Formats the slot's flag and writer count; the cell content is not shown.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // The struct name must be that of this type, not of the reader
        // wrapping it; otherwise reader/writer debug output is misleading.
        f.debug_struct("EventSlot")
            .field("is_enabled", &self.is_enabled.load(Ordering::Relaxed))
            .field("writer_count", &self.writer_count.load(Ordering::Relaxed))
            .finish_non_exhaustive()
    }
}
/// Reading half of an event slot.
///
/// Implements `Stream`, yielding events written to the slot and ending when
/// the slot is empty and all writers have been dropped.
pub struct EventSlotReader<T: Send> {
    /// Slot state shared with the writers.
    inner: Arc<EventSlot<T>>,
    /// Registered with the task waker so that writers' `WakeSource`s can
    /// wake the reader.
    wake_sink: WakeSink,
}
impl<T: Send> Stream for EventSlotReader<T> {
    type Item = T;

    /// Polls for the next event.
    ///
    /// Returns `Ready(Some(event))` if an event is available in the slot,
    /// `Ready(None)` if the slot is empty and all writers had been dropped,
    /// and `Pending` otherwise. In the `Pending` case the task waker remains
    /// registered, so the next successful write or the drop of the last
    /// writer wakes the task.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Fast path: an event is already available.
        if let Some(event) = self.inner.cell.try_read() {
            return Poll::Ready(Some(event));
        }

        // Register the waker before re-checking the state, so that a write
        // or writer drop occurring after the checks below wakes this task.
        self.wake_sink.register(cx.waker());

        // The writer count must be sampled *before* the slot is re-checked.
        // If it were sampled afterwards, the last writer could write an event
        // and be dropped in between, and the stream would report termination
        // while an event is still pending in the slot. `Acquire` ordering
        // makes events written before a `Release` decrement of the count
        // visible to the `try_read` below.
        let all_writers_dropped = self.inner.writer_count.load(Ordering::Acquire) == 0;

        if let Some(event) = self.inner.cell.try_read() {
            self.wake_sink.unregister();
            return Poll::Ready(Some(event));
        }
        if all_writers_dropped {
            // No writer exists anymore, so no event can arrive later.
            self.wake_sink.unregister();
            return Poll::Ready(None);
        }

        Poll::Pending
    }
}
impl<T: Send + 'static> EventSinkReader<T> for EventSlotReader<T> {
fn enable(&mut self) {
self.inner.is_enabled.store(true, Ordering::Relaxed);
}
fn disable(&mut self) {
self.inner.is_enabled.store(false, Ordering::Relaxed);
}
fn try_read(&mut self) -> Option<T> {
self.inner.cell.try_read()
}
fn read(&mut self) -> Option<T> {
pollster::block_on(self.next())
}
}
impl<T: Send> fmt::Debug for EventSlotReader<T> {
    /// Formats the reader through its shared slot state.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("EventSlotReader");
        builder.field("inner", &self.inner);
        builder.finish_non_exhaustive()
    }
}
/// Writing half of an event slot.
///
/// Cloneable; dropping the last clone notifies the reader.
pub struct EventSlotWriter<T: Send> {
    /// Slot state shared with the reader and the other writers.
    inner: Arc<EventSlot<T>>,
    /// Wakes the reader after a successful write and when the last writer is
    /// dropped.
    wake_source: WakeSource,
}
impl<T: Send + 'static> EventSinkWriter<T> for EventSlotWriter<T> {
    /// Stores `event` in the slot and wakes the reader.
    ///
    /// The event is silently dropped if the slot is disabled or if
    /// `SharedCell::try_write` fails; in either case the reader is not woken.
    fn write(&self, event: T) {
        let slot = &*self.inner;
        // Short-circuit: `try_write` is only attempted when the slot is enabled.
        let accepted =
            slot.is_enabled.load(Ordering::Relaxed) && slot.cell.try_write(event).is_ok();
        if accepted {
            self.wake_source.notify();
        }
    }
}
impl<T: Send> Drop for EventSlotWriter<T> {
    /// Decrements the writer count and, if this was the last writer, wakes
    /// the reader so it can observe the end of the stream.
    fn drop(&mut self) {
        // `Release` ordering makes everything this writer did beforehand
        // (in particular an event written to the slot) visible to a reader
        // that observes the decremented count with an `Acquire` load. This
        // matters for the last writer: a reader seeing a count of 0 must not
        // miss an event written just before the drop.
        if self.inner.writer_count.fetch_sub(1, Ordering::Release) == 1 {
            self.wake_source.notify();
        }
    }
}
impl<T: Send> Clone for EventSlotWriter<T> {
fn clone(&self) -> Self {
self.inner.writer_count.fetch_add(1, Ordering::Relaxed);
Self {
inner: self.inner.clone(),
wake_source: self.wake_source.clone(),
}
}
}
impl<T: Send> fmt::Debug for EventSlotWriter<T> {
    /// Formats the writer through its shared slot state.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("EventSlotWriter");
        builder.field("inner", &self.inner);
        builder.finish_non_exhaustive()
    }
}
// Unit tests for the event slot, compiled only when the `nexosim_loom` cfg is
// not set. The multi-threaded tests are additionally skipped under Miri.
// NOTE(review): the multi-threaded tests order events with `thread::sleep`,
// so they may be flaky on heavily loaded machines.
#[cfg(all(test, not(nexosim_loom)))]
mod tests {
    #[cfg(not(miri))]
    use std::thread;
    #[cfg(not(miri))]
    use std::time::Duration;

    use super::*;

    // `try_read` on an empty slot returns `None`, and a later write
    // overwrites an unread earlier one.
    #[test]
    fn event_slot_try_read_single_threaded() {
        let (writer, mut reader) = event_slot(SinkState::Enabled);
        assert!(reader.try_read().is_none());
        writer.write(123);
        assert_eq!(reader.try_read(), Some(123));
        writer.write(7);
        writer.write(42);
        assert_eq!(reader.try_read(), Some(42));
    }

    // A pending event is still delivered after all writers have been
    // dropped; the following `read` then reports end-of-stream.
    #[test]
    fn event_slot_read_drop_single_threaded() {
        let (writer1, mut reader) = event_slot(SinkState::Enabled);
        let writer2 = writer1.clone();
        writer1.write(123);
        drop(writer1);
        drop(writer2);
        assert_eq!(reader.read(), Some(123));
        assert!(reader.read().is_none());
    }

    // Non-blocking read of an event written from another thread; the reader
    // sleeps so that the write has (very likely) happened first.
    #[cfg(not(miri))]
    #[test]
    fn event_slot_try_read_multi_threaded() {
        let (writer, mut reader) = event_slot(SinkState::Enabled);
        assert!(reader.try_read().is_none());
        let th = thread::spawn(move || {
            writer.write(123);
            thread::sleep(Duration::from_millis(20));
        });
        thread::sleep(Duration::from_millis(10));
        assert_eq!(reader.try_read(), Some(123));
        th.join().unwrap();
    }

    // Blocking reads of events written by two writer threads whose writes
    // are staggered by sleeps (10 ms then 20 ms).
    #[cfg(not(miri))]
    #[test]
    fn event_slot_read_multi_threaded() {
        let (writer1, mut reader) = event_slot(SinkState::Enabled);
        let writer2 = writer1.clone();
        let th1 = thread::spawn(move || {
            thread::sleep(Duration::from_millis(10));
            writer1.write(123);
            thread::sleep(Duration::from_millis(30));
        });
        let th2 = thread::spawn(move || {
            thread::sleep(Duration::from_millis(20));
            writer2.write(42);
            thread::sleep(Duration::from_millis(20));
        });
        assert_eq!(reader.read(), Some(123));
        assert_eq!(reader.read(), Some(42));
        assert!(reader.try_read().is_none());
        th1.join().unwrap();
        th2.join().unwrap();
    }

    // Dropping one writer must not end the stream while another writer is
    // alive; the stream ends once the remaining writer is dropped.
    #[cfg(not(miri))]
    #[test]
    fn event_slot_drop_multi_threaded() {
        let (writer1, mut reader) = event_slot(SinkState::Enabled);
        let writer2 = writer1.clone();
        let th1 = thread::spawn(move || drop(writer1));
        let th2 = thread::spawn(move || {
            thread::sleep(Duration::from_millis(10));
            writer2.write(123);
            thread::sleep(Duration::from_millis(10));
        });
        assert_eq!(reader.read(), Some(123));
        assert!(reader.read().is_none());
        th1.join().unwrap();
        th2.join().unwrap();
    }
}