use crate::error::{CapError, CapResult};
use crate::DEFAULT_CAP_TABLE_CAPACITY;
use ruvix_types::CapHandle;
/// One slot of a [`DerivationTree`]: a capability plus the links that place it in the tree.
///
/// A node's children form an intrusive singly linked list: the node points at its first
/// child, and each child points at its next sibling. A null handle terminates either link.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DerivationNode {
    /// The capability recorded in this slot.
    pub handle: CapHandle,
    /// Head of this node's child list, or null when it has no children.
    pub first_child: CapHandle,
    /// Next entry in the parent's child list, or null when this is the last one.
    pub next_sibling: CapHandle,
    /// `true` while the slot holds a live (non-revoked) node.
    pub is_valid: bool,
    /// Depth in the derivation tree; roots are created with depth 0.
    pub depth: u8,
}

impl DerivationNode {
    /// Creates a valid node at `depth` that is not yet linked to any child or sibling.
    #[inline]
    #[must_use]
    pub const fn new_child(handle: CapHandle, depth: u8) -> Self {
        let unlinked = CapHandle::null();
        Self {
            handle,
            first_child: unlinked,
            next_sibling: unlinked,
            is_valid: true,
            depth,
        }
    }

    /// Creates a valid root node: depth 0, no children, no siblings.
    #[inline]
    #[must_use]
    pub const fn new_root(handle: CapHandle) -> Self {
        Self::new_child(handle, 0)
    }

    /// Returns an unoccupied slot: null handle and links, depth 0, and `is_valid == false`.
    #[inline]
    #[must_use]
    pub const fn empty() -> Self {
        Self {
            is_valid: false,
            ..Self::new_root(CapHandle::null())
        }
    }

    /// Reports whether this node's child list is non-empty.
    #[inline]
    #[must_use]
    pub const fn has_children(&self) -> bool {
        let no_children = self.first_child.is_null();
        !no_children
    }

    /// Reports whether another entry follows this node in its parent's child list.
    #[inline]
    #[must_use]
    pub const fn has_sibling(&self) -> bool {
        let last_in_list = self.next_sibling.is_null();
        !last_in_list
    }
}

impl Default for DerivationNode {
    /// Equivalent to [`DerivationNode::empty`].
    #[inline]
    fn default() -> Self {
        Self::empty()
    }
}
/// Capability derivation tree stored in a fixed-size, slot-indexed arena.
///
/// A node lives at the slot given by its handle's `raw().id`. The children of a node form a
/// singly linked list threaded through `first_child` / `next_sibling`, with the most recently
/// added child first. Revoking a node invalidates its whole subtree.
///
/// Invariants maintained by this type:
/// - every valid node appears in at most one parent's child list;
/// - child lists only ever reference valid nodes (revoked subtrees are unlinked and their
///   slots reset), so a reused slot never inherits links into another subtree.
pub struct DerivationTree<const N: usize = DEFAULT_CAP_TABLE_CAPACITY> {
    /// Node storage, indexed by `handle.raw().id`.
    nodes: [DerivationNode; N],
    /// Number of currently valid nodes.
    count: usize,
}

impl<const N: usize> DerivationTree<N> {
    /// Creates a tree in which every slot is empty.
    #[inline]
    #[must_use]
    pub const fn new() -> Self {
        Self {
            nodes: [DerivationNode::empty(); N],
            count: 0,
        }
    }

    /// Number of valid (non-revoked) nodes in the tree.
    #[inline]
    #[must_use]
    pub const fn len(&self) -> usize {
        self.count
    }

    /// Returns `true` when the tree holds no valid nodes.
    #[inline]
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Maps `handle` to its arena slot, or `None` if its id is outside the arena.
    ///
    /// NOTE(review): slots are keyed by id only; the handle's other fields (e.g. a
    /// generation, if `CapHandle` has one) are not compared — confirm callers reject stale
    /// handles before reaching this tree.
    #[inline]
    fn slot_of(handle: CapHandle) -> Option<usize> {
        let index = handle.raw().id as usize;
        if index < N {
            Some(index)
        } else {
            None
        }
    }

    /// Returns `true` if `link` is non-null and refers to slot `index`.
    #[inline]
    fn links_to(link: CapHandle, index: usize) -> bool {
        !link.is_null() && Self::slot_of(link) == Some(index)
    }

    /// Inserts `handle` as a new root (depth 0).
    ///
    /// # Errors
    ///
    /// Returns [`CapError::InvalidHandle`] if `handle` is null, its id is out of range, or
    /// its slot already holds a valid node (overwriting it would corrupt `len()` and orphan
    /// the existing node's subtree from revocation).
    pub fn add_root(&mut self, handle: CapHandle) -> CapResult<()> {
        if handle.is_null() {
            return Err(CapError::InvalidHandle);
        }
        let index = Self::slot_of(handle).ok_or(CapError::InvalidHandle)?;
        if self.nodes[index].is_valid {
            return Err(CapError::InvalidHandle);
        }
        self.nodes[index] = DerivationNode::new_root(handle);
        self.count += 1;
        Ok(())
    }

    /// Inserts `child_handle` as the newest child of `parent_handle`.
    ///
    /// NOTE(review): `depth` is stored as given and is not checked against the parent's
    /// depth — confirm callers always pass `parent.depth + 1`.
    ///
    /// # Errors
    ///
    /// - [`CapError::InvalidHandle`] if either id is out of range, `child_handle` is null,
    ///   or the child's slot already holds a valid node (including `child == parent`).
    /// - [`CapError::Revoked`] if the parent slot does not hold a valid node.
    pub fn add_child(
        &mut self,
        parent_handle: CapHandle,
        child_handle: CapHandle,
        depth: u8,
    ) -> CapResult<()> {
        let (parent_index, child_index) =
            match (Self::slot_of(parent_handle), Self::slot_of(child_handle)) {
                (Some(parent), Some(child)) => (parent, child),
                _ => return Err(CapError::InvalidHandle),
            };
        if !self.nodes[parent_index].is_valid {
            return Err(CapError::Revoked);
        }
        // A null child would become the parent's `first_child` and cut off every existing
        // sibling. An occupied slot would be spliced into two lists at once. Reject both.
        if child_handle.is_null() || self.nodes[child_index].is_valid {
            return Err(CapError::InvalidHandle);
        }
        // Push onto the front of the parent's child list.
        let mut child_node = DerivationNode::new_child(child_handle, depth);
        child_node.next_sibling = self.nodes[parent_index].first_child;
        self.nodes[parent_index].first_child = child_handle;
        self.nodes[child_index] = child_node;
        self.count += 1;
        Ok(())
    }

    /// Revokes `handle` and every node derived from it.
    ///
    /// The revoked node is unlinked from its parent's child list, and every revoked slot is
    /// reset, so the slots can later be reused without affecting other subtrees.
    ///
    /// Returns the number of nodes revoked (including `handle` itself).
    ///
    /// # Errors
    ///
    /// - [`CapError::InvalidHandle`] if the id is out of range.
    /// - [`CapError::Revoked`] if the slot does not hold a valid node.
    pub fn revoke(&mut self, handle: CapHandle) -> CapResult<usize> {
        let index = Self::slot_of(handle).ok_or(CapError::InvalidHandle)?;
        if !self.nodes[index].is_valid {
            return Err(CapError::Revoked);
        }
        self.unlink_from_parent(index);
        Ok(self.revoke_subtree(handle))
    }

    /// Removes the node in slot `index` from its parent's child list, if it has a parent.
    ///
    /// Nodes do not store a parent pointer, so this scans every valid node's child list.
    /// Each node sits in at most one list, so the scan visits O(N) links in total.
    fn unlink_from_parent(&mut self, index: usize) {
        let successor = self.nodes[index].next_sibling;
        for parent_index in 0..N {
            if !self.nodes[parent_index].is_valid {
                continue;
            }
            let head = self.nodes[parent_index].first_child;
            if Self::links_to(head, index) {
                self.nodes[parent_index].first_child = successor;
                return;
            }
            let mut cursor = head;
            while !cursor.is_null() {
                let cursor_index = match Self::slot_of(cursor) {
                    Some(i) => i,
                    None => break,
                };
                let next = self.nodes[cursor_index].next_sibling;
                if Self::links_to(next, index) {
                    self.nodes[cursor_index].next_sibling = successor;
                    return;
                }
                cursor = next;
            }
        }
    }

    /// Invalidates the subtree rooted at `handle` and returns how many nodes were revoked.
    ///
    /// Out-of-range or already-invalid handles contribute 0. Each revoked slot is reset to
    /// [`DerivationNode::empty`] once its children have been processed.
    fn revoke_subtree(&mut self, handle: CapHandle) -> usize {
        let index = match Self::slot_of(handle) {
            Some(index) if self.nodes[index].is_valid => index,
            _ => return 0,
        };
        // Invalidate before descending so the node cannot be counted twice.
        self.nodes[index].is_valid = false;
        self.count = self.count.saturating_sub(1);

        let mut revoked = 1;
        let mut child = self.nodes[index].first_child;
        while !child.is_null() {
            let child_index = match Self::slot_of(child) {
                Some(i) => i,
                None => break,
            };
            // Read the sibling link before recursing: the recursive call resets that slot.
            let next = self.nodes[child_index].next_sibling;
            revoked += self.revoke_subtree(child);
            child = next;
        }

        // Drop all links so a later reuse of this slot cannot inherit stale pointers.
        self.nodes[index] = DerivationNode::empty();
        revoked
    }

    /// Returns the node for `handle`.
    ///
    /// # Errors
    ///
    /// - [`CapError::InvalidHandle`] if the id is out of range.
    /// - [`CapError::Revoked`] if the slot does not hold a valid node.
    pub fn lookup(&self, handle: CapHandle) -> CapResult<&DerivationNode> {
        let index = Self::slot_of(handle).ok_or(CapError::InvalidHandle)?;
        let node = &self.nodes[index];
        if node.is_valid {
            Ok(node)
        } else {
            Err(CapError::Revoked)
        }
    }

    /// Returns `true` if `handle`'s id is in range and its slot holds a valid node.
    pub fn is_valid(&self, handle: CapHandle) -> bool {
        matches!(Self::slot_of(handle), Some(index) if self.nodes[index].is_valid)
    }

    /// Returns the stored depth of `handle`.
    ///
    /// # Errors
    ///
    /// Same as [`Self::lookup`].
    pub fn depth(&self, handle: CapHandle) -> CapResult<u8> {
        self.lookup(handle).map(|n| n.depth)
    }

    /// Collects `handle` followed by every valid node in its subtree.
    ///
    /// The order is depth-first pre-order, with children visited newest-first. The result
    /// is empty if `handle` is out of range or not valid.
    #[cfg(feature = "alloc")]
    pub fn collect_descendants(&self, handle: CapHandle) -> alloc::vec::Vec<CapHandle> {
        let mut result = alloc::vec::Vec::new();
        self.collect_descendants_recursive(handle, &mut result);
        result
    }

    /// Pre-order helper for [`Self::collect_descendants`].
    #[cfg(feature = "alloc")]
    fn collect_descendants_recursive(
        &self,
        handle: CapHandle,
        result: &mut alloc::vec::Vec<CapHandle>,
    ) {
        let index = match Self::slot_of(handle) {
            Some(index) if self.nodes[index].is_valid => index,
            _ => return,
        };
        result.push(handle);
        let mut child = self.nodes[index].first_child;
        while !child.is_null() {
            let child_index = match Self::slot_of(child) {
                Some(i) => i,
                None => break,
            };
            let next = self.nodes[child_index].next_sibling;
            self.collect_descendants_recursive(child, result);
            child = next;
        }
    }
}

impl<const N: usize> Default for DerivationTree<N> {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture for the revocation tests:
    //
    //   root(0)
    //   ├── child1(1)
    //   │   └── grandchild(3)
    //   └── child2(2)
    //
    // Returns the tree plus the handles in the order [root, child1, child2, grandchild].
    fn family() -> (DerivationTree<64>, [CapHandle; 4]) {
        let mut tree = DerivationTree::<64>::new();
        let handles = [0, 1, 2, 3].map(|id| CapHandle::new(id, 0));
        let [root, child1, child2, grandchild] = handles;
        tree.add_root(root).unwrap();
        tree.add_child(root, child1, 1).unwrap();
        tree.add_child(root, child2, 1).unwrap();
        tree.add_child(child1, grandchild, 2).unwrap();
        (tree, handles)
    }

    #[test]
    fn test_derivation_tree_root() {
        let mut tree = DerivationTree::<64>::new();
        let root = CapHandle::new(0, 0);
        tree.add_root(root).unwrap();

        assert_eq!(tree.len(), 1);
        assert!(tree.is_valid(root));

        let node = tree.lookup(root).unwrap();
        assert_eq!(node.depth, 0);
        assert!(!node.has_children());
    }

    #[test]
    fn test_derivation_tree_child() {
        let mut tree = DerivationTree::<64>::new();
        let (root, leaf) = (CapHandle::new(0, 0), CapHandle::new(1, 0));
        tree.add_root(root).unwrap();
        tree.add_child(root, leaf, 1).unwrap();

        assert_eq!(tree.len(), 2);
        assert!(tree.is_valid(leaf));
        assert_eq!(tree.depth(leaf).unwrap(), 1);
        assert!(tree.lookup(root).unwrap().has_children());
    }

    #[test]
    fn test_derivation_tree_revoke() {
        let (mut tree, handles) = family();
        assert_eq!(tree.len(), 4);

        assert_eq!(tree.revoke(handles[0]).unwrap(), 4);
        assert_eq!(tree.len(), 0);
        for handle in handles {
            assert!(!tree.is_valid(handle));
        }
    }

    #[test]
    fn test_derivation_tree_partial_revoke() {
        let (mut tree, [root, child1, child2, grandchild]) = family();

        assert_eq!(tree.revoke(child1).unwrap(), 2);

        // Untouched branch survives.
        assert!(tree.is_valid(root));
        assert!(tree.is_valid(child2));
        // Revoked branch is gone, including its descendant.
        assert!(!tree.is_valid(child1));
        assert!(!tree.is_valid(grandchild));
    }
}