use std::{
borrow::Borrow,
ffi::CStr,
io::Write,
panic::{catch_unwind, AssertUnwindSafe, UnwindSafe},
};
use crate::{
convert::{binary_bytes, binary_bytes_mut},
error::{self, HasStatus, Result, Status},
CryptBuilder,
};
use mongocrypt_sys as sys;
impl CryptBuilder {
/// Installs `handler` as the receiver of libmongocrypt log messages.
///
/// The closure is invoked with the decoded [`LogLevel`] and the message text
/// (bytes that are not valid UTF-8 are replaced lossily). The call is routed
/// through `run_hook`, and whatever `run_hook` returns is discarded, so a
/// panicking handler is not propagated to the C caller.
///
/// # Errors
///
/// Returns the builder's current status as an error when
/// `mongocrypt_setopt_log_handler` reports failure.
pub fn log_handler<F>(mut self, handler: F) -> Result<Self>
where
    F: Fn(LogLevel, &str) + 'static + UnwindSafe,
{
    type LogCb = dyn Fn(LogLevel, &str) + UnwindSafe;

    // C-ABI trampoline: decodes the native arguments and forwards them to the
    // boxed Rust closure found behind `ctx`.
    extern "C" fn log_shim(
        c_level: sys::mongocrypt_log_level_t,
        c_message: *const ::std::os::raw::c_char,
        _message_len: u32,
        ctx: *mut ::std::os::raw::c_void,
    ) {
        let severity = LogLevel::from_native(c_level);
        // SAFETY: relies on libmongocrypt passing a valid, NUL-terminated
        // string; `_message_len` is not consulted.
        let text = unsafe { CStr::from_ptr(c_message) }.to_string_lossy();
        // SAFETY: relies on `ctx` being the `*const Box<LogCb>` registered by
        // `log_handler` below and on that allocation still being alive.
        let callback = unsafe { &*ctx.cast::<Box<LogCb>>() };
        // A log callback has no channel to report failure, so the outcome
        // (including a caught panic) is dropped.
        let _ = run_hook(AssertUnwindSafe(|| {
            callback(severity, &text);
            Ok(())
        }));
    }

    // `Box<dyn Fn>` is a fat pointer; the outer box provides a thin pointer
    // that can round-trip through `*mut c_void`.
    let boxed_handler: Box<Box<LogCb>> = Box::new(Box::new(handler));
    let ctx_ptr = &*boxed_handler as *const Box<LogCb> as *mut std::ffi::c_void;

    let accepted = unsafe {
        sys::mongocrypt_setopt_log_handler(*self.inner.borrow(), Some(log_shim), ctx_ptr)
    };
    if !accepted {
        return Err(self.status().as_error());
    }

    // Hand ownership of the allocation to the builder so the registered
    // pointer keeps pointing at live memory.
    self.cleanup.push(boxed_handler);
    Ok(self)
}
/// Registers the set of Rust closures libmongocrypt should call for its
/// cryptographic primitives.
///
/// * `aes_256_cbc_encrypt` / `aes_256_cbc_decrypt` - called with
///   `(key, iv, input, output)`.
/// * `random` - called with `(output, count)`.
/// * `hmac_sha_512` / `hmac_sha_256` - called with `(key, input, output)`.
/// * `sha_256` - called with `(input, output)`.
///
/// Every closure writes its result into the supplied `Write` sink and
/// returns `Err` to signal failure.
///
/// # Errors
///
/// Returns the builder's current status as an error when
/// `mongocrypt_setopt_crypto_hooks` reports failure.
pub fn crypto_hooks(
    mut self,
    aes_256_cbc_encrypt: impl Fn(&[u8], &[u8], &[u8], &mut dyn Write) -> Result<()>
        + UnwindSafe
        + 'static,
    aes_256_cbc_decrypt: impl Fn(&[u8], &[u8], &[u8], &mut dyn Write) -> Result<()>
        + UnwindSafe
        + 'static,
    random: impl Fn(&mut dyn Write, u32) -> Result<()> + UnwindSafe + 'static,
    hmac_sha_512: impl Fn(&[u8], &[u8], &mut dyn Write) -> Result<()> + UnwindSafe + 'static,
    hmac_sha_256: impl Fn(&[u8], &[u8], &mut dyn Write) -> Result<()> + UnwindSafe + 'static,
    sha_256: impl Fn(&[u8], &mut dyn Write) -> Result<()> + UnwindSafe + 'static,
) -> Result<Self> {
    // All closures live in one heap allocation; its address is the `ctx`
    // value every module-level shim receives.
    let hook_table = Box::new(CryptoHooks {
        aes_256_cbc_encrypt: Box::new(aes_256_cbc_encrypt),
        aes_256_cbc_decrypt: Box::new(aes_256_cbc_decrypt),
        random: Box::new(random),
        hmac_sha_512: Box::new(hmac_sha_512),
        hmac_sha_256: Box::new(hmac_sha_256),
        sha_256: Box::new(sha_256),
    });
    let ctx_ptr = &*hook_table as *const CryptoHooks as *mut std::ffi::c_void;

    let accepted = unsafe {
        sys::mongocrypt_setopt_crypto_hooks(
            *self.inner.borrow(),
            Some(aes_256_cbc_encrypt_shim),
            Some(aes_256_cbc_decrypt_shim),
            Some(random_shim),
            Some(hmac_sha_512_shim),
            Some(hmac_sha_256_shim),
            Some(sha_256_shim),
            ctx_ptr,
        )
    };
    if !accepted {
        return Err(self.status().as_error());
    }

    // Ownership moves to the builder so the registered pointer stays valid.
    self.cleanup.push(hook_table);
    Ok(self)
}
/// Registers the AES-256-CTR encrypt and decrypt hooks.
///
/// Both closures are called with `(key, iv, input, output)` and write their
/// result into the `output` sink.
///
/// # Errors
///
/// Returns the builder's current status as an error when
/// `mongocrypt_setopt_aes_256_ctr` reports failure.
pub fn aes_256_ctr(
    mut self,
    aes_256_ctr_encrypt: impl Fn(&[u8], &[u8], &[u8], &mut dyn Write) -> Result<()>
        + UnwindSafe
        + 'static,
    aes_256_ctr_decrypt: impl Fn(&[u8], &[u8], &[u8], &mut dyn Write) -> Result<()>
        + UnwindSafe
        + 'static,
) -> Result<Self> {
    // Pair of closures addressed by the `ctx` pointer handed to both shims.
    struct CtrHooks {
        aes_256_ctr_encrypt: CryptoFn,
        aes_256_ctr_decrypt: CryptoFn,
    }

    // Trampoline for the encrypt direction.
    extern "C" fn aes_256_ctr_encrypt_shim(
        ctx: *mut ::std::os::raw::c_void,
        key: *mut sys::mongocrypt_binary_t,
        iv: *mut sys::mongocrypt_binary_t,
        in_: *mut sys::mongocrypt_binary_t,
        out: *mut sys::mongocrypt_binary_t,
        bytes_written: *mut u32,
        status: *mut sys::mongocrypt_status_t,
    ) -> bool {
        // SAFETY: relies on `ctx` being the `CtrHooks` pointer registered by
        // `aes_256_ctr` and on that allocation still being alive.
        let pair = unsafe { &*ctx.cast::<CtrHooks>() };
        crypto_fn_shim(
            &pair.aes_256_ctr_encrypt,
            key,
            iv,
            in_,
            out,
            bytes_written,
            status,
        )
    }

    // Trampoline for the decrypt direction.
    extern "C" fn aes_256_ctr_decrypt_shim(
        ctx: *mut ::std::os::raw::c_void,
        key: *mut sys::mongocrypt_binary_t,
        iv: *mut sys::mongocrypt_binary_t,
        in_: *mut sys::mongocrypt_binary_t,
        out: *mut sys::mongocrypt_binary_t,
        bytes_written: *mut u32,
        status: *mut sys::mongocrypt_status_t,
    ) -> bool {
        // SAFETY: same contract as the encrypt trampoline.
        let pair = unsafe { &*ctx.cast::<CtrHooks>() };
        crypto_fn_shim(
            &pair.aes_256_ctr_decrypt,
            key,
            iv,
            in_,
            out,
            bytes_written,
            status,
        )
    }

    let boxed_pair = Box::new(CtrHooks {
        aes_256_ctr_encrypt: Box::new(aes_256_ctr_encrypt),
        aes_256_ctr_decrypt: Box::new(aes_256_ctr_decrypt),
    });
    let ctx_ptr = &*boxed_pair as *const CtrHooks as *mut std::ffi::c_void;

    let accepted = unsafe {
        sys::mongocrypt_setopt_aes_256_ctr(
            *self.inner.borrow(),
            Some(aes_256_ctr_encrypt_shim),
            Some(aes_256_ctr_decrypt_shim),
            ctx_ptr,
        )
    };
    if !accepted {
        return Err(self.status().as_error());
    }

    // Ownership moves to the builder so the registered pointer stays valid.
    self.cleanup.push(boxed_pair);
    Ok(self)
}
/// Registers the AES-256-ECB encrypt hook.
///
/// The closure is called with `(key, iv, input, output)` and writes its
/// result into the `output` sink.
///
/// # Errors
///
/// Returns the builder's current status as an error when
/// `mongocrypt_setopt_aes_256_ecb` reports failure.
pub fn aes_256_ecb(
    mut self,
    aes_256_ecb_encrypt: impl Fn(&[u8], &[u8], &[u8], &mut dyn Write) -> Result<()>
        + UnwindSafe
        + 'static,
) -> Result<Self> {
    // Trampoline that recovers the boxed closure from `ctx` and delegates to
    // the shared `crypto_fn_shim`.
    extern "C" fn shim(
        ctx: *mut ::std::os::raw::c_void,
        key: *mut sys::mongocrypt_binary_t,
        iv: *mut sys::mongocrypt_binary_t,
        in_: *mut sys::mongocrypt_binary_t,
        out: *mut sys::mongocrypt_binary_t,
        bytes_written: *mut u32,
        status: *mut sys::mongocrypt_status_t,
    ) -> bool {
        // SAFETY: relies on `ctx` being the `*const CryptoFn` registered by
        // `aes_256_ecb` and on that allocation still being alive.
        let encrypt = unsafe { &*ctx.cast::<CryptoFn>() };
        crypto_fn_shim(encrypt, key, iv, in_, out, bytes_written, status)
    }

    // `CryptoFn` is itself a (fat) `Box<dyn Fn>`; the outer box yields a thin
    // pointer suitable for `*mut c_void`.
    let boxed_hook: Box<CryptoFn> = Box::new(Box::new(aes_256_ecb_encrypt));
    let ctx_ptr = &*boxed_hook as *const CryptoFn as *mut std::ffi::c_void;

    let accepted =
        unsafe { sys::mongocrypt_setopt_aes_256_ecb(*self.inner.borrow(), Some(shim), ctx_ptr) };
    if !accepted {
        return Err(self.status().as_error());
    }

    // Ownership moves to the builder so the registered pointer stays valid.
    self.cleanup.push(boxed_hook);
    Ok(self)
}
/// Registers the signing hook passed to
/// `mongocrypt_setopt_crypto_hook_sign_rsaes_pkcs1_v1_5`.
///
/// The closure is called with `(key, input, output)` and writes the
/// signature into the `output` sink.
///
/// # Errors
///
/// Returns the builder's current status as an error when the native setter
/// reports failure.
pub fn crypto_hook_sign_rsassa_pkcs1_v1_5(
    mut self,
    sign_rsaes_pkcs1_v1_5: impl Fn(&[u8], &[u8], &mut dyn Write) -> Result<()>
        + UnwindSafe
        + 'static,
) -> Result<Self> {
    // Trampoline that recovers the boxed closure from `ctx` and delegates to
    // the shared `hmac_fn_shim` (the signing hook has the same shape).
    extern "C" fn shim(
        ctx: *mut ::std::os::raw::c_void,
        key: *mut sys::mongocrypt_binary_t,
        in_: *mut sys::mongocrypt_binary_t,
        out: *mut sys::mongocrypt_binary_t,
        status: *mut sys::mongocrypt_status_t,
    ) -> bool {
        // SAFETY: relies on `ctx` being the `*const HmacFn` registered by
        // `crypto_hook_sign_rsassa_pkcs1_v1_5` and on that allocation still
        // being alive.
        let sign = unsafe { &*ctx.cast::<HmacFn>() };
        hmac_fn_shim(sign, key, in_, out, status)
    }

    // The outer box turns the fat `Box<dyn Fn>` into something addressable by
    // a thin pointer.
    let boxed_hook: Box<HmacFn> = Box::new(Box::new(sign_rsaes_pkcs1_v1_5));
    let ctx_ptr = &*boxed_hook as *const HmacFn as *mut std::ffi::c_void;

    let accepted = unsafe {
        sys::mongocrypt_setopt_crypto_hook_sign_rsaes_pkcs1_v1_5(
            *self.inner.borrow(),
            Some(shim),
            ctx_ptr,
        )
    };
    if !accepted {
        return Err(self.status().as_error());
    }

    // Ownership moves to the builder so the registered pointer stays valid.
    self.cleanup.push(boxed_hook);
    Ok(self)
}
}
/// Severity attached to a log message delivered to a
/// `CryptBuilder::log_handler` callback.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum LogLevel {
    /// Native `MONGOCRYPT_LOG_LEVEL_FATAL`.
    Fatal,
    /// Native `MONGOCRYPT_LOG_LEVEL_ERROR`.
    Error,
    /// Native `MONGOCRYPT_LOG_LEVEL_WARNING`.
    Warning,
    /// Native `MONGOCRYPT_LOG_LEVEL_INFO`.
    Info,
    /// Native `MONGOCRYPT_LOG_LEVEL_TRACE`.
    Trace,
    /// A native level value that matches none of the known constants; the raw
    /// value is preserved.
    Other(sys::mongocrypt_log_level_t),
}
impl LogLevel {
    /// Maps a raw libmongocrypt log level onto the matching variant.
    ///
    /// Values that match none of the known constants are carried through
    /// unchanged in [`LogLevel::Other`].
    fn from_native(level: sys::mongocrypt_log_level_t) -> Self {
        match level {
            sys::mongocrypt_log_level_t_MONGOCRYPT_LOG_LEVEL_FATAL => Self::Fatal,
            sys::mongocrypt_log_level_t_MONGOCRYPT_LOG_LEVEL_ERROR => Self::Error,
            sys::mongocrypt_log_level_t_MONGOCRYPT_LOG_LEVEL_WARNING => Self::Warning,
            sys::mongocrypt_log_level_t_MONGOCRYPT_LOG_LEVEL_INFO => Self::Info,
            sys::mongocrypt_log_level_t_MONGOCRYPT_LOG_LEVEL_TRACE => Self::Trace,
            unrecognized => Self::Other(unrecognized),
        }
    }
}
fn run_hook(hook: impl FnOnce() -> Result<()> + UnwindSafe) -> Result<()> {
catch_unwind(hook)
.map_err(|_| error::internal!("panic in rust hook"))?
.map_err(Into::into)
}
/// Boxed cipher hook, called as `(key, iv, input, output)`.
type CryptoFn = Box<dyn Fn(&[u8], &[u8], &[u8], &mut dyn Write) -> Result<()> + UnwindSafe>;

/// Boxed randomness hook, called as `(output, count)`.
type RandomFn = Box<dyn Fn(&mut dyn Write, u32) -> Result<()> + UnwindSafe>;

/// Boxed keyed hook (HMAC or signing), called as `(key, input, output)`.
type HmacFn = Box<dyn Fn(&[u8], &[u8], &mut dyn Write) -> Result<()> + UnwindSafe>;

/// Boxed hash hook, called as `(input, output)`.
type HashFn = Box<dyn Fn(&[u8], &mut dyn Write) -> Result<()> + UnwindSafe>;
/// The Rust closures behind the callbacks installed by
/// `CryptBuilder::crypto_hooks`.
///
/// A pointer to one heap-allocated instance is registered as the `ctx`
/// argument, and each module-level `*_shim` function casts `ctx` back to this
/// type to find its closure.
struct CryptoHooks {
    /// Used by `aes_256_cbc_encrypt_shim`.
    aes_256_cbc_encrypt: CryptoFn,
    /// Used by `random_shim`.
    random: RandomFn,
    /// Used by `hmac_sha_512_shim`.
    hmac_sha_512: HmacFn,
    /// Used by `aes_256_cbc_decrypt_shim`.
    aes_256_cbc_decrypt: CryptoFn,
    /// Used by `hmac_sha_256_shim`.
    hmac_sha_256: HmacFn,
    /// Used by `sha_256_shim`.
    sha_256: HashFn,
}
/// Shared body of every cipher trampoline: exposes the native buffers to
/// `hook_fn`, reports how many bytes it produced, and records any failure.
///
/// * `hook_fn` - the Rust closure to run, called as `(key, iv, input, output)`.
/// * `key`, `iv`, `in_` - native input buffers, read through `binary_bytes`.
/// * `out` - native output buffer, obtained through `binary_bytes_mut`.
/// * `bytes_written` - receives the number of bytes the hook wrote.
/// * `c_status` - receives the error when anything fails.
///
/// Returns `true` on success and `false` after storing an error in
/// `c_status` (see `write_status`).
fn crypto_fn_shim(
    hook_fn: &CryptoFn,
    key: *mut sys::mongocrypt_binary_t,
    iv: *mut sys::mongocrypt_binary_t,
    in_: *mut sys::mongocrypt_binary_t,
    out: *mut sys::mongocrypt_binary_t,
    bytes_written: *mut u32,
    c_status: *mut sys::mongocrypt_status_t,
) -> bool {
    // The closure exists only so that `?` can be used for early exit.
    let outcome = (|| -> Result<()> {
        // SAFETY: relies on the caller passing valid `mongocrypt_binary_t`
        // pointers for all four buffers.
        let key_bytes = unsafe { binary_bytes(key)? };
        let iv_bytes = unsafe { binary_bytes(iv)? };
        let in_bytes = unsafe { binary_bytes(in_)? };
        let mut remaining = unsafe { binary_bytes_mut(out)? };

        // Presumably `remaining` is a `&mut [u8]`, whose `Write` impl shrinks
        // the slice as bytes are written; the length difference is then the
        // number of bytes produced.
        let capacity = remaining.len();
        let sink: &mut dyn Write = &mut remaining;
        let hook_result = run_hook(AssertUnwindSafe(|| {
            hook_fn(key_bytes, iv_bytes, in_bytes, sink)
        }));
        let written = capacity - remaining.len();

        // The count is stored even when the hook failed; the hook's own result
        // is returned afterwards.
        // SAFETY: relies on `bytes_written` pointing at a writable `u32`.
        unsafe {
            *bytes_written = written.try_into()?;
        }
        hook_result
    })();
    write_status(outcome, c_status)
}
/// Translates a hook outcome into the `bool` + status convention used by the
/// native callbacks.
///
/// * `result` - outcome of the Rust hook.
/// * `c_status` - native status object that receives the error, if any.
///
/// Returns `true` for `Ok(())` without touching `c_status`. For `Err`, the
/// error is recorded into `c_status` and `false` is returned. If recording the
/// error itself fails, both errors are printed to stderr and a fixed fallback
/// message is stored instead.
fn write_status(result: Result<()>, c_status: *mut sys::mongocrypt_status_t) -> bool {
    let err = match result {
        Ok(()) => return true,
        Err(e) => e,
    };

    // `c_status` is handed in from outside, so the `Status` wrapper must never
    // run its destructor on it (assumes `Status` releases the native status on
    // drop -- confirm against `Status`). `ManuallyDrop` guarantees that on
    // every exit path, including unwinding out of `set` or `eprintln!`; a
    // trailing `mem::forget` would be skipped if anything before it panicked.
    let mut status = std::mem::ManuallyDrop::new(Status::from_native(c_status));
    if let Err(status_err) = status.set(&err) {
        eprintln!(
            "Failed to record error:\noriginal error = {:?}\nstatus error = {:?}",
            err, status_err
        );
        // Fall back to a fixed, NUL-terminated message; the `-1` length
        // presumably tells libmongocrypt to measure the string itself.
        unsafe {
            sys::mongocrypt_status_set(
                c_status,
                sys::mongocrypt_status_type_t_MONGOCRYPT_STATUS_ERROR_CLIENT,
                0,
                b"Failed to record error, see logs for details\0".as_ptr()
                    as *const std::ffi::c_char,
                -1,
            );
        }
    }
    false
}
/// C-ABI trampoline for the AES-256-CBC encrypt hook registered through
/// `CryptBuilder::crypto_hooks`.
///
/// Looks up the closure in the `CryptoHooks` behind `ctx` and delegates to
/// `crypto_fn_shim`, whose return value is passed through.
extern "C" fn aes_256_cbc_encrypt_shim(
    ctx: *mut ::std::os::raw::c_void,
    key: *mut sys::mongocrypt_binary_t,
    iv: *mut sys::mongocrypt_binary_t,
    in_: *mut sys::mongocrypt_binary_t,
    out: *mut sys::mongocrypt_binary_t,
    bytes_written: *mut u32,
    c_status: *mut sys::mongocrypt_status_t,
) -> bool {
    // SAFETY: relies on `ctx` being the `CryptoHooks` pointer registered by
    // `crypto_hooks` and on that allocation still being alive.
    let registered = unsafe { &*ctx.cast::<CryptoHooks>() };
    let encrypt = &registered.aes_256_cbc_encrypt;
    crypto_fn_shim(encrypt, key, iv, in_, out, bytes_written, c_status)
}
/// C-ABI trampoline for the AES-256-CBC decrypt hook registered through
/// `CryptBuilder::crypto_hooks`.
///
/// Looks up the closure in the `CryptoHooks` behind `ctx` and delegates to
/// `crypto_fn_shim`, whose return value is passed through.
extern "C" fn aes_256_cbc_decrypt_shim(
    ctx: *mut ::std::os::raw::c_void,
    key: *mut sys::mongocrypt_binary_t,
    iv: *mut sys::mongocrypt_binary_t,
    in_: *mut sys::mongocrypt_binary_t,
    out: *mut sys::mongocrypt_binary_t,
    bytes_written: *mut u32,
    c_status: *mut sys::mongocrypt_status_t,
) -> bool {
    // SAFETY: relies on `ctx` being the `CryptoHooks` pointer registered by
    // `crypto_hooks` and on that allocation still being alive.
    let registered = unsafe { &*ctx.cast::<CryptoHooks>() };
    let decrypt = &registered.aes_256_cbc_decrypt;
    crypto_fn_shim(decrypt, key, iv, in_, out, bytes_written, c_status)
}
/// C-ABI trampoline for the randomness hook registered through
/// `CryptBuilder::crypto_hooks`.
///
/// Calls the `random` closure with a writer over `out` and the native `count`
/// (presumably the number of random bytes requested). Returns `true` on
/// success and `false` after storing the error in `status`.
extern "C" fn random_shim(
    ctx: *mut ::std::os::raw::c_void,
    out: *mut sys::mongocrypt_binary_t,
    count: u32,
    status: *mut sys::mongocrypt_status_t,
) -> bool {
    // The closure exists only so that `?` can be used for early exit.
    let outcome = (|| -> Result<()> {
        // SAFETY: relies on `ctx` being the `CryptoHooks` pointer registered
        // by `crypto_hooks` and on that allocation still being alive.
        let registered = unsafe { &*ctx.cast::<CryptoHooks>() };
        // SAFETY: relies on `out` being a valid `mongocrypt_binary_t`.
        let mut buffer = unsafe { binary_bytes_mut(out)? };
        let out_writer: &mut dyn Write = &mut buffer;
        run_hook(AssertUnwindSafe(|| {
            (registered.random)(out_writer, count)
        }))
    })();
    write_status(outcome, status)
}
/// Shared body of the keyed-hook trampolines (HMAC and signing).
///
/// * `hook_fn` - the Rust closure to run, called as `(key, input, output)`.
/// * `key`, `in_` - native input buffers, read through `binary_bytes`.
/// * `out` - native output buffer, obtained through `binary_bytes_mut`.
/// * `c_status` - receives the error when anything fails.
///
/// Returns `true` on success and `false` after storing an error in
/// `c_status` (see `write_status`).
fn hmac_fn_shim(
    hook_fn: &HmacFn,
    key: *mut sys::mongocrypt_binary_t,
    in_: *mut sys::mongocrypt_binary_t,
    out: *mut sys::mongocrypt_binary_t,
    c_status: *mut sys::mongocrypt_status_t,
) -> bool {
    // The closure exists only so that `?` can be used for early exit.
    let outcome = (|| -> Result<()> {
        // SAFETY: relies on the caller passing valid `mongocrypt_binary_t`
        // pointers for all three buffers.
        let key_bytes = unsafe { binary_bytes(key)? };
        let in_bytes = unsafe { binary_bytes(in_)? };
        let mut buffer = unsafe { binary_bytes_mut(out)? };
        let out_writer: &mut dyn Write = &mut buffer;
        run_hook(AssertUnwindSafe(|| hook_fn(key_bytes, in_bytes, out_writer)))
    })();
    write_status(outcome, c_status)
}
/// C-ABI trampoline for the HMAC-SHA-512 hook registered through
/// `CryptBuilder::crypto_hooks`.
///
/// Looks up the closure in the `CryptoHooks` behind `ctx` and delegates to
/// `hmac_fn_shim`, whose return value is passed through.
extern "C" fn hmac_sha_512_shim(
    ctx: *mut ::std::os::raw::c_void,
    key: *mut sys::mongocrypt_binary_t,
    in_: *mut sys::mongocrypt_binary_t,
    out: *mut sys::mongocrypt_binary_t,
    c_status: *mut sys::mongocrypt_status_t,
) -> bool {
    // SAFETY: relies on `ctx` being the `CryptoHooks` pointer registered by
    // `crypto_hooks` and on that allocation still being alive.
    let registered = unsafe { &*ctx.cast::<CryptoHooks>() };
    let hmac = &registered.hmac_sha_512;
    hmac_fn_shim(hmac, key, in_, out, c_status)
}
/// C-ABI trampoline for the HMAC-SHA-256 hook registered through
/// `CryptBuilder::crypto_hooks`.
///
/// Looks up the closure in the `CryptoHooks` behind `ctx` and delegates to
/// `hmac_fn_shim`, whose return value is passed through.
extern "C" fn hmac_sha_256_shim(
    ctx: *mut ::std::os::raw::c_void,
    key: *mut sys::mongocrypt_binary_t,
    in_: *mut sys::mongocrypt_binary_t,
    out: *mut sys::mongocrypt_binary_t,
    c_status: *mut sys::mongocrypt_status_t,
) -> bool {
    // SAFETY: relies on `ctx` being the `CryptoHooks` pointer registered by
    // `crypto_hooks` and on that allocation still being alive.
    let registered = unsafe { &*ctx.cast::<CryptoHooks>() };
    let hmac = &registered.hmac_sha_256;
    hmac_fn_shim(hmac, key, in_, out, c_status)
}
/// C-ABI trampoline for the SHA-256 hook registered through
/// `CryptBuilder::crypto_hooks`.
///
/// Calls the `sha_256` closure with the bytes of `in_` and a writer over
/// `out`. Returns `true` on success and `false` after storing the error in
/// `status`.
extern "C" fn sha_256_shim(
    ctx: *mut ::std::os::raw::c_void,
    in_: *mut sys::mongocrypt_binary_t,
    out: *mut sys::mongocrypt_binary_t,
    status: *mut sys::mongocrypt_status_t,
) -> bool {
    // SAFETY: relies on `ctx` being the `CryptoHooks` pointer registered by
    // `crypto_hooks` and on that allocation still being alive.
    let registered = unsafe { &*ctx.cast::<CryptoHooks>() };
    // The closure exists only so that `?` can be used for early exit.
    let outcome = (|| -> Result<()> {
        // SAFETY: relies on `in_` and `out` being valid `mongocrypt_binary_t`
        // pointers.
        let in_bytes = unsafe { binary_bytes(in_)? };
        let mut digest_buf = unsafe { binary_bytes_mut(out)? };
        let out_writer: &mut dyn Write = &mut digest_buf;
        run_hook(AssertUnwindSafe(|| {
            (registered.sha_256)(in_bytes, out_writer)
        }))
    })();
    write_status(outcome, status)
}