use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, Ordering};
use std::task::{Context, Poll};
use atomiclock_async::AtomicLockAsync;
/// Error produced when a value cannot be delivered by [`ChannelProducer::send`].
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum SendError {
    /// The consumer has been shut down with `async_drop`, so the value can never be received.
    #[error("The consumer has hung up")]
    ConsumerHangup,
}

/// Error produced when no value can be obtained from [`ChannelConsumer::receive`].
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum RecvError {
    /// The producers have been shut down with `async_drop` (this is also what `receive`
    /// reports when it is called after the consumer's own `async_drop`).
    #[error("The producer has hung up")]
    ProducerHangup,
}
/// A slot whose value can be permanently withdrawn ("hung up").
///
/// While live the slot holds `Some(value)`; [`Hungup::hangup`] takes the value out and leaves
/// `None` behind for good.
#[derive(Debug)]
struct Hungup<T> {
    data: Option<T>,
}
impl<T> Hungup<T> {
    /// Wraps `data` in a live slot.
    fn new(data: T) -> Self {
        Self { data: Some(data) }
    }

    /// Exposes the raw slot: `Some` while live, `None` once hung up.
    fn as_mut(&mut self) -> &mut Option<T> {
        &mut self.data
    }

    /// Borrows the live value.
    ///
    /// # Panics
    ///
    /// Panics with `reason` if the slot has been hung up.
    fn expect_mut(&mut self, reason: &str) -> &mut T {
        let live = self.data.as_mut();
        live.expect(reason)
    }

    /// Swaps `replacement` into a live slot and returns the value it displaced.
    ///
    /// If the slot has already been hung up it stays empty, `replacement` is discarded, and
    /// `None` is returned.
    fn replace_if_needed(&mut self, replacement: T) -> Option<T> {
        self.data
            .as_mut()
            .map(|current| std::mem::replace(current, replacement))
    }

    /// Withdraws the value, leaving the slot permanently empty.
    ///
    /// # Panics
    ///
    /// Panics if the slot was already hung up.
    fn hangup(&mut self) -> T {
        let taken = self.data.take();
        taken.expect("Already hungup")
    }
}
/// Channel state that is only touched while holding the shared async lock.
#[derive(Debug)]
struct Locked<T> {
    /// While live: `Some(sender)` when the consumer is parked in `receive`, `None` otherwise.
    /// Hung up by the consumer's `async_drop`.
    pending_consumer: Hungup<Option<r#continue::Sender<Result<T, RecvError>>>>,
    /// Producers parked in `send` waiting for the consumer; new ones are pushed at the back.
    /// Hung up when the last producer runs `async_drop`.
    pending_producers: Hungup<Vec<PendingProducer<T>>>,
}
#[derive(Debug)]
struct Shared<T> {
lock: atomiclock_async::AtomicLockAsync<Locked<T>>,
}
/// A producer parked in `send`, waiting for the consumer to take its value.
#[derive(Debug)]
struct PendingProducer<T> {
    /// The value being sent.
    data: T,
    /// Completed with `Ok(())` when the consumer takes `data`, or with
    /// `Err(SendError::ConsumerHangup)` when the consumer shuts down.
    continuation: r#continue::Sender<Result<(), SendError>>,
}
impl<T> PendingProducer<T> {
    /// Splits the parked producer into its value and the continuation that wakes it.
    fn into_inner(self) -> (T, r#continue::Sender<Result<(), SendError>>) {
        let PendingProducer { data, continuation } = self;
        (data, continuation)
    }
}
/// Creates a new channel with one producer and one consumer.
///
/// A `send` only completes once the consumer has taken the value (or has hung up); values are
/// not buffered independently of their senders. Both halves must be shut down with their
/// `async_drop` method before being dropped.
pub fn channel<T>() -> (ChannelProducer<T>, ChannelConsumer<T>) {
    let state = Locked {
        pending_consumer: Hungup::new(None),
        pending_producers: Hungup::new(Vec::new()),
    };
    let shared = Arc::new(Shared {
        lock: AtomicLockAsync::new(state),
    });
    // The first producer starts the live-producer count at one.
    let producer = ChannelProducer {
        shared: Arc::clone(&shared),
        active_producers: Arc::new(AtomicU8::new(1)),
        async_dropped: false,
    };
    let consumer = ChannelConsumer {
        shared,
        async_dropped: false,
    };
    (producer, consumer)
}
/// The receiving half of a channel created by [`channel`].
///
/// Must be shut down with `async_drop` before it is dropped; dropping it otherwise panics
/// (unless the thread is already panicking).
#[derive(Debug)]
pub struct ChannelConsumer<T> {
    /// Channel state shared with the producers.
    shared: Arc<Shared<T>>,
    /// Set once `async_drop` has completed; checked on drop.
    async_dropped: bool,
}
/// Future returned by [`ChannelConsumer::receive`].
///
/// Resolves to `Ok(value)` with the next value handed over by a producer, or to
/// `Err(RecvError)`. Like any future it is lazy: nothing happens until it is polled, so it is
/// marked `#[must_use]` to flag calls whose result is accidentally discarded.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ChannelConsumerRecvFuture<'a, T> {
    /// The boxed receive operation.
    future: Box<dyn Future<Output=Result<T,RecvError>> + Send>,
    /// Keeps the consumer exclusively borrowed while the receive is in flight, so at most one
    /// receive per consumer can be outstanding.
    _inner: &'a mut ChannelConsumer<T>,
}
impl<'a, T> Future for ChannelConsumerRecvFuture<'a, T> {
    type Output = Result<T, RecvError>;

    /// Drives the boxed receive operation.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The wrapper is `Unpin` (its fields are a `Box` and a `&mut`), so unpinning it is
        // safe; only the boxed future itself needs to be pinned.
        let this = self.get_mut();
        // SAFETY: the inner future lives on the heap inside `this.future`, so moving the
        // wrapper never moves it. The field is private and nothing in this module replaces it
        // or moves the future out after construction, so it stays put until it is dropped.
        let inner = unsafe { Pin::new_unchecked(&mut *this.future) };
        inner.poll(cx)
    }
}
impl<T: 'static + Send> ChannelConsumer<T> {
    /// Receives the next value from the channel.
    ///
    /// If producers are already parked in `send`, the one that parked earliest is served
    /// first: its value is returned and its `send` future is completed with `Ok(())`.
    /// Otherwise this consumer parks itself until a producer hands over a value.
    ///
    /// # Errors
    ///
    /// Returns [`RecvError::ProducerHangup`] if every producer has already been shut down with
    /// `async_drop`, if the last producer is shut down while this receive is waiting, or if this
    /// consumer has itself already been shut down with `async_drop`.
    pub fn receive(&mut self) -> ChannelConsumerRecvFuture<'_, T> {
        let shared = self.shared.clone();
        ChannelConsumerRecvFuture {
            _inner: self,
            future: Box::new(async move {
                let future = {
                    let mut lock = shared.lock.lock().await;
                    // The last producer's `async_drop` hangs up `pending_producers`. That is a
                    // normal end-of-stream condition, so report it instead of panicking.
                    let producers = match lock.pending_producers.as_mut() {
                        Some(producers) => producers,
                        None => return Err(RecvError::ProducerHangup),
                    };
                    if !producers.is_empty() {
                        // Serve parked producers in arrival order (they are pushed at the back)
                        // so an early sender cannot be starved by later ones.
                        let (data, continuation) = producers.remove(0).into_inner();
                        continuation.send(Ok(()));
                        return Ok(data);
                    }
                    match lock.pending_consumer.as_mut() {
                        Some(slot) => {
                            // Park: a producer completes this continuation with a value, or the
                            // last producer's `async_drop` completes it with an error.
                            let (sender, future) = r#continue::continuation();
                            *slot = Some(sender);
                            future
                        }
                        // This consumer has already been shut down.
                        None => return Err(RecvError::ProducerHangup),
                    }
                };
                // The lock guard is dropped at the end of the block above, before waiting.
                future.await
            }),
        }
    }

    /// Shuts down the consumer. This must be awaited before the consumer is dropped.
    ///
    /// Producers currently parked in `send` are woken with [`SendError::ConsumerHangup`], and
    /// any later `send` fails with the same error.
    ///
    /// # Panics
    ///
    /// Panics if called more than once on the same consumer.
    pub async fn async_drop(&mut self) {
        assert!(!self.async_dropped);
        let mut lock = self.shared.lock.lock().await;
        lock.pending_consumer.hangup();
        if let Some(producers) = lock.pending_producers.replace_if_needed(Vec::new()) {
            // Release the lock before waking the parked producers.
            drop(lock);
            for producer in producers {
                producer.continuation.send(Err(SendError::ConsumerHangup));
            }
        }
        self.async_dropped = true;
    }
}
impl<T> Drop for ChannelConsumer<T> {
    /// Enforces that `async_drop` ran before the consumer goes away.
    ///
    /// The check is skipped while the thread is already panicking, because a second panic
    /// during unwinding would abort the process.
    fn drop(&mut self) {
        let unwinding = std::thread::panicking();
        assert!(
            unwinding || self.async_dropped,
            "You must call async_drop on the consumer before dropping it"
        );
    }
}
/// A sending half of a channel created by [`channel`]; clone it to get more producers.
///
/// Every producer (including clones) must be shut down with `async_drop` before it is dropped;
/// dropping it otherwise panics (unless the thread is already panicking).
#[derive(Debug)]
pub struct ChannelProducer<T> {
    /// Channel state shared with the consumer and the other producers.
    shared: Arc<Shared<T>>,
    /// Number of producers that have not yet run `async_drop`; shared by all clones.
    active_producers: Arc<AtomicU8>,
    /// Set once `async_drop` has completed; checked on drop.
    async_dropped: bool,
}
/// Future returned by [`ChannelProducer::send`].
///
/// Resolves to `Ok(())` once the value has been handed to the consumer, or to
/// `Err(SendError)` if the consumer hangs up. Like any future it is lazy: nothing is sent until
/// it is polled, so it is marked `#[must_use]` to flag calls whose result is discarded.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ChannelProducerSendFuture<'a, T> {
    /// Keeps the producer exclusively borrowed while the send is in flight.
    inner: &'a mut ChannelProducer<T>,
    /// The boxed send operation.
    future: Box<dyn Future<Output=Result<(),SendError>> + Send>,
}
/// Debug output for any send future whose payload type is `Debug`; this was previously
/// available only for `T = ()`.
impl<T: Debug> Debug for ChannelProducerSendFuture<'_, T> {
    /// Shows the borrowed producer. The boxed inner future is opaque and therefore omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ChannelProducerSendFuture")
            // Label matches the actual field name (`inner`).
            .field("inner", &self.inner)
            .finish()
    }
}
impl<'a, T> Future for ChannelProducerSendFuture<'a, T> {
    type Output = Result<(), SendError>;

    /// Drives the boxed send operation.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The wrapper is `Unpin` (its fields are a `&mut` and a `Box`), so unpinning it is
        // safe; only the boxed future itself needs to be pinned.
        let this = self.get_mut();
        // SAFETY: the inner future lives on the heap inside `this.future`, so moving the
        // wrapper never moves it. The field is private and nothing in this module replaces it
        // or moves the future out after construction, so it stays put until it is dropped.
        let inner = unsafe { Pin::new_unchecked(&mut *this.future) };
        inner.poll(cx)
    }
}
impl<T: Send + 'static> ChannelProducer<T> {
    /// Sends `data` to the consumer.
    ///
    /// If the consumer is currently parked in `receive`, the value is handed over directly and
    /// the future completes right away. Otherwise this producer parks until the consumer takes
    /// the value.
    ///
    /// # Errors
    ///
    /// Returns [`SendError::ConsumerHangup`] if the consumer has been shut down with
    /// `async_drop`, either before this send or while it was parked.
    ///
    /// # Panics
    ///
    /// Panics if no consumer is waiting and every producer of the channel has already been
    /// shut down with `async_drop` (i.e. a producer is used after its own shutdown).
    pub fn send(&mut self, data: T) -> ChannelProducerSendFuture<'_, T> {
        let move_shared = self.shared.clone();
        ChannelProducerSendFuture {
            inner: self,
            future: Box::new(async move {
                let future = {
                    let mut lock = move_shared.lock.lock().await;
                    match lock.pending_consumer.as_mut() {
                        Some(maybe_consumer) => match maybe_consumer.take() {
                            // The consumer is parked: hand the value over directly.
                            Some(consumer) => {
                                consumer.send(Ok(data));
                                return Ok(());
                            }
                            // Nobody is waiting: park until the consumer takes the value.
                            None => {
                                let (sender, future) = r#continue::continuation();
                                lock.pending_producers
                                    .expect_mut("Pending producers hungup")
                                    .push(PendingProducer {
                                        data,
                                        continuation: sender,
                                    });
                                future
                            }
                        },
                        // The consumer has been shut down.
                        None => return Err(SendError::ConsumerHangup),
                    }
                };
                // The lock guard is dropped at the end of the block above, before waiting.
                future.await
            }),
        }
    }

    /// Shuts down this producer. This must be awaited before the producer is dropped.
    ///
    /// When the last live producer shuts down, the producer side of the channel is hung up, and
    /// a consumer parked in `receive` is woken with [`RecvError::ProducerHangup`].
    ///
    /// # Panics
    ///
    /// Panics if called more than once on the same producer, or if the last producer shuts
    /// down while a value parked by an earlier `send` is still unclaimed.
    pub async fn async_drop(&mut self) {
        assert!(!self.async_dropped);
        let previous = self.active_producers.fetch_sub(1, Ordering::SeqCst);
        if previous == 1 {
            let mut lock = self.shared.lock.lock().await;
            let abandoned = lock.pending_producers.hangup();
            assert!(
                abandoned.is_empty(),
                "last producer shut down while a sent value was still unclaimed"
            );
            if let Some(consumer) = lock.pending_consumer.replace_if_needed(None) {
                // Release the lock before waking the consumer.
                drop(lock);
                if let Some(consumer) = consumer {
                    consumer.send(Err(RecvError::ProducerHangup));
                }
            }
        }
        self.async_dropped = true;
    }
}
impl<T> Clone for ChannelProducer<T> {
    /// Creates another producer for the same channel.
    ///
    /// The clone counts as a live producer and must also be shut down with `async_drop`.
    ///
    /// # Panics
    ///
    /// Panics if the channel already has `u8::MAX` live producers. The shared counter is an
    /// `AtomicU8`, and letting it wrap would corrupt the live-producer count. That would make
    /// the last-producer hangup fire at the wrong time or never.
    fn clone(&self) -> Self {
        self.active_producers
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |count| count.checked_add(1))
            .expect("too many live ChannelProducer clones: producer count would overflow");
        ChannelProducer {
            shared: self.shared.clone(),
            active_producers: self.active_producers.clone(),
            async_dropped: false,
        }
    }
}
impl<T> Drop for ChannelProducer<T> {
    /// Enforces that `async_drop` ran before the producer goes away.
    ///
    /// The check is skipped while the thread is already panicking, because a second panic
    /// during unwinding would abort the process.
    fn drop(&mut self) {
        let unwinding = std::thread::panicking();
        assert!(
            unwinding || self.async_dropped,
            "You must call async_drop on the producer before dropping it"
        );
    }
}
/// A value sent from a spawned task reaches the consumer, and both halves then shut down
/// without panicking.
#[test]
fn test_push() {
    logwise::context::Context::reset("test_push");
    let (sender, mut receiver) = channel();
    let send_task = async move {
        let mut sender = sender;
        sender.send(1).await.unwrap();
        sender.async_drop().await;
    };
    test_executors::spawn_on("test_push_thread", send_task);
    let received = test_executors::spin_on(receiver.receive()).unwrap();
    assert_eq!(received, 1);
    test_executors::spin_on(receiver.async_drop());
    logwise::context::Context::reset("finished");
}