gam_problem/row_measure.rs
1//! Row-subsample mask handle for trust-region invariant enforcement.
2//!
3//! A `RowSubsampleMask` is the explicit identity of the set of rows + per-row
4//! weights used to evaluate any one of {Hessian, gradient, objective}
5//! during a single inner trust-region iteration. The trust-region
6//! globalization computes
7//!
8//! ρ = actual_reduction / predicted_reduction
9//! = [F(β) − F(β + δ)] / [−g·δ − ½·δᵀHδ]
10//!
//! and accepts or rejects the step based on ρ. All four quantities (F(β),
//! F(β + δ), g, H) MUST be evaluated against the same row measure for
//! ρ to be meaningful; otherwise the numerator and denominator estimate
//! different objectives and ρ can take any sign. That mismatch is what
//! produced the observed sign flip of ρ = −0.05 alongside
//! predicted_reduction = +7.378e6.
16//!
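//! `RowSubsampleMask::id` is a 64-bit *identity* hash, not a content hash.
//! For the full-data measure it depends only on `n`. For a subsample it is
//! derived from the `Arc<OuterScoreSubsample>` allocation address plus mask
//! metadata (`n_full`, row count, seed, weight scale). As a result, handles
//! cloned from the same `Arc` share an id, while two separately allocated
//! masks get different ids with high probability — even when their contents
//! are identical. This is deliberate: a mid-iteration mask rebuild is caught
//! even if it reproduces the same rows. The TR loop captures one
//! `RowSubsampleMask` at the top of an iteration and hard-asserts that the
//! id observed by each of the four quantities matches before computing ρ.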
23//!
24//! The `BlockwiseFitOptions`-coupled `from_options` constructor stays up in
25//! `gam-solve` (it depends on the options type, which lives above this tier);
26//! the data type and its pure data methods live here so lower tiers can
27//! consume the measure without depending on `gam-solve`.
28
29use std::sync::Arc;
30
31use crate::outer_subsample::OuterScoreSubsample;
32
33/// Identifier-carrying handle for a single row subsample mask.
34///
35/// The handle is `Clone` and cheap to copy; the `Arc` is shared, not
36/// duplicated.
37#[derive(Clone, Debug)]
38pub struct RowSubsampleMask {
39 /// Stable 64-bit content hash. Same `mask` (by Arc pointer OR by
40 /// row content) ⇒ same id; different `mask` ⇒ different id.
41 pub id: u64,
42 /// `None` means full data (`0..n`, weight 1.0 per row).
43 /// `Some(_)` means the rows and HT weights inside the subsample.
44 pub mask: Option<Arc<OuterScoreSubsample>>,
45}
46
47impl RowSubsampleMask {
48 /// Full-data measure: walk `0..n` with weight 1.0 per row.
49 pub fn full_data(n: usize) -> Self {
50 Self {
51 id: hash_full(n),
52 mask: None,
53 }
54 }
55
56 /// Subsample measure: walk the mask's rows with their per-row HT
57 /// weights. Id is derived from the Arc pointer (cheap and stable
58 /// for the lifetime of the Arc) combined with mask metadata.
59 pub fn subsample(mask: Arc<OuterScoreSubsample>) -> Self {
60 let id = hash_subsample(&mask);
61 Self {
62 id,
63 mask: Some(mask),
64 }
65 }
66
67 /// Materialize the row indices and per-row weights this measure
68 /// implies. `full_data(n)` returns `(0..n collected, [1.0; n])`,
69 /// preserving the full-data semantics of any caller that walked
70 /// `0..self.n` unconditionally with weight 1.0.
71 pub fn indices_and_weights(&self, n: usize) -> (Vec<usize>, Vec<f64>) {
72 match self.mask.as_ref() {
73 Some(m) => {
74 assert_eq!(
75 m.n_full, n,
76 "RowSubsampleMask n_full ({}) must match caller n ({})",
77 m.n_full, n
78 );
79 let indices: Vec<usize> = m.mask.as_ref().clone();
80 let mut weights = vec![1.0_f64; n];
81 for r in m.rows.iter() {
82 if r.index < n {
83 weights[r.index] = r.weight;
84 }
85 }
86 (indices, weights)
87 }
88 None => ((0..n).collect(), vec![1.0_f64; n]),
89 }
90 }
91}
92
93/// Thin wrapper over the canonical SplitMix64 hash in
94/// [`gam_linalg::utils::splitmix64_hash`].
95fn splitmix64(x: u64) -> u64 {
96 gam_linalg::utils::splitmix64_hash(x)
97}
98
99const FULL_DATA_ROW_SUBSAMPLE_SENTINEL: u64 = 0xA5A5_5A5A_DEAD_BEEF;
100
101fn hash_full(n: usize) -> u64 {
102 let mut h = splitmix64(FULL_DATA_ROW_SUBSAMPLE_SENTINEL ^ (n as u64));
103 if h == 0 {
104 h = 0x1234_5678_9ABC_DEF0;
105 }
106 h
107}
108
109fn hash_subsample(mask: &Arc<OuterScoreSubsample>) -> u64 {
110 let ptr = Arc::as_ptr(mask) as u64;
111 let mut h = splitmix64(ptr);
112 h ^= splitmix64(mask.n_full as u64);
113 h ^= splitmix64(mask.len() as u64);
114 h ^= splitmix64(mask.seed);
115 h ^= splitmix64((mask.weight_scale.to_bits()) ^ 0xC0FF_EE00_0000_0000);
116 if h == 0 {
117 h = 0xDEAD_BEEF_FEED_FACE;
118 }
119 h
120}
121
122#[cfg(test)]
123mod tests {
124 use super::*;
125 use crate::outer_subsample::OuterScoreSubsample;
126
127 #[test]
128 fn full_data_id_is_stable_per_n() {
129 let a = RowSubsampleMask::full_data(100);
130 let b = RowSubsampleMask::full_data(100);
131 let c = RowSubsampleMask::full_data(101);
132 assert_eq!(a.id, b.id);
133 assert_ne!(a.id, c.id);
134 assert!(a.mask.is_none());
135 }
136
137 #[test]
138 fn subsample_id_matches_for_same_arc() {
139 let s = Arc::new(OuterScoreSubsample::from_uniform_inclusion_mask(
140 vec![1, 3, 5],
141 10,
142 42,
143 ));
144 let a = RowSubsampleMask::subsample(Arc::clone(&s));
145 let b = RowSubsampleMask::subsample(Arc::clone(&s));
146 assert_eq!(a.id, b.id);
147 }
148
149 #[test]
150 fn subsample_id_differs_for_different_arcs() {
151 let s1 = Arc::new(OuterScoreSubsample::from_uniform_inclusion_mask(
152 vec![1, 3, 5],
153 10,
154 42,
155 ));
156 let s2 = Arc::new(OuterScoreSubsample::from_uniform_inclusion_mask(
157 vec![1, 3, 5],
158 10,
159 42,
160 ));
161 let a = RowSubsampleMask::subsample(s1);
162 let b = RowSubsampleMask::subsample(s2);
163 // Different Arc allocations ⇒ different ids; this is intentional
164 // so the TR invariant catches mid-iteration mask rebuilds even
165 // when the resulting mask happens to be content-equal.
166 assert_ne!(a.id, b.id);
167 }
168
169 #[test]
170 fn indices_and_weights_full_data() {
171 let rm = RowSubsampleMask::full_data(4);
172 let (idx, w) = rm.indices_and_weights(4);
173 assert_eq!(idx, vec![0, 1, 2, 3]);
174 assert_eq!(w, vec![1.0, 1.0, 1.0, 1.0]);
175 }
176
177 #[test]
178 fn indices_and_weights_subsample() {
179 let s = Arc::new(OuterScoreSubsample::from_uniform_inclusion_mask(
180 vec![0, 2],
181 4,
182 7,
183 ));
184 let rm = RowSubsampleMask::subsample(s);
185 let (idx, w) = rm.indices_and_weights(4);
186 assert_eq!(idx, vec![0, 2]);
187 assert_eq!(w.len(), 4);
188 assert!(w[0] > 0.0 && w[2] > 0.0);
189 }
190}