Skip to main content

gam_terms/structure/
coefficient_group_resolver.rs

//! Carrier-agnostic coefficient-group resolver.
//!
//! User-declared coefficient groups are realized into penalty components and
//! rho-prior entries in two places: standard term collections (carrier =
//! columns of the realized design matrix, [`crate::smooth`]) and custom
//! families (carrier = `(block, column)` coordinates of parameter blocks,
//! `crate::families::custom_family`). The carrier differs but the *policy* is
//! identical: validate labels, build the parent/child hierarchy, reject cycles,
//! require each child to be a subset of its parent, require an interior node's
//! coefficients to be exactly the union of its children, and expand interior
//! nodes into the concatenation of their recursively resolved child components.
//!
//! This module hosts that shared policy once, generic over the coordinate type
//! `C` (the carrier). Each caller resolves its own selectors into `C` sets and
//! lays out its own penalty matrices and rho coordinates from the resolved
//! components — only the carrier-specific layout stays on the caller side.
17
18use std::collections::{BTreeMap, BTreeSet};
19
/// A coefficient group after its selectors have been resolved by the carrier
/// into a concrete set of carrier coordinates `C`.
///
/// `coordinates` is the full coordinate set the group selects; the resolver
/// validates hierarchy relationships against it and never reinterprets the
/// coordinates themselves.
///
/// Equality is structural (label, parent and coordinate set), so resolved
/// groups can be compared directly by callers and tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResolvedGroup<C: Ord + Clone> {
    /// Group label; hierarchy construction requires it to be unique and not
    /// blank.
    pub label: String,
    /// Label of the enclosing group, or `None` for a top-level group.
    pub parent: Option<String>,
    /// Carrier coordinates selected by this group; hierarchy construction
    /// requires the set to be non-empty.
    pub coordinates: BTreeSet<C>,
}
32
33/// Validated coefficient-group hierarchy over carrier coordinates `C`.
34///
35/// Construction enforces the full group policy (unique non-empty labels,
36/// non-empty coordinate sets, acyclic parent chains terminating at known
37/// groups, child ⊆ parent, interior coordinates == union of children). After
38/// construction, callers walk `groups` in their original order and request the
39/// concatenated penalty components per group.
40pub struct ResolvedGroupHierarchy<C: Ord + Clone> {
41    groups: Vec<ResolvedGroup<C>>,
42    coordinates_by_label: BTreeMap<String, BTreeSet<C>>,
43    children_by_parent: BTreeMap<String, Vec<String>>,
44}
45
46impl<C: Ord + Clone> std::fmt::Debug for ResolvedGroupHierarchy<C> {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        f.debug_struct("ResolvedGroupHierarchy")
49            .field("group_count", &self.groups.len())
50            .field(
51                "labels",
52                &self.coordinates_by_label.keys().collect::<Vec<_>>(),
53            )
54            .finish()
55    }
56}
57
58impl<C: Ord + Clone> ResolvedGroupHierarchy<C> {
59    /// Validate the carrier-resolved groups and build the hierarchy.
60    ///
61    /// `groups` must already have had each selector resolved into `coordinates`
62    /// by the carrier. The order of `groups` is preserved for [`Self::groups`].
63    pub fn build(groups: Vec<ResolvedGroup<C>>) -> Result<Self, String> {
64        let mut seen = BTreeSet::<String>::new();
65        for group in &groups {
66            if group.label.trim().is_empty() {
67                return Err("coefficient group label must not be empty".to_string());
68            }
69            if !seen.insert(group.label.clone()) {
70                return Err(format!(
71                    "duplicate coefficient group label '{}'",
72                    group.label
73                ));
74            }
75            if group.coordinates.is_empty() {
76                return Err(format!(
77                    "coefficient group '{}' contains no coefficients",
78                    group.label
79                ));
80            }
81        }
82
83        let coordinates_by_label: BTreeMap<String, BTreeSet<C>> = groups
84            .iter()
85            .map(|group| (group.label.clone(), group.coordinates.clone()))
86            .collect();
87        let parent_by_label: BTreeMap<String, Option<String>> = groups
88            .iter()
89            .map(|group| (group.label.clone(), group.parent.clone()))
90            .collect();
91        let mut children_by_parent = BTreeMap::<String, Vec<String>>::new();
92        for group in &groups {
93            if let Some(parent) = group.parent.as_ref() {
94                children_by_parent
95                    .entry(parent.clone())
96                    .or_default()
97                    .push(group.label.clone());
98            }
99        }
100
101        for group in &groups {
102            let mut path = BTreeSet::<String>::new();
103            let mut cursor = Some(group.label.as_str());
104            while let Some(label) = cursor {
105                if !path.insert(label.to_string()) {
106                    return Err(format!(
107                        "coefficient group hierarchy contains a cycle involving '{label}'"
108                    ));
109                }
110                cursor = parent_by_label
111                    .get(label)
112                    .ok_or_else(|| {
113                        format!("coefficient group hierarchy references unknown group '{label}'")
114                    })?
115                    .as_deref();
116            }
117            if let Some(parent) = group.parent.as_ref() {
118                let parent_set = coordinates_by_label.get(parent).ok_or_else(|| {
119                    format!(
120                        "coefficient group '{}' references unknown parent group '{parent}'",
121                        group.label
122                    )
123                })?;
124                let child_set = coordinates_by_label
125                    .get(&group.label)
126                    .expect("resolved group coordinates should exist");
127                if !child_set.is_subset(parent_set) {
128                    return Err(format!(
129                        "coefficient group '{}' is not a subset of parent group '{parent}'",
130                        group.label
131                    ));
132                }
133            }
134            if let Some(children) = children_by_parent.get(&group.label) {
135                let mut child_union = BTreeSet::<C>::new();
136                for child in children {
137                    let child_set = coordinates_by_label
138                        .get(child)
139                        .expect("child group coordinates should exist after resolution");
140                    child_union.extend(child_set.iter().cloned());
141                }
142                let parent_set = coordinates_by_label
143                    .get(&group.label)
144                    .expect("parent group coordinates should exist after resolution");
145                if &child_union != parent_set {
146                    return Err(format!(
147                        "coefficient group '{}' has children but its coefficients are not exactly the union of its child groups; nested supergroups concatenate child coefficients",
148                        group.label
149                    ));
150                }
151            }
152        }
153
154        Ok(Self {
155            groups,
156            coordinates_by_label,
157            children_by_parent,
158        })
159    }
160
161    /// The resolved groups, in their original declaration order.
162    pub fn groups(&self) -> &[ResolvedGroup<C>] {
163        &self.groups
164    }
165
166    /// Concatenated penalty components for `label`, recursively expanded.
167    ///
168    /// A leaf group yields a single component (its own coordinate set). An
169    /// interior node yields the concatenation of its children's components,
170    /// expanding recursively when a child is itself interior. This realizes the
171    /// hierarchical-Gamma identity in which an interior node's coefficient
172    /// vector is the concatenation of its child vectors under one precision:
173    /// overlapping children stay separate factors so their log normalizers and
174    /// quadratic contributions both add — it is not a block-sum shortcut.
175    pub fn concatenated_penalty_components(&self, label: &str) -> Vec<BTreeSet<C>> {
176        let Some(children) = self.children_by_parent.get(label) else {
177            return vec![
178                self.coordinates_by_label
179                    .get(label)
180                    .expect("coefficient group coordinates should exist")
181                    .clone(),
182            ];
183        };
184        let mut components = Vec::new();
185        for child in children {
186            components.extend(self.concatenated_penalty_components(child));
187        }
188        components
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ResolvedGroup` from borrowed label/parent strings and any
    /// iterable of coordinates.
    fn group<C: Ord + Clone>(
        label: &str,
        parent: Option<&str>,
        coords: impl IntoIterator<Item = C>,
    ) -> ResolvedGroup<C> {
        ResolvedGroup {
            label: label.to_string(),
            parent: parent.map(str::to_string),
            coordinates: coords.into_iter().collect(),
        }
    }

    /// Equivalent group declarations on the two real carriers — standard-term
    /// columns (`usize`) and custom-family `(block, column)` coordinates — must
    /// produce identical resolved components and hierarchy structure. We encode
    /// the same declaration on both carriers under the bijection
    /// `col <-> (col / BLOCK_WIDTH, col % BLOCK_WIDTH)` and assert the
    /// concatenated penalty components correspond column-for-column.
    #[test]
    fn carriers_produce_matching_concatenated_components() {
        const BLOCK_WIDTH: usize = 4;
        let to_pair = |c: usize| (c / BLOCK_WIDTH, c % BLOCK_WIDTH);

        // Hierarchy: `root` = union of `left` and `right`; `left` is itself an
        // interior node = union of `leaf_a` and `leaf_b`. This exercises
        // recursive child expansion and concatenation (overlap-free here, so
        // the parent is the disjoint union of its leaves).
        let column_groups = vec![
            group::<usize>("leaf_a", Some("left"), [0usize, 1]),
            group::<usize>("leaf_b", Some("left"), [2usize, 3]),
            group::<usize>("left", Some("root"), [0usize, 1, 2, 3]),
            group::<usize>("right", Some("root"), [4usize, 5]),
            group::<usize>("root", None, [0usize, 1, 2, 3, 4, 5]),
        ];
        let pair_groups: Vec<ResolvedGroup<(usize, usize)>> = column_groups
            .iter()
            .map(|g| ResolvedGroup {
                label: g.label.clone(),
                parent: g.parent.clone(),
                coordinates: g.coordinates.iter().copied().map(to_pair).collect(),
            })
            .collect();

        let column_hierarchy =
            ResolvedGroupHierarchy::build(column_groups).expect("column carrier valid");
        let pair_hierarchy =
            ResolvedGroupHierarchy::build(pair_groups).expect("pair carrier valid");

        for label in ["leaf_a", "leaf_b", "left", "right", "root"] {
            let column_components = column_hierarchy.concatenated_penalty_components(label);
            let pair_components = pair_hierarchy.concatenated_penalty_components(label);
            // Map the column components through the carrier bijection and
            // require exact equality with the pair-carrier components.
            let mapped: Vec<BTreeSet<(usize, usize)>> = column_components
                .iter()
                .map(|component| component.iter().copied().map(to_pair).collect())
                .collect();
            assert_eq!(
                mapped, pair_components,
                "carrier components diverged for group '{label}'"
            );
        }

        // `root` is interior: it must expand into one component per leaf —
        // `left` contributes its two leaf components and `right` its own set —
        // rather than a single merged set, proving recursive expansion.
        let root_components = column_hierarchy.concatenated_penalty_components("root");
        assert_eq!(
            root_components,
            vec![
                BTreeSet::from([0usize, 1]),
                BTreeSet::from([2usize, 3]),
                BTreeSet::from([4usize, 5]),
            ],
            "interior node must concatenate recursively expanded child components"
        );
        // A leaf yields exactly one component: its own coordinate set.
        assert_eq!(
            column_hierarchy.concatenated_penalty_components("leaf_a"),
            vec![BTreeSet::from([0usize, 1])]
        );
    }

    /// The shared policy must reject the same malformed declarations on every
    /// carrier with identical messages, so standard terms and custom families
    /// cannot diverge on hierarchy rules.
    #[test]
    fn policy_violations_are_carrier_agnostic() {
        // Child not a subset of parent.
        let err = ResolvedGroupHierarchy::build(vec![
            group::<usize>("child", Some("parent"), [0usize, 9]),
            group::<usize>("parent", None, [0usize, 1]),
        ])
        .unwrap_err();
        assert_eq!(
            err,
            "coefficient group 'child' is not a subset of parent group 'parent'"
        );

        // Interior node whose coefficients are not exactly the child union.
        let err = ResolvedGroupHierarchy::build(vec![
            group::<usize>("child", Some("parent"), [0usize]),
            group::<usize>("parent", None, [0usize, 1]),
        ])
        .unwrap_err();
        assert!(
            err.contains("not exactly the union of its child groups"),
            "unexpected message: {err}"
        );

        // Cycle in the parent chain.
        let err = ResolvedGroupHierarchy::build(vec![
            group::<usize>("a", Some("b"), [0usize]),
            group::<usize>("b", Some("a"), [0usize]),
        ])
        .unwrap_err();
        assert!(
            err.contains("contains a cycle involving"),
            "unexpected message: {err}"
        );

        // Duplicate label.
        let err = ResolvedGroupHierarchy::build(vec![
            group::<usize>("g", None, [0usize]),
            group::<usize>("g", None, [1usize]),
        ])
        .unwrap_err();
        assert_eq!(err, "duplicate coefficient group label 'g'");

        // Empty coordinate set.
        let err = ResolvedGroupHierarchy::build(vec![group::<usize>("g", None, [])]).unwrap_err();
        assert_eq!(err, "coefficient group 'g' contains no coefficients");

        // Unknown parent reference.
        let err =
            ResolvedGroupHierarchy::build(vec![group::<usize>("g", Some("missing"), [0usize])])
                .unwrap_err();
        assert_eq!(
            err,
            "coefficient group hierarchy references unknown group 'missing'"
        );

        // Identical violation, identical message on the `(block, column)`
        // carrier — confirming the policy is genuinely carrier-agnostic.
        let err = ResolvedGroupHierarchy::build(vec![
            group::<(usize, usize)>("child", Some("parent"), [(0usize, 0usize), (1, 1)]),
            group::<(usize, usize)>("parent", None, [(0usize, 0usize)]),
        ])
        .unwrap_err();
        assert_eq!(
            err,
            "coefficient group 'child' is not a subset of parent group 'parent'"
        );
    }

    /// Whitespace-only labels are rejected like empty ones.
    #[test]
    fn blank_labels_are_rejected() {
        let err = ResolvedGroupHierarchy::build(vec![group::<usize>("  ", None, [0usize])])
            .unwrap_err();
        assert_eq!(err, "coefficient group label must not be empty");
    }

    /// A group naming itself as its parent is the shortest possible cycle.
    #[test]
    fn self_parent_is_a_cycle() {
        let err = ResolvedGroupHierarchy::build(vec![group::<usize>("a", Some("a"), [0usize])])
            .unwrap_err();
        assert_eq!(
            err,
            "coefficient group hierarchy contains a cycle involving 'a'"
        );
    }

    /// Sibling groups may overlap as long as their union is the parent; the
    /// parent then yields one component per child, so shared coordinates
    /// appear in more than one component. Declaration order is preserved by
    /// both `groups()` and the component order.
    #[test]
    fn overlapping_siblings_stay_separate_components() {
        let hierarchy = ResolvedGroupHierarchy::build(vec![
            group::<usize>("p", None, [0usize, 1, 2]),
            group::<usize>("c1", Some("p"), [0usize, 1]),
            group::<usize>("c2", Some("p"), [1usize, 2]),
        ])
        .expect("overlapping siblings whose union is the parent are valid");
        assert_eq!(
            hierarchy.concatenated_penalty_components("p"),
            vec![BTreeSet::from([0usize, 1]), BTreeSet::from([1usize, 2])]
        );
        let labels: Vec<&str> = hierarchy
            .groups()
            .iter()
            .map(|g| g.label.as_str())
            .collect();
        assert_eq!(labels, ["p", "c1", "c2"]);
    }
}