use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Operation identifier for the particle-to-multipole step. [`p2m_step`] uses it as the
/// region generator name and passes it along when it rejects invalid arguments.
pub const P2M_OP_ID: &str = "vyre-primitives::math::fmm_p2m_step";
/// Operation identifier string for `fmm_m2l_step`.
// NOTE(review): not referenced by any builder visible in this part of the file.
pub const M2L_OP_ID: &str = "vyre-primitives::math::fmm_m2l_step";
/// Operation identifier string for `fmm_l2p_step`.
// NOTE(review): not referenced by any builder visible in this part of the file.
pub const L2P_OP_ID: &str = "vyre-primitives::math::fmm_l2p_step";
/// Number of `u32` words reserved per cell in the multipoles buffer: [`p2m_step`] declares
/// that buffer with `n_cells * EXPANSION_WORDS` elements and writes each cell's sum at word
/// `cell * EXPANSION_WORDS`.
pub const EXPANSION_WORDS: u32 = 8;
/// Number of `u32` words per particle record: [`p2m_step`] declares the particles buffer with
/// `n_particles * PARTICLE_STRIDE` elements and reads word 2 of each record.
// NOTE(review): the meaning of the other words in a record is not shown here — confirm.
pub const PARTICLE_STRIDE: u32 = 4;
/// Number of `u32` words per cell in the cell-centers buffer, which [`p2m_step`] declares
/// with `n_cells * CELL_STRIDE` elements.
pub const CELL_STRIDE: u32 = 4;
/// Builds the particle-to-multipole (P2M) step as an IR [`Program`].
///
/// Each invocation along axis 0 whose id is below `n_cells` handles one cell: it walks all
/// `n_particles` entries of `cell_assignment`, and for every particle assigned to that cell
/// adds word 2 of the particle's record (`particles[p_idx * PARTICLE_STRIDE + 2]`) to an
/// accumulator. The accumulated value is stored at `multipoles[cell * EXPANSION_WORDS]`.
///
/// # Parameters
/// - `particles`: name of the read-only particle buffer (binding 0).
/// - `cell_assignment`: name of the read-only per-particle cell index buffer (binding 1).
/// - `cell_centers`: name of the read-only cell-center buffer (binding 2). It is declared
///   but not read by the generated body.
/// - `multipoles`: name of the read-write output buffer (binding 3).
/// - `n_particles`, `n_cells`: problem sizes; both must be non-zero.
///
/// # Returns
/// The wrapped program with a `[256, 1, 1]` workgroup size. If either size is zero, or if a
/// buffer element count (`n_particles * PARTICLE_STRIDE`, `n_cells * CELL_STRIDE`,
/// `n_cells * EXPANSION_WORDS`) does not fit in `u32`, the result of
/// `crate::invalid_output_program` is returned instead of a program with wrapped sizes.
#[must_use]
pub fn p2m_step(
    particles: &str,
    cell_assignment: &str,
    cell_centers: &str,
    multipoles: &str,
    n_particles: u32,
    n_cells: u32,
) -> Program {
    if n_particles == 0 {
        return crate::invalid_output_program(
            P2M_OP_ID,
            multipoles,
            DataType::U32,
            "Fix: p2m_step requires n_particles > 0, got 0.".to_string(),
        );
    }
    if n_cells == 0 {
        return crate::invalid_output_program(
            P2M_OP_ID,
            multipoles,
            DataType::U32,
            "Fix: p2m_step requires n_cells > 0, got 0.".to_string(),
        );
    }

    // Buffer element counts are u32. An unchecked multiplication would panic in debug builds
    // and silently wrap in release builds, declaring buffers smaller than the indices the
    // body generates. Reject such sizes up front.
    let sizes = (
        n_particles.checked_mul(PARTICLE_STRIDE),
        n_cells.checked_mul(CELL_STRIDE),
        n_cells.checked_mul(EXPANSION_WORDS),
    );
    let (particle_words, center_words, multipole_words) = match sizes {
        (Some(p), Some(c), Some(m)) => (p, c, m),
        _ => {
            return crate::invalid_output_program(
                P2M_OP_ID,
                multipoles,
                DataType::U32,
                format!(
                    "Fix: p2m_step buffer sizes must fit in u32, got n_particles={} and n_cells={}.",
                    n_particles, n_cells
                ),
            );
        }
    };

    // One invocation per cell along axis 0.
    let cell = Expr::InvocationId { axis: 0 };

    // Index of the word that is summed for particle `p_idx` (word 2 of its record).
    let summed_word = Expr::add(
        Expr::mul(Expr::var("p_idx"), Expr::u32(PARTICLE_STRIDE)),
        Expr::u32(2),
    );
    let accumulate = Node::assign(
        "acc_real0",
        Expr::add(Expr::var("acc_real0"), Expr::load(particles, summed_word)),
    );
    let belongs_to_cell = Expr::eq(
        Expr::load(cell_assignment, Expr::var("p_idx")),
        cell.clone(),
    );

    let body = vec![Node::if_then(
        // Invocations past the last cell do nothing.
        Expr::lt(cell.clone(), Expr::u32(n_cells)),
        vec![
            Node::let_bind("acc_real0", Expr::u32(0)),
            // Bound but never read or stored by this builder; kept so the emitted IR is
            // unchanged.
            Node::let_bind("acc_imag0", Expr::u32(0)),
            Node::loop_for(
                "p_idx",
                Expr::u32(0),
                Expr::u32(n_particles),
                vec![Node::if_then(belongs_to_cell, vec![accumulate])],
            ),
            Node::store(
                multipoles,
                Expr::mul(cell, Expr::u32(EXPANSION_WORDS)),
                Expr::var("acc_real0"),
            ),
        ],
    )];

    Program::wrapped(
        vec![
            BufferDecl::storage(particles, 0, BufferAccess::ReadOnly, DataType::U32)
                .with_count(particle_words),
            BufferDecl::storage(cell_assignment, 1, BufferAccess::ReadOnly, DataType::U32)
                .with_count(n_particles),
            BufferDecl::storage(cell_centers, 2, BufferAccess::ReadOnly, DataType::U32)
                .with_count(center_words),
            BufferDecl::storage(multipoles, 3, BufferAccess::ReadWrite, DataType::U32)
                .with_count(multipole_words),
        ],
        [256, 1, 1],
        vec![Node::Region {
            generator: Ident::from(P2M_OP_ID),
            source_region: None,
            body: Arc::new(body),
        }],
    )
}
/// CPU reference for the zeroth P2M moment: the per-cell sum of charges.
///
/// Allocating convenience wrapper that creates a fresh vector, hands it to
/// [`p2m_zeroth_moment_cpu_into`] together with `charges` and `cell_assignment`, and returns
/// the filled vector.
#[must_use]
pub fn p2m_zeroth_moment_cpu(charges: &[f64], cell_assignment: &[u32]) -> Vec<f64> {
    let mut per_cell: Vec<f64> = Vec::new();
    p2m_zeroth_moment_cpu_into(charges, cell_assignment, &mut per_cell);
    per_cell
}
/// Accumulates per-cell charge sums (zeroth moments) into a caller-provided buffer.
///
/// `moments` is cleared and resized in place, so an existing allocation with enough capacity
/// is reused. On return, `moments[c]` holds the sum of `charges[i]` over all `i` with
/// `cell_assignment[i] == c`, and its length is the largest assigned cell index plus one.
///
/// `charges` and `cell_assignment` are expected to have equal lengths. If they differ, only
/// the leading entries present in both slices are paired and the surplus is ignored; the
/// function does not index out of bounds. If `charges` is empty, `moments` is left empty. If
/// `charges` is non-empty but no pairs exist, `moments` holds a single `0.0`.
///
/// # Panics
/// In builds with debug assertions enabled, panics if `charges` is empty while
/// `cell_assignment` is not.
pub fn p2m_zeroth_moment_cpu_into(
    charges: &[f64],
    cell_assignment: &[u32],
    moments: &mut Vec<f64>,
) {
    if charges.is_empty() {
        debug_assert!(cell_assignment.is_empty());
        moments.clear();
        return;
    }

    // Restrict both slices to their common prefix so every charge has an assignment and
    // vice versa; indexing `charges` by the assignment position could otherwise overrun.
    let paired = charges.len().min(cell_assignment.len());
    let (charges, cells) = (&charges[..paired], &cell_assignment[..paired]);

    // With no pairs the output still has one (zero) cell.
    let n_cells = cells.iter().copied().max().map_or(1, |c| c as usize + 1);

    // clear + resize zeroes every slot while keeping the existing allocation when it fits.
    moments.clear();
    moments.resize(n_cells, 0.0);
    for (&cell, &charge) in cells.iter().zip(charges) {
        moments[cell as usize] += charge;
    }
}
/// CPU reference for the zeroth-order L2P evaluation.
///
/// Returns `local_moment` unchanged. The target coordinates are accepted but do not
/// influence the result.
#[must_use]
pub fn l2p_zeroth_eval_cpu(local_moment: f64, _target_x: f64, _target_y: f64) -> f64 {
    // Neither target coordinate participates in the value.
    local_moment
}
/// CPU reference for the zeroth-order M2L translation: `source_moment / distance`.
///
/// The divisor is clamped from below at `1e-12`, so a zero distance yields a finite result.
/// Negative distances are clamped to the same floor, and so is a NaN distance because
/// `f64::max` returns the non-NaN operand.
#[must_use]
pub fn m2l_zeroth_translate_cpu(source_moment: f64, distance: f64) -> f64 {
    // Smallest divisor allowed; keeps the quotient finite at zero separation.
    const MIN_DISTANCE: f64 = 1e-12;
    source_moment / f64::max(distance, MIN_DISTANCE)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Closeness check with a tolerance that scales with the magnitude of both operands.
    fn approx_eq(a: f64, b: f64) -> bool {
        let tolerance = 1e-10 * (1.0 + a.abs() + b.abs());
        (a - b).abs() < tolerance
    }

    /// With every particle in cell 0, the single moment is the total charge.
    #[test]
    fn cpu_p2m_total_charge_matches_sum() {
        let charges = [1.0, 2.0, 3.0, 4.0, 5.0];
        let cells = [0u32; 5];
        let moments = p2m_zeroth_moment_cpu(&charges, &cells);
        assert_eq!(moments.len(), 1);
        assert!(approx_eq(moments[0], 15.0));
    }

    /// Charges are summed separately for each assigned cell.
    #[test]
    fn cpu_p2m_partitions_charges_by_cell() {
        let charges = [1.0, 2.0, 3.0, 10.0, 20.0, 30.0];
        let cells = [0u32, 0, 0, 1, 1, 1];
        let moments = p2m_zeroth_moment_cpu(&charges, &cells);
        assert!(approx_eq(moments[0], 6.0));
        assert!(approx_eq(moments[1], 60.0));
    }

    /// The `_into` variant fills a pre-sized buffer without reallocating it.
    #[test]
    fn cpu_p2m_into_reuses_moment_buffer() {
        let charges = [1.0, 2.0, 3.0, 10.0, 20.0, 30.0];
        let cells = [0u32, 0, 0, 1, 1, 1];
        let mut moments = Vec::with_capacity(8);
        let original_ptr = moments.as_ptr();
        p2m_zeroth_moment_cpu_into(&charges, &cells, &mut moments);
        assert!(approx_eq(moments[0], 6.0));
        assert!(approx_eq(moments[1], 60.0));
        assert_eq!(moments.as_ptr(), original_ptr);
    }

    /// The translation divides the moment by the distance.
    #[test]
    fn cpu_m2l_inverse_distance_kernel() {
        let translated = m2l_zeroth_translate_cpu(10.0, 2.0);
        assert!(approx_eq(translated, 5.0));
    }

    /// A zero distance must not produce an infinite or NaN result.
    #[test]
    fn cpu_m2l_zero_distance_clamps() {
        let translated = m2l_zeroth_translate_cpu(1.0, 0.0);
        assert!(translated.is_finite());
    }

    /// The evaluation returns the local moment as-is.
    #[test]
    fn cpu_l2p_passthrough() {
        let value = l2p_zeroth_eval_cpu(7.5, 0.0, 0.0);
        assert!(approx_eq(value, 7.5));
    }

    /// Workgroup size and the element count of each declared buffer.
    #[test]
    fn ir_p2m_program_buffer_layout() {
        let program = p2m_step("part", "asgn", "ccen", "mult", 100, 16);
        assert_eq!(program.workgroup_size, [256, 1, 1]);

        // Expected counts in binding order: particles, assignment, centers, multipoles.
        let expected = [
            100 * PARTICLE_STRIDE,
            100,
            16 * CELL_STRIDE,
            16 * EXPANSION_WORDS,
        ];
        for (slot, want) in expected.iter().enumerate() {
            assert_eq!(program.buffers[slot].count(), *want, "buffer {}", slot);
        }
    }

    /// Zero particles yields a trapping program.
    #[test]
    fn zero_particles_traps() {
        let program = p2m_step("p", "a", "c", "m", 0, 1);
        assert!(program.stats().trap());
    }

    /// Zero cells yields a trapping program.
    #[test]
    fn zero_cells_traps() {
        let program = p2m_step("p", "a", "c", "m", 1, 0);
        assert!(program.stats().trap());
    }
}