net_mux/
stream.rs

1//! Stream abstraction for multiplexing
2//!
3//! This module provides the [`Stream`] struct for representing individual data streams
4//! in network multiplexing. Each stream has a unique stream ID and implements async
5//! read/write interfaces, supporting concurrent processing of multiple streams.
6//!
7//! # Features
8//!
9//! - **Async I/O**: Implements [`AsyncRead`] and [`AsyncWrite`] traits
10//! - **State Management**: Uses bit flags to track stream read/write state
11//! - **Auto Cleanup**: Automatically notifies stream manager when stream closes
//! - **Thread Safety**: Streams can be safely moved between threads and tasks
13
14use std::{
15    cmp,
16    future::Future,
17    pin::Pin,
18    task::{Context, Poll},
19};
20
21use bitflags::bitflags;
22use parking_lot::RwLock;
23use tokio::{
24    io::{AsyncRead, AsyncWrite, ReadBuf},
25    sync::{broadcast, mpsc, oneshot},
26};
27use tokio_util::bytes::{Buf, Bytes};
28
29use crate::{
30    alloc::StreamId,
31    error::Error,
32    frame::Frame,
33    msg::{self, Message},
34};
35
36// Async Future type for writing frames
37type WriteFrameFuture = Pin<Box<dyn Future<Output = Result<usize, Error>> + Send + Sync>>;
38
39/// Multiplexed stream
40///
41/// Represents a data stream in network multiplexing, supporting async read/write operations.
42/// Each stream has a unique stream ID and communicates with the stream manager through
43/// message channels.
44pub struct Stream {
45    stream_id: StreamId,
46    status: RwLock<StreamFlags>,
47    read_buf: Bytes,
48    cur_write_fut: Option<WriteFrameFuture>,
49
50    _shutdown_rx: broadcast::Receiver<()>,
51
52    msg_tx: mpsc::Sender<Message>,
53    frame_rx: mpsc::Receiver<Frame>,
54    remote_fin_rx: oneshot::Receiver<()>,
55    close_tx: mpsc::UnboundedSender<StreamId>,
56}
57
58bitflags! {
59    // Stream state flags
60    //
61    // Used to track stream read/write state, supporting half-close operations.
62    struct StreamFlags: u8 {
63        // Read permission flag
64        const R = 1 << 0;
65        // Write permission flag
66        const W = 1 << 1;
67
68        // Read/Write permission flags (R | W)
69        const V = Self::R.bits() | Self::W.bits();
70    }
71}
72
73impl Stream {
74    /// Close the stream
75    ///
76    /// Sends a FIN message to the remote peer and disables write operations.
77    /// The stream will be automatically cleaned up when both read and write
78    /// operations are disabled.
79    pub async fn close(&self) {
80        let _ = msg::send_fin(self.msg_tx.clone(), self.stream_id).await;
81        self.deny_rw(StreamFlags::W);
82    }
83
84    // Create a new stream and listen remote fin signal.
85    pub(crate) fn new(
86        stream_id: StreamId,
87        shutdown_rx: broadcast::Receiver<()>,
88        msg_tx: mpsc::Sender<Message>,
89        frame_rx: mpsc::Receiver<Frame>,
90        close_tx: mpsc::UnboundedSender<StreamId>,
91        remote_fin_rx: oneshot::Receiver<()>,
92    ) -> Self {
93        Self {
94            stream_id,
95            status: RwLock::new(StreamFlags::V),
96            read_buf: Bytes::new(),
97            cur_write_fut: None,
98            _shutdown_rx: shutdown_rx,
99            msg_tx,
100            frame_rx,
101            close_tx,
102            remote_fin_rx,
103        }
104    }
105
106    // Deny read/write permissions for the stream
107    //
108    // Removes the specified flags from the stream's status. If all permissions
109    // are removed (both read and write), the stream will be automatically
110    // closed and cleaned up.
111    fn deny_rw(&self, flags: StreamFlags) {
112        let mut status_guard = self.status.write();
113        *status_guard -= flags & StreamFlags::V;
114
115        if !status_guard.contains(StreamFlags::V) {
116            let _ = self.close_tx.send(self.stream_id);
117        }
118    }
119}
120
121impl Drop for Stream {
122    fn drop(&mut self) {
123        self.deny_rw(StreamFlags::V);
124        let msg_tx = self.msg_tx.clone();
125        let stream_id = self.stream_id;
126        tokio::spawn(async move {
127            let _ = msg::send_fin(msg_tx, stream_id).await;
128        });
129    }
130}
131
132impl AsyncRead for Stream {
133    fn poll_read(
134        self: Pin<&mut Self>,
135        cx: &mut Context<'_>,
136        buf: &mut ReadBuf<'_>,
137    ) -> Poll<std::io::Result<()>> {
138        let this = self.get_mut();
139        loop {
140            if !this.status.read().contains(StreamFlags::R) {
141                return Poll::Ready(Err(std::io::Error::new(
142                    std::io::ErrorKind::BrokenPipe,
143                    "stream has been closed",
144                )));
145            }
146
147            if !this.read_buf.is_empty() {
148                let to_copy = cmp::min(this.read_buf.len(), buf.remaining());
149                buf.put_slice(&this.read_buf[..to_copy]);
150                this.read_buf.advance(to_copy);
151                return Poll::Ready(Ok(()));
152            }
153
154            match Pin::new(&mut this.remote_fin_rx).poll(cx) {
155                Poll::Ready(_) => {
156                    this.deny_rw(StreamFlags::R);
157                    return Poll::Ready(Ok(()));
158                }
159                Poll::Pending => {}
160            }
161
162            match Pin::new(&mut this.frame_rx).poll_recv(cx) {
163                Poll::Ready(Some(frame)) => {
164                    this.read_buf = Bytes::from(frame.data);
165                    continue;
166                }
167                Poll::Pending => {
168                    return Poll::Pending;
169                }
170                Poll::Ready(None) => {
171                    unreachable!()
172                }
173            }
174        }
175    }
176}
177
178impl AsyncWrite for Stream {
179    fn poll_write(
180        mut self: Pin<&mut Self>,
181        cx: &mut Context<'_>,
182        buf: &[u8],
183    ) -> Poll<Result<usize, std::io::Error>> {
184        if !self.status.read().contains(StreamFlags::W) {
185            return Poll::Ready(Err(std::io::Error::new(
186                std::io::ErrorKind::BrokenPipe,
187                "stream is closed for writing",
188            )));
189        }
190
191        if self.cur_write_fut.is_none() {
192            let msg_tx = self.msg_tx.clone();
193            let stream_id = self.stream_id;
194            let data = buf.to_vec();
195
196            self.cur_write_fut =
197                Some(
198                    Box::pin(async move { msg::send_psh(msg_tx, stream_id, &data).await })
199                        as WriteFrameFuture,
200                );
201            return Poll::Ready(Ok(buf.len()));
202        }
203
204        match self.cur_write_fut.as_mut().unwrap().as_mut().poll(cx) {
205            Poll::Pending => Poll::Pending,
206            Poll::Ready(Ok(_)) => {
207                let msg_tx = self.msg_tx.clone();
208                let stream_id = self.stream_id;
209                let data = buf.to_vec();
210
211                self.cur_write_fut =
212                    Some(
213                        Box::pin(async move { msg::send_psh(msg_tx, stream_id, &data).await })
214                            as WriteFrameFuture,
215                    );
216                Poll::Ready(Ok(buf.len()))
217            }
218            Poll::Ready(Err(e)) => Poll::Ready(Err(std::io::Error::other(e.to_string()))),
219        }
220    }
221
222    fn poll_flush(
223        mut self: Pin<&mut Self>,
224        cx: &mut Context<'_>,
225    ) -> Poll<Result<(), std::io::Error>> {
226        if let Some(fut) = self.cur_write_fut.as_mut() {
227            match fut.as_mut().poll(cx) {
228                Poll::Pending => Poll::Pending,
229                Poll::Ready(Ok(_)) => {
230                    self.cur_write_fut = None;
231                    Poll::Ready(Ok(()))
232                }
233                Poll::Ready(Err(e)) => {
234                    self.cur_write_fut = None;
235                    Poll::Ready(Err(std::io::Error::other(e.to_string())))
236                }
237            }
238        } else {
239            Poll::Ready(Ok(()))
240        }
241    }
242
243    fn poll_shutdown(
244        self: Pin<&mut Self>,
245        _cx: &mut Context<'_>,
246    ) -> Poll<Result<(), std::io::Error>> {
247        Poll::Ready(Ok(()))
248    }
249}