1#[cfg(unix)]
2use std::os::unix::io::{AsRawFd, RawFd};
3#[cfg(windows)]
4use std::os::windows::io::{AsRawSocket, RawSocket};
5
6use super::*;
7use crate::common::IoSession;
8
/// Server-side TLS stream wrapping an I/O object of type `IO`.
///
/// Bundles the wrapped I/O object with its `ServerConnection` and a `TlsState`
/// that records which directions of the stream have been shut down.
#[derive(Debug)]
pub struct TlsStream<IO> {
    /// The wrapped I/O object that the TLS session is driven over.
    pub(crate) io: IO,
    /// The server-side TLS session paired with `io`.
    pub(crate) session: ServerConnection,
    /// Shutdown bookkeeping: consulted via `readable()` / `writeable()` and updated via
    /// `shutdown_read()` / `shutdown_write()` by the read/write implementations.
    pub(crate) state: TlsState,
}
17
18impl<IO> TlsStream<IO> {
19 #[inline]
20 pub fn get_ref(&self) -> (&IO, &ServerConnection) {
21 (&self.io, &self.session)
22 }
23
24 #[inline]
25 pub fn get_mut(&mut self) -> (&mut IO, &mut ServerConnection) {
26 (&mut self.io, &mut self.session)
27 }
28
29 #[inline]
30 pub fn into_inner(self) -> (IO, ServerConnection) {
31 (self.io, self.session)
32 }
33}
34
35impl<IO> IoSession for TlsStream<IO> {
36 type Io = IO;
37 type Session = ServerConnection;
38
39 #[inline]
40 fn skip_handshake(&self) -> bool {
41 false
42 }
43
44 #[inline]
45 fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
46 (&mut self.state, &mut self.io, &mut self.session)
47 }
48
49 #[inline]
50 fn into_io(self) -> Self::Io {
51 self.io
52 }
53}
54
55impl<IO> AsyncRead for TlsStream<IO>
56where
57 IO: AsyncRead + AsyncWrite + Unpin,
58{
59 fn poll_read(
60 self: Pin<&mut Self>,
61 cx: &mut Context<'_>,
62 buf: &mut [u8],
63 ) -> Poll<io::Result<usize>> {
64 let this = self.get_mut();
65 let mut stream =
66 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
67
68 match &this.state {
69 TlsState::Stream | TlsState::WriteShutdown => {
70 match stream.as_mut_pin().poll_read(cx, buf) {
71 Poll::Ready(Ok(n)) => {
72 if n == 0 || stream.eof {
73 this.state.shutdown_read();
74 }
75
76 Poll::Ready(Ok(n))
77 }
78 Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::UnexpectedEof => {
79 this.state.shutdown_read();
80 Poll::Ready(Err(err))
81 }
82 output => output,
83 }
84 }
85 TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
86 #[cfg(feature = "early-data")]
87 s => unreachable!("server TLS can not hit this state: {:?}", s),
88 }
89 }
90}
91
92impl<IO> AsyncWrite for TlsStream<IO>
93where
94 IO: AsyncRead + AsyncWrite + Unpin,
95{
96 fn poll_write(
99 self: Pin<&mut Self>,
100 cx: &mut Context<'_>,
101 buf: &[u8],
102 ) -> Poll<io::Result<usize>> {
103 let this = self.get_mut();
104 let mut stream =
105 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
106 stream.as_mut_pin().poll_write(cx, buf)
107 }
108
109 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
110 let this = self.get_mut();
111 let mut stream =
112 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
113 stream.as_mut_pin().poll_flush(cx)
114 }
115
116 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
117 if self.state.writeable() {
118 self.session.send_close_notify();
119 self.state.shutdown_write();
120 }
121
122 let this = self.get_mut();
123 let mut stream =
124 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
125 stream.as_mut_pin().poll_close(cx)
126 }
127}
128
129#[cfg(unix)]
130impl<IO> AsRawFd for TlsStream<IO>
131where
132 IO: AsRawFd,
133{
134 #[inline]
135 fn as_raw_fd(&self) -> RawFd {
136 self.get_ref().0.as_raw_fd()
137 }
138}
139
#[cfg(windows)]
impl<IO> AsRawSocket for TlsStream<IO>
where
    IO: AsRawSocket,
{
    /// Returns the raw socket handle of the wrapped I/O object.
    #[inline]
    fn as_raw_socket(&self) -> RawSocket {
        self.io.as_raw_socket()
    }
}