1use std::{
2 convert::TryFrom,
3 fmt::{self, Debug, Formatter},
4 os::raw::{c_char, c_int, c_void},
5 panic::RefUnwindSafe,
6 pin::Pin,
7 ptr,
8 rc::Rc,
9 sync::Mutex,
10 time::SystemTime,
11};
12
13use failure::Error;
14use lock_api::RawMutex as _;
15use log::Level;
16use parking_lot::RawMutex;
17
18#[cfg(feature = "crypto-native")]
19use crate::crypto::DefaultCrypto;
20use crate::{
21 crypto::{Crypto, CryptoProvider},
22 errors::{FromInternalErrorCode, InternalError},
23 hkdf::HMACBasedKeyDerivationFunction,
24 keys::{
25 IdentityKeyPair, KeyPair, PreKeyList, PrivateKey, SessionSignedPreKey,
26 },
27 raw_ptr::Raw,
28 session_builder::SessionBuilder,
29 stores::{
30 identity_key_store::{self as iks, IdentityKeyStore},
31 pre_key_store::{self as pks, PreKeyStore},
32 session_store::{self as sess, SessionStore},
33 signed_pre_key_store::{self as spks, SignedPreKeyStore},
34 },
35 Address, Buffer, StoreContext,
36};
37#[allow(unused_imports)]
39use crate::keys::{PreKey, PublicKey};
40
41pub fn generate_identity_key_pair(
43 ctx: &Context,
44) -> Result<IdentityKeyPair, Error> {
45 unsafe {
46 let mut key_pair = ptr::null_mut();
47 sys::signal_protocol_key_helper_generate_identity_key_pair(
48 &mut key_pair,
49 ctx.raw(),
50 )
51 .into_result()?;
52 Ok(IdentityKeyPair {
53 raw: Raw::from_ptr(key_pair),
54 })
55 }
56}
57
58pub fn generate_key_pair(ctx: &Context) -> Result<KeyPair, Error> {
60 unsafe {
61 let mut key_pair = ptr::null_mut();
62 sys::curve_generate_key_pair(ctx.raw(), &mut key_pair).into_result()?;
63
64 Ok(KeyPair {
65 raw: Raw::from_ptr(key_pair),
66 })
67 }
68}
69
70pub fn calculate_signature(
109 ctx: &Context,
110 private: &PrivateKey,
111 message: &[u8],
112) -> Result<Buffer, Error> {
113 unsafe {
114 let mut buffer = ptr::null_mut();
115 sys::curve_calculate_signature(
116 ctx.raw(),
117 &mut buffer,
118 private.raw.as_const_ptr(),
119 message.as_ptr(),
120 message.len(),
121 )
122 .into_result()?;
123
124 Ok(Buffer::from_raw(buffer))
125 }
126}
127
128pub fn generate_registration_id(
130 ctx: &Context,
131 extended_range: i32,
132) -> Result<u32, Error> {
133 let mut id = 0;
134 unsafe {
135 sys::signal_protocol_key_helper_generate_registration_id(
136 &mut id,
137 extended_range,
138 ctx.raw(),
139 )
140 .into_result()?;
141 }
142
143 Ok(id)
144}
145
146pub fn generate_pre_keys(
153 ctx: &Context,
154 start: u32,
155 count: u32,
156) -> Result<PreKeyList, Error> {
157 unsafe {
158 let mut pre_keys_head = ptr::null_mut();
159 sys::signal_protocol_key_helper_generate_pre_keys(
160 &mut pre_keys_head,
161 start,
162 count,
163 ctx.raw(),
164 )
165 .into_result()?;
166
167 Ok(PreKeyList::from_raw(pre_keys_head))
168 }
169}
170
171pub fn generate_signed_pre_key(
173 ctx: &Context,
174 identity_key_pair: &IdentityKeyPair,
175 id: u32,
176 timestamp: SystemTime,
177) -> Result<SessionSignedPreKey, Error> {
178 unsafe {
179 let mut raw = ptr::null_mut();
180 let unix_time = timestamp.duration_since(SystemTime::UNIX_EPOCH)?;
181
182 sys::signal_protocol_key_helper_generate_signed_pre_key(
183 &mut raw,
184 identity_key_pair.raw.as_const_ptr(),
185 id,
186 unix_time.as_secs(),
187 ctx.raw(),
188 )
189 .into_result()?;
190
191 if raw.is_null() {
192 Err(failure::err_msg("Unable to generate a signed pre key"))
193 } else {
194 Ok(SessionSignedPreKey {
195 raw: Raw::from_ptr(raw),
196 })
197 }
198 }
199}
200
201pub fn store_context<P, K, S, I>(
203 ctx: &Context,
204 pre_key_store: P,
205 signed_pre_key_store: K,
206 session_store: S,
207 identity_key_store: I,
208) -> Result<StoreContext, Error>
209where
210 P: PreKeyStore + 'static,
211 K: SignedPreKeyStore + 'static,
212 S: SessionStore + 'static,
213 I: IdentityKeyStore + 'static,
214{
215 unsafe {
216 let mut store_ctx = ptr::null_mut();
217 sys::signal_protocol_store_context_create(&mut store_ctx, ctx.raw())
218 .into_result()?;
219
220 let pre_key_store = pks::new_vtable(pre_key_store);
221 sys::signal_protocol_store_context_set_pre_key_store(
222 store_ctx,
223 &pre_key_store,
224 )
225 .into_result()?;
226
227 let signed_pre_key_store = spks::new_vtable(signed_pre_key_store);
228 sys::signal_protocol_store_context_set_signed_pre_key_store(
229 store_ctx,
230 &signed_pre_key_store,
231 )
232 .into_result()?;
233
234 let session_store = sess::new_vtable(session_store);
235 sys::signal_protocol_store_context_set_session_store(
236 store_ctx,
237 &session_store,
238 )
239 .into_result()?;
240
241 let identity_key_store = iks::new_vtable(identity_key_store);
242 sys::signal_protocol_store_context_set_identity_key_store(
243 store_ctx,
244 &identity_key_store,
245 )
246 .into_result()?;
247
248 Ok(StoreContext::new(store_ctx, &ctx.0))
249 }
250}
251
252pub fn create_hkdf(
254 ctx: &Context,
255 version: i32,
256) -> Result<HMACBasedKeyDerivationFunction, Error> {
257 HMACBasedKeyDerivationFunction::new(version, ctx)
258}
259
260pub fn session_builder(
263 ctx: &Context,
264 store_context: &StoreContext,
265 address: &Address,
266) -> SessionBuilder {
267 SessionBuilder::new(ctx, store_context, address)
268}
269
/// Handle to a libsignal-protocol-c `signal_context`.
///
/// Cloning a `Context` only bumps the reference count of the shared
/// `ContextInner`. The underlying C context is destroyed when the last clone
/// is dropped. Because the handle is built on `Rc`, a `Context` cannot be
/// sent to another thread.
#[derive(Debug, Clone)]
pub struct Context(pub(crate) Rc<ContextInner>);
276
277impl Context {
278 pub fn new<C: Crypto + 'static>(crypto: C) -> Result<Context, Error> {
280 ContextInner::new(crypto)
281 .map(|c| Context(Rc::new(c)))
282 .map_err(Error::from)
283 }
284
285 pub fn crypto(&self) -> &dyn Crypto { self.0.crypto.state() }
287
288 pub(crate) fn raw(&self) -> *mut sys::signal_context { self.0.raw() }
289
290 pub fn set_log_func<F>(&self, log_func: F)
292 where
293 F: Fn(Level, &str) + RefUnwindSafe + 'static,
294 {
295 let mut lf = self.0.state.log_func.lock().unwrap();
296 *lf = Box::new(log_func);
297 }
298}
299
#[cfg(feature = "crypto-native")]
impl Default for Context {
    /// Creates a context backed by `DefaultCrypto::default()`.
    ///
    /// # Panics
    ///
    /// Panics if the context cannot be created, because `Default` has no way
    /// to report an error.
    fn default() -> Context {
        Context::new(DefaultCrypto::default()).unwrap_or_else(|e| {
            panic!("Unable to create a context using the defaults: {}", e)
        })
    }
}
311
/// Owns the raw `signal_context` together with the Rust-side data it was
/// configured with.
///
/// Rust drops fields in declaration order after `Drop::drop` has run. The
/// `signal_context_destroy` call in `Drop` therefore happens before `crypto`
/// and `state` are freed. Keep this field order.
// NOTE(review): every field appears to be read somewhere in this module;
// confirm whether `dead_code` still needs to be allowed (and document why).
#[allow(dead_code)]
pub(crate) struct ContextInner {
    /// Context created by `signal_context_create`.
    raw: *mut sys::signal_context,
    /// Crypto provider whose vtable is registered with `raw`.
    crypto: CryptoProvider,
    /// Boxed state whose address is passed to C as the context's
    /// `user_data`.
    state: Pin<Box<State>>,
}
327
328impl ContextInner {
329 pub(crate) fn new<C: Crypto + 'static>(
330 crypto: C,
331 ) -> Result<ContextInner, InternalError> {
332 unsafe {
333 let mut global_context: *mut sys::signal_context = ptr::null_mut();
334 let crypto = CryptoProvider::new(crypto);
335 let mut state = Pin::new(Box::new(State {
336 mux: RawMutex::INIT,
337 log_func: Mutex::new(Box::new(default_log_func)),
338 }));
339
340 let user_data =
341 state.as_mut().get_mut() as *mut State as *mut c_void;
342 sys::signal_context_create(&mut global_context, user_data)
343 .into_result()?;
344 sys::signal_context_set_crypto_provider(
345 global_context,
346 &crypto.vtable,
347 )
348 .into_result()?;
349 sys::signal_context_set_locking_functions(
350 global_context,
351 Some(lock_function),
352 Some(unlock_function),
353 )
354 .into_result()?;
355 sys::signal_context_set_log_function(
356 global_context,
357 Some(log_trampoline),
358 )
359 .into_result()?;
360
361 Ok(ContextInner {
362 raw: global_context,
363 crypto,
364 state,
365 })
366 }
367 }
368
369 pub(crate) const fn raw(&self) -> *mut sys::signal_context { self.raw }
370}
371
372impl Drop for ContextInner {
373 fn drop(&mut self) {
374 unsafe {
375 sys::signal_context_destroy(self.raw());
376 }
377 }
378}
379
380impl Debug for ContextInner {
381 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
382 f.debug_tuple("ContextInner").finish()
383 }
384}
385
386fn default_log_func(level: Level, message: &str) {
387 log::log!(level, "{}", message);
388
389 if level == Level::Error && std::env::var("RUST_BACKTRACE").is_ok() {
390 log::error!("{}", failure::Backtrace::new());
391 }
392}
393
394unsafe extern "C" fn log_trampoline(
395 level: c_int,
396 msg: *const c_char,
397 len: usize,
398 user_data: *mut c_void,
399) {
400 signal_assert!(!msg.is_null(), ());
401 signal_assert!(!user_data.is_null(), ());
402
403 let state = &*(user_data as *const State);
404 let buffer = std::slice::from_raw_parts(msg as *const u8, len);
405 let level = translate_log_level(level);
406
407 if let Ok(message) = std::str::from_utf8(buffer) {
408 let _ = std::panic::catch_unwind(|| {
411 let log_func = state.log_func.lock().unwrap();
412 log_func(level, message);
413 });
414 }
415}
416
417fn translate_log_level(raw: c_int) -> Level {
418 match u32::try_from(raw) {
419 Ok(sys::SG_LOG_ERROR) => Level::Error,
420 Ok(sys::SG_LOG_WARNING) => Level::Warn,
421 Ok(sys::SG_LOG_INFO) => Level::Info,
422 Ok(sys::SG_LOG_DEBUG) => Level::Debug,
423 Ok(sys::SG_LOG_NOTICE) => Level::Trace,
424 _ => Level::Info,
425 }
426}
427
428unsafe extern "C" fn lock_function(user_data: *mut c_void) {
429 let state = &*(user_data as *const State);
430 state.mux.lock();
431}
432
433unsafe extern "C" fn unlock_function(user_data: *mut c_void) {
434 let state = &*(user_data as *const State);
435 state.mux.unlock();
436}
437
/// Per-context state shared with the C library through the `user_data`
/// pointer passed to `signal_context_create`.
struct State {
    /// Raw mutex locked and unlocked by the C library through
    /// `lock_function` / `unlock_function`. No Rust guard is involved.
    mux: RawMutex,
    /// Current log sink. It is invoked by `log_trampoline` and replaced by
    /// `Context::set_log_func`.
    log_func: Mutex<Box<dyn Fn(Level, &str) + RefUnwindSafe>>,
}
450
#[cfg(test)]
mod tests {
    use super::*;

    /// Creating and immediately dropping a native-crypto context must work.
    #[cfg(feature = "crypto-native")]
    #[test]
    fn library_initialization_example_from_readme_native() {
        drop(Context::default());
    }

    /// Creating and immediately dropping an OpenSSL-backed context must
    /// work.
    #[cfg(feature = "crypto-openssl")]
    #[test]
    fn library_initialization_example_from_readme_openssl() {
        use crate::crypto::OpenSSLCrypto;

        drop(Context::new(OpenSSLCrypto::default()).unwrap());
    }
}