use std::collections::HashMap;
use thiserror::Error;
#[derive(Debug, Clone, Copy, PartialEq, Error)]
pub enum MessageConfigError {
#[error("radius must be positive")]
InvalidRadius,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
pub enum MessagePhaseError {
#[error("cannot output messages after finalize")]
AlreadyFinalized,
#[error("must finalize messages before reading")]
NotFinalized,
}
/// Marker trait for message payloads: clonable, owned (`'static`), and
/// safe to send and share across threads.
pub trait Message: Clone + Send + Sync + 'static {}

// Any type satisfying the bounds is a `Message`; nothing is implemented by hand.
impl<T> Message for T where T: Clone + Send + Sync + 'static {}
/// Unindexed message list: every reader sees every message.
///
/// Messages are appended during the write phase. They become readable, in
/// output order, once the list has been finalized.
#[derive(Clone, Debug)]
pub struct BruteForceMessages<M> {
    /// Messages in the order they were output.
    buffer: Vec<M>,
    /// `true` once the list has switched from writing to reading.
    finalized: bool,
}
impl<M: Clone> BruteForceMessages<M> {
    /// Returns an empty list in the write phase.
    pub fn new() -> Self {
        // A zero-capacity Vec does not allocate, exactly like `Vec::new()`.
        Self::with_capacity(0)
    }

    /// Returns an empty list in the write phase with room reserved for
    /// `capacity` messages.
    pub fn with_capacity(capacity: usize) -> Self {
        let buffer = Vec::with_capacity(capacity);
        Self {
            buffer,
            finalized: false,
        }
    }

    /// Appends `message` to the list.
    ///
    /// # Panics
    ///
    /// Panics if [`Self::finalize`] has already been called.
    pub fn output(&mut self, message: M) {
        self.try_output(message).expect("cannot output after finalize");
    }

    /// Appends `message` to the list.
    ///
    /// # Errors
    ///
    /// Returns [`MessagePhaseError::AlreadyFinalized`] and leaves the list
    /// unchanged if [`Self::finalize`] has already been called.
    pub fn try_output(&mut self, message: M) -> Result<(), MessagePhaseError> {
        if self.finalized {
            Err(MessagePhaseError::AlreadyFinalized)
        } else {
            self.buffer.push(message);
            Ok(())
        }
    }

    /// Ends the write phase so the messages become readable.
    pub fn finalize(&mut self) {
        self.finalized = true;
    }

    /// Returns every message in output order.
    ///
    /// # Panics
    ///
    /// Panics if [`Self::finalize`] has not been called.
    pub fn read_all(&self) -> &[M] {
        self.try_read_all().expect("must finalize messages before reading")
    }

    /// Returns every message in output order.
    ///
    /// # Errors
    ///
    /// Returns [`MessagePhaseError::NotFinalized`] if [`Self::finalize`] has
    /// not been called.
    pub fn try_read_all(&self) -> Result<&[M], MessagePhaseError> {
        if self.finalized {
            Ok(self.buffer.as_slice())
        } else {
            Err(MessagePhaseError::NotFinalized)
        }
    }

    /// Whether the list is currently in its read phase.
    pub fn is_finalized(&self) -> bool {
        self.finalized
    }

    /// Number of messages output so far.
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// `true` when no messages have been output.
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }

    /// Drops all messages and returns the list to its write phase.
    pub fn clear(&mut self) {
        self.finalized = false;
        self.buffer.clear();
    }
}
impl<M: Clone> Default for BruteForceMessages<M> {
fn default() -> Self {
Self::new()
}
}
/// Messages tagged with 2D positions and bucketed into a uniform square grid,
/// so that neighbourhood queries only inspect nearby cells.
#[derive(Clone, Debug)]
pub struct SpatialMessages2D<M> {
    /// Message payloads, in output order.
    messages: Vec<M>,
    /// `positions[i]` is the location of `messages[i]`.
    positions: Vec<(f32, f32)>,
    /// Edge length of a grid cell.
    radius: f32,
    /// Occupied cell -> `(start, count)` range within `sorted_indices`.
    bin_map: HashMap<(i32, i32), (usize, usize)>,
    /// Message indices ordered so each cell's messages are contiguous.
    sorted_indices: Vec<usize>,
    /// `true` once the grid index has been built and reads are allowed.
    finalized: bool,
}
impl<M: Clone> SpatialMessages2D<M> {
    /// Creates an empty list that bins messages into square cells `radius`
    /// units wide.
    ///
    /// # Errors
    ///
    /// Returns [`MessageConfigError::InvalidRadius`] if `radius` is zero,
    /// negative, or NaN.
    pub fn new(radius: f32) -> Result<Self, MessageConfigError> {
        // `radius <= 0.0` is false for NaN, so NaN has to be rejected explicitly.
        if radius.is_nan() || radius <= 0.0 {
            return Err(MessageConfigError::InvalidRadius);
        }
        Ok(Self {
            messages: Vec::new(),
            positions: Vec::new(),
            radius,
            bin_map: HashMap::new(),
            sorted_indices: Vec::new(),
            finalized: false,
        })
    }

    /// Grid cell containing `(x, y)` for cells of size `1.0 / inv_radius`.
    ///
    /// Both `finalize` and `try_read_nearby` go through this helper, so
    /// messages and queries are always binned with identical arithmetic.
    /// Float-to-int `as` casts saturate (NaN becomes 0), so this never panics.
    fn cell_of(inv_radius: f32, x: f32, y: f32) -> (i32, i32) {
        (
            (x * inv_radius).floor() as i32,
            (y * inv_radius).floor() as i32,
        )
    }

    /// Records `message` at position `(x, y)`.
    ///
    /// # Panics
    ///
    /// Panics if the list has already been finalized.
    pub fn output(&mut self, message: M, x: f32, y: f32) {
        self.try_output(message, x, y)
            .expect("cannot output after finalize");
    }

    /// Records `message` at position `(x, y)`.
    ///
    /// # Errors
    ///
    /// Returns [`MessagePhaseError::AlreadyFinalized`] if [`Self::finalize`]
    /// has been called since construction or the last [`Self::clear`].
    pub fn try_output(&mut self, message: M, x: f32, y: f32) -> Result<(), MessagePhaseError> {
        if self.finalized {
            return Err(MessagePhaseError::AlreadyFinalized);
        }
        self.messages.push(message);
        self.positions.push((x, y));
        Ok(())
    }

    /// Builds the grid index and switches the list to its read phase.
    ///
    /// Message indices are sorted by cell into `sorted_indices`. `bin_map`
    /// then maps each occupied cell to its `(start, count)` range there.
    /// Calling this again rebuilds the index from scratch.
    pub fn finalize(&mut self) {
        let n = self.messages.len();
        let inv_radius = 1.0 / self.radius;
        let bin_assignments: Vec<(i32, i32)> = self
            .positions
            .iter()
            .map(|&(x, y)| Self::cell_of(inv_radius, x, y))
            .collect();

        self.sorted_indices.clear();
        self.sorted_indices.extend(0..n);
        self.sorted_indices
            .sort_unstable_by_key(|&i| bin_assignments[i]);

        // Record each run of equal cells as (offset into sorted_indices, length).
        self.bin_map.clear();
        let mut start = 0;
        while start < n {
            let bin = bin_assignments[self.sorted_indices[start]];
            let mut end = start + 1;
            while end < n && bin_assignments[self.sorted_indices[end]] == bin {
                end += 1;
            }
            self.bin_map.insert(bin, (start, end - start));
            start = end;
        }
        self.finalized = true;
    }

    /// Iterates over `(message, squared_distance)` for every message within
    /// `radius` of `(x, y)`.
    ///
    /// # Panics
    ///
    /// Panics if the list has not been finalized.
    pub fn read_nearby(&self, x: f32, y: f32, radius: f32) -> SpatialIter2D<'_, M> {
        self.try_read_nearby(x, y, radius)
            .expect("must finalize messages before reading")
    }

    /// Iterates over `(message, squared_distance)` for every message within
    /// `radius` of `(x, y)`.
    ///
    /// The iterator probes every cell within
    /// `max(1, ceil(|radius| / cell size))` cells of the query cell. Radii
    /// that are very large relative to the cell size are therefore expensive.
    ///
    /// # Errors
    ///
    /// Returns [`MessagePhaseError::NotFinalized`] if [`Self::finalize`] has
    /// not been called.
    pub fn try_read_nearby(
        &self,
        x: f32,
        y: f32,
        radius: f32,
    ) -> Result<SpatialIter2D<'_, M>, MessagePhaseError> {
        if !self.finalized {
            return Err(MessagePhaseError::NotFinalized);
        }
        let (center_bx, center_by) = Self::cell_of(1.0 / self.radius, x, y);
        // The distance filter uses radius², which ignores the sign. Size the
        // window from |radius| as well, or a negative radius wider than one
        // cell would silently miss matches.
        let grid_r = ((radius.abs() / self.radius).ceil() as i32).max(1);
        let first_dx = -grid_r;
        let first_dy = -grid_r;
        // Cells outside the i32 range cannot be occupied. checked_add avoids an
        // overflow panic for queries near the saturated edge of the grid.
        let first_bin = match (
            center_bx.checked_add(first_dx),
            center_by.checked_add(first_dy),
        ) {
            (Some(bx), Some(by)) => self.bin_map.get(&(bx, by)).copied(),
            _ => None,
        };
        let (bin_start, bin_count) = first_bin.unwrap_or((0, 0));
        Ok(SpatialIter2D {
            messages: &self.messages,
            positions: &self.positions,
            sorted_indices: &self.sorted_indices,
            bin_map: &self.bin_map,
            query_x: x,
            query_y: y,
            radius_sq: radius * radius,
            center_bx,
            center_by,
            grid_r,
            cur_dx: first_dx,
            cur_dy: first_dy,
            bin_start,
            bin_offset: 0,
            bin_count,
        })
    }

    /// Whether the list is currently in its read phase.
    pub fn is_finalized(&self) -> bool {
        self.finalized
    }

    /// Number of messages output so far.
    pub fn len(&self) -> usize {
        self.messages.len()
    }

    /// `true` when no messages have been output.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// Drops all messages and the index, and returns to the write phase.
    pub fn clear(&mut self) {
        self.messages.clear();
        self.positions.clear();
        self.bin_map.clear();
        self.sorted_indices.clear();
        self.finalized = false;
    }
}
/// Iterator over the messages near a 2D query point.
///
/// Yields `(message, squared_distance)` for each message whose squared
/// distance to the query point is at most `radius_sq`. It visits the
/// `(2 * grid_r + 1)²` cells around the query cell in row-major order
/// (x offset fastest), then filters each occupied cell's messages by exact
/// distance.
#[derive(Debug, Clone)]
pub struct SpatialIter2D<'a, M> {
    /// Message payloads of the owning list.
    messages: &'a [M],
    /// Positions parallel to `messages`.
    positions: &'a [(f32, f32)],
    /// Message indices grouped by cell.
    sorted_indices: &'a [usize],
    /// Occupied cell -> `(start, count)` range in `sorted_indices`.
    bin_map: &'a HashMap<(i32, i32), (usize, usize)>,
    /// Query point, x coordinate.
    query_x: f32,
    /// Query point, y coordinate.
    query_y: f32,
    /// Squared query radius; farther messages are skipped.
    radius_sq: f32,
    /// Cell containing the query point, x index.
    center_bx: i32,
    /// Cell containing the query point, y index.
    center_by: i32,
    /// Half-width of the searched window, in cells.
    grid_r: i32,
    /// X offset of the currently loaded cell relative to the centre cell.
    cur_dx: i32,
    /// Y offset of the currently loaded cell relative to the centre cell.
    cur_dy: i32,
    /// Start of the loaded cell's range in `sorted_indices`.
    bin_start: usize,
    /// How many entries of the loaded cell have been examined.
    bin_offset: usize,
    /// Number of entries in the loaded cell.
    bin_count: usize,
}

impl<'a, M> Iterator for SpatialIter2D<'a, M> {
    type Item = (&'a M, f32);

    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // Drain the currently loaded cell.
            while self.bin_offset < self.bin_count {
                let idx = self.sorted_indices[self.bin_start + self.bin_offset];
                self.bin_offset += 1;
                let (px, py) = self.positions[idx];
                let dx = self.query_x - px;
                let dy = self.query_y - py;
                let dist_sq = dx * dx + dy * dy;
                if dist_sq <= self.radius_sq {
                    return Some((&self.messages[idx], dist_sq));
                }
            }
            // Step to the next occupied cell inside the window.
            loop {
                if self.cur_dx < self.grid_r {
                    self.cur_dx += 1;
                } else if self.cur_dy < self.grid_r {
                    self.cur_dx = -self.grid_r;
                    self.cur_dy += 1;
                } else {
                    // Window exhausted. The cursor stays on the last cell, so
                    // every later call also ends here (the iterator is fused),
                    // and the offsets never step past `grid_r`.
                    return None;
                }
                let bin = match (
                    self.center_bx.checked_add(self.cur_dx),
                    self.center_by.checked_add(self.cur_dy),
                ) {
                    (Some(bx), Some(by)) => self.bin_map.get(&(bx, by)).copied(),
                    // Cells outside the i32 range cannot be occupied.
                    _ => None,
                };
                if let Some((start, count)) = bin {
                    self.bin_start = start;
                    self.bin_offset = 0;
                    self.bin_count = count;
                    break;
                }
            }
        }
    }
}

// Once `next` returns `None`, the cursor stays parked on the last window cell.
impl<M> std::iter::FusedIterator for SpatialIter2D<'_, M> {}
/// Messages tagged with 3D positions and bucketed into a uniform cubic grid,
/// so that neighbourhood queries only inspect nearby cells.
#[derive(Clone, Debug)]
pub struct SpatialMessages3D<M> {
    /// Message payloads, in output order.
    messages: Vec<M>,
    /// `positions[i]` is the location of `messages[i]`.
    positions: Vec<(f32, f32, f32)>,
    /// Edge length of a grid cell.
    radius: f32,
    /// Occupied cell -> `(start, count)` range within `sorted_indices`.
    bin_map: HashMap<(i32, i32, i32), (usize, usize)>,
    /// Message indices ordered so each cell's messages are contiguous.
    sorted_indices: Vec<usize>,
    /// `true` once the grid index has been built and reads are allowed.
    finalized: bool,
}
impl<M: Clone> SpatialMessages3D<M> {
    /// Creates an empty list that bins messages into cubic cells `radius`
    /// units wide.
    ///
    /// # Errors
    ///
    /// Returns [`MessageConfigError::InvalidRadius`] if `radius` is zero,
    /// negative, or NaN.
    pub fn new(radius: f32) -> Result<Self, MessageConfigError> {
        // `radius <= 0.0` is false for NaN, so NaN has to be rejected explicitly.
        if radius.is_nan() || radius <= 0.0 {
            return Err(MessageConfigError::InvalidRadius);
        }
        Ok(Self {
            messages: Vec::new(),
            positions: Vec::new(),
            radius,
            bin_map: HashMap::new(),
            sorted_indices: Vec::new(),
            finalized: false,
        })
    }

    /// Grid cell containing `(x, y, z)` for cells of size `1.0 / inv_radius`.
    ///
    /// Both `finalize` and `try_read_nearby` go through this helper, so
    /// messages and queries are always binned with identical arithmetic.
    /// Float-to-int `as` casts saturate (NaN becomes 0), so this never panics.
    fn cell_of(inv_radius: f32, x: f32, y: f32, z: f32) -> (i32, i32, i32) {
        (
            (x * inv_radius).floor() as i32,
            (y * inv_radius).floor() as i32,
            (z * inv_radius).floor() as i32,
        )
    }

    /// Records `message` at position `(x, y, z)`.
    ///
    /// # Panics
    ///
    /// Panics if the list has already been finalized.
    pub fn output(&mut self, message: M, x: f32, y: f32, z: f32) {
        self.try_output(message, x, y, z)
            .expect("cannot output after finalize");
    }

    /// Records `message` at position `(x, y, z)`.
    ///
    /// # Errors
    ///
    /// Returns [`MessagePhaseError::AlreadyFinalized`] if [`Self::finalize`]
    /// has been called since construction or the last [`Self::clear`].
    pub fn try_output(
        &mut self,
        message: M,
        x: f32,
        y: f32,
        z: f32,
    ) -> Result<(), MessagePhaseError> {
        if self.finalized {
            return Err(MessagePhaseError::AlreadyFinalized);
        }
        self.messages.push(message);
        self.positions.push((x, y, z));
        Ok(())
    }

    /// Builds the grid index and switches the list to its read phase.
    ///
    /// Message indices are sorted by cell into `sorted_indices`. `bin_map`
    /// then maps each occupied cell to its `(start, count)` range there.
    /// Calling this again rebuilds the index from scratch.
    pub fn finalize(&mut self) {
        let n = self.messages.len();
        let inv_radius = 1.0 / self.radius;
        let bin_assignments: Vec<(i32, i32, i32)> = self
            .positions
            .iter()
            .map(|&(x, y, z)| Self::cell_of(inv_radius, x, y, z))
            .collect();

        self.sorted_indices.clear();
        self.sorted_indices.extend(0..n);
        self.sorted_indices
            .sort_unstable_by_key(|&i| bin_assignments[i]);

        // Record each run of equal cells as (offset into sorted_indices, length).
        self.bin_map.clear();
        let mut start = 0;
        while start < n {
            let bin = bin_assignments[self.sorted_indices[start]];
            let mut end = start + 1;
            while end < n && bin_assignments[self.sorted_indices[end]] == bin {
                end += 1;
            }
            self.bin_map.insert(bin, (start, end - start));
            start = end;
        }
        self.finalized = true;
    }

    /// Iterates over `(message, squared_distance)` for every message within
    /// `radius` of `(x, y, z)`.
    ///
    /// # Panics
    ///
    /// Panics if the list has not been finalized.
    pub fn read_nearby(&self, x: f32, y: f32, z: f32, radius: f32) -> SpatialIter3D<'_, M> {
        self.try_read_nearby(x, y, z, radius)
            .expect("must finalize messages before reading")
    }

    /// Iterates over `(message, squared_distance)` for every message within
    /// `radius` of `(x, y, z)`.
    ///
    /// The iterator probes every cell within
    /// `max(1, ceil(|radius| / cell size))` cells of the query cell along
    /// each axis. Radii that are very large relative to the cell size are
    /// therefore expensive.
    ///
    /// # Errors
    ///
    /// Returns [`MessagePhaseError::NotFinalized`] if [`Self::finalize`] has
    /// not been called.
    pub fn try_read_nearby(
        &self,
        x: f32,
        y: f32,
        z: f32,
        radius: f32,
    ) -> Result<SpatialIter3D<'_, M>, MessagePhaseError> {
        if !self.finalized {
            return Err(MessagePhaseError::NotFinalized);
        }
        let (center_bx, center_by, center_bz) = Self::cell_of(1.0 / self.radius, x, y, z);
        // The distance filter uses radius², which ignores the sign. Size the
        // window from |radius| as well, or a negative radius wider than one
        // cell would silently miss matches.
        let grid_r = ((radius.abs() / self.radius).ceil() as i32).max(1);
        let first_dx = -grid_r;
        let first_dy = -grid_r;
        let first_dz = -grid_r;
        // Cells outside the i32 range cannot be occupied. checked_add avoids an
        // overflow panic for queries near the saturated edge of the grid.
        let first_bin = match (
            center_bx.checked_add(first_dx),
            center_by.checked_add(first_dy),
            center_bz.checked_add(first_dz),
        ) {
            (Some(bx), Some(by), Some(bz)) => self.bin_map.get(&(bx, by, bz)).copied(),
            _ => None,
        };
        let (bin_start, bin_count) = first_bin.unwrap_or((0, 0));
        Ok(SpatialIter3D {
            messages: &self.messages,
            positions: &self.positions,
            sorted_indices: &self.sorted_indices,
            bin_map: &self.bin_map,
            query_x: x,
            query_y: y,
            query_z: z,
            radius_sq: radius * radius,
            center_bx,
            center_by,
            center_bz,
            grid_r,
            cur_dx: first_dx,
            cur_dy: first_dy,
            cur_dz: first_dz,
            bin_start,
            bin_offset: 0,
            bin_count,
        })
    }

    /// Whether the list is currently in its read phase.
    pub fn is_finalized(&self) -> bool {
        self.finalized
    }

    /// Number of messages output so far.
    pub fn len(&self) -> usize {
        self.messages.len()
    }

    /// `true` when no messages have been output.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// Drops all messages and the index, and returns to the write phase.
    pub fn clear(&mut self) {
        self.messages.clear();
        self.positions.clear();
        self.bin_map.clear();
        self.sorted_indices.clear();
        self.finalized = false;
    }
}
/// Iterator over the messages near a 3D query point.
///
/// Yields `(message, squared_distance)` for each message whose squared
/// distance to the query point is at most `radius_sq`. It visits the
/// `(2 * grid_r + 1)³` cells around the query cell with the x offset varying
/// fastest and z slowest, then filters each occupied cell's messages by exact
/// distance.
#[derive(Debug, Clone)]
pub struct SpatialIter3D<'a, M> {
    /// Message payloads of the owning list.
    messages: &'a [M],
    /// Positions parallel to `messages`.
    positions: &'a [(f32, f32, f32)],
    /// Message indices grouped by cell.
    sorted_indices: &'a [usize],
    /// Occupied cell -> `(start, count)` range in `sorted_indices`.
    bin_map: &'a HashMap<(i32, i32, i32), (usize, usize)>,
    /// Query point, x coordinate.
    query_x: f32,
    /// Query point, y coordinate.
    query_y: f32,
    /// Query point, z coordinate.
    query_z: f32,
    /// Squared query radius; farther messages are skipped.
    radius_sq: f32,
    /// Cell containing the query point, x index.
    center_bx: i32,
    /// Cell containing the query point, y index.
    center_by: i32,
    /// Cell containing the query point, z index.
    center_bz: i32,
    /// Half-width of the searched window, in cells.
    grid_r: i32,
    /// X offset of the currently loaded cell relative to the centre cell.
    cur_dx: i32,
    /// Y offset of the currently loaded cell relative to the centre cell.
    cur_dy: i32,
    /// Z offset of the currently loaded cell relative to the centre cell.
    cur_dz: i32,
    /// Start of the loaded cell's range in `sorted_indices`.
    bin_start: usize,
    /// How many entries of the loaded cell have been examined.
    bin_offset: usize,
    /// Number of entries in the loaded cell.
    bin_count: usize,
}

impl<'a, M> Iterator for SpatialIter3D<'a, M> {
    type Item = (&'a M, f32);

    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // Drain the currently loaded cell.
            while self.bin_offset < self.bin_count {
                let idx = self.sorted_indices[self.bin_start + self.bin_offset];
                self.bin_offset += 1;
                let (px, py, pz) = self.positions[idx];
                let dx = self.query_x - px;
                let dy = self.query_y - py;
                let dz = self.query_z - pz;
                let dist_sq = dx * dx + dy * dy + dz * dz;
                if dist_sq <= self.radius_sq {
                    return Some((&self.messages[idx], dist_sq));
                }
            }
            // Step to the next occupied cell inside the window.
            loop {
                if self.cur_dx < self.grid_r {
                    self.cur_dx += 1;
                } else if self.cur_dy < self.grid_r {
                    self.cur_dx = -self.grid_r;
                    self.cur_dy += 1;
                } else if self.cur_dz < self.grid_r {
                    self.cur_dx = -self.grid_r;
                    self.cur_dy = -self.grid_r;
                    self.cur_dz += 1;
                } else {
                    // Window exhausted. The cursor stays on the last cell, so
                    // every later call also ends here (the iterator is fused),
                    // and the offsets never step past `grid_r`.
                    return None;
                }
                let bin = match (
                    self.center_bx.checked_add(self.cur_dx),
                    self.center_by.checked_add(self.cur_dy),
                    self.center_bz.checked_add(self.cur_dz),
                ) {
                    (Some(bx), Some(by), Some(bz)) => self.bin_map.get(&(bx, by, bz)).copied(),
                    // Cells outside the i32 range cannot be occupied.
                    _ => None,
                };
                if let Some((start, count)) = bin {
                    self.bin_start = start;
                    self.bin_offset = 0;
                    self.bin_count = count;
                    break;
                }
            }
        }
    }
}

// Once `next` returns `None`, the cursor stays parked on the last window cell.
impl<M> std::iter::FusedIterator for SpatialIter3D<'_, M> {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pulls the labels out of a neighbourhood query and sorts them, so the
    /// assertions below do not depend on the order cells are visited in.
    fn sorted_labels<'a, 's>(hits: impl Iterator<Item = (&'a &'s str, f32)>) -> Vec<&'s str>
    where
        's: 'a,
    {
        let mut labels: Vec<&'s str> = hits.map(|(label, _)| *label).collect();
        labels.sort_unstable();
        labels
    }

    #[test]
    fn brute_force_basic() {
        let mut msgs = BruteForceMessages::new();
        for m in [42i32, 99] {
            msgs.output(m);
        }
        msgs.finalize();
        assert_eq!(msgs.read_all(), &[42, 99]);

        msgs.clear();
        assert!(msgs.is_empty());
    }

    #[test]
    fn brute_force_phase_errors_are_typed() {
        let mut msgs = BruteForceMessages::new();
        assert_eq!(msgs.try_read_all(), Err(MessagePhaseError::NotFinalized));
        assert_eq!(msgs.try_output(1), Ok(()));
        msgs.finalize();
        assert_eq!(msgs.try_output(2), Err(MessagePhaseError::AlreadyFinalized));
        assert_eq!(msgs.try_read_all(), Ok(&[1][..]));
    }

    #[test]
    fn spatial_2d_basic() {
        let mut msgs = SpatialMessages2D::new(1.0).unwrap();
        for (label, x, y) in [("a", 0.0, 0.0), ("b", 0.5, 0.5), ("c", 10.0, 10.0)] {
            msgs.output(label, x, y);
        }
        msgs.finalize();
        assert_eq!(sorted_labels(msgs.read_nearby(0.0, 0.0, 1.0)), ["a", "b"]);
    }

    #[test]
    fn spatial_2d_phase_errors_are_typed() {
        let mut msgs = SpatialMessages2D::new(1.0).unwrap();
        assert!(matches!(
            msgs.try_read_nearby(0.0, 0.0, 1.0),
            Err(MessagePhaseError::NotFinalized)
        ));
        msgs.try_output("a", 0.0, 0.0).unwrap();
        msgs.finalize();
        assert_eq!(
            msgs.try_output("b", 1.0, 1.0),
            Err(MessagePhaseError::AlreadyFinalized)
        );
    }

    #[test]
    fn spatial_3d_basic() {
        let mut msgs = SpatialMessages3D::new(1.0).unwrap();
        for (label, x, y, z) in [
            ("a", 0.0, 0.0, 0.0),
            ("b", 0.5, 0.5, 0.5),
            ("c", 10.0, 10.0, 10.0),
        ] {
            msgs.output(label, x, y, z);
        }
        msgs.finalize();
        assert_eq!(
            sorted_labels(msgs.read_nearby(0.0, 0.0, 0.0, 1.0)),
            ["a", "b"]
        );
    }

    #[test]
    fn spatial_3d_phase_errors_are_typed() {
        let mut msgs = SpatialMessages3D::new(1.0).unwrap();
        assert!(matches!(
            msgs.try_read_nearby(0.0, 0.0, 0.0, 1.0),
            Err(MessagePhaseError::NotFinalized)
        ));
        msgs.try_output("a", 0.0, 0.0, 0.0).unwrap();
        msgs.finalize();
        assert_eq!(
            msgs.try_output("b", 1.0, 1.0, 1.0),
            Err(MessagePhaseError::AlreadyFinalized)
        );
    }
}