fluvio_async_tls/
server.rs

1//! The server end of a TLS connection.
2
3use crate::common::tls_state::TlsState;
4use crate::rusttls::stream::Stream;
5
6use futures_core::ready;
7use futures_io::{AsyncRead, AsyncWrite};
8use rustls::{Certificate, ServerConnection};
9use std::future::Future;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12use std::{io, mem};
13
/// The server end of a TLS connection. Can be used like any other bidirectional IO stream.
/// Wraps the underlying transport (`IO`, typically a TCP stream).
pub struct TlsStream<IO> {
    /// The wrapped transport. The `AsyncRead`/`AsyncWrite` impls in this module require it
    /// to be `AsyncRead + AsyncWrite + Unpin`.
    pub(crate) io: IO,
    /// The rustls server-side connection: queried for handshake progress and peer
    /// certificates, and used to queue the `close_notify` alert on shutdown.
    pub(crate) session: ServerConnection,
    /// Which halves of the stream are still open (`Stream`, `ReadShutdown`,
    /// `WriteShutdown`, `FullyShutdown`); consulted via `readable()`/`writeable()`.
    pub(crate) state: TlsState,
}
21
22impl<IO> TlsStream<IO> {
23    /// Retrieves the certificate chain used by the client,
24    /// if client authentication was completed.
25    ///
26    /// The return value is None until this value is available.
27    pub fn client_certificates(&self) -> Option<Vec<Certificate>> {
28        self.session.peer_certificates().map(Vec::from)
29    }
30}
31
/// Future that drives a server-side TLS handshake and resolves to the [`TlsStream`].
pub(crate) enum MidHandshake<IO> {
    /// Handshake in progress; owns the stream being negotiated until it is handed out.
    Handshaking(TlsStream<IO>),
    /// The stream has already been returned; polling the future again panics.
    End,
}
36
37impl<IO> Future for MidHandshake<IO>
38where
39    IO: AsyncRead + AsyncWrite + Unpin,
40{
41    type Output = io::Result<TlsStream<IO>>;
42
43    #[inline]
44    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
45        let this = self.get_mut();
46
47        if let MidHandshake::Handshaking(stream) = this {
48            let eof = !stream.state.readable();
49            let (io, session) = (&mut stream.io, &mut stream.session);
50            let mut stream = Stream::new(io, session).set_eof(eof);
51
52            if stream.session.is_handshaking() {
53                ready!(stream.complete_io(cx))?;
54            }
55
56            if stream.session.wants_write() {
57                ready!(stream.complete_io(cx))?;
58            }
59        }
60
61        match mem::replace(this, MidHandshake::End) {
62            MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)),
63            MidHandshake::End => panic!(),
64        }
65    }
66}
67
68impl<IO> AsyncRead for TlsStream<IO>
69where
70    IO: AsyncRead + AsyncWrite + Unpin,
71{
72    fn poll_read(
73        self: Pin<&mut Self>,
74        cx: &mut Context<'_>,
75        buf: &mut [u8],
76    ) -> Poll<io::Result<usize>> {
77        let this = self.get_mut();
78        let mut stream =
79            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
80
81        match this.state {
82            TlsState::Stream | TlsState::WriteShutdown => {
83                match stream.as_mut_pin().poll_read(cx, buf) {
84                    Poll::Ready(Ok(0)) => {
85                        this.state.shutdown_read();
86                        Poll::Ready(Ok(0))
87                    }
88                    Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
89                    Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
90                        this.state.shutdown_read();
91                        if this.state.writeable() {
92                            stream.session.send_close_notify();
93                            this.state.shutdown_write();
94                        }
95                        Poll::Ready(Ok(0))
96                    }
97                    Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
98                    Poll::Pending => Poll::Pending,
99                }
100            }
101            TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
102            #[cfg(feature = "early-data")]
103            s => unreachable!("server TLS can not hit this state: {:?}", s),
104        }
105    }
106}
107
108impl<IO> AsyncWrite for TlsStream<IO>
109where
110    IO: AsyncRead + AsyncWrite + Unpin,
111{
112    fn poll_write(
113        self: Pin<&mut Self>,
114        cx: &mut Context<'_>,
115        buf: &[u8],
116    ) -> Poll<io::Result<usize>> {
117        let this = self.get_mut();
118        let mut stream =
119            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
120        stream.as_mut_pin().poll_write(cx, buf)
121    }
122
123    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
124        let this = self.get_mut();
125        let mut stream =
126            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
127        stream.as_mut_pin().poll_flush(cx)
128    }
129
130    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
131        if self.state.writeable() {
132            self.session.send_close_notify();
133            self.state.shutdown_write();
134        }
135
136        let this = self.get_mut();
137        let mut stream =
138            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
139        stream.as_mut_pin().poll_close(cx)
140    }
141}