Skip to main content

gam_problem/
lib.rs

1//! Shared REML/LAML contract types.
2//!
3//! These are the family-facing interfaces for REML outer assembly. They live
4//! below `solver` so families can construct operator-backed derivative payloads
5//! without importing `solver::estimate::reml::reml_outer_engine`.
6
7use std::any::Any;
8use std::collections::HashMap;
9use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
10use std::sync::{Arc, Condvar, Mutex};
11
12use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2};
13use rayon::iter::{IntoParallelIterator, ParallelIterator};
14
15#[macro_use]
16mod macros;
17
18pub mod basis_error;
19pub mod block_count_error;
20pub mod block_role;
21pub mod block_spec;
22pub mod coefficient_prior_mean;
23pub mod custom_family_blockwise;
24pub mod custom_family_error;
25pub mod diagnostics;
26pub mod dispersion;
27pub mod dispersion_cov;
28pub mod estimation_error;
29pub mod execution_path;
30pub mod family_options;
31pub mod finite_validation;
32pub mod fisher_rao;
33pub mod gauge;
34pub mod identifiability_audit;
35pub mod joint_penalty;
36mod linalg_helpers;
37mod linear_constraints;
38pub mod monotone_root_error;
39pub mod outer_subsample;
40pub mod penalty_coordinate;
41pub mod penalty_matrix;
42mod pseudo_logdet;
43pub mod psi_design_contract;
44pub mod psi_terms;
45pub mod riemannian_retraction;
46// `ρ`-posterior certificate/escalation DATA types contract-downed (#1521) so
47// gam-solve can store/return them without a back-edge into gam-inference; the
48// computation stays UP in the monolith `inference::rho_posterior`.
49pub mod rho_posterior;
50pub mod row_measure;
51pub mod row_metric;
52pub mod schedule;
53// #1521 contract-downs: pure-data carriers + caller-supplied sampler/verdict
54// traits so gam-solve can call up-tier work (NUTS sampling, topology verdicts)
55// without a back-edge into gam-inference/gam-sae; computation stays UP.
56pub mod laplace_sampler_contract;
57pub mod topology_certificates;
58mod seeding;
59pub mod solver_contract;
60pub mod types;
61
62pub use riemannian_retraction::LatentRetractionRegistry;
63pub use row_measure::RowSubsampleMask;
64
65mod gpu {
66    pub(crate) mod linalg_dispatch {
67        use ndarray::{Array2, ArrayView2};
68
69        pub(crate) fn try_fast_atb(
70            a: ArrayView2<'_, f64>,
71            b: ArrayView2<'_, f64>,
72        ) -> Option<Array2<f64>> {
73            let (n_a, p) = a.dim();
74            let (n_b, q) = b.dim();
75            assert_eq!(n_a, n_b, "A and B must have same number of rows");
76            if !crate::linalg_helpers::should_use_faer_matmul(p, q, n_a) {
77                return None;
78            }
79            Some(crate::linalg_helpers::fast_atb_with_parallelism(
80                &a,
81                &b,
82                crate::linalg_helpers::matmul_parallelism(p, q, n_a),
83            ))
84        }
85    }
86}
87
88pub use basis_error::BasisError;
89pub use block_count_error::BlockCountMismatch;
90pub use block_role::BlockRole;
91pub use block_spec::{
92    AdditiveBlockJacobian, BlockEffectiveJacobian, BlockGeometryDirectionalDerivative,
93    BlockWorkingSet, FamilyChannelHessian, FamilyLinearizationState, GaugeComposedJacobian,
94    ParameterBlockSpec, ParameterBlockState, RowScaledJacobian, TensorChannelHessian,
95};
96pub use coefficient_prior_mean::{CoefficientPriorMean, PriorMeanError};
97pub use custom_family_blockwise::{
98    CUSTOM_FAMILY_RIDGE_FLOOR, CUSTOM_FAMILY_WEIGHT_FLOOR, ExactNewtonOuterCurvature,
99    validate_blockspec_consistency,
100};
101pub use custom_family_error::CustomFamilyError;
102pub use dispersion::Dispersion;
103pub use dispersion_cov::{DispersionExt, PhiScaledCovariance, UnscaledPrecision, se_from_covariance};
104pub use estimation_error::EstimationError;
105pub use execution_path::ExecutionPath;
106pub use family_options::{ExactNewtonOuterObjective, ExactOuterDerivativeOrder};
107pub use finite_validation::{
108    bail_if_cached_beta_non_finite, ensure_finite_scalar, ensure_finite_scalar_estimation,
109    validate_all_finite, validate_all_finite_estimation,
110};
111pub use fisher_rao::{
112    FisherRaoDefiniteness, normalize_fisher_rao_blocks, normalize_fisher_rao_blocks_pd,
113};
114pub use gam_linalg::faer_ndarray::{in_nested_parallel_region, with_nested_parallel};
115pub use gauge::Gauge;
116pub use identifiability_audit::{
117    AliasedPair, BlockIdentity, DroppedColumn, IdentifiabilityAudit, MapUniquenessError,
118};
119pub use joint_penalty::{JointPenaltyBundle, JointPenaltyError, JointPenaltySpec};
120use linalg_helpers::{dense_bilinear, dense_matvec_into, dense_matvec_scaled_add_into};
121pub use linear_constraints::LinearInequalityConstraints;
122pub use monotone_root_error::MonotoneRootError;
123pub use penalty_coordinate::PenaltyCoordinate;
124pub use penalty_matrix::PenaltyMatrix;
125pub use pseudo_logdet::PseudoLogdetMode;
126pub use psi_design_contract::{
127    CustomFamilyBlockPsiDerivative, CustomFamilyPsiDerivativeOperator, JointHessianSourcePreference,
128    MaterializablePsiDerivativeOperator, MaterializationIntent, SharedDerivativeBlocks,
129};
130pub use psi_terms::{
131    ExactNewtonJointPsiSecondOrderContracted, ExactNewtonJointPsiSecondOrderTerms,
132    ExactNewtonJointPsiTerms, ExactNewtonJointPsiWorkspace,
133};
134pub use row_metric::{MetricProvenance, RowMetric, WeightField};
135pub use schedule::{GumbelTemperatureSchedule, ScheduleKind, SearchStrategy};
136pub use seeding::{SeedConfig, SeedRiskProfile, clamp_seed_rho_to_bounds, normalize_seed_bounds};
137pub use solver_contract::{
138    DeclaredHessianForm, Derivative, EfsEval, HessianResult, OuterEval,
139    OuterHessianMaterialization, OuterHessianOperator, OuterStrategyError,
140};
141pub use types::*;
142
/// Unwind with `message` as the panic payload.
///
/// The payload handed to `panic_any` is an owned `String`, so code that
/// catches the unwind can downcast it to `String`.
#[cold]
fn reml_contract_panic(message: impl Into<String>) -> ! {
    let payload: String = message.into();
    std::panic::panic_any(payload)
}
147
/// Selects how much derivative information the unified evaluator produces.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum EvalMode {
    /// Cost alone, e.g. for line-search probes.
    ValueOnly,
    /// Cost together with its gradient; the usual request.
    ValueAndGradient,
    /// Cost, gradient, and the outer Hessian.
    ValueGradientHessian,
}
158
/// Private sentinel type behind the default `HyperOperator::as_any`.
///
/// The `HyperOperator` trait that follows describes operators that can compute
/// a hyper-derivative matrix-vector product without necessarily materializing
/// the full matrix. This unit struct is not such an operator: it only gives
/// the default `as_any` something `'static` to return. Because the type is
/// private to this module, downcasting that default result to any concrete
/// operator type fails.
struct NonDowncastableHyperOperator;

/// Shared instance returned by reference from the default `HyperOperator::as_any`.
static NON_DOWNCASTABLE_HYPER_OPERATOR: NonDowncastableHyperOperator = NonDowncastableHyperOperator;
164
165pub trait HyperOperator: Send + Sync {
166    /// Operator dimension `p` such that `B · v` consumes a `p`-vector and
167    /// produces a `p`-vector.
168    fn dim(&self) -> usize;
169
170    /// Compute B · v (matrix-vector product). v and result are p-vectors.
171    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64>;
172
173    /// Expose the concrete type for solver-local downcast helpers when the
174    /// implementor has a `'static` concrete type. Borrowing adapters may keep
175    /// the default, which simply cannot downcast.
176    fn as_any(&self) -> &(dyn Any + 'static) {
177        &NON_DOWNCASTABLE_HYPER_OPERATOR
178    }
179
180    /// Compute B · v from a vector view.
181    fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
182        self.mul_vec(&v.to_owned())
183    }
184
185    /// Compute B · v into caller-owned storage.
186    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
187        out.assign(&self.mul_vec_view(v));
188    }
189
190    /// Compute B · F where F is (p × k). Default dispatches per-column in
191    /// parallel unless already inside a rayon worker.
192    fn mul_mat(&self, factor: &Array2<f64>) -> Array2<f64> {
193        let p = factor.nrows();
194        let k = factor.ncols();
195        let mut out = Array2::<f64>::zeros((p, k));
196        if rayon::current_thread_index().is_some() {
197            for col in 0..k {
198                let bv = out.column_mut(col);
199                self.mul_vec_into(factor.column(col), bv);
200            }
201            return out;
202        }
203        let cols: Vec<Array1<f64>> = (0..k)
204            .into_par_iter()
205            .map(|col| {
206                let mut bv = Array1::<f64>::zeros(p);
207                self.mul_vec_into(factor.column(col), bv.view_mut());
208                bv
209            })
210            .collect();
211        for (col, bv) in cols.into_iter().enumerate() {
212            out.column_mut(col).assign(&bv);
213        }
214        out
215    }
216
217    /// Compute `trace(F^T B F)` for a `(p x k)` factor matrix `F`.
218    fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
219        let op_factor = self.mul_mat(factor);
220        factor
221            .iter()
222            .zip(op_factor.iter())
223            .map(|(&f, &bf)| f * bf)
224            .sum()
225    }
226
227    /// Optional stable identity for this operator's action `B`. When `Some`,
228    /// the default cached trace / projected-matrix paths memoize the `B · F`
229    /// product in the shared [`ProjectedFactorCache`] under a
230    /// `(design_id, factor)` key, so repeated projections of the same factor
231    /// against the same operator within one outer iteration build `B · F`
232    /// once. `None` (the default) disables that reuse: an operator with no
233    /// design factor stable across calls cannot key the cache without risking
234    /// a stale `B · F`, so it recomputes every time.
235    fn projection_design_id(&self) -> Option<usize> {
236        None
237    }
238
239    fn trace_projected_factor_cached(
240        &self,
241        factor: &Array2<f64>,
242        factor_cache: &ProjectedFactorCache,
243    ) -> f64 {
244        // The default implementation has no use for the caller-owned cache;
245        // verify the cache object carries a positive-size allocation before
246        // delegating to the exact path.
247        assert!(std::mem::size_of_val(factor_cache) > 0);
248        match self.projection_design_id() {
249            Some(design_id) => {
250                let key = ProjectedFactorKey::from_factor_view(design_id, factor.view());
251                let projected = factor_cache.get_or_insert_with(key, || self.mul_mat(factor));
252                factor
253                    .iter()
254                    .zip(projected.iter())
255                    .map(|(&f, &bf)| f * bf)
256                    .sum()
257            }
258            None => self.trace_projected_factor(factor),
259        }
260    }
261
262    /// Compute the exact projected matrix `F^T B F`.
263    fn projected_matrix(&self, factor: &Array2<f64>) -> Array2<f64> {
264        let op_factor = self.mul_mat(factor);
265        crate::linalg_helpers::fast_atb(factor, &op_factor)
266    }
267
268    /// Compute the exact projected matrix `F^T B F`, reusing caller-owned
269    /// projection caches when the operator has a shared row/design factor.
270    fn projected_matrix_cached(
271        &self,
272        factor: &Array2<f64>,
273        factor_cache: &ProjectedFactorCache,
274    ) -> Array2<f64> {
275        assert!(std::mem::size_of_val(factor_cache) > 0);
276        match self.projection_design_id() {
277            Some(design_id) => {
278                let key = ProjectedFactorKey::from_factor_view(design_id, factor.view());
279                let projected = factor_cache.get_or_insert_with(key, || self.mul_mat(factor));
280                crate::linalg_helpers::fast_atb(factor, projected.as_ref())
281            }
282            None => self.projected_matrix(factor),
283        }
284    }
285
286    /// Fill columns `[start, start + out.ncols())` of `B` into `out`.
287    fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
288        let cols = out.ncols();
289        let dim = out.nrows();
290        assert!(start + cols <= dim);
291        let mut basis = Array1::<f64>::zeros(dim);
292        for local_col in 0..cols {
293            let global_col = start + local_col;
294            basis[global_col] = 1.0;
295            self.mul_vec_into(basis.view(), out.column_mut(local_col));
296            basis[global_col] = 0.0;
297        }
298    }
299
300    /// Accumulate `scale * B · v` into caller-owned storage.
301    fn scaled_add_mul_vec(
302        &self,
303        v: ArrayView1<'_, f64>,
304        scale: f64,
305        mut out: ArrayViewMut1<'_, f64>,
306    ) {
307        if scale == 0.0 {
308            return;
309        }
310        let mut work = Array1::<f64>::zeros(out.len());
311        self.mul_vec_into(v, work.view_mut());
312        out.scaled_add(scale, &work);
313    }
314
315    /// Compute v^T · B · u (bilinear form).
316    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
317        let mut bv = Array1::<f64>::zeros(v.len());
318        self.mul_vec_into(v.view(), bv.view_mut());
319        u.dot(&bv)
320    }
321
322    /// Compute v^T · B · u without requiring owned vector inputs.
323    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
324        let mut bv = Array1::<f64>::zeros(v.len());
325        self.mul_vec_into(v, bv.view_mut());
326        u.dot(&bv)
327    }
328
329    /// Whether `bilinear_view` is implemented as a direct scalar contraction.
330    fn has_fast_bilinear_view(&self) -> bool {
331        false
332    }
333
334    /// Full dense materialization.
335    fn to_dense(&self) -> Array2<f64> {
336        let p = self.dim();
337        let mut out = Array2::<f64>::zeros((p, p));
338        let mut basis = Array1::<f64>::zeros(p);
339        for j in 0..p {
340            basis[j] = 1.0;
341            self.mul_vec_into(basis.view(), out.column_mut(j));
342            basis[j] = 0.0;
343        }
344        out
345    }
346
347    /// Whether this operator uses implicit (non-materialized) storage.
348    fn is_implicit(&self) -> bool;
349
350    /// If this operator is block-local, returns the block range and local matrix.
351    fn block_local_data(&self) -> Option<(&Array2<f64>, usize, usize)> {
352        None
353    }
354}
355
/// Lookup key for [`ProjectedFactorCache`].
///
/// Combines the operator's design identity with the identity of a factor
/// view: its data pointer, shape, strides, and a two-word hash of its values.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ProjectedFactorKey {
    /// Caller-supplied identity of the operator/design.
    pub(crate) design_id: usize,
    /// Address of the factor view's first element, stored as an integer.
    pub(crate) factor_ptr: usize,
    /// Number of rows of the factor view.
    pub(crate) rows: usize,
    /// Number of columns of the factor view.
    pub(crate) cols: usize,
    /// Stride of the view along axis 0.
    pub(crate) row_stride: isize,
    /// Stride of the view along axis 1.
    pub(crate) col_stride: isize,
    /// First word of the value fingerprint.
    pub(crate) value_hash: u64,
    /// Second word of the value fingerprint.
    pub(crate) value_hash2: u64,
}
367
368impl ProjectedFactorKey {
369    pub fn from_factor_view(design_id: usize, factor: ArrayView2<'_, f64>) -> Self {
370        let strides = factor.strides();
371        let (value_hash, value_hash2) = projected_factor_value_fingerprint(factor);
372        Self {
373            design_id,
374            factor_ptr: factor.as_ptr() as usize,
375            rows: factor.nrows(),
376            cols: factor.ncols(),
377            row_stride: strides[0],
378            col_stride: strides[1],
379            value_hash,
380            value_hash2,
381        }
382    }
383
384    /// Construct a synthetic, unique-by-`seed` key without going through
385    /// [`Self::from_factor_view`]. Used by cache tests that need to inject
386    /// fingerprints directly (and deterministically) rather than relying on
387    /// ndarray pointer aliasing, which the real constructor keys on.
388    pub fn synthetic(seed: u64) -> Self {
389        Self {
390            design_id: 1,
391            factor_ptr: seed as usize,
392            rows: 1,
393            cols: 1,
394            row_stride: 1,
395            col_stride: 1,
396            value_hash: seed,
397            value_hash2: seed.wrapping_mul(31),
398        }
399    }
400}
401
402pub(crate) fn projected_factor_value_fingerprint(factor: ArrayView2<'_, f64>) -> (u64, u64) {
403    let mut h1 = 0xcbf2_9ce4_8422_2325_u64;
404    let mut h2 = 0x9e37_79b1_85eb_ca87_u64;
405    for (idx, value) in factor.iter().enumerate() {
406        let bits = value.to_bits();
407        let mixed = bits.wrapping_add((idx as u64).wrapping_mul(0x517c_c1b7_2722_0a95));
408        h1 ^= mixed;
409        h1 = h1.wrapping_mul(0x0000_0100_0000_01b3);
410        h2 ^= bits.rotate_left((idx & 63) as u32);
411        h2 = h2.wrapping_mul(0x94d0_49bb_1331_11eb).rotate_left(27);
412    }
413    (h1, h2)
414}
415
/// Memoizer for projected factor products keyed on a `(design, factor)` fingerprint.
///
/// All state lives behind a single mutex, so the cache is used through a
/// shared reference.
pub struct ProjectedFactorCache {
    /// Cached entries, in-flight markers, and byte accounting.
    pub(crate) inner: Mutex<ProjectedFactorCacheInner>,
}
420
/// Mutex-protected state of a [`ProjectedFactorCache`].
pub(crate) struct ProjectedFactorCacheInner {
    /// Completed products, by key.
    pub(crate) entries: HashMap<ProjectedFactorKey, ProjectedFactorEntry>,
    /// Markers for keys whose product is currently being computed.
    pub(crate) in_progress: HashMap<ProjectedFactorKey, Arc<ProjectedFactorInProgress>>,
    /// Counter bumped on each lookup/insert; its value stamps `last_used`.
    pub(crate) next_seq: u64,
    /// Sum of the `bytes` of all stored entries.
    pub(crate) total_bytes: usize,
    /// Byte limit consulted when evicting before an insert.
    pub(crate) budget_bytes: usize,
}
428
/// Rendezvous object shared between the thread computing a product and the
/// threads waiting for that same key.
pub(crate) struct ProjectedFactorInProgress {
    /// `None` while the producer is running; set once it finishes or fails.
    pub(crate) state: Mutex<Option<ProjectedFactorInProgressState>>,
    /// Signalled by the producer after `state` has been set.
    pub(crate) ready: Condvar,
    /// Number of threads currently waiting on this marker.
    pub(crate) waiter_count: std::sync::atomic::AtomicUsize,
    /// Lock/condvar pair notified each time a waiter registers; used by
    /// `ProjectedFactorCache::wait_for_subscriber`.
    pub(crate) subscriber_arrived: (Mutex<()>, Condvar),
}
435
/// Final outcome published by the producer of an in-progress product.
pub(crate) enum ProjectedFactorInProgressState {
    /// The product was computed; waiters clone this handle.
    Ready(Arc<Array2<f64>>),
    /// The producer's compute closure panicked.
    Failed,
}
440
/// One cached product together with its eviction bookkeeping.
pub(crate) struct ProjectedFactorEntry {
    /// The cached product, shared with callers.
    pub(crate) value: Arc<Array2<f64>>,
    /// Size charged against the budget (element count times `size_of::<f64>()`).
    pub(crate) bytes: usize,
    /// Sequence number of the most recent hit or insert; the entry with the
    /// smallest value is evicted first.
    pub(crate) last_used: u64,
}
446
447impl Default for ProjectedFactorCache {
448    fn default() -> Self {
449        Self::with_budget(Self::DEFAULT_BUDGET_BYTES)
450    }
451}
452
453impl ProjectedFactorCache {
454    pub const DEFAULT_BUDGET_BYTES: usize = 2 * 1024 * 1024 * 1024;
455
456    pub fn with_budget(budget_bytes: usize) -> Self {
457        Self {
458            inner: Mutex::new(ProjectedFactorCacheInner {
459                entries: HashMap::new(),
460                in_progress: HashMap::new(),
461                next_seq: 0,
462                total_bytes: 0,
463                budget_bytes,
464            }),
465        }
466    }
467
468    pub fn get_or_insert_with(
469        &self,
470        key: ProjectedFactorKey,
471        compute: impl FnOnce() -> Array2<f64>,
472    ) -> Arc<Array2<f64>> {
473        enum CacheLookup {
474            Hit(Arc<Array2<f64>>),
475            Wait(Arc<ProjectedFactorInProgress>),
476            Compute(Arc<ProjectedFactorInProgress>),
477        }
478
479        let lookup = {
480            let mut inner = self
481                .inner
482                .lock()
483                .expect("projected factor cache lock poisoned");
484            inner.next_seq += 1;
485            let now = inner.next_seq;
486            if let Some(entry) = inner.entries.get_mut(&key) {
487                entry.last_used = now;
488                CacheLookup::Hit(entry.value.clone())
489            } else if let Some(waiter) = inner.in_progress.get(&key) {
490                CacheLookup::Wait(waiter.clone())
491            } else {
492                let marker = Arc::new(ProjectedFactorInProgress {
493                    state: Mutex::new(None),
494                    ready: Condvar::new(),
495                    waiter_count: std::sync::atomic::AtomicUsize::new(0),
496                    subscriber_arrived: (Mutex::new(()), Condvar::new()),
497                });
498                inner.in_progress.insert(key, marker.clone());
499                CacheLookup::Compute(marker)
500            }
501        };
502
503        match lookup {
504            CacheLookup::Hit(value) => value,
505            CacheLookup::Wait(marker) => {
506                marker
507                    .waiter_count
508                    .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
509                let (lock, cv) = &marker.subscriber_arrived;
510                drop(
511                    lock.lock()
512                        .expect("subscriber-arrived notification lock poisoned"),
513                );
514                cv.notify_all();
515                let mut guard = marker
516                    .state
517                    .lock()
518                    .expect("projected factor in-progress lock poisoned");
519                let result = loop {
520                    match guard.as_ref() {
521                        Some(ProjectedFactorInProgressState::Ready(value)) => {
522                            break value.clone();
523                        }
524                        Some(ProjectedFactorInProgressState::Failed) => {
525                            marker
526                                .waiter_count
527                                .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
528                            reml_contract_panic("projected factor cache producer panicked")
529                        }
530                        None => {
531                            guard = marker
532                                .ready
533                                .wait(guard)
534                                .expect("projected factor in-progress wait poisoned");
535                        }
536                    }
537                };
538                marker
539                    .waiter_count
540                    .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
541                result
542            }
543            CacheLookup::Compute(marker) => {
544                let computed = match catch_unwind(AssertUnwindSafe(|| Arc::new(compute()))) {
545                    Ok(value) => value,
546                    Err(payload) => {
547                        let mut inner = self
548                            .inner
549                            .lock()
550                            .expect("projected factor cache lock poisoned");
551                        inner.in_progress.remove(&key);
552                        drop(inner);
553
554                        let mut guard = marker
555                            .state
556                            .lock()
557                            .expect("projected factor in-progress lock poisoned");
558                        *guard = Some(ProjectedFactorInProgressState::Failed);
559                        marker.ready.notify_all();
560                        resume_unwind(payload);
561                    }
562                };
563                let bytes = computed.len().saturating_mul(std::mem::size_of::<f64>());
564                let mut inner = self
565                    .inner
566                    .lock()
567                    .expect("projected factor cache lock poisoned");
568                inner.next_seq += 1;
569                let now = inner.next_seq;
570
571                if inner.budget_bytes > 0 && bytes <= inner.budget_bytes {
572                    while inner.total_bytes.saturating_add(bytes) > inner.budget_bytes
573                        && !inner.entries.is_empty()
574                    {
575                        let Some(oldest_key) = inner
576                            .entries
577                            .iter()
578                            .min_by_key(|(_, e)| e.last_used)
579                            .map(|(k, _)| *k)
580                        else {
581                            break;
582                        };
583                        if let Some(removed) = inner.entries.remove(&oldest_key) {
584                            inner.total_bytes = inner.total_bytes.saturating_sub(removed.bytes);
585                        }
586                    }
587                }
588
589                let value = if let Some(entry) = inner.entries.get_mut(&key) {
590                    entry.last_used = now;
591                    entry.value.clone()
592                } else {
593                    inner.entries.insert(
594                        key,
595                        ProjectedFactorEntry {
596                            value: computed.clone(),
597                            bytes,
598                            last_used: now,
599                        },
600                    );
601                    inner.total_bytes = inner.total_bytes.saturating_add(bytes);
602                    computed
603                };
604                inner.in_progress.remove(&key);
605                drop(inner);
606
607                let mut guard = marker
608                    .state
609                    .lock()
610                    .expect("projected factor in-progress lock poisoned");
611                *guard = Some(ProjectedFactorInProgressState::Ready(value.clone()));
612                marker.ready.notify_all();
613                value
614            }
615        }
616    }
617
618    pub fn len(&self) -> usize {
619        self.inner
620            .lock()
621            .map(|inner| inner.entries.len())
622            .unwrap_or(0)
623    }
624
625    pub fn total_bytes(&self) -> usize {
626        self.inner
627            .lock()
628            .map(|inner| inner.total_bytes)
629            .unwrap_or(0)
630    }
631
632    pub fn is_empty(&self) -> bool {
633        self.len() == 0
634    }
635
636    /// Test/diagnostic affordance: block until a consumer has subscribed to the
637    /// in-progress slot for `key` (i.e. is waiting on the producer), or until
638    /// `timeout` elapses. Returns `true` if a subscriber arrived, `false` if the
639    /// key has no in-progress slot or the wait timed out.
640    ///
641    /// This lives on the cache because it reaches into the per-key subscriber
642    /// condvar and waiter counter, which are private synchronization internals;
643    /// exposing it as a method keeps those fields encapsulated while still
644    /// letting downstream tests deterministically order producer/consumer
645    /// interleavings.
646    pub fn wait_for_subscriber(
647        &self,
648        key: ProjectedFactorKey,
649        timeout: std::time::Duration,
650    ) -> bool {
651        let marker = {
652            let inner = self
653                .inner
654                .lock()
655                .expect("projected factor cache lock poisoned");
656            let Some(m) = inner.in_progress.get(&key) else {
657                return false;
658            };
659            Arc::clone(m)
660        };
661        if marker
662            .waiter_count
663            .load(std::sync::atomic::Ordering::Acquire)
664            > 0
665        {
666            return true;
667        }
668        let (lock, cv) = &marker.subscriber_arrived;
669        let mut guard = lock
670            .lock()
671            .expect("subscriber-arrived notification lock poisoned");
672        let deadline = std::time::Instant::now() + timeout;
673        loop {
674            if marker
675                .waiter_count
676                .load(std::sync::atomic::Ordering::Acquire)
677                > 0
678            {
679                return true;
680            }
681            let now = std::time::Instant::now();
682            if now >= deadline {
683                return false;
684            }
685            let (next_guard, result) = cv
686                .wait_timeout(guard, deadline - now)
687                .expect("subscriber-arrived wait poisoned");
688            guard = next_guard;
689            if result.timed_out()
690                && marker
691                    .waiter_count
692                    .load(std::sync::atomic::Ordering::Acquire)
693                    == 0
694            {
695                return false;
696            }
697        }
698    }
699}
700
/// [`HyperOperator`] backed by an explicitly stored dense matrix.
#[derive(Clone)]
pub struct DenseMatrixHyperOperator {
    /// The operator matrix `B`; its row count is reported as the operator
    /// dimension.
    pub matrix: Array2<f64>,
}
705
706impl HyperOperator for DenseMatrixHyperOperator {
707    fn dim(&self) -> usize {
708        self.matrix.nrows()
709    }
710
711    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
712        self.matrix.dot(v)
713    }
714
715    fn as_any(&self) -> &(dyn Any + 'static) {
716        self
717    }
718
719    fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
720        self.matrix.dot(&v)
721    }
722
723    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
724        assert_eq!(self.matrix.ncols(), v.len());
725        assert_eq!(self.matrix.nrows(), out.len());
726        for (row, out_value) in self.matrix.rows().into_iter().zip(out.iter_mut()) {
727            *out_value = row.dot(&v);
728        }
729    }
730
731    fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
732        let end = start + out.ncols();
733        assert!(end <= self.matrix.ncols());
734        out.assign(&self.matrix.slice(ndarray::s![.., start..end]));
735    }
736
737    fn scaled_add_mul_vec(
738        &self,
739        v: ArrayView1<'_, f64>,
740        scale: f64,
741        mut out: ArrayViewMut1<'_, f64>,
742    ) {
743        assert_eq!(self.matrix.ncols(), v.len());
744        assert_eq!(self.matrix.nrows(), out.len());
745        if scale == 0.0 {
746            return;
747        }
748        for (row, out_value) in self.matrix.rows().into_iter().zip(out.iter_mut()) {
749            *out_value += scale * row.dot(&v);
750        }
751    }
752
753    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
754        dense_bilinear(&self.matrix, v.view(), u.view())
755    }
756
757    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
758        dense_bilinear(&self.matrix, v, u)
759    }
760
761    fn to_dense(&self) -> Array2<f64> {
762        self.matrix.clone()
763    }
764
765    fn is_implicit(&self) -> bool {
766        false
767    }
768}
769
/// [`HyperOperator`] that acts only on the index range `start..end` of a
/// `total_dim`-vector and is zero elsewhere.
#[derive(Clone)]
pub struct BlockLocalDrift {
    /// Matrix applied to the block; placed at `[start..end, start..end]` when
    /// the operator is densified.
    pub local: Array2<f64>,
    /// First index of the block (inclusive).
    pub start: usize,
    /// End index of the block (exclusive).
    pub end: usize,
    /// Dimension of the full operator.
    pub total_dim: usize,
}
777
778impl HyperOperator for BlockLocalDrift {
779    fn dim(&self) -> usize {
780        self.total_dim
781    }
782
783    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
784        assert_eq!(v.len(), self.total_dim);
785        let mut out = Array1::zeros(self.total_dim);
786        self.mul_vec_into(v.view(), out.view_mut());
787        out
788    }
789
790    fn as_any(&self) -> &(dyn Any + 'static) {
791        self
792    }
793
794    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
795        assert_eq!(v.len(), self.total_dim);
796        assert_eq!(out.len(), self.total_dim);
797        out.fill(0.0);
798        let v_block = v.slice(ndarray::s![self.start..self.end]);
799        let mut out_block = out.slice_mut(ndarray::s![self.start..self.end]);
800        dense_matvec_into(&self.local, v_block, out_block.view_mut());
801    }
802
803    fn scaled_add_mul_vec(
804        &self,
805        v: ArrayView1<'_, f64>,
806        scale: f64,
807        mut out: ArrayViewMut1<'_, f64>,
808    ) {
809        assert_eq!(v.len(), self.total_dim);
810        assert_eq!(out.len(), self.total_dim);
811        if scale == 0.0 {
812            return;
813        }
814        let v_block = v.slice(ndarray::s![self.start..self.end]);
815        let out_block = out.slice_mut(ndarray::s![self.start..self.end]);
816        dense_matvec_scaled_add_into(&self.local, v_block, scale, out_block);
817    }
818
819    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
820        self.bilinear_view(v.view(), u.view())
821    }
822
823    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
824        assert_eq!(v.len(), self.total_dim);
825        assert_eq!(u.len(), self.total_dim);
826        let v_block = v.slice(ndarray::s![self.start..self.end]);
827        let u_block = u.slice(ndarray::s![self.start..self.end]);
828        dense_bilinear(&self.local, v_block, u_block)
829    }
830
831    fn to_dense(&self) -> Array2<f64> {
832        let p = self.total_dim;
833        let mut out = Array2::zeros((p, p));
834        out.slice_mut(ndarray::s![self.start..self.end, self.start..self.end])
835            .assign(&self.local);
836        out
837    }
838
    /// Always returns `false` for this operator.
    ///
    /// NOTE(review): presumably because the block is held as an explicit matrix
    /// (`local`) — confirm against the trait's definition of "implicit".
    fn is_implicit(&self) -> bool {
        false
    }
842
843    fn block_local_data(&self) -> Option<(&Array2<f64>, usize, usize)> {
844        Some((&self.local, self.start, self.end))
845    }
846}
847
/// Additive bundle of up to three optional pieces: a dense matrix, a
/// block-local matrix, and a shared `HyperOperator`.
///
/// `materialize` sums every piece that is present into one dense matrix, and
/// `scaled_add_apply` accumulates each present piece's action on a vector.
#[derive(Clone)]
pub struct HyperCoordDrift {
    /// Optional dense matrix. When present, its row count is what `infer_dim` reports.
    pub dense: Option<Array2<f64>>,
    /// Optional matrix confined to the `start..end` diagonal block of the full space.
    pub block_local: Option<BlockLocalDrift>,
    /// Optional shared operator contribution.
    pub operator: Option<Arc<dyn HyperOperator>>,
}
854
855impl HyperCoordDrift {
856    pub fn none() -> Self {
857        Self {
858            dense: None,
859            block_local: None,
860            operator: None,
861        }
862    }
863
864    pub fn from_dense(dense: Array2<f64>) -> Self {
865        Self {
866            dense: Some(dense),
867            block_local: None,
868            operator: None,
869        }
870    }
871
872    pub fn from_operator(operator: Arc<dyn HyperOperator>) -> Self {
873        Self {
874            dense: None,
875            block_local: None,
876            operator: Some(operator),
877        }
878    }
879
880    pub fn from_parts(
881        dense: Option<Array2<f64>>,
882        operator: Option<Arc<dyn HyperOperator>>,
883    ) -> Self {
884        let dense = dense.filter(|mat| !(operator.is_some() && mat.is_empty()));
885        Self {
886            dense,
887            block_local: None,
888            operator,
889        }
890    }
891
892    pub fn from_block_local_and_operator(
893        local: Array2<f64>,
894        start: usize,
895        end: usize,
896        total_dim: usize,
897        operator: Option<Arc<dyn HyperOperator>>,
898    ) -> Self {
899        Self {
900            dense: None,
901            block_local: Some(BlockLocalDrift {
902                local,
903                start,
904                end,
905                total_dim,
906            }),
907            operator,
908        }
909    }
910
911    pub fn has_operator(&self) -> bool {
912        self.operator.is_some()
913    }
914
915    pub fn uses_operator_fast_path(&self) -> bool {
916        self.operator.is_some() || self.block_local.is_some()
917    }
918
919    pub fn operator_ref(&self) -> Option<&dyn HyperOperator> {
920        self.operator.as_ref().map(Arc::as_ref)
921    }
922
923    pub fn materialize(&self) -> Array2<f64> {
924        let p = self.infer_dim();
925        if p == 0 {
926            return Array2::zeros((0, 0));
927        }
928        let mut out = self.dense.clone().unwrap_or_else(|| Array2::zeros((p, p)));
929        if let Some(bl) = &self.block_local {
930            out.slice_mut(ndarray::s![bl.start..bl.end, bl.start..bl.end])
931                .scaled_add(1.0, &bl.local);
932        }
933        if let Some(op) = &self.operator {
934            out += &op.to_dense();
935        }
936        out
937    }
938
939    pub fn apply(&self, v: &Array1<f64>) -> Array1<f64> {
940        let mut out = Array1::zeros(v.len());
941        self.scaled_add_apply(v.view(), 1.0, &mut out);
942        out
943    }
944
945    pub fn scaled_add_apply(&self, v: ArrayView1<'_, f64>, scale: f64, out: &mut Array1<f64>) {
946        assert_eq!(v.len(), out.len());
947        if scale == 0.0 {
948            return;
949        }
950        if let Some(dense) = &self.dense {
951            dense_matvec_scaled_add_into(dense, v, scale, out.view_mut());
952        }
953        if let Some(bl) = &self.block_local {
954            let v_block = v.slice(ndarray::s![bl.start..bl.end]);
955            let out_block = out.slice_mut(ndarray::s![bl.start..bl.end]);
956            dense_matvec_scaled_add_into(&bl.local, v_block, scale, out_block);
957        }
958        if let Some(op) = &self.operator {
959            op.scaled_add_mul_vec(v, scale, out.view_mut());
960        }
961    }
962
963    pub(crate) fn infer_dim(&self) -> usize {
964        if let Some(d) = &self.dense {
965            return d.nrows();
966        }
967        if let Some(op) = &self.operator {
968            return op.dim();
969        }
970        if let Some(bl) = &self.block_local {
971            return bl.total_dim;
972        }
973        0
974    }
975}
976
/// Data carried for a single hyper-coordinate.
///
/// NOTE(review): the meaning of the individual fields is not established in
/// this part of the file; the notes below are assumptions to confirm against
/// the code that produces and consumes this struct.
#[derive(Clone)]
pub struct HyperCoord {
    /// Scalar term. NOTE(review): exact role not visible here — confirm.
    pub a: f64,
    /// Vector term. NOTE(review): presumably coefficient-sized — confirm.
    pub g: Array1<f64>,
    /// Matrix-like term, stored as dense / block-local / operator pieces.
    pub drift: HyperCoordDrift,
    /// Scalar term. NOTE(review): presumably a log-determinant derivative of the
    /// penalty — confirm.
    pub ld_s: f64,
    /// NOTE(review): presumably flags that the matrix term varies with the
    /// coefficients — confirm.
    pub b_depends_on_beta: bool,
    /// NOTE(review): presumably marks coordinates that act like penalty
    /// parameters — confirm.
    pub is_penalty_like: bool,
    /// Optional vector. NOTE(review): presumably a Firth-correction counterpart
    /// of `g` — confirm.
    pub firth_g: Option<Array1<f64>>,
    /// Optional vector. NOTE(review): meaning of the `tk_` prefix is not visible
    /// here — confirm.
    pub tk_eta_fixed: Option<Array1<f64>>,
    /// Optional matrix. NOTE(review): meaning of the `tk_` prefix is not visible
    /// here — confirm.
    pub tk_x_fixed: Option<Array2<f64>>,
}
989
/// Data carried for a pair of hyper-coordinates.
///
/// NOTE(review): field meanings mirror `HyperCoord` by name only; confirm each
/// against the code that produces and consumes this struct.
pub struct HyperCoordPair {
    /// Scalar term. NOTE(review): exact role not visible here — confirm.
    pub a: f64,
    /// Vector term; `zero()` sets it to length 0.
    pub g: Array1<f64>,
    /// Dense matrix term; `zero()` sets it to `0 x 0`.
    pub b_mat: Array2<f64>,
    /// Optional owned operator. NOTE(review): presumably an operator-backed
    /// counterpart of `b_mat` — confirm.
    pub b_operator: Option<Box<dyn HyperOperator>>,
    /// Scalar term. NOTE(review): presumably a log-determinant derivative of the
    /// penalty — confirm.
    pub ld_s: f64,
}
997
998impl HyperCoordPair {
999    pub fn zero() -> Self {
1000        Self {
1001            a: 0.0,
1002            g: Array1::zeros(0),
1003            b_mat: Array2::zeros((0, 0)),
1004            b_operator: None,
1005            ld_s: 0.0,
1006        }
1007    }
1008}
1009
/// A drift derivative delivered either as a dense matrix or as a shared operator.
#[derive(Clone)]
pub enum DriftDerivResult {
    /// The derivative as an explicit dense matrix.
    Dense(Array2<f64>),
    /// The derivative as a shared `HyperOperator` trait object.
    Operator(Arc<dyn HyperOperator>),
}
1015
1016impl std::fmt::Debug for DriftDerivResult {
1017    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1018        match self {
1019            Self::Dense(matrix) => f
1020                .debug_tuple("Dense")
1021                .field(&format_args!("{}x{}", matrix.nrows(), matrix.ncols()))
1022                .finish(),
1023            Self::Operator(_) => f
1024                .debug_tuple("Operator")
1025                .field(&"<hyper-operator>")
1026                .finish(),
1027        }
1028    }
1029}
1030
1031impl DriftDerivResult {
1032    pub fn into_operator(self) -> Arc<dyn HyperOperator> {
1033        match self {
1034            Self::Dense(matrix) => Arc::new(DenseMatrixHyperOperator { matrix }),
1035            Self::Operator(operator) => operator,
1036        }
1037    }
1038
1039    pub fn apply(&self, v: &Array1<f64>) -> Array1<f64> {
1040        match self {
1041            Self::Dense(matrix) => matrix.dot(v),
1042            Self::Operator(operator) => operator.mul_vec(v),
1043        }
1044    }
1045}
1046
/// Boxed, thread-safe (`Send + Sync`) callback that takes a `usize` and a
/// vector and may return a `DriftDerivResult`.
///
/// NOTE(review): presumably the `usize` is a hyper-coordinate index and the
/// vector is a direction or coefficient vector — confirm against callers.
pub type FixedDriftDerivFn =
    Box<dyn Fn(usize, &Array1<f64>) -> Option<DriftDerivResult> + Send + Sync>;
1049
/// Result bundle returned by a `ContractedPsiSecondOrderFn` callback.
///
/// NOTE(review): the indexing and shapes of the fields are not visible here;
/// presumably each has one entry (or row) per psi coordinate — confirm.
pub struct ContractedPsiSecondOrder {
    /// Vector of objective-level terms.
    pub objective: Array1<f64>,
    /// Matrix of score-level terms.
    pub score: Array2<f64>,
    /// Hessian-level terms, each either dense or operator-backed.
    pub hessian: Vec<DriftDerivResult>,
    /// Vector of terms. NOTE(review): presumably penalty log-determinant
    /// terms — confirm.
    pub ld_s: Array1<f64>,
}
1056
/// Shared, thread-safe (`Send + Sync`) callback that maps a slice of `f64`
/// values to an optional `ContractedPsiSecondOrder`, or to an error message.
///
/// NOTE(review): presumably the slice holds contraction weights and
/// `Ok(None)` means "not available" — confirm against callers.
pub type ContractedPsiSecondOrderFn =
    Arc<dyn Fn(&[f64]) -> Result<Option<ContractedPsiSecondOrder>, String> + Send + Sync>;