monoio_native_tls/
stream.rs1use std::io::{self, Read, Write};
2
3use monoio::{
4 buf::{IoBuf, IoBufMut, IoVecBuf, IoVecBufMut, RawBuf},
5 io::{AsyncReadRent, AsyncWriteRent, Split},
6 BufResult,
7};
8
9use crate::utils::{Buffers, IOWrapper};
10
/// An async TLS stream built on a synchronous `native_tls` session.
///
/// - `tls` is a `native_tls::TlsStream` whose transport is a `Buffers` value.
/// - `io` wraps the caller's I/O object `S`. It is driven to pull more bytes
///   when `tls` reports `WouldBlock` on read (`do_read_io`), and to push out
///   bytes after `tls` writes, flushes or shuts down (`do_write_io`).
///
/// NOTE(review): presumably `Buffers` is shared with `io` so that
/// `do_read_io`/`do_write_io` move data between it and `S`. Confirm this in
/// `crate::utils`.
#[derive(Debug)]
pub struct TlsStream<S> {
    // Synchronous TLS session operating over the `Buffers` transport.
    tls: native_tls::TlsStream<Buffers>,
    // Wrapper around the user's I/O object; used to feed and drain `tls`.
    io: IOWrapper<S>,
}
23
24impl<S> TlsStream<S> {
25 pub(crate) fn new(tls_stream: native_tls::TlsStream<Buffers>, io: IOWrapper<S>) -> Self {
26 Self {
27 tls: tls_stream,
28 io,
29 }
30 }
31
32 pub fn into_inner(self) -> S {
33 self.io.into_parts().0
34 }
35
36 #[cfg(feature = "alpn")]
37 pub fn alpn_protocol(&self) -> Option<Vec<u8>> {
38 self.tls.negotiated_alpn().ok().flatten()
39 }
40}
41
// SAFETY(review): `Split` is an `unsafe` marker trait. This impl asserts that a
// `TlsStream<S>` may be split into read/write halves whenever `S` can be.
// Both `read` and `write` in this file mutate the same `tls` session and `io`
// wrapper, and both `.await` on `io` in the middle of doing so. Soundness
// therefore rests on monoio's split semantics and on `IOWrapper` keeping its
// read and write paths independent across those `.await` points. That cannot
// be verified from this file; confirm.
unsafe impl<S: Split> Split for TlsStream<S> {}
43
44impl<S: AsyncReadRent> AsyncReadRent for TlsStream<S> {
45 #[allow(clippy::await_holding_refcell_ref)]
46 async fn read<T: IoBufMut>(&mut self, mut buf: T) -> BufResult<usize, T> {
47 let slice = unsafe { std::slice::from_raw_parts_mut(buf.write_ptr(), buf.bytes_total()) };
48
49 loop {
50 match self.tls.read(slice) {
52 Ok(n) => {
53 unsafe { buf.set_init(n) };
54 return (Ok(n), buf);
55 }
56 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => (),
58 Err(e) => {
59 return (Err(e), buf);
60 }
61 }
62
63 match unsafe { self.io.do_read_io() }.await {
65 Ok(0) => {
66 return (Ok(0), buf);
67 }
68 Ok(_) => (),
69 Err(e) => {
70 return (Err(e), buf);
71 }
72 };
73 }
74 }
75
76 async fn readv<T: IoVecBufMut>(&mut self, mut buf: T) -> BufResult<usize, T> {
77 let n = match unsafe { RawBuf::new_from_iovec_mut(&mut buf) } {
78 Some(raw_buf) => self.read(raw_buf).await.0,
79 None => Ok(0),
80 };
81 if let Ok(n) = n {
82 unsafe { buf.set_init(n) };
83 }
84 (n, buf)
85 }
86}
87
88impl<S: AsyncWriteRent> AsyncWriteRent for TlsStream<S> {
89 #[allow(clippy::await_holding_refcell_ref)]
90 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
91 let slice = unsafe { std::slice::from_raw_parts(buf.read_ptr(), buf.bytes_init()) };
93
94 loop {
95 let maybe_n = match self.tls.write(slice) {
97 Ok(n) => Some(n),
98 Err(e) if e.kind() == io::ErrorKind::WouldBlock => None,
99 Err(e) => return (Err(e), buf),
100 };
101
102 if let Err(e) = unsafe { self.io.do_write_io() }.await {
104 return (Err(e), buf);
105 }
106
107 if let Some(n) = maybe_n {
108 return (Ok(n), buf);
109 }
110 }
111 }
112
113 async fn writev<T: IoVecBuf>(&mut self, buf_vec: T) -> BufResult<usize, T> {
115 let n = match unsafe { RawBuf::new_from_iovec(&buf_vec) } {
116 Some(raw_buf) => self.write(raw_buf).await.0,
117 None => Ok(0),
118 };
119 (n, buf_vec)
120 }
121
122 #[allow(clippy::await_holding_refcell_ref)]
123 async fn flush(&mut self) -> io::Result<()> {
124 loop {
125 match self.tls.flush() {
126 Ok(_) => {
127 unsafe { self.io.do_write_io() }.await?;
128 return Ok(());
129 }
130 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
131 unsafe { self.io.do_write_io() }.await?;
132 }
133 Err(e) => {
134 return Err(e);
135 }
136 }
137 }
138 }
139
140 async fn shutdown(&mut self) -> io::Result<()> {
141 self.tls.shutdown()?;
142 unsafe { self.io.do_write_io() }.await?;
143 Ok(())
144 }
145}