use super::lattice::Lattice;
use super::symbol::SymbolId;
use bitflags::bitflags;
use std::collections::{HashMap, HashSet};
bitflags! {
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ResourceLifecycle: u8 {
const UNINIT = 0b0001;
const OPEN = 0b0010;
const CLOSED = 0b0100;
const MOVED = 0b1000;
}
}
impl Lattice for ResourceLifecycle {
fn bot() -> Self {
ResourceLifecycle::empty()
}
fn join(&self, other: &Self) -> Self {
*self | *other
}
fn leq(&self, other: &Self) -> bool {
self.intersection(*other) == *self
}
}
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ResourceDomainState {
pub vars: HashMap<SymbolId, ResourceLifecycle>,
}
impl ResourceDomainState {
pub fn new() -> Self {
Self::default()
}
pub fn get(&self, sym: SymbolId) -> ResourceLifecycle {
self.vars
.get(&sym)
.copied()
.unwrap_or(ResourceLifecycle::empty())
}
pub fn set(&mut self, sym: SymbolId, state: ResourceLifecycle) {
self.vars.insert(sym, state);
}
}
impl Lattice for ResourceDomainState {
fn bot() -> Self {
Self::new()
}
fn join(&self, other: &Self) -> Self {
let mut merged = self.clone();
for (&sym, &other_lc) in &other.vars {
let entry = merged.vars.entry(sym).or_insert(ResourceLifecycle::empty());
*entry = entry.join(&other_lc);
}
merged
}
fn leq(&self, other: &Self) -> bool {
for (&sym, &self_lc) in &self.vars {
let other_lc = other.get(sym);
if !self_lc.leq(&other_lc) {
return false;
}
}
true
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum AuthLevel {
Unauthed,
Authed,
Admin,
}
impl Lattice for AuthLevel {
fn bot() -> Self {
AuthLevel::Unauthed
}
fn join(&self, other: &Self) -> Self {
(*self).min(*other)
}
fn leq(&self, other: &Self) -> bool {
*self >= *other
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AuthDomainState {
pub auth_level: AuthLevel,
pub validated: HashSet<SymbolId>,
}
impl Default for AuthDomainState {
fn default() -> Self {
Self {
auth_level: AuthLevel::Unauthed,
validated: HashSet::new(),
}
}
}
impl AuthDomainState {
pub fn new() -> Self {
Self::default()
}
}
impl Lattice for AuthDomainState {
fn bot() -> Self {
Self::new()
}
fn join(&self, other: &Self) -> Self {
Self {
auth_level: self.auth_level.join(&other.auth_level),
validated: self
.validated
.intersection(&other.validated)
.copied()
.collect(),
}
}
fn leq(&self, other: &Self) -> bool {
self.auth_level.leq(&other.auth_level) && self.validated.is_superset(&other.validated)
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ChainProxyState {
pub lifecycle: ResourceLifecycle,
pub class_group: crate::cfg::BodyId,
pub acquire_span: (usize, usize),
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ProductState {
pub resource: ResourceDomainState,
pub auth: AuthDomainState,
pub receiver_class_group: HashMap<SymbolId, crate::cfg::BodyId>,
pub proxy_acquire_spans: HashMap<SymbolId, (usize, usize)>,
pub chain_proxies: HashMap<String, ChainProxyState>,
}
impl ProductState {
pub fn initial() -> Self {
Self {
resource: ResourceDomainState::new(),
auth: AuthDomainState::new(),
receiver_class_group: HashMap::new(),
proxy_acquire_spans: HashMap::new(),
chain_proxies: HashMap::new(),
}
}
}
impl Lattice for ProductState {
fn bot() -> Self {
Self {
resource: ResourceDomainState::bot(),
auth: AuthDomainState::bot(),
receiver_class_group: HashMap::new(),
proxy_acquire_spans: HashMap::new(),
chain_proxies: HashMap::new(),
}
}
fn join(&self, other: &Self) -> Self {
let mut class_group = self.receiver_class_group.clone();
class_group.extend(other.receiver_class_group.iter());
let mut proxy_spans = self.proxy_acquire_spans.clone();
proxy_spans.extend(other.proxy_acquire_spans.iter());
let mut chain = self.chain_proxies.clone();
for (key, other_state) in &other.chain_proxies {
chain
.entry(key.clone())
.and_modify(|e| {
e.lifecycle = e.lifecycle.join(&other_state.lifecycle);
})
.or_insert_with(|| other_state.clone());
}
Self {
resource: self.resource.join(&other.resource),
auth: self.auth.join(&other.auth),
receiver_class_group: class_group,
proxy_acquire_spans: proxy_spans,
chain_proxies: chain,
}
}
fn leq(&self, other: &Self) -> bool {
self.resource.leq(&other.resource) && self.auth.leq(&other.auth)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn resource_lifecycle_join_is_or() {
let a = ResourceLifecycle::OPEN;
let b = ResourceLifecycle::CLOSED;
assert_eq!(
a.join(&b),
ResourceLifecycle::OPEN | ResourceLifecycle::CLOSED
);
}
#[test]
fn resource_lifecycle_bot_identity() {
let a = ResourceLifecycle::OPEN;
assert_eq!(a.join(&ResourceLifecycle::bot()), a);
}
#[test]
fn resource_lifecycle_leq() {
let a = ResourceLifecycle::OPEN;
let b = ResourceLifecycle::OPEN | ResourceLifecycle::CLOSED;
assert!(a.leq(&b));
assert!(!b.leq(&a));
}
#[test]
fn resource_domain_join_merges_keys() {
let mut a = ResourceDomainState::new();
let mut b = ResourceDomainState::new();
let sym_x = SymbolId(0);
let sym_y = SymbolId(1);
a.set(sym_x, ResourceLifecycle::OPEN);
b.set(sym_x, ResourceLifecycle::CLOSED);
b.set(sym_y, ResourceLifecycle::OPEN);
let joined = a.join(&b);
assert_eq!(
joined.get(sym_x),
ResourceLifecycle::OPEN | ResourceLifecycle::CLOSED
);
assert_eq!(joined.get(sym_y), ResourceLifecycle::OPEN);
}
#[test]
fn auth_level_join_is_min() {
assert_eq!(
AuthLevel::Admin.join(&AuthLevel::Unauthed),
AuthLevel::Unauthed
);
assert_eq!(AuthLevel::Authed.join(&AuthLevel::Admin), AuthLevel::Authed);
assert_eq!(
AuthLevel::Authed.join(&AuthLevel::Authed),
AuthLevel::Authed
);
}
#[test]
fn auth_domain_join_intersects_validated() {
let sym_a = SymbolId(0);
let sym_b = SymbolId(1);
let sym_c = SymbolId(2);
let a = AuthDomainState {
auth_level: AuthLevel::Authed,
validated: [sym_a, sym_b].into_iter().collect(),
};
let b = AuthDomainState {
auth_level: AuthLevel::Admin,
validated: [sym_b, sym_c].into_iter().collect(),
};
let joined = a.join(&b);
assert_eq!(joined.auth_level, AuthLevel::Authed);
assert_eq!(joined.validated, [sym_b].into_iter().collect());
}
#[test]
fn product_state_join() {
let a = ProductState::initial();
let b = ProductState::initial();
let joined = a.join(&b);
assert_eq!(joined, ProductState::initial());
}
#[test]
fn may_must_leak_semantics() {
let must_leak = ResourceLifecycle::OPEN;
assert!(must_leak.contains(ResourceLifecycle::OPEN));
assert!(!must_leak.contains(ResourceLifecycle::CLOSED));
assert!(!must_leak.contains(ResourceLifecycle::MOVED));
let may_leak = ResourceLifecycle::OPEN | ResourceLifecycle::CLOSED;
assert!(may_leak.contains(ResourceLifecycle::OPEN));
assert!(may_leak.contains(ResourceLifecycle::CLOSED));
let no_leak = ResourceLifecycle::CLOSED;
assert!(!no_leak.contains(ResourceLifecycle::OPEN));
assert!(no_leak.contains(ResourceLifecycle::CLOSED));
}
#[test]
fn symbol_id_is_copy() {
let s = SymbolId(0);
let s2 = s;
assert_eq!(s, s2);
}
#[test]
fn resource_lifecycle_join_laws() {
let vals = [
ResourceLifecycle::empty(),
ResourceLifecycle::UNINIT,
ResourceLifecycle::OPEN,
ResourceLifecycle::CLOSED,
ResourceLifecycle::MOVED,
ResourceLifecycle::OPEN | ResourceLifecycle::CLOSED,
ResourceLifecycle::all(),
];
for a in &vals {
assert_eq!(a.join(a), *a, "idempotence broken on {a:?}");
assert_eq!(a.join(&ResourceLifecycle::bot()), *a);
for b in &vals {
assert_eq!(a.join(b), b.join(a), "commutativity broken ({a:?},{b:?})");
let consistent = a.leq(b) == (a.join(b) == *b);
assert!(consistent, "leq/join consistency broken ({a:?} ⊑ {b:?})");
for c in &vals {
assert_eq!(
a.join(b).join(c),
a.join(&b.join(c)),
"associativity broken ({a:?},{b:?},{c:?})"
);
}
}
}
}
#[test]
fn auth_level_join_associative_commutative_idempotent() {
let vals = [AuthLevel::Unauthed, AuthLevel::Authed, AuthLevel::Admin];
for a in &vals {
assert_eq!(a.join(a), *a, "AuthLevel idempotence broken on {a:?}");
for b in &vals {
assert_eq!(
a.join(b),
b.join(a),
"AuthLevel commutativity ({a:?},{b:?})"
);
for c in &vals {
assert_eq!(
a.join(b).join(c),
a.join(&b.join(c)),
"AuthLevel associativity ({a:?},{b:?},{c:?})"
);
}
}
}
}
#[test]
fn auth_level_bot_is_absorbing_not_identity() {
assert_eq!(AuthLevel::bot(), AuthLevel::Unauthed);
assert_eq!(
AuthLevel::Admin.join(&AuthLevel::Unauthed),
AuthLevel::Unauthed,
"Unauthed must absorb Admin under min-join (conservative security)"
);
assert_ne!(
AuthLevel::Admin.join(&AuthLevel::bot()),
AuthLevel::Admin,
"if this passes, AuthLevel::bot() was changed — re-audit security implications"
);
}
#[test]
fn auth_level_leq_consistent_with_join() {
let vals = [AuthLevel::Unauthed, AuthLevel::Authed, AuthLevel::Admin];
for a in &vals {
for b in &vals {
assert_eq!(
a.leq(b),
a.join(b) == *b,
"leq/join inconsistent on ({a:?}, {b:?})"
);
}
}
}
#[test]
fn auth_domain_join_drops_partially_validated() {
let sym_only_a = SymbolId(10);
let sym_only_b = SymbolId(11);
let a = AuthDomainState {
auth_level: AuthLevel::Authed,
validated: [sym_only_a].into_iter().collect(),
};
let b = AuthDomainState {
auth_level: AuthLevel::Authed,
validated: [sym_only_b].into_iter().collect(),
};
let j = a.join(&b);
assert!(
j.validated.is_empty(),
"validated set must drop vars not validated on all paths"
);
}
#[test]
fn product_state_join_non_trivial() {
let sym_x = SymbolId(20);
let sym_y = SymbolId(21);
let mut a = ProductState::initial();
a.resource.set(sym_x, ResourceLifecycle::OPEN);
a.auth.auth_level = AuthLevel::Admin;
a.auth.validated.insert(sym_y);
let mut b = ProductState::initial();
b.resource.set(sym_x, ResourceLifecycle::CLOSED);
b.auth.auth_level = AuthLevel::Authed;
b.auth.validated.insert(sym_y);
let j = a.join(&b);
assert_eq!(
j.resource.get(sym_x),
ResourceLifecycle::OPEN | ResourceLifecycle::CLOSED,
"may-leak: OPEN on one path, CLOSED on the other"
);
assert_eq!(j.auth.auth_level, AuthLevel::Authed, "join takes min auth");
assert!(
j.auth.validated.contains(&sym_y),
"var validated on both paths must survive"
);
}
}