1use std::{
4 convert::TryFrom,
5 os::raw::{c_int, c_void},
6 panic::RefUnwindSafe,
7 pin::Pin,
8 ptr, slice,
9 sync::Mutex,
10};
11
12use sys::{signal_buffer, signal_crypto_provider};
13
14use crate::{
15 buffer::Buffer,
16 errors::{InternalError, IntoInternalErrorCode},
17};
18
19#[cfg(feature = "crypto-native")]
20pub use self::native::DefaultCrypto;
21#[cfg(feature = "crypto-openssl")]
22pub use self::openssl::OpenSSLCrypto;
23
24#[cfg(feature = "crypto-native")]
25mod native;
26#[cfg(feature = "crypto-openssl")]
27mod openssl;
28
/// Error returned by `SignalCipherType::try_from` when the integer code does
/// not name a supported cipher. It carries the rejected code.
#[derive(Debug, Copy, Clone)]
pub struct SignalCipherTypeError(i32);

impl std::fmt::Display for SignalCipherTypeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "unknown signal cipher type: {}", self.0)
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` and friends.
impl std::error::Error for SignalCipherTypeError {}
32
/// Which backend operation `internal_cipher` should run.
#[derive(Debug, Copy, Clone)]
enum CipherMode {
    /// Route the request to `Crypto::encrypt`.
    Encrypt,
    /// Route the request to `Crypto::decrypt`.
    Decrypt,
}
38
/// The symmetric ciphers libsignal asks the crypto backend to run.
#[derive(Debug, Copy, Clone)]
pub enum SignalCipherType {
    /// AES in counter mode, without padding (libsignal's
    /// `SG_CIPHER_AES_CTR_NOPADDING`).
    AesCtrNoPadding,
    /// AES in CBC mode with PKCS#5 padding (libsignal's
    /// `SG_CIPHER_AES_CBC_PKCS5`).
    AesCbcPkcs5,
}
46
47impl TryFrom<i32> for SignalCipherType {
48 type Error = SignalCipherTypeError;
49
50 fn try_from(v: i32) -> Result<Self, Self::Error> {
51 match v as u32 {
52 sys::SG_CIPHER_AES_CTR_NOPADDING => {
53 Ok(SignalCipherType::AesCtrNoPadding)
54 },
55 sys::SG_CIPHER_AES_CBC_PKCS5 => Ok(SignalCipherType::AesCbcPkcs5),
56 _ => Err(SignalCipherTypeError(v)),
57 }
58 }
59}
60
61pub trait Sha256Hmac {
63 fn update(&mut self, data: &[u8]) -> Result<(), InternalError>;
65 fn finalize(&mut self) -> Result<Vec<u8>, InternalError>;
71}
72
73pub trait Sha512Digest {
75 fn update(&mut self, data: &[u8]) -> Result<(), InternalError>;
77 fn finalize(&mut self) -> Result<Vec<u8>, InternalError>;
83}
84
85pub trait Crypto: RefUnwindSafe {
87 fn fill_random(&self, buffer: &mut [u8]) -> Result<(), InternalError>;
89
90 fn hmac_sha256(
92 &self,
93 key: &[u8],
94 ) -> Result<Box<dyn Sha256Hmac>, InternalError>;
95
96 fn sha512_digest(&self) -> Result<Box<dyn Sha512Digest>, InternalError>;
98
99 fn encrypt(
101 &self,
102 cipher: SignalCipherType,
103 key: &[u8],
104 iv: &[u8],
105 data: &[u8],
106 ) -> Result<Vec<u8>, InternalError>;
107
108 fn decrypt(
110 &self,
111 cipher: SignalCipherType,
112 key: &[u8],
113 iv: &[u8],
114 data: &[u8],
115 ) -> Result<Vec<u8>, InternalError>;
116}
117
118pub(crate) struct CryptoProvider {
121 pub(crate) vtable: signal_crypto_provider,
122 state: Pin<Box<State>>,
123}
124
125impl CryptoProvider {
126 pub fn new<C: Crypto + 'static>(crypto: C) -> CryptoProvider {
127 let mut state: Pin<Box<State>> = Box::pin(State(Box::new(crypto)));
129
130 let vtable = signal_crypto_provider {
131 user_data: state.as_mut().get_mut() as *mut State as *mut c_void,
132 random_func: Some(random_func),
133 hmac_sha256_init_func: Some(hmac_sha256_init_func),
134 hmac_sha256_update_func: Some(hmac_sha256_update_func),
135 hmac_sha256_final_func: Some(hmac_sha256_final_func),
136 hmac_sha256_cleanup_func: Some(hmac_sha256_cleanup_func),
137 sha512_digest_init_func: Some(sha512_digest_init_func),
138 sha512_digest_update_func: Some(sha512_digest_update_func),
139 sha512_digest_final_func: Some(sha512_digest_final_func),
140 sha512_digest_cleanup_func: Some(sha512_digest_cleanup_func),
141 encrypt_func: Some(encrypt_func),
142 decrypt_func: Some(decrypt_func),
143 };
144
145 CryptoProvider { vtable, state }
146 }
147
148 pub fn state(&self) -> &dyn Crypto { &*self.state.0 }
149}
150
151struct State(Box<dyn Crypto>);
152
153struct HmacContext(Mutex<Box<dyn Sha256Hmac>>);
154
155struct DigestContext(Mutex<Box<dyn Sha512Digest>>);
156
157unsafe extern "C" fn random_func(
158 data: *mut u8,
159 len: usize,
160 user_data: *mut c_void,
161) -> c_int {
162 signal_assert!(!data.is_null());
163 signal_assert!(!user_data.is_null());
164
165 let user_data = &*(user_data as *const State);
166
167 let panic_result = std::panic::catch_unwind(|| {
168 let buffer = slice::from_raw_parts_mut(data, len);
169 user_data.0.fill_random(buffer)
170 });
171
172 match panic_result {
173 Ok(Ok(_)) => sys::SG_SUCCESS as c_int,
174 Ok(Err(e)) => {
175 log::error!("Unable to generate random data: {}", e);
176 InternalError::Unknown.code()
177 },
178 Err(e) => {
179 let msg = if let Some(m) = e.downcast_ref::<&str>() {
180 m
181 } else if let Some(m) = e.downcast_ref::<String>() {
182 m.as_str()
183 } else {
184 "Unknown panic"
185 };
186 log::error!("Panic encountered while trying to generate {} random bytes at {}#{}: {}",
187 len, file!(), line!(), msg);
188
189 InternalError::Unknown.code()
190 },
191 }
192}
193
194unsafe extern "C" fn hmac_sha256_cleanup_func(
195 hmac_context: *mut c_void,
196 _user_data: *mut c_void,
197) {
198 if hmac_context.is_null() {
199 return;
200 }
201
202 let hmac_context: Box<HmacContext> =
203 Box::from_raw(hmac_context as *mut HmacContext);
204 drop(hmac_context);
205}
206
207unsafe extern "C" fn hmac_sha256_final_func(
208 hmac_context: *mut c_void,
209 output: *mut *mut signal_buffer,
210 _user_data: *mut c_void,
211) -> i32 {
212 signal_assert!(!output.is_null());
214 signal_assert!(!hmac_context.is_null());
215
216 let hmac_context = &*(hmac_context as *const HmacContext);
217
218 match signal_catch_unwind!(hmac_context.0.lock().unwrap().finalize()) {
219 Ok(hmac) => {
220 let buffer = Buffer::from(hmac);
221 *output = buffer.into_raw();
222 sys::SG_SUCCESS as c_int
223 },
224 Err(e) => e.code(),
225 }
226}
227
228unsafe extern "C" fn hmac_sha256_init_func(
229 hmac_context: *mut *mut c_void,
230 key: *const u8,
231 key_len: usize,
232 user_data: *mut c_void,
233) -> i32 {
234 signal_assert!(!key.is_null());
235 signal_assert!(!user_data.is_null());
236
237 let state = &*(user_data as *const State);
238 let key = slice::from_raw_parts(key, key_len);
239
240 let hasher = match signal_catch_unwind!(state.0.hmac_sha256(key)) {
241 Ok(h) => h,
242 Err(e) => {
243 *hmac_context = ptr::null_mut();
244 return e.code();
245 },
246 };
247
248 *hmac_context =
249 Box::into_raw(Box::new(HmacContext(Mutex::new(hasher)))) as *mut c_void;
250 sys::SG_SUCCESS as c_int
251}
252
253unsafe extern "C" fn hmac_sha256_update_func(
254 hmac_context: *mut c_void,
255 data: *const u8,
256 data_len: usize,
257 _user_data: *mut c_void,
258) -> i32 {
259 signal_assert!(!data.is_null());
260 signal_assert!(!hmac_context.is_null());
261
262 let hmac_context = &*(hmac_context as *const HmacContext);
263
264 let data = slice::from_raw_parts(data, data_len);
265
266 signal_catch_unwind!(hmac_context.0.lock().unwrap().update(data))
267 .into_code()
268}
269
270unsafe extern "C" fn sha512_digest_init_func(
271 digest_context: *mut *mut c_void,
272 user_data: *mut c_void,
273) -> c_int {
274 signal_assert!(!user_data.is_null());
275
276 let user_data = &*(user_data as *const State);
277 let hasher = match signal_catch_unwind!(user_data.0.sha512_digest()) {
278 Ok(h) => h,
279 Err(e) => {
280 *digest_context = ptr::null_mut();
281 return e.code();
282 },
283 };
284
285 let dc = Box::new(DigestContext(Mutex::new(hasher)));
286 *digest_context = Box::into_raw(Box::new(dc)) as *mut c_void;
287
288 sys::SG_SUCCESS as c_int
289}
290
291unsafe extern "C" fn sha512_digest_update_func(
292 digest_context: *mut c_void,
293 data: *const u8,
294 data_len: usize,
295 _user_data: *mut c_void,
296) -> c_int {
297 signal_assert!(!data.is_null());
298 signal_assert!(!digest_context.is_null());
299
300 let hasher = &*(digest_context as *const DigestContext);
301
302 let buffer = slice::from_raw_parts(data, data_len);
303 signal_catch_unwind!(hasher.0.lock().unwrap().update(buffer)).into_code()
304}
305
306unsafe extern "C" fn sha512_digest_final_func(
307 digest_context: *mut c_void,
308 output: *mut *mut signal_buffer,
309 _user_data: *mut c_void,
310) -> c_int {
311 signal_assert!(!output.is_null());
313 signal_assert!(!digest_context.is_null());
314
315 let hasher = &*(digest_context as *const DigestContext);
316
317 match signal_catch_unwind!(hasher.0.lock().unwrap().finalize()) {
318 Ok(buf) => {
319 let buffer = Buffer::from(buf);
320 *output = buffer.into_raw();
321 sys::SG_SUCCESS as c_int
322 },
323 Err(e) => e.code(),
324 }
325}
326
327unsafe extern "C" fn sha512_digest_cleanup_func(
328 digest_context: *mut c_void,
329 _user_data: *mut c_void,
330) {
331 if digest_context.is_null() {
332 return;
333 }
334
335 let digest_context: Box<DigestContext> =
336 Box::from_raw(digest_context as *mut DigestContext);
337 drop(digest_context);
338}
339
340unsafe extern "C" fn encrypt_func(
341 output: *mut *mut signal_buffer,
342 cipher: c_int,
343 key: *const u8,
344 key_len: usize,
345 iv: *const u8,
346 iv_len: usize,
347 plaintext: *const u8,
348 plaintext_len: usize,
349 user_data: *mut c_void,
350) -> c_int {
351 internal_cipher(
352 CipherMode::Encrypt,
353 output,
354 cipher,
355 key,
356 key_len,
357 iv,
358 iv_len,
359 plaintext,
360 plaintext_len,
361 user_data,
362 )
363}
364
365unsafe extern "C" fn decrypt_func(
366 output: *mut *mut signal_buffer,
367 cipher: c_int,
368 key: *const u8,
369 key_len: usize,
370 iv: *const u8,
371 iv_len: usize,
372 ciphertext: *const u8,
373 ciphertext_len: usize,
374 user_data: *mut c_void,
375) -> c_int {
376 internal_cipher(
377 CipherMode::Decrypt,
378 output,
379 cipher,
380 key,
381 key_len,
382 iv,
383 iv_len,
384 ciphertext,
385 ciphertext_len,
386 user_data,
387 )
388}
389
390#[allow(clippy::cognitive_complexity)]
391unsafe extern "C" fn internal_cipher(
392 mode: CipherMode,
393 output: *mut *mut signal_buffer,
394 cipher: c_int,
395 key: *const u8,
396 key_len: usize,
397 iv: *const u8,
398 iv_len: usize,
399 data: *const u8,
400 data_len: usize,
401 user_data: *mut c_void,
402) -> c_int {
403 use self::CipherMode::*;
404 signal_assert!(!output.is_null());
406 signal_assert!(!user_data.is_null());
407 signal_assert!(!key.is_null());
408 signal_assert!(!iv.is_null());
409 signal_assert!(!data.is_null());
410
411 let signal_cipher_type = match SignalCipherType::try_from(cipher) {
412 Ok(ty) => ty,
413 Err(_) => return InternalError::InvalidArgument.code(),
416 };
417 let key = slice::from_raw_parts(key, key_len);
418 let iv = slice::from_raw_parts(iv, iv_len);
419 let data = slice::from_raw_parts(data, data_len);
420
421 let user_data = &*(user_data as *const State);
422
423 let result = match mode {
424 Encrypt => signal_catch_unwind!(user_data.0.encrypt(
425 signal_cipher_type,
426 key,
427 iv,
428 data
429 )),
430 Decrypt => signal_catch_unwind!(user_data.0.decrypt(
431 signal_cipher_type,
432 key,
433 iv,
434 data
435 )),
436 };
437
438 match result {
439 Ok(buf) => {
440 let buffer = Buffer::from(buf);
441 *output = buffer.into_raw();
442 sys::SG_SUCCESS as c_int
443 },
444 Err(e) => e.code(),
445 }
446}
447
#[cfg(test)]
mod crypto_tests {
    #[allow(unused_imports)]
    use super::*;

    /// Checks that the native and OpenSSL backends agree for `cipher`.
    ///
    /// A fixed 7-byte message is encrypted by both backends under one random
    /// 16-byte key/IV, and the two ciphertexts must be identical. Each
    /// backend's ciphertext is then decrypted by the other backend and must
    /// round-trip to the original message.
    #[cfg(all(feature = "crypto-native", feature = "crypto-openssl"))]
    fn assert_backends_interoperate(cipher: SignalCipherType) {
        let native = DefaultCrypto::default();
        let openssl = OpenSSLCrypto::default();
        let message = [1, 2, 3, 4, 5, 6, 7];

        let mut key = [0u8; 16];
        let mut iv = [0u8; 16];
        native.fill_random(&mut key).unwrap();
        native.fill_random(&mut iv).unwrap();

        let native_ct = native.encrypt(cipher, &key, &iv, &message).unwrap();
        let openssl_ct = openssl.encrypt(cipher, &key, &iv, &message).unwrap();
        assert_eq!(native_ct, openssl_ct);

        let native_pt = native.decrypt(cipher, &key, &iv, &openssl_ct).unwrap();
        let openssl_pt = openssl.decrypt(cipher, &key, &iv, &native_ct).unwrap();
        assert_eq!(native_pt, message);
        assert_eq!(openssl_pt, message);
    }

    #[cfg(all(feature = "crypto-native", feature = "crypto-openssl"))]
    #[test]
    fn test_crypter_cbc() {
        assert_backends_interoperate(SignalCipherType::AesCbcPkcs5);
    }

    #[cfg(all(feature = "crypto-native", feature = "crypto-openssl"))]
    #[test]
    fn test_crypter_ctr() {
        assert_backends_interoperate(SignalCipherType::AesCtrNoPadding);
    }
}