rustsim 0.0.1

High-performance agent-based modelling engine - top-level orchestration crate
Documentation
//! Device-ready spatial interaction structures for columnar runtimes.
//!
//! The types in this module provide the production interaction boundary for
//! [`ColumnarRuntime`](crate::columnar_runtime::ColumnarRuntime): spatial
//! neighborhoods are built directly from authoritative SoA columns and stored
//! as contiguous row-index buffers. This mirrors the partition boundary matrix
//! (PBM) layout of FLAME GPU 2 spatial messages while retaining a CPU fallback
//! path on hosts without CUDA.

use crate::device_store::{DeviceSoaStore, DeviceSoaStoreF64};
use rustsim_core::types::AgentId;
use thiserror::Error;

/// Errors returned by columnar spatial interaction APIs.
///
/// `Display` messages come from the `thiserror` `#[error]` attribute on each
/// variant.
#[derive(Debug, Clone, PartialEq, Error)]
pub enum DeviceSpatialError {
    /// Spatial query radius must be finite and positive.
    #[error("spatial radius must be finite and positive, got {0}")]
    InvalidRadius(f64),
    /// Spatial hash cell size must be finite and positive.
    #[error("spatial cell size must be finite and positive, got {0}")]
    InvalidCellSize(f64),
    /// A configured coordinate column does not exist in the runtime schema.
    #[error(
        "spatial coordinate column {column} does not exist; runtime has {num_columns} columns"
    )]
    MissingColumn {
        /// Requested (out-of-range) column index.
        column: usize,
        /// Number of columns the store actually has.
        num_columns: usize,
    },
    /// A coordinate column length disagrees with the active agent count.
    #[error("spatial coordinate column {column} has {len} rows, expected {agent_count}")]
    ColumnLengthMismatch {
        /// Column index.
        column: usize,
        /// Observed row count.
        len: usize,
        /// Expected active agent count.
        agent_count: usize,
    },
    /// A row index outside the active population was requested.
    #[error("spatial row {row} is out of bounds for {agent_count} active agents")]
    RowOutOfBounds {
        /// Requested row index.
        row: usize,
        /// Number of indexed agents (valid rows are `0..agent_count`).
        agent_count: usize,
    },
    /// An agent ID was not found in the authoritative runtime state.
    #[error("agent id {0} does not exist in the spatial index")]
    MissingAgentId(AgentId),
}

/// Configuration for a 2-D columnar spatial index.
///
/// Build with [`DeviceSpatialConfig2D::new`] and optionally adjust the hash
/// granularity with [`DeviceSpatialConfig2D::with_cell_size`]. Both `radius`
/// and `cell_size` must be finite and strictly positive; this is checked when
/// an index is built, not when the config is constructed.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DeviceSpatialConfig2D {
    /// SoA column containing X coordinates.
    pub x_column: usize,
    /// SoA column containing Y coordinates.
    pub y_column: usize,
    /// Euclidean query radius. Targets at exactly this distance are included.
    pub radius: f64,
    /// Spatial hash cell size. Values near `radius` usually minimize scans.
    pub cell_size: f64,
}

impl DeviceSpatialConfig2D {
    /// Create a 2-D spatial config using `radius` as the cell size.
    #[must_use]
    pub fn new(x_column: usize, y_column: usize, radius: f64) -> Self {
        Self {
            x_column,
            y_column,
            radius,
            cell_size: radius,
        }
    }

    /// Override the spatial hash cell size.
    ///
    /// The config is `Copy` and this returns a modified copy, so calling it
    /// without using the result leaves the original config unchanged;
    /// `#[must_use]` turns that mistake into a compiler warning.
    #[must_use]
    pub fn with_cell_size(self, cell_size: f64) -> Self {
        Self { cell_size, ..self }
    }

    /// Check that `radius` and `cell_size` are both finite and strictly positive.
    ///
    /// The radius is checked first, so a config where both are invalid reports
    /// [`DeviceSpatialError::InvalidRadius`]. NaN fails `is_finite` and is
    /// therefore rejected.
    fn validate(self) -> Result<(), DeviceSpatialError> {
        if !(self.radius.is_finite() && self.radius > 0.0) {
            return Err(DeviceSpatialError::InvalidRadius(self.radius));
        }
        if !(self.cell_size.is_finite() && self.cell_size > 0.0) {
            return Err(DeviceSpatialError::InvalidCellSize(self.cell_size));
        }
        Ok(())
    }
}

/// One occupied spatial hash cell inside a [`DeviceSpatialIndex2D`].
///
/// The rows that fall in this cell are the contiguous slice
/// `sorted_rows()[start..start + len]` of the owning index.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DeviceSpatialCell2D {
    /// Integer cell coordinate `(x, y)`: each coordinate divided by the cell
    /// size and floored.
    pub cell: [i32; 2],
    /// Start offset into [`DeviceSpatialIndex2D::sorted_rows`].
    pub start: usize,
    /// Number of rows in this cell; cells stored in an index are never empty.
    pub len: usize,
}

/// One directed neighbor relation returned by a spatial query.
///
/// "Directed" means a query reports relations from the queried (source) row
/// only; the reverse relation appears when the target row is queried. A row is
/// never reported as its own neighbor.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DeviceNeighbor2D {
    /// Source agent ID.
    pub source_id: AgentId,
    /// Neighbor agent ID.
    pub target_id: AgentId,
    /// Source row in the authoritative SoA store.
    pub source_row: usize,
    /// Neighbor row in the authoritative SoA store.
    pub target_row: usize,
    /// Squared Euclidean distance between source and target (computed in `f64`).
    pub distance_squared: f64,
}

/// Contiguous 2-D spatial index over authoritative SoA rows.
///
/// The index owns copies of the IDs and coordinates taken at build time; it
/// does not observe later changes to the source store, so rebuild it to pick
/// those up.
#[derive(Debug, Clone, PartialEq)]
pub struct DeviceSpatialIndex2D {
    // Config the index was built with; supplies radius and cell size to queries.
    config: DeviceSpatialConfig2D,
    // Agent IDs in authoritative row order.
    ids: Vec<AgentId>,
    // X coordinate per row (widened to f64 when built from an f32 store).
    x: Vec<f64>,
    // Y coordinate per row (widened to f64 when built from an f32 store).
    y: Vec<f64>,
    // Hash cell of each row, indexed by row.
    row_cells: Vec<[i32; 2]>,
    // Occupied cells in ascending cell-coordinate order (binary-searched on query).
    cells: Vec<DeviceSpatialCell2D>,
    // Row indices ordered by (cell, row); each cell owns a contiguous run.
    sorted_rows: Vec<usize>,
}

impl DeviceSpatialIndex2D {
    /// Build an index from an authoritative `f32` SoA store.
    pub fn build_f32(
        store: &DeviceSoaStore,
        config: DeviceSpatialConfig2D,
    ) -> Result<Self, DeviceSpatialError> {
        validate_config_and_columns(config, store.num_columns())?;
        validate_store_columns(
            config,
            store.num_columns(),
            store.column(config.x_column).len(),
            store.column(config.y_column).len(),
            store.agent_count(),
        )?;
        let x: Vec<f64> = store
            .column(config.x_column)
            .iter()
            .map(|v| *v as f64)
            .collect();
        let y: Vec<f64> = store
            .column(config.y_column)
            .iter()
            .map(|v| *v as f64)
            .collect();
        Self::build_from_parts(store.ids(), x, y, config)
    }

    /// Build an index from an authoritative `f64` SoA store.
    pub fn build_f64(
        store: &DeviceSoaStoreF64,
        config: DeviceSpatialConfig2D,
    ) -> Result<Self, DeviceSpatialError> {
        validate_config_and_columns(config, store.num_columns())?;
        validate_store_columns(
            config,
            store.num_columns(),
            store.column(config.x_column).len(),
            store.column(config.y_column).len(),
            store.agent_count(),
        )?;
        Self::build_from_parts(
            store.ids(),
            store.column(config.x_column).to_vec(),
            store.column(config.y_column).to_vec(),
            config,
        )
    }

    fn build_from_parts(
        ids: &[AgentId],
        x: Vec<f64>,
        y: Vec<f64>,
        config: DeviceSpatialConfig2D,
    ) -> Result<Self, DeviceSpatialError> {
        config.validate()?;
        let mut row_cells = Vec::with_capacity(ids.len());
        for row in 0..ids.len() {
            row_cells.push(cell_of(x[row], y[row], config.cell_size));
        }

        let mut sorted_rows: Vec<usize> = (0..ids.len()).collect();
        sorted_rows.sort_unstable_by_key(|row| (row_cells[*row][0], row_cells[*row][1], *row));

        let mut cells = Vec::new();
        let mut offset = 0;
        while offset < sorted_rows.len() {
            let cell = row_cells[sorted_rows[offset]];
            let start = offset;
            while offset < sorted_rows.len() && row_cells[sorted_rows[offset]] == cell {
                offset += 1;
            }
            cells.push(DeviceSpatialCell2D {
                cell,
                start,
                len: offset - start,
            });
        }

        Ok(Self {
            config,
            ids: ids.to_vec(),
            x,
            y,
            row_cells,
            cells,
            sorted_rows,
        })
    }

    /// Config used to build this index.
    pub fn config(&self) -> DeviceSpatialConfig2D {
        self.config
    }

    /// Number of indexed agents.
    pub fn agent_count(&self) -> usize {
        self.ids.len()
    }

    /// Agent IDs in authoritative row order.
    pub fn ids(&self) -> &[AgentId] {
        &self.ids
    }

    /// Occupied cells in sorted order.
    pub fn cells(&self) -> &[DeviceSpatialCell2D] {
        &self.cells
    }

    /// Row indices sorted by spatial cell, analogous to a PBM message order.
    pub fn sorted_rows(&self) -> &[usize] {
        &self.sorted_rows
    }

    /// Cell coordinate for each authoritative row.
    pub fn row_cells(&self) -> &[[i32; 2]] {
        &self.row_cells
    }

    /// Return the authoritative row for `id`, if present.
    pub fn row_of(&self, id: AgentId) -> Option<usize> {
        self.ids.iter().position(|existing| *existing == id)
    }

    /// Query directed neighbors for an authoritative row.
    pub fn neighbors_for_row(
        &self,
        source_row: usize,
    ) -> Result<Vec<DeviceNeighbor2D>, DeviceSpatialError> {
        if source_row >= self.ids.len() {
            return Err(DeviceSpatialError::RowOutOfBounds {
                row: source_row,
                agent_count: self.ids.len(),
            });
        }

        let radius_cells = (self.config.radius / self.config.cell_size).ceil() as i32;
        let source_cell = self.row_cells[source_row];
        let radius_sq = self.config.radius * self.config.radius;
        let mut out = Vec::new();

        for dx in -radius_cells..=radius_cells {
            for dy in -radius_cells..=radius_cells {
                let cell = [source_cell[0] + dx, source_cell[1] + dy];
                let Some(bucket) = self.cell_bucket(cell) else {
                    continue;
                };
                for sorted_offset in bucket.start..bucket.start + bucket.len {
                    let target_row = self.sorted_rows[sorted_offset];
                    if target_row == source_row {
                        continue;
                    }
                    let ddx = self.x[target_row] - self.x[source_row];
                    let ddy = self.y[target_row] - self.y[source_row];
                    let distance_squared = ddx * ddx + ddy * ddy;
                    if distance_squared <= radius_sq {
                        out.push(DeviceNeighbor2D {
                            source_id: self.ids[source_row],
                            target_id: self.ids[target_row],
                            source_row,
                            target_row,
                            distance_squared,
                        });
                    }
                }
            }
        }

        out.sort_unstable_by_key(|neighbor| (neighbor.target_row, neighbor.target_id));
        Ok(out)
    }

    /// Query directed neighbors for an agent ID.
    pub fn neighbors_for_id(
        &self,
        id: AgentId,
    ) -> Result<Vec<DeviceNeighbor2D>, DeviceSpatialError> {
        let row = self
            .row_of(id)
            .ok_or(DeviceSpatialError::MissingAgentId(id))?;
        self.neighbors_for_row(row)
    }

    /// Materialize all directed neighbor relations in row order.
    pub fn directed_pairs(&self) -> Vec<DeviceNeighbor2D> {
        let mut out = Vec::new();
        for row in 0..self.ids.len() {
            if let Ok(mut neighbors) = self.neighbors_for_row(row) {
                out.append(&mut neighbors);
            }
        }
        out
    }

    fn cell_bucket(&self, cell: [i32; 2]) -> Option<DeviceSpatialCell2D> {
        self.cells
            .binary_search_by_key(&(cell[0], cell[1]), |entry| (entry.cell[0], entry.cell[1]))
            .ok()
            .map(|index| self.cells[index])
    }
}

/// Validate a store's coordinate columns before an index is built from them.
///
/// Re-runs `validate_config_and_columns`, then requires both coordinate
/// columns to hold exactly `agent_count` rows. The X column is checked first,
/// so it is the one reported when both lengths disagree.
fn validate_store_columns(
    config: DeviceSpatialConfig2D,
    num_columns: usize,
    x_len: usize,
    y_len: usize,
    agent_count: usize,
) -> Result<(), DeviceSpatialError> {
    validate_config_and_columns(config, num_columns)?;
    let observed = [(config.x_column, x_len), (config.y_column, y_len)];
    match observed.iter().find(|&&(_, len)| len != agent_count) {
        Some(&(column, len)) => Err(DeviceSpatialError::ColumnLengthMismatch {
            column,
            len,
            agent_count,
        }),
        None => Ok(()),
    }
}

/// Validate `config` and confirm both coordinate columns exist.
///
/// Radius and cell size are checked first. The X column index is checked
/// before the Y column index, so an out-of-range X column is the one reported
/// when both are missing.
fn validate_config_and_columns(
    config: DeviceSpatialConfig2D,
    num_columns: usize,
) -> Result<(), DeviceSpatialError> {
    config.validate()?;
    let first_missing = [config.x_column, config.y_column]
        .iter()
        .copied()
        .find(|&column| column >= num_columns);
    first_missing.map_or(Ok(()), |column| {
        Err(DeviceSpatialError::MissingColumn {
            column,
            num_columns,
        })
    })
}

/// Map a point to its integer hash cell.
///
/// Each coordinate is divided by `cell_size` and floored. The float-to-int
/// `as` cast saturates out-of-range results to `i32::MIN`/`i32::MAX` and maps
/// NaN to `0`.
fn cell_of(x: f64, y: f64, cell_size: f64) -> [i32; 2] {
    let to_cell = |coordinate: f64| (coordinate / cell_size).floor() as i32;
    [to_cell(x), to_cell(y)]
}