Skip to main content

compio_rustls/stream/
split.rs

1use std::{
2    io::{
3        self,
4        Write as _,
5    },
6    ops::DerefMut,
7    rc::Rc,
8};
9
10use compio_buf::{
11    BufResult, IntoInner as _, IoBuf, IoBufMut, bytes::BytesMut
12};
13use compio_io::{
14    AsyncRead,
15    AsyncWrite,
16};
17use rustls::{
18    ConnectionCommon,
19    SideData,
20};
21use xiaoyong_value::unsync::async_mutex::Mutex;
22
23use crate::stream::util::{
24    flush_tls_writes,
25    process_tls_reads,
26    read_plaintext,
27};
28
29pub struct SharedTlsState<SW, C> {
30    os_writer:  SW,
31    connection: C,
32    write_buf:  Option<BytesMut>,
33}
34
35impl<SW, C, SD> SharedTlsState<SW, C>
36where
37    SW: AsyncWrite,
38    C: DerefMut<Target = ConnectionCommon<SD>>,
39    SD: SideData,
40{
41    async fn flush(&mut self) -> io::Result<()> {
42        flush_tls_writes(&mut self.connection, &mut self.os_writer, &mut self.write_buf).await
43    }
44}
45
46pub struct TlsReadHalf<SR, SW, C> {
47    os_reader: SR,
48    shared:    Rc<Mutex<SharedTlsState<SW, C>>>,
49    read_buf:  Option<BytesMut>,
50}
51
52impl<SR, SW, C, SD> AsyncRead for TlsReadHalf<SR, SW, C>
53where
54    SR: AsyncRead,
55    SW: AsyncWrite,
56    C: DerefMut<Target = ConnectionCommon<SD>>,
57    SD: SideData,
58{
59    async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
60        loop {
61            {
62                // Lock to extract available plaintext
63                let mut state = self.shared.lock().await;
64
65                buf = match read_plaintext(&mut state.connection, buf) {
66                    | Ok(res) => return res,
67                    | Err(b) => b,
68                };
69
70                // Handshake/Side-effects might generate writes
71                if state.connection.wants_write() {
72                    if let Err(e) = state.flush().await {
73                        return BufResult(Err(e), buf);
74                    }
75                }
76
77                // If we're done handshaking/writing and rustls doesn't want to read anymore
78                // (e.g. gracefully closed), return clean EOF without blocking on os_reader.
79                if !state.connection.wants_read() {
80                    return BufResult(Ok(0), buf);
81                }
82            }
83
84            // Wait on OS Reader for ciphertext
85            let mut rbuf = self.read_buf.take().unwrap_or_else(|| BytesMut::with_capacity(4096));
86            if rbuf.buf_len() == rbuf.buf_capacity() {
87                rbuf.reserve(4096);
88            }
89
90            let init_len = rbuf.buf_len();
91            let BufResult(res, slice) = self.os_reader.read(rbuf.slice(init_len..)).await;
92            let mut b = slice.into_inner();
93
94            match res {
95                | Ok(0) => {
96                    self.read_buf = Some(b);
97                    return BufResult(Err(io::Error::from(io::ErrorKind::UnexpectedEof)), buf);
98                },
99                | Ok(n) => {
100                    unsafe { b.set_len(init_len + n) };
101                },
102                | Err(e) => {
103                    self.read_buf = Some(b);
104                    return BufResult(Err(e), buf);
105                },
106            };
107
108            // Re-lock to process the received ciphertext
109            let mut state = self.shared.lock().await;
110            if let Err(e) = process_tls_reads(&mut state.connection, b, &mut self.read_buf) {
111                return BufResult(Err(e), buf);
112            }
113        }
114    }
115}
116
117pub struct TlsWriteHalf<SW, C> {
118    shared: Rc<Mutex<SharedTlsState<SW, C>>>,
119}
120
121impl<SW, C, SD> AsyncWrite for TlsWriteHalf<SW, C>
122where
123    SW: AsyncWrite,
124    C: DerefMut<Target = ConnectionCommon<SD>>,
125    SD: SideData,
126{
127    async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
128        let mut state = self.shared.lock().await;
129
130        let slice = buf.as_init();
131        let written = match state.connection.writer().write(slice) {
132            | Ok(n) => n,
133            | Err(e) => return BufResult(Err(e), buf),
134        };
135
136        if let Err(e) = state.flush().await {
137            return BufResult(Err(e), buf);
138        }
139
140        BufResult(Ok(written), buf)
141    }
142
143    async fn flush(&mut self) -> io::Result<()> {
144        let mut state = self.shared.lock().await;
145        state.connection.writer().flush()?;
146        state.flush().await?;
147        state.os_writer.flush().await
148    }
149
150    async fn shutdown(&mut self) -> io::Result<()> {
151        let mut state = self.shared.lock().await;
152        state.connection.send_close_notify();
153        state.flush().await?;
154        state.os_writer.shutdown().await
155    }
156}
157
158/// Constructs independent Read and Write halves from split OS streams and a TLS
159/// connection.
160pub fn split_tls_stream<SR, SW, C>(
161    os_reader: SR,
162    os_writer: SW,
163    connection: C,
164) -> (TlsReadHalf<SR, SW, C>, TlsWriteHalf<SW, C>) {
165    let shared_state = Rc::new(Mutex::new(SharedTlsState {
166        os_writer,
167        connection,
168        write_buf: Some(BytesMut::with_capacity(4096)),
169    }));
170
171    let read_half = TlsReadHalf {
172        os_reader,
173        shared: Rc::clone(&shared_state),
174        read_buf: Some(BytesMut::with_capacity(4096)),
175    };
176
177    let write_half = TlsWriteHalf {
178        shared: shared_state
179    };
180
181    (read_half, write_half)
182}
183
184/// Constructs independent Read and Write halves from split OS streams and a TLS
185/// connection.
186pub fn split_tls_stream_with_capacity<SR, SW, C>(
187    os_reader: SR,
188    os_writer: SW,
189    connection: C,
190    capacity: usize,
191) -> (TlsReadHalf<SR, SW, C>, TlsWriteHalf<SW, C>) {
192    let shared_state = Rc::new(Mutex::new(SharedTlsState {
193        os_writer,
194        connection,
195        write_buf: Some(BytesMut::with_capacity(capacity)),
196    }));
197
198    let read_half = TlsReadHalf {
199        os_reader,
200        shared: Rc::clone(&shared_state),
201        read_buf: Some(BytesMut::with_capacity(capacity)),
202    };
203
204    let write_half = TlsWriteHalf {
205        shared: shared_state
206    };
207
208    (read_half, write_half)
209}