libp2prs_mplex/connection/
stream.rs1use crate::{
22 connection::{Id, StreamCommand},
23 frame::{Frame, StreamID},
24};
25use bytes::{Buf, BufMut};
26use futures::channel::oneshot;
27use futures::lock::Mutex;
28use futures::task::{Context, Poll};
29use futures::{channel::mpsc, AsyncRead, AsyncWrite, FutureExt, Sink, SinkExt};
30use std::pin::Pin;
31use std::sync::Arc;
32use std::{fmt, io};
33
/// Direction-wise open/closed state of a stream.
///
/// A stream starts fully [`State::Open`]; each of the other two variants
/// disables exactly one capability (see `can_read` / `can_write`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum State {
    /// Neither direction has been closed: reading and writing are both allowed.
    Open,
    /// The sending direction is closed: writing is no longer allowed.
    SendClosed,
    /// The receiving direction is closed: reading is no longer allowed.
    RecvClosed,
}
44
45impl State {
46 pub fn can_read(self) -> bool {
48 self != State::RecvClosed
49 }
50
51 pub fn can_write(self) -> bool {
53 self != State::SendClosed
54 }
55}
56
57pub struct Stream {
58 id: StreamID,
59 conn_id: Id,
60 read_buffer: bytes::BytesMut,
61 sender: mpsc::Sender<StreamCommand>,
62 receiver: Arc<Mutex<mpsc::Receiver<Vec<u8>>>>,
63}
64
65impl fmt::Debug for Stream {
66 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
67 write!(f, "(Stream {}/{})", self.conn_id, self.id.id())
68 }
69}
70
71impl fmt::Display for Stream {
72 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
73 write!(f, "(Stream {}/{})", self.conn_id, self.id.id())
74 }
75}
76
77impl Clone for Stream {
78 fn clone(&self) -> Self {
80 Stream {
81 id: self.id,
82 conn_id: self.conn_id,
83 read_buffer: Default::default(),
84 sender: self.sender.clone(),
85 receiver: self.receiver.clone(),
86 }
87 }
88}
89
90impl Stream {
91 pub(crate) fn new(id: StreamID, conn_id: Id, sender: mpsc::Sender<StreamCommand>, receiver: mpsc::Receiver<Vec<u8>>) -> Self {
92 Stream {
93 id,
94 conn_id,
95 read_buffer: Default::default(),
96 sender,
97 receiver: Arc::new(Mutex::new(receiver)),
98 }
99 }
100
101 pub fn val(&self) -> u32 {
102 self.id.val()
103 }
104
105 pub fn id(&self) -> u32 {
107 self.id.id()
108 }
109
110 pub async fn reset(&mut self) -> io::Result<()> {
113 if self.sender.is_closed() {
114 return Ok(());
115 }
116
117 let (tx, rx) = oneshot::channel();
118 let frame = Frame::reset_frame(self.id);
119 let cmd = StreamCommand::ResetStream(frame, tx);
120 self.sender.send(cmd).await.map_err(|_| self.write_zero_err())?;
121 rx.await.map_err(|_| self.closed_err())?;
122
123 self.sender.close().await.map_err(|_| self.write_zero_err())?;
124
125 Ok(())
126 }
127
128 fn write_zero_err(&self) -> io::Error {
130 let msg = format!("{}/{}: connection is closed", self.conn_id, self.id);
131 io::Error::new(io::ErrorKind::WriteZero, msg)
132 }
133
134 fn closed_err(&self) -> io::Error {
136 let msg = format!("{}/{}: stream is closed / reset", self.conn_id, self.id);
137 io::Error::new(io::ErrorKind::WriteZero, msg)
138 }
139}
140
141impl AsyncRead for Stream {
142 fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
143 if self.read_buffer.has_remaining() {
144 let len = std::cmp::min(self.read_buffer.remaining(), buf.len());
145 buf[..len].copy_from_slice(&self.read_buffer[..len]);
146 self.read_buffer.advance(len);
147 return Poll::Ready(Ok(len));
148 }
149
150 let this = self.get_mut();
151
152 let mut receiver = futures::ready!(this.receiver.lock().poll_unpin(cx));
153
154 let x = futures::Stream::poll_next(Pin::new(&mut *receiver), cx);
155 if let Some(data) = futures::ready!(x) {
156 let dlen = data.len();
157 let len = std::cmp::min(data.len(), buf.len());
158 buf[..len].copy_from_slice(&data[..len]);
159
160 if len < dlen {
161 this.read_buffer.reserve(dlen - len);
162 this.read_buffer.put(&data[len..dlen]);
163 }
164 return Poll::Ready(Ok(len));
165 }
166 Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into()))
167 }
168}
169
170impl AsyncWrite for Stream {
171 fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
172 if self.sender.is_closed() {
173 return Poll::Ready(Err(self.closed_err()));
174 }
175
176 futures::ready!(self.sender.poll_ready(cx).map_err(|_| self.write_zero_err())?);
177
178 let frame = Frame::message_frame(self.id, buf);
179 let n = buf.len();
180 log::trace!("{}/{}: write {} bytes", self.conn_id, self.id, n);
181
182 let cmd = StreamCommand::SendFrame(frame);
183 self.sender.start_send(cmd).map_err(|_| self.write_zero_err())?;
184 Poll::Ready(Ok(n))
185 }
186
187 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
188 let this = self.get_mut();
189 Pin::new(&mut this.sender).poll_flush(cx).map_err(|_| this.write_zero_err())
190 }
191
192 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
193 if self.sender.is_closed() {
194 return Poll::Ready(Ok(()));
195 }
196
197 futures::ready!(self.sender.poll_ready(cx).map_err(|_| self.write_zero_err())?);
198
199 let frame = Frame::close_frame(self.id);
200 let cmd = StreamCommand::CloseStream(frame);
201 self.sender.start_send(cmd).map_err(|_| self.write_zero_err())?;
202
203 let this = self.get_mut();
204 Pin::new(&mut this.sender).poll_close(cx).map_err(|_| this.closed_err())
205 }
206}