blaze_schannel/
lib.rs

1//! Bindings to the Windows SChannel APIs.
2#![cfg(windows)]
3#![warn(missing_docs)]
4#![allow(non_upper_case_globals)]
5
6#[macro_use]
7extern crate lazy_static;
8
9use windows_sys::Win32::Foundation;
10use windows_sys::Win32::Security::Credentials;
11use windows_sys::Win32::Security::Authentication::Identity;
12use windows_sys::Win32::Security::Cryptography;
13
14use std::{fmt, io, ptr, mem};
15use std::ffi::OsStr;
16use std::os::windows::ffi::OsStrExt;
17use std::ops::Deref;
18use std::slice;
19use std::ffi::c_void;
20use std::any::Any;
21use std::cmp;
22use std::error::Error;
23use std::io::{BufRead, Cursor, Read, Write};
24use std::sync::Arc;
25
/// Implements the crate-private `Inner` trait and the public `RawPointer` trait
/// for a single-field tuple wrapper around a raw handle or pointer.
///
/// `$wrapper` is the wrapper type and `$handle` is the type of its only field.
macro_rules! inner {
    ($wrapper:path, $handle:ty) => {
        impl crate::Inner<$handle> for $wrapper {
            unsafe fn from_inner(handle: $handle) -> Self {
                Self(handle)
            }

            fn as_inner(&self) -> $handle {
                self.0
            }

            fn get_mut(&mut self) -> &mut $handle {
                &mut self.0
            }
        }

        impl crate::RawPointer for $wrapper {
            unsafe fn from_ptr(raw: *mut ::std::os::raw::c_void) -> Self {
                Self(raw as _)
            }

            unsafe fn as_ptr(&self) -> *mut ::std::os::raw::c_void {
                self.0 as *mut _
            }
        }
    };
}
53
54const ACCEPT_REQUESTS: u32 = Identity::ASC_REQ_ALLOCATE_MEMORY
55    | Identity::ASC_REQ_CONFIDENTIALITY
56    | Identity::ASC_REQ_SEQUENCE_DETECT
57    | Identity::ASC_REQ_STREAM
58    | Identity::ASC_REQ_REPLAY_DETECT;
59
60
61/// Wrapper of a winapi certificate, or a `PCCERT_CONTEXT`.
62#[derive(Debug)]
63pub struct CertContext(*const Cryptography::CERT_CONTEXT);
64
65unsafe impl Sync for CertContext {}
66
67unsafe impl Send for CertContext {}
68
69impl Drop for CertContext {
70    fn drop(&mut self) {
71        unsafe {
72            Cryptography::CertFreeCertificateContext(self.0);
73        }
74    }
75}
76
77impl Clone for CertContext {
78    fn clone(&self) -> CertContext {
79        unsafe { CertContext(Cryptography::CertDuplicateCertificateContext(self.0)) }
80    }
81}
82
83inner!(CertContext, *const Cryptography::CERT_CONTEXT);
84
85/// Representation of certificate store on Windows, wrapping a `HCERTSTORE`.
86pub struct CertStore(Cryptography::HCERTSTORE);
87
88
89inner!(CertStore, Cryptography::HCERTSTORE);
90
91unsafe impl Sync for CertStore {}
92
93unsafe impl Send for CertStore {}
94
95impl fmt::Debug for CertStore {
96    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
97        fmt.debug_struct("CertStore").finish()
98    }
99}
100
101impl Drop for CertStore {
102    fn drop(&mut self) {
103        unsafe {
104            Cryptography::CertCloseStore(self.0, 0);
105        }
106    }
107}
108
109impl Clone for CertStore {
110    fn clone(&self) -> CertStore {
111        unsafe { CertStore(Cryptography::CertDuplicateStore(self.0)) }
112    }
113}
114
115impl CertStore {
116    pub fn import_pkcs12(data: &[u8], password: Option<&str>) -> io::Result<CertStore> {
117        unsafe {
118            let blob = Cryptography::CRYPTOAPI_BLOB {
119                cbData: data.len() as u32,
120                pbData: data.as_ptr() as *mut u8,
121            };
122            let password = password.map(|s| {
123                OsStr::new(s)
124                    .encode_wide()
125                    .chain(Some(0))
126                    .collect::<Vec<_>>()
127            });
128            let password = password.as_ref().map(|s| s.as_ptr());
129            let password = password.unwrap_or(ptr::null());
130            let res = Cryptography::PFXImportCertStore(
131                &blob,
132                password,
133                Cryptography::CRYPT_KEY_FLAGS::default(),
134            );
135            if !res.is_null() {
136                Ok(CertStore(res))
137            } else {
138                Err(io::Error::last_os_error())
139            }
140        }
141    }
142
143    pub fn first(&self) -> Option<CertContext> {
144        unsafe {
145            let value = Cryptography::CertEnumCertificatesInStore(self.0, ptr::null_mut());
146            if value.is_null() {
147                None
148            } else {
149                let next = CertContext::from_inner(value);
150                Some(next)
151            }
152        }
153    }
154}
155
156
157pub struct ContextBuffer(pub Identity::SecBuffer);
158
159impl Drop for ContextBuffer {
160    fn drop(&mut self) {
161        unsafe {
162            Identity::FreeContextBuffer(self.0.pvBuffer);
163        }
164    }
165}
166
167impl Deref for ContextBuffer {
168    type Target = [u8];
169
170    fn deref(&self) -> &[u8] {
171        unsafe { slice::from_raw_parts(self.0.pvBuffer as *const _, self.0.cbBuffer as usize) }
172    }
173}
174
175
176pub struct SecurityContext(Credentials::SecHandle);
177
178impl Drop for SecurityContext {
179    fn drop(&mut self) {
180        unsafe {
181            Identity::DeleteSecurityContext(&self.0);
182        }
183    }
184}
185
186impl Inner<Credentials::SecHandle> for SecurityContext {
187    unsafe fn from_inner(inner: Credentials::SecHandle) -> SecurityContext {
188        SecurityContext(inner)
189    }
190
191    fn as_inner(&self) -> Credentials::SecHandle {
192        self.0
193    }
194
195    fn get_mut(&mut self) -> &mut Credentials::SecHandle {
196        &mut self.0
197    }
198}
199
200impl SecurityContext {
201
202    pub fn new() -> SecurityContext{
203        unsafe { return SecurityContext(mem::zeroed()); }
204    }
205
206    unsafe fn attribute<T>(&self, attr: Identity::SECPKG_ATTR) -> io::Result<T> {
207        let mut value = mem::zeroed();
208        let status =
209            Identity::QueryContextAttributesW(&self.0, attr, &mut value as *mut _ as *mut _);
210        match status {
211            Foundation::SEC_E_OK => Ok(value),
212            err => Err(io::Error::from_raw_os_error(err)),
213        }
214    }
215
216    pub fn stream_sizes(&self) -> io::Result<Identity::SecPkgContext_StreamSizes> {
217        unsafe { self.attribute(Identity::SECPKG_ATTR_STREAM_SIZES) }
218    }
219}
220
221
222lazy_static! {
223    static ref szOID_PKIX_KP_SERVER_AUTH: Vec<u8> = Cryptography::szOID_PKIX_KP_SERVER_AUTH
224        .bytes()
225        .chain(Some(0))
226        .collect();
227    static ref szOID_SERVER_GATED_CRYPTO: Vec<u8> = Cryptography::szOID_SERVER_GATED_CRYPTO
228        .bytes()
229        .chain(Some(0))
230        .collect();
231    static ref szOID_SGC_NETSCAPE: Vec<u8> = Cryptography::szOID_SGC_NETSCAPE
232        .bytes()
233        .chain(Some(0))
234        .collect();
235}
236
237enum State {
238    Initializing {
239        needs_flush: bool,
240        more_calls: bool,
241        shutting_down: bool,
242    },
243    Streaming {
244        sizes: Identity::SecPkgContext_StreamSizes,
245    },
246    Shutdown,
247}
248
249pub struct TlsStream<S> {
250    cred: SchannelCred,
251    context: SecurityContext,
252    stream: S,
253    state: State,
254    accept_first: bool,
255    needs_read: usize,
256    dec_in: Cursor<Vec<u8>>,
257    enc_in: Cursor<Vec<u8>>,
258    out_buf: Cursor<Vec<u8>>,
259    last_write_len: usize,
260}
261
262impl<S> TlsStream<S> {
263    pub fn new(
264        cred: SchannelCred,
265        stream: S,
266    ) -> Result<TlsStream<S>, HandshakeError<S>>
267        where
268            S: Read + Write,
269    {
270        let context = SecurityContext::new();
271        let stream = TlsStream {
272            cred,
273            context,
274            stream,
275            accept_first: true,
276            state: State::Initializing {
277                needs_flush: false,
278                more_calls: true,
279                shutting_down: false,
280            },
281            needs_read: 1,
282            dec_in: Cursor::new(Vec::new()),
283            enc_in: Cursor::new(Vec::new()),
284            out_buf: Cursor::new(Vec::new()),
285            last_write_len: 0,
286        };
287
288        MidHandshakeTlsStream { inner: stream }.handshake()
289    }
290}
291
292
293/// ensures that a TlsStream is always Sync/Send
294fn _is_sync() {
295    fn sync<T: Sync + Send>() {}
296    sync::<TlsStream<()>>();
297}
298
299#[derive(Debug)]
300pub enum HandshakeError<S> {
301    IO(io::Error),
302    Interrupted(MidHandshakeTlsStream<S>),
303}
304
305impl<S: fmt::Debug + Any> Error for HandshakeError<S> {
306    fn source(&self) -> Option<&(dyn Error + 'static)> {
307        match *self {
308            HandshakeError::IO(ref e) => Some(e),
309            HandshakeError::Interrupted(_) => None,
310        }
311    }
312}
313
314impl<S: fmt::Debug + Any> fmt::Display for HandshakeError<S> {
315    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
316        let desc = match *self {
317            HandshakeError::IO(_) => "failed to perform handshake",
318            HandshakeError::Interrupted(_) => "interrupted performing handshake",
319        };
320        write!(f, "{}", desc)?;
321        if let Some(e) = self.source() {
322            write!(f, ": {}", e)?;
323        }
324        Ok(())
325    }
326}
327
328/// A stream which has not yet completed its handshake.
329#[derive(Debug)]
330pub struct MidHandshakeTlsStream<S> {
331    inner: TlsStream<S>,
332}
333
334impl<S> fmt::Debug for TlsStream<S>
335    where
336        S: fmt::Debug,
337{
338    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
339        fmt.debug_struct("TlsStream")
340            .field("stream", &self.stream)
341            .finish()
342    }
343}
344
345impl<S> TlsStream<S> {
346    /// Returns a reference to the wrapped stream.
347    pub fn get_ref(&self) -> &S {
348        &self.stream
349    }
350
351    /// Returns a mutable reference to the wrapped stream.
352    pub fn get_mut(&mut self) -> &mut S {
353        &mut self.stream
354    }
355}
356
357impl<S> TlsStream<S>
358    where
359        S: Read + Write,
360{
361
362    /// Returns a reference to the buffer of pending data.
363    ///
364    /// Like `BufRead::fill_buf` except that it will return an empty slice
365    /// rather than reading from the wrapped stream if there is no buffered
366    /// data.
367    pub fn get_buf(&self) -> &[u8] {
368        &self.dec_in.get_ref()[self.dec_in.position() as usize..]
369    }
370
371    /// Shuts the TLS session down.
372    pub fn shutdown(&mut self) -> io::Result<()> {
373        match self.state {
374            State::Shutdown => return Ok(()),
375            State::Initializing {
376                shutting_down: true,
377                ..
378            } => {}
379            _ => {
380                unsafe {
381                    let mut token = Identity::SCHANNEL_SHUTDOWN;
382                    let ptr = &mut token as *mut _ as *mut u8;
383                    let size = mem::size_of_val(&token);
384                    let token = slice::from_raw_parts_mut(ptr, size);
385                    let mut buf = [secbuf(Identity::SECBUFFER_TOKEN, Some(token))];
386                    let desc = secbuf_desc(&mut buf);
387
388                    match Identity::ApplyControlToken(self.context.get_mut(), &desc) {
389                        Foundation::SEC_E_OK => {}
390                        err => return Err(io::Error::from_raw_os_error(err)),
391                    }
392                }
393
394                self.state = State::Initializing {
395                    needs_flush: false,
396                    more_calls: true,
397                    shutting_down: true,
398                };
399                self.needs_read = 0;
400            }
401        }
402
403        self.initialize().map(|_| ())
404    }
405
406    fn step_initialize(&mut self) -> io::Result<()> {
407        unsafe {
408            let pos = self.enc_in.position() as usize;
409            let mut inbufs = vec![
410                secbuf(
411                    Identity::SECBUFFER_TOKEN,
412                    Some(&mut self.enc_in.get_mut()[..pos]),
413                ),
414                secbuf(Identity::SECBUFFER_EMPTY, None),
415            ];
416            let inbuf_desc = secbuf_desc(&mut inbufs[..]);
417
418            let mut outbufs = [
419                secbuf(Identity::SECBUFFER_TOKEN, None),
420                secbuf(Identity::SECBUFFER_ALERT, None),
421                secbuf(Identity::SECBUFFER_EMPTY, None),
422            ];
423            let mut outbuf_desc = secbuf_desc(&mut outbufs);
424
425            let mut attributes = 0;
426
427            let status = {
428                let ptr = if self.accept_first {
429                    ptr::null_mut()
430                } else {
431                    self.context.get_mut()
432                };
433                Identity::AcceptSecurityContext(
434                    &self.cred.as_inner(),
435                    ptr,
436                    &inbuf_desc,
437                    ACCEPT_REQUESTS,
438                    0,
439                    self.context.get_mut(),
440                    &mut outbuf_desc,
441                    &mut attributes,
442                    ptr::null_mut(),
443                )
444            };
445
446            for buf in &outbufs[1..] {
447                if !buf.pvBuffer.is_null() {
448                    Identity::FreeContextBuffer(buf.pvBuffer);
449                }
450            }
451
452            match status {
453                Foundation::SEC_E_OK => {
454                    let nread = if inbufs[1].BufferType == Identity::SECBUFFER_EXTRA {
455                        self.enc_in.position() as usize - inbufs[1].cbBuffer as usize
456                    } else {
457                        self.enc_in.position() as usize
458                    };
459                    let to_write = if outbufs[0].pvBuffer.is_null() {
460                        None
461                    } else {
462                        Some(ContextBuffer(outbufs[0]))
463                    };
464
465                    self.consume_enc_in(nread);
466                    self.needs_read = (self.enc_in.position() == 0) as usize;
467                    if let Some(to_write) = to_write {
468                        self.out_buf.get_mut().extend_from_slice(&to_write);
469                    }
470                    if self.enc_in.position() != 0 {
471                        self.decrypt()?;
472                    }
473                    if let State::Initializing {
474                        ref mut more_calls, ..
475                    } = self.state
476                    {
477                        *more_calls = false;
478                    }
479                }
480                Foundation::SEC_I_CONTINUE_NEEDED => {
481                    // Windows apparently doesn't like AcceptSecurityContext
482                    // being called as if it were the second time unless the
483                    // first call to AcceptSecurityContext succeeded with
484                    // CONTINUE_NEEDED.
485                    //
486                    // In other words, if we were to set `accept_first` to
487                    // `false` after the literal first call to
488                    // `AcceptSecurityContext` while the call returned
489                    // INCOMPLETE_MESSAGE, the next call would return an error.
490                    //
491                    // For that reason we only set `accept_first` to false here
492                    // once we've actually successfully received the full
493                    // "token" from the client.
494                    self.accept_first = false;
495                    let nread = if inbufs[1].BufferType == Identity::SECBUFFER_EXTRA {
496                        self.enc_in.position() as usize - inbufs[1].cbBuffer as usize
497                    } else {
498                        self.enc_in.position() as usize
499                    };
500                    let to_write = ContextBuffer(outbufs[0]);
501
502                    self.consume_enc_in(nread);
503                    self.needs_read = (self.enc_in.position() == 0) as usize;
504                    self.out_buf.get_mut().extend_from_slice(&to_write);
505                }
506                Foundation::SEC_E_INCOMPLETE_MESSAGE => {
507                    self.needs_read = if inbufs[1].BufferType == Identity::SECBUFFER_MISSING {
508                        inbufs[1].cbBuffer as usize
509                    } else {
510                        1
511                    };
512                }
513                err => return Err(io::Error::from_raw_os_error(err)),
514            }
515            Ok(())
516        }
517    }
518
519    fn initialize(&mut self) -> io::Result<Option<Identity::SecPkgContext_StreamSizes>> {
520        loop {
521            match self.state {
522                State::Initializing {
523                    mut needs_flush,
524                    more_calls,
525                    shutting_down,
526                } => {
527                    if self.write_out()? > 0 {
528                        needs_flush = true;
529                        if let State::Initializing {
530                            ref mut needs_flush,
531                            ..
532                        } = self.state
533                        {
534                            *needs_flush = true;
535                        }
536                    }
537
538                    if needs_flush {
539                        self.stream.flush()?;
540                        if let State::Initializing {
541                            ref mut needs_flush,
542                            ..
543                        } = self.state
544                        {
545                            *needs_flush = false;
546                        }
547                    }
548
549                    if !more_calls {
550                        self.state = if shutting_down {
551                            State::Shutdown
552                        } else {
553                            State::Streaming {
554                                sizes: self.context.stream_sizes()?,
555                            }
556                        };
557                        continue;
558                    }
559
560                    if self.needs_read > 0 && self.read_in()? == 0 {
561                        return Err(io::Error::new(
562                            io::ErrorKind::UnexpectedEof,
563                            "unexpected EOF during handshake",
564                        ));
565                    }
566
567                    self.step_initialize()?;
568                }
569                State::Streaming { sizes } => return Ok(Some(sizes)),
570                State::Shutdown => return Ok(None),
571            }
572        }
573    }
574
575    fn write_out(&mut self) -> io::Result<usize> {
576        let mut out = 0;
577        while self.out_buf.position() as usize != self.out_buf.get_ref().len() {
578            let position = self.out_buf.position() as usize;
579            let nwritten = self.stream.write(&self.out_buf.get_ref()[position..])?;
580            out += nwritten;
581            self.out_buf.set_position((position + nwritten) as u64);
582        }
583
584        Ok(out)
585    }
586
587    fn read_in(&mut self) -> io::Result<usize> {
588        let mut sum_nread = 0;
589
590        while self.needs_read > 0 {
591            let existing_len = self.enc_in.position() as usize;
592            let min_len = cmp::max(cmp::max(1024, 2 * existing_len), self.needs_read);
593            if self.enc_in.get_ref().len() < min_len {
594                self.enc_in.get_mut().resize(min_len, 0);
595            }
596            let nread = {
597                let buf = &mut self.enc_in.get_mut()[existing_len..];
598                self.stream.read(buf)?
599            };
600            self.enc_in.set_position((existing_len + nread) as u64);
601            self.needs_read = self.needs_read.saturating_sub(nread);
602            if nread == 0 {
603                break;
604            }
605            sum_nread += nread;
606        }
607
608        Ok(sum_nread)
609    }
610
611    fn consume_enc_in(&mut self, nread: usize) {
612        let size = self.enc_in.position() as usize;
613        assert!(size >= nread);
614        let count = size - nread;
615
616        if count > 0 {
617            self.enc_in.get_mut().drain(..nread);
618        }
619
620        self.enc_in.set_position(count as u64);
621    }
622
623    fn decrypt(&mut self) -> io::Result<bool> {
624        unsafe {
625            let position = self.enc_in.position() as usize;
626            let mut bufs = [
627                secbuf(
628                    Identity::SECBUFFER_DATA,
629                    Some(&mut self.enc_in.get_mut()[..position]),
630                ),
631                secbuf(Identity::SECBUFFER_EMPTY, None),
632                secbuf(Identity::SECBUFFER_EMPTY, None),
633                secbuf(Identity::SECBUFFER_EMPTY, None),
634            ];
635            let bufdesc = secbuf_desc(&mut bufs);
636
637            match Identity::DecryptMessage(self.context.get_mut(), &bufdesc, 0, ptr::null_mut()) {
638                Foundation::SEC_E_OK => {
639                    let start = bufs[1].pvBuffer as usize - self.enc_in.get_ref().as_ptr() as usize;
640                    let end = start + bufs[1].cbBuffer as usize;
641                    self.dec_in.get_mut().clear();
642                    self.dec_in
643                        .get_mut()
644                        .extend_from_slice(&self.enc_in.get_ref()[start..end]);
645                    self.dec_in.set_position(0);
646
647                    let nread = if bufs[3].BufferType == Identity::SECBUFFER_EXTRA {
648                        self.enc_in.position() as usize - bufs[3].cbBuffer as usize
649                    } else {
650                        self.enc_in.position() as usize
651                    };
652                    self.consume_enc_in(nread);
653                    self.needs_read = (self.enc_in.position() == 0) as usize;
654                    Ok(false)
655                }
656                Foundation::SEC_E_INCOMPLETE_MESSAGE => {
657                    self.needs_read = if bufs[1].BufferType == Identity::SECBUFFER_MISSING {
658                        bufs[1].cbBuffer as usize
659                    } else {
660                        1
661                    };
662                    Ok(false)
663                }
664                Foundation::SEC_I_CONTEXT_EXPIRED => Ok(true),
665                Foundation::SEC_I_RENEGOTIATE => {
666                    self.state = State::Initializing {
667                        needs_flush: false,
668                        more_calls: true,
669                        shutting_down: false,
670                    };
671
672                    let nread = if bufs[3].BufferType == Identity::SECBUFFER_EXTRA {
673                        self.enc_in.position() as usize - bufs[3].cbBuffer as usize
674                    } else {
675                        self.enc_in.position() as usize
676                    };
677                    self.consume_enc_in(nread);
678                    self.needs_read = 0;
679                    Ok(false)
680                }
681                err => Err(io::Error::from_raw_os_error(err)),
682            }
683        }
684    }
685
686    fn encrypt(&mut self, buf: &[u8], sizes: &Identity::SecPkgContext_StreamSizes, ) -> io::Result<()> {
687        assert!(buf.len() <= sizes.cbMaximumMessage as usize);
688
689        unsafe {
690            let len = sizes.cbHeader as usize + buf.len() + sizes.cbTrailer as usize;
691
692            if self.out_buf.get_ref().len() < len {
693                self.out_buf.get_mut().resize(len, 0);
694            }
695
696            let message_start = sizes.cbHeader as usize;
697            self.out_buf.get_mut()[message_start..message_start + buf.len()].clone_from_slice(buf);
698
699            let mut bufs = {
700                let out_buf = self.out_buf.get_mut();
701                let size = sizes.cbHeader as usize;
702
703                let header = secbuf(
704                    Identity::SECBUFFER_STREAM_HEADER,
705                    Some(&mut out_buf[..size]),
706                );
707                let data = secbuf(
708                    Identity::SECBUFFER_DATA,
709                    Some(&mut out_buf[size..size + buf.len()]),
710                );
711                let trailer = secbuf(
712                    Identity::SECBUFFER_STREAM_TRAILER,
713                    Some(&mut out_buf[size + buf.len()..]),
714                );
715                let empty = secbuf(Identity::SECBUFFER_EMPTY, None);
716                [header, data, trailer, empty]
717            };
718            let bufdesc = secbuf_desc(&mut bufs);
719
720            match Identity::EncryptMessage(self.context.get_mut(), 0, &bufdesc, 0) {
721                Foundation::SEC_E_OK => {
722                    let len = bufs[0].cbBuffer + bufs[1].cbBuffer + bufs[2].cbBuffer;
723                    self.out_buf.get_mut().truncate(len as usize);
724                    self.out_buf.set_position(0);
725                    Ok(())
726                }
727                err => Err(io::Error::from_raw_os_error(err)),
728            }
729        }
730    }
731}
732
733impl<S> MidHandshakeTlsStream<S> {
734    /// Returns a shared reference to the inner stream.
735    pub fn get_ref(&self) -> &S {
736        self.inner.get_ref()
737    }
738
739    /// Returns a mutable reference to the inner stream.
740    pub fn get_mut(&mut self) -> &mut S {
741        self.inner.get_mut()
742    }
743}
744
745impl<S> MidHandshakeTlsStream<S>
746    where
747        S: Read + Write,
748{
749    /// Restarts the handshake process.
750    pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
751        match self.inner.initialize() {
752            Ok(_) => Ok(self.inner),
753            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
754                Err(HandshakeError::Interrupted(self))
755            }
756            Err(e) => Err(HandshakeError::IO(e)),
757        }
758    }
759}
760
761impl<S> Write for TlsStream<S>
762    where
763        S: Read + Write,
764{
765    /// In the case of a WouldBlock error, we expect another call
766    /// starting with the same input data
767    /// This is similar to the use of ACCEPT_MOVING_WRITE_BUFFER in openssl
768    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
769        let sizes = match self.initialize()? {
770            Some(sizes) => sizes,
771            None => {
772                return Err(io::Error::from_raw_os_error(
773                    Foundation::SEC_E_CONTEXT_EXPIRED as i32,
774                ));
775            }
776        };
777
778        // if we have pending output data, it must have been because a previous
779        // attempt to send this part of the data ran into an error.
780        if self.out_buf.position() == self.out_buf.get_ref().len() as u64 {
781            let len = cmp::min(buf.len(), sizes.cbMaximumMessage as usize);
782            self.encrypt(&buf[..len], &sizes)?;
783            self.last_write_len = len;
784        }
785        self.write_out()?;
786
787        Ok(self.last_write_len)
788    }
789
790    fn flush(&mut self) -> io::Result<()> {
791        // Make sure the write buffer is emptied
792        self.write_out()?;
793        self.stream.flush()
794    }
795}
796
797impl<S> Read for TlsStream<S>
798    where
799        S: Read + Write,
800{
801    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
802        let nread = {
803            let read_buf = self.fill_buf()?;
804            let nread = cmp::min(buf.len(), read_buf.len());
805            buf[..nread].copy_from_slice(&read_buf[..nread]);
806            nread
807        };
808        self.consume(nread);
809        Ok(nread)
810    }
811}
812
813
814impl<S> BufRead for TlsStream<S>
815    where
816        S: Read + Write,
817{
818    fn fill_buf(&mut self) -> io::Result<&[u8]> {
819        while self.get_buf().is_empty() {
820            if self.initialize()?.is_none() {
821                break;
822            }
823
824            if self.needs_read > 0 {
825                if self.read_in()? == 0 {
826                    break;
827                }
828                self.needs_read = 0;
829            }
830
831            let eof = self.decrypt()?;
832            if eof {
833                break;
834            }
835        }
836
837        Ok(self.get_buf())
838    }
839
840    fn consume(&mut self, amt: usize) {
841        let pos = self.dec_in.position() + amt as u64;
842        assert!(pos <= self.dec_in.get_ref().len() as u64);
843        self.dec_in.set_position(pos);
844    }
845}
846
847
848lazy_static! {
849    static ref UNISP_NAME: Vec<u8> = Identity::UNISP_NAME.bytes().chain(Some(0)).collect();
850}
851
852#[derive(Clone)]
853pub struct SchannelCred(Arc<RawCredHandle>);
854
855struct RawCredHandle(Credentials::SecHandle);
856
857impl Drop for RawCredHandle {
858    fn drop(&mut self) {
859        unsafe {
860            Identity::FreeCredentialsHandle(&self.0);
861        }
862    }
863}
864
865impl SchannelCred {
866    pub fn new(cert: CertContext) -> io::Result<SchannelCred> {
867        unsafe {
868            let mut handle: Credentials::SecHandle = mem::zeroed();
869            let mut cred_data: Identity::SCHANNEL_CRED = mem::zeroed();
870            cred_data.dwVersion = Identity::SCHANNEL_CRED_VERSION;
871            cred_data.dwFlags = Identity::SCH_CRED_NO_DEFAULT_CREDS;
872
873            // Enable all protocols
874            cred_data.grbitEnabledProtocols = Identity::SP_PROT_SSL3_SERVER |
875                Identity::SP_PROT_TLS1_0_SERVER |
876                Identity::SP_PROT_TLS1_1_SERVER |
877                Identity::SP_PROT_TLS1_2_SERVER |
878                Identity::SP_PROT_TLS1_3_SERVER;
879
880            let mut certs = vec![cert];
881            cred_data.cCreds = certs.len() as u32;
882            cred_data.paCred = certs.as_mut_ptr() as _;
883
884            match Identity::AcquireCredentialsHandleA(
885                ptr::null(),
886                UNISP_NAME.as_ptr(),
887                Identity::SECPKG_CRED_INBOUND,
888                ptr::null_mut(),
889                &mut cred_data as *const _ as *const _,
890                None,
891                ptr::null_mut(),
892                &mut handle,
893                ptr::null_mut(),
894            ) {
895                Foundation::SEC_E_OK => Ok(SchannelCred::from_inner(handle)),
896                err => Err(io::Error::from_raw_os_error(err)),
897            }
898        }
899    }
900
901    unsafe fn from_inner(inner: Credentials::SecHandle) -> SchannelCred {
902        SchannelCred(Arc::new(RawCredHandle(inner)))
903    }
904
905    pub(crate) fn as_inner(&self) -> Credentials::SecHandle {
906        self.0.as_ref().0
907    }
908}
909
/// Allows access to the underlying schannel API representation of a wrapped data type.
///
/// Performing actions with internal handles might lead to the violation of internal assumptions
/// and therefore is inherently unsafe.
pub trait RawPointer {
    /// Constructs an instance of this type from its handle / pointer.
    ///
    /// # Safety
    /// `t` must be a valid handle / pointer of the wrapped kind. Ownership of
    /// it passes to the returned value, which may release it when dropped.
    unsafe fn from_ptr(t: *mut ::std::os::raw::c_void) -> Self;

    /// Get a raw pointer from the underlying handle / pointer.
    ///
    /// # Safety
    /// The pointer is borrowed: it is only valid while `self` is alive and must
    /// not be released by the caller.
    unsafe fn as_ptr(&self) -> *mut ::std::os::raw::c_void;
}
925
926
/// Crate-internal access to the raw handle stored in a wrapper type.
trait Inner<T> {
    /// Wraps `t`, taking ownership of it.
    ///
    /// # Safety
    /// `t` must be a valid handle; the wrapper may release it when dropped.
    unsafe fn from_inner(t: T) -> Self;

    /// Returns a copy of the raw handle; ownership stays with `self`.
    fn as_inner(&self) -> T;

    /// Returns a mutable reference to the raw handle (e.g. for FFI out-parameters).
    fn get_mut(&mut self) -> &mut T;
}
934
935unsafe fn secbuf(buftype: u32, bytes: Option<&mut [u8]>) -> Identity::SecBuffer {
936    let (ptr, len) = match bytes {
937        Some(bytes) => (bytes.as_mut_ptr(), bytes.len() as u32),
938        None => (ptr::null_mut(), 0),
939    };
940    Identity::SecBuffer {
941        BufferType: buftype,
942        cbBuffer: len,
943        pvBuffer: ptr as *mut c_void,
944    }
945}
946
947unsafe fn secbuf_desc(bufs: &mut [Identity::SecBuffer]) -> Identity::SecBufferDesc {
948    Identity::SecBufferDesc {
949        ulVersion: Identity::SECBUFFER_VERSION,
950        cBuffers: bufs.len() as u32,
951        pBuffers: bufs.as_mut_ptr(),
952    }
953}