1use std::{
2 io,
3 sync::Arc,
4 task::{Context, Poll},
5};
6
7use compio_buf::{BufResult, IoBuf, bytes::Bytes};
8use compio_io::AsyncWrite;
9use futures_util::{future::poll_fn, ready};
10use quinn_proto::{ClosedStream, FinishError, StreamId, VarInt, Written};
11use thiserror::Error;
12
13use crate::{ConnectionError, ConnectionInner, StoppedError};
14
/// The sending half of a QUIC stream.
///
/// A handle only stores the stream identifier and a shared reference to the
/// connection state; every operation locks that state (via `state()` or
/// `try_state()`) and acts on the protocol-level stream with the same id.
#[derive(Debug)]
pub struct SendStream {
    // Shared connection state this stream belongs to.
    conn: Arc<ConnectionInner>,
    // Identifier of the underlying QUIC stream.
    stream: StreamId,
    // True when the stream was opened as 0-RTT data; `reset`, `stopped`, the
    // write paths and `Drop` consult `check_0rtt()` when this is set.
    is_0rtt: bool,
}
38
39impl SendStream {
40 pub(crate) fn new(conn: Arc<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
41 Self {
42 conn,
43 stream,
44 is_0rtt,
45 }
46 }
47
48 pub fn id(&self) -> StreamId {
50 self.stream
51 }
52
53 pub fn finish(&mut self) -> Result<(), ClosedStream> {
70 let mut state = self.conn.state();
71 match state.conn.send_stream(self.stream).finish() {
72 Ok(()) => {
73 state.wake();
74 Ok(())
75 }
76 Err(FinishError::ClosedStream) => Err(ClosedStream::default()),
77 Err(FinishError::Stopped(_)) => Ok(()),
80 }
81 }
82
83 pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
95 let mut state = self.conn.state();
96 if self.is_0rtt && !state.check_0rtt() {
97 return Ok(());
98 }
99 state.conn.send_stream(self.stream).reset(error_code)?;
100 state.wake();
101 Ok(())
102 }
103
104 pub fn set_priority(&self, priority: i32) -> Result<(), ClosedStream> {
113 self.conn
114 .state()
115 .conn
116 .send_stream(self.stream)
117 .set_priority(priority)
118 }
119
120 pub fn priority(&self) -> Result<i32, ClosedStream> {
122 self.conn.state().conn.send_stream(self.stream).priority()
123 }
124
125 pub async fn stopped(&mut self) -> Result<Option<VarInt>, StoppedError> {
139 poll_fn(|cx| {
140 let mut state = self.conn.state();
141 if self.is_0rtt && !state.check_0rtt() {
142 return Poll::Ready(Err(StoppedError::ZeroRttRejected));
143 }
144 match state.conn.send_stream(self.stream).stopped() {
145 Err(_) => Poll::Ready(Ok(None)),
146 Ok(Some(error_code)) => Poll::Ready(Ok(Some(error_code))),
147 Ok(None) => {
148 if let Some(e) = &state.error {
149 return Poll::Ready(Err(e.clone().into()));
150 }
151 state.stopped.insert(self.stream, cx.waker().clone());
152 Poll::Pending
153 }
154 }
155 })
156 .await
157 }
158
159 fn execute_poll_write<F, R>(&mut self, cx: &mut Context, f: F) -> Poll<Result<R, WriteError>>
160 where
161 F: FnOnce(quinn_proto::SendStream) -> Result<R, quinn_proto::WriteError>,
162 {
163 let mut state = self.conn.try_state()?;
164 if self.is_0rtt && !state.check_0rtt() {
165 return Poll::Ready(Err(WriteError::ZeroRttRejected));
166 }
167 match f(state.conn.send_stream(self.stream)) {
168 Ok(r) => {
169 state.wake();
170 Poll::Ready(Ok(r))
171 }
172 Err(e) => match e.try_into() {
173 Ok(e) => Poll::Ready(Err(e)),
174 Err(()) => {
175 state.writable.insert(self.stream, cx.waker().clone());
176 Poll::Pending
177 }
178 },
179 }
180 }
181
182 pub async fn write(&mut self, buf: &[u8]) -> Result<usize, WriteError> {
190 poll_fn(|cx| self.execute_poll_write(cx, |mut stream| stream.write(buf))).await
191 }
192
193 pub async fn write_all(&mut self, buf: &[u8]) -> Result<(), WriteError> {
197 let mut count = 0;
198 poll_fn(|cx| {
199 loop {
200 if count == buf.len() {
201 return Poll::Ready(Ok(()));
202 }
203 let n =
204 ready!(self.execute_poll_write(cx, |mut stream| stream.write(&buf[count..])))?;
205 count += n;
206 }
207 })
208 .await
209 }
210
211 pub async fn write_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Written, WriteError> {
219 poll_fn(|cx| self.execute_poll_write(cx, |mut stream| stream.write_chunks(bufs))).await
220 }
221
222 pub async fn write_all_chunks(&mut self, bufs: &mut [Bytes]) -> Result<(), WriteError> {
226 let mut chunks = 0;
227 poll_fn(|cx| {
228 loop {
229 if chunks == bufs.len() {
230 return Poll::Ready(Ok(()));
231 }
232 let written = ready!(self.execute_poll_write(cx, |mut stream| {
233 stream.write_chunks(&mut bufs[chunks..])
234 }))?;
235 chunks += written.chunks;
236 }
237 })
238 .await
239 }
240}
241
242impl Drop for SendStream {
243 fn drop(&mut self) {
244 let mut state = self.conn.state();
245
246 state.stopped.remove(&self.stream);
248 state.writable.remove(&self.stream);
249
250 if state.error.is_some() || (self.is_0rtt && !state.check_0rtt()) {
251 return;
252 }
253 match state.conn.send_stream(self.stream).finish() {
254 Ok(()) => state.wake(),
255 Err(FinishError::Stopped(reason)) => {
256 if state.conn.send_stream(self.stream).reset(reason).is_ok() {
257 state.wake();
258 }
259 }
260 Err(FinishError::ClosedStream) => {}
262 }
263 }
264}
265
/// Errors returned by the write operations of [`SendStream`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum WriteError {
    /// The peer stopped the stream; carries the error code it supplied
    /// (mapped from `quinn_proto::WriteError::Stopped`).
    #[error("sending stopped by peer: error {0}")]
    Stopped(VarInt),
    /// The connection failed; carries the connection error.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// The protocol layer reports the stream as closed
    /// (mapped from `quinn_proto::WriteError::ClosedStream`).
    #[error("closed stream")]
    ClosedStream,
    /// The stream was opened as 0-RTT data and `check_0rtt()` reported that
    /// the early data was rejected.
    #[error("0-RTT rejected")]
    ZeroRttRejected,
    /// Returned by the `h3` adapter's `send_data` when previously queued data
    /// has not been fully written yet.
    #[cfg(feature = "h3")]
    #[error("stream not ready")]
    NotReady,
}
294
295impl TryFrom<quinn_proto::WriteError> for WriteError {
296 type Error = ();
297
298 fn try_from(value: quinn_proto::WriteError) -> Result<Self, Self::Error> {
299 use quinn_proto::WriteError::*;
300 match value {
301 Stopped(e) => Ok(Self::Stopped(e)),
302 ClosedStream => Ok(Self::ClosedStream),
303 Blocked => Err(()),
304 }
305 }
306}
307
308impl From<StoppedError> for WriteError {
309 fn from(x: StoppedError) -> Self {
310 match x {
311 StoppedError::ConnectionLost(e) => Self::ConnectionLost(e),
312 StoppedError::ZeroRttRejected => Self::ZeroRttRejected,
313 }
314 }
315}
316
317impl From<WriteError> for io::Error {
318 fn from(x: WriteError) -> Self {
319 use WriteError::*;
320 let kind = match x {
321 Stopped(_) | ZeroRttRejected => io::ErrorKind::ConnectionReset,
322 ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
323 #[cfg(feature = "h3")]
324 NotReady => io::ErrorKind::Other,
325 };
326 Self::new(kind, x)
327 }
328}
329
330impl AsyncWrite for SendStream {
331 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
332 let res = self.write(buf.as_slice()).await.map_err(Into::into);
333 BufResult(res, buf)
334 }
335
336 async fn flush(&mut self) -> io::Result<()> {
337 Ok(())
338 }
339
340 async fn shutdown(&mut self) -> io::Result<()> {
341 self.finish()?;
342 Ok(())
343 }
344}
345
#[cfg(feature = "io-compat")]
impl futures_util::AsyncWrite for SendStream {
    /// Attempts a single write of `buf`.
    ///
    /// Protocol errors are converted to [`io::Error`].
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        match this.execute_poll_write(cx, |mut stream| stream.write(buf)) {
            Poll::Ready(result) => Poll::Ready(result.map_err(io::Error::from)),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Always ready: nothing is buffered locally.
    fn poll_flush(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Finishes the stream; completes immediately.
    fn poll_close(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Poll::Ready(this.finish().map_err(io::Error::from))
    }
}
367
#[cfg(feature = "h3")]
pub(crate) mod h3_impl {
    use compio_buf::bytes::Buf;
    use h3::quic::{self, StreamErrorIncoming, WriteBuf};

    use super::*;

    impl From<WriteError> for StreamErrorIncoming {
        /// Translates a [`WriteError`] into the error type used by `h3`.
        ///
        /// Peer stops and connection loss have dedicated `h3` variants. Every
        /// other write error is passed through boxed as `Unknown`.
        fn from(e: WriteError) -> Self {
            use WriteError::*;
            match e {
                Stopped(code) => Self::StreamTerminated {
                    error_code: code.into_inner(),
                },
                ConnectionLost(e) => Self::ConnectionErrorIncoming {
                    connection_error: e.into(),
                },

                e => Self::Unknown(Box::new(e)),
            }
        }
    }

    /// Send half of a QUIC stream, adapted to `h3`'s `quic::SendStream` traits.
    ///
    /// Data handed over through `send_data` is held in `buf` and written out by
    /// subsequent `poll_ready` calls.
    pub struct SendStream<B> {
        // The underlying QUIC send stream.
        inner: super::SendStream,
        // Data accepted by `send_data` that has not been fully written yet.
        buf: Option<WriteBuf<B>>,
    }

    impl<B> SendStream<B> {
        /// Creates an adapter for `stream` on `conn` with nothing queued.
        pub(crate) fn new(conn: Arc<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
            Self {
                inner: super::SendStream::new(conn, stream, is_0rtt),
                buf: None,
            }
        }
    }

    impl<B> quic::SendStream<B> for SendStream<B>
    where
        B: Buf,
    {
        /// Writes out any data queued by `send_data`.
        ///
        /// Resolves to `Ok(())` once the queue is empty. On error the unwritten
        /// remainder stays queued, so `send_data` keeps reporting `NotReady`.
        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
            if let Some(data) = &mut self.buf {
                while data.has_remaining() {
                    let n = ready!(
                        self.inner
                            .execute_poll_write(cx, |mut stream| stream.write(data.chunk()))
                    )?;
                    data.advance(n);
                }
            }
            self.buf = None;
            Poll::Ready(Ok(()))
        }

        /// Queues `data` to be written by subsequent `poll_ready` calls.
        ///
        /// # Errors
        ///
        /// Returns `NotReady` (converted) if previously queued data has not
        /// been fully written yet.
        fn send_data<T: Into<WriteBuf<B>>>(&mut self, data: T) -> Result<(), StreamErrorIncoming> {
            if self.buf.is_some() {
                return Err(WriteError::NotReady.into());
            }
            self.buf = Some(data.into());
            Ok(())
        }

        /// Finishes the stream; any failure is reported as `ClosedStream`.
        fn poll_finish(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
            Poll::Ready(
                self.inner
                    .finish()
                    .map_err(|_| WriteError::ClosedStream.into()),
            )
        }

        /// Resets the stream with `reset_code`.
        ///
        /// A code that does not fit in a QUIC varint is clamped to
        /// `VarInt::MAX`. Errors from the reset are ignored.
        fn reset(&mut self, reset_code: u64) {
            self.inner
                .reset(reset_code.try_into().unwrap_or(VarInt::MAX))
                .ok();
        }

        /// Returns the stream id in `h3`'s representation.
        fn send_id(&self) -> quic::StreamId {
            u64::from(self.inner.stream)
                .try_into()
                .expect("QUIC stream id must be representable as an h3 stream id")
        }
    }

    impl<B> quic::SendStreamUnframed<B> for SendStream<B>
    where
        B: Buf,
    {
        /// Writes directly from `buf` without framing, advancing it by the
        /// number of bytes accepted.
        fn poll_send<D: Buf>(
            &mut self,
            cx: &mut Context<'_>,
            buf: &mut D,
        ) -> Poll<Result<usize, StreamErrorIncoming>> {
            // Must only be called once all data queued by `send_data` has been
            // written by `poll_ready`. Otherwise these bytes would be sent
            // before the queued ones. Calling it earlier is a caller bug.
            debug_assert!(
                self.buf.is_none(),
                "poll_send called while send stream is not ready"
            );

            let n = ready!(
                self.inner
                    .execute_poll_write(cx, |mut stream| stream.write(buf.chunk()))
            )?;
            buf.advance(n);
            Poll::Ready(Ok(n))
        }
    }
}
475}