use std::any::Any;
use std::collections::HashMap;
use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
use std::sync::{Arc, Condvar, Mutex};
use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
/// Unwinds the current thread with `message` as the panic payload.
///
/// The message is converted into an owned `String` and passed to
/// [`std::panic::panic_any`], so a catcher can downcast the payload to
/// `String`. This function never returns.
#[cold]
fn reml_contract_panic(message: impl Into<String>) -> ! {
    let payload: String = message.into();
    std::panic::panic_any(payload)
}
/// Selects how much derivative information an evaluation should produce.
///
/// NOTE(review): the consumers of this enum are not part of this section;
/// the per-variant descriptions below are read off the variant names and
/// should be confirmed against the callers.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum EvalMode {
/// Presumably: compute the objective value only.
ValueOnly,
/// Presumably: compute the objective value and its gradient.
ValueAndGradient,
/// Presumably: compute the objective value, gradient, and Hessian.
ValueGradientHessian,
}
/// Private unit type used as the target of the default `HyperOperator::as_any`.
///
/// Because this type is private, a caller downcasting the `&dyn Any` returned
/// by the default `as_any` to any operator type gets `None`.
struct NonDowncastableHyperOperator;
/// Single shared instance handed out by the default `HyperOperator::as_any`.
static NON_DOWNCASTABLE_HYPER_OPERATOR: NonDowncastableHyperOperator = NonDowncastableHyperOperator;
/// A linear operator that can be applied to vectors and matrices without
/// necessarily being stored as a dense matrix.
///
/// Implementors must provide [`dim`](Self::dim), [`mul_vec`](Self::mul_vec)
/// and [`is_implicit`](Self::is_implicit). Every other method has a default
/// built on those (ultimately on `mul_vec`), which implementors may override
/// with a faster specialised version.
///
/// The defaults treat the operator as square: `to_dense` produces a
/// `dim() x dim()` matrix and `bilinear` sizes its work vector from the input.
pub trait HyperOperator: Send + Sync {
/// Dimension of the operator (used as both row and column count by `to_dense`).
fn dim(&self) -> usize;
/// Applies the operator to `v` and returns the result as a new vector.
fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64>;
/// Returns `self` as `&dyn Any` for downcasting to a concrete operator type.
///
/// The default returns a private sentinel, so downcasts to any operator
/// type fail; implementors that want to be downcastable override this to
/// return `self`.
fn as_any(&self) -> &(dyn Any + 'static) {
&NON_DOWNCASTABLE_HYPER_OPERATOR
}
/// Applies the operator to a borrowed vector view.
///
/// The default copies the view into an owned array and calls `mul_vec`.
fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
self.mul_vec(&v.to_owned())
}
/// Applies the operator to `v` and overwrites `out` with the result.
///
/// The default computes a temporary via `mul_vec_view` and assigns it.
fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
out.assign(&self.mul_vec_view(v));
}
/// Applies the operator to every column of `factor`.
///
/// Returns a matrix with the same shape as `factor` whose column `j` is the
/// operator applied to `factor.column(j)`.
fn mul_mat(&self, factor: &Array2<f64>) -> Array2<f64> {
let p = factor.nrows();
let k = factor.ncols();
let mut out = Array2::<f64>::zeros((p, k));
// When `rayon::current_thread_index()` reports an index the caller is
// already running on a Rayon worker thread; process the columns
// sequentially in that case (presumably to avoid nested parallelism).
if rayon::current_thread_index().is_some() {
for col in 0..k {
let bv = out.column_mut(col);
self.mul_vec_into(factor.column(col), bv);
}
return out;
}
// Otherwise compute each column in parallel into its own buffer, then
// copy the buffers into the output matrix in column order.
let cols: Vec<Array1<f64>> = (0..k)
.into_par_iter()
.map(|col| {
let mut bv = Array1::<f64>::zeros(p);
self.mul_vec_into(factor.column(col), bv.view_mut());
bv
})
.collect();
for (col, bv) in cols.into_iter().enumerate() {
out.column_mut(col).assign(&bv);
}
out
}
/// Returns the sum over all entries of `factor[i, j] * (B * factor)[i, j]`,
/// where `B * factor` is `mul_mat(factor)`; this equals `trace(factorᵀ B factor)`.
fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
let op_factor = self.mul_mat(factor);
factor
.iter()
.zip(op_factor.iter())
.map(|(&f, &bf)| f * bf)
.sum()
}
/// Identifier used as part of the cache key by the `*_cached` methods.
///
/// The default is `None`, which makes the `*_cached` methods bypass the
/// cache and fall back to their uncached counterparts.
fn projection_design_id(&self) -> Option<usize> {
None
}
/// Same result as [`trace_projected_factor`](Self::trace_projected_factor),
/// but reuses `mul_mat(factor)` from `factor_cache` when
/// `projection_design_id()` is `Some`.
fn trace_projected_factor_cached(
&self,
factor: &Array2<f64>,
factor_cache: &ProjectedFactorCache,
) -> f64 {
// NOTE(review): this only asserts that the cache value has non-zero
// size; its purpose is not evident from this section.
assert!(std::mem::size_of_val(factor_cache) > 0);
match self.projection_design_id() {
Some(design_id) => {
let key = ProjectedFactorKey::from_factor_view(design_id, factor.view());
let projected = factor_cache.get_or_insert_with(key, || self.mul_mat(factor));
factor
.iter()
.zip(projected.iter())
.map(|(&f, &bf)| f * bf)
.sum()
}
None => self.trace_projected_factor(factor),
}
}
/// Combines `factor` with `mul_mat(factor)` through
/// `crate::faer_ndarray::fast_atb`.
///
/// NOTE(review): `fast_atb` is defined elsewhere; by its name it presumably
/// computes `factorᵀ * (B * factor)` — confirm.
fn projected_matrix(&self, factor: &Array2<f64>) -> Array2<f64> {
let op_factor = self.mul_mat(factor);
crate::faer_ndarray::fast_atb(factor, &op_factor)
}
/// Same result as [`projected_matrix`](Self::projected_matrix), but reuses
/// `mul_mat(factor)` from `factor_cache` when `projection_design_id()` is
/// `Some`.
fn projected_matrix_cached(
&self,
factor: &Array2<f64>,
factor_cache: &ProjectedFactorCache,
) -> Array2<f64> {
// NOTE(review): same always-true-looking size assertion as in
// `trace_projected_factor_cached`.
assert!(std::mem::size_of_val(factor_cache) > 0);
match self.projection_design_id() {
Some(design_id) => {
let key = ProjectedFactorKey::from_factor_view(design_id, factor.view());
let projected = factor_cache.get_or_insert_with(key, || self.mul_mat(factor));
crate::faer_ndarray::fast_atb(factor, projected.as_ref())
}
None => self.projected_matrix(factor),
}
}
/// Writes the operator applied to the unit vectors
/// `e_start, e_{start+1}, ...` into the columns of `out`.
///
/// The vector length is taken from `out.nrows()`.
///
/// # Panics
/// Panics if `start + out.ncols()` exceeds `out.nrows()`.
fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
let cols = out.ncols();
let dim = out.nrows();
assert!(start + cols <= dim);
// One reusable unit vector: set the active entry, apply, reset.
let mut basis = Array1::<f64>::zeros(dim);
for local_col in 0..cols {
let global_col = start + local_col;
basis[global_col] = 1.0;
self.mul_vec_into(basis.view(), out.column_mut(local_col));
basis[global_col] = 0.0;
}
}
/// Accumulates `out += scale * (B * v)`.
///
/// Returns immediately, without applying the operator, when `scale == 0.0`.
fn scaled_add_mul_vec(
&self,
v: ArrayView1<'_, f64>,
scale: f64,
mut out: ArrayViewMut1<'_, f64>,
) {
if scale == 0.0 {
return;
}
let mut work = Array1::<f64>::zeros(out.len());
self.mul_vec_into(v, work.view_mut());
out.scaled_add(scale, &work);
}
/// Returns `u · (B * v)`.
fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
let mut bv = Array1::<f64>::zeros(v.len());
self.mul_vec_into(v.view(), bv.view_mut());
u.dot(&bv)
}
/// View-based variant of [`bilinear`](Self::bilinear): returns `u · (B * v)`.
fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
let mut bv = Array1::<f64>::zeros(v.len());
self.mul_vec_into(v, bv.view_mut());
u.dot(&bv)
}
/// Defaults to `false`.
///
/// NOTE(review): presumably signals that `bilinear_view` is overridden with
/// a specialised implementation; how callers use it is not visible here.
fn has_fast_bilinear_view(&self) -> bool {
false
}
/// Materialises the operator as a dense `dim() x dim()` matrix by applying
/// it to each unit vector in turn.
fn to_dense(&self) -> Array2<f64> {
let p = self.dim();
let mut out = Array2::<f64>::zeros((p, p));
let mut basis = Array1::<f64>::zeros(p);
for j in 0..p {
basis[j] = 1.0;
self.mul_vec_into(basis.view(), out.column_mut(j));
basis[j] = 0.0;
}
out
}
/// Required flag supplied by each implementor.
///
/// NOTE(review): the exact meaning is not defined in this section; both
/// dense-backed implementations in this file return `false`.
fn is_implicit(&self) -> bool;
/// Optional block data as `(matrix, start, end)`; the default is `None`.
fn block_local_data(&self) -> Option<(&Array2<f64>, usize, usize)> {
None
}
}
/// Cache key identifying one (operator design, factor matrix) combination.
///
/// Two keys are equal only if every field matches, so a factor with the same
/// values stored at a different address or with different strides yields a
/// different key.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ProjectedFactorKey {
/// Identifier supplied by the operator.
pub(crate) design_id: usize,
/// Address of the factor view's data pointer, stored as an integer.
pub(crate) factor_ptr: usize,
/// Number of rows of the factor.
pub(crate) rows: usize,
/// Number of columns of the factor.
pub(crate) cols: usize,
/// Stride of the factor view along axis 0.
pub(crate) row_stride: isize,
/// Stride of the factor view along axis 1.
pub(crate) col_stride: isize,
/// First component of the value fingerprint of the factor's entries.
pub(crate) value_hash: u64,
/// Second component of the value fingerprint of the factor's entries.
pub(crate) value_hash2: u64,
}
impl ProjectedFactorKey {
    /// Builds the key for `factor` as seen by the operator identified by
    /// `design_id`.
    ///
    /// The key records the view's data address, shape and strides together
    /// with a two-part fingerprint of its values, so it changes whenever the
    /// storage location, layout, or contents change.
    pub fn from_factor_view(design_id: usize, factor: ArrayView2<'_, f64>) -> Self {
        let (rows, cols) = (factor.nrows(), factor.ncols());
        let row_stride = factor.strides()[0];
        let col_stride = factor.strides()[1];
        let factor_ptr = factor.as_ptr() as usize;
        let (value_hash, value_hash2) = projected_factor_value_fingerprint(factor);
        Self {
            design_id,
            factor_ptr,
            rows,
            cols,
            row_stride,
            col_stride,
            value_hash,
            value_hash2,
        }
    }
}
/// Computes a pair of 64-bit fingerprints over the entries of `factor`.
///
/// Entries are visited in the view's iteration order and mixed by their raw
/// bit patterns (`f64::to_bits`) together with their position, so values that
/// compare equal but differ in bits (for example `0.0` and `-0.0`) produce
/// different fingerprints. The two accumulators use independent seeds and
/// mixing steps; the first follows an FNV-1a-style xor-then-multiply update.
pub(crate) fn projected_factor_value_fingerprint(factor: ArrayView2<'_, f64>) -> (u64, u64) {
    const H1_SEED: u64 = 0xcbf2_9ce4_8422_2325;
    const H2_SEED: u64 = 0x9e37_79b1_85eb_ca87;
    const INDEX_SPREAD: u64 = 0x517c_c1b7_2722_0a95;
    const H1_MULTIPLIER: u64 = 0x0000_0100_0000_01b3;
    const H2_MULTIPLIER: u64 = 0x94d0_49bb_1331_11eb;

    factor
        .iter()
        .enumerate()
        .fold((H1_SEED, H2_SEED), |(acc1, acc2), (position, entry)| {
            let bits = entry.to_bits();

            // First accumulator: salt the bits with the position, then xor/multiply.
            let salted = bits.wrapping_add((position as u64).wrapping_mul(INDEX_SPREAD));
            let next1 = (acc1 ^ salted).wrapping_mul(H1_MULTIPLIER);

            // Second accumulator: position-dependent rotation, then multiply/rotate.
            let rotated = bits.rotate_left((position & 63) as u32);
            let next2 = (acc2 ^ rotated)
                .wrapping_mul(H2_MULTIPLIER)
                .rotate_left(27);

            (next1, next2)
        })
}
/// Cache of matrices keyed by [`ProjectedFactorKey`].
///
/// All state lives behind a single mutex, so the cache can be used through a
/// shared reference.
pub struct ProjectedFactorCache {
/// Mutex-protected cache state.
pub(crate) inner: Mutex<ProjectedFactorCacheInner>,
}
/// State guarded by the mutex in [`ProjectedFactorCache`].
pub(crate) struct ProjectedFactorCacheInner {
/// Completed values by key.
pub(crate) entries: HashMap<ProjectedFactorKey, ProjectedFactorEntry>,
/// Markers for keys whose value is currently being computed.
pub(crate) in_progress: HashMap<ProjectedFactorKey, Arc<ProjectedFactorInProgress>>,
/// Counter bumped on every lookup and insert; its value is stored as an
/// entry's `last_used` stamp.
pub(crate) next_seq: u64,
/// Running sum of the `bytes` of all stored entries.
pub(crate) total_bytes: usize,
/// Byte budget consulted when evicting entries before an insert.
pub(crate) budget_bytes: usize,
}
/// Rendezvous point between the thread computing a value and threads waiting
/// for that same key.
pub(crate) struct ProjectedFactorInProgress {
/// `None` while the computation is running; set to the outcome by the producer.
pub(crate) state: Mutex<Option<ProjectedFactorInProgressState>>,
/// Signalled by the producer after `state` has been set.
pub(crate) ready: Condvar,
/// Number of threads currently waiting on this marker.
pub(crate) waiter_count: std::sync::atomic::AtomicUsize,
/// Lock/condvar pair notified each time a waiter registers.
///
/// NOTE(review): nothing in this section waits on it; presumably used by
/// tests or external coordination — confirm.
pub(crate) subscriber_arrived: (Mutex<()>, Condvar),
}
/// Final outcome published by the thread that computed an in-progress value.
pub(crate) enum ProjectedFactorInProgressState {
/// The computation finished; waiters receive a clone of this shared value.
Ready(Arc<Array2<f64>>),
/// The computation panicked; waiters panic in turn.
Failed,
}
/// One stored cache value together with its bookkeeping.
pub(crate) struct ProjectedFactorEntry {
/// The cached matrix, shared with callers.
pub(crate) value: Arc<Array2<f64>>,
/// Size charged against the budget (element count times `size_of::<f64>()`).
pub(crate) bytes: usize,
/// Sequence stamp of the most recent access; the smallest stamp is evicted first.
pub(crate) last_used: u64,
}
impl Default for ProjectedFactorCache {
    /// Creates an empty cache limited to
    /// [`ProjectedFactorCache::DEFAULT_BUDGET_BYTES`].
    fn default() -> Self {
        let budget_bytes = Self::DEFAULT_BUDGET_BYTES;
        Self::with_budget(budget_bytes)
    }
}
impl ProjectedFactorCache {
    /// Byte budget used by [`Default`]: 2 GiB.
    pub const DEFAULT_BUDGET_BYTES: usize = 2 * 1024 * 1024 * 1024;

    /// Creates an empty cache with the given byte budget.
    pub fn with_budget(budget_bytes: usize) -> Self {
        Self {
            inner: Mutex::new(ProjectedFactorCacheInner {
                entries: HashMap::new(),
                in_progress: HashMap::new(),
                next_seq: 0,
                total_bytes: 0,
                budget_bytes,
            }),
        }
    }

    /// Returns the value stored under `key`, computing it with `compute` if
    /// it is not cached yet.
    ///
    /// Concurrent callers asking for the same missing key are de-duplicated:
    /// the first one runs `compute` (outside the cache lock) and the others
    /// block until it publishes the result.
    ///
    /// Before inserting, least-recently-used entries are evicted until the
    /// new value fits the budget.
    ///
    /// NOTE(review): when the budget is `0` or the value alone is larger than
    /// the budget, eviction is skipped but the value is still inserted, so
    /// `total_bytes` can exceed `budget_bytes`. Confirm this is intended.
    ///
    /// # Panics
    /// * If `compute` panics, its panic is propagated to this caller with the
    ///   original payload, and every thread waiting on the same key panics
    ///   with a `String` payload "projected factor cache producer panicked".
    /// * If one of the internal locks is poisoned.
    pub fn get_or_insert_with(
        &self,
        key: ProjectedFactorKey,
        compute: impl FnOnce() -> Array2<f64>,
    ) -> Arc<Array2<f64>> {
        enum CacheLookup {
            Hit(Arc<Array2<f64>>),
            Wait(Arc<ProjectedFactorInProgress>),
            Compute(Arc<ProjectedFactorInProgress>),
        }

        // Classify the request under the cache lock, then release the lock
        // before doing any waiting or computing.
        let lookup = {
            let mut inner = self
                .inner
                .lock()
                .expect("projected factor cache lock poisoned");
            inner.next_seq += 1;
            let now = inner.next_seq;
            if let Some(entry) = inner.entries.get_mut(&key) {
                entry.last_used = now;
                CacheLookup::Hit(entry.value.clone())
            } else if let Some(waiter) = inner.in_progress.get(&key) {
                CacheLookup::Wait(waiter.clone())
            } else {
                let marker = Arc::new(ProjectedFactorInProgress {
                    state: Mutex::new(None),
                    ready: Condvar::new(),
                    waiter_count: std::sync::atomic::AtomicUsize::new(0),
                    subscriber_arrived: (Mutex::new(()), Condvar::new()),
                });
                inner.in_progress.insert(key, marker.clone());
                CacheLookup::Compute(marker)
            }
        };

        match lookup {
            CacheLookup::Hit(value) => value,
            CacheLookup::Wait(marker) => {
                // Register as a waiter and announce the arrival.
                marker
                    .waiter_count
                    .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
                let (lock, cv) = &marker.subscriber_arrived;
                drop(
                    lock.lock()
                        .expect("subscriber-arrived notification lock poisoned"),
                );
                cv.notify_all();

                // Block until the producer publishes an outcome.
                let mut guard = marker
                    .state
                    .lock()
                    .expect("projected factor in-progress lock poisoned");
                let outcome = loop {
                    match guard.as_ref() {
                        Some(ProjectedFactorInProgressState::Ready(value)) => {
                            break Some(value.clone());
                        }
                        Some(ProjectedFactorInProgressState::Failed) => break None,
                        None => {
                            guard = marker
                                .ready
                                .wait(guard)
                                .expect("projected factor in-progress wait poisoned");
                        }
                    }
                };
                // Release the state lock BEFORE a possible panic: a guard
                // dropped during unwinding poisons the mutex, which would make
                // the remaining waiters fail on the poison check instead of
                // observing `Failed`.
                drop(guard);
                marker
                    .waiter_count
                    .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
                match outcome {
                    Some(value) => value,
                    None => reml_contract_panic("projected factor cache producer panicked"),
                }
            }
            CacheLookup::Compute(marker) => {
                let computed = match catch_unwind(AssertUnwindSafe(|| Arc::new(compute()))) {
                    Ok(value) => value,
                    Err(payload) => {
                        // Withdraw the marker so later callers can retry.
                        let mut inner = self
                            .inner
                            .lock()
                            .expect("projected factor cache lock poisoned");
                        inner.in_progress.remove(&key);
                        drop(inner);

                        let mut guard = marker
                            .state
                            .lock()
                            .expect("projected factor in-progress lock poisoned");
                        *guard = Some(ProjectedFactorInProgressState::Failed);
                        marker.ready.notify_all();
                        // Release the state lock before unwinding resumes;
                        // otherwise the guard is dropped while the thread is
                        // panicking and poisons the mutex for every waiter.
                        drop(guard);
                        resume_unwind(payload)
                    }
                };

                let bytes = computed.len().saturating_mul(std::mem::size_of::<f64>());
                let mut inner = self
                    .inner
                    .lock()
                    .expect("projected factor cache lock poisoned");
                inner.next_seq += 1;
                let now = inner.next_seq;

                // Evict least-recently-used entries until the new value fits.
                if inner.budget_bytes > 0 && bytes <= inner.budget_bytes {
                    while inner.total_bytes.saturating_add(bytes) > inner.budget_bytes
                        && !inner.entries.is_empty()
                    {
                        let Some(oldest_key) = inner
                            .entries
                            .iter()
                            .min_by_key(|(_, e)| e.last_used)
                            .map(|(k, _)| *k)
                        else {
                            break;
                        };
                        if let Some(removed) = inner.entries.remove(&oldest_key) {
                            inner.total_bytes = inner.total_bytes.saturating_sub(removed.bytes);
                        }
                    }
                }

                // Prefer an entry that appeared under this key in the meantime;
                // otherwise store the freshly computed value.
                let value = if let Some(entry) = inner.entries.get_mut(&key) {
                    entry.last_used = now;
                    entry.value.clone()
                } else {
                    inner.entries.insert(
                        key,
                        ProjectedFactorEntry {
                            value: computed.clone(),
                            bytes,
                            last_used: now,
                        },
                    );
                    inner.total_bytes = inner.total_bytes.saturating_add(bytes);
                    computed
                };
                inner.in_progress.remove(&key);
                drop(inner);

                // Publish the result to the threads waiting on this marker.
                let mut guard = marker
                    .state
                    .lock()
                    .expect("projected factor in-progress lock poisoned");
                *guard = Some(ProjectedFactorInProgressState::Ready(value.clone()));
                marker.ready.notify_all();
                drop(guard);
                value
            }
        }
    }

    /// Number of stored entries; reports `0` if the cache lock is poisoned.
    pub fn len(&self) -> usize {
        self.inner
            .lock()
            .map(|inner| inner.entries.len())
            .unwrap_or(0)
    }

    /// Bytes currently charged to stored entries; reports `0` if the cache
    /// lock is poisoned.
    pub fn total_bytes(&self) -> usize {
        self.inner
            .lock()
            .map(|inner| inner.total_bytes)
            .unwrap_or(0)
    }

    /// Returns `true` when [`len`](Self::len) is `0`.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// [`HyperOperator`] backed by an explicitly stored dense matrix.
#[derive(Clone)]
pub struct DenseMatrixHyperOperator {
/// The matrix applied by the operator.
pub matrix: Array2<f64>,
}
impl HyperOperator for DenseMatrixHyperOperator {
    /// Row count of the stored matrix.
    fn dim(&self) -> usize {
        self.matrix.nrows()
    }

    /// Dense matrix-vector product.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        self.matrix.dot(v)
    }

    /// Exposes the concrete type so callers can downcast to it.
    fn as_any(&self) -> &(dyn Any + 'static) {
        self
    }

    /// Dense matrix-vector product on a borrowed view (no copy of `v`).
    fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
        self.matrix.dot(&v)
    }

    /// Writes `matrix * v` into `out`, one row dot product per output entry.
    ///
    /// # Panics
    /// Panics if `v` does not match the column count or `out` does not match
    /// the row count of the matrix.
    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
        assert_eq!(self.matrix.ncols(), v.len());
        assert_eq!(self.matrix.nrows(), out.len());
        self.matrix
            .rows()
            .into_iter()
            .zip(out.iter_mut())
            .for_each(|(matrix_row, slot)| *slot = matrix_row.dot(&v));
    }

    /// Copies columns `start..start + out.ncols()` of the matrix into `out`;
    /// applying the matrix to a unit vector just selects a column.
    ///
    /// # Panics
    /// Panics if the requested column range runs past the matrix.
    fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
        let stop = start + out.ncols();
        assert!(stop <= self.matrix.ncols());
        let selected = self.matrix.slice(ndarray::s![.., start..stop]);
        out.assign(&selected);
    }

    /// Accumulates `out += scale * (matrix * v)`; shapes are checked even when
    /// `scale` is zero.
    fn scaled_add_mul_vec(
        &self,
        v: ArrayView1<'_, f64>,
        scale: f64,
        mut out: ArrayViewMut1<'_, f64>,
    ) {
        assert_eq!(self.matrix.ncols(), v.len());
        assert_eq!(self.matrix.nrows(), out.len());
        if scale == 0.0 {
            return;
        }
        self.matrix
            .rows()
            .into_iter()
            .zip(out.iter_mut())
            .for_each(|(matrix_row, slot)| *slot += scale * matrix_row.dot(&v));
    }

    /// Bilinear form of the matrix with `v` and `u`, delegated to
    /// `dense_bilinear` via the view-based variant.
    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
        self.bilinear_view(v.view(), u.view())
    }

    /// Bilinear form of the matrix with `v` and `u`, delegated to
    /// `dense_bilinear`.
    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
        dense_bilinear(&self.matrix, v, u)
    }

    /// Returns a copy of the stored matrix.
    fn to_dense(&self) -> Array2<f64> {
        self.matrix.clone()
    }

    /// Always `false` for this implementation.
    fn is_implicit(&self) -> bool {
        false
    }
}
/// Operator of size `total_dim` that is zero everywhere except for a dense
/// square block placed on the index range `start..end` of both axes.
#[derive(Clone)]
pub struct BlockLocalDrift {
/// Dense block acting on the coordinates `start..end`.
pub local: Array2<f64>,
/// First index (inclusive) covered by the block.
pub start: usize,
/// End index (exclusive) covered by the block.
pub end: usize,
/// Dimension of the full operator.
pub total_dim: usize,
}
impl HyperOperator for BlockLocalDrift {
    /// Dimension of the full (embedded) operator.
    fn dim(&self) -> usize {
        self.total_dim
    }

    /// Applies the embedded block to `v`, returning a full-length vector.
    ///
    /// # Panics
    /// Panics if `v` does not have length `total_dim`.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        assert_eq!(v.len(), self.total_dim);
        let mut result = Array1::zeros(self.total_dim);
        self.mul_vec_into(v.view(), result.view_mut());
        result
    }

    /// Exposes the concrete type so callers can downcast to it.
    fn as_any(&self) -> &(dyn Any + 'static) {
        self
    }

    /// Overwrites `out` with the operator applied to `v`: zero outside
    /// `start..end`, and `local * v[start..end]` inside.
    ///
    /// # Panics
    /// Panics if `v` or `out` does not have length `total_dim`.
    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
        assert_eq!(v.len(), self.total_dim);
        assert_eq!(out.len(), self.total_dim);
        out.fill(0.0);
        let input_block = v.slice(ndarray::s![self.start..self.end]);
        let mut target_block = out.slice_mut(ndarray::s![self.start..self.end]);
        dense_matvec_into(&self.local, input_block, target_block.view_mut());
    }

    /// Accumulates `scale * local * v[start..end]` into `out[start..end]`;
    /// entries outside the block are left untouched.
    ///
    /// # Panics
    /// Panics if `v` or `out` does not have length `total_dim`.
    fn scaled_add_mul_vec(
        &self,
        v: ArrayView1<'_, f64>,
        scale: f64,
        mut out: ArrayViewMut1<'_, f64>,
    ) {
        assert_eq!(v.len(), self.total_dim);
        assert_eq!(out.len(), self.total_dim);
        if scale == 0.0 {
            return;
        }
        let input_block = v.slice(ndarray::s![self.start..self.end]);
        let target_block = out.slice_mut(ndarray::s![self.start..self.end]);
        dense_matvec_scaled_add_into(&self.local, input_block, scale, target_block);
    }

    /// Owned-array wrapper around [`bilinear_view`](Self::bilinear_view).
    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
        self.bilinear_view(v.view(), u.view())
    }

    /// Bilinear form restricted to the block: only the `start..end` slices of
    /// `v` and `u` are passed to `dense_bilinear`.
    ///
    /// # Panics
    /// Panics if `v` or `u` does not have length `total_dim`.
    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
        assert_eq!(v.len(), self.total_dim);
        assert_eq!(u.len(), self.total_dim);
        let v_part = v.slice(ndarray::s![self.start..self.end]);
        let u_part = u.slice(ndarray::s![self.start..self.end]);
        dense_bilinear(&self.local, v_part, u_part)
    }

    /// Builds the full `total_dim x total_dim` matrix with the block copied
    /// onto its diagonal position and zeros elsewhere.
    fn to_dense(&self) -> Array2<f64> {
        let n = self.total_dim;
        let mut dense = Array2::zeros((n, n));
        dense
            .slice_mut(ndarray::s![self.start..self.end, self.start..self.end])
            .assign(&self.local);
        dense
    }

    /// Always `false` for this implementation.
    fn is_implicit(&self) -> bool {
        false
    }

    /// Returns the block and its index range as `(local, start, end)`.
    fn block_local_data(&self) -> Option<(&Array2<f64>, usize, usize)> {
        Some((&self.local, self.start, self.end))
    }
}
/// Sum of up to three optional contributions: a dense matrix, a block-local
/// matrix, and a general operator.
///
/// `materialize` and `scaled_add_apply` add together whichever parts are
/// present; with no parts the drift acts as zero.
#[derive(Clone)]
pub struct HyperCoordDrift {
/// Optional full dense contribution.
pub dense: Option<Array2<f64>>,
/// Optional contribution confined to one diagonal block.
pub block_local: Option<BlockLocalDrift>,
/// Optional operator contribution, shared via `Arc`.
pub operator: Option<Arc<dyn HyperOperator>>,
}
impl HyperCoordDrift {
    /// A drift with no contributions at all.
    pub fn none() -> Self {
        Self::from_parts(None, None)
    }

    /// A drift consisting only of the given dense matrix.
    pub fn from_dense(dense: Array2<f64>) -> Self {
        Self {
            operator: None,
            block_local: None,
            dense: Some(dense),
        }
    }

    /// A drift consisting only of the given operator.
    pub fn from_operator(operator: Arc<dyn HyperOperator>) -> Self {
        Self {
            operator: Some(operator),
            block_local: None,
            dense: None,
        }
    }

    /// Combines an optional dense matrix with an optional operator.
    ///
    /// An empty dense matrix is discarded when an operator is also supplied.
    pub fn from_parts(
        dense: Option<Array2<f64>>,
        operator: Option<Arc<dyn HyperOperator>>,
    ) -> Self {
        let discard_empty_dense = operator.is_some();
        let dense = dense.filter(|mat| !(discard_empty_dense && mat.is_empty()));
        Self {
            dense,
            block_local: None,
            operator,
        }
    }

    /// Combines a block-local matrix (placed on `start..end` of a
    /// `total_dim`-sized space) with an optional operator.
    pub fn from_block_local_and_operator(
        local: Array2<f64>,
        start: usize,
        end: usize,
        total_dim: usize,
        operator: Option<Arc<dyn HyperOperator>>,
    ) -> Self {
        let block = BlockLocalDrift {
            local,
            start,
            end,
            total_dim,
        };
        Self {
            dense: None,
            block_local: Some(block),
            operator,
        }
    }

    /// Whether an operator contribution is present.
    pub fn has_operator(&self) -> bool {
        self.operator.is_some()
    }

    /// Whether an operator or a block-local contribution is present.
    pub fn uses_operator_fast_path(&self) -> bool {
        self.has_operator() || self.block_local.is_some()
    }

    /// Borrows the operator contribution, if any.
    pub fn operator_ref(&self) -> Option<&dyn HyperOperator> {
        self.operator.as_deref()
    }

    /// Builds the dense matrix equal to the sum of all present contributions.
    ///
    /// Returns a `0 x 0` matrix when the inferred dimension is zero.
    pub fn materialize(&self) -> Array2<f64> {
        let p = self.infer_dim();
        if p == 0 {
            return Array2::zeros((0, 0));
        }
        let mut total = match &self.dense {
            Some(dense) => dense.clone(),
            None => Array2::zeros((p, p)),
        };
        if let Some(block) = &self.block_local {
            total
                .slice_mut(ndarray::s![
                    block.start..block.end,
                    block.start..block.end
                ])
                .scaled_add(1.0, &block.local);
        }
        if let Some(operator) = &self.operator {
            total += &operator.to_dense();
        }
        total
    }

    /// Returns the drift applied to `v` as a new vector of the same length.
    pub fn apply(&self, v: &Array1<f64>) -> Array1<f64> {
        let mut result = Array1::zeros(v.len());
        self.scaled_add_apply(v.view(), 1.0, &mut result);
        result
    }

    /// Accumulates `out += scale * (drift * v)` over all present contributions.
    ///
    /// Does nothing beyond the length check when `scale == 0.0`.
    ///
    /// # Panics
    /// Panics if `v` and `out` differ in length.
    pub fn scaled_add_apply(&self, v: ArrayView1<'_, f64>, scale: f64, out: &mut Array1<f64>) {
        assert_eq!(v.len(), out.len());
        if scale == 0.0 {
            return;
        }
        if let Some(dense) = &self.dense {
            dense_matvec_scaled_add_into(dense, v, scale, out.view_mut());
        }
        if let Some(block) = &self.block_local {
            let input_block = v.slice(ndarray::s![block.start..block.end]);
            let target_block = out.slice_mut(ndarray::s![block.start..block.end]);
            dense_matvec_scaled_add_into(&block.local, input_block, scale, target_block);
        }
        if let Some(operator) = &self.operator {
            operator.scaled_add_mul_vec(v, scale, out.view_mut());
        }
    }

    /// Dimension taken from the first available part, checked in the order
    /// dense, operator, block-local; `0` when no part is present.
    pub(crate) fn infer_dim(&self) -> usize {
        self.dense
            .as_ref()
            .map(|dense| dense.nrows())
            .or_else(|| self.operator.as_ref().map(|operator| operator.dim()))
            .or_else(|| self.block_local.as_ref().map(|block| block.total_dim))
            .unwrap_or(0)
    }
}
/// Per-coordinate bundle of a scalar, a vector, a drift and auxiliary data.
///
/// NOTE(review): nothing in this section reads these fields, so their exact
/// mathematical meaning cannot be confirmed here; the field notes below are
/// limited to what the types and names show.
#[derive(Clone)]
pub struct HyperCoord {
/// Scalar term.
pub a: f64,
/// Vector term.
pub g: Array1<f64>,
/// Drift term (dense, block-local and/or operator parts).
pub drift: HyperCoordDrift,
/// Scalar term; presumably a log-determinant derivative — TODO confirm.
pub ld_s: f64,
/// Flag; by its name, whether the drift depends on beta — TODO confirm.
pub b_depends_on_beta: bool,
/// Flag; by its name, whether this coordinate behaves like a penalty — TODO confirm.
pub is_penalty_like: bool,
/// Optional vector; by its name, a Firth-related term — TODO confirm.
pub firth_g: Option<Array1<f64>>,
/// Optional vector; meaning not visible in this section.
pub tk_eta_fixed: Option<Array1<f64>>,
/// Optional matrix; meaning not visible in this section.
pub tk_x_fixed: Option<Array2<f64>>,
}
/// Bundle of a scalar, a vector, a dense matrix and an optional operator.
///
/// NOTE(review): by its name this presumably holds quantities for a pair of
/// coordinates; the consumers are not visible in this section.
pub struct HyperCoordPair {
/// Scalar term.
pub a: f64,
/// Vector term.
pub g: Array1<f64>,
/// Dense matrix term.
pub b_mat: Array2<f64>,
/// Optional operator term, owned by this value.
pub b_operator: Option<Box<dyn HyperOperator>>,
/// Scalar term; presumably a log-determinant derivative — TODO confirm.
pub ld_s: f64,
}
impl HyperCoordPair {
    /// Returns the all-zero pair: zero scalars, an empty vector, a `0 x 0`
    /// matrix and no operator.
    pub fn zero() -> Self {
        let g = Array1::zeros(0);
        let b_mat = Array2::zeros((0, 0));
        Self {
            a: 0.0,
            g,
            b_mat,
            b_operator: None,
            ld_s: 0.0,
        }
    }
}
/// A linear map returned either as an explicit matrix or as an operator.
#[derive(Clone)]
pub enum DriftDerivResult {
/// Explicit dense matrix.
Dense(Array2<f64>),
/// Shared operator applied through the [`HyperOperator`] interface.
Operator(Arc<dyn HyperOperator>),
}
impl std::fmt::Debug for DriftDerivResult {
    /// Prints a compact summary: the shape for a dense matrix, or a fixed
    /// placeholder for an operator (operators are not required to be `Debug`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Dense(matrix) => {
                let (rows, cols) = (matrix.nrows(), matrix.ncols());
                let mut summary = f.debug_tuple("Dense");
                summary.field(&format_args!("{}x{}", rows, cols));
                summary.finish()
            }
            Self::Operator(_) => {
                let mut summary = f.debug_tuple("Operator");
                summary.field(&"<hyper-operator>");
                summary.finish()
            }
        }
    }
}
impl DriftDerivResult {
    /// Converts the result into a shared operator.
    ///
    /// A dense matrix is wrapped in a [`DenseMatrixHyperOperator`]; an
    /// existing operator is returned unchanged.
    pub fn into_operator(self) -> Arc<dyn HyperOperator> {
        match self {
            Self::Operator(operator) => operator,
            Self::Dense(matrix) => {
                let wrapped: Arc<dyn HyperOperator> =
                    Arc::new(DenseMatrixHyperOperator { matrix });
                wrapped
            }
        }
    }

    /// Applies the linear map to `v` and returns the product.
    pub fn apply(&self, v: &Array1<f64>) -> Array1<f64> {
        match self {
            Self::Operator(operator) => operator.mul_vec(v),
            Self::Dense(matrix) => matrix.dot(v),
        }
    }
}
/// Boxed, thread-safe callback taking an index and a vector and optionally
/// returning a [`DriftDerivResult`].
///
/// NOTE(review): the meaning of the `usize` index, of the vector argument, and
/// of a `None` return is defined by the callers, which are not visible here.
pub type FixedDriftDerivFn =
Box<dyn Fn(usize, &Array1<f64>) -> Option<DriftDerivResult> + Send + Sync>;
/// Bundle of second-order terms: one vector, one matrix, a list of linear
/// maps, and another vector.
///
/// NOTE(review): shapes and the exact meaning of each field are defined by
/// the producers and consumers, which are not visible in this section.
pub struct ContractedPsiSecondOrder {
/// Objective-related vector.
pub objective: Array1<f64>,
/// Score-related matrix.
pub score: Array2<f64>,
/// Hessian-related linear maps, each either dense or operator-backed.
pub hessian: Vec<DriftDerivResult>,
/// Vector term; presumably log-determinant derivatives — TODO confirm.
pub ld_s: Array1<f64>,
}
/// Shared, thread-safe callback mapping a slice of `f64` to an optional
/// [`ContractedPsiSecondOrder`], or to an error message on failure.
///
/// NOTE(review): what the input slice represents and when `Ok(None)` is
/// returned is defined by the callers, which are not visible here.
pub type ContractedPsiSecondOrderFn =
Arc<dyn Fn(&[f64]) -> Result<Option<ContractedPsiSecondOrder>, String> + Send + Sync>;
use crate::solver::estimate::reml::reml_outer_engine::{
dense_bilinear, dense_matvec_into, dense_matvec_scaled_add_into,
};