Skip to main content

gam_problem/
lib.rs

1//! Shared REML/LAML contract types.
2//!
3//! These are the family-facing interfaces for REML outer assembly. They live
4//! below `solver` so families can construct operator-backed derivative payloads
5//! without importing `solver::estimate::reml::reml_outer_engine`.
6
7use std::any::Any;
8use std::collections::HashMap;
9use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
10use std::sync::{Arc, Condvar, Mutex};
11
12use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2};
13use rayon::iter::{IntoParallelIterator, ParallelIterator};
14
15#[macro_use]
16mod macros;
17
18pub mod basis_error;
19pub mod block_count_error;
20pub mod block_role;
21pub mod block_spec;
22pub mod coefficient_prior_mean;
23pub mod custom_family_blockwise;
24pub mod custom_family_error;
25pub mod diagnostics;
26pub mod dispersion;
27pub mod dispersion_cov;
28pub mod estimation_error;
29pub mod execution_path;
30pub mod family_options;
31pub mod finite_validation;
32pub mod fisher_rao;
33pub mod gauge;
34pub mod identifiability_audit;
35pub mod joint_penalty;
36mod linalg_helpers;
37mod linear_constraints;
38pub mod monotone_root_error;
39pub mod outer_subsample;
40pub mod penalty_coordinate;
41pub mod penalty_matrix;
42mod pseudo_logdet;
43pub mod psi_design_contract;
44pub mod psi_terms;
45pub mod riemannian_retraction;
46// `ρ`-posterior certificate/escalation DATA types contract-downed (#1521) so
47// gam-solve can store/return them without a back-edge into gam-inference; the
48// computation stays UP in the monolith `inference::rho_posterior`.
49pub mod rho_posterior;
50pub mod row_measure;
51pub mod row_metric;
52pub mod schedule;
53// #1521 contract-downs: pure-data carriers + caller-supplied sampler/verdict
54// traits so gam-solve can call up-tier work (NUTS sampling, topology verdicts)
55// without a back-edge into gam-inference/gam-sae; computation stays UP.
56pub mod laplace_sampler_contract;
57mod seeding;
58pub mod solver_contract;
59pub mod topology_certificates;
60pub mod types;
61
62pub use riemannian_retraction::LatentRetractionRegistry;
63pub use row_measure::RowSubsampleMask;
64
65mod gpu {
66    pub(crate) mod linalg_dispatch {
67        use ndarray::{Array2, ArrayView2};
68
69        pub(crate) fn try_fast_atb(
70            a: ArrayView2<'_, f64>,
71            b: ArrayView2<'_, f64>,
72        ) -> Option<Array2<f64>> {
73            let (n_a, p) = a.dim();
74            let (n_b, q) = b.dim();
75            assert_eq!(n_a, n_b, "A and B must have same number of rows");
76            if !crate::linalg_helpers::should_use_faer_matmul(p, q, n_a) {
77                return None;
78            }
79            Some(crate::linalg_helpers::fast_atb_with_parallelism(
80                &a,
81                &b,
82                crate::linalg_helpers::matmul_parallelism(p, q, n_a),
83            ))
84        }
85    }
86}
87
88pub use basis_error::BasisError;
89pub use block_count_error::BlockCountMismatch;
90pub use block_role::BlockRole;
91pub use block_spec::{
92    AdditiveBlockJacobian, BlockEffectiveJacobian, BlockGeometryDirectionalDerivative,
93    BlockWorkingSet, FamilyChannelHessian, FamilyLinearizationState, GaugeComposedJacobian,
94    ParameterBlockSpec, ParameterBlockState, RowScaledJacobian, TensorChannelHessian,
95};
96pub use coefficient_prior_mean::{CoefficientPriorMean, PriorMeanError};
97pub use custom_family_blockwise::{
98    CUSTOM_FAMILY_RIDGE_FLOOR, CUSTOM_FAMILY_WEIGHT_FLOOR, ExactNewtonOuterCurvature,
99    validate_blockspec_consistency,
100};
101pub use custom_family_error::CustomFamilyError;
102pub use dispersion::Dispersion;
103pub use dispersion_cov::{
104    DispersionExt, PhiScaledCovariance, UnscaledPrecision, se_from_covariance,
105};
106pub use estimation_error::EstimationError;
107pub use execution_path::ExecutionPath;
108pub use family_options::{ExactNewtonOuterObjective, ExactOuterDerivativeOrder};
109pub use finite_validation::{
110    bail_if_cached_beta_non_finite, ensure_finite_scalar, ensure_finite_scalar_estimation,
111    validate_all_finite, validate_all_finite_estimation,
112};
113pub use fisher_rao::{
114    FisherRaoDefiniteness, normalize_fisher_rao_blocks, normalize_fisher_rao_blocks_pd,
115};
116pub use gam_linalg::faer_ndarray::{in_nested_parallel_region, with_nested_parallel};
117pub use gauge::Gauge;
118pub use identifiability_audit::{
119    AliasedPair, BlockIdentity, DroppedColumn, IdentifiabilityAudit, MapUniquenessError,
120};
121pub use joint_penalty::{JointPenaltyBundle, JointPenaltyError, JointPenaltySpec};
122use linalg_helpers::{dense_bilinear, dense_matvec_into, dense_matvec_scaled_add_into};
123pub use linear_constraints::LinearInequalityConstraints;
124pub use monotone_root_error::MonotoneRootError;
125pub use penalty_coordinate::PenaltyCoordinate;
126pub use penalty_matrix::PenaltyMatrix;
127pub use pseudo_logdet::PseudoLogdetMode;
128pub use psi_design_contract::{
129    CustomFamilyBlockPsiDerivative, CustomFamilyPsiDerivativeOperator,
130    JointHessianSourcePreference, MaterializablePsiDerivativeOperator, MaterializationIntent,
131    SharedDerivativeBlocks,
132};
133pub use psi_terms::{
134    ExactNewtonJointPsiSecondOrderContracted, ExactNewtonJointPsiSecondOrderTerms,
135    ExactNewtonJointPsiTerms, ExactNewtonJointPsiWorkspace,
136};
137pub use row_metric::{MetricProvenance, RowMetric, WeightField, pack_probe_factors};
138pub use schedule::{GumbelTemperatureSchedule, ScheduleKind, SearchStrategy};
139pub use seeding::{SeedConfig, SeedRiskProfile, clamp_seed_rho_to_bounds, normalize_seed_bounds};
140pub use solver_contract::{
141    DeclaredHessianForm, Derivative, EfsEval, HessianResult, OuterEval,
142    OuterHessianMaterialization, OuterHessianOperator, OuterStrategyError,
143};
144pub use types::*;
145
/// Abort the current operation with a contract-violation panic.
///
/// The panic payload is an owned `String` built from `message`, so a caller
/// that catches the unwind can downcast the payload to `String`.
#[cold]
fn reml_contract_panic(message: impl Into<String>) -> ! {
    let payload: String = message.into();
    std::panic::panic_any(payload)
}
150
/// How much derivative information the unified evaluator is asked to produce.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EvalMode {
    /// Cost only — enough for a line search.
    ValueOnly,
    /// Cost plus gradient; the common case.
    ValueAndGradient,
    /// Cost, gradient, and the outer Hessian.
    ValueGradientHessian,
}
161
162/// Trait for operators that can compute a hyper-derivative matrix-vector product
163/// without necessarily materializing the full matrix.
164struct NonDowncastableHyperOperator;
165
166static NON_DOWNCASTABLE_HYPER_OPERATOR: NonDowncastableHyperOperator = NonDowncastableHyperOperator;
167
168pub trait HyperOperator: Send + Sync {
169    /// Operator dimension `p` such that `B · v` consumes a `p`-vector and
170    /// produces a `p`-vector.
171    fn dim(&self) -> usize;
172
173    /// Compute B · v (matrix-vector product). v and result are p-vectors.
174    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64>;
175
176    /// Expose the concrete type for solver-local downcast helpers when the
177    /// implementor has a `'static` concrete type. Borrowing adapters may keep
178    /// the default, which simply cannot downcast.
179    fn as_any(&self) -> &(dyn Any + 'static) {
180        &NON_DOWNCASTABLE_HYPER_OPERATOR
181    }
182
183    /// Compute B · v from a vector view.
184    fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
185        self.mul_vec(&v.to_owned())
186    }
187
188    /// Compute B · v into caller-owned storage.
189    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
190        out.assign(&self.mul_vec_view(v));
191    }
192
193    /// Compute B · F where F is (p × k). Default dispatches per-column in
194    /// parallel unless already inside a rayon worker.
195    fn mul_mat(&self, factor: &Array2<f64>) -> Array2<f64> {
196        let p = factor.nrows();
197        let k = factor.ncols();
198        let mut out = Array2::<f64>::zeros((p, k));
199        if rayon::current_thread_index().is_some() {
200            for col in 0..k {
201                let bv = out.column_mut(col);
202                self.mul_vec_into(factor.column(col), bv);
203            }
204            return out;
205        }
206        let cols: Vec<Array1<f64>> = (0..k)
207            .into_par_iter()
208            .map(|col| {
209                let mut bv = Array1::<f64>::zeros(p);
210                self.mul_vec_into(factor.column(col), bv.view_mut());
211                bv
212            })
213            .collect();
214        for (col, bv) in cols.into_iter().enumerate() {
215            out.column_mut(col).assign(&bv);
216        }
217        out
218    }
219
220    /// Compute `trace(F^T B F)` for a `(p x k)` factor matrix `F`.
221    fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
222        let op_factor = self.mul_mat(factor);
223        factor
224            .iter()
225            .zip(op_factor.iter())
226            .map(|(&f, &bf)| f * bf)
227            .sum()
228    }
229
230    /// Optional stable identity for this operator's action `B`. When `Some`,
231    /// the default cached trace / projected-matrix paths memoize the `B · F`
232    /// product in the shared [`ProjectedFactorCache`] under a
233    /// `(design_id, factor)` key, so repeated projections of the same factor
234    /// against the same operator within one outer iteration build `B · F`
235    /// once. `None` (the default) disables that reuse: an operator with no
236    /// design factor stable across calls cannot key the cache without risking
237    /// a stale `B · F`, so it recomputes every time.
238    fn projection_design_id(&self) -> Option<usize> {
239        None
240    }
241
242    fn trace_projected_factor_cached(
243        &self,
244        factor: &Array2<f64>,
245        factor_cache: &ProjectedFactorCache,
246    ) -> f64 {
247        // The default implementation has no use for the caller-owned cache;
248        // verify the cache object carries a positive-size allocation before
249        // delegating to the exact path.
250        assert!(std::mem::size_of_val(factor_cache) > 0);
251        match self.projection_design_id() {
252            Some(design_id) => {
253                let key = ProjectedFactorKey::from_factor_view(design_id, factor.view());
254                let projected = factor_cache.get_or_insert_with(key, || self.mul_mat(factor));
255                factor
256                    .iter()
257                    .zip(projected.iter())
258                    .map(|(&f, &bf)| f * bf)
259                    .sum()
260            }
261            None => self.trace_projected_factor(factor),
262        }
263    }
264
265    /// Compute the exact projected matrix `F^T B F`.
266    fn projected_matrix(&self, factor: &Array2<f64>) -> Array2<f64> {
267        let op_factor = self.mul_mat(factor);
268        crate::linalg_helpers::fast_atb(factor, &op_factor)
269    }
270
271    /// Compute the exact projected matrix `F^T B F`, reusing caller-owned
272    /// projection caches when the operator has a shared row/design factor.
273    fn projected_matrix_cached(
274        &self,
275        factor: &Array2<f64>,
276        factor_cache: &ProjectedFactorCache,
277    ) -> Array2<f64> {
278        assert!(std::mem::size_of_val(factor_cache) > 0);
279        match self.projection_design_id() {
280            Some(design_id) => {
281                let key = ProjectedFactorKey::from_factor_view(design_id, factor.view());
282                let projected = factor_cache.get_or_insert_with(key, || self.mul_mat(factor));
283                crate::linalg_helpers::fast_atb(factor, projected.as_ref())
284            }
285            None => self.projected_matrix(factor),
286        }
287    }
288
289    /// Fill columns `[start, start + out.ncols())` of `B` into `out`.
290    fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
291        let cols = out.ncols();
292        let dim = out.nrows();
293        assert!(start + cols <= dim);
294        let mut basis = Array1::<f64>::zeros(dim);
295        for local_col in 0..cols {
296            let global_col = start + local_col;
297            basis[global_col] = 1.0;
298            self.mul_vec_into(basis.view(), out.column_mut(local_col));
299            basis[global_col] = 0.0;
300        }
301    }
302
303    /// Accumulate `scale * B · v` into caller-owned storage.
304    fn scaled_add_mul_vec(
305        &self,
306        v: ArrayView1<'_, f64>,
307        scale: f64,
308        mut out: ArrayViewMut1<'_, f64>,
309    ) {
310        if scale == 0.0 {
311            return;
312        }
313        let mut work = Array1::<f64>::zeros(out.len());
314        self.mul_vec_into(v, work.view_mut());
315        out.scaled_add(scale, &work);
316    }
317
318    /// Compute v^T · B · u (bilinear form).
319    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
320        let mut bv = Array1::<f64>::zeros(v.len());
321        self.mul_vec_into(v.view(), bv.view_mut());
322        u.dot(&bv)
323    }
324
325    /// Compute v^T · B · u without requiring owned vector inputs.
326    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
327        let mut bv = Array1::<f64>::zeros(v.len());
328        self.mul_vec_into(v, bv.view_mut());
329        u.dot(&bv)
330    }
331
332    /// Whether `bilinear_view` is implemented as a direct scalar contraction.
333    fn has_fast_bilinear_view(&self) -> bool {
334        false
335    }
336
337    /// Full dense materialization.
338    fn to_dense(&self) -> Array2<f64> {
339        let p = self.dim();
340        let mut out = Array2::<f64>::zeros((p, p));
341        let mut basis = Array1::<f64>::zeros(p);
342        for j in 0..p {
343            basis[j] = 1.0;
344            self.mul_vec_into(basis.view(), out.column_mut(j));
345            basis[j] = 0.0;
346        }
347        out
348    }
349
350    /// Whether this operator uses implicit (non-materialized) storage.
351    fn is_implicit(&self) -> bool;
352
353    /// If this operator is block-local, returns the block range and local matrix.
354    fn block_local_data(&self) -> Option<(&Array2<f64>, usize, usize)> {
355        None
356    }
357}
358
/// Cache key identifying one `(operator, factor)` projection.
///
/// Two keys compare equal only when every component matches: the operator's
/// design identity, the factor's data pointer, shape and strides, and both
/// value fingerprints.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ProjectedFactorKey {
    /// Identity of the operator whose action was applied.
    pub(crate) design_id: usize,
    /// Address of the factor's first element, stored as an integer.
    pub(crate) factor_ptr: usize,
    /// Number of rows of the factor.
    pub(crate) rows: usize,
    /// Number of columns of the factor.
    pub(crate) cols: usize,
    /// Stride of the factor along its first axis.
    pub(crate) row_stride: isize,
    /// Stride of the factor along its second axis.
    pub(crate) col_stride: isize,
    /// First fingerprint of the factor's values.
    pub(crate) value_hash: u64,
    /// Second, independently mixed fingerprint of the factor's values.
    pub(crate) value_hash2: u64,
}
370
371impl ProjectedFactorKey {
372    pub fn from_factor_view(design_id: usize, factor: ArrayView2<'_, f64>) -> Self {
373        let strides = factor.strides();
374        let (value_hash, value_hash2) = projected_factor_value_fingerprint(factor);
375        Self {
376            design_id,
377            factor_ptr: factor.as_ptr() as usize,
378            rows: factor.nrows(),
379            cols: factor.ncols(),
380            row_stride: strides[0],
381            col_stride: strides[1],
382            value_hash,
383            value_hash2,
384        }
385    }
386
387    /// Construct a synthetic, unique-by-`seed` key without going through
388    /// [`Self::from_factor_view`]. Used by cache tests that need to inject
389    /// fingerprints directly (and deterministically) rather than relying on
390    /// ndarray pointer aliasing, which the real constructor keys on.
391    pub fn synthetic(seed: u64) -> Self {
392        Self {
393            design_id: 1,
394            factor_ptr: seed as usize,
395            rows: 1,
396            cols: 1,
397            row_stride: 1,
398            col_stride: 1,
399            value_hash: seed,
400            value_hash2: seed.wrapping_mul(31),
401        }
402    }
403}
404
405pub(crate) fn projected_factor_value_fingerprint(factor: ArrayView2<'_, f64>) -> (u64, u64) {
406    let mut h1 = 0xcbf2_9ce4_8422_2325_u64;
407    let mut h2 = 0x9e37_79b1_85eb_ca87_u64;
408    for (idx, value) in factor.iter().enumerate() {
409        let bits = value.to_bits();
410        let mixed = bits.wrapping_add((idx as u64).wrapping_mul(0x517c_c1b7_2722_0a95));
411        h1 ^= mixed;
412        h1 = h1.wrapping_mul(0x0000_0100_0000_01b3);
413        h2 ^= bits.rotate_left((idx & 63) as u32);
414        h2 = h2.wrapping_mul(0x94d0_49bb_1331_11eb).rotate_left(27);
415    }
416    (h1, h2)
417}
418
419/// Memoizer for projected factor products keyed on a `(design, factor)` fingerprint.
420pub struct ProjectedFactorCache {
421    pub(crate) inner: Mutex<ProjectedFactorCacheInner>,
422}
423
424pub(crate) struct ProjectedFactorCacheInner {
425    pub(crate) entries: HashMap<ProjectedFactorKey, ProjectedFactorEntry>,
426    pub(crate) in_progress: HashMap<ProjectedFactorKey, Arc<ProjectedFactorInProgress>>,
427    pub(crate) next_seq: u64,
428    pub(crate) total_bytes: usize,
429    pub(crate) budget_bytes: usize,
430}
431
432pub(crate) struct ProjectedFactorInProgress {
433    pub(crate) state: Mutex<Option<ProjectedFactorInProgressState>>,
434    pub(crate) ready: Condvar,
435    pub(crate) waiter_count: std::sync::atomic::AtomicUsize,
436    pub(crate) subscriber_arrived: (Mutex<()>, Condvar),
437}
438
439pub(crate) enum ProjectedFactorInProgressState {
440    Ready(Arc<Array2<f64>>),
441    Failed,
442}
443
444pub(crate) struct ProjectedFactorEntry {
445    pub(crate) value: Arc<Array2<f64>>,
446    pub(crate) bytes: usize,
447    pub(crate) last_used: u64,
448}
449
450impl Default for ProjectedFactorCache {
451    fn default() -> Self {
452        Self::with_budget(Self::DEFAULT_BUDGET_BYTES)
453    }
454}
455
456impl ProjectedFactorCache {
457    pub const DEFAULT_BUDGET_BYTES: usize = 2 * 1024 * 1024 * 1024;
458
459    pub fn with_budget(budget_bytes: usize) -> Self {
460        Self {
461            inner: Mutex::new(ProjectedFactorCacheInner {
462                entries: HashMap::new(),
463                in_progress: HashMap::new(),
464                next_seq: 0,
465                total_bytes: 0,
466                budget_bytes,
467            }),
468        }
469    }
470
471    pub fn get_or_insert_with(
472        &self,
473        key: ProjectedFactorKey,
474        compute: impl FnOnce() -> Array2<f64>,
475    ) -> Arc<Array2<f64>> {
476        enum CacheLookup {
477            Hit(Arc<Array2<f64>>),
478            Wait(Arc<ProjectedFactorInProgress>),
479            Compute(Arc<ProjectedFactorInProgress>),
480        }
481
482        let lookup = {
483            let mut inner = self
484                .inner
485                .lock()
486                .expect("projected factor cache lock poisoned");
487            inner.next_seq += 1;
488            let now = inner.next_seq;
489            if let Some(entry) = inner.entries.get_mut(&key) {
490                entry.last_used = now;
491                CacheLookup::Hit(entry.value.clone())
492            } else if let Some(waiter) = inner.in_progress.get(&key) {
493                CacheLookup::Wait(waiter.clone())
494            } else {
495                let marker = Arc::new(ProjectedFactorInProgress {
496                    state: Mutex::new(None),
497                    ready: Condvar::new(),
498                    waiter_count: std::sync::atomic::AtomicUsize::new(0),
499                    subscriber_arrived: (Mutex::new(()), Condvar::new()),
500                });
501                inner.in_progress.insert(key, marker.clone());
502                CacheLookup::Compute(marker)
503            }
504        };
505
506        match lookup {
507            CacheLookup::Hit(value) => value,
508            CacheLookup::Wait(marker) => {
509                marker
510                    .waiter_count
511                    .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
512                let (lock, cv) = &marker.subscriber_arrived;
513                drop(
514                    lock.lock()
515                        .expect("subscriber-arrived notification lock poisoned"),
516                );
517                cv.notify_all();
518                let mut guard = marker
519                    .state
520                    .lock()
521                    .expect("projected factor in-progress lock poisoned");
522                let result = loop {
523                    match guard.as_ref() {
524                        Some(ProjectedFactorInProgressState::Ready(value)) => {
525                            break value.clone();
526                        }
527                        Some(ProjectedFactorInProgressState::Failed) => {
528                            marker
529                                .waiter_count
530                                .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
531                            reml_contract_panic("projected factor cache producer panicked")
532                        }
533                        None => {
534                            guard = marker
535                                .ready
536                                .wait(guard)
537                                .expect("projected factor in-progress wait poisoned");
538                        }
539                    }
540                };
541                marker
542                    .waiter_count
543                    .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
544                result
545            }
546            CacheLookup::Compute(marker) => {
547                let computed = match catch_unwind(AssertUnwindSafe(|| Arc::new(compute()))) {
548                    Ok(value) => value,
549                    Err(payload) => {
550                        let mut inner = self
551                            .inner
552                            .lock()
553                            .expect("projected factor cache lock poisoned");
554                        inner.in_progress.remove(&key);
555                        drop(inner);
556
557                        let mut guard = marker
558                            .state
559                            .lock()
560                            .expect("projected factor in-progress lock poisoned");
561                        *guard = Some(ProjectedFactorInProgressState::Failed);
562                        marker.ready.notify_all();
563                        resume_unwind(payload);
564                    }
565                };
566                let bytes = computed.len().saturating_mul(std::mem::size_of::<f64>());
567                let mut inner = self
568                    .inner
569                    .lock()
570                    .expect("projected factor cache lock poisoned");
571                inner.next_seq += 1;
572                let now = inner.next_seq;
573
574                if inner.budget_bytes > 0 && bytes <= inner.budget_bytes {
575                    while inner.total_bytes.saturating_add(bytes) > inner.budget_bytes
576                        && !inner.entries.is_empty()
577                    {
578                        let Some(oldest_key) = inner
579                            .entries
580                            .iter()
581                            .min_by_key(|(_, e)| e.last_used)
582                            .map(|(k, _)| *k)
583                        else {
584                            break;
585                        };
586                        if let Some(removed) = inner.entries.remove(&oldest_key) {
587                            inner.total_bytes = inner.total_bytes.saturating_sub(removed.bytes);
588                        }
589                    }
590                }
591
592                let value = if let Some(entry) = inner.entries.get_mut(&key) {
593                    entry.last_used = now;
594                    entry.value.clone()
595                } else {
596                    inner.entries.insert(
597                        key,
598                        ProjectedFactorEntry {
599                            value: computed.clone(),
600                            bytes,
601                            last_used: now,
602                        },
603                    );
604                    inner.total_bytes = inner.total_bytes.saturating_add(bytes);
605                    computed
606                };
607                inner.in_progress.remove(&key);
608                drop(inner);
609
610                let mut guard = marker
611                    .state
612                    .lock()
613                    .expect("projected factor in-progress lock poisoned");
614                *guard = Some(ProjectedFactorInProgressState::Ready(value.clone()));
615                marker.ready.notify_all();
616                value
617            }
618        }
619    }
620
621    pub fn len(&self) -> usize {
622        self.inner
623            .lock()
624            .map(|inner| inner.entries.len())
625            .unwrap_or(0)
626    }
627
628    pub fn total_bytes(&self) -> usize {
629        self.inner
630            .lock()
631            .map(|inner| inner.total_bytes)
632            .unwrap_or(0)
633    }
634
635    pub fn is_empty(&self) -> bool {
636        self.len() == 0
637    }
638
639    /// Test/diagnostic affordance: block until a consumer has subscribed to the
640    /// in-progress slot for `key` (i.e. is waiting on the producer), or until
641    /// `timeout` elapses. Returns `true` if a subscriber arrived, `false` if the
642    /// key has no in-progress slot or the wait timed out.
643    ///
644    /// This lives on the cache because it reaches into the per-key subscriber
645    /// condvar and waiter counter, which are private synchronization internals;
646    /// exposing it as a method keeps those fields encapsulated while still
647    /// letting downstream tests deterministically order producer/consumer
648    /// interleavings.
649    pub fn wait_for_subscriber(
650        &self,
651        key: ProjectedFactorKey,
652        timeout: std::time::Duration,
653    ) -> bool {
654        let marker = {
655            let inner = self
656                .inner
657                .lock()
658                .expect("projected factor cache lock poisoned");
659            let Some(m) = inner.in_progress.get(&key) else {
660                return false;
661            };
662            Arc::clone(m)
663        };
664        if marker
665            .waiter_count
666            .load(std::sync::atomic::Ordering::Acquire)
667            > 0
668        {
669            return true;
670        }
671        let (lock, cv) = &marker.subscriber_arrived;
672        let mut guard = lock
673            .lock()
674            .expect("subscriber-arrived notification lock poisoned");
675        let deadline = std::time::Instant::now() + timeout;
676        loop {
677            if marker
678                .waiter_count
679                .load(std::sync::atomic::Ordering::Acquire)
680                > 0
681            {
682                return true;
683            }
684            let now = std::time::Instant::now();
685            if now >= deadline {
686                return false;
687            }
688            let (next_guard, result) = cv
689                .wait_timeout(guard, deadline - now)
690                .expect("subscriber-arrived wait poisoned");
691            guard = next_guard;
692            if result.timed_out()
693                && marker
694                    .waiter_count
695                    .load(std::sync::atomic::Ordering::Acquire)
696                    == 0
697            {
698                return false;
699            }
700        }
701    }
702}
703
704#[derive(Clone)]
705pub struct DenseMatrixHyperOperator {
706    pub matrix: Array2<f64>,
707}
708
709impl HyperOperator for DenseMatrixHyperOperator {
710    fn dim(&self) -> usize {
711        self.matrix.nrows()
712    }
713
714    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
715        self.matrix.dot(v)
716    }
717
718    fn as_any(&self) -> &(dyn Any + 'static) {
719        self
720    }
721
722    fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
723        self.matrix.dot(&v)
724    }
725
726    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
727        assert_eq!(self.matrix.ncols(), v.len());
728        assert_eq!(self.matrix.nrows(), out.len());
729        for (row, out_value) in self.matrix.rows().into_iter().zip(out.iter_mut()) {
730            *out_value = row.dot(&v);
731        }
732    }
733
734    fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
735        let end = start + out.ncols();
736        assert!(end <= self.matrix.ncols());
737        out.assign(&self.matrix.slice(ndarray::s![.., start..end]));
738    }
739
740    fn scaled_add_mul_vec(
741        &self,
742        v: ArrayView1<'_, f64>,
743        scale: f64,
744        mut out: ArrayViewMut1<'_, f64>,
745    ) {
746        assert_eq!(self.matrix.ncols(), v.len());
747        assert_eq!(self.matrix.nrows(), out.len());
748        if scale == 0.0 {
749            return;
750        }
751        for (row, out_value) in self.matrix.rows().into_iter().zip(out.iter_mut()) {
752            *out_value += scale * row.dot(&v);
753        }
754    }
755
756    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
757        dense_bilinear(&self.matrix, v.view(), u.view())
758    }
759
760    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
761        dense_bilinear(&self.matrix, v, u)
762    }
763
764    fn to_dense(&self) -> Array2<f64> {
765        self.matrix.clone()
766    }
767
768    fn is_implicit(&self) -> bool {
769        false
770    }
771}
772
773#[derive(Clone)]
774pub struct BlockLocalDrift {
775    pub local: Array2<f64>,
776    pub start: usize,
777    pub end: usize,
778    pub total_dim: usize,
779}
780
781impl HyperOperator for BlockLocalDrift {
782    fn dim(&self) -> usize {
783        self.total_dim
784    }
785
786    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
787        assert_eq!(v.len(), self.total_dim);
788        let mut out = Array1::zeros(self.total_dim);
789        self.mul_vec_into(v.view(), out.view_mut());
790        out
791    }
792
793    fn as_any(&self) -> &(dyn Any + 'static) {
794        self
795    }
796
797    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
798        assert_eq!(v.len(), self.total_dim);
799        assert_eq!(out.len(), self.total_dim);
800        out.fill(0.0);
801        let v_block = v.slice(ndarray::s![self.start..self.end]);
802        let mut out_block = out.slice_mut(ndarray::s![self.start..self.end]);
803        dense_matvec_into(&self.local, v_block, out_block.view_mut());
804    }
805
806    fn scaled_add_mul_vec(
807        &self,
808        v: ArrayView1<'_, f64>,
809        scale: f64,
810        mut out: ArrayViewMut1<'_, f64>,
811    ) {
812        assert_eq!(v.len(), self.total_dim);
813        assert_eq!(out.len(), self.total_dim);
814        if scale == 0.0 {
815            return;
816        }
817        let v_block = v.slice(ndarray::s![self.start..self.end]);
818        let out_block = out.slice_mut(ndarray::s![self.start..self.end]);
819        dense_matvec_scaled_add_into(&self.local, v_block, scale, out_block);
820    }
821
822    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
823        self.bilinear_view(v.view(), u.view())
824    }
825
826    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
827        assert_eq!(v.len(), self.total_dim);
828        assert_eq!(u.len(), self.total_dim);
829        let v_block = v.slice(ndarray::s![self.start..self.end]);
830        let u_block = u.slice(ndarray::s![self.start..self.end]);
831        dense_bilinear(&self.local, v_block, u_block)
832    }
833
834    fn to_dense(&self) -> Array2<f64> {
835        let p = self.total_dim;
836        let mut out = Array2::zeros((p, p));
837        out.slice_mut(ndarray::s![self.start..self.end, self.start..self.end])
838            .assign(&self.local);
839        out
840    }
841
842    fn is_implicit(&self) -> bool {
843        false
844    }
845
846    fn block_local_data(&self) -> Option<(&Array2<f64>, usize, usize)> {
847        Some((&self.local, self.start, self.end))
848    }
849}
850
851#[derive(Clone)]
852pub struct HyperCoordDrift {
853    pub dense: Option<Array2<f64>>,
854    pub block_local: Option<BlockLocalDrift>,
855    pub operator: Option<Arc<dyn HyperOperator>>,
856}
857
858impl HyperCoordDrift {
859    pub fn none() -> Self {
860        Self {
861            dense: None,
862            block_local: None,
863            operator: None,
864        }
865    }
866
867    pub fn from_dense(dense: Array2<f64>) -> Self {
868        Self {
869            dense: Some(dense),
870            block_local: None,
871            operator: None,
872        }
873    }
874
875    pub fn from_operator(operator: Arc<dyn HyperOperator>) -> Self {
876        Self {
877            dense: None,
878            block_local: None,
879            operator: Some(operator),
880        }
881    }
882
883    pub fn from_parts(
884        dense: Option<Array2<f64>>,
885        operator: Option<Arc<dyn HyperOperator>>,
886    ) -> Self {
887        let dense = dense.filter(|mat| !(operator.is_some() && mat.is_empty()));
888        Self {
889            dense,
890            block_local: None,
891            operator,
892        }
893    }
894
895    pub fn from_block_local_and_operator(
896        local: Array2<f64>,
897        start: usize,
898        end: usize,
899        total_dim: usize,
900        operator: Option<Arc<dyn HyperOperator>>,
901    ) -> Self {
902        Self {
903            dense: None,
904            block_local: Some(BlockLocalDrift {
905                local,
906                start,
907                end,
908                total_dim,
909            }),
910            operator,
911        }
912    }
913
914    pub fn has_operator(&self) -> bool {
915        self.operator.is_some()
916    }
917
918    pub fn uses_operator_fast_path(&self) -> bool {
919        self.operator.is_some() || self.block_local.is_some()
920    }
921
922    pub fn operator_ref(&self) -> Option<&dyn HyperOperator> {
923        self.operator.as_ref().map(Arc::as_ref)
924    }
925
926    pub fn materialize(&self) -> Array2<f64> {
927        let p = self.infer_dim();
928        if p == 0 {
929            return Array2::zeros((0, 0));
930        }
931        let mut out = self.dense.clone().unwrap_or_else(|| Array2::zeros((p, p)));
932        if let Some(bl) = &self.block_local {
933            out.slice_mut(ndarray::s![bl.start..bl.end, bl.start..bl.end])
934                .scaled_add(1.0, &bl.local);
935        }
936        if let Some(op) = &self.operator {
937            out += &op.to_dense();
938        }
939        out
940    }
941
942    pub fn apply(&self, v: &Array1<f64>) -> Array1<f64> {
943        let mut out = Array1::zeros(v.len());
944        self.scaled_add_apply(v.view(), 1.0, &mut out);
945        out
946    }
947
948    pub fn scaled_add_apply(&self, v: ArrayView1<'_, f64>, scale: f64, out: &mut Array1<f64>) {
949        assert_eq!(v.len(), out.len());
950        if scale == 0.0 {
951            return;
952        }
953        if let Some(dense) = &self.dense {
954            dense_matvec_scaled_add_into(dense, v, scale, out.view_mut());
955        }
956        if let Some(bl) = &self.block_local {
957            let v_block = v.slice(ndarray::s![bl.start..bl.end]);
958            let out_block = out.slice_mut(ndarray::s![bl.start..bl.end]);
959            dense_matvec_scaled_add_into(&bl.local, v_block, scale, out_block);
960        }
961        if let Some(op) = &self.operator {
962            op.scaled_add_mul_vec(v, scale, out.view_mut());
963        }
964    }
965
966    pub(crate) fn infer_dim(&self) -> usize {
967        if let Some(d) = &self.dense {
968            return d.nrows();
969        }
970        if let Some(op) = &self.operator {
971            return op.dim();
972        }
973        if let Some(bl) = &self.block_local {
974            return bl.total_dim;
975        }
976        0
977    }
978}
979
980#[derive(Clone)]
981pub struct HyperCoord {
982    pub a: f64,
983    pub g: Array1<f64>,
984    pub drift: HyperCoordDrift,
985    pub ld_s: f64,
986    pub b_depends_on_beta: bool,
987    pub is_penalty_like: bool,
988    pub firth_g: Option<Array1<f64>>,
989    pub tk_eta_fixed: Option<Array1<f64>>,
990    pub tk_x_fixed: Option<Array2<f64>>,
991}
992
993pub struct HyperCoordPair {
994    pub a: f64,
995    pub g: Array1<f64>,
996    pub b_mat: Array2<f64>,
997    pub b_operator: Option<Box<dyn HyperOperator>>,
998    pub ld_s: f64,
999}
1000
1001/// Shared-ownership callback computing a second-order fixed-β
1002/// [`HyperCoordPair`] for a coordinate pair `(i, j)`.
1003///
1004/// `Arc` (not `Box`) so the same callback can be cloned into a derived
1005/// `InnerSolution` — notably the tangent-projected solution built under active
1006/// inequality constraints, which must carry the very same pair callbacks
1007/// through to `ValueGradientHessian` outer-Hessian assembly. The pair objects
1008/// are p-space; every consumer contracts them through the (possibly
1009/// tangent-wrapped) Hessian operator, which applies the `ZᵀMZ` / `Z H_T⁻¹ Zᵀ`
1010/// projection internally, so a clone-through is mathematically exact.
1011pub type HyperCoordPairFn = Arc<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>;
1012
1013impl HyperCoordPair {
1014    pub fn zero() -> Self {
1015        Self {
1016            a: 0.0,
1017            g: Array1::zeros(0),
1018            b_mat: Array2::zeros((0, 0)),
1019            b_operator: None,
1020            ld_s: 0.0,
1021        }
1022    }
1023}
1024
1025#[derive(Clone)]
1026pub enum DriftDerivResult {
1027    Dense(Array2<f64>),
1028    Operator(Arc<dyn HyperOperator>),
1029}
1030
1031impl std::fmt::Debug for DriftDerivResult {
1032    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1033        match self {
1034            Self::Dense(matrix) => f
1035                .debug_tuple("Dense")
1036                .field(&format_args!("{}x{}", matrix.nrows(), matrix.ncols()))
1037                .finish(),
1038            Self::Operator(_) => f
1039                .debug_tuple("Operator")
1040                .field(&"<hyper-operator>")
1041                .finish(),
1042        }
1043    }
1044}
1045
1046impl DriftDerivResult {
1047    pub fn into_operator(self) -> Arc<dyn HyperOperator> {
1048        match self {
1049            Self::Dense(matrix) => Arc::new(DenseMatrixHyperOperator { matrix }),
1050            Self::Operator(operator) => operator,
1051        }
1052    }
1053
1054    pub fn apply(&self, v: &Array1<f64>) -> Array1<f64> {
1055        match self {
1056            Self::Dense(matrix) => matrix.dot(v),
1057            Self::Operator(operator) => operator.mul_vec(v),
1058        }
1059    }
1060}
1061
1062pub type FixedDriftDerivFn =
1063    Box<dyn Fn(usize, &Array1<f64>) -> Option<DriftDerivResult> + Send + Sync>;
1064
1065/// Shared-ownership form of [`FixedDriftDerivFn`] used for `InnerSolution`
1066/// storage, so the same `M_i[u] = D_β B_i[u]` callback can be cloned into a
1067/// derived (tangent-projected) solution. Construction sites still hand back a
1068/// `Box` ([`FixedDriftDerivFn`]); storage re-tags it via `Arc::from` (free).
1069/// The drift `M` is a p-space matrix that every consumer contracts through the
1070/// (tangent-wrapped) Hessian operator's `trace_logdet_*`, so the clone-through
1071/// is exact under projection.
1072pub type SharedFixedDriftDerivFn =
1073    Arc<dyn Fn(usize, &Array1<f64>) -> Option<DriftDerivResult> + Send + Sync>;
1074
1075pub struct ContractedPsiSecondOrder {
1076    pub objective: Array1<f64>,
1077    pub score: Array2<f64>,
1078    pub hessian: Vec<DriftDerivResult>,
1079    pub ld_s: Array1<f64>,
1080}
1081
1082pub type ContractedPsiSecondOrderFn =
1083    Arc<dyn Fn(&[f64]) -> Result<Option<ContractedPsiSecondOrder>, String> + Send + Sync>;