// libp2prs_mplex/connection/stream.rs

// Copyright 2020 Netwarps Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

21use crate::{
22    connection::{Id, StreamCommand},
23    frame::{Frame, StreamID},
24};
25use bytes::{Buf, BufMut};
26use futures::channel::oneshot;
27use futures::lock::Mutex;
28use futures::task::{Context, Poll};
29use futures::{channel::mpsc, AsyncRead, AsyncWrite, FutureExt, Sink, SinkExt};
30use std::pin::Pin;
31use std::sync::Arc;
32use std::{fmt, io};
33
/// Which directions of a stream are still open.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum State {
    /// Both directions are open.
    Open,
    /// The local sending side is closed; incoming messages are still accepted.
    SendClosed,
    /// The local receiving side is closed; outgoing messages are still allowed.
    RecvClosed,
}

impl State {
    /// Returns `true` unless the receiving side has been closed.
    pub fn can_read(self) -> bool {
        match self {
            State::RecvClosed => false,
            _ => true,
        }
    }

    /// Returns `true` unless the sending side has been closed.
    pub fn can_write(self) -> bool {
        match self {
            State::SendClosed => false,
            _ => true,
        }
    }
}

57pub struct Stream {
58    id: StreamID,
59    conn_id: Id,
60    read_buffer: bytes::BytesMut,
61    sender: mpsc::Sender<StreamCommand>,
62    receiver: Arc<Mutex<mpsc::Receiver<Vec<u8>>>>,
63}
64
65impl fmt::Debug for Stream {
66    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
67        write!(f, "(Stream {}/{})", self.conn_id, self.id.id())
68    }
69}
70
71impl fmt::Display for Stream {
72    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
73        write!(f, "(Stream {}/{})", self.conn_id, self.id.id())
74    }
75}
76
77impl Clone for Stream {
78    /// impl [`Clone`] trait
79    fn clone(&self) -> Self {
80        Stream {
81            id: self.id,
82            conn_id: self.conn_id,
83            read_buffer: Default::default(),
84            sender: self.sender.clone(),
85            receiver: self.receiver.clone(),
86        }
87    }
88}
89
90impl Stream {
91    pub(crate) fn new(id: StreamID, conn_id: Id, sender: mpsc::Sender<StreamCommand>, receiver: mpsc::Receiver<Vec<u8>>) -> Self {
92        Stream {
93            id,
94            conn_id,
95            read_buffer: Default::default(),
96            sender,
97            receiver: Arc::new(Mutex::new(receiver)),
98        }
99    }
100
101    pub fn val(&self) -> u32 {
102        self.id.val()
103    }
104
105    /// Get this stream's identifier.
106    pub fn id(&self) -> u32 {
107        self.id.id()
108    }
109
110    /// reset stream, sender will be closed and state will turn to Closed
111    /// If stream has reset, return ()
112    pub async fn reset(&mut self) -> io::Result<()> {
113        if self.sender.is_closed() {
114            return Ok(());
115        }
116
117        let (tx, rx) = oneshot::channel();
118        let frame = Frame::reset_frame(self.id);
119        let cmd = StreamCommand::ResetStream(frame, tx);
120        self.sender.send(cmd).await.map_err(|_| self.write_zero_err())?;
121        rx.await.map_err(|_| self.closed_err())?;
122
123        self.sender.close().await.map_err(|_| self.write_zero_err())?;
124
125        Ok(())
126    }
127
128    /// connection is closed
129    fn write_zero_err(&self) -> io::Error {
130        let msg = format!("{}/{}: connection is closed", self.conn_id, self.id);
131        io::Error::new(io::ErrorKind::WriteZero, msg)
132    }
133
134    /// stream is closed or reset
135    fn closed_err(&self) -> io::Error {
136        let msg = format!("{}/{}: stream is closed / reset", self.conn_id, self.id);
137        io::Error::new(io::ErrorKind::WriteZero, msg)
138    }
139}
140
141impl AsyncRead for Stream {
142    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
143        if self.read_buffer.has_remaining() {
144            let len = std::cmp::min(self.read_buffer.remaining(), buf.len());
145            buf[..len].copy_from_slice(&self.read_buffer[..len]);
146            self.read_buffer.advance(len);
147            return Poll::Ready(Ok(len));
148        }
149
150        let this = self.get_mut();
151
152        let mut receiver = futures::ready!(this.receiver.lock().poll_unpin(cx));
153
154        let x = futures::Stream::poll_next(Pin::new(&mut *receiver), cx);
155        if let Some(data) = futures::ready!(x) {
156            let dlen = data.len();
157            let len = std::cmp::min(data.len(), buf.len());
158            buf[..len].copy_from_slice(&data[..len]);
159
160            if len < dlen {
161                this.read_buffer.reserve(dlen - len);
162                this.read_buffer.put(&data[len..dlen]);
163            }
164            return Poll::Ready(Ok(len));
165        }
166        Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into()))
167    }
168}
169
170impl AsyncWrite for Stream {
171    fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
172        if self.sender.is_closed() {
173            return Poll::Ready(Err(self.closed_err()));
174        }
175
176        futures::ready!(self.sender.poll_ready(cx).map_err(|_| self.write_zero_err())?);
177
178        let frame = Frame::message_frame(self.id, buf);
179        let n = buf.len();
180        log::trace!("{}/{}: write {} bytes", self.conn_id, self.id, n);
181
182        let cmd = StreamCommand::SendFrame(frame);
183        self.sender.start_send(cmd).map_err(|_| self.write_zero_err())?;
184        Poll::Ready(Ok(n))
185    }
186
187    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
188        let this = self.get_mut();
189        Pin::new(&mut this.sender).poll_flush(cx).map_err(|_| this.write_zero_err())
190    }
191
192    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
193        if self.sender.is_closed() {
194            return Poll::Ready(Ok(()));
195        }
196
197        futures::ready!(self.sender.poll_ready(cx).map_err(|_| self.write_zero_err())?);
198
199        let frame = Frame::close_frame(self.id);
200        let cmd = StreamCommand::CloseStream(frame);
201        self.sender.start_send(cmd).map_err(|_| self.write_zero_err())?;
202
203        let this = self.get_mut();
204        Pin::new(&mut this.sender).poll_close(cx).map_err(|_| this.closed_err())
205    }
206}