use std::collections::{BTreeMap, BTreeSet};
/// One coefficient group after resolution: a label, an optional parent label, and the
/// set of coordinates the group covers.
///
/// The coordinate type `C` is generic so the same hierarchy rules can be applied to
/// different coordinate carriers (e.g. plain column indices or `(block, offset)` pairs).
/// No trait bounds are placed on the struct itself; code that needs ordering or cloning
/// (such as `ResolvedGroupHierarchy`) states those bounds on its own impls.
#[derive(Debug, Clone)]
pub struct ResolvedGroup<C> {
    /// Identifier of the group; `ResolvedGroupHierarchy::build` requires it to be
    /// non-blank and unique across the groups it is given.
    pub label: String,
    /// Label of the enclosing group, or `None` for a root group.
    pub parent: Option<String>,
    /// Coordinates covered by this group (sorted and deduplicated by `BTreeSet`).
    pub coordinates: BTreeSet<C>,
}
/// A validated hierarchy of [`ResolvedGroup`]s plus label-keyed lookup tables.
///
/// Fields are private; the two maps are derived from `groups` when the hierarchy is
/// built by `ResolvedGroupHierarchy::build`.
pub struct ResolvedGroupHierarchy<C>
where
    C: Ord + Clone,
{
    /// Every group, in the order it was handed to `build`.
    groups: Vec<ResolvedGroup<C>>,
    /// Coordinate set of each group, keyed by the group's label.
    coordinates_by_label: BTreeMap<String, BTreeSet<C>>,
    /// For each group that has children: its label mapped to the child labels, in
    /// the order the children appeared in the input.
    children_by_parent: BTreeMap<String, Vec<String>>,
}
impl<C> std::fmt::Debug for ResolvedGroupHierarchy<C>
where
    C: Ord + Clone,
{
    // Summarises the hierarchy as its group count plus the group labels (in the
    // BTreeMap's sorted key order). Coordinates are not printed, which is why `C`
    // does not need to implement `Debug` here.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let labels: Vec<&String> = self.coordinates_by_label.keys().collect();
        let mut summary = f.debug_struct("ResolvedGroupHierarchy");
        summary.field("group_count", &self.groups.len());
        summary.field("labels", &labels);
        summary.finish()
    }
}
impl<C: Ord + Clone> ResolvedGroupHierarchy<C> {
    /// Validates resolved coefficient groups and indexes them as a hierarchy.
    ///
    /// Validation runs in two phases:
    ///
    /// 1. Each group on its own: the label must not be blank after trimming, labels
    ///    must be unique, and the coordinate set must be non-empty.
    /// 2. Then, one group at a time in input order: following parent links must reach
    ///    a root without revisiting a label or naming an unknown group, the group's
    ///    coordinates must be a subset of its parent's, and a group that has children
    ///    must cover exactly the union of its children's coordinates.
    ///
    /// The groups are stored in the order given, and child lists keep that order.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message describing the first violation found.
    pub fn build(groups: Vec<ResolvedGroup<C>>) -> Result<Self, String> {
        Self::validate_group_entries(&groups)?;
        let coordinates_by_label: BTreeMap<String, BTreeSet<C>> = groups
            .iter()
            .map(|group| (group.label.clone(), group.coordinates.clone()))
            .collect();
        let mut children_by_parent = BTreeMap::<String, Vec<String>>::new();
        for group in &groups {
            if let Some(parent) = group.parent.as_ref() {
                children_by_parent
                    .entry(parent.clone())
                    .or_default()
                    .push(group.label.clone());
            }
        }
        Self::validate_structure(&groups, &coordinates_by_label, &children_by_parent)?;
        Ok(Self {
            groups,
            coordinates_by_label,
            children_by_parent,
        })
    }

    /// Rejects blank labels, duplicate labels, and groups without coordinates.
    fn validate_group_entries(groups: &[ResolvedGroup<C>]) -> Result<(), String> {
        let mut seen = BTreeSet::<&str>::new();
        for group in groups {
            if group.label.trim().is_empty() {
                return Err("coefficient group label must not be empty".to_string());
            }
            if !seen.insert(group.label.as_str()) {
                return Err(format!(
                    "duplicate coefficient group label '{}'",
                    group.label
                ));
            }
            if group.coordinates.is_empty() {
                return Err(format!(
                    "coefficient group '{}' contains no coefficients",
                    group.label
                ));
            }
        }
        Ok(())
    }

    /// Runs the hierarchy checks group by group, in input order.
    fn validate_structure(
        groups: &[ResolvedGroup<C>],
        coordinates_by_label: &BTreeMap<String, BTreeSet<C>>,
        children_by_parent: &BTreeMap<String, Vec<String>>,
    ) -> Result<(), String> {
        // Borrowed view of the parent links; no per-label String clones needed.
        let parent_by_label: BTreeMap<&str, Option<&str>> = groups
            .iter()
            .map(|group| (group.label.as_str(), group.parent.as_deref()))
            .collect();
        // All checks for one group run before the next group is examined, so the
        // error reported is the first one encountered in input order.
        for group in groups {
            Self::validate_ancestry(&group.label, &parent_by_label)?;
            Self::validate_parent_containment(group, coordinates_by_label)?;
            Self::validate_child_union(group, children_by_parent, coordinates_by_label)?;
        }
        Ok(())
    }

    /// Follows parent links from `label` up to a root, rejecting cycles and links to
    /// labels that are not in the hierarchy.
    fn validate_ancestry<'a>(
        label: &'a str,
        parent_by_label: &BTreeMap<&'a str, Option<&'a str>>,
    ) -> Result<(), String> {
        let mut visited = BTreeSet::<&str>::new();
        let mut cursor = Some(label);
        while let Some(current) = cursor {
            if !visited.insert(current) {
                return Err(format!(
                    "coefficient group hierarchy contains a cycle involving '{current}'"
                ));
            }
            cursor = *parent_by_label.get(current).ok_or_else(|| {
                format!("coefficient group hierarchy references unknown group '{current}'")
            })?;
        }
        Ok(())
    }

    /// Checks that a group with a parent covers only coordinates its parent covers.
    fn validate_parent_containment(
        group: &ResolvedGroup<C>,
        coordinates_by_label: &BTreeMap<String, BTreeSet<C>>,
    ) -> Result<(), String> {
        let Some(parent) = group.parent.as_ref() else {
            return Ok(());
        };
        // `validate_ancestry` has already walked through this parent, so the lookup
        // should succeed; it stays fallible as a defensive error rather than a panic.
        let parent_set = coordinates_by_label.get(parent).ok_or_else(|| {
            format!(
                "coefficient group '{}' references unknown parent group '{parent}'",
                group.label
            )
        })?;
        if !group.coordinates.is_subset(parent_set) {
            return Err(format!(
                "coefficient group '{}' is not a subset of parent group '{parent}'",
                group.label
            ));
        }
        Ok(())
    }

    /// Checks that a group with children covers exactly the union of its children.
    fn validate_child_union(
        group: &ResolvedGroup<C>,
        children_by_parent: &BTreeMap<String, Vec<String>>,
        coordinates_by_label: &BTreeMap<String, BTreeSet<C>>,
    ) -> Result<(), String> {
        let Some(children) = children_by_parent.get(&group.label) else {
            return Ok(());
        };
        let child_union: BTreeSet<C> = children
            .iter()
            .flat_map(|child| {
                coordinates_by_label
                    .get(child)
                    .expect("child group coordinates should exist after resolution")
                    .iter()
                    .cloned()
            })
            .collect();
        if child_union != group.coordinates {
            return Err(format!(
                "coefficient group '{}' has children but its coefficients are not exactly the union of its child groups; nested supergroups concatenate child coefficients",
                group.label
            ));
        }
        Ok(())
    }

    /// Returns the validated groups in the order they were passed to [`Self::build`].
    pub fn groups(&self) -> &[ResolvedGroup<C>] {
        &self.groups
    }

    /// Expands `label` into the coordinate sets penalised for that group.
    ///
    /// A group without children yields one component: its own coordinates. A group
    /// with children yields the concatenation of its children's components, expanded
    /// recursively in child input order, so an interior group contributes one
    /// component per descendant leaf.
    ///
    /// # Panics
    ///
    /// Panics if `label` does not name a group in this hierarchy. Use
    /// [`Self::try_concatenated_penalty_components`] to handle that case instead.
    pub fn concatenated_penalty_components(&self, label: &str) -> Vec<BTreeSet<C>> {
        self.try_concatenated_penalty_components(label)
            .expect("coefficient group coordinates should exist")
    }

    /// Non-panicking form of [`Self::concatenated_penalty_components`]: returns
    /// `None` when `label` does not name a group in this hierarchy.
    pub fn try_concatenated_penalty_components(&self, label: &str) -> Option<Vec<BTreeSet<C>>> {
        let Some(children) = self.children_by_parent.get(label) else {
            return self
                .coordinates_by_label
                .get(label)
                .map(|coordinates| vec![coordinates.clone()]);
        };
        let mut components = Vec::new();
        for child in children {
            components.extend(self.try_concatenated_penalty_components(child)?);
        }
        Some(components)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Builds a `ResolvedGroup` from a label, an optional parent label, and coordinates.
    fn group<C: Ord + Clone>(
        label: &str,
        parent: Option<&str>,
        coords: impl IntoIterator<Item = C>,
    ) -> ResolvedGroup<C> {
        ResolvedGroup {
            label: label.to_string(),
            parent: parent.map(str::to_string),
            coordinates: coords.into_iter().collect(),
        }
    }
    #[test]
    fn carriers_produce_matching_concatenated_components() {
        const BLOCK_WIDTH: usize = 4;
        let to_pair = |c: usize| (c / BLOCK_WIDTH, c % BLOCK_WIDTH);
        let column_groups = vec![
            group::<usize>("leaf_a", Some("left"), [0usize, 1]),
            group::<usize>("leaf_b", Some("left"), [2usize, 3]),
            group::<usize>("left", Some("root"), [0usize, 1, 2, 3]),
            group::<usize>("right", Some("root"), [4usize, 5]),
            group::<usize>("root", None, [0usize, 1, 2, 3, 4, 5]),
        ];
        let pair_groups: Vec<ResolvedGroup<(usize, usize)>> = column_groups
            .iter()
            .map(|g| ResolvedGroup {
                label: g.label.clone(),
                parent: g.parent.clone(),
                coordinates: g.coordinates.iter().copied().map(to_pair).collect(),
            })
            .collect();
        let column_hierarchy =
            ResolvedGroupHierarchy::build(column_groups).expect("column carrier valid");
        let pair_hierarchy =
            ResolvedGroupHierarchy::build(pair_groups).expect("pair carrier valid");
        for label in ["leaf_a", "leaf_b", "left", "right", "root"] {
            let column_components = column_hierarchy.concatenated_penalty_components(label);
            let pair_components = pair_hierarchy.concatenated_penalty_components(label);
            let mapped: Vec<BTreeSet<(usize, usize)>> = column_components
                .iter()
                .map(|component| component.iter().copied().map(to_pair).collect())
                .collect();
            assert_eq!(
                mapped, pair_components,
                "carrier components diverged for group '{label}'"
            );
        }
        let root_components = column_hierarchy.concatenated_penalty_components("root");
        assert_eq!(
            root_components,
            vec![
                BTreeSet::from([0usize, 1]),
                BTreeSet::from([2usize, 3]),
                BTreeSet::from([4usize, 5]),
            ],
            "interior node must concatenate recursively expanded child components"
        );
        assert_eq!(
            column_hierarchy.concatenated_penalty_components("leaf_a"),
            vec![BTreeSet::from([0usize, 1])]
        );
    }
    #[test]
    fn policy_violations_are_carrier_agnostic() {
        let err = ResolvedGroupHierarchy::build(vec![
            group::<usize>("child", Some("parent"), [0usize, 9]),
            group::<usize>("parent", None, [0usize, 1]),
        ])
        .unwrap_err();
        assert_eq!(
            err,
            "coefficient group 'child' is not a subset of parent group 'parent'"
        );
        let err = ResolvedGroupHierarchy::build(vec![
            group::<usize>("child", Some("parent"), [0usize]),
            group::<usize>("parent", None, [0usize, 1]),
        ])
        .unwrap_err();
        assert!(
            err.contains("not exactly the union of its child groups"),
            "unexpected message: {err}"
        );
        let err = ResolvedGroupHierarchy::build(vec![
            group::<usize>("a", Some("b"), [0usize]),
            group::<usize>("b", Some("a"), [0usize]),
        ])
        .unwrap_err();
        assert!(
            err.contains("contains a cycle involving"),
            "unexpected message: {err}"
        );
        let err = ResolvedGroupHierarchy::build(vec![
            group::<usize>("g", None, [0usize]),
            group::<usize>("g", None, [1usize]),
        ])
        .unwrap_err();
        assert_eq!(err, "duplicate coefficient group label 'g'");
        let err = ResolvedGroupHierarchy::build(vec![group::<usize>("g", None, [])]).unwrap_err();
        assert_eq!(err, "coefficient group 'g' contains no coefficients");
        let err =
            ResolvedGroupHierarchy::build(vec![group::<usize>("g", Some("missing"), [0usize])])
                .unwrap_err();
        assert_eq!(
            err,
            "coefficient group hierarchy references unknown group 'missing'"
        );
        let err = ResolvedGroupHierarchy::build(vec![
            group::<(usize, usize)>("child", Some("parent"), [(0usize, 0usize), (1, 1)]),
            group::<(usize, usize)>("parent", None, [(0usize, 0usize)]),
        ])
        .unwrap_err();
        assert_eq!(
            err,
            "coefficient group 'child' is not a subset of parent group 'parent'"
        );
    }
    // Edge cases of validation not covered above: whitespace-only labels, a group
    // naming itself as parent, and an unknown group further up the ancestor chain.
    #[test]
    fn blank_labels_self_parents_and_dangling_ancestors_are_rejected() {
        let err = ResolvedGroupHierarchy::build(vec![group::<usize>("   ", None, [0usize])])
            .unwrap_err();
        assert_eq!(err, "coefficient group label must not be empty");
        let err = ResolvedGroupHierarchy::build(vec![group::<usize>("a", Some("a"), [0usize])])
            .unwrap_err();
        assert_eq!(
            err,
            "coefficient group hierarchy contains a cycle involving 'a'"
        );
        let err = ResolvedGroupHierarchy::build(vec![
            group::<usize>("leaf", Some("mid"), [0usize]),
            group::<usize>("mid", Some("ghost"), [0usize]),
        ])
        .unwrap_err();
        assert_eq!(
            err,
            "coefficient group hierarchy references unknown group 'ghost'"
        );
    }
    // `groups()` and child expansion both follow input order, not label order.
    #[test]
    fn groups_and_children_keep_input_order() {
        let hierarchy = ResolvedGroupHierarchy::build(vec![
            group::<usize>("root", None, [0usize, 1]),
            group::<usize>("b", Some("root"), [1usize]),
            group::<usize>("a", Some("root"), [0usize]),
        ])
        .expect("valid hierarchy");
        let labels: Vec<&str> = hierarchy
            .groups()
            .iter()
            .map(|g| g.label.as_str())
            .collect();
        assert_eq!(labels, ["root", "b", "a"]);
        assert_eq!(
            hierarchy.concatenated_penalty_components("root"),
            vec![BTreeSet::from([1usize]), BTreeSet::from([0usize])],
            "children must be expanded in input order"
        );
    }
    #[test]
    #[should_panic(expected = "coefficient group coordinates should exist")]
    fn concatenating_an_unknown_label_panics() {
        let hierarchy = ResolvedGroupHierarchy::build(vec![group::<usize>("g", None, [0usize])])
            .expect("valid hierarchy");
        let _ = hierarchy.concatenated_penalty_components("missing");
    }
}