1use blake3::Hasher;
2
3use crate::error::A1Error;
4
/// BLAKE3 `derive_key` context string (stored as bytes) shared by capability-name bit
/// hashing (`capability_to_bit`) and mask commitments (`NarrowingMatrix::commitment`).
/// Changing it changes every hashed bit position and every commitment.
const DOMAIN: &[u8] = "a1::dyolo::narrowing::v2.8.0".as_bytes();
6
/// A 256-bit capability mask stored as 32 bytes.
///
/// Each set bit grants one capability. Bits are addressed as (byte index, bit index):
/// masks built from hashed names (`from_capabilities`) get their positions from
/// `capability_to_bit`. Masks built by `CapabilityRegistry` place slot `s` at byte `s / 8`,
/// bit `s % 8`. Narrowing is checked bitwise: a child mask is valid for a parent when it
/// sets no bit the parent lacks.
///
/// `Hash` is derived alongside `Eq` so masks can be used as set/map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct NarrowingMatrix {
    /// Raw mask bytes; bit `b` of byte `i` is one capability bit.
    mask: [u8; 32],
}
42
43impl NarrowingMatrix {
44 pub const EMPTY: Self = Self { mask: [0u8; 32] };
46
47 pub const FULL: Self = Self { mask: [0xFF; 32] };
49
50 pub fn from_capabilities<S: AsRef<str>>(caps: &[S]) -> Self {
58 let mut mask = [0u8; 32];
59 for cap in caps {
60 let (byte_idx, bit_idx) = capability_to_bit(cap.as_ref());
61 mask[byte_idx] |= 1u8 << bit_idx;
62 }
63 Self { mask }
64 }
65
66 pub fn from_csv(csv: &str) -> Self {
68 let caps: Vec<&str> = csv
69 .split(',')
70 .map(str::trim)
71 .filter(|s| !s.is_empty())
72 .collect();
73 Self::from_capabilities(&caps)
74 }
75
76 pub(crate) fn from_raw(mask: [u8; 32]) -> Self {
81 Self { mask }
82 }
83
84 pub fn is_subset_of(&self, parent: &NarrowingMatrix) -> bool {
89 let read_u64 = |bytes: &[u8; 32], i: usize| -> u64 {
90 u64::from_le_bytes(
91 bytes[i * 8..(i + 1) * 8]
92 .try_into()
93 .expect("slice is 8 bytes"),
94 )
95 };
96 (0..4).all(|i| {
97 let s = read_u64(&self.mask, i);
98 let p = read_u64(&parent.mask, i);
99 s & p == s
100 })
101 }
102
103 pub fn enforce_narrowing(&self, parent: &NarrowingMatrix) -> Result<(), A1Error> {
105 if self.is_subset_of(parent) {
106 Ok(())
107 } else {
108 Err(A1Error::PassportNarrowingViolation)
109 }
110 }
111
112 pub fn intersect(&self, other: &NarrowingMatrix) -> NarrowingMatrix {
116 let mut mask = [0u8; 32];
117 for (i, item) in mask.iter_mut().enumerate() {
118 *item = self.mask[i] & other.mask[i];
119 }
120 NarrowingMatrix { mask }
121 }
122
123 pub fn commitment(&self) -> [u8; 32] {
128 let mut h = Hasher::new_derive_key(std::str::from_utf8(DOMAIN).unwrap());
129 h.update(&self.mask);
130 *h.finalize().as_bytes()
131 }
132
133 pub fn as_bytes(&self) -> &[u8; 32] {
135 &self.mask
136 }
137
138 pub fn to_hex(&self) -> String {
140 hex::encode(self.mask)
141 }
142
143 pub fn from_hex(s: &str) -> Result<Self, A1Error> {
145 let bytes = hex::decode(s)
146 .map_err(|_| A1Error::WireFormatError("invalid narrowing matrix hex".into()))?;
147 if bytes.len() != 32 {
148 return Err(A1Error::WireFormatError(
149 "narrowing matrix must be exactly 32 bytes".into(),
150 ));
151 }
152 let mut mask = [0u8; 32];
153 mask.copy_from_slice(&bytes);
154 Ok(Self { mask })
155 }
156
157 pub fn is_empty(&self) -> bool {
159 self.mask.iter().all(|&b| b == 0)
160 }
161
162 pub fn capacity_count(&self) -> u32 {
164 self.mask.iter().map(|b| b.count_ones()).sum()
165 }
166}
167
168impl Default for NarrowingMatrix {
169 fn default() -> Self {
170 Self::EMPTY
171 }
172}
173
174impl std::fmt::Display for NarrowingMatrix {
175 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
176 write!(f, "{}", self.to_hex())
177 }
178}
179
180fn capability_to_bit(name: &str) -> (usize, usize) {
186 let mut h = Hasher::new_derive_key(std::str::from_utf8(DOMAIN).unwrap());
187 h.update(name.as_bytes());
188 let out = h.finalize();
189 let b = out.as_bytes();
190 let byte_idx = (b[0] as usize) % 32;
191 let bit_idx = (b[1] as usize) % 8;
192 (byte_idx, bit_idx)
193}
194
/// Assigns capability names to collision-free bit slots (0..=255) of a
/// [`NarrowingMatrix`].
///
/// Unlike hashed masks, each registered name gets its own slot, handed out sequentially
/// in registration order. At most 256 names fit, one per mask bit. `PartialEq`/`Eq` are
/// derived so two registries can be compared by contents.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CapabilityRegistry {
    /// Registered name -> assigned slot.
    slots: std::collections::HashMap<String, u8>,
    /// Slot that the next newly registered name will receive.
    next: u8,
    /// Number of registered names; capped at 256 by `register`.
    count: usize,
}
240
241impl CapabilityRegistry {
242 pub fn new() -> Self {
244 Self {
245 slots: std::collections::HashMap::new(),
246 next: 0,
247 count: 0,
248 }
249 }
250
251 pub fn register(&mut self, name: impl Into<String>) -> Result<u8, A1Error> {
257 let name = name.into();
258 if let Some(&slot) = self.slots.get(&name) {
259 return Ok(slot);
260 }
261 if self.count >= 256 {
262 return Err(A1Error::WireFormatError(
263 "CapabilityRegistry is full: maximum 256 capabilities per registry".into(),
264 ));
265 }
266 let slot = self.next;
267 self.slots.insert(name, slot);
268 self.next = self.next.wrapping_add(1);
269 self.count += 1;
270 Ok(slot)
271 }
272
273 pub fn register_all<S: AsRef<str>>(&mut self, names: &[S]) -> Result<(), A1Error> {
275 for name in names {
276 self.register(name.as_ref())?;
277 }
278 Ok(())
279 }
280
281 pub fn build_mask<S: AsRef<str>>(
288 &self,
289 capabilities: &[S],
290 ) -> Result<NarrowingMatrix, A1Error> {
291 let mut mask = [0u8; 32];
292 for cap in capabilities {
293 let name = cap.as_ref();
294 let slot = self.slots.get(name).ok_or_else(|| {
295 A1Error::WireFormatError(format!(
296 "capability '{}' is not registered; call register() first",
297 name
298 ))
299 })?;
300 let byte_idx = (*slot as usize) / 8;
301 let bit_idx = (*slot as usize) % 8;
302 mask[byte_idx] |= 1u8 << bit_idx;
303 }
304 Ok(NarrowingMatrix::from_raw(mask))
305 }
306
307 pub fn build_full_mask(&self) -> NarrowingMatrix {
309 let mut mask = [0u8; 32];
310 for slot in self.slots.values() {
311 let byte_idx = (*slot as usize) / 8;
312 let bit_idx = (*slot as usize) % 8;
313 mask[byte_idx] |= 1u8 << bit_idx;
314 }
315 NarrowingMatrix::from_raw(mask)
316 }
317
318 pub fn slot_of(&self, name: &str) -> Option<u8> {
320 self.slots.get(name).copied()
321 }
322
323 pub fn names_in_order(&self) -> Vec<&str> {
325 let mut pairs: Vec<(&str, u8)> = self.slots.iter().map(|(k, &v)| (k.as_str(), v)).collect();
326 pairs.sort_by_key(|&(_, slot)| slot);
327 pairs.into_iter().map(|(name, _)| name).collect()
328 }
329
330 pub fn len(&self) -> usize {
332 self.count
333 }
334
335 pub fn is_empty(&self) -> bool {
337 self.count == 0
338 }
339}
340
341impl Default for CapabilityRegistry {
342 fn default() -> Self {
343 Self::new()
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_is_subset_of_full() {
        assert!(NarrowingMatrix::EMPTY.is_subset_of(&NarrowingMatrix::FULL));
    }

    #[test]
    fn full_is_not_subset_of_empty() {
        assert!(!NarrowingMatrix::FULL.is_subset_of(&NarrowingMatrix::EMPTY));
    }

    #[test]
    fn subset_of_itself() {
        let m = NarrowingMatrix::from_capabilities(&["trade.equity", "portfolio.read"]);
        assert!(m.is_subset_of(&m));
    }

    #[test]
    fn sub_is_subset_of_parent() {
        let parent = NarrowingMatrix::from_capabilities(&[
            "trade.equity",
            "portfolio.read",
            "portfolio.write",
        ]);
        let child = NarrowingMatrix::from_capabilities(&["trade.equity"]);
        assert!(child.is_subset_of(&parent));
        assert!(!parent.is_subset_of(&child));
    }

    #[test]
    fn escalation_detected() {
        let parent = NarrowingMatrix::from_capabilities(&["portfolio.read"]);
        let child = NarrowingMatrix::from_capabilities(&["trade.equity"]);
        assert!(child.enforce_narrowing(&parent).is_err());
    }

    #[test]
    fn narrowing_to_subset_is_accepted() {
        let parent = NarrowingMatrix::from_capabilities(&["portfolio.read", "trade.equity"]);
        let child = NarrowingMatrix::from_capabilities(&["portfolio.read"]);
        assert!(child.enforce_narrowing(&parent).is_ok());
    }

    #[test]
    fn commitment_is_stable() {
        let m = NarrowingMatrix::from_capabilities(&["trade.equity"]);
        let c1 = m.commitment();
        let c2 = m.commitment();
        assert_eq!(c1, c2);
    }

    #[test]
    fn commitment_differs_across_masks() {
        let a = NarrowingMatrix::from_capabilities(&["trade.equity"]);
        let b = NarrowingMatrix::from_capabilities(&["portfolio.write"]);
        assert_ne!(a.commitment(), b.commitment());
    }

    #[test]
    fn roundtrip_hex() {
        let m = NarrowingMatrix::from_capabilities(&["trade.equity", "audit.read"]);
        let hex = m.to_hex();
        let m2 = NarrowingMatrix::from_hex(&hex).unwrap();
        assert_eq!(m, m2);
    }

    #[test]
    fn from_hex_rejects_invalid_hex() {
        assert!(NarrowingMatrix::from_hex("zz").is_err());
    }

    #[test]
    fn from_hex_rejects_wrong_length() {
        assert!(NarrowingMatrix::from_hex("").is_err());
        assert!(NarrowingMatrix::from_hex(&"00".repeat(31)).is_err());
        assert!(NarrowingMatrix::from_hex(&"00".repeat(33)).is_err());
    }

    #[test]
    fn csv_parsing() {
        let m = NarrowingMatrix::from_csv("trade.equity , portfolio.read, audit.read");
        let expected =
            NarrowingMatrix::from_capabilities(&["trade.equity", "portfolio.read", "audit.read"]);
        assert_eq!(m, expected);
    }

    #[test]
    fn csv_ignores_blank_entries() {
        assert_eq!(NarrowingMatrix::from_csv(""), NarrowingMatrix::EMPTY);
        assert_eq!(NarrowingMatrix::from_csv(" , ,"), NarrowingMatrix::EMPTY);
        assert_eq!(
            NarrowingMatrix::from_csv("trade.equity,,"),
            NarrowingMatrix::from_capabilities(&["trade.equity"])
        );
    }

    #[test]
    fn constants_have_expected_popcounts() {
        assert_eq!(NarrowingMatrix::EMPTY.capacity_count(), 0);
        assert_eq!(NarrowingMatrix::FULL.capacity_count(), 256);
        assert!(NarrowingMatrix::EMPTY.is_empty());
        assert!(!NarrowingMatrix::FULL.is_empty());
        assert_eq!(NarrowingMatrix::default(), NarrowingMatrix::EMPTY);
    }

    #[test]
    fn display_matches_hex() {
        let m = NarrowingMatrix::from_capabilities(&["trade.equity"]);
        assert_eq!(m.to_string(), m.to_hex());
        assert_eq!(m.to_hex().len(), 64);
    }

    #[test]
    fn intersect_produces_common_bits() {
        let a = NarrowingMatrix::from_capabilities(&["trade.equity", "portfolio.read"]);
        let b = NarrowingMatrix::from_capabilities(&["trade.equity", "audit.read"]);
        let common = a.intersect(&b);
        let expected = NarrowingMatrix::from_capabilities(&["trade.equity"]);
        assert_eq!(common, expected);
    }

    #[test]
    fn intersect_with_full_and_empty() {
        let m = NarrowingMatrix::from_capabilities(&["trade.equity", "audit.read"]);
        assert_eq!(m.intersect(&NarrowingMatrix::FULL), m);
        assert_eq!(m.intersect(&NarrowingMatrix::EMPTY), NarrowingMatrix::EMPTY);
    }

    // The registry hands out slots 0, 1, 2, ... in registration order.
    #[test]
    fn registry_sequential_slots() {
        let mut reg = CapabilityRegistry::new();
        let s0 = reg.register("alpha").unwrap();
        let s1 = reg.register("beta").unwrap();
        let s2 = reg.register("gamma").unwrap();
        assert_eq!(s0, 0);
        assert_eq!(s1, 1);
        assert_eq!(s2, 2);
    }

    #[test]
    fn registry_idempotent_register() {
        let mut reg = CapabilityRegistry::new();
        let s0 = reg.register("alpha").unwrap();
        let s1 = reg.register("alpha").unwrap();
        assert_eq!(s0, s1);
        assert_eq!(reg.len(), 1);
    }

    #[test]
    fn registry_slot_of_reports_registered_names_only() {
        let mut reg = CapabilityRegistry::new();
        assert!(reg.is_empty());
        reg.register("alpha").unwrap();
        assert_eq!(reg.slot_of("alpha"), Some(0));
        assert_eq!(reg.slot_of("beta"), None);
        assert!(!reg.is_empty());
    }

    #[test]
    fn registry_build_mask_subset() {
        let mut reg = CapabilityRegistry::new();
        reg.register_all(&["trade.equity", "portfolio.read", "audit.read"])
            .unwrap();

        let parent = reg.build_mask(&["trade.equity", "portfolio.read"]).unwrap();
        let child = reg.build_mask(&["trade.equity"]).unwrap();

        assert!(child.is_subset_of(&parent));
        assert!(!parent.is_subset_of(&child));
    }

    #[test]
    fn registry_rejects_unknown_capability() {
        let mut reg = CapabilityRegistry::new();
        reg.register("trade.equity").unwrap();
        let result = reg.build_mask(&["portfolio.write"]);
        assert!(result.is_err());
    }

    #[test]
    fn registry_full_mask_covers_all() {
        let mut reg = CapabilityRegistry::new();
        reg.register_all(&["a", "b", "c"]).unwrap();

        let full = reg.build_full_mask();
        let a = reg.build_mask(&["a"]).unwrap();
        let b = reg.build_mask(&["b"]).unwrap();
        let c = reg.build_mask(&["c"]).unwrap();

        assert!(a.is_subset_of(&full));
        assert!(b.is_subset_of(&full));
        assert!(c.is_subset_of(&full));
    }

    #[test]
    fn registry_full_mask_equals_mask_of_all_names() {
        let mut reg = CapabilityRegistry::new();
        assert!(reg.build_full_mask().is_empty());
        reg.register_all(&["a", "b", "c"]).unwrap();
        let full = reg.build_full_mask();
        assert_eq!(full, reg.build_mask(&["a", "b", "c"]).unwrap());
        assert_eq!(full.capacity_count(), 3);
    }

    #[test]
    fn registry_no_collisions_across_256_caps() {
        let mut reg = CapabilityRegistry::new();
        let caps: Vec<String> = (0..256).map(|i| format!("cap.{}", i)).collect();
        let cap_refs: Vec<&str> = caps.iter().map(String::as_str).collect();
        reg.register_all(&cap_refs).unwrap();

        for cap in &caps {
            let mask = reg.build_mask(&[cap.as_str()]).unwrap();
            assert_eq!(
                mask.capacity_count(),
                1,
                "cap '{}' must occupy exactly one bit",
                cap
            );
        }
    }

    #[test]
    fn registry_over_256_returns_error() {
        let mut reg = CapabilityRegistry::new();
        let caps: Vec<String> = (0..256).map(|i| format!("cap.{}", i)).collect();
        let cap_refs: Vec<&str> = caps.iter().map(String::as_str).collect();
        reg.register_all(&cap_refs).unwrap();

        let result = reg.register("one.too.many");
        assert!(result.is_err());
    }

    #[test]
    fn registry_full_still_returns_existing_slots() {
        let mut reg = CapabilityRegistry::new();
        let caps: Vec<String> = (0..256).map(|i| format!("cap.{}", i)).collect();
        let cap_refs: Vec<&str> = caps.iter().map(String::as_str).collect();
        reg.register_all(&cap_refs).unwrap();

        // Re-registering a known name succeeds even when no slots remain.
        assert_eq!(reg.register("cap.255").unwrap(), 255);
        assert_eq!(reg.len(), 256);
    }

    #[test]
    fn registry_names_in_order_matches_registration() {
        let mut reg = CapabilityRegistry::new();
        let names = ["gamma", "alpha", "beta", "delta"];
        reg.register_all(&names).unwrap();
        let ordered = reg.names_in_order();
        assert_eq!(ordered, names.as_slice());
    }

    #[test]
    fn registry_no_collision_where_hash_would_collide() {
        let mut reg = CapabilityRegistry::new();
        // Wherever `capability_to_bit` would place these names, registry slots are
        // distinct, so single-capability masks never overlap.
        reg.register_all(&["cap.0", "cap.1"]).unwrap();
        let m0 = reg.build_mask(&["cap.0"]).unwrap();
        let m1 = reg.build_mask(&["cap.1"]).unwrap();
        assert_eq!(m0.intersect(&m1), NarrowingMatrix::EMPTY);
    }
}
533}