use std::ops::{Deref, DerefMut};
use std::ptr;
use zeroize::Zeroize;
use crate::error::SignerError;
/// Owned byte buffer for secret material such as private keys.
///
/// Construction tries to pin the allocation in RAM with the platform lock helper
/// (`mlock` on Unix, `VirtualLock` on Windows), and `is_locked` records whether that
/// succeeded. The contents are wiped on drop and redacted from `Debug` output.
///
/// NOTE(review): OS memory locks apply to whole pages and (per the POSIX/Linux docs) do not
/// stack, so unlocking one buffer may also unlock a page shared with another locked
/// allocation. Confirm whether page-aligned allocation is needed.
pub struct SecureBuffer {
    /// Whether the platform lock call succeeded for `data`.
    is_locked: bool,
    /// The secret bytes themselves.
    data: Vec<u8>,
}
/// How a [`SecureBuffer`] reacts when the OS refuses to lock its memory.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LockingMode {
    /// Treat a failed lock as an error (`SignerError::MemoryLockFailed`).
    Strict,
    /// Continue with an unlocked buffer after printing a warning to stderr.
    Permissive,
}
impl SecureBuffer {
    /// Creates a zero-filled buffer of `capacity` bytes using [`LockingMode::Permissive`].
    ///
    /// # Errors
    /// In permissive mode a failed lock only prints a warning, so this currently always
    /// returns `Ok`. The `Result` mirrors [`SecureBuffer::with_mode`].
    pub fn new(capacity: usize) -> Result<Self, SignerError> {
        Self::with_mode(capacity, LockingMode::Permissive)
    }

    /// Creates a zero-filled buffer of `capacity` bytes and tries to lock it in RAM.
    ///
    /// # Errors
    /// Returns [`SignerError::MemoryLockFailed`] when `mode` is [`LockingMode::Strict`] and
    /// the lock could not be acquired. In permissive mode a failed lock prints a warning to
    /// stderr and the unlocked buffer is returned.
    pub fn with_mode(capacity: usize, mode: LockingMode) -> Result<Self, SignerError> {
        let data = vec![0u8; capacity];
        let locked = lock_memory(&data);
        if !locked {
            if mode == LockingMode::Strict {
                return Err(SignerError::MemoryLockFailed(
                    "mlock failed - memory may be swapped to disk. \
                     Check ulimit -l or run with CAP_IPC_LOCK capability."
                        .to_string(),
                ));
            }
            eprintln!(
                "Warning: Memory locking failed. Private keys may be swapped to disk. \
                 Consider running with elevated privileges or increasing ulimit -l."
            );
        }
        Ok(Self {
            data,
            is_locked: locked,
        })
    }

    /// Same as [`SecureBuffer::new`]; spells out the permissive mode explicitly.
    pub fn new_permissive(capacity: usize) -> Result<Self, SignerError> {
        Self::with_mode(capacity, LockingMode::Permissive)
    }

    /// Copies `source` into a new buffer created with [`LockingMode::Permissive`].
    pub fn from_slice(source: &[u8]) -> Result<Self, SignerError> {
        Self::from_slice_with_mode(source, LockingMode::Permissive)
    }

    /// Copies `source` into a new buffer created with `mode`.
    ///
    /// The destination is allocated (and locked, if possible) before the bytes are copied.
    /// `source` itself is not wiped; that remains the caller's responsibility.
    ///
    /// # Errors
    /// Same as [`SecureBuffer::with_mode`].
    pub fn from_slice_with_mode(source: &[u8], mode: LockingMode) -> Result<Self, SignerError> {
        let mut buffer = Self::with_mode(source.len(), mode)?;
        buffer.data.copy_from_slice(source);
        Ok(buffer)
    }

    /// Same as [`SecureBuffer::from_slice`]; spells out the permissive mode explicitly.
    pub fn from_slice_permissive(source: &[u8]) -> Result<Self, SignerError> {
        Self::from_slice_with_mode(source, LockingMode::Permissive)
    }

    /// Alias of [`SecureBuffer::from_slice`], kept for API compatibility.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, SignerError> {
        Self::from_slice(bytes)
    }

    /// Number of bytes currently held.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// `true` when the buffer holds no bytes.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Whether the platform lock call succeeded for the current allocation.
    pub fn is_locked(&self) -> bool {
        self.is_locked
    }

    /// Borrows the contents.
    pub fn as_slice(&self) -> &[u8] {
        &self.data
    }

    /// Mutably borrows the contents.
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        &mut self.data
    }

    /// Alias of [`SecureBuffer::as_slice`].
    pub fn as_bytes(&self) -> &[u8] {
        &self.data
    }

    /// Alias of [`SecureBuffer::as_mut_slice`].
    pub fn as_mut_bytes(&mut self) -> &mut [u8] {
        &mut self.data
    }

    /// Overwrites every byte with zero and keeps the buffer's length.
    ///
    /// The slice is wiped in place. `Vec::zeroize` is avoided here because it also clears
    /// the vector, which would silently shrink the buffer to length 0.
    pub fn zeroize(&mut self) {
        self.data.as_mut_slice().zeroize();
    }

    /// Resizes the buffer to `new_len` bytes using [`LockingMode::Strict`].
    ///
    /// Note that this default differs from [`SecureBuffer::new`], which is permissive.
    ///
    /// # Errors
    /// See [`SecureBuffer::resize_with_mode`].
    pub fn resize(&mut self, new_len: usize) -> Result<(), SignerError> {
        self.resize_with_mode(new_len, LockingMode::Strict)
    }

    /// Resizes the buffer to `new_len` bytes.
    ///
    /// Growing works in four steps:
    /// 1. Allocate a fresh zero-filled buffer and lock it.
    /// 2. Copy the current bytes into its prefix.
    /// 3. Wipe the old allocation while it is still locked.
    /// 4. Unlock the old allocation.
    ///
    /// Shrinking wipes the removed tail and truncates in place; the capacity is kept.
    ///
    /// # Errors
    /// When growing in [`LockingMode::Strict`] and the new allocation cannot be locked,
    /// returns [`SignerError::MemoryLockFailed`] and leaves `self` unchanged. Shrinking
    /// never fails.
    pub fn resize_with_mode(
        &mut self,
        new_len: usize,
        mode: LockingMode,
    ) -> Result<(), SignerError> {
        if new_len > self.data.len() {
            let mut new_data = vec![0u8; new_len];
            let new_locked = lock_memory(&new_data);
            if mode == LockingMode::Strict && !new_locked {
                return Err(SignerError::MemoryLockFailed(
                    "mlock failed on resized buffer".to_string(),
                ));
            }
            let old_len = self.data.len();
            new_data[..old_len].copy_from_slice(&self.data);

            // An earlier shrink may have left len < capacity. Extending back up to the
            // capacity never reallocates, and it exposes the whole old allocation, so the
            // wipe and unlock below cover every byte that was originally locked.
            let old_capacity = self.data.capacity();
            self.data.resize(old_capacity, 0);
            // Wipe while the old pages are still locked, and only then release the lock.
            self.data.as_mut_slice().zeroize();
            if self.is_locked {
                unlock_memory(&self.data);
            }

            self.is_locked = new_locked;
            // The old (already wiped) allocation is freed here.
            self.data = new_data;
        } else {
            // Volatile wipe, so the stores cannot be optimized away. `truncate` keeps the
            // capacity, so these bytes stay inside the same allocation.
            self.data[new_len..].zeroize();
            self.data.truncate(new_len);
        }
        Ok(())
    }
}
impl Drop for SecureBuffer {
    /// Wipes the whole allocation while it is still locked, then releases the lock.
    ///
    /// `Vec::zeroize` clears the vector to length 0. Calling it before
    /// `unlock_memory(&self.data)` would hand the helper an empty slice, and the helper
    /// returns early on empty input, so the lock would never be released. The wipe is
    /// therefore done on the slice in place, and the unlock runs while the length still
    /// spans the allocation.
    fn drop(&mut self) {
        // A shrinking `resize` truncates without freeing capacity, so bytes past `len` are
        // still part of the (possibly locked) allocation. Extending back to the current
        // capacity never reallocates, and it exposes the whole allocation as one slice.
        let capacity = self.data.capacity();
        self.data.resize(capacity, 0);
        // Volatile wipe while the pages are still locked.
        self.data.as_mut_slice().zeroize();
        if self.is_locked {
            unlock_memory(&self.data);
        }
    }
}
/// Lets a `SecureBuffer` be used wherever a `&[u8]` is expected.
impl Deref for SecureBuffer {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        self.data.as_slice()
    }
}
/// Lets a `SecureBuffer` be used wherever a `&mut [u8]` is expected.
impl DerefMut for SecureBuffer {
    fn deref_mut(&mut self) -> &mut [u8] {
        self.data.as_mut_slice()
    }
}
/// Debug output shows only metadata; the secret bytes are replaced by a placeholder.
impl std::fmt::Debug for SecureBuffer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let placeholder = "[REDACTED]";
        let mut builder = f.debug_struct("SecureBuffer");
        builder.field("len", &self.data.len());
        builder.field("is_locked", &self.is_locked);
        builder.field("data", &placeholder);
        builder.finish()
    }
}
/// Unix: asks the kernel to keep `data`'s pages resident via `mlock`.
///
/// An empty slice counts as locked without a syscall. Otherwise the result is `true`
/// only when `mlock` returns 0.
#[cfg(unix)]
fn lock_memory(data: &[u8]) -> bool {
    use std::ffi::c_void;

    match data {
        [] => true,
        bytes => {
            // SAFETY: the pointer and length come from a live slice, so they describe
            // readable memory for the whole call. mlock does not write through the pointer.
            let status = unsafe { libc::mlock(bytes.as_ptr().cast::<c_void>(), bytes.len()) };
            status == 0
        }
    }
}
/// Unix: releases an `mlock` on `data` with `munlock` (best effort; the result is ignored).
///
/// An empty slice is a no-op.
#[cfg(unix)]
fn unlock_memory(data: &[u8]) {
    use std::ffi::c_void;

    if !data.is_empty() {
        // SAFETY: the pointer and length come from a live slice, so they describe
        // readable memory for the whole call.
        let _ = unsafe { libc::munlock(data.as_ptr().cast::<c_void>(), data.len()) };
    }
}
/// Windows: pins `data` in physical memory via `VirtualLock`.
///
/// An empty slice counts as locked without calling the OS. Otherwise the result is `true`
/// only when `VirtualLock` returns non-zero.
#[cfg(windows)]
fn lock_memory(data: &[u8]) -> bool {
    use std::ffi::c_void;

    extern "system" {
        fn VirtualLock(address: *const c_void, size: usize) -> i32;
    }

    if data.is_empty() {
        true
    } else {
        // SAFETY: the pointer and length come from a live slice, so they describe
        // readable memory for the whole call.
        let status = unsafe { VirtualLock(data.as_ptr().cast::<c_void>(), data.len()) };
        status != 0
    }
}
/// Windows: releases a `VirtualLock` on `data` (best effort; the result is ignored).
///
/// An empty slice is a no-op.
#[cfg(windows)]
fn unlock_memory(data: &[u8]) {
    use std::ffi::c_void;

    extern "system" {
        fn VirtualUnlock(address: *const c_void, size: usize) -> i32;
    }

    if !data.is_empty() {
        // SAFETY: the pointer and length come from a live slice, so they describe
        // readable memory for the whole call.
        let _ = unsafe { VirtualUnlock(data.as_ptr().cast::<c_void>(), data.len()) };
    }
}
/// Fallback for platforms without a supported memory-locking API.
///
/// An empty slice has nothing to lock and reports success, matching the Unix and Windows
/// implementations. That keeps `SecureBuffer::with_mode(0, Strict)` consistent across
/// platforms and avoids a spurious warning. Any non-empty slice prints a warning and
/// reports failure.
#[cfg(not(any(unix, windows)))]
fn lock_memory(data: &[u8]) -> bool {
    if data.is_empty() {
        return true;
    }
    eprintln!("Warning: Memory locking not supported on this platform");
    false
}
/// Fallback unlock for platforms without memory locking: there is never a lock to release.
#[cfg(not(any(unix, windows)))]
fn unlock_memory(_region: &[u8]) {
    // Intentionally empty: the fallback `lock_memory` never locks anything.
}
/// Borrows a mutable byte slice and overwrites it with zeros when the guard is dropped.
///
/// Because the wipe happens on drop, discarding the guard straight away (for example
/// `SecureGuard::new(&mut key);` as a bare statement) zeroes the borrowed bytes on the spot.
/// `#[must_use]` turns that mistake into a compiler warning.
#[must_use = "dropping a SecureGuard immediately zeroes the slice it guards"]
pub struct SecureGuard<'a> {
    /// The guarded bytes, borrowed exclusively for the guard's lifetime.
    data: &'a mut [u8],
}
impl<'a> SecureGuard<'a> {
pub fn new(data: &'a mut [u8]) -> Self {
Self { data }
}
}
/// Read access to the guarded bytes.
impl Deref for SecureGuard<'_> {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        &*self.data
    }
}
/// Write access to the guarded bytes.
impl DerefMut for SecureGuard<'_> {
    fn deref_mut(&mut self) -> &mut [u8] {
        &mut *self.data
    }
}
impl Drop for SecureGuard<'_> {
    /// Overwrites every guarded byte with zero when the guard goes out of scope.
    fn drop(&mut self) {
        let len = self.data.len();
        let start = self.data.as_mut_ptr();
        for offset in 0..len {
            // SAFETY: `offset < len`, so `start.add(offset)` stays inside the slice this guard
            // borrows exclusively. Volatile stores are not removed as dead writes.
            unsafe { ptr::write_volatile(start.add(offset), 0u8) };
        }
        // Compiler fence: stops the compiler from moving memory accesses across this point.
        std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_secure_buffer_creation_permissive() {
        let buffer = SecureBuffer::new_permissive(32).unwrap();
        assert_eq!(buffer.len(), 32);
        assert!(buffer.as_slice().iter().all(|&b| b == 0));
    }

    #[test]
    fn test_secure_buffer_from_slice_permissive() {
        let data = [1u8, 2, 3, 4, 5];
        let buffer = SecureBuffer::from_slice_permissive(&data).unwrap();
        assert_eq!(buffer.as_slice(), &data);
    }

    #[test]
    fn test_secure_buffer_from_bytes_compat() {
        let data = b"supersecretkey!!";
        let buf = SecureBuffer::from_bytes(data).unwrap();
        assert_eq!(buf.as_bytes(), data);
        assert_eq!(buf.len(), 16);
    }

    #[test]
    fn test_secure_buffer_zeroize() {
        let mut buffer = SecureBuffer::from_slice_permissive(&[1, 2, 3, 4]).unwrap();
        buffer.zeroize();
        assert!(buffer.as_slice().iter().all(|&b| b == 0));
    }

    #[test]
    fn test_debug_redacts_data() {
        let buffer = SecureBuffer::from_slice_permissive(&[0xDE, 0xAD, 0xBE, 0xEF]).unwrap();
        let debug_str = format!("{:?}", buffer);
        assert!(debug_str.contains("[REDACTED]"));
        // A leaked `Vec<u8>` Debug prints decimal, not hex, so check the decimal spellings
        // too (0xDE = 222, 0xAD = 173, 0xBE = 190, 0xEF = 239).
        for leaked in ["DEAD", "BEEF", "222", "173", "190", "239"] {
            assert!(
                !debug_str.contains(leaked),
                "debug output leaked {}: {}",
                leaked,
                debug_str
            );
        }
    }

    #[test]
    fn test_empty_buffer() {
        let buf = SecureBuffer::new(0).unwrap();
        assert!(buf.is_empty());
    }

    #[test]
    fn test_strict_mode_checks_locking() {
        let result = SecureBuffer::with_mode(32, LockingMode::Strict);
        match result {
            Ok(buf) => assert!(buf.is_locked(), "Strict mode should only succeed if locked"),
            Err(SignerError::MemoryLockFailed(_)) => {}
            Err(e) => panic!("Unexpected error: {}", e),
        }
    }

    #[test]
    fn test_resize_grow_preserves_prefix() {
        let mut buf = SecureBuffer::from_slice_permissive(&[1, 2, 3]).unwrap();
        buf.resize_with_mode(5, LockingMode::Permissive).unwrap();
        assert_eq!(buf.as_slice(), &[1, 2, 3, 0, 0]);
    }

    #[test]
    fn test_resize_shrink_then_grow() {
        let mut buf = SecureBuffer::from_slice_permissive(&[1, 2, 3, 4]).unwrap();
        buf.resize_with_mode(2, LockingMode::Permissive).unwrap();
        assert_eq!(buf.as_slice(), &[1, 2]);
        buf.resize_with_mode(4, LockingMode::Permissive).unwrap();
        assert_eq!(buf.as_slice(), &[1, 2, 0, 0]);
    }
}