async_tls/
server.rs

1//! The server end of a TLS connection.
2
3use crate::common::tls_state::TlsState;
4use crate::rusttls::stream::Stream;
5
6use futures_core::ready;
7use futures_io::{AsyncRead, AsyncWrite};
8use rustls::ServerConnection;
9use std::future::Future;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12use std::{io, mem};
13
14/// The server end of a TLS connection. Can be used like any other bidirectional IO stream.
15/// Wraps the underlying TCP stream.
16#[derive(Debug)]
17pub struct TlsStream<IO> {
18    pub(crate) io: IO,
19    pub(crate) conn: ServerConnection,
20    pub(crate) state: TlsState,
21}
22
23pub(crate) enum MidHandshake<IO> {
24    Handshaking(TlsStream<IO>),
25    End,
26}
27
28impl<IO> Future for MidHandshake<IO>
29where
30    IO: AsyncRead + AsyncWrite + Unpin,
31{
32    type Output = io::Result<TlsStream<IO>>;
33
34    #[inline]
35    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
36        let this = self.get_mut();
37
38        if let MidHandshake::Handshaking(stream) = this {
39            let eof = !stream.state.readable();
40            let (io, session) = (&mut stream.io, &mut stream.conn);
41            let mut stream = Stream::new(io, session).set_eof(eof);
42
43            if stream.conn.is_handshaking() {
44                ready!(stream.complete_io(cx))?;
45            }
46
47            if stream.conn.wants_write() {
48                ready!(stream.complete_io(cx))?;
49            }
50        }
51
52        match mem::replace(this, MidHandshake::End) {
53            MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)),
54            MidHandshake::End => panic!(),
55        }
56    }
57}
58
59impl<IO> AsyncRead for TlsStream<IO>
60where
61    IO: AsyncRead + AsyncWrite + Unpin,
62{
63    fn poll_read(
64        self: Pin<&mut Self>,
65        cx: &mut Context<'_>,
66        buf: &mut [u8],
67    ) -> Poll<io::Result<usize>> {
68        let this = self.get_mut();
69        let mut stream = Stream::new(&mut this.io, &mut this.conn).set_eof(!this.state.readable());
70
71        match this.state {
72            TlsState::Stream | TlsState::WriteShutdown => {
73                match stream.as_mut_pin().poll_read(cx, buf) {
74                    Poll::Ready(Ok(0)) => {
75                        this.state.shutdown_read();
76                        Poll::Ready(Ok(0))
77                    }
78                    Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
79                    Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
80                        this.state.shutdown_read();
81                        if this.state.writeable() {
82                            stream.conn.send_close_notify();
83                            this.state.shutdown_write();
84                        }
85                        Poll::Ready(Ok(0))
86                    }
87                    Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
88                    Poll::Pending => Poll::Pending,
89                }
90            }
91            TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
92            #[cfg(feature = "early-data")]
93            s => unreachable!("server TLS can not hit this state: {:?}", s),
94        }
95    }
96}
97
98impl<IO> AsyncWrite for TlsStream<IO>
99where
100    IO: AsyncRead + AsyncWrite + Unpin,
101{
102    fn poll_write(
103        self: Pin<&mut Self>,
104        cx: &mut Context<'_>,
105        buf: &[u8],
106    ) -> Poll<io::Result<usize>> {
107        let this = self.get_mut();
108        let mut stream = Stream::new(&mut this.io, &mut this.conn).set_eof(!this.state.readable());
109        stream.as_mut_pin().poll_write(cx, buf)
110    }
111
112    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
113        let this = self.get_mut();
114        let mut stream = Stream::new(&mut this.io, &mut this.conn).set_eof(!this.state.readable());
115        stream.as_mut_pin().poll_flush(cx)
116    }
117
118    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
119        if self.state.writeable() {
120            self.conn.send_close_notify();
121            self.state.shutdown_write();
122        }
123
124        let this = self.get_mut();
125        let mut stream = Stream::new(&mut this.io, &mut this.conn).set_eof(!this.state.readable());
126        stream.as_mut_pin().poll_close(cx)
127    }
128}