use core::{
cell::Cell,
pin::Pin,
sync::atomic::Ordering,
task::{Context, Poll, Waker},
};
#[cfg(feature = "alloc")]
use alloc::sync::Arc;
#[cfg(feature = "std")]
use std::sync::Arc;
use crossbeam_utils::{Backoff, CachePadded};
use spin::Mutex;
use thid::ThreadLocal;
use crate::{BufferChain, Packet, buffer::BufferPtr};
/// Top bit of a buffer's write cursor: set by the flush receiver when it takes the buffer off
/// the queue; a writer that observes it pushes the buffer onto its local chain again.
pub const WRITE_CURSOR_FLUSHED_FLAG: u32 = 1 << 31;
/// Second-highest bit of a buffer's write cursor: set when the buffer was sent as complete.
pub const WRITE_CURSOR_DONE: u32 = 1 << 30;
/// Selects the byte-offset part of a write cursor, i.e. every bit except the two flags.
pub const WRITE_CURSOR_MASK: u32 = !(WRITE_CURSOR_FLUSHED_FLAG | WRITE_CURSOR_DONE);
/// Creates a connected sender/receiver pair backed by one shared flush queue.
///
/// The sender can be cloned; all clones feed the same queue, which the receiver drains.
pub fn new_writer_flusher() -> (WriterFlushSender, WriterFlushReceiver) {
    let state = Arc::new(Mutex::new(WriterFlushShared {
        head_tail: None,
        waker: None,
    }));
    let sender = WriterFlushSender {
        shared: Arc::clone(&state),
        local: Arc::new(ThreadLocal::new()),
    };
    (sender, WriterFlushReceiver::new(state))
}
/// State shared by every `WriterFlushSender` clone and the `WriterFlushReceiver`; always
/// accessed through the spin `Mutex` created in `new_writer_flusher`.
struct WriterFlushShared {
    /// Flushed buffers waiting for the receiver, as the (head, tail) of a linked chain.
    /// `None` while the queue is empty.
    head_tail: Option<(BufferPtr, BufferPtr)>,
    /// Waker registered by a pending receive future, if any.
    waker: Option<Waker>,
}
// Releases every buffer still queued when the shared state is dropped
// (i.e. once the last sender and the receiver are gone).
impl Drop for WriterFlushShared {
    fn drop(&mut self) {
        // Walk the chain from its head; the tail is only needed for appends.
        let mut release_head = self.head_tail.take().map(|(head, _tail)| head);
        while let Some(buffer) = release_head {
            // Read (and clear) the next link first; the buffer is not touched after release_ref.
            release_head = unsafe { buffer.swap_next(None) };
            // Same release sequence as `Flush::drop`: reset the flush cursor, then receive(1)
            // and release_ref(1).
            // NOTE(review): unlike `FlushIterator::next`, this releases queued buffers whether or
            // not WRITE_CURSOR_DONE is set — confirm that matches the buffer ownership rules.
            unsafe {
                *buffer.flush_cursor_mut() = 0;
                buffer.receive(1);
                buffer.release_ref(1);
            }
        }
    }
}
/// Per-thread sender state: buffers this thread has written into but not yet flushed.
#[derive(Default)]
struct SenderLocal {
    /// Bytes committed by this thread via `advance_write_cursor` since its last `flush`.
    unflushed_bytes: Cell<usize>,
    /// Buffers pushed by this thread; moved wholesale to the shared queue on `flush`.
    chain: BufferChain,
}
/// Producer half of the writer flush queue.
///
/// Clones share both the queue and the per-thread pending chains (both are behind `Arc`).
#[derive(Clone)]
pub struct WriterFlushSender {
    /// Queue shared with the receiver.
    shared: Arc<Mutex<WriterFlushShared>>,
    /// Per-thread pending state, wrapped in `CachePadded` to keep threads' entries on
    /// separate cache lines.
    local: Arc<ThreadLocal<CachePadded<SenderLocal>>>,
}
impl WriterFlushSender {
    /// Returns an identifier for this flush queue, derived from the address of the shared
    /// state; every clone of a sender returns the same value.
    pub fn id(&self) -> u64 {
        Arc::as_ptr(&self.shared) as _
    }

    /// Returns the number of bytes the calling thread has committed through
    /// `advance_write_cursor` since its last `flush`.
    pub fn unflushed_bytes(&self) -> usize {
        let local = self.local.get_or_default();
        local.unflushed_bytes.get()
    }

    /// Moves the calling thread's pending buffer chain onto the shared queue.
    ///
    /// Only the current thread's chain is flushed. If the shared queue was empty, the
    /// receiver's registered waker (if any) is woken.
    pub fn flush(&self) {
        let local = self.local.get_or_default();
        let Some((head, tail)) = local.chain.take_all() else {
            return;
        };
        local.unflushed_bytes.set(0);
        let mut shared = self.shared.lock();
        let waker = if let Some((_, prev_shared_tail)) = &mut shared.head_tail {
            // Queue already non-empty: whoever made it non-empty handled the wake-up,
            // so just append.
            unsafe {
                prev_shared_tail.set_next(Some(head));
            }
            *prev_shared_tail = tail;
            None
        } else {
            shared.head_tail = Some((head, tail));
            // Take the waker and invoke it only after the spin lock is released: waking runs
            // executor code, and a waker that polls the receiver inline would otherwise spin
            // forever on this non-reentrant lock. The receiver registers a fresh waker on its
            // next `Pending` poll.
            shared.waker.take()
        };
        drop(shared);
        if let Some(waker) = waker {
            waker.wake();
        }
    }

    /// Publishes the bytes `[write_start, new_write_cursor & WRITE_CURSOR_MASK)` that the
    /// calling thread wrote into `buffer`.
    ///
    /// Spins until all earlier writes into the buffer are published (the cursor's offset equals
    /// `write_start`), then installs `new_write_cursor` verbatim, which clears any flag bits it
    /// does not carry. The buffer is pushed onto this thread's local chain when this is the
    /// first write into it (`write_start == 0`), or when the receiver had already taken it off the
    /// queue (`WRITE_CURSOR_FLUSHED_FLAG` was set).
    ///
    /// # Panics
    ///
    /// Panics if the cursor's offset changes while this call is installing its own value.
    pub(crate) fn advance_write_cursor(
        &self,
        buffer: BufferPtr,
        write_start: u32,
        new_write_cursor: u32,
    ) {
        let backoff = Backoff::new();
        let mut write_cursor = buffer.write_cursor().load(Ordering::Acquire);
        while write_cursor & WRITE_CURSOR_MASK != write_start {
            backoff.snooze();
            write_cursor = buffer.write_cursor().load(Ordering::Acquire);
        }
        loop {
            // AcqRel/Acquire rather than Release/Relaxed: when the value read here carries
            // WRITE_CURSOR_FLUSHED_FLAG, this thread goes on to re-link the buffer, so it must
            // synchronize with the receiver's `swap_next(None)` that precedes its
            // `fetch_or(.., AcqRel)`. A Relaxed read (the failure path, or the load half of a
            // Release-only success) does not establish that happens-before edge.
            match buffer.write_cursor().compare_exchange(
                write_cursor,
                new_write_cursor,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => break,
                Err(modified_write_cursor) => {
                    // Only flag bits may change underneath us; the offset is ours to advance.
                    write_cursor = modified_write_cursor;
                    assert_eq!(write_cursor & WRITE_CURSOR_MASK, write_start);
                }
            }
        }
        let local = self.local.get_or_default();
        local.unflushed_bytes.set(
            local.unflushed_bytes.get()
                + ((new_write_cursor & WRITE_CURSOR_MASK) - write_start) as usize,
        );
        if write_start == 0 || write_cursor & WRITE_CURSOR_FLUSHED_FLAG != 0 {
            local.chain.push(buffer);
        }
    }

    /// Queues `buffer` on the calling thread's local chain as a completed buffer.
    ///
    /// The length stored earlier with `mark_complete_buffer` becomes the write cursor, tagged
    /// with `WRITE_CURSOR_DONE`. The flush cursor is reset to 0, so the receiver yields
    /// `[0, len)`. This does not add to `unflushed_bytes`.
    pub fn send_complete_buffer(&self, buffer: BufferPtr) {
        let flush_cursor = unsafe { buffer.flush_cursor_mut() };
        let len = core::mem::replace(flush_cursor, 0);
        buffer
            .write_cursor()
            .store(len | WRITE_CURSOR_DONE, Ordering::Release);
        let local = self.local.get_or_default();
        local.chain.push(buffer);
    }

    /// Records `len` in the buffer's flush-cursor slot for a later `send_complete_buffer`.
    ///
    /// NOTE(review): this is a safe fn wrapping the unsafe `flush_cursor_mut`; callers must
    /// ensure nothing else accesses the flush cursor concurrently.
    pub fn mark_complete_buffer(buffer: BufferPtr, len: u32) {
        let flush_cursor = unsafe { buffer.flush_cursor_mut() };
        *flush_cursor = len;
    }

    /// Returns the value currently in the buffer's flush-cursor slot (the length set by
    /// `mark_complete_buffer`, until `send_complete_buffer` resets it).
    pub fn get_complete_buffer_len(buffer: BufferPtr) -> u32 {
        let flush_cursor = unsafe { buffer.flush_cursor_mut() };
        *flush_cursor
    }
}
/// Future resolving to the head of the shared flush chain once a sender has flushed into it.
struct WriterFlushQueueReceive<'a> {
    /// Borrowed queue state owned by the `WriterFlushReceiver`.
    shared: &'a Mutex<WriterFlushShared>,
}
// Takes the whole queued chain when one is present; otherwise registers the task's waker.
impl core::future::Future for WriterFlushQueueReceive<'_> {
    type Output = BufferPtr;
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut guard = self.shared.lock();
        match guard.head_tail.take() {
            Some((head, _)) => {
                // Data obtained; no wake-up is needed any more.
                guard.waker = None;
                Poll::Ready(head)
            }
            None => {
                let current = cx.waker();
                // Keep the stored waker only if it already wakes this same task.
                let up_to_date =
                    matches!(&guard.waker, Some(stored) if stored.will_wake(current));
                if !up_to_date {
                    guard.waker = Some(current.clone());
                }
                Poll::Pending
            }
        }
    }
}
// Unregisters any waker when the receive future goes away, whether it completed or was cancelled.
impl Drop for WriterFlushQueueReceive<'_> {
    fn drop(&mut self) {
        self.shared.lock().waker = None;
    }
}
impl Drop for WriterFlushSender {
fn drop(&mut self) {
self.flush();
}
}
/// Consumer half of the writer flush queue; drains buffers flushed by any sender clone.
pub struct WriterFlushReceiver {
    shared: Arc<Mutex<WriterFlushShared>>,
}
/// One contiguous run of newly written bytes from a single buffer, yielded by `FlushIterator`.
///
/// Dereferences to the bytes. The raw-pointer `PhantomData` makes it neither `Send` nor `Sync`.
pub struct Flush {
    buffer: BufferPtr,
    /// Writer id read from the buffer.
    writer_id: usize,
    /// Start of the run within the buffer's data.
    offset: usize,
    /// Length of the run in bytes.
    len: usize,
    /// Set when the buffer carried `WRITE_CURSOR_DONE`; dropping or converting this value
    /// then releases the buffer.
    release_buffer: bool,
    _not_send: core::marker::PhantomData<*const ()>,
}
impl WriterFlushReceiver {
fn new(shared: Arc<Mutex<WriterFlushShared>>) -> Self {
Self { shared }
}
pub async fn flush(&mut self) -> FlushIterator {
let recv_head = WriterFlushQueueReceive {
shared: &self.shared,
}
.await;
FlushIterator {
head: Some(recv_head),
}
}
}
/// Iterator over the byte runs of one received buffer chain.
///
/// Its `Drop` impl drains whatever was not consumed, so every buffer in the chain is visited.
pub struct FlushIterator {
    /// Next buffer of the chain still to be visited.
    head: Option<BufferPtr>,
}
// Walks the received chain, yielding each buffer's not-yet-flushed byte range.
impl core::iter::Iterator for FlushIterator {
    type Item = Flush;
    fn next(&mut self) -> Option<Self::Item> {
        // Loop because a buffer may have nothing new to yield; such buffers are skipped
        // (and released here if they are complete).
        while let Some(buffer) = self.head {
            // Unlink the buffer from the chain before marking it as taken.
            self.head = unsafe { buffer.swap_next(None) };
            // Mark the buffer as taken off the queue. A writer that later advances the cursor
            // sees this flag and pushes the buffer onto its local chain again
            // (see `advance_write_cursor`). `fetch_or` returns the value before the flag was set.
            let write_cursor = buffer
                .write_cursor()
                .fetch_or(WRITE_CURSOR_FLUSHED_FLAG, Ordering::AcqRel);
            let writer_id = unsafe { buffer.writer_id() };
            let flush_cursor = unsafe { buffer.flush_cursor_mut() };
            let buffer_is_done = (write_cursor & WRITE_CURSOR_DONE) != 0;
            let write_cursor = write_cursor & WRITE_CURSOR_MASK;
            debug_assert!(write_cursor > 0);
            if *flush_cursor < write_cursor {
                // New bytes since the last flush: yield [flush_cursor, write_cursor) and advance
                // the flush cursor. A complete buffer is released later, when the returned
                // `Flush` is dropped or converted into a `Packet`.
                let offset = *flush_cursor as usize;
                let len = (write_cursor - *flush_cursor) as usize;
                *flush_cursor = write_cursor;
                return Some(Flush {
                    buffer,
                    writer_id,
                    offset,
                    len,
                    release_buffer: buffer_is_done,
                    _not_send: core::marker::PhantomData,
                });
            } else if buffer_is_done {
                // Complete and already fully flushed: reset and release it now.
                debug_assert_eq!(*flush_cursor, write_cursor);
                *flush_cursor = 0;
                unsafe {
                    buffer.receive(1);
                }
                unsafe {
                    buffer.release_ref(1);
                }
            }
            // Otherwise the buffer is incomplete with nothing new: it leaves this chain (flag
            // set) until its writer re-queues it.
        }
        None
    }
}
// Drains the rest of the chain so every remaining buffer is visited and, if complete, released;
// each yielded `Flush` is dropped immediately.
impl Drop for FlushIterator {
    fn drop(&mut self) {
        self.by_ref().for_each(drop);
    }
}
impl Flush {
    /// Number of bytes in this flushed run.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if this run contains no bytes.
    ///
    /// Provided alongside the public `len` (clippy `len_without_is_empty`); agrees with
    /// `<[u8]>::is_empty` on the dereferenced slice.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Writer id recorded on the buffer this run came from.
    pub fn writer_id(&self) -> usize {
        self.writer_id
    }
}
impl core::ops::Deref for Flush {
    type Target = [u8];

    /// Borrows bytes `[offset, offset + len)` of the buffer's data.
    fn deref(&self) -> &Self::Target {
        // SAFETY: relies on `offset`/`len` having been derived from the buffer's flush and write
        // cursors in `FlushIterator::next`. NOTE(review): confirm `BufferPtr::data` is valid for
        // reads over that whole range while this `Flush` is alive.
        let start = unsafe { self.buffer.data().add(self.offset) };
        unsafe { core::slice::from_raw_parts(start, self.len) }
    }
}
impl Drop for Flush {
    fn drop(&mut self) {
        // Only a run cut from a buffer marked done is responsible for releasing that buffer.
        if !self.release_buffer {
            return;
        }
        let buffer = self.buffer;
        unsafe {
            *buffer.flush_cursor_mut() = 0;
            buffer.receive(1);
            buffer.release_ref(1);
        }
    }
}
// Converts a flushed run into a `Packet` that takes over the buffer reference.
impl From<Flush> for Packet {
    fn from(flush: Flush) -> Self {
        // Disarm `Flush::drop` up front: the reference is handed to the packet below, and if
        // anything in between panicked, re-running the destructor's release sequence on a
        // partially handed-off buffer would double-release it. Leaking is the safer failure.
        let flush = core::mem::ManuallyDrop::new(flush);
        if flush.release_buffer {
            // Complete buffer: the flush's reference moves into the packet; only the
            // receive-side bookkeeping is done here.
            unsafe {
                *flush.buffer.flush_cursor_mut() = 0;
                flush.buffer.receive(1);
            }
        } else {
            // Buffer still being written: the packet needs its own shared reference.
            unsafe {
                flush.buffer.take_shared_ref(1);
            }
        }
        unsafe { Self::new(flush.buffer, flush.offset, flush.len) }
    }
}