1use std::io::{self, ErrorKind, Write};
7use std::result::Result;
8
9use compio::BufResult;
10use compio::buf::{IoBuf, IoBufMut};
11use compio::io::compat::SyncStream;
12use compio::io::{AsyncRead, AsyncWrite};
13use openssl::error::ErrorStack;
14use openssl::ssl::{self, ErrorCode, ShutdownResult, ShutdownState, Ssl, SslRef};
15
16#[cfg(test)]
17mod test;
18
19#[derive(Debug)]
21pub struct SslStream<S> {
22 stream: ssl::SslStream<SyncStream<S>>,
23}
24
25impl<S: AsyncRead + AsyncWrite> SslStream<S> {
26 pub fn new(ssl: Ssl, stream: S) -> Result<SslStream<S>, ErrorStack> {
30 let stream = ssl::SslStream::new(ssl, SyncStream::new(stream))?;
31 Ok(SslStream { stream })
32 }
33
34 #[inline(always)]
40 pub fn get_mut(&mut self) -> &mut S {
41 self.stream.get_mut().get_mut()
42 }
43
44 #[inline(always)]
46 pub fn get_ref(&self) -> &S {
47 self.stream.get_ref().get_ref()
48 }
49
50 #[inline(always)]
52 pub fn ssl(&self) -> &SslRef {
53 self.stream.ssl()
54 }
55
56 pub async fn accept(&mut self) -> io::Result<()> {
60 self.ssl_async_do(|s| s.accept()).await
61 }
62
63 pub async fn connect(&mut self) -> io::Result<()> {
67 self.ssl_async_do(|s| s.connect()).await
68 }
69
70 #[cfg(any(ossl111, libressl340))]
78 pub async fn read_realy_data(&mut self, buf: &mut [u8]) -> io::Result<usize> {
79 self.ssl_async_do(|s| s.read_early_data(buf)).await
80 }
81
82 #[cfg(any(ossl111, libressl340))]
88 pub async fn write_realy_data(&mut self, buf: &[u8]) -> io::Result<usize> {
89 self.ssl_async_do(|s| s.write_early_data(buf)).await
90 }
91
92 pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
96 self.ssl_async_do(|s| s.ssl_peek(buf)).await
97 }
98
99 #[inline(always)]
101 pub fn get_shutdown(&mut self) -> ShutdownState {
102 self.stream.get_shutdown()
103 }
104
105 #[inline(always)]
109 pub fn set_shutdown(&mut self, state: ShutdownState) {
110 self.stream.set_shutdown(state)
111 }
112
113 #[inline(always)]
124 #[cfg(ossl111)]
125 pub async fn stateless(&mut self) -> Result<bool, ErrorStack> {
126 self.stream.stateless()
127 }
128
129 async fn ssl_async_do<R, F>(&mut self, mut f: F) -> io::Result<R>
130 where
131 F: FnMut(&mut ssl::SslStream<SyncStream<S>>) -> Result<R, ssl::Error>,
132 {
133 loop {
134 match f(&mut self.stream) {
135 Ok(n) => return Ok(n),
136 Err(e) => match e.code() {
137 ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => {
138 if self.stream.get_mut().flush_write_buf().await? == 0 {
139 self.stream.get_mut().fill_read_buf().await?;
140 }
141 }
142 _ => return Err(ssl_err_into_io(e)),
143 },
144 }
145 }
146 }
147}
148
149impl<S> From<ssl::SslStream<SyncStream<S>>> for SslStream<S> {
150 fn from(value: ssl::SslStream<SyncStream<S>>) -> Self {
151 SslStream { stream: value }
152 }
153}
154
155#[inline]
156fn ssl_err_into_io(err: openssl::ssl::Error) -> io::Error {
157 err.into_io_error().unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))
158}
159
160impl<S: AsyncRead> AsyncRead for SslStream<S> {
161 async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
162 let read_buf = buf.as_mut_slice();
163 loop {
164 let ret = self.stream.ssl_read_uninit(read_buf);
165 match ret {
166 Ok(n) => {
167 unsafe { buf.set_buf_init(n) };
169 return BufResult(Ok(n), buf);
170 }
171 Err(e) if e.code() == ErrorCode::ZERO_RETURN => {
172 return BufResult(Ok(0), buf);
173 }
174 Err(e) if e.code() == ErrorCode::WANT_READ => {
175 match self.stream.get_mut().fill_read_buf().await {
176 Ok(_) => continue,
177 Err(e) => return BufResult(Err(e), buf),
178 }
179 }
180 Err(e) if e.code() == ErrorCode::SYSCALL && e.io_error().is_none() => {}
181 Err(e) => return BufResult(Err(ssl_err_into_io(e)), buf),
182 }
183 }
184 }
185
186 }
188
189impl<S: AsyncWrite + AsyncRead> AsyncWrite for SslStream<S> {
191 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
192 let slice = buf.as_slice();
193 loop {
194 let ret = self.stream.ssl_write(slice);
195 match ret {
196 Ok(n) => {
197 let ret = self.stream.get_mut().flush_write_buf().await;
198 return BufResult(ret.map(|_| n), buf);
199 }
200 Err(e) if e.code() == ErrorCode::WANT_WRITE => {
201 match self.stream.get_mut().flush_write_buf().await {
202 Ok(_) => continue,
203 Err(e) => return BufResult(Err(e), buf),
204 }
205 }
206 Err(e) => return BufResult(Err(ssl_err_into_io(e)), buf),
207 }
208 }
209 }
210
211 async fn flush(&mut self) -> io::Result<()> {
214 loop {
215 match self.stream.flush() {
216 Ok(_) => {
217 self.stream.get_mut().flush_write_buf().await?;
218 return Ok(());
219 }
220 Err(e) if e.kind() == ErrorKind::WouldBlock => {
221 self.stream.get_mut().flush_write_buf().await?;
222 }
223 e => return e,
224 }
225 }
226 }
227
228 async fn shutdown(&mut self) -> io::Result<()> {
229 loop {
230 let ret = self.stream.shutdown();
231 match ret {
232 Ok(ShutdownResult::Sent) => {
233 self.stream.get_mut().flush_write_buf().await?;
234 }
235 Ok(ShutdownResult::Received) => {
236 break;
237 }
238 Err(e) if e.code() == ErrorCode::WANT_WRITE => {
239 self.stream.get_mut().flush_write_buf().await?;
240 }
241 Err(e) if e.code() == ErrorCode::WANT_READ => {
242 self.stream.get_mut().fill_read_buf().await?;
243 }
244 Err(e) if e.code() == ErrorCode::SYSCALL && e.io_error().is_none() => {
245 break;
246 }
247 Err(e) => return Err(ssl_err_into_io(e)),
248 }
249 }
250 self.stream.get_mut().get_mut().shutdown().await
251 }
252}