fluvio_async_tls/
server.rs1use crate::common::tls_state::TlsState;
4use crate::rusttls::stream::Stream;
5
6use futures_core::ready;
7use futures_io::{AsyncRead, AsyncWrite};
8use rustls::{Certificate, ServerConnection};
9use std::future::Future;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12use std::{io, mem};
13
14pub struct TlsStream<IO> {
17 pub(crate) io: IO,
18 pub(crate) session: ServerConnection,
19 pub(crate) state: TlsState,
20}
21
22impl<IO> TlsStream<IO> {
23 pub fn client_certificates(&self) -> Option<Vec<Certificate>> {
28 self.session.peer_certificates().map(Vec::from)
29 }
30}
31
32pub(crate) enum MidHandshake<IO> {
33 Handshaking(TlsStream<IO>),
34 End,
35}
36
37impl<IO> Future for MidHandshake<IO>
38where
39 IO: AsyncRead + AsyncWrite + Unpin,
40{
41 type Output = io::Result<TlsStream<IO>>;
42
43 #[inline]
44 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
45 let this = self.get_mut();
46
47 if let MidHandshake::Handshaking(stream) = this {
48 let eof = !stream.state.readable();
49 let (io, session) = (&mut stream.io, &mut stream.session);
50 let mut stream = Stream::new(io, session).set_eof(eof);
51
52 if stream.session.is_handshaking() {
53 ready!(stream.complete_io(cx))?;
54 }
55
56 if stream.session.wants_write() {
57 ready!(stream.complete_io(cx))?;
58 }
59 }
60
61 match mem::replace(this, MidHandshake::End) {
62 MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)),
63 MidHandshake::End => panic!(),
64 }
65 }
66}
67
68impl<IO> AsyncRead for TlsStream<IO>
69where
70 IO: AsyncRead + AsyncWrite + Unpin,
71{
72 fn poll_read(
73 self: Pin<&mut Self>,
74 cx: &mut Context<'_>,
75 buf: &mut [u8],
76 ) -> Poll<io::Result<usize>> {
77 let this = self.get_mut();
78 let mut stream =
79 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
80
81 match this.state {
82 TlsState::Stream | TlsState::WriteShutdown => {
83 match stream.as_mut_pin().poll_read(cx, buf) {
84 Poll::Ready(Ok(0)) => {
85 this.state.shutdown_read();
86 Poll::Ready(Ok(0))
87 }
88 Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
89 Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
90 this.state.shutdown_read();
91 if this.state.writeable() {
92 stream.session.send_close_notify();
93 this.state.shutdown_write();
94 }
95 Poll::Ready(Ok(0))
96 }
97 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
98 Poll::Pending => Poll::Pending,
99 }
100 }
101 TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
102 #[cfg(feature = "early-data")]
103 s => unreachable!("server TLS can not hit this state: {:?}", s),
104 }
105 }
106}
107
108impl<IO> AsyncWrite for TlsStream<IO>
109where
110 IO: AsyncRead + AsyncWrite + Unpin,
111{
112 fn poll_write(
113 self: Pin<&mut Self>,
114 cx: &mut Context<'_>,
115 buf: &[u8],
116 ) -> Poll<io::Result<usize>> {
117 let this = self.get_mut();
118 let mut stream =
119 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
120 stream.as_mut_pin().poll_write(cx, buf)
121 }
122
123 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
124 let this = self.get_mut();
125 let mut stream =
126 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
127 stream.as_mut_pin().poll_flush(cx)
128 }
129
130 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
131 if self.state.writeable() {
132 self.session.send_close_notify();
133 self.state.shutdown_write();
134 }
135
136 let this = self.get_mut();
137 let mut stream =
138 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
139 stream.as_mut_pin().poll_close(cx)
140 }
141}