#![allow(unsafe_code)]
use std::sync::Once;
use libc::{self, size_t};
#[cfg(not(feature = "use-libsodium-sys"))]
use libc::{c_void, c_int};
#[cfg(feature = "use-libsodium-sys")]
use libsodium_sys::{
randombytes_buf, sodium_allocarray, sodium_free, sodium_init,
sodium_memcmp, sodium_memzero, sodium_mlock, sodium_mprotect_noaccess,
sodium_mprotect_readonly, sodium_mprotect_readwrite, sodium_munlock,
};
// Guard used by `init` so that its initialisation closure is executed through
// `Once::call_once` rather than on every call.
static INIT: Once = Once::new();
// Outcome of that initialisation: written inside the `INIT.call_once` closure
// in `init` (`true` when no step failed) and returned by `init` afterwards.
static mut INITIALIZED: bool = false;
// Test-only, per-thread failure-injection flag. `fail()` sets it to `true`;
// the wrappers in this file that consult it swap it back to `false` and, if it
// was set, return `false` without calling into libsodium.
#[cfg(test)]
thread_local! {
static FAIL: std::cell::Cell<bool> = std::cell::Cell::new(false);
}
// Hand-written declarations of the libsodium functions used by this module.
// They are compiled only when the `use-libsodium-sys` feature is disabled;
// with the feature enabled the same names are imported from `libsodium_sys`.
#[cfg(not(feature = "use-libsodium-sys"))]
unsafe extern "C" {
// `init` treats a return value of -1 as failure.
fn sodium_init() -> c_int;
// Wrapped by `allocarray`, which passes the element count and `size_of::<T>()`.
fn sodium_allocarray(count: size_t, size: size_t) -> *mut c_void;
// Wrapped by `free`.
fn sodium_free(ptr: *mut c_void);
// The wrappers for the five functions below treat a return value of 0 as success.
fn sodium_mlock(ptr: *mut c_void, len: size_t) -> c_int;
fn sodium_munlock(ptr: *mut c_void, len: size_t) -> c_int;
fn sodium_mprotect_noaccess(ptr: *mut c_void) -> c_int;
fn sodium_mprotect_readonly(ptr: *mut c_void) -> c_int;
fn sodium_mprotect_readwrite(ptr: *mut c_void) -> c_int;
// `memcmp` treats a return value of 0 as "the two buffers are equal".
fn sodium_memcmp(l: *const c_void, r: *const c_void, len: size_t) -> c_int;
// Wrapped by `memzero` and `memrandom` respectively, each passing a slice's
// pointer and length.
fn sodium_memzero(ptr: *mut c_void, len: size_t);
fn randombytes_buf(ptr: *mut c_void, len: size_t);
}
/// Arms the test-only failure flag for the current thread.
///
/// The next wrapper in this module that consults `FAIL` will clear the flag
/// and report failure instead of calling into libsodium.
#[cfg(test)]
pub(crate) fn fail() {
    FAIL.set(true);
}
/// Performs the one-time libsodium setup and reports whether it succeeded.
///
/// The setup closure is run through `INIT.call_once`; its outcome is stored in
/// `INITIALIZED` and that stored value is what every call returns. On Unix,
/// when the `release` or `coverage` profile cfg is set and the
/// `allow-coredumps` feature is off, the setup also sets the core-dump
/// resource limit to zero; a failure there makes the whole setup count as
/// failed, although `sodium_init` is still invoked.
///
/// In test builds, an armed `FAIL` flag is cleared and makes this call return
/// `false` immediately, without touching the one-time setup.
pub(crate) fn init() -> bool {
    #[cfg(test)]
    {
        let injected_failure = FAIL.with(|flag| flag.replace(false));
        if injected_failure {
            return false;
        }
        // No-op binding carried over unchanged from the original hook.
        let _x = 0;
    };

    INIT.call_once(|| {
        // `transmute` only compiles when both types have the same size, so
        // this line checks at compile time that `usize` matches `size_t`.
        // SAFETY: an integer is reinterpreted as an equally sized integer.
        #[allow(clippy::useless_transmute)]
        let _: size_t = unsafe { std::mem::transmute::<usize, size_t>(0) };

        let mut failed = false;

        #[cfg(all(
            unix,
            any(profile = "release", profile = "coverage"),
            not(feature = "allow-coredumps"),
        ))]
        {
            let no_core_dumps = libc::rlimit {
                rlim_cur: 0,
                rlim_max: 0,
            };
            // SAFETY: `setrlimit` receives a pointer to a fully initialised
            // `rlimit` that outlives the call.
            let rc = unsafe { libc::setrlimit(libc::RLIMIT_CORE, &no_core_dumps) };
            failed |= rc == -1;
        }

        // SAFETY: `sodium_init` takes no arguments; this closure is only run
        // through `INIT.call_once`.
        let sodium_rc = unsafe { sodium_init() };
        failed |= sodium_rc == -1;

        // SAFETY: the only write to `INITIALIZED`, performed inside the
        // `call_once` closure.
        unsafe {
            INITIALIZED = !failed;
        }
    });

    // SAFETY: plain by-value read after `call_once` has returned, i.e. after
    // the single write above has completed.
    unsafe { INITIALIZED }
}
/// Requests storage for `count` values of type `T` from `sodium_allocarray`
/// and returns the result as a `*mut T`.
///
/// The returned pointer is passed through unchecked; callers must inspect it
/// before use.
///
/// # Safety
///
/// The call is forwarded to libsodium as-is. NOTE(review): presumably
/// libsodium must have been initialised (see `init`) and the allocation
/// released with `free` — confirm against the libsodium documentation.
pub(crate) unsafe fn allocarray<T>(count: usize) -> *mut T {
    let element_size = size_of::<T>();
    // SAFETY: forwarded to libsodium; preconditions are the caller's
    // responsibility per this function's contract.
    let raw = unsafe { sodium_allocarray(count, element_size) };
    raw.cast::<T>()
}
/// Hands `ptr` to `sodium_free`.
///
/// # Safety
///
/// `ptr` is forwarded unchecked. NOTE(review): presumably it must originate
/// from `allocarray` and must not be used afterwards — confirm against the
/// libsodium documentation.
pub(crate) unsafe fn free<T>(ptr: *mut T) {
    let raw = ptr.cast();
    // SAFETY: the caller vouches for `ptr` per this function's contract.
    unsafe { sodium_free(raw) };
}
/// Calls `sodium_mlock` on the `size_of::<T>()` bytes at `ptr`, returning
/// `true` when libsodium reports success (a return value of 0).
///
/// In test builds, an armed `FAIL` flag is cleared and makes this call return
/// `false` without calling libsodium.
///
/// # Safety
///
/// `ptr` is forwarded unchecked; NOTE(review): presumably it must point at
/// `size_of::<T>()` addressable bytes — confirm against the libsodium docs.
pub(crate) unsafe fn mlock<T>(ptr: *mut T) -> bool {
    #[cfg(test)]
    {
        let injected_failure = FAIL.with(|flag| flag.replace(false));
        if injected_failure {
            return false;
        }
        // No-op binding carried over unchanged from the original hook.
        let _x = 0;
    };

    let len = size_of::<T>();
    // SAFETY: the caller vouches for `ptr` per this function's contract.
    let status = unsafe { sodium_mlock(ptr.cast(), len) };
    status == 0
}
/// Calls `sodium_munlock` on the `size_of::<T>()` bytes at `ptr`, returning
/// `true` when libsodium reports success (a return value of 0).
///
/// In test builds, an armed `FAIL` flag is cleared and makes this call return
/// `false` without calling libsodium.
///
/// # Safety
///
/// `ptr` is forwarded unchecked; NOTE(review): presumably it must point at
/// `size_of::<T>()` addressable bytes — confirm against the libsodium docs.
pub(crate) unsafe fn munlock<T>(ptr: *mut T) -> bool {
    #[cfg(test)]
    {
        let injected_failure = FAIL.with(|flag| flag.replace(false));
        if injected_failure {
            return false;
        }
        // No-op binding carried over unchanged from the original hook.
        let _x = 0;
    };

    let len = size_of::<T>();
    // SAFETY: the caller vouches for `ptr` per this function's contract.
    let status = unsafe { sodium_munlock(ptr.cast(), len) };
    status == 0
}
/// Calls `sodium_mprotect_noaccess` on `ptr`, returning `true` when libsodium
/// reports success (a return value of 0).
///
/// In test builds, an armed `FAIL` flag is cleared and makes this call return
/// `false` without calling libsodium.
///
/// # Safety
///
/// `ptr` is forwarded unchecked; NOTE(review): presumably it must come from
/// libsodium's guarded allocator (`allocarray`) — confirm against the
/// libsodium documentation.
pub(crate) unsafe fn mprotect_noaccess<T>(ptr: *mut T) -> bool {
    #[cfg(test)]
    {
        let injected_failure = FAIL.with(|flag| flag.replace(false));
        if injected_failure {
            return false;
        }
        // No-op binding carried over unchanged from the original hook.
        let _x = 0;
    };

    // SAFETY: the caller vouches for `ptr` per this function's contract.
    let status = unsafe { sodium_mprotect_noaccess(ptr.cast()) };
    status == 0
}
/// Calls `sodium_mprotect_readonly` on `ptr`, returning `true` when libsodium
/// reports success (a return value of 0).
///
/// In test builds, an armed `FAIL` flag is cleared and makes this call return
/// `false` without calling libsodium.
///
/// # Safety
///
/// `ptr` is forwarded unchecked; NOTE(review): presumably it must come from
/// libsodium's guarded allocator (`allocarray`) — confirm against the
/// libsodium documentation.
pub(crate) unsafe fn mprotect_readonly<T>(ptr: *mut T) -> bool {
    #[cfg(test)]
    {
        let injected_failure = FAIL.with(|flag| flag.replace(false));
        if injected_failure {
            return false;
        }
        // No-op binding carried over unchanged from the original hook.
        let _x = 0;
    };

    // SAFETY: the caller vouches for `ptr` per this function's contract.
    let status = unsafe { sodium_mprotect_readonly(ptr.cast()) };
    status == 0
}
/// Calls `sodium_mprotect_readwrite` on `ptr`, returning `true` when libsodium
/// reports success (a return value of 0).
///
/// In test builds, an armed `FAIL` flag is cleared and makes this call return
/// `false` without calling libsodium.
///
/// # Safety
///
/// `ptr` is forwarded unchecked; NOTE(review): presumably it must come from
/// libsodium's guarded allocator (`allocarray`) — confirm against the
/// libsodium documentation.
pub(crate) unsafe fn mprotect_readwrite<T>(ptr: *mut T) -> bool {
    #[cfg(test)]
    {
        let injected_failure = FAIL.with(|flag| flag.replace(false));
        if injected_failure {
            return false;
        }
        // No-op binding carried over unchanged from the original hook.
        let _x = 0;
    };

    // SAFETY: the caller vouches for `ptr` per this function's contract.
    let status = unsafe { sodium_mprotect_readwrite(ptr.cast()) };
    status == 0
}
/// Reports whether `l` and `r` hold the same bytes.
///
/// Slices of different lengths are unequal without calling libsodium;
/// otherwise the comparison is delegated to `sodium_memcmp`, whose return
/// value of 0 is taken to mean "equal".
pub(crate) fn memcmp(l: &[u8], r: &[u8]) -> bool {
    let len = r.len();

    if l.len() == len {
        // SAFETY: both pointers come from live slices that are exactly `len`
        // bytes long, as established by the length check above.
        let status = unsafe { sodium_memcmp(l.as_ptr().cast(), r.as_ptr().cast(), len) };
        status == 0
    } else {
        false
    }
}
/// Copies all of `src` into the front of `dst`, then zeroes `src` with
/// `memzero`.
///
/// # Safety
///
/// `src` must not be longer than `dst`, and the two slices must not overlap.
/// Both conditions are stated through the crate's `never!` / `proven!`
/// macros. NOTE(review): whether those macros enforce the conditions in every
/// build profile is not visible from this file — confirm.
pub(crate) unsafe fn memtransfer(src: &mut [u8], dst: &mut [u8]) {
    never!(src.len() > dst.len(),
        "secrets: may not transfer a larger `src` into a smaller `dst`");

    proven!(
        {
            // One slice must end at or before the start of the other.
            let (from, to) = (src.as_ptr_range(), dst.as_ptr_range());
            (from.start < to.start && from.end <= to.start) ||
            (to.start < from.start && to.end <= from.start)
        },
        "secrets: may not transfer overlapping slices into one-another"
    );

    let len = src.len();
    // SAFETY: `src` is readable and `dst` writable for `len` bytes, and the
    // regions are disjoint, per the conditions stated above.
    unsafe { std::ptr::copy_nonoverlapping(src.as_ptr(), dst.as_mut_ptr(), len) };

    // The bytes now live in `dst`; wipe the source copy.
    memzero(src);
}
/// Passes the whole of `bytes` (pointer and length) to `sodium_memzero`.
pub(crate) fn memzero(bytes: &mut [u8]) {
    let len = bytes.len();
    let start = bytes.as_mut_ptr();
    // SAFETY: `start` and `len` describe exactly the exclusive borrow `bytes`.
    unsafe { sodium_memzero(start.cast(), len) }
}
/// Passes the whole of `bytes` (pointer and length) to `randombytes_buf`.
pub(crate) fn memrandom(bytes: &mut [u8]) {
    let len = bytes.len();
    let start = bytes.as_mut_ptr();
    // SAFETY: `start` and `len` describe exactly the exclusive borrow `bytes`.
    unsafe { randombytes_buf(start.cast(), len) }
}
#[cfg(test)]
mod test {
    use super::*;

    // Identical contents must compare equal.
    #[test]
    fn memcmp_compares_equality() {
        let a = [0xfd, 0xa1, 0x92, 0x4b];
        let b = a;
        assert!(memcmp(&a, &b));
    }

    // Equal lengths with different contents must compare unequal. Without
    // this case, a `memcmp` that only looked at lengths would pass the suite.
    #[test]
    fn memcmp_compares_inequality_for_same_length() {
        let a = [0xb8, 0xa4, 0x06, 0xd1];
        let b = [0xb8, 0xa4, 0x06, 0xd2];
        assert!(!memcmp(&a, &b));
        assert!(!memcmp(&b, &a));
    }

    // A shared prefix is not enough: differing lengths are unequal in both
    // argument orders.
    #[test]
    fn memcmp_compares_inequality_for_different_lengths() {
        let a = [0xb8, 0xa4, 0x06, 0xd1];
        let b = [0xb8, 0xa4, 0x06];
        let c = [0xb8, 0xa4, 0x06, 0xd1, 0x3a];
        assert!(!memcmp(&a, &b));
        assert!(!memcmp(&b, &a));
        assert!(!memcmp(&a, &c));
        assert!(!memcmp(&c, &a));
    }
}