use libc::{size_t, c_void};
use core_foundation::array::CFArray;
use core_foundation::base::{TCFType, Boolean};
use core_foundation_sys::base::OSStatus;
#[cfg(any(feature = "OSX_10_8", target_os = "ios"))]
use core_foundation_sys::base::{kCFAllocatorDefault, CFRelease};
use security_framework_sys::base::{errSecSuccess, errSecIO, errSecBadReq, errSecTrustSettingDeny,
errSecNotTrusted};
use security_framework_sys::secure_transport::*;
use std::any::Any;
use std::io;
use std::io::prelude::*;
use std::fmt;
use std::marker::PhantomData;
use std::mem;
use std::ptr;
use std::slice;
use std::result;
use {cvt, ErrorNew, CipherSuiteInternals, AsInner};
use base::{Result, Error};
use certificate::SecCertificate;
use cipher_suite::CipherSuite;
use identity::SecIdentity;
use trust::{SecTrust, TrustResult};
/// Which end of the TLS connection a context acts as.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ProtocolSide {
    /// The context accepts connections (server side).
    Server,
    /// The context initiates connections (client side).
    Client,
}
/// Transport framing used by a context.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ConnectionType {
    /// Stream-oriented transport (TLS).
    Stream,
    /// Datagram-oriented transport (DTLS); only available with the `OSX_10_8` feature.
    #[cfg(feature = "OSX_10_8")]
    Datagram,
}
/// Outcome of a handshake attempt that did not complete.
#[derive(Debug)]
pub enum HandshakeError<S> {
// The handshake failed with the given Secure Transport error.
Failure(Error),
// The handshake was paused (e.g. would block, or a break-on-auth option fired);
// the wrapped stream can be resumed with `MidHandshakeSslStream::handshake`.
Interrupted(MidHandshakeSslStream<S>),
}
/// A TLS stream whose handshake was paused by Secure Transport.
///
/// Produced inside `HandshakeError::Interrupted`; call `handshake` to resume.
#[derive(Debug)]
pub struct MidHandshakeSslStream<S> {
// The stream whose handshake is in progress.
stream: SslStream<S>,
// The `SSLHandshake` status code that caused the pause.
reason: OSStatus,
}
impl<S> MidHandshakeSslStream<S> {
pub fn get_ref(&self) -> &S {
self.stream.get_ref()
}
pub fn get_mut(&mut self) -> &mut S {
self.stream.get_mut()
}
pub fn context(&self) -> &SslContext {
self.stream.context()
}
pub fn context_mut(&mut self) -> &mut SslContext {
self.stream.context_mut()
}
pub fn server_auth_completed(&self) -> bool {
self.reason == errSSLPeerAuthCompleted
}
pub fn client_cert_requested(&self) -> bool {
self.reason == errSSLClientCertRequested
}
pub fn would_block(&self) -> bool {
self.reason == errSSLWouldBlock
}
pub fn reason(&self) -> OSStatus {
self.reason
}
pub fn handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>> {
self.stream.handshake()
}
}
/// Lifecycle state of a Secure Transport session.
///
/// Each variant mirrors one `SSLSessionState` constant (`kSSLIdle`, `kSSLHandshake`,
/// `kSSLConnected`, `kSSLClosed`, `kSSLAborted`). `PartialEq` lets callers write
/// `ctx.state()? == SessionState::Connected` instead of pattern matching.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum SessionState {
    /// No I/O has been performed yet (`kSSLIdle`).
    Idle,
    /// The handshake is in progress (`kSSLHandshake`).
    Handshake,
    /// The handshake finished (`kSSLConnected`).
    Connected,
    /// The session was closed (`kSSLClosed`).
    Closed,
    /// The session was aborted (`kSSLAborted`).
    Aborted,
}
impl SessionState {
fn from_raw(raw: SSLSessionState) -> SessionState {
match raw {
kSSLIdle => SessionState::Idle,
kSSLHandshake => SessionState::Handshake,
kSSLConnected => SessionState::Connected,
kSSLClosed => SessionState::Closed,
kSSLAborted => SessionState::Aborted,
_ => panic!("bad session state value {}", raw),
}
}
}
/// Client-certificate policy for a server-side context.
///
/// Passed to `SslContext::set_client_side_authenticate`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum SslAuthenticate {
    /// Maps to `kNeverAuthenticate`.
    Never,
    /// Maps to `kAlwaysAuthenticate`.
    Always,
    /// Maps to `kTryAuthenticate`.
    Try,
}
/// Client-certificate exchange status, as reported by
/// `SslContext::client_certificate_state`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum SslClientCertificateState {
    /// Maps to `kSSLClientCertNone`.
    None,
    /// Maps to `kSSLClientCertRequested`.
    Requested,
    /// Maps to `kSSLClientCertSent`.
    Sent,
    /// Maps to `kSSLClientCertRejected`.
    Rejected,
}
// Generates the `SslProtocol` enum, its `Debug` impl, and conversions to/from the
// raw `SSLProtocol` value. Each entry has the form `const Variant = kRawConstant,`
// and may be preceded by attributes (typically `#[cfg(...)]`). Those attributes are
// copied onto the variant AND onto every match arm that mentions it, so
// feature-gated variants disappear consistently from all generated code.
macro_rules! ssl_protocol {
    ($($(#[$a:meta])* const $variant:ident = $value:ident,)+) => {
        /// A protocol version understood by Secure Transport.
        ///
        /// `Copy`/`PartialEq` allow e.g.
        /// `ctx.negotiated_protocol_version()? == SslProtocol::Tls12`.
        #[derive(Clone, Copy, PartialEq, Eq)]
        pub enum SslProtocol {
            $($(#[$a])* $variant,)+
        }

        impl fmt::Debug for SslProtocol {
            fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
                use self::SslProtocol::*;
                // Prints the bare variant name, e.g. `Tls1`.
                let s = match *self {
                    $($(#[$a])* $variant => stringify!($variant),)+
                };
                fmt.write_str(s)
            }
        }

        impl SslProtocol {
            /// Converts a raw `SSLProtocol` value into a variant.
            ///
            /// # Panics
            ///
            /// Panics if `raw` does not match any variant enabled in this build.
            fn from_raw(raw: SSLProtocol) -> SslProtocol {
                use self::SslProtocol::*;
                match raw {
                    $($(#[$a])* $value => $variant,)+
                    _ => panic!("invalid ssl protocol {}", raw),
                }
            }

            /// Converts the variant back to its raw `SSLProtocol` value.
            #[cfg(feature = "OSX_10_8")]
            fn to_raw(&self) -> SSLProtocol {
                use self::SslProtocol::*;
                match *self {
                    $($(#[$a])* $variant => $value,)+
                }
            }
        }
    }
}
// Table of supported protocol versions: `Variant = raw Secure Transport constant`.
// Entries gated on `OSX_10_8` exist only when that feature is enabled.
ssl_protocol! {
// Raw `kSSLProtocolUnknown`.
const Unknown = kSSLProtocolUnknown,
const Ssl3 = kSSLProtocol3,
const Tls1 = kTLSProtocol1,
#[cfg(feature = "OSX_10_8")]
const Tls11 = kTLSProtocol11,
#[cfg(feature = "OSX_10_8")]
const Tls12 = kTLSProtocol12,
const Ssl2 = kSSLProtocol2,
#[cfg(feature = "OSX_10_8")]
const Dtls1 = kDTLSProtocol1,
// The `*Only` and `All` values below are distinct raw constants, not aliases.
const Ssl3Only = kSSLProtocol3Only,
const Tls1Only = kTLSProtocol1Only,
const All = kSSLProtocolAll,
}
/// An owned Secure Transport `SSLContextRef`; the context is released when dropped.
pub struct SslContext(SSLContextRef);
impl Drop for SslContext {
    /// Legacy API: contexts from `SSLNewContext` are freed with `SSLDisposeContext`.
    #[cfg(not(any(feature = "OSX_10_8", target_os = "ios")))]
    fn drop(&mut self) {
        unsafe {
            SSLDisposeContext(self.0);
        }
    }

    /// Modern API: contexts from `SSLCreateContext` are CoreFoundation objects.
    ///
    /// The raw pointer is cast directly instead of going through
    /// `TCFType::as_CFTypeRef`. That trait impl is only generated with the
    /// `OSX_10_8` feature, so the method does not exist on iOS builds without it.
    #[cfg(any(feature = "OSX_10_8", target_os = "ios"))]
    fn drop(&mut self) {
        unsafe {
            CFRelease(self.0 as *const _);
        }
    }
}
// CoreFoundation wrapper trait impl for the context; generated only when the
// `OSX_10_8` feature is enabled.
#[cfg(feature = "OSX_10_8")]
impl_TCFType!(SslContext, SSLContextRef, SSLContextGetTypeID);
impl fmt::Debug for SslContext {
    /// Formats as `SslContext { state: ... }`. The `state` field is omitted when the
    /// session state cannot be queried.
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        let mut out = fmt.debug_struct("SslContext");
        match self.state() {
            Ok(ref state) => {
                out.field("state", state);
            }
            Err(_) => {}
        }
        out.finish()
    }
}
// SAFETY: asserts that the owned SSLContextRef may be moved to another thread.
// Only `Send` is declared here (no `Sync` is visible in this file), so the context
// is never shared between threads concurrently.
// NOTE(review): relies on Secure Transport contexts not being thread-affine — confirm.
unsafe impl Send for SslContext {}
impl AsInner for SslContext {
    type Inner = SSLContextRef;

    /// Exposes the raw context pointer. Ownership stays with `self`.
    fn as_inner(&self) -> SSLContextRef {
        let raw: SSLContextRef = self.0;
        raw
    }
}
// Generates a getter/setter pair for each boolean Secure Transport session option.
// Input form: `const kOption: getter & setter,`, optionally preceded by attributes
// (e.g. `#[cfg(...)]`), which are applied to both generated methods.
macro_rules! impl_options {
($($(#[$a:meta])* const $opt:ident: $get:ident & $set:ident,)*) => {
$(
// Setter: forwards to `SSLSetSessionOption`, converting the bool to a CF `Boolean`.
$(#[$a])*
pub fn $set(&mut self, value: bool) -> Result<()> {
unsafe { cvt(SSLSetSessionOption(self.0, $opt, value as Boolean)) }
}
// Getter: reads the option via `SSLGetSessionOption`; any non-zero value is `true`.
$(#[$a])*
pub fn $get(&self) -> Result<bool> {
let mut value = 0;
unsafe { try!(cvt(SSLGetSessionOption(self.0, $opt, &mut value))); }
Ok(value != 0)
}
)*
}
}
impl SslContext {
    /// Creates a new Secure Transport context.
    ///
    /// `side` selects client or server behavior. `type_` selects stream or datagram
    /// framing. The legacy constructor (neither `OSX_10_8` nor iOS) ignores `type_`.
    ///
    /// # Errors
    ///
    /// Returns the Secure Transport status when the legacy `SSLNewContext` fails.
    pub fn new(side: ProtocolSide, type_: ConnectionType) -> Result<SslContext> {
        SslContext::new_inner(side, type_)
    }

    // Legacy constructor: `SSLNewContext` only takes a server/client flag.
    #[cfg(not(any(feature = "OSX_10_8", target_os = "ios")))]
    fn new_inner(side: ProtocolSide, _: ConnectionType) -> Result<SslContext> {
        unsafe {
            let is_server = match side {
                ProtocolSide::Server => 1,
                ProtocolSide::Client => 0,
            };
            let mut ctx = ptr::null_mut();
            try!(cvt(SSLNewContext(is_server, &mut ctx)));
            Ok(SslContext(ctx))
        }
    }

    // Modern constructor based on `SSLCreateContext`.
    #[cfg(any(feature = "OSX_10_8", target_os = "ios"))]
    fn new_inner(side: ProtocolSide, type_: ConnectionType) -> Result<SslContext> {
        let side = match side {
            ProtocolSide::Server => kSSLServerSide,
            ProtocolSide::Client => kSSLClientSide,
        };
        let type_ = match type_ {
            ConnectionType::Stream => kSSLStreamType,
            // `ConnectionType::Datagram` only exists with the `OSX_10_8` feature. The
            // arm must be gated the same way, otherwise iOS builds without that
            // feature reference a variant that does not exist.
            #[cfg(feature = "OSX_10_8")]
            ConnectionType::Datagram => kSSLDatagramType,
        };
        unsafe {
            // NOTE(review): a NULL return from `SSLCreateContext` is not checked
            // here — confirm whether it needs handling.
            let ctx = SSLCreateContext(kCFAllocatorDefault, side, type_);
            Ok(SslContext(ctx))
        }
    }

    /// Sets the domain name the peer's certificate is verified against.
    pub fn set_peer_domain_name(&mut self, peer_name: &str) -> Result<()> {
        unsafe {
            cvt(SSLSetPeerDomainName(self.0, peer_name.as_ptr() as *const _, peer_name.len()))
        }
    }

    /// Returns the peer domain name previously set on this context (empty if unset).
    ///
    /// Non-UTF-8 bytes, which can only come from a name set outside this API, are
    /// replaced with U+FFFD instead of panicking.
    pub fn peer_domain_name(&self) -> Result<String> {
        unsafe {
            let mut len = 0;
            try!(cvt(SSLGetPeerDomainNameLength(self.0, &mut len)));
            let mut buf = vec![0; len];
            try!(cvt(SSLGetPeerDomainName(self.0, buf.as_mut_ptr() as *mut _, &mut len)));
            // `len` is passed in/out. Keep only the bytes reported as written so a
            // shorter result never leaves trailing NULs in the string.
            buf.truncate(len);
            Ok(String::from_utf8_lossy(&buf).into_owned())
        }
    }

    /// Sets the local identity plus any extra chain certificates.
    ///
    /// The array passed to `SSLSetCertificate` has `identity` first, followed by
    /// `certs` in order.
    pub fn set_certificate(&mut self,
                           identity: &SecIdentity,
                           certs: &[SecCertificate])
                           -> Result<()> {
        let mut arr = vec![identity.as_CFType()];
        arr.extend(certs.iter().map(|c| c.as_CFType()));
        let certs = CFArray::from_CFTypes(&arr);
        unsafe { cvt(SSLSetCertificate(self.0, certs.as_concrete_TypeRef())) }
    }

    /// Sets the opaque peer ID bytes on the context via `SSLSetPeerID`.
    pub fn set_peer_id(&mut self, peer_id: &[u8]) -> Result<()> {
        unsafe { cvt(SSLSetPeerID(self.0, peer_id.as_ptr() as *const _, peer_id.len())) }
    }

    /// Returns the peer ID set on the context, or `None` if none has been set.
    ///
    /// The slice points at memory returned by `SSLGetPeerID` and borrows `self`.
    pub fn peer_id(&self) -> Result<Option<&[u8]>> {
        unsafe {
            let mut ptr = ptr::null();
            let mut len = 0;
            try!(cvt(SSLGetPeerID(self.0, &mut ptr, &mut len)));
            if ptr.is_null() {
                Ok(None)
            } else {
                Ok(Some(slice::from_raw_parts(ptr as *const _, len)))
            }
        }
    }

    /// Lists every cipher suite Secure Transport supports.
    ///
    /// # Panics
    ///
    /// Panics if Secure Transport reports a suite unknown to `CipherSuite`.
    pub fn supported_ciphers(&self) -> Result<Vec<CipherSuite>> {
        unsafe {
            let mut num_ciphers = 0;
            try!(cvt(SSLGetNumberSupportedCiphers(self.0, &mut num_ciphers)));
            let mut ciphers = vec![0; num_ciphers];
            try!(cvt(SSLGetSupportedCiphers(self.0, ciphers.as_mut_ptr(), &mut num_ciphers)));
            // The count is in/out; never expose zero-filled, unwritten slots.
            ciphers.truncate(num_ciphers);
            Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c).unwrap()).collect())
        }
    }

    /// Lists the cipher suites currently enabled on this context.
    ///
    /// # Panics
    ///
    /// Panics if Secure Transport reports a suite unknown to `CipherSuite`.
    pub fn enabled_ciphers(&self) -> Result<Vec<CipherSuite>> {
        unsafe {
            let mut num_ciphers = 0;
            try!(cvt(SSLGetNumberEnabledCiphers(self.0, &mut num_ciphers)));
            let mut ciphers = vec![0; num_ciphers];
            try!(cvt(SSLGetEnabledCiphers(self.0, ciphers.as_mut_ptr(), &mut num_ciphers)));
            // The count is in/out; never expose zero-filled, unwritten slots.
            ciphers.truncate(num_ciphers);
            Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c).unwrap()).collect())
        }
    }

    /// Replaces the set of enabled cipher suites.
    pub fn set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()> {
        let ciphers = ciphers.iter().map(|c| c.to_raw()).collect::<Vec<_>>();
        unsafe { cvt(SSLSetEnabledCiphers(self.0, ciphers.as_ptr(), ciphers.len())) }
    }

    /// Returns the cipher suite negotiated for this session.
    ///
    /// # Panics
    ///
    /// Panics if the negotiated value is unknown to `CipherSuite`.
    pub fn negotiated_cipher(&self) -> Result<CipherSuite> {
        unsafe {
            let mut cipher = 0;
            try!(cvt(SSLGetNegotiatedCipher(self.0, &mut cipher)));
            Ok(CipherSuite::from_raw(cipher).unwrap())
        }
    }

    /// Sets the client-certificate policy (server side).
    pub fn set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()> {
        let auth = match auth {
            SslAuthenticate::Never => kNeverAuthenticate,
            SslAuthenticate::Always => kAlwaysAuthenticate,
            SslAuthenticate::Try => kTryAuthenticate,
        };
        unsafe { cvt(SSLSetClientSideAuthenticate(self.0, auth)) }
    }

    /// Reports the client-certificate exchange state.
    ///
    /// # Panics
    ///
    /// Panics if Secure Transport returns an unknown state value.
    pub fn client_certificate_state(&self) -> Result<SslClientCertificateState> {
        let mut state = 0;
        unsafe {
            try!(cvt(SSLGetClientCertificateState(self.0, &mut state)));
        }
        let state = match state {
            kSSLClientCertNone => SslClientCertificateState::None,
            kSSLClientCertRequested => SslClientCertificateState::Requested,
            kSSLClientCertSent => SslClientCertificateState::Sent,
            kSSLClientCertRejected => SslClientCertificateState::Rejected,
            _ => panic!("got invalid client cert state {}", state),
        };
        Ok(state)
    }

    /// Returns the peer's trust object.
    ///
    /// # Errors
    ///
    /// Returns `errSecBadReq` while the session is still `Idle`, and otherwise any
    /// error reported by `SSLCopyPeerTrust`.
    pub fn peer_trust(&self) -> Result<SecTrust> {
        if let SessionState::Idle = try!(self.state()) {
            return Err(Error::new(errSecBadReq));
        }
        unsafe {
            let mut trust = ptr::null_mut();
            try!(cvt(SSLCopyPeerTrust(self.0, &mut trust)));
            // `SSLCopyPeerTrust` follows the Create/Copy rule, so we own this reference.
            Ok(SecTrust::wrap_under_create_rule(trust))
        }
    }

    /// Returns the current session state.
    pub fn state(&self) -> Result<SessionState> {
        unsafe {
            let mut state = 0;
            try!(cvt(SSLGetSessionState(self.0, &mut state)));
            Ok(SessionState::from_raw(state))
        }
    }

    /// Returns the protocol version negotiated for this session.
    pub fn negotiated_protocol_version(&self) -> Result<SslProtocol> {
        unsafe {
            let mut version = 0;
            try!(cvt(SSLGetNegotiatedProtocolVersion(self.0, &mut version)));
            Ok(SslProtocol::from_raw(version))
        }
    }

    /// Returns the maximum protocol version this context will negotiate.
    #[cfg(feature = "OSX_10_8")]
    pub fn protocol_version_max(&self) -> Result<SslProtocol> {
        unsafe {
            let mut version = 0;
            try!(cvt(SSLGetProtocolVersionMax(self.0, &mut version)));
            Ok(SslProtocol::from_raw(version))
        }
    }

    /// Sets the maximum protocol version this context will negotiate.
    #[cfg(feature = "OSX_10_8")]
    pub fn set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()> {
        unsafe { cvt(SSLSetProtocolVersionMax(self.0, max_version.to_raw())) }
    }

    /// Returns the minimum protocol version this context will negotiate.
    #[cfg(feature = "OSX_10_8")]
    pub fn protocol_version_min(&self) -> Result<SslProtocol> {
        unsafe {
            let mut version = 0;
            try!(cvt(SSLGetProtocolVersionMin(self.0, &mut version)));
            Ok(SslProtocol::from_raw(version))
        }
    }

    /// Sets the minimum protocol version this context will negotiate.
    #[cfg(feature = "OSX_10_8")]
    pub fn set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()> {
        unsafe { cvt(SSLSetProtocolVersionMin(self.0, min_version.to_raw())) }
    }

    /// Returns the number of bytes reported by `SSLGetBufferedReadSize`.
    pub fn buffered_read_size(&self) -> Result<usize> {
        unsafe {
            let mut size = 0;
            try!(cvt(SSLGetBufferedReadSize(self.0, &mut size)));
            Ok(size)
        }
    }

    impl_options! {
        const kSSLSessionOptionBreakOnServerAuth: break_on_server_auth & set_break_on_server_auth,
        const kSSLSessionOptionBreakOnCertRequested: break_on_cert_requested & set_break_on_cert_requested,
        #[cfg(feature = "OSX_10_8")]
        const kSSLSessionOptionBreakOnClientAuth: break_on_client_auth & set_break_on_client_auth,
        #[cfg(feature = "OSX_10_9")]
        const kSSLSessionOptionFalseStart: false_start & set_false_start,
        #[cfg(feature = "OSX_10_9")]
        const kSSLSessionOptionSendOneByteRecord: send_one_byte_record & set_send_one_byte_record,
    }

    /// Attaches `stream` as the transport for this context and starts the handshake.
    ///
    /// The stream is boxed into a `Connection<S>` and its pointer is registered
    /// with `SSLSetConnection`. From then on, the returned `SslStream` owns that
    /// box and frees it on drop.
    ///
    /// # Errors
    ///
    /// Returns `HandshakeError::Failure` if the I/O callbacks or the connection
    /// cannot be installed, or if the handshake fails. Returns
    /// `HandshakeError::Interrupted` if the handshake pauses.
    pub fn handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>>
        where S: Read + Write
    {
        unsafe {
            let ret = SSLSetIOFuncs(self.0, read_func::<S>, write_func::<S>);
            if ret != errSecSuccess {
                return Err(HandshakeError::Failure(Error::new(ret)));
            }
            let stream = Connection {
                stream: stream,
                err: None,
                panic: None,
            };
            let stream = Box::into_raw(Box::new(stream));
            let ret = SSLSetConnection(self.0, stream as *mut _);
            if ret != errSecSuccess {
                // Secure Transport did not take the pointer; reclaim and free the box.
                let _conn = Box::from_raw(stream);
                return Err(HandshakeError::Failure(Error::new(ret)));
            }
            let stream = SslStream {
                ctx: self,
                _m: PhantomData,
            };
            stream.handshake()
        }
    }
}
// State handed to Secure Transport as the opaque connection ref. It holds the user's
// stream plus slots where `read_func`/`write_func` stash an I/O error or a caught
// panic payload, so `SslStream` can surface them after the FFI call returns.
struct Connection<S> {
stream: S,
err: Option<io::Error>,
panic: Option<Box<Any + Send>>,
}
/// Runs `func`, catching a panic (nightly only) so it does not unwind across FFI.
#[cfg(feature = "nightly")]
fn recover<Func, Out>(func: Func) -> ::std::result::Result<Out, Box<Any + Send>>
    where Func: FnOnce() -> Out + ::std::panic::RecoverSafe
{
    ::std::panic::recover(func)
}

/// Stable fallback: panics cannot be caught, so `func` just runs and its value is
/// wrapped in `Ok`.
#[cfg(not(feature = "nightly"))]
fn recover<Func, Out>(func: Func) -> ::std::result::Result<Out, Box<Any + Send>>
    where Func: FnOnce() -> Out
{
    let value = func();
    Ok(value)
}
#[cfg(feature = "nightly")]
use std::panic::AssertRecoverSafe;
/// Stable stand-in for `std::panic::AssertRecoverSafe`: a transparent newtype.
#[cfg(not(feature = "nightly"))]
struct AssertRecoverSafe<T>(T);

#[cfg(not(feature = "nightly"))]
impl<T> AssertRecoverSafe<T> {
    /// Wraps `value` without any other effect.
    fn new(value: T) -> Self {
        AssertRecoverSafe(value)
    }
}

#[cfg(not(feature = "nightly"))]
impl<T> ::std::ops::Deref for AssertRecoverSafe<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let AssertRecoverSafe(ref inner) = *self;
        inner
    }
}

#[cfg(not(feature = "nightly"))]
impl<T> ::std::ops::DerefMut for AssertRecoverSafe<T> {
    fn deref_mut(&mut self) -> &mut T {
        let AssertRecoverSafe(ref mut inner) = *self;
        inner
    }
}
/// Maps an I/O error from the wrapped stream to the status returned by an I/O callback.
///
/// `NotFound` becomes `errSSLClosedGraceful`, `ConnectionReset` becomes
/// `errSSLClosedAbort`, `WouldBlock` becomes `errSSLWouldBlock`, and every other
/// kind becomes `errSecIO`.
fn translate_err(e: &io::Error) -> OSStatus {
    let kind = e.kind();
    if kind == io::ErrorKind::NotFound {
        errSSLClosedGraceful
    } else if kind == io::ErrorKind::ConnectionReset {
        errSSLClosedAbort
    } else if kind == io::ErrorKind::WouldBlock {
        errSSLWouldBlock
    } else {
        errSecIO
    }
}
unsafe extern "C" fn read_func<S: Read>(connection: SSLConnectionRef,
data: *mut c_void,
data_length: *mut size_t)
-> OSStatus {
let mut conn: &mut Connection<S> = mem::transmute(connection);
let mut data = slice::from_raw_parts_mut(data as *mut u8, *data_length);
let mut start = 0;
let mut ret = errSecSuccess;
while start < data.len() {
let result = {
let mut conn = AssertRecoverSafe::new(&mut *conn);
let mut data = AssertRecoverSafe::new(&mut *data);
recover(move || conn.stream.read(&mut data[start..]))
};
match result {
Ok(Ok(0)) => {
ret = errSSLClosedNoNotify;
break;
}
Ok(Ok(len)) => start += len,
Ok(Err(e)) => {
ret = translate_err(&e);
conn.err = Some(e);
break;
}
Err(e) => {
ret = errSecIO;
conn.panic = Some(e);
break;
}
}
}
*data_length = start;
ret
}
unsafe extern "C" fn write_func<S: Write>(connection: SSLConnectionRef,
data: *const c_void,
data_length: *mut size_t)
-> OSStatus {
let mut conn: &mut Connection<S> = mem::transmute(connection);
let data = slice::from_raw_parts(data as *mut u8, *data_length);
let mut start = 0;
let mut ret = errSecSuccess;
while start < data.len() {
let result = {
let mut conn = AssertRecoverSafe::new(&mut *conn);
recover(move || conn.stream.write(&data[start..]))
};
match result {
Ok(Ok(0)) => {
ret = errSSLClosedNoNotify;
break;
}
Ok(Ok(len)) => start += len,
Ok(Err(e)) => {
ret = translate_err(&e);
conn.err = Some(e);
break;
}
Err(e) => {
ret = errSecIO;
conn.panic = Some(e);
break;
}
}
}
*data_length = start;
ret
}
#[cfg(feature = "nightly")]
use std::panic::propagate;
#[cfg(not(feature = "nightly"))]
use std::mem::drop as propagate;
/// A TLS stream over `S`.
///
/// The wrapped `S` lives in a boxed `Connection<S>` registered with the context
/// (see `SslContext::handshake`), not in this struct. `_m` records that the stream
/// logically owns an `S`.
pub struct SslStream<S> {
ctx: SslContext,
_m: PhantomData<S>,
}
impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
    /// Formats as `SslStream { context: ..., stream: ... }`.
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        let inner = self.get_ref();
        let mut out = fmt.debug_struct("SslStream");
        out.field("context", &self.ctx);
        out.field("stream", inner);
        out.finish()
    }
}
impl<S> Drop for SslStream<S> {
fn drop(&mut self) {
unsafe {
SSLClose(self.ctx.0);
let mut conn = ptr::null();
let ret = SSLGetConnection(self.ctx.0, &mut conn);
assert!(ret == errSecSuccess);
Box::<Connection<S>>::from_raw(conn as *mut _);
}
}
}
impl<S> SslStream<S> {
fn handshake(mut self) -> result::Result<SslStream<S>, HandshakeError<S>> {
match unsafe { SSLHandshake(self.ctx.0) } {
errSecSuccess => Ok(self),
reason @ errSSLPeerAuthCompleted |
reason @ errSSLClientCertRequested |
reason @ errSSLWouldBlock |
reason @ errSSLClientHelloReceived => {
Err(HandshakeError::Interrupted(MidHandshakeSslStream {
stream: self,
reason: reason,
}))
}
err => {
self.check_panic();
Err(HandshakeError::Failure(Error::new(err)))
}
}
}
pub fn get_ref(&self) -> &S {
&self.connection().stream
}
pub fn get_mut(&mut self) -> &mut S {
&mut self.connection_mut().stream
}
pub fn context(&self) -> &SslContext {
&self.ctx
}
pub fn context_mut(&mut self) -> &mut SslContext {
&mut self.ctx
}
fn connection(&self) -> &Connection<S> {
unsafe {
let mut conn = ptr::null();
let ret = SSLGetConnection(self.ctx.0, &mut conn);
assert!(ret == errSecSuccess);
mem::transmute(conn)
}
}
fn connection_mut(&mut self) -> &mut Connection<S> {
unsafe {
let mut conn = ptr::null();
let ret = SSLGetConnection(self.ctx.0, &mut conn);
assert!(ret == errSecSuccess);
mem::transmute(conn)
}
}
fn check_panic(&mut self) {
let conn = self.connection_mut();
if let Some(err) = conn.panic.take() {
propagate(err);
}
}
fn get_error(&mut self, ret: OSStatus) -> io::Error {
self.check_panic();
if let Some(err) = self.connection_mut().err.take() {
err
} else {
io::Error::new(io::ErrorKind::Other, Error::new(ret))
}
}
}
impl<S: Read + Write> Read for SslStream<S> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
unsafe {
let mut nread = 0;
let ret = SSLRead(self.ctx.0,
buf.as_mut_ptr() as *mut _,
buf.len(),
&mut nread);
match ret {
errSecSuccess => Ok(nread as usize),
errSSLClosedGraceful |
errSSLClosedAbort |
errSSLClosedNoNotify => Ok(0),
_ => Err(self.get_error(ret)),
}
}
}
}
impl<S: Read + Write> Write for SslStream<S> {
    /// Encrypts and sends `buf`, returning how many plaintext bytes were accepted.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let mut nwritten = 0;
        let ret = unsafe {
            SSLWrite(self.ctx.0, buf.as_ptr() as *const _, buf.len(), &mut nwritten)
        };
        match ret {
            errSecSuccess => Ok(nwritten as usize),
            _ => Err(self.get_error(ret)),
        }
    }

    /// Flushes the underlying stream.
    fn flush(&mut self) -> io::Result<()> {
        let conn = self.connection_mut();
        conn.stream.flush()
    }
}
/// Convenience builder for client-side TLS connections.
///
/// Optionally holds anchor certificates used to evaluate the server's chain.
/// `Default` matches `ClientBuilder::new()` (no custom anchors).
#[derive(Debug, Clone, Default)]
pub struct ClientBuilder {
    // Custom trust anchors. `None` leaves verification to Secure Transport's defaults.
    certs: Option<Vec<SecCertificate>>,
}
impl ClientBuilder {
pub fn new() -> Self {
ClientBuilder { certs: None }
}
pub fn anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self {
self.certs = Some(certs.to_owned());
self
}
pub fn handshake<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>>
where S: Read + Write
{
let mut ctx = try!(SslContext::new(ProtocolSide::Client, ConnectionType::Stream));
try!(ctx.set_peer_domain_name(domain));
if self.certs.is_some() {
try!(ctx.set_break_on_server_auth(true));
}
let mut result = ctx.handshake(stream);
loop {
match result {
Ok(stream) => return Ok(stream),
Err(HandshakeError::Interrupted(stream)) => {
if stream.server_auth_completed() {
if let Some(ref certs) = self.certs {
let mut trust = try!(stream.context().peer_trust());
try!(trust.set_anchor_certificates(certs));
let trusted = try!(trust.evaluate());
match trusted {
TrustResult::Invalid |
TrustResult::OtherError => return Err(Error::new(errSecBadReq)),
TrustResult::Proceed | TrustResult::Unspecified => {}
TrustResult::Deny => return Err(Error::new(errSecTrustSettingDeny)),
TrustResult::RecoverableTrustFailure |
TrustResult::FatalTrustFailure => {
return Err(Error::new(errSecNotTrusted));
}
}
} else {
return Err(Error::new(stream.reason()));
}
} else {
return Err(Error::new(stream.reason()));
}
result = stream.handshake();
}
Err(HandshakeError::Failure(err)) => return Err(err),
}
}
}
}
// Integration tests. Most of them open real TCP connections to google.com:443, so
// they need network access. `p!` is a crate-level test helper defined outside this
// view; it presumably unwraps a `Result`, panicking on error — confirm.
#[cfg(test)]
mod test {
#[cfg(feature = "nightly")]
use std::io;
use std::io::prelude::*;
use std::net::TcpStream;
use super::*;
// A plain client handshake against a valid host succeeds.
#[test]
fn connect() {
let mut ctx = p!(SslContext::new(ProtocolSide::Client, ConnectionType::Stream));
p!(ctx.set_peer_domain_name("google.com"));
let stream = p!(TcpStream::connect("google.com:443"));
p!(ctx.handshake(stream));
}
// A handshake whose peer domain name does not match the server certificate fails.
#[test]
fn connect_bad_domain() {
let mut ctx = p!(SslContext::new(ProtocolSide::Client, ConnectionType::Stream));
p!(ctx.set_peer_domain_name("foobar.com"));
let stream = p!(TcpStream::connect("google.com:443"));
match ctx.handshake(stream) {
Ok(_) => panic!("expected failure"),
Err(_) => {}
}
}
// End-to-end HTTP/1.0 request over the low-level `SslContext` API.
#[test]
fn load_page() {
let mut ctx = p!(SslContext::new(ProtocolSide::Client, ConnectionType::Stream));
p!(ctx.set_peer_domain_name("google.com"));
let stream = p!(TcpStream::connect("google.com:443"));
let mut stream = p!(ctx.handshake(stream));
p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
p!(stream.flush());
let mut buf = vec![];
p!(stream.read_to_end(&mut buf));
assert!(buf.starts_with(b"HTTP/1.0 200 OK"));
assert!(buf.ends_with(b"</html>"));
}
// `ClientBuilder` rejects a domain mismatch.
#[test]
fn client_bad_domain() {
let stream = p!(TcpStream::connect("google.com:443"));
assert!(ClientBuilder::new().handshake("foobar.com", stream).is_err());
}
// End-to-end HTTP/1.0 request over the `ClientBuilder` API.
#[test]
fn load_page_client() {
let stream = p!(TcpStream::connect("google.com:443"));
let mut stream = p!(ClientBuilder::new().handshake("google.com", stream));
p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
p!(stream.flush());
let mut buf = vec![];
p!(stream.read_to_end(&mut buf));
assert!(buf.starts_with(b"HTTP/1.0 200 OK"));
assert!(buf.ends_with(b"</html>"));
}
// Round-trips a subset (every other enabled cipher) through `set_enabled_ciphers`.
#[test]
fn cipher_configuration() {
let mut ctx = p!(SslContext::new(ProtocolSide::Server, ConnectionType::Stream));
let ciphers = p!(ctx.enabled_ciphers());
let ciphers = ciphers.iter()
.enumerate()
.filter_map(|(i, c)| {
if i % 2 == 0 {
Some(*c)
} else {
None
}
})
.collect::<Vec<_>>();
p!(ctx.set_enabled_ciphers(&ciphers));
assert_eq!(ciphers, p!(ctx.enabled_ciphers()));
}
// `peer_trust` must error while the session is still idle.
#[test]
fn idle_context_peer_trust() {
let ctx = p!(SslContext::new(ProtocolSide::Server, ConnectionType::Stream));
assert!(ctx.peer_trust().is_err());
}
// Peer ID is absent initially and round-trips once set.
#[test]
fn peer_id() {
let mut ctx = p!(SslContext::new(ProtocolSide::Server, ConnectionType::Stream));
assert!(p!(ctx.peer_id()).is_none());
p!(ctx.set_peer_id(b"foobar"));
assert_eq!(p!(ctx.peer_id()), Some(&b"foobar"[..]));
}
// Peer domain name is empty initially and round-trips once set.
#[test]
fn peer_domain_name() {
let mut ctx = p!(SslContext::new(ProtocolSide::Client, ConnectionType::Stream));
assert_eq!("", p!(ctx.peer_domain_name()));
p!(ctx.set_peer_domain_name("foobar.com"));
assert_eq!("foobar.com", p!(ctx.peer_domain_name()));
}
// A panic inside the stream's `write` is caught in the callback and re-raised
// from the handshake (only possible with the nightly panic-recovery API).
#[test]
#[should_panic(expected = "blammo")]
#[cfg(feature = "nightly")]
fn write_panic() {
struct ExplodingStream(TcpStream);
impl Read for ExplodingStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl Write for ExplodingStream {
fn write(&mut self, _: &[u8]) -> io::Result<usize> {
panic!("blammo");
}
fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}
let mut ctx = p!(SslContext::new(ProtocolSide::Client, ConnectionType::Stream));
p!(ctx.set_peer_domain_name("google.com"));
let stream = p!(TcpStream::connect("google.com:443"));
let _ = ctx.handshake(ExplodingStream(stream));
}
// Same as `write_panic`, but the panic comes from the stream's `read`.
#[test]
#[should_panic(expected = "blammo")]
#[cfg(feature = "nightly")]
fn read_panic() {
struct ExplodingStream(TcpStream);
impl Read for ExplodingStream {
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
panic!("blammo");
}
}
impl Write for ExplodingStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}
let mut ctx = p!(SslContext::new(ProtocolSide::Client, ConnectionType::Stream));
p!(ctx.set_peer_domain_name("google.com"));
let stream = p!(TcpStream::connect("google.com:443"));
let _ = ctx.handshake(ExplodingStream(stream));
}
}