use crate::derivation::DerivationTree;
use crate::error::{CapError, CapResult};
use crate::grant::{validate_grant, GrantRequest};
use crate::revoke::{validate_revoke, RevokeRequest, RevokeResult};
#[cfg(feature = "alloc")]
use crate::revoke::RevokeStats;
use crate::table::{CapTableEntry, CapabilityTable};
use crate::{DEFAULT_CAP_TABLE_CAPACITY, DEFAULT_MAX_DELEGATION_DEPTH};
use ruvix_types::{CapHandle, CapRights, Capability, ObjectType, TaskHandle};
/// Tunable settings consumed by `CapabilityManager::new`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CapManagerConfig {
/// Delegation depth limit attached to every grant request the manager
/// builds (passed through `GrantRequest::with_max_depth`).
pub max_delegation_depth: u8,
/// When `true`, the manager registers every capability in its derivation
/// tree and consults that tree during lookups, revocation and iteration.
pub track_derivation: bool,
/// Value the manager's epoch counter starts from.
pub initial_epoch: u64,
}
impl CapManagerConfig {
    /// Returns the baseline configuration: the crate's default delegation
    /// depth limit, derivation tracking switched on, and an initial epoch of 0.
    #[inline]
    #[must_use]
    pub const fn new() -> Self {
        Self {
            initial_epoch: 0,
            track_derivation: true,
            max_delegation_depth: DEFAULT_MAX_DELEGATION_DEPTH,
        }
    }

    /// Returns a copy of this configuration whose delegation depth limit is
    /// `depth`; every other setting is carried over unchanged.
    #[inline]
    #[must_use]
    pub const fn with_max_depth(self, depth: u8) -> Self {
        Self {
            max_delegation_depth: depth,
            ..self
        }
    }

    /// Returns a copy of this configuration with derivation tracking switched
    /// off; every other setting is carried over unchanged.
    #[inline]
    #[must_use]
    pub const fn without_derivation_tracking(self) -> Self {
        Self {
            track_derivation: false,
            ..self
        }
    }
}
impl Default for CapManagerConfig {
fn default() -> Self {
Self::new()
}
}
/// Front end that combines a capability table, a derivation tree, an epoch
/// counter and running statistics behind one API.
///
/// `N` is the const capacity parameter handed to both the table and the
/// derivation tree; it defaults to `DEFAULT_CAP_TABLE_CAPACITY`.
pub struct CapabilityManager<const N: usize = DEFAULT_CAP_TABLE_CAPACITY> {
// Storage for the capability entries themselves.
table: CapabilityTable<N>,
// Parent/child relationships between capabilities; only consulted when
// `config.track_derivation` is set.
derivation: DerivationTree<N>,
// Settings captured at construction time.
config: CapManagerConfig,
// Current epoch; lookups reject entries whose stored epoch differs.
epoch: u64,
// Running operation counters, exposed through `stats()`.
stats: ManagerStats,
}
/// Running counters maintained by `CapabilityManager`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ManagerStats {
/// Number of root capabilities created successfully.
pub caps_created: u64,
/// Number of successful grant (delegation) operations.
pub caps_granted: u64,
/// Total number of capabilities reported as revoked, summed over all
/// successful revoke operations.
pub caps_revoked: u64,
/// Number of successful revoke operations (cascading or single).
pub revoke_operations: u64,
/// Largest delegation depth produced by any successful grant so far.
pub max_depth_reached: u8,
}
impl<const N: usize> CapabilityManager<N> {
    /// Builds an empty manager from `config`.
    ///
    /// The table and derivation tree start empty, the epoch is seeded from
    /// `config.initial_epoch`, and every statistics counter starts at zero.
    #[inline]
    #[must_use]
    pub const fn new(config: CapManagerConfig) -> Self {
        // Trait methods such as `Default::default` cannot be called from a
        // `const fn`, so the zeroed counters are spelled out explicitly.
        let stats = ManagerStats {
            caps_created: 0,
            caps_granted: 0,
            caps_revoked: 0,
            revoke_operations: 0,
            max_depth_reached: 0,
        };
        Self {
            table: CapabilityTable::new(),
            derivation: DerivationTree::new(),
            epoch: config.initial_epoch,
            config,
            stats,
        }
    }

    /// Builds an empty manager using [`CapManagerConfig::new`].
    #[inline]
    #[must_use]
    pub const fn with_defaults() -> Self {
        Self::new(CapManagerConfig::new())
    }

    /// Returns the configuration this manager was built with.
    #[inline]
    #[must_use]
    pub const fn config(&self) -> &CapManagerConfig {
        &self.config
    }

    /// Returns the running operation counters.
    #[inline]
    #[must_use]
    pub const fn stats(&self) -> &ManagerStats {
        &self.stats
    }

    /// Returns the current epoch.
    #[inline]
    #[must_use]
    pub const fn epoch(&self) -> u64 {
        self.epoch
    }

    /// Returns the entry count reported by the underlying table.
    ///
    /// This is the raw table count: it is not filtered by epoch or by
    /// derivation-tree validity the way [`Self::lookup`] is.
    #[inline]
    #[must_use]
    pub const fn len(&self) -> usize {
        self.table.len()
    }

    /// Returns `true` when the underlying table reports itself empty.
    #[inline]
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.table.is_empty()
    }

    /// Advances the epoch by one, wrapping around at `u64::MAX`.
    ///
    /// [`Self::lookup`] rejects entries whose stored epoch differs from the
    /// current one, so root capabilities stamped with the previous epoch are
    /// reported as [`CapError::Revoked`] afterwards.
    #[inline]
    pub fn increment_epoch(&mut self) {
        self.epoch = self.epoch.wrapping_add(1);
    }

    /// Creates a root capability carrying [`CapRights::ALL`].
    ///
    /// Shorthand for [`Self::create_root_capability_with_rights`]; see that
    /// method for the error and bookkeeping behaviour.
    pub fn create_root_capability(
        &mut self,
        object_id: u64,
        object_type: ObjectType,
        badge: u64,
        owner: TaskHandle,
    ) -> CapResult<CapHandle> {
        self.create_root_capability_with_rights(
            object_id,
            object_type,
            CapRights::ALL,
            badge,
            owner,
        )
    }

    /// Creates a root capability with an explicit rights mask.
    ///
    /// The capability is stamped with the current epoch, stored in the table
    /// for `owner`, and (when derivation tracking is enabled) registered as a
    /// root of the derivation tree. `caps_created` is bumped on success.
    ///
    /// # Errors
    ///
    /// Propagates any error from the table allocation or from the derivation
    /// tree. If the derivation tree rejects the new root, the table slot that
    /// was just allocated is released again so it does not leak.
    pub fn create_root_capability_with_rights(
        &mut self,
        object_id: u64,
        object_type: ObjectType,
        rights: CapRights,
        badge: u64,
        owner: TaskHandle,
    ) -> CapResult<CapHandle> {
        let capability = Capability::new(object_id, object_type, rights, badge, self.epoch);
        let handle = self.table.allocate_root(capability, owner)?;

        if self.config.track_derivation {
            let registered = self.derivation.add_root(handle);
            if registered.is_err() {
                // Best-effort rollback: the registration error is the one the
                // caller needs to see, so a secondary failure is ignored here.
                let _ = self.table.deallocate(handle);
            }
            registered?;
        }

        self.stats.caps_created += 1;
        Ok(handle)
    }

    /// Derives a new capability for `target` from `source_handle`.
    ///
    /// The source must pass [`Self::lookup`] (present in the table, current
    /// epoch, and not marked invalid by the derivation tree when tracking is
    /// on). The request is then checked by `validate_grant` with the
    /// configured delegation depth limit. On success the derived capability
    /// is stored in the table, linked under its source in the derivation tree
    /// (when tracking is on), and `caps_granted` / `max_depth_reached` are
    /// updated.
    ///
    /// `_caller` is accepted for interface compatibility but is not consulted
    /// by this implementation.
    ///
    /// # Errors
    ///
    /// * Any error from [`Self::lookup`] on the source, including
    ///   [`CapError::Revoked`] for a stale or revoked source.
    /// * Any error from `validate_grant`.
    /// * Any error from the table allocation or from the derivation tree. If
    ///   the derivation tree rejects the link, the derived table slot is
    ///   released again so it does not leak.
    pub fn grant(
        &mut self,
        source_handle: CapHandle,
        rights: CapRights,
        badge: u64,
        _caller: TaskHandle,
        target: TaskHandle,
    ) -> CapResult<CapHandle> {
        let request = GrantRequest::new(source_handle, rights, badge)
            .with_max_depth(self.config.max_delegation_depth);

        // Go through the manager-level lookup rather than the raw table:
        // entries revoked via the derivation tree stay in the table, and must
        // not be usable as a delegation source.
        let grant_result = {
            let source_entry = self.lookup(source_handle)?;
            validate_grant(source_entry, &request)?
        };

        let derived_handle = self.table.allocate_derived(
            grant_result.capability,
            target,
            grant_result.depth,
            source_handle,
        )?;

        if self.config.track_derivation {
            let linked =
                self.derivation
                    .add_child(source_handle, derived_handle, grant_result.depth);
            if linked.is_err() {
                // Best-effort rollback so a failed link does not strand a
                // table entry that nobody holds a handle to.
                let _ = self.table.deallocate(derived_handle);
            }
            linked?;
        }

        self.stats.caps_granted += 1;
        self.stats.max_depth_reached = self.stats.max_depth_reached.max(grant_result.depth);
        Ok(derived_handle)
    }

    /// Revokes `handle`, cascading through the derivation tree when tracking
    /// is enabled.
    ///
    /// The entry is fetched from the raw table and checked by
    /// `validate_revoke`. With tracking on, the derivation tree performs the
    /// revocation and reports how many capabilities it covered; with tracking
    /// off, exactly one capability is counted. In both modes `handle` itself
    /// is then removed from the table exactly once.
    ///
    /// Only `handle` is removed from the table here. Any descendants the
    /// derivation tree revoked remain stored and are filtered out by
    /// [`Self::lookup`] and [`Self::iter`] through the tree's validity check.
    ///
    /// `_request` is accepted for interface compatibility but is not
    /// consulted by this implementation.
    ///
    /// # Errors
    ///
    /// Propagates errors from the table lookup, `validate_revoke`, the
    /// derivation tree, and the final table deallocation.
    pub fn revoke(&mut self, handle: CapHandle, _request: RevokeRequest) -> CapResult<RevokeResult> {
        let entry = self.table.lookup(handle)?;
        validate_revoke(entry)?;

        let revoked_count = if self.config.track_derivation {
            self.derivation.revoke(handle)?
        } else {
            1
        };

        self.deallocate_revoked_caps(handle)?;

        self.stats.caps_revoked += revoked_count as u64;
        self.stats.revoke_operations += 1;
        Ok(RevokeResult::new(revoked_count))
    }

    /// Removes `handle` from the table, reporting any failure to the caller.
    fn deallocate_revoked_caps(&mut self, handle: CapHandle) -> CapResult<()> {
        self.table.deallocate(handle)?;
        Ok(())
    }

    /// Revokes exactly one capability without cascading.
    ///
    /// The entry is fetched from the raw table, checked by `validate_revoke`,
    /// and removed from the table. `caps_revoked` and `revoke_operations` are
    /// each bumped by one.
    ///
    /// # Errors
    ///
    /// Propagates errors from the table lookup, `validate_revoke`, and the
    /// table deallocation.
    pub fn revoke_single(&mut self, handle: CapHandle) -> CapResult<()> {
        let entry = self.table.lookup(handle)?;
        validate_revoke(entry)?;
        // NOTE(review): the derivation tree is not updated here, so with
        // tracking enabled any children keep pointing at a removed parent.
        // Confirm this is intended before relying on it.
        self.table.deallocate(handle)?;
        self.stats.caps_revoked += 1;
        self.stats.revoke_operations += 1;
        Ok(())
    }

    /// Fetches the table entry for `handle` if it is still usable.
    ///
    /// # Errors
    ///
    /// * Any error from the underlying table lookup.
    /// * [`CapError::Revoked`] when the entry's stored epoch differs from the
    ///   current epoch, or when derivation tracking is on and the tree
    ///   reports the handle as no longer valid.
    pub fn lookup(&self, handle: CapHandle) -> CapResult<&CapTableEntry> {
        let entry = self.table.lookup(handle)?;
        let stale_epoch = entry.capability.epoch != self.epoch;
        if stale_epoch || (self.config.track_derivation && !self.derivation.is_valid(handle)) {
            return Err(CapError::Revoked);
        }
        Ok(entry)
    }

    /// Returns `true` when [`Self::lookup`] succeeds for `handle`.
    pub fn is_valid(&self, handle: CapHandle) -> bool {
        self.lookup(handle).is_ok()
    }

    /// Reports whether the capability's rights contain `right`.
    ///
    /// Identical to [`Self::has_rights`]; kept as a separate name for callers
    /// that test a single right.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::lookup`].
    pub fn has_right(&self, handle: CapHandle, right: CapRights) -> CapResult<bool> {
        self.has_rights(handle, right)
    }

    /// Reports whether the capability's rights contain `rights`, as decided
    /// by `CapRights::contains`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::lookup`].
    pub fn has_rights(&self, handle: CapHandle, rights: CapRights) -> CapResult<bool> {
        self.lookup(handle)
            .map(|entry| entry.capability.rights.contains(rights))
    }

    /// Returns the delegation depth recorded in the capability's table entry.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::lookup`].
    pub fn depth(&self, handle: CapHandle) -> CapResult<u8> {
        self.lookup(handle).map(|entry| entry.depth)
    }

    /// Estimates what [`Self::revoke`] would cover, without changing state.
    ///
    /// `roots_revoked` is always 1 and `max_depth_reached` is the depth
    /// stored on `handle`'s own entry. With tracking on, `derived_revoked` is
    /// the number of items returned by the derivation tree's
    /// `collect_descendants` minus one (saturating at zero); with tracking
    /// off it is 0.
    ///
    /// # Errors
    ///
    /// Propagates any error from the raw table lookup.
    #[cfg(feature = "alloc")]
    pub fn preview_revoke(&self, handle: CapHandle) -> CapResult<RevokeStats> {
        let entry = self.table.lookup(handle)?;
        let derived_revoked = if self.config.track_derivation {
            // NOTE(review): subtracting one presumes the collected set
            // includes `handle` itself; confirm against `DerivationTree`.
            self.derivation
                .collect_descendants(handle)
                .len()
                .saturating_sub(1)
        } else {
            0
        };
        Ok(RevokeStats {
            roots_revoked: 1,
            derived_revoked,
            max_depth_reached: entry.depth,
        })
    }

    /// Iterates over table entries, skipping those the derivation tree
    /// reports as invalid when tracking is enabled.
    ///
    /// Unlike [`Self::lookup`], no epoch comparison is applied here.
    pub fn iter(&self) -> impl Iterator<Item = (CapHandle, &CapTableEntry)> {
        let tracking = self.config.track_derivation;
        let tree = &self.derivation;
        self.table
            .iter()
            .filter(move |(handle, _)| !tracking || tree.is_valid(*handle))
    }
}
impl<const N: usize> Default for CapabilityManager<N> {
fn default() -> Self {
Self::with_defaults()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Manager type shared by every test: capacity for 64 entries.
    type Manager = CapabilityManager<64>;

    /// A new root capability is stored with all rights at depth 0.
    #[test]
    fn test_create_root_capability() {
        let mut mgr = Manager::with_defaults();
        let creator = TaskHandle::new(1, 0);

        let root = mgr
            .create_root_capability(0x1000, ObjectType::VectorStore, 42, creator)
            .unwrap();

        assert_eq!(mgr.len(), 1);
        let stored = mgr.lookup(root).unwrap();
        assert_eq!(stored.capability.object_id, 0x1000);
        assert_eq!(stored.capability.rights, CapRights::ALL);
        assert_eq!(stored.depth, 0);
    }

    /// A grant produces a depth-1 entry carrying only the requested rights.
    #[test]
    fn test_grant_capability() {
        let mut mgr = Manager::with_defaults();
        let granter = TaskHandle::new(1, 0);
        let receiver = TaskHandle::new(2, 0);

        let root = mgr
            .create_root_capability(0x1000, ObjectType::Region, 0, granter)
            .unwrap();
        let requested = CapRights::READ | CapRights::WRITE;
        let derived = mgr.grant(root, requested, 100, granter, receiver).unwrap();

        assert_eq!(mgr.len(), 2);
        let stored = mgr.lookup(derived).unwrap();
        assert!(stored.capability.rights.contains(CapRights::READ));
        assert!(stored.capability.rights.contains(CapRights::WRITE));
        assert!(!stored.capability.rights.contains(CapRights::GRANT));
        assert_eq!(stored.depth, 1);
    }

    /// Revoking a root invalidates the root, its child and its grandchild.
    #[test]
    fn test_revoke_propagation() {
        let mut mgr = Manager::with_defaults();
        let first = TaskHandle::new(1, 0);
        let second = TaskHandle::new(2, 0);
        let third = TaskHandle::new(3, 0);

        let root = mgr
            .create_root_capability(0x1000, ObjectType::Queue, 0, first)
            .unwrap();
        let child = mgr
            .grant(root, CapRights::READ | CapRights::GRANT, 1, first, second)
            .unwrap();
        let grandchild = mgr
            .grant(child, CapRights::READ, 2, second, third)
            .unwrap();
        assert_eq!(mgr.len(), 3);

        let outcome = mgr.revoke(root, RevokeRequest::new()).unwrap();

        assert_eq!(outcome.revoked_count, 3);
        for revoked in [root, child, grandchild] {
            assert!(!mgr.is_valid(revoked));
        }
    }

    /// With a depth limit of 2, the third successive grant is rejected.
    #[test]
    fn test_delegation_depth_limit() {
        let mut mgr = Manager::new(CapManagerConfig::new().with_max_depth(2));
        let task = TaskHandle::new(1, 0);
        let delegable = CapRights::READ | CapRights::GRANT;

        let root = mgr
            .create_root_capability(0x1000, ObjectType::Timer, 0, task)
            .unwrap();
        let level_one = mgr.grant(root, delegable, 1, task, task).unwrap();
        let level_two = mgr.grant(level_one, delegable, 2, task, task).unwrap();

        let too_deep = mgr.grant(level_two, CapRights::READ, 3, task, task);

        assert_eq!(too_deep, Err(CapError::DelegationDepthExceeded));
    }

    /// `has_right` reports granted rights and rejects ones that were not given.
    #[test]
    fn test_has_right() {
        let mut mgr = Manager::with_defaults();
        let creator = TaskHandle::new(1, 0);

        let root = mgr
            .create_root_capability_with_rights(
                0x1000,
                ObjectType::Region,
                CapRights::READ | CapRights::WRITE,
                0,
                creator,
            )
            .unwrap();

        for granted in [CapRights::READ, CapRights::WRITE] {
            assert!(mgr.has_right(root, granted).unwrap());
        }
        assert!(!mgr.has_right(root, CapRights::EXECUTE).unwrap());
    }

    /// Bumping the epoch makes a previously valid capability invalid.
    #[test]
    fn test_epoch_invalidation() {
        let mut mgr = Manager::with_defaults();
        let creator = TaskHandle::new(1, 0);

        let root = mgr
            .create_root_capability(0x1000, ObjectType::VectorStore, 0, creator)
            .unwrap();
        assert!(mgr.is_valid(root));

        mgr.increment_epoch();

        assert!(!mgr.is_valid(root));
    }
}