gam_problem/outer_subsample.rs
1//! Outer-loop row subsampling primitive shared across the solver and the
2//! family-specific outer-score evaluators.
3//!
4//! [`OuterScoreSubsample`] and its per-row [`WeightedOuterRow`] are the
5//! Horvitz–Thompson row subsample consumed on outer-loop hot paths. They live
6//! in the solver layer (below `families`) so that both the solver's row-measure
7//! machinery and the family outer-score builders can depend on them downward,
8//! without `solver` reaching up into `families`. The stratified *builders* that
9//! construct these (`build_outer_score_subsample`, `auto_outer_score_subsample`)
10//! remain in `families::marginal_slope_shared`, since they depend on
11//! family-specific fit options; they import this type downward.
12
13use std::sync::Arc;
14
/// Stratified row-index subsample reused across outer-loop evaluations.
///
/// `mask` holds the retained row indices. It is sorted, free of duplicates,
/// and in practice never empty (`build_outer_score_subsample` guarantees this).
///
/// Alongside the mask, `rows` stores each retained row's inverse-inclusion
/// weight `w_i = N_h / k_h`, where `h` is the stratum the row was drawn from.
/// For any functional that is linear in the rows, `T = Σ_i f_i`, the
/// Horvitz–Thompson estimator `T̂ = Σ_{i ∈ mask} w_i · f_i` stays unbiased even
/// when different strata are sampled at different rates. The stratified
/// builder's `ceil(k * N_h / n).max(1)` rule oversamples rare strata relative
/// to the bulk. That is exactly why a single global rescale `n_full / |mask|`
/// would be biased in those strata.
///
/// `weight_scale` is only a *diagnostic*: the mean of `w_i` over the mask.
/// Under one shared inclusion probability it coincides with `n_full / |mask|`.
/// That is the caller-supplied-mask path,
/// [`OuterScoreSubsample::from_uniform_inclusion_mask`]. The stratified
/// builder's rare-stratum boost can move it away from that value. It is never
/// the per-row factor: Horvitz–Thompson correctness requires reading
/// `rows[i].weight`.
///
/// # Horvitz–Thompson contract
///
/// `rows[i].weight = 1 / π_i`, where `π_i` is row `i`'s inclusion probability
/// under the stratified sampler. Every outer-only score or gradient routine
/// fed by this subsample must accumulate `Σ_{i ∈ mask} w_i · f_i`. The
/// estimator is then unbiased:
///
/// ```text
/// E[ score_subsample ] = score_full.
/// ```
///
/// Consumers on outer-loop hot paths are Gaussian-LS, Binomial-LS, the Wiggle
/// variants, CTN, and Survival-LS. Each threads `rows[i].weight` through its
/// per-row accumulator (gradient, Hessian-action, trace probes).
///
/// # Convergence warning
///
/// A subsampled gradient is noisy by design. The outer driver must **never**
/// accept convergence on one. Close to convergence it falls back to the
/// full-data score, so the KKT stopping test sees the unbiased, low-variance
/// signal. Any new subsampled path has to keep this invariant intact.
#[derive(Clone, Debug)]
pub struct OuterScoreSubsample {
    /// Retained row indices: sorted and deduplicated.
    pub mask: Arc<Vec<usize>>,
    /// One weighted record per entry of `mask`, in the same order.
    pub rows: Arc<Vec<WeightedOuterRow>>,
    /// Row count of the full data set the subsample stands in for.
    pub n_full: usize,
    /// Diagnostic mean of `rows[i].weight` (`1.0` when no rows are retained).
    pub weight_scale: f64,
    /// Seed stored with the subsample exactly as passed to the constructor.
    pub seed: u64,
}

impl OuterScoreSubsample {
    /// Wrap a precomputed mask whose rows all share one inclusion probability.
    ///
    /// Each retained row gets the inverse-inclusion weight `n_full / |mask|`.
    /// An empty mask falls back to weight `1.0`. Sortedness and uniqueness of
    /// `mask` are the caller's job; the stratified per-row Horvitz–Thompson
    /// builder is `build_outer_score_subsample`.
    pub fn from_uniform_inclusion_mask(mask: Vec<usize>, n_full: usize, seed: u64) -> Self {
        let uniform_weight = match mask.len() {
            0 => 1.0,
            retained => n_full as f64 / retained as f64,
        };
        Self::with_uniform_weight(mask, n_full, seed, uniform_weight)
    }

    /// Wrap a precomputed mask, giving every retained row the same explicit
    /// `weight`.
    ///
    /// This suits tests that want the plain (`weight = 1.0`) sum over a custom
    /// mask. It also suits callers that already know the rescaling factor and
    /// do not want it derived from `n_full / |mask|`.
    ///
    /// Every row is tagged with stratum `0`. `weight_scale` is `weight`, or
    /// `1.0` for an empty mask.
    pub fn with_uniform_weight(mask: Vec<usize>, n_full: usize, seed: u64, weight: f64) -> Self {
        let weight_scale = if mask.is_empty() { 1.0 } else { weight };
        let rows = mask
            .iter()
            .map(|&index| WeightedOuterRow {
                index,
                weight,
                stratum: 0,
            })
            .collect::<Vec<_>>();
        Self {
            mask: Arc::new(mask),
            rows: Arc::new(rows),
            n_full,
            weight_scale,
            seed,
        }
    }

    /// Build a subsample from explicit `(index, weight, stratum)` records.
    ///
    /// The stratified builder uses this to install per-row Horvitz–Thompson
    /// weights. Records are stably sorted by `index` and then deduplicated on
    /// it, so when an index repeats only its first-supplied record survives.
    /// The mask is the resulting index list. `weight_scale` is the mean
    /// surviving weight, or `1.0` when nothing remains.
    pub fn from_weighted_rows(mut rows: Vec<WeightedOuterRow>, n_full: usize, seed: u64) -> Self {
        // Stable sort keeps duplicate indices in supply order for the dedup below.
        rows.sort_by_key(|row| row.index);
        rows.dedup_by_key(|row| row.index);
        let mask = rows.iter().map(|row| row.index).collect::<Vec<usize>>();
        let weight_scale = match rows.len() {
            0 => 1.0,
            kept => rows.iter().map(|row| row.weight).sum::<f64>() / kept as f64,
        };
        Self {
            mask: Arc::new(mask),
            rows: Arc::new(rows),
            n_full,
            weight_scale,
            seed,
        }
    }

    /// Number of retained rows.
    #[inline]
    pub fn len(&self) -> usize {
        self.mask.len()
    }

    /// `true` when no rows are retained.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Reports whether some retained row's weight differs from the first row's.
    ///
    /// Consumers that used to apply one post-sum scalar must switch to per-row
    /// weighting whenever this is `true`. The comparison is
    /// `(w - w_first).abs() > 0.0`, so a NaN difference does not count as a
    /// difference.
    pub fn has_variable_weights(&self) -> bool {
        match self.rows.split_first() {
            None => false,
            Some((head, tail)) => tail
                .iter()
                .any(|row| (row.weight - head.weight).abs() > 0.0),
        }
    }
}

/// One retained row of an [`OuterScoreSubsample`].
///
/// It records the row's position in the full data, its per-row weight, and
/// the stratum it came from.
#[derive(Clone, Copy, Debug)]
pub struct WeightedOuterRow {
    /// Row index into the full data set.
    pub index: usize,
    /// Per-row weight applied to this row's contribution. Under the
    /// Horvitz–Thompson contract this is the inverse inclusion probability
    /// `1 / π_i`.
    pub weight: f64,
    /// Stratum the row was drawn from. This is diagnostic only: consumers
    /// must aggregate with `weight`, never with this field.
    pub stratum: u32,
}
159
/// Row-tile width used by the deterministic parallel reductions.
///
/// Every cross-row sum splits the rows into tiles of `ARROW_ROW_CHUNK` rows.
/// The per-tile partials are then folded together in tile-index order on the
/// calling thread. That pins the floating-point reduction tree regardless of
/// how many Rayon workers run or how work stealing plays out. Consumers that
/// rely on deterministic associativity must tile in multiples of this value.
pub const ARROW_ROW_CHUNK: usize = 256;

/// Number of `ARROW_ROW_CHUNK`-row tiles needed to cover `n_rows` rows.
///
/// This is the ceiling of `n_rows / ARROW_ROW_CHUNK`: `0` for `n_rows == 0`,
/// and it never overflows.
#[inline]
pub fn arrow_row_chunk_count(n_rows: usize) -> usize {
    let full_tiles = n_rows / ARROW_ROW_CHUNK;
    let has_partial_tile = n_rows % ARROW_ROW_CHUNK != 0;
    full_tiles + usize::from(has_partial_tile)
}
178
179/// Row selection for an outer-loop evaluation: either the full data (`All`) or
180/// a Horvitz–Thompson [`WeightedOuterRow`] subsample.
181///
182/// `All` walks rows `0..n_total` with unit weight; `Subsample` walks the stored
183/// rows applying each row's inverse-inclusion scale `1/π_i`, so any partial sum
184/// `Σ_i w_i · f(row_i)` is an unbiased estimator of the corresponding full-data
185/// sum `Σ_{i=1..n_full} f(row_i)`. Inner-PIRLS and final-covariance passes
186/// always run with `All`; only outer score / gradient hot loops consume a
187/// non-`All` variant.
188///
189/// Lives in this lower layer (below `families`/`terms`) so the row-kernel
190/// consumers and the term hot-paths can name it without the `Subsample` field
191/// reaching up into `solver` (#1135). The family-specific constructor
192/// (`families::row_kernel::RowSet::from_options`, which reads
193/// `custom_family::BlockwiseFitOptions`) stays in `families` as an extension
194/// `impl` block.
195#[derive(Clone)]
196pub enum RowSet {
197 All,
198 Subsample {
199 rows: Arc<Vec<WeightedOuterRow>>,
200 n_full: usize,
201 },
202}
203
204impl RowSet {
205 /// Parallel fold-reduce over the row set. `init` produces a fresh
206 /// accumulator, `fold` is the per-row update, `reduce` combines two
207 /// accumulators.
208 ///
209 /// Returns the reduced result. Both branches process fixed-size row chunks
210 /// in parallel, then combine the chunk accumulators in chunk-index order on
211 /// the caller thread. The resulting floating-point reduction tree is fixed
212 /// across Rayon worker counts and work-stealing decisions.
213 #[inline]
214 pub fn par_reduce_fold<T, I, F, R>(&self, n_total: usize, init: I, fold: F, reduce: R) -> T
215 where
216 T: Send,
217 I: Fn() -> T + Send + Sync,
218 F: Fn(T, usize, f64) -> T + Send + Sync,
219 R: Fn(T, T) -> T + Send + Sync,
220 {
221 use rayon::iter::{IntoParallelIterator, ParallelIterator};
222 use rayon::slice::ParallelSlice;
223 match self {
224 Self::All => {
225 let chunk_accumulators: Vec<T> = (0..arrow_row_chunk_count(n_total))
226 .into_par_iter()
227 .map(|chunk_idx| {
228 let start = chunk_idx * ARROW_ROW_CHUNK;
229 let end = (start + ARROW_ROW_CHUNK).min(n_total);
230 let mut acc = init();
231 for i in start..end {
232 acc = fold(acc, i, 1.0);
233 }
234 acc
235 })
236 .collect();
237 let mut total = init();
238 for acc in chunk_accumulators {
239 total = reduce(total, acc);
240 }
241 total
242 }
243 Self::Subsample { rows, .. } => {
244 let chunk_accumulators: Vec<T> = rows
245 .par_chunks(ARROW_ROW_CHUNK)
246 .map(|chunk| {
247 let mut acc = init();
248 for r in chunk {
249 acc = fold(acc, r.index, r.weight);
250 }
251 acc
252 })
253 .collect();
254 let mut total = init();
255 for acc in chunk_accumulators {
256 total = reduce(total, acc);
257 }
258 total
259 }
260 }
261 }
262
263 /// Parallel try-fold over fixed-size row chunks, followed by deterministic
264 /// chunk-index-order reduction on the caller thread.
265 #[inline]
266 pub fn par_try_reduce_fold<T, E, I, F, R>(
267 &self,
268 n_total: usize,
269 init: I,
270 fold: F,
271 reduce: R,
272 ) -> Result<T, E>
273 where
274 T: Send,
275 E: Send,
276 I: Fn() -> T + Send + Sync,
277 F: Fn(T, usize, f64) -> Result<T, E> + Send + Sync,
278 R: Fn(T, T) -> Result<T, E> + Send + Sync,
279 {
280 use rayon::iter::{IntoParallelIterator, ParallelIterator};
281 use rayon::slice::ParallelSlice;
282 match self {
283 Self::All => {
284 let chunk_accumulators: Vec<Result<T, E>> = (0..arrow_row_chunk_count(n_total))
285 .into_par_iter()
286 .map(|chunk_idx| {
287 let start = chunk_idx * ARROW_ROW_CHUNK;
288 let end = (start + ARROW_ROW_CHUNK).min(n_total);
289 let mut acc = init();
290 for i in start..end {
291 acc = fold(acc, i, 1.0)?;
292 }
293 Ok(acc)
294 })
295 .collect();
296 let mut total = init();
297 for acc in chunk_accumulators {
298 total = reduce(total, acc?)?;
299 }
300 Ok(total)
301 }
302 Self::Subsample { rows, .. } => {
303 let chunk_accumulators: Vec<Result<T, E>> = rows
304 .par_chunks(ARROW_ROW_CHUNK)
305 .map(|chunk| {
306 let mut acc = init();
307 for r in chunk {
308 acc = fold(acc, r.index, r.weight)?;
309 }
310 Ok(acc)
311 })
312 .collect();
313 let mut total = init();
314 for acc in chunk_accumulators {
315 total = reduce(total, acc?)?;
316 }
317 Ok(total)
318 }
319 }
320 }
321}