apple_security/
secure_transport.rs

1//! SSL/TLS encryption support using Secure Transport.
2//!
3//! # Examples
4//!
5//! To connect as a client to a server with a certificate trusted by the system:
6//!
7//! ```rust
8//! use std::io::prelude::*;
9//! use std::net::TcpStream;
10//! use apple_security::secure_transport::ClientBuilder;
11//!
12//! let stream = TcpStream::connect("google.com:443").unwrap();
13//! let mut stream = ClientBuilder::new().handshake("google.com", stream).unwrap();
14//!
15//! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
16//! let mut page = vec![];
17//! stream.read_to_end(&mut page).unwrap();
18//! println!("{}", String::from_utf8_lossy(&page));
19//! ```
20//!
21//! To connect to a server with a certificate that's *not* trusted by the
22//! system, specify the root certificates for the server's chain to the
23//! `ClientBuilder`:
24//!
25//! ```rust,no_run
26//! use std::io::prelude::*;
27//! use std::net::TcpStream;
28//! use apple_security::secure_transport::ClientBuilder;
29//!
30//! # let root_cert = unsafe { std::mem::zeroed() };
31//! let stream = TcpStream::connect("my_server.com:443").unwrap();
32//! let mut stream = ClientBuilder::new()
33//!                      .anchor_certificates(&[root_cert])
34//!                      .handshake("my_server.com", stream)
35//!                      .unwrap();
36//!
37//! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
38//! let mut page = vec![];
39//! stream.read_to_end(&mut page).unwrap();
40//! println!("{}", String::from_utf8_lossy(&page));
41//! ```
42//!
43//! For more advanced configuration, the `SslContext` type can be used directly.
44//!
45//! To run a server:
46//!
47//! ```rust,no_run
48//! use std::net::TcpListener;
49//! use std::thread;
50//! use apple_security::secure_transport::{SslContext, SslProtocolSide, SslConnectionType};
51//!
52//! // Create a TCP listener and start accepting on it.
53//! let mut listener = TcpListener::bind("0.0.0.0:443").unwrap();
54//!
55//! for stream in listener.incoming() {
56//!     let stream = stream.unwrap();
57//!     thread::spawn(move || {
58//!         // Create a new context configured to operate on the server side of
59//!         // a traditional SSL/TLS session.
60//!         let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)
61//!                           .unwrap();
62//!
63//!         // Install the certificate chain that we will be using.
64//!         # let identity = unsafe { std::mem::zeroed() };
65//!         # let intermediate_cert = unsafe { std::mem::zeroed() };
66//!         # let root_cert = unsafe { std::mem::zeroed() };
67//!         ctx.set_certificate(identity, &[intermediate_cert, root_cert]).unwrap();
68//!
69//!         // Perform the SSL/TLS handshake and get our stream.
70//!         let mut stream = ctx.handshake(stream).unwrap();
71//!     });
72//! }
73//!
74//! ```
75#[allow(unused_imports)]
76use core_foundation::array::{CFArray, CFArrayRef};
77
78use core_foundation::base::{Boolean, TCFType};
79#[cfg(feature = "alpn")]
80use core_foundation::string::CFString;
81use core_foundation_sys::base::{kCFAllocatorDefault, OSStatus};
82use std::os::raw::c_void;
83
84#[allow(unused_imports)]
85use apple_security_sys::base::{
86    errSecBadReq, errSecIO, errSecNotTrusted, errSecSuccess, errSecTrustSettingDeny,
87    errSecUnimplemented,
88};
89
90use apple_security_sys::secure_transport::*;
91use std::any::Any;
92use std::cmp;
93use std::fmt;
94use std::io;
95use std::io::prelude::*;
96use std::marker::PhantomData;
97use std::panic::{self, AssertUnwindSafe};
98use std::ptr;
99use std::result;
100use std::slice;
101
102use crate::base::{Error, Result};
103use crate::certificate::SecCertificate;
104use crate::cipher_suite::CipherSuite;
105use crate::identity::SecIdentity;
106use crate::import_export::Pkcs12ImportOptions;
107use crate::policy::SecPolicy;
108use crate::trust::SecTrust;
109use crate::{cvt, AsInner};
110use apple_security_sys::base::errSecParam;
111
112/// Specifies a side of a TLS session.
113#[derive(Debug, Copy, Clone, PartialEq, Eq)]
114pub struct SslProtocolSide(SSLProtocolSide);
115
116impl SslProtocolSide {
117    /// The server side of the session.
118    pub const SERVER: Self = Self(kSSLServerSide);
119
120    /// The client side of the session.
121    pub const CLIENT: Self = Self(kSSLClientSide);
122}
123
124/// Specifies the type of TLS session.
125#[derive(Debug, Copy, Clone)]
126pub struct SslConnectionType(SSLConnectionType);
127
128impl SslConnectionType {
129    /// A traditional TLS stream.
130    pub const STREAM: Self = Self(kSSLStreamType);
131
132    /// A DTLS session.
133    pub const DATAGRAM: Self = Self(kSSLDatagramType);
134}
135
136/// An error or intermediate state after a TLS handshake attempt.
137#[derive(Debug)]
138pub enum HandshakeError<S> {
139    /// The handshake failed.
140    Failure(Error),
141    /// The handshake was interrupted midway through.
142    Interrupted(MidHandshakeSslStream<S>),
143}
144
145impl<S> From<Error> for HandshakeError<S> {
146    #[inline(always)]
147    fn from(err: Error) -> Self {
148        Self::Failure(err)
149    }
150}
151
152/// An error or intermediate state after a TLS handshake attempt.
153#[derive(Debug)]
154pub enum ClientHandshakeError<S> {
155    /// The handshake failed.
156    Failure(Error),
157    /// The handshake was interrupted midway through.
158    Interrupted(MidHandshakeClientBuilder<S>),
159}
160
161impl<S> From<Error> for ClientHandshakeError<S> {
162    #[inline(always)]
163    fn from(err: Error) -> Self {
164        Self::Failure(err)
165    }
166}
167
168/// An SSL stream midway through the handshake process.
169#[derive(Debug)]
170pub struct MidHandshakeSslStream<S> {
171    stream: SslStream<S>,
172    error: Error,
173}
174
175impl<S> MidHandshakeSslStream<S> {
176    /// Returns a shared reference to the inner stream.
177    #[inline(always)]
178    #[must_use]
179    pub fn get_ref(&self) -> &S {
180        self.stream.get_ref()
181    }
182
183    /// Returns a mutable reference to the inner stream.
184    #[inline(always)]
185    pub fn get_mut(&mut self) -> &mut S {
186        self.stream.get_mut()
187    }
188
189    /// Returns a shared reference to the `SslContext` of the stream.
190    #[inline(always)]
191    #[must_use]
192    pub fn context(&self) -> &SslContext {
193        self.stream.context()
194    }
195
196    /// Returns a mutable reference to the `SslContext` of the stream.
197    #[inline(always)]
198    pub fn context_mut(&mut self) -> &mut SslContext {
199        self.stream.context_mut()
200    }
201
202    /// Returns `true` iff `break_on_server_auth` was set and the handshake has
203    /// progressed to that point.
204    #[inline(always)]
205    #[must_use]
206    pub fn server_auth_completed(&self) -> bool {
207        self.error.code() == errSSLPeerAuthCompleted
208    }
209
210    /// Returns `true` iff `break_on_cert_requested` was set and the handshake
211    /// has progressed to that point.
212    #[inline(always)]
213    #[must_use]
214    pub fn client_cert_requested(&self) -> bool {
215        self.error.code() == errSSLClientCertRequested
216    }
217
218    /// Returns `true` iff the underlying stream returned an error with the
219    /// `WouldBlock` kind.
220    #[inline(always)]
221    #[must_use]
222    pub fn would_block(&self) -> bool {
223        self.error.code() == errSSLWouldBlock
224    }
225
226    /// Returns the error which caused the handshake interruption.
227    #[inline(always)]
228    #[must_use]
229    pub fn error(&self) -> &Error {
230        &self.error
231    }
232
233    /// Restarts the handshake process.
234    #[inline(always)]
235    pub fn handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>> {
236        self.stream.handshake()
237    }
238}
239
240/// An SSL stream midway through the handshake process.
241#[derive(Debug)]
242pub struct MidHandshakeClientBuilder<S> {
243    stream: MidHandshakeSslStream<S>,
244    domain: Option<String>,
245    certs: Vec<SecCertificate>,
246    trust_certs_only: bool,
247    danger_accept_invalid_certs: bool,
248}
249
250impl<S> MidHandshakeClientBuilder<S> {
251    /// Returns a shared reference to the inner stream.
252    #[inline(always)]
253    #[must_use]
254    pub fn get_ref(&self) -> &S {
255        self.stream.get_ref()
256    }
257
258    /// Returns a mutable reference to the inner stream.
259    #[inline(always)]
260    pub fn get_mut(&mut self) -> &mut S {
261        self.stream.get_mut()
262    }
263
264    /// Returns the error which caused the handshake interruption.
265    #[inline(always)]
266    #[must_use]
267    pub fn error(&self) -> &Error {
268        self.stream.error()
269    }
270
271    /// Restarts the handshake process.
272    pub fn handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>> {
273        let MidHandshakeClientBuilder {
274            stream,
275            domain,
276            certs,
277            trust_certs_only,
278            danger_accept_invalid_certs,
279        } = self;
280
281        let mut result = stream.handshake();
282        loop {
283            let stream = match result {
284                Ok(stream) => return Ok(stream),
285                Err(HandshakeError::Interrupted(stream)) => stream,
286                Err(HandshakeError::Failure(err)) => {
287                    return Err(ClientHandshakeError::Failure(err))
288                }
289            };
290
291            if stream.would_block() {
292                let ret = MidHandshakeClientBuilder {
293                    stream,
294                    domain,
295                    certs,
296                    trust_certs_only,
297                    danger_accept_invalid_certs,
298                };
299                return Err(ClientHandshakeError::Interrupted(ret));
300            }
301
302            if stream.server_auth_completed() {
303                if danger_accept_invalid_certs {
304                    result = stream.handshake();
305                    continue;
306                }
307                let mut trust = match stream.context().peer_trust2()? {
308                    Some(trust) => trust,
309                    None => {
310                        result = stream.handshake();
311                        continue;
312                    }
313                };
314                trust.set_anchor_certificates(&certs)?;
315                trust.set_trust_anchor_certificates_only(self.trust_certs_only)?;
316                let policy = SecPolicy::create_ssl(SslProtocolSide::SERVER, domain.as_deref());
317                trust.set_policy(&policy)?;
318                trust.evaluate_with_error().map_err(|error| {
319                    #[cfg(feature = "log")]
320                    log::warn!("SecTrustEvaluateWithError: {}", error.to_string());
321                    Error::from_code(error.code() as _)
322                })?;
323                result = stream.handshake();
324                continue;
325            }
326
327            let err = Error::from_code(stream.error().code());
328            return Err(ClientHandshakeError::Failure(err));
329        }
330    }
331}
332
333/// Specifies the state of a TLS session.
334#[derive(Debug, PartialEq, Eq)]
335pub struct SessionState(SSLSessionState);
336
337impl SessionState {
338    /// The session has not yet started.
339    pub const IDLE: Self = Self(kSSLIdle);
340
341    /// The session is in the handshake process.
342    pub const HANDSHAKE: Self = Self(kSSLHandshake);
343
344    /// The session is connected.
345    pub const CONNECTED: Self = Self(kSSLConnected);
346
347    /// The session has been terminated.
348    pub const CLOSED: Self = Self(kSSLClosed);
349
350    /// The session has been aborted due to an error.
351    pub const ABORTED: Self = Self(kSSLAborted);
352}
353
354/// Specifies a server's requirement for client certificates.
355#[derive(Debug, Copy, Clone, PartialEq, Eq)]
356pub struct SslAuthenticate(SSLAuthenticate);
357
358impl SslAuthenticate {
359    /// Do not request a client certificate.
360    pub const NEVER: Self = Self(kNeverAuthenticate);
361
362    /// Require a client certificate.
363    pub const ALWAYS: Self = Self(kAlwaysAuthenticate);
364
365    /// Request but do not require a client certificate.
366    pub const TRY: Self = Self(kTryAuthenticate);
367}
368
369/// Specifies the state of client certificate processing.
370#[derive(Debug, Copy, Clone, PartialEq, Eq)]
371pub struct SslClientCertificateState(SSLClientCertificateState);
372
373impl SslClientCertificateState {
374    /// A client certificate has not been requested or sent.
375    pub const NONE: Self = Self(kSSLClientCertNone);
376
377    /// A client certificate has been requested but not recieved.
378    pub const REQUESTED: Self = Self(kSSLClientCertRequested);
379    /// A client certificate has been received and successfully validated.
380    pub const SENT: Self = Self(kSSLClientCertSent);
381
382    /// A client certificate has been received but has failed to validate.
383    pub const REJECTED: Self = Self(kSSLClientCertRejected);
384}
385
386/// Specifies protocol versions.
387#[derive(Debug, Copy, Clone, PartialEq, Eq)]
388pub struct SslProtocol(SSLProtocol);
389
390impl SslProtocol {
391    /// No protocol has been or should be negotiated or specified; use the default.
392    pub const UNKNOWN: Self = Self(kSSLProtocolUnknown);
393
394    /// The SSL 3.0 protocol is preferred, though SSL 2.0 may be used if the peer does not support
395    /// SSL 3.0.
396    pub const SSL3: Self = Self(kSSLProtocol3);
397
398    /// The TLS 1.0 protocol is preferred, though lower versions may be used
399    /// if the peer does not support TLS 1.0.
400    pub const TLS1: Self = Self(kTLSProtocol1);
401
402    /// The TLS 1.1 protocol is preferred, though lower versions may be used
403    /// if the peer does not support TLS 1.1.
404    pub const TLS11: Self = Self(kTLSProtocol11);
405
406    /// The TLS 1.2 protocol is preferred, though lower versions may be used
407    /// if the peer does not support TLS 1.2.
408    pub const TLS12: Self = Self(kTLSProtocol12);
409
410    /// The TLS 1.3 protocol is preferred, though lower versions may be used
411    /// if the peer does not support TLS 1.3.
412    pub const TLS13: Self = Self(kTLSProtocol13);
413
414    /// Only the SSL 2.0 protocol is accepted.
415    pub const SSL2: Self = Self(kSSLProtocol2);
416
417    /// The `DTLSv1` protocol is preferred.
418    pub const DTLS1: Self = Self(kDTLSProtocol1);
419
420    /// Only the SSL 3.0 protocol is accepted.
421    pub const SSL3_ONLY: Self = Self(kSSLProtocol3Only);
422
423    /// Only the TLS 1.0 protocol is accepted.
424    pub const TLS1_ONLY: Self = Self(kTLSProtocol1Only);
425
426    /// All supported TLS/SSL versions are accepted.
427    pub const ALL: Self = Self(kSSLProtocolAll);
428}
429
// Declares `SslContext` as a wrapper around the raw `SSLContextRef`; the rest
// of this module constructs it as `Self(ctx)` and reads the pointer as `self.0`.
declare_TCFType! {
    /// A Secure Transport SSL/TLS context object.
    SslContext, SSLContextRef
}

// Hooks `SslContext` into core-foundation's `TCFType` machinery, with
// `SSLContextGetTypeID` supplying its CF type identifier.
impl_TCFType!(SslContext, SSLContextRef, SSLContextGetTypeID);
436
437impl fmt::Debug for SslContext {
438    #[cold]
439    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
440        let mut builder = fmt.debug_struct("SslContext");
441        if let Ok(state) = self.state() {
442            builder.field("state", &state);
443        }
444        builder.finish()
445    }
446}
447
// NOTE(review): these impls assert that an `SslContext` may be moved to and
// shared between threads (`&SslContext` exposes getters that call into Secure
// Transport). Nothing in this file establishes that Secure Transport contexts
// are safe for concurrent `&self` use; confirm against Apple's documentation
// before relying on `Sync`.
unsafe impl Sync for SslContext {}
unsafe impl Send for SslContext {}
450
451impl AsInner for SslContext {
452    type Inner = SSLContextRef;
453
454    #[inline(always)]
455    fn as_inner(&self) -> SSLContextRef {
456        self.0
457    }
458}
459
// Generates a `set_xxx(bool)` / `xxx()` method pair for each listed Secure
// Transport session option, forwarding to `SSLSetSessionOption` and
// `SSLGetSessionOption`. Attributes (docs, `cfg`s) are copied onto both methods.
macro_rules! impl_options {
    ($($(#[$a:meta])* const $opt:ident: $get:ident & $set:ident,)*) => {
        $(
            $(#[$a])*
            #[inline(always)]
            pub fn $set(&mut self, value: bool) -> Result<()> {
                let raw: Boolean = value.into();
                unsafe { cvt(SSLSetSessionOption(self.0, $opt, raw)) }
            }

            $(#[$a])*
            #[inline]
            pub fn $get(&self) -> Result<bool> {
                let mut raw: Boolean = 0;
                unsafe { cvt(SSLGetSessionOption(self.0, $opt, &mut raw)) }.map(|()| raw != 0)
            }
        )*
    }
}
479
480impl SslContext {
481    /// Creates a new `SslContext` for the specified side and type of SSL
482    /// connection.
483    #[inline]
484    pub fn new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self> {
485        unsafe {
486            let ctx = SSLCreateContext(kCFAllocatorDefault, side.0, type_.0);
487            Ok(Self(ctx))
488        }
489    }
490
491    /// Sets the fully qualified domain name of the peer.
492    ///
493    /// This will be used on the client side of a session to validate the
494    /// common name field of the server's certificate. It has no effect if
495    /// called on a server-side `SslContext`.
496    ///
497    /// It is *highly* recommended to call this method before starting the
498    /// handshake process.
499    #[inline]
500    pub fn set_peer_domain_name(&mut self, peer_name: &str) -> Result<()> {
501        unsafe {
502            // SSLSetPeerDomainName doesn't need a null terminated string
503            cvt(SSLSetPeerDomainName(
504                self.0,
505                peer_name.as_ptr().cast(),
506                peer_name.len(),
507            ))
508        }
509    }
510
511    /// Returns the peer domain name set by `set_peer_domain_name`.
512    pub fn peer_domain_name(&self) -> Result<String> {
513        unsafe {
514            let mut len = 0;
515            cvt(SSLGetPeerDomainNameLength(self.0, &mut len))?;
516            let mut buf = vec![0; len];
517            cvt(SSLGetPeerDomainName(
518                self.0,
519                buf.as_mut_ptr().cast(),
520                &mut len,
521            ))?;
522            Ok(String::from_utf8(buf).unwrap())
523        }
524    }
525
526    /// Sets the certificate to be used by this side of the SSL session.
527    ///
528    /// This must be called before the handshake for server-side connections,
529    /// and can be used on the client-side to specify a client certificate.
530    ///
531    /// The `identity` corresponds to the leaf certificate and private
532    /// key, and the `certs` correspond to extra certificates in the chain.
533    pub fn set_certificate(
534        &mut self,
535        identity: &SecIdentity,
536        certs: &[SecCertificate],
537    ) -> Result<()> {
538        let mut arr = vec![identity.as_CFType()];
539        arr.extend(certs.iter().map(|c| c.as_CFType()));
540        let certs = CFArray::from_CFTypes(&arr);
541
542        unsafe { cvt(SSLSetCertificate(self.0, certs.as_concrete_TypeRef())) }
543    }
544
545    /// Sets the peer ID of this session.
546    ///
547    /// A peer ID is an opaque sequence of bytes that will be used by Secure
548    /// Transport to identify the peer of an SSL session. If the peer ID of
549    /// this session matches that of a previously terminated session, the
550    /// previous session can be resumed without requiring a full handshake.
551    #[inline]
552    pub fn set_peer_id(&mut self, peer_id: &[u8]) -> Result<()> {
553        unsafe { cvt(SSLSetPeerID(self.0, peer_id.as_ptr().cast(), peer_id.len())) }
554    }
555
556    /// Returns the peer ID of this session.
557    pub fn peer_id(&self) -> Result<Option<&[u8]>> {
558        unsafe {
559            let mut ptr = ptr::null();
560            let mut len = 0;
561            cvt(SSLGetPeerID(self.0, &mut ptr, &mut len))?;
562            if ptr.is_null() {
563                Ok(None)
564            } else {
565                Ok(Some(slice::from_raw_parts(ptr.cast(), len)))
566            }
567        }
568    }
569
570    /// Returns the list of ciphers that are supported by Secure Transport.
571    pub fn supported_ciphers(&self) -> Result<Vec<CipherSuite>> {
572        unsafe {
573            let mut num_ciphers = 0;
574            cvt(SSLGetNumberSupportedCiphers(self.0, &mut num_ciphers))?;
575            let mut ciphers = vec![0; num_ciphers];
576            cvt(SSLGetSupportedCiphers(
577                self.0,
578                ciphers.as_mut_ptr(),
579                &mut num_ciphers,
580            ))?;
581            Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
582        }
583    }
584
585    /// Returns the list of ciphers that are eligible to be used for
586    /// negotiation.
587    pub fn enabled_ciphers(&self) -> Result<Vec<CipherSuite>> {
588        unsafe {
589            let mut num_ciphers = 0;
590            cvt(SSLGetNumberEnabledCiphers(self.0, &mut num_ciphers))?;
591            let mut ciphers = vec![0; num_ciphers];
592            cvt(SSLGetEnabledCiphers(
593                self.0,
594                ciphers.as_mut_ptr(),
595                &mut num_ciphers,
596            ))?;
597            Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
598        }
599    }
600
601    /// Sets the list of ciphers that are eligible to be used for negotiation.
602    pub fn set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()> {
603        let ciphers = ciphers.iter().map(|c| c.to_raw()).collect::<Vec<_>>();
604        unsafe {
605            cvt(SSLSetEnabledCiphers(
606                self.0,
607                ciphers.as_ptr(),
608                ciphers.len(),
609            ))
610        }
611    }
612
613    /// Returns the cipher being used by the session.
614    #[inline]
615    pub fn negotiated_cipher(&self) -> Result<CipherSuite> {
616        unsafe {
617            let mut cipher = 0;
618            cvt(SSLGetNegotiatedCipher(self.0, &mut cipher))?;
619            Ok(CipherSuite::from_raw(cipher))
620        }
621    }
622
623    /// Sets the requirements for client certificates.
624    ///
625    /// Should only be called on server-side sessions.
626    #[inline]
627    pub fn set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()> {
628        unsafe { cvt(SSLSetClientSideAuthenticate(self.0, auth.0)) }
629    }
630
631    /// Returns the state of client certificate processing.
632    #[inline]
633    pub fn client_certificate_state(&self) -> Result<SslClientCertificateState> {
634        let mut state = 0;
635
636        unsafe {
637            cvt(SSLGetClientCertificateState(self.0, &mut state))?;
638        }
639        Ok(SslClientCertificateState(state))
640    }
641
642    /// Returns the `SecTrust` object corresponding to the peer.
643    ///
644    /// This can be used in conjunction with `set_break_on_server_auth` to
645    /// validate certificates which do not have roots in the default set.
646    pub fn peer_trust2(&self) -> Result<Option<SecTrust>> {
647        // Calling SSLCopyPeerTrust on an idle connection does not seem to be well defined,
648        // so explicitly check for that
649        if self.state()? == SessionState::IDLE {
650            return Err(Error::from_code(errSecBadReq));
651        }
652
653        unsafe {
654            let mut trust = ptr::null_mut();
655            cvt(SSLCopyPeerTrust(self.0, &mut trust))?;
656            if trust.is_null() {
657                Ok(None)
658            } else {
659                Ok(Some(SecTrust::wrap_under_create_rule(trust)))
660            }
661        }
662    }
663
664    /// Returns the state of the session.
665    #[inline]
666    pub fn state(&self) -> Result<SessionState> {
667        unsafe {
668            let mut state = 0;
669            cvt(SSLGetSessionState(self.0, &mut state))?;
670            Ok(SessionState(state))
671        }
672    }
673
674    /// Returns the protocol version being used by the session.
675    #[inline]
676    pub fn negotiated_protocol_version(&self) -> Result<SslProtocol> {
677        unsafe {
678            let mut version = 0;
679            cvt(SSLGetNegotiatedProtocolVersion(self.0, &mut version))?;
680            Ok(SslProtocol(version))
681        }
682    }
683
684    /// Returns the maximum protocol version allowed by the session.
685    #[inline]
686    pub fn protocol_version_max(&self) -> Result<SslProtocol> {
687        unsafe {
688            let mut version = 0;
689            cvt(SSLGetProtocolVersionMax(self.0, &mut version))?;
690            Ok(SslProtocol(version))
691        }
692    }
693
694    /// Sets the maximum protocol version allowed by the session.
695    #[inline]
696    pub fn set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()> {
697        unsafe { cvt(SSLSetProtocolVersionMax(self.0, max_version.0)) }
698    }
699
700    /// Returns the minimum protocol version allowed by the session.
701    #[inline]
702    pub fn protocol_version_min(&self) -> Result<SslProtocol> {
703        unsafe {
704            let mut version = 0;
705            cvt(SSLGetProtocolVersionMin(self.0, &mut version))?;
706            Ok(SslProtocol(version))
707        }
708    }
709
710    /// Sets the minimum protocol version allowed by the session.
711    #[inline]
712    pub fn set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()> {
713        unsafe { cvt(SSLSetProtocolVersionMin(self.0, min_version.0)) }
714    }
715
716    /// Returns the set of protocols selected via ALPN if it succeeded.
717    #[cfg(feature = "alpn")]
718    pub fn alpn_protocols(&self) -> Result<Vec<String>> {
719        let mut array: CFArrayRef = ptr::null();
720        unsafe {
721            #[cfg(feature = "OSX_10_13")]
722            {
723                cvt(SSLCopyALPNProtocols(self.0, &mut array))?;
724            }
725
726            #[cfg(not(feature = "OSX_10_13"))]
727            {
728                dlsym! { fn SSLCopyALPNProtocols(SSLContextRef, *mut CFArrayRef) -> OSStatus }
729                if let Some(f) = SSLCopyALPNProtocols.get() {
730                    cvt(f(self.0, &mut array))?;
731                } else {
732                    return Err(Error::from_code(errSecUnimplemented));
733                }
734            }
735
736            if array.is_null() {
737                return Ok(vec![]);
738            }
739
740            let array = CFArray::<CFString>::wrap_under_create_rule(array);
741            Ok(array.into_iter().map(|p| p.to_string()).collect())
742        }
743    }
744
745    /// Configures the set of protocols use for ALPN.
746    ///
747    /// This is only used for client-side connections.
748    #[cfg(feature = "alpn")]
749    pub fn set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()> {
750        // When CFMutableArray is added to core-foundation and IntoIterator trait
751        // is implemented for CFMutableArray, the code below should directly collect
752        // into a CFMutableArray.
753        let protocols = CFArray::from_CFTypes(
754            &protocols
755                .iter()
756                .map(|proto| CFString::new(proto))
757                .collect::<Vec<_>>(),
758        );
759
760        #[cfg(feature = "OSX_10_13")]
761        {
762            unsafe { cvt(SSLSetALPNProtocols(self.0, protocols.as_concrete_TypeRef())) }
763        }
764        #[cfg(not(feature = "OSX_10_13"))]
765        {
766            dlsym! { fn SSLSetALPNProtocols(SSLContextRef, CFArrayRef) -> OSStatus }
767            if let Some(f) = SSLSetALPNProtocols.get() {
768                unsafe { cvt(f(self.0, protocols.as_concrete_TypeRef())) }
769            } else {
770                Err(Error::from_code(errSecUnimplemented))
771            }
772        }
773    }
774
775    /// Sets whether the client sends the `SessionTicket` extension in its `ClientHello`.
776    ///
777    /// On its own, this will just cause the client to send an empty `SessionTicket` extension on
778    /// every connection. [`SslContext::set_peer_id`] must also be used to key the session
779    /// ticket returned by the server.
780    ///
781    /// [`SslContext::set_peer_id`]: #method.set_peer_id
782    #[cfg(feature = "session-tickets")]
783    pub fn set_session_tickets_enabled(&mut self, enabled: bool) -> Result<()> {
784        #[cfg(feature = "OSX_10_13")]
785        {
786            unsafe { cvt(SSLSetSessionTicketsEnabled(self.0, enabled as Boolean)) }
787        }
788        #[cfg(not(feature = "OSX_10_13"))]
789        {
790            dlsym! { fn SSLSetSessionTicketsEnabled(SSLContextRef, Boolean) -> OSStatus }
791            if let Some(f) = SSLSetSessionTicketsEnabled.get() {
792                unsafe { cvt(f(self.0, enabled as Boolean)) }
793            } else {
794                Err(Error::from_code(errSecUnimplemented))
795            }
796        }
797    }
798
799    /// Sets whether a protocol is enabled or not.
800    ///
801    /// # Note
802    ///
803    /// On OSX this is a deprecated API in favor of `set_protocol_version_max` and
804    /// `set_protocol_version_min`, although if you're working with OSX 10.8 or before you may have
805    /// to use this API instead.
806    #[cfg(target_os = "macos")]
807    #[deprecated(note = "use `set_protocol_version_max`")]
808    pub fn set_protocol_version_enabled(
809        &mut self,
810        protocol: SslProtocol,
811        enabled: bool,
812    ) -> Result<()> {
813        unsafe {
814            cvt(SSLSetProtocolVersionEnabled(
815                self.0,
816                protocol.0,
817                enabled as Boolean,
818            ))
819        }
820    }
821
822    /// Returns the number of bytes which can be read without triggering a
823    /// `read` call in the underlying stream.
824    #[inline]
825    pub fn buffered_read_size(&self) -> Result<usize> {
826        unsafe {
827            let mut size = 0;
828            cvt(SSLGetBufferedReadSize(self.0, &mut size))?;
829            Ok(size)
830        }
831    }
832
833    impl_options! {
834        /// If enabled, the handshake process will pause and return instead of
835        /// automatically validating a server's certificate.
836        const kSSLSessionOptionBreakOnServerAuth: break_on_server_auth & set_break_on_server_auth,
837        /// If enabled, the handshake process will pause and return after
838        /// the server requests a certificate from the client.
839        const kSSLSessionOptionBreakOnCertRequested: break_on_cert_requested & set_break_on_cert_requested,
840        /// If enabled, the handshake process will pause and return instead of
841        /// automatically validating a client's certificate.
842        const kSSLSessionOptionBreakOnClientAuth: break_on_client_auth & set_break_on_client_auth,
843        /// If enabled, TLS false start will be performed if an appropriate
844        /// cipher suite is negotiated.
845        ///
846        /// Requires the `OSX_10_9` (or greater) feature.
847        #[cfg(feature = "OSX_10_9")]
848        const kSSLSessionOptionFalseStart: false_start & set_false_start,
849        /// If enabled, 1/n-1 record splitting will be enabled for TLS 1.0
850        /// connections using block ciphers to mitigate the BEAST attack.
851        ///
852        /// Requires the `OSX_10_9` (or greater) feature.
853        #[cfg(feature = "OSX_10_9")]
854        const kSSLSessionOptionSendOneByteRecord: send_one_byte_record & set_send_one_byte_record,
855    }
856
857    fn into_stream<S>(self, stream: S) -> Result<SslStream<S>>
858    where
859        S: Read + Write,
860    {
861        unsafe {
862            let ret = SSLSetIOFuncs(self.0, read_func::<S>, write_func::<S>);
863            if ret != errSecSuccess {
864                return Err(Error::from_code(ret));
865            }
866
867            let stream = Connection {
868                stream,
869                err: None,
870                panic: None,
871            };
872            let stream = Box::into_raw(Box::new(stream));
873            let ret = SSLSetConnection(self.0, stream.cast());
874            if ret != errSecSuccess {
875                let _conn = Box::from_raw(stream);
876                return Err(Error::from_code(ret));
877            }
878
879            Ok(SslStream {
880                ctx: self,
881                _m: PhantomData,
882            })
883        }
884    }
885
886    /// Performs the SSL/TLS handshake.
887    pub fn handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>>
888    where
889        S: Read + Write,
890    {
891        self.into_stream(stream)
892            .map_err(HandshakeError::Failure)
893            .and_then(SslStream::handshake)
894    }
895}
896
/// State handed to Secure Transport as the connection ref and recovered in the I/O callbacks.
///
/// `SslContext::into_stream` boxes one of these and registers the raw pointer with
/// `SSLSetConnection`; `read_func`/`write_func` cast it back, and `SslStream`'s `Drop`
/// reclaims the box.
struct Connection<S> {
    // The wrapped transport that `read_func`/`write_func` read from and write to.
    stream: S,
    // Last I/O error raised by `stream` inside a callback; handed out by `SslStream::get_error`.
    err: Option<io::Error>,
    // Panic payload caught inside a callback; re-raised by `SslStream::check_panic`.
    panic: Option<Box<dyn Any + Send>>,
}
902
903// the logic here is based off of libcurl's
904#[cold]
905fn translate_err(e: &io::Error) -> OSStatus {
906    match e.kind() {
907        io::ErrorKind::NotFound => errSSLClosedGraceful,
908        io::ErrorKind::ConnectionReset => errSSLClosedAbort,
909        io::ErrorKind::WouldBlock |
910        io::ErrorKind::NotConnected => errSSLWouldBlock,
911        _ => errSecIO,
912    }
913}
914
/// Secure Transport read callback registered by `SslContext::into_stream` via `SSLSetIOFuncs`.
///
/// Tries to fill the `*data_length` bytes at `data` from the wrapped stream, looping over
/// short reads, then stores the number of bytes actually read back into `*data_length`.
///
/// Returns `errSecSuccess` when the buffer was filled, `errSSLClosedNoNotify` if the stream
/// hit EOF (a read returned 0), the status from `translate_err` on an I/O error (the original
/// `io::Error` is kept in `Connection::err`), or `errSecIO` if the stream panicked (the payload
/// is kept in `Connection::panic` so it can be re-raised later instead of unwinding through
/// the C caller).
///
/// # Safety
///
/// `connection` must be the `Connection<S>` pointer installed by `into_stream`, `data` must be
/// valid for writes of `*data_length` bytes, and `data_length` must be valid to read and write.
unsafe extern "C" fn read_func<S>(
    connection: SSLConnectionRef,
    data: *mut c_void,
    data_length: *mut usize,
) -> OSStatus
where
    S: Read,
{
    let conn: &mut Connection<S> = &mut *(connection as *mut _);
    let data = slice::from_raw_parts_mut(data.cast::<u8>(), *data_length);
    // Number of bytes filled so far.
    let mut start = 0;
    let mut ret = errSecSuccess;

    while start < data.len() {
        // Catch panics so they never unwind across the `extern "C"` boundary.
        match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.read(&mut data[start..]))) {
            Ok(Ok(0)) => {
                ret = errSSLClosedNoNotify;
                break;
            }
            Ok(Ok(len)) => start += len,
            Ok(Err(e)) => {
                ret = translate_err(&e);
                conn.err = Some(e);
                break;
            }
            Err(e) => {
                ret = errSecIO;
                conn.panic = Some(e);
                break;
            }
        }
    }

    // Report partial progress even when stopping on an error.
    *data_length = start;
    ret
}
951
952unsafe extern "C" fn write_func<S>(
953    connection: SSLConnectionRef,
954    data: *const c_void,
955    data_length: *mut usize,
956) -> OSStatus
957where
958    S: Write,
959{
960    let conn: &mut Connection<S> = &mut *(connection as *mut _);
961    let data = slice::from_raw_parts(data as *mut u8, *data_length);
962    let mut start = 0;
963    let mut ret = errSecSuccess;
964
965    while start < data.len() {
966        match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.write(&data[start..]))) {
967            Ok(Ok(0)) => {
968                ret = errSSLClosedNoNotify;
969                break;
970            }
971            Ok(Ok(len)) => start += len,
972            Ok(Err(e)) => {
973                ret = translate_err(&e);
974                conn.err = Some(e);
975                break;
976            }
977            Err(e) => {
978                ret = errSecIO;
979                conn.panic = Some(e);
980                break;
981            }
982        }
983    }
984
985    *data_length = start;
986    ret
987}
988
/// A type implementing SSL/TLS encryption over an underlying stream.
///
/// The underlying stream is not stored here directly. It lives in a heap-allocated
/// `Connection<S>` registered with `ctx`, and that allocation is freed when the `SslStream`
/// is dropped.
pub struct SslStream<S> {
    // The Secure Transport context; it also holds the raw `Connection<S>` pointer.
    ctx: SslContext,
    // Records that this type logically owns an `S` (kept behind `ctx`'s connection pointer).
    _m: PhantomData<S>,
}
994
995impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
996    #[cold]
997    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
998        fmt.debug_struct("SslStream")
999            .field("context", &self.ctx)
1000            .field("stream", self.get_ref())
1001            .finish()
1002    }
1003}
1004
1005impl<S> Drop for SslStream<S> {
1006    fn drop(&mut self) {
1007        unsafe {
1008            let mut conn = ptr::null();
1009            let ret = SSLGetConnection(self.ctx.0, &mut conn);
1010            assert!(ret == errSecSuccess);
1011            let _ = Box::<Connection<S>>::from_raw(conn as *mut _);
1012        }
1013    }
1014}
1015
1016impl<S> SslStream<S> {
1017    fn handshake(mut self) -> result::Result<Self, HandshakeError<S>> {
1018        match unsafe { SSLHandshake(self.ctx.0) } {
1019            errSecSuccess => Ok(self),
1020            reason @ errSSLPeerAuthCompleted
1021            | reason @ errSSLClientCertRequested
1022            | reason @ errSSLWouldBlock
1023            | reason @ errSSLClientHelloReceived => {
1024                Err(HandshakeError::Interrupted(MidHandshakeSslStream {
1025                    stream: self,
1026                    error: Error::from_code(reason),
1027                }))
1028            }
1029            err => {
1030                self.check_panic();
1031                Err(HandshakeError::Failure(Error::from_code(err)))
1032            }
1033        }
1034    }
1035
1036    /// Returns a shared reference to the inner stream.
1037    #[inline(always)]
1038    #[must_use]
1039    pub fn get_ref(&self) -> &S {
1040        &self.connection().stream
1041    }
1042
1043    /// Returns a mutable reference to the underlying stream.
1044    #[inline(always)]
1045    pub fn get_mut(&mut self) -> &mut S {
1046        &mut self.connection_mut().stream
1047    }
1048
1049    /// Returns a shared reference to the `SslContext` of the stream.
1050    #[inline(always)]
1051    #[must_use]
1052    pub fn context(&self) -> &SslContext {
1053        &self.ctx
1054    }
1055
1056    /// Returns a mutable reference to the `SslContext` of the stream.
1057    #[inline(always)]
1058    pub fn context_mut(&mut self) -> &mut SslContext {
1059        &mut self.ctx
1060    }
1061
1062    /// Shuts down the connection.
1063    pub fn close(&mut self) -> result::Result<(), io::Error> {
1064        unsafe {
1065            let ret = SSLClose(self.ctx.0);
1066            if ret == errSecSuccess {
1067                Ok(())
1068            } else {
1069                Err(self.get_error(ret))
1070            }
1071        }
1072    }
1073
1074    fn connection(&self) -> &Connection<S> {
1075        unsafe {
1076            let mut conn = ptr::null();
1077            let ret = SSLGetConnection(self.ctx.0, &mut conn);
1078            assert!(ret == errSecSuccess);
1079
1080            &mut *(conn as *mut Connection<S>)
1081        }
1082    }
1083
1084    fn connection_mut(&mut self) -> &mut Connection<S> {
1085        unsafe {
1086            let mut conn = ptr::null();
1087            let ret = SSLGetConnection(self.ctx.0, &mut conn);
1088            assert!(ret == errSecSuccess);
1089
1090            &mut *(conn as *mut Connection<S>)
1091        }
1092    }
1093
1094    #[cold]
1095    fn check_panic(&mut self) {
1096        let conn = self.connection_mut();
1097        if let Some(err) = conn.panic.take() {
1098            panic::resume_unwind(err);
1099        }
1100    }
1101
1102    #[cold]
1103    fn get_error(&mut self, ret: OSStatus) -> io::Error {
1104        self.check_panic();
1105
1106        if let Some(err) = self.connection_mut().err.take() {
1107            err
1108        } else {
1109            io::Error::new(io::ErrorKind::Other, Error::from_code(ret))
1110        }
1111    }
1112}
1113
1114impl<S: Read + Write> Read for SslStream<S> {
1115    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1116        // Below we base our return value off the amount of data read, so a
1117        // zero-length buffer might cause us to erroneously interpret this
1118        // request as an error. Instead short-circuit that logic and return
1119        // `Ok(0)` instead.
1120        if buf.is_empty() {
1121            return Ok(0);
1122        }
1123
1124        // If some data was buffered but not enough to fill `buf`, SSLRead
1125        // will try to read a new packet. This is bad because there may be
1126        // no more data but the socket is remaining open (e.g HTTPS with
1127        // Connection: keep-alive).
1128        let buffered = self.context().buffered_read_size().unwrap_or(0);
1129        let to_read = if buffered > 0 {
1130            cmp::min(buffered, buf.len())
1131        } else {
1132            buf.len()
1133        };
1134
1135        unsafe {
1136            let mut nread = 0;
1137            let ret = SSLRead(self.ctx.0, buf.as_mut_ptr().cast(), to_read, &mut nread);
1138            // SSLRead can return an error at the same time it returns the last
1139            // chunk of data (!)
1140            if nread > 0 {
1141                return Ok(nread);
1142            }
1143
1144            match ret {
1145                errSSLClosedGraceful | errSSLClosedAbort | errSSLClosedNoNotify => Ok(0),
1146                // this error isn't fatal
1147                errSSLPeerAuthCompleted => self.read(buf),
1148                _ => Err(self.get_error(ret)),
1149            }
1150        }
1151    }
1152}
1153
1154impl<S: Read + Write> Write for SslStream<S> {
1155    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1156        // Like above in read, short circuit a 0-length write
1157        if buf.is_empty() {
1158            return Ok(0);
1159        }
1160        unsafe {
1161            let mut nwritten = 0;
1162            let ret = SSLWrite(
1163                self.ctx.0,
1164                buf.as_ptr().cast(),
1165                buf.len(),
1166                &mut nwritten,
1167            );
1168            // just to be safe, base success off of nwritten rather than ret
1169            // for the same reason as in read
1170            if nwritten > 0 {
1171                Ok(nwritten)
1172            } else {
1173                Err(self.get_error(ret))
1174            }
1175        }
1176    }
1177
1178    fn flush(&mut self) -> io::Result<()> {
1179        self.connection_mut().stream.flush()
1180    }
1181}
1182
/// A builder type to simplify the creation of client side `SslStream`s.
#[derive(Debug)]
pub struct ClientBuilder {
    // Client identity presented to the server, set by `identity()`.
    identity: Option<SecIdentity>,
    // Anchor (root) certificates used to verify the server.
    certs: Vec<SecCertificate>,
    // Extra certificates sent alongside `identity`.
    chain: Vec<SecCertificate>,
    // Optional lower bound on the negotiated protocol version.
    protocol_min: Option<SslProtocol>,
    // Optional upper bound on the negotiated protocol version.
    protocol_max: Option<SslProtocol>,
    // When true, trust only `certs` rather than also the built-in roots.
    trust_certs_only: bool,
    // Whether to send the domain via SNI.
    use_sni: bool,
    // Skip certificate validation entirely (dangerous).
    danger_accept_invalid_certs: bool,
    // Skip hostname verification (dangerous).
    danger_accept_invalid_hostnames: bool,
    // If non-empty, only these ciphers are enabled.
    whitelisted_ciphers: Vec<CipherSuite>,
    // Ciphers removed from the enabled set.
    blacklisted_ciphers: Vec<CipherSuite>,
    // ALPN protocol names to offer, if any.
    #[cfg(feature = "alpn")]
    alpn: Option<Vec<String>>,
    // Whether to enable RFC 5077 session tickets.
    #[cfg(feature = "session-tickets")]
    enable_session_tickets: bool,
}
1202
1203impl Default for ClientBuilder {
1204    #[inline(always)]
1205    fn default() -> Self {
1206        Self::new()
1207    }
1208}
1209
1210impl ClientBuilder {
1211    /// Creates a new builder with default options.
1212    #[inline]
1213    #[must_use]
1214    pub fn new() -> Self {
1215        Self {
1216            identity: None,
1217            certs: Vec::new(),
1218            chain: Vec::new(),
1219            protocol_min: None,
1220            protocol_max: None,
1221            trust_certs_only: false,
1222            use_sni: true,
1223            danger_accept_invalid_certs: false,
1224            danger_accept_invalid_hostnames: false,
1225            whitelisted_ciphers: Vec::new(),
1226            blacklisted_ciphers: Vec::new(),
1227            #[cfg(feature = "alpn")]
1228            alpn: None,
1229            #[cfg(feature = "session-tickets")]
1230            enable_session_tickets: false,
1231        }
1232    }
1233
1234    /// Specifies the set of root certificates to trust when
1235    /// verifying the server's certificate.
1236    #[inline]
1237    pub fn anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self {
1238        self.certs = certs.to_owned();
1239        self
1240    }
1241
1242    /// Add the certificate the set of root certificates to trust
1243    /// when verifying the server's certificate.
1244    #[inline]
1245    pub fn add_anchor_certificate(&mut self, certs: &SecCertificate) -> &mut Self {
1246        self.certs.push(certs.to_owned());
1247        self
1248    }
1249
1250    /// Specifies whether to trust the built-in certificates in addition
1251    /// to specified anchor certificates.
1252    #[inline(always)]
1253    pub fn trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self {
1254        self.trust_certs_only = only;
1255        self
1256    }
1257
1258    /// Specifies whether to trust invalid certificates.
1259    ///
1260    /// # Warning
1261    ///
1262    /// You should think very carefully before using this method. If invalid
1263    /// certificates are trusted, *any* certificate for *any* site will be
1264    /// trusted for use. This includes expired certificates. This introduces
1265    /// significant vulnerabilities, and should only be used as a last resort.
1266    #[inline(always)]
1267    pub fn danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self {
1268        self.danger_accept_invalid_certs = noverify;
1269        self
1270    }
1271
1272    /// Specifies whether to use Server Name Indication (SNI).
1273    #[inline(always)]
1274    pub fn use_sni(&mut self, use_sni: bool) -> &mut Self {
1275        self.use_sni = use_sni;
1276        self
1277    }
1278
1279    /// Specifies whether to verify that the server's hostname matches its certificate.
1280    ///
1281    /// # Warning
1282    ///
1283    /// You should think very carefully before using this method. If hostnames are not verified,
1284    /// *any* valid certificate for *any* site will be trusted for use. This introduces significant
1285    /// vulnerabilities, and should only be used as a last resort.
1286    #[inline(always)]
1287    pub fn danger_accept_invalid_hostnames(
1288        &mut self,
1289        danger_accept_invalid_hostnames: bool,
1290    ) -> &mut Self {
1291        self.danger_accept_invalid_hostnames = danger_accept_invalid_hostnames;
1292        self
1293    }
1294
1295    /// Set a whitelist of enabled ciphers. Any ciphers not whitelisted will be disabled.
1296    pub fn whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self {
1297        self.whitelisted_ciphers = whitelisted_ciphers.to_owned();
1298        self
1299    }
1300
1301    /// Set a blacklist of disabled ciphers. Blacklisted ciphers will be disabled.
1302    pub fn blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self {
1303        self.blacklisted_ciphers = blacklisted_ciphers.to_owned();
1304        self
1305    }
1306
1307    /// Use the specified identity as a SSL/TLS client certificate.
1308    pub fn identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self {
1309        self.identity = Some(identity.clone());
1310        self.chain = chain.to_owned();
1311        self
1312    }
1313
1314    /// Configure the minimum protocol that this client will support.
1315    #[inline(always)]
1316    pub fn protocol_min(&mut self, min: SslProtocol) -> &mut Self {
1317        self.protocol_min = Some(min);
1318        self
1319    }
1320
1321    /// Configure the minimum protocol that this client will support.
1322    #[inline(always)]
1323    pub fn protocol_max(&mut self, max: SslProtocol) -> &mut Self {
1324        self.protocol_max = Some(max);
1325        self
1326    }
1327
1328    /// Configures the set of protocols used for ALPN.
1329    #[cfg(feature = "alpn")]
1330    pub fn alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self {
1331        self.alpn = Some(protocols.iter().map(|s| s.to_string()).collect());
1332        self
1333    }
1334
1335    /// Configures the use of the RFC 5077 `SessionTicket` extension.
1336    ///
1337    /// Defaults to `false`.
1338    #[cfg(feature = "session-tickets")]
1339    #[inline(always)]
1340    pub fn enable_session_tickets(&mut self, enable: bool) -> &mut Self {
1341        self.enable_session_tickets = enable;
1342        self
1343    }
1344
1345    /// Initiates a new SSL/TLS session over a stream connected to the specified domain.
1346    ///
1347    /// If both SNI and hostname verification are disabled, the value of `domain` will be ignored.
1348    pub fn handshake<S>(
1349        &self,
1350        domain: &str,
1351        stream: S,
1352    ) -> result::Result<SslStream<S>, ClientHandshakeError<S>>
1353    where
1354        S: Read + Write,
1355    {
1356        // the logic for trust validation is in MidHandshakeClientBuilder::connect, so run all
1357        // of the handshake logic through that.
1358        let stream = MidHandshakeSslStream {
1359            stream: self.ctx_into_stream(domain, stream)?,
1360            error: Error::from(errSecSuccess),
1361        };
1362
1363        let certs = self.certs.clone();
1364        let stream = MidHandshakeClientBuilder {
1365            stream,
1366            domain: if self.danger_accept_invalid_hostnames {
1367                None
1368            } else {
1369                Some(domain.to_string())
1370            },
1371            certs,
1372            trust_certs_only: self.trust_certs_only,
1373            danger_accept_invalid_certs: self.danger_accept_invalid_certs,
1374        };
1375        stream.handshake()
1376    }
1377
1378    fn ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>>
1379    where
1380        S: Read + Write,
1381    {
1382        let mut ctx = SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM)?;
1383
1384        if self.use_sni {
1385            ctx.set_peer_domain_name(domain)?;
1386        }
1387        if let Some(ref identity) = self.identity {
1388            ctx.set_certificate(identity, &self.chain)?;
1389        }
1390        #[cfg(feature = "alpn")]
1391        {
1392            if let Some(ref alpn) = self.alpn {
1393                ctx.set_alpn_protocols(&alpn.iter().map(|s| &**s).collect::<Vec<_>>())?;
1394            }
1395        }
1396        #[cfg(feature = "session-tickets")]
1397        {
1398            if self.enable_session_tickets {
1399                // We must use the domain here to ensure that we go through certificate validation
1400                // again rather than resuming the session if the domain changes.
1401                ctx.set_peer_id(domain.as_bytes())?;
1402                ctx.set_session_tickets_enabled(true)?;
1403            }
1404        }
1405        ctx.set_break_on_server_auth(true)?;
1406        self.configure_protocols(&mut ctx)?;
1407        self.configure_ciphers(&mut ctx)?;
1408
1409        ctx.into_stream(stream)
1410    }
1411
1412    fn configure_protocols(&self, ctx: &mut SslContext) -> Result<()> {
1413        if let Some(min) = self.protocol_min {
1414            ctx.set_protocol_version_min(min)?;
1415        }
1416        if let Some(max) = self.protocol_max {
1417            ctx.set_protocol_version_max(max)?;
1418        }
1419        Ok(())
1420    }
1421
1422    fn configure_ciphers(&self, ctx: &mut SslContext) -> Result<()> {
1423        let mut ciphers = if self.whitelisted_ciphers.is_empty() {
1424            ctx.enabled_ciphers()?
1425        } else {
1426            self.whitelisted_ciphers.clone()
1427        };
1428
1429        if !self.blacklisted_ciphers.is_empty() {
1430            ciphers.retain(|cipher| !self.blacklisted_ciphers.contains(cipher));
1431        }
1432
1433        ctx.set_enabled_ciphers(&ciphers)?;
1434        Ok(())
1435    }
1436}
1437
/// A builder type to simplify the creation of server-side `SslStream`s.
#[derive(Debug)]
pub struct ServerBuilder {
    // Server identity installed on every context this builder creates.
    identity: SecIdentity,
    // Certificates sent alongside `identity`.
    certs: Vec<SecCertificate>,
}
1444
1445impl ServerBuilder {
1446    /// Creates a new `ServerBuilder` which will use the specified identity
1447    /// and certificate chain for handshakes.
1448    #[must_use]
1449    pub fn new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self {
1450        Self {
1451            identity: identity.clone(),
1452            certs: certs.to_owned(),
1453        }
1454    }
1455
1456    /// Creates a new `ServerBuilder` which will use the identity
1457    /// from the given PKCS #12 data.
1458    ///
1459    /// This operation fails if PKCS #12 file contains zero or more than one identity.
1460    ///
1461    /// This is a shortcut for the most common operation.
1462    pub fn from_pkcs12(pkcs12_der: &[u8], passphrase: &str) -> Result<Self> {
1463        let mut identities: Vec<(SecIdentity, Vec<SecCertificate>)> = Pkcs12ImportOptions::new()
1464            .passphrase(passphrase)
1465            .import(pkcs12_der)?
1466            .into_iter()
1467            .filter_map(|idendity| {
1468                let certs = idendity.cert_chain.unwrap_or_default();
1469                idendity.identity.map(|identity| (identity, certs))
1470            })
1471            .collect();
1472        if identities.len() == 1 {
1473            let (identity, certs) = identities.pop().unwrap();
1474            Ok(ServerBuilder::new(&identity, &certs))
1475        } else {
1476            // This error code is not really helpful
1477            Err(Error::from_code(errSecParam))
1478        }
1479    }
1480
1481    /// Create a SSL context for lower-level stream initialization.
1482    pub fn new_ssl_context(&self) -> Result<SslContext> {
1483        let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
1484        ctx.set_certificate(&self.identity, &self.certs)?;
1485        Ok(ctx)
1486    }
1487
1488    /// Initiates a new SSL/TLS session over a stream.
1489    pub fn handshake<S>(&self, stream: S) -> Result<SslStream<S>>
1490    where
1491        S: Read + Write,
1492    {
1493        match self.new_ssl_context()?.handshake(stream) {
1494            Ok(stream) => Ok(stream),
1495            Err(HandshakeError::Interrupted(stream)) => Err(*stream.error()),
1496            Err(HandshakeError::Failure(err)) => Err(err),
1497        }
1498    }
1499}
1500
1501#[cfg(test)]
1502mod test {
1503    use std::io;
1504    use std::io::prelude::*;
1505    use std::net::TcpStream;
1506
1507    use super::*;
1508
1509    #[test]
1510    fn server_builder_from_pkcs12() {
1511        let pkcs12_der = include_bytes!("../test/server.p12");
1512        ServerBuilder::from_pkcs12(pkcs12_der, "password123").unwrap();
1513    }
1514
1515    #[test]
1516    fn connect() {
1517        let mut ctx = p!(SslContext::new(
1518            SslProtocolSide::CLIENT,
1519            SslConnectionType::STREAM
1520        ));
1521        p!(ctx.set_peer_domain_name("google.com"));
1522        let stream = p!(TcpStream::connect("google.com:443"));
1523        p!(ctx.handshake(stream));
1524    }
1525
1526    #[test]
1527    fn connect_bad_domain() {
1528        let mut ctx = p!(SslContext::new(
1529            SslProtocolSide::CLIENT,
1530            SslConnectionType::STREAM
1531        ));
1532        p!(ctx.set_peer_domain_name("foobar.com"));
1533        let stream = p!(TcpStream::connect("google.com:443"));
1534        match ctx.handshake(stream) {
1535            Ok(_) => panic!("expected failure"),
1536            Err(_) => {}
1537        }
1538    }
1539
1540    #[test]
1541    fn load_page() {
1542        let mut ctx = p!(SslContext::new(
1543            SslProtocolSide::CLIENT,
1544            SslConnectionType::STREAM
1545        ));
1546        p!(ctx.set_peer_domain_name("google.com"));
1547        let stream = p!(TcpStream::connect("google.com:443"));
1548        let mut stream = p!(ctx.handshake(stream));
1549        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
1550        p!(stream.flush());
1551        let mut buf = vec![];
1552        p!(stream.read_to_end(&mut buf));
1553        println!("{}", String::from_utf8_lossy(&buf));
1554    }
1555
1556    #[test]
1557    fn client_no_session_ticket_resumption() {
1558        for _ in 0..2 {
1559            let stream = p!(TcpStream::connect("google.com:443"));
1560
1561            // Manually handshake here.
1562            let stream = MidHandshakeSslStream {
1563                stream: ClientBuilder::new()
1564                    .ctx_into_stream("google.com", stream)
1565                    .unwrap(),
1566                error: Error::from(errSecSuccess),
1567            };
1568
1569            let mut result = stream.handshake();
1570
1571            if let Err(HandshakeError::Interrupted(stream)) = result {
1572                assert!(stream.server_auth_completed());
1573                result = stream.handshake();
1574            } else {
1575                panic!("Unexpectedly skipped server auth");
1576            }
1577
1578            assert!(result.is_ok());
1579        }
1580    }
1581
1582    #[test]
1583    #[cfg(feature = "session-tickets")]
1584    fn client_session_ticket_resumption() {
1585        // The first time through this loop, we should do a full handshake. The second time, we
1586        // should immediately finish the handshake without breaking on server auth.
1587        for i in 0..2 {
1588            let stream = p!(TcpStream::connect("google.com:443"));
1589            let mut builder = ClientBuilder::new();
1590            builder.enable_session_tickets(true);
1591
1592            // Manually handshake here.
1593            let stream = MidHandshakeSslStream {
1594                stream: builder.ctx_into_stream("google.com", stream).unwrap(),
1595                error: Error::from(errSecSuccess),
1596            };
1597
1598            let mut result = stream.handshake();
1599
1600            if let Err(HandshakeError::Interrupted(stream)) = result {
1601                assert!(stream.server_auth_completed());
1602                assert_eq!(
1603                    i, 0,
1604                    "Session ticket resumption did not work, server auth was not skipped"
1605                );
1606                result = stream.handshake();
1607            } else {
1608                assert_eq!(i, 1, "Unexpectedly skipped server auth");
1609            }
1610
1611            assert!(result.is_ok());
1612        }
1613    }
1614
1615    #[test]
1616    #[cfg(feature = "alpn")]
1617    fn client_alpn_accept() {
1618        let mut ctx = p!(SslContext::new(
1619            SslProtocolSide::CLIENT,
1620            SslConnectionType::STREAM
1621        ));
1622        p!(ctx.set_peer_domain_name("google.com"));
1623        p!(ctx.set_alpn_protocols(&vec!["h2"]));
1624        let stream = p!(TcpStream::connect("google.com:443"));
1625        let stream = ctx.handshake(stream).unwrap();
1626        assert_eq!(vec!["h2"], stream.context().alpn_protocols().unwrap());
1627    }
1628
1629    #[test]
1630    #[cfg(feature = "alpn")]
1631    fn client_alpn_reject() {
1632        let mut ctx = p!(SslContext::new(
1633            SslProtocolSide::CLIENT,
1634            SslConnectionType::STREAM
1635        ));
1636        p!(ctx.set_peer_domain_name("google.com"));
1637        p!(ctx.set_alpn_protocols(&vec!["h2c"]));
1638        let stream = p!(TcpStream::connect("google.com:443"));
1639        let stream = ctx.handshake(stream).unwrap();
1640        assert!(stream.context().alpn_protocols().is_err());
1641    }
1642
1643    #[test]
1644    fn client_no_anchor_certs() {
1645        let stream = p!(TcpStream::connect("google.com:443"));
1646        assert!(ClientBuilder::new()
1647            .trust_anchor_certificates_only(true)
1648            .handshake("google.com", stream)
1649            .is_err());
1650    }
1651
1652    #[test]
1653    fn client_bad_domain() {
1654        let stream = p!(TcpStream::connect("google.com:443"));
1655        assert!(ClientBuilder::new()
1656            .handshake("foobar.com", stream)
1657            .is_err());
1658    }
1659
1660    #[test]
1661    fn client_bad_domain_ignored() {
1662        let stream = p!(TcpStream::connect("google.com:443"));
1663        ClientBuilder::new()
1664            .danger_accept_invalid_hostnames(true)
1665            .handshake("foobar.com", stream)
1666            .unwrap();
1667    }
1668
1669    #[test]
1670    fn connect_no_verify_ssl() {
1671        let stream = p!(TcpStream::connect("expired.badssl.com:443"));
1672        let mut builder = ClientBuilder::new();
1673        builder.danger_accept_invalid_certs(true);
1674        builder.handshake("expired.badssl.com", stream).unwrap();
1675    }
1676
1677    #[test]
1678    fn load_page_client() {
1679        let stream = p!(TcpStream::connect("google.com:443"));
1680        let mut stream = p!(ClientBuilder::new().handshake("google.com", stream));
1681        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
1682        p!(stream.flush());
1683        let mut buf = vec![];
1684        p!(stream.read_to_end(&mut buf));
1685        println!("{}", String::from_utf8_lossy(&buf));
1686    }
1687
1688    #[test]
1689    #[cfg_attr(target_os = "ios", ignore)] // FIXME what's going on with ios?
1690    fn cipher_configuration() {
1691        let mut ctx = p!(SslContext::new(
1692            SslProtocolSide::SERVER,
1693            SslConnectionType::STREAM
1694        ));
1695        let ciphers = p!(ctx.enabled_ciphers());
1696        let ciphers = ciphers
1697            .iter()
1698            .enumerate()
1699            .filter_map(|(i, c)| if i % 2 == 0 { Some(*c) } else { None })
1700            .collect::<Vec<_>>();
1701        p!(ctx.set_enabled_ciphers(&ciphers));
1702        assert_eq!(ciphers, p!(ctx.enabled_ciphers()));
1703    }
1704
1705    #[test]
1706    fn test_builder_whitelist_ciphers() {
1707        let stream = p!(TcpStream::connect("google.com:443"));
1708
1709        let ctx = p!(SslContext::new(
1710            SslProtocolSide::CLIENT,
1711            SslConnectionType::STREAM
1712        ));
1713        assert!(p!(ctx.enabled_ciphers()).len() > 1);
1714
1715        let ciphers = p!(ctx.enabled_ciphers());
1716        let cipher = ciphers.first().unwrap();
1717        let stream = p!(ClientBuilder::new()
1718            .whitelist_ciphers(&[*cipher])
1719            .ctx_into_stream("google.com", stream));
1720
1721        assert_eq!(1, p!(stream.context().enabled_ciphers()).len());
1722    }
1723
1724    #[test]
1725    #[cfg_attr(target_os = "ios", ignore)] // FIXME same issue as cipher_configuration
1726    fn test_builder_blacklist_ciphers() {
1727        let stream = p!(TcpStream::connect("google.com:443"));
1728
1729        let ctx = p!(SslContext::new(
1730            SslProtocolSide::CLIENT,
1731            SslConnectionType::STREAM
1732        ));
1733        let num = p!(ctx.enabled_ciphers()).len();
1734        assert!(num > 1);
1735
1736        let ciphers = p!(ctx.enabled_ciphers());
1737        let cipher = ciphers.first().unwrap();
1738        let stream = p!(ClientBuilder::new()
1739            .blacklist_ciphers(&[*cipher])
1740            .ctx_into_stream("google.com", stream));
1741
1742        assert_eq!(num - 1, p!(stream.context().enabled_ciphers()).len());
1743    }
1744
1745    #[test]
1746    fn idle_context_peer_trust() {
1747        let ctx = p!(SslContext::new(
1748            SslProtocolSide::SERVER,
1749            SslConnectionType::STREAM
1750        ));
1751        assert!(ctx.peer_trust2().is_err());
1752    }
1753
1754    #[test]
1755    fn peer_id() {
1756        let mut ctx = p!(SslContext::new(
1757            SslProtocolSide::SERVER,
1758            SslConnectionType::STREAM
1759        ));
1760        assert!(p!(ctx.peer_id()).is_none());
1761        p!(ctx.set_peer_id(b"foobar"));
1762        assert_eq!(p!(ctx.peer_id()), Some(&b"foobar"[..]));
1763    }
1764
1765    #[test]
1766    fn peer_domain_name() {
1767        let mut ctx = p!(SslContext::new(
1768            SslProtocolSide::CLIENT,
1769            SslConnectionType::STREAM
1770        ));
1771        assert_eq!("", p!(ctx.peer_domain_name()));
1772        p!(ctx.set_peer_domain_name("foobar.com"));
1773        assert_eq!("foobar.com", p!(ctx.peer_domain_name()));
1774    }
1775
1776    #[test]
1777    #[should_panic(expected = "blammo")]
1778    fn write_panic() {
1779        struct ExplodingStream(TcpStream);
1780
1781        impl Read for ExplodingStream {
1782            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1783                self.0.read(buf)
1784            }
1785        }
1786
1787        impl Write for ExplodingStream {
1788            fn write(&mut self, _: &[u8]) -> io::Result<usize> {
1789                panic!("blammo");
1790            }
1791
1792            fn flush(&mut self) -> io::Result<()> {
1793                self.0.flush()
1794            }
1795        }
1796
1797        let mut ctx = p!(SslContext::new(
1798            SslProtocolSide::CLIENT,
1799            SslConnectionType::STREAM
1800        ));
1801        p!(ctx.set_peer_domain_name("google.com"));
1802        let stream = p!(TcpStream::connect("google.com:443"));
1803        let _ = ctx.handshake(ExplodingStream(stream));
1804    }
1805
1806    #[test]
1807    #[should_panic(expected = "blammo")]
1808    fn read_panic() {
1809        struct ExplodingStream(TcpStream);
1810
1811        impl Read for ExplodingStream {
1812            fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
1813                panic!("blammo");
1814            }
1815        }
1816
1817        impl Write for ExplodingStream {
1818            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1819                self.0.write(buf)
1820            }
1821
1822            fn flush(&mut self) -> io::Result<()> {
1823                self.0.flush()
1824            }
1825        }
1826
1827        let mut ctx = p!(SslContext::new(
1828            SslProtocolSide::CLIENT,
1829            SslConnectionType::STREAM
1830        ));
1831        p!(ctx.set_peer_domain_name("google.com"));
1832        let stream = p!(TcpStream::connect("google.com:443"));
1833        let _ = ctx.handshake(ExplodingStream(stream));
1834    }
1835
1836    #[test]
1837    fn zero_length_buffers() {
1838        let mut ctx = p!(SslContext::new(
1839            SslProtocolSide::CLIENT,
1840            SslConnectionType::STREAM
1841        ));
1842        p!(ctx.set_peer_domain_name("google.com"));
1843        let stream = p!(TcpStream::connect("google.com:443"));
1844        let mut stream = ctx.handshake(stream).unwrap();
1845        assert_eq!(stream.write(b"").unwrap(), 0);
1846        assert_eq!(stream.read(&mut []).unwrap(), 0);
1847    }
1848}