rocket_community/data/
io_stream.rs

1use std::io;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use hyper::upgrade::Upgraded;
6use hyper_util::rt::TokioIo;
7use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8
9/// A bidirectional, raw stream to the client.
10///
11/// An instance of `IoStream` is passed to an [`IoHandler`] in response to a
12/// successful upgrade request initiated by responders via
13/// [`Response::add_upgrade()`] or the equivalent builder method
14/// [`Builder::upgrade()`]. For details on upgrade connections, see
15/// [`Response`#upgrading].
16///
17/// An `IoStream` is guaranteed to be [`AsyncRead`], [`AsyncWrite`], and
18/// `Unpin`. Bytes written to the stream are sent directly to the client. Bytes
19/// read from the stream are those sent directly _by_ the client. See
20/// [`IoHandler`] for one example of how values of this type are used.
21///
22/// [`Response::add_upgrade()`]: crate::Response::add_upgrade()
23/// [`Builder::upgrade()`]: crate::response::Builder::upgrade()
24/// [`Response`#upgrading]: crate::response::Response#upgrading
25pub struct IoStream {
26    kind: IoStreamKind,
27}
28
/// The concrete kinds of stream an `IoStream` can wrap.
///
/// Today the only kind is a hyper upgraded connection adapted through
/// `TokioIo`. This is an enum just in case we want to add stream kinds in the
/// future.
enum IoStreamKind {
    /// A hyper `Upgraded` connection, wrapped in `TokioIo`.
    Upgraded(TokioIo<Upgraded>),
}
33
34/// An upgraded connection I/O handler.
35///
36/// An I/O handler performs raw I/O via the passed in [`IoStream`], which is
37/// [`AsyncRead`], [`AsyncWrite`], and `Unpin`.
38///
39/// # Example
40///
41/// The example below implements an `EchoHandler` that echos the raw bytes back
42/// to the client.
43///
44/// ```rust
45/// # extern crate rocket_community as rocket;
46/// use std::pin::Pin;
47///
48/// use rocket::tokio::io;
49/// use rocket::data::{IoHandler, IoStream};
50///
51/// struct EchoHandler;
52///
53/// #[rocket::async_trait]
54/// impl IoHandler for EchoHandler {
55///     async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
56///         let (mut reader, mut writer) = io::split(io);
57///         io::copy(&mut reader, &mut writer).await?;
58///         Ok(())
59///     }
60/// }
61///
62/// # use rocket::Response;
63/// # rocket::async_test(async {
64/// # let mut response = Response::new();
65/// # response.add_upgrade("raw-echo", EchoHandler);
66/// # assert!(response.upgrade("raw-echo").is_some());
67/// # })
68/// ```
69#[crate::async_trait]
70pub trait IoHandler: Send {
71    /// Performs the raw I/O.
72    async fn io(self: Box<Self>, io: IoStream) -> io::Result<()>;
73}
74
/// A handler that performs no I/O: it ignores the stream it is given and
/// immediately returns `Ok(())`.
#[crate::async_trait]
impl IoHandler for () {
    async fn io(self: Box<Self>, _: IoStream) -> io::Result<()> {
        Ok(())
    }
}
81
82#[doc(hidden)]
83impl From<Upgraded> for IoStream {
84    fn from(io: Upgraded) -> Self {
85        IoStream {
86            kind: IoStreamKind::Upgraded(TokioIo::new(io)),
87        }
88    }
89}
90
91/// A "trait alias" of sorts so we can use `AsyncRead + AsyncWrite + Unpin` in `dyn`.
92pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin {}
93
94/// Implemented for all `AsyncRead + AsyncWrite + Unpin`, of course.
95impl<T: AsyncRead + AsyncWrite + Unpin> AsyncReadWrite for T {}
96
97impl IoStream {
98    /// Returns the internal I/O stream.
99    fn inner_mut(&mut self) -> Pin<&mut dyn AsyncReadWrite> {
100        match self.kind {
101            IoStreamKind::Upgraded(ref mut io) => Pin::new(io),
102        }
103    }
104
105    /// Returns `true` if the inner I/O stream is write vectored.
106    fn inner_is_write_vectored(&self) -> bool {
107        match self.kind {
108            IoStreamKind::Upgraded(ref io) => io.is_write_vectored(),
109        }
110    }
111}
112
113impl AsyncRead for IoStream {
114    fn poll_read(
115        self: Pin<&mut Self>,
116        cx: &mut Context<'_>,
117        buf: &mut ReadBuf<'_>,
118    ) -> Poll<io::Result<()>> {
119        self.get_mut().inner_mut().poll_read(cx, buf)
120    }
121}
122
123impl AsyncWrite for IoStream {
124    fn poll_write(
125        self: Pin<&mut Self>,
126        cx: &mut Context<'_>,
127        buf: &[u8],
128    ) -> Poll<io::Result<usize>> {
129        self.get_mut().inner_mut().poll_write(cx, buf)
130    }
131
132    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
133        self.get_mut().inner_mut().poll_flush(cx)
134    }
135
136    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
137        self.get_mut().inner_mut().poll_shutdown(cx)
138    }
139
140    fn poll_write_vectored(
141        self: Pin<&mut Self>,
142        cx: &mut Context<'_>,
143        bufs: &[io::IoSlice<'_>],
144    ) -> Poll<io::Result<usize>> {
145        self.get_mut().inner_mut().poll_write_vectored(cx, bufs)
146    }
147
148    fn is_write_vectored(&self) -> bool {
149        self.inner_is_write_vectored()
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::{AsyncRead, AsyncWrite, IoStream};

    #[test]
    fn is_unpin() {
        // Compile-time check: this only builds if `IoStream` satisfies every
        // bound in the `where` clause below.
        fn assert_bounds<S>()
        where
            S: AsyncRead + AsyncWrite + Unpin + Send,
        {
        }

        assert_bounds::<IoStream>();
    }
}