use ndarray::Array1;
use crate::gpu::arrow_schur::{
ArrowSchurGpuFailure, solve_arrow_newton_step, solve_arrow_newton_step_dense_reference,
};
use crate::solver::arrow_schur::{ArrowSchurError, ArrowSchurSystem};
/// Problem dimensions of an arrow (row-block + shared border) Newton system.
///
/// There are `n` rows, each owning a `d`-dimensional block of unknowns and
/// `basis_cols` basis columns; all rows couple to one `p`-dimensional border.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct DeviceResidentArrowShape {
    pub n: usize,
    pub p: usize,
    pub basis_cols: usize,
    pub d: usize,
}

impl DeviceResidentArrowShape {
    /// The fixed qwen non-gating problem size: 2 000 rows, a 2 048-wide border,
    /// 8 basis columns per row and 2×2 row blocks.
    #[inline]
    pub const fn qwen_non_gating() -> Self {
        Self { n: 2_000, p: 2_048, basis_cols: 8, d: 2 }
    }

    /// Element count of the `n × p` target buffer.
    #[inline]
    pub const fn target_len(self) -> usize {
        let Self { n, p, .. } = self;
        n * p
    }

    /// Element count of the `n × basis_cols` basis (and gate) buffers.
    #[inline]
    pub const fn basis_len(self) -> usize {
        let Self { n, basis_cols, .. } = self;
        n * basis_cols
    }

    /// Element count of the `n` stacked `d × d` row Hessian blocks.
    #[inline]
    pub const fn row_hessian_len(self) -> usize {
        // (n * d) * d — same evaluation order as n * d * d.
        self.row_gradient_len() * self.d
    }

    /// Element count of the `n` stacked `d × p` row/border cross blocks.
    #[inline]
    pub const fn row_cross_len(self) -> usize {
        // (n * d) * p — same evaluation order as n * d * p.
        self.row_gradient_len() * self.p
    }

    /// Element count of the `n` stacked length-`d` row gradients.
    #[inline]
    pub const fn row_gradient_len(self) -> usize {
        let Self { n, d, .. } = self;
        n * d
    }

    /// Element count of the dense `p × p` border Hessian.
    #[inline]
    pub const fn border_hessian_len(self) -> usize {
        self.p * self.p
    }
}
/// Host-side coefficient slabs of the arrow system, each flattened row-major.
///
/// Row `i` reads its `d × d` Hessian block from `row_hessian_slabs[i*d*d..]`,
/// its `d × p` cross block from `row_cross_slabs[i*d*p..]` and its length-`d`
/// gradient from `row_gradient_slabs[i*d..]`; the border uses the dense
/// `p × p` `border_hessian` and the length-`p` `border_gradient`.
#[derive(Clone, Debug)]
pub struct DeviceResidentArrowSlabs {
    /// `n` stacked `d × d` row Hessian blocks.
    pub row_hessian_slabs: Vec<f64>,
    /// `n` stacked `d × p` row/border cross blocks.
    pub row_cross_slabs: Vec<f64>,
    /// `n` stacked length-`d` row gradients.
    pub row_gradient_slabs: Vec<f64>,
    /// Dense `p × p` border Hessian.
    pub border_hessian: Vec<f64>,
    /// Length-`p` border gradient.
    pub border_gradient: Vec<f64>,
}
/// Result of a single arrow Newton step.
#[derive(Clone, Debug)]
pub struct DeviceResidentArrowStep {
    /// Step for the stacked row unknowns (presumably length `n * d`).
    pub delta_t: Array1<f64>,
    /// Step for the border unknowns (presumably length `p`).
    pub delta_beta: Array1<f64>,
    /// Objective value recorded alongside the step.
    pub objective: f64,
    /// Gradient norm recorded alongside the step.
    pub gradient_norm: f64,
    /// Log-determinant of the Hessian as reported by the solver.
    pub log_det_hessian: f64,
    /// `true` when the CUDA path produced the step, `false` for the CPU reference.
    pub used_device: bool,
}
/// Failure modes of the device-resident arrow workspace.
///
/// Every variant carries a human-readable `reason`; `Display` prints exactly
/// that reason, with no variant prefix.
#[derive(Debug, Clone)]
pub enum DeviceResidentArrowError {
    /// A buffer length disagrees with the declared shape, or a dimension is zero.
    Shape { reason: String },
    /// The CUDA path was not admitted or became unavailable.
    Unavailable { reason: String },
    /// A Newton / LM solve failed.
    Solve { reason: String },
}

impl std::fmt::Display for DeviceResidentArrowError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let reason = match self {
            Self::Shape { reason } => reason,
            Self::Unavailable { reason } => reason,
            Self::Solve { reason } => reason,
        };
        f.write_str(reason)
    }
}

impl std::error::Error for DeviceResidentArrowError {}
/// Device copies of every workspace buffer (Linux only).
///
/// Each `*_dev` slice is uploaded from the host buffer of the same name.
#[cfg(target_os = "linux")]
pub struct DeviceResidentArrowBuffers {
    /// Stream the uploads were issued on.
    pub stream: std::sync::Arc<cudarc::driver::CudaStream>,
    pub target_x_dev: cudarc::driver::CudaSlice<f64>,
    pub basis_values_dev: cudarc::driver::CudaSlice<f64>,
    pub gate_activations_dev: cudarc::driver::CudaSlice<f64>,
    pub row_hessian_dev: cudarc::driver::CudaSlice<f64>,
    pub row_cross_dev: cudarc::driver::CudaSlice<f64>,
    pub row_gradient_dev: cudarc::driver::CudaSlice<f64>,
    pub border_hessian_dev: cudarc::driver::CudaSlice<f64>,
    pub border_gradient_dev: cudarc::driver::CudaSlice<f64>,
    /// Total uploaded size in bytes (element count × `size_of::<f64>()`).
    pub bytes: usize,
}
/// Validated host buffers for one arrow-system fit, plus optional device
/// mirrors on Linux.
pub struct DeviceResidentArrowWorkspace {
    shape: DeviceResidentArrowShape,
    /// `n × p` target values, indexed `i * p + j`.
    target_x: Vec<f64>,
    /// `n × basis_cols` basis values.
    basis_values: Vec<f64>,
    /// `n × basis_cols` gate activations.
    gate_activations: Vec<f64>,
    slabs: DeviceResidentArrowSlabs,
    /// `Some` only when GPU admission and every upload succeeded in `new`.
    #[cfg(target_os = "linux")]
    device: Option<DeviceResidentArrowBuffers>,
}
impl DeviceResidentArrowWorkspace {
pub fn new(
shape: DeviceResidentArrowShape,
target_x: Vec<f64>,
basis_values: Vec<f64>,
gate_activations: Vec<f64>,
slabs: DeviceResidentArrowSlabs,
) -> Result<Self, DeviceResidentArrowError> {
validate_shape(shape, &target_x, &basis_values, &gate_activations, &slabs)?;
#[cfg(target_os = "linux")]
let device =
upload_resident_buffers(shape, &target_x, &basis_values, &gate_activations, &slabs);
Ok(Self {
shape,
target_x,
basis_values,
gate_activations,
slabs,
#[cfg(target_os = "linux")]
device,
})
}
#[inline]
pub const fn shape(&self) -> DeviceResidentArrowShape {
self.shape
}
#[must_use]
pub fn device_resident(&self) -> bool {
#[cfg(target_os = "linux")]
{
self.device.is_some()
}
#[cfg(not(target_os = "linux"))]
{
false
}
}
#[must_use]
pub fn resident_device_bytes(&self) -> usize {
#[cfg(target_os = "linux")]
{
self.device.as_ref().map_or(0, |device| device.bytes)
}
#[cfg(not(target_os = "linux"))]
{
0
}
}
#[must_use]
pub fn host_shadow_bytes(&self) -> usize {
[
self.target_x.len(),
self.basis_values.len(),
self.gate_activations.len(),
self.slabs.row_hessian_slabs.len(),
self.slabs.row_cross_slabs.len(),
self.slabs.row_gradient_slabs.len(),
self.slabs.border_hessian.len(),
self.slabs.border_gradient.len(),
]
.into_iter()
.sum::<usize>()
* std::mem::size_of::<f64>()
}
pub fn one_inner_iteration(
&self,
ridge_t: f64,
ridge_beta: f64,
) -> Result<DeviceResidentArrowStep, DeviceResidentArrowError> {
if !self.device_resident() {
return Err(DeviceResidentArrowError::Unavailable {
reason: "SAE resident inner iteration unavailable: CUDA runtime did not admit the qwen-scale row-block workload".to_string(),
});
}
let sys = self.to_arrow_system();
solve_arrow_newton_step(&sys, ridge_t, ridge_beta)
.map(|solution| self.finish_step(solution, true))
.map_err(map_gpu_error)
}
pub fn cpu_reference_step(
&self,
ridge_t: f64,
ridge_beta: f64,
) -> Result<DeviceResidentArrowStep, DeviceResidentArrowError> {
let sys = self.to_arrow_system();
solve_arrow_newton_step_dense_reference(&sys, ridge_t, ridge_beta)
.map(|solution| self.finish_step(solution, false))
.map_err(|reason| DeviceResidentArrowError::Solve { reason })
}
pub fn to_arrow_system(&self) -> ArrowSchurSystem {
let shape = self.shape;
let mut sys = ArrowSchurSystem::new(shape.n, shape.d, shape.p);
for i in 0..shape.n {
let h_base = i * shape.d * shape.d;
let b_base = i * shape.d * shape.p;
let g_base = i * shape.d;
for r in 0..shape.d {
for c in 0..shape.d {
sys.rows[i].htt[[r, c]] =
self.slabs.row_hessian_slabs[h_base + r * shape.d + c];
}
sys.rows[i].gt[r] = self.slabs.row_gradient_slabs[g_base + r];
for c in 0..shape.p {
sys.rows[i].htbeta[[r, c]] =
self.slabs.row_cross_slabs[b_base + r * shape.p + c];
}
}
}
for r in 0..shape.p {
sys.gb[r] = self.slabs.border_gradient[r];
for c in 0..shape.p {
sys.hbb[[r, c]] = self.slabs.border_hessian[r * shape.p + c];
}
}
sys.refresh_row_hessian_fingerprint();
sys
}
fn finish_step(
&self,
solution: crate::gpu::arrow_schur::ArrowSchurGpuSolution,
used_device: bool,
) -> DeviceResidentArrowStep {
DeviceResidentArrowStep {
delta_t: solution.delta_t,
delta_beta: solution.delta_beta,
objective: 0.5 * squared_norm(&self.target_x),
gradient_norm: self.gradient_norm(),
log_det_hessian: solution.log_det_hessian,
used_device,
}
}
fn gradient_norm(&self) -> f64 {
let row = squared_norm(&self.slabs.row_gradient_slabs);
let border = squared_norm(&self.slabs.border_gradient);
(row + border).sqrt()
}
pub fn device_fit(
&self,
opts: &DeviceResidentInnerOptions,
) -> Result<DeviceResidentInnerOutcome, DeviceResidentArrowError> {
if !self.device_resident() {
return Err(DeviceResidentArrowError::Unavailable {
reason: "SAE resident inner loop unavailable: CUDA runtime did not admit the qwen-scale row-block workload".to_string(),
});
}
self.run_inner_loop(opts, true)
}
pub fn cpu_reference_fit(
&self,
opts: &DeviceResidentInnerOptions,
) -> Result<DeviceResidentInnerOutcome, DeviceResidentArrowError> {
self.run_inner_loop(opts, false)
}
fn run_inner_loop(
&self,
opts: &DeviceResidentInnerOptions,
on_device: bool,
) -> Result<DeviceResidentInnerOutcome, DeviceResidentArrowError> {
let n = self.shape.n;
let d = self.shape.d;
let p = self.shape.p;
let t_len = n * d;
let mut t = vec![0.0_f64; t_len];
let mut beta = vec![0.0_f64; p];
let base = self.to_arrow_system();
let half_target_energy = 0.5 * squared_norm(&self.target_x);
let mut ridge_t = opts.initial_ridge_t.max(0.0);
let mut ridge_beta = opts.initial_ridge_beta.max(0.0);
let mut current_objective = self.objective_at(&base, half_target_energy, &t, &beta);
let mut accepted_iters = 0_usize;
let mut total_iters = 0_usize;
let mut converged = false;
let mut last_step = DeviceResidentArrowStep {
delta_t: Array1::zeros(t_len),
delta_beta: Array1::zeros(p),
objective: current_objective,
gradient_norm: 0.0,
log_det_hessian: 0.0,
used_device: on_device,
};
while total_iters < opts.max_iterations {
let residual = self.residual_system(&base, &t, &beta);
let g_norm = arrow_system_gradient_norm(&residual);
let scale = 1.0 + iterate_norm(&t, &beta);
if g_norm / scale < opts.convergence_tolerance {
converged = true;
break;
}
let solution = if on_device {
solve_arrow_newton_step(&residual, ridge_t, ridge_beta).map_err(map_gpu_error)
} else {
solve_arrow_newton_step_dense_reference(&residual, ridge_t, ridge_beta)
.map_err(|reason| DeviceResidentArrowError::Solve { reason })
};
let solution = match solution {
Ok(sol) => sol,
Err(DeviceResidentArrowError::Solve { .. })
| Err(DeviceResidentArrowError::Unavailable { .. }) => {
ridge_t = grow_ridge(ridge_t, opts.lm_grow);
ridge_beta = grow_ridge(ridge_beta, opts.lm_grow);
if ridge_t > opts.max_ridge || ridge_beta > opts.max_ridge {
return Err(DeviceResidentArrowError::Solve {
reason: format!(
"SAE resident inner loop: LM ridge exceeded max ({:e}) at iter {total_iters}",
opts.max_ridge
),
});
}
total_iters += 1;
continue;
}
Err(other) => return Err(other),
};
let predicted_reduction =
crate::solver::arrow_schur::arrow_bare_quadratic_model_reduction(
&residual,
solution.delta_t.view(),
solution.delta_beta.view(),
ridge_t,
ridge_beta,
)
.map_err(|err| DeviceResidentArrowError::Solve {
reason: format!("SAE resident inner loop predicted-reduction failed: {err}"),
})?;
let mut trial_t = t.clone();
let mut trial_beta = beta.clone();
for (slot, dv) in trial_t.iter_mut().zip(solution.delta_t.iter()) {
*slot += *dv;
}
for (slot, dv) in trial_beta.iter_mut().zip(solution.delta_beta.iter()) {
*slot += *dv;
}
let trial_objective =
self.objective_at(&base, half_target_energy, &trial_t, &trial_beta);
let objective_scale = current_objective.abs().max(1.0);
let noise_floor = objective_scale * 1e-14;
let actual_reduction = current_objective - trial_objective;
let rho = if predicted_reduction > noise_floor {
actual_reduction / predicted_reduction
} else if actual_reduction >= -noise_floor {
1.0
} else {
-1.0
};
if rho > 0.0 && trial_objective.is_finite() {
t = trial_t;
beta = trial_beta;
current_objective = trial_objective;
ridge_t = (ridge_t * opts.lm_shrink).max(0.0);
ridge_beta = (ridge_beta * opts.lm_shrink).max(0.0);
last_step = DeviceResidentArrowStep {
delta_t: solution.delta_t,
delta_beta: solution.delta_beta,
objective: current_objective,
gradient_norm: g_norm,
log_det_hessian: solution.log_det_hessian,
used_device: on_device,
};
accepted_iters += 1;
total_iters += 1;
} else {
ridge_t = grow_ridge(ridge_t, opts.lm_grow);
ridge_beta = grow_ridge(ridge_beta, opts.lm_grow);
if ridge_t > opts.max_ridge || ridge_beta > opts.max_ridge {
return Err(DeviceResidentArrowError::Solve {
reason: format!(
"SAE resident inner loop: LM rejected step until ridge exceeded max ({:e}) at iter {total_iters} (rho={rho:.3e})",
opts.max_ridge
),
});
}
total_iters += 1;
}
}
Ok(DeviceResidentInnerOutcome {
t: Array1::from_vec(t),
beta: Array1::from_vec(beta),
objective: current_objective,
gradient_norm: last_step.gradient_norm,
log_det_hessian: last_step.log_det_hessian,
iterations: total_iters,
accepted_iterations: accepted_iters,
converged,
used_device: on_device,
})
}
fn objective_at(
&self,
base: &ArrowSchurSystem,
half_target_energy: f64,
t: &[f64],
beta: &[f64],
) -> f64 {
let n = self.shape.n;
let d = self.shape.d;
let p = self.shape.p;
let mut quad = 0.0_f64;
let mut lin = 0.0_f64;
for i in 0..n {
let t_base = i * d;
for r in 0..d {
let mut htt_t = 0.0_f64;
for c in 0..d {
htt_t += base.rows[i].htt[[r, c]] * t[t_base + c];
}
let mut htb_b = 0.0_f64;
for c in 0..p {
htb_b += base.rows[i].htbeta[[r, c]] * beta[c];
}
quad += t[t_base + r] * (htt_t + 2.0 * htb_b);
lin += base.rows[i].gt[r] * t[t_base + r];
}
}
for r in 0..p {
let mut hbb_b = 0.0_f64;
for c in 0..p {
hbb_b += base.hbb[[r, c]] * beta[c];
}
quad += beta[r] * hbb_b;
lin += base.gb[r] * beta[r];
}
half_target_energy + 0.5 * quad - lin
}
fn residual_system(
&self,
base: &ArrowSchurSystem,
t: &[f64],
beta: &[f64],
) -> ArrowSchurSystem {
let n = self.shape.n;
let d = self.shape.d;
let p = self.shape.p;
let mut sys = self.to_arrow_system();
for i in 0..n {
let t_base = i * d;
for r in 0..d {
let mut hz = 0.0_f64;
for c in 0..d {
hz += base.rows[i].htt[[r, c]] * t[t_base + c];
}
for c in 0..p {
hz += base.rows[i].htbeta[[r, c]] * beta[c];
}
sys.rows[i].gt[r] = hz - base.rows[i].gt[r];
}
}
for r in 0..p {
let mut hz = 0.0_f64;
for c in 0..p {
hz += base.hbb[[r, c]] * beta[c];
}
for i in 0..n {
let t_base = i * d;
for rr in 0..d {
hz += base.rows[i].htbeta[[rr, r]] * t[t_base + rr];
}
}
sys.gb[r] = hz - base.gb[r];
}
sys.refresh_row_hessian_fingerprint();
sys
}
}
/// Tuning knobs for the Levenberg–Marquardt-damped Newton inner loop.
#[derive(Clone, Copy, Debug)]
pub struct DeviceResidentInnerOptions {
    /// Cap on loop iterations; accepted, rejected and failed-solve iterations all count.
    pub max_iterations: usize,
    /// Stop once `‖gradient‖ / (1 + ‖(t, β)‖)` drops below this value.
    pub convergence_tolerance: f64,
    /// Starting ridge on the row (`t`) blocks; negative values are clamped to 0.
    pub initial_ridge_t: f64,
    /// Starting ridge on the border (`β`) block; negative values are clamped to 0.
    pub initial_ridge_beta: f64,
    /// Ridge growth factor after a rejected step or failed solve.
    pub lm_grow: f64,
    /// Ridge shrink factor after an accepted step.
    pub lm_shrink: f64,
    /// Ridge ceiling; exceeding it aborts the loop with an error.
    pub max_ridge: f64,
}

impl Default for DeviceResidentInnerOptions {
    fn default() -> Self {
        // Start undamped (pure Newton); damping is only introduced on demand.
        let (initial_ridge_t, initial_ridge_beta) = (0.0, 0.0);
        Self {
            max_iterations: 16,
            convergence_tolerance: 1.0e-9,
            initial_ridge_t,
            initial_ridge_beta,
            lm_grow: 4.0,
            lm_shrink: 0.5,
            max_ridge: 1.0e9,
        }
    }
}
/// Final state of a damped Newton inner-loop fit.
#[derive(Clone, Debug)]
pub struct DeviceResidentInnerOutcome {
    /// Fitted stacked row unknowns (length `n * d`).
    pub t: Array1<f64>,
    /// Fitted border unknowns (length `p`).
    pub beta: Array1<f64>,
    /// Objective value at the returned `(t, β)`.
    pub objective: f64,
    /// Gradient norm reported by the inner loop (see `run_inner_loop`).
    pub gradient_norm: f64,
    /// Log-determinant reported by the last accepted step (0.0 if none).
    pub log_det_hessian: f64,
    /// Total iterations consumed, including rejected steps and failed solves.
    pub iterations: usize,
    /// Iterations whose step was accepted.
    pub accepted_iterations: usize,
    /// Whether the scaled-gradient stopping test was met.
    pub converged: bool,
    /// `true` for the CUDA path, `false` for the CPU reference.
    pub used_device: bool,
}
/// Next LM ridge after a rejected step or failed solve.
///
/// A zero ridge is seeded at `1e-6`, because scaling zero would never grow
/// it. Any other ridge is multiplied by `grow`.
fn grow_ridge(current: f64, grow: f64) -> f64 {
    match current {
        unseeded if unseeded == 0.0 => 1e-6,
        seeded => seeded * grow,
    }
}
/// Euclidean norm of the full gradient: every row's `gt`, followed by the border `gb`.
fn arrow_system_gradient_norm(sys: &ArrowSchurSystem) -> f64 {
    let row_entries = sys.rows.iter().flat_map(|row| row.gt.iter());
    let sum_of_squares = row_entries
        .chain(sys.gb.iter())
        .fold(0.0_f64, |acc, &v| acc + v * v);
    sum_of_squares.sqrt()
}
/// Euclidean norm of the stacked iterate `(t, β)`.
fn iterate_norm(t: &[f64], beta: &[f64]) -> f64 {
    let energy = |xs: &[f64]| -> f64 { xs.iter().map(|v| v * v).sum() };
    f64::sqrt(energy(t) + energy(beta))
}
/// Checks every host buffer length against `shape`, then rejects zero dimensions.
///
/// # Errors
/// [`DeviceResidentArrowError::Shape`] naming the first mismatching buffer (in
/// declaration order). Once all lengths match, it is also returned when any of
/// `n`, `p`, `d` or `basis_cols` is zero.
fn validate_shape(
    shape: DeviceResidentArrowShape,
    target_x: &[f64],
    basis_values: &[f64],
    gate_activations: &[f64],
    slabs: &DeviceResidentArrowSlabs,
) -> Result<(), DeviceResidentArrowError> {
    let expectations: [(&str, usize, usize); 8] = [
        ("target_x", target_x.len(), shape.target_len()),
        ("basis_values", basis_values.len(), shape.basis_len()),
        ("gate_activations", gate_activations.len(), shape.basis_len()),
        ("row_hessian_slabs", slabs.row_hessian_slabs.len(), shape.row_hessian_len()),
        ("row_cross_slabs", slabs.row_cross_slabs.len(), shape.row_cross_len()),
        ("row_gradient_slabs", slabs.row_gradient_slabs.len(), shape.row_gradient_len()),
        ("border_hessian", slabs.border_hessian.len(), shape.border_hessian_len()),
        ("border_gradient", slabs.border_gradient.len(), shape.p),
    ];
    let first_mismatch = expectations
        .into_iter()
        .find(|&(_, got, want)| got != want);
    if let Some((label, got, want)) = first_mismatch {
        return Err(DeviceResidentArrowError::Shape {
            reason: format!(
                "SAE resident workspace shape mismatch for {label}: got {got}, expected {want}"
            ),
        });
    }
    let has_empty_dimension = [shape.n, shape.p, shape.d, shape.basis_cols].contains(&0);
    if has_empty_dimension {
        return Err(DeviceResidentArrowError::Shape {
            reason: "SAE resident workspace requires nonzero n, p, basis_cols, and d".to_string(),
        });
    }
    Ok(())
}
/// Attempts to mirror every workspace buffer onto a CUDA device.
///
/// Returns `None` in any of these cases, so the caller falls back to
/// host-only residency:
/// - GPU routing declines both probe operations;
/// - no context or stream can be obtained;
/// - any single upload fails.
#[cfg(target_os = "linux")]
fn upload_resident_buffers(
    shape: DeviceResidentArrowShape,
    target_x: &[f64],
    basis_values: &[f64],
    gate_activations: &[f64],
    slabs: &DeviceResidentArrowSlabs,
) -> Option<DeviceResidentArrowBuffers> {
    use crate::gpu::linalg::{DispatchOp, route_through_gpu};
    // Admission: try the batched row-block factorisation first; only if that is
    // declined, probe with a border-sized GEMM.
    let batched_potrf = DispatchOp::SmallDenseBatchedPotrf {
        p: shape.d,
        batch: shape.n,
    };
    let runtime = match route_through_gpu(batched_potrf) {
        Some(admitted) => admitted,
        None => route_through_gpu(DispatchOp::Gemm {
            m: shape.p,
            n: shape.p,
            k: shape.n * shape.basis_cols,
        })?,
    };
    let ctx = crate::gpu::runtime::cuda_context_for(runtime.device.ordinal)?;
    let stream = ctx.new_stream().ok()?;
    // Upload order follows the struct layout; any failed copy abandons residency.
    let target_x_dev = stream.clone_htod(target_x).ok()?;
    let basis_values_dev = stream.clone_htod(basis_values).ok()?;
    let gate_activations_dev = stream.clone_htod(gate_activations).ok()?;
    let row_hessian_dev = stream.clone_htod(&slabs.row_hessian_slabs).ok()?;
    let row_cross_dev = stream.clone_htod(&slabs.row_cross_slabs).ok()?;
    let row_gradient_dev = stream.clone_htod(&slabs.row_gradient_slabs).ok()?;
    let border_hessian_dev = stream.clone_htod(&slabs.border_hessian).ok()?;
    let border_gradient_dev = stream.clone_htod(&slabs.border_gradient).ok()?;
    let uploaded: [&[f64]; 8] = [
        target_x,
        basis_values,
        gate_activations,
        &slabs.row_hessian_slabs,
        &slabs.row_cross_slabs,
        &slabs.row_gradient_slabs,
        &slabs.border_hessian,
        &slabs.border_gradient,
    ];
    let element_count = uploaded.iter().map(|buffer| buffer.len()).sum::<usize>();
    let bytes = element_count * std::mem::size_of::<f64>();
    Some(DeviceResidentArrowBuffers {
        stream,
        target_x_dev,
        basis_values_dev,
        gate_activations_dev,
        row_hessian_dev,
        row_cross_dev,
        row_gradient_dev,
        border_hessian_dev,
        border_gradient_dev,
        bytes,
    })
}
/// Translates a GPU arrow-Schur failure into the workspace error type.
///
/// Only `Unavailable` maps to [`DeviceResidentArrowError::Unavailable`];
/// every other failure becomes a `Solve` error with a descriptive reason.
fn map_gpu_error(err: ArrowSchurGpuFailure) -> DeviceResidentArrowError {
    let solve_reason = match err {
        ArrowSchurGpuFailure::Unavailable => {
            return DeviceResidentArrowError::Unavailable {
                reason: "SAE resident inner iteration unavailable after GPU admission".to_string(),
            };
        }
        ArrowSchurGpuFailure::RidgeBumpRequired { row, bump } => {
            format!("SAE resident inner iteration row {row} requires ridge bump {bump:e}")
        }
        ArrowSchurGpuFailure::SchurFactorFailed { reason } => reason,
        ArrowSchurGpuFailure::GpuRequiresDenseSystem {
            had_hbb_matvec,
            had_htbeta_matvec,
        } => format!(
            "SAE resident inner iteration requires dense slabs; hbb_matvec={had_hbb_matvec} htbeta_matvec={had_htbeta_matvec}"
        ),
    };
    DeviceResidentArrowError::Solve {
        reason: solve_reason,
    }
}
/// Sum of squares of `values`, i.e. the squared Euclidean norm.
fn squared_norm(values: &[f64]) -> f64 {
    values.iter().copied().map(|v| v * v).sum::<f64>()
}
impl From<ArrowSchurError> for DeviceResidentArrowError {
    /// Solver-level arrow Schur errors become `Solve`, using the solver's
    /// `Display` text as the reason.
    fn from(err: ArrowSchurError) -> Self {
        let reason = err.to_string();
        Self::Solve { reason }
    }
}
/// Builds the qwen non-gating workspace with the canonical fixture seed.
///
/// # Errors
/// Propagates shape-validation errors from [`DeviceResidentArrowWorkspace::new`].
pub fn qwen_non_gating_fixture() -> Result<DeviceResidentArrowWorkspace, DeviceResidentArrowError> {
    const CANONICAL_SEED: u64 = 0x1017_0003_D3A1_5EED;
    qwen_non_gating_fixture_seeded(CANONICAL_SEED)
}
/// Builds a deterministic workspace at the qwen non-gating shape from `seed`.
///
/// Only the `target_x` noise draws from the RNG. Every other buffer is a
/// closed-form function of its indices, so equal seeds give identical buffers.
///
/// # Errors
/// Propagates shape-validation errors from [`DeviceResidentArrowWorkspace::new`].
pub fn qwen_non_gating_fixture_seeded(
    seed: u64,
) -> Result<DeviceResidentArrowWorkspace, DeviceResidentArrowError> {
    // The slab fill below writes 2×2 row Hessians and exactly two cross and
    // gradient rows per sample at hard-coded offsets. If the fixture shape ever
    // stops being d == 2, fail the build instead of silently building a
    // malformed system.
    const _: () = assert!(
        DeviceResidentArrowShape::qwen_non_gating().d == 2,
        "qwen_non_gating_fixture_seeded hard-codes d == 2 slab offsets"
    );
    let shape = DeviceResidentArrowShape::qwen_non_gating();
    let mut rng = SplitMix64::new(seed);
    // Smooth low-amplitude signal plus small uniform noise, indexed i * p + j.
    let mut target_x = vec![0.0_f64; shape.target_len()];
    for i in 0..shape.n {
        for j in 0..shape.p {
            let phase = ((i % 97) as f64) * 0.013 + ((j % 131) as f64) * 0.007;
            target_x[i * shape.p + j] = 0.02 * phase.sin() + 0.001 * rng.sample_signed();
        }
    }
    // Every gate is fully open in the non-gating fixture.
    let gate_activations = vec![1.0_f64; shape.basis_len()];
    let mut basis_values = vec![0.0_f64; shape.basis_len()];
    for i in 0..shape.n {
        for a in 0..shape.basis_cols {
            let phase = ((i + 1) as f64) * ((a + 1) as f64) * 0.003;
            basis_values[i * shape.basis_cols + a] = phase.cos();
        }
    }
    let mut row_hessian_slabs = vec![0.0_f64; shape.row_hessian_len()];
    let mut row_cross_slabs = vec![0.0_f64; shape.row_cross_len()];
    let mut row_gradient_slabs = vec![0.0_f64; shape.row_gradient_len()];
    for i in 0..shape.n {
        // The gate-weighted basis sum drives every per-row coefficient.
        let mut basis_sum = 0.0_f64;
        for a in 0..shape.basis_cols {
            basis_sum +=
                basis_values[i * shape.basis_cols + a] * gate_activations[i * shape.basis_cols + a];
        }
        // Symmetric 2×2 row Hessian with a strongly dominant diagonal.
        let h_base = i * shape.d * shape.d;
        row_hessian_slabs[h_base] = 3.0 + 0.01 * basis_sum.abs();
        row_hessian_slabs[h_base + 1] = 0.02 * basis_sum.sin();
        row_hessian_slabs[h_base + 2] = row_hessian_slabs[h_base + 1];
        row_hessian_slabs[h_base + 3] = 2.5 + 0.01 * basis_sum.abs();
        // Weak row/border coupling: two length-p rows per sample.
        let b_base = i * shape.d * shape.p;
        for j in 0..shape.p {
            let feature = ((j % 257) as f64) * 0.011;
            row_cross_slabs[b_base + j] = 1.0e-4 * basis_sum.sin() * feature.cos();
            row_cross_slabs[b_base + shape.p + j] = 1.0e-4 * basis_sum.cos() * feature.sin();
        }
        let g_base = i * shape.d;
        row_gradient_slabs[g_base] = 0.01 * basis_sum.sin();
        row_gradient_slabs[g_base + 1] = 0.01 * basis_sum.cos();
    }
    // Tridiagonal, diagonally dominant border Hessian.
    let mut border_hessian = vec![0.0_f64; shape.border_hessian_len()];
    for r in 0..shape.p {
        border_hessian[r * shape.p + r] = 4.0;
        if r + 1 < shape.p {
            border_hessian[r * shape.p + r + 1] = 0.01;
            border_hessian[(r + 1) * shape.p + r] = 0.01;
        }
    }
    let border_gradient: Vec<f64> = (0..shape.p)
        .map(|j| 0.001 * ((j % 193) as f64 * 0.017).sin())
        .collect();
    DeviceResidentArrowWorkspace::new(
        shape,
        target_x,
        basis_values,
        gate_activations,
        DeviceResidentArrowSlabs {
            row_hessian_slabs,
            row_cross_slabs,
            row_gradient_slabs,
            border_hessian,
            border_gradient,
        },
    )
}
/// One workspace's fit outcome, as returned by the multiplexed and sequential runners.
///
/// Derives `Clone`/`Debug` like the other public result types in this module,
/// so callers can log it or `assert_eq!`-debug it.
#[derive(Clone, Debug)]
pub struct MultiplexedFit {
    pub outcome: DeviceResidentInnerOutcome,
}
pub fn run_resident_fits_multiplexed(
workspaces: Vec<DeviceResidentArrowWorkspace>,
opts: DeviceResidentInnerOptions,
) -> Result<Vec<Result<MultiplexedFit, DeviceResidentArrowError>>, String> {
run_resident_fits_multiplexed_with(workspaces, opts, |workspace, opts| {
workspace.device_fit(opts)
})
}
/// Generic core of the multiplexed runner.
///
/// `run_one` is applied to each workspace inside `run_topology_race_parallel`,
/// and each executor row's `result` is collected. Tests inject a CPU runner here.
fn run_resident_fits_multiplexed_with<Run>(
    workspaces: Vec<DeviceResidentArrowWorkspace>,
    opts: DeviceResidentInnerOptions,
    run_one: Run,
) -> Result<Vec<Result<MultiplexedFit, DeviceResidentArrowError>>, String>
where
    Run: Fn(
            &DeviceResidentArrowWorkspace,
            &DeviceResidentInnerOptions,
        ) -> Result<DeviceResidentInnerOutcome, DeviceResidentArrowError>
        + Sync,
{
    let fit_one = move |workspace: DeviceResidentArrowWorkspace| -> Result<MultiplexedFit, DeviceResidentArrowError> {
        let outcome = run_one(&workspace, &opts)?;
        Ok(MultiplexedFit { outcome })
    };
    let rows = crate::solver::topology_selector::run_topology_race_parallel(workspaces, fit_one)?;
    let results = rows.into_iter().map(|row| row.result).collect();
    Ok(results)
}
/// Sequential counterpart of the multiplexed runner.
///
/// Fits each workspace in slice order on the calling thread via
/// [`DeviceResidentArrowWorkspace::device_fit`].
pub fn run_resident_fits_sequential(
    workspaces: &[DeviceResidentArrowWorkspace],
    opts: &DeviceResidentInnerOptions,
) -> Vec<Result<MultiplexedFit, DeviceResidentArrowError>> {
    let mut fits = Vec::with_capacity(workspaces.len());
    for workspace in workspaces {
        let fit = workspace
            .device_fit(opts)
            .map(|outcome| MultiplexedFit { outcome });
        fits.push(fit);
    }
    fits
}
/// Minimal SplitMix64 generator: deterministic and dependency-free, used for fixture noise.
struct SplitMix64 {
    state: u64,
}

impl SplitMix64 {
    /// Weyl-sequence increment (the 64-bit golden-ratio constant).
    const GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    /// First finaliser multiplier.
    const MIX_A: u64 = 0xBF58_476D_1CE4_E5B9;
    /// Second finaliser multiplier.
    const MIX_B: u64 = 0x94D0_49BB_1331_11EB;

    const fn new(seed: u64) -> Self {
        Self { state: seed }
    }

    /// Advances the state and returns the next mixed 64-bit output.
    fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_add(Self::GAMMA);
        let stage1 = (self.state ^ (self.state >> 30)).wrapping_mul(Self::MIX_A);
        let stage2 = (stage1 ^ (stage1 >> 27)).wrapping_mul(Self::MIX_B);
        stage2 ^ (stage2 >> 31)
    }

    /// Uniform sample in `[-1, 1)` built from the top 53 bits of the next output.
    fn sample_signed(&mut self) -> f64 {
        let top_bits = self.next_u64() >> 11;
        // 2^-53 is exact, so this multiply equals dividing by 2^53.
        let unit = top_bits as f64 * (1.0 / (1_u64 << 53) as f64);
        2.0 * unit - 1.0
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Tiny workspace (n = 3, p = 4, d = 2) built from a seeded RNG.
    ///
    /// Diagonals (5 / 4 on row blocks, 6 on the border) dominate the small
    /// off-diagonal and cross terms, so the assembled Hessian is positive definite.
    fn small_fixture(seed: u64) -> DeviceResidentArrowWorkspace {
        let shape = DeviceResidentArrowShape {
            n: 3,
            p: 4,
            basis_cols: 2,
            d: 2,
        };
        let mut rng = SplitMix64::new(seed);
        // Zero target, so the objective's ½‖x‖² term is 0.
        let target_x = vec![0.0_f64; shape.target_len()];
        let basis_values = vec![0.5_f64; shape.basis_len()];
        let gate_activations = vec![1.0_f64; shape.basis_len()];
        let mut row_hessian_slabs = vec![0.0_f64; shape.row_hessian_len()];
        let mut row_cross_slabs = vec![0.0_f64; shape.row_cross_len()];
        let mut row_gradient_slabs = vec![0.0_f64; shape.row_gradient_len()];
        // NOTE: the RNG draw order below defines the fixture; reordering these
        // statements changes every value.
        for i in 0..shape.n {
            // Symmetric 2x2 row block.
            let h = i * shape.d * shape.d;
            row_hessian_slabs[h] = 5.0 + 0.1 * rng.sample_signed();
            row_hessian_slabs[h + 1] = 0.05 * rng.sample_signed();
            row_hessian_slabs[h + 2] = row_hessian_slabs[h + 1];
            row_hessian_slabs[h + 3] = 4.0 + 0.1 * rng.sample_signed();
            // Weak coupling of both row unknowns to every border unknown.
            let b = i * shape.d * shape.p;
            for j in 0..shape.p {
                row_cross_slabs[b + j] = 0.01 * rng.sample_signed();
                row_cross_slabs[b + shape.p + j] = 0.01 * rng.sample_signed();
            }
            let g = i * shape.d;
            row_gradient_slabs[g] = rng.sample_signed();
            row_gradient_slabs[g + 1] = rng.sample_signed();
        }
        // Diagonal border Hessian.
        let mut border_hessian = vec![0.0_f64; shape.border_hessian_len()];
        for r in 0..shape.p {
            border_hessian[r * shape.p + r] = 6.0 + 0.1 * rng.sample_signed();
        }
        let border_gradient: Vec<f64> = (0..shape.p).map(|_| rng.sample_signed()).collect();
        DeviceResidentArrowWorkspace::new(
            shape,
            target_x,
            basis_values,
            gate_activations,
            DeviceResidentArrowSlabs {
                row_hessian_slabs,
                row_cross_slabs,
                row_gradient_slabs,
                border_hessian,
                border_gradient,
            },
        )
        .expect("small resident fixture must validate")
    }

    /// Assembles the dense `(n*d + p)²` Hessian `H` and right-hand side `g0`
    /// of `sys`, with row unknowns first and border unknowns last.
    fn dense_hz(
        ws: &DeviceResidentArrowWorkspace,
        sys: &ArrowSchurSystem,
    ) -> (Array2<f64>, Array1<f64>) {
        let shape = ws.shape;
        let total = shape.n * shape.d + shape.p;
        let mut h = Array2::<f64>::zeros((total, total));
        let mut g0 = Array1::<f64>::zeros(total);
        for i in 0..shape.n {
            let base = i * shape.d;
            for r in 0..shape.d {
                for c in 0..shape.d {
                    h[[base + r, base + c]] = sys.rows[i].htt[[r, c]];
                }
                // Cross blocks are mirrored into both off-diagonal quadrants.
                for c in 0..shape.p {
                    let v = sys.rows[i].htbeta[[r, c]];
                    h[[base + r, shape.n * shape.d + c]] = v;
                    h[[shape.n * shape.d + c, base + r]] = v;
                }
                g0[base + r] = sys.rows[i].gt[r];
            }
        }
        for r in 0..shape.p {
            for c in 0..shape.p {
                h[[shape.n * shape.d + r, shape.n * shape.d + c]] = sys.hbb[[r, c]];
            }
            g0[shape.n * shape.d + r] = sys.gb[r];
        }
        (h, g0)
    }

    /// The CPU inner loop must converge and its result must satisfy `H z = g0`.
    #[test]
    fn cpu_inner_loop_reaches_quadratic_minimiser() {
        let ws = small_fixture(0xABCD_0001);
        let opts = DeviceResidentInnerOptions::default();
        let outcome = ws.cpu_reference_fit(&opts).expect("cpu fit");
        assert!(
            outcome.converged,
            "inner loop must converge on a PD quadratic"
        );
        let base = ws.to_arrow_system();
        let (h, g0) = dense_hz(&ws, &base);
        let total = ws.shape.n * ws.shape.d + ws.shape.p;
        // Stack (t, β) into one dense vector matching dense_hz's ordering.
        let mut z = Array1::<f64>::zeros(total);
        for r in 0..ws.shape.n * ws.shape.d {
            z[r] = outcome.t[r];
        }
        for c in 0..ws.shape.p {
            z[ws.shape.n * ws.shape.d + c] = outcome.beta[c];
        }
        let hz = h.dot(&z);
        let mut max_resid = 0.0_f64;
        for r in 0..total {
            max_resid = max_resid.max((hz[r] - g0[r]).abs());
        }
        assert!(
            max_resid < 1e-9,
            "inner loop fixed point must solve H z = g0; residual {max_resid:e}"
        );
    }

    /// The multiplexed runner must reproduce the sequential CPU fits bit for bit.
    #[test]
    fn cpu_multiplex_matches_sequential_bit_identical() {
        let seeds = [0x11, 0x22, 0x33, 0x44, 0x55, 0x66];
        let opts = DeviceResidentInnerOptions::default();
        let seq_workspaces: Vec<_> = seeds.iter().map(|&s| small_fixture(s)).collect();
        let sequential: Vec<_> = seq_workspaces
            .iter()
            .map(|ws| ws.cpu_reference_fit(&opts).expect("seq cpu fit"))
            .collect();
        // Fresh workspaces: the multiplexed runner takes ownership of its inputs.
        let mux_workspaces: Vec<_> = seeds.iter().map(|&s| small_fixture(s)).collect();
        let multiplexed = run_resident_fits_multiplexed_with(mux_workspaces, opts, |ws, opts| {
            ws.cpu_reference_fit(opts)
        })
        .expect("multiplexed cpu fits");
        assert_eq!(sequential.len(), multiplexed.len());
        for (seq, mux) in sequential.iter().zip(multiplexed.iter()) {
            let mux = mux.as_ref().expect("mux fit ok");
            assert_eq!(seq.t.as_slice(), mux.outcome.t.as_slice());
            assert_eq!(seq.beta.as_slice(), mux.outcome.beta.as_slice());
            assert_eq!(seq.objective.to_bits(), mux.outcome.objective.to_bits());
        }
    }
}