foctet_mux/
stream.rs

1use anyhow::Result;
2use bytes::Bytes;
3use foctet_core::connection::SessionId;
4use foctet_core::frame::{Frame, FrameBuilder, FrameFlags, FrameType};
5use foctet_core::stream::StreamId;
6use std::io;
7use std::pin::Pin;
8use std::sync::{Arc, Mutex};
9use std::task::{Context, Poll};
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11use tokio::sync::mpsc::{Receiver, Sender};
12
/// Lifecycle state of a logical stream.
///
/// The state is changed locally (for example by `close()`) and from the
/// flags carried by frames received for the stream.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum StreamState {
    /// Freshly created; no open handshake has happened yet.
    Init,
    /// An open request has been sent to the remote peer.
    OpenSent,
    /// An open request has been received from the remote peer.
    OpenReceived,
    /// The stream is open and usable.
    Established,
    /// The local side has started closing the stream.
    LocalClosing,
    /// The remote side has started closing the stream.
    RemoteClosing,
    /// Both sides have closed the stream.
    Closed,
    /// The remote peer rejected (reset) the stream.
    Reset,
}
33
34// Stream event
35#[derive(Debug, Eq, PartialEq)]
36pub enum StreamEvent {
37    Frame(Frame),
38    Closed(StreamId),
39    Error,
40}
41
42/// Writer half of LogicalStream
43#[derive(Debug)]
44pub struct LogicalStreamWriter {
45    session_id: SessionId,
46    stream_id: StreamId,
47    state: Arc<Mutex<StreamState>>,
48    frame_sender: Sender<StreamEvent>,
49}
50
51impl LogicalStreamWriter {
52    pub fn set_state(&mut self, state: StreamState) {
53        match self.state.lock() {
54            Ok(mut state_guard) => {
55                *state_guard = state;
56            }
57            Err(_) => (),
58        }
59    }
60    /// Get the session id
61    pub fn session_id(&self) -> SessionId {
62        self.session_id
63    }
64    /// Get the stream id
65    pub fn stream_id(&self) -> StreamId {
66        self.stream_id
67    }
68    /// Get the stream state
69    pub fn state(&self) -> StreamState {
70        match self.state.lock() {
71            Ok(state_guard) => *state_guard,
72            Err(_) => StreamState::Closed,
73        }
74    }
75    pub async fn send_event(&self, event: StreamEvent) -> Result<()> {
76        match self.frame_sender.send(event).await {
77            Ok(_) => Ok(()),
78            Err(_) => anyhow::bail!(io::Error::new(
79                io::ErrorKind::BrokenPipe,
80                "Failed to send event"
81            )),
82        }
83    }
84    /// Send a frame
85    pub async fn send_frame(&self, frame: Frame) -> Result<()> {
86        self.frame_sender
87            .send(StreamEvent::Frame(frame))
88            .await
89            .map_err(|_| {
90                anyhow::anyhow!(io::Error::new(
91                    io::ErrorKind::BrokenPipe,
92                    "Failed to send frame"
93                ))
94            })
95    }
96
97    /// Send raw bytes as a data frame
98    pub async fn send_bytes(&self, bytes: Bytes) -> Result<()> {
99        let frame = FrameBuilder::new()
100            .with_stream_id(self.stream_id)
101            .with_frame_type(FrameType::Data)
102            .with_payload(bytes)
103            .build();
104        self.send_frame(frame).await
105    }
106
107    async fn send_close_request(&mut self) -> Result<()> {
108        let frame_flags = FrameFlags::close_request();
109        let close_frame: Frame = Frame::builder()
110            .with_stream_id(self.stream_id)
111            .with_flags(frame_flags)
112            .build();
113        match self
114            .frame_sender
115            .send(StreamEvent::Frame(close_frame))
116            .await
117        {
118            Ok(_) => Ok(()),
119            Err(_) => anyhow::bail!(io::Error::new(
120                io::ErrorKind::BrokenPipe,
121                "Failed to send close frame"
122            )),
123        }
124    }
125
126    /// Close the stream
127    pub async fn close(&mut self) -> Result<()> {
128        let state = match self.state.lock() {
129            Ok(state_guard) => *state_guard,
130            Err(_) => StreamState::Closed,
131        };
132        match state {
133            StreamState::OpenSent
134            | StreamState::OpenReceived
135            | StreamState::Established
136            | StreamState::Init => {
137                self.set_state(StreamState::LocalClosing);
138                self.send_close_request().await?;
139            }
140            StreamState::RemoteClosing => {
141                self.set_state(StreamState::Closed);
142                self.send_close_request().await?;
143                let event = StreamEvent::Closed(self.stream_id);
144                self.send_event(event).await?;
145            }
146            StreamState::Reset | StreamState::Closed => {
147                self.set_state(StreamState::Closed);
148                let event = StreamEvent::Closed(self.stream_id);
149                self.send_event(event).await?;
150            }
151            StreamState::LocalClosing => {
152                self.set_state(StreamState::Closed);
153                let event = StreamEvent::Closed(self.stream_id);
154                self.send_event(event).await?;
155            }
156        }
157        Ok(())
158    }
159}
160
161impl AsyncWrite for LogicalStreamWriter {
162    fn poll_write(
163        self: Pin<&mut Self>,
164        _cx: &mut Context<'_>,
165        buf: &[u8],
166    ) -> Poll<std::io::Result<usize>> {
167        let payload = Bytes::copy_from_slice(buf);
168        let frame = FrameBuilder::new()
169            .with_stream_id(self.stream_id)
170            .with_frame_type(FrameType::Data)
171            .with_payload(payload)
172            .build();
173
174        match self.frame_sender.try_send(StreamEvent::Frame(frame)) {
175            Ok(_) => {
176                Poll::Ready(Ok(buf.len()))
177            }
178            Err(_) => {
179                tracing::error!("Failed to send frame");
180                return Poll::Ready(Err(io::Error::new(
181                    io::ErrorKind::BrokenPipe,
182                    "Failed to send frame",
183                )));
184            }
185        }
186    }
187
188    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
189        Poll::Ready(Ok(()))
190    }
191
192    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
193        // TODO: closing the logical stream here.
194        Poll::Ready(Ok(()))
195    }
196}
197
198/// Reader half of LogicalStream
199#[derive(Debug)]
200pub struct LogicalStreamReader {
201    session_id: SessionId,
202    stream_id: StreamId,
203    state: Arc<Mutex<StreamState>>,
204    frame_receiver: Receiver<Frame>,
205}
206
207impl LogicalStreamReader {
208    pub fn set_state(&mut self, state: StreamState) {
209        match self.state.lock() {
210            Ok(mut state_guard) => {
211                *state_guard = state;
212            }
213            Err(_) => (),
214        }
215    }
216    /// Get the session id
217    pub fn session_id(&self) -> SessionId {
218        self.session_id
219    }
220    /// Get the stream id
221    pub fn stream_id(&self) -> StreamId {
222        self.stream_id
223    }
224    /// Get the stream state
225    pub fn state(&self) -> StreamState {
226        match self.state.lock() {
227            Ok(state_guard) => *state_guard,
228            Err(_) => StreamState::Closed,
229        }
230    }
231    /// Receive a frame
232    pub async fn recv_frame(&mut self) -> Result<Frame> {
233        match self.frame_receiver.recv().await {
234            Some(frame) => {
235                self.set_state_from_flags(frame.header.flags);
236                Ok(frame)
237            }
238            None => {
239                if self.frame_receiver.is_closed() {
240                    anyhow::bail!(io::Error::new(io::ErrorKind::BrokenPipe, "Channel closed"))
241                } else {
242                    anyhow::bail!(io::Error::new(
243                        io::ErrorKind::BrokenPipe,
244                        "Failed to receive frame"
245                    ))
246                }
247            }
248        }
249    }
250
251    /// Receive raw bytes from a data frame
252    pub async fn recv_bytes(&mut self) -> Result<Bytes> {
253        let frame = self.recv_frame().await?;
254        Ok(frame.payload)
255    }
256
257    fn set_state_from_flags(&mut self, flags: FrameFlags) {
258        if flags.is_open_request() {
259            self.set_state(StreamState::OpenReceived);
260        } else if flags.is_open_response() {
261            self.set_state(StreamState::Established);
262        } else if flags.is_open_reset() {
263            self.set_state(StreamState::Reset);
264        }
265    }
266}
267
268impl AsyncRead for LogicalStreamReader {
269    fn poll_read(
270        mut self: Pin<&mut Self>,
271        _cx: &mut Context<'_>,
272        buf: &mut ReadBuf<'_>,
273    ) -> Poll<io::Result<()>> {
274        match self.frame_receiver.try_recv() {
275            Ok(frame) => {
276                self.set_state_from_flags(frame.header.flags);
277                buf.put_slice(&frame.payload);
278                Poll::Ready(Ok(()))
279            }
280            Err(tokio::sync::mpsc::error::TryRecvError::Empty) => Poll::Pending,
281            Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
282                // Session closed, mark as EOF
283                self.set_state(StreamState::Closed);
284                Poll::Ready(Ok(()))
285            }
286        }
287    }
288}
289
290#[derive(Debug)]
291pub struct LogicalStream {
292    /// Session ID
293    session_id: SessionId,
294    /// Stream ID
295    stream_id: StreamId,
296    /// Stream state
297    state: StreamState,
298    // Send frame to parent session
299    frame_sender: Sender<StreamEvent>,
300    // Receive frame of current stream from parent session
301    // (if the sender closed means session closed the stream should close too)
302    frame_receiver: Receiver<Frame>,
303}
304
305impl LogicalStream {
306    pub fn new(
307        session_id: SessionId,
308        stream_id: StreamId,
309        state: StreamState,
310        frame_sender: Sender<StreamEvent>,
311        frame_receiver: Receiver<Frame>,
312    ) -> Self {
313        Self {
314            session_id,
315            stream_id,
316            state,
317            frame_sender,
318            frame_receiver,
319        }
320    }
321
322    pub async fn send_frame(&self, frame: Frame) -> Result<()> {
323        match self.frame_sender.send(StreamEvent::Frame(frame)).await {
324            Ok(_) => Ok(()),
325            Err(_) => anyhow::bail!(io::Error::new(
326                io::ErrorKind::BrokenPipe,
327                "Failed to send frame"
328            )),
329        }
330    }
331
332    pub async fn recv_frame(&mut self) -> Result<Frame> {
333        match self.frame_receiver.recv().await {
334            Some(frame) => {
335                self.set_state_from_flags(frame.header.flags);
336                Ok(frame)
337            }
338            None => anyhow::bail!(io::Error::new(
339                io::ErrorKind::BrokenPipe,
340                "Failed to receive frame"
341            )),
342        }
343    }
344
345    async fn send_event(&self, event: StreamEvent) -> Result<()> {
346        match self.frame_sender.send(event).await {
347            Ok(_) => Ok(()),
348            Err(_) => anyhow::bail!(io::Error::new(
349                io::ErrorKind::BrokenPipe,
350                "Failed to send event"
351            )),
352        }
353    }
354
355    async fn send_close_request(&mut self) -> Result<()> {
356        let frame_flags = FrameFlags::close_request();
357        let close_frame: Frame = Frame::builder()
358            .with_stream_id(self.stream_id)
359            .with_flags(frame_flags)
360            .build();
361        match self
362            .frame_sender
363            .send(StreamEvent::Frame(close_frame))
364            .await
365        {
366            Ok(_) => Ok(()),
367            Err(_) => anyhow::bail!(io::Error::new(
368                io::ErrorKind::BrokenPipe,
369                "Failed to send close frame"
370            )),
371        }
372    }
373
374    /// Close the stream
375    pub async fn close(&mut self) -> Result<()> {
376        match self.state {
377            StreamState::OpenSent
378            | StreamState::OpenReceived
379            | StreamState::Established
380            | StreamState::Init => {
381                self.state = StreamState::LocalClosing;
382                self.send_close_request().await?;
383            }
384            StreamState::RemoteClosing => {
385                self.state = StreamState::Closed;
386                self.send_close_request().await?;
387                let event = StreamEvent::Closed(self.stream_id);
388                self.send_event(event).await?;
389            }
390            StreamState::Reset | StreamState::Closed => {
391                self.state = StreamState::Closed;
392                let event = StreamEvent::Closed(self.stream_id);
393                self.send_event(event).await?;
394            }
395            StreamState::LocalClosing => {
396                self.state = StreamState::Closed;
397                let event = StreamEvent::Closed(self.stream_id);
398                self.send_event(event).await?;
399            }
400        }
401        Ok(())
402    }
403
404    fn set_state_from_flags(&mut self, flags: FrameFlags) {
405        if flags.is_open_request() {
406            self.state = StreamState::OpenReceived;
407        } else if flags.is_open_response() {
408            self.state = StreamState::Established;
409        } else if flags.is_open_reset() {
410            self.state = StreamState::Reset;
411        } else if flags.is_close_request() {
412            self.state = StreamState::RemoteClosing;
413        } else if flags.is_close_response() {
414            self.state = StreamState::Closed;
415        }
416    }
417
418    /// Split this stream into a reader and writer
419    pub fn split(self) -> (LogicalStreamWriter, LogicalStreamReader) {
420        let state = Arc::new(Mutex::new(self.state));
421        let writer = LogicalStreamWriter {
422            session_id: self.session_id,
423            stream_id: self.stream_id,
424            state: Arc::clone(&state),
425            frame_sender: self.frame_sender,
426        };
427        let reader = LogicalStreamReader {
428            session_id: self.session_id,
429            stream_id: self.stream_id,
430            state,
431            frame_receiver: self.frame_receiver,
432        };
433        (writer, reader)
434    }
435
436    pub fn set_state(&mut self, state: StreamState) {
437        self.state = state;
438    }
439
440    pub fn stream_id(&self) -> StreamId {
441        self.stream_id
442    }
443
444    pub fn state(&self) -> StreamState {
445        self.state
446    }
447
448    pub fn session_id(&self) -> SessionId {
449        self.session_id
450    }
451}