mbedtls/ssl/
io.rs

1/* Copyright (c) Fortanix, Inc.
2 *
3 * Licensed under the GNU General Public License, version 2 <LICENSE-GPL or
4 * https://www.gnu.org/licenses/gpl-2.0.html> or the Apache License, Version
5 * 2.0 <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0>, at your
6 * option. This file may not be copied, modified, or distributed except
7 * according to those terms. */
8//! Various I/O abstractions for use with MbedTLS's TLS sessions.
9//!
10//! If you are using `std::net::TcpStream` or any `std::io::Read` and
11//! `std::io::Write` streams, you probably don't need to look at any of this.
12//! Just pass your stream directly to `Context::establish`. If you want to use
13//! a `std::net::UdpSocket` with DTLS, take a look at `ConnectedUdpSocket`. If
14//! you are implementing your own communication types or traits, consider
15//! implementing `Io` for them. If all else fails, implement `IoCallback`.
16
17#[cfg(feature = "std")]
18use std::{
19    io::{Error as IoError, ErrorKind as IoErrorKind, Read, Result as IoResult, Write},
20    net::UdpSocket,
21    result::Result as StdResult,
22};
23
24use mbedtls_sys::types::raw_types::{c_int, c_uchar, c_void};
25use mbedtls_sys::types::size_t;
26
27use super::context::Context;
28use crate::error::Result;
29#[cfg(feature = "std")]
30use crate::error::{codes, Error};
31
/// Raw, C-ABI view of an I/O object whose callbacks match mbedtls'
/// `mbedtls_ssl_send_t` / `mbedtls_ssl_recv_t` function-pointer types.
///
/// Prefer implementing [`IoCallback`]; this module provides a blanket impl
/// that turns every `IoCallback<T>` into an `IoCallbackUnsafe<T>`.
pub trait IoCallbackUnsafe<T> {
    /// Returns the opaque pointer that is passed back to `call_recv` and
    /// `call_send` as their `user_data` argument.
    fn data_ptr(&mut self) -> *mut c_void;

    /// Receive callback: fills up to `len` bytes at `data`.
    ///
    /// # Safety
    ///
    /// `user_data` must be a pointer that the implementation knows how to
    /// interpret (the blanket impl in this module casts it back to
    /// `*mut Self`), and `data` must be valid for `len` bytes of writes.
    unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int
    where
        Self: Sized;

    /// Send callback: transmits up to `len` bytes read from `data`.
    ///
    /// # Safety
    ///
    /// Same contract as [`IoCallbackUnsafe::call_recv`], except that `data`
    /// must be valid for `len` bytes of reads.
    unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int
    where
        Self: Sized;
}
45
46/// A safe representation of the `mbedtls_ssl_send_t` and `mbedtls_ssl_recv_t`
47/// callback function pointers.
48///
49/// `T` specifies whether this abstracts an implementation of `std::io::Read`
50/// and `std::io::Write` or the more generic `Io` type. See the `Stream` and
51/// `AnyIo` types in this module.
52pub trait IoCallback<T> {
53    fn recv(&mut self, buf: &mut [u8]) -> Result<usize>;
54    fn send(&mut self, buf: &[u8]) -> Result<usize>;
55}
56
57impl<IO: IoCallback<T>, T> IoCallbackUnsafe<T> for IO {
58    unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int {
59        let len = if len > (c_int::max_value() as size_t) {
60            c_int::max_value() as size_t
61        } else {
62            len
63        };
64        match (&mut *(user_data as *mut IO)).recv(::core::slice::from_raw_parts_mut(data, len)) {
65            Ok(i) => i as c_int,
66            Err(e) => e.to_int(),
67        }
68    }
69
70    unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int {
71        let len = if len > (c_int::max_value() as size_t) {
72            c_int::max_value() as size_t
73        } else {
74            len
75        };
76        match (&mut *(user_data as *mut IO)).send(::core::slice::from_raw_parts(data, len)) {
77            Ok(i) => i as c_int,
78            Err(e) => e.to_int(),
79        }
80    }
81
82    fn data_ptr(&mut self) -> *mut c_void {
83        self as *mut IO as *mut _
84    }
85}
86
/// Marker for transports driven through the generic [`Io`] trait rather than
/// through `std::io::Read` + `std::io::Write`.
///
/// This enum has no variants; it is only used as a type parameter.
pub enum AnyIo {}

/// Marker for transports that implement both `std::io::Read` and
/// `std::io::Write`.
///
/// This enum has no variants; it is only used as a type parameter.
#[cfg(feature = "std")]
pub enum Stream {}
94
95/// Read and write bytes or packets.
96///
97/// Implementors represent a duplex socket or file descriptor that can be read
98/// from or written to.
99///
100/// You can wrap any type of `Io` with `Context::establish` to protect that
101/// communication channel with (D)TLS. That `Context` then also implements `Io`
102/// so you can use it interchangeably.
103///
104/// If you are using byte streams and are using `std`, you don't need this trait
105/// and can rely on `std::io::Read` and `std::io::Write` instead.
106pub trait Io {
107    fn recv(&mut self, buf: &mut [u8]) -> Result<usize>;
108    fn send(&mut self, buf: &[u8]) -> Result<usize>;
109}
110
111impl<IO: Io> IoCallback<AnyIo> for IO {
112    fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
113        Io::recv(self, buf)
114    }
115
116    fn send(&mut self, buf: &[u8]) -> Result<usize> {
117        Io::send(self, buf)
118    }
119}
120
/// Adapts any `std::io::Read` + `std::io::Write` byte stream to mbedtls'
/// callback interface.
///
/// Error mapping:
/// - `ErrorKind::Interrupted` is retried (the `std::io` docs describe it as
///   non-fatal) instead of being reported as a network failure;
/// - `ErrorKind::WouldBlock` becomes `SslWantRead` / `SslWantWrite`, so a
///   non-blocking stream can be resumed later;
/// - every other error becomes `NetRecvFailed` / `NetSendFailed`.
#[cfg(feature = "std")]
impl<IO: Read + Write> IoCallback<Stream> for IO {
    fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
        loop {
            match self.read(buf) {
                Ok(n) => return Ok(n),
                Err(ref e) if e.kind() == IoErrorKind::Interrupted => continue,
                Err(ref e) if e.kind() == IoErrorKind::WouldBlock => {
                    return Err(Error::from(codes::SslWantRead))
                }
                Err(_) => return Err(Error::from(codes::NetRecvFailed)),
            }
        }
    }

    fn send(&mut self, buf: &[u8]) -> Result<usize> {
        loop {
            match self.write(buf) {
                Ok(n) => return Ok(n),
                Err(ref e) if e.kind() == IoErrorKind::Interrupted => continue,
                Err(ref e) if e.kind() == IoErrorKind::WouldBlock => {
                    return Err(Error::from(codes::SslWantWrite))
                }
                Err(_) => return Err(Error::from(codes::NetSendFailed)),
            }
        }
    }
}
137
/// A [`UdpSocket`] that has been successfully `connect`ed to a single peer.
///
/// Build one with [`ConnectedUdpSocket::connect`]; the field is private so
/// the wrapper cannot be assembled from an unconnected socket elsewhere.
#[cfg(feature = "std")]
pub struct ConnectedUdpSocket {
    socket: UdpSocket,
}
145
#[cfg(feature = "std")]
impl ConnectedUdpSocket {
    /// Connects `socket` to `addr` and wraps it.
    ///
    /// # Errors
    ///
    /// If `UdpSocket::connect` fails, its error is returned together with the
    /// original socket so the caller keeps ownership of it.
    pub fn connect<A: std::net::ToSocketAddrs>(socket: UdpSocket, addr: A) -> StdResult<Self, (IoError, UdpSocket)> {
        if let Err(err) = socket.connect(addr) {
            return Err((err, socket));
        }
        Ok(Self { socket })
    }

    /// Consumes the wrapper and hands back the underlying socket.
    pub fn into_socket(self) -> UdpSocket {
        let Self { socket } = self;
        socket
    }
}
159
/// Datagram I/O over a connected UDP socket, for use with DTLS.
///
/// Both directions share the same error mapping:
/// - `ErrorKind::Interrupted` is retried;
/// - `ErrorKind::WouldBlock` becomes `SslWantRead` / `SslWantWrite`, so a
///   non-blocking socket can be resumed instead of failing the session;
/// - any other error becomes `NetRecvFailed` / `NetSendFailed`.
#[cfg(feature = "std")]
impl Io for ConnectedUdpSocket {
    fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
        loop {
            match self.socket.recv(buf) {
                Ok(n) => return Ok(n),
                Err(ref e) if e.kind() == IoErrorKind::Interrupted => continue,
                Err(ref e) if e.kind() == IoErrorKind::WouldBlock => return Err(codes::SslWantRead.into()),
                Err(_) => return Err(codes::NetRecvFailed.into()),
            }
        }
    }

    fn send(&mut self, buf: &[u8]) -> Result<usize> {
        loop {
            match self.socket.send(buf) {
                Ok(n) => return Ok(n),
                Err(ref e) if e.kind() == IoErrorKind::Interrupted => continue,
                // A full send buffer on a non-blocking socket is transient,
                // mirroring how `recv` and the `Stream` adapter treat it.
                Err(ref e) if e.kind() == IoErrorKind::WouldBlock => return Err(codes::SslWantWrite.into()),
                Err(_) => return Err(codes::NetSendFailed.into()),
            }
        }
    }
}
174
175impl<T: IoCallbackUnsafe<AnyIo>> Io for Context<T> {
176    fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
177        Context::recv(self, buf)
178    }
179
180    fn send(&mut self, buf: &[u8]) -> Result<usize> {
181        Context::send(self, buf)
182    }
183}
184
/// [`std::io::Read`] over a (D)TLS [`Context`].
///
/// Only provided when the transport is an [`IoCallbackUnsafe<Stream>`], which
/// the blanket impls in this module supply for `std::io::Read` +
/// `std::io::Write` types such as `TcpStream`. Byte-oriented `Read` therefore
/// exists for TLS contexts but not for DTLS contexts over datagram transports.
///
/// Error mapping:
/// - peer sent close_notify → `Ok(0)` (end of stream);
/// - mbedtls wants to read or write → `ErrorKind::WouldBlock`;
/// - anything else → converted with `crate::private::error_to_io_error`.
#[cfg(feature = "std")]
impl<T: IoCallbackUnsafe<Stream>> Read for Context<T> {
    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
        let err = match self.recv(buf) {
            Ok(received) => return Ok(received),
            Err(err) => err,
        };
        match err.high_level() {
            Some(codes::SslPeerCloseNotify) => Ok(0),
            Some(codes::SslWantRead) | Some(codes::SslWantWrite) => Err(IoErrorKind::WouldBlock.into()),
            _ => Err(crate::private::error_to_io_error(err)),
        }
    }
}
203
/// [`std::io::Write`] over a (D)TLS [`Context`].
///
/// Only provided when the transport is an [`IoCallbackUnsafe<Stream>`], which
/// the blanket impls in this module supply for `std::io::Read` +
/// `std::io::Write` types such as `TcpStream`. Byte-oriented `Write` therefore
/// exists for TLS contexts but not for DTLS contexts over datagram transports.
///
/// Error mapping:
/// - peer sent close_notify → `Ok(0)`;
/// - mbedtls wants to read or write → `ErrorKind::WouldBlock`;
/// - anything else → converted with `crate::private::error_to_io_error`.
#[cfg(feature = "std")]
impl<T: IoCallbackUnsafe<Stream>> Write for Context<T> {
    fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
        let err = match self.send(buf) {
            Ok(written) => return Ok(written),
            Err(err) => err,
        };
        match err.high_level() {
            Some(codes::SslPeerCloseNotify) => Ok(0),
            Some(codes::SslWantRead) | Some(codes::SslWantWrite) => Err(IoErrorKind::WouldBlock.into()),
            _ => Err(crate::private::error_to_io_error(err)),
        }
    }

    /// Always succeeds without doing anything.
    // NOTE(review): this does not flush the underlying transport `T`; if `T`
    // buffers writes (e.g. a `BufWriter`), confirm that is intended.
    fn flush(&mut self) -> IoResult<()> {
        Ok(())
    }
}
225}