Skip to main content

compio_rustls/stream/
split.rs

1use std::{
2    io::{
3        self,
4        Write as _,
5    },
6    ops::DerefMut,
7    rc::Rc,
8};
9
10use compio_buf::{
11    BufResult,
12    IoBuf,
13    IoBufMut,
14    bytes::BytesMut,
15};
16use compio_io::{
17    AsyncRead,
18    AsyncWrite,
19};
20use rustls::{
21    ConnectionCommon,
22    SideData,
23};
24use xiaoyong_value::unsync::async_mutex::Mutex;
25
26use crate::stream::util::{
27    flush_tls_writes,
28    process_tls_reads,
29    read_plaintext,
30};
31
32pub struct SharedTlsState<SW, C> {
33    os_writer:  SW,
34    connection: C,
35    write_buf:  Option<BytesMut>,
36}
37
38impl<SW, C, SD> SharedTlsState<SW, C>
39where
40    SW: AsyncWrite,
41    C: DerefMut<Target = ConnectionCommon<SD>>,
42    SD: SideData,
43{
44    async fn flush(&mut self) -> io::Result<()> {
45        flush_tls_writes(&mut self.connection, &mut self.os_writer, &mut self.write_buf).await
46    }
47}
48
49pub struct TlsReadHalf<SR, SW, C> {
50    os_reader: SR,
51    shared:    Rc<Mutex<SharedTlsState<SW, C>>>,
52    read_buf:  Option<BytesMut>,
53}
54
55impl<SR, SW, C, SD> AsyncRead for TlsReadHalf<SR, SW, C>
56where
57    SR: AsyncRead,
58    SW: AsyncWrite,
59    C: DerefMut<Target = ConnectionCommon<SD>>,
60    SD: SideData,
61{
62    async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
63        loop {
64            {
65                // Lock to extract available plaintext
66                let mut state = self.shared.lock().await;
67
68                buf = match read_plaintext(&mut state.connection, buf) {
69                    | Ok(res) => return res,
70                    | Err(b) => b,
71                };
72
73                // Handshake/Side-effects might generate writes
74                if state.connection.wants_write() {
75                    if let Err(e) = state.flush().await {
76                        return BufResult(Err(e), buf);
77                    }
78                }
79
80                // If we're done handshaking/writing and rustls doesn't want to read anymore
81                // (e.g. gracefully closed), return clean EOF without blocking on os_reader.
82                if !state.connection.wants_read() {
83                    return BufResult(Ok(0), buf);
84                }
85            }
86
87            // Wait on OS Reader for ciphertext
88            let mut rbuf = self.read_buf.take().unwrap_or_else(|| BytesMut::with_capacity(4096));
89            if rbuf.buf_len() == rbuf.buf_capacity() {
90                rbuf.reserve(4096);
91            }
92
93            let BufResult(res, b) = self.os_reader.read(rbuf).await;
94
95            match res {
96                | Ok(0) => {
97                    self.read_buf = Some(b);
98                    return BufResult(Err(io::Error::from(io::ErrorKind::UnexpectedEof)), buf);
99                },
100                | Ok(_) => {},
101                | Err(e) => {
102                    self.read_buf = Some(b);
103                    return BufResult(Err(e), buf);
104                },
105            };
106
107            // Re-lock to process the received ciphertext
108            let mut state = self.shared.lock().await;
109            if let Err(e) = process_tls_reads(&mut state.connection, b, &mut self.read_buf) {
110                return BufResult(Err(e), buf);
111            }
112        }
113    }
114}
115
116pub struct TlsWriteHalf<SW, C> {
117    shared: Rc<Mutex<SharedTlsState<SW, C>>>,
118}
119
120impl<SW, C, SD> AsyncWrite for TlsWriteHalf<SW, C>
121where
122    SW: AsyncWrite,
123    C: DerefMut<Target = ConnectionCommon<SD>>,
124    SD: SideData,
125{
126    async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
127        let mut state = self.shared.lock().await;
128
129        let slice = buf.as_init();
130        let written = match state.connection.writer().write(slice) {
131            | Ok(n) => n,
132            | Err(e) => return BufResult(Err(e), buf),
133        };
134
135        if let Err(e) = state.flush().await {
136            return BufResult(Err(e), buf);
137        }
138
139        BufResult(Ok(written), buf)
140    }
141
142    async fn flush(&mut self) -> io::Result<()> {
143        let mut state = self.shared.lock().await;
144        state.connection.writer().flush()?;
145        state.flush().await?;
146        state.os_writer.flush().await
147    }
148
149    async fn shutdown(&mut self) -> io::Result<()> {
150        let mut state = self.shared.lock().await;
151        state.connection.send_close_notify();
152        state.flush().await?;
153        state.os_writer.shutdown().await
154    }
155}
156
157/// Constructs independent Read and Write halves from split OS streams and a TLS
158/// connection.
159pub fn split_tls_stream<SR, SW, C>(
160    os_reader: SR,
161    os_writer: SW,
162    connection: C,
163) -> (TlsReadHalf<SR, SW, C>, TlsWriteHalf<SW, C>) {
164    let shared_state = Rc::new(Mutex::new(SharedTlsState {
165        os_writer,
166        connection,
167        write_buf: Some(BytesMut::with_capacity(4096)),
168    }));
169
170    let read_half = TlsReadHalf {
171        os_reader,
172        shared: Rc::clone(&shared_state),
173        read_buf: Some(BytesMut::with_capacity(4096)),
174    };
175
176    let write_half = TlsWriteHalf {
177        shared: shared_state
178    };
179
180    (read_half, write_half)
181}
182
183/// Constructs independent Read and Write halves from split OS streams and a TLS
184/// connection.
185pub fn split_tls_stream_with_capacity<SR, SW, C>(
186    os_reader: SR,
187    os_writer: SW,
188    connection: C,
189    capacity: usize,
190) -> (TlsReadHalf<SR, SW, C>, TlsWriteHalf<SW, C>) {
191    let shared_state = Rc::new(Mutex::new(SharedTlsState {
192        os_writer,
193        connection,
194        write_buf: Some(BytesMut::with_capacity(capacity)),
195    }));
196
197    let read_half = TlsReadHalf {
198        os_reader,
199        shared: Rc::clone(&shared_state),
200        read_buf: Some(BytesMut::with_capacity(capacity)),
201    };
202
203    let write_half = TlsWriteHalf {
204        shared: shared_state
205    };
206
207    (read_half, write_half)
208}