use crate::custom_family::{
ParameterBlockSpec, joint_coupled_coefficient_hessian_cost, use_joint_matrix_free_path,
};
/// Returns the Hessian cost estimate for whichever path the joint gate selects.
///
/// The branch decision is delegated to [`use_joint_matrix_free_path`], so this cost model
/// always agrees with that gate. When the gate reports the matrix-free path,
/// `matrix_free_cost` is returned; otherwise `dense_cost` is returned. Neither cost is
/// inspected or modified.
///
/// # Arguments
/// * `p_total` - total coefficient count passed to the gate as its dimension argument.
/// * `n` - row count passed to the gate as its second argument.
/// * `matrix_free_cost` - cost to report when the matrix-free path is chosen.
/// * `dense_cost` - cost to report otherwise.
pub fn operator_aware_hessian_cost(
    p_total: u64,
    n: u64,
    matrix_free_cost: u64,
    dense_cost: u64,
) -> u64 {
    // The gate takes `usize`. A bare `as` cast would silently truncate on targets where
    // `usize` is narrower than 64 bits, so clamp oversized values instead.
    let dim = u64_to_usize_saturating(p_total);
    let rows = u64_to_usize_saturating(n);
    if use_joint_matrix_free_path(dim, rows) {
        matrix_free_cost
    } else {
        dense_cost
    }
}

/// Converts `value` to `usize`, clamping to `usize::MAX` instead of truncating.
fn u64_to_usize_saturating(value: u64) -> usize {
    usize::try_from(value).unwrap_or(usize::MAX)
}
/// Estimates the Hessian cost of a jointly coupled set of parameter blocks.
///
/// The total coefficient count is the saturating sum of every block's design-matrix column
/// count. Two candidate estimates are then built:
/// * matrix-free: `n * p_total` (saturating multiply);
/// * dense: [`joint_coupled_coefficient_hessian_cost`] for the same `n` and `specs`.
///
/// [`operator_aware_hessian_cost`] picks between them using the joint matrix-free gate.
/// An empty `specs` slice yields a total of zero columns.
pub fn joint_coupled_operator_aware_hessian_cost(n: u64, specs: &[ParameterBlockSpec]) -> u64 {
    let total_cols = specs
        .iter()
        .map(|spec| spec.design.ncols() as u64)
        .fold(0u64, u64::saturating_add);
    let matrix_free_estimate = n.saturating_mul(total_cols);
    let dense_estimate = joint_coupled_coefficient_hessian_cost(n, specs);
    operator_aware_hessian_cost(total_cols, n, matrix_free_estimate, dense_estimate)
}
#[cfg(test)]
mod tests {
    //! Tests tying `operator_aware_hessian_cost` and
    //! `joint_coupled_operator_aware_hessian_cost` to the `use_joint_matrix_free_path` gate.
    //!
    //! The constants below restate the gate thresholds these tests expect. They are not
    //! imported from the gate, so the boundary tests fail if the gate's thresholds drift.
    use super::*;
    use gam_linalg::matrix::{DenseDesignMatrix, DesignMatrix};
    use ndarray::Array2;

    /// Column count expected to force the matrix-free path even for a single row.
    const MIN_DIM: u64 = 512;
    /// Row count at which the large-n branch of the gate is expected to apply.
    const MIN_ROWS: u64 = 50_000;
    /// Column count expected to trigger the matrix-free path once `n >= MIN_ROWS`.
    const MIN_DIM_AT_LARGE_N: u64 = 128;
    /// `p * n` threshold these tests associate with the gate's linear-work branch
    /// (the gate may add further conditions; confirm against `use_joint_matrix_free_path`).
    const MIN_LINEAR_WORK: u64 = 4_000_000;

    /// Builds a default block spec whose dense design matrix is a single all-zero row
    /// with `p` columns.
    fn spec_with_ncols(p: usize) -> ParameterBlockSpec {
        ParameterBlockSpec {
            design: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::<f64>::zeros((1, p)))),
            ..ParameterBlockSpec::defaults()
        }
    }

    #[test]
    fn operator_aware_picks_dense_for_small_problem() {
        let p = 4;
        let n = 100;
        let mf = 11;
        let dense = 22;
        assert!(!use_joint_matrix_free_path(p as usize, n as usize));
        assert_eq!(operator_aware_hessian_cost(p, n, mf, dense), dense);
    }

    #[test]
    fn operator_aware_picks_matrix_free_for_wide_problem() {
        let p = MIN_DIM;
        let n = 1;
        let mf = 11;
        let dense = 22;
        assert!(use_joint_matrix_free_path(p as usize, n as usize));
        assert_eq!(operator_aware_hessian_cost(p, n, mf, dense), mf);
    }

    /// The cost selector must agree with the gate on every sampled (p, n) pair,
    /// including values straddling the thresholds above.
    #[test]
    fn operator_aware_branch_is_exactly_the_gate() {
        let mf = 7;
        let dense = 99;
        for &p in &[1u64, 64, 128, 256, 511, 512, 1024] {
            for &n in &[1u64, 100, 49_999, 50_000, 100_000] {
                let expected = if use_joint_matrix_free_path(p as usize, n as usize) {
                    mf
                } else {
                    dense
                };
                assert_eq!(
                    operator_aware_hessian_cost(p, n, mf, dense),
                    expected,
                    "p={p} n={n}"
                );
            }
        }
    }

    #[test]
    fn gate_large_n_branch_boundary() {
        let mf = 1;
        let dense = 2;
        // Just below the large-n column threshold: stays dense.
        assert_eq!(
            operator_aware_hessian_cost(MIN_DIM_AT_LARGE_N - 1, MIN_ROWS, mf, dense),
            dense
        );
        // Enough columns but too few rows, and below the linear-work threshold: stays dense.
        assert!(MIN_DIM_AT_LARGE_N * 30_000 < MIN_LINEAR_WORK);
        assert_eq!(
            operator_aware_hessian_cost(MIN_DIM_AT_LARGE_N, 30_000, mf, dense),
            dense
        );
        // Both large-n thresholds met: matrix-free.
        assert_eq!(
            operator_aware_hessian_cost(MIN_DIM_AT_LARGE_N, MIN_ROWS, mf, dense),
            mf
        );
    }

    /// Neither the wide-dimension nor the large-n thresholds are met here, so only the
    /// linear-work branch can select the matrix-free path.
    #[test]
    fn gate_linear_work_branch() {
        let p = 200u64;
        let n = 20_000u64;
        assert!(p * n >= MIN_LINEAR_WORK);
        assert!(n < MIN_ROWS);
        assert!(p < MIN_DIM);
        assert!(use_joint_matrix_free_path(p as usize, n as usize));
        assert_eq!(operator_aware_hessian_cost(p, n, 5, 6), 5);
    }

    #[test]
    fn joint_coupled_small_returns_dense_n_p_squared() {
        let specs = [spec_with_ncols(3), spec_with_ncols(5)];
        let n = 10u64;
        let p_total = 8u64;
        assert!(!use_joint_matrix_free_path(p_total as usize, n as usize));
        assert_eq!(
            joint_coupled_operator_aware_hessian_cost(n, &specs),
            n * p_total * p_total
        );
    }

    #[test]
    fn joint_coupled_wide_returns_matrix_free_n_times_p() {
        let specs = [spec_with_ncols(300), spec_with_ncols(300)];
        let n = 4u64;
        let p_total = 600u64;
        assert!(use_joint_matrix_free_path(p_total as usize, n as usize));
        assert_eq!(
            joint_coupled_operator_aware_hessian_cost(n, &specs),
            n * p_total
        );
    }

    #[test]
    fn joint_coupled_empty_specs_is_zero() {
        let specs: [ParameterBlockSpec; 0] = [];
        assert_eq!(joint_coupled_operator_aware_hessian_cost(1234, &specs), 0);
    }

    #[test]
    fn joint_coupled_p_total_sums_block_ncols() {
        let specs = [spec_with_ncols(2), spec_with_ncols(2), spec_with_ncols(2)];
        let n = 7u64;
        let p_total = 6u64;
        assert!(!use_joint_matrix_free_path(p_total as usize, n as usize));
        assert_eq!(
            joint_coupled_operator_aware_hessian_cost(n, &specs),
            n * p_total * p_total
        );
    }
}