use std::{
convert::TryFrom,
fmt::{self, Debug, Formatter},
os::raw::{c_char, c_int, c_void},
panic::RefUnwindSafe,
pin::Pin,
ptr,
rc::Rc,
sync::Mutex,
time::SystemTime,
};
use failure::Error;
use lock_api::RawMutex as _;
use log::Level;
use parking_lot::RawMutex;
#[cfg(feature = "crypto-native")]
use crate::crypto::DefaultCrypto;
use crate::{
crypto::{Crypto, CryptoProvider},
errors::{FromInternalErrorCode, InternalError},
hkdf::HMACBasedKeyDerivationFunction,
keys::{
IdentityKeyPair, KeyPair, PreKeyList, PrivateKey, SessionSignedPreKey,
},
raw_ptr::Raw,
session_builder::SessionBuilder,
stores::{
identity_key_store::{self as iks, IdentityKeyStore},
pre_key_store::{self as pks, PreKeyStore},
session_store::{self as sess, SessionStore},
signed_pre_key_store::{self as spks, SignedPreKeyStore},
},
Address, Buffer, StoreContext,
};
#[allow(unused_imports)]
use crate::keys::{PreKey, PublicKey};
pub fn generate_identity_key_pair(
ctx: &Context,
) -> Result<IdentityKeyPair, Error> {
unsafe {
let mut key_pair = ptr::null_mut();
sys::signal_protocol_key_helper_generate_identity_key_pair(
&mut key_pair,
ctx.raw(),
)
.into_result()?;
Ok(IdentityKeyPair {
raw: Raw::from_ptr(key_pair),
})
}
}
pub fn generate_key_pair(ctx: &Context) -> Result<KeyPair, Error> {
unsafe {
let mut key_pair = ptr::null_mut();
sys::curve_generate_key_pair(ctx.raw(), &mut key_pair).into_result()?;
Ok(KeyPair {
raw: Raw::from_ptr(key_pair),
})
}
}
pub fn calculate_signature(
ctx: &Context,
private: &PrivateKey,
message: &[u8],
) -> Result<Buffer, Error> {
unsafe {
let mut buffer = ptr::null_mut();
sys::curve_calculate_signature(
ctx.raw(),
&mut buffer,
private.raw.as_const_ptr(),
message.as_ptr(),
message.len(),
)
.into_result()?;
Ok(Buffer::from_raw(buffer))
}
}
pub fn generate_registration_id(
ctx: &Context,
extended_range: i32,
) -> Result<u32, Error> {
let mut id = 0;
unsafe {
sys::signal_protocol_key_helper_generate_registration_id(
&mut id,
extended_range,
ctx.raw(),
)
.into_result()?;
}
Ok(id)
}
pub fn generate_pre_keys(
ctx: &Context,
start: u32,
count: u32,
) -> Result<PreKeyList, Error> {
unsafe {
let mut pre_keys_head = ptr::null_mut();
sys::signal_protocol_key_helper_generate_pre_keys(
&mut pre_keys_head,
start,
count,
ctx.raw(),
)
.into_result()?;
Ok(PreKeyList::from_raw(pre_keys_head))
}
}
pub fn generate_signed_pre_key(
ctx: &Context,
identity_key_pair: &IdentityKeyPair,
id: u32,
timestamp: SystemTime,
) -> Result<SessionSignedPreKey, Error> {
unsafe {
let mut raw = ptr::null_mut();
let unix_time = timestamp.duration_since(SystemTime::UNIX_EPOCH)?;
sys::signal_protocol_key_helper_generate_signed_pre_key(
&mut raw,
identity_key_pair.raw.as_const_ptr(),
id,
unix_time.as_secs(),
ctx.raw(),
)
.into_result()?;
if raw.is_null() {
Err(failure::err_msg("Unable to generate a signed pre key"))
} else {
Ok(SessionSignedPreKey {
raw: Raw::from_ptr(raw),
})
}
}
}
pub fn store_context<P, K, S, I>(
ctx: &Context,
pre_key_store: P,
signed_pre_key_store: K,
session_store: S,
identity_key_store: I,
) -> Result<StoreContext, Error>
where
P: PreKeyStore + 'static,
K: SignedPreKeyStore + 'static,
S: SessionStore + 'static,
I: IdentityKeyStore + 'static,
{
unsafe {
let mut store_ctx = ptr::null_mut();
sys::signal_protocol_store_context_create(&mut store_ctx, ctx.raw())
.into_result()?;
let pre_key_store = pks::new_vtable(pre_key_store);
sys::signal_protocol_store_context_set_pre_key_store(
store_ctx,
&pre_key_store,
)
.into_result()?;
let signed_pre_key_store = spks::new_vtable(signed_pre_key_store);
sys::signal_protocol_store_context_set_signed_pre_key_store(
store_ctx,
&signed_pre_key_store,
)
.into_result()?;
let session_store = sess::new_vtable(session_store);
sys::signal_protocol_store_context_set_session_store(
store_ctx,
&session_store,
)
.into_result()?;
let identity_key_store = iks::new_vtable(identity_key_store);
sys::signal_protocol_store_context_set_identity_key_store(
store_ctx,
&identity_key_store,
)
.into_result()?;
Ok(StoreContext::new(store_ctx, &ctx.0))
}
}
pub fn create_hkdf(
ctx: &Context,
version: i32,
) -> Result<HMACBasedKeyDerivationFunction, Error> {
HMACBasedKeyDerivationFunction::new(version, ctx)
}
pub fn session_builder(
ctx: &Context,
store_context: &StoreContext,
address: &Address,
) -> SessionBuilder {
SessionBuilder::new(ctx, store_context, address)
}
#[derive(Debug, Clone)]
pub struct Context(pub(crate) Rc<ContextInner>);
impl Context {
pub fn new<C: Crypto + 'static>(crypto: C) -> Result<Context, Error> {
ContextInner::new(crypto)
.map(|c| Context(Rc::new(c)))
.map_err(Error::from)
}
pub fn crypto(&self) -> &dyn Crypto { self.0.crypto.state() }
pub(crate) fn raw(&self) -> *mut sys::signal_context { self.0.raw() }
pub fn set_log_func<F>(&self, log_func: F)
where
F: Fn(Level, &str) + RefUnwindSafe + 'static,
{
let mut lf = self.0.state.log_func.lock().unwrap();
*lf = Box::new(log_func);
}
}
#[cfg(feature = "crypto-native")]
impl Default for Context {
fn default() -> Context {
match Context::new(DefaultCrypto::default()) {
Ok(c) => c,
Err(e) => {
panic!("Unable to create a context using the defaults: {}", e)
},
}
}
}
#[allow(dead_code)]
pub(crate) struct ContextInner {
raw: *mut sys::signal_context,
crypto: CryptoProvider,
state: Pin<Box<State>>,
}
impl ContextInner {
pub(crate) fn new<C: Crypto + 'static>(
crypto: C,
) -> Result<ContextInner, InternalError> {
unsafe {
let mut global_context: *mut sys::signal_context = ptr::null_mut();
let crypto = CryptoProvider::new(crypto);
let mut state = Pin::new(Box::new(State {
mux: RawMutex::INIT,
log_func: Mutex::new(Box::new(default_log_func)),
}));
let user_data =
state.as_mut().get_mut() as *mut State as *mut c_void;
sys::signal_context_create(&mut global_context, user_data)
.into_result()?;
sys::signal_context_set_crypto_provider(
global_context,
&crypto.vtable,
)
.into_result()?;
sys::signal_context_set_locking_functions(
global_context,
Some(lock_function),
Some(unlock_function),
)
.into_result()?;
sys::signal_context_set_log_function(
global_context,
Some(log_trampoline),
)
.into_result()?;
Ok(ContextInner {
raw: global_context,
crypto,
state,
})
}
}
pub(crate) const fn raw(&self) -> *mut sys::signal_context { self.raw }
}
impl Drop for ContextInner {
fn drop(&mut self) {
unsafe {
sys::signal_context_destroy(self.raw());
}
}
}
impl Debug for ContextInner {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_tuple("ContextInner").finish()
}
}
fn default_log_func(level: Level, message: &str) {
log::log!(level, "{}", message);
if level == Level::Error && std::env::var("RUST_BACKTRACE").is_ok() {
log::error!("{}", failure::Backtrace::new());
}
}
unsafe extern "C" fn log_trampoline(
level: c_int,
msg: *const c_char,
len: usize,
user_data: *mut c_void,
) {
signal_assert!(!msg.is_null(), ());
signal_assert!(!user_data.is_null(), ());
let state = &*(user_data as *const State);
let buffer = std::slice::from_raw_parts(msg as *const u8, len);
let level = translate_log_level(level);
if let Ok(message) = std::str::from_utf8(buffer) {
let _ = std::panic::catch_unwind(|| {
let log_func = state.log_func.lock().unwrap();
log_func(level, message);
});
}
}
fn translate_log_level(raw: c_int) -> Level {
match u32::try_from(raw) {
Ok(sys::SG_LOG_ERROR) => Level::Error,
Ok(sys::SG_LOG_WARNING) => Level::Warn,
Ok(sys::SG_LOG_INFO) => Level::Info,
Ok(sys::SG_LOG_DEBUG) => Level::Debug,
Ok(sys::SG_LOG_NOTICE) => Level::Trace,
_ => Level::Info,
}
}
unsafe extern "C" fn lock_function(user_data: *mut c_void) {
let state = &*(user_data as *const State);
state.mux.lock();
}
unsafe extern "C" fn unlock_function(user_data: *mut c_void) {
let state = &*(user_data as *const State);
state.mux.unlock();
}
struct State {
mux: RawMutex,
log_func: Mutex<Box<dyn Fn(Level, &str) + RefUnwindSafe>>,
}
#[cfg(test)]
mod tests {
use super::*;
#[cfg(feature = "crypto-native")]
#[test]
fn library_initialization_example_from_readme_native() {
let ctx = Context::default();
drop(ctx);
}
#[cfg(feature = "crypto-openssl")]
#[test]
fn library_initialization_example_from_readme_openssl() {
use crate::crypto::OpenSSLCrypto;
let ctx = Context::new(OpenSSLCrypto::default()).unwrap();
drop(ctx);
}
}