1use std::{os::raw::c_void, ptr};
4
5use bitflags::bitflags;
6use windows_sys::Win32::Security::Cryptography::*;
7
8use crate::{cert::CertContext, error::CngError, Result};
9
10const MY_ENCODING_TYPE: CERT_QUERY_ENCODING_TYPE = PKCS_7_ASN_ENCODING | X509_ASN_ENCODING;
11
// Encodes a string-like expression (anything with `encode_utf16`) as UTF-16 and
// appends a terminating NUL, yielding a `Vec<u16>` usable as a Win32 wide-string buffer.
// An interior '\0' in the input is copied through unchanged, so Win32 APIs will treat
// the string as ending at that point.
macro_rules! utf16z {
    ($str: expr) => {
        $str.encode_utf16()
            .chain(::std::iter::once(0u16))
            .collect::<::std::vec::Vec<u16>>()
    };
}
17
/// Location of a Windows system certificate store.
///
/// Passed to `CertStore::open`, which converts it into the system-store location
/// bits given to `CertOpenStore`.
///
/// `Ord` and `Hash` are derived alongside `Eq`/`PartialOrd` so the type can be
/// sorted and used as a key in hashed collections; ordering follows declaration order.
#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash)]
pub enum CertStoreType {
    /// The local machine system store location.
    LocalMachine,
    /// The current user system store location.
    CurrentUser,
    /// The current service system store location.
    CurrentService,
}
25
26bitflags! {
27 #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
29 pub struct Pkcs12Flags: u32 {
30 const INCLUDE_EXTENDED_PROPERTIES = 0x0010;
31 const PREFER_CNG_KSP = 0x0000_0100;
32 const ALWAYS_CNG_KSP = 0x0000_0200;
33 const ALLOW_OVERWRITE_KEY = 0x0000_4000;
34 const NO_PERSIST_KEY =0x0000_8000;
35 }
36}
37
38impl Default for Pkcs12Flags {
39 fn default() -> Self {
40 Pkcs12Flags::INCLUDE_EXTENDED_PROPERTIES | Pkcs12Flags::PREFER_CNG_KSP
41 }
42}
43
44impl CertStoreType {
45 fn as_flags(&self) -> u32 {
46 match self {
47 CertStoreType::LocalMachine => {
48 CERT_SYSTEM_STORE_LOCAL_MACHINE_ID << CERT_SYSTEM_STORE_LOCATION_SHIFT
49 }
50 CertStoreType::CurrentUser => {
51 CERT_SYSTEM_STORE_CURRENT_USER_ID << CERT_SYSTEM_STORE_LOCATION_SHIFT
52 }
53 CertStoreType::CurrentService => {
54 CERT_SYSTEM_STORE_CURRENT_SERVICE_ID << CERT_SYSTEM_STORE_LOCATION_SHIFT
55 }
56 }
57 }
58}
59
/// Owned wrapper around a Windows certificate store handle (`HCERTSTORE`).
///
/// The handle is released with `CertCloseStore` when the value is dropped.
#[derive(Debug)]
pub struct CertStore(HCERTSTORE);

// SAFETY (review note): these impls assert that the raw `HCERTSTORE` may be moved to and
// shared between threads. Nothing in this file synchronizes access to the handle, so
// soundness relies on the Win32 certificate-store APIs being safe to call concurrently
// on one handle — TODO confirm against the CryptoAPI threading documentation.
unsafe impl Send for CertStore {}
unsafe impl Sync for CertStore {}
66
67impl CertStore {
68 pub fn inner(&self) -> HCERTSTORE {
70 self.0
71 }
72
73 pub fn open(store_type: CertStoreType, store_name: &str) -> Result<CertStore> {
75 unsafe {
76 let store_name = utf16z!(store_name);
77 let handle = CertOpenStore(
78 CERT_STORE_PROV_SYSTEM_W,
79 CERT_QUERY_ENCODING_TYPE::default(),
80 HCRYPTPROV_LEGACY::default(),
81 store_type.as_flags() | CERT_STORE_OPEN_EXISTING_FLAG,
82 store_name.as_ptr() as _,
83 );
84 if handle.is_null() {
85 Err(CngError::from_win32_error())
86 } else {
87 Ok(CertStore(handle))
88 }
89 }
90 }
91
92 pub fn from_pkcs12(data: &[u8], password: &str, flags: Pkcs12Flags) -> Result<CertStore> {
94 unsafe {
95 let blob = CRYPT_INTEGER_BLOB {
96 cbData: data.len() as u32,
97 pbData: data.as_ptr() as _,
98 };
99
100 let password = utf16z!(password);
101 let store =
102 PFXImportCertStore(&blob, password.as_ptr(), CRYPT_EXPORTABLE | flags.bits());
103 if store.is_null() {
104 Err(CngError::from_win32_error())
105 } else {
106 Ok(CertStore(store))
107 }
108 }
109 }
110
111 pub fn find_by_subject_str<S>(&self, subject: S) -> Result<Vec<CertContext>>
113 where
114 S: AsRef<str>,
115 {
116 self.find_by_str(subject.as_ref(), CERT_FIND_SUBJECT_STR)
117 }
118
119 pub fn find_by_subject_name<S>(&self, subject: S) -> Result<Vec<CertContext>>
121 where
122 S: AsRef<str>,
123 {
124 self.find_by_name(subject.as_ref(), CERT_FIND_SUBJECT_NAME)
125 }
126
127 pub fn find_by_issuer_str<S>(&self, subject: S) -> Result<Vec<CertContext>>
129 where
130 S: AsRef<str>,
131 {
132 self.find_by_str(subject.as_ref(), CERT_FIND_ISSUER_STR)
133 }
134
135 pub fn find_by_issuer_name<S>(&self, subject: S) -> Result<Vec<CertContext>>
137 where
138 S: AsRef<str>,
139 {
140 self.find_by_name(subject.as_ref(), CERT_FIND_ISSUER_NAME)
141 }
142
143 pub fn find_by_sha1<D>(&self, hash: D) -> Result<Vec<CertContext>>
145 where
146 D: AsRef<[u8]>,
147 {
148 let hash_blob = CRYPT_INTEGER_BLOB {
149 cbData: hash.as_ref().len() as u32,
150 pbData: hash.as_ref().as_ptr() as _,
151 };
152 unsafe { self.do_find(CERT_FIND_HASH, &hash_blob as *const _ as _) }
153 }
154
155 pub fn find_by_sha256<D>(&self, hash: D) -> Result<Vec<CertContext>>
164 where
165 D: AsRef<[u8]>,
166 {
167 let hash_blob = CRYPT_INTEGER_BLOB {
168 cbData: hash.as_ref().len() as u32,
169 pbData: hash.as_ref().as_ptr() as _,
170 };
171 unsafe { self.do_find_by_sha256_property(&hash_blob as *const _ as _) }
172 }
173
174 pub fn find_by_key_id<D>(&self, key_id: D) -> Result<Vec<CertContext>>
176 where
177 D: AsRef<[u8]>,
178 {
179 let cert_id = CERT_ID {
180 dwIdChoice: CERT_ID_KEY_IDENTIFIER,
181 Anonymous: CERT_ID_0 {
182 KeyId: CRYPT_INTEGER_BLOB {
183 cbData: key_id.as_ref().len() as u32,
184 pbData: key_id.as_ref().as_ptr() as _,
185 },
186 },
187 };
188 unsafe { self.do_find(CERT_FIND_CERT_ID, &cert_id as *const _ as _) }
189 }
190
191 pub fn find_all(&self) -> Result<Vec<CertContext>> {
193 unsafe { self.do_find(CERT_FIND_ANY, ptr::null()) }
194 }
195
196 unsafe fn do_find(
197 &self,
198 flags: CERT_FIND_FLAGS,
199 find_param: *const c_void,
200 ) -> Result<Vec<CertContext>> {
201 let mut certs = Vec::new();
202
203 let mut cert: *mut CERT_CONTEXT = ptr::null_mut();
204
205 loop {
206 cert = CertFindCertificateInStore(self.0, MY_ENCODING_TYPE, 0, flags, find_param, cert);
207 if cert.is_null() {
208 break;
209 } else {
210 let cert = CertDuplicateCertificateContext(cert);
212 certs.push(CertContext::new_owned(cert))
213 }
214 }
215 Ok(certs)
216 }
217
218 unsafe fn do_find_by_sha256_property(
219 &self,
220 find_param: *const c_void,
221 ) -> Result<Vec<CertContext>> {
222 let mut certs = Vec::new();
223 let mut cert: *mut CERT_CONTEXT = ptr::null_mut();
224 let hash_blob = &*(find_param as *const CRYPT_INTEGER_BLOB);
225 let sha256_hash = std::slice::from_raw_parts(hash_blob.pbData, hash_blob.cbData as usize);
226 loop {
227 cert = CertFindCertificateInStore(
228 self.0,
229 MY_ENCODING_TYPE,
230 0,
231 CERT_FIND_ANY,
232 find_param,
233 cert,
234 );
235 if cert.is_null() {
236 break;
237 } else {
238 let mut prop_data = [0u8; 32];
239 let mut prop_data_len = prop_data.len() as u32;
240
241 if CertGetCertificateContextProperty(
242 cert,
243 CERT_SHA256_HASH_PROP_ID,
244 prop_data.as_mut_ptr() as *mut c_void,
245 &mut prop_data_len,
246 ) != 0
247 && prop_data[..prop_data_len as usize] == sha256_hash[..]
248 {
249 let cert = CertDuplicateCertificateContext(cert);
250 certs.push(CertContext::new_owned(cert))
251 }
252 }
253 }
254 Ok(certs)
255 }
256
257 fn find_by_str(&self, pattern: &str, flags: CERT_FIND_FLAGS) -> Result<Vec<CertContext>> {
258 let u16pattern = utf16z!(pattern);
259 unsafe { self.do_find(flags, u16pattern.as_ptr() as _) }
260 }
261
262 fn find_by_name(&self, field: &str, flags: CERT_FIND_FLAGS) -> Result<Vec<CertContext>> {
263 let mut name_size = 0;
264
265 unsafe {
266 let field_name = utf16z!(field);
267 if CertStrToNameW(
268 MY_ENCODING_TYPE,
269 field_name.as_ptr(),
270 CERT_X500_NAME_STR,
271 ptr::null(),
272 ptr::null_mut(),
273 &mut name_size,
274 ptr::null_mut(),
275 ) == 0
276 {
277 return Err(CngError::from_win32_error());
278 }
279
280 let mut x509name = vec![0u8; name_size as usize];
281 if CertStrToNameW(
282 MY_ENCODING_TYPE,
283 field_name.as_ptr(),
284 CERT_X500_NAME_STR,
285 ptr::null(),
286 x509name.as_mut_ptr(),
287 &mut name_size,
288 ptr::null_mut(),
289 ) == 0
290 {
291 return Err(CngError::from_win32_error());
292 }
293
294 let name_blob = CRYPT_INTEGER_BLOB {
295 cbData: x509name.len() as _,
296 pbData: x509name.as_mut_ptr(),
297 };
298
299 self.do_find(flags, &name_blob as *const _ as _)
300 }
301 }
302}
303
304impl Drop for CertStore {
305 fn drop(&mut self) {
306 unsafe { CertCloseStore(self.0, 0) };
307 }
308}