1use std::any::Any;
8use std::collections::HashMap;
9use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
10use std::sync::{Arc, Condvar, Mutex};
11
12use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2};
13use rayon::iter::{IntoParallelIterator, ParallelIterator};
14
15#[macro_use]
16mod macros;
17
18mod linalg_helpers;
19mod linear_constraints;
20mod pseudo_logdet;
21mod seeding;
22pub mod solver_contract;
23
24mod gpu {
25 pub(crate) mod linalg_dispatch {
26 use ndarray::{Array2, ArrayView2};
27
28 pub(crate) fn try_fast_atb(
29 a: ArrayView2<'_, f64>,
30 b: ArrayView2<'_, f64>,
31 ) -> Option<Array2<f64>> {
32 let (n_a, p) = a.dim();
33 let (n_b, q) = b.dim();
34 assert_eq!(n_a, n_b, "A and B must have same number of rows");
35 if !crate::linalg_helpers::should_use_faer_matmul(p, q, n_a) {
36 return None;
37 }
38 Some(crate::linalg_helpers::fast_atb_with_parallelism(
39 &a,
40 &b,
41 crate::linalg_helpers::matmul_parallelism(p, q, n_a),
42 ))
43 }
44 }
45}
46
47pub use gam_linalg::faer_ndarray::{in_nested_parallel_region, with_nested_parallel};
48use linalg_helpers::{dense_bilinear, dense_matvec_into, dense_matvec_scaled_add_into};
49pub use linear_constraints::LinearInequalityConstraints;
50pub use pseudo_logdet::PseudoLogdetMode;
51pub use seeding::{SeedConfig, SeedRiskProfile, clamp_seed_rho_to_bounds, normalize_seed_bounds};
52pub use solver_contract::{
53 DeclaredHessianForm, Derivative, EfsEval, HessianResult, OuterEval,
54 OuterHessianMaterialization, OuterHessianOperator, OuterStrategyError,
55};
56
/// Unwinds with a `String` payload describing a violated REML solver contract.
///
/// `panic_any` is used so the payload is exactly the converted `String`. Code that
/// catches the unwind can recover the message with `downcast_ref::<String>()`.
#[cold]
fn reml_contract_panic(message: impl Into<String>) -> ! {
    let payload: String = message.into();
    std::panic::panic_any(payload)
}
61
/// How much of the objective an evaluation should produce.
///
/// The variant names indicate the requested quantities. Each consumer defines
/// exactly what it computes for each tier.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EvalMode {
    /// Objective value only.
    ValueOnly,
    /// Objective value and gradient.
    ValueAndGradient,
    /// Objective value, gradient and Hessian.
    ValueGradientHessian,
}
72
/// Private placeholder returned by the default `HyperOperator::as_any`.
///
/// The type is private to this module, so a `downcast_ref` on an operator that
/// does not override `as_any` cannot succeed for any public operator type.
struct NonDowncastableHyperOperator;

/// The single (zero-sized) placeholder instance handed out by the default `as_any`.
static NON_DOWNCASTABLE_HYPER_OPERATOR: NonDowncastableHyperOperator =
    NonDowncastableHyperOperator;
78
79pub trait HyperOperator: Send + Sync {
80 fn dim(&self) -> usize;
83
84 fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64>;
86
87 fn as_any(&self) -> &(dyn Any + 'static) {
91 &NON_DOWNCASTABLE_HYPER_OPERATOR
92 }
93
94 fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
96 self.mul_vec(&v.to_owned())
97 }
98
99 fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
101 out.assign(&self.mul_vec_view(v));
102 }
103
104 fn mul_mat(&self, factor: &Array2<f64>) -> Array2<f64> {
107 let p = factor.nrows();
108 let k = factor.ncols();
109 let mut out = Array2::<f64>::zeros((p, k));
110 if rayon::current_thread_index().is_some() {
111 for col in 0..k {
112 let bv = out.column_mut(col);
113 self.mul_vec_into(factor.column(col), bv);
114 }
115 return out;
116 }
117 let cols: Vec<Array1<f64>> = (0..k)
118 .into_par_iter()
119 .map(|col| {
120 let mut bv = Array1::<f64>::zeros(p);
121 self.mul_vec_into(factor.column(col), bv.view_mut());
122 bv
123 })
124 .collect();
125 for (col, bv) in cols.into_iter().enumerate() {
126 out.column_mut(col).assign(&bv);
127 }
128 out
129 }
130
131 fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
133 let op_factor = self.mul_mat(factor);
134 factor
135 .iter()
136 .zip(op_factor.iter())
137 .map(|(&f, &bf)| f * bf)
138 .sum()
139 }
140
141 fn projection_design_id(&self) -> Option<usize> {
150 None
151 }
152
153 fn trace_projected_factor_cached(
154 &self,
155 factor: &Array2<f64>,
156 factor_cache: &ProjectedFactorCache,
157 ) -> f64 {
158 assert!(std::mem::size_of_val(factor_cache) > 0);
162 match self.projection_design_id() {
163 Some(design_id) => {
164 let key = ProjectedFactorKey::from_factor_view(design_id, factor.view());
165 let projected = factor_cache.get_or_insert_with(key, || self.mul_mat(factor));
166 factor
167 .iter()
168 .zip(projected.iter())
169 .map(|(&f, &bf)| f * bf)
170 .sum()
171 }
172 None => self.trace_projected_factor(factor),
173 }
174 }
175
176 fn projected_matrix(&self, factor: &Array2<f64>) -> Array2<f64> {
178 let op_factor = self.mul_mat(factor);
179 crate::linalg_helpers::fast_atb(factor, &op_factor)
180 }
181
182 fn projected_matrix_cached(
185 &self,
186 factor: &Array2<f64>,
187 factor_cache: &ProjectedFactorCache,
188 ) -> Array2<f64> {
189 assert!(std::mem::size_of_val(factor_cache) > 0);
190 match self.projection_design_id() {
191 Some(design_id) => {
192 let key = ProjectedFactorKey::from_factor_view(design_id, factor.view());
193 let projected = factor_cache.get_or_insert_with(key, || self.mul_mat(factor));
194 crate::linalg_helpers::fast_atb(factor, projected.as_ref())
195 }
196 None => self.projected_matrix(factor),
197 }
198 }
199
200 fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
202 let cols = out.ncols();
203 let dim = out.nrows();
204 assert!(start + cols <= dim);
205 let mut basis = Array1::<f64>::zeros(dim);
206 for local_col in 0..cols {
207 let global_col = start + local_col;
208 basis[global_col] = 1.0;
209 self.mul_vec_into(basis.view(), out.column_mut(local_col));
210 basis[global_col] = 0.0;
211 }
212 }
213
214 fn scaled_add_mul_vec(
216 &self,
217 v: ArrayView1<'_, f64>,
218 scale: f64,
219 mut out: ArrayViewMut1<'_, f64>,
220 ) {
221 if scale == 0.0 {
222 return;
223 }
224 let mut work = Array1::<f64>::zeros(out.len());
225 self.mul_vec_into(v, work.view_mut());
226 out.scaled_add(scale, &work);
227 }
228
229 fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
231 let mut bv = Array1::<f64>::zeros(v.len());
232 self.mul_vec_into(v.view(), bv.view_mut());
233 u.dot(&bv)
234 }
235
236 fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
238 let mut bv = Array1::<f64>::zeros(v.len());
239 self.mul_vec_into(v, bv.view_mut());
240 u.dot(&bv)
241 }
242
243 fn has_fast_bilinear_view(&self) -> bool {
245 false
246 }
247
248 fn to_dense(&self) -> Array2<f64> {
250 let p = self.dim();
251 let mut out = Array2::<f64>::zeros((p, p));
252 let mut basis = Array1::<f64>::zeros(p);
253 for j in 0..p {
254 basis[j] = 1.0;
255 self.mul_vec_into(basis.view(), out.column_mut(j));
256 basis[j] = 0.0;
257 }
258 out
259 }
260
261 fn is_implicit(&self) -> bool;
263
264 fn block_local_data(&self) -> Option<(&Array2<f64>, usize, usize)> {
266 None
267 }
268}
269
/// Identity of one factor matrix as seen by one operator, used as the
/// [`ProjectedFactorCache`] key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ProjectedFactorKey {
    /// Operator identifier (`HyperOperator::projection_design_id`).
    pub(crate) design_id: usize,
    /// Address of the factor view's first element.
    pub(crate) factor_ptr: usize,
    /// Number of rows in the factor view.
    pub(crate) rows: usize,
    /// Number of columns in the factor view.
    pub(crate) cols: usize,
    /// Element stride between consecutive rows.
    pub(crate) row_stride: isize,
    /// Element stride between consecutive columns.
    pub(crate) col_stride: isize,
    /// First fingerprint of the element values.
    pub(crate) value_hash: u64,
    /// Second, independently mixed fingerprint of the element values.
    pub(crate) value_hash2: u64,
}
281
282impl ProjectedFactorKey {
283 pub fn from_factor_view(design_id: usize, factor: ArrayView2<'_, f64>) -> Self {
284 let strides = factor.strides();
285 let (value_hash, value_hash2) = projected_factor_value_fingerprint(factor);
286 Self {
287 design_id,
288 factor_ptr: factor.as_ptr() as usize,
289 rows: factor.nrows(),
290 cols: factor.ncols(),
291 row_stride: strides[0],
292 col_stride: strides[1],
293 value_hash,
294 value_hash2,
295 }
296 }
297
298 pub fn synthetic(seed: u64) -> Self {
303 Self {
304 design_id: 1,
305 factor_ptr: seed as usize,
306 rows: 1,
307 cols: 1,
308 row_stride: 1,
309 col_stride: 1,
310 value_hash: seed,
311 value_hash2: seed.wrapping_mul(31),
312 }
313 }
314}
315
316pub(crate) fn projected_factor_value_fingerprint(factor: ArrayView2<'_, f64>) -> (u64, u64) {
317 let mut h1 = 0xcbf2_9ce4_8422_2325_u64;
318 let mut h2 = 0x9e37_79b1_85eb_ca87_u64;
319 for (idx, value) in factor.iter().enumerate() {
320 let bits = value.to_bits();
321 let mixed = bits.wrapping_add((idx as u64).wrapping_mul(0x517c_c1b7_2722_0a95));
322 h1 ^= mixed;
323 h1 = h1.wrapping_mul(0x0000_0100_0000_01b3);
324 h2 ^= bits.rotate_left((idx & 63) as u32);
325 h2 = h2.wrapping_mul(0x94d0_49bb_1331_11eb).rotate_left(27);
326 }
327 (h1, h2)
328}
329
330pub struct ProjectedFactorCache {
332 pub(crate) inner: Mutex<ProjectedFactorCacheInner>,
333}
334
335pub(crate) struct ProjectedFactorCacheInner {
336 pub(crate) entries: HashMap<ProjectedFactorKey, ProjectedFactorEntry>,
337 pub(crate) in_progress: HashMap<ProjectedFactorKey, Arc<ProjectedFactorInProgress>>,
338 pub(crate) next_seq: u64,
339 pub(crate) total_bytes: usize,
340 pub(crate) budget_bytes: usize,
341}
342
343pub(crate) struct ProjectedFactorInProgress {
344 pub(crate) state: Mutex<Option<ProjectedFactorInProgressState>>,
345 pub(crate) ready: Condvar,
346 pub(crate) waiter_count: std::sync::atomic::AtomicUsize,
347 pub(crate) subscriber_arrived: (Mutex<()>, Condvar),
348}
349
350pub(crate) enum ProjectedFactorInProgressState {
351 Ready(Arc<Array2<f64>>),
352 Failed,
353}
354
355pub(crate) struct ProjectedFactorEntry {
356 pub(crate) value: Arc<Array2<f64>>,
357 pub(crate) bytes: usize,
358 pub(crate) last_used: u64,
359}
360
361impl Default for ProjectedFactorCache {
362 fn default() -> Self {
363 Self::with_budget(Self::DEFAULT_BUDGET_BYTES)
364 }
365}
366
367impl ProjectedFactorCache {
368 pub const DEFAULT_BUDGET_BYTES: usize = 2 * 1024 * 1024 * 1024;
369
370 pub fn with_budget(budget_bytes: usize) -> Self {
371 Self {
372 inner: Mutex::new(ProjectedFactorCacheInner {
373 entries: HashMap::new(),
374 in_progress: HashMap::new(),
375 next_seq: 0,
376 total_bytes: 0,
377 budget_bytes,
378 }),
379 }
380 }
381
382 pub fn get_or_insert_with(
383 &self,
384 key: ProjectedFactorKey,
385 compute: impl FnOnce() -> Array2<f64>,
386 ) -> Arc<Array2<f64>> {
387 enum CacheLookup {
388 Hit(Arc<Array2<f64>>),
389 Wait(Arc<ProjectedFactorInProgress>),
390 Compute(Arc<ProjectedFactorInProgress>),
391 }
392
393 let lookup = {
394 let mut inner = self
395 .inner
396 .lock()
397 .expect("projected factor cache lock poisoned");
398 inner.next_seq += 1;
399 let now = inner.next_seq;
400 if let Some(entry) = inner.entries.get_mut(&key) {
401 entry.last_used = now;
402 CacheLookup::Hit(entry.value.clone())
403 } else if let Some(waiter) = inner.in_progress.get(&key) {
404 CacheLookup::Wait(waiter.clone())
405 } else {
406 let marker = Arc::new(ProjectedFactorInProgress {
407 state: Mutex::new(None),
408 ready: Condvar::new(),
409 waiter_count: std::sync::atomic::AtomicUsize::new(0),
410 subscriber_arrived: (Mutex::new(()), Condvar::new()),
411 });
412 inner.in_progress.insert(key, marker.clone());
413 CacheLookup::Compute(marker)
414 }
415 };
416
417 match lookup {
418 CacheLookup::Hit(value) => value,
419 CacheLookup::Wait(marker) => {
420 marker
421 .waiter_count
422 .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
423 let (lock, cv) = &marker.subscriber_arrived;
424 drop(
425 lock.lock()
426 .expect("subscriber-arrived notification lock poisoned"),
427 );
428 cv.notify_all();
429 let mut guard = marker
430 .state
431 .lock()
432 .expect("projected factor in-progress lock poisoned");
433 let result = loop {
434 match guard.as_ref() {
435 Some(ProjectedFactorInProgressState::Ready(value)) => {
436 break value.clone();
437 }
438 Some(ProjectedFactorInProgressState::Failed) => {
439 marker
440 .waiter_count
441 .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
442 reml_contract_panic("projected factor cache producer panicked")
443 }
444 None => {
445 guard = marker
446 .ready
447 .wait(guard)
448 .expect("projected factor in-progress wait poisoned");
449 }
450 }
451 };
452 marker
453 .waiter_count
454 .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
455 result
456 }
457 CacheLookup::Compute(marker) => {
458 let computed = match catch_unwind(AssertUnwindSafe(|| Arc::new(compute()))) {
459 Ok(value) => value,
460 Err(payload) => {
461 let mut inner = self
462 .inner
463 .lock()
464 .expect("projected factor cache lock poisoned");
465 inner.in_progress.remove(&key);
466 drop(inner);
467
468 let mut guard = marker
469 .state
470 .lock()
471 .expect("projected factor in-progress lock poisoned");
472 *guard = Some(ProjectedFactorInProgressState::Failed);
473 marker.ready.notify_all();
474 resume_unwind(payload);
475 }
476 };
477 let bytes = computed.len().saturating_mul(std::mem::size_of::<f64>());
478 let mut inner = self
479 .inner
480 .lock()
481 .expect("projected factor cache lock poisoned");
482 inner.next_seq += 1;
483 let now = inner.next_seq;
484
485 if inner.budget_bytes > 0 && bytes <= inner.budget_bytes {
486 while inner.total_bytes.saturating_add(bytes) > inner.budget_bytes
487 && !inner.entries.is_empty()
488 {
489 let Some(oldest_key) = inner
490 .entries
491 .iter()
492 .min_by_key(|(_, e)| e.last_used)
493 .map(|(k, _)| *k)
494 else {
495 break;
496 };
497 if let Some(removed) = inner.entries.remove(&oldest_key) {
498 inner.total_bytes = inner.total_bytes.saturating_sub(removed.bytes);
499 }
500 }
501 }
502
503 let value = if let Some(entry) = inner.entries.get_mut(&key) {
504 entry.last_used = now;
505 entry.value.clone()
506 } else {
507 inner.entries.insert(
508 key,
509 ProjectedFactorEntry {
510 value: computed.clone(),
511 bytes,
512 last_used: now,
513 },
514 );
515 inner.total_bytes = inner.total_bytes.saturating_add(bytes);
516 computed
517 };
518 inner.in_progress.remove(&key);
519 drop(inner);
520
521 let mut guard = marker
522 .state
523 .lock()
524 .expect("projected factor in-progress lock poisoned");
525 *guard = Some(ProjectedFactorInProgressState::Ready(value.clone()));
526 marker.ready.notify_all();
527 value
528 }
529 }
530 }
531
532 pub fn len(&self) -> usize {
533 self.inner
534 .lock()
535 .map(|inner| inner.entries.len())
536 .unwrap_or(0)
537 }
538
539 pub fn total_bytes(&self) -> usize {
540 self.inner
541 .lock()
542 .map(|inner| inner.total_bytes)
543 .unwrap_or(0)
544 }
545
546 pub fn is_empty(&self) -> bool {
547 self.len() == 0
548 }
549
550 pub fn wait_for_subscriber(
561 &self,
562 key: ProjectedFactorKey,
563 timeout: std::time::Duration,
564 ) -> bool {
565 let marker = {
566 let inner = self
567 .inner
568 .lock()
569 .expect("projected factor cache lock poisoned");
570 let Some(m) = inner.in_progress.get(&key) else {
571 return false;
572 };
573 Arc::clone(m)
574 };
575 if marker
576 .waiter_count
577 .load(std::sync::atomic::Ordering::Acquire)
578 > 0
579 {
580 return true;
581 }
582 let (lock, cv) = &marker.subscriber_arrived;
583 let mut guard = lock
584 .lock()
585 .expect("subscriber-arrived notification lock poisoned");
586 let deadline = std::time::Instant::now() + timeout;
587 loop {
588 if marker
589 .waiter_count
590 .load(std::sync::atomic::Ordering::Acquire)
591 > 0
592 {
593 return true;
594 }
595 let now = std::time::Instant::now();
596 if now >= deadline {
597 return false;
598 }
599 let (next_guard, result) = cv
600 .wait_timeout(guard, deadline - now)
601 .expect("subscriber-arrived wait poisoned");
602 guard = next_guard;
603 if result.timed_out()
604 && marker
605 .waiter_count
606 .load(std::sync::atomic::Ordering::Acquire)
607 == 0
608 {
609 return false;
610 }
611 }
612 }
613}
614
615#[derive(Clone)]
616pub struct DenseMatrixHyperOperator {
617 pub matrix: Array2<f64>,
618}
619
620impl HyperOperator for DenseMatrixHyperOperator {
621 fn dim(&self) -> usize {
622 self.matrix.nrows()
623 }
624
625 fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
626 self.matrix.dot(v)
627 }
628
629 fn as_any(&self) -> &(dyn Any + 'static) {
630 self
631 }
632
633 fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
634 self.matrix.dot(&v)
635 }
636
637 fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
638 assert_eq!(self.matrix.ncols(), v.len());
639 assert_eq!(self.matrix.nrows(), out.len());
640 for (row, out_value) in self.matrix.rows().into_iter().zip(out.iter_mut()) {
641 *out_value = row.dot(&v);
642 }
643 }
644
645 fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
646 let end = start + out.ncols();
647 assert!(end <= self.matrix.ncols());
648 out.assign(&self.matrix.slice(ndarray::s![.., start..end]));
649 }
650
651 fn scaled_add_mul_vec(
652 &self,
653 v: ArrayView1<'_, f64>,
654 scale: f64,
655 mut out: ArrayViewMut1<'_, f64>,
656 ) {
657 assert_eq!(self.matrix.ncols(), v.len());
658 assert_eq!(self.matrix.nrows(), out.len());
659 if scale == 0.0 {
660 return;
661 }
662 for (row, out_value) in self.matrix.rows().into_iter().zip(out.iter_mut()) {
663 *out_value += scale * row.dot(&v);
664 }
665 }
666
667 fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
668 dense_bilinear(&self.matrix, v.view(), u.view())
669 }
670
671 fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
672 dense_bilinear(&self.matrix, v, u)
673 }
674
675 fn to_dense(&self) -> Array2<f64> {
676 self.matrix.clone()
677 }
678
679 fn is_implicit(&self) -> bool {
680 false
681 }
682}
683
684#[derive(Clone)]
685pub struct BlockLocalDrift {
686 pub local: Array2<f64>,
687 pub start: usize,
688 pub end: usize,
689 pub total_dim: usize,
690}
691
692impl HyperOperator for BlockLocalDrift {
693 fn dim(&self) -> usize {
694 self.total_dim
695 }
696
697 fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
698 assert_eq!(v.len(), self.total_dim);
699 let mut out = Array1::zeros(self.total_dim);
700 self.mul_vec_into(v.view(), out.view_mut());
701 out
702 }
703
704 fn as_any(&self) -> &(dyn Any + 'static) {
705 self
706 }
707
708 fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
709 assert_eq!(v.len(), self.total_dim);
710 assert_eq!(out.len(), self.total_dim);
711 out.fill(0.0);
712 let v_block = v.slice(ndarray::s![self.start..self.end]);
713 let mut out_block = out.slice_mut(ndarray::s![self.start..self.end]);
714 dense_matvec_into(&self.local, v_block, out_block.view_mut());
715 }
716
717 fn scaled_add_mul_vec(
718 &self,
719 v: ArrayView1<'_, f64>,
720 scale: f64,
721 mut out: ArrayViewMut1<'_, f64>,
722 ) {
723 assert_eq!(v.len(), self.total_dim);
724 assert_eq!(out.len(), self.total_dim);
725 if scale == 0.0 {
726 return;
727 }
728 let v_block = v.slice(ndarray::s![self.start..self.end]);
729 let out_block = out.slice_mut(ndarray::s![self.start..self.end]);
730 dense_matvec_scaled_add_into(&self.local, v_block, scale, out_block);
731 }
732
733 fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
734 self.bilinear_view(v.view(), u.view())
735 }
736
737 fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
738 assert_eq!(v.len(), self.total_dim);
739 assert_eq!(u.len(), self.total_dim);
740 let v_block = v.slice(ndarray::s![self.start..self.end]);
741 let u_block = u.slice(ndarray::s![self.start..self.end]);
742 dense_bilinear(&self.local, v_block, u_block)
743 }
744
745 fn to_dense(&self) -> Array2<f64> {
746 let p = self.total_dim;
747 let mut out = Array2::zeros((p, p));
748 out.slice_mut(ndarray::s![self.start..self.end, self.start..self.end])
749 .assign(&self.local);
750 out
751 }
752
753 fn is_implicit(&self) -> bool {
754 false
755 }
756
757 fn block_local_data(&self) -> Option<(&Array2<f64>, usize, usize)> {
758 Some((&self.local, self.start, self.end))
759 }
760}
761
762#[derive(Clone)]
763pub struct HyperCoordDrift {
764 pub dense: Option<Array2<f64>>,
765 pub block_local: Option<BlockLocalDrift>,
766 pub operator: Option<Arc<dyn HyperOperator>>,
767}
768
769impl HyperCoordDrift {
770 pub fn none() -> Self {
771 Self {
772 dense: None,
773 block_local: None,
774 operator: None,
775 }
776 }
777
778 pub fn from_dense(dense: Array2<f64>) -> Self {
779 Self {
780 dense: Some(dense),
781 block_local: None,
782 operator: None,
783 }
784 }
785
786 pub fn from_operator(operator: Arc<dyn HyperOperator>) -> Self {
787 Self {
788 dense: None,
789 block_local: None,
790 operator: Some(operator),
791 }
792 }
793
794 pub fn from_parts(
795 dense: Option<Array2<f64>>,
796 operator: Option<Arc<dyn HyperOperator>>,
797 ) -> Self {
798 let dense = dense.filter(|mat| !(operator.is_some() && mat.is_empty()));
799 Self {
800 dense,
801 block_local: None,
802 operator,
803 }
804 }
805
806 pub fn from_block_local_and_operator(
807 local: Array2<f64>,
808 start: usize,
809 end: usize,
810 total_dim: usize,
811 operator: Option<Arc<dyn HyperOperator>>,
812 ) -> Self {
813 Self {
814 dense: None,
815 block_local: Some(BlockLocalDrift {
816 local,
817 start,
818 end,
819 total_dim,
820 }),
821 operator,
822 }
823 }
824
825 pub fn has_operator(&self) -> bool {
826 self.operator.is_some()
827 }
828
829 pub fn uses_operator_fast_path(&self) -> bool {
830 self.operator.is_some() || self.block_local.is_some()
831 }
832
833 pub fn operator_ref(&self) -> Option<&dyn HyperOperator> {
834 self.operator.as_ref().map(Arc::as_ref)
835 }
836
837 pub fn materialize(&self) -> Array2<f64> {
838 let p = self.infer_dim();
839 if p == 0 {
840 return Array2::zeros((0, 0));
841 }
842 let mut out = self.dense.clone().unwrap_or_else(|| Array2::zeros((p, p)));
843 if let Some(bl) = &self.block_local {
844 out.slice_mut(ndarray::s![bl.start..bl.end, bl.start..bl.end])
845 .scaled_add(1.0, &bl.local);
846 }
847 if let Some(op) = &self.operator {
848 out += &op.to_dense();
849 }
850 out
851 }
852
853 pub fn apply(&self, v: &Array1<f64>) -> Array1<f64> {
854 let mut out = Array1::zeros(v.len());
855 self.scaled_add_apply(v.view(), 1.0, &mut out);
856 out
857 }
858
859 pub fn scaled_add_apply(&self, v: ArrayView1<'_, f64>, scale: f64, out: &mut Array1<f64>) {
860 assert_eq!(v.len(), out.len());
861 if scale == 0.0 {
862 return;
863 }
864 if let Some(dense) = &self.dense {
865 dense_matvec_scaled_add_into(dense, v, scale, out.view_mut());
866 }
867 if let Some(bl) = &self.block_local {
868 let v_block = v.slice(ndarray::s![bl.start..bl.end]);
869 let out_block = out.slice_mut(ndarray::s![bl.start..bl.end]);
870 dense_matvec_scaled_add_into(&bl.local, v_block, scale, out_block);
871 }
872 if let Some(op) = &self.operator {
873 op.scaled_add_mul_vec(v, scale, out.view_mut());
874 }
875 }
876
877 pub(crate) fn infer_dim(&self) -> usize {
878 if let Some(d) = &self.dense {
879 return d.nrows();
880 }
881 if let Some(op) = &self.operator {
882 return op.dim();
883 }
884 if let Some(bl) = &self.block_local {
885 return bl.total_dim;
886 }
887 0
888 }
889}
890
891#[derive(Clone)]
892pub struct HyperCoord {
893 pub a: f64,
894 pub g: Array1<f64>,
895 pub drift: HyperCoordDrift,
896 pub ld_s: f64,
897 pub b_depends_on_beta: bool,
898 pub is_penalty_like: bool,
899 pub firth_g: Option<Array1<f64>>,
900 pub tk_eta_fixed: Option<Array1<f64>>,
901 pub tk_x_fixed: Option<Array2<f64>>,
902}
903
904pub struct HyperCoordPair {
905 pub a: f64,
906 pub g: Array1<f64>,
907 pub b_mat: Array2<f64>,
908 pub b_operator: Option<Box<dyn HyperOperator>>,
909 pub ld_s: f64,
910}
911
912impl HyperCoordPair {
913 pub fn zero() -> Self {
914 Self {
915 a: 0.0,
916 g: Array1::zeros(0),
917 b_mat: Array2::zeros((0, 0)),
918 b_operator: None,
919 ld_s: 0.0,
920 }
921 }
922}
923
924#[derive(Clone)]
925pub enum DriftDerivResult {
926 Dense(Array2<f64>),
927 Operator(Arc<dyn HyperOperator>),
928}
929
930impl std::fmt::Debug for DriftDerivResult {
931 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
932 match self {
933 Self::Dense(matrix) => f
934 .debug_tuple("Dense")
935 .field(&format_args!("{}x{}", matrix.nrows(), matrix.ncols()))
936 .finish(),
937 Self::Operator(_) => f
938 .debug_tuple("Operator")
939 .field(&"<hyper-operator>")
940 .finish(),
941 }
942 }
943}
944
945impl DriftDerivResult {
946 pub fn into_operator(self) -> Arc<dyn HyperOperator> {
947 match self {
948 Self::Dense(matrix) => Arc::new(DenseMatrixHyperOperator { matrix }),
949 Self::Operator(operator) => operator,
950 }
951 }
952
953 pub fn apply(&self, v: &Array1<f64>) -> Array1<f64> {
954 match self {
955 Self::Dense(matrix) => matrix.dot(v),
956 Self::Operator(operator) => operator.mul_vec(v),
957 }
958 }
959}
960
961pub type FixedDriftDerivFn =
962 Box<dyn Fn(usize, &Array1<f64>) -> Option<DriftDerivResult> + Send + Sync>;
963
964pub struct ContractedPsiSecondOrder {
965 pub objective: Array1<f64>,
966 pub score: Array2<f64>,
967 pub hessian: Vec<DriftDerivResult>,
968 pub ld_s: Array1<f64>,
969}
970
971pub type ContractedPsiSecondOrderFn =
972 Arc<dyn Fn(&[f64]) -> Result<Option<ContractedPsiSecondOrder>, String> + Send + Sync>;