Skip to main content

selium_messaging/
reader.rs

1#![deny(missing_docs)]
2
3use std::{
4    future::Future,
5    pin::{Pin, pin},
6    sync::{Arc, Weak},
7    task::{Context, Poll},
8};
9
10#[cfg(feature = "loom")]
11use loom::sync::atomic::{AtomicBool, AtomicU64, Ordering};
12#[cfg(not(feature = "loom"))]
13use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
14
15use pin_project::pin_project;
16use selium_kernel::drivers::channel::FrameReadable;
17use tokio::io::{AsyncRead, ReadBuf};
18use tracing::{Span, debug, instrument};
19
20use crate::{Backpressure, Channel, ChannelError};
21
22/// Convenience type for implementors to treat both reader types as one.
23#[pin_project(project = ReaderProj)]
24pub enum Reader {
25    /// Strong reader variant that maintains backpressure guarantees.
26    Strong(#[pin] StrongReader),
27    /// Weak reader variant that may drop data when lagging.
28    Weak(#[pin] WeakReader),
29}
30
31/// Reader that prevents overwriting unread bytes.
32pub struct StrongReader {
33    /// Ptr to `Channel`
34    chan: Weak<Channel>,
35    /// Start position being read
36    ///
37    /// The position is tracked by the `Channel` to prevent overwrites
38    pub(crate) pos: Arc<AtomicU64>,
39    /// The ID of the position in the `Channel` map
40    pos_id: usize,
41    /// Termination fuse, which when 'lit' (=true), will safely terminate the reader
42    fuse: AtomicBool,
43    /// Tracing span for instrumentation
44    span: Span,
45}
46
47/// Reader that sacrifices retention to avoid blocking writers.
48pub struct WeakReader {
49    /// Ptr to `Channel`
50    chan: Weak<Channel>,
51    /// Start position being read
52    pub(crate) pos: u64,
53    /// Termination fuse, which when 'lit' (=true), will safely terminate the reader
54    fuse: AtomicBool,
55    /// Tracing span for instrumentation
56    span: Span,
57}
58
59impl Reader {
60    /// Signal termination to the underlying reader variant.
61    pub fn terminate(&self) {
62        match self {
63            Self::Strong(strong) => strong.terminate(),
64            Self::Weak(weak) => weak.terminate(),
65        }
66    }
67
68    /// Extract the strong reader variant, returning `self` on mismatch.
69    pub fn into_strong(self) -> std::result::Result<StrongReader, Self> {
70        match self {
71            Self::Strong(strong) => Ok(strong),
72            Self::Weak(_) => Err(self),
73        }
74    }
75
76    /// Extract the weak reader variant, returning `self` on mismatch.
77    pub fn into_weak(self) -> std::result::Result<WeakReader, Self> {
78        match self {
79            Self::Strong(_) => Err(self),
80            Self::Weak(weak) => Ok(weak),
81        }
82    }
83}
84
85impl AsyncRead for Reader {
86    fn poll_read(
87        self: Pin<&mut Self>,
88        cx: &mut Context<'_>,
89        buf: &mut ReadBuf,
90    ) -> Poll<std::io::Result<()>> {
91        match self.project() {
92            ReaderProj::Strong(strong) => pin!(strong).poll_read(cx, buf),
93            ReaderProj::Weak(weak) => pin!(weak).poll_read(cx, buf),
94        }
95    }
96}
97
98impl From<StrongReader> for Reader {
99    fn from(value: StrongReader) -> Self {
100        Self::Strong(value)
101    }
102}
103
104impl From<WeakReader> for Reader {
105    fn from(value: WeakReader) -> Self {
106        Self::Weak(value)
107    }
108}
109
110impl StrongReader {
111    #[instrument(name = "StrongReader", parent = &chan.span, skip_all, fields(position_id=pos_id))]
112    pub(crate) fn new(chan: Arc<Channel>, pos: Arc<AtomicU64>, pos_id: usize) -> Self {
113        debug!("create reader");
114
115        Self {
116            chan: Arc::downgrade(&chan),
117            pos,
118            pos_id,
119            fuse: AtomicBool::new(false),
120            span: Span::current(),
121        }
122    }
123
124    /// Safely terminate this reader.
125    ///
126    /// This will cause `poll_read` to error with [std::io::ErrorKind::ConnectionAborted].
127    #[instrument(parent = &self.span, skip(self))]
128    pub fn terminate(&self) {
129        if let Some(chan) = self.chan.upgrade() {
130            debug!("terminate reader");
131
132            self.fuse.store(true, Ordering::Release);
133            chan.remove_head(self.pos_id);
134        }
135    }
136
137    /// Read a complete frame, returning the writer identifier and payload bytes.
138    pub async fn read_frame(&mut self, max_len: usize) -> std::io::Result<(u16, Vec<u8>)> {
139        futures::future::poll_fn(|cx| self.poll_read_frame(cx, max_len)).await
140    }
141
142    #[instrument(parent = &self.span, skip_all)]
143    fn poll_read_frame(
144        &mut self,
145        cx: &mut Context<'_>,
146        max_len: usize,
147    ) -> Poll<std::io::Result<(u16, Vec<u8>)>> {
148        let Some(chan) = self.chan.upgrade() else {
149            return Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe)));
150        };
151
152        if self.fuse.load(Ordering::Acquire) || chan.terminated.load(Ordering::Acquire) {
153            return Poll::Ready(Err(std::io::Error::from(
154                std::io::ErrorKind::ConnectionAborted,
155            )));
156        }
157
158        let mut pos = self.pos.load(Ordering::Acquire);
159
160        let draining = chan.draining.load(Ordering::Acquire);
161
162        let frame = if let Some(frame) = chan.frame_for(pos) {
163            frame
164        } else if matches!(chan.backpressure, Backpressure::Drop)
165            && let Some(frame) = chan.frame_from(pos)
166        {
167            if frame.start > pos {
168                self.pos.store(frame.start, Ordering::Release);
169                pos = frame.start;
170            }
171            frame
172        } else if draining {
173            return Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::Interrupted)));
174        } else {
175            chan.enqueue(pos, cx.waker().to_owned());
176            debug!("frame metadata pending");
177            return Poll::Pending;
178        };
179
180        if frame.len as usize > max_len {
181            return Poll::Ready(Err(std::io::Error::new(
182                std::io::ErrorKind::InvalidData,
183                "frame exceeds requested length",
184            )));
185        }
186
187        let end = frame.start + frame.len;
188        if chan.get_tail() < end {
189            chan.enqueue(end, cx.waker().to_owned());
190            debug!("frame pending");
191            return Poll::Pending;
192        }
193
194        let mut payload = vec![0u8; frame.len as usize];
195        if frame.len > 0 {
196            unsafe { chan.read_unsafe(pos, &mut payload) };
197        }
198
199        self.pos.store(end, Ordering::Release);
200        chan.prune_frames();
201        debug!(len = payload.len(), "consumed frame");
202        chan.schedule_writers();
203
204        Poll::Ready(Ok((frame.writer_id, payload)))
205    }
206
207    /// Convert this strong reader into a weak reader that relinquishes its head slot when idle.
208    #[instrument(parent = &self.span, skip(self))]
209    pub fn downgrade(self) -> WeakReader {
210        debug!("downgrade this reader");
211
212        if let Some(chan) = self.chan.upgrade() {
213            chan.remove_head(self.pos_id);
214        }
215
216        WeakReader::new_with_state(
217            self.chan.clone(),
218            self.pos.load(Ordering::Acquire),
219            self.fuse.load(Ordering::Acquire),
220        )
221    }
222}
223
224impl FrameReadable for StrongReader {
225    fn read_frame(
226        &mut self,
227        max_len: usize,
228    ) -> Pin<Box<dyn Future<Output = std::io::Result<(u16, Vec<u8>)>> + Send + '_>> {
229        Box::pin(StrongReader::read_frame(self, max_len))
230    }
231}
232
233impl Drop for StrongReader {
234    fn drop(&mut self) {
235        if let Some(chan) = self.chan.upgrade() {
236            chan.remove_head(self.pos_id);
237        }
238    }
239}
240
241impl AsyncRead for StrongReader {
242    #[instrument(parent = &self.span, skip_all)]
243    fn poll_read(
244        self: Pin<&mut Self>,
245        cx: &mut Context<'_>,
246        buf: &mut ReadBuf,
247    ) -> Poll<std::io::Result<()>> {
248        let Some(chan) = self.chan.upgrade() else {
249            return Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe)));
250        };
251
252        if self.fuse.load(Ordering::Acquire) || chan.terminated.load(Ordering::Acquire) {
253            return Poll::Ready(Err(std::io::Error::from(
254                std::io::ErrorKind::ConnectionAborted,
255            )));
256        }
257
258        let pos = self.pos.load(Ordering::Acquire);
259
260        // If the channel is draining, only allow unfinished reads to proceed
261        if chan.draining.load(Ordering::Acquire) {
262            return Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::Interrupted)));
263        }
264
265        let filled = buf.filled().len();
266        let read = unsafe { chan.read_unsafe(pos, &mut buf.initialized_mut()[filled..]) };
267        buf.advance(read);
268
269        if read == 0 {
270            chan.enqueue(pos, cx.waker().to_owned());
271            debug!("pending");
272            Poll::Pending
273        } else {
274            self.pos.store(pos + read as u64, Ordering::Release);
275            debug!(size = read, "consumed bytes");
276            chan.schedule_writers();
277            Poll::Ready(Ok(()))
278        }
279    }
280}
281
282impl WeakReader {
283    #[instrument(name = "WeakReader", parent = &chan.span, skip_all)]
284    pub(crate) fn new(chan: Arc<Channel>, pos: u64) -> Self {
285        debug!("create reader");
286
287        Self {
288            chan: Arc::downgrade(&chan),
289            pos,
290            fuse: AtomicBool::new(false),
291            span: Span::current(),
292        }
293    }
294
295    #[instrument(name = "WeakReader", parent = &chan.upgrade().expect("channel missing").span, skip_all)]
296    fn new_with_state(chan: Weak<Channel>, pos: u64, fuse_state: bool) -> Self {
297        let reader = Self {
298            chan,
299            pos,
300            fuse: AtomicBool::new(fuse_state),
301            span: Span::current(),
302        };
303        if fuse_state {
304            reader.terminate();
305        }
306        reader
307    }
308
309    /// Safely terminate this reader.
310    ///
311    /// This will cause `poll_read` to error with [std::io::ErrorKind::ConnectionAborted].
312    #[instrument(parent = &self.span, skip(self))]
313    pub fn terminate(&self) {
314        debug!("terminate");
315
316        self.fuse.store(true, Ordering::Release);
317    }
318
319    /// Read a complete frame, returning the writer identifier and payload bytes.
320    pub async fn read_frame(&mut self, max_len: usize) -> std::io::Result<(u16, Vec<u8>)> {
321        futures::future::poll_fn(|cx| self.poll_read_frame(cx, max_len)).await
322    }
323
324    #[instrument(parent = &self.span, skip_all)]
325    fn poll_read_frame(
326        &mut self,
327        cx: &mut Context<'_>,
328        max_len: usize,
329    ) -> Poll<std::io::Result<(u16, Vec<u8>)>> {
330        let Some(chan) = self.chan.upgrade() else {
331            return Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe)));
332        };
333
334        if self.fuse.load(Ordering::Acquire) || chan.terminated.load(Ordering::Acquire) {
335            return Poll::Ready(Err(std::io::Error::from(
336                std::io::ErrorKind::ConnectionAborted,
337            )));
338        }
339
340        let draining = chan.draining.load(Ordering::Acquire);
341
342        if let Err(ChannelError::ReaderBehind(pos)) = chan.read(self.pos, &mut []) {
343            if let Some(frame) = chan.frame_from(pos) {
344                self.pos = frame.start;
345            } else {
346                self.pos = pos;
347            }
348            return Poll::Ready(Err(std::io::Error::other(ChannelError::ReaderBehind(pos))));
349        }
350
351        let Some(frame) = chan.frame_from(self.pos) else {
352            if draining {
353                return Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::Interrupted)));
354            }
355            chan.enqueue(self.pos, cx.waker().to_owned());
356            debug!("frame metadata pending");
357            return Poll::Pending;
358        };
359        if frame.start > self.pos {
360            self.pos = frame.start;
361        }
362
363        if frame.len as usize > max_len {
364            return Poll::Ready(Err(std::io::Error::new(
365                std::io::ErrorKind::InvalidData,
366                "frame exceeds requested length",
367            )));
368        }
369
370        let end = frame.start + frame.len;
371        if chan.get_tail() < end {
372            chan.enqueue(end, cx.waker().to_owned());
373            debug!("weak reader frame pending");
374            return Poll::Pending;
375        }
376
377        let mut payload = vec![0u8; frame.len as usize];
378        match chan.read(self.pos, &mut payload) {
379            Ok(read) => {
380                self.pos = end;
381                chan.prune_frames();
382                debug!(len = payload.len(), read, "weak reader consumed frame");
383                chan.schedule_writers();
384                Poll::Ready(Ok((frame.writer_id, payload)))
385            }
386            Err(ChannelError::ReaderBehind(pos)) => {
387                if let Some(frame) = chan.frame_from(pos) {
388                    self.pos = frame.start;
389                } else {
390                    self.pos = pos;
391                }
392                Poll::Ready(Err(std::io::Error::other(ChannelError::ReaderBehind(pos))))
393            }
394            Err(_) => unreachable!(),
395        }
396    }
397}
398
399impl FrameReadable for WeakReader {
400    fn read_frame(
401        &mut self,
402        max_len: usize,
403    ) -> Pin<Box<dyn Future<Output = std::io::Result<(u16, Vec<u8>)>> + Send + '_>> {
404        Box::pin(WeakReader::read_frame(self, max_len))
405    }
406}
407
408impl AsyncRead for WeakReader {
409    fn poll_read(
410        mut self: Pin<&mut Self>,
411        cx: &mut Context<'_>,
412        buf: &mut ReadBuf,
413    ) -> Poll<std::io::Result<()>> {
414        let Some(chan) = self.chan.upgrade() else {
415            return Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe)));
416        };
417
418        if self.fuse.load(Ordering::Acquire) || chan.terminated.load(Ordering::Acquire) {
419            return Poll::Ready(Err(std::io::Error::from(
420                std::io::ErrorKind::ConnectionAborted,
421            )));
422        }
423
424        if chan.draining.load(Ordering::Acquire) {
425            return Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::Interrupted)));
426        }
427
428        let filled = buf.filled().len();
429        match chan.read(self.pos, &mut buf.initialized_mut()[filled..]) {
430            Ok(read) if read > 0 => {
431                self.pos += read as u64;
432                buf.advance(read);
433                Poll::Ready(Ok(()))
434            }
435            Ok(_) => {
436                chan.enqueue(self.pos, cx.waker().to_owned());
437                Poll::Pending
438            }
439            Err(ChannelError::ReaderBehind(pos)) => {
440                if let Some(frame) = chan.frame_from(pos) {
441                    self.pos = frame.start;
442                } else {
443                    self.pos = pos;
444                }
445                Poll::Ready(Err(std::io::Error::other(ChannelError::ReaderBehind(pos))))
446            }
447            Err(_) => unreachable!(),
448        }
449    }
450}