1use hyper::net::{NetworkStream, SslClient, SslServer};
7use std::fmt;
8use std::io;
9use std::net::SocketAddr;
10use std::sync::{Arc, Mutex};
11use std::time::Duration;
12use std::io::{Error as IoError, ErrorKind as IoErrorKind};
13
14use mbedtls::ssl::{Config, Context};
15
16#[derive(Clone)]
18pub struct TlsStream<T> {
19 context: Arc<Mutex<Context<T>>>,
20}
21
22impl<T: 'static> TlsStream<T> {
23 pub fn new(context: Arc<Mutex<Context<T>>>) -> io::Result<Self> {
24 if context.lock().unwrap().io_mut().is_none() {
25 return Err(IoError::new(IoErrorKind::InvalidInput, "Peer set in context is not of expected type"));
26 }
27
28 Ok(TlsStream {
29 context,
30 })
31 }
32}
33
34impl<T: 'static + io::Read + io::Write> io::Read for TlsStream<T>
35{
36 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
37 self.context.lock().unwrap().read(buf)
38 }
39}
40
41impl<T: 'static + io::Read + io::Write> io::Write for TlsStream<T>
42{
43 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
44 self.context.lock().unwrap().write(buf)
45 }
46
47 fn flush(&mut self) -> io::Result<()> {
48 self.context.lock().unwrap().flush()
49 }
50}
51
52impl<T: 'static> NetworkStream for TlsStream<T>
53 where T: NetworkStream
54{
55 fn peer_addr(&mut self) -> io::Result<SocketAddr> {
56 self.context.lock().unwrap().io_mut()
57 .ok_or(IoError::new(IoErrorKind::NotFound, "No peer available"))?
58 .peer_addr()
59 }
60
61 fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
62 self.context.lock().unwrap().io_mut()
63 .ok_or(IoError::new(IoErrorKind::NotFound, "No peer available"))?
64 .set_read_timeout(dur)
65 }
66
67 fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
68 self.context.lock().unwrap().io_mut()
69 .ok_or(IoError::new(IoErrorKind::NotFound, "No peer available"))?
70 .set_write_timeout(dur)
71 }
72 }
73
74
75#[derive(Clone)]
76pub struct MbedSSLServer {
77 rc_config: Arc<Config>,
78}
79
80impl MbedSSLServer {
81 pub fn new(rc_config: Arc<Config>) -> Self {
82 MbedSSLServer {
83 rc_config,
84 }
85 }
86}
87
88impl<T> SslServer<T> for MbedSSLServer
90 where T: NetworkStream + Send + Clone + fmt::Debug + Sync
91{
92 type Stream = TlsStream<T>;
94
95 fn wrap_server(&self, stream: T) -> Result<Self::Stream, hyper::Error> {
97 let mut ctx = Context::new(self.rc_config.clone());
98 ctx.establish(stream, None).map_err(|e| hyper::error::Error::Ssl(e.into()))?;
99
100 Ok(TlsStream::new(Arc::new(Mutex::new(ctx))).expect("Software error creating TlsStream"))
101 }
102}
103
104#[derive(Clone)]
105pub struct MbedSSLClient {
106 rc_config: Arc<Config>,
107 verify_hostname: bool,
108
109 override_sni: Option<String>,
112}
113
114impl MbedSSLClient {
115 #[allow(dead_code)]
116 pub fn new(rc_config: Arc<Config>, verify_hostname: bool) -> Self {
117 MbedSSLClient {
118 rc_config,
119 verify_hostname,
120 override_sni: None,
121 }
122 }
123
124 #[allow(dead_code)]
125 pub fn new_with_sni(rc_config: Arc<Config>, verify_hostname: bool, override_sni: Option<String>) -> Self {
126 MbedSSLClient {
127 rc_config,
128 verify_hostname,
129 override_sni,
130 }
131 }
132}
133
134impl<T> SslClient<T> for MbedSSLClient
135 where T: NetworkStream + Send + Clone + fmt::Debug + Sync
136{
137 type Stream = TlsStream<T>;
138
139 fn wrap_client(&self, stream: T, host: &str) -> hyper::Result<TlsStream<T>> {
140 let mut context = Context::new(self.rc_config.clone());
141
142 let verify_hostname = match self.verify_hostname {
143 true => Some(self.override_sni.as_ref().map(|v| v.as_str()).unwrap_or(host)),
144 false => None,
145 };
146
147 match context.establish(stream, verify_hostname) {
148 Ok(()) => Ok(TlsStream::new(Arc::new(Mutex::new(context))).expect("Software error creating TlsStream")),
149 Err(e) => Err(hyper::Error::Ssl(Box::new(e))),
150 }
151 }
152}
153