1#![deny(missing_docs)]
2
3use std::{
4 future::Future,
5 pin::{Pin, pin},
6 sync::{Arc, Weak},
7 task::{Context, Poll},
8};
9
10#[cfg(feature = "loom")]
11use loom::sync::atomic::{AtomicBool, AtomicU64, Ordering};
12#[cfg(not(feature = "loom"))]
13use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
14
15use pin_project::pin_project;
16use selium_kernel::drivers::channel::FrameReadable;
17use tokio::io::{AsyncRead, ReadBuf};
18use tracing::{Span, debug, instrument};
19
20use crate::{Backpressure, Channel, ChannelError};
21
/// A channel reader that is either a [`StrongReader`] or a [`WeakReader`].
///
/// Implements [`AsyncRead`] by forwarding to the wrapped variant; use
/// [`Reader::into_strong`] / [`Reader::into_weak`] to recover the concrete reader.
#[pin_project(project = ReaderProj)]
pub enum Reader {
    /// A reader registered with the channel under a position id (see [`StrongReader`]).
    Strong(#[pin] StrongReader),
    /// A reader that tracks its position privately (see [`WeakReader`]).
    Weak(#[pin] WeakReader),
}
30
/// A channel reader whose position is held in a shared atomic and that is registered with the
/// channel under a position id.
///
/// The registration is released through `Channel::remove_head(pos_id)` when the reader is
/// terminated, downgraded or dropped. After consuming data the reader calls
/// `Channel::schedule_writers`; presumably the channel uses registered heads to hold writers
/// back — NOTE(review): confirm against `Channel`.
pub struct StrongReader {
    /// Non-owning handle to the channel; reads fail with `BrokenPipe` once it is gone.
    chan: Weak<Channel>,
    /// Stream position of the next byte to read. Supplied by the caller of `new`, so it is
    /// presumably also held by the channel — confirm.
    pub(crate) pos: Arc<AtomicU64>,
    /// Identifier passed to `Channel::remove_head` to release this reader's registration.
    pos_id: usize,
    /// Set by `terminate`; once set, reads fail with `ConnectionAborted`.
    fuse: AtomicBool,
    /// Tracing span created in `new`, used as the parent of this reader's spans.
    span: Span,
}
46
/// A channel reader that keeps its position privately.
///
/// It has no position id and never calls `Channel::remove_head`. Reads can therefore report
/// `ChannelError::ReaderBehind`, after which the reader repositions itself.
pub struct WeakReader {
    /// Non-owning handle to the channel; reads fail with `BrokenPipe` once it is gone.
    chan: Weak<Channel>,
    /// Stream position of the next byte to read.
    pub(crate) pos: u64,
    /// Set by `terminate`; once set, reads fail with `ConnectionAborted`.
    fuse: AtomicBool,
    /// Tracing span created at construction, used as the parent of this reader's spans.
    span: Span,
}
58
59impl Reader {
60 pub fn terminate(&self) {
62 match self {
63 Self::Strong(strong) => strong.terminate(),
64 Self::Weak(weak) => weak.terminate(),
65 }
66 }
67
68 pub fn into_strong(self) -> std::result::Result<StrongReader, Self> {
70 match self {
71 Self::Strong(strong) => Ok(strong),
72 Self::Weak(_) => Err(self),
73 }
74 }
75
76 pub fn into_weak(self) -> std::result::Result<WeakReader, Self> {
78 match self {
79 Self::Strong(_) => Err(self),
80 Self::Weak(weak) => Ok(weak),
81 }
82 }
83}
84
85impl AsyncRead for Reader {
86 fn poll_read(
87 self: Pin<&mut Self>,
88 cx: &mut Context<'_>,
89 buf: &mut ReadBuf,
90 ) -> Poll<std::io::Result<()>> {
91 match self.project() {
92 ReaderProj::Strong(strong) => pin!(strong).poll_read(cx, buf),
93 ReaderProj::Weak(weak) => pin!(weak).poll_read(cx, buf),
94 }
95 }
96}
97
98impl From<StrongReader> for Reader {
99 fn from(value: StrongReader) -> Self {
100 Self::Strong(value)
101 }
102}
103
104impl From<WeakReader> for Reader {
105 fn from(value: WeakReader) -> Self {
106 Self::Weak(value)
107 }
108}
109
110impl StrongReader {
111 #[instrument(name = "StrongReader", parent = &chan.span, skip_all, fields(position_id=pos_id))]
112 pub(crate) fn new(chan: Arc<Channel>, pos: Arc<AtomicU64>, pos_id: usize) -> Self {
113 debug!("create reader");
114
115 Self {
116 chan: Arc::downgrade(&chan),
117 pos,
118 pos_id,
119 fuse: AtomicBool::new(false),
120 span: Span::current(),
121 }
122 }
123
124 #[instrument(parent = &self.span, skip(self))]
128 pub fn terminate(&self) {
129 if let Some(chan) = self.chan.upgrade() {
130 debug!("terminate reader");
131
132 self.fuse.store(true, Ordering::Release);
133 chan.remove_head(self.pos_id);
134 }
135 }
136
137 pub async fn read_frame(&mut self, max_len: usize) -> std::io::Result<(u16, Vec<u8>)> {
139 futures::future::poll_fn(|cx| self.poll_read_frame(cx, max_len)).await
140 }
141
142 #[instrument(parent = &self.span, skip_all)]
143 fn poll_read_frame(
144 &mut self,
145 cx: &mut Context<'_>,
146 max_len: usize,
147 ) -> Poll<std::io::Result<(u16, Vec<u8>)>> {
148 let Some(chan) = self.chan.upgrade() else {
149 return Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe)));
150 };
151
152 if self.fuse.load(Ordering::Acquire) || chan.terminated.load(Ordering::Acquire) {
153 return Poll::Ready(Err(std::io::Error::from(
154 std::io::ErrorKind::ConnectionAborted,
155 )));
156 }
157
158 let mut pos = self.pos.load(Ordering::Acquire);
159
160 let draining = chan.draining.load(Ordering::Acquire);
161
162 let frame = if let Some(frame) = chan.frame_for(pos) {
163 frame
164 } else if matches!(chan.backpressure, Backpressure::Drop)
165 && let Some(frame) = chan.frame_from(pos)
166 {
167 if frame.start > pos {
168 self.pos.store(frame.start, Ordering::Release);
169 pos = frame.start;
170 }
171 frame
172 } else if draining {
173 return Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::Interrupted)));
174 } else {
175 chan.enqueue(pos, cx.waker().to_owned());
176 debug!("frame metadata pending");
177 return Poll::Pending;
178 };
179
180 if frame.len as usize > max_len {
181 return Poll::Ready(Err(std::io::Error::new(
182 std::io::ErrorKind::InvalidData,
183 "frame exceeds requested length",
184 )));
185 }
186
187 let end = frame.start + frame.len;
188 if chan.get_tail() < end {
189 chan.enqueue(end, cx.waker().to_owned());
190 debug!("frame pending");
191 return Poll::Pending;
192 }
193
194 let mut payload = vec![0u8; frame.len as usize];
195 if frame.len > 0 {
196 unsafe { chan.read_unsafe(pos, &mut payload) };
197 }
198
199 self.pos.store(end, Ordering::Release);
200 chan.prune_frames();
201 debug!(len = payload.len(), "consumed frame");
202 chan.schedule_writers();
203
204 Poll::Ready(Ok((frame.writer_id, payload)))
205 }
206
207 #[instrument(parent = &self.span, skip(self))]
209 pub fn downgrade(self) -> WeakReader {
210 debug!("downgrade this reader");
211
212 if let Some(chan) = self.chan.upgrade() {
213 chan.remove_head(self.pos_id);
214 }
215
216 WeakReader::new_with_state(
217 self.chan.clone(),
218 self.pos.load(Ordering::Acquire),
219 self.fuse.load(Ordering::Acquire),
220 )
221 }
222}
223
224impl FrameReadable for StrongReader {
225 fn read_frame(
226 &mut self,
227 max_len: usize,
228 ) -> Pin<Box<dyn Future<Output = std::io::Result<(u16, Vec<u8>)>> + Send + '_>> {
229 Box::pin(StrongReader::read_frame(self, max_len))
230 }
231}
232
233impl Drop for StrongReader {
234 fn drop(&mut self) {
235 if let Some(chan) = self.chan.upgrade() {
236 chan.remove_head(self.pos_id);
237 }
238 }
239}
240
241impl AsyncRead for StrongReader {
242 #[instrument(parent = &self.span, skip_all)]
243 fn poll_read(
244 self: Pin<&mut Self>,
245 cx: &mut Context<'_>,
246 buf: &mut ReadBuf,
247 ) -> Poll<std::io::Result<()>> {
248 let Some(chan) = self.chan.upgrade() else {
249 return Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe)));
250 };
251
252 if self.fuse.load(Ordering::Acquire) || chan.terminated.load(Ordering::Acquire) {
253 return Poll::Ready(Err(std::io::Error::from(
254 std::io::ErrorKind::ConnectionAborted,
255 )));
256 }
257
258 let pos = self.pos.load(Ordering::Acquire);
259
260 if chan.draining.load(Ordering::Acquire) {
262 return Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::Interrupted)));
263 }
264
265 let filled = buf.filled().len();
266 let read = unsafe { chan.read_unsafe(pos, &mut buf.initialized_mut()[filled..]) };
267 buf.advance(read);
268
269 if read == 0 {
270 chan.enqueue(pos, cx.waker().to_owned());
271 debug!("pending");
272 Poll::Pending
273 } else {
274 self.pos.store(pos + read as u64, Ordering::Release);
275 debug!(size = read, "consumed bytes");
276 chan.schedule_writers();
277 Poll::Ready(Ok(()))
278 }
279 }
280}
281
282impl WeakReader {
283 #[instrument(name = "WeakReader", parent = &chan.span, skip_all)]
284 pub(crate) fn new(chan: Arc<Channel>, pos: u64) -> Self {
285 debug!("create reader");
286
287 Self {
288 chan: Arc::downgrade(&chan),
289 pos,
290 fuse: AtomicBool::new(false),
291 span: Span::current(),
292 }
293 }
294
295 #[instrument(name = "WeakReader", parent = &chan.upgrade().expect("channel missing").span, skip_all)]
296 fn new_with_state(chan: Weak<Channel>, pos: u64, fuse_state: bool) -> Self {
297 let reader = Self {
298 chan,
299 pos,
300 fuse: AtomicBool::new(fuse_state),
301 span: Span::current(),
302 };
303 if fuse_state {
304 reader.terminate();
305 }
306 reader
307 }
308
309 #[instrument(parent = &self.span, skip(self))]
313 pub fn terminate(&self) {
314 debug!("terminate");
315
316 self.fuse.store(true, Ordering::Release);
317 }
318
319 pub async fn read_frame(&mut self, max_len: usize) -> std::io::Result<(u16, Vec<u8>)> {
321 futures::future::poll_fn(|cx| self.poll_read_frame(cx, max_len)).await
322 }
323
324 #[instrument(parent = &self.span, skip_all)]
325 fn poll_read_frame(
326 &mut self,
327 cx: &mut Context<'_>,
328 max_len: usize,
329 ) -> Poll<std::io::Result<(u16, Vec<u8>)>> {
330 let Some(chan) = self.chan.upgrade() else {
331 return Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe)));
332 };
333
334 if self.fuse.load(Ordering::Acquire) || chan.terminated.load(Ordering::Acquire) {
335 return Poll::Ready(Err(std::io::Error::from(
336 std::io::ErrorKind::ConnectionAborted,
337 )));
338 }
339
340 let draining = chan.draining.load(Ordering::Acquire);
341
342 if let Err(ChannelError::ReaderBehind(pos)) = chan.read(self.pos, &mut []) {
343 if let Some(frame) = chan.frame_from(pos) {
344 self.pos = frame.start;
345 } else {
346 self.pos = pos;
347 }
348 return Poll::Ready(Err(std::io::Error::other(ChannelError::ReaderBehind(pos))));
349 }
350
351 let Some(frame) = chan.frame_from(self.pos) else {
352 if draining {
353 return Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::Interrupted)));
354 }
355 chan.enqueue(self.pos, cx.waker().to_owned());
356 debug!("frame metadata pending");
357 return Poll::Pending;
358 };
359 if frame.start > self.pos {
360 self.pos = frame.start;
361 }
362
363 if frame.len as usize > max_len {
364 return Poll::Ready(Err(std::io::Error::new(
365 std::io::ErrorKind::InvalidData,
366 "frame exceeds requested length",
367 )));
368 }
369
370 let end = frame.start + frame.len;
371 if chan.get_tail() < end {
372 chan.enqueue(end, cx.waker().to_owned());
373 debug!("weak reader frame pending");
374 return Poll::Pending;
375 }
376
377 let mut payload = vec![0u8; frame.len as usize];
378 match chan.read(self.pos, &mut payload) {
379 Ok(read) => {
380 self.pos = end;
381 chan.prune_frames();
382 debug!(len = payload.len(), read, "weak reader consumed frame");
383 chan.schedule_writers();
384 Poll::Ready(Ok((frame.writer_id, payload)))
385 }
386 Err(ChannelError::ReaderBehind(pos)) => {
387 if let Some(frame) = chan.frame_from(pos) {
388 self.pos = frame.start;
389 } else {
390 self.pos = pos;
391 }
392 Poll::Ready(Err(std::io::Error::other(ChannelError::ReaderBehind(pos))))
393 }
394 Err(_) => unreachable!(),
395 }
396 }
397}
398
399impl FrameReadable for WeakReader {
400 fn read_frame(
401 &mut self,
402 max_len: usize,
403 ) -> Pin<Box<dyn Future<Output = std::io::Result<(u16, Vec<u8>)>> + Send + '_>> {
404 Box::pin(WeakReader::read_frame(self, max_len))
405 }
406}
407
408impl AsyncRead for WeakReader {
409 fn poll_read(
410 mut self: Pin<&mut Self>,
411 cx: &mut Context<'_>,
412 buf: &mut ReadBuf,
413 ) -> Poll<std::io::Result<()>> {
414 let Some(chan) = self.chan.upgrade() else {
415 return Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe)));
416 };
417
418 if self.fuse.load(Ordering::Acquire) || chan.terminated.load(Ordering::Acquire) {
419 return Poll::Ready(Err(std::io::Error::from(
420 std::io::ErrorKind::ConnectionAborted,
421 )));
422 }
423
424 if chan.draining.load(Ordering::Acquire) {
425 return Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::Interrupted)));
426 }
427
428 let filled = buf.filled().len();
429 match chan.read(self.pos, &mut buf.initialized_mut()[filled..]) {
430 Ok(read) if read > 0 => {
431 self.pos += read as u64;
432 buf.advance(read);
433 Poll::Ready(Ok(()))
434 }
435 Ok(_) => {
436 chan.enqueue(self.pos, cx.waker().to_owned());
437 Poll::Pending
438 }
439 Err(ChannelError::ReaderBehind(pos)) => {
440 if let Some(frame) = chan.frame_from(pos) {
441 self.pos = frame.start;
442 } else {
443 self.pos = pos;
444 }
445 Poll::Ready(Err(std::io::Error::other(ChannelError::ReaderBehind(pos))))
446 }
447 Err(_) => unreachable!(),
448 }
449 }
450}