use std::{collections::HashMap, error, str::FromStr};
use coordinate_system::octant;
use ordered_float::OrderedFloat;
use permutation::Permutation;
use rstar::{PointDistance, RTree, RTreeObject, AABB};
use ultraviolet::DVec3;
use crate::{
geometry::{aabb::Aabb, ellipsoid::Ellipsoid, support::Support},
group_operators::ConditioningParams,
};
pub mod coordinate_system;
pub mod group_provider;
pub mod zero_mean;
/// An R-tree entry: the spatial support of one sample together with the position of that
/// sample in the owning database's `data` vector.
#[derive(Copy, Clone)]
pub struct SpatialData {
/// Geometry used for the R-tree envelope and for distance queries.
pub(crate) support: Support,
/// Index of the sample in the database's `data` (and `supports`) vector.
pub(crate) data_idx: u32,
}
impl SpatialData {
/// Returns a copy of the sample's support.
pub fn support(&self) -> Support {
self.support
}
/// Returns the sample's index into the database's data vector.
pub fn data_idx(&self) -> u32 {
self.data_idx
}
}
impl RTreeObject for SpatialData {
    type Envelope = AABB<[f64; 3]>;

    /// Axis-aligned bounding envelope of the entry's support.
    ///
    /// A point support collapses to a zero-volume box at the point; a box support uses its
    /// own min/max corners (the `disc` payload does not affect the envelope).
    fn envelope(&self) -> Self::Envelope {
        let [lower, upper] = match &self.support {
            Support::Point(center) => [*center; 2],
            Support::Aabb { aabb, .. } => [aabb.mins(), aabb.maxs()],
        };
        AABB::from_corners(*lower.as_array(), *upper.as_array())
    }
}
impl PointDistance for SpatialData {
    /// Squared distance between the query `point` and this entry's support.
    ///
    /// The `[x, y, z]` array supplied by rstar is turned into a `DVec3` and handed to
    /// `Support::sq_dist_to_point`, so the distance reflects the support's own geometry.
    fn distance_2(
        &self,
        point: &<Self::Envelope as rstar::Envelope>::Point,
    ) -> <<Self::Envelope as rstar::Envelope>::Point as rstar::Point>::Scalar {
        let [x, y, z] = *point;
        let query = DVec3::new(x, y, z);
        self.support.sq_dist_to_point(query)
    }
}
/// One result of a nearest-neighbour query on `SpatialAcceleratedDB`, borrowing the payload.
pub struct NeighboringElement<'a, T> {
/// Index of the sample in the database's `data`/`supports` vectors.
pub idx: u32,
/// Borrowed payload of the sample.
pub data: &'a T,
/// Spatial support of the sample.
pub support: Support,
/// Squared distance from the query location to `support` (as computed by
/// `Support::sq_dist_to_point`).
pub sq_dist: f64,
}
/// A collection of samples indexed by an R-tree over their spatial supports.
///
/// `supports[i]` is the geometry of `data[i]`; each tree entry stores `i` so query results
/// can be mapped back to the payload.
#[derive(Clone)]
pub struct SpatialAcceleratedDB<T> {
/// R-tree over the supports, each entry carrying its index into `data`.
pub tree: RTree<SpatialData>,
/// Support of every sample, in the same order as `data`.
pub supports: Vec<Support>,
/// Payload of every sample.
pub data: Vec<T>,
}
impl<T> SpatialAcceleratedDB<T> {
pub fn new(supports: Vec<Support>, data: Vec<T>) -> Self {
let spatial_data = supports
.clone()
.into_iter()
.enumerate()
.map(|(data_idx, support)| SpatialData {
support,
data_idx: data_idx as u32,
})
.collect::<Vec<_>>();
let tree = RTree::bulk_load(spatial_data);
Self {
tree,
supports,
data,
}
}
fn _iter_nearest(&self, location: DVec3) -> impl Iterator<Item = NeighboringElement<'_, T>> {
self.tree
.nearest_neighbor_iter_with_distance_2(&[location.x, location.y, location.z])
.map(|(sd, sq_dist)| NeighboringElement {
idx: sd.data_idx,
data: &self.data[sd.data_idx as usize],
support: sd.support,
sq_dist,
})
}
}
impl<T> SpatialAcceleratedDB<T>
where
    T: FromStr,
    <T as FromStr>::Err: std::error::Error + 'static,
{
    /// Builds a point database from a CSV file that has a header row.
    ///
    /// Every data row becomes a `Support::Point` at (`x_col`, `y_col`, `z_col`) carrying the
    /// value parsed from `value_col`. Row order is preserved, so a sample's index is its
    /// zero-based data-row number. If a column name occurs more than once in the header, the
    /// first occurrence is used.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened or parsed as CSV, if any requested column
    /// is missing from the header, if a record is shorter than the header, or if a cell fails
    /// to parse as `f64` (coordinates) or `T` (value).
    pub fn from_csv_index(
        csv_path: &str,
        x_col: &str,
        y_col: &str,
        z_col: &str,
        value_col: &str,
    ) -> Result<Self, Box<dyn error::Error>> {
        let mut rdr = csv::Reader::from_path(csv_path)?;
        // Resolve the requested columns once from the header. Indexing a per-row map with an
        // unknown column name used to panic; report it as an error instead.
        let headers = rdr.headers()?.clone();
        let column = |name: &str| -> Result<usize, Box<dyn error::Error>> {
            headers
                .iter()
                .position(|header| header == name)
                .ok_or_else(|| {
                    format!("column `{name}` not found in CSV header of {csv_path}").into()
                })
        };
        let (x_idx, y_idx, z_idx, value_idx) =
            (column(x_col)?, column(y_col)?, column(z_col)?, column(value_col)?);

        let mut point_vec = Vec::new();
        let mut value_vec = Vec::new();
        for result in rdr.records() {
            let record = result?;
            let field = |i: usize| {
                record
                    .get(i)
                    .ok_or("CSV record has fewer fields than its header")
            };
            let x = field(x_idx)?.parse::<f64>()?;
            let y = field(y_idx)?.parse::<f64>()?;
            let z = field(z_idx)?.parse::<f64>()?;
            let value = field(value_idx)?.parse::<T>()?;
            point_vec.push(Support::Point(DVec3::new(x, y, z)));
            value_vec.push(value);
        }
        Ok(Self::new(point_vec, value_vec))
    }
}
impl<T: Copy> IterNearest for SpatialAcceleratedDB<T> {
type Data = T;
fn iter_nearest(
&self,
location: &DVec3,
) -> impl Iterator<Item = IterNearestElem<Self::Data>> + '_ {
self._iter_nearest(*location).map(|elem| IterNearestElem {
shape: elem.support,
sq_dist: elem.sq_dist,
data: *elem.data,
tag: 0,
idx: elem.idx,
})
}
}
/// Accumulates conditioning samples for one search ellipsoid, split into eight octants of the
/// ellipsoid's local frame (octant index from `octant(&local_point)`).
///
/// The `octant_*` vectors hold one inner `Vec` per octant and are meant to be parallel: entry
/// `i` of each inner vector describes the same accepted sample.
pub struct ConditioningDataCollector<'b, T> {
/// Search limits (e.g. `max_octant`, `same_source_group_limit`).
pub cond_params: &'b ConditioningParams,
/// Initialized to `f64::MAX` by `new`.
/// NOTE(review): not read by the collector's own methods in this file.
pub max_accepted_dist: f64,
/// Search ellipsoid; provides the local frame and the normalized distance.
pub ellipsoid: &'b Ellipsoid,
/// Supports of the accepted samples, per octant.
pub octant_shapes: Vec<Vec<Support>>,
/// Normalized squared ellipsoid distance (`normalized_local_distance_sq`) of each sample.
pub octant_norm_dists: Vec<Vec<f64>>,
/// Payloads of the accepted samples.
pub octant_values: Vec<Vec<T>>,
/// Source indices (`IterNearestElem::idx`) of the accepted samples.
pub octant_inds: Vec<Vec<u32>>,
/// Source tags of the accepted samples, used for the per-source limit.
pub octant_tags: Vec<Vec<u32>>,
/// Number of accepted samples per octant.
pub octant_counts: Vec<u32>,
/// Counter compared against 8 by `all_octants_full`.
pub full_octants: u8,
/// Incremented when an octant receives its first sample.
pub conditioned_octants: u8,
/// Distinct source tags seen so far (parallel to `source_count`).
pub source_tag: Vec<u32>,
/// Number of accepted samples per entry of `source_tag`.
pub source_count: Vec<u32>,
/// Set when the caller should stop feeding candidates.
pub stop: bool,
}
impl<'b, T> ConditioningDataCollector<'b, T> {
    /// Creates an empty collector for one search ellipsoid.
    ///
    /// Each of the eight per-octant buffers is pre-allocated with `cond_params.max_octant`
    /// slots, the most samples `try_insert_shape` keeps in a single octant.
    pub fn new(ellipsoid: &'b Ellipsoid, cond_params: &'b ConditioningParams) -> Self {
        let octant_max = cond_params.max_octant;
        Self {
            cond_params,
            max_accepted_dist: f64::MAX,
            ellipsoid,
            octant_shapes: (0..8)
                .map(|_| Vec::with_capacity(octant_max))
                .collect::<Vec<_>>(),
            octant_norm_dists: (0..8)
                .map(|_| Vec::with_capacity(octant_max))
                .collect::<Vec<_>>(),
            octant_values: (0..8)
                .map(|_| Vec::with_capacity(octant_max))
                .collect::<Vec<_>>(),
            octant_inds: (0..8)
                .map(|_| Vec::with_capacity(octant_max))
                .collect::<Vec<_>>(),
            octant_tags: (0..8)
                .map(|_| Vec::with_capacity(octant_max))
                .collect::<Vec<_>>(),
            octant_counts: vec![0; 8],
            full_octants: 0,
            conditioned_octants: 0,
            source_tag: Vec::new(),
            source_count: Vec::new(),
            stop: false,
        }
    }

    /// Returns `true` when `full_octants` equals 8.
    #[inline(always)]
    pub fn all_octants_full(&self) -> bool {
        self.full_octants == 8
    }

    /// Accounts for one more sample from source `tag`.
    ///
    /// An unseen tag is registered with a count of 1 and `true` is returned. A known tag is
    /// incremented and `true` returned only while its count is below
    /// `same_source_group_limit`; otherwise nothing changes and `false` is returned.
    #[inline(always)]
    pub fn increment_or_insert_tag(&mut self, tag: u32) -> bool {
        if let Some(ind) = self.source_tag.iter().position(|&x| x == tag) {
            if self.source_count[ind] < self.cond_params.same_source_group_limit as u32 {
                self.source_count[ind] += 1;
                true
            } else {
                false
            }
        } else {
            self.source_tag.push(tag);
            self.source_count.push(1);
            true
        }
    }

    /// Undoes one `increment_or_insert_tag` for `tag`. Unknown tags are ignored; a tag whose
    /// count reaches zero stays registered.
    #[inline(always)]
    pub fn decrement_tag(&mut self, tag: u32) {
        if let Some(ind) = self.source_tag.iter().position(|&x| x == tag) {
            self.source_count[ind] -= 1;
        }
    }

    /// Slot index and value of the largest normalized distance stored in `octant`, or `None`
    /// if the octant is empty. On ties the last slot wins (`Iterator::max_by_key`).
    #[inline(always)]
    pub fn max_octant_dist(&self, octant: usize) -> Option<(usize, f64)> {
        self.octant_norm_dists[octant]
            .iter()
            .copied()
            .enumerate()
            .max_by_key(|(_, dist)| OrderedFloat(*dist))
    }

    /// Appends a sample to `octant` if the per-source limit for `tag` allows it; otherwise
    /// does nothing. Counts the octant as conditioned when this is its first sample.
    #[inline(always)]
    pub fn insert_shape(
        &mut self,
        octant: usize,
        shape: Support,
        value: T,
        dist: f64,
        ind: u32,
        tag: u32,
    ) {
        if !self.increment_or_insert_tag(tag) {
            return;
        }
        self.octant_shapes[octant].push(shape);
        self.octant_inds[octant].push(ind);
        self.octant_norm_dists[octant].push(dist);
        self.octant_values[octant].push(value);
        self.octant_tags[octant].push(tag);
        self.octant_counts[octant] += 1;
        if self.octant_shapes[octant].len() == 1 {
            self.conditioned_octants += 1;
        }
    }

    /// Removes slot `ind` of `octant` (order within the octant is not preserved) and releases
    /// its per-source count. An octant left empty is no longer counted as conditioned.
    #[inline(always)]
    pub fn remove_shape(&mut self, octant: usize, ind: usize) {
        let tag = self.octant_tags[octant][ind];
        self.octant_shapes[octant].swap_remove(ind);
        self.octant_inds[octant].swap_remove(ind);
        self.octant_norm_dists[octant].swap_remove(ind);
        self.octant_values[octant].swap_remove(ind);
        // The tag vector is parallel to the others and must move with them; otherwise later
        // lookups of `octant_tags[octant][i]` refer to a different sample.
        self.octant_tags[octant].swap_remove(ind);
        self.octant_counts[octant] -= 1;
        // Mirror `insert_shape`, which counts an octant when it gains its first sample;
        // without this a remove/insert pair on a one-slot octant counts it twice.
        if self.octant_shapes[octant].is_empty() {
            self.conditioned_octants -= 1;
        }
        self.decrement_tag(tag);
    }

    /// Whether replacing slot `ind` of `octant` with a sample tagged `tag` respects the
    /// per-source limit: always for the same tag (net count unchanged) or an unseen tag,
    /// otherwise only while that tag's count is below `same_source_group_limit`.
    #[inline(always)]
    pub fn can_swap_insert(&self, octant: usize, ind: usize, tag: u32) -> bool {
        let old_tag = self.octant_tags[octant][ind];
        if old_tag == tag {
            return true;
        }
        let Some(tag_ind) = self.source_tag.iter().position(|&x| x == tag) else {
            return true;
        };
        self.source_count[tag_ind] < self.cond_params.same_source_group_limit as u32
    }

    /// Offers one candidate sample to the octant search.
    ///
    /// * If the ellipsoid cannot contain anything at `sq_dist`, sets `stop` and returns.
    /// * The support's centre is moved into the ellipsoid's local frame; candidates whose
    ///   normalized squared distance `h` exceeds 1 are ignored.
    /// * If the candidate's octant holds fewer than `max_octant` samples it is inserted;
    ///   otherwise it replaces the octant's farthest sample when it is closer and the
    ///   per-source limit allows the swap.
    /// * Otherwise, if `local_point.mag()` exceeds the octant's farthest normalized distance,
    ///   `full_octants` is incremented and `stop` is set once it reaches 8.
    ///
    /// NOTE(review): `full_octants` is incremented per rejected candidate, not per distinct
    /// octant, so eight such rejections in one octant already stop the search. It also
    /// compares the unnormalized length `mag()` with a normalized squared distance. Confirm
    /// the intent; a real fix needs per-octant "full" state on the struct.
    #[inline(always)]
    pub fn try_insert_shape(&mut self, shape: Support, value: T, sq_dist: f64, ind: u32, tag: u32) {
        let point = shape.center();
        if !self.ellipsoid.may_contain_local_point_at_sq_dist(sq_dist) {
            self.stop = true;
            return;
        }
        let local_point = self
            .ellipsoid
            .coordinate_system
            .into_local()
            .transform_vec(point);
        let h = self.ellipsoid.normalized_local_distance_sq(&local_point);
        if h > 1.0 {
            return;
        }
        let octant = octant(&local_point) as usize;
        if self.octant_shapes[octant].len() < self.cond_params.max_octant {
            self.insert_shape(octant, shape, value, h, ind, tag);
            return;
        }
        if let Some((slot, max_dist)) = self.max_octant_dist(octant) {
            if h < max_dist && self.can_swap_insert(octant, slot, tag) {
                self.remove_shape(octant, slot);
                self.insert_shape(octant, shape, value, h, ind, tag);
                return;
            }
            let h_major = local_point.mag();
            if h_major > max_dist {
                self.full_octants += 1;
                if self.all_octants_full() {
                    self.stop = true;
                }
            }
        }
    }
}
/// One element of an `IterNearest` stream: a sample with its owned payload.
pub struct IterNearestElem<T> {
/// Spatial support of the sample.
pub shape: Support,
/// Squared distance from the query location, as reported by the source.
pub sq_dist: f64,
/// Payload of the sample.
pub data: T,
/// Source/group tag used for the per-source limit (`SpatialAcceleratedDB` always emits 0).
pub tag: u32,
/// Index of the sample in its source.
pub idx: u32,
}
/// A source of samples that can be streamed around a query location.
///
/// `ConditioningProvider::query` stops consuming at the first element whose `sq_dist` lies
/// beyond the search ellipsoid, so implementations are presumably expected to yield elements
/// nearest-first (as `SpatialAcceleratedDB` does via the R-tree). Confirm for new impls.
pub trait IterNearest {
/// Payload type carried by each element.
type Data;
/// Streams the samples around `location`; the iterator may borrow `self`.
fn iter_nearest(
&self,
location: &DVec3,
) -> impl Iterator<Item = IterNearestElem<Self::Data>> + '_;
}
/// Adapter over another `IterNearest` source that hides elements rejected by a predicate.
///
/// Kept elements are yielded in the same order as the wrapped source.
pub struct FilteredIterNearest<'a, IN, F>
where
    IN: IterNearest,
    F: Fn(&IterNearestElem<IN::Data>) -> bool,
{
    /// Wrapped neighbour source.
    pub iter_nearest: &'a IN,
    /// Predicate; elements for which it returns `false` are dropped.
    pub filter: F,
}

impl<'a, IN, F> FilteredIterNearest<'a, IN, F>
where
    IN: IterNearest,
    F: Fn(&IterNearestElem<IN::Data>) -> bool,
{
    /// Wraps `iter_nearest`, keeping only elements accepted by `filter`.
    pub fn new(iter_nearest: &'a IN, filter: F) -> Self {
        Self {
            filter,
            iter_nearest,
        }
    }
}

impl<IN, F> IterNearest for FilteredIterNearest<'_, IN, F>
where
    IN: IterNearest,
    F: Fn(&IterNearestElem<IN::Data>) -> bool,
{
    type Data = IN::Data;

    /// Streams the wrapped source's elements around `location` that pass the predicate.
    fn iter_nearest(
        &self,
        location: &DVec3,
    ) -> impl Iterator<Item = IterNearestElem<Self::Data>> + '_ {
        let keep = &self.filter;
        self.iter_nearest
            .iter_nearest(location)
            .filter(move |candidate| keep(candidate))
    }
}
/// Adapter over another `IterNearest` source that transforms every element.
///
/// The transformation is applied lazily, one element at a time, in source order.
pub struct MappedIterNearest<'a, IN, F>
where
    IN: IterNearest,
    F: Fn(IterNearestElem<IN::Data>) -> IterNearestElem<IN::Data>,
{
    /// Wrapped neighbour source.
    pub iter_nearest: &'a IN,
    /// Element transformation.
    pub map: F,
}

impl<'a, IN, F> MappedIterNearest<'a, IN, F>
where
    IN: IterNearest,
    F: Fn(IterNearestElem<IN::Data>) -> IterNearestElem<IN::Data>,
{
    /// Wraps `iter_nearest`, passing every element through `map`.
    pub fn new(iter_nearest: &'a IN, map: F) -> Self {
        Self {
            map,
            iter_nearest,
        }
    }
}

impl<'a, IN, F> IterNearest for MappedIterNearest<'a, IN, F>
where
    IN: IterNearest,
    F: Fn(IterNearestElem<IN::Data>) -> IterNearestElem<IN::Data>,
{
    type Data = IN::Data;

    /// Streams the wrapped source's elements around `location`, each transformed by `map`.
    fn iter_nearest(
        &self,
        location: &DVec3,
    ) -> impl Iterator<Item = IterNearestElem<Self::Data>> + '_ {
        let transform = &self.map;
        self.iter_nearest
            .iter_nearest(location)
            .map(move |elem| transform(elem))
    }
}
/// Something that can assemble the conditioning data for one location.
///
/// In this file it is implemented for every `IterNearest + Sync + Send` type by a blanket
/// impl.
pub trait ConditioningProvider: IterNearest + Sync + Send {
/// Gathers conditioning samples around `point` inside `ellipsoid`, limited by `params`.
///
/// Returns `(indices, values, supports, ok)`; the first three vectors are parallel, and
/// `ok` reports whether the minimum-data requirements in `params` were met.
fn query(
&self,
point: &DVec3,
ellipsoid: &Ellipsoid,
params: &ConditioningParams,
) -> (Vec<usize>, Vec<Self::Data>, Vec<Support>, bool);
}
/// Outcome of a `FilterMappedIterNearest` callback for one element.
pub enum FilterMapResult<T> {
/// Yield this (possibly transformed) element.
Mapped(T),
/// Skip this element and continue with the next one.
Ignore,
/// End the stream; no further elements are examined.
ExitEarly,
}
/// Adapter over another `IterNearest` source whose callback can transform, skip, or stop.
///
/// For each element the callback returns a `FilterMapResult`: `Mapped` yields the value,
/// `Ignore` skips it, and `ExitEarly` ends the stream without inspecting further elements.
pub struct FilterMappedIterNearest<'a, IN, F>
where
    IN: IterNearest,
    F: Fn(IterNearestElem<IN::Data>) -> FilterMapResult<IterNearestElem<IN::Data>>,
{
    /// Wrapped neighbour source.
    pub iter_nearest: &'a IN,
    /// Per-element callback.
    pub map: F,
}

impl<'a, IN, F> FilterMappedIterNearest<'a, IN, F>
where
    IN: IterNearest,
    F: Fn(IterNearestElem<IN::Data>) -> FilterMapResult<IterNearestElem<IN::Data>>,
{
    /// Wraps `iter_nearest` with the filter/map/stop callback `map`.
    pub fn new(iter_nearest: &'a IN, map: F) -> Self {
        Self {
            map,
            iter_nearest,
        }
    }
}

impl<'a, IN, F> IterNearest for FilterMappedIterNearest<'a, IN, F>
where
    IN: IterNearest,
    F: Fn(IterNearestElem<IN::Data>) -> FilterMapResult<IterNearestElem<IN::Data>>,
{
    type Data = IN::Data;

    /// Streams the wrapped source's elements around `location` through the callback.
    fn iter_nearest(
        &self,
        location: &DVec3,
    ) -> impl Iterator<Item = IterNearestElem<Self::Data>> + '_ {
        let classify = &self.map;
        self.iter_nearest
            .iter_nearest(location)
            // Outer `None` ends the stream; inner `None` skips a single element.
            .map_while(move |elem| match classify(elem) {
                FilterMapResult::Mapped(kept) => Some(Some(kept)),
                FilterMapResult::Ignore => Some(None),
                FilterMapResult::ExitEarly => None,
            })
            .flatten()
    }
}
impl<IN> ConditioningProvider for IN
where
    IN: IterNearest + Sync + Send,
{
    /// Gathers conditioning data for `point` from the stream produced by `iter_nearest`.
    ///
    /// Candidates are fed to a `ConditioningDataCollector` (octant search inside `ellipsoid`)
    /// until the stream ends or the collector sets `stop`. If more than `params.max_n_cond`
    /// samples were accepted, samples are dropped farthest-first (by normalized ellipsoid
    /// distance). Dropping is restricted so that the number of populated octants does not fall
    /// below `params.min_conditioned_octants` where avoidable.
    ///
    /// Returns `(indices, values, supports, ok)`; the three vectors are parallel. `ok` is
    /// `true` when at least `min_conditioned_octants` octants are populated and at least
    /// `params.min_n_cond` samples remain.
    fn query(
        &self,
        point: &DVec3,
        ellipsoid: &Ellipsoid,
        params: &ConditioningParams,
    ) -> (Vec<usize>, Vec<IN::Data>, Vec<Support>, bool) {
        let mut cond_points = ConditioningDataCollector::new(ellipsoid, params);
        for IterNearestElem {
            shape,
            sq_dist,
            data,
            tag,
            idx,
        } in self.iter_nearest(point)
        {
            cond_points.try_insert_shape(shape, data, sq_dist, idx, tag);
            if cond_points.stop {
                break;
            }
        }
        // Flatten the per-octant buffers (octant 0 first); these stay index-aligned.
        let mut inds: Vec<usize> = cond_points
            .octant_inds
            .into_iter()
            .flatten()
            .map(|i| i as usize)
            .collect();
        let mut points: Vec<_> = cond_points.octant_shapes.into_iter().flatten().collect();
        let mut data = cond_points
            .octant_values
            .into_iter()
            .flatten()
            .collect::<Vec<IN::Data>>();
        if data.len() > params.max_n_cond {
            let min_octants = params.min_conditioned_octants as u8;
            let mut octant_counts = cond_points.octant_counts;
            // Above the minimum number of populated octants any sample may be dropped; at or
            // below it, only octants holding a spare sample (count > 1) may shrink.
            let mut can_remove_flag: Vec<bool> =
                if cond_points.conditioned_octants > min_octants {
                    vec![true; 8]
                } else {
                    octant_counts.iter().map(|&count| count > 1).collect()
                };
            // Octant of every flattened sample.
            let mut octant_inds: Vec<usize> = cond_points
                .octant_norm_dists
                .iter()
                .enumerate()
                .flat_map(|(i, d)| std::iter::repeat(i).take(d.len()))
                .collect();
            let dists: Vec<f64> = cond_points
                .octant_norm_dists
                .into_iter()
                .flatten()
                .collect();
            // Reorder every parallel vector by ascending normalized distance (stable on ties).
            let mut sorted_inds = (0..inds.len()).collect::<Vec<_>>();
            sorted_inds.sort_by_key(|&i| OrderedFloat(dists[i]));
            let mut permutation = Permutation::oneline(sorted_inds).inverse();
            permutation.apply_slice_in_place(&mut inds);
            permutation.apply_slice_in_place(&mut points);
            permutation.apply_slice_in_place(&mut data);
            permutation.apply_slice_in_place(&mut octant_inds);
            // Only the prefix `[0, end)` is still candidate for removal. `swap_remove(ind)`
            // only disturbs positions >= ind, which are excluded afterwards, so the prefix
            // stays sorted.
            let mut end = octant_inds.len();
            while data.len() > params.max_n_cond {
                let Some(r_ind) = octant_inds[0..end]
                    .iter()
                    .rev()
                    .position(|&oct| can_remove_flag[oct])
                else {
                    break;
                };
                let ind = end - r_ind - 1;
                end = ind;
                let octant = octant_inds[ind];
                inds.swap_remove(ind);
                points.swap_remove(ind);
                data.swap_remove(ind);
                octant_inds.swap_remove(ind);
                octant_counts[octant] -= 1;
                if octant_counts[octant] == 0 {
                    cond_points.conditioned_octants -= 1;
                }
                // Same threshold as the initial choice above (`<=`, not `<`). Otherwise,
                // after reaching exactly the minimum, the next removal could empty another
                // octant and drop below it.
                if cond_points.conditioned_octants <= min_octants {
                    can_remove_flag = octant_counts.iter().map(|&count| count > 1).collect();
                }
            }
        }
        let res = cond_points.conditioned_octants >= params.min_conditioned_octants as u8
            && data.len() >= params.min_n_cond;
        (inds, data, points, res)
    }
}
/// A volume that can be approximated by a set of sample points.
pub trait DiscretiveVolume {
/// Returns sample points covering the volume, using target spacings `dx`, `dy`, `dz`
/// along the respective axes.
fn discretize(&self, dx: f64, dy: f64, dz: f64) -> Vec<DVec3>;
}
impl DiscretiveVolume for Aabb {
    /// Approximates the box by the centres of a regular grid of cells.
    ///
    /// Each axis is split into `ceil(extent / d)` equal cells, never fewer than one, so the
    /// realized spacing is at most the requested `dx`/`dy`/`dz`. Points are returned with `x`
    /// varying slowest and `z` fastest.
    ///
    /// A flat axis (zero extent) contributes a single layer of points on that plane instead of
    /// producing an empty result. Spacings are expected to be positive and finite.
    fn discretize(&self, dx: f64, dy: f64, dz: f64) -> Vec<DVec3> {
        let (mins, maxs) = (self.mins(), self.maxs());
        // Number of cells along one axis and the resulting step.
        let split = |lo: f64, hi: f64, d: f64| -> (usize, f64) {
            let extent = hi - lo;
            let cells = ((extent / d).ceil() as usize).max(1);
            (cells, extent / cells as f64)
        };
        let (nx, step_x) = split(mins.x, maxs.x, dx);
        let (ny, step_y) = split(mins.y, maxs.y, dy);
        let (nz, step_z) = split(mins.z, maxs.z, dz);
        // Derive each centre from its cell index. Accumulating `+= step` drifts with rounding
        // error, yields nothing for a NaN step, and never terminates for a non-positive one.
        let center = |lo: f64, step: f64, i: usize| lo + (i as f64 + 0.5) * step;
        let capacity = nx
            .checked_mul(ny)
            .and_then(|n| n.checked_mul(nz))
            .unwrap_or(0);
        let mut points = Vec::with_capacity(capacity);
        for i in 0..nx {
            let x = center(mins.x, step_x, i);
            for j in 0..ny {
                let y = center(mins.y, step_y, j);
                for k in 0..nz {
                    points.push(DVec3::new(x, y, center(mins.z, step_z, k)));
                }
            }
        }
        points
    }
}