1#![cfg(windows)]
3#![warn(missing_docs)]
4#![allow(non_upper_case_globals)]
5
6#[macro_use]
7extern crate lazy_static;
8
9use windows_sys::Win32::Foundation;
10use windows_sys::Win32::Security::Credentials;
11use windows_sys::Win32::Security::Authentication::Identity;
12use windows_sys::Win32::Security::Cryptography;
13
14use std::{fmt, io, ptr, mem};
15use std::ffi::OsStr;
16use std::os::windows::ffi::OsStrExt;
17use std::ops::Deref;
18use std::slice;
19use std::ffi::c_void;
20use std::any::Any;
21use std::cmp;
22use std::error::Error;
23use std::io::{BufRead, Cursor, Read, Write};
24use std::sync::Arc;
25
/// Implements the crate-internal `Inner` trait and the public `RawPointer`
/// trait for a single-field tuple struct that wraps a raw Windows handle.
///
/// `$wrapper` is the newtype (for example `CertStore`) and `$handle` is the raw
/// type stored in its `.0` field (for example `HCERTSTORE`).
macro_rules! inner {
    ($wrapper:path, $handle:ty) => {
        impl crate::Inner<$handle> for $wrapper {
            unsafe fn from_inner(handle: $handle) -> Self {
                $wrapper(handle)
            }

            fn as_inner(&self) -> $handle {
                self.0
            }

            fn get_mut(&mut self) -> &mut $handle {
                &mut self.0
            }
        }

        impl crate::RawPointer for $wrapper {
            unsafe fn from_ptr(raw: *mut ::std::os::raw::c_void) -> $wrapper {
                $wrapper(raw as _)
            }

            unsafe fn as_ptr(&self) -> *mut ::std::os::raw::c_void {
                self.0 as *mut _
            }
        }
    };
}
53
54const ACCEPT_REQUESTS: u32 = Identity::ASC_REQ_ALLOCATE_MEMORY
55 | Identity::ASC_REQ_CONFIDENTIALITY
56 | Identity::ASC_REQ_SEQUENCE_DETECT
57 | Identity::ASC_REQ_STREAM
58 | Identity::ASC_REQ_REPLAY_DETECT;
59
60
61#[derive(Debug)]
63pub struct CertContext(*const Cryptography::CERT_CONTEXT);
64
65unsafe impl Sync for CertContext {}
66
67unsafe impl Send for CertContext {}
68
69impl Drop for CertContext {
70 fn drop(&mut self) {
71 unsafe {
72 Cryptography::CertFreeCertificateContext(self.0);
73 }
74 }
75}
76
77impl Clone for CertContext {
78 fn clone(&self) -> CertContext {
79 unsafe { CertContext(Cryptography::CertDuplicateCertificateContext(self.0)) }
80 }
81}
82
83inner!(CertContext, *const Cryptography::CERT_CONTEXT);
84
85pub struct CertStore(Cryptography::HCERTSTORE);
87
88
89inner!(CertStore, Cryptography::HCERTSTORE);
90
91unsafe impl Sync for CertStore {}
92
93unsafe impl Send for CertStore {}
94
95impl fmt::Debug for CertStore {
96 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
97 fmt.debug_struct("CertStore").finish()
98 }
99}
100
101impl Drop for CertStore {
102 fn drop(&mut self) {
103 unsafe {
104 Cryptography::CertCloseStore(self.0, 0);
105 }
106 }
107}
108
109impl Clone for CertStore {
110 fn clone(&self) -> CertStore {
111 unsafe { CertStore(Cryptography::CertDuplicateStore(self.0)) }
112 }
113}
114
115impl CertStore {
116 pub fn import_pkcs12(data: &[u8], password: Option<&str>) -> io::Result<CertStore> {
117 unsafe {
118 let blob = Cryptography::CRYPTOAPI_BLOB {
119 cbData: data.len() as u32,
120 pbData: data.as_ptr() as *mut u8,
121 };
122 let password = password.map(|s| {
123 OsStr::new(s)
124 .encode_wide()
125 .chain(Some(0))
126 .collect::<Vec<_>>()
127 });
128 let password = password.as_ref().map(|s| s.as_ptr());
129 let password = password.unwrap_or(ptr::null());
130 let res = Cryptography::PFXImportCertStore(
131 &blob,
132 password,
133 Cryptography::CRYPT_KEY_FLAGS::default(),
134 );
135 if !res.is_null() {
136 Ok(CertStore(res))
137 } else {
138 Err(io::Error::last_os_error())
139 }
140 }
141 }
142
143 pub fn first(&self) -> Option<CertContext> {
144 unsafe {
145 let value = Cryptography::CertEnumCertificatesInStore(self.0, ptr::null_mut());
146 if value.is_null() {
147 None
148 } else {
149 let next = CertContext::from_inner(value);
150 Some(next)
151 }
152 }
153 }
154}
155
156
157pub struct ContextBuffer(pub Identity::SecBuffer);
158
159impl Drop for ContextBuffer {
160 fn drop(&mut self) {
161 unsafe {
162 Identity::FreeContextBuffer(self.0.pvBuffer);
163 }
164 }
165}
166
167impl Deref for ContextBuffer {
168 type Target = [u8];
169
170 fn deref(&self) -> &[u8] {
171 unsafe { slice::from_raw_parts(self.0.pvBuffer as *const _, self.0.cbBuffer as usize) }
172 }
173}
174
175
176pub struct SecurityContext(Credentials::SecHandle);
177
178impl Drop for SecurityContext {
179 fn drop(&mut self) {
180 unsafe {
181 Identity::DeleteSecurityContext(&self.0);
182 }
183 }
184}
185
186impl Inner<Credentials::SecHandle> for SecurityContext {
187 unsafe fn from_inner(inner: Credentials::SecHandle) -> SecurityContext {
188 SecurityContext(inner)
189 }
190
191 fn as_inner(&self) -> Credentials::SecHandle {
192 self.0
193 }
194
195 fn get_mut(&mut self) -> &mut Credentials::SecHandle {
196 &mut self.0
197 }
198}
199
200impl SecurityContext {
201
202 pub fn new() -> SecurityContext{
203 unsafe { return SecurityContext(mem::zeroed()); }
204 }
205
206 unsafe fn attribute<T>(&self, attr: Identity::SECPKG_ATTR) -> io::Result<T> {
207 let mut value = mem::zeroed();
208 let status =
209 Identity::QueryContextAttributesW(&self.0, attr, &mut value as *mut _ as *mut _);
210 match status {
211 Foundation::SEC_E_OK => Ok(value),
212 err => Err(io::Error::from_raw_os_error(err)),
213 }
214 }
215
216 pub fn stream_sizes(&self) -> io::Result<Identity::SecPkgContext_StreamSizes> {
217 unsafe { self.attribute(Identity::SECPKG_ATTR_STREAM_SIZES) }
218 }
219}
220
221
222lazy_static! {
223 static ref szOID_PKIX_KP_SERVER_AUTH: Vec<u8> = Cryptography::szOID_PKIX_KP_SERVER_AUTH
224 .bytes()
225 .chain(Some(0))
226 .collect();
227 static ref szOID_SERVER_GATED_CRYPTO: Vec<u8> = Cryptography::szOID_SERVER_GATED_CRYPTO
228 .bytes()
229 .chain(Some(0))
230 .collect();
231 static ref szOID_SGC_NETSCAPE: Vec<u8> = Cryptography::szOID_SGC_NETSCAPE
232 .bytes()
233 .chain(Some(0))
234 .collect();
235}
236
237enum State {
238 Initializing {
239 needs_flush: bool,
240 more_calls: bool,
241 shutting_down: bool,
242 },
243 Streaming {
244 sizes: Identity::SecPkgContext_StreamSizes,
245 },
246 Shutdown,
247}
248
/// A server-side TLS stream over `S`, implemented with Windows Schannel.
///
/// The handshake is driven from `TlsStream::new`, from `read`/`write`, and from
/// `shutdown`, whenever `state` is `State::Initializing`.
///
/// NOTE: fields drop in declaration order, so `cred` is released before `context`.
pub struct TlsStream<S> {
    // Credential handle passed to every AcceptSecurityContext call.
    cred: SchannelCred,
    // SSPI context handle; zero-initialised until AcceptSecurityContext creates it.
    context: SecurityContext,
    // The underlying transport.
    stream: S,
    // Handshake / streaming / shutdown progress.
    state: State,
    // True until AcceptSecurityContext first returns SEC_I_CONTINUE_NEEDED; while
    // true, a null input context is passed so that a new context is created.
    accept_first: bool,
    // Number of additional ciphertext bytes still wanted before the next SSPI call
    // (0 = none).
    needs_read: usize,
    // Decrypted plaintext; `position()` marks how much the reader has consumed.
    dec_in: Cursor<Vec<u8>>,
    // Received ciphertext; `position()` is the number of valid buffered bytes.
    enc_in: Cursor<Vec<u8>>,
    // Outgoing handshake tokens / encrypted records; `position()` marks how much
    // has been written to `stream`.
    out_buf: Cursor<Vec<u8>>,
    // Plaintext length of the most recently encrypted record, reported by `write`.
    last_write_len: usize,
}
261
262impl<S> TlsStream<S> {
263 pub fn new(
264 cred: SchannelCred,
265 stream: S,
266 ) -> Result<TlsStream<S>, HandshakeError<S>>
267 where
268 S: Read + Write,
269 {
270 let context = SecurityContext::new();
271 let stream = TlsStream {
272 cred,
273 context,
274 stream,
275 accept_first: true,
276 state: State::Initializing {
277 needs_flush: false,
278 more_calls: true,
279 shutting_down: false,
280 },
281 needs_read: 1,
282 dec_in: Cursor::new(Vec::new()),
283 enc_in: Cursor::new(Vec::new()),
284 out_buf: Cursor::new(Vec::new()),
285 last_write_len: 0,
286 };
287
288 MidHandshakeTlsStream { inner: stream }.handshake()
289 }
290}
291
292
293fn _is_sync() {
295 fn sync<T: Sync + Send>() {}
296 sync::<TlsStream<()>>();
297}
298
299#[derive(Debug)]
300pub enum HandshakeError<S> {
301 IO(io::Error),
302 Interrupted(MidHandshakeTlsStream<S>),
303}
304
305impl<S: fmt::Debug + Any> Error for HandshakeError<S> {
306 fn source(&self) -> Option<&(dyn Error + 'static)> {
307 match *self {
308 HandshakeError::IO(ref e) => Some(e),
309 HandshakeError::Interrupted(_) => None,
310 }
311 }
312}
313
314impl<S: fmt::Debug + Any> fmt::Display for HandshakeError<S> {
315 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
316 let desc = match *self {
317 HandshakeError::IO(_) => "failed to perform handshake",
318 HandshakeError::Interrupted(_) => "interrupted performing handshake",
319 };
320 write!(f, "{}", desc)?;
321 if let Some(e) = self.source() {
322 write!(f, ": {}", e)?;
323 }
324 Ok(())
325 }
326}
327
328#[derive(Debug)]
330pub struct MidHandshakeTlsStream<S> {
331 inner: TlsStream<S>,
332}
333
334impl<S> fmt::Debug for TlsStream<S>
335 where
336 S: fmt::Debug,
337{
338 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
339 fmt.debug_struct("TlsStream")
340 .field("stream", &self.stream)
341 .finish()
342 }
343}
344
345impl<S> TlsStream<S> {
346 pub fn get_ref(&self) -> &S {
348 &self.stream
349 }
350
351 pub fn get_mut(&mut self) -> &mut S {
353 &mut self.stream
354 }
355}
356
357impl<S> TlsStream<S>
358 where
359 S: Read + Write,
360{
361
362 pub fn get_buf(&self) -> &[u8] {
368 &self.dec_in.get_ref()[self.dec_in.position() as usize..]
369 }
370
371 pub fn shutdown(&mut self) -> io::Result<()> {
373 match self.state {
374 State::Shutdown => return Ok(()),
375 State::Initializing {
376 shutting_down: true,
377 ..
378 } => {}
379 _ => {
380 unsafe {
381 let mut token = Identity::SCHANNEL_SHUTDOWN;
382 let ptr = &mut token as *mut _ as *mut u8;
383 let size = mem::size_of_val(&token);
384 let token = slice::from_raw_parts_mut(ptr, size);
385 let mut buf = [secbuf(Identity::SECBUFFER_TOKEN, Some(token))];
386 let desc = secbuf_desc(&mut buf);
387
388 match Identity::ApplyControlToken(self.context.get_mut(), &desc) {
389 Foundation::SEC_E_OK => {}
390 err => return Err(io::Error::from_raw_os_error(err)),
391 }
392 }
393
394 self.state = State::Initializing {
395 needs_flush: false,
396 more_calls: true,
397 shutting_down: true,
398 };
399 self.needs_read = 0;
400 }
401 }
402
403 self.initialize().map(|_| ())
404 }
405
406 fn step_initialize(&mut self) -> io::Result<()> {
407 unsafe {
408 let pos = self.enc_in.position() as usize;
409 let mut inbufs = vec![
410 secbuf(
411 Identity::SECBUFFER_TOKEN,
412 Some(&mut self.enc_in.get_mut()[..pos]),
413 ),
414 secbuf(Identity::SECBUFFER_EMPTY, None),
415 ];
416 let inbuf_desc = secbuf_desc(&mut inbufs[..]);
417
418 let mut outbufs = [
419 secbuf(Identity::SECBUFFER_TOKEN, None),
420 secbuf(Identity::SECBUFFER_ALERT, None),
421 secbuf(Identity::SECBUFFER_EMPTY, None),
422 ];
423 let mut outbuf_desc = secbuf_desc(&mut outbufs);
424
425 let mut attributes = 0;
426
427 let status = {
428 let ptr = if self.accept_first {
429 ptr::null_mut()
430 } else {
431 self.context.get_mut()
432 };
433 Identity::AcceptSecurityContext(
434 &self.cred.as_inner(),
435 ptr,
436 &inbuf_desc,
437 ACCEPT_REQUESTS,
438 0,
439 self.context.get_mut(),
440 &mut outbuf_desc,
441 &mut attributes,
442 ptr::null_mut(),
443 )
444 };
445
446 for buf in &outbufs[1..] {
447 if !buf.pvBuffer.is_null() {
448 Identity::FreeContextBuffer(buf.pvBuffer);
449 }
450 }
451
452 match status {
453 Foundation::SEC_E_OK => {
454 let nread = if inbufs[1].BufferType == Identity::SECBUFFER_EXTRA {
455 self.enc_in.position() as usize - inbufs[1].cbBuffer as usize
456 } else {
457 self.enc_in.position() as usize
458 };
459 let to_write = if outbufs[0].pvBuffer.is_null() {
460 None
461 } else {
462 Some(ContextBuffer(outbufs[0]))
463 };
464
465 self.consume_enc_in(nread);
466 self.needs_read = (self.enc_in.position() == 0) as usize;
467 if let Some(to_write) = to_write {
468 self.out_buf.get_mut().extend_from_slice(&to_write);
469 }
470 if self.enc_in.position() != 0 {
471 self.decrypt()?;
472 }
473 if let State::Initializing {
474 ref mut more_calls, ..
475 } = self.state
476 {
477 *more_calls = false;
478 }
479 }
480 Foundation::SEC_I_CONTINUE_NEEDED => {
481 self.accept_first = false;
495 let nread = if inbufs[1].BufferType == Identity::SECBUFFER_EXTRA {
496 self.enc_in.position() as usize - inbufs[1].cbBuffer as usize
497 } else {
498 self.enc_in.position() as usize
499 };
500 let to_write = ContextBuffer(outbufs[0]);
501
502 self.consume_enc_in(nread);
503 self.needs_read = (self.enc_in.position() == 0) as usize;
504 self.out_buf.get_mut().extend_from_slice(&to_write);
505 }
506 Foundation::SEC_E_INCOMPLETE_MESSAGE => {
507 self.needs_read = if inbufs[1].BufferType == Identity::SECBUFFER_MISSING {
508 inbufs[1].cbBuffer as usize
509 } else {
510 1
511 };
512 }
513 err => return Err(io::Error::from_raw_os_error(err)),
514 }
515 Ok(())
516 }
517 }
518
519 fn initialize(&mut self) -> io::Result<Option<Identity::SecPkgContext_StreamSizes>> {
520 loop {
521 match self.state {
522 State::Initializing {
523 mut needs_flush,
524 more_calls,
525 shutting_down,
526 } => {
527 if self.write_out()? > 0 {
528 needs_flush = true;
529 if let State::Initializing {
530 ref mut needs_flush,
531 ..
532 } = self.state
533 {
534 *needs_flush = true;
535 }
536 }
537
538 if needs_flush {
539 self.stream.flush()?;
540 if let State::Initializing {
541 ref mut needs_flush,
542 ..
543 } = self.state
544 {
545 *needs_flush = false;
546 }
547 }
548
549 if !more_calls {
550 self.state = if shutting_down {
551 State::Shutdown
552 } else {
553 State::Streaming {
554 sizes: self.context.stream_sizes()?,
555 }
556 };
557 continue;
558 }
559
560 if self.needs_read > 0 && self.read_in()? == 0 {
561 return Err(io::Error::new(
562 io::ErrorKind::UnexpectedEof,
563 "unexpected EOF during handshake",
564 ));
565 }
566
567 self.step_initialize()?;
568 }
569 State::Streaming { sizes } => return Ok(Some(sizes)),
570 State::Shutdown => return Ok(None),
571 }
572 }
573 }
574
575 fn write_out(&mut self) -> io::Result<usize> {
576 let mut out = 0;
577 while self.out_buf.position() as usize != self.out_buf.get_ref().len() {
578 let position = self.out_buf.position() as usize;
579 let nwritten = self.stream.write(&self.out_buf.get_ref()[position..])?;
580 out += nwritten;
581 self.out_buf.set_position((position + nwritten) as u64);
582 }
583
584 Ok(out)
585 }
586
587 fn read_in(&mut self) -> io::Result<usize> {
588 let mut sum_nread = 0;
589
590 while self.needs_read > 0 {
591 let existing_len = self.enc_in.position() as usize;
592 let min_len = cmp::max(cmp::max(1024, 2 * existing_len), self.needs_read);
593 if self.enc_in.get_ref().len() < min_len {
594 self.enc_in.get_mut().resize(min_len, 0);
595 }
596 let nread = {
597 let buf = &mut self.enc_in.get_mut()[existing_len..];
598 self.stream.read(buf)?
599 };
600 self.enc_in.set_position((existing_len + nread) as u64);
601 self.needs_read = self.needs_read.saturating_sub(nread);
602 if nread == 0 {
603 break;
604 }
605 sum_nread += nread;
606 }
607
608 Ok(sum_nread)
609 }
610
611 fn consume_enc_in(&mut self, nread: usize) {
612 let size = self.enc_in.position() as usize;
613 assert!(size >= nread);
614 let count = size - nread;
615
616 if count > 0 {
617 self.enc_in.get_mut().drain(..nread);
618 }
619
620 self.enc_in.set_position(count as u64);
621 }
622
623 fn decrypt(&mut self) -> io::Result<bool> {
624 unsafe {
625 let position = self.enc_in.position() as usize;
626 let mut bufs = [
627 secbuf(
628 Identity::SECBUFFER_DATA,
629 Some(&mut self.enc_in.get_mut()[..position]),
630 ),
631 secbuf(Identity::SECBUFFER_EMPTY, None),
632 secbuf(Identity::SECBUFFER_EMPTY, None),
633 secbuf(Identity::SECBUFFER_EMPTY, None),
634 ];
635 let bufdesc = secbuf_desc(&mut bufs);
636
637 match Identity::DecryptMessage(self.context.get_mut(), &bufdesc, 0, ptr::null_mut()) {
638 Foundation::SEC_E_OK => {
639 let start = bufs[1].pvBuffer as usize - self.enc_in.get_ref().as_ptr() as usize;
640 let end = start + bufs[1].cbBuffer as usize;
641 self.dec_in.get_mut().clear();
642 self.dec_in
643 .get_mut()
644 .extend_from_slice(&self.enc_in.get_ref()[start..end]);
645 self.dec_in.set_position(0);
646
647 let nread = if bufs[3].BufferType == Identity::SECBUFFER_EXTRA {
648 self.enc_in.position() as usize - bufs[3].cbBuffer as usize
649 } else {
650 self.enc_in.position() as usize
651 };
652 self.consume_enc_in(nread);
653 self.needs_read = (self.enc_in.position() == 0) as usize;
654 Ok(false)
655 }
656 Foundation::SEC_E_INCOMPLETE_MESSAGE => {
657 self.needs_read = if bufs[1].BufferType == Identity::SECBUFFER_MISSING {
658 bufs[1].cbBuffer as usize
659 } else {
660 1
661 };
662 Ok(false)
663 }
664 Foundation::SEC_I_CONTEXT_EXPIRED => Ok(true),
665 Foundation::SEC_I_RENEGOTIATE => {
666 self.state = State::Initializing {
667 needs_flush: false,
668 more_calls: true,
669 shutting_down: false,
670 };
671
672 let nread = if bufs[3].BufferType == Identity::SECBUFFER_EXTRA {
673 self.enc_in.position() as usize - bufs[3].cbBuffer as usize
674 } else {
675 self.enc_in.position() as usize
676 };
677 self.consume_enc_in(nread);
678 self.needs_read = 0;
679 Ok(false)
680 }
681 err => Err(io::Error::from_raw_os_error(err)),
682 }
683 }
684 }
685
686 fn encrypt(&mut self, buf: &[u8], sizes: &Identity::SecPkgContext_StreamSizes, ) -> io::Result<()> {
687 assert!(buf.len() <= sizes.cbMaximumMessage as usize);
688
689 unsafe {
690 let len = sizes.cbHeader as usize + buf.len() + sizes.cbTrailer as usize;
691
692 if self.out_buf.get_ref().len() < len {
693 self.out_buf.get_mut().resize(len, 0);
694 }
695
696 let message_start = sizes.cbHeader as usize;
697 self.out_buf.get_mut()[message_start..message_start + buf.len()].clone_from_slice(buf);
698
699 let mut bufs = {
700 let out_buf = self.out_buf.get_mut();
701 let size = sizes.cbHeader as usize;
702
703 let header = secbuf(
704 Identity::SECBUFFER_STREAM_HEADER,
705 Some(&mut out_buf[..size]),
706 );
707 let data = secbuf(
708 Identity::SECBUFFER_DATA,
709 Some(&mut out_buf[size..size + buf.len()]),
710 );
711 let trailer = secbuf(
712 Identity::SECBUFFER_STREAM_TRAILER,
713 Some(&mut out_buf[size + buf.len()..]),
714 );
715 let empty = secbuf(Identity::SECBUFFER_EMPTY, None);
716 [header, data, trailer, empty]
717 };
718 let bufdesc = secbuf_desc(&mut bufs);
719
720 match Identity::EncryptMessage(self.context.get_mut(), 0, &bufdesc, 0) {
721 Foundation::SEC_E_OK => {
722 let len = bufs[0].cbBuffer + bufs[1].cbBuffer + bufs[2].cbBuffer;
723 self.out_buf.get_mut().truncate(len as usize);
724 self.out_buf.set_position(0);
725 Ok(())
726 }
727 err => Err(io::Error::from_raw_os_error(err)),
728 }
729 }
730 }
731}
732
733impl<S> MidHandshakeTlsStream<S> {
734 pub fn get_ref(&self) -> &S {
736 self.inner.get_ref()
737 }
738
739 pub fn get_mut(&mut self) -> &mut S {
741 self.inner.get_mut()
742 }
743}
744
745impl<S> MidHandshakeTlsStream<S>
746 where
747 S: Read + Write,
748{
749 pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
751 match self.inner.initialize() {
752 Ok(_) => Ok(self.inner),
753 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
754 Err(HandshakeError::Interrupted(self))
755 }
756 Err(e) => Err(HandshakeError::IO(e)),
757 }
758 }
759}
760
761impl<S> Write for TlsStream<S>
762 where
763 S: Read + Write,
764{
765 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
769 let sizes = match self.initialize()? {
770 Some(sizes) => sizes,
771 None => {
772 return Err(io::Error::from_raw_os_error(
773 Foundation::SEC_E_CONTEXT_EXPIRED as i32,
774 ));
775 }
776 };
777
778 if self.out_buf.position() == self.out_buf.get_ref().len() as u64 {
781 let len = cmp::min(buf.len(), sizes.cbMaximumMessage as usize);
782 self.encrypt(&buf[..len], &sizes)?;
783 self.last_write_len = len;
784 }
785 self.write_out()?;
786
787 Ok(self.last_write_len)
788 }
789
790 fn flush(&mut self) -> io::Result<()> {
791 self.write_out()?;
793 self.stream.flush()
794 }
795}
796
797impl<S> Read for TlsStream<S>
798 where
799 S: Read + Write,
800{
801 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
802 let nread = {
803 let read_buf = self.fill_buf()?;
804 let nread = cmp::min(buf.len(), read_buf.len());
805 buf[..nread].copy_from_slice(&read_buf[..nread]);
806 nread
807 };
808 self.consume(nread);
809 Ok(nread)
810 }
811}
812
813
814impl<S> BufRead for TlsStream<S>
815 where
816 S: Read + Write,
817{
818 fn fill_buf(&mut self) -> io::Result<&[u8]> {
819 while self.get_buf().is_empty() {
820 if self.initialize()?.is_none() {
821 break;
822 }
823
824 if self.needs_read > 0 {
825 if self.read_in()? == 0 {
826 break;
827 }
828 self.needs_read = 0;
829 }
830
831 let eof = self.decrypt()?;
832 if eof {
833 break;
834 }
835 }
836
837 Ok(self.get_buf())
838 }
839
840 fn consume(&mut self, amt: usize) {
841 let pos = self.dec_in.position() + amt as u64;
842 assert!(pos <= self.dec_in.get_ref().len() as u64);
843 self.dec_in.set_position(pos);
844 }
845}
846
847
848lazy_static! {
849 static ref UNISP_NAME: Vec<u8> = Identity::UNISP_NAME.bytes().chain(Some(0)).collect();
850}
851
852#[derive(Clone)]
853pub struct SchannelCred(Arc<RawCredHandle>);
854
855struct RawCredHandle(Credentials::SecHandle);
856
857impl Drop for RawCredHandle {
858 fn drop(&mut self) {
859 unsafe {
860 Identity::FreeCredentialsHandle(&self.0);
861 }
862 }
863}
864
865impl SchannelCred {
866 pub fn new(cert: CertContext) -> io::Result<SchannelCred> {
867 unsafe {
868 let mut handle: Credentials::SecHandle = mem::zeroed();
869 let mut cred_data: Identity::SCHANNEL_CRED = mem::zeroed();
870 cred_data.dwVersion = Identity::SCHANNEL_CRED_VERSION;
871 cred_data.dwFlags = Identity::SCH_CRED_NO_DEFAULT_CREDS;
872
873 cred_data.grbitEnabledProtocols = Identity::SP_PROT_SSL3_SERVER |
875 Identity::SP_PROT_TLS1_0_SERVER |
876 Identity::SP_PROT_TLS1_1_SERVER |
877 Identity::SP_PROT_TLS1_2_SERVER |
878 Identity::SP_PROT_TLS1_3_SERVER;
879
880 let mut certs = vec![cert];
881 cred_data.cCreds = certs.len() as u32;
882 cred_data.paCred = certs.as_mut_ptr() as _;
883
884 match Identity::AcquireCredentialsHandleA(
885 ptr::null(),
886 UNISP_NAME.as_ptr(),
887 Identity::SECPKG_CRED_INBOUND,
888 ptr::null_mut(),
889 &mut cred_data as *const _ as *const _,
890 None,
891 ptr::null_mut(),
892 &mut handle,
893 ptr::null_mut(),
894 ) {
895 Foundation::SEC_E_OK => Ok(SchannelCred::from_inner(handle)),
896 err => Err(io::Error::from_raw_os_error(err)),
897 }
898 }
899 }
900
901 unsafe fn from_inner(inner: Credentials::SecHandle) -> SchannelCred {
902 SchannelCred(Arc::new(RawCredHandle(inner)))
903 }
904
905 pub(crate) fn as_inner(&self) -> Credentials::SecHandle {
906 self.0.as_ref().0
907 }
908}
909
/// Conversion between a handle wrapper and the raw `void*` it stores.
pub trait RawPointer {
    /// Wraps a raw pointer in the implementing type.
    ///
    /// # Safety
    ///
    /// `ptr` must be a valid pointer of the kind the implementor wraps. Wrappers
    /// with a `Drop` impl (such as `CertStore`) will release it when dropped,
    /// so ownership effectively transfers to the returned value.
    unsafe fn from_ptr(ptr: *mut ::std::os::raw::c_void) -> Self;

    /// Returns the wrapped raw pointer without giving up ownership.
    ///
    /// # Safety
    ///
    /// The pointer must not be used after `self` is dropped.
    unsafe fn as_ptr(&self) -> *mut ::std::os::raw::c_void;
}
925
926
/// Crate-internal access to the raw handle stored in a wrapper type.
trait Inner<T> {
    /// Wraps `inner` without validation.
    ///
    /// # Safety
    ///
    /// `inner` must be a live handle of the kind the wrapper manages; wrappers
    /// with a `Drop` impl will release it.
    unsafe fn from_inner(inner: T) -> Self;

    /// Returns the raw handle by value, leaving ownership with `self`.
    fn as_inner(&self) -> T;

    /// Returns a mutable reference to the raw handle, e.g. for SSPI out-parameters.
    fn get_mut(&mut self) -> &mut T;
}
934
935unsafe fn secbuf(buftype: u32, bytes: Option<&mut [u8]>) -> Identity::SecBuffer {
936 let (ptr, len) = match bytes {
937 Some(bytes) => (bytes.as_mut_ptr(), bytes.len() as u32),
938 None => (ptr::null_mut(), 0),
939 };
940 Identity::SecBuffer {
941 BufferType: buftype,
942 cbBuffer: len,
943 pvBuffer: ptr as *mut c_void,
944 }
945}
946
947unsafe fn secbuf_desc(bufs: &mut [Identity::SecBuffer]) -> Identity::SecBufferDesc {
948 Identity::SecBufferDesc {
949 ulVersion: Identity::SECBUFFER_VERSION,
950 cBuffers: bufs.len() as u32,
951 pBuffers: bufs.as_mut_ptr(),
952 }
953}