use ndarray::Array1;
use crate::gpu::kernels::arrow_schur::{
ArrowSchurGpuFailure, solve_arrow_newton_step, solve_arrow_newton_step_dense_reference,
};
use crate::model_types::ExecutionPath;
/// Which linear-solve backend the LM inner loops use for each Newton step.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InnerSolveMode {
/// Reuse a `ResidentArrowFrameHandle` across iterations while the ridge pair is unchanged.
DeviceResident,
/// Call `solve_arrow_newton_step` afresh on every iteration.
DeviceReupload,
/// Call the dense reference solver (reported as `ExecutionPath::Cpu`).
CpuReference,
}
impl InnerSolveMode {
#[inline]
const fn execution_path(self) -> ExecutionPath {
match self {
Self::DeviceResident => ExecutionPath::GpuResidentFull,
Self::DeviceReupload => ExecutionPath::GpuReupload,
Self::CpuReference => ExecutionPath::Cpu,
}
}
}
use crate::solver::arrow_schur::{ArrowSchurError, ArrowSchurSystem};
/// Dimensions of a row-block ("arrow") system.
///
/// There are `n` row blocks of size `d`, all coupled to a shared border of
/// size `p`; `basis_cols` is the number of basis columns stored per row.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct DeviceResidentArrowShape {
    /// Number of row blocks.
    pub n: usize,
    /// Border (shared) dimension.
    pub p: usize,
    /// Basis columns per row.
    pub basis_cols: usize,
    /// Per-row block dimension.
    pub d: usize,
}
impl DeviceResidentArrowShape {
    /// Dimensions of the qwen non-gating fixture.
    #[inline]
    pub const fn qwen_non_gating() -> Self {
        Self { n: 2_000, p: 2_048, d: 2, basis_cols: 8 }
    }

    /// Dimensions of the color-arm fixture.
    #[inline]
    pub const fn color_arm() -> Self {
        Self { n: 180, p: 5_120, d: 2, basis_cols: 9 }
    }

    /// Length of the row-major `n x p` target buffer.
    #[inline]
    pub const fn target_len(self) -> usize {
        self.n * self.p
    }

    /// Length of the `n x basis_cols` basis (and gate-activation) buffers.
    #[inline]
    pub const fn basis_len(self) -> usize {
        self.n * self.basis_cols
    }

    /// Length of the stacked per-row `d x d` Hessian blocks (`n * d * d`).
    #[inline]
    pub const fn row_hessian_len(self) -> usize {
        self.row_gradient_len() * self.d
    }

    /// Length of the stacked per-row `d x p` cross blocks (`n * d * p`).
    #[inline]
    pub const fn row_cross_len(self) -> usize {
        self.row_gradient_len() * self.p
    }

    /// Length of the stacked per-row gradients (`n * d`).
    #[inline]
    pub const fn row_gradient_len(self) -> usize {
        self.n * self.d
    }

    /// Length of the dense `p x p` border Hessian.
    #[inline]
    pub const fn border_hessian_len(self) -> usize {
        self.p * self.p
    }
}
/// Host-side arrow-system slabs, each flattened row-major.
///
/// Lengths checked by `validate_shape`: `row_hessian_slabs` is `n*d*d`,
/// `row_cross_slabs` is `n*d*p`, `row_gradient_slabs` is `n*d`,
/// `border_hessian` is `p*p`, and `border_gradient` is `p`.
#[derive(Clone, Debug)]
pub struct DeviceResidentArrowSlabs {
/// Per-row `d x d` Hessian blocks; row `i` starts at offset `i*d*d`.
pub row_hessian_slabs: Vec<f64>,
/// Per-row `d x p` row/border cross blocks; row `i` starts at offset `i*d*p`.
pub row_cross_slabs: Vec<f64>,
/// Per-row gradients of length `d`; row `i` starts at offset `i*d`.
pub row_gradient_slabs: Vec<f64>,
/// Dense `p x p` border Hessian.
pub border_hessian: Vec<f64>,
/// Border gradient of length `p`.
pub border_gradient: Vec<f64>,
}
/// One Newton step on the arrow system plus bookkeeping about where it was computed.
#[derive(Clone, Debug)]
pub struct DeviceResidentArrowStep {
/// Row-block step (length `n*d` in the inner loops).
pub delta_t: Array1<f64>,
/// Border step (length `p`).
pub delta_beta: Array1<f64>,
/// Objective value associated with the step; see the producer for its exact meaning.
pub objective: f64,
/// Gradient norm at the point the step was linearized around.
pub gradient_norm: f64,
/// Log-determinant reported by the solver.
pub log_det_hessian: f64,
/// Backend that produced the step.
pub execution_path: ExecutionPath,
}
/// Failure raised by the device-resident arrow workspace and its inner loops.
#[derive(Debug, Clone)]
pub enum DeviceResidentArrowError {
    /// Buffer lengths or dimensions are inconsistent.
    Shape { reason: String },
    /// The GPU path was not admitted or is not available.
    Unavailable { reason: String },
    /// A linear solve or LM iteration failed.
    Solve { reason: String },
}
impl std::fmt::Display for DeviceResidentArrowError {
    /// Writes the carried reason verbatim, without a variant prefix.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Shape { reason } => reason.as_str(),
            Self::Unavailable { reason } => reason.as_str(),
            Self::Solve { reason } => reason.as_str(),
        };
        f.write_str(message)
    }
}
impl std::error::Error for DeviceResidentArrowError {}
/// Device copies of every workspace host buffer (Linux only), all uploaded on one stream.
#[cfg(target_os = "linux")]
pub struct DeviceResidentArrowBuffers {
/// Stream the buffers were uploaded on.
pub stream: std::sync::Arc<cudarc::driver::CudaStream>,
pub target_x_dev: cudarc::driver::CudaSlice<f64>,
pub basis_values_dev: cudarc::driver::CudaSlice<f64>,
pub gate_activations_dev: cudarc::driver::CudaSlice<f64>,
pub row_hessian_dev: cudarc::driver::CudaSlice<f64>,
pub row_cross_dev: cudarc::driver::CudaSlice<f64>,
pub row_gradient_dev: cudarc::driver::CudaSlice<f64>,
pub border_hessian_dev: cudarc::driver::CudaSlice<f64>,
pub border_gradient_dev: cudarc::driver::CudaSlice<f64>,
/// Total uploaded size, computed from the host lengths times `size_of::<f64>()`.
pub bytes: usize,
}
/// Validated host copies of an arrow problem, plus (on Linux) the device copies
/// uploaded at construction when the GPU router admitted the workload.
pub struct DeviceResidentArrowWorkspace {
shape: DeviceResidentArrowShape,
target_x: Vec<f64>,
basis_values: Vec<f64>,
gate_activations: Vec<f64>,
slabs: DeviceResidentArrowSlabs,
// `None` when routing, context creation, or any upload was declined.
#[cfg(target_os = "linux")]
device: Option<DeviceResidentArrowBuffers>,
}
impl DeviceResidentArrowWorkspace {
pub fn new(
    shape: DeviceResidentArrowShape,
    target_x: Vec<f64>,
    basis_values: Vec<f64>,
    gate_activations: Vec<f64>,
    slabs: DeviceResidentArrowSlabs,
) -> Result<Self, DeviceResidentArrowError> {
    // Builds a workspace after checking every buffer against `shape`.
    //
    // On Linux an upload of all buffers is attempted right away; if it is
    // declined the workspace is still returned, just without device buffers
    // (see `device_resident`).
    //
    // Errors: `Shape` when a buffer length disagrees with `shape` or a
    // dimension is zero.
    validate_shape(shape, &target_x, &basis_values, &gate_activations, &slabs)?;
    Ok(Self {
        // Struct-expression fields evaluate in the order written, so this
        // upload borrows the host vectors before they are moved in below.
        #[cfg(target_os = "linux")]
        device: upload_resident_buffers(
            shape,
            &target_x,
            &basis_values,
            &gate_activations,
            &slabs,
        ),
        shape,
        target_x,
        basis_values,
        gate_activations,
        slabs,
    })
}
/// Dimensions this workspace was validated against.
#[inline]
pub const fn shape(&self) -> DeviceResidentArrowShape {
    self.shape
}
/// Whether device buffers were uploaded at construction.
#[cfg(target_os = "linux")]
#[must_use]
pub fn device_resident(&self) -> bool {
    self.device.is_some()
}
/// Whether device buffers were uploaded at construction (never off Linux).
#[cfg(not(target_os = "linux"))]
#[must_use]
pub fn device_resident(&self) -> bool {
    false
}
/// Bytes held on the device, or 0 when nothing was uploaded.
#[cfg(target_os = "linux")]
#[must_use]
pub fn resident_device_bytes(&self) -> usize {
    match &self.device {
        Some(buffers) => buffers.bytes,
        None => 0,
    }
}
/// Bytes held on the device; always 0 off Linux.
#[cfg(not(target_os = "linux"))]
#[must_use]
pub fn resident_device_bytes(&self) -> usize {
    0
}
/// Telemetry context tag: 1 when device-resident, 0 otherwise.
#[must_use]
fn context_id(&self) -> usize {
    self.device_resident().into()
}
/// Byte size of the five arrow slabs, recorded as the host-to-device volume
/// whenever a solve frame is (re)built.
#[must_use]
fn frame_upload_bytes(&self) -> usize {
    let slabs = &self.slabs;
    let values = slabs.row_hessian_slabs.len()
        + slabs.row_cross_slabs.len()
        + slabs.row_gradient_slabs.len()
        + slabs.border_hessian.len()
        + slabs.border_gradient.len();
    values * std::mem::size_of::<f64>()
}
/// Byte size of every host buffer: the three input arrays plus all slabs.
#[must_use]
pub fn host_shadow_bytes(&self) -> usize {
    let input_values =
        self.target_x.len() + self.basis_values.len() + self.gate_activations.len();
    input_values * std::mem::size_of::<f64>() + self.frame_upload_bytes()
}
/// Solves one Newton step on the stored slabs with the GPU arrow solver.
///
/// # Errors
/// `Unavailable` when no device buffers exist; GPU failures are translated
/// through `map_gpu_error`.
pub fn one_inner_iteration(
    &self,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<DeviceResidentArrowStep, DeviceResidentArrowError> {
    if self.device_resident() {
        let system = self.to_arrow_system();
        let solution =
            solve_arrow_newton_step(&system, ridge_t, ridge_beta).map_err(map_gpu_error)?;
        Ok(self.finish_step(solution, ExecutionPath::GpuResidentLinearization))
    } else {
        Err(DeviceResidentArrowError::Unavailable {
            reason: "SAE resident inner iteration unavailable: CUDA runtime did not admit the qwen-scale row-block workload".to_string(),
        })
    }
}
/// Solves the same step with the dense reference solver.
///
/// # Errors
/// `Solve` carrying the reference solver's message.
pub fn cpu_reference_step(
    &self,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<DeviceResidentArrowStep, DeviceResidentArrowError> {
    let system = self.to_arrow_system();
    match solve_arrow_newton_step_dense_reference(&system, ridge_t, ridge_beta) {
        Ok(solution) => Ok(self.finish_step(solution, ExecutionPath::Cpu)),
        Err(reason) => Err(DeviceResidentArrowError::Solve { reason }),
    }
}
pub fn to_arrow_system(&self) -> ArrowSchurSystem {
let shape = self.shape;
let mut sys = ArrowSchurSystem::new(shape.n, shape.d, shape.p);
for i in 0..shape.n {
let h_base = i * shape.d * shape.d;
let b_base = i * shape.d * shape.p;
let g_base = i * shape.d;
for r in 0..shape.d {
for c in 0..shape.d {
sys.rows[i].htt[[r, c]] =
self.slabs.row_hessian_slabs[h_base + r * shape.d + c];
}
sys.rows[i].gt[r] = self.slabs.row_gradient_slabs[g_base + r];
for c in 0..shape.p {
sys.rows[i].htbeta[[r, c]] =
self.slabs.row_cross_slabs[b_base + r * shape.p + c];
}
}
}
for r in 0..shape.p {
sys.gb[r] = self.slabs.border_gradient[r];
for c in 0..shape.p {
sys.hbb[[r, c]] = self.slabs.border_hessian[r * shape.p + c];
}
}
sys.refresh_row_hessian_fingerprint();
sys
}
fn finish_step(
&self,
solution: crate::gpu::kernels::arrow_schur::ArrowSchurGpuSolution,
execution_path: ExecutionPath,
) -> DeviceResidentArrowStep {
DeviceResidentArrowStep {
delta_t: solution.delta_t,
delta_beta: solution.delta_beta,
objective: 0.5 * squared_norm(&self.target_x),
gradient_norm: self.gradient_norm(),
log_det_hessian: solution.log_det_hessian,
execution_path,
}
}
fn gradient_norm(&self) -> f64 {
let row = squared_norm(&self.slabs.row_gradient_slabs);
let border = squared_norm(&self.slabs.border_gradient);
(row + border).sqrt()
}
/// Runs the LM inner loop, reusing a resident solve frame while the ridges are unchanged.
///
/// # Errors
/// `Unavailable` without device buffers; otherwise whatever the loop reports.
pub fn device_fit(
    &self,
    opts: &DeviceResidentInnerOptions,
) -> Result<DeviceResidentInnerOutcome, DeviceResidentArrowError> {
    if self.device_resident() {
        self.run_inner_loop(opts, InnerSolveMode::DeviceResident)
    } else {
        Err(DeviceResidentArrowError::Unavailable {
            reason: "SAE resident inner loop unavailable: CUDA runtime did not admit the qwen-scale row-block workload".to_string(),
        })
    }
}
/// Runs the LM inner loop, calling the GPU solver afresh every iteration.
///
/// # Errors
/// `Unavailable` without device buffers; otherwise whatever the loop reports.
pub fn device_reupload_fit(
    &self,
    opts: &DeviceResidentInnerOptions,
) -> Result<DeviceResidentInnerOutcome, DeviceResidentArrowError> {
    if self.device_resident() {
        self.run_inner_loop(opts, InnerSolveMode::DeviceReupload)
    } else {
        Err(DeviceResidentArrowError::Unavailable {
            reason: "SAE re-uploading inner loop unavailable: CUDA runtime did not admit the row-block workload".to_string(),
        })
    }
}
/// Runs the LM inner loop with the dense reference solver.
pub fn cpu_reference_fit(
    &self,
    opts: &DeviceResidentInnerOptions,
) -> Result<DeviceResidentInnerOutcome, DeviceResidentArrowError> {
    let mode = InnerSolveMode::CpuReference;
    self.run_inner_loop(opts, mode)
}
/// Levenberg–Marquardt-style loop on the quadratic model built from the stored
/// slabs, starting from `t = 0`, `beta = 0`.
///
/// Each iteration does the following:
/// - forms the residual gradient `H z - g` via `residual_system`;
/// - stops when `||grad|| / (1 + ||(t, beta)||) < opts.convergence_tolerance`;
/// - solves a ridged Newton step with the backend selected by `mode`;
/// - accepts the step when the gain ratio `rho` is positive and the trial
///   objective is finite.
///
/// Accepted steps scale both ridges by `opts.lm_shrink`. Rejected steps and
/// `Solve`/`Unavailable` solver failures grow them via `grow_ridge`.
///
/// NOTE(review): `converged` is only set by the check at the top of the loop.
/// If `opts.max_iterations` runs out right after an accepted step that meets
/// the tolerance, the outcome still reports `converged == false`. Likewise,
/// the outcome's `gradient_norm` is the norm at the linearization point of
/// the last accepted step (0.0 if none), not at the returned iterate.
///
/// # Errors
/// `Solve` when either ridge exceeds `opts.max_ridge` or the predicted
/// reduction cannot be computed; any other solver error variant is returned as-is.
fn run_inner_loop(
&self,
opts: &DeviceResidentInnerOptions,
mode: InnerSolveMode,
) -> Result<DeviceResidentInnerOutcome, DeviceResidentArrowError> {
let execution_path = mode.execution_path();
let n = self.shape.n;
let d = self.shape.d;
let p = self.shape.p;
let t_len = n * d;
// Iterate starts at the origin: `t` is the flattened n*d row vector, `beta` the border.
let mut t = vec![0.0_f64; t_len];
let mut beta = vec![0.0_f64; p];
let base = self.to_arrow_system();
let half_target_energy = 0.5 * squared_norm(&self.target_x);
// `max(0.0)` clamps negative (and NaN) initial ridges to zero.
let mut ridge_t = opts.initial_ridge_t.max(0.0);
let mut ridge_beta = opts.initial_ridge_beta.max(0.0);
// Resident frame cached together with the exact ridge pair it was built for.
let mut resident_frame: Option<(
f64,
f64,
crate::gpu::kernels::arrow_schur::ResidentArrowFrameHandle,
)> = None;
let mut current_objective = self.objective_at(&base, half_target_energy, &t, &beta);
let mut accepted_iters = 0_usize;
let mut total_iters = 0_usize;
let mut converged = false;
let mut last_step = DeviceResidentArrowStep {
delta_t: Array1::zeros(t_len),
delta_beta: Array1::zeros(p),
objective: current_objective,
gradient_norm: 0.0,
log_det_hessian: 0.0,
execution_path,
};
// `total_iters` counts accepted steps, rejected steps and failed solves alike.
while total_iters < opts.max_iterations {
// Model gradient at the current iterate, packaged as an arrow system.
let residual = self.residual_system(&base, &t, &beta);
let g_norm = arrow_system_gradient_norm(&residual);
// Relative criterion: gradient norm scaled by 1 + ||(t, beta)||.
let scale = 1.0 + iterate_norm(&t, &beta);
if g_norm / scale < opts.convergence_tolerance {
converged = true;
break;
}
let solution = match mode {
InnerSolveMode::DeviceResident => {
// `residual_system` only rewrites gradients, so a frame built for the same
// ridge pair is reused; rebuild only when either ridge changed.
let frame_matches = resident_frame
.as_ref()
.is_some_and(|(rt, rb, _)| *rt == ridge_t && *rb == ridge_beta);
let mut frame_build_error: Option<DeviceResidentArrowError> = None;
if !frame_matches {
// Clear the stale frame first so a failed rebuild leaves none cached.
resident_frame = None;
match crate::gpu::kernels::arrow_schur::ResidentArrowFrameHandle::new(
&residual, ridge_t, ridge_beta,
) {
Ok(frame) => {
crate::gpu::profile::telemetry_record_handle_creation(
self.context_id(),
);
crate::gpu::profile::telemetry_record_factorization();
crate::gpu::profile::telemetry_record_h2d(
self.frame_upload_bytes(),
);
resident_frame = Some((ridge_t, ridge_beta, frame));
}
Err(err) => frame_build_error = Some(map_gpu_error(err)),
}
}
match resident_frame.as_ref() {
Some((_, _, frame)) => {
// Flatten the per-row gradients, row by row, for the resident solve.
let mut g_t = Vec::with_capacity(n * d);
for row in &residual.rows {
for &v in row.gt.iter() {
g_t.push(v);
}
}
let g_beta: Vec<f64> = residual.gb.iter().copied().collect();
// Record the gradient upload, one kernel launch and the n*d + p step download.
let grad_bytes =
(g_t.len() + g_beta.len()) * std::mem::size_of::<f64>();
crate::gpu::profile::telemetry_record_h2d(grad_bytes);
crate::gpu::profile::telemetry_record_kernel_launch();
crate::gpu::profile::telemetry_record_d2h(
(n * d + p) * std::mem::size_of::<f64>(),
);
frame.solve_gradient(&g_t, &g_beta).map_err(map_gpu_error)
}
None => Err(frame_build_error.unwrap_or_else(|| {
DeviceResidentArrowError::Solve {
reason: "SAE resident frame build declined".to_string(),
}
})),
}
}
InnerSolveMode::DeviceReupload => {
// Every iteration is recorded as a full frame upload + factorization.
crate::gpu::profile::telemetry_record_handle_creation(self.context_id());
crate::gpu::profile::telemetry_record_factorization();
crate::gpu::profile::telemetry_record_h2d(self.frame_upload_bytes());
crate::gpu::profile::telemetry_record_kernel_launch();
crate::gpu::profile::telemetry_record_d2h(
(n * d + p) * std::mem::size_of::<f64>(),
);
solve_arrow_newton_step(&residual, ridge_t, ridge_beta).map_err(map_gpu_error)
}
InnerSolveMode::CpuReference => {
solve_arrow_newton_step_dense_reference(&residual, ridge_t, ridge_beta)
.map_err(|reason| DeviceResidentArrowError::Solve { reason })
}
};
let solution = match solution {
Ok(sol) => sol,
// A failed solve is handled like a rejected step: stiffen both ridges and retry.
Err(DeviceResidentArrowError::Solve { .. })
| Err(DeviceResidentArrowError::Unavailable { .. }) => {
ridge_t = grow_ridge(ridge_t, opts.lm_grow);
ridge_beta = grow_ridge(ridge_beta, opts.lm_grow);
if ridge_t > opts.max_ridge || ridge_beta > opts.max_ridge {
return Err(DeviceResidentArrowError::Solve {
reason: format!(
"SAE resident inner loop: LM ridge exceeded max ({:e}) at iter {total_iters}",
opts.max_ridge
),
});
}
total_iters += 1;
continue;
}
Err(other) => return Err(other),
};
// Model-predicted decrease for this step, given both ridges.
let predicted_reduction =
crate::solver::arrow_schur::arrow_bare_quadratic_model_reduction(
&residual,
solution.delta_t.view(),
solution.delta_beta.view(),
ridge_t,
ridge_beta,
)
.map_err(|err| DeviceResidentArrowError::Solve {
reason: format!("SAE resident inner loop predicted-reduction failed: {err}"),
})?;
let mut trial_t = t.clone();
let mut trial_beta = beta.clone();
for (slot, dv) in trial_t.iter_mut().zip(solution.delta_t.iter()) {
*slot += *dv;
}
for (slot, dv) in trial_beta.iter_mut().zip(solution.delta_beta.iter()) {
*slot += *dv;
}
let trial_objective =
self.objective_at(&base, half_target_energy, &trial_t, &trial_beta);
// Changes below this relative floor are treated as rounding noise.
let objective_scale = current_objective.abs();
let noise_floor = objective_scale * 1e-14;
let actual_reduction = current_objective - trial_objective;
// Gain ratio actual/predicted. With a negligible prediction, a step that does
// not worsen the objective beyond the floor counts as rho = 1, otherwise -1.
let rho = if predicted_reduction > noise_floor {
actual_reduction / predicted_reduction
} else if actual_reduction >= -noise_floor {
1.0
} else {
-1.0
};
if rho > 0.0 && trial_objective.is_finite() {
// Accept: move to the trial point and relax both ridges.
t = trial_t;
beta = trial_beta;
current_objective = trial_objective;
ridge_t = (ridge_t * opts.lm_shrink).max(0.0);
ridge_beta = (ridge_beta * opts.lm_shrink).max(0.0);
last_step = DeviceResidentArrowStep {
delta_t: solution.delta_t,
delta_beta: solution.delta_beta,
objective: current_objective,
gradient_norm: g_norm,
log_det_hessian: solution.log_det_hessian,
execution_path,
};
accepted_iters += 1;
total_iters += 1;
} else {
// Reject: keep the iterate and stiffen both ridges.
ridge_t = grow_ridge(ridge_t, opts.lm_grow);
ridge_beta = grow_ridge(ridge_beta, opts.lm_grow);
if ridge_t > opts.max_ridge || ridge_beta > opts.max_ridge {
return Err(DeviceResidentArrowError::Solve {
reason: format!(
"SAE resident inner loop: LM rejected step until ridge exceeded max ({:e}) at iter {total_iters} (rho={rho:.3e})",
opts.max_ridge
),
});
}
total_iters += 1;
}
}
Ok(DeviceResidentInnerOutcome {
t: Array1::from_vec(t),
beta: Array1::from_vec(beta),
objective: current_objective,
gradient_norm: last_step.gradient_norm,
log_det_hessian: last_step.log_det_hessian,
iterations: total_iters,
accepted_iterations: accepted_iters,
converged,
execution_path,
})
}
/// Runs one inner solve per gradient override, sharing a resident frame across them.
///
/// # Errors
/// `Unavailable` without device buffers; otherwise whatever the sequence reports.
pub fn device_fit_outer_sequence(
    &self,
    base_gradient_overrides: &[(Vec<f64>, Vec<f64>)],
    opts: &DeviceResidentInnerOptions,
) -> Result<OuterSequenceOutcome, DeviceResidentArrowError> {
    if self.device_resident() {
        self.run_outer_sequence(base_gradient_overrides, opts, InnerSolveMode::DeviceResident)
    } else {
        Err(DeviceResidentArrowError::Unavailable {
            reason: "SAE outer-sequence residency unavailable: CUDA runtime did not admit the row-block workload".to_string(),
        })
    }
}
/// Dense-reference counterpart of `device_fit_outer_sequence`.
pub fn cpu_reference_outer_sequence(
    &self,
    base_gradient_overrides: &[(Vec<f64>, Vec<f64>)],
    opts: &DeviceResidentInnerOptions,
) -> Result<OuterSequenceOutcome, DeviceResidentArrowError> {
    let mode = InnerSolveMode::CpuReference;
    self.run_outer_sequence(base_gradient_overrides, opts, mode)
}
fn run_outer_sequence(
    &self,
    base_gradient_overrides: &[(Vec<f64>, Vec<f64>)],
    opts: &DeviceResidentInnerOptions,
    mode: InnerSolveMode,
) -> Result<OuterSequenceOutcome, DeviceResidentArrowError> {
    // Runs one LM inner solve per `(g_t, g_beta)` override, in order.
    //
    // Each override replaces the row and border gradients of the stored
    // system. The Hessian blocks are the same for every outer, which is why a
    // single frame in `SharedFrameState` can be carried across them.
    //
    // Errors: `Shape` if any override has the wrong length (every pair is
    // checked before the first solve runs), otherwise the first error
    // returned by an inner solve.
    let d = self.shape.d;
    let p = self.shape.p;
    let t_len = self.shape.n * d;
    // Fail fast: a malformed override late in the list must not cost the
    // device work of every outer that precedes it.
    if let Some((g_t_override, g_beta_override)) = base_gradient_overrides
        .iter()
        .find(|(g_t, g_beta)| g_t.len() != t_len || g_beta.len() != p)
    {
        return Err(DeviceResidentArrowError::Shape {
            reason: format!(
                "outer-sequence gradient shape mismatch: g_t={} (want {t_len}), g_beta={} (want {p})",
                g_t_override.len(),
                g_beta_override.len()
            ),
        });
    }
    let half_target_energy = 0.5 * squared_norm(&self.target_x);
    let mut shared = SharedFrameState::default();
    let mut outcomes = Vec::with_capacity(base_gradient_overrides.len());
    // Build the system once. Each outer overwrites every gradient entry, and
    // `run_one_outer` only borrows it, so there is no need to re-expand the
    // slabs (an O(n*d*p + p^2) copy) per outer.
    let mut base = self.to_arrow_system();
    for (g_t_override, g_beta_override) in base_gradient_overrides {
        // `d > 0` is guaranteed by `validate_shape`, so `chunks_exact` is safe.
        for (row, g_row) in base.rows.iter_mut().zip(g_t_override.chunks_exact(d)) {
            for (slot, &value) in row.gt.iter_mut().zip(g_row) {
                *slot = value;
            }
        }
        for (slot, &value) in base.gb.iter_mut().zip(g_beta_override) {
            *slot = value;
        }
        base.refresh_row_hessian_fingerprint();
        let outcome = self.run_one_outer(&base, half_target_energy, opts, mode, &mut shared)?;
        outcomes.push(outcome);
    }
    Ok(OuterSequenceOutcome {
        outers: outcomes,
        frame_builds: shared.frame_builds,
    })
}
/// One LM inner solve against a caller-prepared `base` system, starting from
/// `t = 0`, `beta = 0`.
///
/// The iteration matches the resident/reference inner loop:
/// - residual gradient via `residual_system`;
/// - relative convergence test `||grad|| / (1 + ||(t, beta)||)`;
/// - gain-ratio acceptance;
/// - ridge shrink on accept and `grow_ridge` on reject or solver failure.
///
/// In `DeviceResident` mode the frame lives in `shared`, so it survives across
/// outers, and `shared.frame_builds` counts every construction. The cache key
/// is the ridge pair alone. NOTE(review): that presumes the frame captures
/// only the Hessian blocks; fresh gradients are passed to `solve_gradient`
/// on every call.
///
/// Unlike `run_inner_loop`, the `DeviceReupload` branch here records no
/// telemetry.
///
/// # Errors
/// `Solve` when either ridge exceeds `opts.max_ridge` or the predicted
/// reduction cannot be computed; any other solver error variant is returned as-is.
fn run_one_outer(
&self,
base: &ArrowSchurSystem,
half_target_energy: f64,
opts: &DeviceResidentInnerOptions,
mode: InnerSolveMode,
shared: &mut SharedFrameState,
) -> Result<DeviceResidentInnerOutcome, DeviceResidentArrowError> {
let execution_path = mode.execution_path();
let n = self.shape.n;
let d = self.shape.d;
let p = self.shape.p;
let t_len = n * d;
let mut t = vec![0.0_f64; t_len];
let mut beta = vec![0.0_f64; p];
// Ridges restart from the options for every outer; `max(0.0)` clamps negatives/NaN.
let mut ridge_t = opts.initial_ridge_t.max(0.0);
let mut ridge_beta = opts.initial_ridge_beta.max(0.0);
let mut current_objective = self.objective_at(base, half_target_energy, &t, &beta);
let mut accepted_iters = 0_usize;
let mut total_iters = 0_usize;
let mut converged = false;
// Values from the last accepted step; stay 0.0 if no step is accepted.
let mut last_gradient_norm = 0.0_f64;
let mut last_log_det = 0.0_f64;
while total_iters < opts.max_iterations {
let residual = self.residual_system(base, &t, &beta);
let g_norm = arrow_system_gradient_norm(&residual);
let scale = 1.0 + iterate_norm(&t, &beta);
if g_norm / scale < opts.convergence_tolerance {
converged = true;
break;
}
let solution = match mode {
InnerSolveMode::DeviceResident => {
let frame_matches = shared
.frame
.as_ref()
.is_some_and(|(rt, rb, _)| *rt == ridge_t && *rb == ridge_beta);
let mut frame_build_error: Option<DeviceResidentArrowError> = None;
if !frame_matches {
// Clear the stale frame first so a failed rebuild leaves none cached.
shared.frame = None;
match crate::gpu::kernels::arrow_schur::ResidentArrowFrameHandle::new(
&residual, ridge_t, ridge_beta,
) {
Ok(frame) => {
shared.frame_builds += 1;
crate::gpu::profile::telemetry_record_handle_creation(
self.context_id(),
);
crate::gpu::profile::telemetry_record_factorization();
crate::gpu::profile::telemetry_record_h2d(
self.frame_upload_bytes(),
);
shared.frame = Some((ridge_t, ridge_beta, frame));
}
Err(err) => frame_build_error = Some(map_gpu_error(err)),
}
}
match shared.frame.as_ref() {
Some((_, _, frame)) => {
// Flatten the per-row gradients, row by row, for the resident solve.
let mut g_t = Vec::with_capacity(n * d);
for row in &residual.rows {
for &v in row.gt.iter() {
g_t.push(v);
}
}
let g_beta: Vec<f64> = residual.gb.iter().copied().collect();
// Record the gradient upload, one kernel launch and the n*d + p step download.
let grad_bytes =
(g_t.len() + g_beta.len()) * std::mem::size_of::<f64>();
crate::gpu::profile::telemetry_record_h2d(grad_bytes);
crate::gpu::profile::telemetry_record_kernel_launch();
crate::gpu::profile::telemetry_record_d2h(
(n * d + p) * std::mem::size_of::<f64>(),
);
frame.solve_gradient(&g_t, &g_beta).map_err(map_gpu_error)
}
None => Err(frame_build_error.unwrap_or_else(|| {
DeviceResidentArrowError::Solve {
reason: "SAE resident frame build declined".to_string(),
}
})),
}
}
InnerSolveMode::DeviceReupload => {
solve_arrow_newton_step(&residual, ridge_t, ridge_beta).map_err(map_gpu_error)
}
InnerSolveMode::CpuReference => {
solve_arrow_newton_step_dense_reference(&residual, ridge_t, ridge_beta)
.map_err(|reason| DeviceResidentArrowError::Solve { reason })
}
};
let solution = match solution {
Ok(sol) => sol,
// A failed solve is handled like a rejected step: stiffen both ridges and retry.
Err(DeviceResidentArrowError::Solve { .. })
| Err(DeviceResidentArrowError::Unavailable { .. }) => {
ridge_t = grow_ridge(ridge_t, opts.lm_grow);
ridge_beta = grow_ridge(ridge_beta, opts.lm_grow);
if ridge_t > opts.max_ridge || ridge_beta > opts.max_ridge {
return Err(DeviceResidentArrowError::Solve {
reason: format!(
"SAE outer-sequence inner loop: LM ridge exceeded max ({:e}) at iter {total_iters}",
opts.max_ridge
),
});
}
total_iters += 1;
continue;
}
Err(other) => return Err(other),
};
// Model-predicted decrease for this step, given both ridges.
let predicted_reduction =
crate::solver::arrow_schur::arrow_bare_quadratic_model_reduction(
&residual,
solution.delta_t.view(),
solution.delta_beta.view(),
ridge_t,
ridge_beta,
)
.map_err(|err| DeviceResidentArrowError::Solve {
reason: format!("SAE outer-sequence predicted-reduction failed: {err}"),
})?;
let mut trial_t = t.clone();
let mut trial_beta = beta.clone();
for (slot, dv) in trial_t.iter_mut().zip(solution.delta_t.iter()) {
*slot += *dv;
}
for (slot, dv) in trial_beta.iter_mut().zip(solution.delta_beta.iter()) {
*slot += *dv;
}
let trial_objective =
self.objective_at(base, half_target_energy, &trial_t, &trial_beta);
// Changes below this relative floor are treated as rounding noise.
let objective_scale = current_objective.abs();
let noise_floor = objective_scale * 1e-14;
let actual_reduction = current_objective - trial_objective;
// Gain ratio actual/predicted. With a negligible prediction, a step that does
// not worsen the objective beyond the floor counts as rho = 1, otherwise -1.
let rho = if predicted_reduction > noise_floor {
actual_reduction / predicted_reduction
} else if actual_reduction >= -noise_floor {
1.0
} else {
-1.0
};
if rho > 0.0 && trial_objective.is_finite() {
// Accept: move to the trial point and relax both ridges.
t = trial_t;
beta = trial_beta;
current_objective = trial_objective;
ridge_t = (ridge_t * opts.lm_shrink).max(0.0);
ridge_beta = (ridge_beta * opts.lm_shrink).max(0.0);
last_gradient_norm = g_norm;
last_log_det = solution.log_det_hessian;
accepted_iters += 1;
total_iters += 1;
} else {
// Reject: keep the iterate and stiffen both ridges.
ridge_t = grow_ridge(ridge_t, opts.lm_grow);
ridge_beta = grow_ridge(ridge_beta, opts.lm_grow);
if ridge_t > opts.max_ridge || ridge_beta > opts.max_ridge {
return Err(DeviceResidentArrowError::Solve {
reason: format!(
"SAE outer-sequence inner loop: LM rejected step until ridge exceeded max ({:e}) at iter {total_iters} (rho={rho:.3e})",
opts.max_ridge
),
});
}
total_iters += 1;
}
}
Ok(DeviceResidentInnerOutcome {
t: Array1::from_vec(t),
beta: Array1::from_vec(beta),
objective: current_objective,
gradient_norm: last_gradient_norm,
log_det_hessian: last_log_det,
iterations: total_iters,
accepted_iterations: accepted_iters,
converged,
execution_path,
})
}
fn objective_at(
&self,
base: &ArrowSchurSystem,
half_target_energy: f64,
t: &[f64],
beta: &[f64],
) -> f64 {
let n = self.shape.n;
let d = self.shape.d;
let p = self.shape.p;
let mut quad = 0.0_f64;
let mut lin = 0.0_f64;
for i in 0..n {
let t_base = i * d;
for r in 0..d {
let mut htt_t = 0.0_f64;
for c in 0..d {
htt_t += base.rows[i].htt[[r, c]] * t[t_base + c];
}
let mut htb_b = 0.0_f64;
for c in 0..p {
htb_b += base.rows[i].htbeta[[r, c]] * beta[c];
}
quad += t[t_base + r] * (htt_t + 2.0 * htb_b);
lin += base.rows[i].gt[r] * t[t_base + r];
}
}
for r in 0..p {
let mut hbb_b = 0.0_f64;
for c in 0..p {
hbb_b += base.hbb[[r, c]] * beta[c];
}
quad += beta[r] * hbb_b;
lin += base.gb[r] * beta[r];
}
half_target_energy + 0.5 * quad - lin
}
/// Builds the arrow system whose gradients are the model residual `H z - g`
/// at `z = (t, beta)`.
///
/// `H` and `g` are read from `base`. The returned system's Hessian blocks are
/// rebuilt from the stored slabs via `to_arrow_system`; only its row
/// gradients (`gt`) and border gradient (`gb`) are overwritten.
fn residual_system(
    &self,
    base: &ArrowSchurSystem,
    t: &[f64],
    beta: &[f64],
) -> ArrowSchurSystem {
    let n = self.shape.n;
    let d = self.shape.d;
    let p = self.shape.p;
    let mut sys = self.to_arrow_system();
    // Row blocks: Htt_i t_i + Htβ_i beta - g_t,i.
    for i in 0..n {
        let row = &base.rows[i];
        let t_base = i * d;
        for r in 0..d {
            let mut hz = 0.0_f64;
            for c in 0..d {
                hz += row.htt[[r, c]] * t[t_base + c];
            }
            for c in 0..p {
                hz += row.htbeta[[r, c]] * beta[c];
            }
            sys.rows[i].gt[r] = hz - row.gt[r];
        }
    }
    // Border block: Hββ beta + Σ_i Htβ_iᵀ t_i - g_beta.
    //
    // The cross term is accumulated row block by row block into one
    // accumulator per border coordinate. Each `Htβ_i` is then swept along its
    // rows (contiguous under the default row-major layout) instead of being
    // read column-wise across all n blocks. Every accumulator still receives
    // its terms in the original order (Hββ terms by column, then cross terms
    // by block i and in-block row), so the result is bit-for-bit unchanged.
    let mut border = vec![0.0_f64; p];
    for (r, acc) in border.iter_mut().enumerate() {
        for c in 0..p {
            *acc += base.hbb[[r, c]] * beta[c];
        }
    }
    for i in 0..n {
        let row = &base.rows[i];
        let t_base = i * d;
        for rr in 0..d {
            let t_val = t[t_base + rr];
            for (r, acc) in border.iter_mut().enumerate() {
                *acc += row.htbeta[[rr, r]] * t_val;
            }
        }
    }
    for (r, acc) in border.into_iter().enumerate() {
        sys.gb[r] = acc - base.gb[r];
    }
    sys.refresh_row_hessian_fingerprint();
    sys
}
}
/// Tuning knobs for the LM inner loops.
#[derive(Clone, Copy, Debug)]
pub struct DeviceResidentInnerOptions {
    /// Iteration budget; accepted steps, rejected steps and failed solves all count.
    pub max_iterations: usize,
    /// Stop once `||grad|| / (1 + ||(t, beta)||)` drops below this value.
    pub convergence_tolerance: f64,
    /// Starting ridge on the row blocks (negative values are clamped to 0).
    pub initial_ridge_t: f64,
    /// Starting ridge on the border block (negative values are clamped to 0).
    pub initial_ridge_beta: f64,
    /// Multiplier for a nonzero ridge after a rejected step or failed solve
    /// (a zero ridge jumps to `1e-6` instead).
    pub lm_grow: f64,
    /// Multiplier for both ridges after an accepted step.
    pub lm_shrink: f64,
    /// The loop fails with `Solve` once either ridge exceeds this value.
    pub max_ridge: f64,
}
impl Default for DeviceResidentInnerOptions {
    /// 16 iterations, tolerance `1e-9`, no initial ridge, grow x4 / shrink x0.5,
    /// ridge cap `1e9`.
    fn default() -> Self {
        let (initial_ridge_t, initial_ridge_beta) = (0.0, 0.0);
        Self {
            max_iterations: 16,
            convergence_tolerance: 1e-9,
            max_ridge: 1e9,
            lm_grow: 4.0,
            lm_shrink: 0.5,
            initial_ridge_t,
            initial_ridge_beta,
        }
    }
}
/// Result of one LM inner solve.
#[derive(Clone, Debug)]
pub struct DeviceResidentInnerOutcome {
/// Final row iterate, flattened as `n*d`.
pub t: Array1<f64>,
/// Final border iterate (length `p`).
pub beta: Array1<f64>,
/// Model objective at the returned (`t`, `beta`).
pub objective: f64,
/// Gradient norm at the linearization point of the last accepted step (0.0 if none).
pub gradient_norm: f64,
/// Log-determinant the solver reported for the last accepted step (0.0 if none).
pub log_det_hessian: f64,
/// Iterations consumed, including rejected steps and failed solves.
pub iterations: usize,
/// Iterations whose step was accepted.
pub accepted_iterations: usize,
/// Whether the relative gradient test passed at the top of an iteration.
pub converged: bool,
/// Backend tag derived from the solve mode.
pub execution_path: ExecutionPath,
}
/// Results of an outer sequence, one inner outcome per gradient override.
#[derive(Clone, Debug)]
pub struct OuterSequenceOutcome {
/// Inner outcomes in override order.
pub outers: Vec<DeviceResidentInnerOutcome>,
/// Resident frames constructed across the whole sequence (0 outside `DeviceResident` mode).
pub frame_builds: usize,
}
/// Resident frame carried across the outers of one sequence.
#[derive(Default)]
struct SharedFrameState {
// `(ridge_t, ridge_beta, handle)`: the handle is reused only while both ridges match exactly.
frame: Option<(
f64,
f64,
crate::gpu::kernels::arrow_schur::ResidentArrowFrameHandle,
)>,
// Number of successful frame constructions so far.
frame_builds: usize,
}
/// Next LM ridge after a rejected step or failed solve: a zero ridge is seeded
/// at `1e-6`, any other value is multiplied by `grow`.
fn grow_ridge(current: f64, grow: f64) -> f64 {
    const SEED_RIDGE: f64 = 1e-6;
    if current != 0.0 {
        current * grow
    } else {
        SEED_RIDGE
    }
}
/// Euclidean norm of every row gradient (in row order) followed by the border gradient.
fn arrow_system_gradient_norm(sys: &ArrowSchurSystem) -> f64 {
    let row_entries = sys.rows.iter().flat_map(|row| row.gt.iter());
    row_entries
        .chain(sys.gb.iter())
        .fold(0.0_f64, |acc, &v| acc + v * v)
        .sqrt()
}
/// Euclidean norm of the concatenated iterate `(t, beta)`.
fn iterate_norm(t: &[f64], beta: &[f64]) -> f64 {
    let energy = squared_norm(t) + squared_norm(beta);
    energy.sqrt()
}
/// Checks that every buffer length matches `shape` and that no dimension is zero.
///
/// # Errors
/// `Shape` describing the first problem found:
/// - a zero dimension (checked first, as the root cause of any length mismatch);
/// - a shape whose expected lengths overflow `usize`;
/// - the first buffer whose length disagrees with `shape`.
fn validate_shape(
    shape: DeviceResidentArrowShape,
    target_x: &[f64],
    basis_values: &[f64],
    gate_activations: &[f64],
    slabs: &DeviceResidentArrowSlabs,
) -> Result<(), DeviceResidentArrowError> {
    if shape.n == 0 || shape.p == 0 || shape.d == 0 || shape.basis_cols == 0 {
        return Err(DeviceResidentArrowError::Shape {
            reason: "SAE resident workspace requires nonzero n, p, basis_cols, and d".to_string(),
        });
    }
    // Expected lengths use checked arithmetic. An absurd shape is rejected
    // here instead of panicking (debug) or wrapping (release) into a length
    // that a wrong buffer could accidentally match.
    let expected = |factors: &[usize]| -> Result<usize, DeviceResidentArrowError> {
        factors
            .iter()
            .try_fold(1_usize, |acc, &factor| acc.checked_mul(factor))
            .ok_or_else(|| DeviceResidentArrowError::Shape {
                reason: format!("SAE resident workspace shape overflows usize: {shape:?}"),
            })
    };
    let basis_len = expected(&[shape.n, shape.basis_cols])?;
    let checks = [
        ("target_x", target_x.len(), expected(&[shape.n, shape.p])?),
        ("basis_values", basis_values.len(), basis_len),
        ("gate_activations", gate_activations.len(), basis_len),
        (
            "row_hessian_slabs",
            slabs.row_hessian_slabs.len(),
            expected(&[shape.n, shape.d, shape.d])?,
        ),
        (
            "row_cross_slabs",
            slabs.row_cross_slabs.len(),
            expected(&[shape.n, shape.d, shape.p])?,
        ),
        (
            "row_gradient_slabs",
            slabs.row_gradient_slabs.len(),
            expected(&[shape.n, shape.d])?,
        ),
        (
            "border_hessian",
            slabs.border_hessian.len(),
            expected(&[shape.p, shape.p])?,
        ),
        ("border_gradient", slabs.border_gradient.len(), shape.p),
    ];
    for (label, got, want) in checks {
        if got != want {
            return Err(DeviceResidentArrowError::Shape {
                reason: format!(
                    "SAE resident workspace shape mismatch for {label}: got {got}, expected {want}"
                ),
            });
        }
    }
    Ok(())
}
/// Uploads every host buffer to a new CUDA stream if the GPU router admits the
/// workload.
///
/// Returns `None` to mean "stay host-only": routing declined, no context, no
/// stream, or any failed copy.
#[cfg(target_os = "linux")]
fn upload_resident_buffers(
    shape: DeviceResidentArrowShape,
    target_x: &[f64],
    basis_values: &[f64],
    gate_activations: &[f64],
    slabs: &DeviceResidentArrowSlabs,
) -> Option<DeviceResidentArrowBuffers> {
    use crate::gpu::linalg_dispatch::{DispatchOp, route_through_gpu};
    // Admission: try the batched small-dense POTRF route for the per-row
    // blocks first, then fall back to a GEMM route sized by the border product.
    let runtime = match route_through_gpu(DispatchOp::SmallDenseBatchedPotrf {
        p: shape.d,
        batch: shape.n,
    }) {
        Some(admitted) => admitted,
        None => route_through_gpu(DispatchOp::Gemm {
            m: shape.p,
            n: shape.p,
            k: shape.n * shape.basis_cols,
        })?,
    };
    let ctx = crate::gpu::device_runtime::cuda_context_for(runtime.device.ordinal)?;
    let stream = ctx.new_stream().ok()?;
    let target_x_dev = stream.clone_htod(target_x).ok()?;
    let basis_values_dev = stream.clone_htod(basis_values).ok()?;
    let gate_activations_dev = stream.clone_htod(gate_activations).ok()?;
    let row_hessian_dev = stream.clone_htod(&slabs.row_hessian_slabs).ok()?;
    let row_cross_dev = stream.clone_htod(&slabs.row_cross_slabs).ok()?;
    let row_gradient_dev = stream.clone_htod(&slabs.row_gradient_slabs).ok()?;
    let border_hessian_dev = stream.clone_htod(&slabs.border_hessian).ok()?;
    let border_gradient_dev = stream.clone_htod(&slabs.border_gradient).ok()?;
    // Size is computed from the host lengths, one f64 per element.
    let host_lengths = [
        target_x.len(),
        basis_values.len(),
        gate_activations.len(),
        slabs.row_hessian_slabs.len(),
        slabs.row_cross_slabs.len(),
        slabs.row_gradient_slabs.len(),
        slabs.border_hessian.len(),
        slabs.border_gradient.len(),
    ];
    let bytes = host_lengths.iter().sum::<usize>() * std::mem::size_of::<f64>();
    Some(DeviceResidentArrowBuffers {
        bytes,
        stream,
        target_x_dev,
        basis_values_dev,
        gate_activations_dev,
        row_hessian_dev,
        row_cross_dev,
        row_gradient_dev,
        border_hessian_dev,
        border_gradient_dev,
    })
}
/// Translates a GPU arrow-solver failure into a workspace error.
///
/// `Unavailable` maps to `DeviceResidentArrowError::Unavailable`; every other
/// failure becomes `Solve`. The inner loops treat both variants as a cue to
/// grow the ridge and retry.
fn map_gpu_error(err: ArrowSchurGpuFailure) -> DeviceResidentArrowError {
    let reason = match err {
        ArrowSchurGpuFailure::Unavailable => {
            return DeviceResidentArrowError::Unavailable {
                reason: "SAE resident inner iteration unavailable after GPU admission".to_string(),
            };
        }
        ArrowSchurGpuFailure::RidgeBumpRequired { row, bump } => {
            format!("SAE resident inner iteration row {row} requires ridge bump {bump:e}")
        }
        ArrowSchurGpuFailure::SchurFactorFailed { reason } => reason,
        ArrowSchurGpuFailure::GpuRequiresDenseSystem {
            had_hbb_matvec,
            had_htbeta_matvec,
        } => format!(
            "SAE resident inner iteration requires dense slabs; hbb_matvec={had_hbb_matvec} htbeta_matvec={had_htbeta_matvec}"
        ),
    };
    DeviceResidentArrowError::Solve { reason }
}
/// Sum of squares of `values` (the squared Euclidean norm).
fn squared_norm(values: &[f64]) -> f64 {
    values.iter().map(|&x| x * x).sum::<f64>()
}
impl From<ArrowSchurError> for DeviceResidentArrowError {
    /// Solver-level arrow errors surface as `Solve`, keeping their display text.
    fn from(err: ArrowSchurError) -> Self {
        let reason = err.to_string();
        Self::Solve { reason }
    }
}
/// Qwen non-gating fixture built with its canonical seed.
pub fn qwen_non_gating_fixture() -> Result<DeviceResidentArrowWorkspace, DeviceResidentArrowError> {
    const CANONICAL_SEED: u64 = 0x1017_0003_D3A1_5EED;
    qwen_non_gating_fixture_seeded(CANONICAL_SEED)
}
/// Qwen non-gating fixture built with a caller-chosen seed.
pub fn qwen_non_gating_fixture_seeded(
    seed: u64,
) -> Result<DeviceResidentArrowWorkspace, DeviceResidentArrowError> {
    let shape = DeviceResidentArrowShape::qwen_non_gating();
    fixture_for_shape_seeded(shape, seed)
}
/// Color-arm fixture built with its canonical seed.
pub fn color_arm_fixture() -> Result<DeviceResidentArrowWorkspace, DeviceResidentArrowError> {
    const CANONICAL_SEED: u64 = 0x1017_C010_2A12_5EED;
    fixture_for_shape_seeded(DeviceResidentArrowShape::color_arm(), CANONICAL_SEED)
}
/// Builds a deterministic synthetic workspace for `shape` from `seed`.
///
/// Only `target_x` consumes randomness: one `SplitMix64::sample_signed` draw per
/// entry, taken in row-major `(i, j)` order. Reordering these loops or changing
/// the draw count changes every fixture produced from a given seed.
///
/// Every other buffer is a closed-form function of the indices:
/// - row Hessians are symmetric, with diagonal `>= 3.0` and off-diagonal
///   entries bounded by `0.02 / d` in magnitude;
/// - the border Hessian is tridiagonal (4.0 on the diagonal, 0.01 beside it);
/// - cross blocks are scaled by `1e-4`;
/// - gradients are scaled by `0.01` (rows) and `0.001` (border).
///
/// # Errors
/// `Shape` if `shape.d == 0`, or whatever `DeviceResidentArrowWorkspace::new`
/// reports for the assembled buffers.
fn fixture_for_shape_seeded(
shape: DeviceResidentArrowShape,
seed: u64,
) -> Result<DeviceResidentArrowWorkspace, DeviceResidentArrowError> {
if shape.d == 0 {
return Err(DeviceResidentArrowError::Shape {
reason: "fixture_for_shape_seeded requires d >= 1".to_string(),
});
}
let d = shape.d;
let mut rng = SplitMix64::new(seed);
// Smooth periodic target plus a small seeded perturbation per entry.
let mut target_x = vec![0.0_f64; shape.target_len()];
for i in 0..shape.n {
for j in 0..shape.p {
let phase = ((i % 97) as f64) * 0.013 + ((j % 131) as f64) * 0.007;
target_x[i * shape.p + j] = 0.02 * phase.sin() + 0.001 * rng.sample_signed();
}
}
let mut basis_values = vec![0.0_f64; shape.basis_len()];
// NOTE(review): already all 1.0 here; the loop below re-assigns 1.0, so every gate is 1.0.
let mut gate_activations = vec![1.0_f64; shape.basis_len()];
for i in 0..shape.n {
for a in 0..shape.basis_cols {
let phase = ((i + 1) as f64) * ((a + 1) as f64) * 0.003;
basis_values[i * shape.basis_cols + a] = phase.cos();
gate_activations[i * shape.basis_cols + a] = 1.0;
}
}
let mut row_hessian_slabs = vec![0.0_f64; shape.row_hessian_len()];
let mut row_cross_slabs = vec![0.0_f64; shape.row_cross_len()];
let mut row_gradient_slabs = vec![0.0_f64; shape.row_gradient_len()];
for i in 0..shape.n {
// Gated basis sum for row i; seeds every per-row quantity below.
let mut basis_sum = 0.0_f64;
for a in 0..shape.basis_cols {
basis_sum +=
basis_values[i * shape.basis_cols + a] * gate_activations[i * shape.basis_cols + a];
}
let h_base = i * d * d;
for r in 0..d {
for c in 0..d {
let v = if r == c {
3.0 + 0.01 * basis_sum.abs() + 0.1 * (r as f64)
} else {
0.02 * (basis_sum + (r + c) as f64).sin() / (d as f64)
};
row_hessian_slabs[h_base + r * d + c] = v;
}
}
// Symmetrize the block. NOTE(review): the off-diagonal formula above depends
// only on r + c, so this pass appears to leave the values unchanged.
for r in 0..d {
for c in 0..r {
let avg = 0.5
* (row_hessian_slabs[h_base + r * d + c]
+ row_hessian_slabs[h_base + c * d + r]);
row_hessian_slabs[h_base + r * d + c] = avg;
row_hessian_slabs[h_base + c * d + r] = avg;
}
}
let b_base = i * d * shape.p;
let g_base = i * d;
for r in 0..d {
for j in 0..shape.p {
let feature = ((j % 257) as f64) * 0.011;
row_cross_slabs[b_base + r * shape.p + j] =
1.0e-4 * (basis_sum + r as f64).sin() * feature.cos();
}
row_gradient_slabs[g_base + r] = 0.01 * (basis_sum + r as f64).sin();
}
}
// Tridiagonal border Hessian: 4.0 on the diagonal, 0.01 on both neighbours.
let mut border_hessian = vec![0.0_f64; shape.border_hessian_len()];
for r in 0..shape.p {
border_hessian[r * shape.p + r] = 4.0;
if r + 1 < shape.p {
border_hessian[r * shape.p + r + 1] = 0.01;
border_hessian[(r + 1) * shape.p + r] = 0.01;
}
}
let mut border_gradient = vec![0.0_f64; shape.p];
for j in 0..shape.p {
border_gradient[j] = 0.001 * ((j % 193) as f64 * 0.017).sin();
}
DeviceResidentArrowWorkspace::new(
shape,
target_x,
basis_values,
gate_activations,
DeviceResidentArrowSlabs {
row_hessian_slabs,
row_cross_slabs,
row_gradient_slabs,
border_hessian,
border_gradient,
},
)
}
/// One successful workspace fit from the multiplexed or sequential runners.
pub struct MultiplexedFit {
/// Outcome of the workspace's `device_fit`.
pub outcome: DeviceResidentInnerOutcome,
}
pub fn run_resident_fits_multiplexed(
workspaces: Vec<DeviceResidentArrowWorkspace>,
opts: DeviceResidentInnerOptions,
) -> Result<Vec<Result<MultiplexedFit, DeviceResidentArrowError>>, String> {
run_resident_fits_multiplexed_with(workspaces, opts, |workspace, opts| {
workspace.device_fit(opts)
})
}
/// Runs `run_one` on every workspace via `run_topology_race_parallel` and
/// returns the per-workspace results in the runner's row order.
fn run_resident_fits_multiplexed_with<Run>(
    workspaces: Vec<DeviceResidentArrowWorkspace>,
    opts: DeviceResidentInnerOptions,
    run_one: Run,
) -> Result<Vec<Result<MultiplexedFit, DeviceResidentArrowError>>, String>
where
    Run: Fn(
            &DeviceResidentArrowWorkspace,
            &DeviceResidentInnerOptions,
        ) -> Result<DeviceResidentInnerOutcome, DeviceResidentArrowError>
        + Sync,
{
    let fit_one = move |workspace: DeviceResidentArrowWorkspace| -> Result<MultiplexedFit, DeviceResidentArrowError> {
        run_one(&workspace, &opts).map(|outcome| MultiplexedFit { outcome })
    };
    let rows = crate::solver::topology_selector::run_topology_race_parallel(workspaces, fit_one)?;
    let results = rows.into_iter().map(|row| row.result).collect();
    Ok(results)
}
/// Sequential baseline for [`run_resident_fits_multiplexed`]: calls
/// `device_fit` on each workspace, one after another, in slice order.
///
/// The result vector has exactly one entry per workspace.
pub fn run_resident_fits_sequential(
    workspaces: &[DeviceResidentArrowWorkspace],
    opts: &DeviceResidentInnerOptions,
) -> Vec<Result<MultiplexedFit, DeviceResidentArrowError>> {
    let mut fits = Vec::with_capacity(workspaces.len());
    for workspace in workspaces {
        let fit = workspace
            .device_fit(opts)
            .map(|outcome| MultiplexedFit { outcome });
        fits.push(fit);
    }
    fits
}
/// One entry of a fit sweep: the problem shape to build and the seed that
/// `build_sweep_workspaces` passes to `fixture_for_shape_seeded` for it.
#[derive(Clone, Copy, Debug)]
pub struct SweepVariant {
/// Arrow-system dimensions (`n`, `p`, `basis_cols`, `d`) for this variant.
pub dim: DeviceResidentArrowShape,
/// Seed forwarded to the fixture builder together with `dim`.
pub seed: u64,
}
/// Wall-clock throughput summary of one sweep run.
///
/// The producers in this module time only the fit phase; workspace
/// construction happens before the clock starts.
#[derive(Clone, Copy, Debug)]
pub struct SweepThroughput {
/// Number of fits attempted.
pub fits: usize,
/// Number of fits whose result was `Ok`.
pub succeeded: usize,
/// Elapsed wall time of the fit phase, in seconds.
pub wall_seconds: f64,
/// `fits / wall_seconds`; the producers floor the denominator at 1e-9 s.
pub fits_per_second: f64,
}
/// Materialises one workspace per sweep variant, in variant order, via
/// `fixture_for_shape_seeded(variant.dim, variant.seed)`.
///
/// # Errors
/// Returns the first fixture-construction error; variants after the failing
/// one are not built.
pub fn build_sweep_workspaces(
    variants: &[SweepVariant],
) -> Result<Vec<DeviceResidentArrowWorkspace>, DeviceResidentArrowError> {
    let mut workspaces = Vec::with_capacity(variants.len());
    for variant in variants {
        workspaces.push(fixture_for_shape_seeded(variant.dim, variant.seed)?);
    }
    Ok(workspaces)
}
/// Builds the workspaces for `variants` and runs them through
/// [`run_battery_sweep_multiplexed`], returning per-fit results plus the
/// multiplexed fit-phase throughput.
///
/// # Errors
/// A workspace-construction failure is converted to its `Display` string;
/// driver failures propagate unchanged.
pub fn run_variant_sweep_multiplexed(
    variants: &[SweepVariant],
    opts: DeviceResidentInnerOptions,
) -> Result<
    (
        Vec<Result<MultiplexedFit, DeviceResidentArrowError>>,
        SweepThroughput,
    ),
    String,
> {
    match build_sweep_workspaces(variants) {
        Ok(workspaces) => run_battery_sweep_multiplexed(workspaces, opts),
        Err(build_error) => Err(build_error.to_string()),
    }
}
/// Runs pre-built workspaces through [`run_resident_fits_multiplexed`] and
/// measures throughput over the fit phase only.
///
/// `succeeded` counts the `Ok` entries; `fits_per_second` divides by the
/// elapsed time floored at 1e-9 s so that an instantaneous run cannot
/// divide by zero.
///
/// # Errors
/// Propagates the multiplexing driver's `String` error.
pub fn run_battery_sweep_multiplexed(
    workspaces: Vec<DeviceResidentArrowWorkspace>,
    opts: DeviceResidentInnerOptions,
) -> Result<
    (
        Vec<Result<MultiplexedFit, DeviceResidentArrowError>>,
        SweepThroughput,
    ),
    String,
> {
    let fits = workspaces.len();
    let clock = std::time::Instant::now();
    let results = run_resident_fits_multiplexed(workspaces, opts)?;
    let wall_seconds = clock.elapsed().as_secs_f64();
    let succeeded = results.iter().filter(|fit| matches!(fit, Ok(_))).count();
    let fits_per_second = fits as f64 / wall_seconds.max(1e-9);
    Ok((
        results,
        SweepThroughput {
            fits,
            succeeded,
            wall_seconds,
            fits_per_second,
        },
    ))
}
/// Builds the colour-arm sweep matrix: for each `k` in `1..=4`, each of the
/// four topology labels, and each `(d, basis_cols)` layout in `[(2, 8), (1, 2)]`,
/// one [`SweepVariant`] derived from [`DeviceResidentArrowShape::color_arm`].
///
/// The topology label itself is not stored; only its index enters the seed:
/// `0x1017_C010_0000_0000 ^ (k << 16) ^ (topology_index << 8) ^ basis_tag`.
/// Variants are ordered by `k`, then topology, then layout (32 in total).
#[must_use]
pub fn color_arm_variant_matrix() -> Vec<SweepVariant> {
    static TOPOLOGIES: [&str; 4] = ["euclidean", "circle", "torus", "sphere"];
    // (d, basis_cols, basis_tag): the tag keeps the two layouts' seeds distinct.
    static BASIS_LAYOUTS: [(usize, usize, u64); 2] = [(2, 8, 0), (1, 2, 1)];
    (1..=4u64)
        .flat_map(|k| {
            (0..TOPOLOGIES.len()).flat_map(move |t_idx| {
                BASIS_LAYOUTS.iter().map(move |&(d, basis_cols, basis_tag)| {
                    let mut dim = DeviceResidentArrowShape::color_arm();
                    dim.d = d;
                    dim.basis_cols = basis_cols;
                    SweepVariant {
                        dim,
                        seed: 0x1017_C010_0000_0000
                            ^ (k << 16)
                            ^ ((t_idx as u64) << 8)
                            ^ basis_tag,
                    }
                })
            })
        })
        .collect()
}
/// Re-runs `variants` sequentially and checks the multiplexed results against
/// that baseline, fit by fit.
///
/// The workspaces are rebuilt from `variants` and fitted one at a time with
/// [`run_resident_fits_sequential`]. Entry `idx` of the sequential run is
/// compared with entry `idx` of `multiplexed`:
/// - two `Ok` fits must agree bit-for-bit on `t`, `beta` and `objective`;
/// - two `Err` fits are accepted without comparing the error values.
///
/// On success, returns the wall-clock throughput of the sequential fit phase.
///
/// # Errors
/// Returns a message describing the first problem found:
/// - a workspace build failure;
/// - a length mismatch;
/// - a fit where only one side failed;
/// - a fit whose `t`, `beta` or `objective` differ (the message names the field).
pub fn assert_sweep_parity_vs_sequential(
    variants: &[SweepVariant],
    opts: &DeviceResidentInnerOptions,
    multiplexed: &[Result<MultiplexedFit, DeviceResidentArrowError>],
) -> Result<SweepThroughput, String> {
    /// Bitwise equality of two `f64` sequences.
    ///
    /// Unlike `==`, identical NaN payloads match and `-0.0` does not match
    /// `0.0`. Comparing element iterators (rather than `as_slice()`) also
    /// avoids `None == None` passing silently when an `ndarray` array is not
    /// in standard layout.
    fn bit_identical<'a, 'b>(
        lhs: impl ExactSizeIterator<Item = &'a f64>,
        rhs: impl ExactSizeIterator<Item = &'b f64>,
    ) -> bool {
        lhs.len() == rhs.len() && lhs.zip(rhs).all(|(a, b)| a.to_bits() == b.to_bits())
    }

    let workspaces = build_sweep_workspaces(variants).map_err(|e| e.to_string())?;
    let start = std::time::Instant::now();
    let sequential = run_resident_fits_sequential(&workspaces, opts);
    let wall_seconds = start.elapsed().as_secs_f64();
    if sequential.len() != multiplexed.len() {
        return Err(format!(
            "sweep parity: length mismatch seq={} mux={}",
            sequential.len(),
            multiplexed.len()
        ));
    }
    for (idx, (seq, mux)) in sequential.iter().zip(multiplexed.iter()).enumerate() {
        match (seq, mux) {
            (Ok(s), Ok(m)) => {
                let (s, m) = (&s.outcome, &m.outcome);
                let differing = if !bit_identical(s.t.iter(), m.t.iter()) {
                    Some("t")
                } else if !bit_identical(s.beta.iter(), m.beta.iter()) {
                    Some("beta")
                } else if s.objective.to_bits() != m.objective.to_bits() {
                    Some("objective")
                } else {
                    None
                };
                if let Some(field) = differing {
                    return Err(format!(
                        "sweep parity: fit {idx} multiplexed result differs from sequential \
                         (field `{field}`)"
                    ));
                }
            }
            // Both sides failing counts as agreement; error values are not compared.
            (Err(_), Err(_)) => {}
            _ => {
                return Err(format!(
                    "sweep parity: fit {idx} success/failure disagrees seq-vs-mux"
                ));
            }
        }
    }
    let fits = variants.len();
    let succeeded = sequential.iter().filter(|r| r.is_ok()).count();
    Ok(SweepThroughput {
        fits,
        succeeded,
        wall_seconds,
        fits_per_second: (fits as f64) / wall_seconds.max(1e-9),
    })
}
/// Small deterministic generator whose state is advanced through
/// `crate::linalg::utils::splitmix64`.
struct SplitMix64 {
    /// Generator state, passed by `&mut` to `splitmix64` on every draw.
    state: u64,
}
impl SplitMix64 {
    /// Creates a generator whose state starts at `seed`.
    const fn new(seed: u64) -> Self {
        Self { state: seed }
    }
    /// Returns the next 64-bit output from `splitmix64`.
    fn next_u64(&mut self) -> u64 {
        crate::linalg::utils::splitmix64(&mut self.state)
    }
    /// Draws a value in `[-1, 1)` from the top 53 bits of the next output.
    fn sample_signed(&mut self) -> f64 {
        const TWO_POW_53: f64 = (1_u64 << 53) as f64;
        let top_bits = self.next_u64() >> 11;
        let unit = top_bits as f64 / TWO_POW_53;
        unit * 2.0 - 1.0
    }
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::Array2;
/// Builds a small resident workspace (n=8, p=4, basis_cols=2, d=2) whose
/// slab values are drawn from `SplitMix64::new(seed)`.
///
/// Each 2x2 row Hessian is symmetric with diagonal entries near 5 and 4 and an
/// off-diagonal of magnitude at most 0.05. Cross terms are at most 0.01 and the
/// border Hessian is diagonal near 6. The assembled arrow matrix is therefore
/// diagonally dominant, hence positive definite.
///
/// The order of the `rng` draws below fixes every value for a given seed.
///
/// # Panics
/// Panics if `DeviceResidentArrowWorkspace::new` rejects the slabs.
fn small_fixture(seed: u64) -> DeviceResidentArrowWorkspace {
let shape = DeviceResidentArrowShape {
n: 8,
p: 4,
basis_cols: 2,
d: 2,
};
let mut rng = SplitMix64::new(seed);
// Constant target/basis/gate inputs; only the slabs vary with the seed.
let target_x = vec![0.0_f64; shape.target_len()];
let basis_values = vec![0.5_f64; shape.basis_len()];
let gate_activations = vec![1.0_f64; shape.basis_len()];
let mut row_hessian_slabs = vec![0.0_f64; shape.row_hessian_len()];
let mut row_cross_slabs = vec![0.0_f64; shape.row_cross_len()];
let mut row_gradient_slabs = vec![0.0_f64; shape.row_gradient_len()];
for i in 0..shape.n {
// Row i's d x d Hessian, flattened row-major: h+0=(0,0), h+1=(0,1), h+2=(1,0), h+3=(1,1).
let h = i * shape.d * shape.d;
row_hessian_slabs[h] = 5.0 + 0.1 * rng.sample_signed();
row_hessian_slabs[h + 1] = 0.05 * rng.sample_signed();
// Mirror the off-diagonal so the block is symmetric.
row_hessian_slabs[h + 2] = row_hessian_slabs[h + 1];
row_hessian_slabs[h + 3] = 4.0 + 0.1 * rng.sample_signed();
// Row i's d x p cross block: component 0 at b.., component 1 at b+p..; the two draws interleave per column.
let b = i * shape.d * shape.p;
for j in 0..shape.p {
row_cross_slabs[b + j] = 0.01 * rng.sample_signed();
row_cross_slabs[b + shape.p + j] = 0.01 * rng.sample_signed();
}
let g = i * shape.d;
row_gradient_slabs[g] = rng.sample_signed();
row_gradient_slabs[g + 1] = rng.sample_signed();
}
// Diagonal p x p border Hessian with entries in [5.9, 6.1).
let mut border_hessian = vec![0.0_f64; shape.border_hessian_len()];
for r in 0..shape.p {
border_hessian[r * shape.p + r] = 6.0 + 0.1 * rng.sample_signed();
}
let border_gradient: Vec<f64> = (0..shape.p).map(|_| rng.sample_signed()).collect();
DeviceResidentArrowWorkspace::new(
shape,
target_x,
basis_values,
gate_activations,
DeviceResidentArrowSlabs {
row_hessian_slabs,
row_cross_slabs,
row_gradient_slabs,
border_hessian,
border_gradient,
},
)
.expect("small resident fixture must validate")
}
fn dense_hz(
ws: &DeviceResidentArrowWorkspace,
sys: &ArrowSchurSystem,
) -> (Array2<f64>, Array1<f64>) {
let shape = ws.shape;
let total = shape.n * shape.d + shape.p;
let mut h = Array2::<f64>::zeros((total, total));
let mut g0 = Array1::<f64>::zeros(total);
for i in 0..shape.n {
let base = i * shape.d;
for r in 0..shape.d {
for c in 0..shape.d {
h[[base + r, base + c]] = sys.rows[i].htt[[r, c]];
}
for c in 0..shape.p {
let v = sys.rows[i].htbeta[[r, c]];
h[[base + r, shape.n * shape.d + c]] = v;
h[[shape.n * shape.d + c, base + r]] = v;
}
g0[base + r] = sys.rows[i].gt[r];
}
}
for r in 0..shape.p {
for c in 0..shape.p {
h[[shape.n * shape.d + r, shape.n * shape.d + c]] = sys.hbb[[r, c]];
}
g0[shape.n * shape.d + r] = sys.gb[r];
}
(h, g0)
}
/// The converged CPU reference fit `z = (t, beta)` must solve `H z = g0` for
/// the dense arrow operator of the same system, to within 1e-9 (max abs).
#[test]
fn cpu_inner_loop_reaches_quadratic_minimiser() {
    let ws = small_fixture(0xABCD_0001);
    let opts = DeviceResidentInnerOptions::default();
    let outcome = ws.cpu_reference_fit(&opts).expect("cpu fit");
    assert!(
        outcome.converged,
        "inner loop must converge on a PD quadratic"
    );
    let base = ws.to_arrow_system();
    let (h, g0) = dense_hz(&ws, &base);
    let row_dofs = ws.shape.n * ws.shape.d;
    let mut z = Array1::<f64>::zeros(row_dofs + ws.shape.p);
    for k in 0..row_dofs {
        z[k] = outcome.t[k];
    }
    for j in 0..ws.shape.p {
        z[row_dofs + j] = outcome.beta[j];
    }
    let max_resid = h
        .dot(&z)
        .iter()
        .zip(g0.iter())
        .map(|(lhs, rhs)| (lhs - rhs).abs())
        .fold(0.0_f64, f64::max);
    assert!(
        max_resid < 1e-9,
        "inner loop fixed point must solve H z = g0; residual {max_resid:e}"
    );
}
/// Running the CPU reference fit through the multiplexing driver must give
/// the same `t`, `beta` and objective bits as running it sequentially.
#[test]
fn cpu_multiplex_matches_sequential_bit_identical() {
    const SEEDS: [u64; 6] = [0x11, 0x22, 0x33, 0x44, 0x55, 0x66];
    let opts = DeviceResidentInnerOptions::default();
    let seq_workspaces: Vec<_> = SEEDS.iter().map(|&seed| small_fixture(seed)).collect();
    let sequential: Vec<_> = seq_workspaces
        .iter()
        .map(|ws| ws.cpu_reference_fit(&opts).expect("seq cpu fit"))
        .collect();
    let mux_workspaces: Vec<_> = SEEDS.iter().copied().map(small_fixture).collect();
    let multiplexed =
        run_resident_fits_multiplexed_with(mux_workspaces, opts, |ws, fit_opts| {
            ws.cpu_reference_fit(fit_opts)
        })
        .expect("multiplexed cpu fits");
    assert_eq!(sequential.len(), multiplexed.len());
    for (seq, mux) in sequential.iter().zip(multiplexed.iter()) {
        let mux = &mux.as_ref().expect("mux fit ok").outcome;
        assert_eq!(seq.t.as_slice(), mux.t.as_slice());
        assert_eq!(seq.beta.as_slice(), mux.beta.as_slice());
        assert_eq!(seq.objective.to_bits(), mux.objective.to_bits());
    }
}
/// Cross-checks every device path against the CPU reference on a PD fixture.
///
/// On a host where the workspace is device-resident, three comparisons must
/// hold to 1e-9 relative:
/// - the resident fit against the CPU fit;
/// - a single resident `solve_gradient` against the dense reference step;
/// - the re-uploading fit against the CPU fit.
///
/// On a CPU-only host, every device entry point must decline.
#[test]
fn device_resident_fit_matches_cpu_reference() {
    use crate::gpu::kernels::arrow_schur::ResidentArrowFrameHandle;

    /// Largest `|a - b| / scale` over paired elements (0 if there are none).
    fn max_scaled_gap<'a, 'b>(
        lhs: impl IntoIterator<Item = &'a f64>,
        rhs: impl IntoIterator<Item = &'b f64>,
        scale: f64,
    ) -> f64 {
        lhs.into_iter()
            .zip(rhs)
            .fold(0.0_f64, |worst, (a, b)| worst.max((a - b).abs() / scale))
    }

    /// Largest absolute value, floored at 1 so it is a safe relative-error denominator.
    fn unit_floored_scale<'a>(values: impl IntoIterator<Item = &'a f64>) -> f64 {
        values.into_iter().fold(1.0_f64, |m, &v| m.max(v.abs()))
    }

    let ws = small_fixture(0x5AE_1017);
    let opts = DeviceResidentInnerOptions::default();
    let cpu = ws.cpu_reference_fit(&opts).expect("cpu reference fit");
    assert!(cpu.converged, "cpu reference must converge on PD quadratic");
    let base = ws.to_arrow_system();

    if !ws.device_resident() {
        // CPU-only host: every device entry point must decline cleanly.
        assert!(
            crate::gpu::device_runtime::GpuRuntime::global().is_none(),
            "device_resident() is false on a host WITH a CUDA runtime present, \
             despite a floor-clearing fixture (batch=8): the resident device \
             buffers failed to bind — a real device fault, not a CPU-only skip."
        );
        let dev = ws.device_fit(&opts);
        assert!(
            matches!(dev, Err(DeviceResidentArrowError::Unavailable { .. })),
            "device_fit must report Unavailable on a CPU-only host, got {dev:?}"
        );
        let reup = ws.device_reupload_fit(&opts);
        assert!(
            matches!(reup, Err(DeviceResidentArrowError::Unavailable { .. })),
            "device_reupload_fit must report Unavailable on a CPU-only host, got {reup:?}"
        );
        let frame = ResidentArrowFrameHandle::new(
            &base,
            opts.initial_ridge_t,
            opts.initial_ridge_beta,
        );
        assert!(
            frame.is_err(),
            "resident frame construction must decline on a CPU-only host"
        );
        return;
    }

    // Full device-resident inner loop vs the CPU reference.
    let dev = ws.device_fit(&opts).expect("device resident fit");
    assert_eq!(
        dev.execution_path,
        ExecutionPath::GpuResidentFull,
        "device_fit must report the full device-resident execution path"
    );
    assert!(dev.converged, "device resident loop must converge");
    let t_scale = unit_floored_scale(cpu.t.iter());
    let b_scale = unit_floored_scale(cpu.beta.iter());
    let max_rel = max_scaled_gap(dev.t.iter(), cpu.t.iter(), t_scale)
        .max(max_scaled_gap(dev.beta.iter(), cpu.beta.iter(), b_scale));
    assert!(
        max_rel < 1e-9,
        "resident device fit must match CPU reference (rel {max_rel:e})"
    );

    // One resident solve with the system's own gradient vs the dense reference step.
    let frame = ResidentArrowFrameHandle::new(
        &base,
        opts.initial_ridge_t,
        opts.initial_ridge_beta,
    )
    .expect("resident frame must build on CUDA host");
    let mut g_t: Vec<f64> = Vec::new();
    for row in base.rows.iter() {
        g_t.extend(row.gt.iter().copied());
    }
    let g_beta: Vec<f64> = base.gb.iter().copied().collect();
    let resident_sol = frame
        .solve_gradient(&g_t, &g_beta)
        .expect("resident single-gradient solve");
    let full = crate::gpu::kernels::arrow_schur::solve_arrow_newton_step_dense_reference(
        &base,
        opts.initial_ridge_t,
        opts.initial_ridge_beta,
    )
    .expect("dense reference single solve");
    let step_scale = unit_floored_scale(full.delta_t.iter().chain(full.delta_beta.iter()));
    let max_step_rel =
        max_scaled_gap(resident_sol.delta_t.iter(), full.delta_t.iter(), step_scale).max(
            max_scaled_gap(resident_sol.delta_beta.iter(), full.delta_beta.iter(), step_scale),
        );
    assert!(
        max_step_rel < 1e-9,
        "resident solve_gradient must match full dense reference step (rel {max_step_rel:e})"
    );

    // Re-uploading device loop vs the CPU reference.
    let reup = ws
        .device_reupload_fit(&opts)
        .expect("device re-uploading fit");
    assert_eq!(
        reup.execution_path,
        ExecutionPath::GpuReupload,
        "device_reupload_fit must report the re-uploading device path"
    );
    assert!(reup.converged, "re-uploading loop must converge");
    let max_reup_rel = max_scaled_gap(reup.t.iter(), cpu.t.iter(), t_scale)
        .max(max_scaled_gap(reup.beta.iter(), cpu.beta.iter(), b_scale));
    assert!(
        max_reup_rel < 1e-9,
        "re-uploading GPU fit must match CPU reference (rel {max_reup_rel:e})"
    );
}
/// The production arrow-core Newton step from the origin must equal the
/// negated converged resident fit, so `Δ + fit` must vanish element-wise.
#[test]
fn resident_inner_solve_matches_production_arrow_core() {
    use crate::solver::arrow_schur::{ArrowSolveOptions, solve_arrow_newton_step_core};

    /// Scans the element-wise sums `prod + res` and returns their largest
    /// scaled magnitude together with the first `(index, prod, res)` that
    /// attains it (`None` if no sum exceeds zero).
    fn worst_cancellation<'a, 'b>(
        prod: impl IntoIterator<Item = &'a f64>,
        res: impl IntoIterator<Item = &'b f64>,
        scale: f64,
    ) -> (f64, Option<(usize, f64, f64)>) {
        prod.into_iter().zip(res).enumerate().fold(
            (0.0_f64, None),
            |(best, worst), (i, (&p, &r))| {
                let rel = (p + r).abs() / scale;
                if rel > best {
                    (rel, Some((i, p, r)))
                } else {
                    (best, worst)
                }
            },
        )
    }

    let ws = small_fixture(0x1017_F17);
    let opts = DeviceResidentInnerOptions::default();
    let resident = ws.cpu_reference_fit(&opts).expect("resident cpu fit");
    assert!(
        resident.converged,
        "resident reference must converge on the PD quadratic"
    );
    let sys = ws.to_arrow_system();
    let (delta_t, delta_beta, _diag) = solve_arrow_newton_step_core(
        &sys,
        opts.initial_ridge_t,
        opts.initial_ridge_beta,
        &ArrowSolveOptions::direct(),
    )
    .expect("production arrow-core solve");
    let t_scale = resident.t.iter().fold(1.0_f64, |m, &v| m.max(v.abs()));
    let b_scale = resident.beta.iter().fold(1.0_f64, |m, &v| m.max(v.abs()));
    let (max_rel_t, worst_t) = worst_cancellation(delta_t.iter(), resident.t.iter(), t_scale);
    let (max_rel_b, worst_b) =
        worst_cancellation(delta_beta.iter(), resident.beta.iter(), b_scale);
    let max_rel = max_rel_t.max(max_rel_b);
    assert!(
        max_rel < 1e-9,
        "production arrow-core Newton step must be −(resident converged fit) on \
         the same quadratic; wiring the device seam into the SAE inner loop must \
         not change the system being solved. rel_t={max_rel_t:e} (worst {worst_t:?}: \
         Δt+t* must be 0), rel_beta={max_rel_b:e} (worst {worst_b:?}: Δβ+β* must \
         be 0). A t-only gap implicates the per-row factor / row-gradient \
         assembly; a β-only gap the border Schur path."
    );
}
/// Across three outer gradients on the colour-arm fixture, the device outer
/// sequence must build its resident frame exactly once and match the CPU
/// reference sequence outer by outer (1e-9 relative). On a host without a
/// device-resident workspace it prints a skip notice after the CPU run.
#[test]
fn outer_sequence_reuses_frame_and_matches_independent() {
    /// Largest `|a - b| / scale` over paired elements (0 if there are none).
    fn max_scaled_gap<'a, 'b>(
        lhs: impl IntoIterator<Item = &'a f64>,
        rhs: impl IntoIterator<Item = &'b f64>,
        scale: f64,
    ) -> f64 {
        lhs.into_iter()
            .zip(rhs)
            .fold(0.0_f64, |worst, (a, b)| worst.max((a - b).abs() / scale))
    }

    const OUTER_COUNT: usize = 3;
    let ws = super::color_arm_fixture().expect("color_arm fixture");
    let opts = DeviceResidentInnerOptions::default();
    let DeviceResidentArrowShape { n, p, d, .. } = ws.shape;
    let outers: Vec<(Vec<f64>, Vec<f64>)> = (0..OUTER_COUNT)
        .map(|s| {
            let g_t: Vec<f64> = (0..n * d)
                .map(|i| 0.01 * (((i + 3 * s) as f64) * 0.002).sin())
                .collect();
            let g_beta: Vec<f64> = (0..p)
                .map(|j| 0.001 * (((j + 11 * s) as f64) * 0.0009).cos())
                .collect();
            (g_t, g_beta)
        })
        .collect();
    let independent = ws
        .cpu_reference_outer_sequence(&outers, &opts)
        .expect("cpu reference outer sequence");
    assert_eq!(independent.outers.len(), outers.len());
    if !ws.device_resident() {
        println!(
            "[#1017 outer-seq color_arm] no CUDA device — across-outer residency skipped; \
             run on the GPU node to assert frame_builds==1 + device parity"
        );
        return;
    }
    let shared = ws
        .device_fit_outer_sequence(&outers, &opts)
        .expect("device outer sequence");
    assert_eq!(
        shared.frame_builds,
        1,
        "across-outer residency must build the resident frame exactly once \
         for an unchanged operator (got {} builds over {} outers) — a count \
         > 1 means the frame was needlessly re-factored per outer",
        shared.frame_builds,
        outers.len()
    );
    for (idx, (sh, ind)) in shared
        .outers
        .iter()
        .zip(independent.outers.iter())
        .enumerate()
    {
        let scale = ind
            .t
            .iter()
            .chain(ind.beta.iter())
            .fold(1.0_f64, |m, &v| m.max(v.abs()));
        let max_rel = max_scaled_gap(sh.t.iter(), ind.t.iter(), scale)
            .max(max_scaled_gap(sh.beta.iter(), ind.beta.iter(), scale));
        assert!(
            max_rel < 1e-9,
            "outer {idx}: across-outer-shared frame must match independent fit \
             (rel {max_rel:e})"
        );
    }
    println!(
        "[#1017 outer-seq color_arm] outers={} frame_builds={} (across-outer factor \
         amortized) parity<1e-9 OK",
        outers.len(),
        shared.frame_builds
    );
}
/// Per-solve residency benchmark.
///
/// For each bench fixture, times `N_SOLVES` gradient solves two ways:
/// - through one resident frame;
/// - through fresh per-solve `solve_arrow_newton_step` calls (re-upload).
///
/// Both paths get one warm-up solve beforehand. The two step sequences must
/// agree to 1e-9 relative. Residency must beat re-upload by more than 1.5x on
/// `color_arm` and more than 1.0x elsewhere. Fixtures that are not
/// device-resident are skipped with a notice.
#[test]
fn gpu_residency_per_solve_bench() {
    use std::time::Instant;
    const N_SOLVES: usize = 24;

    /// Largest `|a - b| / scale` over paired elements (0 if there are none).
    fn max_scaled_gap<'a, 'b>(
        lhs: impl IntoIterator<Item = &'a f64>,
        rhs: impl IntoIterator<Item = &'b f64>,
        scale: f64,
    ) -> f64 {
        lhs.into_iter()
            .zip(rhs)
            .fold(0.0_f64, |worst, (a, b)| worst.max((a - b).abs() / scale))
    }

    /// Builds a fresh system from `ws`, sets each row's `gt` from `g_t`
    /// (`d` entries per row) and `gb` from `g_beta`, then calls
    /// `refresh_row_hessian_fingerprint()`. Used as the re-upload input.
    fn system_with_gradients(
        ws: &DeviceResidentArrowWorkspace,
        g_t: &[f64],
        g_beta: &[f64],
    ) -> ArrowSchurSystem {
        let d = ws.shape.d;
        let mut sys = ws.to_arrow_system();
        for (i, row) in sys.rows.iter_mut().enumerate() {
            for r in 0..d {
                row.gt[r] = g_t[i * d + r];
            }
        }
        for (j, gb) in sys.gb.iter_mut().enumerate() {
            *gb = g_beta[j];
        }
        sys.refresh_row_hessian_fingerprint();
        sys
    }

    let fixtures = [
        ("color_arm", super::color_arm_fixture()),
        ("qwen_non_gating", super::qwen_non_gating_fixture()),
    ];
    for (label, ws) in fixtures {
        let ws = ws.expect("bench fixture must validate");
        let base = ws.to_arrow_system();
        let DeviceResidentArrowShape { n, p, d, .. } = ws.shape;
        let gradients: Vec<(Vec<f64>, Vec<f64>)> = (0..N_SOLVES)
            .map(|s| {
                let g_t: Vec<f64> = (0..n * d).map(|i| ((i + s) as f64 * 0.001).sin()).collect();
                let g_beta: Vec<f64> = (0..p)
                    .map(|j| ((j + 7 * s) as f64 * 0.0007).cos())
                    .collect();
                (g_t, g_beta)
            })
            .collect();
        if !ws.device_resident() {
            println!(
                "[#1017 per-solve {label}] no CUDA device — {N_SOLVES} solves skipped; \
                 run on the GPU node for the across-iteration residency speedup"
            );
            continue;
        }

        let t_build = Instant::now();
        let frame =
            crate::gpu::kernels::arrow_schur::ResidentArrowFrameHandle::new(&base, 0.0, 0.0)
                .expect("resident frame must build on CUDA host");
        let frame_build_ms = t_build.elapsed().as_secs_f64() * 1e3;

        // Warm both paths once so neither timed loop pays first-call costs.
        let (warm_t, warm_beta) = &gradients[0];
        let _ = frame
            .solve_gradient(warm_t, warm_beta)
            .expect("resident warm-up solve");
        let _ = crate::gpu::kernels::arrow_schur::solve_arrow_newton_step(
            &system_with_gradients(&ws, warm_t, warm_beta),
            0.0,
            0.0,
        )
        .expect("reupload warm-up solve");

        let t_res = Instant::now();
        let resident_steps: Vec<_> = gradients
            .iter()
            .map(|(g_t, g_beta)| {
                frame
                    .solve_gradient(g_t, g_beta)
                    .expect("resident solve_gradient")
            })
            .collect();
        let resident_ms = t_res.elapsed().as_secs_f64() * 1e3;

        let t_reup = Instant::now();
        let reupload_steps: Vec<_> = gradients
            .iter()
            .map(|(g_t, g_beta)| {
                crate::gpu::kernels::arrow_schur::solve_arrow_newton_step(
                    &system_with_gradients(&ws, g_t, g_beta),
                    0.0,
                    0.0,
                )
                .expect("reupload solve_arrow_newton_step")
            })
            .collect();
        let reupload_ms = t_reup.elapsed().as_secs_f64() * 1e3;

        let max_rel = resident_steps
            .iter()
            .zip(reupload_steps.iter())
            .fold(0.0_f64, |acc, (rs, us)| {
                let scale = us
                    .delta_t
                    .iter()
                    .chain(us.delta_beta.iter())
                    .fold(1.0_f64, |m, &v| m.max(v.abs()));
                acc.max(max_scaled_gap(rs.delta_t.iter(), us.delta_t.iter(), scale))
                    .max(max_scaled_gap(rs.delta_beta.iter(), us.delta_beta.iter(), scale))
            });
        let resident_per_solve = resident_ms / N_SOLVES as f64;
        let reupload_per_solve = reupload_ms / N_SOLVES as f64;
        let residency_speedup = reupload_ms / resident_ms.max(1e-9);
        println!(
            "[#1017 per-solve {label}] N={N_SOLVES} frame_build={frame_build_ms:.2}ms \
             resident={resident_ms:.2}ms ({resident_per_solve:.3}ms/solve, \
             grad-upload + warm factors) reupload={reupload_ms:.2}ms \
             ({reupload_per_solve:.3}ms/solve, N factors + N D/B uploads) \
             residency_speedup={residency_speedup:.2}x parity_rel={max_rel:e}"
        );
        assert!(
            max_rel < 1e-9,
            "{label}: resident per-solve steps must match reupload (rel {max_rel:e})"
        );
        let min_speedup = match label {
            "color_arm" => 1.5,
            _ => 1.0,
        };
        assert!(
            residency_speedup > min_speedup,
            "{label}: across-iteration residency must beat per-solve re-upload \
             (residency_speedup={residency_speedup:.3}x, required >{min_speedup}x; \
             resident {resident_per_solve:.3}ms/solve vs reupload \
             {reupload_per_solve:.3}ms/solve over N={N_SOLVES} solves) — the resident \
             frame either silently re-uploaded D/B or the dispatch dropped the \
             amortized factor path"
        );
    }
}
fn battery_variant_matrix() -> Vec<super::SweepVariant> {
let mut variants = Vec::new();
for k in 1..=4u64 {
for basis_cols in [4usize, 8, 12] {
let mut dim = DeviceResidentArrowShape::color_arm();
dim.basis_cols = basis_cols;
variants.push(super::SweepVariant {
dim,
seed: 0x1017_0040_0000_0000 ^ (k << 8) ^ (basis_cols as u64),
});
}
}
variants
}
/// CPU-path parity for the battery sweep.
///
/// Running `cpu_reference_fit` through the multiplexing driver must reproduce
/// the sequential results variant by variant: identical `t` and `beta` slices
/// and identical objective bits. A failed fit panics with the variant index
/// and its error, so the offending variant can be identified directly.
#[test]
fn variant_sweep_multiplex_matches_sequential() {
    let variants = battery_variant_matrix();
    let opts = DeviceResidentInnerOptions::default();
    let workspaces =
        super::build_sweep_workspaces(&variants).expect("sweep workspaces must build");
    let multiplexed =
        super::run_resident_fits_multiplexed_with(workspaces, opts, |ws, opts| {
            ws.cpu_reference_fit(opts)
        })
        .expect("multiplexed cpu sweep");
    let seq_workspaces =
        super::build_sweep_workspaces(&variants).expect("sweep workspaces must build");
    let sequential: Vec<_> = seq_workspaces
        .iter()
        .map(|ws| ws.cpu_reference_fit(&opts))
        .collect();
    assert_eq!(multiplexed.len(), sequential.len());
    for (idx, (mux, seq)) in multiplexed.iter().zip(sequential.iter()).enumerate() {
        // Report which variant failed instead of a context-free `unwrap` panic.
        let mux = match mux {
            Ok(fit) => &fit.outcome,
            Err(err) => panic!("variant {idx}: multiplexed fit failed: {err:?}"),
        };
        let seq = match seq {
            Ok(outcome) => outcome,
            Err(err) => panic!("variant {idx}: sequential fit failed: {err:?}"),
        };
        assert_eq!(
            mux.t.as_slice(),
            seq.t.as_slice(),
            "variant {idx}: multiplexed t differs from sequential"
        );
        assert_eq!(
            mux.beta.as_slice(),
            seq.beta.as_slice(),
            "variant {idx}: multiplexed beta differs from sequential"
        );
        assert_eq!(
            mux.objective.to_bits(),
            seq.objective.to_bits(),
            "variant {idx}: multiplexed objective differs from sequential"
        );
    }
}
/// Cross-fit throughput benchmark on device.
///
/// Runs the battery sweep multiplexed and then sequentially, asserts
/// bit-level parity between the two, prints both throughputs, and requires
/// every variant to succeed. The test returns early with a notice when no
/// workspace is device-resident.
#[test]
fn gpu_multiplex_throughput_bench() {
    let variants = battery_variant_matrix();
    let opts = DeviceResidentInnerOptions::default();
    // Build the workspaces once. They both gate the bench (device probe) and
    // feed the multiplexed sweep. Rebuilding them via
    // `run_variant_sweep_multiplexed` would hold two full sets (each with a
    // p x p border Hessian) alive at once.
    let workspaces = super::build_sweep_workspaces(&variants).expect("sweep workspaces");
    let any_device = workspaces.iter().any(|w| w.device_resident());
    if !any_device {
        println!(
            "[#1017 mux-bench] no CUDA device — {} variants (K1..4 x 3 basis) \
             skipped; run on the GPU node for cross-fit throughput",
            variants.len()
        );
        return;
    }
    let (results, mux_tp) =
        super::run_battery_sweep_multiplexed(workspaces, opts).expect("multiplexed sweep");
    let seq_tp = super::assert_sweep_parity_vs_sequential(&variants, &opts, &results)
        .expect("sweep parity vs sequential must hold");
    println!(
        "[#1017 mux-bench] fits={} succeeded={} multiplexed={:.3}s ({:.1} fits/s) \
         sequential={:.3}s ({:.1} fits/s) cross-fit-speedup={:.2}x",
        mux_tp.fits,
        mux_tp.succeeded,
        mux_tp.wall_seconds,
        mux_tp.fits_per_second,
        seq_tp.wall_seconds,
        seq_tp.fits_per_second,
        mux_tp.fits_per_second / seq_tp.fits_per_second.max(1e-9),
    );
    assert_eq!(
        mux_tp.succeeded, mux_tp.fits,
        "all battery variants must fit successfully on device"
    );
}
}