libsignal_protocol/crypto/mod.rs

1//! Underlying cryptographic routines.
2
3use std::{
4    convert::TryFrom,
5    os::raw::{c_int, c_void},
6    panic::RefUnwindSafe,
7    pin::Pin,
8    ptr, slice,
9    sync::Mutex,
10};
11
12use sys::{signal_buffer, signal_crypto_provider};
13
14use crate::{
15    buffer::Buffer,
16    errors::{InternalError, IntoInternalErrorCode},
17};
18
19#[cfg(feature = "crypto-native")]
20pub use self::native::DefaultCrypto;
21#[cfg(feature = "crypto-openssl")]
22pub use self::openssl::OpenSSLCrypto;
23
24#[cfg(feature = "crypto-native")]
25mod native;
26#[cfg(feature = "crypto-openssl")]
27mod openssl;
28
/// The error returned from a failed conversion to [`SignalCipherType`].
///
/// It carries the raw cipher code that did not match any known
/// `SG_CIPHER_*` constant.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct SignalCipherTypeError(i32);

impl SignalCipherTypeError {
    /// The raw cipher code which could not be converted.
    pub fn value(self) -> i32 { self.0 }
}

impl std::fmt::Display for SignalCipherTypeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "unknown signal cipher type: {}", self.0)
    }
}

impl std::error::Error for SignalCipherTypeError {}
32
/// Which direction a shared cipher call should run in.
#[derive(Clone, Copy, Debug)]
enum CipherMode {
    /// Turn plaintext into ciphertext.
    Encrypt,
    /// Turn ciphertext back into plaintext.
    Decrypt,
}
38
/// The type of AES cipher.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum SignalCipherType {
    /// AES in CTR mode without padding (`sys::SG_CIPHER_AES_CTR_NOPADDING`).
    AesCtrNoPadding,
    /// AES in CBC mode with PKCS#5 padding (`sys::SG_CIPHER_AES_CBC_PKCS5`).
    AesCbcPkcs5,
}
46
47impl TryFrom<i32> for SignalCipherType {
48    type Error = SignalCipherTypeError;
49
50    fn try_from(v: i32) -> Result<Self, Self::Error> {
51        match v as u32 {
52            sys::SG_CIPHER_AES_CTR_NOPADDING => {
53                Ok(SignalCipherType::AesCtrNoPadding)
54            },
55            sys::SG_CIPHER_AES_CBC_PKCS5 => Ok(SignalCipherType::AesCbcPkcs5),
56            _ => Err(SignalCipherTypeError(v)),
57        }
58    }
59}
60
61/// Something which can calculate a SHA-256 HMAC.
62pub trait Sha256Hmac {
63    /// Update the HMAC context with the provided data.
64    fn update(&mut self, data: &[u8]) -> Result<(), InternalError>;
65    /// Return the HMAC result.
66    ///
67    /// # Note
68    ///
69    /// This method should prepare the context for reuse.
70    fn finalize(&mut self) -> Result<Vec<u8>, InternalError>;
71}
72
73/// Something which can generate a SHA-512 hash.
74pub trait Sha512Digest {
75    /// Update the digest context with the provided data.
76    fn update(&mut self, data: &[u8]) -> Result<(), InternalError>;
77    /// Return the digest result.
78    ///
79    /// # Note
80    ///
81    /// This method should prepare the context for reuse.
82    fn finalize(&mut self) -> Result<Vec<u8>, InternalError>;
83}
84
85/// Cryptography routines used in the signal protocol.
86pub trait Crypto: RefUnwindSafe {
87    /// Fill the provided buffer with some random bytes.
88    fn fill_random(&self, buffer: &mut [u8]) -> Result<(), InternalError>;
89
90    /// Start to calculate a SHA-256 HMAC using the provided key.
91    fn hmac_sha256(
92        &self,
93        key: &[u8],
94    ) -> Result<Box<dyn Sha256Hmac>, InternalError>;
95
96    /// Start to generate a SHA-512 digest.
97    fn sha512_digest(&self) -> Result<Box<dyn Sha512Digest>, InternalError>;
98
99    /// Encrypt the provided data using AES.
100    fn encrypt(
101        &self,
102        cipher: SignalCipherType,
103        key: &[u8],
104        iv: &[u8],
105        data: &[u8],
106    ) -> Result<Vec<u8>, InternalError>;
107
108    /// Decrypt the provided data using AES.
109    fn decrypt(
110        &self,
111        cipher: SignalCipherType,
112        key: &[u8],
113        iv: &[u8],
114        data: &[u8],
115    ) -> Result<Vec<u8>, InternalError>;
116}
117
118/// A simple vtable ([`signal_crypto_provider`]) and set of trampolines to let C
119/// use our [`Crypto`] trait object.
120pub(crate) struct CryptoProvider {
121    pub(crate) vtable: signal_crypto_provider,
122    state: Pin<Box<State>>,
123}
124
125impl CryptoProvider {
126    pub fn new<C: Crypto + 'static>(crypto: C) -> CryptoProvider {
127        // we need a double-pointer because C doesn't do fat pointers
128        let mut state: Pin<Box<State>> = Box::pin(State(Box::new(crypto)));
129
130        let vtable = signal_crypto_provider {
131            user_data: state.as_mut().get_mut() as *mut State as *mut c_void,
132            random_func: Some(random_func),
133            hmac_sha256_init_func: Some(hmac_sha256_init_func),
134            hmac_sha256_update_func: Some(hmac_sha256_update_func),
135            hmac_sha256_final_func: Some(hmac_sha256_final_func),
136            hmac_sha256_cleanup_func: Some(hmac_sha256_cleanup_func),
137            sha512_digest_init_func: Some(sha512_digest_init_func),
138            sha512_digest_update_func: Some(sha512_digest_update_func),
139            sha512_digest_final_func: Some(sha512_digest_final_func),
140            sha512_digest_cleanup_func: Some(sha512_digest_cleanup_func),
141            encrypt_func: Some(encrypt_func),
142            decrypt_func: Some(decrypt_func),
143        };
144
145        CryptoProvider { vtable, state }
146    }
147
148    pub fn state(&self) -> &dyn Crypto { &*self.state.0 }
149}
150
151struct State(Box<dyn Crypto>);
152
153struct HmacContext(Mutex<Box<dyn Sha256Hmac>>);
154
155struct DigestContext(Mutex<Box<dyn Sha512Digest>>);
156
157unsafe extern "C" fn random_func(
158    data: *mut u8,
159    len: usize,
160    user_data: *mut c_void,
161) -> c_int {
162    signal_assert!(!data.is_null());
163    signal_assert!(!user_data.is_null());
164
165    let user_data = &*(user_data as *const State);
166
167    let panic_result = std::panic::catch_unwind(|| {
168        let buffer = slice::from_raw_parts_mut(data, len);
169        user_data.0.fill_random(buffer)
170    });
171
172    match panic_result {
173        Ok(Ok(_)) => sys::SG_SUCCESS as c_int,
174        Ok(Err(e)) => {
175            log::error!("Unable to generate random data: {}", e);
176            InternalError::Unknown.code()
177        },
178        Err(e) => {
179            let msg = if let Some(m) = e.downcast_ref::<&str>() {
180                m
181            } else if let Some(m) = e.downcast_ref::<String>() {
182                m.as_str()
183            } else {
184                "Unknown panic"
185            };
186            log::error!("Panic encountered while trying to generate {} random bytes at {}#{}: {}",
187            len, file!(), line!(), msg);
188
189            InternalError::Unknown.code()
190        },
191    }
192}
193
194unsafe extern "C" fn hmac_sha256_cleanup_func(
195    hmac_context: *mut c_void,
196    _user_data: *mut c_void,
197) {
198    if hmac_context.is_null() {
199        return;
200    }
201
202    let hmac_context: Box<HmacContext> =
203        Box::from_raw(hmac_context as *mut HmacContext);
204    drop(hmac_context);
205}
206
207unsafe extern "C" fn hmac_sha256_final_func(
208    hmac_context: *mut c_void,
209    output: *mut *mut signal_buffer,
210    _user_data: *mut c_void,
211) -> i32 {
212    // just to make sure that the c ffi gave us a valid buffer to write to.
213    signal_assert!(!output.is_null());
214    signal_assert!(!hmac_context.is_null());
215
216    let hmac_context = &*(hmac_context as *const HmacContext);
217
218    match signal_catch_unwind!(hmac_context.0.lock().unwrap().finalize()) {
219        Ok(hmac) => {
220            let buffer = Buffer::from(hmac);
221            *output = buffer.into_raw();
222            sys::SG_SUCCESS as c_int
223        },
224        Err(e) => e.code(),
225    }
226}
227
228unsafe extern "C" fn hmac_sha256_init_func(
229    hmac_context: *mut *mut c_void,
230    key: *const u8,
231    key_len: usize,
232    user_data: *mut c_void,
233) -> i32 {
234    signal_assert!(!key.is_null());
235    signal_assert!(!user_data.is_null());
236
237    let state = &*(user_data as *const State);
238    let key = slice::from_raw_parts(key, key_len);
239
240    let hasher = match signal_catch_unwind!(state.0.hmac_sha256(key)) {
241        Ok(h) => h,
242        Err(e) => {
243            *hmac_context = ptr::null_mut();
244            return e.code();
245        },
246    };
247
248    *hmac_context =
249        Box::into_raw(Box::new(HmacContext(Mutex::new(hasher)))) as *mut c_void;
250    sys::SG_SUCCESS as c_int
251}
252
253unsafe extern "C" fn hmac_sha256_update_func(
254    hmac_context: *mut c_void,
255    data: *const u8,
256    data_len: usize,
257    _user_data: *mut c_void,
258) -> i32 {
259    signal_assert!(!data.is_null());
260    signal_assert!(!hmac_context.is_null());
261
262    let hmac_context = &*(hmac_context as *const HmacContext);
263
264    let data = slice::from_raw_parts(data, data_len);
265
266    signal_catch_unwind!(hmac_context.0.lock().unwrap().update(data))
267        .into_code()
268}
269
270unsafe extern "C" fn sha512_digest_init_func(
271    digest_context: *mut *mut c_void,
272    user_data: *mut c_void,
273) -> c_int {
274    signal_assert!(!user_data.is_null());
275
276    let user_data = &*(user_data as *const State);
277    let hasher = match signal_catch_unwind!(user_data.0.sha512_digest()) {
278        Ok(h) => h,
279        Err(e) => {
280            *digest_context = ptr::null_mut();
281            return e.code();
282        },
283    };
284
285    let dc = Box::new(DigestContext(Mutex::new(hasher)));
286    *digest_context = Box::into_raw(Box::new(dc)) as *mut c_void;
287
288    sys::SG_SUCCESS as c_int
289}
290
291unsafe extern "C" fn sha512_digest_update_func(
292    digest_context: *mut c_void,
293    data: *const u8,
294    data_len: usize,
295    _user_data: *mut c_void,
296) -> c_int {
297    signal_assert!(!data.is_null());
298    signal_assert!(!digest_context.is_null());
299
300    let hasher = &*(digest_context as *const DigestContext);
301
302    let buffer = slice::from_raw_parts(data, data_len);
303    signal_catch_unwind!(hasher.0.lock().unwrap().update(buffer)).into_code()
304}
305
306unsafe extern "C" fn sha512_digest_final_func(
307    digest_context: *mut c_void,
308    output: *mut *mut signal_buffer,
309    _user_data: *mut c_void,
310) -> c_int {
311    // just to make sure that the c ffi gave us a valid buffer to write to.
312    signal_assert!(!output.is_null());
313    signal_assert!(!digest_context.is_null());
314
315    let hasher = &*(digest_context as *const DigestContext);
316
317    match signal_catch_unwind!(hasher.0.lock().unwrap().finalize()) {
318        Ok(buf) => {
319            let buffer = Buffer::from(buf);
320            *output = buffer.into_raw();
321            sys::SG_SUCCESS as c_int
322        },
323        Err(e) => e.code(),
324    }
325}
326
327unsafe extern "C" fn sha512_digest_cleanup_func(
328    digest_context: *mut c_void,
329    _user_data: *mut c_void,
330) {
331    if digest_context.is_null() {
332        return;
333    }
334
335    let digest_context: Box<DigestContext> =
336        Box::from_raw(digest_context as *mut DigestContext);
337    drop(digest_context);
338}
339
340unsafe extern "C" fn encrypt_func(
341    output: *mut *mut signal_buffer,
342    cipher: c_int,
343    key: *const u8,
344    key_len: usize,
345    iv: *const u8,
346    iv_len: usize,
347    plaintext: *const u8,
348    plaintext_len: usize,
349    user_data: *mut c_void,
350) -> c_int {
351    internal_cipher(
352        CipherMode::Encrypt,
353        output,
354        cipher,
355        key,
356        key_len,
357        iv,
358        iv_len,
359        plaintext,
360        plaintext_len,
361        user_data,
362    )
363}
364
365unsafe extern "C" fn decrypt_func(
366    output: *mut *mut signal_buffer,
367    cipher: c_int,
368    key: *const u8,
369    key_len: usize,
370    iv: *const u8,
371    iv_len: usize,
372    ciphertext: *const u8,
373    ciphertext_len: usize,
374    user_data: *mut c_void,
375) -> c_int {
376    internal_cipher(
377        CipherMode::Decrypt,
378        output,
379        cipher,
380        key,
381        key_len,
382        iv,
383        iv_len,
384        ciphertext,
385        ciphertext_len,
386        user_data,
387    )
388}
389
390#[allow(clippy::cognitive_complexity)]
391unsafe extern "C" fn internal_cipher(
392    mode: CipherMode,
393    output: *mut *mut signal_buffer,
394    cipher: c_int,
395    key: *const u8,
396    key_len: usize,
397    iv: *const u8,
398    iv_len: usize,
399    data: *const u8,
400    data_len: usize,
401    user_data: *mut c_void,
402) -> c_int {
403    use self::CipherMode::*;
404    // just to make sure that the c ffi gave us a valid buffer to write to.
405    signal_assert!(!output.is_null());
406    signal_assert!(!user_data.is_null());
407    signal_assert!(!key.is_null());
408    signal_assert!(!iv.is_null());
409    signal_assert!(!data.is_null());
410
411    let signal_cipher_type = match SignalCipherType::try_from(cipher) {
412        Ok(ty) => ty,
413        // return early from the function with invalid arg instead of unknown
414        // error, cuz we know it xD
415        Err(_) => return InternalError::InvalidArgument.code(),
416    };
417    let key = slice::from_raw_parts(key, key_len);
418    let iv = slice::from_raw_parts(iv, iv_len);
419    let data = slice::from_raw_parts(data, data_len);
420
421    let user_data = &*(user_data as *const State);
422
423    let result = match mode {
424        Encrypt => signal_catch_unwind!(user_data.0.encrypt(
425            signal_cipher_type,
426            key,
427            iv,
428            data
429        )),
430        Decrypt => signal_catch_unwind!(user_data.0.decrypt(
431            signal_cipher_type,
432            key,
433            iv,
434            data
435        )),
436    };
437
438    match result {
439        Ok(buf) => {
440            let buffer = Buffer::from(buf);
441            *output = buffer.into_raw();
442            sys::SG_SUCCESS as c_int
443        },
444        Err(e) => e.code(),
445    }
446}
447
#[cfg(test)]
mod crypto_tests {
    #[allow(unused_imports)]
    use super::*;

    /// Cross-checks DefaultCrypto against OpenSSLCrypto for one cipher type.
    ///
    /// Both backends encrypt the same plaintext under a random key/IV and must
    /// produce identical ciphertexts. Each backend then decrypts the other's
    /// ciphertext and must recover the original plaintext.
    #[cfg(all(feature = "crypto-native", feature = "crypto-openssl"))]
    fn assert_backends_agree(cipher: SignalCipherType) {
        let native = DefaultCrypto::default();
        let openssl = OpenSSLCrypto::default();
        let plaintext = [1, 2, 3, 4, 5, 6, 7];

        let mut key = [0u8; 16];
        let mut iv = [0u8; 16];
        native.fill_random(&mut key).unwrap();
        native.fill_random(&mut iv).unwrap();

        let from_native = native.encrypt(cipher, &key, &iv, &plaintext).unwrap();
        let from_openssl = openssl.encrypt(cipher, &key, &iv, &plaintext).unwrap();
        assert_eq!(from_native, from_openssl);

        let native_roundtrip = native.decrypt(cipher, &key, &iv, &from_openssl).unwrap();
        let openssl_roundtrip = openssl.decrypt(cipher, &key, &iv, &from_native).unwrap();
        assert_eq!(native_roundtrip, plaintext);
        assert_eq!(openssl_roundtrip, plaintext);
    }

    #[cfg(all(feature = "crypto-native", feature = "crypto-openssl"))]
    #[test]
    fn test_crypter_cbc() { assert_backends_agree(SignalCipherType::AesCbcPkcs5); }

    #[cfg(all(feature = "crypto-native", feature = "crypto-openssl"))]
    #[test]
    fn test_crypter_ctr() { assert_backends_agree(SignalCipherType::AesCtrNoPadding); }
}