1use std::collections::{BTreeMap, BTreeSet};
19
/// One coefficient group: a labelled set of coordinates, optionally nested under a
/// parent group.
///
/// The struct itself places no bounds on `C`. The ordering and cloning requirements
/// are imposed only where they are needed, by the `impl` blocks of
/// [`ResolvedGroupHierarchy`].
#[derive(Debug, Clone)]
pub struct ResolvedGroup<C> {
    /// Identifier of the group. `ResolvedGroupHierarchy::build` requires it to be
    /// non-blank and unique among the groups it is given.
    pub label: String,
    /// Label of the enclosing group, or `None` for a top-level group.
    pub parent: Option<String>,
    /// Coordinates (coefficients) that belong to this group.
    pub coordinates: BTreeSet<C>,
}
32
33pub struct ResolvedGroupHierarchy<C: Ord + Clone> {
41 groups: Vec<ResolvedGroup<C>>,
42 coordinates_by_label: BTreeMap<String, BTreeSet<C>>,
43 children_by_parent: BTreeMap<String, Vec<String>>,
44}
45
46impl<C: Ord + Clone> std::fmt::Debug for ResolvedGroupHierarchy<C> {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 f.debug_struct("ResolvedGroupHierarchy")
49 .field("group_count", &self.groups.len())
50 .field(
51 "labels",
52 &self.coordinates_by_label.keys().collect::<Vec<_>>(),
53 )
54 .finish()
55 }
56}
57
58impl<C: Ord + Clone> ResolvedGroupHierarchy<C> {
59 pub fn build(groups: Vec<ResolvedGroup<C>>) -> Result<Self, String> {
64 let mut seen = BTreeSet::<String>::new();
65 for group in &groups {
66 if group.label.trim().is_empty() {
67 return Err("coefficient group label must not be empty".to_string());
68 }
69 if !seen.insert(group.label.clone()) {
70 return Err(format!(
71 "duplicate coefficient group label '{}'",
72 group.label
73 ));
74 }
75 if group.coordinates.is_empty() {
76 return Err(format!(
77 "coefficient group '{}' contains no coefficients",
78 group.label
79 ));
80 }
81 }
82
83 let coordinates_by_label: BTreeMap<String, BTreeSet<C>> = groups
84 .iter()
85 .map(|group| (group.label.clone(), group.coordinates.clone()))
86 .collect();
87 let parent_by_label: BTreeMap<String, Option<String>> = groups
88 .iter()
89 .map(|group| (group.label.clone(), group.parent.clone()))
90 .collect();
91 let mut children_by_parent = BTreeMap::<String, Vec<String>>::new();
92 for group in &groups {
93 if let Some(parent) = group.parent.as_ref() {
94 children_by_parent
95 .entry(parent.clone())
96 .or_default()
97 .push(group.label.clone());
98 }
99 }
100
101 for group in &groups {
102 let mut path = BTreeSet::<String>::new();
103 let mut cursor = Some(group.label.as_str());
104 while let Some(label) = cursor {
105 if !path.insert(label.to_string()) {
106 return Err(format!(
107 "coefficient group hierarchy contains a cycle involving '{label}'"
108 ));
109 }
110 cursor = parent_by_label
111 .get(label)
112 .ok_or_else(|| {
113 format!("coefficient group hierarchy references unknown group '{label}'")
114 })?
115 .as_deref();
116 }
117 if let Some(parent) = group.parent.as_ref() {
118 let parent_set = coordinates_by_label.get(parent).ok_or_else(|| {
119 format!(
120 "coefficient group '{}' references unknown parent group '{parent}'",
121 group.label
122 )
123 })?;
124 let child_set = coordinates_by_label
125 .get(&group.label)
126 .expect("resolved group coordinates should exist");
127 if !child_set.is_subset(parent_set) {
128 return Err(format!(
129 "coefficient group '{}' is not a subset of parent group '{parent}'",
130 group.label
131 ));
132 }
133 }
134 if let Some(children) = children_by_parent.get(&group.label) {
135 let mut child_union = BTreeSet::<C>::new();
136 for child in children {
137 let child_set = coordinates_by_label
138 .get(child)
139 .expect("child group coordinates should exist after resolution");
140 child_union.extend(child_set.iter().cloned());
141 }
142 let parent_set = coordinates_by_label
143 .get(&group.label)
144 .expect("parent group coordinates should exist after resolution");
145 if &child_union != parent_set {
146 return Err(format!(
147 "coefficient group '{}' has children but its coefficients are not exactly the union of its child groups; nested supergroups concatenate child coefficients",
148 group.label
149 ));
150 }
151 }
152 }
153
154 Ok(Self {
155 groups,
156 coordinates_by_label,
157 children_by_parent,
158 })
159 }
160
161 pub fn groups(&self) -> &[ResolvedGroup<C>] {
163 &self.groups
164 }
165
166 pub fn concatenated_penalty_components(&self, label: &str) -> Vec<BTreeSet<C>> {
176 let Some(children) = self.children_by_parent.get(label) else {
177 return vec![
178 self.coordinates_by_label
179 .get(label)
180 .expect("coefficient group coordinates should exist")
181 .clone(),
182 ];
183 };
184 let mut components = Vec::new();
185 for child in children {
186 components.extend(self.concatenated_penalty_components(child));
187 }
188 components
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor for a [`ResolvedGroup`] in test fixtures.
    fn group<C: Ord + Clone>(
        label: &str,
        parent: Option<&str>,
        coords: impl IntoIterator<Item = C>,
    ) -> ResolvedGroup<C> {
        ResolvedGroup {
            label: label.to_string(),
            parent: parent.map(str::to_string),
            coordinates: coords.into_iter().collect(),
        }
    }

    /// The same tree, expressed once over column indices and once over
    /// (block, offset) pairs, must expand into components that map onto each other
    /// one-to-one.
    #[test]
    fn carriers_produce_matching_concatenated_components() {
        const BLOCK_WIDTH: usize = 4;
        let to_pair = |c: usize| (c / BLOCK_WIDTH, c % BLOCK_WIDTH);

        let column_groups = vec![
            group::<usize>("leaf_a", Some("left"), [0usize, 1]),
            group::<usize>("leaf_b", Some("left"), [2usize, 3]),
            group::<usize>("left", Some("root"), [0usize, 1, 2, 3]),
            group::<usize>("right", Some("root"), [4usize, 5]),
            group::<usize>("root", None, [0usize, 1, 2, 3, 4, 5]),
        ];
        let pair_groups: Vec<ResolvedGroup<(usize, usize)>> = column_groups
            .iter()
            .map(|g| ResolvedGroup {
                label: g.label.clone(),
                parent: g.parent.clone(),
                coordinates: g.coordinates.iter().copied().map(to_pair).collect(),
            })
            .collect();

        let column_hierarchy =
            ResolvedGroupHierarchy::build(column_groups).expect("column carrier valid");
        let pair_hierarchy =
            ResolvedGroupHierarchy::build(pair_groups).expect("pair carrier valid");

        for label in ["leaf_a", "leaf_b", "left", "right", "root"] {
            let column_components = column_hierarchy.concatenated_penalty_components(label);
            let pair_components = pair_hierarchy.concatenated_penalty_components(label);
            let mapped: Vec<BTreeSet<(usize, usize)>> = column_components
                .iter()
                .map(|component| component.iter().copied().map(to_pair).collect())
                .collect();
            assert_eq!(
                mapped, pair_components,
                "carrier components diverged for group '{label}'"
            );
        }

        let root_components = column_hierarchy.concatenated_penalty_components("root");
        assert_eq!(
            root_components,
            vec![
                BTreeSet::from([0usize, 1]),
                BTreeSet::from([2usize, 3]),
                BTreeSet::from([4usize, 5]),
            ],
            "interior node must concatenate recursively expanded child components"
        );
        assert_eq!(
            column_hierarchy.concatenated_penalty_components("leaf_a"),
            vec![BTreeSet::from([0usize, 1])]
        );
    }

    /// Each rejection in `build` reports the same message whatever the coordinate type.
    #[test]
    fn policy_violations_are_carrier_agnostic() {
        let err = ResolvedGroupHierarchy::build(vec![
            group::<usize>("child", Some("parent"), [0usize, 9]),
            group::<usize>("parent", None, [0usize, 1]),
        ])
        .unwrap_err();
        assert_eq!(
            err,
            "coefficient group 'child' is not a subset of parent group 'parent'"
        );

        let err = ResolvedGroupHierarchy::build(vec![
            group::<usize>("child", Some("parent"), [0usize]),
            group::<usize>("parent", None, [0usize, 1]),
        ])
        .unwrap_err();
        assert!(
            err.contains("not exactly the union of its child groups"),
            "unexpected message: {err}"
        );

        let err = ResolvedGroupHierarchy::build(vec![
            group::<usize>("a", Some("b"), [0usize]),
            group::<usize>("b", Some("a"), [0usize]),
        ])
        .unwrap_err();
        assert!(
            err.contains("contains a cycle involving"),
            "unexpected message: {err}"
        );

        let err = ResolvedGroupHierarchy::build(vec![
            group::<usize>("g", None, [0usize]),
            group::<usize>("g", None, [1usize]),
        ])
        .unwrap_err();
        assert_eq!(err, "duplicate coefficient group label 'g'");

        let err = ResolvedGroupHierarchy::build(vec![group::<usize>("g", None, [])]).unwrap_err();
        assert_eq!(err, "coefficient group 'g' contains no coefficients");

        let err =
            ResolvedGroupHierarchy::build(vec![group::<usize>("g", Some("missing"), [0usize])])
                .unwrap_err();
        assert_eq!(
            err,
            "coefficient group hierarchy references unknown group 'missing'"
        );

        let err = ResolvedGroupHierarchy::build(vec![
            group::<(usize, usize)>("child", Some("parent"), [(0usize, 0usize), (1, 1)]),
            group::<(usize, usize)>("parent", None, [(0usize, 0usize)]),
        ])
        .unwrap_err();
        assert_eq!(
            err,
            "coefficient group 'child' is not a subset of parent group 'parent'"
        );
    }

    /// Empty and whitespace-only labels are rejected, and so is a group that names
    /// itself as its parent.
    #[test]
    fn blank_labels_and_self_parenting_are_rejected() {
        for label in ["", "   "] {
            let err = ResolvedGroupHierarchy::build(vec![group::<usize>(label, None, [0usize])])
                .unwrap_err();
            assert_eq!(err, "coefficient group label must not be empty");
        }

        let err = ResolvedGroupHierarchy::build(vec![group::<usize>("g", Some("g"), [0usize])])
            .unwrap_err();
        assert_eq!(
            err,
            "coefficient group hierarchy contains a cycle involving 'g'"
        );
    }

    /// Independent top-level groups form a valid forest. They keep their input order
    /// in `groups()`, and each expands to its own coordinates.
    #[test]
    fn flat_forest_preserves_input_order() {
        let hierarchy = ResolvedGroupHierarchy::build(vec![
            group::<usize>("b", None, [2usize, 3]),
            group::<usize>("a", None, [0usize, 1]),
        ])
        .expect("flat forest valid");
        let labels: Vec<&str> = hierarchy
            .groups()
            .iter()
            .map(|g| g.label.as_str())
            .collect();
        assert_eq!(labels, ["b", "a"]);
        assert_eq!(
            hierarchy.concatenated_penalty_components("b"),
            vec![BTreeSet::from([2usize, 3])]
        );
    }
}