use super::context::{
ComputeBindingDescriptor, ComputeBindingKind, ComputeDispatchDescriptor, GpuContext,
};
/// Below this many single excitations, `compute_stda_j_matrix_gpu` always uses
/// the CPU path instead of dispatching to the GPU.
const GPU_DISPATCH_THRESHOLD: usize = 100;

/// Computes `J = 2 · Qᵀ Γ Q`, on the GPU when the problem is large enough.
///
/// `q_matrix` is read as an `n_atoms × n_singles` row-major matrix and `gamma`
/// as an `n_atoms × n_atoms` row-major matrix. The result is the
/// `n_singles × n_singles` row-major product, scaled by 2.
///
/// The CPU implementation (`compute_stda_j_matrix_cpu`) is used in three cases:
/// `n_singles` is below [`GPU_DISPATCH_THRESHOLD`], `n_atoms` is zero, or the
/// context reports no GPU.
///
/// The GPU path runs two dispatches, `ΓQ` and then `Qᵀ(ΓQ)`, in `f32`. Its
/// results therefore carry single-precision rounding, unlike the `f64` CPU path.
///
/// # Errors
/// Returns `Err` in any of these cases:
/// - `q_matrix` or `gamma` is shorter than the given dimensions require.
/// - A GPU dispatch fails.
/// - A dispatch reports no output, or fewer bytes than the result needs.
pub fn compute_stda_j_matrix_gpu(
    ctx: &GpuContext,
    q_matrix: &[f64],
    gamma: &[f64],
    n_atoms: usize,
    n_singles: usize,
) -> Result<Vec<f64>, String> {
    /// Side length of both shaders' square `@workgroup_size(16, 16)`.
    const TILE: usize = 16;
    /// Size of one `f32` element in the GPU buffers.
    const F32_BYTES: usize = 4;

    /// Takes the last entry of a dispatch's `outputs` (treated here as the
    /// readback of the `result` binding) and checks it holds `min_bytes`.
    fn last_output<I: IntoIterator<Item = Vec<u8>>>(
        outputs: I,
        min_bytes: usize,
        label: &str,
    ) -> Result<Vec<u8>, String> {
        let bytes = outputs
            .into_iter()
            .last()
            .ok_or_else(|| format!("{}: GPU dispatch returned no outputs", label))?;
        if bytes.len() < min_bytes {
            return Err(format!(
                "{}: GPU dispatch returned {} bytes, expected at least {}",
                label,
                bytes.len(),
                min_bytes
            ));
        }
        Ok(bytes)
    }

    let q_len = n_atoms * n_singles;
    let gamma_len = n_atoms * n_atoms;
    let j_len = n_singles * n_singles;

    if q_matrix.len() < q_len {
        return Err(format!(
            "q_matrix has {} elements, expected at least {} ({} atoms x {} singles)",
            q_matrix.len(),
            q_len,
            n_atoms,
            n_singles
        ));
    }
    if gamma.len() < gamma_len {
        return Err(format!(
            "gamma has {} elements, expected at least {} ({} x {} atoms)",
            gamma.len(),
            gamma_len,
            n_atoms,
            n_atoms
        ));
    }

    // With no atoms the product is identically zero. The GPU path would need
    // zero-sized buffers, so let the CPU path build the zero matrix instead.
    if n_singles < GPU_DISPATCH_THRESHOLD || n_atoms == 0 || !ctx.capabilities.gpu_available {
        return compute_stda_j_matrix_cpu(q_matrix, gamma, n_atoms, n_singles);
    }

    // The shaders work in f32. Only the prefixes covered by the dimensions are uploaded.
    let q_f32: Vec<f32> = q_matrix[..q_len].iter().map(|&x| x as f32).collect();
    let gamma_f32: Vec<f32> = gamma[..gamma_len].iter().map(|&x| x as f32).collect();
    // Q is bound in both passes, so serialise it once.
    let q_bytes = bytemuck_cast_f32(&q_f32);

    // Pass 1: gamma_q = Γ · Q, an (n_atoms × n_singles) matrix.
    let dispatch = ComputeDispatchDescriptor {
        label: "stda_gamma_q".to_string(),
        shader_source: MATMUL_SHADER.to_string(),
        entry_point: "main".to_string(),
        workgroup_count: [
            n_atoms.div_ceil(TILE) as u32,
            n_singles.div_ceil(TILE) as u32,
            1,
        ],
        bindings: vec![
            ComputeBindingDescriptor {
                label: "gamma".to_string(),
                kind: ComputeBindingKind::StorageReadOnly,
                bytes: bytemuck_cast_f32(&gamma_f32),
            },
            ComputeBindingDescriptor {
                label: "q".to_string(),
                kind: ComputeBindingKind::StorageReadOnly,
                bytes: q_bytes.clone(),
            },
            ComputeBindingDescriptor {
                label: "result".to_string(),
                kind: ComputeBindingKind::StorageReadWrite,
                bytes: vec![0u8; q_len * F32_BYTES],
            },
            ComputeBindingDescriptor {
                label: "dims".to_string(),
                kind: ComputeBindingKind::Uniform,
                // MATMUL_SHADER's uniform is `Dims { M, K, N, _pad }`.
                bytes: pack_dims(n_atoms as u32, n_atoms as u32, n_singles as u32),
            },
        ],
    };
    let gamma_q_bytes = last_output(
        ctx.run_compute(&dispatch)?.outputs,
        q_len * F32_BYTES,
        "stda_gamma_q",
    )?;

    // Pass 2: Qᵀ · gamma_q, an (n_singles × n_singles) matrix.
    let dispatch2 = ComputeDispatchDescriptor {
        label: "stda_qt_gamma_q".to_string(),
        shader_source: MATMUL_TRANSPOSE_SHADER.to_string(),
        entry_point: "main".to_string(),
        workgroup_count: [
            n_singles.div_ceil(TILE) as u32,
            n_singles.div_ceil(TILE) as u32,
            1,
        ],
        bindings: vec![
            ComputeBindingDescriptor {
                label: "q".to_string(),
                kind: ComputeBindingKind::StorageReadOnly,
                bytes: q_bytes,
            },
            ComputeBindingDescriptor {
                label: "gamma_q".to_string(),
                kind: ComputeBindingKind::StorageReadOnly,
                bytes: gamma_q_bytes,
            },
            ComputeBindingDescriptor {
                label: "result".to_string(),
                kind: ComputeBindingKind::StorageReadWrite,
                bytes: vec![0u8; j_len * F32_BYTES],
            },
            ComputeBindingDescriptor {
                label: "dims".to_string(),
                kind: ComputeBindingKind::Uniform,
                // MATMUL_TRANSPOSE_SHADER's uniform is `Dims { K, M, N, _pad }`,
                // so the contracted dimension (n_atoms) is packed first.
                bytes: pack_dims(n_atoms as u32, n_singles as u32, n_singles as u32),
            },
        ],
    };
    let a_off_bytes = last_output(
        ctx.run_compute(&dispatch2)?.outputs,
        j_len * F32_BYTES,
        "stda_qt_gamma_q",
    )?;

    // Decode element by element rather than reinterpreting the buffer in place.
    // An in-place view would only work if the readback Vec<u8> were 4-byte aligned.
    Ok(a_off_bytes
        .chunks_exact(F32_BYTES)
        .take(j_len)
        .map(|c| 2.0 * f64::from(f32::from_ne_bytes([c[0], c[1], c[2], c[3]])))
        .collect())
}
/// CPU reference implementation of `J = 2 · Qᵀ Γ Q` in `f64`.
///
/// Matrix layouts, all row-major:
/// - `q_matrix` is `n_atoms × n_singles`; `q[a][ia]` is at `a * n_singles + ia`.
/// - `gamma` is `n_atoms × n_atoms`.
/// - The result is `n_singles × n_singles`.
///
/// Only entries with `jb <= ia` are computed. Each one is mirrored across the
/// diagonal, so the output is symmetric by construction. This equals `Qᵀ Γ Q`
/// only when `gamma` is symmetric (presumably always true for callers — TODO
/// confirm).
///
/// Screening: when computing entry `(ia, jb)`, any atom `a` with
/// `|q[a][ia]| < 1e-12` is skipped entirely.
///
/// # Errors
/// Returns `Err` if `q_matrix` has fewer than `n_atoms * n_singles` elements, or
/// if `gamma` has fewer than `n_atoms * n_atoms`.
fn compute_stda_j_matrix_cpu(
    q_matrix: &[f64],
    gamma: &[f64],
    n_atoms: usize,
    n_singles: usize,
) -> Result<Vec<f64>, String> {
    /// Entries of `q` below this magnitude are treated as zero.
    const SCREEN: f64 = 1e-12;

    if q_matrix.len() < n_atoms * n_singles {
        return Err(format!(
            "q_matrix has {} elements, expected at least {} ({} atoms x {} singles)",
            q_matrix.len(),
            n_atoms * n_singles,
            n_atoms,
            n_singles
        ));
    }
    if gamma.len() < n_atoms * n_atoms {
        return Err(format!(
            "gamma has {} elements, expected at least {} ({} x {} atoms)",
            gamma.len(),
            n_atoms * n_atoms,
            n_atoms,
            n_atoms
        ));
    }

    // Contract Γ with Q once: gamma_q[a][jb] = Σ_b gamma[a][b] · q[b][jb].
    // The original loop recomputed this sum over b for every ia, costing
    // O(n_singles² · n_atoms²). Reusing it costs
    // O(n_atoms · n_singles · (n_atoms + n_singles)).
    let mut gamma_q = vec![0.0; n_atoms * n_singles];
    for a in 0..n_atoms {
        let gamma_row = &gamma[a * n_atoms..(a + 1) * n_atoms];
        let out_row = &mut gamma_q[a * n_singles..(a + 1) * n_singles];
        for (b, &g_ab) in gamma_row.iter().enumerate() {
            let q_row = &q_matrix[b * n_singles..(b + 1) * n_singles];
            for (out, &q_b_jb) in out_row.iter_mut().zip(q_row) {
                *out += g_ab * q_b_jb;
            }
        }
    }

    let mut result = vec![0.0; n_singles * n_singles];
    for ia in 0..n_singles {
        for jb in 0..=ia {
            let mut val = 0.0;
            for a in 0..n_atoms {
                let q_a_ia = q_matrix[a * n_singles + ia];
                if q_a_ia.abs() < SCREEN {
                    continue;
                }
                val += q_a_ia * gamma_q[a * n_singles + jb];
            }
            result[ia * n_singles + jb] = 2.0 * val;
            result[jb * n_singles + ia] = 2.0 * val;
        }
    }
    Ok(result)
}
/// Serialises `f32` values into a byte buffer for GPU upload.
///
/// Each value contributes its four native-endian bytes, in input order.
fn bytemuck_cast_f32(data: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(data.len() * std::mem::size_of::<f32>());
    for value in data {
        bytes.extend_from_slice(&value.to_ne_bytes());
    }
    bytes
}
/// Views a byte buffer as `f32` values without copying.
///
/// Returns all of `data` as `&[f32]` when both of these hold:
/// - `data` starts at an address aligned for `f32`;
/// - its length is a whole multiple of 4.
///
/// Otherwise it returns an empty slice. Callers cannot tell that apart from
/// genuinely empty input, so they should compare the returned length against
/// what they expect.
fn bytemuck_cast_from_u8(data: &[u8]) -> &[f32] {
    const F32_SIZE: usize = std::mem::size_of::<f32>();
    // Check alignment explicitly. `align_to` is allowed by its contract to return
    // every element in the prefix/suffix, so it cannot be relied on for correctness.
    let is_aligned = data.as_ptr() as usize % std::mem::align_of::<f32>() == 0;
    if !is_aligned || data.len() % F32_SIZE != 0 {
        return &[];
    }
    // SAFETY: the pointer comes from a slice, so it is non-null. It is aligned
    // for f32 (checked above). It covers exactly `data.len()` initialised bytes,
    // which is `data.len() / 4` whole f32 values, and any bit pattern is a valid
    // f32. The returned slice borrows `data`, so it cannot outlive the bytes.
    unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<f32>(), data.len() / F32_SIZE) }
}
/// Packs three `u32` dimensions plus a zero pad word into a 16-byte uniform payload.
///
/// Words are written in argument order, using native endianness. The pad word
/// makes the payload match the shaders' four-`u32` `Dims` struct.
fn pack_dims(m: u32, k: u32, n: u32) -> Vec<u8> {
    const PAD: u32 = 0;
    [m, k, n, PAD]
        .iter()
        .flat_map(|word| word.to_ne_bytes())
        .collect()
}
/// WGSL compute kernel for a row-major matrix product `c = a · b`.
///
/// Bindings (group 0):
/// - 0: `a`, read-only storage, `M × K`;
/// - 1: `b`, read-only storage, `K × N`;
/// - 2: `c`, read-write storage, `M × N`;
/// - 3: uniform `Dims { M, K, N, _pad }`.
///
/// Each invocation computes one element of `c`: `gid.x` selects the row and
/// `gid.y` the column. Invocations outside `M × N` return early. Workgroups are
/// 16 × 16, so covering the output takes `ceil(M/16) × ceil(N/16)` workgroups.
// NOTE(review): consecutive `gid.x` invocations step through rows. Their reads
// of `a` and writes to `c` are therefore strided by K and N. Mapping `gid.x` to
// the column (and swapping the dispatch's workgroup counts) may coalesce better;
// profile before changing.
const MATMUL_SHADER: &str = r#"
struct Dims { M: u32, K: u32, N: u32, _pad: u32 }
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;
@group(0) @binding(3) var<uniform> dims: Dims;
@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let row = gid.x;
let col = gid.y;
if row >= dims.M || col >= dims.N { return; }
var sum: f32 = 0.0;
for (var k: u32 = 0u; k < dims.K; k = k + 1u) {
sum = sum + a[row * dims.K + k] * b[k * dims.N + col];
}
c[row * dims.N + col] = sum;
}
"#;
/// WGSL compute kernel for a product with a transposed left operand: `c = aᵀ · b`.
///
/// Bindings (group 0):
/// - 0: `a`, read-only storage, stored row-major as `K × M` and read as its transpose;
/// - 1: `b`, read-only storage, `K × N`;
/// - 2: `c`, read-write storage, `M × N`;
/// - 3: uniform `Dims { K, M, N, _pad }`.
///
/// The `Dims` field order differs from `MATMUL_SHADER`'s `{ M, K, N, _pad }`.
/// The first packed word is `K` here, so a caller using `pack_dims(m, k, n)`
/// must pass `(K, M, N)` positionally.
///
/// There is one invocation per output element (`gid.x` = row, `gid.y` = column).
/// Workgroups are 16 × 16, and invocations outside `M × N` return early.
const MATMUL_TRANSPOSE_SHADER: &str = r#"
struct Dims { K: u32, M: u32, N: u32, _pad: u32 }
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;
@group(0) @binding(3) var<uniform> dims: Dims;
@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let row = gid.x;
let col = gid.y;
if row >= dims.M || col >= dims.N { return; }
var sum: f32 = 0.0;
for (var k: u32 = 0u; k < dims.K; k = k + 1u) {
sum = sum + a[k * dims.M + row] * b[k * dims.N + col];
}
c[row * dims.N + col] = sum;
}
"#;