rustls_cng/
store.rs

//! Safe wrapper around a Windows certificate store handle.
2
3use std::{os::raw::c_void, ptr};
4
5use bitflags::bitflags;
6use windows_sys::Win32::Security::Cryptography::*;
7
8use crate::{cert::CertContext, error::CngError, Result};
9
10const MY_ENCODING_TYPE: CERT_QUERY_ENCODING_TYPE = PKCS_7_ASN_ENCODING | X509_ASN_ENCODING;
11
/// Encode a `&str`-like expression as UTF-16 and append a trailing NUL, producing a
/// `Vec<u16>` suitable for passing as a wide-string pointer to Win32 APIs.
macro_rules! utf16z {
    ($str: expr) => {{
        let mut wide: Vec<u16> = $str.encode_utf16().collect();
        wide.push(0);
        wide
    }};
}
17
/// Location of a system certificate store, used when opening a store by name.
#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash)]
pub enum CertStoreType {
    /// Stores of the local machine (`CERT_SYSTEM_STORE_LOCAL_MACHINE`).
    LocalMachine,
    /// Stores of the current user (`CERT_SYSTEM_STORE_CURRENT_USER`).
    CurrentUser,
    /// Stores of the current service (`CERT_SYSTEM_STORE_CURRENT_SERVICE`).
    CurrentService,
}
25
26bitflags! {
27    /// Set of flags to pass to the ` CertStore::from_pkcs12 ` method.
28    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
29    pub struct Pkcs12Flags: u32 {
30        const INCLUDE_EXTENDED_PROPERTIES = 0x0010;
31        const PREFER_CNG_KSP = 0x0000_0100;
32        const ALWAYS_CNG_KSP = 0x0000_0200;
33        const ALLOW_OVERWRITE_KEY = 0x0000_4000;
34        const NO_PERSIST_KEY =0x0000_8000;
35    }
36}
37
38impl Default for Pkcs12Flags {
39    fn default() -> Self {
40        Pkcs12Flags::INCLUDE_EXTENDED_PROPERTIES | Pkcs12Flags::PREFER_CNG_KSP
41    }
42}
43
44impl CertStoreType {
45    fn as_flags(&self) -> u32 {
46        match self {
47            CertStoreType::LocalMachine => {
48                CERT_SYSTEM_STORE_LOCAL_MACHINE_ID << CERT_SYSTEM_STORE_LOCATION_SHIFT
49            }
50            CertStoreType::CurrentUser => {
51                CERT_SYSTEM_STORE_CURRENT_USER_ID << CERT_SYSTEM_STORE_LOCATION_SHIFT
52            }
53            CertStoreType::CurrentService => {
54                CERT_SYSTEM_STORE_CURRENT_SERVICE_ID << CERT_SYSTEM_STORE_LOCATION_SHIFT
55            }
56        }
57    }
58}
59
60/// Windows certificate store wrapper
61#[derive(Debug)]
62pub struct CertStore(HCERTSTORE);
63
64unsafe impl Send for CertStore {}
65unsafe impl Sync for CertStore {}
66
67impl CertStore {
68    /// Return an inner handle to the store
69    pub fn inner(&self) -> HCERTSTORE {
70        self.0
71    }
72
73    /// Open certificate store of the given type and name
74    pub fn open(store_type: CertStoreType, store_name: &str) -> Result<CertStore> {
75        unsafe {
76            let store_name = utf16z!(store_name);
77            let handle = CertOpenStore(
78                CERT_STORE_PROV_SYSTEM_W,
79                CERT_QUERY_ENCODING_TYPE::default(),
80                HCRYPTPROV_LEGACY::default(),
81                store_type.as_flags() | CERT_STORE_OPEN_EXISTING_FLAG,
82                store_name.as_ptr() as _,
83            );
84            if handle.is_null() {
85                Err(CngError::from_win32_error())
86            } else {
87                Ok(CertStore(handle))
88            }
89        }
90    }
91
92    /// Import certificate store from PKCS12 file
93    pub fn from_pkcs12(data: &[u8], password: &str, flags: Pkcs12Flags) -> Result<CertStore> {
94        unsafe {
95            let blob = CRYPT_INTEGER_BLOB {
96                cbData: data.len() as u32,
97                pbData: data.as_ptr() as _,
98            };
99
100            let password = utf16z!(password);
101            let store =
102                PFXImportCertStore(&blob, password.as_ptr(), CRYPT_EXPORTABLE | flags.bits());
103            if store.is_null() {
104                Err(CngError::from_win32_error())
105            } else {
106                Ok(CertStore(store))
107            }
108        }
109    }
110
111    /// Find list of certificates matching the subject substring
112    pub fn find_by_subject_str<S>(&self, subject: S) -> Result<Vec<CertContext>>
113    where
114        S: AsRef<str>,
115    {
116        self.find_by_str(subject.as_ref(), CERT_FIND_SUBJECT_STR)
117    }
118
119    /// Find list of certificates matching the exact subject name
120    pub fn find_by_subject_name<S>(&self, subject: S) -> Result<Vec<CertContext>>
121    where
122        S: AsRef<str>,
123    {
124        self.find_by_name(subject.as_ref(), CERT_FIND_SUBJECT_NAME)
125    }
126
127    /// Find list of certificates matching the issuer substring
128    pub fn find_by_issuer_str<S>(&self, subject: S) -> Result<Vec<CertContext>>
129    where
130        S: AsRef<str>,
131    {
132        self.find_by_str(subject.as_ref(), CERT_FIND_ISSUER_STR)
133    }
134
135    /// Find list of certificates matching the exact issuer name
136    pub fn find_by_issuer_name<S>(&self, subject: S) -> Result<Vec<CertContext>>
137    where
138        S: AsRef<str>,
139    {
140        self.find_by_name(subject.as_ref(), CERT_FIND_ISSUER_NAME)
141    }
142
143    /// Find list of certificates matching the SHA1 hash
144    pub fn find_by_sha1<D>(&self, hash: D) -> Result<Vec<CertContext>>
145    where
146        D: AsRef<[u8]>,
147    {
148        let hash_blob = CRYPT_INTEGER_BLOB {
149            cbData: hash.as_ref().len() as u32,
150            pbData: hash.as_ref().as_ptr() as _,
151        };
152        unsafe { self.do_find(CERT_FIND_HASH, &hash_blob as *const _ as _) }
153    }
154
155    // On later OS releases, we added CERT_FIND_SHA256_HASH.
156    // However, rustls-cng could be installed on earlier OS release where this FIND_SHA256 isn't present.
157    // But the CERT_SHA256_HASH_PROP_ID is present.
158    // So will need to add a new internal find function that gets and compares the SHA256 property.
159    // Also, since SHA1 is being deprecated, Windows components should not use.
160    // Therefore, the need to find via SHA256 instead of SHA1.
161
162    /// Find list of certificates matching the SHA256 hash
163    pub fn find_by_sha256<D>(&self, hash: D) -> Result<Vec<CertContext>>
164    where
165        D: AsRef<[u8]>,
166    {
167        let hash_blob = CRYPT_INTEGER_BLOB {
168            cbData: hash.as_ref().len() as u32,
169            pbData: hash.as_ref().as_ptr() as _,
170        };
171        unsafe { self.do_find_by_sha256_property(&hash_blob as *const _ as _) }
172    }
173
174    /// Find list of certificates matching the key identifier
175    pub fn find_by_key_id<D>(&self, key_id: D) -> Result<Vec<CertContext>>
176    where
177        D: AsRef<[u8]>,
178    {
179        let cert_id = CERT_ID {
180            dwIdChoice: CERT_ID_KEY_IDENTIFIER,
181            Anonymous: CERT_ID_0 {
182                KeyId: CRYPT_INTEGER_BLOB {
183                    cbData: key_id.as_ref().len() as u32,
184                    pbData: key_id.as_ref().as_ptr() as _,
185                },
186            },
187        };
188        unsafe { self.do_find(CERT_FIND_CERT_ID, &cert_id as *const _ as _) }
189    }
190
191    /// Get all certificates
192    pub fn find_all(&self) -> Result<Vec<CertContext>> {
193        unsafe { self.do_find(CERT_FIND_ANY, ptr::null()) }
194    }
195
196    unsafe fn do_find(
197        &self,
198        flags: CERT_FIND_FLAGS,
199        find_param: *const c_void,
200    ) -> Result<Vec<CertContext>> {
201        let mut certs = Vec::new();
202
203        let mut cert: *mut CERT_CONTEXT = ptr::null_mut();
204
205        loop {
206            cert = CertFindCertificateInStore(self.0, MY_ENCODING_TYPE, 0, flags, find_param, cert);
207            if cert.is_null() {
208                break;
209            } else {
210                // increase refcount because it will be released by next call to CertFindCertificateInStore
211                let cert = CertDuplicateCertificateContext(cert);
212                certs.push(CertContext::new_owned(cert))
213            }
214        }
215        Ok(certs)
216    }
217
218    unsafe fn do_find_by_sha256_property(
219        &self,
220        find_param: *const c_void,
221    ) -> Result<Vec<CertContext>> {
222        let mut certs = Vec::new();
223        let mut cert: *mut CERT_CONTEXT = ptr::null_mut();
224        let hash_blob = &*(find_param as *const CRYPT_INTEGER_BLOB);
225        let sha256_hash = std::slice::from_raw_parts(hash_blob.pbData, hash_blob.cbData as usize);
226        loop {
227            cert = CertFindCertificateInStore(
228                self.0,
229                MY_ENCODING_TYPE,
230                0,
231                CERT_FIND_ANY,
232                find_param,
233                cert,
234            );
235            if cert.is_null() {
236                break;
237            } else {
238                let mut prop_data = [0u8; 32];
239                let mut prop_data_len = prop_data.len() as u32;
240
241                if CertGetCertificateContextProperty(
242                    cert,
243                    CERT_SHA256_HASH_PROP_ID,
244                    prop_data.as_mut_ptr() as *mut c_void,
245                    &mut prop_data_len,
246                ) != 0
247                    && prop_data[..prop_data_len as usize] == sha256_hash[..]
248                {
249                    let cert = CertDuplicateCertificateContext(cert);
250                    certs.push(CertContext::new_owned(cert))
251                }
252            }
253        }
254        Ok(certs)
255    }
256
257    fn find_by_str(&self, pattern: &str, flags: CERT_FIND_FLAGS) -> Result<Vec<CertContext>> {
258        let u16pattern = utf16z!(pattern);
259        unsafe { self.do_find(flags, u16pattern.as_ptr() as _) }
260    }
261
262    fn find_by_name(&self, field: &str, flags: CERT_FIND_FLAGS) -> Result<Vec<CertContext>> {
263        let mut name_size = 0;
264
265        unsafe {
266            let field_name = utf16z!(field);
267            if CertStrToNameW(
268                MY_ENCODING_TYPE,
269                field_name.as_ptr(),
270                CERT_X500_NAME_STR,
271                ptr::null(),
272                ptr::null_mut(),
273                &mut name_size,
274                ptr::null_mut(),
275            ) == 0
276            {
277                return Err(CngError::from_win32_error());
278            }
279
280            let mut x509name = vec![0u8; name_size as usize];
281            if CertStrToNameW(
282                MY_ENCODING_TYPE,
283                field_name.as_ptr(),
284                CERT_X500_NAME_STR,
285                ptr::null(),
286                x509name.as_mut_ptr(),
287                &mut name_size,
288                ptr::null_mut(),
289            ) == 0
290            {
291                return Err(CngError::from_win32_error());
292            }
293
294            let name_blob = CRYPT_INTEGER_BLOB {
295                cbData: x509name.len() as _,
296                pbData: x509name.as_mut_ptr(),
297            };
298
299            self.do_find(flags, &name_blob as *const _ as _)
300        }
301    }
302}
303
304impl Drop for CertStore {
305    fn drop(&mut self) {
306        unsafe { CertCloseStore(self.0, 0) };
307    }
308}