Skip to main content

gam_problem/
outer_subsample.rs

1//! Outer-loop row subsampling primitive shared across the solver and the
2//! family-specific outer-score evaluators.
3//!
4//! [`OuterScoreSubsample`] and its per-row [`WeightedOuterRow`] are the
5//! Horvitz–Thompson row subsample consumed on outer-loop hot paths. They live
6//! in the solver layer (below `families`) so that both the solver's row-measure
7//! machinery and the family outer-score builders can depend on them downward,
8//! without `solver` reaching up into `families`. The stratified *builders* that
9//! construct these (`build_outer_score_subsample`, `auto_outer_score_subsample`)
10//! remain in `families::marginal_slope_shared`, since they depend on
11//! family-specific fit options; they import this type downward.
12
13use std::sync::Arc;
14
15/// Stratified row index subsample shared across outer-loop evaluations.
16///
17/// `mask` is sorted, deduplicated, and never empty in practice (enforced by
18/// `build_outer_score_subsample`).
19///
20/// Per-row inverse-inclusion weights `w_i = N_h / k_h` (where `h` is the row's
21/// stratum) are stored alongside the mask in `rows`. The Horvitz–Thompson
22/// estimator for any linear-in-row functional T = Σ_i f_i is
23///   T̂ = Σ_{i ∈ mask} w_i · f_i,
24/// which is unbiased even when per-stratum sampling fractions differ
25/// (the `ceil(k * N_h / n).max(1)` rule in the stratified builder makes
26/// rare strata oversample relative to the bulk, so a single global rescale
27/// `n_full / |mask|` is biased in those strata).
28///
29/// `weight_scale` is retained as a *diagnostic* (mean of `w_i` across the
30/// mask). It equals `n_full / |mask|` when all rows share a uniform inclusion
31/// probability (the caller-supplied-mask case represented by
32/// [`OuterScoreSubsample::from_uniform_inclusion_mask`]); it can drift from
33/// that value under the stratified builder's rare-stratum boost. It is not the
34/// per-row scaling factor — consumers must read `rows[i].weight` for HT
35/// correctness.
36///
37/// # Horvitz–Thompson contract
38///
39/// Per-row weight `rows[i].weight = 1 / π_i`, where `π_i` is the
40/// inclusion probability of row `i` under the stratified sampler. Any
41/// outer-only score/gradient routine that consumes this subsample must
42/// form `Σ_{i ∈ mask} w_i · f_i` so the resulting estimator is unbiased:
43///
44/// ```text
45///   E[ score_subsample ]  =  score_full.
46/// ```
47///
48/// The following families consume this subsample on their outer-loop hot
49/// paths: Gaussian-LS, Binomial-LS, the Wiggle variants, CTN, and
50/// Survival-LS. Each routes the `rows[i].weight` factor through its
51/// per-row accumulator (gradient, Hessian-action, trace probes).
52///
53/// # Convergence warning
54///
55/// Subsampled gradients are noisy by construction. The outer driver must
56/// **never** declare convergence on a subsampled gradient — near
57/// convergence it switches back to the full-data score so that the KKT
58/// stopping test sees the unbiased, low-variance signal. New consumers
59/// adding subsampled paths must preserve this invariant.
60#[derive(Debug, Clone)]
61pub struct OuterScoreSubsample {
62    pub mask: Arc<Vec<usize>>,
63    pub rows: Arc<Vec<WeightedOuterRow>>,
64    pub n_full: usize,
65    pub weight_scale: f64,
66    pub seed: u64,
67}
68
69impl OuterScoreSubsample {
70    /// Wrap a precomputed mask sampled with a uniform inclusion probability,
71    /// assigning each selected row the inverse-inclusion weight `n_full / m`.
72    /// The caller is responsible for sortedness and uniqueness;
73    /// `build_outer_score_subsample` remains the stratified per-row HT builder.
74    pub fn from_uniform_inclusion_mask(mask: Vec<usize>, n_full: usize, seed: u64) -> Self {
75        let m = mask.len();
76        let w = if m == 0 {
77            1.0
78        } else {
79            n_full as f64 / m as f64
80        };
81        Self::with_uniform_weight(mask, n_full, seed, w)
82    }
83
84    /// Wrap a precomputed mask with an explicit uniform per-row weight.
85    /// Useful for tests that need the unrescaled (`weight = 1.0`) sum over a
86    /// custom mask, and for callers that already know the desired
87    /// rescaling factor and don't want the constructor to derive it from
88    /// `n_full / |mask|`.
89    pub fn with_uniform_weight(mask: Vec<usize>, n_full: usize, seed: u64, weight: f64) -> Self {
90        let rows: Vec<WeightedOuterRow> = mask
91            .iter()
92            .map(|&index| WeightedOuterRow {
93                index,
94                weight,
95                stratum: 0,
96            })
97            .collect();
98        let weight_scale = if rows.is_empty() { 1.0 } else { weight };
99        Self {
100            mask: Arc::new(mask),
101            rows: Arc::new(rows),
102            n_full,
103            weight_scale,
104            seed,
105        }
106    }
107
108    /// Wrap a vector of `(index, weight, stratum)` triples. The mask is
109    /// derived as the sorted/dedup'd index list. Used by the stratified
110    /// builder to install per-row HT weights.
111    pub fn from_weighted_rows(mut rows: Vec<WeightedOuterRow>, n_full: usize, seed: u64) -> Self {
112        rows.sort_by_key(|r| r.index);
113        rows.dedup_by_key(|r| r.index);
114        let mask: Vec<usize> = rows.iter().map(|r| r.index).collect();
115        let weight_scale = if rows.is_empty() {
116            1.0
117        } else {
118            rows.iter().map(|r| r.weight).sum::<f64>() / rows.len() as f64
119        };
120        Self {
121            mask: Arc::new(mask),
122            rows: Arc::new(rows),
123            n_full,
124            weight_scale,
125            seed,
126        }
127    }
128
129    #[inline]
130    pub fn len(&self) -> usize {
131        self.mask.len()
132    }
133
134    #[inline]
135    pub fn is_empty(&self) -> bool {
136        self.mask.is_empty()
137    }
138
139    /// True when at least two retained rows have different per-row weights.
140    /// Consumers that previously applied a single post-sum scalar must
141    /// switch to per-row weighting whenever this returns true.
142    pub fn has_variable_weights(&self) -> bool {
143        let mut iter = self.rows.iter();
144        let Some(first) = iter.next() else {
145            return false;
146        };
147        iter.any(|r| (r.weight - first.weight).abs() > 0.0)
148    }
149}
150
/// A single retained row of an [`OuterScoreSubsample`], paired with the
/// Horvitz–Thompson weight its contribution must be multiplied by.
#[derive(Debug, Clone, Copy)]
pub struct WeightedOuterRow {
    /// Position of the row in the full data set.
    pub index: usize,
    /// Inverse-inclusion weight `1 / π_i` applied to this row's contribution.
    pub weight: f64,
    /// Stratum the row was drawn from. This is diagnostic only; every
    /// aggregation must use `weight`.
    pub stratum: u32,
}
159
/// Fixed tile size, in rows, used by the deterministic parallel reductions.
///
/// Every cross-row sum splits its rows into tiles of this many rows. It then
/// combines the per-tile partial results on the calling thread, in tile-index
/// order. Because of that, the floating-point reduction tree stays the same no
/// matter how many Rayon workers run or how work is stolen. Any consumer that
/// relies on this deterministic associativity must use a tiling that is a
/// multiple of this constant.
pub const ARROW_ROW_CHUNK: usize = 256;
168
169/// Number of `ARROW_ROW_CHUNK`-sized tiles covering `n_rows`.
170#[inline]
171pub fn arrow_row_chunk_count(n_rows: usize) -> usize {
172    if n_rows == 0 {
173        0
174    } else {
175        (n_rows - 1) / ARROW_ROW_CHUNK + 1
176    }
177}
178
179/// Row selection for an outer-loop evaluation: either the full data (`All`) or
180/// a Horvitz–Thompson [`WeightedOuterRow`] subsample.
181///
182/// `All` walks rows `0..n_total` with unit weight; `Subsample` walks the stored
183/// rows applying each row's inverse-inclusion scale `1/π_i`, so any partial sum
184/// `Σ_i w_i · f(row_i)` is an unbiased estimator of the corresponding full-data
185/// sum `Σ_{i=1..n_full} f(row_i)`. Inner-PIRLS and final-covariance passes
186/// always run with `All`; only outer score / gradient hot loops consume a
187/// non-`All` variant.
188///
189/// Lives in this lower layer (below `families`/`terms`) so the row-kernel
190/// consumers and the term hot-paths can name it without the `Subsample` field
191/// reaching up into `solver` (#1135). The family-specific constructor
192/// (`families::row_kernel::RowSet::from_options`, which reads
193/// `custom_family::BlockwiseFitOptions`) stays in `families` as an extension
194/// `impl` block.
195#[derive(Clone)]
196pub enum RowSet {
197    All,
198    Subsample {
199        rows: Arc<Vec<WeightedOuterRow>>,
200        n_full: usize,
201    },
202}
203
204impl RowSet {
205    /// Parallel fold-reduce over the row set. `init` produces a fresh
206    /// accumulator, `fold` is the per-row update, `reduce` combines two
207    /// accumulators.
208    ///
209    /// Returns the reduced result. Both branches process fixed-size row chunks
210    /// in parallel, then combine the chunk accumulators in chunk-index order on
211    /// the caller thread. The resulting floating-point reduction tree is fixed
212    /// across Rayon worker counts and work-stealing decisions.
213    #[inline]
214    pub fn par_reduce_fold<T, I, F, R>(&self, n_total: usize, init: I, fold: F, reduce: R) -> T
215    where
216        T: Send,
217        I: Fn() -> T + Send + Sync,
218        F: Fn(T, usize, f64) -> T + Send + Sync,
219        R: Fn(T, T) -> T + Send + Sync,
220    {
221        use rayon::iter::{IntoParallelIterator, ParallelIterator};
222        use rayon::slice::ParallelSlice;
223        match self {
224            Self::All => {
225                let chunk_accumulators: Vec<T> = (0..arrow_row_chunk_count(n_total))
226                    .into_par_iter()
227                    .map(|chunk_idx| {
228                        let start = chunk_idx * ARROW_ROW_CHUNK;
229                        let end = (start + ARROW_ROW_CHUNK).min(n_total);
230                        let mut acc = init();
231                        for i in start..end {
232                            acc = fold(acc, i, 1.0);
233                        }
234                        acc
235                    })
236                    .collect();
237                let mut total = init();
238                for acc in chunk_accumulators {
239                    total = reduce(total, acc);
240                }
241                total
242            }
243            Self::Subsample { rows, .. } => {
244                let chunk_accumulators: Vec<T> = rows
245                    .par_chunks(ARROW_ROW_CHUNK)
246                    .map(|chunk| {
247                        let mut acc = init();
248                        for r in chunk {
249                            acc = fold(acc, r.index, r.weight);
250                        }
251                        acc
252                    })
253                    .collect();
254                let mut total = init();
255                for acc in chunk_accumulators {
256                    total = reduce(total, acc);
257                }
258                total
259            }
260        }
261    }
262
263    /// Parallel try-fold over fixed-size row chunks, followed by deterministic
264    /// chunk-index-order reduction on the caller thread.
265    #[inline]
266    pub fn par_try_reduce_fold<T, E, I, F, R>(
267        &self,
268        n_total: usize,
269        init: I,
270        fold: F,
271        reduce: R,
272    ) -> Result<T, E>
273    where
274        T: Send,
275        E: Send,
276        I: Fn() -> T + Send + Sync,
277        F: Fn(T, usize, f64) -> Result<T, E> + Send + Sync,
278        R: Fn(T, T) -> Result<T, E> + Send + Sync,
279    {
280        use rayon::iter::{IntoParallelIterator, ParallelIterator};
281        use rayon::slice::ParallelSlice;
282        match self {
283            Self::All => {
284                let chunk_accumulators: Vec<Result<T, E>> = (0..arrow_row_chunk_count(n_total))
285                    .into_par_iter()
286                    .map(|chunk_idx| {
287                        let start = chunk_idx * ARROW_ROW_CHUNK;
288                        let end = (start + ARROW_ROW_CHUNK).min(n_total);
289                        let mut acc = init();
290                        for i in start..end {
291                            acc = fold(acc, i, 1.0)?;
292                        }
293                        Ok(acc)
294                    })
295                    .collect();
296                let mut total = init();
297                for acc in chunk_accumulators {
298                    total = reduce(total, acc?)?;
299                }
300                Ok(total)
301            }
302            Self::Subsample { rows, .. } => {
303                let chunk_accumulators: Vec<Result<T, E>> = rows
304                    .par_chunks(ARROW_ROW_CHUNK)
305                    .map(|chunk| {
306                        let mut acc = init();
307                        for r in chunk {
308                            acc = fold(acc, r.index, r.weight)?;
309                        }
310                        Ok(acc)
311                    })
312                    .collect();
313                let mut total = init();
314                for acc in chunk_accumulators {
315                    total = reduce(total, acc?)?;
316                }
317                Ok(total)
318            }
319        }
320    }
321}