1#[cfg(feature = "std")]
18use std::{
19 io::{Error as IoError, ErrorKind as IoErrorKind, Read, Result as IoResult, Write},
20 net::UdpSocket,
21 result::Result as StdResult,
22};
23
24use mbedtls_sys::types::raw_types::{c_int, c_uchar, c_void};
25use mbedtls_sys::types::size_t;
26
27use super::context::Context;
28use crate::error::Result;
29#[cfg(feature = "std")]
30use crate::error::{codes, Error};
31
/// Raw, C-ABI view of an I/O object: a pair of `extern "C"` receive/send functions plus an
/// opaque pointer that must be handed back to them as `user_data`.
///
/// A blanket implementation exists for every [`IoCallback`] implementor, so this trait normally
/// does not need to be implemented by hand.
///
/// The `Self: Sized` bounds on the two associated functions exclude them from trait objects,
/// which keeps this trait object-safe (only `data_ptr` takes a receiver).
pub trait IoCallbackUnsafe<T> {
    /// Receives up to `len` bytes into `data` from the object behind `user_data`.
    ///
    /// Returns the number of bytes received on success, or an integer error code on failure
    /// (the blanket implementation produces it with `Error::to_int`).
    ///
    /// # Safety
    ///
    /// `user_data` must come from [`data_ptr`](Self::data_ptr) on a `Self` that is still alive
    /// and not otherwise borrowed for the duration of the call, and `data` must be valid for
    /// writes of `len` bytes.
    unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int
    where
        Self: Sized;
    /// Sends up to `len` bytes from `data` through the object behind `user_data`.
    ///
    /// Returns the number of bytes sent on success, or an integer error code on failure
    /// (the blanket implementation produces it with `Error::to_int`).
    ///
    /// # Safety
    ///
    /// `user_data` must come from [`data_ptr`](Self::data_ptr) on a `Self` that is still alive
    /// and not otherwise borrowed for the duration of the call, and `data` must be valid for
    /// reads of `len` bytes.
    unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int
    where
        Self: Sized;
    /// Returns the opaque pointer to pass as `user_data` to `call_recv` / `call_send`.
    fn data_ptr(&mut self) -> *mut c_void;
}
45
/// Safe, Rust-level I/O callbacks that can be plugged into the raw [`IoCallbackUnsafe`]
/// interface through its blanket implementation.
///
/// `T` is a marker type (such as [`AnyIo`]) used only at the type level. It lets several
/// blanket implementations of this trait coexist for the same concrete type without
/// violating coherence.
pub trait IoCallback<T> {
    /// Receives data into `buf`, returning the number of bytes written into it.
    ///
    /// When called through the blanket `IoCallbackUnsafe` implementation, `buf` is at most
    /// `c_int::MAX` bytes long.
    fn recv(&mut self, buf: &mut [u8]) -> Result<usize>;
    /// Sends data from `buf`, returning the number of bytes consumed from it.
    ///
    /// When called through the blanket `IoCallbackUnsafe` implementation, `buf` is at most
    /// `c_int::MAX` bytes long.
    fn send(&mut self, buf: &[u8]) -> Result<usize>;
}
56
57impl<IO: IoCallback<T>, T> IoCallbackUnsafe<T> for IO {
58 unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int {
59 let len = if len > (c_int::max_value() as size_t) {
60 c_int::max_value() as size_t
61 } else {
62 len
63 };
64 match (&mut *(user_data as *mut IO)).recv(::core::slice::from_raw_parts_mut(data, len)) {
65 Ok(i) => i as c_int,
66 Err(e) => e.to_int(),
67 }
68 }
69
70 unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int {
71 let len = if len > (c_int::max_value() as size_t) {
72 c_int::max_value() as size_t
73 } else {
74 len
75 };
76 match (&mut *(user_data as *mut IO)).send(::core::slice::from_raw_parts(data, len)) {
77 Ok(i) => i as c_int,
78 Err(e) => e.to_int(),
79 }
80 }
81
82 fn data_ptr(&mut self) -> *mut c_void {
83 self as *mut IO as *mut _
84 }
85}
86
/// Marker type selecting the [`IoCallback`] implementation that delegates to [`Io`].
///
/// It has no variants and cannot be constructed; it exists only at the type level.
pub enum AnyIo {}
#[cfg(feature = "std")]
/// Marker type selecting the [`IoCallback`] implementation for `std::io::Read + std::io::Write`
/// streams.
///
/// It has no variants and cannot be constructed; it exists only at the type level.
pub enum Stream {}
94
/// Transport interface used with the [`AnyIo`] marker.
///
/// Every implementor automatically implements `IoCallback<AnyIo>` by plain delegation. In
/// this module it is implemented by `ConnectedUdpSocket` (with the `std` feature) and by
/// `Context<T>`.
pub trait Io {
    /// Receives data into `buf`, returning the number of bytes written into it.
    fn recv(&mut self, buf: &mut [u8]) -> Result<usize>;
    /// Sends data from `buf`, returning the number of bytes consumed from it.
    fn send(&mut self, buf: &[u8]) -> Result<usize>;
}
110
111impl<IO: Io> IoCallback<AnyIo> for IO {
112 fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
113 Io::recv(self, buf)
114 }
115
116 fn send(&mut self, buf: &[u8]) -> Result<usize> {
117 Io::send(self, buf)
118 }
119}
120
#[cfg(feature = "std")]
/// Adapts any [`Read`] + [`Write`] stream to [`IoCallback`] under the [`Stream`] marker.
///
/// Error mapping:
/// - `Interrupted` is retried, as `std::io` conventions prescribe.
/// - `WouldBlock` becomes `SslWantRead` or `SslWantWrite`, so the operation can be resumed later.
/// - Any other error becomes `NetRecvFailed` or `NetSendFailed`.
impl<IO: Read + Write> IoCallback<Stream> for IO {
    fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
        loop {
            match self.read(buf) {
                Ok(n) => return Ok(n),
                // The call was interrupted before any data was read; it is safe to retry.
                Err(ref e) if e.kind() == IoErrorKind::Interrupted => continue,
                Err(ref e) if e.kind() == IoErrorKind::WouldBlock => {
                    return Err(Error::from(codes::SslWantRead))
                }
                Err(_) => return Err(Error::from(codes::NetRecvFailed)),
            }
        }
    }

    fn send(&mut self, buf: &[u8]) -> Result<usize> {
        loop {
            match self.write(buf) {
                Ok(n) => return Ok(n),
                // The call was interrupted before any data was written; it is safe to retry.
                Err(ref e) if e.kind() == IoErrorKind::Interrupted => continue,
                Err(ref e) if e.kind() == IoErrorKind::WouldBlock => {
                    return Err(Error::from(codes::SslWantWrite))
                }
                Err(_) => return Err(Error::from(codes::NetSendFailed)),
            }
        }
    }
}
137
#[cfg(feature = "std")]
/// A [`UdpSocket`] connected to a single peer, so that plain `send`/`recv` can implement
/// the [`Io`] interface.
///
/// Construct it with `ConnectedUdpSocket::connect`; recover the socket with `into_socket`.
#[derive(Debug)]
pub struct ConnectedUdpSocket {
    /// The wrapped socket, connected by `ConnectedUdpSocket::connect`.
    socket: UdpSocket,
}
145
#[cfg(feature = "std")]
impl ConnectedUdpSocket {
    /// Connects `socket` to `addr` and wraps it.
    ///
    /// # Errors
    ///
    /// If `UdpSocket::connect` fails, returns its I/O error together with the original socket,
    /// so the caller keeps ownership of it.
    pub fn connect<A: std::net::ToSocketAddrs>(socket: UdpSocket, addr: A) -> StdResult<Self, (IoError, UdpSocket)> {
        if let Err(err) = socket.connect(addr) {
            return Err((err, socket));
        }
        Ok(Self { socket })
    }

    /// Consumes the wrapper and returns the underlying socket.
    pub fn into_socket(self) -> UdpSocket {
        let Self { socket } = self;
        socket
    }
}
159
#[cfg(feature = "std")]
/// Message I/O over the connected UDP socket.
///
/// Error mapping, symmetric for both directions:
/// - `Interrupted` is retried.
/// - `WouldBlock` (e.g. a non-blocking socket with nothing ready) becomes `SslWantRead` or
///   `SslWantWrite`, so the operation can be resumed later.
/// - Any other error becomes `NetRecvFailed` or `NetSendFailed`.
impl Io for ConnectedUdpSocket {
    fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
        loop {
            match self.socket.recv(buf) {
                Ok(i) => return Ok(i),
                // Interrupted before anything was received; retrying is safe.
                Err(ref e) if e.kind() == IoErrorKind::Interrupted => continue,
                Err(ref e) if e.kind() == IoErrorKind::WouldBlock => {
                    return Err(codes::SslWantRead.into())
                }
                Err(_) => return Err(codes::NetRecvFailed.into()),
            }
        }
    }

    fn send(&mut self, buf: &[u8]) -> Result<usize> {
        loop {
            match self.socket.send(buf) {
                Ok(i) => return Ok(i),
                // Interrupted before anything was sent; retrying is safe.
                Err(ref e) if e.kind() == IoErrorKind::Interrupted => continue,
                // Previously every send error, including `WouldBlock`, was fatal. That was
                // inconsistent with `recv` and with the `Read + Write` adapter.
                Err(ref e) if e.kind() == IoErrorKind::WouldBlock => {
                    return Err(codes::SslWantWrite.into())
                }
                Err(_) => return Err(codes::NetSendFailed.into()),
            }
        }
    }
}
174
175impl<T: IoCallbackUnsafe<AnyIo>> Io for Context<T> {
176 fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
177 Context::recv(self, buf)
178 }
179
180 fn send(&mut self, buf: &[u8]) -> Result<usize> {
181 Context::send(self, buf)
182 }
183}
184
#[cfg(feature = "std")]
impl<T: IoCallbackUnsafe<Stream>> Read for Context<T> {
    /// Reads decrypted data through the context's `recv`.
    ///
    /// Error mapping:
    /// - An `SslPeerCloseNotify` error is reported as end-of-stream (`Ok(0)`).
    /// - `SslWantRead` or `SslWantWrite` becomes `ErrorKind::WouldBlock`.
    /// - Any other error is converted with `error_to_io_error`.
    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
        let err = match self.recv(buf) {
            Ok(received) => return Ok(received),
            Err(err) => err,
        };
        match err.high_level() {
            Some(codes::SslPeerCloseNotify) => Ok(0),
            Some(codes::SslWantRead | codes::SslWantWrite) => Err(IoErrorKind::WouldBlock.into()),
            _ => Err(crate::private::error_to_io_error(err)),
        }
    }
}
203
#[cfg(feature = "std")]
impl<T: IoCallbackUnsafe<Stream>> Write for Context<T> {
    /// Writes plaintext through the context's `send`.
    ///
    /// Error mapping:
    /// - An `SslPeerCloseNotify` error is reported as `Ok(0)`. Note that `Write::write_all`
    ///   turns a zero-length write into an `ErrorKind::WriteZero` error.
    /// - `SslWantRead` or `SslWantWrite` becomes `ErrorKind::WouldBlock`.
    /// - Any other error is converted with `error_to_io_error`.
    fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
        let err = match self.send(buf) {
            Ok(written) => return Ok(written),
            Err(err) => err,
        };
        match err.high_level() {
            Some(codes::SslPeerCloseNotify) => Ok(0),
            Some(codes::SslWantRead | codes::SslWantWrite) => Err(IoErrorKind::WouldBlock.into()),
            _ => Err(crate::private::error_to_io_error(err)),
        }
    }

    /// Does nothing and always succeeds.
    ///
    /// NOTE(review): if the context can retain pending output after an `SslWantWrite`,
    /// flushing it here may be required. Confirm against `Context`.
    fn flush(&mut self) -> IoResult<()> {
        Ok(())
    }
}