use std::convert::AsRef;
use std::net::TcpStream;
use std::path::Path;
use std::error;
use std::fmt;
use std::str;
use std::io;
use http::{HttpScheme, ALPN_PROTOCOLS};
use super::{ClientStream, write_preface, HttpConnect, HttpConnectError};
use openssl::ssl::{Ssl, SslStream, SslContext};
use openssl::ssl::{SSL_VERIFY_PEER, SSL_VERIFY_FAIL_IF_NO_PEER_CERT};
use openssl::ssl::SSL_OP_NO_COMPRESSION;
use openssl::ssl::error::SslError;
use openssl::ssl::SslMethod;
/// An `HttpConnect` implementation that opens a TCP connection to `host` on
/// port 443, wraps it in an OpenSSL `SslStream`, and only succeeds if the
/// server selects one of `ALPN_PROTOCOLS` (checked via the NPN selection).
pub struct TlsConnector<'a, 'ctx> {
    /// Host to connect to; also passed to `Ssl::set_hostname` before the
    /// handshake (presumably for SNI / certificate hostname checks — confirm
    /// against the openssl version in use).
    pub host: &'a str,
    /// Where the `SslContext` for the connection comes from: either a
    /// caller-supplied context or a CA file path to build a default one from.
    context: Http2TlsContext<'ctx>,
}
/// Source of the OpenSSL context a `TlsConnector` uses when connecting.
enum Http2TlsContext<'a> {
    /// An already-configured context supplied by the caller
    /// (`TlsConnector::with_context`); used as-is.
    Wrapped(&'a SslContext),
    /// Path to a CA certificate file; `connect` builds a fresh context from it
    /// with `TlsConnector::build_default_context`.
    CertPath(&'a Path),
}
/// Errors produced while establishing a TLS connection that speaks HTTP/2.
pub enum TlsConnectError {
    /// An I/O error, e.g. from opening the TCP connection.
    IoError(io::Error),
    /// An error reported by the OpenSSL bindings (e.g. while building the
    /// context or setting up the `SslStream`).
    SslError(SslError),
    /// The TLS stream was established, but the server either selected no
    /// protocol or one not listed in `ALPN_PROTOCOLS`. The stream is handed
    /// back, presumably so the caller can still use it (e.g. for a non-HTTP/2
    /// fallback).
    Http2NotSupported(SslStream<TcpStream>),
}
impl fmt::Debug for TlsConnectError {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
try!(write!(fmt, "TlsConnectError::{}", match *self {
TlsConnectError::IoError(_) => "IoError",
TlsConnectError::SslError(_) => "SslError",
TlsConnectError::Http2NotSupported(_) => "Http2NotSupported",
}));
match *self {
TlsConnectError::IoError(ref err) => try!(write!(fmt, "({:?})", err)),
TlsConnectError::SslError(ref err) => try!(write!(fmt, "({:?})", err)),
TlsConnectError::Http2NotSupported(_) => try!(write!(fmt, "(...)")),
};
Ok(())
}
}
impl fmt::Display for TlsConnectError {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
write!(fmt, "TLS HTTP/2 connect error: {}", (self as &error::Error).description())
}
}
impl error::Error for TlsConnectError {
fn description(&self) -> &str {
match *self {
TlsConnectError::IoError(ref err) => err.description(),
TlsConnectError::SslError(ref err) => err.description(),
TlsConnectError::Http2NotSupported(_) => "HTTP/2 not supported by the server",
}
}
fn cause(&self) -> Option<&error::Error> {
match *self {
TlsConnectError::IoError(ref err) => Some(err),
TlsConnectError::SslError(ref err) => Some(err),
TlsConnectError::Http2NotSupported(_) => None,
}
}
}
/// Lets `try!` lift I/O failures (such as a failed `TcpStream::connect`)
/// into `TlsConnectError::IoError`.
impl From<io::Error> for TlsConnectError {
    fn from(cause: io::Error) -> Self {
        TlsConnectError::IoError(cause)
    }
}
/// Lets `try!` lift OpenSSL failures into `TlsConnectError::SslError`.
impl From<SslError> for TlsConnectError {
    fn from(cause: SslError) -> Self {
        TlsConnectError::SslError(cause)
    }
}
/// `TlsConnectError` is the `HttpConnect::Err` type of `TlsConnector`; this
/// empty impl overrides nothing (presumably `HttpConnectError` only has
/// default/marker requirements — confirm in the parent module).
impl HttpConnectError for TlsConnectError {}
impl<'a, 'ctx> TlsConnector<'a, 'ctx> {
    /// Creates a connector for `host` that builds its OpenSSL context from the
    /// CA certificate file at `ca_file_path` (via `build_default_context`)
    /// when `connect` is called.
    ///
    /// The path is only borrowed for `'ctx`. `P` may be unsized, so a string
    /// literal works directly (`TlsConnector::new("example.com", "ca.pem")`),
    /// as do `&PathBuf`, `&Path`, `&String`, ...
    pub fn new<P: AsRef<Path> + ?Sized>(host: &'a str, ca_file_path: &'ctx P) -> TlsConnector<'a, 'ctx> {
        TlsConnector {
            host: host,
            context: Http2TlsContext::CertPath(ca_file_path.as_ref()),
        }
    }

    /// Creates a connector for `host` that uses the caller-supplied `context`
    /// unchanged. The caller is responsible for configuring verification and
    /// protocol negotiation on it; `connect` still requires the server to
    /// select one of `ALPN_PROTOCOLS` via NPN.
    pub fn with_context(host: &'a str, context: &'ctx SslContext) -> TlsConnector<'a, 'ctx> {
        TlsConnector {
            host: host,
            context: Http2TlsContext::Wrapped(context),
        }
    }

    /// Builds the default client `SslContext`:
    ///
    /// - created with `SslMethod::Tlsv1_2`;
    /// - peer certificate verification enabled;
    /// - trusted CA certificates loaded from `ca_file_path`;
    /// - TLS compression disabled (`SSL_OP_NO_COMPRESSION`);
    /// - `ALPN_PROTOCOLS` advertised through NPN.
    ///
    /// # Errors
    ///
    /// Returns an error if OpenSSL cannot create the context or load the CA file.
    pub fn build_default_context(ca_file_path: &Path) -> Result<SslContext, TlsConnectError> {
        let mut context = try!(SslContext::new(SslMethod::Tlsv1_2));
        // NOTE(review): per the OpenSSL `SSL_CTX_set_verify` docs,
        // SSL_VERIFY_FAIL_IF_NO_PEER_CERT only has an effect in server mode;
        // for this client context SSL_VERIFY_PEER is what enforces verification.
        context.set_verify(SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, None);
        try!(context.set_CA_file(ca_file_path));
        // Disabled compression: HTTP/2 (RFC 7540 §9.2.1) forbids TLS compression.
        context.set_options(SSL_OP_NO_COMPRESSION);
        // Despite the constant's name, the protocol list is offered via NPN, not
        // ALPN; `connect` correspondingly inspects the NPN selection.
        context.set_npn_protocols(ALPN_PROTOCOLS);
        Ok(context)
    }
}
impl<'a, 'ctx> HttpConnect for TlsConnector<'a, 'ctx> {
    type Stream = SslStream<TcpStream>;
    type Err = TlsConnectError;

    /// Connects to `self.host` on port 443 over TLS and prepares the stream
    /// for HTTP/2.
    ///
    /// Steps, in order:
    ///
    /// 1. Open the TCP socket.
    /// 2. Obtain an `Ssl` from the configured context (building one from the CA
    ///    path if needed) and set `self.host` as its hostname.
    /// 3. Establish the `SslStream`.
    /// 4. Require that the server's NPN selection is one of `ALPN_PROTOCOLS`.
    /// 5. Send the preface via `write_preface`.
    ///
    /// # Errors
    ///
    /// - `IoError` / `SslError` when any of the socket, context or stream steps fail.
    /// - Any error from `write_preface`, converted into `TlsConnectError`.
    /// - `Http2NotSupported`, carrying the established stream, when no protocol
    ///   or an unrecognized one was selected.
    fn connect(self) -> Result<ClientStream<SslStream<TcpStream>>, TlsConnectError> {
        let tcp = try!(TcpStream::connect(&(self.host, 443)));

        let ssl = match self.context {
            Http2TlsContext::Wrapped(shared) => try!(Ssl::new(shared)),
            Http2TlsContext::CertPath(ca_path) => {
                let owned = try!(TlsConnector::build_default_context(ca_path));
                try!(Ssl::new(&owned))
            }
        };
        try!(ssl.set_hostname(self.host));

        let mut stream = try!(SslStream::new_from(ssl, tcp));

        // The negotiated protocol must be one we offered as HTTP/2.
        let negotiated_h2 = match stream.get_selected_npn_protocol() {
            Some(proto) => {
                debug!("Selected protocol -> {:?}", str::from_utf8(proto));
                ALPN_PROTOCOLS.iter().any(|&candidate| candidate == proto)
            }
            None => false,
        };
        if !negotiated_h2 {
            return Err(TlsConnectError::Http2NotSupported(stream));
        }

        try!(write_preface(&mut stream));
        Ok(ClientStream(stream, HttpScheme::Https, self.host.into()))
    }
}