1use std::{
2 future::Future,
3 io,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use bytes::Bytes;
9use proto::{ConnectionError, FinishError, StreamId, Written};
10use thiserror::Error;
11use tokio::sync::oneshot;
12
13use crate::{
14 connection::{ConnectionRef, UnknownStream},
15 VarInt,
16};
17
/// The sending half of a stream on a connection.
///
/// Writes are forwarded to the protocol-level stream under the connection lock. Dropping the
/// handle (see the `Drop` impl) finishes the stream unless a finish is already in flight, the
/// connection has failed, or 0-RTT was rejected; if the peer has already stopped the stream, it
/// is reset with the peer's code instead.
#[derive(Debug)]
pub struct SendStream {
    // Shared handle to the connection that owns this stream.
    conn: ConnectionRef,
    // Identifier of this stream within `conn`.
    stream: StreamId,
    // Set for streams opened as 0-RTT; their operations consult `check_0rtt` first.
    is_0rtt: bool,
    // Receiver for the outcome of an in-progress `finish` (`None` inside = success,
    // `Some(err)` = failure); the outer `None` means no finish has been started or it completed.
    finishing: Option<oneshot::Receiver<Option<WriteError>>>,
}
31
32impl SendStream {
33 pub(crate) fn new(conn: ConnectionRef, stream: StreamId, is_0rtt: bool) -> Self {
34 Self {
35 conn,
36 stream,
37 is_0rtt,
38 finishing: None,
39 }
40 }
41
42 pub async fn write(&mut self, buf: &[u8]) -> Result<usize, WriteError> {
47 Write { stream: self, buf }.await
48 }
49
50 pub async fn write_all(&mut self, buf: &[u8]) -> Result<(), WriteError> {
52 WriteAll { stream: self, buf }.await
53 }
54
55 pub async fn write_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Written, WriteError> {
61 WriteChunks { stream: self, bufs }.await
62 }
63
64 pub async fn write_chunk(&mut self, buf: Bytes) -> Result<(), WriteError> {
66 WriteChunk {
67 stream: self,
68 buf: [buf],
69 }
70 .await
71 }
72
73 pub async fn write_all_chunks(&mut self, bufs: &mut [Bytes]) -> Result<(), WriteError> {
75 WriteAllChunks {
76 stream: self,
77 bufs,
78 offset: 0,
79 }
80 .await
81 }
82
83 fn execute_poll<F, R>(&mut self, cx: &mut Context, write_fn: F) -> Poll<Result<R, WriteError>>
84 where
85 F: FnOnce(&mut proto::SendStream) -> Result<R, proto::WriteError>,
86 {
87 use proto::WriteError::*;
88 let mut conn = self.conn.state.lock("SendStream::poll_write");
89 if self.is_0rtt {
90 conn.check_0rtt()
91 .map_err(|()| WriteError::ZeroRttRejected)?;
92 }
93 if let Some(ref x) = conn.error {
94 return Poll::Ready(Err(WriteError::ConnectionLost(x.clone())));
95 }
96
97 let result = match write_fn(&mut conn.inner.send_stream(self.stream)) {
98 Ok(result) => result,
99 Err(Blocked) => {
100 conn.blocked_writers.insert(self.stream, cx.waker().clone());
101 return Poll::Pending;
102 }
103 Err(Stopped(error_code)) => {
104 return Poll::Ready(Err(WriteError::Stopped(error_code)));
105 }
106 Err(UnknownStream) => {
107 return Poll::Ready(Err(WriteError::UnknownStream));
108 }
109 };
110
111 conn.wake();
112 Poll::Ready(Ok(result))
113 }
114
115 pub async fn finish(&mut self) -> Result<(), WriteError> {
120 Finish { stream: self }.await
121 }
122
123 #[doc(hidden)]
124 pub fn poll_finish(&mut self, cx: &mut Context) -> Poll<Result<(), WriteError>> {
125 let mut conn = self.conn.state.lock("poll_finish");
126 if self.is_0rtt {
127 conn.check_0rtt()
128 .map_err(|()| WriteError::ZeroRttRejected)?;
129 }
130 if self.finishing.is_none() {
131 conn.inner
132 .send_stream(self.stream)
133 .finish()
134 .map_err(|e| match e {
135 FinishError::UnknownStream => WriteError::UnknownStream,
136 FinishError::Stopped(error_code) => WriteError::Stopped(error_code),
137 })?;
138 let (send, recv) = oneshot::channel();
139 self.finishing = Some(recv);
140 conn.finishing.insert(self.stream, send);
141 conn.wake();
142 }
143 match Pin::new(self.finishing.as_mut().unwrap())
144 .poll(cx)
145 .map(|x| x.unwrap())
146 {
147 Poll::Ready(x) => {
148 self.finishing = None;
149 Poll::Ready(x.map_or(Ok(()), Err))
150 }
151 Poll::Pending => {
152 if let Some(ref x) = conn.error {
158 return Poll::Ready(Err(WriteError::ConnectionLost(x.clone())));
159 }
160 Poll::Pending
161 }
162 }
163 }
164
165 pub fn reset(&mut self, error_code: VarInt) -> Result<(), UnknownStream> {
171 let mut conn = self.conn.state.lock("SendStream::reset");
172 if self.is_0rtt && conn.check_0rtt().is_err() {
173 return Ok(());
174 }
175 conn.inner.send_stream(self.stream).reset(error_code)?;
176 conn.wake();
177 Ok(())
178 }
179
180 pub fn set_priority(&self, priority: i32) -> Result<(), UnknownStream> {
188 let mut conn = self.conn.state.lock("SendStream::set_priority");
189 conn.inner.send_stream(self.stream).set_priority(priority)?;
190 Ok(())
191 }
192
193 pub fn priority(&self) -> Result<i32, UnknownStream> {
195 let mut conn = self.conn.state.lock("SendStream::priority");
196 Ok(conn.inner.send_stream(self.stream).priority()?)
197 }
198
199 pub async fn stopped(&mut self) -> Result<VarInt, StoppedError> {
201 Stopped { stream: self }.await
202 }
203
204 #[doc(hidden)]
205 pub fn poll_stopped(&mut self, cx: &mut Context) -> Poll<Result<VarInt, StoppedError>> {
206 let mut conn = self.conn.state.lock("SendStream::poll_stopped");
207
208 if self.is_0rtt {
209 conn.check_0rtt()
210 .map_err(|()| StoppedError::ZeroRttRejected)?;
211 }
212
213 match conn.inner.send_stream(self.stream).stopped() {
214 Err(_) => Poll::Ready(Err(StoppedError::UnknownStream)),
215 Ok(Some(error_code)) => Poll::Ready(Ok(error_code)),
216 Ok(None) => {
217 conn.stopped.insert(self.stream, cx.waker().clone());
218 Poll::Pending
219 }
220 }
221 }
222
223 pub fn id(&self) -> StreamId {
225 self.stream
226 }
227}
228
#[cfg(feature = "futures-io")]
impl futures_io::AsyncWrite for SendStream {
    /// Forward `buf` to the stream, converting the error into an `io::Error`.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
        let stream = self.get_mut();
        stream
            .execute_poll(cx, |send| send.write(buf))
            .map_err(io::Error::from)
    }

    /// Always ready: `poll_write` hands data straight to the protocol stream, so this type
    /// keeps no local buffer to flush.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Closing the writer finishes the stream.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        let stream = self.get_mut();
        stream.poll_finish(cx).map_err(io::Error::from)
    }
}
243
#[cfg(feature = "runtime-tokio")]
impl tokio::io::AsyncWrite for SendStream {
    /// Forward `buf` to the stream, converting the error into an `io::Error`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let stream = self.get_mut();
        stream
            .execute_poll(cx, |send| send.write(buf))
            .map_err(io::Error::from)
    }

    /// Always ready: `poll_write` hands data straight to the protocol stream, so this type
    /// keeps no local buffer to flush.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Shutting down the writer finishes the stream.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let stream = self.get_mut();
        stream.poll_finish(cx).map_err(io::Error::from)
    }
}
262
263impl Drop for SendStream {
264 fn drop(&mut self) {
265 let mut conn = self.conn.state.lock("SendStream::drop");
266
267 conn.finishing.remove(&self.stream);
269 conn.stopped.remove(&self.stream);
270 conn.blocked_writers.remove(&self.stream);
271
272 if conn.error.is_some() || (self.is_0rtt && conn.check_0rtt().is_err()) {
273 return;
274 }
275 if self.finishing.is_none() {
276 match conn.inner.send_stream(self.stream).finish() {
277 Ok(()) => conn.wake(),
278 Err(FinishError::Stopped(reason)) => {
279 if conn.inner.send_stream(self.stream).reset(reason).is_ok() {
280 conn.wake();
281 }
282 }
283 Err(FinishError::UnknownStream) => {}
285 }
286 }
287 }
288}
289
290#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
292struct Finish<'a> {
293 stream: &'a mut SendStream,
294}
295
296impl Future for Finish<'_> {
297 type Output = Result<(), WriteError>;
298
299 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
300 self.get_mut().stream.poll_finish(cx)
301 }
302}
303
304#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
306struct Stopped<'a> {
307 stream: &'a mut SendStream,
308}
309
310impl Future for Stopped<'_> {
311 type Output = Result<VarInt, StoppedError>;
312
313 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
314 self.get_mut().stream.poll_stopped(cx)
315 }
316}
317
318#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
322struct Write<'a> {
323 stream: &'a mut SendStream,
324 buf: &'a [u8],
325}
326
327impl<'a> Future for Write<'a> {
328 type Output = Result<usize, WriteError>;
329 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
330 let this = self.get_mut();
331 let buf = this.buf;
332 this.stream.execute_poll(cx, |s| s.write(buf))
333 }
334}
335
336#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
340struct WriteAll<'a> {
341 stream: &'a mut SendStream,
342 buf: &'a [u8],
343}
344
345impl<'a> Future for WriteAll<'a> {
346 type Output = Result<(), WriteError>;
347 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
348 let this = self.get_mut();
349 loop {
350 if this.buf.is_empty() {
351 return Poll::Ready(Ok(()));
352 }
353 let buf = this.buf;
354 let n = ready!(this.stream.execute_poll(cx, |s| s.write(buf)))?;
355 this.buf = &this.buf[n..];
356 }
357 }
358}
359
360#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
364struct WriteChunks<'a> {
365 stream: &'a mut SendStream,
366 bufs: &'a mut [Bytes],
367}
368
369impl<'a> Future for WriteChunks<'a> {
370 type Output = Result<Written, WriteError>;
371 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
372 let this = self.get_mut();
373 let bufs = &mut *this.bufs;
374 this.stream.execute_poll(cx, |s| s.write_chunks(bufs))
375 }
376}
377
378#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
382struct WriteChunk<'a> {
383 stream: &'a mut SendStream,
384 buf: [Bytes; 1],
385}
386
387impl<'a> Future for WriteChunk<'a> {
388 type Output = Result<(), WriteError>;
389 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
390 let this = self.get_mut();
391 loop {
392 if this.buf[0].is_empty() {
393 return Poll::Ready(Ok(()));
394 }
395 let bufs = &mut this.buf[..];
396 ready!(this.stream.execute_poll(cx, |s| s.write_chunks(bufs)))?;
397 }
398 }
399}
400
401#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
405struct WriteAllChunks<'a> {
406 stream: &'a mut SendStream,
407 bufs: &'a mut [Bytes],
408 offset: usize,
409}
410
411impl<'a> Future for WriteAllChunks<'a> {
412 type Output = Result<(), WriteError>;
413 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
414 let this = self.get_mut();
415 loop {
416 if this.offset == this.bufs.len() {
417 return Poll::Ready(Ok(()));
418 }
419 let bufs = &mut this.bufs[this.offset..];
420 let written = ready!(this.stream.execute_poll(cx, |s| s.write_chunks(bufs)))?;
421 this.offset += written.chunks;
422 }
423 }
424}
425
/// Errors produced by writing to or finishing a [`SendStream`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum WriteError {
    /// The peer stopped the stream; carries the error code the peer supplied.
    #[error("sending stopped by peer: error {0}")]
    Stopped(VarInt),
    /// The connection had already failed with this error.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// The protocol layer no longer knows the stream (presumably it was already finished or
    /// reset — confirm against proto).
    #[error("unknown stream")]
    UnknownStream,
    /// The stream was opened as 0-RTT and the connection's `check_0rtt` reported an error.
    #[error("0-RTT rejected")]
    ZeroRttRejected,
}
449
/// Errors produced while waiting for a [`SendStream`] to be stopped by the peer.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum StoppedError {
    /// The connection failed with this error before the stream was stopped.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// The protocol layer no longer knows the stream.
    #[error("unknown stream")]
    UnknownStream,
    /// The stream was opened as 0-RTT and the connection's `check_0rtt` reported an error.
    #[error("0-RTT rejected")]
    ZeroRttRejected,
}
468
469impl From<WriteError> for io::Error {
470 fn from(x: WriteError) -> Self {
471 use self::WriteError::*;
472 let kind = match x {
473 Stopped(_) | ZeroRttRejected => io::ErrorKind::ConnectionReset,
474 ConnectionLost(_) | UnknownStream => io::ErrorKind::NotConnected,
475 };
476 Self::new(kind, x)
477 }
478}