1use std::{
2 io::{
3 self,
4 Write as _,
5 },
6 ops::DerefMut,
7 rc::Rc,
8};
9
10use compio_buf::{
11 BufResult, IntoInner as _, IoBuf, IoBufMut, bytes::BytesMut
12};
13use compio_io::{
14 AsyncRead,
15 AsyncWrite,
16};
17use rustls::{
18 ConnectionCommon,
19 SideData,
20};
21use xiaoyong_value::unsync::async_mutex::Mutex;
22
23use crate::stream::util::{
24 flush_tls_writes,
25 process_tls_reads,
26 read_plaintext,
27};
28
29pub struct SharedTlsState<SW, C> {
30 os_writer: SW,
31 connection: C,
32 write_buf: Option<BytesMut>,
33}
34
35impl<SW, C, SD> SharedTlsState<SW, C>
36where
37 SW: AsyncWrite,
38 C: DerefMut<Target = ConnectionCommon<SD>>,
39 SD: SideData,
40{
41 async fn flush(&mut self) -> io::Result<()> {
42 flush_tls_writes(&mut self.connection, &mut self.os_writer, &mut self.write_buf).await
43 }
44}
45
46pub struct TlsReadHalf<SR, SW, C> {
47 os_reader: SR,
48 shared: Rc<Mutex<SharedTlsState<SW, C>>>,
49 read_buf: Option<BytesMut>,
50}
51
52impl<SR, SW, C, SD> AsyncRead for TlsReadHalf<SR, SW, C>
53where
54 SR: AsyncRead,
55 SW: AsyncWrite,
56 C: DerefMut<Target = ConnectionCommon<SD>>,
57 SD: SideData,
58{
59 async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
60 loop {
61 {
62 let mut state = self.shared.lock().await;
64
65 buf = match read_plaintext(&mut state.connection, buf) {
66 | Ok(res) => return res,
67 | Err(b) => b,
68 };
69
70 if state.connection.wants_write() {
72 if let Err(e) = state.flush().await {
73 return BufResult(Err(e), buf);
74 }
75 }
76
77 if !state.connection.wants_read() {
80 return BufResult(Ok(0), buf);
81 }
82 }
83
84 let mut rbuf = self.read_buf.take().unwrap_or_else(|| BytesMut::with_capacity(4096));
86 if rbuf.buf_len() == rbuf.buf_capacity() {
87 rbuf.reserve(4096);
88 }
89
90 let init_len = rbuf.buf_len();
91 let BufResult(res, slice) = self.os_reader.read(rbuf.slice(init_len..)).await;
92 let mut b = slice.into_inner();
93
94 match res {
95 | Ok(0) => {
96 self.read_buf = Some(b);
97 return BufResult(Err(io::Error::from(io::ErrorKind::UnexpectedEof)), buf);
98 },
99 | Ok(n) => {
100 unsafe { b.set_len(init_len + n) };
101 },
102 | Err(e) => {
103 self.read_buf = Some(b);
104 return BufResult(Err(e), buf);
105 },
106 };
107
108 let mut state = self.shared.lock().await;
110 if let Err(e) = process_tls_reads(&mut state.connection, b, &mut self.read_buf) {
111 return BufResult(Err(e), buf);
112 }
113 }
114 }
115}
116
117pub struct TlsWriteHalf<SW, C> {
118 shared: Rc<Mutex<SharedTlsState<SW, C>>>,
119}
120
121impl<SW, C, SD> AsyncWrite for TlsWriteHalf<SW, C>
122where
123 SW: AsyncWrite,
124 C: DerefMut<Target = ConnectionCommon<SD>>,
125 SD: SideData,
126{
127 async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
128 let mut state = self.shared.lock().await;
129
130 let slice = buf.as_init();
131 let written = match state.connection.writer().write(slice) {
132 | Ok(n) => n,
133 | Err(e) => return BufResult(Err(e), buf),
134 };
135
136 if let Err(e) = state.flush().await {
137 return BufResult(Err(e), buf);
138 }
139
140 BufResult(Ok(written), buf)
141 }
142
143 async fn flush(&mut self) -> io::Result<()> {
144 let mut state = self.shared.lock().await;
145 state.connection.writer().flush()?;
146 state.flush().await?;
147 state.os_writer.flush().await
148 }
149
150 async fn shutdown(&mut self) -> io::Result<()> {
151 let mut state = self.shared.lock().await;
152 state.connection.send_close_notify();
153 state.flush().await?;
154 state.os_writer.shutdown().await
155 }
156}
157
158pub fn split_tls_stream<SR, SW, C>(
161 os_reader: SR,
162 os_writer: SW,
163 connection: C,
164) -> (TlsReadHalf<SR, SW, C>, TlsWriteHalf<SW, C>) {
165 let shared_state = Rc::new(Mutex::new(SharedTlsState {
166 os_writer,
167 connection,
168 write_buf: Some(BytesMut::with_capacity(4096)),
169 }));
170
171 let read_half = TlsReadHalf {
172 os_reader,
173 shared: Rc::clone(&shared_state),
174 read_buf: Some(BytesMut::with_capacity(4096)),
175 };
176
177 let write_half = TlsWriteHalf {
178 shared: shared_state
179 };
180
181 (read_half, write_half)
182}
183
184pub fn split_tls_stream_with_capacity<SR, SW, C>(
187 os_reader: SR,
188 os_writer: SW,
189 connection: C,
190 capacity: usize,
191) -> (TlsReadHalf<SR, SW, C>, TlsWriteHalf<SW, C>) {
192 let shared_state = Rc::new(Mutex::new(SharedTlsState {
193 os_writer,
194 connection,
195 write_buf: Some(BytesMut::with_capacity(capacity)),
196 }));
197
198 let read_half = TlsReadHalf {
199 os_reader,
200 shared: Rc::clone(&shared_state),
201 read_buf: Some(BytesMut::with_capacity(capacity)),
202 };
203
204 let write_half = TlsWriteHalf {
205 shared: shared_state
206 };
207
208 (read_half, write_half)
209}