use blake3::Hasher;
use crate::error::A1Error;
/// BLAKE3 key-derivation context shared by `NarrowingMatrix::commitment` and
/// `capability_to_bit` (both pass it to `Hasher::new_derive_key`). Changing this string
/// changes every commitment and every hashed capability bit position.
const DOMAIN: &[u8] = b"a1::dyolo::narrowing::v2.8.0";
/// A 256-bit capability mask stored as 32 bytes.
///
/// A child matrix is a valid narrowing of a parent when every bit set in the child is
/// also set in the parent (`is_subset_of` / `enforce_narrowing`).
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct NarrowingMatrix {
/// One bit per capability position; byte `i`, bit `j` is position `i * 8 + j`.
mask: [u8; 32],
}
impl NarrowingMatrix {
    /// Matrix with no capability bits set; it is a subset of every matrix.
    pub const EMPTY: Self = Self { mask: [0u8; 32] };

    /// Matrix with all 256 bits set; every matrix is a subset of it.
    pub const FULL: Self = Self { mask: [0xFF; 32] };

    /// Builds a matrix by hashing each capability name to one of 256 bit positions.
    ///
    /// Positions come from `capability_to_bit`, so two distinct names can land on the
    /// same bit. When that happens, a child holding one name passes
    /// [`is_subset_of`](Self::is_subset_of) against a parent that only holds the other.
    /// Use [`CapabilityRegistry::build_mask`] when every capability must own its own bit.
    pub fn from_capabilities<S: AsRef<str>>(caps: &[S]) -> Self {
        let mut mask = [0u8; 32];
        for cap in caps {
            let (byte_idx, bit_idx) = capability_to_bit(cap.as_ref());
            mask[byte_idx] |= 1u8 << bit_idx;
        }
        Self { mask }
    }

    /// Parses a comma-separated list of capability names and hashes them like
    /// [`from_capabilities`](Self::from_capabilities).
    ///
    /// Each entry is trimmed of surrounding whitespace. Empty entries (for example from
    /// `"a,,b"` or a trailing comma) are ignored.
    pub fn from_csv(csv: &str) -> Self {
        let caps: Vec<&str> = csv
            .split(',')
            .map(str::trim)
            .filter(|s| !s.is_empty())
            .collect();
        Self::from_capabilities(&caps)
    }

    /// Wraps a raw 32-byte mask as-is (no validation).
    pub(crate) fn from_raw(mask: [u8; 32]) -> Self {
        Self { mask }
    }

    /// Returns `true` if every bit set in `self` is also set in `parent`.
    pub fn is_subset_of(&self, parent: &NarrowingMatrix) -> bool {
        // A bit set in `self` but clear in `parent` would widen the parent's capabilities.
        self.mask
            .iter()
            .zip(parent.mask.iter())
            .all(|(&s, &p)| s & !p == 0)
    }

    /// Checks that `self` only narrows (never widens) `parent`.
    ///
    /// # Errors
    ///
    /// Returns [`A1Error::PassportNarrowingViolation`] if `self` has any bit that
    /// `parent` lacks.
    pub fn enforce_narrowing(&self, parent: &NarrowingMatrix) -> Result<(), A1Error> {
        if self.is_subset_of(parent) {
            Ok(())
        } else {
            Err(A1Error::PassportNarrowingViolation)
        }
    }

    /// Returns the bitwise AND of both masks, i.e. the bits present in both.
    pub fn intersect(&self, other: &NarrowingMatrix) -> NarrowingMatrix {
        let mut mask = [0u8; 32];
        for (i, item) in mask.iter_mut().enumerate() {
            *item = self.mask[i] & other.mask[i];
        }
        NarrowingMatrix { mask }
    }

    /// Returns a 32-byte BLAKE3 commitment to the mask.
    ///
    /// The hash runs in key-derivation mode with `DOMAIN` as the context string.
    pub fn commitment(&self) -> [u8; 32] {
        let context =
            std::str::from_utf8(DOMAIN).expect("DOMAIN is an ASCII byte-string literal");
        let mut h = Hasher::new_derive_key(context);
        h.update(&self.mask);
        *h.finalize().as_bytes()
    }

    /// Borrows the raw 32-byte mask.
    pub fn as_bytes(&self) -> &[u8; 32] {
        &self.mask
    }

    /// Hex-encodes the 32-byte mask (64 hex characters).
    pub fn to_hex(&self) -> String {
        hex::encode(self.mask)
    }

    /// Parses a mask from its hex encoding, the inverse of [`to_hex`](Self::to_hex).
    ///
    /// # Errors
    ///
    /// Returns [`A1Error::WireFormatError`] if `s` is not valid hex, or if it does not
    /// decode to exactly 32 bytes.
    pub fn from_hex(s: &str) -> Result<Self, A1Error> {
        let bytes = hex::decode(s)
            .map_err(|_| A1Error::WireFormatError("invalid narrowing matrix hex".into()))?;
        if bytes.len() != 32 {
            return Err(A1Error::WireFormatError(
                "narrowing matrix must be exactly 32 bytes".into(),
            ));
        }
        let mut mask = [0u8; 32];
        mask.copy_from_slice(&bytes);
        Ok(Self { mask })
    }

    /// Returns `true` if no bit is set.
    pub fn is_empty(&self) -> bool {
        self.mask.iter().all(|&b| b == 0)
    }

    /// Returns the number of set bits (0..=256).
    pub fn capacity_count(&self) -> u32 {
        self.mask.iter().map(|b| b.count_ones()).sum()
    }
}
impl Default for NarrowingMatrix {
fn default() -> Self {
Self::EMPTY
}
}
impl std::fmt::Display for NarrowingMatrix {
    /// Writes the hex encoding produced by [`NarrowingMatrix::to_hex`].
    ///
    /// Width and fill flags on the formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let encoded = self.to_hex();
        f.write_str(&encoded)
    }
}
/// Maps a capability name to a `(byte_index, bit_index)` position in a 32-byte mask.
///
/// The name is hashed with BLAKE3 in key-derivation mode, using [`DOMAIN`] as the context.
/// The first digest byte (mod 32) selects the byte and the second (mod 8) selects the bit.
///
/// There are only 256 positions. Assuming uniform hash output, two distinct names
/// collide with probability of about 50% once roughly 19 names are hashed (birthday
/// bound). `CapabilityRegistry` avoids this.
///
/// NOTE(review): the same context string is used by `NarrowingMatrix::commitment`.
/// Changing it would move every bit, so it is presumably kept for compatibility with
/// existing masks. Confirm before altering.
fn capability_to_bit(name: &str) -> (usize, usize) {
    let context =
        std::str::from_utf8(DOMAIN).expect("DOMAIN is an ASCII byte-string literal");
    let mut h = Hasher::new_derive_key(context);
    h.update(name.as_bytes());
    let digest = h.finalize();
    let bytes = digest.as_bytes();
    // 256 is divisible by both 32 and 8, so neither reduction introduces bias.
    let byte_idx = usize::from(bytes[0]) % 32;
    let bit_idx = usize::from(bytes[1]) % 8;
    (byte_idx, bit_idx)
}
/// Assigns each capability name a unique, sequential slot (a `u8`, so 0..=255) in a
/// 256-bit mask. Unlike the hash-based `NarrowingMatrix::from_capabilities`, two
/// registered names never share a bit.
#[derive(Debug, Clone)]
pub struct CapabilityRegistry {
/// Registered name -> assigned slot.
slots: std::collections::HashMap<String, u8>,
/// Slot handed to the next newly registered name.
next: u8,
/// Number of registered names; `register` refuses new names once this reaches 256.
count: usize,
}
impl CapabilityRegistry {
    /// Maximum number of distinct capabilities: one per bit of the 256-bit mask.
    const MAX_SLOTS: usize = 256;

    /// Creates an empty registry. The first registered name receives slot 0.
    pub fn new() -> Self {
        Self {
            slots: std::collections::HashMap::new(),
            next: 0,
            count: 0,
        }
    }

    /// Assigns `name` the next free slot, or returns its existing slot if it is already
    /// registered (idempotent).
    ///
    /// Slots are handed out sequentially from 0 in registration order.
    ///
    /// # Errors
    ///
    /// Returns [`A1Error::WireFormatError`] if `name` is new and all 256 slots are taken.
    pub fn register(&mut self, name: impl Into<String>) -> Result<u8, A1Error> {
        // One hash lookup serves both the "already registered" and the "insert" paths.
        match self.slots.entry(name.into()) {
            std::collections::hash_map::Entry::Occupied(existing) => Ok(*existing.get()),
            std::collections::hash_map::Entry::Vacant(vacant) => {
                if self.count >= Self::MAX_SLOTS {
                    return Err(Self::full_error());
                }
                let slot = self.next;
                vacant.insert(slot);
                // `next` wraps to 0 only after slot 255 is issued. By then `count` is 256,
                // so the capacity check above rejects any further names.
                self.next = self.next.wrapping_add(1);
                self.count += 1;
                Ok(slot)
            }
        }
    }

    /// Registers every name in `names`, in order, as an all-or-nothing operation.
    ///
    /// Names that are already registered, or repeated within `names`, keep a single slot.
    /// Capacity is checked before anything is inserted. On error the registry is left
    /// unchanged rather than partially filled.
    ///
    /// # Errors
    ///
    /// Returns [`A1Error::WireFormatError`] if the new distinct names would need more
    /// than 256 slots in total.
    pub fn register_all<S: AsRef<str>>(&mut self, names: &[S]) -> Result<(), A1Error> {
        let mut fresh: std::collections::HashSet<&str> = std::collections::HashSet::new();
        for name in names {
            let name: &str = name.as_ref();
            if !self.slots.contains_key(name) {
                fresh.insert(name);
            }
        }
        if self.count + fresh.len() > Self::MAX_SLOTS {
            return Err(Self::full_error());
        }
        // Second pass preserves the caller's order, so slots follow `names`.
        for name in names {
            self.register(name.as_ref())?;
        }
        Ok(())
    }

    /// Builds a mask with exactly the bits of the given registered capabilities.
    ///
    /// # Errors
    ///
    /// Returns [`A1Error::WireFormatError`] naming the first capability that was never
    /// registered.
    pub fn build_mask<S: AsRef<str>>(
        &self,
        capabilities: &[S],
    ) -> Result<NarrowingMatrix, A1Error> {
        let mut mask = [0u8; 32];
        for cap in capabilities {
            let name = cap.as_ref();
            let slot = self.slot_of(name).ok_or_else(|| {
                A1Error::WireFormatError(format!(
                    "capability '{}' is not registered; call register() first",
                    name
                ))
            })?;
            Self::set_slot(&mut mask, slot);
        }
        Ok(NarrowingMatrix::from_raw(mask))
    }

    /// Builds a mask with the bit of every registered capability set.
    pub fn build_full_mask(&self) -> NarrowingMatrix {
        let mut mask = [0u8; 32];
        for &slot in self.slots.values() {
            Self::set_slot(&mut mask, slot);
        }
        NarrowingMatrix::from_raw(mask)
    }

    /// Returns the slot assigned to `name`, if registered.
    pub fn slot_of(&self, name: &str) -> Option<u8> {
        self.slots.get(name).copied()
    }

    /// Returns registered names sorted by slot, i.e. in registration order.
    pub fn names_in_order(&self) -> Vec<&str> {
        let mut pairs: Vec<(&str, u8)> =
            self.slots.iter().map(|(k, &v)| (k.as_str(), v)).collect();
        // Slots are unique, so an unstable sort yields the same order as a stable one.
        pairs.sort_unstable_by_key(|&(_, slot)| slot);
        pairs.into_iter().map(|(name, _)| name).collect()
    }

    /// Number of registered capabilities.
    pub fn len(&self) -> usize {
        self.count
    }

    /// Returns `true` if no capability has been registered.
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Sets the bit for `slot` (byte `slot / 8`, bit `slot % 8`) in `mask`.
    fn set_slot(mask: &mut [u8; 32], slot: u8) {
        let slot = usize::from(slot);
        mask[slot / 8] |= 1u8 << (slot % 8);
    }

    /// Error returned when a registration would need a 257th slot.
    fn full_error() -> A1Error {
        A1Error::WireFormatError(
            "CapabilityRegistry is full: maximum 256 capabilities per registry".into(),
        )
    }
}
impl Default for CapabilityRegistry {
fn default() -> Self {
Self::new()
}
}
// Unit tests for the hash-based `NarrowingMatrix` constructors and for the slot-based
// `CapabilityRegistry`.
#[cfg(test)]
mod tests {
use super::*;
// --- NarrowingMatrix: subset / narrowing semantics ---
#[test]
fn empty_is_subset_of_full() {
assert!(NarrowingMatrix::EMPTY.is_subset_of(&NarrowingMatrix::FULL));
}
#[test]
fn full_is_not_subset_of_empty() {
assert!(!NarrowingMatrix::FULL.is_subset_of(&NarrowingMatrix::EMPTY));
}
#[test]
fn subset_of_itself() {
let m = NarrowingMatrix::from_capabilities(&["trade.equity", "portfolio.read"]);
assert!(m.is_subset_of(&m));
}
#[test]
fn sub_is_subset_of_parent() {
let parent = NarrowingMatrix::from_capabilities(&[
"trade.equity",
"portfolio.read",
"portfolio.write",
]);
let child = NarrowingMatrix::from_capabilities(&["trade.equity"]);
assert!(child.is_subset_of(&parent));
// NOTE(review): relies on these hashed names not sharing a bit with "trade.equity".
assert!(!parent.is_subset_of(&child));
}
#[test]
fn escalation_detected() {
// NOTE(review): relies on "trade.equity" and "portfolio.read" hashing to different bits.
let parent = NarrowingMatrix::from_capabilities(&["portfolio.read"]);
let child = NarrowingMatrix::from_capabilities(&["trade.equity"]);
assert!(child.enforce_narrowing(&parent).is_err());
}
// --- NarrowingMatrix: commitment and encoding ---
#[test]
fn commitment_is_stable() {
let m = NarrowingMatrix::from_capabilities(&["trade.equity"]);
let c1 = m.commitment();
let c2 = m.commitment();
assert_eq!(c1, c2);
}
#[test]
fn commitment_differs_across_masks() {
let a = NarrowingMatrix::from_capabilities(&["trade.equity"]);
let b = NarrowingMatrix::from_capabilities(&["portfolio.write"]);
assert_ne!(a.commitment(), b.commitment());
}
#[test]
fn roundtrip_hex() {
let m = NarrowingMatrix::from_capabilities(&["trade.equity", "audit.read"]);
let hex = m.to_hex();
let m2 = NarrowingMatrix::from_hex(&hex).unwrap();
assert_eq!(m, m2);
}
#[test]
fn csv_parsing() {
// Whitespace around entries must be trimmed before hashing.
let m = NarrowingMatrix::from_csv("trade.equity , portfolio.read, audit.read");
let expected =
NarrowingMatrix::from_capabilities(&["trade.equity", "portfolio.read", "audit.read"]);
assert_eq!(m, expected);
}
#[test]
fn intersect_produces_common_bits() {
let a = NarrowingMatrix::from_capabilities(&["trade.equity", "portfolio.read"]);
let b = NarrowingMatrix::from_capabilities(&["trade.equity", "audit.read"]);
let common = a.intersect(&b);
let expected = NarrowingMatrix::from_capabilities(&["trade.equity"]);
assert_eq!(common, expected);
}
// --- CapabilityRegistry: slot assignment ---
#[test]
fn registry_sequential_slots() {
let mut reg = CapabilityRegistry::new();
let s0 = reg.register("alpha").unwrap();
let s1 = reg.register("beta").unwrap();
let s2 = reg.register("gamma").unwrap();
assert_eq!(s0, 0);
assert_eq!(s1, 1);
assert_eq!(s2, 2);
}
#[test]
fn registry_idempotent_register() {
let mut reg = CapabilityRegistry::new();
let s0 = reg.register("alpha").unwrap();
let s1 = reg.register("alpha").unwrap();
assert_eq!(s0, s1);
assert_eq!(reg.len(), 1);
}
// --- CapabilityRegistry: mask building ---
#[test]
fn registry_build_mask_subset() {
let mut reg = CapabilityRegistry::new();
reg.register_all(&["trade.equity", "portfolio.read", "audit.read"])
.unwrap();
let parent = reg.build_mask(&["trade.equity", "portfolio.read"]).unwrap();
let child = reg.build_mask(&["trade.equity"]).unwrap();
assert!(child.is_subset_of(&parent));
assert!(!parent.is_subset_of(&child));
}
#[test]
fn registry_rejects_unknown_capability() {
let mut reg = CapabilityRegistry::new();
reg.register("trade.equity").unwrap();
let result = reg.build_mask(&["portfolio.write"]);
assert!(result.is_err());
}
#[test]
fn registry_full_mask_covers_all() {
let mut reg = CapabilityRegistry::new();
reg.register_all(&["a", "b", "c"]).unwrap();
let full = reg.build_full_mask();
let a = reg.build_mask(&["a"]).unwrap();
let b = reg.build_mask(&["b"]).unwrap();
let c = reg.build_mask(&["c"]).unwrap();
assert!(a.is_subset_of(&full));
assert!(b.is_subset_of(&full));
assert!(c.is_subset_of(&full));
}
// --- CapabilityRegistry: capacity limits ---
#[test]
fn registry_no_collisions_across_256_caps() {
// Fills every slot; each capability's mask must contain exactly one set bit.
let mut reg = CapabilityRegistry::new();
let caps: Vec<String> = (0..256).map(|i| format!("cap.{}", i)).collect();
let cap_refs: Vec<&str> = caps.iter().map(String::as_str).collect();
reg.register_all(&cap_refs).unwrap();
for cap in &caps {
let mask = reg.build_mask(&[cap.as_str()]).unwrap();
assert_eq!(
mask.capacity_count(),
1,
"cap '{}' must occupy exactly one bit",
cap
);
}
}
#[test]
fn registry_over_256_returns_error() {
// The 257th distinct name must be rejected.
let mut reg = CapabilityRegistry::new();
let caps: Vec<String> = (0..256).map(|i| format!("cap.{}", i)).collect();
let cap_refs: Vec<&str> = caps.iter().map(String::as_str).collect();
reg.register_all(&cap_refs).unwrap();
let result = reg.register("one.too.many");
assert!(result.is_err());
}
#[test]
fn registry_names_in_order_matches_registration() {
let mut reg = CapabilityRegistry::new();
let names = ["gamma", "alpha", "beta", "delta"];
reg.register_all(&names).unwrap();
let ordered = reg.names_in_order();
assert_eq!(ordered, names.as_slice());
}
#[test]
fn registry_no_collision_where_hash_would_collide() {
// NOTE(review): the name implies "cap.0" and "cap.1" collide under `capability_to_bit`.
// That premise is not asserted here. Confirm it, or rename the test.
let mut reg = CapabilityRegistry::new();
reg.register_all(&["cap.0", "cap.1"]).unwrap();
let m0 = reg.build_mask(&["cap.0"]).unwrap();
let m1 = reg.build_mask(&["cap.1"]).unwrap();
assert_eq!(m0.intersect(&m1), NarrowingMatrix::EMPTY);
}
}