use slang_hal::backend::Backend;
use slang_hal::function::GpuFunction;
use slang_hal::Shader;
/// Compute-shader bindings for the Slang module `stensor::geometry::svd_glam::test_svd3`.
///
/// `#[derive(Shader)]` implements the `slang_hal::Shader` trait for this struct; the tests in
/// this file load it with `Svd3GlamShader::from_backend(backend, &compiler)` after registering
/// the crate's shaders with the compiler.
#[derive(Shader)]
#[shader(module = "stensor::geometry::svd_glam::test_svd3")]
pub struct Svd3GlamShader<B: Backend> {
/// The `test_svd3` kernel. The tests in this file bind an input buffer holding 9 `f32` per
/// 3x3 matrix (nalgebra column-major order) and read back 21 `f32` per matrix, unpacked as
/// `U` (9), `S` (3), `Vt` (9).
/// NOTE(review): the field name presumably selects the entry point by name — keep it in sync
/// with the Slang source, and confirm the buffer layout there.
pub test_svd3: GpuFunction<B>,
}
#[cfg(test)]
mod test {
    //! GPU tests for the `test_svd3` entry point of [`super::Svd3GlamShader`].
    //!
    //! Matrices are packed into a flat `f32` buffer, decomposed on the GPU, read back as
    //! `U`, `S`, `Vt`, and then checked three ways:
    //! - the SVD invariants (non-negative, descending singular values; orthogonal `U` and `Vt`);
    //! - reconstruction of the input;
    //! - agreement with `nalgebra`'s CPU SVD.

    use crate::tensor::GpuTensor;
    use minislang::SlangCompiler;
    use nalgebra::{Matrix3, Vector3};
    use slang_hal::backend::WebGpu;
    use slang_hal::backend::{Backend, Encoder};
    use slang_hal::{BufferUsages, Shader, ShaderArgs};

    /// Number of `f32` values in one packed 3x3 matrix (nalgebra column-major order).
    const MAT3_LEN: usize = 9;
    /// Number of `f32` values read back per matrix: `U` (9), then `S` (3), then `Vt` (9).
    const SVD3_LEN: usize = MAT3_LEN + 3 + MAT3_LEN;
    /// Absolute slack allowed when checking singular values for sign and ordering.
    const SV_SLACK: f32 = 1e-6;

    /// One decomposition read back from the GPU, expected to satisfy `m ≈ u * diag(s) * vt`.
    #[derive(Copy, Clone, Debug)]
    struct GpuSvd3 {
        u: Matrix3<f32>,
        s: Vector3<f32>,
        vt: Matrix3<f32>,
    }

    impl GpuSvd3 {
        /// Recomposes `u * diag(s) * vt`.
        fn reconstruct(&self) -> Matrix3<f32> {
            self.u * Matrix3::from_diagonal(&self.s) * self.vt
        }
    }

    /// Buffers bound to the `test_svd3` kernel.
    /// NOTE(review): the `ShaderArgs` derive presumably binds these by field name — keep the
    /// names in sync with the kernel's parameters.
    #[derive(ShaderArgs)]
    struct Svd3Args<'a, B: Backend> {
        inputs: &'a B::Buffer<f32>,
        outputs: &'a B::Buffer<f32>,
    }

    /// Largest absolute entry of `m` (NaN entries are ignored by `f32::max`).
    fn max_abs(m: &Matrix3<f32>) -> f32 {
        m.iter().fold(0.0f32, |acc, x| acc.max(x.abs()))
    }

    /// Returns `true` when every entry of `a` and `b` differs by less than `eps` times the
    /// largest absolute entry of either matrix.
    ///
    /// The scale is floored at `1e-10` so that two all-zero matrices still compare equal
    /// (a zero scale would make the strict `<` comparison always fail).
    fn approx_eq_mat3_rel(a: &Matrix3<f32>, b: &Matrix3<f32>, eps: f32) -> bool {
        let tolerance = eps * max_abs(a).max(max_abs(b)).max(1e-10);
        a.iter().zip(b.iter()).all(|(x, y)| (x - y).abs() < tolerance)
    }

    /// Asserts the input-independent SVD invariants.
    ///
    /// These are: singular values non-negative and sorted in descending order (both up to
    /// [`SV_SLACK`]); `U^T U ≈ I` and `Vt Vt^T ≈ I` within relative tolerance `ortho_eps`.
    fn assert_svd_invariants(svd: &GpuSvd3, ortho_eps: f32, label: &str) {
        for (k, &sk) in svd.s.iter().enumerate() {
            assert!(sk >= -SV_SLACK, "{label}: s{k} negative: {sk}");
        }
        for k in 0..2 {
            let next = k + 1;
            let (a, b) = (svd.s[k], svd.s[next]);
            assert!(a + SV_SLACK >= b, "{label}: s{k} < s{next}: {a} < {b}");
        }
        let utu = svd.u.transpose() * svd.u;
        assert!(
            approx_eq_mat3_rel(&utu, &Matrix3::identity(), ortho_eps),
            "{label}: U not orthogonal\n U^T*U: {utu:?}"
        );
        let vtv = svd.vt * svd.vt.transpose();
        assert!(
            approx_eq_mat3_rel(&vtv, &Matrix3::identity(), ortho_eps),
            "{label}: V not orthogonal\n V*V^T: {vtv:?}"
        );
    }

    /// Asserts [`assert_svd_invariants`] and that `u * diag(s) * vt` reconstructs `m`
    /// within relative tolerance `recon_rel_eps`.
    fn assert_valid_svd(
        m: &Matrix3<f32>,
        svd: &GpuSvd3,
        recon_rel_eps: f32,
        ortho_eps: f32,
        label: &str,
    ) {
        assert_svd_invariants(svd, ortho_eps, label);
        let reconstructed = svd.reconstruct();
        assert!(
            approx_eq_mat3_rel(m, &reconstructed, recon_rel_eps),
            "{label}: reconstruction failed\n original: {m:?}\n reconstructed: {reconstructed:?}"
        );
    }

    /// Asserts that the GPU singular values match nalgebra's, each within `rel_eps` times the
    /// largest nalgebra singular value.
    fn assert_matches_nalgebra(m: &Matrix3<f32>, svd: &GpuSvd3, rel_eps: f32, label: &str) {
        // Only singular values are compared, so U and V are not computed on the CPU side.
        let na_sv = m.svd(false, false).singular_values;
        let scale = na_sv[0].max(1e-10);
        for k in 0..3 {
            let (gpu, cpu) = (svd.s[k], na_sv[k]);
            assert!((gpu - cpu).abs() < rel_eps * scale, "{label}: s{k} {gpu} vs {cpu}");
        }
    }

    /// Minimal deterministic xorshift32 generator, so the randomized test is reproducible
    /// without extra dependencies.
    struct Rng(u32);

    impl Rng {
        /// Creates a generator from `seed`.
        ///
        /// # Panics
        /// Panics if `seed` is zero: zero is a fixed point of the xorshift steps, so the
        /// generator would only ever produce zeros.
        fn new(seed: u32) -> Self {
            assert_ne!(seed, 0, "xorshift32 seed must be non-zero");
            Self(seed)
        }

        fn next_u32(&mut self) -> u32 {
            self.0 ^= self.0 << 13;
            self.0 ^= self.0 >> 17;
            self.0 ^= self.0 << 5;
            self.0
        }

        /// Uniform-ish sample in the closed range `[lo, hi]`.
        fn next_f32_range(&mut self, lo: f32, hi: f32) -> f32 {
            let t = (self.next_u32() as f64) / (u32::MAX as f64);
            lo + (hi - lo) * t as f32
        }

        /// 3x3 matrix with entries drawn from `[lo, hi]`, filled column by column.
        fn next_mat3(&mut self, lo: f32, hi: f32) -> Matrix3<f32> {
            // `array::from_fn` calls the closure in index order, so the draw order is fixed.
            let entries: [f32; MAT3_LEN] = std::array::from_fn(|_| self.next_f32_range(lo, hi));
            Matrix3::from_column_slice(&entries)
        }
    }

    /// Flattens `matrices` into one buffer, `MAT3_LEN` column-major values per matrix.
    fn pack_matrices(matrices: &[Matrix3<f32>]) -> Vec<f32> {
        let mut data = Vec::with_capacity(matrices.len() * MAT3_LEN);
        for m in matrices {
            data.extend_from_slice(m.as_slice());
        }
        data
    }

    /// Splits the kernel output into `count` results of `SVD3_LEN` values each.
    ///
    /// # Panics
    /// Panics if `data` holds fewer than `count * SVD3_LEN` values.
    fn unpack_svd_results(data: &[f32], count: usize) -> Vec<GpuSvd3> {
        assert!(
            data.len() >= count * SVD3_LEN,
            "GPU output too short: {} floats for {count} results",
            data.len()
        );
        data.chunks_exact(SVD3_LEN)
            .take(count)
            .map(|chunk| {
                let (u, rest) = chunk.split_at(MAT3_LEN);
                let (s, vt) = rest.split_at(3);
                GpuSvd3 {
                    u: Matrix3::from_column_slice(u),
                    s: Vector3::from_column_slice(s),
                    vt: Matrix3::from_column_slice(vt),
                }
            })
            .collect()
    }

    /// Runs the `test_svd3` kernel on `matrices` and returns one result per input matrix.
    ///
    /// # Panics
    /// Panics on any shader-compilation, buffer, dispatch or read-back error, or if
    /// `matrices.len()` does not fit in a `u32` dispatch size.
    async fn run_gpu_svd3(backend: &impl Backend, matrices: &[Matrix3<f32>]) -> Vec<GpuSvd3> {
        let count = matrices.len();
        // Nothing to decompose; avoid creating zero-sized GPU buffers.
        if count == 0 {
            return Vec::new();
        }
        assert!(
            count <= u32::MAX as usize,
            "too many matrices for a single dispatch: {count}"
        );

        let mut compiler = SlangCompiler::new(vec![]);
        crate::register_shaders(&mut compiler);
        let shader = super::Svd3GlamShader::from_backend(backend, &compiler).unwrap();

        let input_data = pack_matrices(matrices);
        let gpu_inputs = GpuTensor::vector(backend, &input_data, BufferUsages::STORAGE).unwrap();
        let output_len = count * SVD3_LEN;
        let gpu_outputs = GpuTensor::<f32, _>::vector(
            backend,
            &vec![0.0f32; output_len],
            BufferUsages::STORAGE | BufferUsages::COPY_SRC,
        )
        .unwrap();

        let mut encoder = backend.begin_encoding();
        let mut pass = encoder.begin_pass("test_svd3_glam", None);
        let args = Svd3Args { inputs: gpu_inputs.buffer(), outputs: gpu_outputs.buffer() };
        // The x extent of the launch is the number of matrices (checked to fit in u32 above).
        shader
            .test_svd3
            .launch(backend, &mut pass, &args, [count as u32, 1, 1])
            .unwrap();
        drop(pass);
        backend.submit(encoder).unwrap();
        backend.synchronize().unwrap();

        let mut output_data = vec![0.0f32; output_len];
        backend.slow_read_buffer(gpu_outputs.buffer(), &mut output_data).await.unwrap();
        unpack_svd_results(&output_data, count)
    }

    #[futures_test::test]
    #[serial_test::serial]
    async fn gpu_svd3_glam_webgpu() {
        let backend = WebGpu::default().await.unwrap();

        // Hand-picked inputs covering trivial, symmetric, general and negative-determinant cases.
        let identity = Matrix3::<f32>::identity();
        let diagonal = Matrix3::from_columns(&[
            Vector3::new(3.0, 0.0, 0.0),
            Vector3::new(0.0, 2.0, 0.0),
            Vector3::new(0.0, 0.0, 1.0),
        ]);
        let symmetric = Matrix3::from_columns(&[
            Vector3::new(2.0, 1.0, 0.0),
            Vector3::new(1.0, 3.0, 1.0),
            Vector3::new(0.0, 1.0, 2.0),
        ]);
        let general = Matrix3::from_columns(&[
            Vector3::new(1.0, 4.0, 7.0),
            Vector3::new(2.0, 5.0, 8.0),
            Vector3::new(3.0, 6.0, 10.0),
        ]);
        let mixed_sign = Matrix3::from_columns(&[
            Vector3::new(0.5, -1.2, 3.7),
            Vector3::new(2.1, 0.3, -0.8),
            Vector3::new(-1.0, 4.5, 2.2),
        ]);
        let rank2 = Matrix3::from_columns(&[
            Vector3::new(1.0, 4.0, 7.0),
            Vector3::new(2.0, 5.0, 8.0),
            Vector3::new(3.0, 6.0, 9.0),
        ]);
        let mut neg_det = general;
        neg_det.set_column(0, &(-general.column(0)));

        // Labels travel with their matrices so the two can never drift out of sync.
        let cases = [
            ("identity", identity),
            ("diagonal", diagonal),
            ("symmetric", symmetric),
            ("general", general),
            ("mixed_sign", mixed_sign),
            ("neg_det", neg_det),
        ];
        let matrices: Vec<Matrix3<f32>> = cases.iter().map(|(_, m)| *m).collect();
        let results = run_gpu_svd3(&backend, &matrices).await;

        // Singular values known in closed form.
        assert!((results[0].s.x - 1.0).abs() < 1e-6, "identity: s0 = {}", results[0].s.x);
        assert!((results[1].s.x - 3.0).abs() < 1e-6, "diagonal: s0 = {}", results[1].s.x);
        assert!((results[1].s.z - 1.0).abs() < 1e-6, "diagonal: s2 = {}", results[1].s.z);

        for ((label, m), svd) in cases.iter().zip(&results) {
            assert_valid_svd(m, svd, 1e-3, 1e-4, label);
            assert_matches_nalgebra(m, svd, 3e-3, label);
        }

        // Rank-deficient input: the smallest singular value should collapse towards zero, and
        // the reconstruction tolerance is looser.
        let rank2_res = run_gpu_svd3(&backend, &[rank2]).await;
        let rank2_svd = &rank2_res[0];
        assert!(rank2_svd.s.z < 0.1, "rank2 s2 should be small: {}", rank2_svd.s.z);
        assert_valid_svd(&rank2, rank2_svd, 0.02, 1e-4, "rank2");

        // Randomized sweep at two magnitudes. Invariants are checked for every matrix.
        // Reconstruction is only checked when nalgebra's condition number is below 100.
        for (seed, lo, hi) in [(0xDEAD_BEEFu32, -1.0f32, 1.0f32), (0xCAFE_1234, -100.0, 100.0)] {
            let mut rng = Rng::new(seed);
            let matrices: Vec<_> = (0..500).map(|_| rng.next_mat3(lo, hi)).collect();
            let results = run_gpu_svd3(&backend, &matrices).await;
            for (i, (m, svd)) in matrices.iter().zip(&results).enumerate() {
                let label = format!("random_{seed:x}_{i}");
                assert_svd_invariants(svd, 1e-3, &label);
                let na_sv = m.svd(false, false).singular_values;
                if na_sv[2] > 1e-2 * na_sv[0] {
                    let recon = svd.reconstruct();
                    assert!(
                        approx_eq_mat3_rel(m, &recon, 1e-2),
                        "{label}: reconstruction failed (cond {})\n original: {m:?}\n recon: {recon:?}",
                        na_sv[0] / na_sv[2]
                    );
                }
            }
        }
    }
}