use crate::families::cubic_cell_kernel::{
DenestedCubicCell, evaluate_cell_derivative_moments_uncached,
};
use crate::gpu::cubic_cell::branch::classify_cell_for_gpu;
use crate::gpu::cubic_cell::{
CubicCellDerivativeMomentHostView, CubicCellMomentStatus, GpuCellBranchTag,
};
/// Host-side result of evaluating derivative moments for a batch of cells.
///
/// Produced by `build_host_moments`: one fixed-width row of moments and one
/// status byte per input cell.
pub(crate) struct HostMomentBatch {
/// Row-major moment storage; cell `i` occupies
/// `moments[i * stride..(i + 1) * stride]`. Rows of cells that did not
/// evaluate successfully are all zeros.
pub moments: Vec<f64>,
/// One entry per cell holding a `CubicCellMomentStatus` cast to `u8`.
pub status: Vec<u8>,
/// Number of `f64` slots per cell row (`max_degree + 1` of the source view).
pub stride: usize,
}
/// Evaluates the derivative moments of every cell in `view` with the CPU
/// kernel and packs them into a row-major [`HostMomentBatch`].
///
/// Each cell owns a row of `view.max_degree + 1` slots and one status byte.
/// Per-cell failures never abort the batch; the row stays all zeros and the
/// status byte records why:
///
/// * `classify_cell_for_gpu` rejects the cell: the code it returned.
/// * The caller-supplied branch tag disagrees with the host classification:
///   `InvalidInterval`.
/// * The CPU evaluator returns an error: `NonAffineInfiniteInterval` when the
///   cell was classified as `AffineTail`, otherwise `InvalidInterval`.
/// * Any copied moment is non-finite: the row is zeroed again and the status
///   is `NonFiniteEvaluation`.
///
/// # Errors
///
/// Returns `Err` when the view itself is unusable rather than a single cell:
/// fewer branch tags than cells, or a `max_degree` / cell count whose row
/// layout does not fit in `usize`. Extra trailing branch tags are ignored.
pub(crate) fn build_host_moments(
    view: &CubicCellDerivativeMomentHostView<'_>,
) -> Result<HostMomentBatch, String> {
    let n_cells = view.cells.len();

    // Indexing a missing tag used to panic; report it as a batch-level error.
    if view.branches.len() < n_cells {
        return Err(format!(
            "build_host_moments: {} branch tags supplied for {} cells",
            view.branches.len(),
            n_cells
        ));
    }

    // `stride` is at least 1 once the addition is known not to overflow,
    // which is what makes `chunks_exact_mut(stride)` below well-defined.
    let stride = view.max_degree.checked_add(1).ok_or_else(|| {
        format!(
            "build_host_moments: max_degree {} overflows the row stride",
            view.max_degree
        )
    })?;
    let total_len = n_cells.checked_mul(stride).ok_or_else(|| {
        format!("build_host_moments: {n_cells} cells x stride {stride} overflows usize")
    })?;

    let mut moments = vec![0.0_f64; total_len];
    let mut status = vec![CubicCellMomentStatus::Ok as u8; n_cells];

    let tagged_cells = view.cells.iter().zip(view.branches.iter());
    let rows = moments.chunks_exact_mut(stride);
    for (((&gpu_cell, &caller_tag), row), cell_status) in
        tagged_cells.zip(rows).zip(status.iter_mut())
    {
        let host_tag = match classify_cell_for_gpu(gpu_cell) {
            Ok(tag) => tag,
            Err(code) => {
                *cell_status = code as u8;
                continue;
            }
        };

        // The caller's tag must agree with what the host derives from the
        // cell; a disagreement is reported instead of evaluated.
        if host_tag != caller_tag {
            *cell_status = CubicCellMomentStatus::InvalidInterval as u8;
            continue;
        }

        let cpu_cell = DenestedCubicCell {
            left: gpu_cell.left,
            right: gpu_cell.right,
            c0: gpu_cell.c0,
            c1: gpu_cell.c1,
            c2: gpu_cell.c2,
            c3: gpu_cell.c3,
        };

        match evaluate_cell_derivative_moments_uncached(cpu_cell, view.max_degree) {
            Ok(state) => {
                // Copy at most `stride` values; if the evaluator returns
                // fewer, the remaining slots keep their initial zero.
                let copy_len = state.moments.len().min(stride);
                row[..copy_len].copy_from_slice(&state.moments[..copy_len]);
                if row.iter().any(|x| !x.is_finite()) {
                    // Never hand NaN/inf downstream: blank the whole row.
                    row.fill(0.0);
                    *cell_status = CubicCellMomentStatus::NonFiniteEvaluation as u8;
                }
            }
            Err(_) => {
                *cell_status = match host_tag {
                    GpuCellBranchTag::AffineTail => {
                        CubicCellMomentStatus::NonAffineInfiniteInterval as u8
                    }
                    _ => CubicCellMomentStatus::InvalidInterval as u8,
                };
            }
        }
    }

    Ok(HostMomentBatch {
        moments,
        status,
        stride,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::families::cubic_cell_kernel::{
        DenestedCubicCell, evaluate_cell_derivative_moments_uncached,
    };
    use crate::gpu::cubic_cell::{
        CubicCellDerivativeMomentHostView, CubicCellMomentResidency, GpuCellBranchTag,
        GpuDenestedCubicCell,
    };

    /// Compact constructor: `coeffs` is `[c0, c1, c2, c3]`.
    fn cpu_cell(left: f64, right: f64, coeffs: [f64; 4]) -> DenestedCubicCell {
        let [c0, c1, c2, c3] = coeffs;
        DenestedCubicCell {
            left,
            right,
            c0,
            c1,
            c2,
            c3,
        }
    }

    /// Field-for-field copy of a CPU cell into its GPU-side twin.
    fn gpu_from_cpu(cpu: DenestedCubicCell) -> GpuDenestedCubicCell {
        let DenestedCubicCell {
            left,
            right,
            c0,
            c1,
            c2,
            c3,
        } = cpu;
        GpuDenestedCubicCell {
            left,
            right,
            c0,
            c1,
            c2,
            c3,
        }
    }

    /// Compares `row` against a fresh CPU evaluation of `cell`, slot by slot,
    /// using a relative error whose denominator is clamped to at least 1.
    fn assert_row_matches_cpu(
        row: &[f64],
        cell: DenestedCubicCell,
        max_degree: usize,
        ulp_rel: f64,
    ) {
        let reference =
            evaluate_cell_derivative_moments_uncached(cell, max_degree).expect("cpu reference");
        assert_eq!(row.len(), max_degree + 1);
        let paired = row.iter().zip(reference.moments.iter());
        for (k, (&got, &want)) in paired.enumerate() {
            let scale = want.abs().max(1.0);
            let rel = (got - want).abs() / scale;
            assert!(
                rel <= ulp_rel,
                "moment k={k} got={got:.17e} want={want:.17e} rel={rel:.3e} tol={ulp_rel:.3e}"
            );
        }
    }

    /// Runs the host substrate over a one-cell, host-resident view.
    fn run_single_cell(
        gpu: GpuDenestedCubicCell,
        tag: GpuCellBranchTag,
        max_degree: usize,
    ) -> HostMomentBatch {
        let branches = vec![tag];
        let view = CubicCellDerivativeMomentHostView {
            cells: std::slice::from_ref(&gpu),
            branches: &branches,
            max_degree,
            residency: CubicCellMomentResidency::Host,
        };
        build_host_moments(&view).expect("host substrate")
    }

    /// Expects an `Ok` status and a row identical to the CPU evaluator's.
    fn assert_single_cell_parity(
        cpu: DenestedCubicCell,
        tag: GpuCellBranchTag,
        max_degree: usize,
    ) {
        let out = run_single_cell(gpu_from_cpu(cpu), tag, max_degree);
        assert_eq!(out.status[0], CubicCellMomentStatus::Ok as u8);
        assert_row_matches_cpu(&out.moments[..out.stride], cpu, max_degree, 0.0);
    }

    /// Expects `InvalidInterval` and an all-zero moment buffer.
    fn assert_single_cell_rejected(
        gpu: GpuDenestedCubicCell,
        tag: GpuCellBranchTag,
        max_degree: usize,
    ) {
        let out = run_single_cell(gpu, tag, max_degree);
        assert_eq!(out.status[0], CubicCellMomentStatus::InvalidInterval as u8);
        assert!(out.moments.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn host_substrate_matches_cpu_for_quartic_finite_cell() {
        assert_single_cell_parity(
            cpu_cell(-1.25, -0.2, [-0.35, 0.85, 0.4, 0.0]),
            GpuCellBranchTag::NonAffineFinite,
            9,
        );
    }

    #[test]
    fn host_substrate_matches_cpu_for_sextic_finite_cell_at_d21() {
        assert_single_cell_parity(
            cpu_cell(-0.5, 1.7, [0.2, -0.6, 0.25, 0.18]),
            GpuCellBranchTag::NonAffineFinite,
            21,
        );
    }

    #[test]
    fn host_substrate_matches_cpu_for_affine_tail_cell() {
        assert_single_cell_parity(
            cpu_cell(f64::NEG_INFINITY, -0.7, [0.1, 0.5, 0.0, 0.0]),
            GpuCellBranchTag::AffineTail,
            15,
        );
    }

    #[test]
    fn host_substrate_matches_cpu_for_whole_line_affine() {
        assert_single_cell_parity(
            cpu_cell(f64::NEG_INFINITY, f64::INFINITY, [0.0, 0.0, 0.0, 0.0]),
            GpuCellBranchTag::AffineTail,
            9,
        );
    }

    /// A cell whose left bound exceeds its right bound is reported as
    /// `InvalidInterval` and leaves only zeros behind.
    #[test]
    fn host_substrate_zeros_invalid_cell_and_records_status() {
        assert_single_cell_rejected(
            gpu_from_cpu(cpu_cell(1.0, -1.0, [0.0, 0.0, 0.0, 0.0])),
            GpuCellBranchTag::NonAffineFinite,
            9,
        );
    }

    /// Multi-cell batch covering both infinite tails plus affine and
    /// non-affine finite cells, checked at several degrees.
    #[test]
    fn cubic_cell_substrate_parity_against_cpu_evaluator() {
        let cells_cpu = [
            cpu_cell(f64::NEG_INFINITY, -1.5, [0.05, 0.4, 0.0, 0.0]),
            cpu_cell(-1.5, -0.3, [-0.1, 0.2, 0.0, 0.0]),
            cpu_cell(-0.3, 0.4, [0.0, 0.5, 0.3, 0.0]),
            cpu_cell(0.4, 1.1, [0.15, -0.25, 0.1, 0.18]),
            cpu_cell(1.1, 2.0, [-0.2, 0.6, 0.0, 0.0]),
            cpu_cell(2.0, f64::INFINITY, [0.3, -0.4, 0.0, 0.0]),
        ];
        let cells_gpu: Vec<GpuDenestedCubicCell> =
            cells_cpu.iter().copied().map(gpu_from_cpu).collect();
        let branches: Vec<GpuCellBranchTag> = cells_cpu
            .iter()
            .map(|c| {
                let bounded = c.left.is_finite() && c.right.is_finite();
                let affine = c.c2 == 0.0 && c.c3 == 0.0;
                match (bounded, affine) {
                    (false, _) => GpuCellBranchTag::AffineTail,
                    (true, true) => GpuCellBranchTag::Affine,
                    (true, false) => GpuCellBranchTag::NonAffineFinite,
                }
            })
            .collect();

        for max_degree in [9usize, 15, 21] {
            let view = CubicCellDerivativeMomentHostView {
                cells: &cells_gpu,
                branches: &branches,
                max_degree,
                residency: CubicCellMomentResidency::Host,
            };
            let out = build_host_moments(&view).expect("host substrate");
            assert_eq!(out.stride, max_degree + 1);
            assert_eq!(out.status.len(), cells_cpu.len());

            for (i, &cell) in cells_cpu.iter().enumerate() {
                assert_eq!(
                    out.status[i],
                    CubicCellMomentStatus::Ok as u8,
                    "cell {i} status was {} at degree={max_degree}",
                    out.status[i]
                );
                let start = i * out.stride;
                let end = (i + 1) * out.stride;
                assert_row_matches_cpu(&out.moments[start..end], cell, max_degree, 0.0);
            }
        }
    }

    /// A finite cell with a non-zero `c2`, tagged `AffineTail` by the caller,
    /// must be flagged rather than evaluated.
    #[test]
    fn host_substrate_flags_caller_branch_mismatch() {
        assert_single_cell_rejected(
            gpu_from_cpu(cpu_cell(-1.0, 1.0, [0.2, 0.3, 0.4, 0.0])),
            GpuCellBranchTag::AffineTail,
            9,
        );
    }
}