use super::{CryptoError, CryptoResult};
use std::alloc::Layout;
use std::io;
use subtle::{Choice, ConstantTimeEq};
use zeroize::Zeroize;
/// Thin wrappers over the OS primitives that keep secret buffers resident in RAM
/// and, on Linux, out of core dumps.
///
/// Every function succeeds immediately for an empty buffer and reports OS failures
/// as the text of `io::Error::last_os_error()`.
struct MemoryProtector;
impl MemoryProtector {
    /// Locks the pages backing `buffer` into RAM (`mlock` on Unix, `VirtualLock` on
    /// Windows). On platforms with neither API this is a successful no-op.
    ///
    /// # Errors
    /// Returns the OS error text when the lock call fails, or when the page size
    /// cannot be determined on Unix.
    #[inline]
    fn lock_memory(buffer: &mut [u8]) -> Result<(), String> {
        if buffer.is_empty() {
            return Ok(());
        }
        #[cfg(unix)]
        {
            let (start, span_len) = Self::page_span(buffer)?;
            // SAFETY: `start..start + span_len` covers exactly the whole pages that contain
            // `buffer`, all of which are mapped while `buffer` is borrowed. `mlock` only
            // changes page residency; it never reads or writes the memory.
            if unsafe { libc::mlock(start, span_len) } != 0 {
                return Err(io::Error::last_os_error().to_string());
            }
        }
        #[cfg(windows)]
        {
            use windows_sys::Win32::System::Memory::VirtualLock;
            let ptr = buffer.as_mut_ptr() as *mut core::ffi::c_void;
            let len = buffer.len();
            // SAFETY: `ptr..ptr + len` is a live, exclusively borrowed allocation;
            // `VirtualLock` locks every page that range touches and does not access it.
            let ok = unsafe { VirtualLock(ptr, len) };
            if ok == 0 {
                return Err(io::Error::last_os_error().to_string());
            }
        }
        // Other platforms have no locking primitive: treated as success.
        Ok(())
    }

    /// Releases the page locks taken by [`MemoryProtector::lock_memory`].
    ///
    /// NOTE(review): page locks do not nest on Unix, so this also unlocks any page
    /// shared with another locked buffer.
    ///
    /// # Errors
    /// Returns the OS error text when the unlock call fails, or when the page size
    /// cannot be determined on Unix.
    #[inline]
    fn unlock_memory(buffer: &mut [u8]) -> Result<(), String> {
        if buffer.is_empty() {
            return Ok(());
        }
        #[cfg(unix)]
        {
            let (start, span_len) = Self::page_span(buffer)?;
            // SAFETY: same page range as `lock_memory`; `munlock` does not access the memory.
            if unsafe { libc::munlock(start, span_len) } != 0 {
                return Err(io::Error::last_os_error().to_string());
            }
        }
        #[cfg(windows)]
        {
            use windows_sys::Win32::System::Memory::VirtualUnlock;
            let ptr = buffer.as_mut_ptr() as *mut core::ffi::c_void;
            let len = buffer.len();
            // SAFETY: `ptr..ptr + len` is a live, exclusively borrowed allocation;
            // `VirtualUnlock` only changes page residency.
            let ok = unsafe { VirtualUnlock(ptr, len) };
            if ok == 0 {
                return Err(io::Error::last_os_error().to_string());
            }
        }
        // Other platforms have no locking primitive: treated as success.
        Ok(())
    }

    /// Best-effort extra hardening. On Linux it marks the pages holding `buffer`
    /// `MADV_DONTDUMP` so they are left out of core dumps; elsewhere it does nothing.
    ///
    /// # Errors
    /// Returns the OS error text when `madvise` fails, or when the page size cannot be
    /// determined.
    #[inline]
    fn additional_protection(buffer: &mut [u8]) -> Result<(), String> {
        if buffer.is_empty() {
            return Ok(());
        }
        #[cfg(all(unix, target_os = "linux"))]
        {
            // `madvise` rejects a start address that is not page-aligned (EINVAL), and a
            // heap `Vec` almost never is, so advise the enclosing page range instead.
            let (start, span_len) = Self::page_span(buffer)?;
            // SAFETY: the range covers only mapped pages containing `buffer`;
            // MADV_DONTDUMP changes dump eligibility, not contents or accessibility.
            if unsafe { libc::madvise(start, span_len, libc::MADV_DONTDUMP) } != 0 {
                return Err(io::Error::last_os_error().to_string());
            }
        }
        Ok(())
    }

    /// Returns the start pointer and length of the smallest page-aligned range that
    /// contains `buffer`.
    ///
    /// `madvise` requires a page-aligned start, and POSIX allows `mlock`/`munlock` to
    /// require one too, so all Unix calls go through here. The first and last pages may
    /// also hold bytes of neighbouring allocations.
    ///
    /// # Errors
    /// Returns an error if `sysconf(_SC_PAGESIZE)` does not yield a power of two.
    #[cfg(unix)]
    fn page_span(buffer: &mut [u8]) -> Result<(*mut libc::c_void, usize), String> {
        // SAFETY: `sysconf` only queries a configuration value.
        let raw_page_size = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
        // A positive `c_long` always fits in `usize` on Unix targets.
        let page_size = if raw_page_size > 0 {
            raw_page_size as usize
        } else {
            0
        };
        if !page_size.is_power_of_two() {
            return Err("Unable to determine the system page size".to_string());
        }
        let offset = (buffer.as_ptr() as usize) & (page_size - 1);
        // The page start may precede the allocation, so use `wrapping_sub` (no
        // in-bounds requirement); the pointer is only handed to the kernel.
        let start = buffer
            .as_mut_ptr()
            .wrapping_sub(offset)
            .cast::<libc::c_void>();
        Ok((start, buffer.len() + offset))
    }
}
/// Result of the best-effort protections applied when a `SecureMemory` buffer is created.
#[derive(Debug, Clone)]
pub struct MemoryProtectionStatus {
    // OS error text from the dump-exclusion step; `None` when it succeeded or was not
    // attempted on this platform.
    additional_protection_error: Option<String>,
}
impl MemoryProtectionStatus {
    /// Returns the reason the additional protection step failed, or `None` if it did not.
    pub fn additional_protection_error(&self) -> Option<&str> {
        Option::as_deref(&self.additional_protection_error)
    }
}
pub struct SecureMemory {
data: Vec<u8>,
protection: MemoryProtectionStatus,
}
impl SecureMemory {
pub fn new(len: usize) -> CryptoResult<Self> {
if len == 0 {
return Err(CryptoError::MemoryError(
"Zero-length memory allocation".to_string(),
));
}
Layout::array::<u8>(len)
.map_err(|_| CryptoError::MemoryError("Layout error".to_string()))?;
let mut data = Vec::new();
if data.try_reserve_exact(len).is_err() {
return Err(CryptoError::MemoryError(
"Memory allocation failed".to_string(),
));
}
data.resize(len, 0u8);
if let Err(e) = MemoryProtector::lock_memory(&mut data) {
return Err(CryptoError::MemoryError(format!(
"Memory protection failed: {e}"
)));
}
let additional_protection_error = MemoryProtector::additional_protection(&mut data).err();
Ok(Self {
data,
protection: MemoryProtectionStatus {
additional_protection_error,
},
})
}
pub fn as_slice(&self) -> &[u8] {
&self.data
}
pub fn as_mut_slice(&mut self) -> &mut [u8] {
&mut self.data
}
pub fn len(&self) -> usize {
self.data.len()
}
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
pub fn capacity(&self) -> usize {
self.data.len()
}
pub fn protection_status(&self) -> &MemoryProtectionStatus {
&self.protection
}
pub fn try_unlock(&mut self) -> CryptoResult<()> {
MemoryProtector::unlock_memory(&mut self.data)
.map_err(|e| CryptoError::MemoryError(format!("Memory unlock failed: {e}")))
}
}
impl Drop for SecureMemory {
    /// Wipes the buffer, then releases its page locks.
    ///
    /// The wipe happens before unlocking so that secret bytes are never resident in
    /// unlocked (swappable) pages.
    fn drop(&mut self) {
        // Zeroize the *slice*, not the Vec. `Vec::zeroize` also clears the Vec
        // (length becomes 0), which made the unlock below a no-op and leaked the lock.
        self.data.as_mut_slice().zeroize();
        // Best-effort: `drop` cannot report failure, and the bytes are already wiped.
        let _ = MemoryProtector::unlock_memory(&mut self.data);
        // Also clears any spare capacity and resets the length before deallocation.
        self.data.zeroize();
    }
}
/// UTF-8 text held in a locked `SecureMemory` buffer.
///
/// `len` is the number of bytes in use. The backing buffer may be larger, and is
/// always at least 1 byte because `SecureMemory` rejects zero-length allocations.
pub struct SecureString {
    data: SecureMemory,
    len: usize,
}
impl SecureString {
    /// Copies `s` into a newly allocated, locked buffer.
    ///
    /// The empty string is accepted. Its backing buffer holds a single zeroed byte,
    /// and `len()` reports 0.
    ///
    /// # Errors
    /// Propagates any `CryptoError` from `SecureMemory::new`.
    pub fn new(s: &str) -> CryptoResult<Self> {
        let bytes = s.as_bytes();
        // `max(1)`: SecureMemory refuses zero-length buffers, but "" is a valid string.
        let mut memory = SecureMemory::new(bytes.len().max(1))?;
        memory.as_mut_slice()[..bytes.len()].copy_from_slice(bytes);
        Ok(Self {
            data: memory,
            len: bytes.len(),
        })
    }

    /// Allocates an empty string backed by `capacity.max(1)` locked bytes.
    ///
    /// # Errors
    /// Propagates any `CryptoError` from `SecureMemory::new`.
    pub fn with_capacity(capacity: usize) -> CryptoResult<Self> {
        let memory = SecureMemory::new(capacity.max(1))?;
        Ok(Self {
            data: memory,
            len: 0,
        })
    }

    /// Returns the contents as `&str`, or `None` if the used bytes are not valid UTF-8.
    pub fn as_str(&self) -> Option<&str> {
        self.try_as_str().ok()
    }

    /// Returns the contents as `&str`, reporting the UTF-8 error if decoding fails.
    pub fn try_as_str(&self) -> Result<&str, std::str::Utf8Error> {
        std::str::from_utf8(&self.data.as_slice()[..self.len])
    }

    /// Number of bytes in use.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Whether no bytes are in use.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Size of the backing buffer in bytes.
    pub fn capacity(&self) -> usize {
        self.data.capacity()
    }
}
impl Drop for SecureString {
    /// Wipes the in-use prefix of the buffer. The `SecureMemory` field is dropped
    /// afterwards and runs its own cleanup.
    fn drop(&mut self) {
        let used = self.len;
        if used == 0 {
            return;
        }
        self.data.as_mut_slice()[..used].zeroize();
    }
}
impl std::fmt::Display for SecureString {
    /// Renders only the byte count and never the contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let byte_count = self.len;
        f.write_fmt(format_args!("SECURE STRING: {byte_count} bytes"))
    }
}
impl std::fmt::Debug for SecureString {
    /// Shows `len` and `capacity` only; the contents are deliberately omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("SecureString");
        builder.field("len", &self.len);
        builder.field("capacity", &self.capacity());
        builder.finish()
    }
}
pub struct MemoryProtection;
impl MemoryProtection {
#[inline(always)]
pub fn secure_clear<T: Zeroize>(data: &mut T) {
data.zeroize();
std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
}
pub fn secure_compare(a: &[u8], b: &[u8]) -> bool {
let len = a.len().max(b.len());
let mut result = Choice::from(1u8);
result &= a.len().ct_eq(&b.len());
for i in 0..len {
let a_byte = if i < a.len() { a[i] } else { 0 };
let b_byte = if i < b.len() { b[i] } else { 0 };
result &= a_byte.ct_eq(&b_byte);
}
result.unwrap_u8() == 1
}
}