1use std::{
73 any::Any,
74 cmp, fmt,
75 io::{
76 prelude::*,
77 {self},
78 },
79 marker::PhantomData,
80 os::raw::c_void,
81 panic::{
82 AssertUnwindSafe, {self},
83 },
84 ptr, result, slice,
85};
86
87#[allow(unused_imports)]
88use core_foundation::array::CFArray;
89#[allow(unused_imports)]
90use core_foundation::array::CFArrayRef;
91use core_foundation::base::{Boolean, TCFType};
92#[cfg(feature = "alpn")]
93use core_foundation::string::CFString;
94use core_foundation_sys::base::{kCFAllocatorDefault, OSStatus};
95#[allow(unused_imports)]
96use security_framework_sys::base::errSecBadReq;
97#[allow(unused_imports)]
98use security_framework_sys::base::errSecIO;
99#[allow(unused_imports)]
100use security_framework_sys::base::errSecNotTrusted;
101#[allow(unused_imports)]
102use security_framework_sys::base::errSecSuccess;
103#[allow(unused_imports)]
104use security_framework_sys::base::errSecTrustSettingDeny;
105#[allow(unused_imports)]
106use security_framework_sys::base::errSecUnimplemented;
107use security_framework_sys::{base::errSecParam, secure_transport::*};
108
109use crate::{
110 base::{Error, Result},
111 certificate::SecCertificate,
112 cipher_suite::CipherSuite,
113 cvt,
114 identity::SecIdentity,
115 import_export::Pkcs12ImportOptions,
116 policy::SecPolicy,
117 trust::SecTrust,
118 AsInner,
119};
120
121#[derive(Debug, Copy, Clone, PartialEq, Eq)]
123pub struct SslProtocolSide(SSLProtocolSide);
124
125impl SslProtocolSide {
126 pub const SERVER: Self = Self(kSSLServerSide);
128
129 pub const CLIENT: Self = Self(kSSLClientSide);
131}
132
133#[derive(Debug, Copy, Clone)]
135pub struct SslConnectionType(SSLConnectionType);
136
137impl SslConnectionType {
138 pub const STREAM: Self = Self(kSSLStreamType);
140
141 pub const DATAGRAM: Self = Self(kSSLDatagramType);
143}
144
145#[derive(Debug)]
147pub enum HandshakeError<S> {
148 Failure(Error),
150
151 Interrupted(MidHandshakeSslStream<S>),
153}
154
155impl<S> From<Error> for HandshakeError<S> {
156 #[inline(always)]
157 fn from(err: Error) -> Self {
158 Self::Failure(err)
159 }
160}
161
162#[derive(Debug)]
164pub enum ClientHandshakeError<S> {
165 Failure(Error),
167
168 Interrupted(MidHandshakeClientBuilder<S>),
170}
171
172impl<S> From<Error> for ClientHandshakeError<S> {
173 #[inline(always)]
174 fn from(err: Error) -> Self {
175 Self::Failure(err)
176 }
177}
178
179#[derive(Debug)]
181pub struct MidHandshakeSslStream<S> {
182 stream: SslStream<S>,
183 error: Error,
184}
185
186impl<S> MidHandshakeSslStream<S> {
187 #[inline(always)]
189 #[must_use]
190 pub fn get_ref(&self) -> &S {
191 self.stream.get_ref()
192 }
193
194 #[inline(always)]
196 pub fn get_mut(&mut self) -> &mut S {
197 self.stream.get_mut()
198 }
199
200 #[inline(always)]
202 #[must_use]
203 pub fn context(&self) -> &SslContext {
204 self.stream.context()
205 }
206
207 #[inline(always)]
209 pub fn context_mut(&mut self) -> &mut SslContext {
210 self.stream.context_mut()
211 }
212
213 #[inline(always)]
216 #[must_use]
217 pub fn server_auth_completed(&self) -> bool {
218 self.error.code() == errSSLPeerAuthCompleted
219 }
220
221 #[inline(always)]
224 #[must_use]
225 pub fn client_cert_requested(&self) -> bool {
226 self.error.code() == errSSLClientCertRequested
227 }
228
229 #[inline(always)]
232 #[must_use]
233 pub fn would_block(&self) -> bool {
234 self.error.code() == errSSLWouldBlock
235 }
236
237 #[inline(always)]
239 #[must_use]
240 pub fn error(&self) -> &Error {
241 &self.error
242 }
243
244 #[inline(always)]
246 pub fn handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>> {
247 self.stream.handshake()
248 }
249}
250
251#[derive(Debug)]
253pub struct MidHandshakeClientBuilder<S> {
254 stream: MidHandshakeSslStream<S>,
255 domain: Option<String>,
256 certs: Vec<SecCertificate>,
257 trust_certs_only: bool,
258 danger_accept_invalid_certs: bool,
259}
260
261impl<S> MidHandshakeClientBuilder<S> {
262 #[inline(always)]
264 #[must_use]
265 pub fn get_ref(&self) -> &S {
266 self.stream.get_ref()
267 }
268
269 #[inline(always)]
271 pub fn get_mut(&mut self) -> &mut S {
272 self.stream.get_mut()
273 }
274
275 #[inline(always)]
277 #[must_use]
278 pub fn error(&self) -> &Error {
279 self.stream.error()
280 }
281
282 pub fn handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>> {
284 let MidHandshakeClientBuilder {
285 stream,
286 domain,
287 certs,
288 trust_certs_only,
289 danger_accept_invalid_certs,
290 } = self;
291
292 let mut result = stream.handshake();
293 loop {
294 let stream = match result {
295 Ok(stream) => return Ok(stream),
296 Err(HandshakeError::Interrupted(stream)) => stream,
297 Err(HandshakeError::Failure(err)) => {
298 return Err(ClientHandshakeError::Failure(err))
299 }
300 };
301
302 if stream.would_block() {
303 let ret = MidHandshakeClientBuilder {
304 stream,
305 domain,
306 certs,
307 trust_certs_only,
308 danger_accept_invalid_certs,
309 };
310 return Err(ClientHandshakeError::Interrupted(ret));
311 }
312
313 if stream.server_auth_completed() {
314 if danger_accept_invalid_certs {
315 result = stream.handshake();
316 continue;
317 }
318 let mut trust = match stream.context().peer_trust2()? {
319 Some(trust) => trust,
320 None => {
321 result = stream.handshake();
322 continue;
323 }
324 };
325 trust.set_anchor_certificates(&certs)?;
326 trust.set_trust_anchor_certificates_only(self.trust_certs_only)?;
327 let policy = SecPolicy::create_ssl(SslProtocolSide::SERVER, domain.as_deref());
328 trust.set_policy(&policy)?;
329 trust.evaluate_with_error().map_err(|error| {
330 #[cfg(feature = "log")]
331 log::warn!("SecTrustEvaluateWithError: {}", error.to_string());
332 Error::from_code(error.code() as _)
333 })?;
334 result = stream.handshake();
335 continue;
336 }
337
338 let err = Error::from_code(stream.error().code());
339 return Err(ClientHandshakeError::Failure(err));
340 }
341 }
342}
343
344#[derive(Debug, PartialEq, Eq)]
346pub struct SessionState(SSLSessionState);
347
348impl SessionState {
349 pub const IDLE: Self = Self(kSSLIdle);
351
352 pub const HANDSHAKE: Self = Self(kSSLHandshake);
354
355 pub const CONNECTED: Self = Self(kSSLConnected);
357
358 pub const CLOSED: Self = Self(kSSLClosed);
360
361 pub const ABORTED: Self = Self(kSSLAborted);
363}
364
365#[derive(Debug, Copy, Clone, PartialEq, Eq)]
367pub struct SslAuthenticate(SSLAuthenticate);
368
369impl SslAuthenticate {
370 pub const NEVER: Self = Self(kNeverAuthenticate);
372
373 pub const ALWAYS: Self = Self(kAlwaysAuthenticate);
375
376 pub const TRY: Self = Self(kTryAuthenticate);
378}
379
380#[derive(Debug, Copy, Clone, PartialEq, Eq)]
382pub struct SslClientCertificateState(SSLClientCertificateState);
383
384impl SslClientCertificateState {
385 pub const NONE: Self = Self(kSSLClientCertNone);
387
388 pub const REQUESTED: Self = Self(kSSLClientCertRequested);
390 pub const SENT: Self = Self(kSSLClientCertSent);
392
393 pub const REJECTED: Self = Self(kSSLClientCertRejected);
395}
396
397#[derive(Debug, Copy, Clone, PartialEq, Eq)]
399pub struct SslProtocol(SSLProtocol);
400
401impl SslProtocol {
402 pub const UNKNOWN: Self = Self(kSSLProtocolUnknown);
404
405 pub const SSL3: Self = Self(kSSLProtocol3);
408
409 pub const TLS1: Self = Self(kTLSProtocol1);
412
413 pub const TLS11: Self = Self(kTLSProtocol11);
416
417 pub const TLS12: Self = Self(kTLSProtocol12);
420
421 pub const TLS13: Self = Self(kTLSProtocol13);
424
425 pub const SSL2: Self = Self(kSSLProtocol2);
427
428 pub const DTLS1: Self = Self(kDTLSProtocol1);
430
431 pub const SSL3_ONLY: Self = Self(kSSLProtocol3Only);
433
434 pub const TLS1_ONLY: Self = Self(kTLSProtocol1Only);
436
437 pub const ALL: Self = Self(kSSLProtocolAll);
439}
440
441declare_TCFType! {
442 SslContext, SSLContextRef
444}
445
446impl_TCFType!(SslContext, SSLContextRef, SSLContextGetTypeID);
447
448impl fmt::Debug for SslContext {
449 #[cold]
450 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
451 let mut builder = fmt.debug_struct("SslContext");
452 if let Ok(state) = self.state() {
453 builder.field("state", &state);
454 }
455 builder.finish()
456 }
457}
458
459unsafe impl Sync for SslContext {}
460unsafe impl Send for SslContext {}
461
462impl AsInner for SslContext {
463 type Inner = SSLContextRef;
464
465 #[inline(always)]
466 fn as_inner(&self) -> SSLContextRef {
467 self.0
468 }
469}
470
// Generates a getter/setter pair for each listed Secure Transport session option.
// Any attributes (docs, `cfg`) written before `const` are copied onto both methods.
macro_rules! impl_options {
    ($($(#[$a:meta])* const $opt:ident: $get:ident & $set:ident,)*) => {
        $(
            $(#[$a])*
            #[inline]
            pub fn $get(&self) -> Result<bool> {
                let mut raw: Boolean = 0;
                // SAFETY: `raw` is a valid out-pointer for the option value.
                unsafe { cvt(SSLGetSessionOption(self.0, $opt, &mut raw))?; }
                Ok(raw != 0)
            }

            $(#[$a])*
            #[inline(always)]
            pub fn $set(&mut self, value: bool) -> Result<()> {
                let raw = Boolean::from(value);
                unsafe { cvt(SSLSetSessionOption(self.0, $opt, raw)) }
            }
        )*
    }
}
490
491impl SslContext {
492 #[inline]
495 pub fn new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self> {
496 unsafe {
497 let ctx = SSLCreateContext(kCFAllocatorDefault, side.0, type_.0);
498 Ok(Self(ctx))
499 }
500 }
501
502 #[inline]
511 pub fn set_peer_domain_name(&mut self, peer_name: &str) -> Result<()> {
512 unsafe {
513 cvt(SSLSetPeerDomainName(
515 self.0,
516 peer_name.as_ptr().cast(),
517 peer_name.len(),
518 ))
519 }
520 }
521
522 pub fn peer_domain_name(&self) -> Result<String> {
524 unsafe {
525 let mut len = 0;
526 cvt(SSLGetPeerDomainNameLength(self.0, &mut len))?;
527 let mut buf = vec![0; len];
528 cvt(SSLGetPeerDomainName(
529 self.0,
530 buf.as_mut_ptr().cast(),
531 &mut len,
532 ))?;
533 Ok(String::from_utf8(buf).unwrap())
534 }
535 }
536
537 pub fn set_certificate(
545 &mut self,
546 identity: &SecIdentity,
547 certs: &[SecCertificate],
548 ) -> Result<()> {
549 let mut arr = vec![identity.as_CFType()];
550 arr.extend(certs.iter().map(|c| c.as_CFType()));
551 let certs = CFArray::from_CFTypes(&arr);
552
553 unsafe { cvt(SSLSetCertificate(self.0, certs.as_concrete_TypeRef())) }
554 }
555
556 #[inline]
563 pub fn set_peer_id(&mut self, peer_id: &[u8]) -> Result<()> {
564 unsafe { cvt(SSLSetPeerID(self.0, peer_id.as_ptr().cast(), peer_id.len())) }
565 }
566
567 pub fn peer_id(&self) -> Result<Option<&[u8]>> {
569 unsafe {
570 let mut ptr = ptr::null();
571 let mut len = 0;
572 cvt(SSLGetPeerID(self.0, &mut ptr, &mut len))?;
573 if ptr.is_null() {
574 Ok(None)
575 } else {
576 Ok(Some(slice::from_raw_parts(ptr.cast(), len)))
577 }
578 }
579 }
580
581 pub fn supported_ciphers(&self) -> Result<Vec<CipherSuite>> {
583 unsafe {
584 let mut num_ciphers = 0;
585 cvt(SSLGetNumberSupportedCiphers(self.0, &mut num_ciphers))?;
586 let mut ciphers = vec![0; num_ciphers];
587 cvt(SSLGetSupportedCiphers(
588 self.0,
589 ciphers.as_mut_ptr(),
590 &mut num_ciphers,
591 ))?;
592 Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
593 }
594 }
595
596 pub fn enabled_ciphers(&self) -> Result<Vec<CipherSuite>> {
599 unsafe {
600 let mut num_ciphers = 0;
601 cvt(SSLGetNumberEnabledCiphers(self.0, &mut num_ciphers))?;
602 let mut ciphers = vec![0; num_ciphers];
603 cvt(SSLGetEnabledCiphers(
604 self.0,
605 ciphers.as_mut_ptr(),
606 &mut num_ciphers,
607 ))?;
608 Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
609 }
610 }
611
612 pub fn set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()> {
614 let ciphers = ciphers.iter().map(|c| c.to_raw()).collect::<Vec<_>>();
615 unsafe {
616 cvt(SSLSetEnabledCiphers(
617 self.0,
618 ciphers.as_ptr(),
619 ciphers.len(),
620 ))
621 }
622 }
623
624 #[inline]
626 pub fn negotiated_cipher(&self) -> Result<CipherSuite> {
627 unsafe {
628 let mut cipher = 0;
629 cvt(SSLGetNegotiatedCipher(self.0, &mut cipher))?;
630 Ok(CipherSuite::from_raw(cipher))
631 }
632 }
633
634 #[inline]
638 pub fn set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()> {
639 unsafe { cvt(SSLSetClientSideAuthenticate(self.0, auth.0)) }
640 }
641
642 #[inline]
644 pub fn client_certificate_state(&self) -> Result<SslClientCertificateState> {
645 let mut state = 0;
646
647 unsafe {
648 cvt(SSLGetClientCertificateState(self.0, &mut state))?;
649 }
650 Ok(SslClientCertificateState(state))
651 }
652
653 pub fn peer_trust2(&self) -> Result<Option<SecTrust>> {
658 if self.state()? == SessionState::IDLE {
661 return Err(Error::from_code(errSecBadReq));
662 }
663
664 unsafe {
665 let mut trust = ptr::null_mut();
666 cvt(SSLCopyPeerTrust(self.0, &mut trust))?;
667 if trust.is_null() {
668 Ok(None)
669 } else {
670 Ok(Some(SecTrust::wrap_under_create_rule(trust)))
671 }
672 }
673 }
674
675 #[inline]
677 pub fn state(&self) -> Result<SessionState> {
678 unsafe {
679 let mut state = 0;
680 cvt(SSLGetSessionState(self.0, &mut state))?;
681 Ok(SessionState(state))
682 }
683 }
684
685 #[inline]
687 pub fn negotiated_protocol_version(&self) -> Result<SslProtocol> {
688 unsafe {
689 let mut version = 0;
690 cvt(SSLGetNegotiatedProtocolVersion(self.0, &mut version))?;
691 Ok(SslProtocol(version))
692 }
693 }
694
695 #[inline]
697 pub fn protocol_version_max(&self) -> Result<SslProtocol> {
698 unsafe {
699 let mut version = 0;
700 cvt(SSLGetProtocolVersionMax(self.0, &mut version))?;
701 Ok(SslProtocol(version))
702 }
703 }
704
705 #[inline]
707 pub fn set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()> {
708 unsafe { cvt(SSLSetProtocolVersionMax(self.0, max_version.0)) }
709 }
710
711 #[inline]
713 pub fn protocol_version_min(&self) -> Result<SslProtocol> {
714 unsafe {
715 let mut version = 0;
716 cvt(SSLGetProtocolVersionMin(self.0, &mut version))?;
717 Ok(SslProtocol(version))
718 }
719 }
720
721 #[inline]
723 pub fn set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()> {
724 unsafe { cvt(SSLSetProtocolVersionMin(self.0, min_version.0)) }
725 }
726
727 #[cfg(feature = "alpn")]
729 pub fn alpn_protocols(&self) -> Result<Vec<String>> {
730 let mut array: CFArrayRef = ptr::null();
731 unsafe {
732 #[cfg(feature = "OSX_10_13")]
733 {
734 cvt(SSLCopyALPNProtocols(self.0, &mut array))?;
735 }
736
737 #[cfg(not(feature = "OSX_10_13"))]
738 {
739 dlsym! { fn SSLCopyALPNProtocols(SSLContextRef, *mut CFArrayRef) -> OSStatus }
740 if let Some(f) = SSLCopyALPNProtocols.get() {
741 cvt(f(self.0, &mut array))?;
742 } else {
743 return Err(Error::from_code(errSecUnimplemented));
744 }
745 }
746
747 if array.is_null() {
748 return Ok(vec![]);
749 }
750
751 let array = CFArray::<CFString>::wrap_under_create_rule(array);
752 Ok(array.into_iter().map(|p| p.to_string()).collect())
753 }
754 }
755
756 #[cfg(feature = "alpn")]
760 pub fn set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()> {
761 let protocols = CFArray::from_CFTypes(
765 &protocols
766 .iter()
767 .map(|proto| CFString::new(proto))
768 .collect::<Vec<_>>(),
769 );
770
771 #[cfg(feature = "OSX_10_13")]
772 {
773 unsafe { cvt(SSLSetALPNProtocols(self.0, protocols.as_concrete_TypeRef())) }
774 }
775 #[cfg(not(feature = "OSX_10_13"))]
776 {
777 dlsym! { fn SSLSetALPNProtocols(SSLContextRef, CFArrayRef) -> OSStatus }
778 if let Some(f) = SSLSetALPNProtocols.get() {
779 unsafe { cvt(f(self.0, protocols.as_concrete_TypeRef())) }
780 } else {
781 Err(Error::from_code(errSecUnimplemented))
782 }
783 }
784 }
785
786 #[cfg(feature = "session-tickets")]
794 pub fn set_session_tickets_enabled(&mut self, enabled: bool) -> Result<()> {
795 #[cfg(feature = "OSX_10_13")]
796 {
797 unsafe { cvt(SSLSetSessionTicketsEnabled(self.0, enabled as Boolean)) }
798 }
799 #[cfg(not(feature = "OSX_10_13"))]
800 {
801 dlsym! { fn SSLSetSessionTicketsEnabled(SSLContextRef, Boolean) -> OSStatus }
802 if let Some(f) = SSLSetSessionTicketsEnabled.get() {
803 unsafe { cvt(f(self.0, enabled as Boolean)) }
804 } else {
805 Err(Error::from_code(errSecUnimplemented))
806 }
807 }
808 }
809
810 #[cfg(target_os = "macos")]
818 #[deprecated(note = "use `set_protocol_version_max`")]
819 pub fn set_protocol_version_enabled(
820 &mut self,
821 protocol: SslProtocol,
822 enabled: bool,
823 ) -> Result<()> {
824 unsafe {
825 cvt(SSLSetProtocolVersionEnabled(
826 self.0,
827 protocol.0,
828 enabled as Boolean,
829 ))
830 }
831 }
832
833 #[inline]
836 pub fn buffered_read_size(&self) -> Result<usize> {
837 unsafe {
838 let mut size = 0;
839 cvt(SSLGetBufferedReadSize(self.0, &mut size))?;
840 Ok(size)
841 }
842 }
843
844 impl_options! {
845 const kSSLSessionOptionBreakOnServerAuth: break_on_server_auth & set_break_on_server_auth,
848 const kSSLSessionOptionBreakOnCertRequested: break_on_cert_requested & set_break_on_cert_requested,
851 const kSSLSessionOptionBreakOnClientAuth: break_on_client_auth & set_break_on_client_auth,
854 #[cfg(feature = "OSX_10_9")]
859 const kSSLSessionOptionFalseStart: false_start & set_false_start,
860 #[cfg(feature = "OSX_10_9")]
865 const kSSLSessionOptionSendOneByteRecord: send_one_byte_record & set_send_one_byte_record,
866 }
867
868 fn into_stream<S>(self, stream: S) -> Result<SslStream<S>>
869 where
870 S: Read + Write,
871 {
872 unsafe {
873 let ret = SSLSetIOFuncs(self.0, read_func::<S>, write_func::<S>);
874 if ret != errSecSuccess {
875 return Err(Error::from_code(ret));
876 }
877
878 let stream = Connection {
879 stream,
880 err: None,
881 panic: None,
882 };
883 let stream = Box::into_raw(Box::new(stream));
884 let ret = SSLSetConnection(self.0, stream.cast());
885 if ret != errSecSuccess {
886 let _conn = Box::from_raw(stream);
887 return Err(Error::from_code(ret));
888 }
889
890 Ok(SslStream {
891 ctx: self,
892 _m: PhantomData,
893 })
894 }
895 }
896
897 pub fn handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>>
899 where
900 S: Read + Write,
901 {
902 self.into_stream(stream)
903 .map_err(HandshakeError::Failure)
904 .and_then(SslStream::handshake)
905 }
906}
907
/// State reachable from the Secure Transport I/O callbacks through the context's
/// connection pointer.
struct Connection<S> {
    /// The wrapped transport.
    stream: S,
    /// The most recent I/O error from `stream`, kept so it can be returned to the caller.
    err: Option<io::Error>,
    /// A panic caught inside a callback, re-raised later by `SslStream::check_panic`.
    panic: Option<Box<dyn Any + Send>>,
}
913
914#[cold]
916fn translate_err(e: &io::Error) -> OSStatus {
917 match e.kind() {
918 io::ErrorKind::NotFound => errSSLClosedGraceful,
919 io::ErrorKind::ConnectionReset => errSSLClosedAbort,
920 io::ErrorKind::WouldBlock | io::ErrorKind::NotConnected => errSSLWouldBlock,
921 _ => errSecIO,
922 }
923}
924
925unsafe extern "C" fn read_func<S>(
926 connection: SSLConnectionRef,
927 data: *mut c_void,
928 data_length: *mut usize,
929) -> OSStatus
930where
931 S: Read,
932{
933 let conn: &mut Connection<S> = &mut *(connection as *mut _);
934 let data = slice::from_raw_parts_mut(data.cast::<u8>(), *data_length);
935 let mut start = 0;
936 let mut ret = errSecSuccess;
937
938 while start < data.len() {
939 match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.read(&mut data[start..]))) {
940 Ok(Ok(0)) => {
941 ret = errSSLClosedNoNotify;
942 break;
943 }
944 Ok(Ok(len)) => start += len,
945 Ok(Err(e)) => {
946 ret = translate_err(&e);
947 conn.err = Some(e);
948 break;
949 }
950 Err(e) => {
951 ret = errSecIO;
952 conn.panic = Some(e);
953 break;
954 }
955 }
956 }
957
958 *data_length = start;
959 ret
960}
961
962unsafe extern "C" fn write_func<S>(
963 connection: SSLConnectionRef,
964 data: *const c_void,
965 data_length: *mut usize,
966) -> OSStatus
967where
968 S: Write,
969{
970 let conn: &mut Connection<S> = &mut *(connection as *mut _);
971 let data = slice::from_raw_parts(data as *mut u8, *data_length);
972 let mut start = 0;
973 let mut ret = errSecSuccess;
974
975 while start < data.len() {
976 match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.write(&data[start..]))) {
977 Ok(Ok(0)) => {
978 ret = errSSLClosedNoNotify;
979 break;
980 }
981 Ok(Ok(len)) => start += len,
982 Ok(Err(e)) => {
983 ret = translate_err(&e);
984 conn.err = Some(e);
985 break;
986 }
987 Err(e) => {
988 ret = errSecIO;
989 conn.panic = Some(e);
990 break;
991 }
992 }
993 }
994
995 *data_length = start;
996 ret
997}
998
999pub struct SslStream<S> {
1001 ctx: SslContext,
1002 _m: PhantomData<S>,
1003}
1004
1005impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
1006 #[cold]
1007 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
1008 fmt.debug_struct("SslStream")
1009 .field("context", &self.ctx)
1010 .field("stream", self.get_ref())
1011 .finish()
1012 }
1013}
1014
1015impl<S> Drop for SslStream<S> {
1016 fn drop(&mut self) {
1017 unsafe {
1018 let mut conn = ptr::null();
1019 let ret = SSLGetConnection(self.ctx.0, &mut conn);
1020 assert!(ret == errSecSuccess);
1021 let _ = Box::<Connection<S>>::from_raw(conn as *mut _);
1022 }
1023 }
1024}
1025
1026impl<S> SslStream<S> {
1027 fn handshake(mut self) -> result::Result<Self, HandshakeError<S>> {
1028 match unsafe { SSLHandshake(self.ctx.0) } {
1029 errSecSuccess => Ok(self),
1030 reason @ errSSLPeerAuthCompleted
1031 | reason @ errSSLClientCertRequested
1032 | reason @ errSSLWouldBlock
1033 | reason @ errSSLClientHelloReceived => {
1034 Err(HandshakeError::Interrupted(MidHandshakeSslStream {
1035 stream: self,
1036 error: Error::from_code(reason),
1037 }))
1038 }
1039 err => {
1040 self.check_panic();
1041 Err(HandshakeError::Failure(Error::from_code(err)))
1042 }
1043 }
1044 }
1045
1046 #[inline(always)]
1048 #[must_use]
1049 pub fn get_ref(&self) -> &S {
1050 &self.connection().stream
1051 }
1052
1053 #[inline(always)]
1055 pub fn get_mut(&mut self) -> &mut S {
1056 &mut self.connection_mut().stream
1057 }
1058
1059 #[inline(always)]
1061 #[must_use]
1062 pub fn context(&self) -> &SslContext {
1063 &self.ctx
1064 }
1065
1066 #[inline(always)]
1068 pub fn context_mut(&mut self) -> &mut SslContext {
1069 &mut self.ctx
1070 }
1071
1072 pub fn close(&mut self) -> result::Result<(), io::Error> {
1074 unsafe {
1075 let ret = SSLClose(self.ctx.0);
1076 if ret == errSecSuccess {
1077 Ok(())
1078 } else {
1079 Err(self.get_error(ret))
1080 }
1081 }
1082 }
1083
1084 fn connection(&self) -> &Connection<S> {
1085 unsafe {
1086 let mut conn = ptr::null();
1087 let ret = SSLGetConnection(self.ctx.0, &mut conn);
1088 assert!(ret == errSecSuccess);
1089
1090 &mut *(conn as *mut Connection<S>)
1091 }
1092 }
1093
1094 fn connection_mut(&mut self) -> &mut Connection<S> {
1095 unsafe {
1096 let mut conn = ptr::null();
1097 let ret = SSLGetConnection(self.ctx.0, &mut conn);
1098 assert!(ret == errSecSuccess);
1099
1100 &mut *(conn as *mut Connection<S>)
1101 }
1102 }
1103
1104 #[cold]
1105 fn check_panic(&mut self) {
1106 let conn = self.connection_mut();
1107 if let Some(err) = conn.panic.take() {
1108 panic::resume_unwind(err);
1109 }
1110 }
1111
1112 #[cold]
1113 fn get_error(&mut self, ret: OSStatus) -> io::Error {
1114 self.check_panic();
1115
1116 if let Some(err) = self.connection_mut().err.take() {
1117 err
1118 } else {
1119 io::Error::new(io::ErrorKind::Other, Error::from_code(ret))
1120 }
1121 }
1122}
1123
1124impl<S: Read + Write> Read for SslStream<S> {
1125 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1126 if buf.is_empty() {
1131 return Ok(0);
1132 }
1133
1134 let buffered = self.context().buffered_read_size().unwrap_or(0);
1139 let to_read = if buffered > 0 {
1140 cmp::min(buffered, buf.len())
1141 } else {
1142 buf.len()
1143 };
1144
1145 unsafe {
1146 let mut n_read = 0;
1147 let ret = SSLRead(self.ctx.0, buf.as_mut_ptr().cast(), to_read, &mut n_read);
1148 if n_read > 0 {
1151 return Ok(n_read);
1152 }
1153
1154 match ret {
1155 errSSLClosedGraceful | errSSLClosedAbort | errSSLClosedNoNotify => Ok(0),
1156 errSSLPeerAuthCompleted => self.read(buf),
1158 _ => Err(self.get_error(ret)),
1159 }
1160 }
1161 }
1162}
1163
1164impl<S: Read + Write> Write for SslStream<S> {
1165 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1166 if buf.is_empty() {
1168 return Ok(0);
1169 }
1170 unsafe {
1171 let mut n_written = 0;
1172 let ret = SSLWrite(self.ctx.0, buf.as_ptr().cast(), buf.len(), &mut n_written);
1173 if n_written > 0 {
1176 Ok(n_written)
1177 } else {
1178 Err(self.get_error(ret))
1179 }
1180 }
1181 }
1182
1183 fn flush(&mut self) -> io::Result<()> {
1184 self.connection_mut().stream.flush()
1185 }
1186}
1187
1188#[derive(Debug)]
1190pub struct ClientBuilder {
1191 identity: Option<SecIdentity>,
1192 certs: Vec<SecCertificate>,
1193 chain: Vec<SecCertificate>,
1194 protocol_min: Option<SslProtocol>,
1195 protocol_max: Option<SslProtocol>,
1196 trust_certs_only: bool,
1197 use_sni: bool,
1198 danger_accept_invalid_certs: bool,
1199 danger_accept_invalid_hostnames: bool,
1200 whitelisted_ciphers: Vec<CipherSuite>,
1201 blacklisted_ciphers: Vec<CipherSuite>,
1202 #[cfg(feature = "alpn")]
1203 alpn: Option<Vec<String>>,
1204 #[cfg(feature = "session-tickets")]
1205 enable_session_tickets: bool,
1206}
1207
1208impl Default for ClientBuilder {
1209 #[inline(always)]
1210 fn default() -> Self {
1211 Self::new()
1212 }
1213}
1214
1215impl ClientBuilder {
1216 #[inline]
1218 #[must_use]
1219 pub fn new() -> Self {
1220 Self {
1221 identity: None,
1222 certs: Vec::new(),
1223 chain: Vec::new(),
1224 protocol_min: None,
1225 protocol_max: None,
1226 trust_certs_only: false,
1227 use_sni: true,
1228 danger_accept_invalid_certs: false,
1229 danger_accept_invalid_hostnames: false,
1230 whitelisted_ciphers: Vec::new(),
1231 blacklisted_ciphers: Vec::new(),
1232 #[cfg(feature = "alpn")]
1233 alpn: None,
1234 #[cfg(feature = "session-tickets")]
1235 enable_session_tickets: false,
1236 }
1237 }
1238
1239 #[inline]
1242 pub fn anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self {
1243 self.certs = certs.to_owned();
1244 self
1245 }
1246
1247 #[inline]
1250 pub fn add_anchor_certificate(&mut self, certs: &SecCertificate) -> &mut Self {
1251 self.certs.push(certs.to_owned());
1252 self
1253 }
1254
1255 #[inline(always)]
1258 pub fn trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self {
1259 self.trust_certs_only = only;
1260 self
1261 }
1262
1263 #[inline(always)]
1272 pub fn danger_accept_invalid_certs(&mut self, no_verify: bool) -> &mut Self {
1273 self.danger_accept_invalid_certs = no_verify;
1274 self
1275 }
1276
1277 #[inline(always)]
1279 pub fn use_sni(&mut self, use_sni: bool) -> &mut Self {
1280 self.use_sni = use_sni;
1281 self
1282 }
1283
1284 #[inline(always)]
1292 pub fn danger_accept_invalid_hostnames(
1293 &mut self,
1294 danger_accept_invalid_hostnames: bool,
1295 ) -> &mut Self {
1296 self.danger_accept_invalid_hostnames = danger_accept_invalid_hostnames;
1297 self
1298 }
1299
1300 pub fn whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self {
1302 self.whitelisted_ciphers = whitelisted_ciphers.to_owned();
1303 self
1304 }
1305
1306 pub fn blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self {
1308 self.blacklisted_ciphers = blacklisted_ciphers.to_owned();
1309 self
1310 }
1311
1312 pub fn identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self {
1314 self.identity = Some(identity.clone());
1315 self.chain = chain.to_owned();
1316 self
1317 }
1318
1319 #[inline(always)]
1321 pub fn protocol_min(&mut self, min: SslProtocol) -> &mut Self {
1322 self.protocol_min = Some(min);
1323 self
1324 }
1325
1326 #[inline(always)]
1328 pub fn protocol_max(&mut self, max: SslProtocol) -> &mut Self {
1329 self.protocol_max = Some(max);
1330 self
1331 }
1332
1333 #[cfg(feature = "alpn")]
1335 pub fn alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self {
1336 self.alpn = Some(protocols.iter().map(|s| s.to_string()).collect());
1337 self
1338 }
1339
1340 #[cfg(feature = "session-tickets")]
1344 #[inline(always)]
1345 pub fn enable_session_tickets(&mut self, enable: bool) -> &mut Self {
1346 self.enable_session_tickets = enable;
1347 self
1348 }
1349
1350 pub fn handshake<S>(
1354 &self,
1355 domain: &str,
1356 stream: S,
1357 ) -> result::Result<SslStream<S>, ClientHandshakeError<S>>
1358 where
1359 S: Read + Write,
1360 {
1361 let stream = MidHandshakeSslStream {
1364 stream: self.ctx_into_stream(domain, stream)?,
1365 error: Error::from(errSecSuccess),
1366 };
1367
1368 let certs = self.certs.clone();
1369 let stream = MidHandshakeClientBuilder {
1370 stream,
1371 domain: if self.danger_accept_invalid_hostnames {
1372 None
1373 } else {
1374 Some(domain.to_string())
1375 },
1376 certs,
1377 trust_certs_only: self.trust_certs_only,
1378 danger_accept_invalid_certs: self.danger_accept_invalid_certs,
1379 };
1380 stream.handshake()
1381 }
1382
1383 fn ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>>
1384 where
1385 S: Read + Write,
1386 {
1387 let mut ctx = SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM)?;
1388
1389 if self.use_sni {
1390 ctx.set_peer_domain_name(domain)?;
1391 }
1392 if let Some(ref identity) = self.identity {
1393 ctx.set_certificate(identity, &self.chain)?;
1394 }
1395 #[cfg(feature = "alpn")]
1396 {
1397 if let Some(ref alpn) = self.alpn {
1398 ctx.set_alpn_protocols(&alpn.iter().map(|s| &**s).collect::<Vec<_>>())?;
1399 }
1400 }
1401 #[cfg(feature = "session-tickets")]
1402 {
1403 if self.enable_session_tickets {
1404 ctx.set_peer_id(domain.as_bytes())?;
1407 ctx.set_session_tickets_enabled(true)?;
1408 }
1409 }
1410 ctx.set_break_on_server_auth(true)?;
1411 self.configure_protocols(&mut ctx)?;
1412 self.configure_ciphers(&mut ctx)?;
1413
1414 ctx.into_stream(stream)
1415 }
1416
1417 fn configure_protocols(&self, ctx: &mut SslContext) -> Result<()> {
1418 if let Some(min) = self.protocol_min {
1419 ctx.set_protocol_version_min(min)?;
1420 }
1421 if let Some(max) = self.protocol_max {
1422 ctx.set_protocol_version_max(max)?;
1423 }
1424 Ok(())
1425 }
1426
1427 fn configure_ciphers(&self, ctx: &mut SslContext) -> Result<()> {
1428 let mut ciphers = if self.whitelisted_ciphers.is_empty() {
1429 ctx.enabled_ciphers()?
1430 } else {
1431 self.whitelisted_ciphers.clone()
1432 };
1433
1434 if !self.blacklisted_ciphers.is_empty() {
1435 ciphers.retain(|cipher| !self.blacklisted_ciphers.contains(cipher));
1436 }
1437
1438 ctx.set_enabled_ciphers(&ciphers)?;
1439 Ok(())
1440 }
1441}
1442
1443#[derive(Debug)]
1445pub struct ServerBuilder {
1446 identity: SecIdentity,
1447 certs: Vec<SecCertificate>,
1448}
1449
1450impl ServerBuilder {
1451 #[must_use]
1454 pub fn new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self {
1455 Self {
1456 identity: identity.clone(),
1457 certs: certs.to_owned(),
1458 }
1459 }
1460
1461 pub fn from_pkcs12(pkcs12_der: &[u8], passphrase: &str) -> Result<Self> {
1468 let mut identities: Vec<(SecIdentity, Vec<SecCertificate>)> = Pkcs12ImportOptions::new()
1469 .passphrase(passphrase)
1470 .import(pkcs12_der)?
1471 .into_iter()
1472 .filter_map(|identity| {
1473 let certs = identity.cert_chain.unwrap_or_default();
1474 identity.identity.map(|identity| (identity, certs))
1475 })
1476 .collect();
1477 if identities.len() == 1 {
1478 let (identity, certs) = identities.pop().unwrap();
1479 Ok(ServerBuilder::new(&identity, &certs))
1480 } else {
1481 Err(Error::from_code(errSecParam))
1483 }
1484 }
1485
1486 pub fn new_ssl_context(&self) -> Result<SslContext> {
1488 let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
1489 ctx.set_certificate(&self.identity, &self.certs)?;
1490 Ok(ctx)
1491 }
1492
1493 pub fn handshake<S>(&self, stream: S) -> Result<SslStream<S>>
1495 where
1496 S: Read + Write,
1497 {
1498 match self.new_ssl_context()?.handshake(stream) {
1499 Ok(stream) => Ok(stream),
1500 Err(HandshakeError::Interrupted(stream)) => Err(*stream.error()),
1501 Err(HandshakeError::Failure(err)) => Err(err),
1502 }
1503 }
1504}
1505
#[cfg(test)]
mod test {
    // Most tests below open real TCP connections (google.com:443 and
    // expired.badssl.com:443), so they require network access.

    use std::{io, io::prelude::*, net::TcpStream};

    use super::*;

    /// Creates a stream-oriented `SslContext` for `side`, panicking on failure.
    fn stream_context(side: SslProtocolSide) -> SslContext {
        p!(SslContext::new(side, SslConnectionType::STREAM))
    }

    /// Creates a client stream context whose expected peer name is `domain`.
    fn client_context(domain: &str) -> SslContext {
        let mut ctx = stream_context(SslProtocolSide::CLIENT);
        p!(ctx.set_peer_domain_name(domain));
        ctx
    }

    /// Sends a bare HTTP/1.0 request for `/` over `stream` and prints the
    /// whole response, panicking on any I/O error.
    fn fetch_and_print_root<S: Read + Write>(stream: &mut S) {
        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
        p!(stream.flush());
        let mut buf = vec![];
        p!(stream.read_to_end(&mut buf));
        println!("{}", String::from_utf8_lossy(&buf));
    }

    #[test]
    fn server_builder_from_pkcs12() {
        let pkcs12_der = include_bytes!("../test/server.p12");
        ServerBuilder::from_pkcs12(pkcs12_der, "password123").unwrap();
    }

    #[test]
    fn connect() {
        let ctx = client_context("google.com");
        let stream = p!(TcpStream::connect("google.com:443"));
        p!(ctx.handshake(stream));
    }

    /// The handshake must fail when the expected peer name does not match the
    /// host actually contacted.
    #[test]
    fn connect_bad_domain() {
        let ctx = client_context("foobar.com");
        let stream = p!(TcpStream::connect("google.com:443"));
        assert!(ctx.handshake(stream).is_err(), "expected failure");
    }

    #[test]
    fn load_page() {
        let ctx = client_context("google.com");
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = p!(ctx.handshake(stream));
        fetch_and_print_root(&mut stream);
    }

    /// Without session tickets, each connection's handshake is expected to be
    /// interrupted once server auth completes, and then to finish on resume.
    #[test]
    fn client_no_session_ticket_resumption() {
        for _ in 0..2 {
            let stream = p!(TcpStream::connect("google.com:443"));

            let stream = MidHandshakeSslStream {
                stream: ClientBuilder::new()
                    .ctx_into_stream("google.com", stream)
                    .unwrap(),
                error: Error::from(errSecSuccess),
            };

            let mut result = stream.handshake();

            if let Err(HandshakeError::Interrupted(stream)) = result {
                assert!(stream.server_auth_completed());
                result = stream.handshake();
            } else {
                panic!("Unexpectedly skipped server auth");
            }

            assert!(result.is_ok());
        }
    }

    /// With session tickets enabled, the first connection performs server auth
    /// (interrupting the handshake) and the second is expected to resume the
    /// session and skip it.
    #[test]
    #[cfg(feature = "session-tickets")]
    fn client_session_ticket_resumption() {
        for i in 0..2 {
            let stream = p!(TcpStream::connect("google.com:443"));
            let mut builder = ClientBuilder::new();
            builder.enable_session_tickets(true);

            let stream = MidHandshakeSslStream {
                stream: builder.ctx_into_stream("google.com", stream).unwrap(),
                error: Error::from(errSecSuccess),
            };

            let mut result = stream.handshake();

            if let Err(HandshakeError::Interrupted(stream)) = result {
                assert!(stream.server_auth_completed());
                assert_eq!(
                    i, 0,
                    "Session ticket resumption did not work, server auth was not skipped"
                );
                result = stream.handshake();
            } else {
                assert_eq!(i, 1, "Unexpectedly skipped server auth");
            }

            assert!(result.is_ok());
        }
    }

    #[test]
    #[cfg(feature = "alpn")]
    fn client_alpn_accept() {
        let mut ctx = client_context("google.com");
        p!(ctx.set_alpn_protocols(&["h2"]));
        let stream = p!(TcpStream::connect("google.com:443"));
        let stream = ctx.handshake(stream).unwrap();
        assert_eq!(vec!["h2"], stream.context().alpn_protocols().unwrap());
    }

    /// Offering only a protocol the server does not select must leave no
    /// negotiated ALPN protocol.
    #[test]
    #[cfg(feature = "alpn")]
    fn client_alpn_reject() {
        let mut ctx = client_context("google.com");
        p!(ctx.set_alpn_protocols(&["h2c"]));
        let stream = p!(TcpStream::connect("google.com:443"));
        let stream = ctx.handshake(stream).unwrap();
        assert!(stream.context().alpn_protocols().is_err());
    }

    #[test]
    fn client_no_anchor_certs() {
        let stream = p!(TcpStream::connect("google.com:443"));
        assert!(ClientBuilder::new()
            .trust_anchor_certificates_only(true)
            .handshake("google.com", stream)
            .is_err());
    }

    #[test]
    fn client_bad_domain() {
        let stream = p!(TcpStream::connect("google.com:443"));
        assert!(ClientBuilder::new()
            .handshake("foobar.com", stream)
            .is_err());
    }

    #[test]
    fn client_bad_domain_ignored() {
        let stream = p!(TcpStream::connect("google.com:443"));
        ClientBuilder::new()
            .danger_accept_invalid_hostnames(true)
            .handshake("foobar.com", stream)
            .unwrap();
    }

    #[test]
    fn connect_no_verify_ssl() {
        let stream = p!(TcpStream::connect("expired.badssl.com:443"));
        let mut builder = ClientBuilder::new();
        builder.danger_accept_invalid_certs(true);
        builder.handshake("expired.badssl.com", stream).unwrap();
    }

    #[test]
    fn load_page_client() {
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = p!(ClientBuilder::new().handshake("google.com", stream));
        fetch_and_print_root(&mut stream);
    }

    /// Restricting a context to a subset of its ciphers must round-trip.
    #[test]
    #[cfg_attr(target_os = "ios", ignore)]
    fn cipher_configuration() {
        let mut ctx = stream_context(SslProtocolSide::SERVER);
        // Keep every other enabled cipher (indices 0, 2, 4, ...).
        let ciphers: Vec<_> = p!(ctx.enabled_ciphers()).into_iter().step_by(2).collect();
        p!(ctx.set_enabled_ciphers(&ciphers));
        assert_eq!(ciphers, p!(ctx.enabled_ciphers()));
    }

    #[test]
    fn test_builder_whitelist_ciphers() {
        let stream = p!(TcpStream::connect("google.com:443"));

        let ciphers = p!(stream_context(SslProtocolSide::CLIENT).enabled_ciphers());
        assert!(ciphers.len() > 1);

        let cipher = *ciphers.first().unwrap();
        let stream = p!(ClientBuilder::new()
            .whitelist_ciphers(&[cipher])
            .ctx_into_stream("google.com", stream));

        assert_eq!(1, p!(stream.context().enabled_ciphers()).len());
    }

    #[test]
    #[cfg_attr(target_os = "ios", ignore)]
    fn test_builder_blacklist_ciphers() {
        let stream = p!(TcpStream::connect("google.com:443"));

        let ciphers = p!(stream_context(SslProtocolSide::CLIENT).enabled_ciphers());
        let num = ciphers.len();
        assert!(num > 1);

        let cipher = *ciphers.first().unwrap();
        let stream = p!(ClientBuilder::new()
            .blacklist_ciphers(&[cipher])
            .ctx_into_stream("google.com", stream));

        assert_eq!(num - 1, p!(stream.context().enabled_ciphers()).len());
    }

    /// A context that has not handshaken yet has no peer trust to report.
    #[test]
    fn idle_context_peer_trust() {
        let ctx = stream_context(SslProtocolSide::SERVER);
        assert!(ctx.peer_trust2().is_err());
    }

    #[test]
    fn peer_id() {
        let mut ctx = stream_context(SslProtocolSide::SERVER);
        assert!(p!(ctx.peer_id()).is_none());
        p!(ctx.set_peer_id(b"foobar"));
        assert_eq!(p!(ctx.peer_id()), Some(&b"foobar"[..]));
    }

    #[test]
    fn peer_domain_name() {
        let mut ctx = stream_context(SslProtocolSide::CLIENT);
        assert_eq!("", p!(ctx.peer_domain_name()));
        p!(ctx.set_peer_domain_name("foobar.com"));
        assert_eq!("foobar.com", p!(ctx.peer_domain_name()));
    }

    /// A panic raised inside the user stream's `write` during the handshake
    /// must propagate to the caller with its original payload.
    #[test]
    #[should_panic(expected = "blammo")]
    fn write_panic() {
        struct ExplodingStream(TcpStream);

        impl Read for ExplodingStream {
            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
                self.0.read(buf)
            }
        }

        impl Write for ExplodingStream {
            fn write(&mut self, _: &[u8]) -> io::Result<usize> {
                panic!("blammo");
            }

            fn flush(&mut self) -> io::Result<()> {
                self.0.flush()
            }
        }

        let ctx = client_context("google.com");
        let stream = p!(TcpStream::connect("google.com:443"));
        let _ = ctx.handshake(ExplodingStream(stream));
    }

    /// A panic raised inside the user stream's `read` during the handshake
    /// must propagate to the caller with its original payload.
    #[test]
    #[should_panic(expected = "blammo")]
    fn read_panic() {
        struct ExplodingStream(TcpStream);

        impl Read for ExplodingStream {
            fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
                panic!("blammo");
            }
        }

        impl Write for ExplodingStream {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                self.0.write(buf)
            }

            fn flush(&mut self) -> io::Result<()> {
                self.0.flush()
            }
        }

        let ctx = client_context("google.com");
        let stream = p!(TcpStream::connect("google.com:443"));
        let _ = ctx.handshake(ExplodingStream(stream));
    }

    /// Empty reads and writes on an established stream return 0 immediately.
    #[test]
    fn zero_length_buffers() {
        let ctx = client_context("google.com");
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = ctx.handshake(stream).unwrap();
        assert_eq!(stream.write(b"").unwrap(), 0);
        assert_eq!(stream.read(&mut []).unwrap(), 0);
    }
}