Skip to main content

selium_messaging/
writer.rs

1#![deny(missing_docs)]
//! Writer handles for appending frames to an in-memory [`Channel`] with configurable backpressure.
3
4#[cfg(not(feature = "loom"))]
5use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
6
7use std::{
8    pin::Pin,
9    sync::{Arc, Weak},
10    task::{Context, Poll},
11};
12
13#[cfg(feature = "loom")]
14use loom::sync::atomic::{AtomicBool, AtomicU64, Ordering};
15use pin_project::{pin_project, pinned_drop};
16use tokio::io::AsyncWrite;
17use tracing::{Span, debug, instrument};
18
19use crate::{Backpressure, Channel, id_factory::Id};
20
21/// Writable handle into a [`Channel`].
22#[pin_project(project = WriterProj)]
23pub enum Writer {
24    /// Strong writer variant that participates in backpressure accounting.
25    Strong(#[pin] StrongWriter),
26    /// Weak writer variant that does not hold a persistent tail slot.
27    Weak(#[pin] WeakWriter),
28}
29
30/// Writer that maintains a persistent tail slot and participates in backpressure accounting.
31pub struct StrongWriter {
32    /// Identifier for writer attribution
33    id: Id,
34    /// Ptr to `Channel`
35    chan: Weak<Channel>,
36    /// Start position being written
37    ///
38    /// The position is tracked by the `Channel` to trigger reads
39    pub(crate) pos: Arc<AtomicU64>,
40    /// The ID of the position in the `Channel` map
41    pos_id: Option<usize>,
42    /// Number of bytes left to write
43    pub(crate) rem: usize,
44    /// Termination fuse, which when 'lit' (=true), will safely terminate the writer
45    fuse: AtomicBool,
46    /// Tracing span for instrumentation
47    span: Span,
48}
49
50/// Writer that relinquishes its tail slot when idle.
51#[pin_project(PinnedDrop)]
52pub struct WeakWriter {
53    /// Identifier for writer attribution
54    id: Id,
55    /// Ptr to `Channel`
56    chan: Weak<Channel>,
57    /// Termination fuse, which when 'lit' (=true), will safely terminate the writer
58    fuse: AtomicBool,
59    /// Writes are channeled through a temporary `StrongWriter` which is dropped after
60    /// writing a full frame.
61    #[pin]
62    current: Option<StrongWriter>,
63    /// Tracing span for instrumentation
64    span: Span,
65}
66
67impl StrongWriter {
68    #[instrument(name = "StrongWriter", parent = &chan.span, skip_all, fields(id = id.get()))]
69    pub(crate) fn new(
70        id: Id,
71        chan: Arc<Channel>,
72        pos: Arc<AtomicU64>,
73        pos_id: Option<usize>,
74    ) -> Self {
75        Self {
76            id,
77            chan: Arc::downgrade(&chan),
78            pos,
79            pos_id,
80            rem: 0,
81            fuse: AtomicBool::new(false),
82            span: Span::current(),
83        }
84    }
85
86    fn release_tail(&mut self) {
87        if let Some(id) = self.pos_id.take()
88            && let Some(chan) = self.chan.upgrade()
89        {
90            chan.remove_tail(id);
91        }
92    }
93
94    fn is_idle(&self) -> bool {
95        self.rem == 0
96    }
97
98    /// Current number of bytes the writer can append without blocking.
99    pub fn writable_size(&self) -> u64 {
100        if let Some(chan) = self.chan.upgrade() {
101            chan.writable_size(self.pos.load(Ordering::Acquire))
102        } else {
103            0
104        }
105    }
106
107    /// Safely terminate this writer.
108    ///
109    /// This will cause `poll_write` to error with [std::io::ErrorKind::ConnectionAborted].
110    #[instrument(parent = &self.span, skip(self))]
111    pub fn terminate(&mut self) {
112        debug!("terminate");
113
114        self.fuse.store(true, Ordering::Release);
115        self.release_tail();
116    }
117
118    /// Convert this strong writer into a weak writer that relinquishes its tail slot when idle.
119    #[instrument(parent = &self.span, skip(self))]
120    pub fn downgrade(mut self) -> WeakWriter {
121        debug!("downgrade this writer");
122
123        self.release_tail();
124
125        WeakWriter::new_with_state(
126            self.id.clone(),
127            self.chan.clone(),
128            self.fuse.load(Ordering::Acquire),
129            None,
130        )
131    }
132}
133
134impl Drop for StrongWriter {
135    fn drop(&mut self) {
136        self.release_tail();
137    }
138}
139
140impl AsyncWrite for StrongWriter {
141    #[instrument(parent = &self.span, skip_all)]
142    fn poll_write(
143        mut self: Pin<&mut Self>,
144        cx: &mut Context<'_>,
145        buf: &[u8],
146    ) -> Poll<std::io::Result<usize>> {
147        let Some(chan) = self.chan.upgrade() else {
148            return Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe)));
149        };
150
151        // If the channel or writer is terminating, abort the write
152        if self.fuse.load(Ordering::Acquire) || chan.terminated.load(Ordering::Acquire) {
153            return Poll::Ready(Err(std::io::Error::from(
154                std::io::ErrorKind::ConnectionAborted,
155            )));
156        }
157
158        // If the channel is draining, only allow unfinished writes to proceed
159        if chan.draining.load(Ordering::Acquire) && self.rem == 0 {
160            return Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::Interrupted)));
161        }
162
163        let len = buf.len();
164
165        // If no bytes remain in the current slice, reserve a new slice
166        let pos = if self.rem == 0 {
167            // In Drop mode, if no space is available, drop immediately without reserving
168            if matches!(chan.backpressure, Backpressure::Drop) {
169                let avail = chan.writable_size(self.pos.load(Ordering::Acquire)) as usize;
170                if avail < len {
171                    return Poll::Ready(Ok(0));
172                }
173                let reserve = len as u64;
174                let pos = chan.reserve_slice(reserve);
175                self.pos.store(pos, Ordering::Release);
176                self.rem = len;
177                chan.register_frame(pos, reserve, self.id.get());
178                pos
179            } else {
180                let pos = chan.reserve_slice(len as u64);
181                self.pos.store(pos, Ordering::Release);
182                self.rem = len;
183                chan.register_frame(pos, len as u64, self.id.get());
184                pos
185            }
186        } else {
187            self.pos.load(Ordering::Acquire)
188        };
189
190        let written = chan.write(pos, &buf[..self.rem]);
191
192        if written == 0 {
193            debug!("writer poll_write made no progress");
194            if matches!(chan.backpressure, Backpressure::Drop) {
195                // Do not enqueue; signal that nothing was written.
196                return Poll::Ready(Ok(0));
197            }
198            chan.enqueue(pos, cx.waker().to_owned());
199            Poll::Pending
200        } else {
201            self.pos.store(pos + written as u64, Ordering::Release);
202            self.rem -= written;
203            debug!(
204                pos = self.pos.load(Ordering::Acquire),
205                rem = self.rem,
206                written,
207                "writer poll_write committed bytes"
208            );
209            chan.schedule_readers();
210            Poll::Ready(Ok(written))
211        }
212    }
213
214    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
215        Poll::Ready(Ok(()))
216    }
217
218    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
219        Poll::Ready(Ok(()))
220    }
221}
222
223impl WeakWriter {
224    #[instrument(name = "WeakWriter", parent = &chan.span, skip_all, fields(id = id.get()))]
225    pub(crate) fn new(id: Id, chan: Arc<Channel>) -> Self {
226        Self {
227            id,
228            chan: Arc::downgrade(&chan),
229            fuse: AtomicBool::new(false),
230            current: None,
231            span: Span::current(),
232        }
233    }
234
235    #[instrument(name = "WeakWriter", parent = &chan.upgrade().expect("channel missing").span, skip_all, fields(id = id.get()))]
236    fn new_with_state(
237        id: Id,
238        chan: Weak<Channel>,
239        fuse_state: bool,
240        current: Option<StrongWriter>,
241    ) -> Self {
242        let mut writer = Self {
243            id,
244            chan,
245            fuse: AtomicBool::new(fuse_state),
246            current,
247            span: Span::current(),
248        };
249        if fuse_state {
250            writer.terminate();
251        }
252        writer
253    }
254
255    fn ensure_strong(self: Pin<&mut Self>) -> std::io::Result<Pin<&mut StrongWriter>> {
256        let Some(chan) = self.chan.upgrade() else {
257            return Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe));
258        };
259
260        let mut this = self.project();
261        if this.fuse.load(Ordering::Acquire) {
262            return Err(std::io::Error::from(std::io::ErrorKind::ConnectionAborted));
263        }
264        if chan.draining.load(Ordering::Acquire) {
265            return Err(std::io::Error::from(std::io::ErrorKind::Interrupted));
266        }
267        if this.current.is_none() {
268            let strong = chan.new_strong_writer_with_id(this.id.clone());
269            this.current.set(Some(strong));
270        }
271        Ok(this.current.as_pin_mut().expect("strong writer present"))
272    }
273
274    fn release_if_idle(self: Pin<&mut Self>) {
275        let mut this = self.project();
276        if let Some(mut strong) = this.current.as_mut().as_pin_mut()
277            && strong.is_idle()
278        {
279            strong.as_mut().get_mut().release_tail();
280            if let Some(chan) = this.chan.upgrade() {
281                chan.schedule_readers();
282            }
283            this.current.set(None);
284        }
285    }
286
287    #[instrument(parent = &self.span, skip(self))]
288    fn terminate(&mut self) {
289        debug!("terminate");
290
291        self.fuse.store(true, Ordering::Release);
292        if let Some(mut strong) = self.current.take() {
293            strong.terminate();
294        }
295    }
296}
297
298impl AsyncWrite for WeakWriter {
299    fn poll_write(
300        self: Pin<&mut Self>,
301        cx: &mut Context<'_>,
302        buf: &[u8],
303    ) -> Poll<std::io::Result<usize>> {
304        let mut this = self;
305        match this.as_mut().ensure_strong() {
306            Ok(mut strong) => {
307                let result = strong.as_mut().poll_write(cx, buf);
308                if matches!(result, Poll::Ready(Ok(_))) {
309                    this.release_if_idle();
310                }
311                result
312            }
313            Err(err) => Poll::Ready(Err(err)),
314        }
315    }
316
317    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
318        let mut this = self;
319        match this.as_mut().ensure_strong() {
320            Ok(mut strong) => {
321                let result = strong.as_mut().poll_flush(cx);
322                this.release_if_idle();
323                result
324            }
325            Err(err) => Poll::Ready(Err(err)),
326        }
327    }
328
329    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
330        let mut this = self;
331        match this.as_mut().ensure_strong() {
332            Ok(mut strong) => {
333                let result = strong.as_mut().poll_shutdown(cx);
334                this.release_if_idle();
335                result
336            }
337            Err(err) => Poll::Ready(Err(err)),
338        }
339    }
340}
341
342#[pinned_drop]
343impl PinnedDrop for WeakWriter {
344    fn drop(self: Pin<&mut Self>) {
345        let mut this = self.project();
346        if let Some(mut strong) = this.current.take() {
347            strong.release_tail();
348            if let Some(chan) = strong.chan.upgrade() {
349                chan.schedule_readers();
350            }
351        }
352    }
353}
354
355impl Writer {
356    /// Signal termination to the underlying writer variant.
357    pub fn terminate(&mut self) {
358        match self {
359            Writer::Strong(writer) => writer.terminate(),
360            Writer::Weak(writer) => writer.terminate(),
361        }
362    }
363
364    /// Downgrade the writer to its weak form.
365    pub fn downgrade(self) -> Writer {
366        match self {
367            Writer::Strong(writer) => Writer::Weak(writer.downgrade()),
368            Writer::Weak(writer) => Writer::Weak(writer),
369        }
370    }
371
372    /// Extract the strong writer variant, returning `self` on mismatch.
373    #[allow(clippy::result_large_err)]
374    pub fn into_strong(self) -> std::result::Result<StrongWriter, Self> {
375        match self {
376            Writer::Strong(strong) => Ok(strong),
377            Writer::Weak(_) => Err(self),
378        }
379    }
380
381    /// Extract the weak writer variant, returning a downgraded strong writer when required.
382    #[allow(clippy::result_large_err)]
383    pub fn into_weak(self) -> std::result::Result<WeakWriter, Self> {
384        match self {
385            Writer::Weak(weak) => Ok(weak),
386            Writer::Strong(strong) => Ok(strong.downgrade()),
387        }
388    }
389}
390
391impl AsyncWrite for Writer {
392    fn poll_write(
393        self: Pin<&mut Self>,
394        cx: &mut Context<'_>,
395        buf: &[u8],
396    ) -> Poll<std::io::Result<usize>> {
397        match self.project() {
398            WriterProj::Strong(strong) => strong.poll_write(cx, buf),
399            WriterProj::Weak(weak) => weak.poll_write(cx, buf),
400        }
401    }
402
403    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
404        match self.project() {
405            WriterProj::Strong(strong) => strong.poll_flush(cx),
406            WriterProj::Weak(weak) => weak.poll_flush(cx),
407        }
408    }
409
410    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
411        match self.project() {
412            WriterProj::Strong(strong) => strong.poll_shutdown(cx),
413            WriterProj::Weak(weak) => weak.poll_shutdown(cx),
414        }
415    }
416}