1use std::{
2 borrow::Borrow,
3 ffi::CStr,
4 io::Write,
5 panic::{catch_unwind, AssertUnwindSafe, UnwindSafe},
6};
7
8use crate::{
9 convert::{binary_bytes, binary_bytes_mut},
10 error::{self, HasStatus, Result, Status},
11 CryptBuilder,
12};
13
14use mongocrypt_sys as sys;
15
16impl CryptBuilder {
17 pub fn log_handler<F>(mut self, handler: F) -> Result<Self>
19 where
20 F: Fn(LogLevel, &str) + 'static + UnwindSafe,
21 {
22 type LogCb = dyn Fn(LogLevel, &str) + UnwindSafe;
23
24 extern "C" fn log_shim(
25 c_level: sys::mongocrypt_log_level_t,
26 c_message: *const ::std::os::raw::c_char,
27 _message_len: u32,
28 ctx: *mut ::std::os::raw::c_void,
29 ) {
30 let level = LogLevel::from_native(c_level);
31 let cs_message = unsafe { CStr::from_ptr(c_message) };
32 let message = cs_message.to_string_lossy();
33 let handler = unsafe { &*(ctx as *const Box<LogCb>) };
35 let _ = run_hook(AssertUnwindSafe(|| {
36 handler(level, &message);
37 Ok(())
38 }));
39 }
40
41 let handler: Box<Box<LogCb>> = Box::new(Box::new(handler));
43 let handler_ptr = &*handler as *const Box<LogCb> as *mut std::ffi::c_void;
44 unsafe {
45 if !sys::mongocrypt_setopt_log_handler(
46 *self.inner.borrow(),
47 Some(log_shim),
48 handler_ptr,
49 ) {
50 return Err(self.status().as_error());
51 }
52 }
53
54 self.cleanup.push(handler);
56 Ok(self)
57 }
58
59 pub fn crypto_hooks(
85 mut self,
86 aes_256_cbc_encrypt: impl Fn(&[u8], &[u8], &[u8], &mut dyn Write) -> Result<()>
87 + UnwindSafe
88 + 'static,
89 aes_256_cbc_decrypt: impl Fn(&[u8], &[u8], &[u8], &mut dyn Write) -> Result<()>
90 + UnwindSafe
91 + 'static,
92 random: impl Fn(&mut dyn Write, u32) -> Result<()> + UnwindSafe + 'static,
93 hmac_sha_512: impl Fn(&[u8], &[u8], &mut dyn Write) -> Result<()> + UnwindSafe + 'static,
94 hmac_sha_256: impl Fn(&[u8], &[u8], &mut dyn Write) -> Result<()> + UnwindSafe + 'static,
95 sha_256: impl Fn(&[u8], &mut dyn Write) -> Result<()> + UnwindSafe + 'static,
96 ) -> Result<Self> {
97 let hooks = Box::new(CryptoHooks {
98 aes_256_cbc_encrypt: Box::new(aes_256_cbc_encrypt),
99 aes_256_cbc_decrypt: Box::new(aes_256_cbc_decrypt),
100 random: Box::new(random),
101 hmac_sha_512: Box::new(hmac_sha_512),
102 hmac_sha_256: Box::new(hmac_sha_256),
103 sha_256: Box::new(sha_256),
104 });
105 unsafe {
106 if !sys::mongocrypt_setopt_crypto_hooks(
107 *self.inner.borrow(),
108 Some(aes_256_cbc_encrypt_shim),
109 Some(aes_256_cbc_decrypt_shim),
110 Some(random_shim),
111 Some(hmac_sha_512_shim),
112 Some(hmac_sha_256_shim),
113 Some(sha_256_shim),
114 &*hooks as *const CryptoHooks as *mut std::ffi::c_void,
115 ) {
116 return Err(self.status().as_error());
117 }
118 }
119 self.cleanup.push(hooks);
120 Ok(self)
121 }
122
123 pub fn aes_256_ctr(
130 mut self,
131 aes_256_ctr_encrypt: impl Fn(&[u8], &[u8], &[u8], &mut dyn Write) -> Result<()>
132 + UnwindSafe
133 + 'static,
134 aes_256_ctr_decrypt: impl Fn(&[u8], &[u8], &[u8], &mut dyn Write) -> Result<()>
135 + UnwindSafe
136 + 'static,
137 ) -> Result<Self> {
138 struct Hooks {
139 aes_256_ctr_encrypt: CryptoFn,
140 aes_256_ctr_decrypt: CryptoFn,
141 }
142 let hooks = Box::new(Hooks {
143 aes_256_ctr_encrypt: Box::new(aes_256_ctr_encrypt),
144 aes_256_ctr_decrypt: Box::new(aes_256_ctr_decrypt),
145 });
146 extern "C" fn aes_256_ctr_encrypt_shim(
147 ctx: *mut ::std::os::raw::c_void,
148 key: *mut sys::mongocrypt_binary_t,
149 iv: *mut sys::mongocrypt_binary_t,
150 in_: *mut sys::mongocrypt_binary_t,
151 out: *mut sys::mongocrypt_binary_t,
152 bytes_written: *mut u32,
153 status: *mut sys::mongocrypt_status_t,
154 ) -> bool {
155 let hooks = unsafe { &*(ctx as *const Hooks) };
156 crypto_fn_shim(
157 &hooks.aes_256_ctr_encrypt,
158 key,
159 iv,
160 in_,
161 out,
162 bytes_written,
163 status,
164 )
165 }
166 extern "C" fn aes_256_ctr_decrypt_shim(
167 ctx: *mut ::std::os::raw::c_void,
168 key: *mut sys::mongocrypt_binary_t,
169 iv: *mut sys::mongocrypt_binary_t,
170 in_: *mut sys::mongocrypt_binary_t,
171 out: *mut sys::mongocrypt_binary_t,
172 bytes_written: *mut u32,
173 status: *mut sys::mongocrypt_status_t,
174 ) -> bool {
175 let hooks = unsafe { &*(ctx as *const Hooks) };
176 crypto_fn_shim(
177 &hooks.aes_256_ctr_decrypt,
178 key,
179 iv,
180 in_,
181 out,
182 bytes_written,
183 status,
184 )
185 }
186 unsafe {
187 if !sys::mongocrypt_setopt_aes_256_ctr(
188 *self.inner.borrow(),
189 Some(aes_256_ctr_encrypt_shim),
190 Some(aes_256_ctr_decrypt_shim),
191 &*hooks as *const Hooks as *mut std::ffi::c_void,
192 ) {
193 return Err(self.status().as_error());
194 }
195 }
196 self.cleanup.push(hooks);
197 Ok(self)
198 }
199
200 pub fn aes_256_ecb(
206 mut self,
207 aes_256_ecb_encrypt: impl Fn(&[u8], &[u8], &[u8], &mut dyn Write) -> Result<()>
208 + UnwindSafe
209 + 'static,
210 ) -> Result<Self> {
211 let hook: Box<CryptoFn> = Box::new(Box::new(aes_256_ecb_encrypt));
212 extern "C" fn shim(
213 ctx: *mut ::std::os::raw::c_void,
214 key: *mut sys::mongocrypt_binary_t,
215 iv: *mut sys::mongocrypt_binary_t,
216 in_: *mut sys::mongocrypt_binary_t,
217 out: *mut sys::mongocrypt_binary_t,
218 bytes_written: *mut u32,
219 status: *mut sys::mongocrypt_status_t,
220 ) -> bool {
221 let hook = unsafe { &*(ctx as *const CryptoFn) };
222 crypto_fn_shim(hook, key, iv, in_, out, bytes_written, status)
223 }
224 unsafe {
225 if !sys::mongocrypt_setopt_aes_256_ecb(
226 *self.inner.borrow(),
227 Some(shim),
228 &*hook as *const CryptoFn as *mut std::ffi::c_void,
229 ) {
230 return Err(self.status().as_error());
231 }
232 }
233 self.cleanup.push(hook);
234 Ok(self)
235 }
236
237 pub fn crypto_hook_sign_rsassa_pkcs1_v1_5(
243 mut self,
244 sign_rsaes_pkcs1_v1_5: impl Fn(&[u8], &[u8], &mut dyn Write) -> Result<()>
245 + UnwindSafe
246 + 'static,
247 ) -> Result<Self> {
248 let hook: Box<HmacFn> = Box::new(Box::new(sign_rsaes_pkcs1_v1_5));
249 extern "C" fn shim(
250 ctx: *mut ::std::os::raw::c_void,
251 key: *mut sys::mongocrypt_binary_t,
252 in_: *mut sys::mongocrypt_binary_t,
253 out: *mut sys::mongocrypt_binary_t,
254 status: *mut sys::mongocrypt_status_t,
255 ) -> bool {
256 let hook = unsafe { &*(ctx as *const HmacFn) };
257 hmac_fn_shim(hook, key, in_, out, status)
258 }
259 unsafe {
260 if !sys::mongocrypt_setopt_crypto_hook_sign_rsaes_pkcs1_v1_5(
261 *self.inner.borrow(),
262 Some(shim),
263 &*hook as *const HmacFn as *mut std::ffi::c_void,
264 ) {
265 return Err(self.status().as_error());
266 }
267 }
268 self.cleanup.push(hook);
269 Ok(self)
270 }
271}
272
273#[derive(PartialEq, Eq, Debug, Clone, Copy)]
274#[non_exhaustive]
275pub enum LogLevel {
276 Fatal,
277 Error,
278 Warning,
279 Info,
280 Trace,
281 Other(sys::mongocrypt_log_level_t),
282}
283
284impl LogLevel {
285 fn from_native(level: sys::mongocrypt_log_level_t) -> Self {
286 match level {
287 sys::mongocrypt_log_level_t_MONGOCRYPT_LOG_LEVEL_FATAL => Self::Fatal,
288 sys::mongocrypt_log_level_t_MONGOCRYPT_LOG_LEVEL_ERROR => Self::Error,
289 sys::mongocrypt_log_level_t_MONGOCRYPT_LOG_LEVEL_WARNING => Self::Warning,
290 sys::mongocrypt_log_level_t_MONGOCRYPT_LOG_LEVEL_INFO => Self::Info,
291 sys::mongocrypt_log_level_t_MONGOCRYPT_LOG_LEVEL_TRACE => Self::Trace,
292 _ => LogLevel::Other(level),
293 }
294 }
295}
296
297fn run_hook(hook: impl FnOnce() -> Result<()> + UnwindSafe) -> Result<()> {
298 catch_unwind(hook)
299 .map_err(|_| error::internal!("panic in rust hook"))?
300 .map_err(Into::into)
301}
302
303type CryptoFn = Box<dyn Fn(&[u8], &[u8], &[u8], &mut dyn Write) -> Result<()> + UnwindSafe>;
304type RandomFn = Box<dyn Fn(&mut dyn Write, u32) -> Result<()> + UnwindSafe>;
305type HmacFn = Box<dyn Fn(&[u8], &[u8], &mut dyn Write) -> Result<()> + UnwindSafe>;
306type HashFn = Box<dyn Fn(&[u8], &mut dyn Write) -> Result<()> + UnwindSafe>;
307
/// Context object for the hooks registered by `CryptBuilder::crypto_hooks`.
///
/// A pointer to a boxed `CryptoHooks` is handed to libmongocrypt as `ctx`. Each `*_shim`
/// function casts it back and picks out its own hook.
// NOTE: field declaration order also determines the drop order of the boxed closures.
struct CryptoHooks {
    /// AES-256-CBC encryption hook, called as `(key, iv, input, output)`.
    aes_256_cbc_encrypt: CryptoFn,
    /// Hook behind the `random` callback, called as `(output, count)`.
    random: RandomFn,
    /// HMAC-SHA-512 hook, called as `(key, input, output)`.
    hmac_sha_512: HmacFn,
    /// AES-256-CBC decryption hook, called as `(key, iv, input, output)`.
    aes_256_cbc_decrypt: CryptoFn,
    /// HMAC-SHA-256 hook, called as `(key, input, output)`.
    hmac_sha_256: HmacFn,
    /// SHA-256 hook, called as `(input, output)`.
    sha_256: HashFn,
}
316
317fn crypto_fn_shim(
318 hook_fn: &CryptoFn,
319 key: *mut sys::mongocrypt_binary_t,
320 iv: *mut sys::mongocrypt_binary_t,
321 in_: *mut sys::mongocrypt_binary_t,
322 out: *mut sys::mongocrypt_binary_t,
323 bytes_written: *mut u32,
324 c_status: *mut sys::mongocrypt_status_t,
325) -> bool {
326 let result = || -> Result<()> {
328 let key_bytes = unsafe { binary_bytes(key)? };
329 let iv_bytes = unsafe { binary_bytes(iv)? };
330 let in_bytes = unsafe { binary_bytes(in_)? };
331 let mut out_bytes = unsafe { binary_bytes_mut(out)? };
332 let buffer_len = out_bytes.len();
333 let out_bytes_writer: &mut dyn Write = &mut out_bytes;
334 let result = run_hook(AssertUnwindSafe(|| {
335 hook_fn(key_bytes, iv_bytes, in_bytes, out_bytes_writer)
336 }));
337 let written = buffer_len - out_bytes.len();
338 unsafe {
339 *bytes_written = written.try_into()?;
340 }
341 result
342 }();
343 write_status(result, c_status)
344}
345
346fn write_status(result: Result<()>, c_status: *mut sys::mongocrypt_status_t) -> bool {
347 let err = match result {
348 Ok(()) => return true,
349 Err(e) => e,
350 };
351 let mut status = Status::from_native(c_status);
352 if let Err(status_err) = status.set(&err) {
353 eprintln!(
354 "Failed to record error:\noriginal error = {:?}\nstatus error = {:?}",
355 err, status_err
356 );
357 unsafe {
358 sys::mongocrypt_status_set(
360 c_status,
361 sys::mongocrypt_status_type_t_MONGOCRYPT_STATUS_ERROR_CLIENT,
362 0,
363 b"Failed to record error, see logs for details\0".as_ptr()
364 as *const std::ffi::c_char,
365 -1,
366 );
367 }
368 }
369 std::mem::forget(status);
371 false
372}
373
374extern "C" fn aes_256_cbc_encrypt_shim(
375 ctx: *mut ::std::os::raw::c_void,
376 key: *mut sys::mongocrypt_binary_t,
377 iv: *mut sys::mongocrypt_binary_t,
378 in_: *mut sys::mongocrypt_binary_t,
379 out: *mut sys::mongocrypt_binary_t,
380 bytes_written: *mut u32,
381 c_status: *mut sys::mongocrypt_status_t,
382) -> bool {
383 let hooks = unsafe { &*(ctx as *const CryptoHooks) };
384 crypto_fn_shim(
385 &hooks.aes_256_cbc_encrypt,
386 key,
387 iv,
388 in_,
389 out,
390 bytes_written,
391 c_status,
392 )
393}
394
395extern "C" fn aes_256_cbc_decrypt_shim(
396 ctx: *mut ::std::os::raw::c_void,
397 key: *mut sys::mongocrypt_binary_t,
398 iv: *mut sys::mongocrypt_binary_t,
399 in_: *mut sys::mongocrypt_binary_t,
400 out: *mut sys::mongocrypt_binary_t,
401 bytes_written: *mut u32,
402 c_status: *mut sys::mongocrypt_status_t,
403) -> bool {
404 let hooks = unsafe { &*(ctx as *const CryptoHooks) };
405 crypto_fn_shim(
406 &hooks.aes_256_cbc_decrypt,
407 key,
408 iv,
409 in_,
410 out,
411 bytes_written,
412 c_status,
413 )
414}
415
416extern "C" fn random_shim(
417 ctx: *mut ::std::os::raw::c_void,
418 out: *mut sys::mongocrypt_binary_t,
419 count: u32,
420 status: *mut sys::mongocrypt_status_t,
421) -> bool {
422 let result = || -> Result<()> {
423 let hooks = unsafe { &*(ctx as *const CryptoHooks) };
424 let out_writer: &mut dyn Write = &mut unsafe { binary_bytes_mut(out)? };
425 run_hook(AssertUnwindSafe(|| (hooks.random)(out_writer, count)))
426 }();
427 write_status(result, status)
428}
429
430fn hmac_fn_shim(
431 hook_fn: &HmacFn,
432 key: *mut sys::mongocrypt_binary_t,
433 in_: *mut sys::mongocrypt_binary_t,
434 out: *mut sys::mongocrypt_binary_t,
435 c_status: *mut sys::mongocrypt_status_t,
436) -> bool {
437 let result = || -> Result<()> {
438 let key_bytes = unsafe { binary_bytes(key)? };
439 let in_bytes = unsafe { binary_bytes(in_)? };
440 let out_writer: &mut dyn Write = &mut unsafe { binary_bytes_mut(out)? };
441 run_hook(AssertUnwindSafe(|| {
442 hook_fn(key_bytes, in_bytes, out_writer)
443 }))
444 }();
445 write_status(result, c_status)
446}
447
448extern "C" fn hmac_sha_512_shim(
449 ctx: *mut ::std::os::raw::c_void,
450 key: *mut sys::mongocrypt_binary_t,
451 in_: *mut sys::mongocrypt_binary_t,
452 out: *mut sys::mongocrypt_binary_t,
453 c_status: *mut sys::mongocrypt_status_t,
454) -> bool {
455 let hooks = unsafe { &*(ctx as *const CryptoHooks) };
456 hmac_fn_shim(&hooks.hmac_sha_512, key, in_, out, c_status)
457}
458
459extern "C" fn hmac_sha_256_shim(
460 ctx: *mut ::std::os::raw::c_void,
461 key: *mut sys::mongocrypt_binary_t,
462 in_: *mut sys::mongocrypt_binary_t,
463 out: *mut sys::mongocrypt_binary_t,
464 c_status: *mut sys::mongocrypt_status_t,
465) -> bool {
466 let hooks = unsafe { &*(ctx as *const CryptoHooks) };
467 hmac_fn_shim(&hooks.hmac_sha_256, key, in_, out, c_status)
468}
469
470extern "C" fn sha_256_shim(
471 ctx: *mut ::std::os::raw::c_void,
472 in_: *mut sys::mongocrypt_binary_t,
473 out: *mut sys::mongocrypt_binary_t,
474 status: *mut sys::mongocrypt_status_t,
475) -> bool {
476 let hooks = unsafe { &*(ctx as *const CryptoHooks) };
477 let result = || -> Result<()> {
478 let in_bytes = unsafe { binary_bytes(in_)? };
479 let out_writer: &mut dyn Write = &mut unsafe { binary_bytes_mut(out)? };
480 run_hook(AssertUnwindSafe(|| (hooks.sha_256)(in_bytes, out_writer)))
481 }();
482 write_status(result, status)
483}