use crate::arrow_schur::ArrowSchurSystem;
use crate::gpu_kernels::arrow_schur::{
ArrowSchurGpuFailure, ArrowSchurGpuSolution, solve_arrow_newton_step_dense_reference,
};
/// Solves every arrow-Schur atom in a class, returning one result per input
/// system in the same order as `systems`.
///
/// A batched device path is attempted first via `try_device_batched_k1`. When
/// that declines (returns `None`), each system is solved independently with
/// the dense CPU reference using `ridge_t` / `ridge_beta`. A failure in one
/// atom is therefore reported only in that atom's slot and never prevents the
/// other atoms from being solved.
///
/// On the CPU path a failing atom yields
/// `Err(ArrowSchurGpuFailure::SchurFactorFailed { .. })`. An empty `systems`
/// slice yields an empty vector.
#[must_use]
pub fn solve_batched_k1_border(
    systems: &[ArrowSchurSystem],
    ridge_t: f64,
    ridge_beta: f64,
) -> Vec<Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure>> {
    try_device_batched_k1(systems).unwrap_or_else(|| {
        // Per-atom CPU fallback: results stay positional and independent.
        let mut per_atom = Vec::with_capacity(systems.len());
        for system in systems {
            per_atom.push(cpu_reference_k1(system, ridge_t, ridge_beta));
        }
        per_atom
    })
}
/// Solves a single atom with the dense CPU reference solver.
///
/// A successful solve is returned unchanged. Any failure reason from
/// `solve_arrow_newton_step_dense_reference` is wrapped as
/// `ArrowSchurGpuFailure::SchurFactorFailed`, so CPU-path failures share the
/// GPU failure type.
fn cpu_reference_k1(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
    match solve_arrow_newton_step_dense_reference(sys, ridge_t, ridge_beta) {
        Ok(solution) => Ok(solution),
        Err(reason) => Err(ArrowSchurGpuFailure::SchurFactorFailed { reason }),
    }
}
/// Attempts to solve the whole class in one batched device launch.
///
/// `None` means "fall back to the per-atom CPU path". It is returned for an
/// empty class and when no global GPU runtime is available. Otherwise the
/// class is summarised (total row count, integer-mean border width `k`,
/// largest per-row block `d`) and the runtime's offload policy is queried.
///
/// NOTE(review): as written, every path returns `None`. The policy's answer
/// is computed but does not change the outcome, because no batched device
/// kernel is launched from this function. An approved offload still falls
/// back to the CPU reference.
fn try_device_batched_k1(
    systems: &[ArrowSchurSystem],
) -> Option<Vec<Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure>>> {
    // Empty class: nothing to batch.
    let first = systems.first()?;
    let runtime = gam_gpu::device_runtime::GpuRuntime::global()?;

    // Single pass over the class to build the offload heuristic's inputs.
    let mut total_rows = 0usize;
    let mut k_sum = 0usize;
    let mut max_d = first.d;
    for sys in systems {
        total_rows += sys.rows.len();
        k_sum += sys.k;
        max_d = std::cmp::max(max_d, sys.d);
    }
    // Non-empty was established above, so the division is safe (floors).
    let mean_k = k_sum / systems.len();

    let _offload_approved = runtime
        .policy()
        .reduced_schur_matvec_should_offload(total_rows, mean_k, max_d, 1);
    // Declined or approved, there is no device kernel to dispatch yet.
    None
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an arrow system that the tests below treat as positive-definite.
    ///
    /// The system has `n` rows, a `d x d` per-row block, and a `k`-wide shared
    /// border. Its diagonals are `2.0 + seed`, its per-row gradients grow with
    /// the row index, and its coupling `htbeta` entries are small. `seed`
    /// shifts diagonals and gradients so that distinct atoms differ.
    fn pd_k1_system(n: usize, d: usize, k: usize, seed: f64) -> ArrowSchurSystem {
        let mut system = ArrowSchurSystem::new(n, d, k);
        for (row_idx, row) in system.rows.iter_mut().enumerate() {
            let row_offset = row_idx as f64 + 1.0;
            for diag in 0..d {
                row.htt[[diag, diag]] = 2.0 + seed;
                row.gt[diag] = 0.1 * row_offset + seed;
                for col in 0..k {
                    row.htbeta[[diag, col]] = 0.05 * ((diag + col + row_idx) as f64 + 1.0);
                }
            }
        }
        for b in 0..k {
            system.hbb[[b, b]] = 2.0 + seed;
            system.gb[b] = 0.2 * (b as f64 + 1.0) + seed;
        }
        system
    }

    /// Same layout as `pd_k1_system(n, d, k, 0.0)`, but every per-row `htt`
    /// diagonal is set to `-2.0`, so the atom is indefinite.
    fn indefinite_k1_system(n: usize, d: usize, k: usize) -> ArrowSchurSystem {
        let mut system = pd_k1_system(n, d, k, 0.0);
        for row in system.rows.iter_mut() {
            (0..d).for_each(|diag| row.htt[[diag, diag]] = -2.0);
        }
        system
    }

    /// Asserts that two solutions have matching lengths and agree elementwise
    /// to an absolute tolerance of `1e-12`.
    fn assert_solution_eq(a: &ArrowSchurGpuSolution, b: &ArrowSchurGpuSolution) {
        const TOL: f64 = 1e-12;
        assert_eq!(a.delta_t.len(), b.delta_t.len());
        assert_eq!(a.delta_beta.len(), b.delta_beta.len());
        a.delta_t.iter().zip(b.delta_t.iter()).for_each(|(x, y)| {
            assert!((x - y).abs() < TOL, "delta_t mismatch: {x} vs {y}");
        });
        a.delta_beta.iter().zip(b.delta_beta.iter()).for_each(|(x, y)| {
            assert!((x - y).abs() < TOL, "delta_beta mismatch: {x} vs {y}");
        });
    }

    #[test]
    fn empty_class_returns_empty() {
        let results = solve_batched_k1_border(&[], 1e-6, 1e-6);
        assert!(results.is_empty(), "an empty color class must return no results");
    }

    #[test]
    fn single_atom_matches_dense_reference() {
        const RIDGE: f64 = 1e-6;
        let system = pd_k1_system(5, 2, 3, 0.0);
        let results = solve_batched_k1_border(std::slice::from_ref(&system), RIDGE, RIDGE);
        assert_eq!(results.len(), 1);
        let expected = solve_arrow_newton_step_dense_reference(&system, RIDGE, RIDGE)
            .expect("PD reference atom must solve");
        let got = results[0].as_ref().expect("batched single atom must solve");
        assert_solution_eq(got, &expected);
    }

    #[test]
    fn class_results_are_positional_and_independent() {
        let systems = [
            pd_k1_system(4, 2, 2, 0.0),
            pd_k1_system(6, 2, 3, 0.5),
            pd_k1_system(3, 1, 2, 1.0),
        ];
        let results = solve_batched_k1_border(&systems, 1e-6, 1e-6);
        assert_eq!(results.len(), systems.len());
        // Lengths are equal, so zipping pairs each atom with its own slot.
        for (system, result) in systems.iter().zip(results.iter()) {
            let alone = solve_arrow_newton_step_dense_reference(system, 1e-6, 1e-6)
                .expect("each PD atom must solve on its own");
            let in_class = result.as_ref().expect("each atom must solve in-class");
            assert_solution_eq(in_class, &alone);
        }
    }

    #[test]
    fn per_atom_decline_is_isolated_never_fatal() {
        // Zero ridges leave the indefinite atom unregularised.
        let systems = [pd_k1_system(4, 2, 2, 0.0), indefinite_k1_system(4, 2, 2)];
        let results = solve_batched_k1_border(&systems, 0.0, 0.0);
        assert_eq!(results.len(), 2);
        assert!(results[0].is_ok(), "the PD atom must still solve in a mixed class");
        match &results[1] {
            Ok(_) => panic!("the indefinite atom must decline per-atom, but it solved"),
            Err(ArrowSchurGpuFailure::SchurFactorFailed { .. }) => {}
            Err(other) => panic!("the indefinite atom must decline as SchurFactorFailed; got {other:?}"),
        }
    }
}