apple_security_framework/
secure_transport.rs

1//! SSL/TLS encryption support using Secure Transport.
2//!
3//! # Examples
4//!
5//! To connect as a client to a server with a certificate trusted by the system:
6//!
7//! ```rust
8//! use std::{net::TcpStream, io::prelude::*};
9//! use apple_security_framework::secure_transport::ClientBuilder;
10//!
11//! let stream = TcpStream::connect("google.com:443").unwrap();
12//! let mut stream = ClientBuilder::new().handshake("google.com", stream).unwrap();
13//!
14//! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
15//! let mut page = vec![];
16//! stream.read_to_end(&mut page).unwrap();
17//! println!("{}", String::from_utf8_lossy(&page));
18//! ```
19//!
20//! To connect to a server with a certificate that's *not* trusted by the
21//! system, specify the root certificates for the server's chain to the
22//! `ClientBuilder`:
23//!
24//! ```rust,no_run
25//! use std::{net::TcpStream, io::prelude::*};
26//! use apple_security_framework::secure_transport::ClientBuilder;
27//!
28//! # let root_cert = unsafe { std::mem::zeroed() };
29//! let stream = TcpStream::connect("my_server.com:443").unwrap();
30//! let mut stream = ClientBuilder::new()
31//!                      .anchor_certificates(&[root_cert])
32//!                      .handshake("my_server.com", stream)
33//!                      .unwrap();
34//!
35//! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
36//! let mut page = vec![];
37//! stream.read_to_end(&mut page).unwrap();
38//! println!("{}", String::from_utf8_lossy(&page));
39//! ```
40//!
41//! For more advanced configuration, the `SslContext` type can be used directly.
42//!
43//! To run a server:
44//!
45//! ```rust,no_run
46//! use std::{net::TcpListener, thread};
47//! use apple_security_framework::secure_transport::{SslContext, SslProtocolSide, SslConnectionType};
48//!
49//! // Create a TCP listener and start accepting on it.
50//! let mut listener = TcpListener::bind("0.0.0.0:443").unwrap();
51//!
52//! for stream in listener.incoming() {
53//!     let stream = stream.unwrap();
54//!     thread::spawn(move || {
55//!         // Create a new context configured to operate on the server side of
56//!         // a traditional SSL/TLS session.
57//!         let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)
58//!                           .unwrap();
59//!
60//!         // Install the certificate chain that we will be using.
61//!         # let identity = unsafe { std::mem::zeroed() };
62//!         # let intermediate_cert = unsafe { std::mem::zeroed() };
63//!         # let root_cert = unsafe { std::mem::zeroed() };
64//!         ctx.set_certificate(identity, &[intermediate_cert, root_cert]).unwrap();
65//!
66//!         // Perform the SSL/TLS handshake and get our stream.
67//!         let mut stream = ctx.handshake(stream).unwrap();
68//!     });
69//! }
70//!
71//! ```
72use std::{
73    any::Any,
74    cmp, fmt,
75    io::{
76        prelude::*,
77        {self},
78    },
79    marker::PhantomData,
80    os::raw::c_void,
81    panic::{
82        AssertUnwindSafe, {self},
83    },
84    ptr, result, slice,
85};
86
87#[allow(unused_imports)]
88use core_foundation::array::CFArray;
89#[allow(unused_imports)]
90use core_foundation::array::CFArrayRef;
91use core_foundation::base::{Boolean, TCFType};
92#[cfg(feature = "alpn")]
93use core_foundation::string::CFString;
94use core_foundation_sys::base::{kCFAllocatorDefault, OSStatus};
95#[allow(unused_imports)]
96use security_framework_sys::base::errSecBadReq;
97#[allow(unused_imports)]
98use security_framework_sys::base::errSecIO;
99#[allow(unused_imports)]
100use security_framework_sys::base::errSecNotTrusted;
101#[allow(unused_imports)]
102use security_framework_sys::base::errSecSuccess;
103#[allow(unused_imports)]
104use security_framework_sys::base::errSecTrustSettingDeny;
105#[allow(unused_imports)]
106use security_framework_sys::base::errSecUnimplemented;
107use security_framework_sys::{base::errSecParam, secure_transport::*};
108
109use crate::{
110    base::{Error, Result},
111    certificate::SecCertificate,
112    cipher_suite::CipherSuite,
113    cvt,
114    identity::SecIdentity,
115    import_export::Pkcs12ImportOptions,
116    policy::SecPolicy,
117    trust::SecTrust,
118    AsInner,
119};
120
121/// Specifies a side of a TLS session.
122#[derive(Debug, Copy, Clone, PartialEq, Eq)]
123pub struct SslProtocolSide(SSLProtocolSide);
124
125impl SslProtocolSide {
126    /// The server side of the session.
127    pub const SERVER: Self = Self(kSSLServerSide);
128
129    /// The client side of the session.
130    pub const CLIENT: Self = Self(kSSLClientSide);
131}
132
133/// Specifies the type of TLS session.
134#[derive(Debug, Copy, Clone)]
135pub struct SslConnectionType(SSLConnectionType);
136
137impl SslConnectionType {
138    /// A traditional TLS stream.
139    pub const STREAM: Self = Self(kSSLStreamType);
140
141    /// A DTLS session.
142    pub const DATAGRAM: Self = Self(kSSLDatagramType);
143}
144
145/// An error or intermediate state after a TLS handshake attempt.
146#[derive(Debug)]
147pub enum HandshakeError<S> {
148    /// The handshake failed.
149    Failure(Error),
150
151    /// The handshake was interrupted midway through.
152    Interrupted(MidHandshakeSslStream<S>),
153}
154
155impl<S> From<Error> for HandshakeError<S> {
156    #[inline(always)]
157    fn from(err: Error) -> Self {
158        Self::Failure(err)
159    }
160}
161
162/// An error or intermediate state after a TLS handshake attempt.
163#[derive(Debug)]
164pub enum ClientHandshakeError<S> {
165    /// The handshake failed.
166    Failure(Error),
167
168    /// The handshake was interrupted midway through.
169    Interrupted(MidHandshakeClientBuilder<S>),
170}
171
172impl<S> From<Error> for ClientHandshakeError<S> {
173    #[inline(always)]
174    fn from(err: Error) -> Self {
175        Self::Failure(err)
176    }
177}
178
179/// An SSL stream midway through the handshake process.
180#[derive(Debug)]
181pub struct MidHandshakeSslStream<S> {
182    stream: SslStream<S>,
183    error: Error,
184}
185
186impl<S> MidHandshakeSslStream<S> {
187    /// Returns a shared reference to the inner stream.
188    #[inline(always)]
189    #[must_use]
190    pub fn get_ref(&self) -> &S {
191        self.stream.get_ref()
192    }
193
194    /// Returns a mutable reference to the inner stream.
195    #[inline(always)]
196    pub fn get_mut(&mut self) -> &mut S {
197        self.stream.get_mut()
198    }
199
200    /// Returns a shared reference to the `SslContext` of the stream.
201    #[inline(always)]
202    #[must_use]
203    pub fn context(&self) -> &SslContext {
204        self.stream.context()
205    }
206
207    /// Returns a mutable reference to the `SslContext` of the stream.
208    #[inline(always)]
209    pub fn context_mut(&mut self) -> &mut SslContext {
210        self.stream.context_mut()
211    }
212
213    /// Returns `true` iff `break_on_server_auth` was set and the handshake has
214    /// progressed to that point.
215    #[inline(always)]
216    #[must_use]
217    pub fn server_auth_completed(&self) -> bool {
218        self.error.code() == errSSLPeerAuthCompleted
219    }
220
221    /// Returns `true` iff `break_on_cert_requested` was set and the handshake
222    /// has progressed to that point.
223    #[inline(always)]
224    #[must_use]
225    pub fn client_cert_requested(&self) -> bool {
226        self.error.code() == errSSLClientCertRequested
227    }
228
229    /// Returns `true` iff the underlying stream returned an error with the
230    /// `WouldBlock` kind.
231    #[inline(always)]
232    #[must_use]
233    pub fn would_block(&self) -> bool {
234        self.error.code() == errSSLWouldBlock
235    }
236
237    /// Returns the error which caused the handshake interruption.
238    #[inline(always)]
239    #[must_use]
240    pub fn error(&self) -> &Error {
241        &self.error
242    }
243
244    /// Restarts the handshake process.
245    #[inline(always)]
246    pub fn handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>> {
247        self.stream.handshake()
248    }
249}
250
251/// An SSL stream midway through the handshake process.
252#[derive(Debug)]
253pub struct MidHandshakeClientBuilder<S> {
254    stream: MidHandshakeSslStream<S>,
255    domain: Option<String>,
256    certs: Vec<SecCertificate>,
257    trust_certs_only: bool,
258    danger_accept_invalid_certs: bool,
259}
260
261impl<S> MidHandshakeClientBuilder<S> {
262    /// Returns a shared reference to the inner stream.
263    #[inline(always)]
264    #[must_use]
265    pub fn get_ref(&self) -> &S {
266        self.stream.get_ref()
267    }
268
269    /// Returns a mutable reference to the inner stream.
270    #[inline(always)]
271    pub fn get_mut(&mut self) -> &mut S {
272        self.stream.get_mut()
273    }
274
275    /// Returns the error which caused the handshake interruption.
276    #[inline(always)]
277    #[must_use]
278    pub fn error(&self) -> &Error {
279        self.stream.error()
280    }
281
282    /// Restarts the handshake process.
283    pub fn handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>> {
284        let MidHandshakeClientBuilder {
285            stream,
286            domain,
287            certs,
288            trust_certs_only,
289            danger_accept_invalid_certs,
290        } = self;
291
292        let mut result = stream.handshake();
293        loop {
294            let stream = match result {
295                Ok(stream) => return Ok(stream),
296                Err(HandshakeError::Interrupted(stream)) => stream,
297                Err(HandshakeError::Failure(err)) => {
298                    return Err(ClientHandshakeError::Failure(err))
299                }
300            };
301
302            if stream.would_block() {
303                let ret = MidHandshakeClientBuilder {
304                    stream,
305                    domain,
306                    certs,
307                    trust_certs_only,
308                    danger_accept_invalid_certs,
309                };
310                return Err(ClientHandshakeError::Interrupted(ret));
311            }
312
313            if stream.server_auth_completed() {
314                if danger_accept_invalid_certs {
315                    result = stream.handshake();
316                    continue;
317                }
318                let mut trust = match stream.context().peer_trust2()? {
319                    Some(trust) => trust,
320                    None => {
321                        result = stream.handshake();
322                        continue;
323                    }
324                };
325                trust.set_anchor_certificates(&certs)?;
326                trust.set_trust_anchor_certificates_only(self.trust_certs_only)?;
327                let policy = SecPolicy::create_ssl(SslProtocolSide::SERVER, domain.as_deref());
328                trust.set_policy(&policy)?;
329                trust.evaluate_with_error().map_err(|error| {
330                    #[cfg(feature = "log")]
331                    log::warn!("SecTrustEvaluateWithError: {}", error.to_string());
332                    Error::from_code(error.code() as _)
333                })?;
334                result = stream.handshake();
335                continue;
336            }
337
338            let err = Error::from_code(stream.error().code());
339            return Err(ClientHandshakeError::Failure(err));
340        }
341    }
342}
343
344/// Specifies the state of a TLS session.
345#[derive(Debug, PartialEq, Eq)]
346pub struct SessionState(SSLSessionState);
347
348impl SessionState {
349    /// The session has not yet started.
350    pub const IDLE: Self = Self(kSSLIdle);
351
352    /// The session is in the handshake process.
353    pub const HANDSHAKE: Self = Self(kSSLHandshake);
354
355    /// The session is connected.
356    pub const CONNECTED: Self = Self(kSSLConnected);
357
358    /// The session has been terminated.
359    pub const CLOSED: Self = Self(kSSLClosed);
360
361    /// The session has been aborted due to an error.
362    pub const ABORTED: Self = Self(kSSLAborted);
363}
364
365/// Specifies a server's requirement for client certificates.
366#[derive(Debug, Copy, Clone, PartialEq, Eq)]
367pub struct SslAuthenticate(SSLAuthenticate);
368
369impl SslAuthenticate {
370    /// Do not request a client certificate.
371    pub const NEVER: Self = Self(kNeverAuthenticate);
372
373    /// Require a client certificate.
374    pub const ALWAYS: Self = Self(kAlwaysAuthenticate);
375
376    /// Request but do not require a client certificate.
377    pub const TRY: Self = Self(kTryAuthenticate);
378}
379
380/// Specifies the state of client certificate processing.
381#[derive(Debug, Copy, Clone, PartialEq, Eq)]
382pub struct SslClientCertificateState(SSLClientCertificateState);
383
384impl SslClientCertificateState {
385    /// A client certificate has not been requested or sent.
386    pub const NONE: Self = Self(kSSLClientCertNone);
387
388    /// A client certificate has been requested but not recieved.
389    pub const REQUESTED: Self = Self(kSSLClientCertRequested);
390    /// A client certificate has been received and successfully validated.
391    pub const SENT: Self = Self(kSSLClientCertSent);
392
393    /// A client certificate has been received but has failed to validate.
394    pub const REJECTED: Self = Self(kSSLClientCertRejected);
395}
396
397/// Specifies protocol versions.
398#[derive(Debug, Copy, Clone, PartialEq, Eq)]
399pub struct SslProtocol(SSLProtocol);
400
401impl SslProtocol {
402    /// No protocol has been or should be negotiated or specified; use the default.
403    pub const UNKNOWN: Self = Self(kSSLProtocolUnknown);
404
405    /// The SSL 3.0 protocol is preferred, though SSL 2.0 may be used if the peer does not support
406    /// SSL 3.0.
407    pub const SSL3: Self = Self(kSSLProtocol3);
408
409    /// The TLS 1.0 protocol is preferred, though lower versions may be used
410    /// if the peer does not support TLS 1.0.
411    pub const TLS1: Self = Self(kTLSProtocol1);
412
413    /// The TLS 1.1 protocol is preferred, though lower versions may be used
414    /// if the peer does not support TLS 1.1.
415    pub const TLS11: Self = Self(kTLSProtocol11);
416
417    /// The TLS 1.2 protocol is preferred, though lower versions may be used
418    /// if the peer does not support TLS 1.2.
419    pub const TLS12: Self = Self(kTLSProtocol12);
420
421    /// The TLS 1.3 protocol is preferred, though lower versions may be used
422    /// if the peer does not support TLS 1.3.
423    pub const TLS13: Self = Self(kTLSProtocol13);
424
425    /// Only the SSL 2.0 protocol is accepted.
426    pub const SSL2: Self = Self(kSSLProtocol2);
427
428    /// The `DTLSv1` protocol is preferred.
429    pub const DTLS1: Self = Self(kDTLSProtocol1);
430
431    /// Only the SSL 3.0 protocol is accepted.
432    pub const SSL3_ONLY: Self = Self(kSSLProtocol3Only);
433
434    /// Only the TLS 1.0 protocol is accepted.
435    pub const TLS1_ONLY: Self = Self(kTLSProtocol1Only);
436
437    /// All supported TLS/SSL versions are accepted.
438    pub const ALL: Self = Self(kSSLProtocolAll);
439}
440
// `declare_TCFType!` defines `SslContext` as a tuple-struct wrapper around a raw
// `SSLContextRef` (accessed elsewhere in this file as `self.0` / `Self(ctx)`).
// `impl_TCFType!` then registers it with core-foundation's `TCFType` machinery,
// using `SSLContextGetTypeID` as its Core Foundation type identifier.
declare_TCFType! {
    /// A Secure Transport SSL/TLS context object.
    SslContext, SSLContextRef
}

impl_TCFType!(SslContext, SSLContextRef, SSLContextGetTypeID);
447
448impl fmt::Debug for SslContext {
449    #[cold]
450    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
451        let mut builder = fmt.debug_struct("SslContext");
452        if let Ok(state) = self.state() {
453            builder.field("state", &state);
454        }
455        builder.finish()
456    }
457}
458
// SAFETY (NOTE(review)): these impls assert that the wrapped `SSLContextRef`
// may be moved to and shared between threads. This file adds no locking of its
// own. Every configuration setter takes `&mut self`, but the `&self` getters
// (e.g. `state`, `peer_trust2`) can run concurrently through a shared
// reference. Confirm this matches Secure Transport's thread-safety guarantees.
unsafe impl Sync for SslContext {}
unsafe impl Send for SslContext {}
461
462impl AsInner for SslContext {
463    type Inner = SSLContextRef;
464
465    #[inline(always)]
466    fn as_inner(&self) -> SSLContextRef {
467        self.0
468    }
469}
470
// Generates a setter/getter method pair for each listed Secure Transport
// session option. Each entry has the shape
// `$(#[attr])* const kOption: getter_name & setter_name,`
// and its attributes (doc comments, `cfg`s) are copied onto both methods.
macro_rules! impl_options {
    ($($(#[$attr:meta])* const $option:ident: $getter:ident & $setter:ident,)*) => {
        $(
            $(#[$attr])*
            #[inline(always)]
            pub fn $setter(&mut self, value: bool) -> Result<()> {
                let flag = Boolean::from(value);
                unsafe { cvt(SSLSetSessionOption(self.0, $option, flag)) }
            }

            $(#[$attr])*
            #[inline]
            pub fn $getter(&self) -> Result<bool> {
                let mut flag: Boolean = 0;
                unsafe {
                    cvt(SSLGetSessionOption(self.0, $option, &mut flag))?;
                }
                Ok(flag != 0)
            }
        )*
    };
}
490
491impl SslContext {
492    /// Creates a new `SslContext` for the specified side and type of SSL
493    /// connection.
494    #[inline]
495    pub fn new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self> {
496        unsafe {
497            let ctx = SSLCreateContext(kCFAllocatorDefault, side.0, type_.0);
498            Ok(Self(ctx))
499        }
500    }
501
502    /// Sets the fully qualified domain name of the peer.
503    ///
504    /// This will be used on the client side of a session to validate the
505    /// common name field of the server's certificate. It has no effect if
506    /// called on a server-side `SslContext`.
507    ///
508    /// It is *highly* recommended to call this method before starting the
509    /// handshake process.
510    #[inline]
511    pub fn set_peer_domain_name(&mut self, peer_name: &str) -> Result<()> {
512        unsafe {
513            // SSLSetPeerDomainName doesn't need a null terminated string
514            cvt(SSLSetPeerDomainName(
515                self.0,
516                peer_name.as_ptr().cast(),
517                peer_name.len(),
518            ))
519        }
520    }
521
522    /// Returns the peer domain name set by `set_peer_domain_name`.
523    pub fn peer_domain_name(&self) -> Result<String> {
524        unsafe {
525            let mut len = 0;
526            cvt(SSLGetPeerDomainNameLength(self.0, &mut len))?;
527            let mut buf = vec![0; len];
528            cvt(SSLGetPeerDomainName(
529                self.0,
530                buf.as_mut_ptr().cast(),
531                &mut len,
532            ))?;
533            Ok(String::from_utf8(buf).unwrap())
534        }
535    }
536
537    /// Sets the certificate to be used by this side of the SSL session.
538    ///
539    /// This must be called before the handshake for server-side connections,
540    /// and can be used on the client-side to specify a client certificate.
541    ///
542    /// The `identity` corresponds to the leaf certificate and private
543    /// key, and the `certs` correspond to extra certificates in the chain.
544    pub fn set_certificate(
545        &mut self,
546        identity: &SecIdentity,
547        certs: &[SecCertificate],
548    ) -> Result<()> {
549        let mut arr = vec![identity.as_CFType()];
550        arr.extend(certs.iter().map(|c| c.as_CFType()));
551        let certs = CFArray::from_CFTypes(&arr);
552
553        unsafe { cvt(SSLSetCertificate(self.0, certs.as_concrete_TypeRef())) }
554    }
555
556    /// Sets the peer ID of this session.
557    ///
558    /// A peer ID is an opaque sequence of bytes that will be used by Secure
559    /// Transport to identify the peer of an SSL session. If the peer ID of
560    /// this session matches that of a previously terminated session, the
561    /// previous session can be resumed without requiring a full handshake.
562    #[inline]
563    pub fn set_peer_id(&mut self, peer_id: &[u8]) -> Result<()> {
564        unsafe { cvt(SSLSetPeerID(self.0, peer_id.as_ptr().cast(), peer_id.len())) }
565    }
566
567    /// Returns the peer ID of this session.
568    pub fn peer_id(&self) -> Result<Option<&[u8]>> {
569        unsafe {
570            let mut ptr = ptr::null();
571            let mut len = 0;
572            cvt(SSLGetPeerID(self.0, &mut ptr, &mut len))?;
573            if ptr.is_null() {
574                Ok(None)
575            } else {
576                Ok(Some(slice::from_raw_parts(ptr.cast(), len)))
577            }
578        }
579    }
580
581    /// Returns the list of ciphers that are supported by Secure Transport.
582    pub fn supported_ciphers(&self) -> Result<Vec<CipherSuite>> {
583        unsafe {
584            let mut num_ciphers = 0;
585            cvt(SSLGetNumberSupportedCiphers(self.0, &mut num_ciphers))?;
586            let mut ciphers = vec![0; num_ciphers];
587            cvt(SSLGetSupportedCiphers(
588                self.0,
589                ciphers.as_mut_ptr(),
590                &mut num_ciphers,
591            ))?;
592            Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
593        }
594    }
595
596    /// Returns the list of ciphers that are eligible to be used for
597    /// negotiation.
598    pub fn enabled_ciphers(&self) -> Result<Vec<CipherSuite>> {
599        unsafe {
600            let mut num_ciphers = 0;
601            cvt(SSLGetNumberEnabledCiphers(self.0, &mut num_ciphers))?;
602            let mut ciphers = vec![0; num_ciphers];
603            cvt(SSLGetEnabledCiphers(
604                self.0,
605                ciphers.as_mut_ptr(),
606                &mut num_ciphers,
607            ))?;
608            Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
609        }
610    }
611
612    /// Sets the list of ciphers that are eligible to be used for negotiation.
613    pub fn set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()> {
614        let ciphers = ciphers.iter().map(|c| c.to_raw()).collect::<Vec<_>>();
615        unsafe {
616            cvt(SSLSetEnabledCiphers(
617                self.0,
618                ciphers.as_ptr(),
619                ciphers.len(),
620            ))
621        }
622    }
623
624    /// Returns the cipher being used by the session.
625    #[inline]
626    pub fn negotiated_cipher(&self) -> Result<CipherSuite> {
627        unsafe {
628            let mut cipher = 0;
629            cvt(SSLGetNegotiatedCipher(self.0, &mut cipher))?;
630            Ok(CipherSuite::from_raw(cipher))
631        }
632    }
633
634    /// Sets the requirements for client certificates.
635    ///
636    /// Should only be called on server-side sessions.
637    #[inline]
638    pub fn set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()> {
639        unsafe { cvt(SSLSetClientSideAuthenticate(self.0, auth.0)) }
640    }
641
642    /// Returns the state of client certificate processing.
643    #[inline]
644    pub fn client_certificate_state(&self) -> Result<SslClientCertificateState> {
645        let mut state = 0;
646
647        unsafe {
648            cvt(SSLGetClientCertificateState(self.0, &mut state))?;
649        }
650        Ok(SslClientCertificateState(state))
651    }
652
653    /// Returns the `SecTrust` object corresponding to the peer.
654    ///
655    /// This can be used in conjunction with `set_break_on_server_auth` to
656    /// validate certificates which do not have roots in the default set.
657    pub fn peer_trust2(&self) -> Result<Option<SecTrust>> {
658        // Calling SSLCopyPeerTrust on an idle connection does not seem to be well defined,
659        // so explicitly check for that
660        if self.state()? == SessionState::IDLE {
661            return Err(Error::from_code(errSecBadReq));
662        }
663
664        unsafe {
665            let mut trust = ptr::null_mut();
666            cvt(SSLCopyPeerTrust(self.0, &mut trust))?;
667            if trust.is_null() {
668                Ok(None)
669            } else {
670                Ok(Some(SecTrust::wrap_under_create_rule(trust)))
671            }
672        }
673    }
674
675    /// Returns the state of the session.
676    #[inline]
677    pub fn state(&self) -> Result<SessionState> {
678        unsafe {
679            let mut state = 0;
680            cvt(SSLGetSessionState(self.0, &mut state))?;
681            Ok(SessionState(state))
682        }
683    }
684
685    /// Returns the protocol version being used by the session.
686    #[inline]
687    pub fn negotiated_protocol_version(&self) -> Result<SslProtocol> {
688        unsafe {
689            let mut version = 0;
690            cvt(SSLGetNegotiatedProtocolVersion(self.0, &mut version))?;
691            Ok(SslProtocol(version))
692        }
693    }
694
695    /// Returns the maximum protocol version allowed by the session.
696    #[inline]
697    pub fn protocol_version_max(&self) -> Result<SslProtocol> {
698        unsafe {
699            let mut version = 0;
700            cvt(SSLGetProtocolVersionMax(self.0, &mut version))?;
701            Ok(SslProtocol(version))
702        }
703    }
704
705    /// Sets the maximum protocol version allowed by the session.
706    #[inline]
707    pub fn set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()> {
708        unsafe { cvt(SSLSetProtocolVersionMax(self.0, max_version.0)) }
709    }
710
711    /// Returns the minimum protocol version allowed by the session.
712    #[inline]
713    pub fn protocol_version_min(&self) -> Result<SslProtocol> {
714        unsafe {
715            let mut version = 0;
716            cvt(SSLGetProtocolVersionMin(self.0, &mut version))?;
717            Ok(SslProtocol(version))
718        }
719    }
720
721    /// Sets the minimum protocol version allowed by the session.
722    #[inline]
723    pub fn set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()> {
724        unsafe { cvt(SSLSetProtocolVersionMin(self.0, min_version.0)) }
725    }
726
727    /// Returns the set of protocols selected via ALPN if it succeeded.
728    #[cfg(feature = "alpn")]
729    pub fn alpn_protocols(&self) -> Result<Vec<String>> {
730        let mut array: CFArrayRef = ptr::null();
731        unsafe {
732            #[cfg(feature = "OSX_10_13")]
733            {
734                cvt(SSLCopyALPNProtocols(self.0, &mut array))?;
735            }
736
737            #[cfg(not(feature = "OSX_10_13"))]
738            {
739                dlsym! { fn SSLCopyALPNProtocols(SSLContextRef, *mut CFArrayRef) -> OSStatus }
740                if let Some(f) = SSLCopyALPNProtocols.get() {
741                    cvt(f(self.0, &mut array))?;
742                } else {
743                    return Err(Error::from_code(errSecUnimplemented));
744                }
745            }
746
747            if array.is_null() {
748                return Ok(vec![]);
749            }
750
751            let array = CFArray::<CFString>::wrap_under_create_rule(array);
752            Ok(array.into_iter().map(|p| p.to_string()).collect())
753        }
754    }
755
756    /// Configures the set of protocols use for ALPN.
757    ///
758    /// This is only used for client-side connections.
759    #[cfg(feature = "alpn")]
760    pub fn set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()> {
761        // When CFMutableArray is added to core-foundation and IntoIterator trait
762        // is implemented for CFMutableArray, the code below should directly collect
763        // into a CFMutableArray.
764        let protocols = CFArray::from_CFTypes(
765            &protocols
766                .iter()
767                .map(|proto| CFString::new(proto))
768                .collect::<Vec<_>>(),
769        );
770
771        #[cfg(feature = "OSX_10_13")]
772        {
773            unsafe { cvt(SSLSetALPNProtocols(self.0, protocols.as_concrete_TypeRef())) }
774        }
775        #[cfg(not(feature = "OSX_10_13"))]
776        {
777            dlsym! { fn SSLSetALPNProtocols(SSLContextRef, CFArrayRef) -> OSStatus }
778            if let Some(f) = SSLSetALPNProtocols.get() {
779                unsafe { cvt(f(self.0, protocols.as_concrete_TypeRef())) }
780            } else {
781                Err(Error::from_code(errSecUnimplemented))
782            }
783        }
784    }
785
786    /// Sets whether the client sends the `SessionTicket` extension in its `ClientHello`.
787    ///
788    /// On its own, this will just cause the client to send an empty `SessionTicket` extension on
789    /// every connection. [`SslContext::set_peer_id`] must also be used to key the session
790    /// ticket returned by the server.
791    ///
792    /// [`SslContext::set_peer_id`]: #method.set_peer_id
793    #[cfg(feature = "session-tickets")]
794    pub fn set_session_tickets_enabled(&mut self, enabled: bool) -> Result<()> {
795        #[cfg(feature = "OSX_10_13")]
796        {
797            unsafe { cvt(SSLSetSessionTicketsEnabled(self.0, enabled as Boolean)) }
798        }
799        #[cfg(not(feature = "OSX_10_13"))]
800        {
801            dlsym! { fn SSLSetSessionTicketsEnabled(SSLContextRef, Boolean) -> OSStatus }
802            if let Some(f) = SSLSetSessionTicketsEnabled.get() {
803                unsafe { cvt(f(self.0, enabled as Boolean)) }
804            } else {
805                Err(Error::from_code(errSecUnimplemented))
806            }
807        }
808    }
809
810    /// Sets whether a protocol is enabled or not.
811    ///
812    /// # Note
813    ///
814    /// On OSX this is a deprecated API in favor of `set_protocol_version_max` and
815    /// `set_protocol_version_min`, although if you're working with OSX 10.8 or before you may have
816    /// to use this API instead.
817    #[cfg(target_os = "macos")]
818    #[deprecated(note = "use `set_protocol_version_max`")]
819    pub fn set_protocol_version_enabled(
820        &mut self,
821        protocol: SslProtocol,
822        enabled: bool,
823    ) -> Result<()> {
824        unsafe {
825            cvt(SSLSetProtocolVersionEnabled(
826                self.0,
827                protocol.0,
828                enabled as Boolean,
829            ))
830        }
831    }
832
833    /// Returns the number of bytes which can be read without triggering a
834    /// `read` call in the underlying stream.
835    #[inline]
836    pub fn buffered_read_size(&self) -> Result<usize> {
837        unsafe {
838            let mut size = 0;
839            cvt(SSLGetBufferedReadSize(self.0, &mut size))?;
840            Ok(size)
841        }
842    }
843
844    impl_options! {
845        /// If enabled, the handshake process will pause and return instead of
846        /// automatically validating a server's certificate.
847        const kSSLSessionOptionBreakOnServerAuth: break_on_server_auth & set_break_on_server_auth,
848        /// If enabled, the handshake process will pause and return after
849        /// the server requests a certificate from the client.
850        const kSSLSessionOptionBreakOnCertRequested: break_on_cert_requested & set_break_on_cert_requested,
851        /// If enabled, the handshake process will pause and return instead of
852        /// automatically validating a client's certificate.
853        const kSSLSessionOptionBreakOnClientAuth: break_on_client_auth & set_break_on_client_auth,
854        /// If enabled, TLS false start will be performed if an appropriate
855        /// cipher suite is negotiated.
856        ///
857        /// Requires the `OSX_10_9` (or greater) feature.
858        #[cfg(feature = "OSX_10_9")]
859        const kSSLSessionOptionFalseStart: false_start & set_false_start,
860        /// If enabled, 1/n-1 record splitting will be enabled for TLS 1.0
861        /// connections using block ciphers to mitigate the BEAST attack.
862        ///
863        /// Requires the `OSX_10_9` (or greater) feature.
864        #[cfg(feature = "OSX_10_9")]
865        const kSSLSessionOptionSendOneByteRecord: send_one_byte_record & set_send_one_byte_record,
866    }
867
868    fn into_stream<S>(self, stream: S) -> Result<SslStream<S>>
869    where
870        S: Read + Write,
871    {
872        unsafe {
873            let ret = SSLSetIOFuncs(self.0, read_func::<S>, write_func::<S>);
874            if ret != errSecSuccess {
875                return Err(Error::from_code(ret));
876            }
877
878            let stream = Connection {
879                stream,
880                err: None,
881                panic: None,
882            };
883            let stream = Box::into_raw(Box::new(stream));
884            let ret = SSLSetConnection(self.0, stream.cast());
885            if ret != errSecSuccess {
886                let _conn = Box::from_raw(stream);
887                return Err(Error::from_code(ret));
888            }
889
890            Ok(SslStream {
891                ctx: self,
892                _m: PhantomData,
893            })
894        }
895    }
896
897    /// Performs the SSL/TLS handshake.
898    pub fn handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>>
899    where
900        S: Read + Write,
901    {
902        self.into_stream(stream)
903            .map_err(HandshakeError::Failure)
904            .and_then(SslStream::handshake)
905    }
906}
907
/// Per-stream state handed to Secure Transport as the connection pointer and
/// accessed by the `read_func`/`write_func` callbacks.
struct Connection<S> {
    // The user's underlying transport that the callbacks read from / write to.
    stream: S,
    // Last I/O error raised by the stream inside a callback; taken by
    // `SslStream::get_error` so callers see the original error.
    err: Option<io::Error>,
    // Panic payload caught inside a callback (panics must not cross the FFI
    // boundary); re-raised by `SslStream::check_panic`.
    panic: Option<Box<dyn Any + Send>>,
}
913
914// the logic here is based off libcurl's
915#[cold]
916fn translate_err(e: &io::Error) -> OSStatus {
917    match e.kind() {
918        io::ErrorKind::NotFound => errSSLClosedGraceful,
919        io::ErrorKind::ConnectionReset => errSSLClosedAbort,
920        io::ErrorKind::WouldBlock | io::ErrorKind::NotConnected => errSSLWouldBlock,
921        _ => errSecIO,
922    }
923}
924
925unsafe extern "C" fn read_func<S>(
926    connection: SSLConnectionRef,
927    data: *mut c_void,
928    data_length: *mut usize,
929) -> OSStatus
930where
931    S: Read,
932{
933    let conn: &mut Connection<S> = &mut *(connection as *mut _);
934    let data = slice::from_raw_parts_mut(data.cast::<u8>(), *data_length);
935    let mut start = 0;
936    let mut ret = errSecSuccess;
937
938    while start < data.len() {
939        match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.read(&mut data[start..]))) {
940            Ok(Ok(0)) => {
941                ret = errSSLClosedNoNotify;
942                break;
943            }
944            Ok(Ok(len)) => start += len,
945            Ok(Err(e)) => {
946                ret = translate_err(&e);
947                conn.err = Some(e);
948                break;
949            }
950            Err(e) => {
951                ret = errSecIO;
952                conn.panic = Some(e);
953                break;
954            }
955        }
956    }
957
958    *data_length = start;
959    ret
960}
961
962unsafe extern "C" fn write_func<S>(
963    connection: SSLConnectionRef,
964    data: *const c_void,
965    data_length: *mut usize,
966) -> OSStatus
967where
968    S: Write,
969{
970    let conn: &mut Connection<S> = &mut *(connection as *mut _);
971    let data = slice::from_raw_parts(data as *mut u8, *data_length);
972    let mut start = 0;
973    let mut ret = errSecSuccess;
974
975    while start < data.len() {
976        match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.write(&data[start..]))) {
977            Ok(Ok(0)) => {
978                ret = errSSLClosedNoNotify;
979                break;
980            }
981            Ok(Ok(len)) => start += len,
982            Ok(Err(e)) => {
983                ret = translate_err(&e);
984                conn.err = Some(e);
985                break;
986            }
987            Err(e) => {
988                ret = errSecIO;
989                conn.panic = Some(e);
990                break;
991            }
992        }
993    }
994
995    *data_length = start;
996    ret
997}
998
/// A type implementing SSL/TLS encryption over an underlying stream.
///
/// The wrapped stream is not stored inline: it lives in a heap-allocated
/// `Connection<S>` registered with the context via `SSLSetConnection`, and is
/// recovered with `SSLGetConnection` when needed (and freed in `Drop`).
pub struct SslStream<S> {
    // The Secure Transport context driving the session.
    ctx: SslContext,
    // Records the stream type `S` so the boxed `Connection<S>` can be
    // recovered with the right type; no `S` is stored directly.
    _m: PhantomData<S>,
}
1004
1005impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
1006    #[cold]
1007    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
1008        fmt.debug_struct("SslStream")
1009            .field("context", &self.ctx)
1010            .field("stream", self.get_ref())
1011            .finish()
1012    }
1013}
1014
1015impl<S> Drop for SslStream<S> {
1016    fn drop(&mut self) {
1017        unsafe {
1018            let mut conn = ptr::null();
1019            let ret = SSLGetConnection(self.ctx.0, &mut conn);
1020            assert!(ret == errSecSuccess);
1021            let _ = Box::<Connection<S>>::from_raw(conn as *mut _);
1022        }
1023    }
1024}
1025
1026impl<S> SslStream<S> {
1027    fn handshake(mut self) -> result::Result<Self, HandshakeError<S>> {
1028        match unsafe { SSLHandshake(self.ctx.0) } {
1029            errSecSuccess => Ok(self),
1030            reason @ errSSLPeerAuthCompleted
1031            | reason @ errSSLClientCertRequested
1032            | reason @ errSSLWouldBlock
1033            | reason @ errSSLClientHelloReceived => {
1034                Err(HandshakeError::Interrupted(MidHandshakeSslStream {
1035                    stream: self,
1036                    error: Error::from_code(reason),
1037                }))
1038            }
1039            err => {
1040                self.check_panic();
1041                Err(HandshakeError::Failure(Error::from_code(err)))
1042            }
1043        }
1044    }
1045
1046    /// Returns a shared reference to the inner stream.
1047    #[inline(always)]
1048    #[must_use]
1049    pub fn get_ref(&self) -> &S {
1050        &self.connection().stream
1051    }
1052
1053    /// Returns a mutable reference to the underlying stream.
1054    #[inline(always)]
1055    pub fn get_mut(&mut self) -> &mut S {
1056        &mut self.connection_mut().stream
1057    }
1058
1059    /// Returns a shared reference to the `SslContext` of the stream.
1060    #[inline(always)]
1061    #[must_use]
1062    pub fn context(&self) -> &SslContext {
1063        &self.ctx
1064    }
1065
1066    /// Returns a mutable reference to the `SslContext` of the stream.
1067    #[inline(always)]
1068    pub fn context_mut(&mut self) -> &mut SslContext {
1069        &mut self.ctx
1070    }
1071
1072    /// Shuts down the connection.
1073    pub fn close(&mut self) -> result::Result<(), io::Error> {
1074        unsafe {
1075            let ret = SSLClose(self.ctx.0);
1076            if ret == errSecSuccess {
1077                Ok(())
1078            } else {
1079                Err(self.get_error(ret))
1080            }
1081        }
1082    }
1083
1084    fn connection(&self) -> &Connection<S> {
1085        unsafe {
1086            let mut conn = ptr::null();
1087            let ret = SSLGetConnection(self.ctx.0, &mut conn);
1088            assert!(ret == errSecSuccess);
1089
1090            &mut *(conn as *mut Connection<S>)
1091        }
1092    }
1093
1094    fn connection_mut(&mut self) -> &mut Connection<S> {
1095        unsafe {
1096            let mut conn = ptr::null();
1097            let ret = SSLGetConnection(self.ctx.0, &mut conn);
1098            assert!(ret == errSecSuccess);
1099
1100            &mut *(conn as *mut Connection<S>)
1101        }
1102    }
1103
1104    #[cold]
1105    fn check_panic(&mut self) {
1106        let conn = self.connection_mut();
1107        if let Some(err) = conn.panic.take() {
1108            panic::resume_unwind(err);
1109        }
1110    }
1111
1112    #[cold]
1113    fn get_error(&mut self, ret: OSStatus) -> io::Error {
1114        self.check_panic();
1115
1116        if let Some(err) = self.connection_mut().err.take() {
1117            err
1118        } else {
1119            io::Error::new(io::ErrorKind::Other, Error::from_code(ret))
1120        }
1121    }
1122}
1123
1124impl<S: Read + Write> Read for SslStream<S> {
1125    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1126        // Below we base our return value off the amount of data read, so a
1127        // zero-length buffer might cause us to erroneously interpret this
1128        // request as an error. Instead short-circuit that logic and return
1129        // `Ok(0)` instead.
1130        if buf.is_empty() {
1131            return Ok(0);
1132        }
1133
1134        // If some data was buffered but not enough to fill `buf`, SSLRead
1135        // will try to read a new packet. This is bad because there may be
1136        // no more data but the socket is remaining open (e.g HTTPS with
1137        // Connection: keep-alive).
1138        let buffered = self.context().buffered_read_size().unwrap_or(0);
1139        let to_read = if buffered > 0 {
1140            cmp::min(buffered, buf.len())
1141        } else {
1142            buf.len()
1143        };
1144
1145        unsafe {
1146            let mut n_read = 0;
1147            let ret = SSLRead(self.ctx.0, buf.as_mut_ptr().cast(), to_read, &mut n_read);
1148            // SSLRead can return an error at the same time it returns the last
1149            // chunk of data (!)
1150            if n_read > 0 {
1151                return Ok(n_read);
1152            }
1153
1154            match ret {
1155                errSSLClosedGraceful | errSSLClosedAbort | errSSLClosedNoNotify => Ok(0),
1156                // this error isn't fatal
1157                errSSLPeerAuthCompleted => self.read(buf),
1158                _ => Err(self.get_error(ret)),
1159            }
1160        }
1161    }
1162}
1163
1164impl<S: Read + Write> Write for SslStream<S> {
1165    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1166        // Like above in read, short circuit a 0-length write
1167        if buf.is_empty() {
1168            return Ok(0);
1169        }
1170        unsafe {
1171            let mut n_written = 0;
1172            let ret = SSLWrite(self.ctx.0, buf.as_ptr().cast(), buf.len(), &mut n_written);
1173            // just to be safe, base success off of n_written rather than ret
1174            // for the same reason as in read
1175            if n_written > 0 {
1176                Ok(n_written)
1177            } else {
1178                Err(self.get_error(ret))
1179            }
1180        }
1181    }
1182
1183    fn flush(&mut self) -> io::Result<()> {
1184        self.connection_mut().stream.flush()
1185    }
1186}
1187
/// A builder type to simplify the creation of client side `SslStream`s.
#[derive(Debug)]
pub struct ClientBuilder {
    // Client certificate identity, installed with `SslContext::set_certificate`.
    identity: Option<SecIdentity>,
    // Anchor (root) certificates used when validating the server's certificate.
    certs: Vec<SecCertificate>,
    // Extra certificates sent alongside `identity`.
    chain: Vec<SecCertificate>,
    // Optional lower bound on the negotiated protocol version.
    protocol_min: Option<SslProtocol>,
    // Optional upper bound on the negotiated protocol version.
    protocol_max: Option<SslProtocol>,
    // When true, only `certs` are trusted, not the built-in certificates.
    trust_certs_only: bool,
    // Whether to set the domain as the peer domain name (SNI).
    use_sni: bool,
    // Trust invalid certificates (see `danger_accept_invalid_certs`).
    danger_accept_invalid_certs: bool,
    // Skip hostname verification (see `danger_accept_invalid_hostnames`).
    danger_accept_invalid_hostnames: bool,
    // If non-empty, replaces the context's default cipher list.
    whitelisted_ciphers: Vec<CipherSuite>,
    // Ciphers removed from the enabled list after the whitelist is applied.
    blacklisted_ciphers: Vec<CipherSuite>,
    // ALPN protocol names to advertise, if configured.
    #[cfg(feature = "alpn")]
    alpn: Option<Vec<String>>,
    // Whether to enable RFC 5077 session tickets (off by default).
    #[cfg(feature = "session-tickets")]
    enable_session_tickets: bool,
}
1207
1208impl Default for ClientBuilder {
1209    #[inline(always)]
1210    fn default() -> Self {
1211        Self::new()
1212    }
1213}
1214
1215impl ClientBuilder {
1216    /// Creates a new builder with default options.
1217    #[inline]
1218    #[must_use]
1219    pub fn new() -> Self {
1220        Self {
1221            identity: None,
1222            certs: Vec::new(),
1223            chain: Vec::new(),
1224            protocol_min: None,
1225            protocol_max: None,
1226            trust_certs_only: false,
1227            use_sni: true,
1228            danger_accept_invalid_certs: false,
1229            danger_accept_invalid_hostnames: false,
1230            whitelisted_ciphers: Vec::new(),
1231            blacklisted_ciphers: Vec::new(),
1232            #[cfg(feature = "alpn")]
1233            alpn: None,
1234            #[cfg(feature = "session-tickets")]
1235            enable_session_tickets: false,
1236        }
1237    }
1238
1239    /// Specifies the set of root certificates to trust when
1240    /// verifying the server's certificate.
1241    #[inline]
1242    pub fn anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self {
1243        self.certs = certs.to_owned();
1244        self
1245    }
1246
1247    /// Add the certificate the set of root certificates to trust
1248    /// when verifying the server's certificate.
1249    #[inline]
1250    pub fn add_anchor_certificate(&mut self, certs: &SecCertificate) -> &mut Self {
1251        self.certs.push(certs.to_owned());
1252        self
1253    }
1254
1255    /// Specifies whether to trust the built-in certificates in addition
1256    /// to specified anchor certificates.
1257    #[inline(always)]
1258    pub fn trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self {
1259        self.trust_certs_only = only;
1260        self
1261    }
1262
1263    /// Specifies whether to trust invalid certificates.
1264    ///
1265    /// # Warning
1266    ///
1267    /// You should think very carefully before using this method. If invalid
1268    /// certificates are trusted, *any* certificate for *any* site will be
1269    /// trusted for use. This includes expired certificates. This introduces
1270    /// significant vulnerabilities, and should only be used as a last resort.
1271    #[inline(always)]
1272    pub fn danger_accept_invalid_certs(&mut self, no_verify: bool) -> &mut Self {
1273        self.danger_accept_invalid_certs = no_verify;
1274        self
1275    }
1276
1277    /// Specifies whether to use Server Name Indication (SNI).
1278    #[inline(always)]
1279    pub fn use_sni(&mut self, use_sni: bool) -> &mut Self {
1280        self.use_sni = use_sni;
1281        self
1282    }
1283
1284    /// Specifies whether to verify that the server's hostname matches its certificate.
1285    ///
1286    /// # Warning
1287    ///
1288    /// You should think very carefully before using this method. If hostnames are not verified,
1289    /// *any* valid certificate for *any* site will be trusted for use. This introduces significant
1290    /// vulnerabilities, and should only be used as a last resort.
1291    #[inline(always)]
1292    pub fn danger_accept_invalid_hostnames(
1293        &mut self,
1294        danger_accept_invalid_hostnames: bool,
1295    ) -> &mut Self {
1296        self.danger_accept_invalid_hostnames = danger_accept_invalid_hostnames;
1297        self
1298    }
1299
1300    /// Set a whitelist of enabled ciphers. Any ciphers not whitelisted will be disabled.
1301    pub fn whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self {
1302        self.whitelisted_ciphers = whitelisted_ciphers.to_owned();
1303        self
1304    }
1305
1306    /// Set a blacklist of disabled ciphers. Blacklisted ciphers will be disabled.
1307    pub fn blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self {
1308        self.blacklisted_ciphers = blacklisted_ciphers.to_owned();
1309        self
1310    }
1311
1312    /// Use the specified identity as a SSL/TLS client certificate.
1313    pub fn identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self {
1314        self.identity = Some(identity.clone());
1315        self.chain = chain.to_owned();
1316        self
1317    }
1318
1319    /// Configure the minimum protocol that this client will support.
1320    #[inline(always)]
1321    pub fn protocol_min(&mut self, min: SslProtocol) -> &mut Self {
1322        self.protocol_min = Some(min);
1323        self
1324    }
1325
1326    /// Configure the minimum protocol that this client will support.
1327    #[inline(always)]
1328    pub fn protocol_max(&mut self, max: SslProtocol) -> &mut Self {
1329        self.protocol_max = Some(max);
1330        self
1331    }
1332
1333    /// Configures the set of protocols used for ALPN.
1334    #[cfg(feature = "alpn")]
1335    pub fn alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self {
1336        self.alpn = Some(protocols.iter().map(|s| s.to_string()).collect());
1337        self
1338    }
1339
1340    /// Configures the use of the RFC 5077 `SessionTicket` extension.
1341    ///
1342    /// Defaults to `false`.
1343    #[cfg(feature = "session-tickets")]
1344    #[inline(always)]
1345    pub fn enable_session_tickets(&mut self, enable: bool) -> &mut Self {
1346        self.enable_session_tickets = enable;
1347        self
1348    }
1349
1350    /// Initiates a new SSL/TLS session over a stream connected to the specified domain.
1351    ///
1352    /// If both SNI and hostname verification are disabled, the value of `domain` will be ignored.
1353    pub fn handshake<S>(
1354        &self,
1355        domain: &str,
1356        stream: S,
1357    ) -> result::Result<SslStream<S>, ClientHandshakeError<S>>
1358    where
1359        S: Read + Write,
1360    {
1361        // the logic for trust validation is in MidHandshakeClientBuilder::connect, so run all
1362        // of the handshake logic through that.
1363        let stream = MidHandshakeSslStream {
1364            stream: self.ctx_into_stream(domain, stream)?,
1365            error: Error::from(errSecSuccess),
1366        };
1367
1368        let certs = self.certs.clone();
1369        let stream = MidHandshakeClientBuilder {
1370            stream,
1371            domain: if self.danger_accept_invalid_hostnames {
1372                None
1373            } else {
1374                Some(domain.to_string())
1375            },
1376            certs,
1377            trust_certs_only: self.trust_certs_only,
1378            danger_accept_invalid_certs: self.danger_accept_invalid_certs,
1379        };
1380        stream.handshake()
1381    }
1382
1383    fn ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>>
1384    where
1385        S: Read + Write,
1386    {
1387        let mut ctx = SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM)?;
1388
1389        if self.use_sni {
1390            ctx.set_peer_domain_name(domain)?;
1391        }
1392        if let Some(ref identity) = self.identity {
1393            ctx.set_certificate(identity, &self.chain)?;
1394        }
1395        #[cfg(feature = "alpn")]
1396        {
1397            if let Some(ref alpn) = self.alpn {
1398                ctx.set_alpn_protocols(&alpn.iter().map(|s| &**s).collect::<Vec<_>>())?;
1399            }
1400        }
1401        #[cfg(feature = "session-tickets")]
1402        {
1403            if self.enable_session_tickets {
1404                // We must use the domain here to ensure that we go through certificate validation
1405                // again rather than resuming the session if the domain changes.
1406                ctx.set_peer_id(domain.as_bytes())?;
1407                ctx.set_session_tickets_enabled(true)?;
1408            }
1409        }
1410        ctx.set_break_on_server_auth(true)?;
1411        self.configure_protocols(&mut ctx)?;
1412        self.configure_ciphers(&mut ctx)?;
1413
1414        ctx.into_stream(stream)
1415    }
1416
1417    fn configure_protocols(&self, ctx: &mut SslContext) -> Result<()> {
1418        if let Some(min) = self.protocol_min {
1419            ctx.set_protocol_version_min(min)?;
1420        }
1421        if let Some(max) = self.protocol_max {
1422            ctx.set_protocol_version_max(max)?;
1423        }
1424        Ok(())
1425    }
1426
1427    fn configure_ciphers(&self, ctx: &mut SslContext) -> Result<()> {
1428        let mut ciphers = if self.whitelisted_ciphers.is_empty() {
1429            ctx.enabled_ciphers()?
1430        } else {
1431            self.whitelisted_ciphers.clone()
1432        };
1433
1434        if !self.blacklisted_ciphers.is_empty() {
1435            ciphers.retain(|cipher| !self.blacklisted_ciphers.contains(cipher));
1436        }
1437
1438        ctx.set_enabled_ciphers(&ciphers)?;
1439        Ok(())
1440    }
1441}
1442
/// A builder type to simplify the creation of server-side `SslStream`s.
#[derive(Debug)]
pub struct ServerBuilder {
    // Identity presented to clients, installed with `SslContext::set_certificate`.
    identity: SecIdentity,
    // Additional chain certificates sent along with `identity`.
    certs: Vec<SecCertificate>,
}
1449
1450impl ServerBuilder {
1451    /// Creates a new `ServerBuilder` which will use the specified identity
1452    /// and certificate chain for handshakes.
1453    #[must_use]
1454    pub fn new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self {
1455        Self {
1456            identity: identity.clone(),
1457            certs: certs.to_owned(),
1458        }
1459    }
1460
1461    /// Creates a new `ServerBuilder` which will use the identity
1462    /// from the given PKCS #12 data.
1463    ///
1464    /// This operation fails if PKCS #12 file contains zero or more than one identity.
1465    ///
1466    /// This is a shortcut for the most common operation.
1467    pub fn from_pkcs12(pkcs12_der: &[u8], passphrase: &str) -> Result<Self> {
1468        let mut identities: Vec<(SecIdentity, Vec<SecCertificate>)> = Pkcs12ImportOptions::new()
1469            .passphrase(passphrase)
1470            .import(pkcs12_der)?
1471            .into_iter()
1472            .filter_map(|identity| {
1473                let certs = identity.cert_chain.unwrap_or_default();
1474                identity.identity.map(|identity| (identity, certs))
1475            })
1476            .collect();
1477        if identities.len() == 1 {
1478            let (identity, certs) = identities.pop().unwrap();
1479            Ok(ServerBuilder::new(&identity, &certs))
1480        } else {
1481            // This error code is not really helpful
1482            Err(Error::from_code(errSecParam))
1483        }
1484    }
1485
1486    /// Create a SSL context for lower-level stream initialization.
1487    pub fn new_ssl_context(&self) -> Result<SslContext> {
1488        let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
1489        ctx.set_certificate(&self.identity, &self.certs)?;
1490        Ok(ctx)
1491    }
1492
1493    /// Initiates a new SSL/TLS session over a stream.
1494    pub fn handshake<S>(&self, stream: S) -> Result<SslStream<S>>
1495    where
1496        S: Read + Write,
1497    {
1498        match self.new_ssl_context()?.handshake(stream) {
1499            Ok(stream) => Ok(stream),
1500            Err(HandshakeError::Interrupted(stream)) => Err(*stream.error()),
1501            Err(HandshakeError::Failure(err)) => Err(err),
1502        }
1503    }
1504}
1505
1506#[cfg(test)]
1507mod test {
1508    use std::{io, io::prelude::*, net::TcpStream};
1509
1510    use super::*;
1511
1512    #[test]
1513    fn server_builder_from_pkcs12() {
1514        let pkcs12_der = include_bytes!("../test/server.p12");
1515        ServerBuilder::from_pkcs12(pkcs12_der, "password123").unwrap();
1516    }
1517
1518    #[test]
1519    fn connect() {
1520        let mut ctx = p!(SslContext::new(
1521            SslProtocolSide::CLIENT,
1522            SslConnectionType::STREAM
1523        ));
1524        p!(ctx.set_peer_domain_name("google.com"));
1525        let stream = p!(TcpStream::connect("google.com:443"));
1526        p!(ctx.handshake(stream));
1527    }
1528
1529    #[test]
1530    fn connect_bad_domain() {
1531        let mut ctx = p!(SslContext::new(
1532            SslProtocolSide::CLIENT,
1533            SslConnectionType::STREAM
1534        ));
1535        p!(ctx.set_peer_domain_name("foobar.com"));
1536        let stream = p!(TcpStream::connect("google.com:443"));
1537        if ctx.handshake(stream).is_ok() {
1538            panic!("expected failure")
1539        }
1540    }
1541
1542    #[test]
1543    fn load_page() {
1544        let mut ctx = p!(SslContext::new(
1545            SslProtocolSide::CLIENT,
1546            SslConnectionType::STREAM
1547        ));
1548        p!(ctx.set_peer_domain_name("google.com"));
1549        let stream = p!(TcpStream::connect("google.com:443"));
1550        let mut stream = p!(ctx.handshake(stream));
1551        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
1552        p!(stream.flush());
1553        let mut buf = vec![];
1554        p!(stream.read_to_end(&mut buf));
1555        println!("{}", String::from_utf8_lossy(&buf));
1556    }
1557
1558    #[test]
1559    fn client_no_session_ticket_resumption() {
1560        for _ in 0..2 {
1561            let stream = p!(TcpStream::connect("google.com:443"));
1562
1563            // Manually handshake here.
1564            let stream = MidHandshakeSslStream {
1565                stream: ClientBuilder::new()
1566                    .ctx_into_stream("google.com", stream)
1567                    .unwrap(),
1568                error: Error::from(errSecSuccess),
1569            };
1570
1571            let mut result = stream.handshake();
1572
1573            if let Err(HandshakeError::Interrupted(stream)) = result {
1574                assert!(stream.server_auth_completed());
1575                result = stream.handshake();
1576            } else {
1577                panic!("Unexpectedly skipped server auth");
1578            }
1579
1580            assert!(result.is_ok());
1581        }
1582    }
1583
1584    #[test]
1585    #[cfg(feature = "session-tickets")]
1586    fn client_session_ticket_resumption() {
1587        // The first time through this loop, we should do a full handshake. The second time, we
1588        // should immediately finish the handshake without breaking on server auth.
1589        for i in 0..2 {
1590            let stream = p!(TcpStream::connect("google.com:443"));
1591            let mut builder = ClientBuilder::new();
1592            builder.enable_session_tickets(true);
1593
1594            // Manually handshake here.
1595            let stream = MidHandshakeSslStream {
1596                stream: builder.ctx_into_stream("google.com", stream).unwrap(),
1597                error: Error::from(errSecSuccess),
1598            };
1599
1600            let mut result = stream.handshake();
1601
1602            if let Err(HandshakeError::Interrupted(stream)) = result {
1603                assert!(stream.server_auth_completed());
1604                assert_eq!(
1605                    i, 0,
1606                    "Session ticket resumption did not work, server auth was not skipped"
1607                );
1608                result = stream.handshake();
1609            } else {
1610                assert_eq!(i, 1, "Unexpectedly skipped server auth");
1611            }
1612
1613            assert!(result.is_ok());
1614        }
1615    }
1616
1617    #[test]
1618    #[cfg(feature = "alpn")]
1619    fn client_alpn_accept() {
1620        let mut ctx = p!(SslContext::new(
1621            SslProtocolSide::CLIENT,
1622            SslConnectionType::STREAM
1623        ));
1624        p!(ctx.set_peer_domain_name("google.com"));
1625        p!(ctx.set_alpn_protocols(&["h2"]));
1626        let stream = p!(TcpStream::connect("google.com:443"));
1627        let stream = ctx.handshake(stream).unwrap();
1628        assert_eq!(vec!["h2"], stream.context().alpn_protocols().unwrap());
1629    }
1630
1631    #[test]
1632    #[cfg(feature = "alpn")]
1633    fn client_alpn_reject() {
1634        let mut ctx = p!(SslContext::new(
1635            SslProtocolSide::CLIENT,
1636            SslConnectionType::STREAM
1637        ));
1638        p!(ctx.set_peer_domain_name("google.com"));
1639        p!(ctx.set_alpn_protocols(&["h2c"]));
1640        let stream = p!(TcpStream::connect("google.com:443"));
1641        let stream = ctx.handshake(stream).unwrap();
1642        assert!(stream.context().alpn_protocols().is_err());
1643    }
1644
1645    #[test]
1646    fn client_no_anchor_certs() {
1647        let stream = p!(TcpStream::connect("google.com:443"));
1648        assert!(ClientBuilder::new()
1649            .trust_anchor_certificates_only(true)
1650            .handshake("google.com", stream)
1651            .is_err());
1652    }
1653
1654    #[test]
1655    fn client_bad_domain() {
1656        let stream = p!(TcpStream::connect("google.com:443"));
1657        assert!(ClientBuilder::new()
1658            .handshake("foobar.com", stream)
1659            .is_err());
1660    }
1661
1662    #[test]
1663    fn client_bad_domain_ignored() {
1664        let stream = p!(TcpStream::connect("google.com:443"));
1665        ClientBuilder::new()
1666            .danger_accept_invalid_hostnames(true)
1667            .handshake("foobar.com", stream)
1668            .unwrap();
1669    }
1670
1671    #[test]
1672    fn connect_no_verify_ssl() {
1673        let stream = p!(TcpStream::connect("expired.badssl.com:443"));
1674        let mut builder = ClientBuilder::new();
1675        builder.danger_accept_invalid_certs(true);
1676        builder.handshake("expired.badssl.com", stream).unwrap();
1677    }
1678
1679    #[test]
1680    fn load_page_client() {
1681        let stream = p!(TcpStream::connect("google.com:443"));
1682        let mut stream = p!(ClientBuilder::new().handshake("google.com", stream));
1683        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
1684        p!(stream.flush());
1685        let mut buf = vec![];
1686        p!(stream.read_to_end(&mut buf));
1687        println!("{}", String::from_utf8_lossy(&buf));
1688    }
1689
1690    #[test]
1691    #[cfg_attr(target_os = "ios", ignore)] // FIXME what's going on with ios?
1692    fn cipher_configuration() {
1693        let mut ctx = p!(SslContext::new(
1694            SslProtocolSide::SERVER,
1695            SslConnectionType::STREAM
1696        ));
1697        let ciphers = p!(ctx.enabled_ciphers());
1698        let ciphers = ciphers
1699            .iter()
1700            .enumerate()
1701            .filter_map(|(i, c)| if i % 2 == 0 { Some(*c) } else { None })
1702            .collect::<Vec<_>>();
1703        p!(ctx.set_enabled_ciphers(&ciphers));
1704        assert_eq!(ciphers, p!(ctx.enabled_ciphers()));
1705    }
1706
1707    #[test]
1708    fn test_builder_whitelist_ciphers() {
1709        let stream = p!(TcpStream::connect("google.com:443"));
1710
1711        let ctx = p!(SslContext::new(
1712            SslProtocolSide::CLIENT,
1713            SslConnectionType::STREAM
1714        ));
1715        assert!(p!(ctx.enabled_ciphers()).len() > 1);
1716
1717        let ciphers = p!(ctx.enabled_ciphers());
1718        let cipher = ciphers.first().unwrap();
1719        let stream = p!(ClientBuilder::new()
1720            .whitelist_ciphers(&[*cipher])
1721            .ctx_into_stream("google.com", stream));
1722
1723        assert_eq!(1, p!(stream.context().enabled_ciphers()).len());
1724    }
1725
1726    #[test]
1727    #[cfg_attr(target_os = "ios", ignore)] // FIXME same issue as cipher_configuration
1728    fn test_builder_blacklist_ciphers() {
1729        let stream = p!(TcpStream::connect("google.com:443"));
1730
1731        let ctx = p!(SslContext::new(
1732            SslProtocolSide::CLIENT,
1733            SslConnectionType::STREAM
1734        ));
1735        let num = p!(ctx.enabled_ciphers()).len();
1736        assert!(num > 1);
1737
1738        let ciphers = p!(ctx.enabled_ciphers());
1739        let cipher = ciphers.first().unwrap();
1740        let stream = p!(ClientBuilder::new()
1741            .blacklist_ciphers(&[*cipher])
1742            .ctx_into_stream("google.com", stream));
1743
1744        assert_eq!(num - 1, p!(stream.context().enabled_ciphers()).len());
1745    }
1746
1747    #[test]
1748    fn idle_context_peer_trust() {
1749        let ctx = p!(SslContext::new(
1750            SslProtocolSide::SERVER,
1751            SslConnectionType::STREAM
1752        ));
1753        assert!(ctx.peer_trust2().is_err());
1754    }
1755
1756    #[test]
1757    fn peer_id() {
1758        let mut ctx = p!(SslContext::new(
1759            SslProtocolSide::SERVER,
1760            SslConnectionType::STREAM
1761        ));
1762        assert!(p!(ctx.peer_id()).is_none());
1763        p!(ctx.set_peer_id(b"foobar"));
1764        assert_eq!(p!(ctx.peer_id()), Some(&b"foobar"[..]));
1765    }
1766
1767    #[test]
1768    fn peer_domain_name() {
1769        let mut ctx = p!(SslContext::new(
1770            SslProtocolSide::CLIENT,
1771            SslConnectionType::STREAM
1772        ));
1773        assert_eq!("", p!(ctx.peer_domain_name()));
1774        p!(ctx.set_peer_domain_name("foobar.com"));
1775        assert_eq!("foobar.com", p!(ctx.peer_domain_name()));
1776    }
1777
1778    #[test]
1779    #[should_panic(expected = "blammo")]
1780    fn write_panic() {
1781        struct ExplodingStream(TcpStream);
1782
1783        impl Read for ExplodingStream {
1784            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1785                self.0.read(buf)
1786            }
1787        }
1788
1789        impl Write for ExplodingStream {
1790            fn write(&mut self, _: &[u8]) -> io::Result<usize> {
1791                panic!("blammo");
1792            }
1793
1794            fn flush(&mut self) -> io::Result<()> {
1795                self.0.flush()
1796            }
1797        }
1798
1799        let mut ctx = p!(SslContext::new(
1800            SslProtocolSide::CLIENT,
1801            SslConnectionType::STREAM
1802        ));
1803        p!(ctx.set_peer_domain_name("google.com"));
1804        let stream = p!(TcpStream::connect("google.com:443"));
1805        let _ = ctx.handshake(ExplodingStream(stream));
1806    }
1807
1808    #[test]
1809    #[should_panic(expected = "blammo")]
1810    fn read_panic() {
1811        struct ExplodingStream(TcpStream);
1812
1813        impl Read for ExplodingStream {
1814            fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
1815                panic!("blammo");
1816            }
1817        }
1818
1819        impl Write for ExplodingStream {
1820            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1821                self.0.write(buf)
1822            }
1823
1824            fn flush(&mut self) -> io::Result<()> {
1825                self.0.flush()
1826            }
1827        }
1828
1829        let mut ctx = p!(SslContext::new(
1830            SslProtocolSide::CLIENT,
1831            SslConnectionType::STREAM
1832        ));
1833        p!(ctx.set_peer_domain_name("google.com"));
1834        let stream = p!(TcpStream::connect("google.com:443"));
1835        let _ = ctx.handshake(ExplodingStream(stream));
1836    }
1837
1838    #[test]
1839    fn zero_length_buffers() {
1840        let mut ctx = p!(SslContext::new(
1841            SslProtocolSide::CLIENT,
1842            SslConnectionType::STREAM
1843        ));
1844        p!(ctx.set_peer_domain_name("google.com"));
1845        let stream = p!(TcpStream::connect("google.com:443"));
1846        let mut stream = ctx.handshake(stream).unwrap();
1847        assert_eq!(stream.write(b"").unwrap(), 0);
1848        assert_eq!(stream.read(&mut []).unwrap(), 0);
1849    }
1850}