Skip to main content

gam_problem/
row_measure.rs

//! Row-subsample mask handle for trust-region invariant enforcement.
//!
//! A `RowSubsampleMask` is the explicit identity of the set of rows + per-row
//! weights used to evaluate any one of {Hessian, gradient, objective}
//! during a single inner trust-region iteration. The trust-region
//! globalization computes
//!
//!   ρ = actual_reduction / predicted_reduction
//!     = [F(β) − F(β + δ)] / [−g·δ − ½·δᵀHδ]
//!
//! and accepts/rejects the step from ρ. All four quantities (F(β),
//! F(β + δ), g, H) MUST be evaluated against the same row measure for
//! ρ to be meaningful; otherwise the numerator and denominator estimate
//! different objectives and ρ can take any sign, producing the observed
//! ρ = -0.05 with predicted_reduction = +7.378e6 sign flip.
//!
//! `RowSubsampleMask::id` is a stable 64-bit identity hash. Full-data
//! handles for the same `n` share an id. Subsample handles built from the
//! same `Arc<OuterScoreSubsample>` allocation share an id; handles built
//! from distinct allocations get differing ids with high probability, even
//! when the mask contents are equal, so a mid-iteration mask rebuild is
//! detected. The TR loop captures one `RowSubsampleMask` at the top of an
//! iteration and hard-asserts that the id observed by each of the four
//! quantities matches before computing ρ.
//!
//! The `BlockwiseFitOptions`-coupled `from_options` constructor stays up in
//! `gam-solve` (it depends on the options type, which lives above this tier);
//! the data type and its pure data methods live here so lower tiers can
//! consume the measure without depending on `gam-solve`.
28
29use std::sync::Arc;
30
31use crate::outer_subsample::OuterScoreSubsample;
32
33/// Identifier-carrying handle for a single row subsample mask.
34///
35/// The handle is `Clone` and cheap to copy; the `Arc` is shared, not
36/// duplicated.
37#[derive(Clone, Debug)]
38pub struct RowSubsampleMask {
39    /// Stable 64-bit content hash. Same `mask` (by Arc pointer OR by
40    /// row content) ⇒ same id; different `mask` ⇒ different id.
41    pub id: u64,
42    /// `None` means full data (`0..n`, weight 1.0 per row).
43    /// `Some(_)` means the rows and HT weights inside the subsample.
44    pub mask: Option<Arc<OuterScoreSubsample>>,
45}
46
47impl RowSubsampleMask {
48    /// Full-data measure: walk `0..n` with weight 1.0 per row.
49    pub fn full_data(n: usize) -> Self {
50        Self {
51            id: hash_full(n),
52            mask: None,
53        }
54    }
55
56    /// Subsample measure: walk the mask's rows with their per-row HT
57    /// weights. Id is derived from the Arc pointer (cheap and stable
58    /// for the lifetime of the Arc) combined with mask metadata.
59    pub fn subsample(mask: Arc<OuterScoreSubsample>) -> Self {
60        let id = hash_subsample(&mask);
61        Self {
62            id,
63            mask: Some(mask),
64        }
65    }
66
67    /// Materialize the row indices and per-row weights this measure
68    /// implies. `full_data(n)` returns `(0..n collected, [1.0; n])`,
69    /// preserving the full-data semantics of any caller that walked
70    /// `0..self.n` unconditionally with weight 1.0.
71    pub fn indices_and_weights(&self, n: usize) -> (Vec<usize>, Vec<f64>) {
72        match self.mask.as_ref() {
73            Some(m) => {
74                assert_eq!(
75                    m.n_full, n,
76                    "RowSubsampleMask n_full ({}) must match caller n ({})",
77                    m.n_full, n
78                );
79                let indices: Vec<usize> = m.mask.as_ref().clone();
80                let mut weights = vec![1.0_f64; n];
81                for r in m.rows.iter() {
82                    if r.index < n {
83                        weights[r.index] = r.weight;
84                    }
85                }
86                (indices, weights)
87            }
88            None => ((0..n).collect(), vec![1.0_f64; n]),
89        }
90    }
91}
92
93/// Thin wrapper over the canonical SplitMix64 hash in
94/// [`gam_linalg::utils::splitmix64_hash`].
95fn splitmix64(x: u64) -> u64 {
96    gam_linalg::utils::splitmix64_hash(x)
97}
98
99const FULL_DATA_ROW_SUBSAMPLE_SENTINEL: u64 = 0xA5A5_5A5A_DEAD_BEEF;
100
101fn hash_full(n: usize) -> u64 {
102    let mut h = splitmix64(FULL_DATA_ROW_SUBSAMPLE_SENTINEL ^ (n as u64));
103    if h == 0 {
104        h = 0x1234_5678_9ABC_DEF0;
105    }
106    h
107}
108
109fn hash_subsample(mask: &Arc<OuterScoreSubsample>) -> u64 {
110    let ptr = Arc::as_ptr(mask) as u64;
111    let mut h = splitmix64(ptr);
112    h ^= splitmix64(mask.n_full as u64);
113    h ^= splitmix64(mask.len() as u64);
114    h ^= splitmix64(mask.seed);
115    h ^= splitmix64((mask.weight_scale.to_bits()) ^ 0xC0FF_EE00_0000_0000);
116    if h == 0 {
117        h = 0xDEAD_BEEF_FEED_FACE;
118    }
119    h
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::outer_subsample::OuterScoreSubsample;

    /// Full-data ids are a function of `n`: equal `n` gives equal ids,
    /// different `n` gives different ids, and no mask is attached.
    #[test]
    fn full_data_id_is_stable_per_n() {
        let first = RowSubsampleMask::full_data(100);
        let repeat = RowSubsampleMask::full_data(100);
        let larger = RowSubsampleMask::full_data(101);

        assert!(first.mask.is_none());
        assert_eq!(first.id, repeat.id);
        assert_ne!(first.id, larger.id);
    }

    /// Two handles wrapping the same `Arc` allocation report the same id.
    #[test]
    fn subsample_id_matches_for_same_arc() {
        let shared = Arc::new(OuterScoreSubsample::from_uniform_inclusion_mask(
            vec![1, 3, 5],
            10,
            42,
        ));
        let first = RowSubsampleMask::subsample(Arc::clone(&shared));
        let second = RowSubsampleMask::subsample(Arc::clone(&shared));

        assert_eq!(first.id, second.id);
    }

    /// Content-equal masks living in separate allocations get distinct ids.
    #[test]
    fn subsample_id_differs_for_different_arcs() {
        let build = || {
            Arc::new(OuterScoreSubsample::from_uniform_inclusion_mask(
                vec![1, 3, 5],
                10,
                42,
            ))
        };
        // Both allocations are alive at the same time, so their addresses
        // are guaranteed to differ.
        let (left, right) = (build(), build());
        let a = RowSubsampleMask::subsample(left);
        let b = RowSubsampleMask::subsample(right);

        // Distinct ids for distinct allocations are intentional: the TR
        // invariant must flag a mid-iteration mask rebuild even when the
        // rebuilt mask happens to be content-equal.
        assert_ne!(a.id, b.id);
    }

    /// The full-data measure enumerates every row with unit weight.
    #[test]
    fn indices_and_weights_full_data() {
        let (rows, weights) = RowSubsampleMask::full_data(4).indices_and_weights(4);

        assert_eq!(rows, vec![0, 1, 2, 3]);
        assert_eq!(weights, vec![1.0, 1.0, 1.0, 1.0]);
    }

    /// A subsample returns only its own row indices, with a full-length
    /// weight vector whose subsampled entries are positive.
    #[test]
    fn indices_and_weights_subsample() {
        let handle = RowSubsampleMask::subsample(Arc::new(
            OuterScoreSubsample::from_uniform_inclusion_mask(vec![0, 2], 4, 7),
        ));
        let (rows, weights) = handle.indices_and_weights(4);

        assert_eq!(rows, vec![0, 2]);
        assert_eq!(weights.len(), 4);
        assert!(weights[0] > 0.0 && weights[2] > 0.0);
    }
}