1use std::{
2 io::{
3 self,
4 Write as _,
5 },
6 ops::DerefMut,
7 rc::Rc,
8};
9
10use compio_buf::{
11 BufResult,
12 IoBuf,
13 IoBufMut,
14 bytes::BytesMut,
15};
16use compio_io::{
17 AsyncRead,
18 AsyncWrite,
19};
20use rustls::{
21 ConnectionCommon,
22 SideData,
23};
24use xiaoyong_value::unsync::async_mutex::Mutex;
25
26use crate::stream::util::{
27 flush_tls_writes,
28 process_tls_reads,
29 read_plaintext,
30};
31
32pub struct SharedTlsState<SW, C> {
33 os_writer: SW,
34 connection: C,
35 write_buf: Option<BytesMut>,
36}
37
38impl<SW, C, SD> SharedTlsState<SW, C>
39where
40 SW: AsyncWrite,
41 C: DerefMut<Target = ConnectionCommon<SD>>,
42 SD: SideData,
43{
44 async fn flush(&mut self) -> io::Result<()> {
45 flush_tls_writes(&mut self.connection, &mut self.os_writer, &mut self.write_buf).await
46 }
47}
48
49pub struct TlsReadHalf<SR, SW, C> {
50 os_reader: SR,
51 shared: Rc<Mutex<SharedTlsState<SW, C>>>,
52 read_buf: Option<BytesMut>,
53}
54
55impl<SR, SW, C, SD> AsyncRead for TlsReadHalf<SR, SW, C>
56where
57 SR: AsyncRead,
58 SW: AsyncWrite,
59 C: DerefMut<Target = ConnectionCommon<SD>>,
60 SD: SideData,
61{
62 async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
63 loop {
64 {
65 let mut state = self.shared.lock().await;
67
68 buf = match read_plaintext(&mut state.connection, buf) {
69 | Ok(res) => return res,
70 | Err(b) => b,
71 };
72
73 if state.connection.wants_write() {
75 if let Err(e) = state.flush().await {
76 return BufResult(Err(e), buf);
77 }
78 }
79
80 if !state.connection.wants_read() {
83 return BufResult(Ok(0), buf);
84 }
85 }
86
87 let mut rbuf = self.read_buf.take().unwrap_or_else(|| BytesMut::with_capacity(4096));
89 if rbuf.buf_len() == rbuf.buf_capacity() {
90 rbuf.reserve(4096);
91 }
92
93 let BufResult(res, b) = self.os_reader.read(rbuf).await;
94
95 match res {
96 | Ok(0) => {
97 self.read_buf = Some(b);
98 return BufResult(Err(io::Error::from(io::ErrorKind::UnexpectedEof)), buf);
99 },
100 | Ok(_) => {},
101 | Err(e) => {
102 self.read_buf = Some(b);
103 return BufResult(Err(e), buf);
104 },
105 };
106
107 let mut state = self.shared.lock().await;
109 if let Err(e) = process_tls_reads(&mut state.connection, b, &mut self.read_buf) {
110 return BufResult(Err(e), buf);
111 }
112 }
113 }
114}
115
116pub struct TlsWriteHalf<SW, C> {
117 shared: Rc<Mutex<SharedTlsState<SW, C>>>,
118}
119
120impl<SW, C, SD> AsyncWrite for TlsWriteHalf<SW, C>
121where
122 SW: AsyncWrite,
123 C: DerefMut<Target = ConnectionCommon<SD>>,
124 SD: SideData,
125{
126 async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
127 let mut state = self.shared.lock().await;
128
129 let slice = buf.as_init();
130 let written = match state.connection.writer().write(slice) {
131 | Ok(n) => n,
132 | Err(e) => return BufResult(Err(e), buf),
133 };
134
135 if let Err(e) = state.flush().await {
136 return BufResult(Err(e), buf);
137 }
138
139 BufResult(Ok(written), buf)
140 }
141
142 async fn flush(&mut self) -> io::Result<()> {
143 let mut state = self.shared.lock().await;
144 state.connection.writer().flush()?;
145 state.flush().await?;
146 state.os_writer.flush().await
147 }
148
149 async fn shutdown(&mut self) -> io::Result<()> {
150 let mut state = self.shared.lock().await;
151 state.connection.send_close_notify();
152 state.flush().await?;
153 state.os_writer.shutdown().await
154 }
155}
156
157pub fn split_tls_stream<SR, SW, C>(
160 os_reader: SR,
161 os_writer: SW,
162 connection: C,
163) -> (TlsReadHalf<SR, SW, C>, TlsWriteHalf<SW, C>) {
164 let shared_state = Rc::new(Mutex::new(SharedTlsState {
165 os_writer,
166 connection,
167 write_buf: Some(BytesMut::with_capacity(4096)),
168 }));
169
170 let read_half = TlsReadHalf {
171 os_reader,
172 shared: Rc::clone(&shared_state),
173 read_buf: Some(BytesMut::with_capacity(4096)),
174 };
175
176 let write_half = TlsWriteHalf {
177 shared: shared_state
178 };
179
180 (read_half, write_half)
181}
182
183pub fn split_tls_stream_with_capacity<SR, SW, C>(
186 os_reader: SR,
187 os_writer: SW,
188 connection: C,
189 capacity: usize,
190) -> (TlsReadHalf<SR, SW, C>, TlsWriteHalf<SW, C>) {
191 let shared_state = Rc::new(Mutex::new(SharedTlsState {
192 os_writer,
193 connection,
194 write_buf: Some(BytesMut::with_capacity(capacity)),
195 }));
196
197 let read_half = TlsReadHalf {
198 os_reader,
199 shared: Rc::clone(&shared_state),
200 read_buf: Some(BytesMut::with_capacity(capacity)),
201 };
202
203 let write_half = TlsWriteHalf {
204 shared: shared_state
205 };
206
207 (read_half, write_half)
208}