use bytes::BytesMut;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Debug;
use std::fmt::Formatter;
use std::ops::{Deref, DerefMut};
/// A byte buffer intended for secret material.
///
/// Every value is created through `From<BytesMut>`, which hands the current contents to
/// `crate::misc::mlock`. The contents are passed to `crate::misc::munlock` before they are
/// exposed for mutation through [`SecBuffer::handle`] or released by
/// [`SecBuffer::into_buffer`].
pub struct SecBuffer {
    /// Backing storage; only the initialized `len()` bytes are locked / wiped.
    inner: BytesMut,
}

impl SecBuffer {
    /// Creates a buffer with zero capacity.
    pub fn empty() -> Self {
        Self::with_capacity(0)
    }

    /// Creates an empty buffer whose backing storage is pre-allocated for `cap` bytes.
    pub fn with_capacity(cap: usize) -> Self {
        Self::from(BytesMut::with_capacity(cap))
    }

    /// Consumes the wrapper and returns the underlying bytes after unlocking them.
    ///
    /// The returned `BytesMut` is *not* wiped: the caller becomes responsible for it.
    pub fn into_buffer(mut self) -> BytesMut {
        self.unlock();
        // `SecBuffer` implements `Drop`, so the field cannot simply be moved out. Swapping in
        // an empty `BytesMut` lets the destructor run on that empty placeholder instead of on
        // the bytes being returned.
        std::mem::take(&mut self.inner)
    }

    /// Unlocks the buffer and returns a guard that dereferences to the inner `BytesMut`.
    ///
    /// The guard locks the (possibly resized) contents again when it is dropped.
    pub fn handle(&mut self) -> SecureBufMutHandle<'_> {
        SecureBufMutHandle::new(self)
    }

    /// Number of initialized bytes currently stored.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Passes the initialized bytes to `crate::misc::mlock`.
    fn lock(&self) {
        // SAFETY: pointer and length both come from the same live slice of `self.inner`.
        // NOTE(review): assumes `mlock` only requires a valid (ptr, len) pair - confirm.
        unsafe { crate::misc::mlock(self.slice().as_ptr(), self.inner.len()) }
    }

    /// Passes the initialized bytes to `crate::misc::munlock`.
    fn unlock(&self) {
        // SAFETY: pointer and length both come from the same live slice of `self.inner`.
        // NOTE(review): assumes `munlock` only requires a valid (ptr, len) pair - confirm.
        unsafe { crate::misc::munlock(self.slice().as_ptr(), self.inner.len()) }
    }

    /// Passes the initialized bytes to `crate::misc::zeroize`.
    ///
    /// NOTE(review): only `len()` bytes are covered; spare capacity left behind by an earlier
    /// truncation through a handle is not touched.
    fn zeroize(&mut self) {
        let bytes: &mut [u8] = &mut self.inner[..];
        let len = bytes.len();
        // The helper presumably writes through this pointer, so it must be derived from the
        // unique borrow; a pointer obtained from `&[u8]` must not be written through. The cast
        // only restores the `*const u8` type the helper was originally called with.
        let ptr = bytes.as_mut_ptr() as *const u8;
        // SAFETY: `ptr` addresses `len` initialized bytes exclusively borrowed via `&mut self`.
        unsafe { crate::misc::zeroize(ptr, len) }
    }

    /// Shared view of the initialized bytes.
    fn slice(&self) -> &[u8] {
        &self.inner[..]
    }

    /// Returns `true` when no bytes are stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
pub struct SecureBufMutHandle<'a> {
inner: &'a mut SecBuffer,
}
impl<'a> SecureBufMutHandle<'a> {
fn new(inner: &'a mut SecBuffer) -> SecureBufMutHandle<'a> {
inner.unlock();
Self { inner }
}
}
impl Deref for SecureBufMutHandle<'_> {
type Target = BytesMut;
fn deref(&self) -> &Self::Target {
&self.inner.inner
}
}
impl DerefMut for SecureBufMutHandle<'_> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner.inner
}
}
impl Drop for SecureBufMutHandle<'_> {
fn drop(&mut self) {
self.inner.lock()
}
}
impl AsRef<[u8]> for SecBuffer {
    /// Read-only view of the stored bytes; the lock state is left untouched.
    fn as_ref(&self) -> &[u8] {
        let bytes: &[u8] = &self.inner;
        bytes
    }
}
impl AsMut<[u8]> for SecBuffer {
    /// Mutable view of the stored bytes; the length cannot change through a slice, and the
    /// lock state is left untouched.
    fn as_mut(&mut self) -> &mut [u8] {
        &mut self.inner[..]
    }
}
impl From<Vec<u8>> for SecBuffer {
    /// Copies the bytes of `inner` into a new `SecBuffer`, then wipes the source vector
    /// before it is dropped.
    fn from(mut inner: Vec<u8>) -> Self {
        let secured = Self::from(&inner[..]);

        // The conversion above copies, so the original allocation would otherwise be returned
        // to the allocator with the secret still in it. Volatile stores are used so the
        // compiler cannot discard the wipe as a dead write. Only the initialized `len()`
        // bytes are covered; spare capacity is not touched.
        for byte in inner.iter_mut() {
            // SAFETY: `byte` is a valid, aligned, exclusive reference to an initialized `u8`.
            unsafe { std::ptr::write_volatile(byte, 0) };
        }
        // Keep the wipe ordered before the deallocation that follows.
        std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);

        secured
    }
}
impl From<BytesMut> for SecBuffer {
fn from(inner: BytesMut) -> Self {
let this = Self { inner };
this.lock();
this
}
}
impl<const N: usize> From<[u8; N]> for SecBuffer {
    /// Copies the `N` bytes of the array into a new `SecBuffer`.
    fn from(this: [u8; N]) -> Self {
        // View the array as a slice and reuse the `From<&[u8]>` conversion.
        Self::from(&this[..])
    }
}
impl From<&[u8]> for SecBuffer {
    /// Copies the slice into a fresh `BytesMut` and wraps it.
    fn from(this: &[u8]) -> Self {
        let copied = BytesMut::from(this);
        Self::from(copied)
    }
}
impl From<&str> for SecBuffer {
fn from(this: &str) -> Self {
Self::from(BytesMut::from(this))
}
}
impl Drop for SecBuffer {
    /// Wipes the stored bytes and then releases the memory lock.
    fn drop(&mut self) {
        // Wipe first, while the memory is still locked: unlocking before the wipe would leave
        // a window in which the still-populated memory is no longer locked.
        self.zeroize();
        self.unlock();
    }
}
impl Debug for SecBuffer {
    /// Prints a fixed placeholder so the contents never appear in debug output.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("***SECRET***")
    }
}
impl<T: AsRef<[u8]>> PartialEq<T> for SecBuffer {
    /// Compares the stored bytes with any byte-like value via `super::const_time_compare`.
    fn eq(&self, other: &T) -> bool {
        let lhs: &[u8] = self.as_ref();
        let rhs: &[u8] = other.as_ref();
        super::const_time_compare(lhs, rhs)
    }
}
impl Clone for SecBuffer {
fn clone(&self) -> Self {
self.unlock();
let ret = SecBuffer::from(self.inner.clone());
self.lock();
ret
}
}
impl Serialize for SecBuffer {
fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
where
S: Serializer,
{
self.unlock();
let ret = self.inner.serialize(serializer);
self.lock();
ret
}
}
impl<'de> Deserialize<'de> for SecBuffer {
    /// Deserializes a `BytesMut` and wraps it; deserializer errors are returned unchanged.
    fn deserialize<D>(deserializer: D) -> Result<Self, <D as Deserializer<'de>>::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = BytesMut::deserialize(deserializer)?;
        Ok(Self::from(raw))
    }
}
#[cfg(test)]
mod test {
    use crate::prelude::SecBuffer;

    // Identical contents compare equal.
    #[test]
    fn test_secbuffer_cmp_same() {
        let lhs = SecBuffer::from("Hello");
        let rhs = SecBuffer::from("Hello");
        assert_eq!(lhs, rhs);
    }

    // Same length, different contents.
    #[test]
    fn test_secbuffer_cmp_diff() {
        let lhs = SecBuffer::from("Hello");
        let rhs = SecBuffer::from("World");
        assert_ne!(lhs, rhs);
    }

    // Right-hand side is longer than the left-hand side.
    #[test]
    fn test_secbuffer_cmp_diff2() {
        let short = SecBuffer::from("Hello");
        let long = SecBuffer::from("World................");
        assert_ne!(short, long);
    }

    // Left-hand side is longer than the right-hand side.
    #[test]
    fn test_secbuffer_cmp_diff3() {
        let long = SecBuffer::from("Hello................");
        let short = SecBuffer::from("World");
        assert_ne!(long, short);
    }
}