use wasm_dbms_api::prelude::{
DEFAULT_ALIGNMENT, DataSize, Encode, MSize, MemoryError, MemoryResult, PageOffset,
};
use crate::{MemoryAccess, MemoryManager, MemoryProvider};
/// A pluggable access-control policy whose state lives in a [`MemoryManager`].
///
/// Implementors decide which identities (of type [`AccessControl::Id`]) are
/// permitted, and may persist changes to their identity set through the memory
/// manager passed to the mutating methods.
pub trait AccessControl: Default {
    /// The identity type checked by this policy.
    type Id;

    /// Reconstructs the policy state from `mm`.
    fn load<M: MemoryProvider>(mm: &mut MemoryManager<M>) -> MemoryResult<Self>
    where
        Self: Sized;

    /// Returns `true` when `identity` is permitted by this policy.
    fn is_allowed(&self, identity: &Self::Id) -> bool;

    /// Returns an owned snapshot of the identities this policy explicitly lists.
    fn allowed_identities(&self) -> Vec<Self::Id>;

    /// Adds `identity` to the policy, persisting through `mm` if the
    /// implementation stores state.
    fn add_identity<M: MemoryProvider>(
        &mut self,
        identity: Self::Id,
        mm: &mut MemoryManager<M>,
    ) -> MemoryResult<()>;

    /// Removes `identity` from the policy, persisting through `mm` if the
    /// implementation stores state.
    ///
    /// Implementations may refuse removals that would break their own
    /// invariants by returning an error.
    fn remove_identity<M: MemoryProvider>(
        &mut self,
        identity: &Self::Id,
        mm: &mut MemoryManager<M>,
    ) -> MemoryResult<()>;
}
/// An access-control policy that permits every caller and stores no state.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct NoAccessControl;
impl AccessControl for NoAccessControl {
type Id = ();
fn load<M>(_mm: &mut MemoryManager<M>) -> MemoryResult<Self>
where
M: MemoryProvider,
{
Ok(Self)
}
fn is_allowed(&self, _identity: &Self::Id) -> bool {
true
}
fn allowed_identities(&self) -> Vec<Self::Id> {
vec![]
}
fn add_identity<M>(
&mut self,
_identity: Self::Id,
_mm: &mut MemoryManager<M>,
) -> MemoryResult<()>
where
M: MemoryProvider,
{
Ok(())
}
fn remove_identity<M>(
&mut self,
_identity: &Self::Id,
_mm: &mut MemoryManager<M>,
) -> MemoryResult<()>
where
M: MemoryProvider,
{
Ok(())
}
}
/// An access-control policy backed by an explicit list of byte-string identities,
/// persisted in the memory manager's ACL page.
#[derive(Debug, Default, Clone, Eq, PartialEq)]
pub struct AccessControlList {
    /// Identities permitted by this list, kept as raw bytes.
    allowed: Vec<Vec<u8>>,
}
impl AccessControlList {
    /// Writes the encoded list at offset 0 of the memory manager's ACL page.
    fn save<M: MemoryProvider>(&self, mm: &mut MemoryManager<M>) -> MemoryResult<()> {
        let page = mm.acl_page();
        mm.write_at(page, 0, self)
    }
}
impl AccessControl for AccessControlList {
    type Id = Vec<u8>;

    /// Reads the list stored at offset 0 of the memory manager's ACL page.
    fn load<M>(mm: &mut MemoryManager<M>) -> MemoryResult<Self>
    where
        M: MemoryProvider,
    {
        mm.read_at(mm.acl_page(), 0)
    }

    /// Returns `true` when `identity` is byte-for-byte equal to a listed identity.
    fn is_allowed(&self, identity: &Self::Id) -> bool {
        self.allowed.contains(identity)
    }

    /// Returns a copy of the listed identities.
    fn allowed_identities(&self) -> Vec<Self::Id> {
        self.allowed.clone()
    }

    /// Adds `identity` and persists the list. Adding an already-listed identity
    /// is a no-op and does not write to memory.
    ///
    /// # Errors
    ///
    /// * `MemoryError::ConstraintViolation` if `identity` is longer than 255 bytes,
    ///   because the persisted encoding stores each identity length in one byte.
    /// * Any error from persisting the list. In that case the in-memory list is
    ///   rolled back, so it keeps matching the persisted state.
    fn add_identity<M>(&mut self, identity: Self::Id, mm: &mut MemoryManager<M>) -> MemoryResult<()>
    where
        M: MemoryProvider,
    {
        if self.is_allowed(&identity) {
            return Ok(());
        }
        if identity.len() > usize::from(u8::MAX) {
            return Err(MemoryError::ConstraintViolation(format!(
                "ACL identity must be at most {} bytes, got {}",
                u8::MAX,
                identity.len()
            )));
        }
        self.allowed.push(identity);
        if let Err(err) = self.save(mm) {
            // Undo the push so the in-memory list does not grant access that was never persisted.
            self.allowed.pop();
            return Err(err);
        }
        Ok(())
    }

    /// Removes `identity` and persists the list. Removing an unlisted identity
    /// is a no-op and does not write to memory.
    ///
    /// # Errors
    ///
    /// * `MemoryError::ConstraintViolation` if `identity` is the only listed identity.
    /// * Any error from persisting the list. In that case the in-memory list is
    ///   restored exactly, including its order.
    fn remove_identity<M>(
        &mut self,
        identity: &Self::Id,
        mm: &mut MemoryManager<M>,
    ) -> MemoryResult<()>
    where
        M: MemoryProvider,
    {
        let Some(pos) = self.allowed.iter().position(|p| p == identity) else {
            return Ok(());
        };
        if self.allowed.len() == 1 {
            return Err(MemoryError::ConstraintViolation(
                "ACL must contain at least one identity".to_string(),
            ));
        }
        let removed = self.allowed.swap_remove(pos);
        if let Err(err) = self.save(mm) {
            // Reverse the swap_remove: re-append the removed entry, then swap it back
            // into `pos` so the element that was moved there returns to the end.
            self.allowed.push(removed);
            let last = self.allowed.len() - 1;
            self.allowed.swap(pos, last);
            return Err(err);
        }
        Ok(())
    }
}
impl Encode for AccessControlList {
const SIZE: DataSize = DataSize::Dynamic;
const ALIGNMENT: PageOffset = DEFAULT_ALIGNMENT;
fn encode(&'_ self) -> std::borrow::Cow<'_, [u8]> {
let mut bytes = Vec::with_capacity(self.size() as usize);
let len = self.allowed.len() as u32;
bytes.extend_from_slice(&len.to_le_bytes());
for identity in &self.allowed {
let identity_len = identity.len() as u8;
bytes.extend_from_slice(&identity_len.to_le_bytes());
bytes.extend_from_slice(identity);
}
std::borrow::Cow::Owned(bytes)
}
fn decode(data: std::borrow::Cow<[u8]>) -> MemoryResult<Self>
where
Self: Sized,
{
let mut offset = 0;
let len_bytes = &data[offset..offset + 4];
offset += 4;
let len = u32::from_le_bytes(len_bytes.try_into()?) as usize;
let mut allowed = Vec::with_capacity(len);
for _ in 0..len {
let identity_len_bytes = &data[offset..offset + 1];
offset += 1;
let identity_len = u8::from_le_bytes(identity_len_bytes.try_into()?) as usize;
let identity_bytes = data[offset..offset + identity_len].to_vec();
offset += identity_len;
allowed.push(identity_bytes);
}
Ok(AccessControlList { allowed })
}
fn size(&self) -> MSize {
4 + self
.allowed
.iter()
.map(|p| 1 + p.len() as MSize)
.sum::<MSize>()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HeapMemoryProvider;

    fn make_mm() -> MemoryManager<HeapMemoryProvider> {
        MemoryManager::init(HeapMemoryProvider::default())
    }

    #[test]
    fn test_acl_encode_decode() {
        let acl = AccessControlList {
            allowed: vec![
                vec![0x04],
                vec![0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01],
                vec![0xDE, 0xAD, 0xBE, 0xEF],
                vec![0x01, 0x02, 0x03, 0x04, 0x05],
            ],
        };
        let encoded = acl.encode();
        let decoded = AccessControlList::decode(encoded).unwrap();
        assert_eq!(acl, decoded);
    }

    #[test]
    fn test_acl_encode_layout() {
        let acl = AccessControlList {
            allowed: vec![vec![0x01, 0x02], vec![0x03]],
        };
        // Little-endian u32 count, then (u8 length, bytes) for each identity.
        assert_eq!(
            acl.encode().into_owned(),
            vec![0x02, 0x00, 0x00, 0x00, 0x02, 0x01, 0x02, 0x01, 0x03]
        );
    }

    #[test]
    fn test_acl_add_remove_identity() {
        let mut mm = make_mm();
        let mut acl = AccessControlList::default();
        let identity = vec![0x00, 0x00, 0x00, 0x00, 0x01, 0x01];
        assert!(!acl.is_allowed(&identity));
        acl.add_identity(identity.clone(), &mut mm).unwrap();
        let other = vec![0xDE, 0xAD, 0xBE, 0xEF];
        acl.add_identity(other.clone(), &mut mm).unwrap();
        assert!(acl.is_allowed(&identity));
        assert!(acl.is_allowed(&other));
        assert_eq!(acl.allowed_identities().len(), 2);
        acl.remove_identity(&other, &mut mm).unwrap();
        assert!(!acl.is_allowed(&other));
        assert!(acl.is_allowed(&identity));
        assert_eq!(acl.allowed_identities(), vec![identity]);
    }

    #[test]
    fn test_remove_last_identity_returns_error() {
        let mut mm = make_mm();
        let mut acl = AccessControlList::default();
        let identity = vec![0x00, 0x00, 0x00, 0x00, 0x01, 0x01];
        acl.add_identity(identity.clone(), &mut mm).unwrap();
        assert!(acl.is_allowed(&identity));
        let result = acl.remove_identity(&identity, &mut mm);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            MemoryError::ConstraintViolation(_)
        ));
        // A rejected removal must leave the ACL untouched.
        assert!(acl.is_allowed(&identity));
        assert_eq!(acl.allowed_identities().len(), 1);
    }

    #[test]
    fn test_should_add_more_identities() {
        let mut mm = make_mm();
        let mut acl = AccessControlList::default();
        let identity1 = vec![0x00, 0x00, 0x00, 0x00, 0x01, 0x01];
        let identity2 = vec![0xDE, 0xAD, 0xBE, 0xEF];
        acl.add_identity(identity1.clone(), &mut mm).unwrap();
        acl.add_identity(identity2.clone(), &mut mm).unwrap();
        assert!(acl.is_allowed(&identity1));
        assert!(acl.is_allowed(&identity2));
        assert_eq!(
            acl.allowed_identities(),
            vec![identity1.clone(), identity2.clone()]
        );
    }

    #[test]
    fn test_add_identity_should_write_to_memory() {
        let mut mm = make_mm();
        let mut acl = AccessControlList::default();
        let identity = vec![0x00, 0x00, 0x00, 0x00, 0x01, 0x01];
        acl.add_identity(identity.clone(), &mut mm).unwrap();
        let loaded_acl = AccessControlList::load(&mut mm).unwrap();
        assert!(loaded_acl.is_allowed(&identity));
    }

    #[test]
    fn test_remove_identity_should_write_to_memory() {
        let mut mm = make_mm();
        let mut acl = AccessControlList::default();
        let kept = vec![0x0A, 0x0B];
        let removed = vec![0x0C];
        acl.add_identity(kept.clone(), &mut mm).unwrap();
        acl.add_identity(removed.clone(), &mut mm).unwrap();
        acl.remove_identity(&removed, &mut mm).unwrap();
        let loaded_acl = AccessControlList::load(&mut mm).unwrap();
        assert!(!loaded_acl.is_allowed(&removed));
        assert_eq!(loaded_acl.allowed_identities(), vec![kept]);
    }

    #[test]
    fn test_no_access_control_allows_everything() {
        let acl = NoAccessControl;
        assert!(acl.is_allowed(&()));
        assert!(acl.allowed_identities().is_empty());
    }

    #[test]
    fn test_add_duplicate_identity_is_idempotent() {
        let mut mm = make_mm();
        let mut acl = AccessControlList::default();
        let identity = vec![0x01, 0x02, 0x03];
        acl.add_identity(identity.clone(), &mut mm).unwrap();
        acl.add_identity(identity.clone(), &mut mm).unwrap();
        assert_eq!(acl.allowed_identities().len(), 1);
        assert!(acl.is_allowed(&identity));
    }

    #[test]
    fn test_remove_nonexistent_identity_is_noop() {
        let mut mm = make_mm();
        let mut acl = AccessControlList::default();
        let identity_a = vec![0x01, 0x02];
        let identity_b = vec![0x03, 0x04];
        acl.add_identity(identity_a.clone(), &mut mm).unwrap();
        acl.add_identity(identity_b.clone(), &mut mm).unwrap();
        let nonexistent = vec![0xFF, 0xFF];
        acl.remove_identity(&nonexistent, &mut mm).unwrap();
        assert_eq!(acl.allowed_identities().len(), 2);
    }

    #[test]
    fn test_no_access_control_load() {
        let mut mm = make_mm();
        let acl = NoAccessControl::load(&mut mm).unwrap();
        assert!(acl.is_allowed(&()));
    }

    #[test]
    fn test_no_access_control_add_and_remove_identity() {
        let mut mm = make_mm();
        let mut acl = NoAccessControl;
        acl.add_identity((), &mut mm).unwrap();
        acl.remove_identity(&(), &mut mm).unwrap();
        assert!(acl.is_allowed(&()));
    }

    #[test]
    fn test_empty_acl_encode_decode() {
        let acl = AccessControlList::default();
        let encoded = acl.encode();
        let decoded = AccessControlList::decode(encoded).unwrap();
        assert_eq!(acl, decoded);
        assert!(decoded.allowed_identities().is_empty());
    }

    #[test]
    fn test_acl_size() {
        let acl = AccessControlList {
            allowed: vec![vec![0x01, 0x02], vec![0x03]],
        };
        assert_eq!(acl.size(), 4 + 3 + 2);
    }

    #[test]
    fn test_empty_acl_size() {
        let acl = AccessControlList::default();
        assert_eq!(acl.size(), 4);
    }
}