mongocrypt/
hooks.rs

1use std::{
2    borrow::Borrow,
3    ffi::CStr,
4    io::Write,
5    panic::{catch_unwind, AssertUnwindSafe, UnwindSafe},
6};
7
8use crate::{
9    convert::{binary_bytes, binary_bytes_mut},
10    error::{self, HasStatus, Result, Status},
11    CryptBuilder,
12};
13
14use mongocrypt_sys as sys;
15
16impl CryptBuilder {
17    /// Set a handler to get called on every log message.
18    pub fn log_handler<F>(mut self, handler: F) -> Result<Self>
19    where
20        F: Fn(LogLevel, &str) + 'static + UnwindSafe,
21    {
22        type LogCb = dyn Fn(LogLevel, &str) + UnwindSafe;
23
24        extern "C" fn log_shim(
25            c_level: sys::mongocrypt_log_level_t,
26            c_message: *const ::std::os::raw::c_char,
27            _message_len: u32,
28            ctx: *mut ::std::os::raw::c_void,
29        ) {
30            let level = LogLevel::from_native(c_level);
31            let cs_message = unsafe { CStr::from_ptr(c_message) };
32            let message = cs_message.to_string_lossy();
33            // Safety: this pointer originates below with the same type and with a lifetime of that of the containing `MongoCrypt`.
34            let handler = unsafe { &*(ctx as *const Box<LogCb>) };
35            let _ = run_hook(AssertUnwindSafe(|| {
36                handler(level, &message);
37                Ok(())
38            }));
39        }
40
41        // Double-boxing is required because the inner `Box<dyn ..>` is represented as a fat pointer; the outer one is a thin pointer convertible to *c_void.
42        let handler: Box<Box<LogCb>> = Box::new(Box::new(handler));
43        let handler_ptr = &*handler as *const Box<LogCb> as *mut std::ffi::c_void;
44        unsafe {
45            if !sys::mongocrypt_setopt_log_handler(
46                *self.inner.borrow(),
47                Some(log_shim),
48                handler_ptr,
49            ) {
50                return Err(self.status().as_error());
51            }
52        }
53
54        // Now that the handler's successfully set, store it so it gets cleaned up on drop.
55        self.cleanup.push(handler);
56        Ok(self)
57    }
58
59    /// Set crypto hooks.
60    ///
61    /// * `aes_256_cbc_encrypt` - A `crypto fn`.
62    /// * `aes_256_cbc_decrypt` - A `crypto fn`.
63    /// * `random` - A `random fn`.
64    /// * `hmac_sha_512` - A `hmac fn`.
65    /// * `hmac_sha_256` - A `hmac fn`.
66    /// * `sha_256` - A `hash fn`.
67    ///
68    /// The `Fn` bounds used here fall into four distinct kinds, some of which are reused elswhere:
69    /// * `crypto fn` - A crypto AES-256-CBC encrypt or decrypt function.
70    ///   - `key` - An encryption key (32 bytes for AES_256).
71    ///   - `iv` - An initialization vector (16 bytes for AES_256).
72    ///   - `in` - The input.  Note, this is already padded.  Encrypt with padding disabled.
73    ///   - `out` - The output.
74    /// * `hmac fn` - A crypto signature or HMAC function.
75    ///   - `key` - An encryption key (32 bytes for HMAC_SHA512).
76    ///   - `in` - The input.
77    ///   - `out` - The output.
78    /// * `hash fn` - A crypto hash (SHA-256) function.
79    ///   - `in` - The input.
80    ///   - `out` - The output.
81    /// * `random fn` - A crypto secure random function.
82    ///   - `out` - The output.
83    ///   - `count` - The number of random bytes requested.
84    pub fn crypto_hooks(
85        mut self,
86        aes_256_cbc_encrypt: impl Fn(&[u8], &[u8], &[u8], &mut dyn Write) -> Result<()>
87            + UnwindSafe
88            + 'static,
89        aes_256_cbc_decrypt: impl Fn(&[u8], &[u8], &[u8], &mut dyn Write) -> Result<()>
90            + UnwindSafe
91            + 'static,
92        random: impl Fn(&mut dyn Write, u32) -> Result<()> + UnwindSafe + 'static,
93        hmac_sha_512: impl Fn(&[u8], &[u8], &mut dyn Write) -> Result<()> + UnwindSafe + 'static,
94        hmac_sha_256: impl Fn(&[u8], &[u8], &mut dyn Write) -> Result<()> + UnwindSafe + 'static,
95        sha_256: impl Fn(&[u8], &mut dyn Write) -> Result<()> + UnwindSafe + 'static,
96    ) -> Result<Self> {
97        let hooks = Box::new(CryptoHooks {
98            aes_256_cbc_encrypt: Box::new(aes_256_cbc_encrypt),
99            aes_256_cbc_decrypt: Box::new(aes_256_cbc_decrypt),
100            random: Box::new(random),
101            hmac_sha_512: Box::new(hmac_sha_512),
102            hmac_sha_256: Box::new(hmac_sha_256),
103            sha_256: Box::new(sha_256),
104        });
105        unsafe {
106            if !sys::mongocrypt_setopt_crypto_hooks(
107                *self.inner.borrow(),
108                Some(aes_256_cbc_encrypt_shim),
109                Some(aes_256_cbc_decrypt_shim),
110                Some(random_shim),
111                Some(hmac_sha_512_shim),
112                Some(hmac_sha_256_shim),
113                Some(sha_256_shim),
114                &*hooks as *const CryptoHooks as *mut std::ffi::c_void,
115            ) {
116                return Err(self.status().as_error());
117            }
118        }
119        self.cleanup.push(hooks);
120        Ok(self)
121    }
122
123    /// Set a crypto hook for the AES256-CTR operations.
124    ///
125    /// * `aes_256_ctr_encrypt` - A `crypto fn`.  The crypto callback function for encrypt
126    /// operation.
127    /// * `aes_256_ctr_decrypt` - A `crypto fn`.  The crypto callback function for decrypt
128    /// operation.
129    pub fn aes_256_ctr(
130        mut self,
131        aes_256_ctr_encrypt: impl Fn(&[u8], &[u8], &[u8], &mut dyn Write) -> Result<()>
132            + UnwindSafe
133            + 'static,
134        aes_256_ctr_decrypt: impl Fn(&[u8], &[u8], &[u8], &mut dyn Write) -> Result<()>
135            + UnwindSafe
136            + 'static,
137    ) -> Result<Self> {
138        struct Hooks {
139            aes_256_ctr_encrypt: CryptoFn,
140            aes_256_ctr_decrypt: CryptoFn,
141        }
142        let hooks = Box::new(Hooks {
143            aes_256_ctr_encrypt: Box::new(aes_256_ctr_encrypt),
144            aes_256_ctr_decrypt: Box::new(aes_256_ctr_decrypt),
145        });
146        extern "C" fn aes_256_ctr_encrypt_shim(
147            ctx: *mut ::std::os::raw::c_void,
148            key: *mut sys::mongocrypt_binary_t,
149            iv: *mut sys::mongocrypt_binary_t,
150            in_: *mut sys::mongocrypt_binary_t,
151            out: *mut sys::mongocrypt_binary_t,
152            bytes_written: *mut u32,
153            status: *mut sys::mongocrypt_status_t,
154        ) -> bool {
155            let hooks = unsafe { &*(ctx as *const Hooks) };
156            crypto_fn_shim(
157                &hooks.aes_256_ctr_encrypt,
158                key,
159                iv,
160                in_,
161                out,
162                bytes_written,
163                status,
164            )
165        }
166        extern "C" fn aes_256_ctr_decrypt_shim(
167            ctx: *mut ::std::os::raw::c_void,
168            key: *mut sys::mongocrypt_binary_t,
169            iv: *mut sys::mongocrypt_binary_t,
170            in_: *mut sys::mongocrypt_binary_t,
171            out: *mut sys::mongocrypt_binary_t,
172            bytes_written: *mut u32,
173            status: *mut sys::mongocrypt_status_t,
174        ) -> bool {
175            let hooks = unsafe { &*(ctx as *const Hooks) };
176            crypto_fn_shim(
177                &hooks.aes_256_ctr_decrypt,
178                key,
179                iv,
180                in_,
181                out,
182                bytes_written,
183                status,
184            )
185        }
186        unsafe {
187            if !sys::mongocrypt_setopt_aes_256_ctr(
188                *self.inner.borrow(),
189                Some(aes_256_ctr_encrypt_shim),
190                Some(aes_256_ctr_decrypt_shim),
191                &*hooks as *const Hooks as *mut std::ffi::c_void,
192            ) {
193                return Err(self.status().as_error());
194            }
195        }
196        self.cleanup.push(hooks);
197        Ok(self)
198    }
199
200    /// Set an AES256-ECB crypto hook for the AES256-CTR operations. If CTR hook was
201    /// configured using `aes_256_ctr`, ECB hook will be ignored.
202    ///
203    /// * `aes_256_ecb_encrypt` - A `crypto fn`.  The crypto callback function for encrypt
204    /// operation.
205    pub fn aes_256_ecb(
206        mut self,
207        aes_256_ecb_encrypt: impl Fn(&[u8], &[u8], &[u8], &mut dyn Write) -> Result<()>
208            + UnwindSafe
209            + 'static,
210    ) -> Result<Self> {
211        let hook: Box<CryptoFn> = Box::new(Box::new(aes_256_ecb_encrypt));
212        extern "C" fn shim(
213            ctx: *mut ::std::os::raw::c_void,
214            key: *mut sys::mongocrypt_binary_t,
215            iv: *mut sys::mongocrypt_binary_t,
216            in_: *mut sys::mongocrypt_binary_t,
217            out: *mut sys::mongocrypt_binary_t,
218            bytes_written: *mut u32,
219            status: *mut sys::mongocrypt_status_t,
220        ) -> bool {
221            let hook = unsafe { &*(ctx as *const CryptoFn) };
222            crypto_fn_shim(hook, key, iv, in_, out, bytes_written, status)
223        }
224        unsafe {
225            if !sys::mongocrypt_setopt_aes_256_ecb(
226                *self.inner.borrow(),
227                Some(shim),
228                &*hook as *const CryptoFn as *mut std::ffi::c_void,
229            ) {
230                return Err(self.status().as_error());
231            }
232        }
233        self.cleanup.push(hook);
234        Ok(self)
235    }
236
237    /// Set a crypto hook for the RSASSA-PKCS1-v1_5 algorithm with a SHA-256 hash.
238    ///
239    /// See: https://tools.ietf.org/html/rfc3447#section-8.2
240    ///
241    /// * `sign_rsaes_pkcs1_v1_5` - A `hmac fn`.  The crypto callback function.
242    pub fn crypto_hook_sign_rsassa_pkcs1_v1_5(
243        mut self,
244        sign_rsaes_pkcs1_v1_5: impl Fn(&[u8], &[u8], &mut dyn Write) -> Result<()>
245            + UnwindSafe
246            + 'static,
247    ) -> Result<Self> {
248        let hook: Box<HmacFn> = Box::new(Box::new(sign_rsaes_pkcs1_v1_5));
249        extern "C" fn shim(
250            ctx: *mut ::std::os::raw::c_void,
251            key: *mut sys::mongocrypt_binary_t,
252            in_: *mut sys::mongocrypt_binary_t,
253            out: *mut sys::mongocrypt_binary_t,
254            status: *mut sys::mongocrypt_status_t,
255        ) -> bool {
256            let hook = unsafe { &*(ctx as *const HmacFn) };
257            hmac_fn_shim(hook, key, in_, out, status)
258        }
259        unsafe {
260            if !sys::mongocrypt_setopt_crypto_hook_sign_rsaes_pkcs1_v1_5(
261                *self.inner.borrow(),
262                Some(shim),
263                &*hook as *const HmacFn as *mut std::ffi::c_void,
264            ) {
265                return Err(self.status().as_error());
266            }
267        }
268        self.cleanup.push(hook);
269        Ok(self)
270    }
271}
272
273#[derive(PartialEq, Eq, Debug, Clone, Copy)]
274#[non_exhaustive]
275pub enum LogLevel {
276    Fatal,
277    Error,
278    Warning,
279    Info,
280    Trace,
281    Other(sys::mongocrypt_log_level_t),
282}
283
284impl LogLevel {
285    fn from_native(level: sys::mongocrypt_log_level_t) -> Self {
286        match level {
287            sys::mongocrypt_log_level_t_MONGOCRYPT_LOG_LEVEL_FATAL => Self::Fatal,
288            sys::mongocrypt_log_level_t_MONGOCRYPT_LOG_LEVEL_ERROR => Self::Error,
289            sys::mongocrypt_log_level_t_MONGOCRYPT_LOG_LEVEL_WARNING => Self::Warning,
290            sys::mongocrypt_log_level_t_MONGOCRYPT_LOG_LEVEL_INFO => Self::Info,
291            sys::mongocrypt_log_level_t_MONGOCRYPT_LOG_LEVEL_TRACE => Self::Trace,
292            _ => LogLevel::Other(level),
293        }
294    }
295}
296
297fn run_hook(hook: impl FnOnce() -> Result<()> + UnwindSafe) -> Result<()> {
298    catch_unwind(hook)
299        .map_err(|_| error::internal!("panic in rust hook"))?
300        .map_err(Into::into)
301}
302
/// Boxed `crypto fn` hook: takes three input byte slices and an output writer
/// (`key`, `iv`, `in`, `out` as described on `CryptBuilder::crypto_hooks`).
type CryptoFn = Box<dyn Fn(&[u8], &[u8], &[u8], &mut dyn Write) -> Result<()> + UnwindSafe>;
/// Boxed `random fn` hook: takes an output writer and the number of bytes requested.
type RandomFn = Box<dyn Fn(&mut dyn Write, u32) -> Result<()> + UnwindSafe>;
/// Boxed `hmac fn` hook: takes two input byte slices and an output writer
/// (`key`, `in`, `out` as described on `CryptBuilder::crypto_hooks`).
type HmacFn = Box<dyn Fn(&[u8], &[u8], &mut dyn Write) -> Result<()> + UnwindSafe>;
/// Boxed `hash fn` hook: takes one input byte slice and an output writer.
type HashFn = Box<dyn Fn(&[u8], &mut dyn Write) -> Result<()> + UnwindSafe>;
307
/// The set of closures registered through `CryptBuilder::crypto_hooks`.
///
/// A pointer to this struct is passed to libmongocrypt as the callback context; each
/// `*_shim` function casts that context back to `CryptoHooks` and calls the matching field.
struct CryptoHooks {
    // Called by `aes_256_cbc_encrypt_shim`.
    aes_256_cbc_encrypt: CryptoFn,
    // Called by `random_shim`.
    random: RandomFn,
    // Called by `hmac_sha_512_shim`.
    hmac_sha_512: HmacFn,
    // Called by `aes_256_cbc_decrypt_shim`.
    aes_256_cbc_decrypt: CryptoFn,
    // Called by `hmac_sha_256_shim`.
    hmac_sha_256: HmacFn,
    // Called by `sha_256_shim`.
    sha_256: HashFn,
}
316
317fn crypto_fn_shim(
318    hook_fn: &CryptoFn,
319    key: *mut sys::mongocrypt_binary_t,
320    iv: *mut sys::mongocrypt_binary_t,
321    in_: *mut sys::mongocrypt_binary_t,
322    out: *mut sys::mongocrypt_binary_t,
323    bytes_written: *mut u32,
324    c_status: *mut sys::mongocrypt_status_t,
325) -> bool {
326    // Convenience scope for intermediate error propagation via `?`.
327    let result = || -> Result<()> {
328        let key_bytes = unsafe { binary_bytes(key)? };
329        let iv_bytes = unsafe { binary_bytes(iv)? };
330        let in_bytes = unsafe { binary_bytes(in_)? };
331        let mut out_bytes = unsafe { binary_bytes_mut(out)? };
332        let buffer_len = out_bytes.len();
333        let out_bytes_writer: &mut dyn Write = &mut out_bytes;
334        let result = run_hook(AssertUnwindSafe(|| {
335            hook_fn(key_bytes, iv_bytes, in_bytes, out_bytes_writer)
336        }));
337        let written = buffer_len - out_bytes.len();
338        unsafe {
339            *bytes_written = written.try_into()?;
340        }
341        result
342    }();
343    write_status(result, c_status)
344}
345
346fn write_status(result: Result<()>, c_status: *mut sys::mongocrypt_status_t) -> bool {
347    let err = match result {
348        Ok(()) => return true,
349        Err(e) => e,
350    };
351    let mut status = Status::from_native(c_status);
352    if let Err(status_err) = status.set(&err) {
353        eprintln!(
354            "Failed to record error:\noriginal error = {:?}\nstatus error = {:?}",
355            err, status_err
356        );
357        unsafe {
358            // Set a hardcoded status that can't fail.
359            sys::mongocrypt_status_set(
360                c_status,
361                sys::mongocrypt_status_type_t_MONGOCRYPT_STATUS_ERROR_CLIENT,
362                0,
363                b"Failed to record error, see logs for details\0".as_ptr()
364                    as *const std::ffi::c_char,
365                -1,
366            );
367        }
368    }
369    // The status is owned by the caller, so don't run cleanup.
370    std::mem::forget(status);
371    false
372}
373
374extern "C" fn aes_256_cbc_encrypt_shim(
375    ctx: *mut ::std::os::raw::c_void,
376    key: *mut sys::mongocrypt_binary_t,
377    iv: *mut sys::mongocrypt_binary_t,
378    in_: *mut sys::mongocrypt_binary_t,
379    out: *mut sys::mongocrypt_binary_t,
380    bytes_written: *mut u32,
381    c_status: *mut sys::mongocrypt_status_t,
382) -> bool {
383    let hooks = unsafe { &*(ctx as *const CryptoHooks) };
384    crypto_fn_shim(
385        &hooks.aes_256_cbc_encrypt,
386        key,
387        iv,
388        in_,
389        out,
390        bytes_written,
391        c_status,
392    )
393}
394
395extern "C" fn aes_256_cbc_decrypt_shim(
396    ctx: *mut ::std::os::raw::c_void,
397    key: *mut sys::mongocrypt_binary_t,
398    iv: *mut sys::mongocrypt_binary_t,
399    in_: *mut sys::mongocrypt_binary_t,
400    out: *mut sys::mongocrypt_binary_t,
401    bytes_written: *mut u32,
402    c_status: *mut sys::mongocrypt_status_t,
403) -> bool {
404    let hooks = unsafe { &*(ctx as *const CryptoHooks) };
405    crypto_fn_shim(
406        &hooks.aes_256_cbc_decrypt,
407        key,
408        iv,
409        in_,
410        out,
411        bytes_written,
412        c_status,
413    )
414}
415
416extern "C" fn random_shim(
417    ctx: *mut ::std::os::raw::c_void,
418    out: *mut sys::mongocrypt_binary_t,
419    count: u32,
420    status: *mut sys::mongocrypt_status_t,
421) -> bool {
422    let result = || -> Result<()> {
423        let hooks = unsafe { &*(ctx as *const CryptoHooks) };
424        let out_writer: &mut dyn Write = &mut unsafe { binary_bytes_mut(out)? };
425        run_hook(AssertUnwindSafe(|| (hooks.random)(out_writer, count)))
426    }();
427    write_status(result, status)
428}
429
430fn hmac_fn_shim(
431    hook_fn: &HmacFn,
432    key: *mut sys::mongocrypt_binary_t,
433    in_: *mut sys::mongocrypt_binary_t,
434    out: *mut sys::mongocrypt_binary_t,
435    c_status: *mut sys::mongocrypt_status_t,
436) -> bool {
437    let result = || -> Result<()> {
438        let key_bytes = unsafe { binary_bytes(key)? };
439        let in_bytes = unsafe { binary_bytes(in_)? };
440        let out_writer: &mut dyn Write = &mut unsafe { binary_bytes_mut(out)? };
441        run_hook(AssertUnwindSafe(|| {
442            hook_fn(key_bytes, in_bytes, out_writer)
443        }))
444    }();
445    write_status(result, c_status)
446}
447
448extern "C" fn hmac_sha_512_shim(
449    ctx: *mut ::std::os::raw::c_void,
450    key: *mut sys::mongocrypt_binary_t,
451    in_: *mut sys::mongocrypt_binary_t,
452    out: *mut sys::mongocrypt_binary_t,
453    c_status: *mut sys::mongocrypt_status_t,
454) -> bool {
455    let hooks = unsafe { &*(ctx as *const CryptoHooks) };
456    hmac_fn_shim(&hooks.hmac_sha_512, key, in_, out, c_status)
457}
458
459extern "C" fn hmac_sha_256_shim(
460    ctx: *mut ::std::os::raw::c_void,
461    key: *mut sys::mongocrypt_binary_t,
462    in_: *mut sys::mongocrypt_binary_t,
463    out: *mut sys::mongocrypt_binary_t,
464    c_status: *mut sys::mongocrypt_status_t,
465) -> bool {
466    let hooks = unsafe { &*(ctx as *const CryptoHooks) };
467    hmac_fn_shim(&hooks.hmac_sha_256, key, in_, out, c_status)
468}
469
470extern "C" fn sha_256_shim(
471    ctx: *mut ::std::os::raw::c_void,
472    in_: *mut sys::mongocrypt_binary_t,
473    out: *mut sys::mongocrypt_binary_t,
474    status: *mut sys::mongocrypt_status_t,
475) -> bool {
476    let hooks = unsafe { &*(ctx as *const CryptoHooks) };
477    let result = || -> Result<()> {
478        let in_bytes = unsafe { binary_bytes(in_)? };
479        let out_writer: &mut dyn Write = &mut unsafe { binary_bytes_mut(out)? };
480        run_hook(AssertUnwindSafe(|| (hooks.sha_256)(in_bytes, out_writer)))
481    }();
482    write_status(result, status)
483}