em_app/
mbedtls_hyper.rs

1/* Copyright (c) Fortanix, Inc.
2 *
3 * This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6use hyper::net::{NetworkStream, SslClient, SslServer};
7use std::fmt;
8use std::io;
9use std::net::SocketAddr;
10use std::sync::{Arc, Mutex};
11use std::time::Duration;
12use std::io::{Error as IoError, ErrorKind as IoErrorKind};
13
14use mbedtls::ssl::{Config, Context};
15
16// Native TLS compatibility - to move to native tls client in the future
17#[derive(Clone)]
18pub struct TlsStream<T> {
19    context: Arc<Mutex<Context<T>>>,
20}
21
22impl<T: 'static> TlsStream<T> {
23    pub fn new(context: Arc<Mutex<Context<T>>>) -> io::Result<Self> {
24        if context.lock().unwrap().io_mut().is_none() {
25            return Err(IoError::new(IoErrorKind::InvalidInput, "Peer set in context is not of expected type"));
26        }
27
28        Ok(TlsStream {
29            context,
30        })
31    }
32}
33
34impl<T: 'static + io::Read + io::Write> io::Read for TlsStream<T>
35{
36    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
37        self.context.lock().unwrap().read(buf)
38    }
39}
40
41impl<T: 'static + io::Read + io::Write> io::Write for TlsStream<T>
42{
43    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
44        self.context.lock().unwrap().write(buf)
45    }
46
47    fn flush(&mut self) -> io::Result<()> {
48        self.context.lock().unwrap().flush()
49    }
50}
51
52impl<T: 'static> NetworkStream for TlsStream<T>
53    where T: NetworkStream
54{
55    fn peer_addr(&mut self) -> io::Result<SocketAddr> {
56        self.context.lock().unwrap().io_mut()
57            .ok_or(IoError::new(IoErrorKind::NotFound, "No peer available"))?
58            .peer_addr()
59    }
60
61    fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
62        self.context.lock().unwrap().io_mut()
63            .ok_or(IoError::new(IoErrorKind::NotFound, "No peer available"))?
64            .set_read_timeout(dur)
65    }
66
67    fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
68        self.context.lock().unwrap().io_mut()
69            .ok_or(IoError::new(IoErrorKind::NotFound, "No peer available"))?
70            .set_write_timeout(dur)
71    }
72 }
73
74
75#[derive(Clone)]
76pub struct MbedSSLServer {
77    rc_config: Arc<Config>,
78}
79
80impl MbedSSLServer {
81    pub fn new(rc_config: Arc<Config>) -> Self {
82        MbedSSLServer {
83            rc_config,
84        }
85    }
86}
87
88/// An abstraction to allow any SSL implementation to be used with server-side HttpsStreams.
89impl<T> SslServer<T> for MbedSSLServer
90    where T: NetworkStream + Send + Clone + fmt::Debug + Sync
91{
92    /// The protected stream.
93    type Stream = TlsStream<T>;
94
95    /// Wrap a server stream with SSL.
96    fn wrap_server(&self, stream: T) -> Result<Self::Stream, hyper::Error> {
97        let mut ctx = Context::new(self.rc_config.clone());
98        ctx.establish(stream, None).map_err(|e| hyper::error::Error::Ssl(e.into()))?;
99
100        Ok(TlsStream::new(Arc::new(Mutex::new(ctx))).expect("Software error creating TlsStream"))
101    }
102}
103
104#[derive(Clone)]
105pub struct MbedSSLClient {
106    rc_config: Arc<Config>,
107    verify_hostname: bool,
108
109    // This can be used when verify_hostname is set to true.
110    // It will force ssl client to send this specific SNI on all established connections disregarding any host provided by hyper.
111    override_sni: Option<String>,
112}
113
114impl MbedSSLClient {
115    #[allow(dead_code)]
116    pub fn new(rc_config: Arc<Config>, verify_hostname: bool) -> Self {
117        MbedSSLClient {
118            rc_config,
119            verify_hostname,
120            override_sni: None,
121        }
122    }
123
124    #[allow(dead_code)]
125    pub fn new_with_sni(rc_config: Arc<Config>, verify_hostname: bool, override_sni: Option<String>) -> Self {
126        MbedSSLClient {
127            rc_config,
128            verify_hostname,
129            override_sni,
130        }
131    }
132}
133
134impl<T> SslClient<T> for MbedSSLClient
135    where T: NetworkStream + Send + Clone + fmt::Debug + Sync
136{
137    type Stream = TlsStream<T>;
138
139    fn wrap_client(&self, stream: T, host: &str) -> hyper::Result<TlsStream<T>> {
140        let mut context = Context::new(self.rc_config.clone());
141
142        let verify_hostname = match self.verify_hostname {
143            true => Some(self.override_sni.as_ref().map(|v| v.as_str()).unwrap_or(host)),
144            false => None,
145        };
146
147        match context.establish(stream, verify_hostname) {
148            Ok(()) => Ok(TlsStream::new(Arc::new(Mutex::new(context))).expect("Software error creating TlsStream")),
149            Err(e) => Err(hyper::Error::Ssl(Box::new(e))),
150        }
151    }
152}
153