Skip to main content

gam_problem/
lib.rs

1//! Shared REML/LAML contract types.
2//!
3//! These are the family-facing interfaces for REML outer assembly. They live
4//! below `solver` so families can construct operator-backed derivative payloads
5//! without importing `solver::estimate::reml::reml_outer_engine`.
6
7use std::any::Any;
8use std::collections::HashMap;
9use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
10use std::sync::{Arc, Condvar, Mutex};
11
12use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2};
13use rayon::iter::{IntoParallelIterator, ParallelIterator};
14
15#[macro_use]
16mod macros;
17
18mod linalg_helpers;
19mod linear_constraints;
20mod pseudo_logdet;
21mod seeding;
22pub mod solver_contract;
23
24mod gpu {
25    pub(crate) mod linalg_dispatch {
26        use ndarray::{Array2, ArrayView2};
27
28        pub(crate) fn try_fast_atb(
29            a: ArrayView2<'_, f64>,
30            b: ArrayView2<'_, f64>,
31        ) -> Option<Array2<f64>> {
32            let (n_a, p) = a.dim();
33            let (n_b, q) = b.dim();
34            assert_eq!(n_a, n_b, "A and B must have same number of rows");
35            if !crate::linalg_helpers::should_use_faer_matmul(p, q, n_a) {
36                return None;
37            }
38            Some(crate::linalg_helpers::fast_atb_with_parallelism(
39                &a,
40                &b,
41                crate::linalg_helpers::matmul_parallelism(p, q, n_a),
42            ))
43        }
44    }
45}
46
47pub use gam_linalg::faer_ndarray::{in_nested_parallel_region, with_nested_parallel};
48use linalg_helpers::{dense_bilinear, dense_matvec_into, dense_matvec_scaled_add_into};
49pub use linear_constraints::LinearInequalityConstraints;
50pub use pseudo_logdet::PseudoLogdetMode;
51pub use seeding::{SeedConfig, SeedRiskProfile, clamp_seed_rho_to_bounds, normalize_seed_bounds};
52pub use solver_contract::{
53    DeclaredHessianForm, Derivative, EfsEval, HessianResult, OuterEval,
54    OuterHessianMaterialization, OuterHessianOperator, OuterStrategyError,
55};
56
/// Abort the current operation with a contract-violation panic.
///
/// The panic payload is the owned `String` form of `message`, so callers that
/// catch the unwind can downcast the payload to `String`.
#[cold]
fn reml_contract_panic(message: impl Into<String>) -> ! {
    let payload: String = message.into();
    std::panic::panic_any(payload)
}
61
/// Selects how much work the unified evaluator performs on a call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EvalMode {
    /// Cost only (e.g., for line search).
    ValueOnly,
    /// Cost and gradient (the common case).
    ValueAndGradient,
    /// Cost, gradient, and outer Hessian.
    ValueGradientHessian,
}
72
73/// Trait for operators that can compute a hyper-derivative matrix-vector product
74/// without necessarily materializing the full matrix.
75struct NonDowncastableHyperOperator;
76
77static NON_DOWNCASTABLE_HYPER_OPERATOR: NonDowncastableHyperOperator = NonDowncastableHyperOperator;
78
79pub trait HyperOperator: Send + Sync {
80    /// Operator dimension `p` such that `B · v` consumes a `p`-vector and
81    /// produces a `p`-vector.
82    fn dim(&self) -> usize;
83
84    /// Compute B · v (matrix-vector product). v and result are p-vectors.
85    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64>;
86
87    /// Expose the concrete type for solver-local downcast helpers when the
88    /// implementor has a `'static` concrete type. Borrowing adapters may keep
89    /// the default, which simply cannot downcast.
90    fn as_any(&self) -> &(dyn Any + 'static) {
91        &NON_DOWNCASTABLE_HYPER_OPERATOR
92    }
93
94    /// Compute B · v from a vector view.
95    fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
96        self.mul_vec(&v.to_owned())
97    }
98
99    /// Compute B · v into caller-owned storage.
100    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
101        out.assign(&self.mul_vec_view(v));
102    }
103
104    /// Compute B · F where F is (p × k). Default dispatches per-column in
105    /// parallel unless already inside a rayon worker.
106    fn mul_mat(&self, factor: &Array2<f64>) -> Array2<f64> {
107        let p = factor.nrows();
108        let k = factor.ncols();
109        let mut out = Array2::<f64>::zeros((p, k));
110        if rayon::current_thread_index().is_some() {
111            for col in 0..k {
112                let bv = out.column_mut(col);
113                self.mul_vec_into(factor.column(col), bv);
114            }
115            return out;
116        }
117        let cols: Vec<Array1<f64>> = (0..k)
118            .into_par_iter()
119            .map(|col| {
120                let mut bv = Array1::<f64>::zeros(p);
121                self.mul_vec_into(factor.column(col), bv.view_mut());
122                bv
123            })
124            .collect();
125        for (col, bv) in cols.into_iter().enumerate() {
126            out.column_mut(col).assign(&bv);
127        }
128        out
129    }
130
131    /// Compute `trace(F^T B F)` for a `(p x k)` factor matrix `F`.
132    fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
133        let op_factor = self.mul_mat(factor);
134        factor
135            .iter()
136            .zip(op_factor.iter())
137            .map(|(&f, &bf)| f * bf)
138            .sum()
139    }
140
141    /// Optional stable identity for this operator's action `B`. When `Some`,
142    /// the default cached trace / projected-matrix paths memoize the `B · F`
143    /// product in the shared [`ProjectedFactorCache`] under a
144    /// `(design_id, factor)` key, so repeated projections of the same factor
145    /// against the same operator within one outer iteration build `B · F`
146    /// once. `None` (the default) disables that reuse: an operator with no
147    /// design factor stable across calls cannot key the cache without risking
148    /// a stale `B · F`, so it recomputes every time.
149    fn projection_design_id(&self) -> Option<usize> {
150        None
151    }
152
153    fn trace_projected_factor_cached(
154        &self,
155        factor: &Array2<f64>,
156        factor_cache: &ProjectedFactorCache,
157    ) -> f64 {
158        // The default implementation has no use for the caller-owned cache;
159        // verify the cache object carries a positive-size allocation before
160        // delegating to the exact path.
161        assert!(std::mem::size_of_val(factor_cache) > 0);
162        match self.projection_design_id() {
163            Some(design_id) => {
164                let key = ProjectedFactorKey::from_factor_view(design_id, factor.view());
165                let projected = factor_cache.get_or_insert_with(key, || self.mul_mat(factor));
166                factor
167                    .iter()
168                    .zip(projected.iter())
169                    .map(|(&f, &bf)| f * bf)
170                    .sum()
171            }
172            None => self.trace_projected_factor(factor),
173        }
174    }
175
176    /// Compute the exact projected matrix `F^T B F`.
177    fn projected_matrix(&self, factor: &Array2<f64>) -> Array2<f64> {
178        let op_factor = self.mul_mat(factor);
179        crate::linalg_helpers::fast_atb(factor, &op_factor)
180    }
181
182    /// Compute the exact projected matrix `F^T B F`, reusing caller-owned
183    /// projection caches when the operator has a shared row/design factor.
184    fn projected_matrix_cached(
185        &self,
186        factor: &Array2<f64>,
187        factor_cache: &ProjectedFactorCache,
188    ) -> Array2<f64> {
189        assert!(std::mem::size_of_val(factor_cache) > 0);
190        match self.projection_design_id() {
191            Some(design_id) => {
192                let key = ProjectedFactorKey::from_factor_view(design_id, factor.view());
193                let projected = factor_cache.get_or_insert_with(key, || self.mul_mat(factor));
194                crate::linalg_helpers::fast_atb(factor, projected.as_ref())
195            }
196            None => self.projected_matrix(factor),
197        }
198    }
199
200    /// Fill columns `[start, start + out.ncols())` of `B` into `out`.
201    fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
202        let cols = out.ncols();
203        let dim = out.nrows();
204        assert!(start + cols <= dim);
205        let mut basis = Array1::<f64>::zeros(dim);
206        for local_col in 0..cols {
207            let global_col = start + local_col;
208            basis[global_col] = 1.0;
209            self.mul_vec_into(basis.view(), out.column_mut(local_col));
210            basis[global_col] = 0.0;
211        }
212    }
213
214    /// Accumulate `scale * B · v` into caller-owned storage.
215    fn scaled_add_mul_vec(
216        &self,
217        v: ArrayView1<'_, f64>,
218        scale: f64,
219        mut out: ArrayViewMut1<'_, f64>,
220    ) {
221        if scale == 0.0 {
222            return;
223        }
224        let mut work = Array1::<f64>::zeros(out.len());
225        self.mul_vec_into(v, work.view_mut());
226        out.scaled_add(scale, &work);
227    }
228
229    /// Compute v^T · B · u (bilinear form).
230    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
231        let mut bv = Array1::<f64>::zeros(v.len());
232        self.mul_vec_into(v.view(), bv.view_mut());
233        u.dot(&bv)
234    }
235
236    /// Compute v^T · B · u without requiring owned vector inputs.
237    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
238        let mut bv = Array1::<f64>::zeros(v.len());
239        self.mul_vec_into(v, bv.view_mut());
240        u.dot(&bv)
241    }
242
243    /// Whether `bilinear_view` is implemented as a direct scalar contraction.
244    fn has_fast_bilinear_view(&self) -> bool {
245        false
246    }
247
248    /// Full dense materialization.
249    fn to_dense(&self) -> Array2<f64> {
250        let p = self.dim();
251        let mut out = Array2::<f64>::zeros((p, p));
252        let mut basis = Array1::<f64>::zeros(p);
253        for j in 0..p {
254            basis[j] = 1.0;
255            self.mul_vec_into(basis.view(), out.column_mut(j));
256            basis[j] = 0.0;
257        }
258        out
259    }
260
261    /// Whether this operator uses implicit (non-materialized) storage.
262    fn is_implicit(&self) -> bool;
263
264    /// If this operator is block-local, returns the block range and local matrix.
265    fn block_local_data(&self) -> Option<(&Array2<f64>, usize, usize)> {
266        None
267    }
268}
269
/// Cache key identifying one `(design, factor)` pair in
/// [`ProjectedFactorCache`].
///
/// Two keys compare equal only when every field matches, so the key combines
/// the factor's storage identity (address, shape, strides) with a two-word
/// fingerprint of its values.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ProjectedFactorKey {
    /// Caller-supplied identity of the operator/design the factor is projected against.
    pub(crate) design_id: usize,
    /// Address returned by the factor view's `as_ptr()`, stored as an integer.
    pub(crate) factor_ptr: usize,
    /// Number of rows of the factor.
    pub(crate) rows: usize,
    /// Number of columns of the factor.
    pub(crate) cols: usize,
    /// Stride of the factor view along axis 0.
    pub(crate) row_stride: isize,
    /// Stride of the factor view along axis 1.
    pub(crate) col_stride: isize,
    /// First word of the value fingerprint.
    pub(crate) value_hash: u64,
    /// Second word of the value fingerprint.
    pub(crate) value_hash2: u64,
}
281
282impl ProjectedFactorKey {
283    pub fn from_factor_view(design_id: usize, factor: ArrayView2<'_, f64>) -> Self {
284        let strides = factor.strides();
285        let (value_hash, value_hash2) = projected_factor_value_fingerprint(factor);
286        Self {
287            design_id,
288            factor_ptr: factor.as_ptr() as usize,
289            rows: factor.nrows(),
290            cols: factor.ncols(),
291            row_stride: strides[0],
292            col_stride: strides[1],
293            value_hash,
294            value_hash2,
295        }
296    }
297
298    /// Construct a synthetic, unique-by-`seed` key without going through
299    /// [`Self::from_factor_view`]. Used by cache tests that need to inject
300    /// fingerprints directly (and deterministically) rather than relying on
301    /// ndarray pointer aliasing, which the real constructor keys on.
302    pub fn synthetic(seed: u64) -> Self {
303        Self {
304            design_id: 1,
305            factor_ptr: seed as usize,
306            rows: 1,
307            cols: 1,
308            row_stride: 1,
309            col_stride: 1,
310            value_hash: seed,
311            value_hash2: seed.wrapping_mul(31),
312        }
313    }
314}
315
316pub(crate) fn projected_factor_value_fingerprint(factor: ArrayView2<'_, f64>) -> (u64, u64) {
317    let mut h1 = 0xcbf2_9ce4_8422_2325_u64;
318    let mut h2 = 0x9e37_79b1_85eb_ca87_u64;
319    for (idx, value) in factor.iter().enumerate() {
320        let bits = value.to_bits();
321        let mixed = bits.wrapping_add((idx as u64).wrapping_mul(0x517c_c1b7_2722_0a95));
322        h1 ^= mixed;
323        h1 = h1.wrapping_mul(0x0000_0100_0000_01b3);
324        h2 ^= bits.rotate_left((idx & 63) as u32);
325        h2 = h2.wrapping_mul(0x94d0_49bb_1331_11eb).rotate_left(27);
326    }
327    (h1, h2)
328}
329
/// Memoizer for projected factor products keyed on a `(design, factor)` fingerprint.
///
/// All mutable state lives behind a single mutex, so the cache can be shared
/// by reference across threads.
pub struct ProjectedFactorCache {
    /// Cached entries, in-flight markers and byte accounting.
    pub(crate) inner: Mutex<ProjectedFactorCacheInner>,
}
334
/// Mutex-protected state of a [`ProjectedFactorCache`].
pub(crate) struct ProjectedFactorCacheInner {
    /// Completed products, by key.
    pub(crate) entries: HashMap<ProjectedFactorKey, ProjectedFactorEntry>,
    /// Markers for keys whose product is currently being computed by a producer.
    pub(crate) in_progress: HashMap<ProjectedFactorKey, Arc<ProjectedFactorInProgress>>,
    /// Counter bumped on cache accesses; its value stamps `last_used` on entries.
    pub(crate) next_seq: u64,
    /// Sum of the `bytes` of all entries currently stored.
    pub(crate) total_bytes: usize,
    /// Byte budget consulted when evicting entries to make room for a new one.
    pub(crate) budget_bytes: usize,
}
342
/// Rendezvous point between the thread computing a product and the threads
/// waiting for that same key.
pub(crate) struct ProjectedFactorInProgress {
    /// `None` while the producer is still computing; set once to the outcome.
    pub(crate) state: Mutex<Option<ProjectedFactorInProgressState>>,
    /// Signalled by the producer after `state` has been set.
    pub(crate) ready: Condvar,
    /// Number of consumers currently subscribed to this marker.
    pub(crate) waiter_count: std::sync::atomic::AtomicUsize,
    /// Lock/condvar pair notified whenever a consumer subscribes.
    pub(crate) subscriber_arrived: (Mutex<()>, Condvar),
}
349
/// Final outcome published by the producer of an in-progress cache slot.
pub(crate) enum ProjectedFactorInProgressState {
    /// The product was computed; waiters receive a clone of this handle.
    Ready(Arc<Array2<f64>>),
    /// The producer's compute closure panicked.
    Failed,
}
354
/// One cached product together with its bookkeeping.
pub(crate) struct ProjectedFactorEntry {
    /// Shared handle to the cached matrix.
    pub(crate) value: Arc<Array2<f64>>,
    /// Size charged against the cache's byte total for this entry.
    pub(crate) bytes: usize,
    /// Sequence number of the most recent access; the smallest is evicted first.
    pub(crate) last_used: u64,
}
360
361impl Default for ProjectedFactorCache {
362    fn default() -> Self {
363        Self::with_budget(Self::DEFAULT_BUDGET_BYTES)
364    }
365}
366
367impl ProjectedFactorCache {
368    pub const DEFAULT_BUDGET_BYTES: usize = 2 * 1024 * 1024 * 1024;
369
370    pub fn with_budget(budget_bytes: usize) -> Self {
371        Self {
372            inner: Mutex::new(ProjectedFactorCacheInner {
373                entries: HashMap::new(),
374                in_progress: HashMap::new(),
375                next_seq: 0,
376                total_bytes: 0,
377                budget_bytes,
378            }),
379        }
380    }
381
382    pub fn get_or_insert_with(
383        &self,
384        key: ProjectedFactorKey,
385        compute: impl FnOnce() -> Array2<f64>,
386    ) -> Arc<Array2<f64>> {
387        enum CacheLookup {
388            Hit(Arc<Array2<f64>>),
389            Wait(Arc<ProjectedFactorInProgress>),
390            Compute(Arc<ProjectedFactorInProgress>),
391        }
392
393        let lookup = {
394            let mut inner = self
395                .inner
396                .lock()
397                .expect("projected factor cache lock poisoned");
398            inner.next_seq += 1;
399            let now = inner.next_seq;
400            if let Some(entry) = inner.entries.get_mut(&key) {
401                entry.last_used = now;
402                CacheLookup::Hit(entry.value.clone())
403            } else if let Some(waiter) = inner.in_progress.get(&key) {
404                CacheLookup::Wait(waiter.clone())
405            } else {
406                let marker = Arc::new(ProjectedFactorInProgress {
407                    state: Mutex::new(None),
408                    ready: Condvar::new(),
409                    waiter_count: std::sync::atomic::AtomicUsize::new(0),
410                    subscriber_arrived: (Mutex::new(()), Condvar::new()),
411                });
412                inner.in_progress.insert(key, marker.clone());
413                CacheLookup::Compute(marker)
414            }
415        };
416
417        match lookup {
418            CacheLookup::Hit(value) => value,
419            CacheLookup::Wait(marker) => {
420                marker
421                    .waiter_count
422                    .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
423                let (lock, cv) = &marker.subscriber_arrived;
424                drop(
425                    lock.lock()
426                        .expect("subscriber-arrived notification lock poisoned"),
427                );
428                cv.notify_all();
429                let mut guard = marker
430                    .state
431                    .lock()
432                    .expect("projected factor in-progress lock poisoned");
433                let result = loop {
434                    match guard.as_ref() {
435                        Some(ProjectedFactorInProgressState::Ready(value)) => {
436                            break value.clone();
437                        }
438                        Some(ProjectedFactorInProgressState::Failed) => {
439                            marker
440                                .waiter_count
441                                .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
442                            reml_contract_panic("projected factor cache producer panicked")
443                        }
444                        None => {
445                            guard = marker
446                                .ready
447                                .wait(guard)
448                                .expect("projected factor in-progress wait poisoned");
449                        }
450                    }
451                };
452                marker
453                    .waiter_count
454                    .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
455                result
456            }
457            CacheLookup::Compute(marker) => {
458                let computed = match catch_unwind(AssertUnwindSafe(|| Arc::new(compute()))) {
459                    Ok(value) => value,
460                    Err(payload) => {
461                        let mut inner = self
462                            .inner
463                            .lock()
464                            .expect("projected factor cache lock poisoned");
465                        inner.in_progress.remove(&key);
466                        drop(inner);
467
468                        let mut guard = marker
469                            .state
470                            .lock()
471                            .expect("projected factor in-progress lock poisoned");
472                        *guard = Some(ProjectedFactorInProgressState::Failed);
473                        marker.ready.notify_all();
474                        resume_unwind(payload);
475                    }
476                };
477                let bytes = computed.len().saturating_mul(std::mem::size_of::<f64>());
478                let mut inner = self
479                    .inner
480                    .lock()
481                    .expect("projected factor cache lock poisoned");
482                inner.next_seq += 1;
483                let now = inner.next_seq;
484
485                if inner.budget_bytes > 0 && bytes <= inner.budget_bytes {
486                    while inner.total_bytes.saturating_add(bytes) > inner.budget_bytes
487                        && !inner.entries.is_empty()
488                    {
489                        let Some(oldest_key) = inner
490                            .entries
491                            .iter()
492                            .min_by_key(|(_, e)| e.last_used)
493                            .map(|(k, _)| *k)
494                        else {
495                            break;
496                        };
497                        if let Some(removed) = inner.entries.remove(&oldest_key) {
498                            inner.total_bytes = inner.total_bytes.saturating_sub(removed.bytes);
499                        }
500                    }
501                }
502
503                let value = if let Some(entry) = inner.entries.get_mut(&key) {
504                    entry.last_used = now;
505                    entry.value.clone()
506                } else {
507                    inner.entries.insert(
508                        key,
509                        ProjectedFactorEntry {
510                            value: computed.clone(),
511                            bytes,
512                            last_used: now,
513                        },
514                    );
515                    inner.total_bytes = inner.total_bytes.saturating_add(bytes);
516                    computed
517                };
518                inner.in_progress.remove(&key);
519                drop(inner);
520
521                let mut guard = marker
522                    .state
523                    .lock()
524                    .expect("projected factor in-progress lock poisoned");
525                *guard = Some(ProjectedFactorInProgressState::Ready(value.clone()));
526                marker.ready.notify_all();
527                value
528            }
529        }
530    }
531
532    pub fn len(&self) -> usize {
533        self.inner
534            .lock()
535            .map(|inner| inner.entries.len())
536            .unwrap_or(0)
537    }
538
539    pub fn total_bytes(&self) -> usize {
540        self.inner
541            .lock()
542            .map(|inner| inner.total_bytes)
543            .unwrap_or(0)
544    }
545
546    pub fn is_empty(&self) -> bool {
547        self.len() == 0
548    }
549
550    /// Test/diagnostic affordance: block until a consumer has subscribed to the
551    /// in-progress slot for `key` (i.e. is waiting on the producer), or until
552    /// `timeout` elapses. Returns `true` if a subscriber arrived, `false` if the
553    /// key has no in-progress slot or the wait timed out.
554    ///
555    /// This lives on the cache because it reaches into the per-key subscriber
556    /// condvar and waiter counter, which are private synchronization internals;
557    /// exposing it as a method keeps those fields encapsulated while still
558    /// letting downstream tests deterministically order producer/consumer
559    /// interleavings.
560    pub fn wait_for_subscriber(
561        &self,
562        key: ProjectedFactorKey,
563        timeout: std::time::Duration,
564    ) -> bool {
565        let marker = {
566            let inner = self
567                .inner
568                .lock()
569                .expect("projected factor cache lock poisoned");
570            let Some(m) = inner.in_progress.get(&key) else {
571                return false;
572            };
573            Arc::clone(m)
574        };
575        if marker
576            .waiter_count
577            .load(std::sync::atomic::Ordering::Acquire)
578            > 0
579        {
580            return true;
581        }
582        let (lock, cv) = &marker.subscriber_arrived;
583        let mut guard = lock
584            .lock()
585            .expect("subscriber-arrived notification lock poisoned");
586        let deadline = std::time::Instant::now() + timeout;
587        loop {
588            if marker
589                .waiter_count
590                .load(std::sync::atomic::Ordering::Acquire)
591                > 0
592            {
593                return true;
594            }
595            let now = std::time::Instant::now();
596            if now >= deadline {
597                return false;
598            }
599            let (next_guard, result) = cv
600                .wait_timeout(guard, deadline - now)
601                .expect("subscriber-arrived wait poisoned");
602            guard = next_guard;
603            if result.timed_out()
604                && marker
605                    .waiter_count
606                    .load(std::sync::atomic::Ordering::Acquire)
607                    == 0
608            {
609                return false;
610            }
611        }
612    }
613}
614
/// [`HyperOperator`] backed by an explicitly stored dense matrix.
#[derive(Clone)]
pub struct DenseMatrixHyperOperator {
    /// The materialized operator matrix `B`.
    pub matrix: Array2<f64>,
}
619
620impl HyperOperator for DenseMatrixHyperOperator {
621    fn dim(&self) -> usize {
622        self.matrix.nrows()
623    }
624
625    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
626        self.matrix.dot(v)
627    }
628
629    fn as_any(&self) -> &(dyn Any + 'static) {
630        self
631    }
632
633    fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
634        self.matrix.dot(&v)
635    }
636
637    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
638        assert_eq!(self.matrix.ncols(), v.len());
639        assert_eq!(self.matrix.nrows(), out.len());
640        for (row, out_value) in self.matrix.rows().into_iter().zip(out.iter_mut()) {
641            *out_value = row.dot(&v);
642        }
643    }
644
645    fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
646        let end = start + out.ncols();
647        assert!(end <= self.matrix.ncols());
648        out.assign(&self.matrix.slice(ndarray::s![.., start..end]));
649    }
650
651    fn scaled_add_mul_vec(
652        &self,
653        v: ArrayView1<'_, f64>,
654        scale: f64,
655        mut out: ArrayViewMut1<'_, f64>,
656    ) {
657        assert_eq!(self.matrix.ncols(), v.len());
658        assert_eq!(self.matrix.nrows(), out.len());
659        if scale == 0.0 {
660            return;
661        }
662        for (row, out_value) in self.matrix.rows().into_iter().zip(out.iter_mut()) {
663            *out_value += scale * row.dot(&v);
664        }
665    }
666
667    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
668        dense_bilinear(&self.matrix, v.view(), u.view())
669    }
670
671    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
672        dense_bilinear(&self.matrix, v, u)
673    }
674
675    fn to_dense(&self) -> Array2<f64> {
676        self.matrix.clone()
677    }
678
679    fn is_implicit(&self) -> bool {
680        false
681    }
682}
683
/// Operator that is zero everywhere except on one square diagonal block.
#[derive(Clone)]
pub struct BlockLocalDrift {
    /// Dense matrix occupying rows and columns `start..end` of the full operator.
    pub local: Array2<f64>,
    /// First index (inclusive) of the block within the full dimension.
    pub start: usize,
    /// End index (exclusive) of the block within the full dimension.
    pub end: usize,
    /// Dimension of the full operator.
    pub total_dim: usize,
}
691
692impl HyperOperator for BlockLocalDrift {
693    fn dim(&self) -> usize {
694        self.total_dim
695    }
696
697    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
698        assert_eq!(v.len(), self.total_dim);
699        let mut out = Array1::zeros(self.total_dim);
700        self.mul_vec_into(v.view(), out.view_mut());
701        out
702    }
703
704    fn as_any(&self) -> &(dyn Any + 'static) {
705        self
706    }
707
708    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
709        assert_eq!(v.len(), self.total_dim);
710        assert_eq!(out.len(), self.total_dim);
711        out.fill(0.0);
712        let v_block = v.slice(ndarray::s![self.start..self.end]);
713        let mut out_block = out.slice_mut(ndarray::s![self.start..self.end]);
714        dense_matvec_into(&self.local, v_block, out_block.view_mut());
715    }
716
717    fn scaled_add_mul_vec(
718        &self,
719        v: ArrayView1<'_, f64>,
720        scale: f64,
721        mut out: ArrayViewMut1<'_, f64>,
722    ) {
723        assert_eq!(v.len(), self.total_dim);
724        assert_eq!(out.len(), self.total_dim);
725        if scale == 0.0 {
726            return;
727        }
728        let v_block = v.slice(ndarray::s![self.start..self.end]);
729        let out_block = out.slice_mut(ndarray::s![self.start..self.end]);
730        dense_matvec_scaled_add_into(&self.local, v_block, scale, out_block);
731    }
732
733    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
734        self.bilinear_view(v.view(), u.view())
735    }
736
737    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
738        assert_eq!(v.len(), self.total_dim);
739        assert_eq!(u.len(), self.total_dim);
740        let v_block = v.slice(ndarray::s![self.start..self.end]);
741        let u_block = u.slice(ndarray::s![self.start..self.end]);
742        dense_bilinear(&self.local, v_block, u_block)
743    }
744
745    fn to_dense(&self) -> Array2<f64> {
746        let p = self.total_dim;
747        let mut out = Array2::zeros((p, p));
748        out.slice_mut(ndarray::s![self.start..self.end, self.start..self.end])
749            .assign(&self.local);
750        out
751    }
752
753    fn is_implicit(&self) -> bool {
754        false
755    }
756
757    fn block_local_data(&self) -> Option<(&Array2<f64>, usize, usize)> {
758        Some((&self.local, self.start, self.end))
759    }
760}
761
/// Drift term stored as the sum of up to three optional parts.
///
/// `materialize` and `scaled_add_apply` add the contributions of every part
/// that is present.
#[derive(Clone)]
pub struct HyperCoordDrift {
    /// Full-size dense contribution, if any.
    pub dense: Option<Array2<f64>>,
    /// Contribution confined to one diagonal block, if any.
    pub block_local: Option<BlockLocalDrift>,
    /// Operator-backed contribution, if any.
    pub operator: Option<Arc<dyn HyperOperator>>,
}
768
769impl HyperCoordDrift {
770    pub fn none() -> Self {
771        Self {
772            dense: None,
773            block_local: None,
774            operator: None,
775        }
776    }
777
778    pub fn from_dense(dense: Array2<f64>) -> Self {
779        Self {
780            dense: Some(dense),
781            block_local: None,
782            operator: None,
783        }
784    }
785
786    pub fn from_operator(operator: Arc<dyn HyperOperator>) -> Self {
787        Self {
788            dense: None,
789            block_local: None,
790            operator: Some(operator),
791        }
792    }
793
794    pub fn from_parts(
795        dense: Option<Array2<f64>>,
796        operator: Option<Arc<dyn HyperOperator>>,
797    ) -> Self {
798        let dense = dense.filter(|mat| !(operator.is_some() && mat.is_empty()));
799        Self {
800            dense,
801            block_local: None,
802            operator,
803        }
804    }
805
806    pub fn from_block_local_and_operator(
807        local: Array2<f64>,
808        start: usize,
809        end: usize,
810        total_dim: usize,
811        operator: Option<Arc<dyn HyperOperator>>,
812    ) -> Self {
813        Self {
814            dense: None,
815            block_local: Some(BlockLocalDrift {
816                local,
817                start,
818                end,
819                total_dim,
820            }),
821            operator,
822        }
823    }
824
825    pub fn has_operator(&self) -> bool {
826        self.operator.is_some()
827    }
828
829    pub fn uses_operator_fast_path(&self) -> bool {
830        self.operator.is_some() || self.block_local.is_some()
831    }
832
833    pub fn operator_ref(&self) -> Option<&dyn HyperOperator> {
834        self.operator.as_ref().map(Arc::as_ref)
835    }
836
837    pub fn materialize(&self) -> Array2<f64> {
838        let p = self.infer_dim();
839        if p == 0 {
840            return Array2::zeros((0, 0));
841        }
842        let mut out = self.dense.clone().unwrap_or_else(|| Array2::zeros((p, p)));
843        if let Some(bl) = &self.block_local {
844            out.slice_mut(ndarray::s![bl.start..bl.end, bl.start..bl.end])
845                .scaled_add(1.0, &bl.local);
846        }
847        if let Some(op) = &self.operator {
848            out += &op.to_dense();
849        }
850        out
851    }
852
853    pub fn apply(&self, v: &Array1<f64>) -> Array1<f64> {
854        let mut out = Array1::zeros(v.len());
855        self.scaled_add_apply(v.view(), 1.0, &mut out);
856        out
857    }
858
859    pub fn scaled_add_apply(&self, v: ArrayView1<'_, f64>, scale: f64, out: &mut Array1<f64>) {
860        assert_eq!(v.len(), out.len());
861        if scale == 0.0 {
862            return;
863        }
864        if let Some(dense) = &self.dense {
865            dense_matvec_scaled_add_into(dense, v, scale, out.view_mut());
866        }
867        if let Some(bl) = &self.block_local {
868            let v_block = v.slice(ndarray::s![bl.start..bl.end]);
869            let out_block = out.slice_mut(ndarray::s![bl.start..bl.end]);
870            dense_matvec_scaled_add_into(&bl.local, v_block, scale, out_block);
871        }
872        if let Some(op) = &self.operator {
873            op.scaled_add_mul_vec(v, scale, out.view_mut());
874        }
875    }
876
877    pub(crate) fn infer_dim(&self) -> usize {
878        if let Some(d) = &self.dense {
879            return d.nrows();
880        }
881        if let Some(op) = &self.operator {
882            return op.dim();
883        }
884        if let Some(bl) = &self.block_local {
885            return bl.total_dim;
886        }
887        0
888    }
889}
890
/// Per-coordinate derivative payload for one hyperparameter.
///
/// NOTE(review): the precise meaning of each field is defined by the REML
/// outer assembly that consumes this type, which is not visible in this file;
/// the field notes below are assumptions to confirm against that code.
#[derive(Clone)]
pub struct HyperCoord {
    // presumably the objective-derivative scalar for this coordinate — TODO confirm
    pub a: f64,
    // presumably the score/gradient-derivative vector — TODO confirm
    pub g: Array1<f64>,
    // Drift term, stored as optional dense, block-local and operator parts.
    pub drift: HyperCoordDrift,
    // presumably the penalty log-determinant derivative — TODO confirm
    pub ld_s: f64,
    // Flag; by its name, whether the `B` term varies with the coefficients — TODO confirm
    pub b_depends_on_beta: bool,
    // Flag; by its name, whether this coordinate behaves like a penalty parameter — TODO confirm
    pub is_penalty_like: bool,
    // Optional vector; presumably a Firth-correction counterpart of `g` — TODO confirm
    pub firth_g: Option<Array1<f64>>,
    // Optional vector; semantics not visible here — TODO confirm
    pub tk_eta_fixed: Option<Array1<f64>>,
    // Optional matrix; semantics not visible here — TODO confirm
    pub tk_x_fixed: Option<Array2<f64>>,
}
903
/// Payload for a pair of hyper-coordinates.
///
/// NOTE(review): field semantics mirror the names on `HyperCoord` and are
/// defined by code outside this file — confirm before relying on them.
pub struct HyperCoordPair {
    // Scalar term for the pair.
    pub a: f64,
    // Vector term for the pair.
    pub g: Array1<f64>,
    // Dense matrix term for the pair.
    pub b_mat: Array2<f64>,
    // Optional operator-backed matrix term for the pair.
    pub b_operator: Option<Box<dyn HyperOperator>>,
    // Scalar term; presumably the penalty log-determinant contribution — TODO confirm
    pub ld_s: f64,
}
911
912impl HyperCoordPair {
913    pub fn zero() -> Self {
914        Self {
915            a: 0.0,
916            g: Array1::zeros(0),
917            b_mat: Array2::zeros((0, 0)),
918            b_operator: None,
919            ld_s: 0.0,
920        }
921    }
922}
923
/// A drift derivative in either materialized or operator form.
#[derive(Clone)]
pub enum DriftDerivResult {
    /// Explicit dense matrix.
    Dense(Array2<f64>),
    /// Shared handle to an operator implementation.
    Operator(Arc<dyn HyperOperator>),
}
929
930impl std::fmt::Debug for DriftDerivResult {
931    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
932        match self {
933            Self::Dense(matrix) => f
934                .debug_tuple("Dense")
935                .field(&format_args!("{}x{}", matrix.nrows(), matrix.ncols()))
936                .finish(),
937            Self::Operator(_) => f
938                .debug_tuple("Operator")
939                .field(&"<hyper-operator>")
940                .finish(),
941        }
942    }
943}
944
945impl DriftDerivResult {
946    pub fn into_operator(self) -> Arc<dyn HyperOperator> {
947        match self {
948            Self::Dense(matrix) => Arc::new(DenseMatrixHyperOperator { matrix }),
949            Self::Operator(operator) => operator,
950        }
951    }
952
953    pub fn apply(&self, v: &Array1<f64>) -> Array1<f64> {
954        match self {
955            Self::Dense(matrix) => matrix.dot(v),
956            Self::Operator(operator) => operator.mul_vec(v),
957        }
958    }
959}
960
/// Boxed, thread-safe callback that maps an index and a vector to an optional
/// [`DriftDerivResult`].
///
/// NOTE(review): the meaning of the `usize` (presumably a hyper-coordinate
/// index) and of the vector argument is set by the callers — confirm there.
pub type FixedDriftDerivFn =
    Box<dyn Fn(usize, &Array1<f64>) -> Option<DriftDerivResult> + Send + Sync>;
963
/// Bundle of second-order terms returned by a [`ContractedPsiSecondOrderFn`].
///
/// NOTE(review): the shapes and the contraction these terms result from are
/// defined by the producing family code, not visible here — confirm there.
pub struct ContractedPsiSecondOrder {
    // Objective-level terms.
    pub objective: Array1<f64>,
    // Score-level terms.
    pub score: Array2<f64>,
    // Hessian-level terms, each in dense or operator form.
    pub hessian: Vec<DriftDerivResult>,
    // presumably penalty log-determinant terms — TODO confirm
    pub ld_s: Array1<f64>,
}
970
/// Shared, thread-safe callback that maps a slice of `f64` to an optional
/// [`ContractedPsiSecondOrder`], or to an error message on failure.
///
/// NOTE(review): what the input slice holds and when `Ok(None)` is returned
/// are decided by the implementing closure — confirm with its producers.
pub type ContractedPsiSecondOrderFn =
    Arc<dyn Fn(&[f64]) -> Result<Option<ContractedPsiSecondOrder>, String> + Send + Sync>;