1use std::any::Any;
8use std::collections::HashMap;
9use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
10use std::sync::{Arc, Condvar, Mutex};
11
12use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2};
13use rayon::iter::{IntoParallelIterator, ParallelIterator};
14
15#[macro_use]
16mod macros;
17
18pub mod basis_error;
19pub mod block_count_error;
20pub mod block_role;
21pub mod block_spec;
22pub mod coefficient_prior_mean;
23pub mod custom_family_blockwise;
24pub mod custom_family_error;
25pub mod diagnostics;
26pub mod dispersion;
27pub mod dispersion_cov;
28pub mod estimation_error;
29pub mod execution_path;
30pub mod family_options;
31pub mod finite_validation;
32pub mod fisher_rao;
33pub mod gauge;
34pub mod identifiability_audit;
35pub mod joint_penalty;
36mod linalg_helpers;
37mod linear_constraints;
38pub mod monotone_root_error;
39pub mod outer_subsample;
40pub mod penalty_coordinate;
41pub mod penalty_matrix;
42mod pseudo_logdet;
43pub mod psi_design_contract;
44pub mod psi_terms;
45pub mod riemannian_retraction;
46pub mod rho_posterior;
50pub mod row_measure;
51pub mod row_metric;
52pub mod schedule;
53pub mod laplace_sampler_contract;
57mod seeding;
58pub mod solver_contract;
59pub mod topology_certificates;
60pub mod types;
61
62pub use riemannian_retraction::LatentRetractionRegistry;
63pub use row_measure::RowSubsampleMask;
64
65mod gpu {
66 pub(crate) mod linalg_dispatch {
67 use ndarray::{Array2, ArrayView2};
68
69 pub(crate) fn try_fast_atb(
70 a: ArrayView2<'_, f64>,
71 b: ArrayView2<'_, f64>,
72 ) -> Option<Array2<f64>> {
73 let (n_a, p) = a.dim();
74 let (n_b, q) = b.dim();
75 assert_eq!(n_a, n_b, "A and B must have same number of rows");
76 if !crate::linalg_helpers::should_use_faer_matmul(p, q, n_a) {
77 return None;
78 }
79 Some(crate::linalg_helpers::fast_atb_with_parallelism(
80 &a,
81 &b,
82 crate::linalg_helpers::matmul_parallelism(p, q, n_a),
83 ))
84 }
85 }
86}
87
88pub use basis_error::BasisError;
89pub use block_count_error::BlockCountMismatch;
90pub use block_role::BlockRole;
91pub use block_spec::{
92 AdditiveBlockJacobian, BlockEffectiveJacobian, BlockGeometryDirectionalDerivative,
93 BlockWorkingSet, FamilyChannelHessian, FamilyLinearizationState, GaugeComposedJacobian,
94 ParameterBlockSpec, ParameterBlockState, RowScaledJacobian, TensorChannelHessian,
95};
96pub use coefficient_prior_mean::{CoefficientPriorMean, PriorMeanError};
97pub use custom_family_blockwise::{
98 CUSTOM_FAMILY_RIDGE_FLOOR, CUSTOM_FAMILY_WEIGHT_FLOOR, ExactNewtonOuterCurvature,
99 validate_blockspec_consistency,
100};
101pub use custom_family_error::CustomFamilyError;
102pub use dispersion::Dispersion;
103pub use dispersion_cov::{
104 DispersionExt, PhiScaledCovariance, UnscaledPrecision, se_from_covariance,
105};
106pub use estimation_error::EstimationError;
107pub use execution_path::ExecutionPath;
108pub use family_options::{ExactNewtonOuterObjective, ExactOuterDerivativeOrder};
109pub use finite_validation::{
110 bail_if_cached_beta_non_finite, ensure_finite_scalar, ensure_finite_scalar_estimation,
111 validate_all_finite, validate_all_finite_estimation,
112};
113pub use fisher_rao::{
114 FisherRaoDefiniteness, normalize_fisher_rao_blocks, normalize_fisher_rao_blocks_pd,
115};
116pub use gam_linalg::faer_ndarray::{in_nested_parallel_region, with_nested_parallel};
117pub use gauge::Gauge;
118pub use identifiability_audit::{
119 AliasedPair, BlockIdentity, DroppedColumn, IdentifiabilityAudit, MapUniquenessError,
120};
121pub use joint_penalty::{JointPenaltyBundle, JointPenaltyError, JointPenaltySpec};
122use linalg_helpers::{dense_bilinear, dense_matvec_into, dense_matvec_scaled_add_into};
123pub use linear_constraints::LinearInequalityConstraints;
124pub use monotone_root_error::MonotoneRootError;
125pub use penalty_coordinate::PenaltyCoordinate;
126pub use penalty_matrix::PenaltyMatrix;
127pub use pseudo_logdet::PseudoLogdetMode;
128pub use psi_design_contract::{
129 CustomFamilyBlockPsiDerivative, CustomFamilyPsiDerivativeOperator,
130 JointHessianSourcePreference, MaterializablePsiDerivativeOperator, MaterializationIntent,
131 SharedDerivativeBlocks,
132};
133pub use psi_terms::{
134 ExactNewtonJointPsiSecondOrderContracted, ExactNewtonJointPsiSecondOrderTerms,
135 ExactNewtonJointPsiTerms, ExactNewtonJointPsiWorkspace,
136};
137pub use row_metric::{MetricProvenance, RowMetric, WeightField, pack_probe_factors};
138pub use schedule::{GumbelTemperatureSchedule, ScheduleKind, SearchStrategy};
139pub use seeding::{SeedConfig, SeedRiskProfile, clamp_seed_rho_to_bounds, normalize_seed_bounds};
140pub use solver_contract::{
141 DeclaredHessianForm, Derivative, EfsEval, HessianResult, OuterEval,
142 OuterHessianMaterialization, OuterHessianOperator, OuterStrategyError,
143};
144pub use types::*;
145
/// Unwinds the current thread with `message` as the panic payload.
///
/// The payload is an owned `String` handed to `std::panic::panic_any`, so code that
/// catches the unwind can recover the text with `downcast_ref::<String>()`.
#[cold]
fn reml_contract_panic(message: impl Into<String>) -> ! {
    let payload: String = message.into();
    std::panic::panic_any(payload)
}
150
/// Selects how much derivative information an evaluation should produce.
///
/// NOTE(review): the per-variant descriptions are read off the variant names only;
/// confirm against the evaluators that consume this enum.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum EvalMode {
    /// Presumably: objective value only.
    ValueOnly,
    /// Presumably: objective value plus gradient.
    ValueAndGradient,
    /// Presumably: objective value, gradient and Hessian.
    ValueGradientHessian,
}
161
/// Private unit type used as the `Any` target for operators that do not expose
/// their concrete type.
struct NonDowncastableHyperOperator;

/// Shared instance returned by the default `HyperOperator::as_any`; because the type is
/// private, a downcast to any operator type fails for implementors that keep the default.
static NON_DOWNCASTABLE_HYPER_OPERATOR: NonDowncastableHyperOperator = NonDowncastableHyperOperator;
167
168pub trait HyperOperator: Send + Sync {
169 fn dim(&self) -> usize;
172
173 fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64>;
175
176 fn as_any(&self) -> &(dyn Any + 'static) {
180 &NON_DOWNCASTABLE_HYPER_OPERATOR
181 }
182
183 fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
185 self.mul_vec(&v.to_owned())
186 }
187
188 fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
190 out.assign(&self.mul_vec_view(v));
191 }
192
193 fn mul_mat(&self, factor: &Array2<f64>) -> Array2<f64> {
196 let p = factor.nrows();
197 let k = factor.ncols();
198 let mut out = Array2::<f64>::zeros((p, k));
199 if rayon::current_thread_index().is_some() {
200 for col in 0..k {
201 let bv = out.column_mut(col);
202 self.mul_vec_into(factor.column(col), bv);
203 }
204 return out;
205 }
206 let cols: Vec<Array1<f64>> = (0..k)
207 .into_par_iter()
208 .map(|col| {
209 let mut bv = Array1::<f64>::zeros(p);
210 self.mul_vec_into(factor.column(col), bv.view_mut());
211 bv
212 })
213 .collect();
214 for (col, bv) in cols.into_iter().enumerate() {
215 out.column_mut(col).assign(&bv);
216 }
217 out
218 }
219
220 fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
222 let op_factor = self.mul_mat(factor);
223 factor
224 .iter()
225 .zip(op_factor.iter())
226 .map(|(&f, &bf)| f * bf)
227 .sum()
228 }
229
230 fn projection_design_id(&self) -> Option<usize> {
239 None
240 }
241
242 fn trace_projected_factor_cached(
243 &self,
244 factor: &Array2<f64>,
245 factor_cache: &ProjectedFactorCache,
246 ) -> f64 {
247 assert!(std::mem::size_of_val(factor_cache) > 0);
251 match self.projection_design_id() {
252 Some(design_id) => {
253 let key = ProjectedFactorKey::from_factor_view(design_id, factor.view());
254 let projected = factor_cache.get_or_insert_with(key, || self.mul_mat(factor));
255 factor
256 .iter()
257 .zip(projected.iter())
258 .map(|(&f, &bf)| f * bf)
259 .sum()
260 }
261 None => self.trace_projected_factor(factor),
262 }
263 }
264
265 fn projected_matrix(&self, factor: &Array2<f64>) -> Array2<f64> {
267 let op_factor = self.mul_mat(factor);
268 crate::linalg_helpers::fast_atb(factor, &op_factor)
269 }
270
271 fn projected_matrix_cached(
274 &self,
275 factor: &Array2<f64>,
276 factor_cache: &ProjectedFactorCache,
277 ) -> Array2<f64> {
278 assert!(std::mem::size_of_val(factor_cache) > 0);
279 match self.projection_design_id() {
280 Some(design_id) => {
281 let key = ProjectedFactorKey::from_factor_view(design_id, factor.view());
282 let projected = factor_cache.get_or_insert_with(key, || self.mul_mat(factor));
283 crate::linalg_helpers::fast_atb(factor, projected.as_ref())
284 }
285 None => self.projected_matrix(factor),
286 }
287 }
288
289 fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
291 let cols = out.ncols();
292 let dim = out.nrows();
293 assert!(start + cols <= dim);
294 let mut basis = Array1::<f64>::zeros(dim);
295 for local_col in 0..cols {
296 let global_col = start + local_col;
297 basis[global_col] = 1.0;
298 self.mul_vec_into(basis.view(), out.column_mut(local_col));
299 basis[global_col] = 0.0;
300 }
301 }
302
303 fn scaled_add_mul_vec(
305 &self,
306 v: ArrayView1<'_, f64>,
307 scale: f64,
308 mut out: ArrayViewMut1<'_, f64>,
309 ) {
310 if scale == 0.0 {
311 return;
312 }
313 let mut work = Array1::<f64>::zeros(out.len());
314 self.mul_vec_into(v, work.view_mut());
315 out.scaled_add(scale, &work);
316 }
317
318 fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
320 let mut bv = Array1::<f64>::zeros(v.len());
321 self.mul_vec_into(v.view(), bv.view_mut());
322 u.dot(&bv)
323 }
324
325 fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
327 let mut bv = Array1::<f64>::zeros(v.len());
328 self.mul_vec_into(v, bv.view_mut());
329 u.dot(&bv)
330 }
331
332 fn has_fast_bilinear_view(&self) -> bool {
334 false
335 }
336
337 fn to_dense(&self) -> Array2<f64> {
339 let p = self.dim();
340 let mut out = Array2::<f64>::zeros((p, p));
341 let mut basis = Array1::<f64>::zeros(p);
342 for j in 0..p {
343 basis[j] = 1.0;
344 self.mul_vec_into(basis.view(), out.column_mut(j));
345 basis[j] = 0.0;
346 }
347 out
348 }
349
350 fn is_implicit(&self) -> bool;
352
353 fn block_local_data(&self) -> Option<(&Array2<f64>, usize, usize)> {
355 None
356 }
357}
358
/// Lookup key for `ProjectedFactorCache`.
///
/// Combines an operator-supplied design identifier with the address, shape, strides
/// and a two-word content fingerprint of the factor matrix.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ProjectedFactorKey {
    /// Identifier supplied by the caller (the operator's `projection_design_id`).
    pub(crate) design_id: usize,
    /// Address of the factor view's first element, stored as an integer.
    pub(crate) factor_ptr: usize,
    /// Number of rows of the factor.
    pub(crate) rows: usize,
    /// Number of columns of the factor.
    pub(crate) cols: usize,
    /// Stride of the factor view along axis 0.
    pub(crate) row_stride: isize,
    /// Stride of the factor view along axis 1.
    pub(crate) col_stride: isize,
    /// First word of the content fingerprint.
    pub(crate) value_hash: u64,
    /// Second word of the content fingerprint.
    pub(crate) value_hash2: u64,
}
370
371impl ProjectedFactorKey {
372 pub fn from_factor_view(design_id: usize, factor: ArrayView2<'_, f64>) -> Self {
373 let strides = factor.strides();
374 let (value_hash, value_hash2) = projected_factor_value_fingerprint(factor);
375 Self {
376 design_id,
377 factor_ptr: factor.as_ptr() as usize,
378 rows: factor.nrows(),
379 cols: factor.ncols(),
380 row_stride: strides[0],
381 col_stride: strides[1],
382 value_hash,
383 value_hash2,
384 }
385 }
386
387 pub fn synthetic(seed: u64) -> Self {
392 Self {
393 design_id: 1,
394 factor_ptr: seed as usize,
395 rows: 1,
396 cols: 1,
397 row_stride: 1,
398 col_stride: 1,
399 value_hash: seed,
400 value_hash2: seed.wrapping_mul(31),
401 }
402 }
403}
404
405pub(crate) fn projected_factor_value_fingerprint(factor: ArrayView2<'_, f64>) -> (u64, u64) {
406 let mut h1 = 0xcbf2_9ce4_8422_2325_u64;
407 let mut h2 = 0x9e37_79b1_85eb_ca87_u64;
408 for (idx, value) in factor.iter().enumerate() {
409 let bits = value.to_bits();
410 let mixed = bits.wrapping_add((idx as u64).wrapping_mul(0x517c_c1b7_2722_0a95));
411 h1 ^= mixed;
412 h1 = h1.wrapping_mul(0x0000_0100_0000_01b3);
413 h2 ^= bits.rotate_left((idx & 63) as u32);
414 h2 = h2.wrapping_mul(0x94d0_49bb_1331_11eb).rotate_left(27);
415 }
416 (h1, h2)
417}
418
/// Cache of projected factor matrices keyed by `ProjectedFactorKey`.
///
/// All state lives behind a single mutex; see `get_or_insert_with` for the lookup,
/// compute and eviction protocol.
pub struct ProjectedFactorCache {
    /// Mutex-protected cache state.
    pub(crate) inner: Mutex<ProjectedFactorCacheInner>,
}
423
/// Mutable state of a `ProjectedFactorCache`, guarded by its mutex.
pub(crate) struct ProjectedFactorCacheInner {
    /// Completed values, keyed by factor key.
    pub(crate) entries: HashMap<ProjectedFactorKey, ProjectedFactorEntry>,
    /// Markers for keys whose value is currently being computed by some caller.
    pub(crate) in_progress: HashMap<ProjectedFactorKey, Arc<ProjectedFactorInProgress>>,
    /// Counter bumped on every lookup and insert; its value stamps `last_used`.
    pub(crate) next_seq: u64,
    /// Running sum of the `bytes` fields of all stored entries.
    pub(crate) total_bytes: usize,
    /// Byte budget that the eviction loop compares against.
    pub(crate) budget_bytes: usize,
}
431
/// Rendezvous point between the caller computing a value and callers waiting for it.
pub(crate) struct ProjectedFactorInProgress {
    /// `None` while the computation is pending; set to the outcome by the producer.
    pub(crate) state: Mutex<Option<ProjectedFactorInProgressState>>,
    /// Notified by the producer after `state` has been set.
    pub(crate) ready: Condvar,
    /// Number of callers currently waiting on this marker.
    pub(crate) waiter_count: std::sync::atomic::AtomicUsize,
    /// Lock/condvar pair signalled whenever a waiter registers; used by
    /// `ProjectedFactorCache::wait_for_subscriber`.
    pub(crate) subscriber_arrived: (Mutex<()>, Condvar),
}
438
/// Outcome published by the producer of an in-progress cache computation.
pub(crate) enum ProjectedFactorInProgressState {
    /// The computation finished; waiters receive a clone of this shared value.
    Ready(Arc<Array2<f64>>),
    /// The computation panicked; waiters panic in turn.
    Failed,
}
443
/// A completed cache entry.
pub(crate) struct ProjectedFactorEntry {
    /// Shared handle to the cached matrix.
    pub(crate) value: Arc<Array2<f64>>,
    /// Payload size charged against the budget (element count × `size_of::<f64>()`).
    pub(crate) bytes: usize,
    /// Sequence stamp of the most recent access; the smallest stamp is evicted first.
    pub(crate) last_used: u64,
}
449
450impl Default for ProjectedFactorCache {
451 fn default() -> Self {
452 Self::with_budget(Self::DEFAULT_BUDGET_BYTES)
453 }
454}
455
456impl ProjectedFactorCache {
457 pub const DEFAULT_BUDGET_BYTES: usize = 2 * 1024 * 1024 * 1024;
458
459 pub fn with_budget(budget_bytes: usize) -> Self {
460 Self {
461 inner: Mutex::new(ProjectedFactorCacheInner {
462 entries: HashMap::new(),
463 in_progress: HashMap::new(),
464 next_seq: 0,
465 total_bytes: 0,
466 budget_bytes,
467 }),
468 }
469 }
470
471 pub fn get_or_insert_with(
472 &self,
473 key: ProjectedFactorKey,
474 compute: impl FnOnce() -> Array2<f64>,
475 ) -> Arc<Array2<f64>> {
476 enum CacheLookup {
477 Hit(Arc<Array2<f64>>),
478 Wait(Arc<ProjectedFactorInProgress>),
479 Compute(Arc<ProjectedFactorInProgress>),
480 }
481
482 let lookup = {
483 let mut inner = self
484 .inner
485 .lock()
486 .expect("projected factor cache lock poisoned");
487 inner.next_seq += 1;
488 let now = inner.next_seq;
489 if let Some(entry) = inner.entries.get_mut(&key) {
490 entry.last_used = now;
491 CacheLookup::Hit(entry.value.clone())
492 } else if let Some(waiter) = inner.in_progress.get(&key) {
493 CacheLookup::Wait(waiter.clone())
494 } else {
495 let marker = Arc::new(ProjectedFactorInProgress {
496 state: Mutex::new(None),
497 ready: Condvar::new(),
498 waiter_count: std::sync::atomic::AtomicUsize::new(0),
499 subscriber_arrived: (Mutex::new(()), Condvar::new()),
500 });
501 inner.in_progress.insert(key, marker.clone());
502 CacheLookup::Compute(marker)
503 }
504 };
505
506 match lookup {
507 CacheLookup::Hit(value) => value,
508 CacheLookup::Wait(marker) => {
509 marker
510 .waiter_count
511 .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
512 let (lock, cv) = &marker.subscriber_arrived;
513 drop(
514 lock.lock()
515 .expect("subscriber-arrived notification lock poisoned"),
516 );
517 cv.notify_all();
518 let mut guard = marker
519 .state
520 .lock()
521 .expect("projected factor in-progress lock poisoned");
522 let result = loop {
523 match guard.as_ref() {
524 Some(ProjectedFactorInProgressState::Ready(value)) => {
525 break value.clone();
526 }
527 Some(ProjectedFactorInProgressState::Failed) => {
528 marker
529 .waiter_count
530 .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
531 reml_contract_panic("projected factor cache producer panicked")
532 }
533 None => {
534 guard = marker
535 .ready
536 .wait(guard)
537 .expect("projected factor in-progress wait poisoned");
538 }
539 }
540 };
541 marker
542 .waiter_count
543 .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
544 result
545 }
546 CacheLookup::Compute(marker) => {
547 let computed = match catch_unwind(AssertUnwindSafe(|| Arc::new(compute()))) {
548 Ok(value) => value,
549 Err(payload) => {
550 let mut inner = self
551 .inner
552 .lock()
553 .expect("projected factor cache lock poisoned");
554 inner.in_progress.remove(&key);
555 drop(inner);
556
557 let mut guard = marker
558 .state
559 .lock()
560 .expect("projected factor in-progress lock poisoned");
561 *guard = Some(ProjectedFactorInProgressState::Failed);
562 marker.ready.notify_all();
563 resume_unwind(payload);
564 }
565 };
566 let bytes = computed.len().saturating_mul(std::mem::size_of::<f64>());
567 let mut inner = self
568 .inner
569 .lock()
570 .expect("projected factor cache lock poisoned");
571 inner.next_seq += 1;
572 let now = inner.next_seq;
573
574 if inner.budget_bytes > 0 && bytes <= inner.budget_bytes {
575 while inner.total_bytes.saturating_add(bytes) > inner.budget_bytes
576 && !inner.entries.is_empty()
577 {
578 let Some(oldest_key) = inner
579 .entries
580 .iter()
581 .min_by_key(|(_, e)| e.last_used)
582 .map(|(k, _)| *k)
583 else {
584 break;
585 };
586 if let Some(removed) = inner.entries.remove(&oldest_key) {
587 inner.total_bytes = inner.total_bytes.saturating_sub(removed.bytes);
588 }
589 }
590 }
591
592 let value = if let Some(entry) = inner.entries.get_mut(&key) {
593 entry.last_used = now;
594 entry.value.clone()
595 } else {
596 inner.entries.insert(
597 key,
598 ProjectedFactorEntry {
599 value: computed.clone(),
600 bytes,
601 last_used: now,
602 },
603 );
604 inner.total_bytes = inner.total_bytes.saturating_add(bytes);
605 computed
606 };
607 inner.in_progress.remove(&key);
608 drop(inner);
609
610 let mut guard = marker
611 .state
612 .lock()
613 .expect("projected factor in-progress lock poisoned");
614 *guard = Some(ProjectedFactorInProgressState::Ready(value.clone()));
615 marker.ready.notify_all();
616 value
617 }
618 }
619 }
620
621 pub fn len(&self) -> usize {
622 self.inner
623 .lock()
624 .map(|inner| inner.entries.len())
625 .unwrap_or(0)
626 }
627
628 pub fn total_bytes(&self) -> usize {
629 self.inner
630 .lock()
631 .map(|inner| inner.total_bytes)
632 .unwrap_or(0)
633 }
634
635 pub fn is_empty(&self) -> bool {
636 self.len() == 0
637 }
638
639 pub fn wait_for_subscriber(
650 &self,
651 key: ProjectedFactorKey,
652 timeout: std::time::Duration,
653 ) -> bool {
654 let marker = {
655 let inner = self
656 .inner
657 .lock()
658 .expect("projected factor cache lock poisoned");
659 let Some(m) = inner.in_progress.get(&key) else {
660 return false;
661 };
662 Arc::clone(m)
663 };
664 if marker
665 .waiter_count
666 .load(std::sync::atomic::Ordering::Acquire)
667 > 0
668 {
669 return true;
670 }
671 let (lock, cv) = &marker.subscriber_arrived;
672 let mut guard = lock
673 .lock()
674 .expect("subscriber-arrived notification lock poisoned");
675 let deadline = std::time::Instant::now() + timeout;
676 loop {
677 if marker
678 .waiter_count
679 .load(std::sync::atomic::Ordering::Acquire)
680 > 0
681 {
682 return true;
683 }
684 let now = std::time::Instant::now();
685 if now >= deadline {
686 return false;
687 }
688 let (next_guard, result) = cv
689 .wait_timeout(guard, deadline - now)
690 .expect("subscriber-arrived wait poisoned");
691 guard = next_guard;
692 if result.timed_out()
693 && marker
694 .waiter_count
695 .load(std::sync::atomic::Ordering::Acquire)
696 == 0
697 {
698 return false;
699 }
700 }
701 }
702}
703
/// `HyperOperator` backed by an explicitly stored dense matrix.
#[derive(Clone)]
pub struct DenseMatrixHyperOperator {
    /// The stored matrix; products are computed directly from it.
    pub matrix: Array2<f64>,
}
708
709impl HyperOperator for DenseMatrixHyperOperator {
710 fn dim(&self) -> usize {
711 self.matrix.nrows()
712 }
713
714 fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
715 self.matrix.dot(v)
716 }
717
718 fn as_any(&self) -> &(dyn Any + 'static) {
719 self
720 }
721
722 fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
723 self.matrix.dot(&v)
724 }
725
726 fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
727 assert_eq!(self.matrix.ncols(), v.len());
728 assert_eq!(self.matrix.nrows(), out.len());
729 for (row, out_value) in self.matrix.rows().into_iter().zip(out.iter_mut()) {
730 *out_value = row.dot(&v);
731 }
732 }
733
734 fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
735 let end = start + out.ncols();
736 assert!(end <= self.matrix.ncols());
737 out.assign(&self.matrix.slice(ndarray::s![.., start..end]));
738 }
739
740 fn scaled_add_mul_vec(
741 &self,
742 v: ArrayView1<'_, f64>,
743 scale: f64,
744 mut out: ArrayViewMut1<'_, f64>,
745 ) {
746 assert_eq!(self.matrix.ncols(), v.len());
747 assert_eq!(self.matrix.nrows(), out.len());
748 if scale == 0.0 {
749 return;
750 }
751 for (row, out_value) in self.matrix.rows().into_iter().zip(out.iter_mut()) {
752 *out_value += scale * row.dot(&v);
753 }
754 }
755
756 fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
757 dense_bilinear(&self.matrix, v.view(), u.view())
758 }
759
760 fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
761 dense_bilinear(&self.matrix, v, u)
762 }
763
764 fn to_dense(&self) -> Array2<f64> {
765 self.matrix.clone()
766 }
767
768 fn is_implicit(&self) -> bool {
769 false
770 }
771}
772
/// Operator of size `total_dim` that is zero outside one diagonal block.
///
/// The block occupies rows and columns `start..end` and holds `local`.
#[derive(Clone)]
pub struct BlockLocalDrift {
    /// Dense contents of the diagonal block.
    pub local: Array2<f64>,
    /// First index covered by the block (inclusive).
    pub start: usize,
    /// One past the last index covered by the block.
    pub end: usize,
    /// Dimension of the full operator.
    pub total_dim: usize,
}
780
781impl HyperOperator for BlockLocalDrift {
782 fn dim(&self) -> usize {
783 self.total_dim
784 }
785
786 fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
787 assert_eq!(v.len(), self.total_dim);
788 let mut out = Array1::zeros(self.total_dim);
789 self.mul_vec_into(v.view(), out.view_mut());
790 out
791 }
792
793 fn as_any(&self) -> &(dyn Any + 'static) {
794 self
795 }
796
797 fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
798 assert_eq!(v.len(), self.total_dim);
799 assert_eq!(out.len(), self.total_dim);
800 out.fill(0.0);
801 let v_block = v.slice(ndarray::s![self.start..self.end]);
802 let mut out_block = out.slice_mut(ndarray::s![self.start..self.end]);
803 dense_matvec_into(&self.local, v_block, out_block.view_mut());
804 }
805
806 fn scaled_add_mul_vec(
807 &self,
808 v: ArrayView1<'_, f64>,
809 scale: f64,
810 mut out: ArrayViewMut1<'_, f64>,
811 ) {
812 assert_eq!(v.len(), self.total_dim);
813 assert_eq!(out.len(), self.total_dim);
814 if scale == 0.0 {
815 return;
816 }
817 let v_block = v.slice(ndarray::s![self.start..self.end]);
818 let out_block = out.slice_mut(ndarray::s![self.start..self.end]);
819 dense_matvec_scaled_add_into(&self.local, v_block, scale, out_block);
820 }
821
822 fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
823 self.bilinear_view(v.view(), u.view())
824 }
825
826 fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
827 assert_eq!(v.len(), self.total_dim);
828 assert_eq!(u.len(), self.total_dim);
829 let v_block = v.slice(ndarray::s![self.start..self.end]);
830 let u_block = u.slice(ndarray::s![self.start..self.end]);
831 dense_bilinear(&self.local, v_block, u_block)
832 }
833
834 fn to_dense(&self) -> Array2<f64> {
835 let p = self.total_dim;
836 let mut out = Array2::zeros((p, p));
837 out.slice_mut(ndarray::s![self.start..self.end, self.start..self.end])
838 .assign(&self.local);
839 out
840 }
841
842 fn is_implicit(&self) -> bool {
843 false
844 }
845
846 fn block_local_data(&self) -> Option<(&Array2<f64>, usize, usize)> {
847 Some((&self.local, self.start, self.end))
848 }
849}
850
/// A drift term stored as the sum of up to three optional components.
///
/// `materialize` and `scaled_add_apply` add together every component that is present.
#[derive(Clone)]
pub struct HyperCoordDrift {
    /// Full dense component, if any.
    pub dense: Option<Array2<f64>>,
    /// Component confined to a single diagonal block, if any.
    pub block_local: Option<BlockLocalDrift>,
    /// Matrix-free component, if any.
    pub operator: Option<Arc<dyn HyperOperator>>,
}
857
858impl HyperCoordDrift {
859 pub fn none() -> Self {
860 Self {
861 dense: None,
862 block_local: None,
863 operator: None,
864 }
865 }
866
867 pub fn from_dense(dense: Array2<f64>) -> Self {
868 Self {
869 dense: Some(dense),
870 block_local: None,
871 operator: None,
872 }
873 }
874
875 pub fn from_operator(operator: Arc<dyn HyperOperator>) -> Self {
876 Self {
877 dense: None,
878 block_local: None,
879 operator: Some(operator),
880 }
881 }
882
883 pub fn from_parts(
884 dense: Option<Array2<f64>>,
885 operator: Option<Arc<dyn HyperOperator>>,
886 ) -> Self {
887 let dense = dense.filter(|mat| !(operator.is_some() && mat.is_empty()));
888 Self {
889 dense,
890 block_local: None,
891 operator,
892 }
893 }
894
895 pub fn from_block_local_and_operator(
896 local: Array2<f64>,
897 start: usize,
898 end: usize,
899 total_dim: usize,
900 operator: Option<Arc<dyn HyperOperator>>,
901 ) -> Self {
902 Self {
903 dense: None,
904 block_local: Some(BlockLocalDrift {
905 local,
906 start,
907 end,
908 total_dim,
909 }),
910 operator,
911 }
912 }
913
914 pub fn has_operator(&self) -> bool {
915 self.operator.is_some()
916 }
917
918 pub fn uses_operator_fast_path(&self) -> bool {
919 self.operator.is_some() || self.block_local.is_some()
920 }
921
922 pub fn operator_ref(&self) -> Option<&dyn HyperOperator> {
923 self.operator.as_ref().map(Arc::as_ref)
924 }
925
926 pub fn materialize(&self) -> Array2<f64> {
927 let p = self.infer_dim();
928 if p == 0 {
929 return Array2::zeros((0, 0));
930 }
931 let mut out = self.dense.clone().unwrap_or_else(|| Array2::zeros((p, p)));
932 if let Some(bl) = &self.block_local {
933 out.slice_mut(ndarray::s![bl.start..bl.end, bl.start..bl.end])
934 .scaled_add(1.0, &bl.local);
935 }
936 if let Some(op) = &self.operator {
937 out += &op.to_dense();
938 }
939 out
940 }
941
942 pub fn apply(&self, v: &Array1<f64>) -> Array1<f64> {
943 let mut out = Array1::zeros(v.len());
944 self.scaled_add_apply(v.view(), 1.0, &mut out);
945 out
946 }
947
948 pub fn scaled_add_apply(&self, v: ArrayView1<'_, f64>, scale: f64, out: &mut Array1<f64>) {
949 assert_eq!(v.len(), out.len());
950 if scale == 0.0 {
951 return;
952 }
953 if let Some(dense) = &self.dense {
954 dense_matvec_scaled_add_into(dense, v, scale, out.view_mut());
955 }
956 if let Some(bl) = &self.block_local {
957 let v_block = v.slice(ndarray::s![bl.start..bl.end]);
958 let out_block = out.slice_mut(ndarray::s![bl.start..bl.end]);
959 dense_matvec_scaled_add_into(&bl.local, v_block, scale, out_block);
960 }
961 if let Some(op) = &self.operator {
962 op.scaled_add_mul_vec(v, scale, out.view_mut());
963 }
964 }
965
966 pub(crate) fn infer_dim(&self) -> usize {
967 if let Some(d) = &self.dense {
968 return d.nrows();
969 }
970 if let Some(op) = &self.operator {
971 return op.dim();
972 }
973 if let Some(bl) = &self.block_local {
974 return bl.total_dim;
975 }
976 0
977 }
978}
979
/// Per-coordinate bundle of derivative ingredients for one hyperparameter.
///
/// NOTE(review): the field meanings below are inferred from names and types only;
/// this chunk does not show how they are consumed — confirm before relying on them.
#[derive(Clone)]
pub struct HyperCoord {
    /// Scalar term (presumably an objective-level derivative).
    pub a: f64,
    /// Vector term (presumably a score/gradient-level derivative).
    pub g: Array1<f64>,
    /// Drift term, stored as a sum of dense, block-local and operator components.
    pub drift: HyperCoordDrift,
    /// Scalar term (presumably a penalty log-determinant derivative).
    pub ld_s: f64,
    /// Flag; presumably marks that the drift varies with the coefficients.
    pub b_depends_on_beta: bool,
    /// Flag; presumably marks coordinates that behave like penalty parameters.
    pub is_penalty_like: bool,
    /// Optional extra vector term (presumably Firth-related).
    pub firth_g: Option<Array1<f64>>,
    /// Optional fixed vector (semantics not visible here).
    pub tk_eta_fixed: Option<Array1<f64>>,
    /// Optional fixed matrix (semantics not visible here).
    pub tk_x_fixed: Option<Array2<f64>>,
}
992
/// Ingredients associated with a pair of hyperparameter coordinates.
///
/// NOTE(review): the field meanings are inferred from names and types only; confirm
/// against the code that consumes this struct.
pub struct HyperCoordPair {
    /// Scalar term.
    pub a: f64,
    /// Vector term.
    pub g: Array1<f64>,
    /// Dense matrix term.
    pub b_mat: Array2<f64>,
    /// Optional matrix-free term (presumably an alternative or addition to `b_mat`).
    pub b_operator: Option<Box<dyn HyperOperator>>,
    /// Scalar term (presumably a penalty log-determinant derivative).
    pub ld_s: f64,
}
1000
/// Shared, thread-safe callback that builds the `HyperCoordPair` for two indices
/// (presumably the indices of the two hyperparameter coordinates — confirm with callers).
pub type HyperCoordPairFn = Arc<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>;
1012
1013impl HyperCoordPair {
1014 pub fn zero() -> Self {
1015 Self {
1016 a: 0.0,
1017 g: Array1::zeros(0),
1018 b_mat: Array2::zeros((0, 0)),
1019 b_operator: None,
1020 ld_s: 0.0,
1021 }
1022 }
1023}
1024
/// Result of a drift-derivative computation, in either dense or matrix-free form.
#[derive(Clone)]
pub enum DriftDerivResult {
    /// Explicit dense matrix.
    Dense(Array2<f64>),
    /// Shared matrix-free operator.
    Operator(Arc<dyn HyperOperator>),
}
1030
1031impl std::fmt::Debug for DriftDerivResult {
1032 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1033 match self {
1034 Self::Dense(matrix) => f
1035 .debug_tuple("Dense")
1036 .field(&format_args!("{}x{}", matrix.nrows(), matrix.ncols()))
1037 .finish(),
1038 Self::Operator(_) => f
1039 .debug_tuple("Operator")
1040 .field(&"<hyper-operator>")
1041 .finish(),
1042 }
1043 }
1044}
1045
1046impl DriftDerivResult {
1047 pub fn into_operator(self) -> Arc<dyn HyperOperator> {
1048 match self {
1049 Self::Dense(matrix) => Arc::new(DenseMatrixHyperOperator { matrix }),
1050 Self::Operator(operator) => operator,
1051 }
1052 }
1053
1054 pub fn apply(&self, v: &Array1<f64>) -> Array1<f64> {
1055 match self {
1056 Self::Dense(matrix) => matrix.dot(v),
1057 Self::Operator(operator) => operator.mul_vec(v),
1058 }
1059 }
1060}
1061
/// Boxed, thread-safe callback that maps an index and a vector to an optional
/// `DriftDerivResult` (`None` presumably meaning "no contribution" — confirm with callers).
pub type FixedDriftDerivFn =
    Box<dyn Fn(usize, &Array1<f64>) -> Option<DriftDerivResult> + Send + Sync>;
1064
/// `Arc`-shared counterpart of `FixedDriftDerivFn`: same call signature, but the
/// callback can be cloned and held by several owners.
pub type SharedFixedDriftDerivFn =
    Arc<dyn Fn(usize, &Array1<f64>) -> Option<DriftDerivResult> + Send + Sync>;
1074
/// Bundle of contracted second-order terms returned by `ContractedPsiSecondOrderFn`.
///
/// NOTE(review): the field meanings and array shapes are inferred from names and
/// types only; confirm against the producer of this struct.
pub struct ContractedPsiSecondOrder {
    /// Vector of objective-level terms.
    pub objective: Array1<f64>,
    /// Matrix of score-level terms.
    pub score: Array2<f64>,
    /// Hessian-level terms, each in dense or operator form.
    pub hessian: Vec<DriftDerivResult>,
    /// Vector term (presumably penalty log-determinant contributions).
    pub ld_s: Array1<f64>,
}
1081
/// Shared, thread-safe callback that takes a slice of `f64` (presumably a contraction
/// direction — confirm) and returns the contracted terms, `Ok(None)` when no result is
/// produced, or an error message.
pub type ContractedPsiSecondOrderFn =
    Arc<dyn Fn(&[f64]) -> Result<Option<ContractedPsiSecondOrder>, String> + Send + Sync>;