use crate::device_store::{DeviceSoaStore, DeviceSoaStoreF64};
use rustsim_core::types::AgentId;
use thiserror::Error;
/// Errors produced while building or querying a [`DeviceSpatialIndex2D`].
#[derive(Debug, Clone, PartialEq, Error)]
pub enum DeviceSpatialError {
    /// The configured query radius was NaN, infinite, zero, or negative.
    #[error("spatial radius must be finite and positive, got {0}")]
    InvalidRadius(f64),
    /// The configured grid cell size was NaN, infinite, zero, or negative.
    #[error("spatial cell size must be finite and positive, got {0}")]
    InvalidCellSize(f64),
    /// A coordinate column index is not smaller than the store's column count.
    #[error(
        "spatial coordinate column {column} does not exist; runtime has {num_columns} columns"
    )]
    MissingColumn { column: usize, num_columns: usize },
    /// A coordinate column's length (`len`) differs from the number of agents.
    #[error("spatial coordinate column {column} has {len} rows, expected {agent_count}")]
    ColumnLengthMismatch {
        column: usize,
        len: usize,
        agent_count: usize,
    },
    /// A row index passed to a query was not smaller than the indexed agent count.
    #[error("spatial row {row} is out of bounds for {agent_count} active agents")]
    RowOutOfBounds { row: usize, agent_count: usize },
    /// An agent-id lookup found no matching row in the index.
    #[error("agent id {0} does not exist in the spatial index")]
    MissingAgentId(AgentId),
}
/// Parameters for building a [`DeviceSpatialIndex2D`].
///
/// `radius` and `cell_size` are accepted as-is here; they are checked (finite and
/// strictly positive) when an index is built.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DeviceSpatialConfig2D {
    /// Index of the store column holding each agent's x coordinate.
    pub x_column: usize,
    /// Index of the store column holding each agent's y coordinate.
    pub y_column: usize,
    /// Neighbor query radius; the comparison against it is inclusive.
    pub radius: f64,
    /// Edge length of the square grid cells that agents are bucketed into.
    pub cell_size: f64,
}
impl DeviceSpatialConfig2D {
    /// Creates a configuration reading coordinates from `x_column` / `y_column`
    /// with query radius `radius`; the grid cell size defaults to `radius`.
    pub fn new(x_column: usize, y_column: usize, radius: f64) -> Self {
        let cell_size = radius;
        Self {
            x_column,
            y_column,
            radius,
            cell_size,
        }
    }

    /// Returns this configuration with the grid cell size replaced by `cell_size`.
    pub fn with_cell_size(self, cell_size: f64) -> Self {
        Self { cell_size, ..self }
    }

    /// Checks that both `radius` and `cell_size` are finite and strictly positive.
    ///
    /// The radius is checked first, so it is the one reported when both are invalid.
    fn validate(self) -> Result<(), DeviceSpatialError> {
        let usable = |value: f64| value.is_finite() && value > 0.0;
        if !usable(self.radius) {
            Err(DeviceSpatialError::InvalidRadius(self.radius))
        } else if !usable(self.cell_size) {
            Err(DeviceSpatialError::InvalidCellSize(self.cell_size))
        } else {
            Ok(())
        }
    }
}
/// One occupied grid cell and the slice of the index's sorted row list it owns.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DeviceSpatialCell2D {
    /// Integer grid coordinates `[floor(x / cell_size), floor(y / cell_size)]`.
    pub cell: [i32; 2],
    /// Offset of this cell's first entry in the index's `sorted_rows`.
    pub start: usize,
    /// Number of consecutive `sorted_rows` entries belonging to this cell.
    pub len: usize,
}
/// A directed (source -> target) neighbor relationship found by a radius query.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DeviceNeighbor2D {
    /// Agent id of the queried agent.
    pub source_id: AgentId,
    /// Agent id of the agent found within the radius.
    pub target_id: AgentId,
    /// Row of the queried agent in the index.
    pub source_row: usize,
    /// Row of the neighboring agent in the index.
    pub target_row: usize,
    /// Squared Euclidean distance between the two agents' coordinates.
    pub distance_squared: f64,
}
/// Snapshot of agent positions bucketed into a uniform 2D grid for radius queries.
///
/// All per-row vectors (`ids`, `x`, `y`, `row_cells`) share the same row indexing.
#[derive(Debug, Clone, PartialEq)]
pub struct DeviceSpatialIndex2D {
    /// Configuration the index was built with.
    config: DeviceSpatialConfig2D,
    /// Agent id of each row, copied from the source store.
    ids: Vec<AgentId>,
    /// x coordinate of each row, as `f64`.
    x: Vec<f64>,
    /// y coordinate of each row, as `f64`.
    y: Vec<f64>,
    /// Grid cell of each row.
    row_cells: Vec<[i32; 2]>,
    /// Occupied cells in ascending `(x, y)` cell order, each owning a range of `sorted_rows`.
    cells: Vec<DeviceSpatialCell2D>,
    /// Row indices sorted by (cell, row), so each cell's rows are contiguous.
    sorted_rows: Vec<usize>,
}
impl DeviceSpatialIndex2D {
    /// Builds an index from an `f32` store, converting each coordinate to `f64`.
    ///
    /// # Errors
    ///
    /// Returns a [`DeviceSpatialError`] if the radius or cell size is not finite and
    /// positive, if either coordinate column index is out of range, or if a
    /// coordinate column's length differs from the agent count.
    pub fn build_f32(
        store: &DeviceSoaStore,
        config: DeviceSpatialConfig2D,
    ) -> Result<Self, DeviceSpatialError> {
        // Check the column indices before handing them to `store.column`.
        validate_config_and_columns(config, store.num_columns())?;
        let x_column = store.column(config.x_column);
        let y_column = store.column(config.y_column);
        validate_store_columns(
            config,
            store.num_columns(),
            x_column.len(),
            y_column.len(),
            store.agent_count(),
        )?;
        let x: Vec<f64> = x_column.iter().map(|v| *v as f64).collect();
        let y: Vec<f64> = y_column.iter().map(|v| *v as f64).collect();
        Self::build_from_parts(store.ids(), x, y, config)
    }

    /// Builds an index from an `f64` store, copying the coordinate columns.
    ///
    /// # Errors
    ///
    /// Same conditions as [`Self::build_f32`].
    pub fn build_f64(
        store: &DeviceSoaStoreF64,
        config: DeviceSpatialConfig2D,
    ) -> Result<Self, DeviceSpatialError> {
        // Check the column indices before handing them to `store.column`.
        validate_config_and_columns(config, store.num_columns())?;
        let x_column = store.column(config.x_column);
        let y_column = store.column(config.y_column);
        validate_store_columns(
            config,
            store.num_columns(),
            x_column.len(),
            y_column.len(),
            store.agent_count(),
        )?;
        Self::build_from_parts(store.ids(), x_column.to_vec(), y_column.to_vec(), config)
    }

    /// Shared constructor: assigns each row to a grid cell and groups rows by cell.
    ///
    /// `x[row]` and `y[row]` are the coordinates of agent `ids[row]`.
    fn build_from_parts(
        ids: &[AgentId],
        x: Vec<f64>,
        y: Vec<f64>,
        config: DeviceSpatialConfig2D,
    ) -> Result<Self, DeviceSpatialError> {
        config.validate()?;
        // The callers compare column lengths with the store's agent count, not with
        // `ids`; check here so a disagreement is an error rather than an index panic.
        for (column, len) in [(config.x_column, x.len()), (config.y_column, y.len())] {
            if len != ids.len() {
                return Err(DeviceSpatialError::ColumnLengthMismatch {
                    column,
                    len,
                    agent_count: ids.len(),
                });
            }
        }
        let row_cells: Vec<[i32; 2]> = x
            .iter()
            .zip(&y)
            .map(|(&px, &py)| cell_of(px, py, config.cell_size))
            .collect();
        // Lexicographic (cell, row) order makes each cell's rows contiguous and the
        // cells ascending in (x, y); the neighbor scan's binary searches rely on it.
        let mut sorted_rows: Vec<usize> = (0..ids.len()).collect();
        sorted_rows.sort_unstable_by_key(|&row| (row_cells[row], row));
        let mut cells = Vec::new();
        let mut offset = 0;
        while offset < sorted_rows.len() {
            let cell = row_cells[sorted_rows[offset]];
            let start = offset;
            while offset < sorted_rows.len() && row_cells[sorted_rows[offset]] == cell {
                offset += 1;
            }
            cells.push(DeviceSpatialCell2D {
                cell,
                start,
                len: offset - start,
            });
        }
        Ok(Self {
            config,
            ids: ids.to_vec(),
            x,
            y,
            row_cells,
            cells,
            sorted_rows,
        })
    }

    /// Configuration the index was built with.
    pub fn config(&self) -> DeviceSpatialConfig2D {
        self.config
    }

    /// Number of indexed agents (rows).
    pub fn agent_count(&self) -> usize {
        self.ids.len()
    }

    /// Agent id of each row.
    pub fn ids(&self) -> &[AgentId] {
        &self.ids
    }

    /// Occupied cells in ascending `(x, y)` cell order.
    pub fn cells(&self) -> &[DeviceSpatialCell2D] {
        &self.cells
    }

    /// Row indices grouped by cell; see [`DeviceSpatialCell2D::start`].
    pub fn sorted_rows(&self) -> &[usize] {
        &self.sorted_rows
    }

    /// Grid cell of each row.
    pub fn row_cells(&self) -> &[[i32; 2]] {
        &self.row_cells
    }

    /// Row of the first agent with id `id`, found by a linear scan.
    pub fn row_of(&self, id: AgentId) -> Option<usize> {
        self.ids.iter().position(|existing| *existing == id)
    }

    /// Returns every other row within `radius` (inclusive) of `source_row`,
    /// ordered by target row.
    ///
    /// # Errors
    ///
    /// [`DeviceSpatialError::RowOutOfBounds`] if `source_row >= self.agent_count()`.
    pub fn neighbors_for_row(
        &self,
        source_row: usize,
    ) -> Result<Vec<DeviceNeighbor2D>, DeviceSpatialError> {
        if source_row >= self.ids.len() {
            return Err(DeviceSpatialError::RowOutOfBounds {
                row: source_row,
                agent_count: self.ids.len(),
            });
        }
        Ok(self.collect_neighbors(source_row))
    }

    /// Like [`Self::neighbors_for_row`], addressing the source agent by id.
    ///
    /// # Errors
    ///
    /// [`DeviceSpatialError::MissingAgentId`] if no row has id `id`.
    pub fn neighbors_for_id(
        &self,
        id: AgentId,
    ) -> Result<Vec<DeviceNeighbor2D>, DeviceSpatialError> {
        let row = self
            .row_of(id)
            .ok_or(DeviceSpatialError::MissingAgentId(id))?;
        self.neighbors_for_row(row)
    }

    /// All directed neighbor pairs, grouped by source row in ascending order.
    pub fn directed_pairs(&self) -> Vec<DeviceNeighbor2D> {
        (0..self.ids.len())
            .flat_map(|row| self.collect_neighbors(row))
            .collect()
    }

    /// Radius query for a row already known to be in bounds.
    ///
    /// Visits only occupied cells whose coordinates lie within `reach` cells of the
    /// source cell on both axes. It jumps over out-of-band cells with binary
    /// searches instead of probing all (2 * reach + 1)^2 candidate cells, so a
    /// small `cell_size` relative to `radius` cannot blow up the query.
    fn collect_neighbors(&self, source_row: usize) -> Vec<DeviceNeighbor2D> {
        // `as` saturates, so an enormous radius/cell_size ratio widens the band to
        // every occupied cell instead of overflowing.
        let reach = (self.config.radius / self.config.cell_size).ceil() as i64;
        let [cell_x, cell_y] = self.row_cells[source_row];
        // Band edges in i64 with saturation: no overflow for cells near the i32 limits.
        let min_x = i64::from(cell_x).saturating_sub(reach);
        let max_x = i64::from(cell_x).saturating_add(reach);
        let min_y = i64::from(cell_y).saturating_sub(reach);
        let max_y = i64::from(cell_y).saturating_add(reach);
        let (source_x, source_y) = (self.x[source_row], self.y[source_row]);
        let radius_sq = self.config.radius * self.config.radius;
        let mut out = Vec::new();
        let mut index = self.first_cell_at_or_after(min_x, min_y);
        while let Some(entry) = self.cells.get(index) {
            let entry_x = i64::from(entry.cell[0]);
            let entry_y = i64::from(entry.cell[1]);
            if entry_x > max_x {
                break;
            }
            if entry_y < min_y {
                // Below the band for this cell-x value: jump to its first in-band cell.
                index = self.first_cell_at_or_after(entry_x, min_y);
                continue;
            }
            if entry_y > max_y {
                // Past the band for this cell-x value: jump to the next cell-x value.
                index = self.first_cell_at_or_after(entry_x + 1, min_y);
                continue;
            }
            for &target_row in &self.sorted_rows[entry.start..entry.start + entry.len] {
                if target_row == source_row {
                    continue;
                }
                let ddx = self.x[target_row] - source_x;
                let ddy = self.y[target_row] - source_y;
                let distance_squared = ddx * ddx + ddy * ddy;
                if distance_squared <= radius_sq {
                    out.push(DeviceNeighbor2D {
                        source_id: self.ids[source_row],
                        target_id: self.ids[target_row],
                        source_row,
                        target_row,
                        distance_squared,
                    });
                }
            }
            index += 1;
        }
        // Each target row appears at most once, so this is a total, deterministic order.
        out.sort_unstable_by_key(|neighbor| neighbor.target_row);
        out
    }

    /// Index of the first cell whose `(x, y)` coordinates are `>= (x, y)` lexicographically.
    fn first_cell_at_or_after(&self, x: i64, y: i64) -> usize {
        self.cells.partition_point(|entry| {
            (i64::from(entry.cell[0]), i64::from(entry.cell[1])) < (x, y)
        })
    }
}
/// Validates the config and column indices, then checks that both coordinate
/// columns hold exactly `agent_count` values.
///
/// The x column is checked before the y column, so x is reported when both mismatch.
fn validate_store_columns(
    config: DeviceSpatialConfig2D,
    num_columns: usize,
    x_len: usize,
    y_len: usize,
    agent_count: usize,
) -> Result<(), DeviceSpatialError> {
    validate_config_and_columns(config, num_columns)?;
    let first_mismatch = [(config.x_column, x_len), (config.y_column, y_len)]
        .iter()
        .copied()
        .find(|&(_, len)| len != agent_count);
    match first_mismatch {
        Some((column, len)) => Err(DeviceSpatialError::ColumnLengthMismatch {
            column,
            len,
            agent_count,
        }),
        None => Ok(()),
    }
}
/// Validates radius and cell size, then checks that both coordinate column indices
/// are below `num_columns`.
///
/// The x column is checked first, so it is reported when both are out of range.
fn validate_config_and_columns(
    config: DeviceSpatialConfig2D,
    num_columns: usize,
) -> Result<(), DeviceSpatialError> {
    config.validate()?;
    let out_of_range = [config.x_column, config.y_column]
        .iter()
        .copied()
        .find(|&column| column >= num_columns);
    match out_of_range {
        Some(column) => Err(DeviceSpatialError::MissingColumn {
            column,
            num_columns,
        }),
        None => Ok(()),
    }
}
/// Maps a point to its integer grid cell `[floor(x / cell_size), floor(y / cell_size)]`.
///
/// Float-to-int `as` saturates out-of-range values to `i32::MIN`/`i32::MAX` and maps
/// NaN to 0.
fn cell_of(x: f64, y: f64, cell_size: f64) -> [i32; 2] {
    let to_cell = |coordinate: f64| (coordinate / cell_size).floor() as i32;
    [to_cell(x), to_cell(y)]
}