1use crate::schubert::{IntersectionResult, SchubertCalculus, SchubertClass};
28use crate::EnumerativeResult;
29use std::sync::Arc;
30
31#[cfg(feature = "parallel")]
32use rayon::prelude::*;
33
/// Cheaply clonable identifier of a capability.
///
/// The text lives in an `Arc<str>`, so cloning an id only bumps a reference
/// count instead of copying the string.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct CapabilityId(pub Arc<str>);

impl CapabilityId {
    /// Creates an identifier from anything convertible into a `String`.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        Self(Arc::from(name.into()))
    }

    /// Returns the identifier text.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl std::fmt::Display for CapabilityId {
    /// Writes the identifier text.
    ///
    /// Delegates to `Formatter::pad` so width, fill, alignment and precision
    /// flags (e.g. `{:<12}` or `{:.3}`) are honoured exactly as for `str`.
    /// Plain `{}` output is unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(&self.0)
    }
}

impl From<&str> for CapabilityId {
    fn from(s: &str) -> Self {
        // Build the `Arc<str>` straight from the borrowed slice; going through
        // `CapabilityId::new` would first allocate a throwaway `String`.
        Self(Arc::from(s))
    }
}

impl From<String> for CapabilityId {
    fn from(s: String) -> Self {
        Self::new(s)
    }
}
82
83#[derive(Debug, Clone)]
95pub struct Capability {
96 pub id: CapabilityId,
98 pub name: String,
100 pub schubert_class: SchubertClass,
103 pub requires: Vec<CapabilityId>,
105 pub conflicts: Vec<CapabilityId>,
107}
108
109impl Capability {
110 pub fn new(
119 id: impl Into<String>,
120 name: impl Into<String>,
121 partition: Vec<usize>,
122 grassmannian: (usize, usize),
123 ) -> EnumerativeResult<Self> {
124 let id_str = id.into();
125 Ok(Self {
126 id: CapabilityId::new(id_str),
127 name: name.into(),
128 schubert_class: SchubertClass::new(partition, grassmannian)?,
129 requires: Vec::new(),
130 conflicts: Vec::new(),
131 })
132 }
133
134 #[must_use]
136 pub fn requires(mut self, cap_id: CapabilityId) -> Self {
137 self.requires.push(cap_id);
138 self
139 }
140
141 #[must_use]
143 pub fn requires_all(mut self, cap_ids: impl IntoIterator<Item = CapabilityId>) -> Self {
144 self.requires.extend(cap_ids);
145 self
146 }
147
148 #[must_use]
150 pub fn conflicts_with(mut self, cap_id: CapabilityId) -> Self {
151 self.conflicts.push(cap_id);
152 self
153 }
154
155 #[must_use]
157 pub fn conflicts_with_all(mut self, cap_ids: impl IntoIterator<Item = CapabilityId>) -> Self {
158 self.conflicts.extend(cap_ids);
159 self
160 }
161
162 #[must_use]
170 pub fn codimension(&self) -> usize {
171 self.schubert_class.partition.iter().sum()
172 }
173
174 #[must_use]
176 pub fn has_dependencies(&self) -> bool {
177 !self.requires.is_empty()
178 }
179
180 #[must_use]
182 pub fn has_conflicts(&self) -> bool {
183 !self.conflicts.is_empty()
184 }
185}
186
187#[derive(Debug, Clone)]
199pub struct Namespace {
200 pub grassmannian: (usize, usize),
202 pub position: SchubertClass,
204 pub capabilities: Vec<Capability>,
206 pub name: String,
208}
209
210impl Namespace {
211 #[must_use]
213 pub fn new(name: impl Into<String>, position: SchubertClass) -> Self {
214 Self {
215 grassmannian: position.grassmannian_dim,
216 position,
217 capabilities: Vec::new(),
218 name: name.into(),
219 }
220 }
221
222 pub fn full(name: impl Into<String>, k: usize, n: usize) -> EnumerativeResult<Self> {
232 let position = SchubertClass::new(vec![], (k, n))?;
233 Ok(Self::new(name, position))
234 }
235
236 pub fn grant(&mut self, capability: Capability) -> Result<(), NamespaceError> {
252 for existing in &self.capabilities {
254 if capability.conflicts.contains(&existing.id) {
255 return Err(NamespaceError::Conflict {
256 new: capability.id,
257 existing: existing.id.clone(),
258 });
259 }
260 if existing.conflicts.contains(&capability.id) {
261 return Err(NamespaceError::Conflict {
262 new: capability.id,
263 existing: existing.id.clone(),
264 });
265 }
266 }
267
268 for req in &capability.requires {
270 if !self.capabilities.iter().any(|c| &c.id == req) {
271 return Err(NamespaceError::MissingDependency {
272 capability: capability.id,
273 required: req.clone(),
274 });
275 }
276 }
277
278 self.capabilities.push(capability);
279 Ok(())
280 }
281
282 pub fn grant_all(&mut self, capabilities: Vec<Capability>) -> Result<(), NamespaceError> {
287 let mut remaining = capabilities;
289 let mut progress = true;
290
291 while !remaining.is_empty() && progress {
292 progress = false;
293 let mut still_remaining = Vec::new();
294
295 for cap in remaining {
296 let deps_satisfied = cap.requires.iter().all(|dep| self.has_capability(dep));
297
298 if deps_satisfied {
299 self.grant(cap)?;
300 progress = true;
301 } else {
302 still_remaining.push(cap);
303 }
304 }
305
306 remaining = still_remaining;
307 }
308
309 if !remaining.is_empty() {
310 let first = remaining.into_iter().next().unwrap();
311 return Err(NamespaceError::MissingDependency {
312 capability: first.id,
313 required: first
314 .requires
315 .into_iter()
316 .next()
317 .unwrap_or_else(|| CapabilityId::new("unknown")),
318 });
319 }
320
321 Ok(())
322 }
323
324 pub fn revoke(&mut self, id: &CapabilityId) -> bool {
336 if let Some(pos) = self.capabilities.iter().position(|c| &c.id == id) {
337 let dependents: Vec<CapabilityId> = self
339 .capabilities
340 .iter()
341 .filter(|c| c.requires.contains(id))
342 .map(|c| c.id.clone())
343 .collect();
344
345 if dependents.is_empty() {
346 self.capabilities.remove(pos);
347 return true;
348 }
349 }
350 false
351 }
352
353 #[must_use]
355 pub fn has_capability(&self, id: &CapabilityId) -> bool {
356 self.capabilities.iter().any(|c| &c.id == id)
357 }
358
359 #[must_use]
361 pub fn capability_ids(&self) -> Vec<CapabilityId> {
362 self.capabilities.iter().map(|c| c.id.clone()).collect()
363 }
364
365 #[must_use]
367 pub fn capability_count(&self) -> usize {
368 self.capabilities.len()
369 }
370
371 #[must_use]
385 pub fn count_configurations(&self) -> IntersectionResult {
386 let mut calc = SchubertCalculus::new(self.grassmannian);
387
388 let mut classes = vec![self.position.clone()];
389 for cap in &self.capabilities {
390 classes.push(cap.schubert_class.clone());
391 }
392
393 calc.multi_intersect(&classes)
394 }
395
396 #[must_use]
398 pub fn total_codimension(&self) -> usize {
399 let position_codim: usize = self.position.partition.iter().sum();
400 let cap_codim: usize = self.capabilities.iter().map(|c| c.codimension()).sum();
401 position_codim + cap_codim
402 }
403
404 #[must_use]
406 pub fn would_overdetermine(&self, capability: &Capability) -> bool {
407 let (k, n) = self.grassmannian;
408 let grassmannian_dim = k * (n - k);
409 let new_total = self.total_codimension() + capability.codimension();
410 new_total > grassmannian_dim
411 }
412}
413
414#[derive(Debug, Clone, PartialEq, Eq)]
416pub enum NamespaceError {
417 Conflict {
419 new: CapabilityId,
421 existing: CapabilityId,
423 },
424 MissingDependency {
426 capability: CapabilityId,
428 required: CapabilityId,
430 },
431 InvalidConfiguration,
433}
434
435impl std::fmt::Display for NamespaceError {
436 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
437 match self {
438 NamespaceError::Conflict { new, existing } => {
439 write!(f, "Capability {} conflicts with existing {}", new, existing)
440 }
441 NamespaceError::MissingDependency {
442 capability,
443 required,
444 } => {
445 write!(
446 f,
447 "Capability {} requires {} which is not present",
448 capability, required
449 )
450 }
451 NamespaceError::InvalidConfiguration => {
452 write!(f, "Invalid namespace configuration")
453 }
454 }
455 }
456}
457
458impl std::error::Error for NamespaceError {}
459
460pub fn namespace_intersection(
470 ns1: &Namespace,
471 ns2: &Namespace,
472) -> EnumerativeResult<NamespaceIntersection> {
473 if ns1.grassmannian != ns2.grassmannian {
474 return Ok(NamespaceIntersection::Incompatible);
475 }
476
477 let mut calc = SchubertCalculus::new(ns1.grassmannian);
478 let result = calc.multi_intersect(&[ns1.position.clone(), ns2.position.clone()]);
479
480 Ok(match result {
481 IntersectionResult::Empty => NamespaceIntersection::Disjoint,
482 IntersectionResult::Finite(1) => NamespaceIntersection::SinglePoint,
483 IntersectionResult::Finite(n) => NamespaceIntersection::FinitePoints(n),
484 IntersectionResult::PositiveDimensional { dimension, .. } => {
485 NamespaceIntersection::Subspace { dimension }
486 }
487 })
488}
489
490#[derive(Debug, Clone, PartialEq, Eq, Default)]
492pub enum NamespaceIntersection {
493 #[default]
495 Incompatible,
496 Disjoint,
498 SinglePoint,
500 FinitePoints(u64),
502 Subspace {
504 dimension: usize,
506 },
507}
508
509pub fn capability_accessible(
518 namespace: &Namespace,
519 capability: &Capability,
520) -> EnumerativeResult<bool> {
521 if namespace.grassmannian != capability.schubert_class.grassmannian_dim {
522 return Ok(false);
523 }
524
525 let mut calc = SchubertCalculus::new(namespace.grassmannian);
526 let result = calc.multi_intersect(&[
527 namespace.position.clone(),
528 capability.schubert_class.clone(),
529 ]);
530
531 Ok(!matches!(result, IntersectionResult::Empty))
532}
533
534#[derive(Debug)]
536pub struct NamespaceBuilder {
537 name: String,
538 grassmannian: (usize, usize),
539 position: Vec<usize>,
540 capabilities: Vec<Capability>,
541}
542
543impl NamespaceBuilder {
544 #[must_use]
546 pub fn new(name: impl Into<String>, k: usize, n: usize) -> Self {
547 Self {
548 name: name.into(),
549 grassmannian: (k, n),
550 position: vec![],
551 capabilities: vec![],
552 }
553 }
554
555 #[must_use]
557 pub fn position(mut self, partition: Vec<usize>) -> Self {
558 self.position = partition;
559 self
560 }
561
562 #[must_use]
564 pub fn with_capability(mut self, capability: Capability) -> Self {
565 self.capabilities.push(capability);
566 self
567 }
568
569 #[must_use]
571 pub fn with_capabilities(mut self, capabilities: impl IntoIterator<Item = Capability>) -> Self {
572 self.capabilities.extend(capabilities);
573 self
574 }
575
576 pub fn build(self) -> EnumerativeResult<Namespace> {
578 let position = SchubertClass::new(self.position, self.grassmannian)?;
579 let mut ns = Namespace::new(self.name, position);
580
581 for cap in self.capabilities {
582 ns.grant(cap).map_err(|e| {
583 crate::EnumerativeError::SchubertError(format!("Failed to grant capability: {}", e))
584 })?;
585 }
586
587 Ok(ns)
588 }
589}
590
/// Runs [`Namespace::count_configurations`] for every namespace in parallel
/// (rayon). Results are returned in the same order as `namespaces`.
#[cfg(feature = "parallel")]
pub fn count_configurations_batch(namespaces: &[Namespace]) -> Vec<IntersectionResult> {
    namespaces
        .par_iter()
        .map(Namespace::count_configurations)
        .collect()
}
603
/// Evaluates [`capability_accessible`] for each `(namespace, capability)` pair
/// in parallel (rayon). Results are returned in input order.
///
/// # Errors
///
/// Returns an `Err` if any individual call fails.
#[cfg(feature = "parallel")]
pub fn capability_accessible_batch(
    pairs: &[(&Namespace, &Capability)],
) -> EnumerativeResult<Vec<bool>> {
    pairs
        .par_iter()
        .map(|&(namespace, capability)| capability_accessible(namespace, capability))
        .collect()
}
614
/// Evaluates [`namespace_intersection`] for each namespace pair in parallel
/// (rayon). Results are returned in input order.
///
/// # Errors
///
/// Returns an `Err` if any individual call fails.
#[cfg(feature = "parallel")]
pub fn namespace_intersection_batch(
    pairs: &[(&Namespace, &Namespace)],
) -> EnumerativeResult<Vec<NamespaceIntersection>> {
    pairs
        .par_iter()
        .map(|&(left, right)| namespace_intersection(left, right))
        .collect()
}
625
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_capability_creation() {
        let cap = Capability::new("read", "Read Access", vec![1], (2, 4)).unwrap();
        // Partition [1] has codimension 1.
        assert_eq!(cap.codimension(), 1);
        assert_eq!(cap.id, CapabilityId::new("read"));
    }

    #[test]
    fn test_capability_id_from() {
        let id1: CapabilityId = "read".into();
        let id2: CapabilityId = String::from("write").into();
        assert_eq!(id1.as_str(), "read");
        assert_eq!(id2.as_str(), "write");
    }

    #[test]
    fn test_namespace_full() {
        let ns = Namespace::full("test", 2, 4).unwrap();
        assert_eq!(ns.grassmannian, (2, 4));
        assert!(ns.position.partition.is_empty());
    }

    #[test]
    fn test_namespace_grant() {
        let mut ns = Namespace::full("test", 2, 4).unwrap();
        let cap = Capability::new("read", "Read", vec![1], (2, 4)).unwrap();

        assert!(ns.grant(cap).is_ok());
        assert!(ns.has_capability(&CapabilityId::new("read")));
    }

    #[test]
    fn test_capability_conflict() {
        let mut ns = Namespace::full("test", 2, 4).unwrap();

        // Here the already-granted capability declares the conflict.
        let read = Capability::new("read", "Read", vec![1], (2, 4))
            .unwrap()
            .conflicts_with(CapabilityId::new("write"));
        let write = Capability::new("write", "Write", vec![1], (2, 4)).unwrap();

        ns.grant(read).unwrap();

        let result = ns.grant(write);
        assert!(matches!(result, Err(NamespaceError::Conflict { .. })));
    }

    #[test]
    fn test_capability_conflict_declared_by_new_capability() {
        let mut ns = Namespace::full("test", 2, 4).unwrap();

        // Here the capability being granted declares the conflict.
        let read = Capability::new("read", "Read", vec![1], (2, 4)).unwrap();
        let write = Capability::new("write", "Write", vec![1], (2, 4))
            .unwrap()
            .conflicts_with(CapabilityId::new("read"));

        ns.grant(read).unwrap();

        assert_eq!(
            ns.grant(write),
            Err(NamespaceError::Conflict {
                new: CapabilityId::new("write"),
                existing: CapabilityId::new("read"),
            })
        );
        assert!(!ns.has_capability(&CapabilityId::new("write")));
    }

    #[test]
    fn test_capability_dependency() {
        let mut ns = Namespace::full("test", 2, 4).unwrap();

        let write = Capability::new("write", "Write", vec![1], (2, 4))
            .unwrap()
            .requires(CapabilityId::new("read"));

        // `read` is not granted yet, so `write` must be rejected.
        let result = ns.grant(write.clone());
        assert!(matches!(
            result,
            Err(NamespaceError::MissingDependency { .. })
        ));

        // Once the dependency is present the same grant succeeds.
        let read = Capability::new("read", "Read", vec![1], (2, 4)).unwrap();
        ns.grant(read).unwrap();
        assert!(ns.grant(write).is_ok());
    }

    #[test]
    fn test_namespace_intersection() {
        let ns1 = Namespace::full("ns1", 2, 4).unwrap();
        let ns2 = Namespace::full("ns2", 2, 4).unwrap();

        let result = namespace_intersection(&ns1, &ns2).unwrap();
        // Two full namespaces on Gr(2, 4) meet in all of it: dimension 2 * (4 - 2) = 4.
        assert!(matches!(
            result,
            NamespaceIntersection::Subspace { dimension: 4 }
        ));
    }

    #[test]
    fn test_namespace_incompatible() {
        let ns1 = Namespace::full("ns1", 2, 4).unwrap();
        let ns2 = Namespace::full("ns2", 3, 6).unwrap();

        let result = namespace_intersection(&ns1, &ns2).unwrap();
        assert_eq!(result, NamespaceIntersection::Incompatible);
    }

    #[test]
    fn test_count_configurations() {
        let mut ns = Namespace::full("agent", 2, 4).unwrap();

        // Four codimension-1 classes exactly fill the 4-dimensional Gr(2, 4).
        for i in 0..4 {
            let cap = Capability::new(format!("cap{}", i), format!("Cap {}", i), vec![1], (2, 4))
                .unwrap();
            ns.grant(cap).unwrap();
        }

        let count = ns.count_configurations();
        assert_eq!(count, IntersectionResult::Finite(2));
    }

    #[test]
    fn test_namespace_builder() {
        let read = Capability::new("read", "Read", vec![1], (2, 4)).unwrap();

        let ns = NamespaceBuilder::new("test", 2, 4)
            .position(vec![])
            .with_capability(read)
            .build()
            .unwrap();

        assert!(ns.has_capability(&CapabilityId::new("read")));
    }

    #[test]
    fn test_capability_accessible() {
        let ns = Namespace::full("test", 2, 4).unwrap();
        let cap = Capability::new("read", "Read", vec![1], (2, 4)).unwrap();

        assert!(capability_accessible(&ns, &cap).unwrap());
    }

    #[test]
    fn test_revoke_capability() {
        let mut ns = Namespace::full("test", 2, 4).unwrap();
        let cap = Capability::new("read", "Read", vec![1], (2, 4)).unwrap();

        ns.grant(cap).unwrap();
        assert!(ns.has_capability(&CapabilityId::new("read")));

        assert!(ns.revoke(&CapabilityId::new("read")));
        assert!(!ns.has_capability(&CapabilityId::new("read")));
    }

    #[test]
    fn test_revoke_refuses_while_dependents_exist() {
        let mut ns = Namespace::full("test", 2, 4).unwrap();
        let read = Capability::new("read", "Read", vec![1], (2, 4)).unwrap();
        let write = Capability::new("write", "Write", vec![1], (2, 4))
            .unwrap()
            .requires(CapabilityId::new("read"));

        ns.grant(read).unwrap();
        ns.grant(write).unwrap();

        // `write` still depends on `read`, so `read` must stay.
        assert!(!ns.revoke(&CapabilityId::new("read")));
        assert!(ns.has_capability(&CapabilityId::new("read")));

        // Revoking the dependent first unblocks the dependency.
        assert!(ns.revoke(&CapabilityId::new("write")));
        assert!(ns.revoke(&CapabilityId::new("read")));
        assert_eq!(ns.capability_count(), 0);
    }

    #[test]
    fn test_revoke_unknown_capability() {
        let mut ns = Namespace::full("test", 2, 4).unwrap();
        assert!(!ns.revoke(&CapabilityId::new("missing")));
    }

    #[test]
    fn test_would_overdetermine() {
        let mut ns = Namespace::full("test", 2, 4).unwrap();

        // Fill the full dimension 4 of Gr(2, 4) with codimension-1 capabilities.
        for i in 0..4 {
            let cap = Capability::new(format!("cap{}", i), format!("Cap {}", i), vec![1], (2, 4))
                .unwrap();
            ns.grant(cap).unwrap();
        }

        let extra = Capability::new("extra", "Extra", vec![1], (2, 4)).unwrap();
        // A fifth codimension-1 class would exceed dimension 4.
        assert!(ns.would_overdetermine(&extra));
    }

    #[test]
    fn test_grant_all() {
        let mut ns = Namespace::full("test", 2, 4).unwrap();

        let read = Capability::new("read", "Read", vec![1], (2, 4)).unwrap();
        let write = Capability::new("write", "Write", vec![1], (2, 4))
            .unwrap()
            .requires(CapabilityId::new("read"));

        // Dependent listed before its dependency: grant_all must reorder.
        ns.grant_all(vec![write, read]).unwrap();

        assert!(ns.has_capability(&CapabilityId::new("read")));
        assert!(ns.has_capability(&CapabilityId::new("write")));
    }

    #[test]
    fn test_grant_all_unsatisfiable_dependency() {
        let mut ns = Namespace::full("test", 2, 4).unwrap();
        let orphan = Capability::new("orphan", "Orphan", vec![1], (2, 4))
            .unwrap()
            .requires(CapabilityId::new("missing"));

        assert_eq!(
            ns.grant_all(vec![orphan]),
            Err(NamespaceError::MissingDependency {
                capability: CapabilityId::new("orphan"),
                required: CapabilityId::new("missing"),
            })
        );
        assert!(!ns.has_capability(&CapabilityId::new("orphan")));
    }

    #[test]
    fn test_namespace_intersection_default() {
        let default = NamespaceIntersection::default();
        assert_eq!(default, NamespaceIntersection::Incompatible);
    }
}
807
// Tests for the rayon-backed batch helpers; compiled only in test builds with
// the `parallel` feature enabled.
#[cfg(all(test, feature = "parallel"))]
mod parallel_tests {
    use super::*;

    #[test]
    fn test_count_configurations_batch() {
        // Three namespaces on Gr(2, 4) with 4, 2 and 0 codimension-1 capabilities.
        let mut ns1 = Namespace::full("ns1", 2, 4).unwrap();
        let mut ns2 = Namespace::full("ns2", 2, 4).unwrap();
        let ns3 = Namespace::full("ns3", 2, 4).unwrap();

        for i in 0..4 {
            ns1.grant(
                Capability::new(format!("c{}", i), format!("Cap {}", i), vec![1], (2, 4)).unwrap(),
            )
            .unwrap();
        }

        for i in 0..2 {
            ns2.grant(
                Capability::new(format!("d{}", i), format!("Cap {}", i), vec![1], (2, 4)).unwrap(),
            )
            .unwrap();
        }

        let namespaces = vec![ns1, ns2, ns3];
        let results = count_configurations_batch(&namespaces);

        // Results must come back in input order.
        assert_eq!(results.len(), 3);
        // Total codimension 4 = dim Gr(2, 4): a finite count.
        assert_eq!(results[0], IntersectionResult::Finite(2));
        // Total codimension 2: a 4 - 2 = 2 dimensional solution space.
        assert!(matches!(
            results[1],
            IntersectionResult::PositiveDimensional { dimension: 2, .. }
        ));
        // No capabilities: the whole 4-dimensional Grassmannian.
        assert!(matches!(
            results[2],
            IntersectionResult::PositiveDimensional { dimension: 4, .. }
        ));
    }

    #[test]
    fn test_capability_accessible_batch() {
        let ns1 = Namespace::full("ns1", 2, 4).unwrap();
        let ns2 = Namespace::full("ns2", 3, 6).unwrap();

        let cap_24 = Capability::new("read", "Read", vec![1], (2, 4)).unwrap();
        let cap_36 = Capability::new("write", "Write", vec![1], (3, 6)).unwrap();

        let pairs: Vec<(&Namespace, &Capability)> = vec![
            (&ns1, &cap_24), // same Grassmannian
            (&ns1, &cap_36), // mismatched Grassmannian
            (&ns2, &cap_36), // same Grassmannian
            (&ns2, &cap_24), // mismatched Grassmannian
        ];

        let results = capability_accessible_batch(&pairs).unwrap();

        assert_eq!(results.len(), 4);
        // Only pairs on a shared Grassmannian are accessible.
        assert!(results[0]);
        assert!(!results[1]);
        assert!(results[2]);
        assert!(!results[3]);
    }

    #[test]
    fn test_namespace_intersection_batch() {
        let ns1 = Namespace::full("ns1", 2, 4).unwrap();
        let ns2 = Namespace::full("ns2", 2, 4).unwrap();
        let ns3 = Namespace::full("ns3", 3, 6).unwrap();

        let pairs: Vec<(&Namespace, &Namespace)> = vec![
            (&ns1, &ns2), // both on Gr(2, 4)
            (&ns1, &ns3), // Gr(2, 4) vs Gr(3, 6)
            (&ns2, &ns3), // Gr(2, 4) vs Gr(3, 6)
        ];

        let results = namespace_intersection_batch(&pairs).unwrap();

        assert_eq!(results.len(), 3);
        // Two full namespaces on Gr(2, 4) meet in the whole 4-dimensional space.
        assert!(matches!(
            results[0],
            NamespaceIntersection::Subspace { dimension: 4 }
        ));
        assert_eq!(results[1], NamespaceIntersection::Incompatible);
        assert_eq!(results[2], NamespaceIntersection::Incompatible);
    }
}