1use ndarray::Array1;
11
12use crate::gpu_kernels::arrow_schur::{
13 ArrowSchurGpuFailure, solve_arrow_newton_step, solve_arrow_newton_step_dense_reference,
14};
15use gam_problem::ExecutionPath;
16
17#[derive(Clone, Copy, Debug, Eq, PartialEq)]
35pub enum InnerSolveMode {
36 DeviceResident,
37 DeviceReupload,
38 CpuReference,
39}
40
41impl InnerSolveMode {
42 #[inline]
47 const fn execution_path(self) -> ExecutionPath {
48 match self {
49 Self::DeviceResident => ExecutionPath::GpuResidentFull,
50 Self::DeviceReupload => ExecutionPath::GpuReupload,
51 Self::CpuReference => ExecutionPath::Cpu,
52 }
53 }
54}
55use crate::arrow_schur::{ArrowSchurError, ArrowSchurSystem};
56
/// Dimensions of an arrow-structured (block-bordered) Newton system.
///
/// There are `n` row blocks, each with `d` local coordinates, all coupled to a shared
/// border of `p` coefficients. `basis_cols` is the number of basis values stored per row.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct DeviceResidentArrowShape {
    /// Number of row blocks.
    pub n: usize,
    /// Size of the shared border.
    pub p: usize,
    /// Basis values stored per row.
    pub basis_cols: usize,
    /// Local coordinates per row block.
    pub d: usize,
}

impl DeviceResidentArrowShape {
    /// Qwen-scale non-gating benchmark: 2 000 rows, a 2 048-wide border, 8 basis columns,
    /// 2 local coordinates per row.
    #[inline]
    pub const fn qwen_non_gating() -> Self {
        Self { n: 2_000, p: 2_048, basis_cols: 8, d: 2 }
    }

    /// Color-arm benchmark: 180 rows, a 5 120-wide border, 9 basis columns,
    /// 2 local coordinates per row.
    #[inline]
    pub const fn color_arm() -> Self {
        Self { n: 180, p: 5_120, basis_cols: 9, d: 2 }
    }

    /// Length of `target_x`: `n * p`.
    #[inline]
    pub const fn target_len(self) -> usize {
        self.n * self.p
    }

    /// Length of the basis / gate buffers: `n * basis_cols`.
    #[inline]
    pub const fn basis_len(self) -> usize {
        self.n * self.basis_cols
    }

    /// Length of the stacked row gradients: `n * d`.
    #[inline]
    pub const fn row_gradient_len(self) -> usize {
        self.n * self.d
    }

    /// Length of the stacked `d x d` row Hessian blocks: `n * d * d`.
    #[inline]
    pub const fn row_hessian_len(self) -> usize {
        self.row_gradient_len() * self.d
    }

    /// Length of the stacked `d x p` row/border cross blocks: `n * d * p`.
    #[inline]
    pub const fn row_cross_len(self) -> usize {
        self.row_gradient_len() * self.p
    }

    /// Length of the dense `p x p` border Hessian.
    #[inline]
    pub const fn border_hessian_len(self) -> usize {
        self.p * self.p
    }
}
125
/// Flat, row-major host buffers holding the Newton-system blocks for every row plus the border.
#[derive(Clone, Debug)]
pub struct DeviceResidentArrowSlabs {
    /// `n` consecutive `d x d` row Hessian blocks.
    pub row_hessian_slabs: Vec<f64>,
    /// `n` consecutive `d x p` row/border cross blocks.
    pub row_cross_slabs: Vec<f64>,
    /// `n` consecutive length-`d` row gradients.
    pub row_gradient_slabs: Vec<f64>,
    /// Dense `p x p` border Hessian.
    pub border_hessian: Vec<f64>,
    /// Length-`p` border gradient.
    pub border_gradient: Vec<f64>,
}
140
141#[derive(Clone, Debug)]
143pub struct DeviceResidentArrowStep {
144 pub delta_t: Array1<f64>,
145 pub delta_beta: Array1<f64>,
146 pub objective: f64,
147 pub gradient_norm: f64,
148 pub log_det_hessian: f64,
149 pub execution_path: ExecutionPath,
150}
151
/// Failure categories for the device-resident arrow workspace.
#[derive(Debug, Clone)]
pub enum DeviceResidentArrowError {
    /// Buffer lengths or dimensions are inconsistent.
    Shape { reason: String },
    /// The GPU path was not admitted or is unavailable.
    Unavailable { reason: String },
    /// A factorization or solve failed.
    Solve { reason: String },
}

impl std::fmt::Display for DeviceResidentArrowError {
    /// Writes only the human-readable reason; the variant is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (Self::Shape { reason } | Self::Unavailable { reason } | Self::Solve { reason }) =
            self;
        f.write_str(reason)
    }
}

impl std::error::Error for DeviceResidentArrowError {}
170
171#[cfg(target_os = "linux")]
172pub struct DeviceResidentArrowBuffers {
173 pub stream: std::sync::Arc<cudarc::driver::CudaStream>,
174 pub target_x_dev: cudarc::driver::CudaSlice<f64>,
175 pub basis_values_dev: cudarc::driver::CudaSlice<f64>,
176 pub gate_activations_dev: cudarc::driver::CudaSlice<f64>,
177 pub row_hessian_dev: cudarc::driver::CudaSlice<f64>,
178 pub row_cross_dev: cudarc::driver::CudaSlice<f64>,
179 pub row_gradient_dev: cudarc::driver::CudaSlice<f64>,
180 pub border_hessian_dev: cudarc::driver::CudaSlice<f64>,
181 pub border_gradient_dev: cudarc::driver::CudaSlice<f64>,
182 pub bytes: usize,
183}
184
185pub struct DeviceResidentArrowWorkspace {
187 shape: DeviceResidentArrowShape,
188 target_x: Vec<f64>,
189 basis_values: Vec<f64>,
190 gate_activations: Vec<f64>,
191 slabs: DeviceResidentArrowSlabs,
192 #[cfg(target_os = "linux")]
193 device: Option<DeviceResidentArrowBuffers>,
194}
195
196impl DeviceResidentArrowWorkspace {
197 pub fn new(
198 shape: DeviceResidentArrowShape,
199 target_x: Vec<f64>,
200 basis_values: Vec<f64>,
201 gate_activations: Vec<f64>,
202 slabs: DeviceResidentArrowSlabs,
203 ) -> Result<Self, DeviceResidentArrowError> {
204 validate_shape(shape, &target_x, &basis_values, &gate_activations, &slabs)?;
205 #[cfg(target_os = "linux")]
206 let device =
207 upload_resident_buffers(shape, &target_x, &basis_values, &gate_activations, &slabs);
208 Ok(Self {
209 shape,
210 target_x,
211 basis_values,
212 gate_activations,
213 slabs,
214 #[cfg(target_os = "linux")]
215 device,
216 })
217 }
218
219 #[inline]
220 pub const fn shape(&self) -> DeviceResidentArrowShape {
221 self.shape
222 }
223
224 #[must_use]
225 pub fn device_resident(&self) -> bool {
226 #[cfg(target_os = "linux")]
227 {
228 self.device.is_some()
229 }
230 #[cfg(not(target_os = "linux"))]
231 {
232 false
233 }
234 }
235
236 #[must_use]
237 pub fn resident_device_bytes(&self) -> usize {
238 #[cfg(target_os = "linux")]
239 {
240 self.device.as_ref().map_or(0, |device| device.bytes)
241 }
242 #[cfg(not(target_os = "linux"))]
243 {
244 0
245 }
246 }
247
248 #[must_use]
253 fn context_id(&self) -> usize {
254 usize::from(self.device_resident())
255 }
256
257 #[must_use]
260 fn frame_upload_bytes(&self) -> usize {
261 [
262 self.slabs.row_hessian_slabs.len(),
263 self.slabs.row_cross_slabs.len(),
264 self.slabs.row_gradient_slabs.len(),
265 self.slabs.border_hessian.len(),
266 self.slabs.border_gradient.len(),
267 ]
268 .into_iter()
269 .sum::<usize>()
270 * std::mem::size_of::<f64>()
271 }
272
273 #[must_use]
274 pub fn host_shadow_bytes(&self) -> usize {
275 [
276 self.target_x.len(),
277 self.basis_values.len(),
278 self.gate_activations.len(),
279 self.slabs.row_hessian_slabs.len(),
280 self.slabs.row_cross_slabs.len(),
281 self.slabs.row_gradient_slabs.len(),
282 self.slabs.border_hessian.len(),
283 self.slabs.border_gradient.len(),
284 ]
285 .into_iter()
286 .sum::<usize>()
287 * std::mem::size_of::<f64>()
288 }
289
290 pub fn one_inner_iteration(
293 &self,
294 ridge_t: f64,
295 ridge_beta: f64,
296 ) -> Result<DeviceResidentArrowStep, DeviceResidentArrowError> {
297 if !self.device_resident() {
298 return Err(DeviceResidentArrowError::Unavailable {
299 reason: "SAE resident inner iteration unavailable: CUDA runtime did not admit the qwen-scale row-block workload".to_string(),
300 });
301 }
302 let sys = self.to_arrow_system();
303 solve_arrow_newton_step(&sys, ridge_t, ridge_beta)
304 .map(|solution| self.finish_step(solution, ExecutionPath::GpuResidentLinearization))
305 .map_err(map_gpu_error)
306 }
307
308 pub fn cpu_reference_step(
311 &self,
312 ridge_t: f64,
313 ridge_beta: f64,
314 ) -> Result<DeviceResidentArrowStep, DeviceResidentArrowError> {
315 let sys = self.to_arrow_system();
316 solve_arrow_newton_step_dense_reference(&sys, ridge_t, ridge_beta)
317 .map(|solution| self.finish_step(solution, ExecutionPath::Cpu))
318 .map_err(|reason| DeviceResidentArrowError::Solve { reason })
319 }
320
321 pub fn to_arrow_system(&self) -> ArrowSchurSystem {
322 let shape = self.shape;
323 let mut sys = ArrowSchurSystem::new(shape.n, shape.d, shape.p);
324 for i in 0..shape.n {
325 let h_base = i * shape.d * shape.d;
326 let b_base = i * shape.d * shape.p;
327 let g_base = i * shape.d;
328 for r in 0..shape.d {
329 for c in 0..shape.d {
330 sys.rows[i].htt[[r, c]] =
331 self.slabs.row_hessian_slabs[h_base + r * shape.d + c];
332 }
333 sys.rows[i].gt[r] = self.slabs.row_gradient_slabs[g_base + r];
334 for c in 0..shape.p {
335 sys.rows[i].htbeta[[r, c]] =
336 self.slabs.row_cross_slabs[b_base + r * shape.p + c];
337 }
338 }
339 }
340 for r in 0..shape.p {
341 sys.gb[r] = self.slabs.border_gradient[r];
342 for c in 0..shape.p {
343 sys.hbb[[r, c]] = self.slabs.border_hessian[r * shape.p + c];
344 }
345 }
346 sys.refresh_row_hessian_fingerprint();
347 sys
348 }
349
350 fn finish_step(
351 &self,
352 solution: crate::gpu_kernels::arrow_schur::ArrowSchurGpuSolution,
353 execution_path: ExecutionPath,
354 ) -> DeviceResidentArrowStep {
355 DeviceResidentArrowStep {
356 delta_t: solution.delta_t,
357 delta_beta: solution.delta_beta,
358 objective: 0.5 * squared_norm(&self.target_x),
359 gradient_norm: self.gradient_norm(),
360 log_det_hessian: solution.log_det_hessian,
361 execution_path,
362 }
363 }
364
365 fn gradient_norm(&self) -> f64 {
366 let row = squared_norm(&self.slabs.row_gradient_slabs);
367 let border = squared_norm(&self.slabs.border_gradient);
368 (row + border).sqrt()
369 }
370
371 pub fn device_fit(
400 &self,
401 opts: &DeviceResidentInnerOptions,
402 ) -> Result<DeviceResidentInnerOutcome, DeviceResidentArrowError> {
403 if !self.device_resident() {
404 return Err(DeviceResidentArrowError::Unavailable {
405 reason: "SAE resident inner loop unavailable: CUDA runtime did not admit the qwen-scale row-block workload".to_string(),
406 });
407 }
408 self.run_inner_loop(opts, InnerSolveMode::DeviceResident)
409 }
410
411 pub fn device_reupload_fit(
419 &self,
420 opts: &DeviceResidentInnerOptions,
421 ) -> Result<DeviceResidentInnerOutcome, DeviceResidentArrowError> {
422 if !self.device_resident() {
423 return Err(DeviceResidentArrowError::Unavailable {
424 reason: "SAE re-uploading inner loop unavailable: CUDA runtime did not admit the row-block workload".to_string(),
425 });
426 }
427 self.run_inner_loop(opts, InnerSolveMode::DeviceReupload)
428 }
429
430 pub fn cpu_reference_fit(
434 &self,
435 opts: &DeviceResidentInnerOptions,
436 ) -> Result<DeviceResidentInnerOutcome, DeviceResidentArrowError> {
437 self.run_inner_loop(opts, InnerSolveMode::CpuReference)
438 }
439
440 fn run_inner_loop(
441 &self,
442 opts: &DeviceResidentInnerOptions,
443 mode: InnerSolveMode,
444 ) -> Result<DeviceResidentInnerOutcome, DeviceResidentArrowError> {
445 let execution_path = mode.execution_path();
446 let n = self.shape.n;
447 let d = self.shape.d;
448 let p = self.shape.p;
449 let t_len = n * d;
450
451 let mut t = vec![0.0_f64; t_len];
455 let mut beta = vec![0.0_f64; p];
456
457 let base = self.to_arrow_system();
458 let half_target_energy = 0.5 * squared_norm(&self.target_x);
459
460 let mut ridge_t = opts.initial_ridge_t.max(0.0);
461 let mut ridge_beta = opts.initial_ridge_beta.max(0.0);
462 let mut resident_frame: Option<(
471 f64,
472 f64,
473 crate::gpu_kernels::arrow_schur::ResidentArrowFrameHandle,
474 )> = None;
475 let mut current_objective = self.objective_at(&base, half_target_energy, &t, &beta);
476 let mut accepted_iters = 0_usize;
477 let mut total_iters = 0_usize;
478 let mut converged = false;
479 let mut last_step = DeviceResidentArrowStep {
480 delta_t: Array1::zeros(t_len),
481 delta_beta: Array1::zeros(p),
482 objective: current_objective,
483 gradient_norm: 0.0,
484 log_det_hessian: 0.0,
485 execution_path,
486 };
487
488 while total_iters < opts.max_iterations {
489 let residual = self.residual_system(&base, &t, &beta);
491 let g_norm = arrow_system_gradient_norm(&residual);
492 let scale = 1.0 + iterate_norm(&t, &beta);
493 if g_norm / scale < opts.convergence_tolerance {
494 converged = true;
495 break;
496 }
497
498 let solution = match mode {
499 InnerSolveMode::DeviceResident => {
500 let frame_matches = resident_frame
505 .as_ref()
506 .is_some_and(|(rt, rb, _)| *rt == ridge_t && *rb == ridge_beta);
507 let mut frame_build_error: Option<DeviceResidentArrowError> = None;
508 if !frame_matches {
509 resident_frame = None;
510 match crate::gpu_kernels::arrow_schur::ResidentArrowFrameHandle::new(
511 &residual, ridge_t, ridge_beta,
512 ) {
513 Ok(frame) => {
514 gam_gpu::profile::telemetry_record_handle_creation(
520 self.context_id(),
521 );
522 gam_gpu::profile::telemetry_record_factorization();
523 gam_gpu::profile::telemetry_record_h2d(
524 self.frame_upload_bytes(),
525 );
526 resident_frame = Some((ridge_t, ridge_beta, frame));
527 }
528 Err(err) => frame_build_error = Some(map_gpu_error(err)),
529 }
530 }
531 match resident_frame.as_ref() {
532 Some((_, _, frame)) => {
533 let mut g_t = Vec::with_capacity(n * d);
536 for row in &residual.rows {
537 for &v in row.gt.iter() {
538 g_t.push(v);
539 }
540 }
541 let g_beta: Vec<f64> = residual.gb.iter().copied().collect();
542 let grad_bytes =
546 (g_t.len() + g_beta.len()) * std::mem::size_of::<f64>();
547 gam_gpu::profile::telemetry_record_h2d(grad_bytes);
548 gam_gpu::profile::telemetry_record_kernel_launch();
549 gam_gpu::profile::telemetry_record_d2h(
550 (n * d + p) * std::mem::size_of::<f64>(),
551 );
552 frame.solve_gradient(&g_t, &g_beta).map_err(map_gpu_error)
553 }
554 None => Err(frame_build_error.unwrap_or_else(|| {
555 DeviceResidentArrowError::Solve {
556 reason: "SAE resident frame build declined".to_string(),
557 }
558 })),
559 }
560 }
561 InnerSolveMode::DeviceReupload => {
562 gam_gpu::profile::telemetry_record_handle_creation(self.context_id());
568 gam_gpu::profile::telemetry_record_factorization();
569 gam_gpu::profile::telemetry_record_h2d(self.frame_upload_bytes());
570 gam_gpu::profile::telemetry_record_kernel_launch();
571 gam_gpu::profile::telemetry_record_d2h(
572 (n * d + p) * std::mem::size_of::<f64>(),
573 );
574 solve_arrow_newton_step(&residual, ridge_t, ridge_beta).map_err(map_gpu_error)
575 }
576 InnerSolveMode::CpuReference => {
577 solve_arrow_newton_step_dense_reference(&residual, ridge_t, ridge_beta)
578 .map_err(|reason| DeviceResidentArrowError::Solve { reason })
579 }
580 };
581
582 let solution = match solution {
583 Ok(sol) => sol,
584 Err(DeviceResidentArrowError::Solve { .. })
585 | Err(DeviceResidentArrowError::Unavailable { .. }) => {
586 ridge_t = grow_ridge(ridge_t, opts.lm_grow);
590 ridge_beta = grow_ridge(ridge_beta, opts.lm_grow);
591 if ridge_t > opts.max_ridge || ridge_beta > opts.max_ridge {
592 return Err(DeviceResidentArrowError::Solve {
593 reason: format!(
594 "SAE resident inner loop: LM ridge exceeded max ({:e}) at iter {total_iters}",
595 opts.max_ridge
596 ),
597 });
598 }
599 total_iters += 1;
600 continue;
601 }
602 Err(other) => return Err(other),
603 };
604
605 let predicted_reduction =
608 crate::arrow_schur::arrow_bare_quadratic_model_reduction(
609 &residual,
610 solution.delta_t.view(),
611 solution.delta_beta.view(),
612 ridge_t,
613 ridge_beta,
614 )
615 .map_err(|err| DeviceResidentArrowError::Solve {
616 reason: format!("SAE resident inner loop predicted-reduction failed: {err}"),
617 })?;
618
619 let mut trial_t = t.clone();
621 let mut trial_beta = beta.clone();
622 for (slot, dv) in trial_t.iter_mut().zip(solution.delta_t.iter()) {
623 *slot += *dv;
624 }
625 for (slot, dv) in trial_beta.iter_mut().zip(solution.delta_beta.iter()) {
626 *slot += *dv;
627 }
628 let trial_objective =
629 self.objective_at(&base, half_target_energy, &trial_t, &trial_beta);
630
631 let objective_scale = current_objective.abs();
643 let noise_floor = objective_scale * 1e-14;
644 let actual_reduction = current_objective - trial_objective;
645 let rho = if predicted_reduction > noise_floor {
646 actual_reduction / predicted_reduction
647 } else if actual_reduction >= -noise_floor {
648 1.0
649 } else {
650 -1.0
651 };
652
653 if rho > 0.0 && trial_objective.is_finite() {
654 t = trial_t;
655 beta = trial_beta;
656 current_objective = trial_objective;
657 ridge_t = (ridge_t * opts.lm_shrink).max(0.0);
658 ridge_beta = (ridge_beta * opts.lm_shrink).max(0.0);
659 last_step = DeviceResidentArrowStep {
660 delta_t: solution.delta_t,
661 delta_beta: solution.delta_beta,
662 objective: current_objective,
663 gradient_norm: g_norm,
664 log_det_hessian: solution.log_det_hessian,
665 execution_path,
666 };
667 accepted_iters += 1;
668 total_iters += 1;
669 } else {
670 ridge_t = grow_ridge(ridge_t, opts.lm_grow);
671 ridge_beta = grow_ridge(ridge_beta, opts.lm_grow);
672 if ridge_t > opts.max_ridge || ridge_beta > opts.max_ridge {
673 return Err(DeviceResidentArrowError::Solve {
674 reason: format!(
675 "SAE resident inner loop: LM rejected step until ridge exceeded max ({:e}) at iter {total_iters} (rho={rho:.3e})",
676 opts.max_ridge
677 ),
678 });
679 }
680 total_iters += 1;
681 }
682 }
683
684 Ok(DeviceResidentInnerOutcome {
685 t: Array1::from_vec(t),
686 beta: Array1::from_vec(beta),
687 objective: current_objective,
688 gradient_norm: last_step.gradient_norm,
689 log_det_hessian: last_step.log_det_hessian,
690 iterations: total_iters,
691 accepted_iterations: accepted_iters,
692 converged,
693 execution_path,
694 })
695 }
696
697 pub fn device_fit_outer_sequence(
733 &self,
734 base_gradient_overrides: &[(Vec<f64>, Vec<f64>)],
735 opts: &DeviceResidentInnerOptions,
736 ) -> Result<OuterSequenceOutcome, DeviceResidentArrowError> {
737 if !self.device_resident() {
738 return Err(DeviceResidentArrowError::Unavailable {
739 reason: "SAE outer-sequence residency unavailable: CUDA runtime did not admit the row-block workload".to_string(),
740 });
741 }
742 self.run_outer_sequence(
743 base_gradient_overrides,
744 opts,
745 InnerSolveMode::DeviceResident,
746 )
747 }
748
749 pub fn cpu_reference_outer_sequence(
754 &self,
755 base_gradient_overrides: &[(Vec<f64>, Vec<f64>)],
756 opts: &DeviceResidentInnerOptions,
757 ) -> Result<OuterSequenceOutcome, DeviceResidentArrowError> {
758 self.run_outer_sequence(base_gradient_overrides, opts, InnerSolveMode::CpuReference)
759 }
760
761 fn run_outer_sequence(
762 &self,
763 base_gradient_overrides: &[(Vec<f64>, Vec<f64>)],
764 opts: &DeviceResidentInnerOptions,
765 mode: InnerSolveMode,
766 ) -> Result<OuterSequenceOutcome, DeviceResidentArrowError> {
767 let n = self.shape.n;
768 let d = self.shape.d;
769 let p = self.shape.p;
770 let t_len = n * d;
771 let half_target_energy = 0.5 * squared_norm(&self.target_x);
772
773 let mut shared = SharedFrameState::default();
780 let mut outcomes = Vec::with_capacity(base_gradient_overrides.len());
781
782 for (g_t_override, g_beta_override) in base_gradient_overrides {
783 if g_t_override.len() != t_len || g_beta_override.len() != p {
784 return Err(DeviceResidentArrowError::Shape {
785 reason: format!(
786 "outer-sequence gradient shape mismatch: g_t={} (want {t_len}), g_beta={} (want {p})",
787 g_t_override.len(),
788 g_beta_override.len()
789 ),
790 });
791 }
792 let mut base = self.to_arrow_system();
795 for (i, row) in base.rows.iter_mut().enumerate() {
796 for r in 0..d {
797 row.gt[r] = g_t_override[i * d + r];
798 }
799 }
800 for (j, gb) in base.gb.iter_mut().enumerate() {
801 *gb = g_beta_override[j];
802 }
803 base.refresh_row_hessian_fingerprint();
804
805 let outcome = self.run_one_outer(&base, half_target_energy, opts, mode, &mut shared)?;
806 outcomes.push(outcome);
807 }
808
809 Ok(OuterSequenceOutcome {
810 outers: outcomes,
811 frame_builds: shared.frame_builds,
812 })
813 }
814
815 fn run_one_outer(
822 &self,
823 base: &ArrowSchurSystem,
824 half_target_energy: f64,
825 opts: &DeviceResidentInnerOptions,
826 mode: InnerSolveMode,
827 shared: &mut SharedFrameState,
828 ) -> Result<DeviceResidentInnerOutcome, DeviceResidentArrowError> {
829 let execution_path = mode.execution_path();
830 let n = self.shape.n;
831 let d = self.shape.d;
832 let p = self.shape.p;
833 let t_len = n * d;
834
835 let mut t = vec![0.0_f64; t_len];
836 let mut beta = vec![0.0_f64; p];
837 let mut ridge_t = opts.initial_ridge_t.max(0.0);
838 let mut ridge_beta = opts.initial_ridge_beta.max(0.0);
839 let mut current_objective = self.objective_at(base, half_target_energy, &t, &beta);
840 let mut accepted_iters = 0_usize;
841 let mut total_iters = 0_usize;
842 let mut converged = false;
843 let mut last_gradient_norm = 0.0_f64;
844 let mut last_log_det = 0.0_f64;
845
846 while total_iters < opts.max_iterations {
847 let residual = self.residual_system(base, &t, &beta);
848 let g_norm = arrow_system_gradient_norm(&residual);
849 let scale = 1.0 + iterate_norm(&t, &beta);
850 if g_norm / scale < opts.convergence_tolerance {
851 converged = true;
852 break;
853 }
854
855 let solution = match mode {
856 InnerSolveMode::DeviceResident => {
857 let frame_matches = shared
858 .frame
859 .as_ref()
860 .is_some_and(|(rt, rb, _)| *rt == ridge_t && *rb == ridge_beta);
861 let mut frame_build_error: Option<DeviceResidentArrowError> = None;
862 if !frame_matches {
863 shared.frame = None;
864 match crate::gpu_kernels::arrow_schur::ResidentArrowFrameHandle::new(
865 &residual, ridge_t, ridge_beta,
866 ) {
867 Ok(frame) => {
868 shared.frame_builds += 1;
869 gam_gpu::profile::telemetry_record_handle_creation(
870 self.context_id(),
871 );
872 gam_gpu::profile::telemetry_record_factorization();
873 gam_gpu::profile::telemetry_record_h2d(
874 self.frame_upload_bytes(),
875 );
876 shared.frame = Some((ridge_t, ridge_beta, frame));
877 }
878 Err(err) => frame_build_error = Some(map_gpu_error(err)),
879 }
880 }
881 match shared.frame.as_ref() {
882 Some((_, _, frame)) => {
883 let mut g_t = Vec::with_capacity(n * d);
884 for row in &residual.rows {
885 for &v in row.gt.iter() {
886 g_t.push(v);
887 }
888 }
889 let g_beta: Vec<f64> = residual.gb.iter().copied().collect();
890 let grad_bytes =
891 (g_t.len() + g_beta.len()) * std::mem::size_of::<f64>();
892 gam_gpu::profile::telemetry_record_h2d(grad_bytes);
893 gam_gpu::profile::telemetry_record_kernel_launch();
894 gam_gpu::profile::telemetry_record_d2h(
895 (n * d + p) * std::mem::size_of::<f64>(),
896 );
897 frame.solve_gradient(&g_t, &g_beta).map_err(map_gpu_error)
898 }
899 None => Err(frame_build_error.unwrap_or_else(|| {
900 DeviceResidentArrowError::Solve {
901 reason: "SAE resident frame build declined".to_string(),
902 }
903 })),
904 }
905 }
906 InnerSolveMode::DeviceReupload => {
907 solve_arrow_newton_step(&residual, ridge_t, ridge_beta).map_err(map_gpu_error)
908 }
909 InnerSolveMode::CpuReference => {
910 solve_arrow_newton_step_dense_reference(&residual, ridge_t, ridge_beta)
911 .map_err(|reason| DeviceResidentArrowError::Solve { reason })
912 }
913 };
914
915 let solution = match solution {
916 Ok(sol) => sol,
917 Err(DeviceResidentArrowError::Solve { .. })
918 | Err(DeviceResidentArrowError::Unavailable { .. }) => {
919 ridge_t = grow_ridge(ridge_t, opts.lm_grow);
920 ridge_beta = grow_ridge(ridge_beta, opts.lm_grow);
921 if ridge_t > opts.max_ridge || ridge_beta > opts.max_ridge {
922 return Err(DeviceResidentArrowError::Solve {
923 reason: format!(
924 "SAE outer-sequence inner loop: LM ridge exceeded max ({:e}) at iter {total_iters}",
925 opts.max_ridge
926 ),
927 });
928 }
929 total_iters += 1;
930 continue;
931 }
932 Err(other) => return Err(other),
933 };
934
935 let predicted_reduction =
936 crate::arrow_schur::arrow_bare_quadratic_model_reduction(
937 &residual,
938 solution.delta_t.view(),
939 solution.delta_beta.view(),
940 ridge_t,
941 ridge_beta,
942 )
943 .map_err(|err| DeviceResidentArrowError::Solve {
944 reason: format!("SAE outer-sequence predicted-reduction failed: {err}"),
945 })?;
946
947 let mut trial_t = t.clone();
948 let mut trial_beta = beta.clone();
949 for (slot, dv) in trial_t.iter_mut().zip(solution.delta_t.iter()) {
950 *slot += *dv;
951 }
952 for (slot, dv) in trial_beta.iter_mut().zip(solution.delta_beta.iter()) {
953 *slot += *dv;
954 }
955 let trial_objective =
956 self.objective_at(base, half_target_energy, &trial_t, &trial_beta);
957
958 let objective_scale = current_objective.abs();
959 let noise_floor = objective_scale * 1e-14;
960 let actual_reduction = current_objective - trial_objective;
961 let rho = if predicted_reduction > noise_floor {
962 actual_reduction / predicted_reduction
963 } else if actual_reduction >= -noise_floor {
964 1.0
965 } else {
966 -1.0
967 };
968
969 if rho > 0.0 && trial_objective.is_finite() {
970 t = trial_t;
971 beta = trial_beta;
972 current_objective = trial_objective;
973 ridge_t = (ridge_t * opts.lm_shrink).max(0.0);
974 ridge_beta = (ridge_beta * opts.lm_shrink).max(0.0);
975 last_gradient_norm = g_norm;
976 last_log_det = solution.log_det_hessian;
977 accepted_iters += 1;
978 total_iters += 1;
979 } else {
980 ridge_t = grow_ridge(ridge_t, opts.lm_grow);
981 ridge_beta = grow_ridge(ridge_beta, opts.lm_grow);
982 if ridge_t > opts.max_ridge || ridge_beta > opts.max_ridge {
983 return Err(DeviceResidentArrowError::Solve {
984 reason: format!(
985 "SAE outer-sequence inner loop: LM rejected step until ridge exceeded max ({:e}) at iter {total_iters} (rho={rho:.3e})",
986 opts.max_ridge
987 ),
988 });
989 }
990 total_iters += 1;
991 }
992 }
993
994 Ok(DeviceResidentInnerOutcome {
995 t: Array1::from_vec(t),
996 beta: Array1::from_vec(beta),
997 objective: current_objective,
998 gradient_norm: last_gradient_norm,
999 log_det_hessian: last_log_det,
1000 iterations: total_iters,
1001 accepted_iterations: accepted_iters,
1002 converged,
1003 execution_path,
1004 })
1005 }
1006
1007 fn objective_at(
1014 &self,
1015 base: &ArrowSchurSystem,
1016 half_target_energy: f64,
1017 t: &[f64],
1018 beta: &[f64],
1019 ) -> f64 {
1020 let n = self.shape.n;
1021 let d = self.shape.d;
1022 let p = self.shape.p;
1023 let mut quad = 0.0_f64;
1025 let mut lin = 0.0_f64;
1026 for i in 0..n {
1029 let t_base = i * d;
1030 for r in 0..d {
1031 let mut htt_t = 0.0_f64;
1033 for c in 0..d {
1034 htt_t += base.rows[i].htt[[r, c]] * t[t_base + c];
1035 }
1036 let mut htb_b = 0.0_f64;
1038 for c in 0..p {
1039 htb_b += base.rows[i].htbeta[[r, c]] * beta[c];
1040 }
1041 quad += t[t_base + r] * (htt_t + 2.0 * htb_b);
1042 lin += base.rows[i].gt[r] * t[t_base + r];
1043 }
1044 }
1045 for r in 0..p {
1047 let mut hbb_b = 0.0_f64;
1048 for c in 0..p {
1049 hbb_b += base.hbb[[r, c]] * beta[c];
1050 }
1051 quad += beta[r] * hbb_b;
1052 lin += base.gb[r] * beta[r];
1053 }
1054 half_target_energy + 0.5 * quad - lin
1055 }
1056
1057 fn residual_system(
1062 &self,
1063 base: &ArrowSchurSystem,
1064 t: &[f64],
1065 beta: &[f64],
1066 ) -> ArrowSchurSystem {
1067 let n = self.shape.n;
1068 let d = self.shape.d;
1069 let p = self.shape.p;
1070 let mut sys = self.to_arrow_system();
1078 for i in 0..n {
1079 let t_base = i * d;
1080 for r in 0..d {
1081 let mut hz = 0.0_f64;
1082 for c in 0..d {
1083 hz += base.rows[i].htt[[r, c]] * t[t_base + c];
1084 }
1085 for c in 0..p {
1086 hz += base.rows[i].htbeta[[r, c]] * beta[c];
1087 }
1088 sys.rows[i].gt[r] = hz - base.rows[i].gt[r];
1089 }
1090 }
1091 for r in 0..p {
1092 let mut hz = 0.0_f64;
1093 for c in 0..p {
1095 hz += base.hbb[[r, c]] * beta[c];
1096 }
1097 for i in 0..n {
1099 let t_base = i * d;
1100 for rr in 0..d {
1101 hz += base.rows[i].htbeta[[rr, r]] * t[t_base + rr];
1102 }
1103 }
1104 sys.gb[r] = hz - base.gb[r];
1105 }
1106 sys.refresh_row_hessian_fingerprint();
1107 sys
1108 }
1109}
1110
/// Tuning knobs for the Levenberg–Marquardt inner loop.
#[derive(Clone, Copy, Debug)]
pub struct DeviceResidentInnerOptions {
    /// Cap on total iterations, accepted or rejected.
    pub max_iterations: usize,
    /// Stop once `||g|| / (1 + ||z||)` falls below this value.
    pub convergence_tolerance: f64,
    /// Starting ridge on the row-block coordinates (negative values are clamped to 0).
    pub initial_ridge_t: f64,
    /// Starting ridge on the border coefficients (negative values are clamped to 0).
    pub initial_ridge_beta: f64,
    /// Multiplier applied to a non-zero ridge after a rejected or failed step.
    pub lm_grow: f64,
    /// Multiplier applied to the ridges after an accepted step.
    pub lm_shrink: f64,
    /// The loop errors out once either ridge exceeds this value.
    pub max_ridge: f64,
}

impl Default for DeviceResidentInnerOptions {
    fn default() -> Self {
        Self {
            max_ridge: 1e9,
            lm_shrink: 0.5,
            lm_grow: 4.0,
            initial_ridge_beta: 0.0,
            initial_ridge_t: 0.0,
            convergence_tolerance: 1e-9,
            max_iterations: 16,
        }
    }
}
1138
1139#[derive(Clone, Debug)]
1141pub struct DeviceResidentInnerOutcome {
1142 pub t: Array1<f64>,
1143 pub beta: Array1<f64>,
1144 pub objective: f64,
1145 pub gradient_norm: f64,
1146 pub log_det_hessian: f64,
1147 pub iterations: usize,
1148 pub accepted_iterations: usize,
1149 pub converged: bool,
1150 pub execution_path: ExecutionPath,
1151}
1152
1153#[derive(Clone, Debug)]
1164pub struct OuterSequenceOutcome {
1165 pub outers: Vec<DeviceResidentInnerOutcome>,
1166 pub frame_builds: usize,
1167}
1168
1169#[derive(Default)]
1174struct SharedFrameState {
1175 frame: Option<(
1176 f64,
1177 f64,
1178 crate::gpu_kernels::arrow_schur::ResidentArrowFrameHandle,
1179 )>,
1180 frame_builds: usize,
1181}
1182
/// Next LM ridge after a rejected or failed step.
///
/// A zero ridge jumps to `1e-6`; any other value is scaled by `grow`.
fn grow_ridge(current: f64, grow: f64) -> f64 {
    // Ridge used the first time damping is switched on from an undamped start.
    const FIRST_RIDGE: f64 = 1e-6;
    if current != 0.0 {
        return current * grow;
    }
    FIRST_RIDGE
}
1186
1187fn arrow_system_gradient_norm(sys: &ArrowSchurSystem) -> f64 {
1188 let mut acc = 0.0_f64;
1189 for row in &sys.rows {
1190 for &v in row.gt.iter() {
1191 acc += v * v;
1192 }
1193 }
1194 for &v in sys.gb.iter() {
1195 acc += v * v;
1196 }
1197 acc.sqrt()
1198}
1199
/// Euclidean norm of the stacked iterate `(t, beta)`.
fn iterate_norm(t: &[f64], beta: &[f64]) -> f64 {
    // Each half is summed separately, then combined, to keep the original rounding order.
    let t_sq: f64 = t.iter().map(|v| v * v).sum();
    let beta_sq: f64 = beta.iter().map(|v| v * v).sum();
    (t_sq + beta_sq).sqrt()
}
1203
1204fn validate_shape(
1205 shape: DeviceResidentArrowShape,
1206 target_x: &[f64],
1207 basis_values: &[f64],
1208 gate_activations: &[f64],
1209 slabs: &DeviceResidentArrowSlabs,
1210) -> Result<(), DeviceResidentArrowError> {
1211 let checks = [
1212 ("target_x", target_x.len(), shape.target_len()),
1213 ("basis_values", basis_values.len(), shape.basis_len()),
1214 (
1215 "gate_activations",
1216 gate_activations.len(),
1217 shape.basis_len(),
1218 ),
1219 (
1220 "row_hessian_slabs",
1221 slabs.row_hessian_slabs.len(),
1222 shape.row_hessian_len(),
1223 ),
1224 (
1225 "row_cross_slabs",
1226 slabs.row_cross_slabs.len(),
1227 shape.row_cross_len(),
1228 ),
1229 (
1230 "row_gradient_slabs",
1231 slabs.row_gradient_slabs.len(),
1232 shape.row_gradient_len(),
1233 ),
1234 (
1235 "border_hessian",
1236 slabs.border_hessian.len(),
1237 shape.border_hessian_len(),
1238 ),
1239 ("border_gradient", slabs.border_gradient.len(), shape.p),
1240 ];
1241 for (label, got, want) in checks {
1242 if got != want {
1243 return Err(DeviceResidentArrowError::Shape {
1244 reason: format!(
1245 "SAE resident workspace shape mismatch for {label}: got {got}, expected {want}"
1246 ),
1247 });
1248 }
1249 }
1250 if shape.n == 0 || shape.p == 0 || shape.d == 0 || shape.basis_cols == 0 {
1251 return Err(DeviceResidentArrowError::Shape {
1252 reason: "SAE resident workspace requires nonzero n, p, basis_cols, and d".to_string(),
1253 });
1254 }
1255 Ok(())
1256}
1257
1258#[cfg(target_os = "linux")]
1259fn upload_resident_buffers(
1260 shape: DeviceResidentArrowShape,
1261 target_x: &[f64],
1262 basis_values: &[f64],
1263 gate_activations: &[f64],
1264 slabs: &DeviceResidentArrowSlabs,
1265) -> Option<DeviceResidentArrowBuffers> {
1266 use gam_gpu::linalg_dispatch::{DispatchOp, route_through_gpu};
1267
1268 let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf {
1269 p: shape.d,
1270 batch: shape.n,
1271 })
1272 .or_else(|| {
1273 route_through_gpu(DispatchOp::Gemm {
1274 m: shape.p,
1275 n: shape.p,
1276 k: shape.n * shape.basis_cols,
1277 })
1278 })?;
1279 let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)?;
1280 let stream = ctx.new_stream().ok()?;
1281 let target_x_dev = stream.clone_htod(target_x).ok()?;
1282 let basis_values_dev = stream.clone_htod(basis_values).ok()?;
1283 let gate_activations_dev = stream.clone_htod(gate_activations).ok()?;
1284 let row_hessian_dev = stream.clone_htod(&slabs.row_hessian_slabs).ok()?;
1285 let row_cross_dev = stream.clone_htod(&slabs.row_cross_slabs).ok()?;
1286 let row_gradient_dev = stream.clone_htod(&slabs.row_gradient_slabs).ok()?;
1287 let border_hessian_dev = stream.clone_htod(&slabs.border_hessian).ok()?;
1288 let border_gradient_dev = stream.clone_htod(&slabs.border_gradient).ok()?;
1289 let bytes = [
1290 target_x.len(),
1291 basis_values.len(),
1292 gate_activations.len(),
1293 slabs.row_hessian_slabs.len(),
1294 slabs.row_cross_slabs.len(),
1295 slabs.row_gradient_slabs.len(),
1296 slabs.border_hessian.len(),
1297 slabs.border_gradient.len(),
1298 ]
1299 .into_iter()
1300 .sum::<usize>()
1301 * std::mem::size_of::<f64>();
1302 Some(DeviceResidentArrowBuffers {
1303 stream,
1304 target_x_dev,
1305 basis_values_dev,
1306 gate_activations_dev,
1307 row_hessian_dev,
1308 row_cross_dev,
1309 row_gradient_dev,
1310 border_hessian_dev,
1311 border_gradient_dev,
1312 bytes,
1313 })
1314}
1315
1316fn map_gpu_error(err: ArrowSchurGpuFailure) -> DeviceResidentArrowError {
1317 match err {
1318 ArrowSchurGpuFailure::Unavailable => DeviceResidentArrowError::Unavailable {
1319 reason: "SAE resident inner iteration unavailable after GPU admission".to_string(),
1320 },
1321 ArrowSchurGpuFailure::RidgeBumpRequired { row, bump } => DeviceResidentArrowError::Solve {
1322 reason: format!("SAE resident inner iteration row {row} requires ridge bump {bump:e}"),
1323 },
1324 ArrowSchurGpuFailure::SchurFactorFailed { reason } => {
1325 DeviceResidentArrowError::Solve { reason }
1326 }
1327 ArrowSchurGpuFailure::GpuRequiresDenseSystem {
1328 had_hbb_matvec,
1329 had_htbeta_matvec,
1330 } => DeviceResidentArrowError::Solve {
1331 reason: format!(
1332 "SAE resident inner iteration requires dense slabs; hbb_matvec={had_hbb_matvec} htbeta_matvec={had_htbeta_matvec}"
1333 ),
1334 },
1335 }
1336}
1337
/// Squared Euclidean norm of `values` (sum of squared entries); an empty slice yields zero.
fn squared_norm(values: &[f64]) -> f64 {
    values.iter().copied().map(|x| x * x).sum::<f64>()
}
1341
1342impl From<ArrowSchurError> for DeviceResidentArrowError {
1343 fn from(err: ArrowSchurError) -> Self {
1344 Self::Solve {
1345 reason: err.to_string(),
1346 }
1347 }
1348}
1349
1350pub fn qwen_non_gating_fixture() -> Result<DeviceResidentArrowWorkspace, DeviceResidentArrowError> {
1352 qwen_non_gating_fixture_seeded(0x1017_0003_D3A1_5EED)
1353}
1354
1355pub fn qwen_non_gating_fixture_seeded(
1359 seed: u64,
1360) -> Result<DeviceResidentArrowWorkspace, DeviceResidentArrowError> {
1361 fixture_for_shape_seeded(DeviceResidentArrowShape::qwen_non_gating(), seed)
1362}
1363
1364pub fn color_arm_fixture() -> Result<DeviceResidentArrowWorkspace, DeviceResidentArrowError> {
1369 fixture_for_shape_seeded(DeviceResidentArrowShape::color_arm(), 0x1017_C010_2A12_5EED)
1370}
1371
1372fn fixture_for_shape_seeded(
1377 shape: DeviceResidentArrowShape,
1378 seed: u64,
1379) -> Result<DeviceResidentArrowWorkspace, DeviceResidentArrowError> {
1380 if shape.d == 0 {
1381 return Err(DeviceResidentArrowError::Shape {
1382 reason: "fixture_for_shape_seeded requires d >= 1".to_string(),
1383 });
1384 }
1385 let d = shape.d;
1386 let mut rng = SplitMix64::new(seed);
1387 let mut target_x = vec![0.0_f64; shape.target_len()];
1388 for i in 0..shape.n {
1389 for j in 0..shape.p {
1390 let phase = ((i % 97) as f64) * 0.013 + ((j % 131) as f64) * 0.007;
1391 target_x[i * shape.p + j] = 0.02 * phase.sin() + 0.001 * rng.sample_signed();
1392 }
1393 }
1394 let mut basis_values = vec![0.0_f64; shape.basis_len()];
1395 let mut gate_activations = vec![1.0_f64; shape.basis_len()];
1396 for i in 0..shape.n {
1397 for a in 0..shape.basis_cols {
1398 let phase = ((i + 1) as f64) * ((a + 1) as f64) * 0.003;
1399 basis_values[i * shape.basis_cols + a] = phase.cos();
1400 gate_activations[i * shape.basis_cols + a] = 1.0;
1401 }
1402 }
1403 let mut row_hessian_slabs = vec![0.0_f64; shape.row_hessian_len()];
1404 let mut row_cross_slabs = vec![0.0_f64; shape.row_cross_len()];
1405 let mut row_gradient_slabs = vec![0.0_f64; shape.row_gradient_len()];
1406 for i in 0..shape.n {
1407 let mut basis_sum = 0.0_f64;
1408 for a in 0..shape.basis_cols {
1409 basis_sum +=
1410 basis_values[i * shape.basis_cols + a] * gate_activations[i * shape.basis_cols + a];
1411 }
1412 let h_base = i * d * d;
1415 for r in 0..d {
1416 for c in 0..d {
1417 let v = if r == c {
1418 3.0 + 0.01 * basis_sum.abs() + 0.1 * (r as f64)
1419 } else {
1420 0.02 * (basis_sum + (r + c) as f64).sin() / (d as f64)
1421 };
1422 row_hessian_slabs[h_base + r * d + c] = v;
1423 }
1424 }
1425 for r in 0..d {
1427 for c in 0..r {
1428 let avg = 0.5
1429 * (row_hessian_slabs[h_base + r * d + c]
1430 + row_hessian_slabs[h_base + c * d + r]);
1431 row_hessian_slabs[h_base + r * d + c] = avg;
1432 row_hessian_slabs[h_base + c * d + r] = avg;
1433 }
1434 }
1435 let b_base = i * d * shape.p;
1437 let g_base = i * d;
1438 for r in 0..d {
1439 for j in 0..shape.p {
1440 let feature = ((j % 257) as f64) * 0.011;
1441 row_cross_slabs[b_base + r * shape.p + j] =
1442 1.0e-4 * (basis_sum + r as f64).sin() * feature.cos();
1443 }
1444 row_gradient_slabs[g_base + r] = 0.01 * (basis_sum + r as f64).sin();
1445 }
1446 }
1447 let mut border_hessian = vec![0.0_f64; shape.border_hessian_len()];
1448 for r in 0..shape.p {
1449 border_hessian[r * shape.p + r] = 4.0;
1450 if r + 1 < shape.p {
1451 border_hessian[r * shape.p + r + 1] = 0.01;
1452 border_hessian[(r + 1) * shape.p + r] = 0.01;
1453 }
1454 }
1455 let mut border_gradient = vec![0.0_f64; shape.p];
1456 for j in 0..shape.p {
1457 border_gradient[j] = 0.001 * ((j % 193) as f64 * 0.017).sin();
1458 }
1459 DeviceResidentArrowWorkspace::new(
1460 shape,
1461 target_x,
1462 basis_values,
1463 gate_activations,
1464 DeviceResidentArrowSlabs {
1465 row_hessian_slabs,
1466 row_cross_slabs,
1467 row_gradient_slabs,
1468 border_hessian,
1469 border_gradient,
1470 },
1471 )
1472}
1473
1474pub struct MultiplexedFit {
1476 pub outcome: DeviceResidentInnerOutcome,
1477}
1478
1479pub fn run_resident_fits_multiplexed(
1509 workspaces: Vec<DeviceResidentArrowWorkspace>,
1510 opts: DeviceResidentInnerOptions,
1511) -> Result<Vec<Result<MultiplexedFit, DeviceResidentArrowError>>, String> {
1512 run_resident_fits_multiplexed_with(workspaces, opts, |workspace, opts| {
1513 workspace.device_fit(opts)
1514 })
1515}
1516
1517fn run_resident_fits_multiplexed_with<Run>(
1521 workspaces: Vec<DeviceResidentArrowWorkspace>,
1522 opts: DeviceResidentInnerOptions,
1523 run_one: Run,
1524) -> Result<Vec<Result<MultiplexedFit, DeviceResidentArrowError>>, String>
1525where
1526 Run: Fn(
1527 &DeviceResidentArrowWorkspace,
1528 &DeviceResidentInnerOptions,
1529 ) -> Result<DeviceResidentInnerOutcome, DeviceResidentArrowError>
1530 + Sync,
1531{
1532 let rows = crate::topology_selector::run_topology_race_parallel(
1533 workspaces,
1534 move |workspace: DeviceResidentArrowWorkspace| {
1535 run_one(&workspace, &opts).map(|outcome| MultiplexedFit { outcome })
1536 },
1537 )?;
1538 Ok(rows.into_iter().map(|row| row.result).collect())
1539}
1540
1541pub fn run_resident_fits_sequential(
1546 workspaces: &[DeviceResidentArrowWorkspace],
1547 opts: &DeviceResidentInnerOptions,
1548) -> Vec<Result<MultiplexedFit, DeviceResidentArrowError>> {
1549 workspaces
1550 .iter()
1551 .map(|workspace| {
1552 workspace
1553 .device_fit(opts)
1554 .map(|outcome| MultiplexedFit { outcome })
1555 })
1556 .collect()
1557}
1558
/// One entry of a variant sweep: the fixture shape to build and the seed that
/// makes its generated data deterministic (consumed by `build_sweep_workspaces`).
#[derive(Clone, Copy, Debug)]
pub struct SweepVariant {
    /// Shape (`n`, `p`, `basis_cols`, `d`) of the fixture workspace.
    pub dim: DeviceResidentArrowShape,
    /// Seed passed to `fixture_for_shape_seeded` for this variant.
    pub seed: u64,
}
1584
/// Wall-clock throughput summary for one batch of fits.
#[derive(Clone, Copy, Debug)]
pub struct SweepThroughput {
    /// Number of fits attempted.
    pub fits: usize,
    /// Number of fits whose result was `Ok`.
    pub succeeded: usize,
    /// Elapsed wall time of the timed batch, in seconds.
    pub wall_seconds: f64,
    /// `fits / wall_seconds`, with the denominator floored at `1e-9`.
    pub fits_per_second: f64,
}
1594
1595pub fn build_sweep_workspaces(
1600 variants: &[SweepVariant],
1601) -> Result<Vec<DeviceResidentArrowWorkspace>, DeviceResidentArrowError> {
1602 variants
1603 .iter()
1604 .map(|v| fixture_for_shape_seeded(v.dim, v.seed))
1605 .collect()
1606}
1607
1608pub fn run_variant_sweep_multiplexed(
1613 variants: &[SweepVariant],
1614 opts: DeviceResidentInnerOptions,
1615) -> Result<
1616 (
1617 Vec<Result<MultiplexedFit, DeviceResidentArrowError>>,
1618 SweepThroughput,
1619 ),
1620 String,
1621> {
1622 let workspaces = build_sweep_workspaces(variants).map_err(|e| e.to_string())?;
1623 run_battery_sweep_multiplexed(workspaces, opts)
1624}
1625
1626pub fn run_battery_sweep_multiplexed(
1637 workspaces: Vec<DeviceResidentArrowWorkspace>,
1638 opts: DeviceResidentInnerOptions,
1639) -> Result<
1640 (
1641 Vec<Result<MultiplexedFit, DeviceResidentArrowError>>,
1642 SweepThroughput,
1643 ),
1644 String,
1645> {
1646 let fits = workspaces.len();
1647 let start = std::time::Instant::now();
1648 let results = run_resident_fits_multiplexed(workspaces, opts)?;
1649 let wall_seconds = start.elapsed().as_secs_f64();
1650 let succeeded = results.iter().filter(|r| r.is_ok()).count();
1651 let throughput = SweepThroughput {
1652 fits,
1653 succeeded,
1654 wall_seconds,
1655 fits_per_second: (fits as f64) / wall_seconds.max(1e-9),
1656 };
1657 Ok((results, throughput))
1658}
1659
1660#[must_use]
1667pub fn color_arm_variant_matrix() -> Vec<SweepVariant> {
1668 let topologies = ["euclidean", "circle", "torus", "sphere"];
1669 let mut variants = Vec::with_capacity(4 * topologies.len() * 2);
1670 for k in 1..=4u64 {
1671 for (t_idx, _topology) in topologies.iter().enumerate() {
1672 for &(d, basis_cols, basis_tag) in &[(2usize, 8usize, 0u64), (1usize, 2usize, 1u64)] {
1674 let mut dim = DeviceResidentArrowShape::color_arm();
1675 dim.d = d;
1676 dim.basis_cols = basis_cols;
1677 let seed = 0x1017_C010_0000_0000 ^ (k << 16) ^ ((t_idx as u64) << 8) ^ basis_tag;
1678 variants.push(SweepVariant { dim, seed });
1679 }
1680 }
1681 }
1682 variants
1683}
1684
1685pub fn assert_sweep_parity_vs_sequential(
1692 variants: &[SweepVariant],
1693 opts: &DeviceResidentInnerOptions,
1694 multiplexed: &[Result<MultiplexedFit, DeviceResidentArrowError>],
1695) -> Result<SweepThroughput, String> {
1696 let workspaces = build_sweep_workspaces(variants).map_err(|e| e.to_string())?;
1697 let start = std::time::Instant::now();
1698 let sequential = run_resident_fits_sequential(&workspaces, opts);
1699 let wall_seconds = start.elapsed().as_secs_f64();
1700 if sequential.len() != multiplexed.len() {
1701 return Err(format!(
1702 "sweep parity: length mismatch seq={} mux={}",
1703 sequential.len(),
1704 multiplexed.len()
1705 ));
1706 }
1707 for (idx, (seq, mux)) in sequential.iter().zip(multiplexed.iter()).enumerate() {
1708 match (seq, mux) {
1709 (Ok(s), Ok(m)) => {
1710 if s.outcome.t.as_slice() != m.outcome.t.as_slice()
1711 || s.outcome.beta.as_slice() != m.outcome.beta.as_slice()
1712 || s.outcome.objective.to_bits() != m.outcome.objective.to_bits()
1713 {
1714 return Err(format!(
1715 "sweep parity: fit {idx} multiplexed result differs from sequential"
1716 ));
1717 }
1718 }
1719 (Err(_), Err(_)) => {}
1720 _ => {
1721 return Err(format!(
1722 "sweep parity: fit {idx} success/failure disagrees seq-vs-mux"
1723 ));
1724 }
1725 }
1726 }
1727 let fits = variants.len();
1728 let succeeded = sequential.iter().filter(|r| r.is_ok()).count();
1729 Ok(SweepThroughput {
1730 fits,
1731 succeeded,
1732 wall_seconds,
1733 fits_per_second: (fits as f64) / wall_seconds.max(1e-9),
1734 })
1735}
1736
1737struct SplitMix64 {
1738 state: u64,
1739}
1740
1741impl SplitMix64 {
1742 const fn new(seed: u64) -> Self {
1743 Self { state: seed }
1744 }
1745
1746 fn next_u64(&mut self) -> u64 {
1747 gam_linalg::utils::splitmix64(&mut self.state)
1748 }
1749
1750 fn sample_signed(&mut self) -> f64 {
1751 let unit = (self.next_u64() >> 11) as f64 / ((1_u64 << 53) as f64);
1752 2.0 * unit - 1.0
1753 }
1754}
1755
1756#[cfg(test)]
1757mod tests {
1758 use super::*;
1759 use ndarray::Array2;
1760
    /// Small workspace (n = 8, p = 4, basis_cols = 2, d = 2) used throughout
    /// these tests.
    ///
    /// Row blocks have diagonals near 5 and 4 with off-diagonals ≤ 0.05. The
    /// border Hessian is diagonal near 6. Cross couplings are ≤ 0.01, and the
    /// gradients are uniform draws in [-1, 1). The exact RNG draw order below
    /// defines the fixture for a given `seed`, so do not reorder the draws.
    fn small_fixture(seed: u64) -> DeviceResidentArrowWorkspace {
        let shape = DeviceResidentArrowShape {
            n: 8,
            p: 4,
            basis_cols: 2,
            d: 2,
        };
        let mut rng = SplitMix64::new(seed);
        let target_x = vec![0.0_f64; shape.target_len()];
        let basis_values = vec![0.5_f64; shape.basis_len()];
        let gate_activations = vec![1.0_f64; shape.basis_len()];

        let mut row_hessian_slabs = vec![0.0_f64; shape.row_hessian_len()];
        let mut row_cross_slabs = vec![0.0_f64; shape.row_cross_len()];
        let mut row_gradient_slabs = vec![0.0_f64; shape.row_gradient_len()];
        for i in 0..shape.n {
            // The hard-coded offsets below (h..h+3, b / b+p, g / g+1) assume d == 2.
            let h = i * shape.d * shape.d;
            row_hessian_slabs[h] = 5.0 + 0.1 * rng.sample_signed();
            row_hessian_slabs[h + 1] = 0.05 * rng.sample_signed();
            // Mirror the off-diagonal so the 2×2 block is exactly symmetric.
            row_hessian_slabs[h + 2] = row_hessian_slabs[h + 1];
            row_hessian_slabs[h + 3] = 4.0 + 0.1 * rng.sample_signed();
            let b = i * shape.d * shape.p;
            for j in 0..shape.p {
                // Cross rows 0 and 1 are drawn interleaved, per column j.
                row_cross_slabs[b + j] = 0.01 * rng.sample_signed();
                row_cross_slabs[b + shape.p + j] = 0.01 * rng.sample_signed();
            }
            let g = i * shape.d;
            row_gradient_slabs[g] = rng.sample_signed();
            row_gradient_slabs[g + 1] = rng.sample_signed();
        }
        // Diagonal-only border Hessian.
        let mut border_hessian = vec![0.0_f64; shape.border_hessian_len()];
        for r in 0..shape.p {
            border_hessian[r * shape.p + r] = 6.0 + 0.1 * rng.sample_signed();
        }
        let border_gradient: Vec<f64> = (0..shape.p).map(|_| rng.sample_signed()).collect();

        DeviceResidentArrowWorkspace::new(
            shape,
            target_x,
            basis_values,
            gate_activations,
            DeviceResidentArrowSlabs {
                row_hessian_slabs,
                row_cross_slabs,
                row_gradient_slabs,
                border_hessian,
                border_gradient,
            },
        )
        .expect("small resident fixture must validate")
    }
1822
1823 fn dense_hz(
1826 ws: &DeviceResidentArrowWorkspace,
1827 sys: &ArrowSchurSystem,
1828 ) -> (Array2<f64>, Array1<f64>) {
1829 let shape = ws.shape;
1830 let total = shape.n * shape.d + shape.p;
1831 let mut h = Array2::<f64>::zeros((total, total));
1832 let mut g0 = Array1::<f64>::zeros(total);
1833 for i in 0..shape.n {
1834 let base = i * shape.d;
1835 for r in 0..shape.d {
1836 for c in 0..shape.d {
1837 h[[base + r, base + c]] = sys.rows[i].htt[[r, c]];
1838 }
1839 for c in 0..shape.p {
1840 let v = sys.rows[i].htbeta[[r, c]];
1841 h[[base + r, shape.n * shape.d + c]] = v;
1842 h[[shape.n * shape.d + c, base + r]] = v;
1843 }
1844 g0[base + r] = sys.rows[i].gt[r];
1845 }
1846 }
1847 for r in 0..shape.p {
1848 for c in 0..shape.p {
1849 h[[shape.n * shape.d + r, shape.n * shape.d + c]] = sys.hbb[[r, c]];
1850 }
1851 g0[shape.n * shape.d + r] = sys.gb[r];
1852 }
1853 (h, g0)
1854 }
1855
1856 #[test]
1857 fn cpu_inner_loop_reaches_quadratic_minimiser() {
1858 let ws = small_fixture(0xABCD_0001);
1859 let opts = DeviceResidentInnerOptions::default();
1860 let outcome = ws.cpu_reference_fit(&opts).expect("cpu fit");
1861 assert!(
1862 outcome.converged,
1863 "inner loop must converge on a PD quadratic"
1864 );
1865
1866 let base = ws.to_arrow_system();
1868 let (h, g0) = dense_hz(&ws, &base);
1869 let total = ws.shape.n * ws.shape.d + ws.shape.p;
1870 let mut z = Array1::<f64>::zeros(total);
1871 for r in 0..ws.shape.n * ws.shape.d {
1872 z[r] = outcome.t[r];
1873 }
1874 for c in 0..ws.shape.p {
1875 z[ws.shape.n * ws.shape.d + c] = outcome.beta[c];
1876 }
1877 let hz = h.dot(&z);
1878 let mut max_resid = 0.0_f64;
1879 for r in 0..total {
1880 max_resid = max_resid.max((hz[r] - g0[r]).abs());
1881 }
1882 assert!(
1883 max_resid < 1e-9,
1884 "inner loop fixed point must solve H z = g0; residual {max_resid:e}"
1885 );
1886 }
1887
1888 #[test]
1889 fn cpu_multiplex_matches_sequential_bit_identical() {
1890 let seeds = [0x11, 0x22, 0x33, 0x44, 0x55, 0x66];
1891 let opts = DeviceResidentInnerOptions::default();
1892
1893 let seq_workspaces: Vec<_> = seeds.iter().map(|&s| small_fixture(s)).collect();
1894 let sequential: Vec<_> = seq_workspaces
1895 .iter()
1896 .map(|ws| ws.cpu_reference_fit(&opts).expect("seq cpu fit"))
1897 .collect();
1898
1899 let mux_workspaces: Vec<_> = seeds.iter().map(|&s| small_fixture(s)).collect();
1900 let multiplexed = run_resident_fits_multiplexed_with(mux_workspaces, opts, |ws, opts| {
1901 ws.cpu_reference_fit(opts)
1902 })
1903 .expect("multiplexed cpu fits");
1904
1905 assert_eq!(sequential.len(), multiplexed.len());
1906 for (seq, mux) in sequential.iter().zip(multiplexed.iter()) {
1907 let mux = mux.as_ref().expect("mux fit ok");
1908 assert_eq!(seq.t.as_slice(), mux.outcome.t.as_slice());
1911 assert_eq!(seq.beta.as_slice(), mux.outcome.beta.as_slice());
1912 assert_eq!(seq.objective.to_bits(), mux.outcome.objective.to_bits());
1913 }
1914 }
1915
    /// Device paths versus the CPU reference on `small_fixture`.
    ///
    /// When the workspace is device-resident, this checks the following to a
    /// 1e-9 relative tolerance:
    ///
    /// 1. `device_fit` matches `cpu_reference_fit`, and reports `GpuResidentFull`.
    /// 2. One `ResidentArrowFrameHandle::solve_gradient` step matches the dense
    ///    reference Newton step.
    /// 3. `device_reupload_fit` matches the CPU reference, and reports `GpuReupload`.
    ///
    /// Otherwise it requires that no GPU runtime is present, and that every
    /// device entry point declines.
    #[test]
    fn device_resident_fit_matches_cpu_reference() {
        let ws = small_fixture(0x5AE_1017);
        let opts = DeviceResidentInnerOptions::default();

        let cpu = ws.cpu_reference_fit(&opts).expect("cpu reference fit");
        assert!(cpu.converged, "cpu reference must converge on PD quadratic");

        let base = ws.to_arrow_system();

        if ws.device_resident() {
            let dev = ws.device_fit(&opts).expect("device resident fit");
            assert_eq!(
                dev.execution_path,
                ExecutionPath::GpuResidentFull,
                "device_fit must report the full device-resident execution path"
            );
            assert!(dev.converged, "device resident loop must converge");

            // Scales are floored at 1.0, so errors on sub-unit entries are
            // effectively compared in absolute terms.
            let t_scale = cpu.t.iter().fold(1.0_f64, |m, &v| m.max(v.abs()));
            let b_scale = cpu.beta.iter().fold(1.0_f64, |m, &v| m.max(v.abs()));
            let mut max_rel = 0.0_f64;
            for (a, b) in dev.t.iter().zip(cpu.t.iter()) {
                max_rel = max_rel.max((a - b).abs() / t_scale);
            }
            for (a, b) in dev.beta.iter().zip(cpu.beta.iter()) {
                max_rel = max_rel.max((a - b).abs() / b_scale);
            }
            assert!(
                max_rel < 1e-9,
                "resident device fit must match CPU reference (rel {max_rel:e})"
            );

            // Single-step check: solve once through a resident frame with the
            // fixture's own gradients, and compare against the dense reference step.
            let frame = crate::gpu_kernels::arrow_schur::ResidentArrowFrameHandle::new(
                &base,
                opts.initial_ridge_t,
                opts.initial_ridge_beta,
            )
            .expect("resident frame must build on CUDA host");
            let g_t: Vec<f64> = base
                .rows
                .iter()
                .flat_map(|r| r.gt.iter().copied())
                .collect();
            let g_beta: Vec<f64> = base.gb.iter().copied().collect();
            let resident_sol = frame
                .solve_gradient(&g_t, &g_beta)
                .expect("resident single-gradient solve");
            let full = crate::gpu_kernels::arrow_schur::solve_arrow_newton_step_dense_reference(
                &base,
                opts.initial_ridge_t,
                opts.initial_ridge_beta,
            )
            .expect("dense reference single solve");
            let mut max_step_rel = 0.0_f64;
            // One shared scale over both step components, floored at 1.0.
            let step_scale = full
                .delta_t
                .iter()
                .chain(full.delta_beta.iter())
                .fold(1.0_f64, |m, &v| m.max(v.abs()));
            for (a, b) in resident_sol.delta_t.iter().zip(full.delta_t.iter()) {
                max_step_rel = max_step_rel.max((a - b).abs() / step_scale);
            }
            for (a, b) in resident_sol.delta_beta.iter().zip(full.delta_beta.iter()) {
                max_step_rel = max_step_rel.max((a - b).abs() / step_scale);
            }
            assert!(
                max_step_rel < 1e-9,
                "resident solve_gradient must match full dense reference step (rel {max_step_rel:e})"
            );

            // The re-uploading device path must agree with the CPU reference too.
            let reup = ws
                .device_reupload_fit(&opts)
                .expect("device re-uploading fit");
            assert_eq!(
                reup.execution_path,
                ExecutionPath::GpuReupload,
                "device_reupload_fit must report the re-uploading device path"
            );
            assert!(reup.converged, "re-uploading loop must converge");
            let mut max_reup_rel = 0.0_f64;
            for (a, b) in reup.t.iter().zip(cpu.t.iter()) {
                max_reup_rel = max_reup_rel.max((a - b).abs() / t_scale);
            }
            for (a, b) in reup.beta.iter().zip(cpu.beta.iter()) {
                max_reup_rel = max_reup_rel.max((a - b).abs() / b_scale);
            }
            assert!(
                max_reup_rel < 1e-9,
                "re-uploading GPU fit must match CPU reference (rel {max_reup_rel:e})"
            );
        } else {
            // Not device-resident: this is only acceptable when no GPU runtime exists.
            assert!(
                gam_gpu::device_runtime::GpuRuntime::global().is_none(),
                "device_resident() is false on a host WITH a CUDA runtime present, \
                 despite a floor-clearing fixture (batch=8): the resident device \
                 buffers failed to bind — a real device fault, not a CPU-only skip."
            );
            let dev = ws.device_fit(&opts);
            assert!(
                matches!(dev, Err(DeviceResidentArrowError::Unavailable { .. })),
                "device_fit must report Unavailable on a CPU-only host, got {dev:?}"
            );
            let reup = ws.device_reupload_fit(&opts);
            assert!(
                matches!(reup, Err(DeviceResidentArrowError::Unavailable { .. })),
                "device_reupload_fit must report Unavailable on a CPU-only host, got {reup:?}"
            );
            let frame = crate::gpu_kernels::arrow_schur::ResidentArrowFrameHandle::new(
                &base,
                opts.initial_ridge_t,
                opts.initial_ridge_beta,
            );
            assert!(
                frame.is_err(),
                "resident frame construction must decline on a CPU-only host"
            );
        }
    }
2062
2063 #[test]
2081 fn resident_inner_solve_matches_production_arrow_core() {
2082 use crate::arrow_schur::{ArrowSolveOptions, solve_arrow_newton_step_core};
2083
2084 let ws = small_fixture(0x1017_F17);
2085 let opts = DeviceResidentInnerOptions::default();
2086
2087 let resident = ws.cpu_reference_fit(&opts).expect("resident cpu fit");
2089 assert!(
2090 resident.converged,
2091 "resident reference must converge on the PD quadratic"
2092 );
2093
2094 let sys = ws.to_arrow_system();
2098 let (delta_t, delta_beta, _diag) = solve_arrow_newton_step_core(
2099 &sys,
2100 opts.initial_ridge_t,
2101 opts.initial_ridge_beta,
2102 &ArrowSolveOptions::direct(),
2103 )
2104 .expect("production arrow-core solve");
2105
2106 let t_scale = resident.t.iter().fold(1.0_f64, |m, &v| m.max(v.abs()));
2111 let b_scale = resident.beta.iter().fold(1.0_f64, |m, &v| m.max(v.abs()));
2112 let mut max_rel_t = 0.0_f64;
2118 let mut worst_t: Option<(usize, f64, f64)> = None;
2119 for (i, (prod, res)) in delta_t.iter().zip(resident.t.iter()).enumerate() {
2120 let rel = (prod + res).abs() / t_scale;
2121 if rel > max_rel_t {
2122 max_rel_t = rel;
2123 worst_t = Some((i, *prod, *res));
2124 }
2125 }
2126 let mut max_rel_b = 0.0_f64;
2127 let mut worst_b: Option<(usize, f64, f64)> = None;
2128 for (i, (prod, res)) in delta_beta.iter().zip(resident.beta.iter()).enumerate() {
2129 let rel = (prod + res).abs() / b_scale;
2130 if rel > max_rel_b {
2131 max_rel_b = rel;
2132 worst_b = Some((i, *prod, *res));
2133 }
2134 }
2135 let max_rel = max_rel_t.max(max_rel_b);
2136 assert!(
2137 max_rel < 1e-9,
2138 "production arrow-core Newton step must be −(resident converged fit) on \
2139 the same quadratic; wiring the device seam into the SAE inner loop must \
2140 not change the system being solved. rel_t={max_rel_t:e} (worst {worst_t:?}: \
2141 Δt+t* must be 0), rel_beta={max_rel_b:e} (worst {worst_b:?}: Δβ+β* must \
2142 be 0). A t-only gap implicates the per-row factor / row-gradient \
2143 assembly; a β-only gap the border Schur path."
2144 );
2145 }
2146
    /// Across-outer residency on the color_arm fixture.
    ///
    /// Three synthetic outer gradients are used. `device_fit_outer_sequence`
    /// must build the resident frame exactly once, and must match
    /// `cpu_reference_outer_sequence` per outer to 1e-9 relative. Without a
    /// device, only the CPU sequence runs and a skip message is printed.
    #[test]
    fn outer_sequence_reuses_frame_and_matches_independent() {
        let ws = super::color_arm_fixture().expect("color_arm fixture");
        let opts = DeviceResidentInnerOptions::default();
        let n = ws.shape.n;
        let d = ws.shape.d;
        let p = ws.shape.p;

        // Three deterministic (g_t, g_beta) pairs, phase-shifted by the outer index `s`.
        let outers: Vec<(Vec<f64>, Vec<f64>)> = (0..3)
            .map(|s| {
                let g_t: Vec<f64> = (0..n * d)
                    .map(|i| 0.01 * (((i + 3 * s) as f64) * 0.002).sin())
                    .collect();
                let g_beta: Vec<f64> = (0..p)
                    .map(|j| 0.001 * (((j + 11 * s) as f64) * 0.0009).cos())
                    .collect();
                (g_t, g_beta)
            })
            .collect();

        // CPU baseline: every outer is solved independently.
        let independent = ws
            .cpu_reference_outer_sequence(&outers, &opts)
            .expect("cpu reference outer sequence");
        assert_eq!(independent.outers.len(), outers.len());

        if ws.device_resident() {
            let shared = ws
                .device_fit_outer_sequence(&outers, &opts)
                .expect("device outer sequence");
            assert_eq!(
                shared.frame_builds,
                1,
                "across-outer residency must build the resident frame exactly once \
                 for an unchanged operator (got {} builds over {} outers) — a count \
                 > 1 means the frame was needlessly re-factored per outer",
                shared.frame_builds,
                outers.len()
            );
            for (idx, (sh, ind)) in shared
                .outers
                .iter()
                .zip(independent.outers.iter())
                .enumerate()
            {
                // Per-outer scale over both t and beta, floored at 1.0.
                let scale = ind
                    .t
                    .iter()
                    .chain(ind.beta.iter())
                    .fold(1.0_f64, |m, &v| m.max(v.abs()));
                let mut max_rel = 0.0_f64;
                for (a, b) in sh.t.iter().zip(ind.t.iter()) {
                    max_rel = max_rel.max((a - b).abs() / scale);
                }
                for (a, b) in sh.beta.iter().zip(ind.beta.iter()) {
                    max_rel = max_rel.max((a - b).abs() / scale);
                }
                assert!(
                    max_rel < 1e-9,
                    "outer {idx}: across-outer-shared frame must match independent fit \
                     (rel {max_rel:e})"
                );
            }
            println!(
                "[#1017 outer-seq color_arm] outers={} frame_builds={} (across-outer factor \
                 amortized) parity<1e-9 OK",
                outers.len(),
                shared.frame_builds
            );
        } else {
            println!(
                "[#1017 outer-seq color_arm] no CUDA device — across-outer residency skipped; \
                 run on the GPU node to assert frame_builds==1 + device parity"
            );
        }
    }
2237
    /// Per-solve residency benchmark on the color_arm and qwen_non_gating fixtures.
    ///
    /// `N_SOLVES` Newton steps go through one `ResidentArrowFrameHandle`, built
    /// once. They are compared against rebuilding the system and calling
    /// `solve_arrow_newton_step` for every solve. The test asserts step parity
    /// (1e-9 relative) and a minimum residency speedup. Fixtures whose workspace
    /// is not device-resident are skipped with a message.
    #[test]
    fn gpu_residency_per_solve_bench() {
        use std::time::Instant;
        const N_SOLVES: usize = 24;
        for (label, ws) in [
            ("color_arm", super::color_arm_fixture()),
            ("qwen_non_gating", super::qwen_non_gating_fixture()),
        ] {
            let ws = ws.expect("bench fixture must validate");
            let base = ws.to_arrow_system();
            let n = ws.shape.n;
            let d = ws.shape.d;
            let p = ws.shape.p;
            // N_SOLVES distinct deterministic gradients (shifted by `s`). Only the
            // right-hand side changes between solves; the operator is `base`'s.
            let gradients: Vec<(Vec<f64>, Vec<f64>)> = (0..N_SOLVES)
                .map(|s| {
                    let g_t: Vec<f64> =
                        (0..n * d).map(|i| ((i + s) as f64 * 0.001).sin()).collect();
                    let g_beta: Vec<f64> = (0..p)
                        .map(|j| ((j + 7 * s) as f64 * 0.0007).cos())
                        .collect();
                    (g_t, g_beta)
                })
                .collect();

            if !ws.device_resident() {
                println!(
                    "[#1017 per-solve {label}] no CUDA device — {N_SOLVES} solves skipped; \
                     run on the GPU node for the across-iteration residency speedup"
                );
                continue;
            }

            // Frame construction is timed separately and excluded from the
            // per-solve comparison below.
            let t_build = Instant::now();
            let frame =
                crate::gpu_kernels::arrow_schur::ResidentArrowFrameHandle::new(&base, 0.0, 0.0)
                    .expect("resident frame must build on CUDA host");
            let frame_build_ms = t_build.elapsed().as_secs_f64() * 1e3;

            // Untimed warm-up of both paths, so one-time first-call costs land
            // outside the timed loops.
            let _ = frame
                .solve_gradient(&gradients[0].0, &gradients[0].1)
                .expect("resident warm-up solve");
            {
                let mut sys = ws.to_arrow_system();
                for (i, row) in sys.rows.iter_mut().enumerate() {
                    for r in 0..d {
                        row.gt[r] = gradients[0].0[i * d + r];
                    }
                }
                for (j, gb) in sys.gb.iter_mut().enumerate() {
                    *gb = gradients[0].1[j];
                }
                sys.refresh_row_hessian_fingerprint();
                let _ = crate::gpu_kernels::arrow_schur::solve_arrow_newton_step(&sys, 0.0, 0.0)
                    .expect("reupload warm-up solve");
            }

            // Timed: resident path, one gradient solve per step through the shared frame.
            let t_res = Instant::now();
            let mut resident_steps = Vec::with_capacity(N_SOLVES);
            for (g_t, g_beta) in &gradients {
                resident_steps.push(
                    frame
                        .solve_gradient(g_t, g_beta)
                        .expect("resident solve_gradient"),
                );
            }
            let resident_ms = t_res.elapsed().as_secs_f64() * 1e3;

            // Timed: reupload path, which rebuilds the full system with the new
            // gradients for every step.
            // NOTE(review): `ws.to_arrow_system()` (host assembly, including the
            // dense p×p border) runs inside this timed window, while the resident
            // loop never pays it. This may inflate `residency_speedup` beyond pure
            // factor/upload savings — confirm this is the intended comparison.
            let t_reup = Instant::now();
            let mut reupload_steps = Vec::with_capacity(N_SOLVES);
            for (g_t, g_beta) in &gradients {
                let mut sys = ws.to_arrow_system();
                for (i, row) in sys.rows.iter_mut().enumerate() {
                    for r in 0..d {
                        row.gt[r] = g_t[i * d + r];
                    }
                }
                for (j, gb) in sys.gb.iter_mut().enumerate() {
                    *gb = g_beta[j];
                }
                sys.refresh_row_hessian_fingerprint();
                reupload_steps.push(
                    crate::gpu_kernels::arrow_schur::solve_arrow_newton_step(&sys, 0.0, 0.0)
                        .expect("reupload solve_arrow_newton_step"),
                );
            }
            let reupload_ms = t_reup.elapsed().as_secs_f64() * 1e3;

            // Step parity: per-solve scale over both components, floored at 1.0.
            let mut max_rel = 0.0_f64;
            for (rs, us) in resident_steps.iter().zip(reupload_steps.iter()) {
                let scale = us
                    .delta_t
                    .iter()
                    .chain(us.delta_beta.iter())
                    .fold(1.0_f64, |m, &v| m.max(v.abs()));
                for (a, b) in rs.delta_t.iter().zip(us.delta_t.iter()) {
                    max_rel = max_rel.max((a - b).abs() / scale);
                }
                for (a, b) in rs.delta_beta.iter().zip(us.delta_beta.iter()) {
                    max_rel = max_rel.max((a - b).abs() / scale);
                }
            }

            let resident_per_solve = resident_ms / N_SOLVES as f64;
            let reupload_per_solve = reupload_ms / N_SOLVES as f64;
            let residency_speedup = reupload_ms / resident_ms.max(1e-9);
            println!(
                "[#1017 per-solve {label}] N={N_SOLVES} frame_build={frame_build_ms:.2}ms \
                 resident={resident_ms:.2}ms ({resident_per_solve:.3}ms/solve, \
                 grad-upload + warm factors) reupload={reupload_ms:.2}ms \
                 ({reupload_per_solve:.3}ms/solve, N factors + N D/B uploads) \
                 residency_speedup={residency_speedup:.2}x parity_rel={max_rel:e}"
            );
            assert!(
                max_rel < 1e-9,
                "{label}: resident per-solve steps must match reupload (rel {max_rel:e})"
            );

            // color_arm must beat re-upload by more than 1.5x; every other
            // fixture only has to beat it at all.
            let min_speedup = if label == "color_arm" { 1.5 } else { 1.0 };
            assert!(
                residency_speedup > min_speedup,
                "{label}: across-iteration residency must beat per-solve re-upload \
                 (residency_speedup={residency_speedup:.3}x, required >{min_speedup}x; \
                 resident {resident_per_solve:.3}ms/solve vs reupload \
                 {reupload_per_solve:.3}ms/solve over N={N_SOLVES} solves) — the resident \
                 frame either silently re-uploaded D/B or the dispatch dropped the \
                 amortized factor path"
            );
        }
    }
2422
2423 fn battery_variant_matrix() -> Vec<super::SweepVariant> {
2428 let mut variants = Vec::new();
2429 for k in 1..=4u64 {
2432 for basis_cols in [4usize, 8, 12] {
2433 let mut dim = DeviceResidentArrowShape::color_arm();
2434 dim.basis_cols = basis_cols;
2435 variants.push(super::SweepVariant {
2436 dim,
2437 seed: 0x1017_0040_0000_0000 ^ (k << 8) ^ (basis_cols as u64),
2438 });
2439 }
2440 }
2441 variants
2442 }
2443
2444 #[test]
2448 fn variant_sweep_multiplex_matches_sequential() {
2449 let variants = battery_variant_matrix();
2450 let opts = DeviceResidentInnerOptions::default();
2451
2452 let workspaces =
2455 super::build_sweep_workspaces(&variants).expect("sweep workspaces must build");
2456 let multiplexed =
2457 super::run_resident_fits_multiplexed_with(workspaces, opts, |ws, opts| {
2458 ws.cpu_reference_fit(opts)
2459 })
2460 .expect("multiplexed cpu sweep");
2461
2462 let seq_workspaces =
2463 super::build_sweep_workspaces(&variants).expect("sweep workspaces must build");
2464 let sequential: Vec<_> = seq_workspaces
2465 .iter()
2466 .map(|ws| ws.cpu_reference_fit(&opts))
2467 .collect();
2468
2469 assert_eq!(multiplexed.len(), sequential.len());
2470 for (idx, (mux, seq)) in multiplexed.iter().zip(sequential.iter()).enumerate() {
2471 let mux = &mux.as_ref().unwrap().outcome;
2472 let seq = seq.as_ref().unwrap();
2473 assert_eq!(
2474 mux.t.as_slice(),
2475 seq.t.as_slice(),
2476 "variant {idx}: multiplexed t differs from sequential"
2477 );
2478 assert_eq!(
2479 mux.beta.as_slice(),
2480 seq.beta.as_slice(),
2481 "variant {idx}: multiplexed beta differs from sequential"
2482 );
2483 assert_eq!(
2484 mux.objective.to_bits(),
2485 seq.objective.to_bits(),
2486 "variant {idx}: multiplexed objective differs from sequential"
2487 );
2488 }
2489 }
2490
2491 #[test]
2497 fn gpu_multiplex_throughput_bench() {
2498 let variants = battery_variant_matrix();
2499 let opts = DeviceResidentInnerOptions::default();
2500
2501 let probe = super::build_sweep_workspaces(&variants).expect("sweep workspaces");
2502 let any_device = probe.iter().any(|w| w.device_resident());
2503 if !any_device {
2504 println!(
2505 "[#1017 mux-bench] no CUDA device — {} variants (K1..4 x 3 basis) \
2506 skipped; run on the GPU node for cross-fit throughput",
2507 variants.len()
2508 );
2509 return;
2510 }
2511
2512 let (results, mux_tp) =
2513 super::run_variant_sweep_multiplexed(&variants, opts).expect("multiplexed sweep");
2514 let seq_tp = super::assert_sweep_parity_vs_sequential(&variants, &opts, &results)
2515 .expect("sweep parity vs sequential must hold");
2516 println!(
2517 "[#1017 mux-bench] fits={} succeeded={} multiplexed={:.3}s ({:.1} fits/s) \
2518 sequential={:.3}s ({:.1} fits/s) cross-fit-speedup={:.2}x",
2519 mux_tp.fits,
2520 mux_tp.succeeded,
2521 mux_tp.wall_seconds,
2522 mux_tp.fits_per_second,
2523 seq_tp.wall_seconds,
2524 seq_tp.fits_per_second,
2525 mux_tp.fits_per_second / seq_tp.fits_per_second.max(1e-9),
2526 );
2527 assert_eq!(
2528 mux_tp.succeeded, mux_tp.fits,
2529 "all battery variants must fit successfully on device"
2530 );
2531 }
2532}