net_mux/
stream.rs

1//! Stream abstraction for multiplexing
2//!
3//! This module provides the [`Stream`] struct for representing individual data streams
4//! in network multiplexing. Each stream has a unique stream ID and implements async
5//! read/write interfaces, supporting concurrent processing of multiple streams.
6//!
7//! # Features
8//!
9//! - **Async I/O**: Implements [`AsyncRead`] and [`AsyncWrite`] traits
10//! - **State Management**: Uses bit flags to track stream read/write state
11//! - **Auto Cleanup**: Automatically notifies stream manager when stream closes
12//! - **Thread Safety**: Supports safe cross-thread transfer
13
14use std::{
15    cmp,
16    future::Future,
17    pin::Pin,
18    sync::OnceLock,
19    task::{Context, Poll},
20};
21
22use bitflags::bitflags;
23use parking_lot::RwLock;
24use tokio::{
25    io::{AsyncRead, AsyncWrite, ReadBuf},
26    sync::{broadcast, mpsc, oneshot},
27};
28use tokio_util::bytes::{Buf, Bytes};
29
30use crate::{
31    alloc::StreamId,
32    error::Error,
33    frame::Frame,
34    msg::{self, Message},
35};
36
// Boxed, thread-safe future that sends one PSH frame; it resolves to the result of
// `msg::send_psh`.
type WriteFrameFuture = Pin<Box<dyn Future<Output = Result<usize, Error>> + Send + Sync>>;

/// Multiplexed stream
///
/// Represents one logical data stream carried over a multiplexed connection and
/// implements [`AsyncRead`] / [`AsyncWrite`]. Each stream is identified by its
/// [`StreamId`] and communicates with the stream manager through message channels.
pub struct Stream {
    // Identifier of this stream; attached to every PSH/FIN message it sends.
    stream_id: StreamId,
    // Remaining read/write permissions (see `StreamFlags`).
    status: RwLock<StreamFlags>,
    // Unread bytes left over from the most recently received frame.
    read_buf: Bytes,
    // Frame accepted by `poll_write` whose send future has not completed yet.
    cur_write_fut: Option<WriteFrameFuture>,

    // Held for the lifetime of the stream but never read in this file.
    // NOTE(review): presumably kept so the owner can observe live streams via the
    // broadcast channel — confirm against the stream manager.
    _shutdown_rx: broadcast::Receiver<()>,

    // Outgoing PSH/FIN messages to the stream manager.
    msg_tx: mpsc::UnboundedSender<Message>,
    // Incoming data frames for this stream.
    frame_rx: mpsc::Receiver<Frame>,
    // `poll_read` treats completion of this receiver (including a dropped sender)
    // as the remote FIN.
    remote_fin_rx: oneshot::Receiver<()>,
    // `deny_rw` sends this stream's id here so the manager can clean up.
    close_tx: mpsc::UnboundedSender<StreamId>,

    // Ensures the body of `close()` runs at most once.
    close_once: OnceLock<()>,
}
60
bitflags! {
    // Remaining permissions of a stream.
    //
    // A stream starts as `V` (read and write). `R` is cleared by `poll_read` once the
    // remote FIN has been observed, `W` by a local `close()`, and both by `Drop`.
    // Having exactly one of them cleared represents a half-closed stream.
    struct StreamFlags: u8 {
        // Reading is still allowed.
        const R = 1 << 0;
        // Writing is still allowed.
        const W = 1 << 1;

        // Both permissions (R | W). Note that `contains(V)` is true only while
        // neither half has been closed.
        const V = Self::R.bits() | Self::W.bits();
    }
}
75
76impl Stream {
77    /// Close the stream
78    ///
79    /// Sends a FIN message to the remote peer and disables write operations.
80    /// The stream will be automatically cleaned up when both read and write
81    /// operations are disabled.
82    pub fn close(&self) {
83        self.close_once.get_or_init(|| {
84            self.deny_rw(StreamFlags::W);
85            let _ = msg::send_fin(self.msg_tx.clone(), self.stream_id);
86        });
87    }
88
89    // Create a new stream and listen remote fin signal.
90    pub(crate) fn new(
91        stream_id: StreamId,
92        shutdown_rx: broadcast::Receiver<()>,
93        msg_tx: mpsc::UnboundedSender<Message>,
94        frame_rx: mpsc::Receiver<Frame>,
95        close_tx: mpsc::UnboundedSender<StreamId>,
96        remote_fin_rx: oneshot::Receiver<()>,
97    ) -> Self {
98        Self {
99            stream_id,
100            status: RwLock::new(StreamFlags::V),
101            read_buf: Bytes::new(),
102            cur_write_fut: None,
103            _shutdown_rx: shutdown_rx,
104            msg_tx,
105            frame_rx,
106            close_tx,
107            remote_fin_rx,
108            close_once: OnceLock::new(),
109        }
110    }
111
112    // Deny read/write permissions for the stream
113    //
114    // Removes the specified flags from the stream's status. If all permissions
115    // are removed (both read and write), the stream will be automatically
116    // closed and cleaned up.
117    fn deny_rw(&self, flags: StreamFlags) {
118        let mut status_guard = self.status.write();
119        *status_guard -= flags & StreamFlags::V;
120
121        if !status_guard.contains(StreamFlags::V) {
122            let _ = self.close_tx.send(self.stream_id);
123        }
124    }
125}
126
127impl Drop for Stream {
128    fn drop(&mut self) {
129        if self.status.read().contains(StreamFlags::W) {
130            let _ = msg::send_fin(self.msg_tx.clone(), self.stream_id);
131        }
132        self.deny_rw(StreamFlags::V);
133    }
134}
135
136impl AsyncRead for Stream {
137    fn poll_read(
138        self: Pin<&mut Self>,
139        cx: &mut Context<'_>,
140        buf: &mut ReadBuf<'_>,
141    ) -> Poll<std::io::Result<()>> {
142        let this = self.get_mut();
143        loop {
144            if !this.status.read().contains(StreamFlags::R) {
145                return Poll::Ready(Err(std::io::Error::new(
146                    std::io::ErrorKind::BrokenPipe,
147                    "stream has been closed",
148                )));
149            }
150
151            if !this.read_buf.is_empty() {
152                let to_copy = cmp::min(this.read_buf.len(), buf.remaining());
153                buf.put_slice(&this.read_buf[..to_copy]);
154                this.read_buf.advance(to_copy);
155                return Poll::Ready(Ok(()));
156            }
157
158            match Pin::new(&mut this.frame_rx).poll_recv(cx) {
159                Poll::Ready(Some(frame)) => {
160                    this.read_buf = Bytes::from(frame.data);
161                    continue;
162                }
163                Poll::Pending => {
164                    match Pin::new(&mut this.remote_fin_rx).poll(cx) {
165                        Poll::Ready(_) => {
166                            this.deny_rw(StreamFlags::R);
167                            return Poll::Ready(Ok(()));
168                        }
169                        Poll::Pending => {}
170                    }
171                    return Poll::Pending;
172                }
173                Poll::Ready(None) => {
174                    return Poll::Ready(Ok(()));
175                }
176            }
177        }
178    }
179}
180
181impl AsyncWrite for Stream {
182    fn poll_write(
183        mut self: Pin<&mut Self>,
184        cx: &mut Context<'_>,
185        buf: &[u8],
186    ) -> Poll<Result<usize, std::io::Error>> {
187        if !self.status.read().contains(StreamFlags::W) {
188            return Poll::Ready(Err(std::io::Error::new(
189                std::io::ErrorKind::BrokenPipe,
190                "stream is closed for writing",
191            )));
192        }
193
194        if self.cur_write_fut.is_none() {
195            let msg_tx = self.msg_tx.clone();
196            let stream_id = self.stream_id;
197            let data = buf.to_vec();
198
199            self.cur_write_fut =
200                Some(
201                    Box::pin(async move { msg::send_psh(msg_tx, stream_id, &data).await })
202                        as WriteFrameFuture,
203                );
204            return Poll::Ready(Ok(buf.len()));
205        }
206
207        match self.cur_write_fut.as_mut().unwrap().as_mut().poll(cx) {
208            Poll::Pending => Poll::Pending,
209            Poll::Ready(Ok(_)) => {
210                let msg_tx = self.msg_tx.clone();
211                let stream_id = self.stream_id;
212                let data = buf.to_vec();
213
214                self.cur_write_fut =
215                    Some(
216                        Box::pin(async move { msg::send_psh(msg_tx, stream_id, &data).await })
217                            as WriteFrameFuture,
218                    );
219                Poll::Ready(Ok(buf.len()))
220            }
221            Poll::Ready(Err(e)) => Poll::Ready(Err(std::io::Error::other(e.to_string()))),
222        }
223    }
224
225    fn poll_flush(
226        mut self: Pin<&mut Self>,
227        cx: &mut Context<'_>,
228    ) -> Poll<Result<(), std::io::Error>> {
229        if let Some(fut) = self.cur_write_fut.as_mut() {
230            match fut.as_mut().poll(cx) {
231                Poll::Pending => Poll::Pending,
232                Poll::Ready(Ok(_)) => {
233                    self.cur_write_fut = None;
234                    Poll::Ready(Ok(()))
235                }
236                Poll::Ready(Err(e)) => {
237                    self.cur_write_fut = None;
238                    Poll::Ready(Err(std::io::Error::other(e.to_string())))
239                }
240            }
241        } else {
242            Poll::Ready(Ok(()))
243        }
244    }
245
246    fn poll_shutdown(
247        mut self: Pin<&mut Self>,
248        cx: &mut Context<'_>,
249    ) -> Poll<Result<(), std::io::Error>> {
250        let res = match &mut self.cur_write_fut {
251            Some(fut) => match fut.as_mut().poll(cx) {
252                Poll::Pending => Poll::Pending,
253                Poll::Ready(Ok(_)) => {
254                    self.cur_write_fut = None;
255                    Poll::Ready(Ok(()))
256                }
257                Poll::Ready(Err(e)) => {
258                    self.cur_write_fut = None;
259                    Poll::Ready(Err(std::io::Error::other(e.to_string())))
260                }
261            },
262            None => Poll::Ready(Ok(())),
263        };
264        if !res.is_pending() {
265            self.close();
266        }
267        res
268    }
269}