use std::ffi::c_void;
use tracing::{debug, instrument};
use windows::Win32::Security::Cryptography::CryptAcquireCertificatePrivateKey;
use windows::Win32::Security::Cryptography::{
BCRYPT_OAEP_PADDING_INFO, BCRYPT_PAD_OAEP, BCRYPT_PAD_PKCS1, BCRYPT_PKCS1_PADDING_INFO,
CERT_CLOSE_STORE_CHECK_FLAG, CERT_FIND_HASH, CERT_KEY_SPEC, CERT_OPEN_STORE_FLAGS,
CERT_QUERY_ENCODING_TYPE, CERT_STORE_PROV_SYSTEM_W, CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG,
CRYPT_INTEGER_BLOB, CertCloseStore, CertFindCertificateInStore, CertFreeCertificateContext,
CertOpenStore, HCRYPTPROV_OR_NCRYPT_KEY_HANDLE, NCRYPT_FLAGS, NCRYPT_HANDLE, NCRYPT_KEY_HANDLE,
NCRYPT_SILENT_FLAG, NCryptDecrypt, NCryptFreeObject, NCryptSignHash, NCryptVerifySignature,
X509_ASN_ENCODING,
};
use windows::core::{BOOL, PCWSTR};
use crate::encryption::{EncryptionError, KeyStoreProvider};
/// Identifier reported by `provider_name`; this is the key store provider name
/// SQL Server Always Encrypted uses for the Windows certificate store.
const PROVIDER_NAME: &str = "MSSQL_CERTIFICATE_STORE";

/// Column-master-key store backed by certificates in the Windows certificate store,
/// whose private keys are used through CNG (NCrypt).
///
/// The type carries no state. The private unit field only stops code outside this
/// module from building it with a struct literal; use `new()` or `Default`.
#[derive(Clone, Debug, Default)]
pub struct WindowsCertStoreProvider {
    _private: (),
}
impl WindowsCertStoreProvider {
    /// Creates a provider. It holds no state: each operation opens the
    /// certificate store and acquires the private key afresh.
    #[must_use]
    pub fn new() -> Self {
        Self { _private: () }
    }

    /// Parses a CMK path of the form `<StoreLocation>/<StoreName>/<Thumbprint>`.
    ///
    /// `StoreLocation` is matched case-insensitively against `CurrentUser`,
    /// `Current_User`, `LocalMachine` and `Local_Machine`. Every segment after the
    /// store name is concatenated and hex-decoded as the certificate thumbprint.
    ///
    /// # Errors
    /// [`EncryptionError::CmkError`] when the path has fewer than three segments,
    /// names an unknown store location, has an empty store name, or carries a
    /// thumbprint that is empty or not valid hex.
    fn parse_cmk_path(cmk_path: &str) -> Result<(StoreLocation, String, Vec<u8>), EncryptionError> {
        let parts: Vec<&str> = cmk_path.split('/').collect();
        if parts.len() < 3 {
            return Err(EncryptionError::CmkError(format!(
                "Invalid CMK path format: expected '<StoreLocation>/<StoreName>/<Thumbprint>', got '{cmk_path}'"
            )));
        }
        let store_location = match parts[0].to_uppercase().as_str() {
            "CURRENTUSER" | "CURRENT_USER" => StoreLocation::CurrentUser,
            "LOCALMACHINE" | "LOCAL_MACHINE" => StoreLocation::LocalMachine,
            _ => {
                return Err(EncryptionError::CmkError(format!(
                    "Unknown store location: '{}'. Expected 'CurrentUser' or 'LocalMachine'",
                    parts[0]
                )));
            }
        };
        let store_name = parts[1].to_string();
        if store_name.is_empty() {
            return Err(EncryptionError::CmkError(format!(
                "Invalid CMK path '{cmk_path}': store name is empty"
            )));
        }
        let thumbprint_hex = parts[2..].join("");
        let thumbprint = hex_to_bytes(&thumbprint_hex)
            .map_err(|e| EncryptionError::CmkError(format!("Invalid thumbprint hex: {e}")))?;
        // An empty thumbprint would be handed to CertFindCertificateInStore as a
        // zero-length hash blob; reject it up front with a clear message instead.
        if thumbprint.is_empty() {
            return Err(EncryptionError::CmkError(format!(
                "Invalid CMK path '{cmk_path}': certificate thumbprint is empty"
            )));
        }
        Ok((store_location, store_name, thumbprint))
    }

    /// Opens the system store `store_name` at `store_location`, finds the
    /// certificate whose hash equals `thumbprint` (`CERT_FIND_HASH`), and acquires
    /// its private key as an NCrypt (CNG) key handle.
    ///
    /// The store is opened read-only, because looking up a certificate and
    /// acquiring its key needs no write access, and write access to
    /// `LocalMachine` stores typically requires elevation. It is also opened only
    /// if it already exists, so a misspelled store name fails here instead of
    /// possibly creating a new store.
    ///
    /// The store handle and certificate context are released by guards on every
    /// return path. The returned [`CngKeyHandle`] frees the key on drop when
    /// `CryptAcquireCertificatePrivateKey` reports that the caller owns it.
    ///
    /// # Errors
    /// [`EncryptionError::CmkError`] if the store cannot be opened, no matching
    /// certificate exists, or its private key cannot be acquired as a CNG key.
    fn get_private_key(
        store_location: StoreLocation,
        store_name: &str,
        thumbprint: &[u8],
    ) -> Result<CngKeyHandle, EncryptionError> {
        // wincrypt.h values of CERT_STORE_OPEN_EXISTING_FLAG and CERT_STORE_READONLY_FLAG.
        const STORE_OPEN_EXISTING: u32 = 0x0000_4000;
        const STORE_READ_ONLY: u32 = 0x0000_8000;

        // CertOpenStore takes the system store name as a NUL-terminated UTF-16 string.
        let store_name_wide: Vec<u16> = store_name
            .encode_utf16()
            .chain(std::iter::once(0))
            .collect();
        // SAFETY: `store_name_wide` is NUL-terminated and outlives the call.
        let store = unsafe {
            CertOpenStore(
                CERT_STORE_PROV_SYSTEM_W,
                CERT_QUERY_ENCODING_TYPE(0),
                None,
                CERT_OPEN_STORE_FLAGS(
                    store_location.to_flags() | STORE_OPEN_EXISTING | STORE_READ_ONLY,
                ),
                Some(store_name_wide.as_ptr() as *const c_void),
            )
        }
        .map_err(|e| {
            EncryptionError::CmkError(format!(
                "Failed to open certificate store '{store_name}': {e}"
            ))
        })?;
        let store_guard = CertStoreGuard(store);

        // The blob only borrows `thumbprint`; the API does not write through pbData.
        let hash_blob = CRYPT_INTEGER_BLOB {
            cbData: thumbprint.len() as u32,
            pbData: thumbprint.as_ptr() as *mut u8,
        };
        // SAFETY: the store handle is open (owned by `store_guard`) and `hash_blob`
        // points at `thumbprint`, both of which outlive the call.
        let cert_context = unsafe {
            CertFindCertificateInStore(
                store_guard.0,
                X509_ASN_ENCODING,
                0,
                CERT_FIND_HASH,
                Some(&hash_blob as *const _ as *const c_void),
                None,
            )
        };
        if cert_context.is_null() {
            return Err(EncryptionError::CmkError(format!(
                "Certificate not found with thumbprint: {}",
                bytes_to_hex(thumbprint)
            )));
        }
        let cert_guard = CertContextGuard(cert_context);

        let mut key_handle = HCRYPTPROV_OR_NCRYPT_KEY_HANDLE::default();
        let mut key_spec = CERT_KEY_SPEC::default();
        let mut caller_free = BOOL::default();
        // SAFETY: `cert_guard.0` is a non-null context returned above; the out
        // pointers refer to live locals. CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG makes
        // the returned handle an NCrypt key handle.
        let result = unsafe {
            CryptAcquireCertificatePrivateKey(
                cert_guard.0,
                CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG,
                None,
                &raw mut key_handle,
                Some(&raw mut key_spec),
                Some(&raw mut caller_free),
            )
        };
        if result.is_err() {
            return Err(EncryptionError::CmkError(format!(
                "Failed to acquire private key for certificate: {:?}",
                result.err()
            )));
        }
        Ok(CngKeyHandle {
            handle: NCRYPT_KEY_HANDLE(key_handle.0),
            should_free: caller_free.as_bool(),
        })
    }
}
/// [`KeyStoreProvider`] backed by the Windows certificate store. Every operation
/// resolves `cmk_path` to a certificate and uses its private key through NCrypt.
///
/// NOTE(review): each method calls blocking Win32/CNG functions directly inside an
/// `async fn`; consider offloading them to a blocking thread if key access can
/// stall (e.g. hardware-backed keys) — confirm against the runtime in use.
#[async_trait::async_trait]
impl KeyStoreProvider for WindowsCertStoreProvider {
    /// Returns [`PROVIDER_NAME`].
    fn provider_name(&self) -> &str {
        PROVIDER_NAME
    }

    /// Decrypts an encrypted column encryption key (CEK) with the private key of
    /// the certificate identified by `cmk_path`.
    ///
    /// `encrypted_cek` is parsed with `parse_sql_server_encrypted_cek` to obtain the
    /// RSA ciphertext, and `algorithm` selects the padding via `get_padding_info`.
    /// `NCryptDecrypt` is called twice: first with no output buffer to learn the
    /// output size, then to decrypt into a buffer of that size.
    ///
    /// NOTE(review): only the ciphertext returned by the parser is used; no
    /// signature over `encrypted_cek` is verified in this method — confirm whether
    /// that check happens elsewhere.
    ///
    /// # Errors
    /// * [`EncryptionError::CmkError`] if the path is invalid or the key cannot be acquired.
    /// * [`EncryptionError::ConfigurationError`] for an unsupported `algorithm`.
    /// * [`EncryptionError::CekDecryptionFailed`] if the blob is malformed or either
    ///   `NCryptDecrypt` call fails.
    #[instrument(skip(self, encrypted_cek), fields(cmk_path = %cmk_path, algorithm = %algorithm))]
    async fn decrypt_cek(
        &self,
        cmk_path: &str,
        algorithm: &str,
        encrypted_cek: &[u8],
    ) -> Result<Vec<u8>, EncryptionError> {
        debug!("Decrypting CEK using Windows Certificate Store");
        let (store_location, store_name, thumbprint) = Self::parse_cmk_path(cmk_path)?;
        let key_handle = Self::get_private_key(store_location, &store_name, &thumbprint)?;
        let ciphertext = parse_sql_server_encrypted_cek(encrypted_cek)?;
        // `padding_info` must stay alive across both calls: only a raw pointer to
        // it is passed to NCryptDecrypt.
        let (padding_info, flags) = get_padding_info(algorithm)?;
        let mut result_size = 0u32;
        // Size query: no output buffer, NCrypt writes the required size.
        let decrypt_result = unsafe {
            NCryptDecrypt(
                key_handle.handle,
                Some(ciphertext),
                Some(padding_info.as_ptr()),
                None,
                &mut result_size,
                flags,
            )
        };
        if decrypt_result.is_err() {
            return Err(EncryptionError::CekDecryptionFailed(format!(
                "NCryptDecrypt (size query) failed: {:?}",
                decrypt_result.err()
            )));
        }
        let mut output = vec![0u8; result_size as usize];
        // Actual decryption into the sized buffer; `result_size` is overwritten
        // with the number of bytes written.
        let decrypt_result = unsafe {
            NCryptDecrypt(
                key_handle.handle,
                Some(ciphertext),
                Some(padding_info.as_ptr()),
                Some(&mut output),
                &mut result_size,
                flags,
            )
        };
        if decrypt_result.is_err() {
            return Err(EncryptionError::CekDecryptionFailed(format!(
                "NCryptDecrypt failed: {:?}",
                decrypt_result.err()
            )));
        }
        // The written length may be smaller than the size-query estimate.
        output.truncate(result_size as usize);
        debug!("Successfully decrypted CEK using Windows Certificate Store");
        Ok(output)
    }

    /// Signs with the certificate's private key using PKCS#1 v1.5 padding and
    /// `SHA256` as the hash algorithm identifier.
    ///
    /// `NCryptSignHash` is called twice: a size query, then the signing call.
    ///
    /// NOTE(review): `NCryptSignHash` signs a precomputed digest and `data` is
    /// passed straight through as that digest, so presumably callers must supply
    /// a 32-byte SHA-256 hash — confirm against the `KeyStoreProvider` contract.
    ///
    /// # Errors
    /// [`EncryptionError::CmkError`] if the path is invalid, the key cannot be
    /// acquired, or either `NCryptSignHash` call fails.
    #[instrument(skip(self, data), fields(cmk_path = %cmk_path))]
    async fn sign_data(&self, cmk_path: &str, data: &[u8]) -> Result<Vec<u8>, EncryptionError> {
        debug!("Signing data using Windows Certificate Store");
        let (store_location, store_name, thumbprint) = Self::parse_cmk_path(cmk_path)?;
        let key_handle = Self::get_private_key(store_location, &store_name, &thumbprint)?;
        // NUL-terminated UTF-16 algorithm id; must outlive both NCrypt calls because
        // `padding_info` only stores a pointer into it.
        let hash_algorithm: Vec<u16> = "SHA256\0".encode_utf16().collect();
        let padding_info = BCRYPT_PKCS1_PADDING_INFO {
            pszAlgId: PCWSTR(hash_algorithm.as_ptr()),
        };
        let mut sig_size = 0u32;
        // Size query: no output buffer, NCrypt writes the required signature size.
        let sign_result = unsafe {
            NCryptSignHash(
                key_handle.handle,
                Some(&padding_info as *const _ as *const c_void),
                data,
                None,
                &mut sig_size,
                NCRYPT_FLAGS(BCRYPT_PAD_PKCS1.0),
            )
        };
        if sign_result.is_err() {
            return Err(EncryptionError::CmkError(format!(
                "NCryptSignHash (size query) failed: {:?}",
                sign_result.err()
            )));
        }
        let mut signature = vec![0u8; sig_size as usize];
        let sign_result = unsafe {
            NCryptSignHash(
                key_handle.handle,
                Some(&padding_info as *const _ as *const c_void),
                data,
                Some(&mut signature),
                &mut sig_size,
                NCRYPT_FLAGS(BCRYPT_PAD_PKCS1.0),
            )
        };
        if sign_result.is_err() {
            return Err(EncryptionError::CmkError(format!(
                "NCryptSignHash failed: {:?}",
                sign_result.err()
            )));
        }
        // Keep only the bytes NCrypt reported as written.
        signature.truncate(sig_size as usize);
        debug!("Successfully signed data using Windows Certificate Store");
        Ok(signature)
    }

    /// Verifies a PKCS#1 v1.5 / `SHA256` signature with the certificate's key.
    ///
    /// Returns `Ok(true)` when `NCryptVerifySignature` succeeds and `Ok(false)` when
    /// it fails for any reason. Path-parsing and key-acquisition failures are
    /// propagated as `Err`.
    ///
    /// NOTE(review): every `NCryptVerifySignature` error (not only a bad signature)
    /// becomes `false`. The method also acquires the certificate's *private* key,
    /// although verification only needs the public key. As with `sign_data`,
    /// `data` is passed as the digest.
    #[instrument(skip(self, data, signature), fields(cmk_path = %cmk_path))]
    async fn verify_signature(
        &self,
        cmk_path: &str,
        data: &[u8],
        signature: &[u8],
    ) -> Result<bool, EncryptionError> {
        debug!("Verifying signature using Windows Certificate Store");
        let (store_location, store_name, thumbprint) = Self::parse_cmk_path(cmk_path)?;
        let key_handle = Self::get_private_key(store_location, &store_name, &thumbprint)?;
        // NUL-terminated UTF-16 algorithm id referenced by `padding_info`.
        let hash_algorithm: Vec<u16> = "SHA256\0".encode_utf16().collect();
        let padding_info = BCRYPT_PKCS1_PADDING_INFO {
            pszAlgId: PCWSTR(hash_algorithm.as_ptr()),
        };
        let verify_result = unsafe {
            NCryptVerifySignature(
                key_handle.handle,
                Some(&padding_info as *const _ as *const c_void),
                data,
                signature,
                NCRYPT_FLAGS(BCRYPT_PAD_PKCS1.0),
            )
        };
        let is_valid = verify_result.is_ok();
        debug!("Signature verification result: {}", is_valid);
        Ok(is_valid)
    }
}
/// Which system certificate store hive a CMK path refers to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StoreLocation {
    /// The current user's system stores.
    CurrentUser,
    /// The machine-wide system stores.
    LocalMachine,
}

impl StoreLocation {
    /// `CERT_SYSTEM_STORE_CURRENT_USER` (wincrypt.h).
    const CURRENT_USER_FLAG: u32 = 0x0001_0000;
    /// `CERT_SYSTEM_STORE_LOCAL_MACHINE` (wincrypt.h).
    const LOCAL_MACHINE_FLAG: u32 = 0x0002_0000;

    /// Returns the `CertOpenStore` location flag for this store location.
    fn to_flags(self) -> u32 {
        match self {
            Self::CurrentUser => Self::CURRENT_USER_FLAG,
            Self::LocalMachine => Self::LOCAL_MACHINE_FLAG,
        }
    }
}
/// Owns an open certificate store handle and closes it when dropped.
struct CertStoreGuard(windows::Win32::Security::Cryptography::HCERTSTORE);

impl Drop for CertStoreGuard {
    fn drop(&mut self) {
        let store = self.0;
        // SAFETY: the guard owns the store handle and drop runs once, so the
        // handle is closed exactly once. Errors are ignored because `drop`
        // cannot report them.
        let _ = unsafe { CertCloseStore(Some(store), CERT_CLOSE_STORE_CHECK_FLAG) };
    }
}
/// Owns a certificate context and frees it when dropped; a null pointer is ignored.
struct CertContextGuard(*const windows::Win32::Security::Cryptography::CERT_CONTEXT);

impl Drop for CertContextGuard {
    fn drop(&mut self) {
        let context = self.0;
        if context.is_null() {
            return;
        }
        // SAFETY: a non-null context held by the guard is owned by it and is
        // freed exactly once, here.
        let _ = unsafe { CertFreeCertificateContext(Some(context)) };
    }
}
/// An NCrypt key handle together with whether this code must release it.
struct CngKeyHandle {
    handle: NCRYPT_KEY_HANDLE,
    // Set from the caller-free flag returned by CryptAcquireCertificatePrivateKey.
    should_free: bool,
}

impl Drop for CngKeyHandle {
    fn drop(&mut self) {
        let owns_live_handle = self.should_free && !self.handle.is_invalid();
        if !owns_live_handle {
            return;
        }
        // SAFETY: the handle is valid and owned by us (`should_free`), and drop
        // runs once, so it is freed exactly once.
        let _ = unsafe { NCryptFreeObject(NCRYPT_HANDLE(self.handle.0)) };
    }
}
/// Padding parameters handed to `NCryptDecrypt` for the selected algorithm.
enum PaddingInfo {
    Oaep(BCRYPT_OAEP_PADDING_INFO),
    #[allow(dead_code)]
    Pkcs1(BCRYPT_PKCS1_PADDING_INFO),
}

impl PaddingInfo {
    /// Returns an untyped pointer to the contained padding struct, for the
    /// `pPaddingInfo` argument. It borrows `self`, so it is only valid while
    /// `self` is alive and not moved.
    fn as_ptr(&self) -> *const c_void {
        let raw: *const c_void = match self {
            Self::Oaep(oaep) => std::ptr::from_ref(oaep).cast::<c_void>(),
            Self::Pkcs1(pkcs1) => std::ptr::from_ref(pkcs1).cast::<c_void>(),
        };
        raw
    }
}
/// Maps a key-encryption algorithm name to CNG padding parameters and flags for
/// `NCryptDecrypt`.
///
/// * `RSA_OAEP` / `RSA-OAEP`: OAEP with SHA-1. SQL Server Always Encrypted wraps
///   CEKs this way, and JWA defines `RSA-OAEP` the same way.
/// * `RSA_OAEP_256` / `RSA-OAEP-256`: OAEP with SHA-256.
/// * `RSA1_5` / `RSA-1_5` / `RSA_PKCS1` / `RSA-PKCS1`: PKCS#1 v1.5 encryption padding.
///
/// Names are matched case-insensitively. Every flag set includes
/// `NCRYPT_SILENT_FLAG`.
///
/// # Errors
/// [`EncryptionError::ConfigurationError`] for any other algorithm name.
fn get_padding_info(algorithm: &str) -> Result<(PaddingInfo, NCRYPT_FLAGS), EncryptionError> {
    // NUL-terminated UTF-16 CNG hash identifiers. They are `static`, so the
    // `pszAlgId` pointers stored below stay valid for the whole program without
    // leaking a fresh allocation on every call.
    static SHA1_ALG_W: [u16; 5] = [b'S' as u16, b'H' as u16, b'A' as u16, b'1' as u16, 0];
    static SHA256_ALG_W: [u16; 7] = [
        b'S' as u16,
        b'H' as u16,
        b'A' as u16,
        b'2' as u16,
        b'5' as u16,
        b'6' as u16,
        0,
    ];

    // OAEP without a label, using `hash_alg` for both the digest and MGF1.
    let oaep = |hash_alg: &'static [u16]| {
        (
            PaddingInfo::Oaep(BCRYPT_OAEP_PADDING_INFO {
                pszAlgId: PCWSTR(hash_alg.as_ptr()),
                pbLabel: std::ptr::null_mut(),
                cbLabel: 0,
            }),
            NCRYPT_FLAGS(BCRYPT_PAD_OAEP.0 | NCRYPT_SILENT_FLAG.0),
        )
    };

    match algorithm.to_uppercase().as_str() {
        "RSA_OAEP" | "RSA-OAEP" => Ok(oaep(&SHA1_ALG_W)),
        "RSA_OAEP_256" | "RSA-OAEP-256" => Ok(oaep(&SHA256_ALG_W)),
        "RSA1_5" | "RSA-1_5" | "RSA_PKCS1" | "RSA-PKCS1" => {
            let info = BCRYPT_PKCS1_PADDING_INFO {
                pszAlgId: PCWSTR(SHA256_ALG_W.as_ptr()),
            };
            Ok((
                PaddingInfo::Pkcs1(info),
                NCRYPT_FLAGS(BCRYPT_PAD_PKCS1.0 | NCRYPT_SILENT_FLAG.0),
            ))
        }
        _ => Err(EncryptionError::ConfigurationError(format!(
            "Unsupported key encryption algorithm: {algorithm}. Expected RSA_OAEP, RSA_OAEP_256, or RSA1_5"
        ))),
    }
}
/// Extracts the RSA ciphertext from a SQL Server Always Encrypted encrypted CEK.
///
/// Layout (both lengths are little-endian `u16`):
///
/// ```text
/// version (1 byte, 0x01) | key_path_len (2) | ciphertext_len (2)
///   | key_path (key_path_len bytes, UTF-16LE) | ciphertext (ciphertext_len bytes)
///   | signature (remaining bytes)
/// ```
///
/// Both length fields come before the key path, so the ciphertext starts at
/// `5 + key_path_len`. The trailing signature is not validated here.
///
/// # Errors
/// [`EncryptionError::CekDecryptionFailed`] if the blob is shorter than the header,
/// has an unknown version byte, declares an empty ciphertext, or is truncated.
fn parse_sql_server_encrypted_cek(data: &[u8]) -> Result<&[u8], EncryptionError> {
    // Fixed header: version byte plus the key-path and ciphertext length fields.
    const HEADER_LEN: usize = 5;
    const VERSION: u8 = 0x01;

    if data.len() < HEADER_LEN {
        return Err(EncryptionError::CekDecryptionFailed(
            "Encrypted CEK too short".into(),
        ));
    }
    if data[0] != VERSION {
        return Err(EncryptionError::CekDecryptionFailed(format!(
            "Invalid CEK version: expected 0x01, got {:#04x}",
            data[0]
        )));
    }
    let key_path_len = usize::from(u16::from_le_bytes([data[1], data[2]]));
    let ciphertext_len = usize::from(u16::from_le_bytes([data[3], data[4]]));
    if ciphertext_len == 0 {
        return Err(EncryptionError::CekDecryptionFailed(
            "Encrypted CEK has an empty ciphertext".into(),
        ));
    }
    let ciphertext_offset = HEADER_LEN + key_path_len;
    if data.len() < ciphertext_offset {
        return Err(EncryptionError::CekDecryptionFailed(
            "Encrypted CEK truncated: missing key path".into(),
        ));
    }
    let available = data.len() - ciphertext_offset;
    if available < ciphertext_len {
        return Err(EncryptionError::CekDecryptionFailed(format!(
            "Encrypted CEK truncated: expected {ciphertext_len} bytes of ciphertext, got {available}"
        )));
    }
    Ok(&data[ciphertext_offset..ciphertext_offset + ciphertext_len])
}

/// Decodes a hex string (surrounding whitespace ignored, either letter case) into bytes.
///
/// # Errors
/// Returns a static message for odd-length input or a non-hex digit.
fn hex_to_bytes(hex: &str) -> Result<Vec<u8>, &'static str> {
    let hex = hex.trim();
    if hex.len() % 2 != 0 {
        return Err("Hex string has odd length");
    }
    let nibble = |digit: u8| char::from(digit).to_digit(16).ok_or("Invalid hex digit");
    // Length is even, so every chunk is exactly one (high, low) digit pair.
    hex.as_bytes()
        .chunks_exact(2)
        .map(|pair| Ok((nibble(pair[0])? * 16 + nibble(pair[1])?) as u8))
        .collect()
}

/// Formats bytes as uppercase hex without separators (e.g. `[0x0A, 0xFF]` -> `"0AFF"`).
fn bytes_to_hex(bytes: &[u8]) -> String {
    use std::fmt::Write as _;
    let mut hex = String::with_capacity(bytes.len() * 2);
    for byte in bytes {
        // Writing into a String cannot fail.
        let _ = write!(hex, "{byte:02X}");
    }
    hex
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an encrypted CEK blob in the SQL Server layout read by
    /// `parse_sql_server_encrypted_cek`.
    fn build_encrypted_cek(key_path: &str, ciphertext: &[u8], signature: &[u8]) -> Vec<u8> {
        let key_path_utf16: Vec<u8> = key_path
            .encode_utf16()
            .flat_map(u16::to_le_bytes)
            .collect();
        let mut data = vec![0x01];
        data.extend_from_slice(&(key_path_utf16.len() as u16).to_le_bytes());
        data.extend_from_slice(&(ciphertext.len() as u16).to_le_bytes());
        data.extend_from_slice(&key_path_utf16);
        data.extend_from_slice(ciphertext);
        data.extend_from_slice(signature);
        data
    }

    #[test]
    fn test_parse_cmk_path() {
        let (location, name, thumb) =
            WindowsCertStoreProvider::parse_cmk_path("CurrentUser/My/AABBCCDD").unwrap();
        assert_eq!(location, StoreLocation::CurrentUser);
        assert_eq!(name, "My");
        assert_eq!(thumb, vec![0xAA, 0xBB, 0xCC, 0xDD]);
        let (location, name, _) =
            WindowsCertStoreProvider::parse_cmk_path("localmachine/My/1234").unwrap();
        assert_eq!(location, StoreLocation::LocalMachine);
        assert_eq!(name, "My");
        let (location, _, _) =
            WindowsCertStoreProvider::parse_cmk_path("Current_User/My/1234").unwrap();
        assert_eq!(location, StoreLocation::CurrentUser);
    }

    #[test]
    fn test_parse_cmk_path_invalid() {
        assert!(WindowsCertStoreProvider::parse_cmk_path("CurrentUser/My").is_err());
        assert!(WindowsCertStoreProvider::parse_cmk_path("Invalid/My/1234").is_err());
        assert!(WindowsCertStoreProvider::parse_cmk_path("CurrentUser/My/GGGG").is_err());
    }

    #[test]
    fn test_hex_conversion() {
        assert_eq!(
            hex_to_bytes("AABBCCDD").unwrap(),
            vec![0xAA, 0xBB, 0xCC, 0xDD]
        );
        assert_eq!(
            hex_to_bytes("aabbccdd").unwrap(),
            vec![0xAA, 0xBB, 0xCC, 0xDD]
        );
        assert_eq!(hex_to_bytes("").unwrap(), Vec::<u8>::new());
        assert!(hex_to_bytes("ABC").is_err());
        assert!(hex_to_bytes("GGGG").is_err());
    }

    #[test]
    fn test_bytes_to_hex() {
        assert_eq!(bytes_to_hex(&[0xAA, 0xBB, 0xCC, 0xDD]), "AABBCCDD");
        assert_eq!(bytes_to_hex(&[0x01, 0x02, 0x0F]), "01020F");
        assert_eq!(bytes_to_hex(&[]), "");
    }

    #[test]
    fn test_parse_sql_server_encrypted_cek() {
        let ciphertext = [0xAB, 0xCD, 0xEF];
        // With a trailing signature, as SQL Server produces it.
        let data = build_encrypted_cek("CurrentUser/My/AABBCCDD", &ciphertext, &[0x11, 0x22]);
        assert_eq!(parse_sql_server_encrypted_cek(&data).unwrap(), &ciphertext[..]);
        // Without a signature the ciphertext is still located correctly.
        let data = build_encrypted_cek("test", &ciphertext, &[]);
        assert_eq!(parse_sql_server_encrypted_cek(&data).unwrap(), &ciphertext[..]);
    }

    #[test]
    fn test_parse_sql_server_encrypted_cek_invalid() {
        assert!(parse_sql_server_encrypted_cek(&[0x01, 0x00]).is_err());
        assert!(parse_sql_server_encrypted_cek(&[0x02, 0x00, 0x00, 0x00, 0x00]).is_err());
        // Empty ciphertext.
        assert!(parse_sql_server_encrypted_cek(&build_encrypted_cek("test", &[], &[])).is_err());
        // Key path declared as 4 bytes but only 1 present.
        assert!(parse_sql_server_encrypted_cek(&[0x01, 0x04, 0x00, 0x01, 0x00, 0x74]).is_err());
        // Ciphertext one byte short.
        let mut data = build_encrypted_cek("test", &[0xAB, 0xCD, 0xEF], &[]);
        data.pop();
        assert!(parse_sql_server_encrypted_cek(&data).is_err());
    }

    #[test]
    fn test_store_location_flags() {
        assert_eq!(StoreLocation::CurrentUser.to_flags(), 0x00010000);
        assert_eq!(StoreLocation::LocalMachine.to_flags(), 0x00020000);
    }
}