bab 0.0.4

build a bus - low-level components for your next message bus
Documentation
use core::{
    pin::Pin,
    task::{Context, Poll, Waker},
};

#[cfg(feature = "alloc")]
use alloc::sync::Arc;
#[cfg(feature = "std")]
use std::sync::Arc;

use crossbeam_utils::CachePadded;
use spin::Mutex;
use thid::ThreadLocal;

use crate::{buffer::BufferPtr, buffer_chain::BufferChain};

/// Creates a connected sender/receiver pair backed by one shared, initially
/// empty queue.
///
/// The returned [`BufferQueueSender`] and [`BufferQueueReceiver`] both hold a
/// reference-counted handle to the same mutex-guarded state.
pub fn buffer_queue() -> (BufferQueueSender, BufferQueueReceiver) {
    // Start with nothing queued and no receiver waiting.
    let initial_state = BufferQueueShared {
        head_tail: None,
        waker: None,
    };
    let shared = Arc::new(Mutex::new(initial_state));

    let receiver = BufferQueueReceiver::new(Arc::clone(&shared));
    let sender = BufferQueueSender {
        shared,
        local_chain: Arc::new(ThreadLocal::new()),
    };

    (sender, receiver)
}

/// Queue state shared by all sender and receiver handles (kept behind a mutex).
struct BufferQueueShared {
    /// First and last buffer of the currently queued chain, or `None` when
    /// nothing is queued.
    head_tail: Option<(BufferPtr, BufferPtr)>,
    /// Waker stored by a receive future that found the queue empty; cleared
    /// when a receive completes or its future is dropped.
    waker: Option<Waker>,
}

/// Sending half of a buffer queue.
///
/// Buffers handed to `push` are collected in a local chain and only reach the
/// shared queue when `flush` is called. Cloning yields another handle onto the
/// same shared queue and the same set of local chains.
#[derive(Clone)]
pub struct BufferQueueSender {
    /// Shared queue state that `flush` publishes into.
    shared: Arc<Mutex<BufferQueueShared>>,
    /// Staging chains for pushed buffers.
    // NOTE(review): presumably one chain per thread (via `thid::ThreadLocal`),
    // cache-padded to avoid false sharing — confirm against the `thid` docs.
    local_chain: Arc<ThreadLocal<CachePadded<BufferChain>>>,
}

impl BufferQueueSender {
    /// Stages `buffer` on the local chain.
    ///
    /// The buffer is not published to the shared queue (and the receiver is
    /// not notified) until [`flush`](Self::flush) is called.
    pub fn push(&self, buffer: BufferPtr) {
        let local_chain = self.local_chain.get_or_default();
        local_chain.push(buffer);
    }

    /// Moves every buffer staged on the local chain into the shared queue.
    ///
    /// Does nothing if the local chain is empty. If the shared queue was empty
    /// before this call, a waiting receiver (if any) is woken; if it already
    /// held buffers, the new ones are linked behind the existing tail and no
    /// wake-up is issued, because the receiver was already notified when the
    /// queue first became non-empty.
    pub fn flush(&self) {
        let local_chain = self.local_chain.get_or_default();

        let Some((head, tail)) = local_chain.take_all() else {
            return;
        };

        let mut shared = self.shared.lock();
        let waker = if let Some((_, prev_shared_tail)) = &mut shared.head_tail {
            // SAFETY (assumed — confirm against `BufferPtr::set_next`'s
            // contract): the shared tail is only reachable through the queue
            // while the lock is held, so nothing else is relinking it.
            unsafe {
                prev_shared_tail.set_next(Some(head));
            }
            *prev_shared_tail = tail;
            None
        } else {
            shared.head_tail = Some((head, tail));
            // Take the waker instead of waking in place: the receiver clears
            // or re-registers it on its next poll anyway.
            shared.waker.take()
        };
        // Release the spin lock before waking. `wake` runs executor code that
        // may itself try to poll the receiver (and thus take this lock), and
        // other threads would otherwise spin for the duration of the wake-up.
        drop(shared);

        if let Some(waker) = waker {
            waker.wake();
        }
    }
}

impl Drop for BufferQueueSender {
    /// Publishes any still-staged buffers before the handle goes away.
    ///
    /// Only the chain returned by `get_or_default` for the dropping context is
    /// flushed.
    // NOTE(review): presumably that is the current thread's chain, so buffers
    // staged on other threads are not flushed here — confirm.
    fn drop(&mut self) {
        Self::flush(self);
    }
}

/// Future that resolves to the head of the queued buffer chain once the shared
/// queue is non-empty.
struct BufferQueueReceive<'a> {
    /// Shared queue state to take the chain from and register the waker in.
    shared: &'a Mutex<BufferQueueShared>,
}

impl core::future::Future for BufferQueueReceive<'_> {
    type Output = BufferPtr;

    /// Takes the whole queued chain if one is present, otherwise registers the
    /// task's waker and reports `Pending`.
    ///
    /// On success the queue is left empty, the stored waker is cleared, and the
    /// chain's head is returned; the rest of the chain is reached by following
    /// the buffers' links.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut guard = self.shared.lock();

        // Fast path: something is queued, hand over the entire chain.
        if let Some((head, _)) = guard.head_tail.take() {
            guard.waker = None;
            return Poll::Ready(head);
        }

        // Nothing queued: make sure the stored waker targets this task,
        // skipping the clone when the stored one would already wake it.
        let task_waker = cx.waker();
        let already_registered =
            matches!(&guard.waker, Some(current) if current.will_wake(task_waker));
        if !already_registered {
            guard.waker = Some(task_waker.clone());
        }

        Poll::Pending
    }
}

impl Drop for BufferQueueReceive<'_> {
    /// Clears the stored waker so a finished or cancelled receive does not
    /// leave a stale waker behind in the shared state.
    fn drop(&mut self) {
        // The guard is a temporary: the lock is held only for this assignment.
        self.shared.lock().waker = None;
    }
}

/// Receiving half of a buffer queue.
///
/// Cloning yields another handle onto the same shared queue.
// NOTE(review): the shared state holds a single waker slot, so concurrent
// `recv` calls from several clones can overwrite or clear each other's
// waker — confirm whether multiple simultaneous receivers are supported.
#[derive(Clone)]
pub struct BufferQueueReceiver {
    /// Shared queue state that receives take buffers from.
    shared: Arc<Mutex<BufferQueueShared>>,
}

impl BufferQueueReceiver {
    fn new(shared: Arc<Mutex<BufferQueueShared>>) -> Self {
        Self { shared }
    }

    pub async fn recv(&self) -> BufferQueueReceiveIterator {
        let recv_head = BufferQueueReceive {
            shared: &self.shared,
        }
        .await;
        BufferQueueReceiveIterator {
            head: Some(recv_head),
        }
    }
}

/// Iterator over a chain of buffers taken from the queue by `recv`.
///
/// Dropping the iterator walks whatever part of the chain has not been
/// yielded yet.
pub struct BufferQueueReceiveIterator {
    /// Next buffer to yield, or `None` once the chain is exhausted.
    head: Option<BufferPtr>,
}

impl core::iter::Iterator for BufferQueueReceiveIterator {
    type Item = BufferPtr;

    /// Yields the current buffer and advances to the one it was linked to.
    ///
    /// The yielded buffer's link is replaced with `None` as part of advancing.
    fn next(&mut self) -> Option<Self::Item> {
        let current = self.head?;
        // SAFETY (assumed — confirm against `BufferPtr::swap_next`'s
        // contract): the chain was removed from the shared queue as a whole,
        // so this iterator is presumably the only party touching its links.
        self.head = unsafe { current.swap_next(None) };
        Some(current)
    }
}

impl Drop for BufferQueueReceiveIterator {
    /// Drains the remainder of the chain so every buffer still in it has its
    /// link reset, exactly as if it had been yielded.
    fn drop(&mut self) {
        for _ in self.by_ref() {}
    }
}