vyre-primitives 0.4.1

Compositional primitives for vyre — marker types (always on) + Tier 2.5 LEGO substrate (feature-gated per domain).
Documentation
//! Fast Multipole Method primitives — `p2m`, `m2l`, `l2p`.
//!
//! FMM (Greengard-Rokhlin 1987) evaluates n-body sums in O(n log n)
//! or O(n) via hierarchical multipole expansions:
//!
//! ```text
//!   1. P2M  Particle → Multipole at each leaf cell
//!   2. M2M  Multipole → Multipole up the tree (not implemented in
//!           this file; composes from P2M repeated at higher levels)
//!   3. M2L  Multipole → Local at well-separated cells
//!   4. L2L  Local → Local down the tree (not implemented in this
//!           file; composes from L2P)
//!   5. L2P  Local → Particle, evaluate at target points
//! ```
//!
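//! Three stages are covered here, at zeroth order. P2M ships as an IR
//! program builder, `p2m_step` (charges → multipole moment per cell),
//! with `p2m_zeroth_moment_cpu` as its CPU reference. M2L (multipole →
//! local for one source/target cell pair) and L2P (local expansion →
//! potential at one particle) currently ship as CPU references only
//! (`m2l_zeroth_translate_cpu`, `l2p_zeroth_eval_cpu`). Their op ids
//! (`M2L_OP_ID`, `L2P_OP_ID`) are declared, but no `m2l_step` /
//! `l2p_step` IR builder is defined in this file yet.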
//!
//! # Why these primitives are dual-use
//!
//! | Consumer | Use |
//! |---|---|
//! | future `vyre-libs::sci::nbody` | n-body / molecular dynamics |
//! | future `vyre-libs::ml::kernel_methods` | exact-Gaussian-process inference at scale |
//! | future `vyre-libs::sci::electrostatic` | Poisson / electrostatic solvers |
//! | `vyre-foundation::transform` all-pairs compression | FMM-style hierarchical compression keeps polyhedral fusion tractable at workspace scale |
//!
//! # Simplifying assumptions for v0.6
//!
//! - **2D Coulomb-style kernel**, fixed expansion order `p = 3`. The
//!   primitives are parameterized by buffer layout, but the kernel is
//!   hard-coded to `1 / r`. That is the 3D Coulomb form; the free-space
//!   2D Laplace kernel is `log r`. Future: kernel-generic via Cat-C
//!   intrinsic.
//! - **Truncated complex multipoles encoded as 8 u32 values per cell**
//!   (`EXPANSION_WORDS`): real and imaginary parts, interleaved, for
//!   each of the p + 1 = 4 complex moments (orders 0–3). `p2m_step`
//!   currently fills only the zeroth-order real word (total charge).
//! - **Particle data: 4 u32 per particle** = `(x, y, charge, _pad)`.

use std::sync::Arc;

use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};

/// Op id for the P2M (particle → multipole) program emitted by `p2m_step`.
pub const P2M_OP_ID: &str = "vyre-primitives::math::fmm_p2m_step";
/// Op id for the M2L (multipole → local) stage; not yet referenced by an
/// IR builder in this module.
pub const M2L_OP_ID: &str = "vyre-primitives::math::fmm_m2l_step";
/// Op id for the L2P (local → particle) stage; not yet referenced by an
/// IR builder in this module.
pub const L2P_OP_ID: &str = "vyre-primitives::math::fmm_l2p_step";

/// u32 words per multipole/local expansion: real and imaginary parts,
/// interleaved, for the `p + 1 = 4` complex moments of order `p = 3`.
pub const EXPANSION_WORDS: u32 = 2 * (3 + 1);

/// u32 words per particle record `(x, y, charge, _pad)` in the particle buffer.
pub const PARTICLE_STRIDE: u32 = 4;

/// u32 words per cell record `(cx, cy, _, _)` in the cell-centers buffer.
pub const CELL_STRIDE: u32 = 4;

/// Emit P2M: for each leaf cell, sum the charges of the particles
/// assigned to that cell into the cell's zeroth-order multipole moment.
///
/// Inputs:
/// - `particles`: `n_particles · PARTICLE_STRIDE` u32. Per-particle
///   `(x, y, charge, _)` in 16.16. Only the charge word (slot 2) is read.
/// - `cell_assignment`: length-`n_particles` u32 — which cell index
///   each particle is in.
/// - `cell_centers`: `n_cells · CELL_STRIDE` u32 (`(cx, cy, _, _)`).
///   Declared in the program's buffer layout but not read by this
///   zeroth-order body.
///
/// Output:
/// - `multipoles`: `n_cells · EXPANSION_WORDS` u32. For each cell, word
///   `cell · EXPANSION_WORDS` is overwritten (not accumulated into) with
///   the u32 sum of that cell's charge words. The remaining
///   `EXPANSION_WORDS - 1` words of the cell are not written.
///
/// Lane `t` = cell index; lanes with `t >= n_cells` do nothing. Each lane
/// walks all `n_particles` particles and adds those in its cell to its own
/// accumulator. This avoids atomics at the cost of
/// `n_cells · n_particles` assignment loads in total.
///
/// Invalid arguments do not panic. Instead the function returns the
/// crate's invalid-output (trapping) program when any of these hold:
/// - `n_particles == 0`;
/// - `n_cells == 0`;
/// - any declared buffer element count (`n_particles · PARTICLE_STRIDE`,
///   `n_cells · CELL_STRIDE`, `n_cells · EXPANSION_WORDS`) overflows `u32`.
#[must_use]
pub fn p2m_step(
    particles: &str,
    cell_assignment: &str,
    cell_centers: &str,
    multipoles: &str,
    n_particles: u32,
    n_cells: u32,
) -> Program {
    // Index of the charge word within a particle's `(x, y, charge, _)` record.
    const CHARGE_SLOT: u32 = 2;

    let invalid = |message: String| {
        crate::invalid_output_program(P2M_OP_ID, multipoles, DataType::U32, message)
    };

    if n_particles == 0 {
        return invalid("Fix: p2m_step requires n_particles > 0, got 0.".to_string());
    }
    if n_cells == 0 {
        return invalid("Fix: p2m_step requires n_cells > 0, got 0.".to_string());
    }

    // Every declared element count must fit in u32. An overflowing `*` panics
    // in debug builds and wraps in release, which would declare undersized
    // buffers. Once these counts fit, the largest in-kernel indices
    // (`(n_particles - 1) * PARTICLE_STRIDE + CHARGE_SLOT` and
    // `(n_cells - 1) * EXPANSION_WORDS`) lie below them, so those u32
    // expressions cannot wrap either.
    let particle_words = match n_particles.checked_mul(PARTICLE_STRIDE) {
        Some(words) => words,
        None => {
            return invalid(format!(
                "Fix: p2m_step requires n_particles * {} to fit in u32, got n_particles = {}.",
                PARTICLE_STRIDE, n_particles
            ));
        }
    };
    let (center_words, multipole_words) = match (
        n_cells.checked_mul(CELL_STRIDE),
        n_cells.checked_mul(EXPANSION_WORDS),
    ) {
        (Some(centers), Some(moments)) => (centers, moments),
        _ => {
            return invalid(format!(
                "Fix: p2m_step requires n_cells * {} to fit in u32, got n_cells = {}.",
                CELL_STRIDE.max(EXPANSION_WORDS),
                n_cells
            ));
        }
    };

    // One invocation per cell.
    let cell = Expr::InvocationId { axis: 0 };

    let body = vec![Node::if_then(
        Expr::lt(cell.clone(), Expr::u32(n_cells)),
        vec![
            Node::let_bind("acc_real0", Expr::u32(0)),
            // Bound but never updated or stored by this zeroth-order body.
            Node::let_bind("acc_imag0", Expr::u32(0)),
            Node::loop_for(
                "p_idx",
                Expr::u32(0),
                Expr::u32(n_particles),
                vec![Node::if_then(
                    Expr::eq(
                        Expr::load(cell_assignment, Expr::var("p_idx")),
                        cell.clone(),
                    ),
                    vec![Node::assign(
                        "acc_real0",
                        Expr::add(
                            Expr::var("acc_real0"),
                            Expr::load(
                                particles,
                                Expr::add(
                                    Expr::mul(Expr::var("p_idx"), Expr::u32(PARTICLE_STRIDE)),
                                    Expr::u32(CHARGE_SLOT),
                                ),
                            ),
                        ),
                    )],
                )],
            ),
            // Write the zeroth-order moment (total charge in the cell) into
            // multipoles[cell * EXPANSION_WORDS + 0]. Higher-order moments
            // are not computed by this variant.
            Node::store(
                multipoles,
                Expr::mul(cell, Expr::u32(EXPANSION_WORDS)),
                Expr::var("acc_real0"),
            ),
        ],
    )];

    Program::wrapped(
        vec![
            BufferDecl::storage(particles, 0, BufferAccess::ReadOnly, DataType::U32)
                .with_count(particle_words),
            BufferDecl::storage(cell_assignment, 1, BufferAccess::ReadOnly, DataType::U32)
                .with_count(n_particles),
            BufferDecl::storage(cell_centers, 2, BufferAccess::ReadOnly, DataType::U32)
                .with_count(center_words),
            BufferDecl::storage(multipoles, 3, BufferAccess::ReadWrite, DataType::U32)
                .with_count(multipole_words),
        ],
        [256, 1, 1],
        vec![Node::Region {
            generator: Ident::from(P2M_OP_ID),
            source_region: None,
            body: Arc::new(body),
        }],
    )
}

/// CPU reference for `p2m_step`: returns, for each cell, the total charge
/// of the particles assigned to it (the zeroth-order multipole moment).
///
/// This is an allocating convenience wrapper around
/// `p2m_zeroth_moment_cpu_into`. Output length and panic conditions are
/// the same as for that function.
#[must_use]
pub fn p2m_zeroth_moment_cpu(charges: &[f64], cell_assignment: &[u32]) -> Vec<f64> {
    let mut per_cell: Vec<f64> = Vec::new();
    p2m_zeroth_moment_cpu_into(charges, cell_assignment, &mut per_cell);
    per_cell
}

/// CPU reference for `p2m_step` that writes into caller-owned storage.
///
/// The function proceeds as follows:
/// 1. Clears `moments`.
/// 2. If any particles exist, resizes `moments` to
///    `max(cell_assignment) + 1` zeroed cells; otherwise leaves it empty.
/// 3. Adds each particle's charge to the slot of its assigned cell, in
///    particle order.
///
/// Reusing `moments` across calls avoids a reallocation whenever its
/// capacity already covers the number of cells.
///
/// # Panics
///
/// Panics if `charges` and `cell_assignment` have different lengths.
/// Every particle needs exactly one charge and one cell index.
pub fn p2m_zeroth_moment_cpu_into(
    charges: &[f64],
    cell_assignment: &[u32],
    moments: &mut Vec<f64>,
) {
    // Previously a length mismatch either silently dropped charges, hit an
    // index panic, or (in release) returned empty moments; reject it uniformly.
    assert_eq!(
        charges.len(),
        cell_assignment.len(),
        "Fix: p2m_zeroth_moment_cpu requires one cell assignment per charge, got {} charges and {} assignments.",
        charges.len(),
        cell_assignment.len(),
    );
    moments.clear();
    if let Some(&max_cell) = cell_assignment.iter().max() {
        moments.resize(max_cell as usize + 1, 0.0);
        for (&cell, &charge) in cell_assignment.iter().zip(charges) {
            moments[cell as usize] += charge;
        }
    }
}

/// CPU reference for **L2P** evaluation at zeroth order.
///
/// At zeroth order the local expansion is a single value (the cell's
/// far-field potential). Evaluating it at a target particle therefore
/// returns `local_moment` unchanged, and the target position is ignored.
#[must_use]
pub fn l2p_zeroth_eval_cpu(local_moment: f64, _target_x: f64, _target_y: f64) -> f64 {
    // Order 0 carries no positional terms, so the target coordinates are unused.
    local_moment
}

/// CPU reference for **M2L** translation at zeroth order.
///
/// Converts a source cell's zeroth-order multipole (its total charge) into
/// the target cell's zeroth-order local contribution with the `1 / r`
/// kernel: `source_moment / distance`.
///
/// To avoid division by zero, `distance` is clamped from below to `1e-12`.
/// Because of `f64::max` semantics, a zero, negative, or `NaN` distance
/// is also replaced by `1e-12`.
#[must_use]
pub fn m2l_zeroth_translate_cpu(source_moment: f64, distance: f64) -> f64 {
    // Smallest separation fed to the kernel.
    const MIN_DISTANCE: f64 = 1e-12;
    source_moment / f64::max(distance, MIN_DISTANCE)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Float comparison with a tolerance scaled by operand magnitude
    /// (absolute floor of `1e-10`).
    fn approx_eq(lhs: f64, rhs: f64) -> bool {
        let tolerance = 1e-10 * (1.0 + lhs.abs() + rhs.abs());
        (lhs - rhs).abs() < tolerance
    }

    #[test]
    fn cpu_p2m_total_charge_matches_sum() {
        // Every particle lands in cell 0, so there is exactly one moment:
        // 1 + 2 + 3 + 4 + 5.
        let moments = p2m_zeroth_moment_cpu(&[1.0, 2.0, 3.0, 4.0, 5.0], &[0u32; 5]);
        assert_eq!(moments.len(), 1);
        assert!(approx_eq(moments[0], 15.0));
    }

    #[test]
    fn cpu_p2m_partitions_charges_by_cell() {
        // Two cells with three particles each.
        let charges = [1.0, 2.0, 3.0, 10.0, 20.0, 30.0];
        let assignment = [0u32, 0, 0, 1, 1, 1];
        let moments = p2m_zeroth_moment_cpu(&charges, &assignment);
        assert!(approx_eq(moments[0], 6.0));
        assert!(approx_eq(moments[1], 60.0));
    }

    #[test]
    fn cpu_p2m_into_reuses_moment_buffer() {
        let charges = [1.0, 2.0, 3.0, 10.0, 20.0, 30.0];
        let assignment = [0u32, 0, 0, 1, 1, 1];
        // Capacity for more cells than needed, so no reallocation should happen.
        let mut moments: Vec<f64> = Vec::with_capacity(8);
        let original_ptr = moments.as_ptr();
        p2m_zeroth_moment_cpu_into(&charges, &assignment, &mut moments);
        assert!(approx_eq(moments[0], 6.0));
        assert!(approx_eq(moments[1], 60.0));
        assert_eq!(moments.as_ptr(), original_ptr);
    }

    #[test]
    fn cpu_m2l_inverse_distance_kernel() {
        // 1 / r kernel: source charge 10 seen from distance 2 contributes 5.
        let local = m2l_zeroth_translate_cpu(10.0, 2.0);
        assert!(approx_eq(local, 5.0));
    }

    #[test]
    fn cpu_m2l_zero_distance_clamps() {
        // Coincident cells must not divide by zero.
        let local = m2l_zeroth_translate_cpu(1.0, 0.0);
        assert!(local.is_finite());
    }

    #[test]
    fn cpu_l2p_passthrough() {
        // Zeroth-order L2P hands back the local moment untouched.
        let potential = l2p_zeroth_eval_cpu(7.5, 0.0, 0.0);
        assert!(approx_eq(potential, 7.5));
    }

    #[test]
    fn ir_p2m_program_buffer_layout() {
        let program = p2m_step("part", "asgn", "ccen", "mult", 100, 16);
        assert_eq!(program.workgroup_size, [256, 1, 1]);
        let expected = [
            100 * PARTICLE_STRIDE,
            100,
            16 * CELL_STRIDE,
            16 * EXPANSION_WORDS,
        ];
        for (slot, &count) in expected.iter().enumerate() {
            assert_eq!(program.buffers[slot].count(), count, "buffer {}", slot);
        }
    }

    #[test]
    fn zero_particles_traps() {
        let program = p2m_step("p", "a", "c", "m", 0, 1);
        assert!(program.stats().trap());
    }

    #[test]
    fn zero_cells_traps() {
        let program = p2m_step("p", "a", "c", "m", 1, 0);
        assert!(program.stats().trap());
    }
}