1use std::any::Any;
8use std::collections::HashMap;
9use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
10use std::sync::{Arc, Condvar, Mutex};
11
12use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2};
13use rayon::iter::{IntoParallelIterator, ParallelIterator};
14
15#[macro_use]
16mod macros;
17
18pub mod basis_error;
19pub mod block_count_error;
20pub mod block_role;
21pub mod block_spec;
22pub mod coefficient_prior_mean;
23pub mod custom_family_blockwise;
24pub mod custom_family_error;
25pub mod diagnostics;
26pub mod dispersion;
27pub mod dispersion_cov;
28pub mod estimation_error;
29pub mod execution_path;
30pub mod family_options;
31pub mod finite_validation;
32pub mod fisher_rao;
33pub mod gauge;
34pub mod identifiability_audit;
35pub mod joint_penalty;
36mod linalg_helpers;
37mod linear_constraints;
38pub mod monotone_root_error;
39pub mod outer_subsample;
40pub mod penalty_coordinate;
41pub mod penalty_matrix;
42mod pseudo_logdet;
43pub mod psi_design_contract;
44pub mod psi_terms;
45pub mod riemannian_retraction;
46pub mod rho_posterior;
50pub mod row_measure;
51pub mod row_metric;
52pub mod schedule;
53pub mod laplace_sampler_contract;
57pub mod topology_certificates;
58mod seeding;
59pub mod solver_contract;
60pub mod types;
61
62pub use riemannian_retraction::LatentRetractionRegistry;
63pub use row_measure::RowSubsampleMask;
64
65mod gpu {
66 pub(crate) mod linalg_dispatch {
67 use ndarray::{Array2, ArrayView2};
68
69 pub(crate) fn try_fast_atb(
70 a: ArrayView2<'_, f64>,
71 b: ArrayView2<'_, f64>,
72 ) -> Option<Array2<f64>> {
73 let (n_a, p) = a.dim();
74 let (n_b, q) = b.dim();
75 assert_eq!(n_a, n_b, "A and B must have same number of rows");
76 if !crate::linalg_helpers::should_use_faer_matmul(p, q, n_a) {
77 return None;
78 }
79 Some(crate::linalg_helpers::fast_atb_with_parallelism(
80 &a,
81 &b,
82 crate::linalg_helpers::matmul_parallelism(p, q, n_a),
83 ))
84 }
85 }
86}
87
88pub use basis_error::BasisError;
89pub use block_count_error::BlockCountMismatch;
90pub use block_role::BlockRole;
91pub use block_spec::{
92 AdditiveBlockJacobian, BlockEffectiveJacobian, BlockGeometryDirectionalDerivative,
93 BlockWorkingSet, FamilyChannelHessian, FamilyLinearizationState, GaugeComposedJacobian,
94 ParameterBlockSpec, ParameterBlockState, RowScaledJacobian, TensorChannelHessian,
95};
96pub use coefficient_prior_mean::{CoefficientPriorMean, PriorMeanError};
97pub use custom_family_blockwise::{
98 CUSTOM_FAMILY_RIDGE_FLOOR, CUSTOM_FAMILY_WEIGHT_FLOOR, ExactNewtonOuterCurvature,
99 validate_blockspec_consistency,
100};
101pub use custom_family_error::CustomFamilyError;
102pub use dispersion::Dispersion;
103pub use dispersion_cov::{DispersionExt, PhiScaledCovariance, UnscaledPrecision, se_from_covariance};
104pub use estimation_error::EstimationError;
105pub use execution_path::ExecutionPath;
106pub use family_options::{ExactNewtonOuterObjective, ExactOuterDerivativeOrder};
107pub use finite_validation::{
108 bail_if_cached_beta_non_finite, ensure_finite_scalar, ensure_finite_scalar_estimation,
109 validate_all_finite, validate_all_finite_estimation,
110};
111pub use fisher_rao::{
112 FisherRaoDefiniteness, normalize_fisher_rao_blocks, normalize_fisher_rao_blocks_pd,
113};
114pub use gam_linalg::faer_ndarray::{in_nested_parallel_region, with_nested_parallel};
115pub use gauge::Gauge;
116pub use identifiability_audit::{
117 AliasedPair, BlockIdentity, DroppedColumn, IdentifiabilityAudit, MapUniquenessError,
118};
119pub use joint_penalty::{JointPenaltyBundle, JointPenaltyError, JointPenaltySpec};
120use linalg_helpers::{dense_bilinear, dense_matvec_into, dense_matvec_scaled_add_into};
121pub use linear_constraints::LinearInequalityConstraints;
122pub use monotone_root_error::MonotoneRootError;
123pub use penalty_coordinate::PenaltyCoordinate;
124pub use penalty_matrix::PenaltyMatrix;
125pub use pseudo_logdet::PseudoLogdetMode;
126pub use psi_design_contract::{
127 CustomFamilyBlockPsiDerivative, CustomFamilyPsiDerivativeOperator, JointHessianSourcePreference,
128 MaterializablePsiDerivativeOperator, MaterializationIntent, SharedDerivativeBlocks,
129};
130pub use psi_terms::{
131 ExactNewtonJointPsiSecondOrderContracted, ExactNewtonJointPsiSecondOrderTerms,
132 ExactNewtonJointPsiTerms, ExactNewtonJointPsiWorkspace,
133};
134pub use row_metric::{MetricProvenance, RowMetric, WeightField};
135pub use schedule::{GumbelTemperatureSchedule, ScheduleKind, SearchStrategy};
136pub use seeding::{SeedConfig, SeedRiskProfile, clamp_seed_rho_to_bounds, normalize_seed_bounds};
137pub use solver_contract::{
138 DeclaredHessianForm, Derivative, EfsEval, HessianResult, OuterEval,
139 OuterHessianMaterialization, OuterHessianOperator, OuterStrategyError,
140};
141pub use types::*;
142
/// Raises a panic whose payload is the owned `String` form of `message`.
///
/// `panic_any` is used (instead of `panic!`) so the payload type is always
/// exactly `String`, which callers catching the unwind can downcast to.
#[cold]
fn reml_contract_panic(message: impl Into<String>) -> ! {
    let payload: String = message.into();
    std::panic::panic_any(payload)
}
147
/// How much of an evaluation to produce; by the variant names, each level
/// adds one derivative order on top of the previous one.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum EvalMode {
    /// Objective value only.
    ValueOnly,
    /// Value plus gradient.
    ValueAndGradient,
    /// Value, gradient and Hessian.
    ValueGradientHessian,
}
158
/// Zero-sized sentinel handed out by the default `HyperOperator::as_any`.
///
/// Downcasting it to any concrete operator type yields `None`, so operators
/// that do not override `as_any` are effectively opaque.
struct NonDowncastableHyperOperator;

/// The one shared sentinel instance returned by the default `as_any`.
static NON_DOWNCASTABLE_HYPER_OPERATOR: NonDowncastableHyperOperator =
    NonDowncastableHyperOperator;
164
165pub trait HyperOperator: Send + Sync {
166 fn dim(&self) -> usize;
169
170 fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64>;
172
173 fn as_any(&self) -> &(dyn Any + 'static) {
177 &NON_DOWNCASTABLE_HYPER_OPERATOR
178 }
179
180 fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
182 self.mul_vec(&v.to_owned())
183 }
184
185 fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
187 out.assign(&self.mul_vec_view(v));
188 }
189
190 fn mul_mat(&self, factor: &Array2<f64>) -> Array2<f64> {
193 let p = factor.nrows();
194 let k = factor.ncols();
195 let mut out = Array2::<f64>::zeros((p, k));
196 if rayon::current_thread_index().is_some() {
197 for col in 0..k {
198 let bv = out.column_mut(col);
199 self.mul_vec_into(factor.column(col), bv);
200 }
201 return out;
202 }
203 let cols: Vec<Array1<f64>> = (0..k)
204 .into_par_iter()
205 .map(|col| {
206 let mut bv = Array1::<f64>::zeros(p);
207 self.mul_vec_into(factor.column(col), bv.view_mut());
208 bv
209 })
210 .collect();
211 for (col, bv) in cols.into_iter().enumerate() {
212 out.column_mut(col).assign(&bv);
213 }
214 out
215 }
216
217 fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
219 let op_factor = self.mul_mat(factor);
220 factor
221 .iter()
222 .zip(op_factor.iter())
223 .map(|(&f, &bf)| f * bf)
224 .sum()
225 }
226
227 fn projection_design_id(&self) -> Option<usize> {
236 None
237 }
238
239 fn trace_projected_factor_cached(
240 &self,
241 factor: &Array2<f64>,
242 factor_cache: &ProjectedFactorCache,
243 ) -> f64 {
244 assert!(std::mem::size_of_val(factor_cache) > 0);
248 match self.projection_design_id() {
249 Some(design_id) => {
250 let key = ProjectedFactorKey::from_factor_view(design_id, factor.view());
251 let projected = factor_cache.get_or_insert_with(key, || self.mul_mat(factor));
252 factor
253 .iter()
254 .zip(projected.iter())
255 .map(|(&f, &bf)| f * bf)
256 .sum()
257 }
258 None => self.trace_projected_factor(factor),
259 }
260 }
261
262 fn projected_matrix(&self, factor: &Array2<f64>) -> Array2<f64> {
264 let op_factor = self.mul_mat(factor);
265 crate::linalg_helpers::fast_atb(factor, &op_factor)
266 }
267
268 fn projected_matrix_cached(
271 &self,
272 factor: &Array2<f64>,
273 factor_cache: &ProjectedFactorCache,
274 ) -> Array2<f64> {
275 assert!(std::mem::size_of_val(factor_cache) > 0);
276 match self.projection_design_id() {
277 Some(design_id) => {
278 let key = ProjectedFactorKey::from_factor_view(design_id, factor.view());
279 let projected = factor_cache.get_or_insert_with(key, || self.mul_mat(factor));
280 crate::linalg_helpers::fast_atb(factor, projected.as_ref())
281 }
282 None => self.projected_matrix(factor),
283 }
284 }
285
286 fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
288 let cols = out.ncols();
289 let dim = out.nrows();
290 assert!(start + cols <= dim);
291 let mut basis = Array1::<f64>::zeros(dim);
292 for local_col in 0..cols {
293 let global_col = start + local_col;
294 basis[global_col] = 1.0;
295 self.mul_vec_into(basis.view(), out.column_mut(local_col));
296 basis[global_col] = 0.0;
297 }
298 }
299
300 fn scaled_add_mul_vec(
302 &self,
303 v: ArrayView1<'_, f64>,
304 scale: f64,
305 mut out: ArrayViewMut1<'_, f64>,
306 ) {
307 if scale == 0.0 {
308 return;
309 }
310 let mut work = Array1::<f64>::zeros(out.len());
311 self.mul_vec_into(v, work.view_mut());
312 out.scaled_add(scale, &work);
313 }
314
315 fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
317 let mut bv = Array1::<f64>::zeros(v.len());
318 self.mul_vec_into(v.view(), bv.view_mut());
319 u.dot(&bv)
320 }
321
322 fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
324 let mut bv = Array1::<f64>::zeros(v.len());
325 self.mul_vec_into(v, bv.view_mut());
326 u.dot(&bv)
327 }
328
329 fn has_fast_bilinear_view(&self) -> bool {
331 false
332 }
333
334 fn to_dense(&self) -> Array2<f64> {
336 let p = self.dim();
337 let mut out = Array2::<f64>::zeros((p, p));
338 let mut basis = Array1::<f64>::zeros(p);
339 for j in 0..p {
340 basis[j] = 1.0;
341 self.mul_vec_into(basis.view(), out.column_mut(j));
342 basis[j] = 0.0;
343 }
344 out
345 }
346
347 fn is_implicit(&self) -> bool;
349
350 fn block_local_data(&self) -> Option<(&Array2<f64>, usize, usize)> {
352 None
353 }
354}
355
/// Identity of a factor matrix as seen by a [`ProjectedFactorCache`].
///
/// Combines the operator's design id with the factor's address, shape,
/// strides and a two-lane fingerprint of its values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ProjectedFactorKey {
    /// Value returned by the operator's `projection_design_id`.
    pub(crate) design_id: usize,
    /// Address of the factor's first element.
    pub(crate) factor_ptr: usize,
    /// Factor row count.
    pub(crate) rows: usize,
    /// Factor column count.
    pub(crate) cols: usize,
    /// Stride between consecutive rows, in elements.
    pub(crate) row_stride: isize,
    /// Stride between consecutive columns, in elements.
    pub(crate) col_stride: isize,
    /// First fingerprint lane over the factor's values.
    pub(crate) value_hash: u64,
    /// Second, independently mixed fingerprint lane.
    pub(crate) value_hash2: u64,
}
367
368impl ProjectedFactorKey {
369 pub fn from_factor_view(design_id: usize, factor: ArrayView2<'_, f64>) -> Self {
370 let strides = factor.strides();
371 let (value_hash, value_hash2) = projected_factor_value_fingerprint(factor);
372 Self {
373 design_id,
374 factor_ptr: factor.as_ptr() as usize,
375 rows: factor.nrows(),
376 cols: factor.ncols(),
377 row_stride: strides[0],
378 col_stride: strides[1],
379 value_hash,
380 value_hash2,
381 }
382 }
383
384 pub fn synthetic(seed: u64) -> Self {
389 Self {
390 design_id: 1,
391 factor_ptr: seed as usize,
392 rows: 1,
393 cols: 1,
394 row_stride: 1,
395 col_stride: 1,
396 value_hash: seed,
397 value_hash2: seed.wrapping_mul(31),
398 }
399 }
400}
401
402pub(crate) fn projected_factor_value_fingerprint(factor: ArrayView2<'_, f64>) -> (u64, u64) {
403 let mut h1 = 0xcbf2_9ce4_8422_2325_u64;
404 let mut h2 = 0x9e37_79b1_85eb_ca87_u64;
405 for (idx, value) in factor.iter().enumerate() {
406 let bits = value.to_bits();
407 let mixed = bits.wrapping_add((idx as u64).wrapping_mul(0x517c_c1b7_2722_0a95));
408 h1 ^= mixed;
409 h1 = h1.wrapping_mul(0x0000_0100_0000_01b3);
410 h2 ^= bits.rotate_left((idx & 63) as u32);
411 h2 = h2.wrapping_mul(0x94d0_49bb_1331_11eb).rotate_left(27);
412 }
413 (h1, h2)
414}
415
416pub struct ProjectedFactorCache {
418 pub(crate) inner: Mutex<ProjectedFactorCacheInner>,
419}
420
421pub(crate) struct ProjectedFactorCacheInner {
422 pub(crate) entries: HashMap<ProjectedFactorKey, ProjectedFactorEntry>,
423 pub(crate) in_progress: HashMap<ProjectedFactorKey, Arc<ProjectedFactorInProgress>>,
424 pub(crate) next_seq: u64,
425 pub(crate) total_bytes: usize,
426 pub(crate) budget_bytes: usize,
427}
428
429pub(crate) struct ProjectedFactorInProgress {
430 pub(crate) state: Mutex<Option<ProjectedFactorInProgressState>>,
431 pub(crate) ready: Condvar,
432 pub(crate) waiter_count: std::sync::atomic::AtomicUsize,
433 pub(crate) subscriber_arrived: (Mutex<()>, Condvar),
434}
435
436pub(crate) enum ProjectedFactorInProgressState {
437 Ready(Arc<Array2<f64>>),
438 Failed,
439}
440
441pub(crate) struct ProjectedFactorEntry {
442 pub(crate) value: Arc<Array2<f64>>,
443 pub(crate) bytes: usize,
444 pub(crate) last_used: u64,
445}
446
447impl Default for ProjectedFactorCache {
448 fn default() -> Self {
449 Self::with_budget(Self::DEFAULT_BUDGET_BYTES)
450 }
451}
452
453impl ProjectedFactorCache {
454 pub const DEFAULT_BUDGET_BYTES: usize = 2 * 1024 * 1024 * 1024;
455
456 pub fn with_budget(budget_bytes: usize) -> Self {
457 Self {
458 inner: Mutex::new(ProjectedFactorCacheInner {
459 entries: HashMap::new(),
460 in_progress: HashMap::new(),
461 next_seq: 0,
462 total_bytes: 0,
463 budget_bytes,
464 }),
465 }
466 }
467
468 pub fn get_or_insert_with(
469 &self,
470 key: ProjectedFactorKey,
471 compute: impl FnOnce() -> Array2<f64>,
472 ) -> Arc<Array2<f64>> {
473 enum CacheLookup {
474 Hit(Arc<Array2<f64>>),
475 Wait(Arc<ProjectedFactorInProgress>),
476 Compute(Arc<ProjectedFactorInProgress>),
477 }
478
479 let lookup = {
480 let mut inner = self
481 .inner
482 .lock()
483 .expect("projected factor cache lock poisoned");
484 inner.next_seq += 1;
485 let now = inner.next_seq;
486 if let Some(entry) = inner.entries.get_mut(&key) {
487 entry.last_used = now;
488 CacheLookup::Hit(entry.value.clone())
489 } else if let Some(waiter) = inner.in_progress.get(&key) {
490 CacheLookup::Wait(waiter.clone())
491 } else {
492 let marker = Arc::new(ProjectedFactorInProgress {
493 state: Mutex::new(None),
494 ready: Condvar::new(),
495 waiter_count: std::sync::atomic::AtomicUsize::new(0),
496 subscriber_arrived: (Mutex::new(()), Condvar::new()),
497 });
498 inner.in_progress.insert(key, marker.clone());
499 CacheLookup::Compute(marker)
500 }
501 };
502
503 match lookup {
504 CacheLookup::Hit(value) => value,
505 CacheLookup::Wait(marker) => {
506 marker
507 .waiter_count
508 .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
509 let (lock, cv) = &marker.subscriber_arrived;
510 drop(
511 lock.lock()
512 .expect("subscriber-arrived notification lock poisoned"),
513 );
514 cv.notify_all();
515 let mut guard = marker
516 .state
517 .lock()
518 .expect("projected factor in-progress lock poisoned");
519 let result = loop {
520 match guard.as_ref() {
521 Some(ProjectedFactorInProgressState::Ready(value)) => {
522 break value.clone();
523 }
524 Some(ProjectedFactorInProgressState::Failed) => {
525 marker
526 .waiter_count
527 .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
528 reml_contract_panic("projected factor cache producer panicked")
529 }
530 None => {
531 guard = marker
532 .ready
533 .wait(guard)
534 .expect("projected factor in-progress wait poisoned");
535 }
536 }
537 };
538 marker
539 .waiter_count
540 .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
541 result
542 }
543 CacheLookup::Compute(marker) => {
544 let computed = match catch_unwind(AssertUnwindSafe(|| Arc::new(compute()))) {
545 Ok(value) => value,
546 Err(payload) => {
547 let mut inner = self
548 .inner
549 .lock()
550 .expect("projected factor cache lock poisoned");
551 inner.in_progress.remove(&key);
552 drop(inner);
553
554 let mut guard = marker
555 .state
556 .lock()
557 .expect("projected factor in-progress lock poisoned");
558 *guard = Some(ProjectedFactorInProgressState::Failed);
559 marker.ready.notify_all();
560 resume_unwind(payload);
561 }
562 };
563 let bytes = computed.len().saturating_mul(std::mem::size_of::<f64>());
564 let mut inner = self
565 .inner
566 .lock()
567 .expect("projected factor cache lock poisoned");
568 inner.next_seq += 1;
569 let now = inner.next_seq;
570
571 if inner.budget_bytes > 0 && bytes <= inner.budget_bytes {
572 while inner.total_bytes.saturating_add(bytes) > inner.budget_bytes
573 && !inner.entries.is_empty()
574 {
575 let Some(oldest_key) = inner
576 .entries
577 .iter()
578 .min_by_key(|(_, e)| e.last_used)
579 .map(|(k, _)| *k)
580 else {
581 break;
582 };
583 if let Some(removed) = inner.entries.remove(&oldest_key) {
584 inner.total_bytes = inner.total_bytes.saturating_sub(removed.bytes);
585 }
586 }
587 }
588
589 let value = if let Some(entry) = inner.entries.get_mut(&key) {
590 entry.last_used = now;
591 entry.value.clone()
592 } else {
593 inner.entries.insert(
594 key,
595 ProjectedFactorEntry {
596 value: computed.clone(),
597 bytes,
598 last_used: now,
599 },
600 );
601 inner.total_bytes = inner.total_bytes.saturating_add(bytes);
602 computed
603 };
604 inner.in_progress.remove(&key);
605 drop(inner);
606
607 let mut guard = marker
608 .state
609 .lock()
610 .expect("projected factor in-progress lock poisoned");
611 *guard = Some(ProjectedFactorInProgressState::Ready(value.clone()));
612 marker.ready.notify_all();
613 value
614 }
615 }
616 }
617
618 pub fn len(&self) -> usize {
619 self.inner
620 .lock()
621 .map(|inner| inner.entries.len())
622 .unwrap_or(0)
623 }
624
625 pub fn total_bytes(&self) -> usize {
626 self.inner
627 .lock()
628 .map(|inner| inner.total_bytes)
629 .unwrap_or(0)
630 }
631
632 pub fn is_empty(&self) -> bool {
633 self.len() == 0
634 }
635
636 pub fn wait_for_subscriber(
647 &self,
648 key: ProjectedFactorKey,
649 timeout: std::time::Duration,
650 ) -> bool {
651 let marker = {
652 let inner = self
653 .inner
654 .lock()
655 .expect("projected factor cache lock poisoned");
656 let Some(m) = inner.in_progress.get(&key) else {
657 return false;
658 };
659 Arc::clone(m)
660 };
661 if marker
662 .waiter_count
663 .load(std::sync::atomic::Ordering::Acquire)
664 > 0
665 {
666 return true;
667 }
668 let (lock, cv) = &marker.subscriber_arrived;
669 let mut guard = lock
670 .lock()
671 .expect("subscriber-arrived notification lock poisoned");
672 let deadline = std::time::Instant::now() + timeout;
673 loop {
674 if marker
675 .waiter_count
676 .load(std::sync::atomic::Ordering::Acquire)
677 > 0
678 {
679 return true;
680 }
681 let now = std::time::Instant::now();
682 if now >= deadline {
683 return false;
684 }
685 let (next_guard, result) = cv
686 .wait_timeout(guard, deadline - now)
687 .expect("subscriber-arrived wait poisoned");
688 guard = next_guard;
689 if result.timed_out()
690 && marker
691 .waiter_count
692 .load(std::sync::atomic::Ordering::Acquire)
693 == 0
694 {
695 return false;
696 }
697 }
698 }
699}
700
701#[derive(Clone)]
702pub struct DenseMatrixHyperOperator {
703 pub matrix: Array2<f64>,
704}
705
706impl HyperOperator for DenseMatrixHyperOperator {
707 fn dim(&self) -> usize {
708 self.matrix.nrows()
709 }
710
711 fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
712 self.matrix.dot(v)
713 }
714
715 fn as_any(&self) -> &(dyn Any + 'static) {
716 self
717 }
718
719 fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
720 self.matrix.dot(&v)
721 }
722
723 fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
724 assert_eq!(self.matrix.ncols(), v.len());
725 assert_eq!(self.matrix.nrows(), out.len());
726 for (row, out_value) in self.matrix.rows().into_iter().zip(out.iter_mut()) {
727 *out_value = row.dot(&v);
728 }
729 }
730
731 fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
732 let end = start + out.ncols();
733 assert!(end <= self.matrix.ncols());
734 out.assign(&self.matrix.slice(ndarray::s![.., start..end]));
735 }
736
737 fn scaled_add_mul_vec(
738 &self,
739 v: ArrayView1<'_, f64>,
740 scale: f64,
741 mut out: ArrayViewMut1<'_, f64>,
742 ) {
743 assert_eq!(self.matrix.ncols(), v.len());
744 assert_eq!(self.matrix.nrows(), out.len());
745 if scale == 0.0 {
746 return;
747 }
748 for (row, out_value) in self.matrix.rows().into_iter().zip(out.iter_mut()) {
749 *out_value += scale * row.dot(&v);
750 }
751 }
752
753 fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
754 dense_bilinear(&self.matrix, v.view(), u.view())
755 }
756
757 fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
758 dense_bilinear(&self.matrix, v, u)
759 }
760
761 fn to_dense(&self) -> Array2<f64> {
762 self.matrix.clone()
763 }
764
765 fn is_implicit(&self) -> bool {
766 false
767 }
768}
769
770#[derive(Clone)]
771pub struct BlockLocalDrift {
772 pub local: Array2<f64>,
773 pub start: usize,
774 pub end: usize,
775 pub total_dim: usize,
776}
777
778impl HyperOperator for BlockLocalDrift {
779 fn dim(&self) -> usize {
780 self.total_dim
781 }
782
783 fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
784 assert_eq!(v.len(), self.total_dim);
785 let mut out = Array1::zeros(self.total_dim);
786 self.mul_vec_into(v.view(), out.view_mut());
787 out
788 }
789
790 fn as_any(&self) -> &(dyn Any + 'static) {
791 self
792 }
793
794 fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
795 assert_eq!(v.len(), self.total_dim);
796 assert_eq!(out.len(), self.total_dim);
797 out.fill(0.0);
798 let v_block = v.slice(ndarray::s![self.start..self.end]);
799 let mut out_block = out.slice_mut(ndarray::s![self.start..self.end]);
800 dense_matvec_into(&self.local, v_block, out_block.view_mut());
801 }
802
803 fn scaled_add_mul_vec(
804 &self,
805 v: ArrayView1<'_, f64>,
806 scale: f64,
807 mut out: ArrayViewMut1<'_, f64>,
808 ) {
809 assert_eq!(v.len(), self.total_dim);
810 assert_eq!(out.len(), self.total_dim);
811 if scale == 0.0 {
812 return;
813 }
814 let v_block = v.slice(ndarray::s![self.start..self.end]);
815 let out_block = out.slice_mut(ndarray::s![self.start..self.end]);
816 dense_matvec_scaled_add_into(&self.local, v_block, scale, out_block);
817 }
818
819 fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
820 self.bilinear_view(v.view(), u.view())
821 }
822
823 fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
824 assert_eq!(v.len(), self.total_dim);
825 assert_eq!(u.len(), self.total_dim);
826 let v_block = v.slice(ndarray::s![self.start..self.end]);
827 let u_block = u.slice(ndarray::s![self.start..self.end]);
828 dense_bilinear(&self.local, v_block, u_block)
829 }
830
831 fn to_dense(&self) -> Array2<f64> {
832 let p = self.total_dim;
833 let mut out = Array2::zeros((p, p));
834 out.slice_mut(ndarray::s![self.start..self.end, self.start..self.end])
835 .assign(&self.local);
836 out
837 }
838
839 fn is_implicit(&self) -> bool {
840 false
841 }
842
843 fn block_local_data(&self) -> Option<(&Array2<f64>, usize, usize)> {
844 Some((&self.local, self.start, self.end))
845 }
846}
847
848#[derive(Clone)]
849pub struct HyperCoordDrift {
850 pub dense: Option<Array2<f64>>,
851 pub block_local: Option<BlockLocalDrift>,
852 pub operator: Option<Arc<dyn HyperOperator>>,
853}
854
855impl HyperCoordDrift {
856 pub fn none() -> Self {
857 Self {
858 dense: None,
859 block_local: None,
860 operator: None,
861 }
862 }
863
864 pub fn from_dense(dense: Array2<f64>) -> Self {
865 Self {
866 dense: Some(dense),
867 block_local: None,
868 operator: None,
869 }
870 }
871
872 pub fn from_operator(operator: Arc<dyn HyperOperator>) -> Self {
873 Self {
874 dense: None,
875 block_local: None,
876 operator: Some(operator),
877 }
878 }
879
880 pub fn from_parts(
881 dense: Option<Array2<f64>>,
882 operator: Option<Arc<dyn HyperOperator>>,
883 ) -> Self {
884 let dense = dense.filter(|mat| !(operator.is_some() && mat.is_empty()));
885 Self {
886 dense,
887 block_local: None,
888 operator,
889 }
890 }
891
892 pub fn from_block_local_and_operator(
893 local: Array2<f64>,
894 start: usize,
895 end: usize,
896 total_dim: usize,
897 operator: Option<Arc<dyn HyperOperator>>,
898 ) -> Self {
899 Self {
900 dense: None,
901 block_local: Some(BlockLocalDrift {
902 local,
903 start,
904 end,
905 total_dim,
906 }),
907 operator,
908 }
909 }
910
911 pub fn has_operator(&self) -> bool {
912 self.operator.is_some()
913 }
914
915 pub fn uses_operator_fast_path(&self) -> bool {
916 self.operator.is_some() || self.block_local.is_some()
917 }
918
919 pub fn operator_ref(&self) -> Option<&dyn HyperOperator> {
920 self.operator.as_ref().map(Arc::as_ref)
921 }
922
923 pub fn materialize(&self) -> Array2<f64> {
924 let p = self.infer_dim();
925 if p == 0 {
926 return Array2::zeros((0, 0));
927 }
928 let mut out = self.dense.clone().unwrap_or_else(|| Array2::zeros((p, p)));
929 if let Some(bl) = &self.block_local {
930 out.slice_mut(ndarray::s![bl.start..bl.end, bl.start..bl.end])
931 .scaled_add(1.0, &bl.local);
932 }
933 if let Some(op) = &self.operator {
934 out += &op.to_dense();
935 }
936 out
937 }
938
939 pub fn apply(&self, v: &Array1<f64>) -> Array1<f64> {
940 let mut out = Array1::zeros(v.len());
941 self.scaled_add_apply(v.view(), 1.0, &mut out);
942 out
943 }
944
945 pub fn scaled_add_apply(&self, v: ArrayView1<'_, f64>, scale: f64, out: &mut Array1<f64>) {
946 assert_eq!(v.len(), out.len());
947 if scale == 0.0 {
948 return;
949 }
950 if let Some(dense) = &self.dense {
951 dense_matvec_scaled_add_into(dense, v, scale, out.view_mut());
952 }
953 if let Some(bl) = &self.block_local {
954 let v_block = v.slice(ndarray::s![bl.start..bl.end]);
955 let out_block = out.slice_mut(ndarray::s![bl.start..bl.end]);
956 dense_matvec_scaled_add_into(&bl.local, v_block, scale, out_block);
957 }
958 if let Some(op) = &self.operator {
959 op.scaled_add_mul_vec(v, scale, out.view_mut());
960 }
961 }
962
963 pub(crate) fn infer_dim(&self) -> usize {
964 if let Some(d) = &self.dense {
965 return d.nrows();
966 }
967 if let Some(op) = &self.operator {
968 return op.dim();
969 }
970 if let Some(bl) = &self.block_local {
971 return bl.total_dim;
972 }
973 0
974 }
975}
976
977#[derive(Clone)]
978pub struct HyperCoord {
979 pub a: f64,
980 pub g: Array1<f64>,
981 pub drift: HyperCoordDrift,
982 pub ld_s: f64,
983 pub b_depends_on_beta: bool,
984 pub is_penalty_like: bool,
985 pub firth_g: Option<Array1<f64>>,
986 pub tk_eta_fixed: Option<Array1<f64>>,
987 pub tk_x_fixed: Option<Array2<f64>>,
988}
989
990pub struct HyperCoordPair {
991 pub a: f64,
992 pub g: Array1<f64>,
993 pub b_mat: Array2<f64>,
994 pub b_operator: Option<Box<dyn HyperOperator>>,
995 pub ld_s: f64,
996}
997
998impl HyperCoordPair {
999 pub fn zero() -> Self {
1000 Self {
1001 a: 0.0,
1002 g: Array1::zeros(0),
1003 b_mat: Array2::zeros((0, 0)),
1004 b_operator: None,
1005 ld_s: 0.0,
1006 }
1007 }
1008}
1009
1010#[derive(Clone)]
1011pub enum DriftDerivResult {
1012 Dense(Array2<f64>),
1013 Operator(Arc<dyn HyperOperator>),
1014}
1015
1016impl std::fmt::Debug for DriftDerivResult {
1017 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1018 match self {
1019 Self::Dense(matrix) => f
1020 .debug_tuple("Dense")
1021 .field(&format_args!("{}x{}", matrix.nrows(), matrix.ncols()))
1022 .finish(),
1023 Self::Operator(_) => f
1024 .debug_tuple("Operator")
1025 .field(&"<hyper-operator>")
1026 .finish(),
1027 }
1028 }
1029}
1030
1031impl DriftDerivResult {
1032 pub fn into_operator(self) -> Arc<dyn HyperOperator> {
1033 match self {
1034 Self::Dense(matrix) => Arc::new(DenseMatrixHyperOperator { matrix }),
1035 Self::Operator(operator) => operator,
1036 }
1037 }
1038
1039 pub fn apply(&self, v: &Array1<f64>) -> Array1<f64> {
1040 match self {
1041 Self::Dense(matrix) => matrix.dot(v),
1042 Self::Operator(operator) => operator.mul_vec(v),
1043 }
1044 }
1045}
1046
1047pub type FixedDriftDerivFn =
1048 Box<dyn Fn(usize, &Array1<f64>) -> Option<DriftDerivResult> + Send + Sync>;
1049
1050pub struct ContractedPsiSecondOrder {
1051 pub objective: Array1<f64>,
1052 pub score: Array2<f64>,
1053 pub hessian: Vec<DriftDerivResult>,
1054 pub ld_s: Array1<f64>,
1055}
1056
1057pub type ContractedPsiSecondOrderFn =
1058 Arc<dyn Fn(&[f64]) -> Result<Option<ContractedPsiSecondOrder>, String> + Send + Sync>;