libsignal_protocol/
context.rs

1use std::{
2    convert::TryFrom,
3    fmt::{self, Debug, Formatter},
4    os::raw::{c_char, c_int, c_void},
5    panic::RefUnwindSafe,
6    pin::Pin,
7    ptr,
8    rc::Rc,
9    sync::Mutex,
10    time::SystemTime,
11};
12
13use failure::Error;
14use lock_api::RawMutex as _;
15use log::Level;
16use parking_lot::RawMutex;
17
18#[cfg(feature = "crypto-native")]
19use crate::crypto::DefaultCrypto;
20use crate::{
21    crypto::{Crypto, CryptoProvider},
22    errors::{FromInternalErrorCode, InternalError},
23    hkdf::HMACBasedKeyDerivationFunction,
24    keys::{
25        IdentityKeyPair, KeyPair, PreKeyList, PrivateKey, SessionSignedPreKey,
26    },
27    raw_ptr::Raw,
28    session_builder::SessionBuilder,
29    stores::{
30        identity_key_store::{self as iks, IdentityKeyStore},
31        pre_key_store::{self as pks, PreKeyStore},
32        session_store::{self as sess, SessionStore},
33        signed_pre_key_store::{self as spks, SignedPreKeyStore},
34    },
35    Address, Buffer, StoreContext,
36};
37// for rustdoc link resolution
38#[allow(unused_imports)]
39use crate::keys::{PreKey, PublicKey};
40
41/// A helper function for generating a new [`IdentityKeyPair`].
42pub fn generate_identity_key_pair(
43    ctx: &Context,
44) -> Result<IdentityKeyPair, Error> {
45    unsafe {
46        let mut key_pair = ptr::null_mut();
47        sys::signal_protocol_key_helper_generate_identity_key_pair(
48            &mut key_pair,
49            ctx.raw(),
50        )
51        .into_result()?;
52        Ok(IdentityKeyPair {
53            raw: Raw::from_ptr(key_pair),
54        })
55    }
56}
57
58/// Generate a normal elliptic curve key pair.
59pub fn generate_key_pair(ctx: &Context) -> Result<KeyPair, Error> {
60    unsafe {
61        let mut key_pair = ptr::null_mut();
62        sys::curve_generate_key_pair(ctx.raw(), &mut key_pair).into_result()?;
63
64        Ok(KeyPair {
65            raw: Raw::from_ptr(key_pair),
66        })
67    }
68}
69
70/// Calculate the signature for a message.
71///
72/// # Examples
73///
74/// This is the counterpart to [`PublicKey::verify_signature`].
75///
76/// ```rust
77/// # use libsignal_protocol::{keys::PublicKey, Context};
78/// # use failure::Error;
79/// # use cfg_if::cfg_if;
80/// # fn main() -> Result<(), Error> {
81/// # cfg_if::cfg_if! {
82/// #  if #[cfg(feature = "crypto-native")] {
83/// #      type Crypto = libsignal_protocol::crypto::DefaultCrypto;
84/// #  } else if #[cfg(feature = "crypto-openssl")] {
85/// #      type Crypto = libsignal_protocol::crypto::OpenSSLCrypto;
86/// #  } else {
87/// #      compile_error!("These tests require one of the crypto features to be enabled");
88/// #  }
89/// # }
90/// // the `Crypto` here is a type alias to one of `OpenSSLCrypto` or `DefaultCrypto`.
91/// let ctx = Context::new(Crypto::default()).unwrap();
92/// let key_pair = libsignal_protocol::generate_key_pair(&ctx)?;
93///
94/// let msg = "Hello, World!";
95/// let private_key = key_pair.private();
96/// let signature = libsignal_protocol::calculate_signature(
97///     &ctx,
98///     &private_key,
99///     msg.as_bytes(),
100/// )?;
101///
102/// let public = key_pair.public();
103/// let got = public.verify_signature(msg.as_bytes(), signature.as_slice());
104/// assert!(got.is_ok());
105/// # Ok(())
106/// # }
107/// ```
108pub fn calculate_signature(
109    ctx: &Context,
110    private: &PrivateKey,
111    message: &[u8],
112) -> Result<Buffer, Error> {
113    unsafe {
114        let mut buffer = ptr::null_mut();
115        sys::curve_calculate_signature(
116            ctx.raw(),
117            &mut buffer,
118            private.raw.as_const_ptr(),
119            message.as_ptr(),
120            message.len(),
121        )
122        .into_result()?;
123
124        Ok(Buffer::from_raw(buffer))
125    }
126}
127
128/// Generate a new registration ID.
129pub fn generate_registration_id(
130    ctx: &Context,
131    extended_range: i32,
132) -> Result<u32, Error> {
133    let mut id = 0;
134    unsafe {
135        sys::signal_protocol_key_helper_generate_registration_id(
136            &mut id,
137            extended_range,
138            ctx.raw(),
139        )
140        .into_result()?;
141    }
142
143    Ok(id)
144}
145
146/// Generate a list of [`PreKey`]s. Clients should do this at install time, and
147/// subsequently any time the list of [`PreKey`]s stored on the server runs low.
148///
149/// Pre key IDs are shorts, so they will eventually be repeated. Clients should
150/// store pre keys in a circular buffer, so that they are repeated as
151/// infrequently as possible.
152pub fn generate_pre_keys(
153    ctx: &Context,
154    start: u32,
155    count: u32,
156) -> Result<PreKeyList, Error> {
157    unsafe {
158        let mut pre_keys_head = ptr::null_mut();
159        sys::signal_protocol_key_helper_generate_pre_keys(
160            &mut pre_keys_head,
161            start,
162            count,
163            ctx.raw(),
164        )
165        .into_result()?;
166
167        Ok(PreKeyList::from_raw(pre_keys_head))
168    }
169}
170
171/// Generate a signed pre-key.
172pub fn generate_signed_pre_key(
173    ctx: &Context,
174    identity_key_pair: &IdentityKeyPair,
175    id: u32,
176    timestamp: SystemTime,
177) -> Result<SessionSignedPreKey, Error> {
178    unsafe {
179        let mut raw = ptr::null_mut();
180        let unix_time = timestamp.duration_since(SystemTime::UNIX_EPOCH)?;
181
182        sys::signal_protocol_key_helper_generate_signed_pre_key(
183            &mut raw,
184            identity_key_pair.raw.as_const_ptr(),
185            id,
186            unix_time.as_secs(),
187            ctx.raw(),
188        )
189        .into_result()?;
190
191        if raw.is_null() {
192            Err(failure::err_msg("Unable to generate a signed pre key"))
193        } else {
194            Ok(SessionSignedPreKey {
195                raw: Raw::from_ptr(raw),
196            })
197        }
198    }
199}
200
201/// Create a container for the state used by the signal protocol.
202pub fn store_context<P, K, S, I>(
203    ctx: &Context,
204    pre_key_store: P,
205    signed_pre_key_store: K,
206    session_store: S,
207    identity_key_store: I,
208) -> Result<StoreContext, Error>
209where
210    P: PreKeyStore + 'static,
211    K: SignedPreKeyStore + 'static,
212    S: SessionStore + 'static,
213    I: IdentityKeyStore + 'static,
214{
215    unsafe {
216        let mut store_ctx = ptr::null_mut();
217        sys::signal_protocol_store_context_create(&mut store_ctx, ctx.raw())
218            .into_result()?;
219
220        let pre_key_store = pks::new_vtable(pre_key_store);
221        sys::signal_protocol_store_context_set_pre_key_store(
222            store_ctx,
223            &pre_key_store,
224        )
225        .into_result()?;
226
227        let signed_pre_key_store = spks::new_vtable(signed_pre_key_store);
228        sys::signal_protocol_store_context_set_signed_pre_key_store(
229            store_ctx,
230            &signed_pre_key_store,
231        )
232        .into_result()?;
233
234        let session_store = sess::new_vtable(session_store);
235        sys::signal_protocol_store_context_set_session_store(
236            store_ctx,
237            &session_store,
238        )
239        .into_result()?;
240
241        let identity_key_store = iks::new_vtable(identity_key_store);
242        sys::signal_protocol_store_context_set_identity_key_store(
243            store_ctx,
244            &identity_key_store,
245        )
246        .into_result()?;
247
248        Ok(StoreContext::new(store_ctx, &ctx.0))
249    }
250}
251
252/// Create a new HMAC-based key derivation function.
253pub fn create_hkdf(
254    ctx: &Context,
255    version: i32,
256) -> Result<HMACBasedKeyDerivationFunction, Error> {
257    HMACBasedKeyDerivationFunction::new(version, ctx)
258}
259
260/// Create a new session builder for communication with the user with the
261/// specified address.
262pub fn session_builder(
263    ctx: &Context,
264    store_context: &StoreContext,
265    address: &Address,
266) -> SessionBuilder {
267    SessionBuilder::new(ctx, store_context, address)
268}
269
270/// Global state and callbacks used by the library.
271///
272/// Most functions which require access to the global context (e.g. for crypto
273/// functions or locking) will accept a `&Context` as their first argument.
274#[derive(Debug, Clone)]
275pub struct Context(pub(crate) Rc<ContextInner>);
276
277impl Context {
278    /// Create a new [`Context`] using the provided cryptographic functions.
279    pub fn new<C: Crypto + 'static>(crypto: C) -> Result<Context, Error> {
280        ContextInner::new(crypto)
281            .map(|c| Context(Rc::new(c)))
282            .map_err(Error::from)
283    }
284
285    /// Access the original [`Crypto`] object.
286    pub fn crypto(&self) -> &dyn Crypto { self.0.crypto.state() }
287
288    pub(crate) fn raw(&self) -> *mut sys::signal_context { self.0.raw() }
289
290    /// Se the function to use when `libsignal-protocol-c` emits a log message.
291    pub fn set_log_func<F>(&self, log_func: F)
292    where
293        F: Fn(Level, &str) + RefUnwindSafe + 'static,
294    {
295        let mut lf = self.0.state.log_func.lock().unwrap();
296        *lf = Box::new(log_func);
297    }
298}
299
#[cfg(feature = "crypto-native")]
impl Default for Context {
    /// Create a [`Context`] backed by [`DefaultCrypto`].
    ///
    /// # Panics
    ///
    /// Panics if the context cannot be created.
    fn default() -> Context {
        Context::new(DefaultCrypto::default()).unwrap_or_else(|e| {
            panic!("Unable to create a context using the defaults: {}", e)
        })
    }
}
311
312/// Our Rust wrapper around the [`sys::signal_context`].
313///
314/// # Safety
315///
316/// This **must** outlive any data created by the `libsignal-protocol-c`
317/// library. You'll usually do this by adding a `Rc<ContextInner>` to any
318/// wrapper types.
319#[allow(dead_code)]
320pub(crate) struct ContextInner {
321    raw: *mut sys::signal_context,
322    crypto: CryptoProvider,
323    // A pointer to our [`State`] has been passed to `libsignal-protocol-c`, so
324    // we need to make sure it is never moved.
325    state: Pin<Box<State>>,
326}
327
328impl ContextInner {
329    pub(crate) fn new<C: Crypto + 'static>(
330        crypto: C,
331    ) -> Result<ContextInner, InternalError> {
332        unsafe {
333            let mut global_context: *mut sys::signal_context = ptr::null_mut();
334            let crypto = CryptoProvider::new(crypto);
335            let mut state = Pin::new(Box::new(State {
336                mux: RawMutex::INIT,
337                log_func: Mutex::new(Box::new(default_log_func)),
338            }));
339
340            let user_data =
341                state.as_mut().get_mut() as *mut State as *mut c_void;
342            sys::signal_context_create(&mut global_context, user_data)
343                .into_result()?;
344            sys::signal_context_set_crypto_provider(
345                global_context,
346                &crypto.vtable,
347            )
348            .into_result()?;
349            sys::signal_context_set_locking_functions(
350                global_context,
351                Some(lock_function),
352                Some(unlock_function),
353            )
354            .into_result()?;
355            sys::signal_context_set_log_function(
356                global_context,
357                Some(log_trampoline),
358            )
359            .into_result()?;
360
361            Ok(ContextInner {
362                raw: global_context,
363                crypto,
364                state,
365            })
366        }
367    }
368
369    pub(crate) const fn raw(&self) -> *mut sys::signal_context { self.raw }
370}
371
372impl Drop for ContextInner {
373    fn drop(&mut self) {
374        unsafe {
375            sys::signal_context_destroy(self.raw());
376        }
377    }
378}
379
380impl Debug for ContextInner {
381    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
382        f.debug_tuple("ContextInner").finish()
383    }
384}
385
386fn default_log_func(level: Level, message: &str) {
387    log::log!(level, "{}", message);
388
389    if level == Level::Error && std::env::var("RUST_BACKTRACE").is_ok() {
390        log::error!("{}", failure::Backtrace::new());
391    }
392}
393
394unsafe extern "C" fn log_trampoline(
395    level: c_int,
396    msg: *const c_char,
397    len: usize,
398    user_data: *mut c_void,
399) {
400    signal_assert!(!msg.is_null(), ());
401    signal_assert!(!user_data.is_null(), ());
402
403    let state = &*(user_data as *const State);
404    let buffer = std::slice::from_raw_parts(msg as *const u8, len);
405    let level = translate_log_level(level);
406
407    if let Ok(message) = std::str::from_utf8(buffer) {
408        // we can't log the errors that occur while logging errors, so just
409        // drop them on the floor...
410        let _ = std::panic::catch_unwind(|| {
411            let log_func = state.log_func.lock().unwrap();
412            log_func(level, message);
413        });
414    }
415}
416
417fn translate_log_level(raw: c_int) -> Level {
418    match u32::try_from(raw) {
419        Ok(sys::SG_LOG_ERROR) => Level::Error,
420        Ok(sys::SG_LOG_WARNING) => Level::Warn,
421        Ok(sys::SG_LOG_INFO) => Level::Info,
422        Ok(sys::SG_LOG_DEBUG) => Level::Debug,
423        Ok(sys::SG_LOG_NOTICE) => Level::Trace,
424        _ => Level::Info,
425    }
426}
427
428unsafe extern "C" fn lock_function(user_data: *mut c_void) {
429    let state = &*(user_data as *const State);
430    state.mux.lock();
431}
432
433unsafe extern "C" fn unlock_function(user_data: *mut c_void) {
434    let state = &*(user_data as *const State);
435    state.mux.unlock();
436}
437
438/// The "user state" we pass to `libsignal-protocol-c` as part of the global
439/// context.
440///
441/// # Safety
442///
443/// A pointer to this [`State`] will be shared throughout the
444/// `libsignal-protocol-c` library, so any mutation **must** be done using the
445/// appropriate synchronisation mechanisms (i.e. `RefCell` or atomics).
446struct State {
447    mux: RawMutex,
448    log_func: Mutex<Box<dyn Fn(Level, &str) + RefUnwindSafe>>,
449}
450
#[cfg(test)]
mod tests {
    use super::*;

    /// A context backed by the native crypto implementation can be created
    /// and torn down again.
    #[cfg(feature = "crypto-native")]
    #[test]
    fn library_initialization_example_from_readme_native() {
        let context = Context::default();

        drop(context);
    }

    /// A context backed by the OpenSSL crypto implementation can be created
    /// and torn down again.
    #[cfg(feature = "crypto-openssl")]
    #[test]
    fn library_initialization_example_from_readme_openssl() {
        let crypto = crate::crypto::OpenSSLCrypto::default();
        let context = Context::new(crypto).unwrap();

        drop(context);
    }
}
471}