1#[allow(unused_imports)]
76use core_foundation::array::{CFArray, CFArrayRef};
77
78use core_foundation::base::{Boolean, TCFType};
79#[cfg(feature = "alpn")]
80use core_foundation::string::CFString;
81use core_foundation_sys::base::{kCFAllocatorDefault, OSStatus};
82use std::os::raw::c_void;
83
84#[allow(unused_imports)]
85use apple_security_sys::base::{
86 errSecBadReq, errSecIO, errSecNotTrusted, errSecSuccess, errSecTrustSettingDeny,
87 errSecUnimplemented,
88};
89
90use apple_security_sys::secure_transport::*;
91use std::any::Any;
92use std::cmp;
93use std::fmt;
94use std::io;
95use std::io::prelude::*;
96use std::marker::PhantomData;
97use std::panic::{self, AssertUnwindSafe};
98use std::ptr;
99use std::result;
100use std::slice;
101
102use crate::base::{Error, Result};
103use crate::certificate::SecCertificate;
104use crate::cipher_suite::CipherSuite;
105use crate::identity::SecIdentity;
106use crate::import_export::Pkcs12ImportOptions;
107use crate::policy::SecPolicy;
108use crate::trust::SecTrust;
109use crate::{cvt, AsInner};
110use apple_security_sys::base::errSecParam;
111
112#[derive(Debug, Copy, Clone, PartialEq, Eq)]
114pub struct SslProtocolSide(SSLProtocolSide);
115
116impl SslProtocolSide {
117 pub const SERVER: Self = Self(kSSLServerSide);
119
120 pub const CLIENT: Self = Self(kSSLClientSide);
122}
123
124#[derive(Debug, Copy, Clone)]
126pub struct SslConnectionType(SSLConnectionType);
127
128impl SslConnectionType {
129 pub const STREAM: Self = Self(kSSLStreamType);
131
132 pub const DATAGRAM: Self = Self(kSSLDatagramType);
134}
135
136#[derive(Debug)]
138pub enum HandshakeError<S> {
139 Failure(Error),
141 Interrupted(MidHandshakeSslStream<S>),
143}
144
145impl<S> From<Error> for HandshakeError<S> {
146 #[inline(always)]
147 fn from(err: Error) -> Self {
148 Self::Failure(err)
149 }
150}
151
152#[derive(Debug)]
154pub enum ClientHandshakeError<S> {
155 Failure(Error),
157 Interrupted(MidHandshakeClientBuilder<S>),
159}
160
161impl<S> From<Error> for ClientHandshakeError<S> {
162 #[inline(always)]
163 fn from(err: Error) -> Self {
164 Self::Failure(err)
165 }
166}
167
168#[derive(Debug)]
170pub struct MidHandshakeSslStream<S> {
171 stream: SslStream<S>,
172 error: Error,
173}
174
175impl<S> MidHandshakeSslStream<S> {
176 #[inline(always)]
178 #[must_use]
179 pub fn get_ref(&self) -> &S {
180 self.stream.get_ref()
181 }
182
183 #[inline(always)]
185 pub fn get_mut(&mut self) -> &mut S {
186 self.stream.get_mut()
187 }
188
189 #[inline(always)]
191 #[must_use]
192 pub fn context(&self) -> &SslContext {
193 self.stream.context()
194 }
195
196 #[inline(always)]
198 pub fn context_mut(&mut self) -> &mut SslContext {
199 self.stream.context_mut()
200 }
201
202 #[inline(always)]
205 #[must_use]
206 pub fn server_auth_completed(&self) -> bool {
207 self.error.code() == errSSLPeerAuthCompleted
208 }
209
210 #[inline(always)]
213 #[must_use]
214 pub fn client_cert_requested(&self) -> bool {
215 self.error.code() == errSSLClientCertRequested
216 }
217
218 #[inline(always)]
221 #[must_use]
222 pub fn would_block(&self) -> bool {
223 self.error.code() == errSSLWouldBlock
224 }
225
226 #[inline(always)]
228 #[must_use]
229 pub fn error(&self) -> &Error {
230 &self.error
231 }
232
233 #[inline(always)]
235 pub fn handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>> {
236 self.stream.handshake()
237 }
238}
239
240#[derive(Debug)]
242pub struct MidHandshakeClientBuilder<S> {
243 stream: MidHandshakeSslStream<S>,
244 domain: Option<String>,
245 certs: Vec<SecCertificate>,
246 trust_certs_only: bool,
247 danger_accept_invalid_certs: bool,
248}
249
250impl<S> MidHandshakeClientBuilder<S> {
251 #[inline(always)]
253 #[must_use]
254 pub fn get_ref(&self) -> &S {
255 self.stream.get_ref()
256 }
257
258 #[inline(always)]
260 pub fn get_mut(&mut self) -> &mut S {
261 self.stream.get_mut()
262 }
263
264 #[inline(always)]
266 #[must_use]
267 pub fn error(&self) -> &Error {
268 self.stream.error()
269 }
270
271 pub fn handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>> {
273 let MidHandshakeClientBuilder {
274 stream,
275 domain,
276 certs,
277 trust_certs_only,
278 danger_accept_invalid_certs,
279 } = self;
280
281 let mut result = stream.handshake();
282 loop {
283 let stream = match result {
284 Ok(stream) => return Ok(stream),
285 Err(HandshakeError::Interrupted(stream)) => stream,
286 Err(HandshakeError::Failure(err)) => {
287 return Err(ClientHandshakeError::Failure(err))
288 }
289 };
290
291 if stream.would_block() {
292 let ret = MidHandshakeClientBuilder {
293 stream,
294 domain,
295 certs,
296 trust_certs_only,
297 danger_accept_invalid_certs,
298 };
299 return Err(ClientHandshakeError::Interrupted(ret));
300 }
301
302 if stream.server_auth_completed() {
303 if danger_accept_invalid_certs {
304 result = stream.handshake();
305 continue;
306 }
307 let mut trust = match stream.context().peer_trust2()? {
308 Some(trust) => trust,
309 None => {
310 result = stream.handshake();
311 continue;
312 }
313 };
314 trust.set_anchor_certificates(&certs)?;
315 trust.set_trust_anchor_certificates_only(self.trust_certs_only)?;
316 let policy = SecPolicy::create_ssl(SslProtocolSide::SERVER, domain.as_deref());
317 trust.set_policy(&policy)?;
318 trust.evaluate_with_error().map_err(|error| {
319 #[cfg(feature = "log")]
320 log::warn!("SecTrustEvaluateWithError: {}", error.to_string());
321 Error::from_code(error.code() as _)
322 })?;
323 result = stream.handshake();
324 continue;
325 }
326
327 let err = Error::from_code(stream.error().code());
328 return Err(ClientHandshakeError::Failure(err));
329 }
330 }
331}
332
333#[derive(Debug, PartialEq, Eq)]
335pub struct SessionState(SSLSessionState);
336
337impl SessionState {
338 pub const IDLE: Self = Self(kSSLIdle);
340
341 pub const HANDSHAKE: Self = Self(kSSLHandshake);
343
344 pub const CONNECTED: Self = Self(kSSLConnected);
346
347 pub const CLOSED: Self = Self(kSSLClosed);
349
350 pub const ABORTED: Self = Self(kSSLAborted);
352}
353
354#[derive(Debug, Copy, Clone, PartialEq, Eq)]
356pub struct SslAuthenticate(SSLAuthenticate);
357
358impl SslAuthenticate {
359 pub const NEVER: Self = Self(kNeverAuthenticate);
361
362 pub const ALWAYS: Self = Self(kAlwaysAuthenticate);
364
365 pub const TRY: Self = Self(kTryAuthenticate);
367}
368
369#[derive(Debug, Copy, Clone, PartialEq, Eq)]
371pub struct SslClientCertificateState(SSLClientCertificateState);
372
373impl SslClientCertificateState {
374 pub const NONE: Self = Self(kSSLClientCertNone);
376
377 pub const REQUESTED: Self = Self(kSSLClientCertRequested);
379 pub const SENT: Self = Self(kSSLClientCertSent);
381
382 pub const REJECTED: Self = Self(kSSLClientCertRejected);
384}
385
386#[derive(Debug, Copy, Clone, PartialEq, Eq)]
388pub struct SslProtocol(SSLProtocol);
389
390impl SslProtocol {
391 pub const UNKNOWN: Self = Self(kSSLProtocolUnknown);
393
394 pub const SSL3: Self = Self(kSSLProtocol3);
397
398 pub const TLS1: Self = Self(kTLSProtocol1);
401
402 pub const TLS11: Self = Self(kTLSProtocol11);
405
406 pub const TLS12: Self = Self(kTLSProtocol12);
409
410 pub const TLS13: Self = Self(kTLSProtocol13);
413
414 pub const SSL2: Self = Self(kSSLProtocol2);
416
417 pub const DTLS1: Self = Self(kDTLSProtocol1);
419
420 pub const SSL3_ONLY: Self = Self(kSSLProtocol3Only);
422
423 pub const TLS1_ONLY: Self = Self(kTLSProtocol1Only);
425
426 pub const ALL: Self = Self(kSSLProtocolAll);
428}
429
430declare_TCFType! {
431 SslContext, SSLContextRef
433}
434
435impl_TCFType!(SslContext, SSLContextRef, SSLContextGetTypeID);
436
437impl fmt::Debug for SslContext {
438 #[cold]
439 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
440 let mut builder = fmt.debug_struct("SslContext");
441 if let Ok(state) = self.state() {
442 builder.field("state", &state);
443 }
444 builder.finish()
445 }
446}
447
448unsafe impl Sync for SslContext {}
449unsafe impl Send for SslContext {}
450
451impl AsInner for SslContext {
452 type Inner = SSLContextRef;
453
454 #[inline(always)]
455 fn as_inner(&self) -> SSLContextRef {
456 self.0
457 }
458}
459
// Generates a setter/getter pair per listed `SSLSessionOption` constant.
// Each entry has the form `const OPTION: getter & setter,`. Any attributes
// (docs, `cfg`) preceding an entry are copied onto both generated methods.
macro_rules! impl_options {
    ($($(#[$a:meta])* const $opt:ident: $get:ident & $set:ident,)*) => {
        $(
            $(#[$a])*
            #[inline(always)]
            pub fn $set(&mut self, value: bool) -> Result<()> {
                let raw: Boolean = value.into();
                unsafe { cvt(SSLSetSessionOption(self.0, $opt, raw)) }
            }

            $(#[$a])*
            #[inline]
            pub fn $get(&self) -> Result<bool> {
                let mut raw: Boolean = 0;
                unsafe {
                    cvt(SSLGetSessionOption(self.0, $opt, &mut raw))?;
                }
                Ok(raw != 0)
            }
        )*
    };
}
479
480impl SslContext {
481 #[inline]
484 pub fn new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self> {
485 unsafe {
486 let ctx = SSLCreateContext(kCFAllocatorDefault, side.0, type_.0);
487 Ok(Self(ctx))
488 }
489 }
490
491 #[inline]
500 pub fn set_peer_domain_name(&mut self, peer_name: &str) -> Result<()> {
501 unsafe {
502 cvt(SSLSetPeerDomainName(
504 self.0,
505 peer_name.as_ptr().cast(),
506 peer_name.len(),
507 ))
508 }
509 }
510
511 pub fn peer_domain_name(&self) -> Result<String> {
513 unsafe {
514 let mut len = 0;
515 cvt(SSLGetPeerDomainNameLength(self.0, &mut len))?;
516 let mut buf = vec![0; len];
517 cvt(SSLGetPeerDomainName(
518 self.0,
519 buf.as_mut_ptr().cast(),
520 &mut len,
521 ))?;
522 Ok(String::from_utf8(buf).unwrap())
523 }
524 }
525
526 pub fn set_certificate(
534 &mut self,
535 identity: &SecIdentity,
536 certs: &[SecCertificate],
537 ) -> Result<()> {
538 let mut arr = vec![identity.as_CFType()];
539 arr.extend(certs.iter().map(|c| c.as_CFType()));
540 let certs = CFArray::from_CFTypes(&arr);
541
542 unsafe { cvt(SSLSetCertificate(self.0, certs.as_concrete_TypeRef())) }
543 }
544
545 #[inline]
552 pub fn set_peer_id(&mut self, peer_id: &[u8]) -> Result<()> {
553 unsafe { cvt(SSLSetPeerID(self.0, peer_id.as_ptr().cast(), peer_id.len())) }
554 }
555
556 pub fn peer_id(&self) -> Result<Option<&[u8]>> {
558 unsafe {
559 let mut ptr = ptr::null();
560 let mut len = 0;
561 cvt(SSLGetPeerID(self.0, &mut ptr, &mut len))?;
562 if ptr.is_null() {
563 Ok(None)
564 } else {
565 Ok(Some(slice::from_raw_parts(ptr.cast(), len)))
566 }
567 }
568 }
569
570 pub fn supported_ciphers(&self) -> Result<Vec<CipherSuite>> {
572 unsafe {
573 let mut num_ciphers = 0;
574 cvt(SSLGetNumberSupportedCiphers(self.0, &mut num_ciphers))?;
575 let mut ciphers = vec![0; num_ciphers];
576 cvt(SSLGetSupportedCiphers(
577 self.0,
578 ciphers.as_mut_ptr(),
579 &mut num_ciphers,
580 ))?;
581 Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
582 }
583 }
584
585 pub fn enabled_ciphers(&self) -> Result<Vec<CipherSuite>> {
588 unsafe {
589 let mut num_ciphers = 0;
590 cvt(SSLGetNumberEnabledCiphers(self.0, &mut num_ciphers))?;
591 let mut ciphers = vec![0; num_ciphers];
592 cvt(SSLGetEnabledCiphers(
593 self.0,
594 ciphers.as_mut_ptr(),
595 &mut num_ciphers,
596 ))?;
597 Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
598 }
599 }
600
601 pub fn set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()> {
603 let ciphers = ciphers.iter().map(|c| c.to_raw()).collect::<Vec<_>>();
604 unsafe {
605 cvt(SSLSetEnabledCiphers(
606 self.0,
607 ciphers.as_ptr(),
608 ciphers.len(),
609 ))
610 }
611 }
612
613 #[inline]
615 pub fn negotiated_cipher(&self) -> Result<CipherSuite> {
616 unsafe {
617 let mut cipher = 0;
618 cvt(SSLGetNegotiatedCipher(self.0, &mut cipher))?;
619 Ok(CipherSuite::from_raw(cipher))
620 }
621 }
622
623 #[inline]
627 pub fn set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()> {
628 unsafe { cvt(SSLSetClientSideAuthenticate(self.0, auth.0)) }
629 }
630
631 #[inline]
633 pub fn client_certificate_state(&self) -> Result<SslClientCertificateState> {
634 let mut state = 0;
635
636 unsafe {
637 cvt(SSLGetClientCertificateState(self.0, &mut state))?;
638 }
639 Ok(SslClientCertificateState(state))
640 }
641
642 pub fn peer_trust2(&self) -> Result<Option<SecTrust>> {
647 if self.state()? == SessionState::IDLE {
650 return Err(Error::from_code(errSecBadReq));
651 }
652
653 unsafe {
654 let mut trust = ptr::null_mut();
655 cvt(SSLCopyPeerTrust(self.0, &mut trust))?;
656 if trust.is_null() {
657 Ok(None)
658 } else {
659 Ok(Some(SecTrust::wrap_under_create_rule(trust)))
660 }
661 }
662 }
663
664 #[inline]
666 pub fn state(&self) -> Result<SessionState> {
667 unsafe {
668 let mut state = 0;
669 cvt(SSLGetSessionState(self.0, &mut state))?;
670 Ok(SessionState(state))
671 }
672 }
673
674 #[inline]
676 pub fn negotiated_protocol_version(&self) -> Result<SslProtocol> {
677 unsafe {
678 let mut version = 0;
679 cvt(SSLGetNegotiatedProtocolVersion(self.0, &mut version))?;
680 Ok(SslProtocol(version))
681 }
682 }
683
684 #[inline]
686 pub fn protocol_version_max(&self) -> Result<SslProtocol> {
687 unsafe {
688 let mut version = 0;
689 cvt(SSLGetProtocolVersionMax(self.0, &mut version))?;
690 Ok(SslProtocol(version))
691 }
692 }
693
694 #[inline]
696 pub fn set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()> {
697 unsafe { cvt(SSLSetProtocolVersionMax(self.0, max_version.0)) }
698 }
699
700 #[inline]
702 pub fn protocol_version_min(&self) -> Result<SslProtocol> {
703 unsafe {
704 let mut version = 0;
705 cvt(SSLGetProtocolVersionMin(self.0, &mut version))?;
706 Ok(SslProtocol(version))
707 }
708 }
709
710 #[inline]
712 pub fn set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()> {
713 unsafe { cvt(SSLSetProtocolVersionMin(self.0, min_version.0)) }
714 }
715
716 #[cfg(feature = "alpn")]
718 pub fn alpn_protocols(&self) -> Result<Vec<String>> {
719 let mut array: CFArrayRef = ptr::null();
720 unsafe {
721 #[cfg(feature = "OSX_10_13")]
722 {
723 cvt(SSLCopyALPNProtocols(self.0, &mut array))?;
724 }
725
726 #[cfg(not(feature = "OSX_10_13"))]
727 {
728 dlsym! { fn SSLCopyALPNProtocols(SSLContextRef, *mut CFArrayRef) -> OSStatus }
729 if let Some(f) = SSLCopyALPNProtocols.get() {
730 cvt(f(self.0, &mut array))?;
731 } else {
732 return Err(Error::from_code(errSecUnimplemented));
733 }
734 }
735
736 if array.is_null() {
737 return Ok(vec![]);
738 }
739
740 let array = CFArray::<CFString>::wrap_under_create_rule(array);
741 Ok(array.into_iter().map(|p| p.to_string()).collect())
742 }
743 }
744
745 #[cfg(feature = "alpn")]
749 pub fn set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()> {
750 let protocols = CFArray::from_CFTypes(
754 &protocols
755 .iter()
756 .map(|proto| CFString::new(proto))
757 .collect::<Vec<_>>(),
758 );
759
760 #[cfg(feature = "OSX_10_13")]
761 {
762 unsafe { cvt(SSLSetALPNProtocols(self.0, protocols.as_concrete_TypeRef())) }
763 }
764 #[cfg(not(feature = "OSX_10_13"))]
765 {
766 dlsym! { fn SSLSetALPNProtocols(SSLContextRef, CFArrayRef) -> OSStatus }
767 if let Some(f) = SSLSetALPNProtocols.get() {
768 unsafe { cvt(f(self.0, protocols.as_concrete_TypeRef())) }
769 } else {
770 Err(Error::from_code(errSecUnimplemented))
771 }
772 }
773 }
774
775 #[cfg(feature = "session-tickets")]
783 pub fn set_session_tickets_enabled(&mut self, enabled: bool) -> Result<()> {
784 #[cfg(feature = "OSX_10_13")]
785 {
786 unsafe { cvt(SSLSetSessionTicketsEnabled(self.0, enabled as Boolean)) }
787 }
788 #[cfg(not(feature = "OSX_10_13"))]
789 {
790 dlsym! { fn SSLSetSessionTicketsEnabled(SSLContextRef, Boolean) -> OSStatus }
791 if let Some(f) = SSLSetSessionTicketsEnabled.get() {
792 unsafe { cvt(f(self.0, enabled as Boolean)) }
793 } else {
794 Err(Error::from_code(errSecUnimplemented))
795 }
796 }
797 }
798
799 #[cfg(target_os = "macos")]
807 #[deprecated(note = "use `set_protocol_version_max`")]
808 pub fn set_protocol_version_enabled(
809 &mut self,
810 protocol: SslProtocol,
811 enabled: bool,
812 ) -> Result<()> {
813 unsafe {
814 cvt(SSLSetProtocolVersionEnabled(
815 self.0,
816 protocol.0,
817 enabled as Boolean,
818 ))
819 }
820 }
821
822 #[inline]
825 pub fn buffered_read_size(&self) -> Result<usize> {
826 unsafe {
827 let mut size = 0;
828 cvt(SSLGetBufferedReadSize(self.0, &mut size))?;
829 Ok(size)
830 }
831 }
832
833 impl_options! {
834 const kSSLSessionOptionBreakOnServerAuth: break_on_server_auth & set_break_on_server_auth,
837 const kSSLSessionOptionBreakOnCertRequested: break_on_cert_requested & set_break_on_cert_requested,
840 const kSSLSessionOptionBreakOnClientAuth: break_on_client_auth & set_break_on_client_auth,
843 #[cfg(feature = "OSX_10_9")]
848 const kSSLSessionOptionFalseStart: false_start & set_false_start,
849 #[cfg(feature = "OSX_10_9")]
854 const kSSLSessionOptionSendOneByteRecord: send_one_byte_record & set_send_one_byte_record,
855 }
856
857 fn into_stream<S>(self, stream: S) -> Result<SslStream<S>>
858 where
859 S: Read + Write,
860 {
861 unsafe {
862 let ret = SSLSetIOFuncs(self.0, read_func::<S>, write_func::<S>);
863 if ret != errSecSuccess {
864 return Err(Error::from_code(ret));
865 }
866
867 let stream = Connection {
868 stream,
869 err: None,
870 panic: None,
871 };
872 let stream = Box::into_raw(Box::new(stream));
873 let ret = SSLSetConnection(self.0, stream.cast());
874 if ret != errSecSuccess {
875 let _conn = Box::from_raw(stream);
876 return Err(Error::from_code(ret));
877 }
878
879 Ok(SslStream {
880 ctx: self,
881 _m: PhantomData,
882 })
883 }
884 }
885
886 pub fn handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>>
888 where
889 S: Read + Write,
890 {
891 self.into_stream(stream)
892 .map_err(HandshakeError::Failure)
893 .and_then(SslStream::handshake)
894 }
895}
896
/// Per-stream state handed to Secure Transport as the connection handle.
///
/// `SslContext::into_stream` boxes it and registers it with `SSLSetConnection`.
/// `read_func`/`write_func` stash any I/O error or caught panic here so that
/// `SslStream` can surface it afterwards.
struct Connection<S> {
    // The wrapped transport.
    stream: S,
    // Last I/O error returned by `stream`, if any.
    err: Option<io::Error>,
    // Panic payload caught while calling into `stream`; `SslStream` re-raises it.
    panic: Option<Box<dyn Any + Send>>,
}
902
903#[cold]
905fn translate_err(e: &io::Error) -> OSStatus {
906 match e.kind() {
907 io::ErrorKind::NotFound => errSSLClosedGraceful,
908 io::ErrorKind::ConnectionReset => errSSLClosedAbort,
909 io::ErrorKind::WouldBlock |
910 io::ErrorKind::NotConnected => errSSLWouldBlock,
911 _ => errSecIO,
912 }
913}
914
915unsafe extern "C" fn read_func<S>(
916 connection: SSLConnectionRef,
917 data: *mut c_void,
918 data_length: *mut usize,
919) -> OSStatus
920where
921 S: Read,
922{
923 let conn: &mut Connection<S> = &mut *(connection as *mut _);
924 let data = slice::from_raw_parts_mut(data.cast::<u8>(), *data_length);
925 let mut start = 0;
926 let mut ret = errSecSuccess;
927
928 while start < data.len() {
929 match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.read(&mut data[start..]))) {
930 Ok(Ok(0)) => {
931 ret = errSSLClosedNoNotify;
932 break;
933 }
934 Ok(Ok(len)) => start += len,
935 Ok(Err(e)) => {
936 ret = translate_err(&e);
937 conn.err = Some(e);
938 break;
939 }
940 Err(e) => {
941 ret = errSecIO;
942 conn.panic = Some(e);
943 break;
944 }
945 }
946 }
947
948 *data_length = start;
949 ret
950}
951
952unsafe extern "C" fn write_func<S>(
953 connection: SSLConnectionRef,
954 data: *const c_void,
955 data_length: *mut usize,
956) -> OSStatus
957where
958 S: Write,
959{
960 let conn: &mut Connection<S> = &mut *(connection as *mut _);
961 let data = slice::from_raw_parts(data as *mut u8, *data_length);
962 let mut start = 0;
963 let mut ret = errSecSuccess;
964
965 while start < data.len() {
966 match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.write(&data[start..]))) {
967 Ok(Ok(0)) => {
968 ret = errSSLClosedNoNotify;
969 break;
970 }
971 Ok(Ok(len)) => start += len,
972 Ok(Err(e)) => {
973 ret = translate_err(&e);
974 conn.err = Some(e);
975 break;
976 }
977 Err(e) => {
978 ret = errSecIO;
979 conn.panic = Some(e);
980 break;
981 }
982 }
983 }
984
985 *data_length = start;
986 ret
987}
988
989pub struct SslStream<S> {
991 ctx: SslContext,
992 _m: PhantomData<S>,
993}
994
995impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
996 #[cold]
997 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
998 fmt.debug_struct("SslStream")
999 .field("context", &self.ctx)
1000 .field("stream", self.get_ref())
1001 .finish()
1002 }
1003}
1004
1005impl<S> Drop for SslStream<S> {
1006 fn drop(&mut self) {
1007 unsafe {
1008 let mut conn = ptr::null();
1009 let ret = SSLGetConnection(self.ctx.0, &mut conn);
1010 assert!(ret == errSecSuccess);
1011 let _ = Box::<Connection<S>>::from_raw(conn as *mut _);
1012 }
1013 }
1014}
1015
1016impl<S> SslStream<S> {
1017 fn handshake(mut self) -> result::Result<Self, HandshakeError<S>> {
1018 match unsafe { SSLHandshake(self.ctx.0) } {
1019 errSecSuccess => Ok(self),
1020 reason @ errSSLPeerAuthCompleted
1021 | reason @ errSSLClientCertRequested
1022 | reason @ errSSLWouldBlock
1023 | reason @ errSSLClientHelloReceived => {
1024 Err(HandshakeError::Interrupted(MidHandshakeSslStream {
1025 stream: self,
1026 error: Error::from_code(reason),
1027 }))
1028 }
1029 err => {
1030 self.check_panic();
1031 Err(HandshakeError::Failure(Error::from_code(err)))
1032 }
1033 }
1034 }
1035
1036 #[inline(always)]
1038 #[must_use]
1039 pub fn get_ref(&self) -> &S {
1040 &self.connection().stream
1041 }
1042
1043 #[inline(always)]
1045 pub fn get_mut(&mut self) -> &mut S {
1046 &mut self.connection_mut().stream
1047 }
1048
1049 #[inline(always)]
1051 #[must_use]
1052 pub fn context(&self) -> &SslContext {
1053 &self.ctx
1054 }
1055
1056 #[inline(always)]
1058 pub fn context_mut(&mut self) -> &mut SslContext {
1059 &mut self.ctx
1060 }
1061
1062 pub fn close(&mut self) -> result::Result<(), io::Error> {
1064 unsafe {
1065 let ret = SSLClose(self.ctx.0);
1066 if ret == errSecSuccess {
1067 Ok(())
1068 } else {
1069 Err(self.get_error(ret))
1070 }
1071 }
1072 }
1073
1074 fn connection(&self) -> &Connection<S> {
1075 unsafe {
1076 let mut conn = ptr::null();
1077 let ret = SSLGetConnection(self.ctx.0, &mut conn);
1078 assert!(ret == errSecSuccess);
1079
1080 &mut *(conn as *mut Connection<S>)
1081 }
1082 }
1083
1084 fn connection_mut(&mut self) -> &mut Connection<S> {
1085 unsafe {
1086 let mut conn = ptr::null();
1087 let ret = SSLGetConnection(self.ctx.0, &mut conn);
1088 assert!(ret == errSecSuccess);
1089
1090 &mut *(conn as *mut Connection<S>)
1091 }
1092 }
1093
1094 #[cold]
1095 fn check_panic(&mut self) {
1096 let conn = self.connection_mut();
1097 if let Some(err) = conn.panic.take() {
1098 panic::resume_unwind(err);
1099 }
1100 }
1101
1102 #[cold]
1103 fn get_error(&mut self, ret: OSStatus) -> io::Error {
1104 self.check_panic();
1105
1106 if let Some(err) = self.connection_mut().err.take() {
1107 err
1108 } else {
1109 io::Error::new(io::ErrorKind::Other, Error::from_code(ret))
1110 }
1111 }
1112}
1113
1114impl<S: Read + Write> Read for SslStream<S> {
1115 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1116 if buf.is_empty() {
1121 return Ok(0);
1122 }
1123
1124 let buffered = self.context().buffered_read_size().unwrap_or(0);
1129 let to_read = if buffered > 0 {
1130 cmp::min(buffered, buf.len())
1131 } else {
1132 buf.len()
1133 };
1134
1135 unsafe {
1136 let mut nread = 0;
1137 let ret = SSLRead(self.ctx.0, buf.as_mut_ptr().cast(), to_read, &mut nread);
1138 if nread > 0 {
1141 return Ok(nread);
1142 }
1143
1144 match ret {
1145 errSSLClosedGraceful | errSSLClosedAbort | errSSLClosedNoNotify => Ok(0),
1146 errSSLPeerAuthCompleted => self.read(buf),
1148 _ => Err(self.get_error(ret)),
1149 }
1150 }
1151 }
1152}
1153
1154impl<S: Read + Write> Write for SslStream<S> {
1155 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1156 if buf.is_empty() {
1158 return Ok(0);
1159 }
1160 unsafe {
1161 let mut nwritten = 0;
1162 let ret = SSLWrite(
1163 self.ctx.0,
1164 buf.as_ptr().cast(),
1165 buf.len(),
1166 &mut nwritten,
1167 );
1168 if nwritten > 0 {
1171 Ok(nwritten)
1172 } else {
1173 Err(self.get_error(ret))
1174 }
1175 }
1176 }
1177
1178 fn flush(&mut self) -> io::Result<()> {
1179 self.connection_mut().stream.flush()
1180 }
1181}
1182
1183#[derive(Debug)]
1185pub struct ClientBuilder {
1186 identity: Option<SecIdentity>,
1187 certs: Vec<SecCertificate>,
1188 chain: Vec<SecCertificate>,
1189 protocol_min: Option<SslProtocol>,
1190 protocol_max: Option<SslProtocol>,
1191 trust_certs_only: bool,
1192 use_sni: bool,
1193 danger_accept_invalid_certs: bool,
1194 danger_accept_invalid_hostnames: bool,
1195 whitelisted_ciphers: Vec<CipherSuite>,
1196 blacklisted_ciphers: Vec<CipherSuite>,
1197 #[cfg(feature = "alpn")]
1198 alpn: Option<Vec<String>>,
1199 #[cfg(feature = "session-tickets")]
1200 enable_session_tickets: bool,
1201}
1202
1203impl Default for ClientBuilder {
1204 #[inline(always)]
1205 fn default() -> Self {
1206 Self::new()
1207 }
1208}
1209
1210impl ClientBuilder {
1211 #[inline]
1213 #[must_use]
1214 pub fn new() -> Self {
1215 Self {
1216 identity: None,
1217 certs: Vec::new(),
1218 chain: Vec::new(),
1219 protocol_min: None,
1220 protocol_max: None,
1221 trust_certs_only: false,
1222 use_sni: true,
1223 danger_accept_invalid_certs: false,
1224 danger_accept_invalid_hostnames: false,
1225 whitelisted_ciphers: Vec::new(),
1226 blacklisted_ciphers: Vec::new(),
1227 #[cfg(feature = "alpn")]
1228 alpn: None,
1229 #[cfg(feature = "session-tickets")]
1230 enable_session_tickets: false,
1231 }
1232 }
1233
1234 #[inline]
1237 pub fn anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self {
1238 self.certs = certs.to_owned();
1239 self
1240 }
1241
1242 #[inline]
1245 pub fn add_anchor_certificate(&mut self, certs: &SecCertificate) -> &mut Self {
1246 self.certs.push(certs.to_owned());
1247 self
1248 }
1249
1250 #[inline(always)]
1253 pub fn trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self {
1254 self.trust_certs_only = only;
1255 self
1256 }
1257
1258 #[inline(always)]
1267 pub fn danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self {
1268 self.danger_accept_invalid_certs = noverify;
1269 self
1270 }
1271
1272 #[inline(always)]
1274 pub fn use_sni(&mut self, use_sni: bool) -> &mut Self {
1275 self.use_sni = use_sni;
1276 self
1277 }
1278
1279 #[inline(always)]
1287 pub fn danger_accept_invalid_hostnames(
1288 &mut self,
1289 danger_accept_invalid_hostnames: bool,
1290 ) -> &mut Self {
1291 self.danger_accept_invalid_hostnames = danger_accept_invalid_hostnames;
1292 self
1293 }
1294
1295 pub fn whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self {
1297 self.whitelisted_ciphers = whitelisted_ciphers.to_owned();
1298 self
1299 }
1300
1301 pub fn blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self {
1303 self.blacklisted_ciphers = blacklisted_ciphers.to_owned();
1304 self
1305 }
1306
1307 pub fn identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self {
1309 self.identity = Some(identity.clone());
1310 self.chain = chain.to_owned();
1311 self
1312 }
1313
1314 #[inline(always)]
1316 pub fn protocol_min(&mut self, min: SslProtocol) -> &mut Self {
1317 self.protocol_min = Some(min);
1318 self
1319 }
1320
1321 #[inline(always)]
1323 pub fn protocol_max(&mut self, max: SslProtocol) -> &mut Self {
1324 self.protocol_max = Some(max);
1325 self
1326 }
1327
1328 #[cfg(feature = "alpn")]
1330 pub fn alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self {
1331 self.alpn = Some(protocols.iter().map(|s| s.to_string()).collect());
1332 self
1333 }
1334
1335 #[cfg(feature = "session-tickets")]
1339 #[inline(always)]
1340 pub fn enable_session_tickets(&mut self, enable: bool) -> &mut Self {
1341 self.enable_session_tickets = enable;
1342 self
1343 }
1344
1345 pub fn handshake<S>(
1349 &self,
1350 domain: &str,
1351 stream: S,
1352 ) -> result::Result<SslStream<S>, ClientHandshakeError<S>>
1353 where
1354 S: Read + Write,
1355 {
1356 let stream = MidHandshakeSslStream {
1359 stream: self.ctx_into_stream(domain, stream)?,
1360 error: Error::from(errSecSuccess),
1361 };
1362
1363 let certs = self.certs.clone();
1364 let stream = MidHandshakeClientBuilder {
1365 stream,
1366 domain: if self.danger_accept_invalid_hostnames {
1367 None
1368 } else {
1369 Some(domain.to_string())
1370 },
1371 certs,
1372 trust_certs_only: self.trust_certs_only,
1373 danger_accept_invalid_certs: self.danger_accept_invalid_certs,
1374 };
1375 stream.handshake()
1376 }
1377
1378 fn ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>>
1379 where
1380 S: Read + Write,
1381 {
1382 let mut ctx = SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM)?;
1383
1384 if self.use_sni {
1385 ctx.set_peer_domain_name(domain)?;
1386 }
1387 if let Some(ref identity) = self.identity {
1388 ctx.set_certificate(identity, &self.chain)?;
1389 }
1390 #[cfg(feature = "alpn")]
1391 {
1392 if let Some(ref alpn) = self.alpn {
1393 ctx.set_alpn_protocols(&alpn.iter().map(|s| &**s).collect::<Vec<_>>())?;
1394 }
1395 }
1396 #[cfg(feature = "session-tickets")]
1397 {
1398 if self.enable_session_tickets {
1399 ctx.set_peer_id(domain.as_bytes())?;
1402 ctx.set_session_tickets_enabled(true)?;
1403 }
1404 }
1405 ctx.set_break_on_server_auth(true)?;
1406 self.configure_protocols(&mut ctx)?;
1407 self.configure_ciphers(&mut ctx)?;
1408
1409 ctx.into_stream(stream)
1410 }
1411
1412 fn configure_protocols(&self, ctx: &mut SslContext) -> Result<()> {
1413 if let Some(min) = self.protocol_min {
1414 ctx.set_protocol_version_min(min)?;
1415 }
1416 if let Some(max) = self.protocol_max {
1417 ctx.set_protocol_version_max(max)?;
1418 }
1419 Ok(())
1420 }
1421
1422 fn configure_ciphers(&self, ctx: &mut SslContext) -> Result<()> {
1423 let mut ciphers = if self.whitelisted_ciphers.is_empty() {
1424 ctx.enabled_ciphers()?
1425 } else {
1426 self.whitelisted_ciphers.clone()
1427 };
1428
1429 if !self.blacklisted_ciphers.is_empty() {
1430 ciphers.retain(|cipher| !self.blacklisted_ciphers.contains(cipher));
1431 }
1432
1433 ctx.set_enabled_ciphers(&ciphers)?;
1434 Ok(())
1435 }
1436}
1437
1438#[derive(Debug)]
1440pub struct ServerBuilder {
1441 identity: SecIdentity,
1442 certs: Vec<SecCertificate>,
1443}
1444
1445impl ServerBuilder {
1446 #[must_use]
1449 pub fn new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self {
1450 Self {
1451 identity: identity.clone(),
1452 certs: certs.to_owned(),
1453 }
1454 }
1455
1456 pub fn from_pkcs12(pkcs12_der: &[u8], passphrase: &str) -> Result<Self> {
1463 let mut identities: Vec<(SecIdentity, Vec<SecCertificate>)> = Pkcs12ImportOptions::new()
1464 .passphrase(passphrase)
1465 .import(pkcs12_der)?
1466 .into_iter()
1467 .filter_map(|idendity| {
1468 let certs = idendity.cert_chain.unwrap_or_default();
1469 idendity.identity.map(|identity| (identity, certs))
1470 })
1471 .collect();
1472 if identities.len() == 1 {
1473 let (identity, certs) = identities.pop().unwrap();
1474 Ok(ServerBuilder::new(&identity, &certs))
1475 } else {
1476 Err(Error::from_code(errSecParam))
1478 }
1479 }
1480
1481 pub fn new_ssl_context(&self) -> Result<SslContext> {
1483 let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
1484 ctx.set_certificate(&self.identity, &self.certs)?;
1485 Ok(ctx)
1486 }
1487
1488 pub fn handshake<S>(&self, stream: S) -> Result<SslStream<S>>
1490 where
1491 S: Read + Write,
1492 {
1493 match self.new_ssl_context()?.handshake(stream) {
1494 Ok(stream) => Ok(stream),
1495 Err(HandshakeError::Interrupted(stream)) => Err(*stream.error()),
1496 Err(HandshakeError::Failure(err)) => Err(err),
1497 }
1498 }
1499}
1500
1501#[cfg(test)]
1502mod test {
1503 use std::io;
1504 use std::io::prelude::*;
1505 use std::net::TcpStream;
1506
1507 use super::*;
1508
#[test]
fn server_builder_from_pkcs12() {
    let der: &[u8] = include_bytes!("../test/server.p12");
    let built = ServerBuilder::from_pkcs12(der, "password123");
    built.unwrap();
}
1514
#[test]
fn connect() {
    let side = SslProtocolSide::CLIENT;
    let kind = SslConnectionType::STREAM;
    let mut ctx = p!(SslContext::new(side, kind));
    p!(ctx.set_peer_domain_name("google.com"));
    let tcp = p!(TcpStream::connect("google.com:443"));
    p!(ctx.handshake(tcp));
}
1525
#[test]
fn connect_bad_domain() {
    let side = SslProtocolSide::CLIENT;
    let kind = SslConnectionType::STREAM;
    let mut ctx = p!(SslContext::new(side, kind));
    // Deliberately mismatched with the host we actually connect to.
    p!(ctx.set_peer_domain_name("foobar.com"));
    let tcp = p!(TcpStream::connect("google.com:443"));
    assert!(ctx.handshake(tcp).is_err(), "expected failure");
}
1539
1540 #[test]
1541 fn load_page() {
1542 let mut ctx = p!(SslContext::new(
1543 SslProtocolSide::CLIENT,
1544 SslConnectionType::STREAM
1545 ));
1546 p!(ctx.set_peer_domain_name("google.com"));
1547 let stream = p!(TcpStream::connect("google.com:443"));
1548 let mut stream = p!(ctx.handshake(stream));
1549 p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
1550 p!(stream.flush());
1551 let mut buf = vec![];
1552 p!(stream.read_to_end(&mut buf));
1553 println!("{}", String::from_utf8_lossy(&buf));
1554 }
1555
1556 #[test]
1557 fn client_no_session_ticket_resumption() {
1558 for _ in 0..2 {
1559 let stream = p!(TcpStream::connect("google.com:443"));
1560
1561 let stream = MidHandshakeSslStream {
1563 stream: ClientBuilder::new()
1564 .ctx_into_stream("google.com", stream)
1565 .unwrap(),
1566 error: Error::from(errSecSuccess),
1567 };
1568
1569 let mut result = stream.handshake();
1570
1571 if let Err(HandshakeError::Interrupted(stream)) = result {
1572 assert!(stream.server_auth_completed());
1573 result = stream.handshake();
1574 } else {
1575 panic!("Unexpectedly skipped server auth");
1576 }
1577
1578 assert!(result.is_ok());
1579 }
1580 }
1581
1582 #[test]
1583 #[cfg(feature = "session-tickets")]
1584 fn client_session_ticket_resumption() {
1585 for i in 0..2 {
1588 let stream = p!(TcpStream::connect("google.com:443"));
1589 let mut builder = ClientBuilder::new();
1590 builder.enable_session_tickets(true);
1591
1592 let stream = MidHandshakeSslStream {
1594 stream: builder.ctx_into_stream("google.com", stream).unwrap(),
1595 error: Error::from(errSecSuccess),
1596 };
1597
1598 let mut result = stream.handshake();
1599
1600 if let Err(HandshakeError::Interrupted(stream)) = result {
1601 assert!(stream.server_auth_completed());
1602 assert_eq!(
1603 i, 0,
1604 "Session ticket resumption did not work, server auth was not skipped"
1605 );
1606 result = stream.handshake();
1607 } else {
1608 assert_eq!(i, 1, "Unexpectedly skipped server auth");
1609 }
1610
1611 assert!(result.is_ok());
1612 }
1613 }
1614
1615 #[test]
1616 #[cfg(feature = "alpn")]
1617 fn client_alpn_accept() {
1618 let mut ctx = p!(SslContext::new(
1619 SslProtocolSide::CLIENT,
1620 SslConnectionType::STREAM
1621 ));
1622 p!(ctx.set_peer_domain_name("google.com"));
1623 p!(ctx.set_alpn_protocols(&vec!["h2"]));
1624 let stream = p!(TcpStream::connect("google.com:443"));
1625 let stream = ctx.handshake(stream).unwrap();
1626 assert_eq!(vec!["h2"], stream.context().alpn_protocols().unwrap());
1627 }
1628
1629 #[test]
1630 #[cfg(feature = "alpn")]
1631 fn client_alpn_reject() {
1632 let mut ctx = p!(SslContext::new(
1633 SslProtocolSide::CLIENT,
1634 SslConnectionType::STREAM
1635 ));
1636 p!(ctx.set_peer_domain_name("google.com"));
1637 p!(ctx.set_alpn_protocols(&vec!["h2c"]));
1638 let stream = p!(TcpStream::connect("google.com:443"));
1639 let stream = ctx.handshake(stream).unwrap();
1640 assert!(stream.context().alpn_protocols().is_err());
1641 }
1642
1643 #[test]
1644 fn client_no_anchor_certs() {
1645 let stream = p!(TcpStream::connect("google.com:443"));
1646 assert!(ClientBuilder::new()
1647 .trust_anchor_certificates_only(true)
1648 .handshake("google.com", stream)
1649 .is_err());
1650 }
1651
1652 #[test]
1653 fn client_bad_domain() {
1654 let stream = p!(TcpStream::connect("google.com:443"));
1655 assert!(ClientBuilder::new()
1656 .handshake("foobar.com", stream)
1657 .is_err());
1658 }
1659
1660 #[test]
1661 fn client_bad_domain_ignored() {
1662 let stream = p!(TcpStream::connect("google.com:443"));
1663 ClientBuilder::new()
1664 .danger_accept_invalid_hostnames(true)
1665 .handshake("foobar.com", stream)
1666 .unwrap();
1667 }
1668
1669 #[test]
1670 fn connect_no_verify_ssl() {
1671 let stream = p!(TcpStream::connect("expired.badssl.com:443"));
1672 let mut builder = ClientBuilder::new();
1673 builder.danger_accept_invalid_certs(true);
1674 builder.handshake("expired.badssl.com", stream).unwrap();
1675 }
1676
1677 #[test]
1678 fn load_page_client() {
1679 let stream = p!(TcpStream::connect("google.com:443"));
1680 let mut stream = p!(ClientBuilder::new().handshake("google.com", stream));
1681 p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
1682 p!(stream.flush());
1683 let mut buf = vec![];
1684 p!(stream.read_to_end(&mut buf));
1685 println!("{}", String::from_utf8_lossy(&buf));
1686 }
1687
1688 #[test]
1689 #[cfg_attr(target_os = "ios", ignore)] fn cipher_configuration() {
1691 let mut ctx = p!(SslContext::new(
1692 SslProtocolSide::SERVER,
1693 SslConnectionType::STREAM
1694 ));
1695 let ciphers = p!(ctx.enabled_ciphers());
1696 let ciphers = ciphers
1697 .iter()
1698 .enumerate()
1699 .filter_map(|(i, c)| if i % 2 == 0 { Some(*c) } else { None })
1700 .collect::<Vec<_>>();
1701 p!(ctx.set_enabled_ciphers(&ciphers));
1702 assert_eq!(ciphers, p!(ctx.enabled_ciphers()));
1703 }
1704
1705 #[test]
1706 fn test_builder_whitelist_ciphers() {
1707 let stream = p!(TcpStream::connect("google.com:443"));
1708
1709 let ctx = p!(SslContext::new(
1710 SslProtocolSide::CLIENT,
1711 SslConnectionType::STREAM
1712 ));
1713 assert!(p!(ctx.enabled_ciphers()).len() > 1);
1714
1715 let ciphers = p!(ctx.enabled_ciphers());
1716 let cipher = ciphers.first().unwrap();
1717 let stream = p!(ClientBuilder::new()
1718 .whitelist_ciphers(&[*cipher])
1719 .ctx_into_stream("google.com", stream));
1720
1721 assert_eq!(1, p!(stream.context().enabled_ciphers()).len());
1722 }
1723
1724 #[test]
1725 #[cfg_attr(target_os = "ios", ignore)] fn test_builder_blacklist_ciphers() {
1727 let stream = p!(TcpStream::connect("google.com:443"));
1728
1729 let ctx = p!(SslContext::new(
1730 SslProtocolSide::CLIENT,
1731 SslConnectionType::STREAM
1732 ));
1733 let num = p!(ctx.enabled_ciphers()).len();
1734 assert!(num > 1);
1735
1736 let ciphers = p!(ctx.enabled_ciphers());
1737 let cipher = ciphers.first().unwrap();
1738 let stream = p!(ClientBuilder::new()
1739 .blacklist_ciphers(&[*cipher])
1740 .ctx_into_stream("google.com", stream));
1741
1742 assert_eq!(num - 1, p!(stream.context().enabled_ciphers()).len());
1743 }
1744
1745 #[test]
1746 fn idle_context_peer_trust() {
1747 let ctx = p!(SslContext::new(
1748 SslProtocolSide::SERVER,
1749 SslConnectionType::STREAM
1750 ));
1751 assert!(ctx.peer_trust2().is_err());
1752 }
1753
1754 #[test]
1755 fn peer_id() {
1756 let mut ctx = p!(SslContext::new(
1757 SslProtocolSide::SERVER,
1758 SslConnectionType::STREAM
1759 ));
1760 assert!(p!(ctx.peer_id()).is_none());
1761 p!(ctx.set_peer_id(b"foobar"));
1762 assert_eq!(p!(ctx.peer_id()), Some(&b"foobar"[..]));
1763 }
1764
1765 #[test]
1766 fn peer_domain_name() {
1767 let mut ctx = p!(SslContext::new(
1768 SslProtocolSide::CLIENT,
1769 SslConnectionType::STREAM
1770 ));
1771 assert_eq!("", p!(ctx.peer_domain_name()));
1772 p!(ctx.set_peer_domain_name("foobar.com"));
1773 assert_eq!("foobar.com", p!(ctx.peer_domain_name()));
1774 }
1775
1776 #[test]
1777 #[should_panic(expected = "blammo")]
1778 fn write_panic() {
1779 struct ExplodingStream(TcpStream);
1780
1781 impl Read for ExplodingStream {
1782 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1783 self.0.read(buf)
1784 }
1785 }
1786
1787 impl Write for ExplodingStream {
1788 fn write(&mut self, _: &[u8]) -> io::Result<usize> {
1789 panic!("blammo");
1790 }
1791
1792 fn flush(&mut self) -> io::Result<()> {
1793 self.0.flush()
1794 }
1795 }
1796
1797 let mut ctx = p!(SslContext::new(
1798 SslProtocolSide::CLIENT,
1799 SslConnectionType::STREAM
1800 ));
1801 p!(ctx.set_peer_domain_name("google.com"));
1802 let stream = p!(TcpStream::connect("google.com:443"));
1803 let _ = ctx.handshake(ExplodingStream(stream));
1804 }
1805
1806 #[test]
1807 #[should_panic(expected = "blammo")]
1808 fn read_panic() {
1809 struct ExplodingStream(TcpStream);
1810
1811 impl Read for ExplodingStream {
1812 fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
1813 panic!("blammo");
1814 }
1815 }
1816
1817 impl Write for ExplodingStream {
1818 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1819 self.0.write(buf)
1820 }
1821
1822 fn flush(&mut self) -> io::Result<()> {
1823 self.0.flush()
1824 }
1825 }
1826
1827 let mut ctx = p!(SslContext::new(
1828 SslProtocolSide::CLIENT,
1829 SslConnectionType::STREAM
1830 ));
1831 p!(ctx.set_peer_domain_name("google.com"));
1832 let stream = p!(TcpStream::connect("google.com:443"));
1833 let _ = ctx.handshake(ExplodingStream(stream));
1834 }
1835
1836 #[test]
1837 fn zero_length_buffers() {
1838 let mut ctx = p!(SslContext::new(
1839 SslProtocolSide::CLIENT,
1840 SslConnectionType::STREAM
1841 ));
1842 p!(ctx.set_peer_domain_name("google.com"));
1843 let stream = p!(TcpStream::connect("google.com:443"));
1844 let mut stream = ctx.handshake(stream).unwrap();
1845 assert_eq!(stream.write(b"").unwrap(), 0);
1846 assert_eq!(stream.read(&mut []).unwrap(), 0);
1847 }
1848}