async_rustls/
server.rs

1#[cfg(unix)]
2use std::os::unix::io::{AsRawFd, RawFd};
3#[cfg(windows)]
4use std::os::windows::io::{AsRawSocket, RawSocket};
5
6use super::*;
7use crate::common::IoSession;
8
9/// A wrapper around an underlying raw stream which implements the TLS or SSL
10/// protocol.
11#[derive(Debug)]
12pub struct TlsStream<IO> {
13    pub(crate) io: IO,
14    pub(crate) session: ServerConnection,
15    pub(crate) state: TlsState,
16}
17
18impl<IO> TlsStream<IO> {
19    #[inline]
20    pub fn get_ref(&self) -> (&IO, &ServerConnection) {
21        (&self.io, &self.session)
22    }
23
24    #[inline]
25    pub fn get_mut(&mut self) -> (&mut IO, &mut ServerConnection) {
26        (&mut self.io, &mut self.session)
27    }
28
29    #[inline]
30    pub fn into_inner(self) -> (IO, ServerConnection) {
31        (self.io, self.session)
32    }
33}
34
35impl<IO> IoSession for TlsStream<IO> {
36    type Io = IO;
37    type Session = ServerConnection;
38
39    #[inline]
40    fn skip_handshake(&self) -> bool {
41        false
42    }
43
44    #[inline]
45    fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
46        (&mut self.state, &mut self.io, &mut self.session)
47    }
48
49    #[inline]
50    fn into_io(self) -> Self::Io {
51        self.io
52    }
53}
54
55impl<IO> AsyncRead for TlsStream<IO>
56where
57    IO: AsyncRead + AsyncWrite + Unpin,
58{
59    fn poll_read(
60        self: Pin<&mut Self>,
61        cx: &mut Context<'_>,
62        buf: &mut [u8],
63    ) -> Poll<io::Result<usize>> {
64        let this = self.get_mut();
65        let mut stream =
66            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
67
68        match &this.state {
69            TlsState::Stream | TlsState::WriteShutdown => {
70                match stream.as_mut_pin().poll_read(cx, buf) {
71                    Poll::Ready(Ok(n)) => {
72                        if n == 0 || stream.eof {
73                            this.state.shutdown_read();
74                        }
75
76                        Poll::Ready(Ok(n))
77                    }
78                    Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::UnexpectedEof => {
79                        this.state.shutdown_read();
80                        Poll::Ready(Err(err))
81                    }
82                    output => output,
83                }
84            }
85            TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
86            #[cfg(feature = "early-data")]
87            s => unreachable!("server TLS can not hit this state: {:?}", s),
88        }
89    }
90}
91
92impl<IO> AsyncWrite for TlsStream<IO>
93where
94    IO: AsyncRead + AsyncWrite + Unpin,
95{
96    /// Note: that it does not guarantee the final data to be sent.
97    /// To be cautious, you must manually call `flush`.
98    fn poll_write(
99        self: Pin<&mut Self>,
100        cx: &mut Context<'_>,
101        buf: &[u8],
102    ) -> Poll<io::Result<usize>> {
103        let this = self.get_mut();
104        let mut stream =
105            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
106        stream.as_mut_pin().poll_write(cx, buf)
107    }
108
109    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
110        let this = self.get_mut();
111        let mut stream =
112            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
113        stream.as_mut_pin().poll_flush(cx)
114    }
115
116    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
117        if self.state.writeable() {
118            self.session.send_close_notify();
119            self.state.shutdown_write();
120        }
121
122        let this = self.get_mut();
123        let mut stream =
124            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
125        stream.as_mut_pin().poll_close(cx)
126    }
127}
128
129#[cfg(unix)]
130impl<IO> AsRawFd for TlsStream<IO>
131where
132    IO: AsRawFd,
133{
134    #[inline]
135    fn as_raw_fd(&self) -> RawFd {
136        self.get_ref().0.as_raw_fd()
137    }
138}
139
#[cfg(windows)]
impl<IO: AsRawSocket> AsRawSocket for TlsStream<IO> {
    /// Returns the raw socket handle of the wrapped I/O object.
    #[inline]
    fn as_raw_socket(&self) -> RawSocket {
        AsRawSocket::as_raw_socket(&self.io)
    }
}