#[macro_use] extern crate lazy_static;
extern crate libtls_sys;
#[macro_use] extern crate quick_error;
use libtls_sys::*;
use std::borrow::Cow;
use std::convert::AsRef;
use std::default::Default;
use std::ffi::{NulError, CString, CStr};
use std::io::{self, Read, Write};
use std::net::{self, ToSocketAddrs, TcpStream, SocketAddr};
use std::os::unix::io::{FromRawFd, IntoRawFd};
use std::os::unix::ffi::OsStrExt;
use std::ptr;
use std::path::Path;
/// Zero-sized token proving that `tls_init` has been called successfully.
///
/// The only public way to obtain one is [`Init::init`], which lazily runs
/// the one-time library initialisation on first use.
#[derive(Copy, Clone)]
pub struct Init(());
lazy_static!{
    /// Process-wide token produced by the single `tls_init` call.
    static ref TLS_INIT: Init = unsafe { Init::new() };
}
impl Init {
    /// Calls `tls_init` and wraps success in a token.
    ///
    /// # Safety
    ///
    /// Performs a raw FFI call; it is meant to be run once, through `TLS_INIT`.
    ///
    /// # Panics
    ///
    /// Panics if `tls_init` returns -1.
    unsafe fn new() -> Init {
        match tls_init() {
            -1 => panic!("TLS init failed."),
            _ => Init(()),
        }
    }

    /// Returns the initialisation token, running `tls_init` on the first call.
    pub fn init() -> Init {
        let token: &Init = &*TLS_INIT;
        *token
    }
}
quick_error! {
    /// Errors that can occur while establishing a TLS client connection.
    #[derive(Debug)]
    pub enum ConnectError {
        /// The host name could not be converted into a C string
        /// (`CString::new` failed, i.e. it contained an interior NUL byte).
        InvalidHostError {}
        /// An I/O error, either from the TCP connect or from a libtls call
        /// (whose `tls_error` message is wrapped in an `io::Error`).
        IoError(err: io::Error) {
            from()
            cause(err)
        }
    }
}
/// A socket-address source that also knows the host name it refers to.
///
/// The host name is what gets passed to libtls as the server name, so it
/// must not include the port.
pub trait ToNamedSocketAddrs: ToSocketAddrs {
    /// Returns the host part of the address (without any port).
    fn host<'h>(&'h self) -> Cow<'h, str>;
}
impl<'a> ToNamedSocketAddrs for (&'a str, u16) {
    fn host(&self) -> Cow<str> {
        // The tuple already separates host and port.
        Cow::Borrowed(self.0)
    }
}
impl<'a> ToNamedSocketAddrs for &'a str {
    fn host(&self) -> Cow<str> {
        match self.parse::<SocketAddr>() {
            // Literal `ip:port` (including `[v6]:port`): render just the IP.
            Ok(sock) => Cow::Owned(sock.ip().to_string()),
            // `name:port`: everything before the last colon; otherwise the
            // whole string is taken as the host.
            Err(_) => match self.rfind(':') {
                Some(colon) => Cow::Borrowed(&self[..colon]),
                None => Cow::Borrowed(*self),
            },
        }
    }
}
/// Owned handle to a libtls `struct tls_config`.
///
/// Allocated with `tls_config_new` in [`Config::new`] and released with
/// `tls_config_free` on drop.
/// NOTE(review): with older libtls versions the configuration may have to
/// outlive every `Stream` configured from it — confirm against the linked
/// library before dropping a `Config` early.
pub struct Config(*mut Struct_tls_config);
impl Config {
    /// Allocates an empty libtls configuration.
    ///
    /// The `Init` argument proves that `tls_init` has already run.
    ///
    /// # Panics
    ///
    /// Panics if `tls_config_new` returns a null pointer.
    pub fn new(_: Init) -> Config {
        let raw = unsafe { tls_config_new() };
        if !raw.is_null() {
            return Config(raw);
        }
        panic!("TLS configuration failed to allocate.")
    }

    /// Connects over TCP to `addr` and performs a TLS handshake, using the
    /// host part of `addr` as the server name.
    pub fn connect<A>(&self, addr: A) -> Result<Stream, ConnectError>
        where A: ToNamedSocketAddrs + Clone
    {
        let target = addr.clone();
        let host = addr.host();
        self.connect_name(target, &host)
    }

    /// Connects over TCP to `addr` and performs a TLS handshake, verifying
    /// against the explicitly supplied `host` name.
    pub fn connect_name<A>(&self, addr: A, host: &str) -> Result<Stream, ConnectError>
        where A: ToSocketAddrs
    {
        let tcp = try!(TcpStream::connect(addr));
        self.connect_stream(tcp, host)
    }

    /// Performs a TLS client handshake over an already connected `stream`.
    pub fn connect_stream(&self, stream: TcpStream, host: &str)
        -> Result<Stream, ConnectError>
    {
        unsafe { Stream::do_connect(self, stream, host) }
    }
}
impl Default for Config {
    /// A fresh configuration, initialising libtls first if needed.
    fn default() -> Config {
        let token = Init::init();
        Config::new(token)
    }
}
impl Drop for Config {
    /// Releases the underlying `tls_config`.
    fn drop(&mut self) {
        let raw = self.0;
        unsafe { tls_config_free(raw) }
    }
}
impl Config {
pub fn set_ca_path<P: AsRef<Path>>(&mut self, path: P) -> Result<(), ()> {
unsafe {
tls_config_set_ca_path(self.0, path.as_ref().to_path_c_str().as_ptr())
.as_tls_result_bare()
}
}
pub fn set_ca_file<P: AsRef<Path>>(&mut self, path: P) -> Result<(), ()> {
unsafe {
tls_config_set_ca_file(self.0, path.as_ref().to_path_c_str().as_ptr())
.as_tls_result_bare()
}
}
pub fn set_ca_mem(&mut self, cert: &[u8]) -> Result<(), ()> {
unsafe {
tls_config_set_ca_mem(self.0, &cert[0], cert.len())
.as_tls_result_bare()
}
}
pub fn set_cert_file<P: AsRef<Path>>(&mut self, path: P) -> Result<(), ()> {
unsafe {
tls_config_set_cert_file(self.0, path.as_ref().to_path_c_str().as_ptr())
.as_tls_result_bare()
}
}
pub fn set_cert_mem(&mut self, cert: &[u8]) -> Result<(), ()> {
unsafe {
tls_config_set_cert_mem(self.0, &cert[0], cert.len())
.as_tls_result_bare()
}
}
pub fn set_ciphers(&mut self, ciphers: Ciphers) -> Result<(), ()> {
unsafe {
tls_config_set_ciphers(
self.0,
ciphers.0.as_ptr()
).as_tls_result_bare()
}
}
pub fn set_key_file<P: AsRef<Path>>(&mut self, path: P) -> Result<(), ()> {
unsafe {
tls_config_set_key_file(self.0, path.as_ref().to_path_c_str().as_ptr())
.as_tls_result_bare()
}
}
pub fn set_key_mem(&mut self, cert: &[u8]) -> Result<(), ()> {
unsafe {
tls_config_set_key_mem(self.0, &cert[0], cert.len())
.as_tls_result_bare()
}
}
pub fn set_protocols(&mut self, protocols: Protocols) {
unsafe {
tls_config_set_protocols(self.0, protocols.0)
}
}
pub fn set_verify_depth(&mut self, verify_depth: u64) {
unsafe {
tls_config_set_verify_depth(self.0, verify_depth as libc::c_int);
}
}
pub fn prefer_ciphers_client(&mut self) {
unsafe {
tls_config_prefer_ciphers_client(self.0);
}
}
pub fn prefer_ciphers_server(&mut self) {
unsafe {
tls_config_prefer_ciphers_server(self.0);
}
}
pub fn clear_keys(&mut self) {
unsafe {
tls_config_clear_keys(self.0);
}
}
pub fn insecure_noverifycert(&mut self) {
unsafe {
tls_config_insecure_noverifycert(self.0);
}
}
pub fn insecure_noverifyname(&mut self) {
unsafe {
tls_config_insecure_noverifyname(self.0);
}
}
pub fn insecure_noverifytime(&mut self) {
unsafe {
tls_config_insecure_noverifytime(self.0);
}
}
pub fn verify(&mut self) {
unsafe {
tls_config_verify(self.0);
}
}
pub fn verify_client(&mut self) {
unsafe {
tls_config_verify_client(self.0);
}
}
pub fn verify_client_optional(&mut self) {
unsafe {
tls_config_verify_client_optional(self.0);
}
}
}
#[derive(Debug, Clone)]
pub struct Protocols(libc::uint32_t);
impl Protocols {
pub fn from_str(protocols: &str) -> Result<Protocols, ()> {
Init::init();
let mut ret_val = Protocols(0);
let protocols = try!(CString::new(protocols).map_err(|_| ()));
unsafe {
try!(
tls_config_parse_protocols(&mut ret_val.0, protocols.as_ptr())
.as_tls_result_bare()
)
}
Ok(ret_val)
}
pub fn secure() -> Protocols {
Protocols(libtls_sys::TLS_PROTOCOLS_DEFAULT)
}
pub fn all() -> Protocols {
Protocols(libtls_sys::TLS_PROTOCOLS_ALL)
}
pub fn legacy() -> Protocols {
Protocols::all()
}
}
impl Default for Protocols {
fn default() -> Protocols {
Protocols::secure()
}
}
/// A cipher-list string handed to `tls_config_set_ciphers`.
#[derive(Debug, Clone)]
pub struct Ciphers(CString);
impl Ciphers {
    /// Wraps `ciphers`. Fails if it contains an interior NUL byte.
    pub fn from_str(ciphers: &str) -> Result<Ciphers, NulError> {
        CString::new(ciphers).map(Ciphers)
    }

    /// The `"secure"` cipher list.
    pub fn secure() -> Ciphers {
        let spec = Self::from_str("secure");
        spec.expect("'secure' does not contain a nul")
    }

    /// The `"compat"` cipher list.
    pub fn compat() -> Ciphers {
        let spec = Self::from_str("compat");
        spec.expect("'compat' does not contain a nul")
    }

    /// Alias for [`Ciphers::compat`].
    pub fn legacy() -> Ciphers {
        Self::compat()
    }
}
impl Default for Ciphers {
    /// Same as [`Ciphers::secure`].
    fn default() -> Ciphers {
        Self::secure()
    }
}
pub struct Stream {
context: *mut Struct_tls,
stream: TcpStream,
}
impl Stream {
unsafe fn do_connect(config: &Config, stream: TcpStream, host: &str)
-> Result<Self, ConnectError> {
let host = try!(
CString::new(host).map_err(|_| ConnectError::InvalidHostError)
);
let context = tls_client();
if context.is_null() {
panic!("TLS client building failed to allocate.");
}
let fd = stream.into_raw_fd();
let mut result = Stream {
stream: TcpStream::from_raw_fd(fd),
context: context,
};
try!(tls_configure(context, config.0).as_tls_result_io(context));
try!(result.run_to_completion(|| {
tls_connect_socket(context, fd, host.as_ptr())
}));
Ok(result)
}
pub fn connect<A: ToNamedSocketAddrs+Clone>(addr: A)
-> Result<Stream, ConnectError> {
Config::default().connect(addr)
}
pub fn close(mut self) -> Result<(), io::Error> {
unsafe { self.do_close() }
}
pub fn get_ref(&self) -> &TcpStream {
&self.stream
}
pub fn get_mut(&mut self) -> &mut TcpStream {
&mut self.stream
}
unsafe fn do_close(&mut self) -> Result<(), io::Error> {
let context = self.context;
try!(self.run_to_completion(|| {
tls_close(context)
}));
tls_free(self.context);
self.context = ptr::null_mut();
try!(self.stream.shutdown(net::Shutdown::Both));
Ok(())
}
unsafe fn run_to_completion<C, T>(&mut self, mut c: C)
-> Result<T, io::Error>
where C: FnMut() -> T, T: Into<libc::c_int> + Clone {
loop {
let result = c();
match result.clone().into() {
TLS_WANT_POLLIN | TLS_WANT_POLLOUT => (),
-1 => {
return Err(io::Error::new(
io::ErrorKind::Other,
CStr::from_ptr(tls_error(self.context))
.to_string_lossy()
.into_owned()
))
}
_ => return Ok(result),
}
}
}
}
impl Read for Stream {
fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
let len = buf.len();
let buf = &mut buf[0] as *mut u8 as *mut libc::c_void;
unsafe {
tls_read(self.context, buf, len)
.as_tls_result_io_number(self.context).map(|x| x as usize)
}
}
}
impl Write for Stream {
fn write(&mut self, buf: &[u8]) -> Result<usize, io::Error> {
let len = buf.len();
let buf = &buf[0] as *const u8 as *const libc::c_void;
unsafe {
tls_write(self.context, buf, len)
.as_tls_result_io_number(self.context).map(|x| x as usize)
}
}
fn flush(&mut self) -> Result<(), io::Error> {
self.stream.flush()
}
}
impl Drop for Stream {
    /// Best-effort teardown of a stream that was not closed explicitly.
    ///
    /// Errors are deliberately discarded. A destructor cannot report them,
    /// and panicking here aborts the process if the drop happens during
    /// unwinding. That includes the drop of a half-built stream after a
    /// failed handshake. Use `Stream::close` to observe close errors.
    fn drop(&mut self) {
        if !self.context.is_null() {
            let _ = unsafe { self.do_close() };
        }
    }
}
/// Converts a filesystem path into a NUL-terminated C string for libtls.
trait PathRawStr {
    fn to_path_c_str(&self) -> Cow<CStr>;
}
impl PathRawStr for Path {
    /// Copies the path's bytes into an owned, NUL-terminated `CString`.
    ///
    /// The previous implementation handed the path's raw bytes to
    /// `CStr::from_ptr`. Those bytes are not NUL-terminated, so the read ran
    /// out of bounds, and an empty path panicked on `[0]`.
    ///
    /// If the path contains an interior NUL, it is truncated there, which is
    /// exactly what C code reading the string would see. An empty path yields
    /// an empty C string.
    fn to_path_c_str(&self) -> Cow<CStr> {
        let bytes = self.as_os_str().as_bytes();
        let bytes = match bytes.iter().position(|&b| b == 0) {
            Some(nul) => &bytes[..nul],
            None => bytes,
        };
        Cow::Owned(CString::new(bytes).expect("interior NUL bytes were cut off above"))
    }
}
/// Interprets a libtls integer return code, where a negative value means failure.
trait AsTlsResult: Sized {
    /// `Ok(())` for non-negative codes, `Err(())` otherwise.
    fn as_tls_result_bare(&self) -> Result<(), ()>;
    /// Like `as_tls_result_bare`, but yields the code itself on success.
    fn as_tls_result_bare_number(&self) -> Result<Self, ()>;
    /// Like `as_tls_result_bare`, but a failure becomes an `io::Error`
    /// carrying `tls_error(context)`'s message.
    ///
    /// # Safety
    ///
    /// `context` must be a live libtls context.
    unsafe fn as_tls_result_io(&self, context: *mut Struct_tls)
        -> Result<(), io::Error>;
    /// Like `as_tls_result_io`, but yields the code itself on success.
    ///
    /// # Safety
    ///
    /// `context` must be a live libtls context.
    unsafe fn as_tls_result_io_number(&self, context: *mut Struct_tls)
        -> Result<Self, io::Error>;
}
/// Implements `AsTlsResult` for a signed integer return type.
macro_rules! define_tls_result {
    ($t:ty) => (
        impl AsTlsResult for $t {
            fn as_tls_result_bare(&self) -> Result<(), ()> {
                if *self < 0 {
                    Err(())
                } else {
                    Ok(())
                }
            }
            fn as_tls_result_bare_number(&self) -> Result<$t, ()> {
                self.as_tls_result_bare().map(|_| *self)
            }
            unsafe fn as_tls_result_io(&self, context: *mut Struct_tls)
                    -> Result<(), io::Error> {
                if *self < 0 {
                    // tls_error may return NULL (e.g. when no message was
                    // set); CStr::from_ptr(NULL) is undefined behaviour.
                    let msg = tls_error(context);
                    let msg = if msg.is_null() {
                        "unknown TLS error".to_owned()
                    } else {
                        CStr::from_ptr(msg).to_string_lossy().into_owned()
                    };
                    Err(io::Error::new(io::ErrorKind::Other, msg))
                } else {
                    Ok(())
                }
            }
            unsafe fn as_tls_result_io_number(&self, context: *mut Struct_tls)
                    -> Result<$t, io::Error> {
                self.as_tls_result_io(context).map(|_| *self)
            }
        }
    )
}
define_tls_result!(libc::c_int);
define_tls_result!(libc::ssize_t);
#[cfg(test)]
mod test;