use super::num_divisions;
use crate::geometry::{point_circle_distance, point_cylinder_distance, point_line_distance, point_point_distance};
use crate::StrError;
use plotpy::{Canvas, Curve, Plot, Text};
use russell_lab::math::{SQRT_2, SQRT_3};
use std::collections::{HashMap, HashSet};
use std::fmt;
/// Point filter that accepts every coordinate vector
///
/// Pass it to the `search_on_*` methods of [`GridSearch`] when no additional
/// selection criterion is required.
pub fn any_x(_x: &[f64]) -> bool {
    true
}
/// Default number of divisions (used when `ndiv` is `None` in `GridSearch::new`)
pub const GS_DEFAULT_NDIV: usize = 100;

/// Default tolerance (used when `tolerance` is `None` in `GridSearch::new`)
pub const GS_DEFAULT_TOLERANCE: f64 = 1.0e-4;

/// Default border margin (used when `border_tol` is `None` in `GridSearch::new`)
pub const GS_DEFAULT_BORDER_TOL: f64 = 1.0e-2;
/// Identifier of an item stored in the grid
type ItemId = usize;

/// Linear index of a container (cell) of the grid
type ContainerKey = usize;

/// Items stored in one container: item id → item coordinates
type Container = HashMap<ItemId, Vec<f64>>;

/// All non-empty containers, indexed by their key
type Containers = HashMap<ContainerKey, Container>;
/// Spatial hash (uniform grid of square/cubic containers) for locating 2D or 3D points
pub struct GridSearch {
    /// Space dimension (2 or 3)
    ndim: usize,

    /// Number of containers along each direction
    ndiv: Vec<usize>,

    /// Lower corner of the grid (the requested `xmin` minus the border tolerance)
    xmin: Vec<f64>,

    /// Upper corner of the grid (recomputed as `xmin + side_length * ndiv`)
    xmax: Vec<f64>,

    /// Multipliers converting per-direction container indices into a container key
    coefficient: Vec<usize>,

    /// Side length of every container (the same along all directions)
    side_length: f64,

    /// Half-width of the box around each point used to register it in neighboring containers
    tolerance: f64,

    /// Distance tolerance used in the searches: `sqrt(ndim) * tolerance`
    tol_dist: f64,

    /// Distance from a container's center to its corners: `sqrt(ndim) * side_length / 2`
    radius: f64,

    /// Scratch storage for the corners of the `±tolerance` box around a point
    halo: Vec<Vec<f64>>,

    /// Number of corners in `halo` (`2^ndim`)
    halo_ncorner: usize,

    /// Non-empty containers
    containers: Containers,
}
impl GridSearch {
pub fn new(
xmin: &[f64],
xmax: &[f64],
ndiv: Option<usize>,
tolerance: Option<f64>,
border_tol: Option<f64>,
) -> Result<Self, StrError> {
let ndim = xmin.len();
if ndim < 2 || ndim > 3 {
return Err("xmin.len() = ndim must be 2 or 3");
}
if xmax.len() != ndim {
return Err("xmax.len() must equal ndim = xmin.len()");
}
let ndiv_long = match ndiv {
Some(v) => v,
None => GS_DEFAULT_NDIV,
};
if ndiv_long < 1 {
return Err("ndiv must be ≥ 1");
}
let ndiv = num_divisions(1, ndiv_long, xmin, xmax)?; let coefficient = vec![1, ndiv[0], ndiv[0] * ndiv[1]];
let tolerance = match tolerance {
Some(v) => v,
None => GS_DEFAULT_TOLERANCE,
};
if tolerance <= 0.0 {
return Err("tolerance must be > 0.0");
}
let border_tol = match border_tol {
Some(v) => v,
None => GS_DEFAULT_BORDER_TOL,
};
if border_tol < 0.0 {
return Err("border_tol must be ≥ 0.0");
}
let mut xmin = xmin.to_vec();
let mut xmax = xmax.to_vec();
if border_tol > 0.0 {
for i in 0..ndim {
xmin[i] -= border_tol;
xmax[i] += border_tol;
}
}
let mut side_length = (xmax[0] - xmin[0]) / (ndiv[0] as f64);
for i in 1..ndim {
let sl = (xmax[i] - xmin[i]) / (ndiv[i] as f64);
side_length = f64::max(side_length, sl);
}
if side_length <= 2.0 * tolerance {
return Err("(xmax-xmin)/ndiv must be > 2·tolerance; reduce the tolerance (or ndiv)");
}
for i in 0..ndim {
xmax[i] = xmin[i] + side_length * (ndiv[i] as f64);
}
let tol_dist = if ndim == 2 {
SQRT_2 * tolerance
} else {
SQRT_3 * tolerance
};
let radius = if ndim == 2 {
SQRT_2 * side_length / 2.0
} else {
SQRT_3 * side_length / 2.0
};
let halo_ncorner = usize::pow(2, ndim as u32);
let halo = vec![vec![0.0; ndim]; halo_ncorner];
Ok(GridSearch {
ndim,
ndiv,
xmin,
xmax,
coefficient,
side_length,
tolerance,
tol_dist,
radius,
halo,
halo_ncorner,
containers: HashMap::new(),
})
}
pub fn is_outside(&self, x: &[f64]) -> bool {
assert_eq!(x.len(), self.ndim);
for i in 0..self.ndim {
if x[i] < self.xmin[i] || x[i] > self.xmax[i] {
return true;
}
}
return false;
}
pub fn insert(&mut self, id: usize, x: &[f64]) -> Result<(), StrError> {
if x.len() != self.ndim {
return Err("x.len() must equal ndim");
}
let key = match self.calc_container_key(x) {
Some(k) => k,
None => return Err("cannot insert point because its coordinates are outside the grid"),
};
self.update_or_insert(key, id, x);
self.set_halo(x);
let mut tmp = vec![0.0; self.ndim];
for c in 0..self.halo_ncorner {
tmp.copy_from_slice(&self.halo[c][0..self.ndim]);
if let Some(key_corner) = self.calc_container_key(&tmp) {
if key_corner != key {
self.update_or_insert(key_corner, id, x); }
}
}
Ok(())
}
pub fn search(&self, x: &[f64]) -> Result<Option<usize>, StrError> {
if x.len() != self.ndim {
return Err("x.len() must equal ndim");
}
let key = match self.calc_container_key(x) {
Some(k) => k,
None => return Err("cannot find point because the coordinates are outside the grid"),
};
let container = match self.containers.get(&key) {
Some(c) => c,
None => return Ok(None), };
for (id, x_other) in container {
let distance = point_point_distance(x_other, x)?;
if distance <= self.tol_dist {
return Ok(Some(*id));
}
}
Ok(None)
}
pub fn search_on_line<F>(&self, a: &[f64], b: &[f64], mut filter: F) -> Result<HashSet<usize>, StrError>
where
F: FnMut(&[f64]) -> bool,
{
if a.len() != self.ndim {
return Err("a.len() must equal ndim");
}
if b.len() != self.ndim {
return Err("b.len() must equal ndim");
}
let nearest_containers = self.containers_near_line(a, b)?;
let mut ids = HashSet::new();
for index in nearest_containers {
let container = self.containers.get(&index).unwrap();
for (id, x_other) in container {
let distance = point_line_distance(a, b, x_other)?;
if distance <= self.tol_dist && filter(x_other) {
ids.insert(*id);
}
}
}
Ok(ids)
}
pub fn search_on_circle<F>(&self, center: &[f64], radius: f64, mut filter: F) -> Result<HashSet<usize>, StrError>
where
F: FnMut(&[f64]) -> bool,
{
if self.ndim != 2 {
return Err("search_on_circle works in 2D only");
}
if center.len() != self.ndim {
return Err("center.len() must equal ndim");
}
let nearest_containers = self.containers_near_circle(center, radius)?;
let mut ids = HashSet::new();
for index in nearest_containers {
let container = self.containers.get(&index).unwrap();
for (id, x_other) in container {
let distance = point_circle_distance(center, radius, x_other)?;
if f64::abs(distance) <= self.tol_dist && filter(x_other) {
ids.insert(*id);
}
}
}
Ok(ids)
}
pub fn search_on_cylinder<F>(
&self,
a: &[f64],
b: &[f64],
radius: f64,
mut filter: F,
) -> Result<HashSet<usize>, StrError>
where
F: FnMut(&[f64]) -> bool,
{
if self.ndim != 3 {
return Err("search_on_cylinder works in 3D only");
}
if a.len() != self.ndim {
return Err("a.len() must equal ndim");
}
if b.len() != self.ndim {
return Err("b.len() must equal ndim");
}
let nearest_containers = self.containers_near_cylinder(a, b, radius)?;
let mut ids = HashSet::new();
for index in nearest_containers {
let container = self.containers.get(&index).unwrap();
for (id, x_other) in container {
let distance = point_cylinder_distance(a, b, radius, x_other)?;
if f64::abs(distance) <= self.tol_dist && filter(x_other) {
ids.insert(*id);
}
}
}
Ok(ids)
}
pub fn search_on_plane_xy<F>(&self, z: f64, mut filter: F) -> Result<HashSet<usize>, StrError>
where
F: FnMut(&[f64]) -> bool,
{
if self.ndim != 3 {
return Err("search_on_plane_xy works in 3D only");
}
let nearest_containers = self.containers_near_plane(2, z);
let mut ids = HashSet::new();
for index in nearest_containers {
let container = self.containers.get(&index).unwrap();
for (id, x_other) in container {
let distance = f64::abs(x_other[2] - z);
if f64::abs(distance) <= self.tol_dist && filter(x_other) {
ids.insert(*id);
}
}
}
Ok(ids)
}
pub fn search_on_plane_yz<F>(&self, x: f64, mut filter: F) -> Result<HashSet<usize>, StrError>
where
F: FnMut(&[f64]) -> bool,
{
if self.ndim != 3 {
return Err("search_on_plane_yz works in 3D only");
}
let nearest_containers = self.containers_near_plane(0, x);
let mut ids = HashSet::new();
for index in nearest_containers {
let container = self.containers.get(&index).unwrap();
for (id, x_other) in container {
let distance = f64::abs(x_other[0] - x);
if f64::abs(distance) <= self.tol_dist && filter(x_other) {
ids.insert(*id);
}
}
}
Ok(ids)
}
pub fn search_on_plane_xz<F>(&self, y: f64, mut filter: F) -> Result<HashSet<usize>, StrError>
where
F: FnMut(&[f64]) -> bool,
{
if self.ndim != 3 {
return Err("search_on_plane_xz works in 3D only");
}
let nearest_containers = self.containers_near_plane(1, y);
let mut ids = HashSet::new();
for index in nearest_containers {
let container = self.containers.get(&index).unwrap();
for (id, x_other) in container {
let distance = f64::abs(x_other[1] - y);
if f64::abs(distance) <= self.tol_dist && filter(x_other) {
ids.insert(*id);
}
}
}
Ok(ids)
}
pub fn draw(&self, plot: &mut Plot, with_ids: bool) -> Result<(), StrError> {
let mut xmin = vec![0.0; self.ndim];
let mut xmax = vec![0.0; self.ndim];
let mut ndiv = vec![0; self.ndim];
for i in 0..self.ndim {
xmin[i] = self.xmin[i];
xmax[i] = self.xmax[i];
ndiv[i] = self.ndiv[i];
}
let mut canvas = Canvas::new();
canvas
.set_alt_text_color("#5d5d5d")
.draw_grid(&xmin, &xmax, &ndiv, false, with_ids)?;
plot.add(&canvas);
let mut curve = Curve::new();
let mut text = Text::new();
curve
.set_marker_style("o")
.set_marker_color("#fab32faa")
.set_marker_line_color("black")
.set_marker_line_width(0.5);
text.set_color("#cd0000");
for container in self.containers.values() {
for (id, x) in container {
let txt = format!("{}", id);
if self.ndim == 2 {
curve.draw(&[x[0]], &[x[1]]);
if with_ids {
text.draw(x[0], x[1], &txt);
}
} else {
curve.draw_3d(&[x[0]], &[x[1]], &[x[2]]);
if with_ids {
text.draw_3d(x[0], x[1], x[2], &txt);
}
}
}
}
plot.add(&curve).add(&text);
Ok(())
}
#[inline]
fn calc_container_key(&self, x: &[f64]) -> Option<usize> {
let mut ratio = vec![0; self.ndim]; let mut key = 0;
for i in 0..self.ndim {
if x[i] < self.xmin[i] || x[i] > self.xmax[i] {
return None;
}
ratio[i] = ((x[i] - self.xmin[i]) / self.side_length) as usize;
if ratio[i] == self.ndiv[i] {
ratio[i] -= 1; }
key += ratio[i] * self.coefficient[i];
}
Some(key)
}
#[inline]
fn container_pivot_indices(&self, key: usize) -> (usize, usize, usize) {
let i = key % self.ndiv[0];
let j = (key % self.coefficient[2]) / self.ndiv[0];
let k = key / self.coefficient[2];
(i, j, k)
}
#[inline]
fn container_center(&self, cen: &mut [f64], i: usize, j: usize, k: usize) {
cen[0] = self.xmin[0] + (i as f64) * self.side_length + self.side_length / 2.0;
cen[1] = self.xmin[1] + (j as f64) * self.side_length + self.side_length / 2.0;
if self.ndim == 3 {
cen[2] = self.xmin[2] + (k as f64) * self.side_length + self.side_length / 2.0;
}
}
#[inline]
fn containers_near_line(&self, a: &[f64], b: &[f64]) -> Result<Vec<usize>, StrError> {
let mut nearest_containers = Vec::new();
let mut cen = vec![0.0; self.ndim];
for key in self.containers.keys() {
let (i, j, k) = self.container_pivot_indices(*key);
self.container_center(&mut cen, i, j, k);
let distance = point_line_distance(a, b, &cen)?;
if distance <= self.radius + self.tol_dist {
nearest_containers.push(*key);
}
}
Ok(nearest_containers)
}
#[inline]
fn containers_near_circle(&self, circle_center: &[f64], radius: f64) -> Result<Vec<usize>, StrError> {
let mut nearest_containers = Vec::new();
let mut cen = vec![0.0; self.ndim];
for key in self.containers.keys() {
let (i, j, k) = self.container_pivot_indices(*key);
self.container_center(&mut cen, i, j, k);
let distance = point_circle_distance(circle_center, radius, &cen)?;
if distance <= self.radius + self.tol_dist {
nearest_containers.push(*key);
}
}
Ok(nearest_containers)
}
#[inline]
fn containers_near_cylinder(&self, a: &[f64], b: &[f64], radius: f64) -> Result<Vec<usize>, StrError> {
let mut nearest_containers = Vec::new();
let mut cen = vec![0.0; self.ndim];
for key in self.containers.keys() {
let (i, j, k) = self.container_pivot_indices(*key);
self.container_center(&mut cen, i, j, k);
let distance = point_cylinder_distance(a, b, radius, &cen)?;
if distance <= self.radius + self.tol_dist {
nearest_containers.push(*key);
}
}
Ok(nearest_containers)
}
#[inline]
fn containers_near_plane(&self, fixed_dim: usize, fixed_coord: f64) -> Vec<usize> {
let mut nearest_containers = Vec::new();
let mut cen = vec![0.0; self.ndim];
for key in self.containers.keys() {
let (i, j, k) = self.container_pivot_indices(*key);
self.container_center(&mut cen, i, j, k);
let distance = f64::abs(cen[fixed_dim] - fixed_coord);
if distance <= self.radius + self.tol_dist {
nearest_containers.push(*key);
}
}
nearest_containers
}
#[inline]
fn update_or_insert(&mut self, key: ContainerKey, id: ItemId, x: &[f64]) {
let container = self.containers.entry(key).or_insert(HashMap::new());
container.insert(id, x.to_vec());
}
#[inline]
fn set_halo(&mut self, x: &[f64]) {
if self.ndim == 2 {
self.halo[0][0] = x[0] - self.tolerance;
self.halo[0][1] = x[1] - self.tolerance;
self.halo[1][0] = x[0] + self.tolerance;
self.halo[1][1] = x[1] - self.tolerance;
self.halo[2][0] = x[0] + self.tolerance;
self.halo[2][1] = x[1] + self.tolerance;
self.halo[3][0] = x[0] - self.tolerance;
self.halo[3][1] = x[1] + self.tolerance;
} else {
self.halo[0][0] = x[0] - self.tolerance;
self.halo[0][1] = x[1] - self.tolerance;
self.halo[0][2] = x[2] - self.tolerance;
self.halo[1][0] = x[0] + self.tolerance;
self.halo[1][1] = x[1] - self.tolerance;
self.halo[1][2] = x[2] - self.tolerance;
self.halo[2][0] = x[0] + self.tolerance;
self.halo[2][1] = x[1] + self.tolerance;
self.halo[2][2] = x[2] - self.tolerance;
self.halo[3][0] = x[0] - self.tolerance;
self.halo[3][1] = x[1] + self.tolerance;
self.halo[3][2] = x[2] - self.tolerance;
self.halo[4][0] = x[0] - self.tolerance;
self.halo[4][1] = x[1] - self.tolerance;
self.halo[4][2] = x[2] + self.tolerance;
self.halo[5][0] = x[0] + self.tolerance;
self.halo[5][1] = x[1] - self.tolerance;
self.halo[5][2] = x[2] + self.tolerance;
self.halo[6][0] = x[0] + self.tolerance;
self.halo[6][1] = x[1] + self.tolerance;
self.halo[6][2] = x[2] + self.tolerance;
self.halo[7][0] = x[0] - self.tolerance;
self.halo[7][1] = x[1] + self.tolerance;
self.halo[7][2] = x[2] + self.tolerance;
}
}
}
impl fmt::Display for GridSearch {
    /// Writes one line per non-empty container (sorted by key) with its sorted item ids,
    /// followed by all unique ids, the number of unique items, the number of containers,
    /// and `ndiv`
    ///
    /// Formatting errors are propagated instead of panicking.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut unique_items = HashSet::new();
        let mut keys: Vec<_> = self.containers.keys().copied().collect();
        keys.sort_unstable();
        for key in keys {
            let mut ids: Vec<_> = self.containers[&key].keys().copied().collect();
            ids.sort_unstable();
            writeln!(f, "{}: {:?}", key, ids)?;
            unique_items.extend(ids);
        }
        let mut ids: Vec<_> = unique_items.into_iter().collect();
        ids.sort_unstable();
        writeln!(f, "ids = {:?}", ids)?;
        writeln!(f, "nitem = {}", ids.len())?;
        writeln!(f, "ncontainer = {}", self.containers.len())?;
        writeln!(f, "ndiv = {:?}", self.ndiv)
    }
}
#[cfg(test)]
mod tests {
    use super::{any_x, GridSearch, GS_DEFAULT_TOLERANCE};
    use plotpy::Plot;
    use russell_lab::math::{SQRT_2, SQRT_3};
    use russell_lab::{approx_eq, array_approx_eq};

    // Not used by the tests below; allowed so they remain available for ad-hoc plotting
    #[allow(unused_imports)]
    use plotpy::{Canvas, Curve, RayEndpoint, Surface};

    /// Small perturbation applied to some sample coordinates
    const NOISE: f64 = 1.23456e-5;

    /// Sample circle passed as (center, radius) to `search_on_circle`
    const CIRCLE: ([f64; 2], f64) = ([-0.2, 1.8], 0.45);

    /// Sample 2D points; ids 100, 101, ... are assigned in this order
    const POINTS_2D: [[f64; 2]; 12] = [
        [-0.1, -0.1 + NOISE],
        [0.0, 0.0],
        [0.185, -0.185],
        [0.6 + NOISE, 0.0],
        [0.0, 0.3],
        [0.2, 0.5 + NOISE],
        [0.31, 0.42],
        [0.8, 1.8],
        [0.6 + NOISE, 1.5],
        [CIRCLE.0[0] + NOISE, CIRCLE.0[1] - CIRCLE.1 + NOISE],
        [CIRCLE.0[0] + CIRCLE.1 + NOISE, CIRCLE.0[1] + NOISE],
        [
            CIRCLE.0[0] + CIRCLE.1 * SQRT_2 / 2.0,
            CIRCLE.0[1] - CIRCLE.1 * SQRT_2 / 2.0,
        ],
    ];

    /// Sample 2D lines, each given by two points
    const LINES_2D: [[[f64; 2]; 2]; 3] = [
        [[0.6, -0.2], [0.6, 1.8]],
        [[-0.2, 1.8], [0.8, 1.8]],
        [[0.4, -0.1], [0.8, 0.1]],
    ];

    /// Sample cylinder passed as (a, b, radius) to `search_on_cylinder`
    const CYLINDER: ([f64; 3], [f64; 3], f64) = ([1.0, -1.0, -1.0], [1.0, 1.0, -1.0], 0.4);

    /// Sample 3D points; ids 100, 101, ... are assigned in this order
    const POINTS_3D: [[f64; 3]; 8] = [
        [-1.0 + NOISE, -1.0 + NOISE, -1.0 + NOISE],
        [0.0, 0.0, 0.0],
        [1.0, 1.0 + NOISE, 1.0],
        [0.0, -0.5, -1.0 + NOISE],
        [CYLINDER.0[0] - CYLINDER.2, CYLINDER.0[1], CYLINDER.0[2] + NOISE],
        [CYLINDER.0[0], CYLINDER.0[1], CYLINDER.0[2] + CYLINDER.2 + NOISE],
        [CYLINDER.1[0] - CYLINDER.2, CYLINDER.1[1], CYLINDER.1[2] + NOISE],
        [CYLINDER.1[0], CYLINDER.1[1], CYLINDER.1[2] + CYLINDER.2 + NOISE],
    ];

    /// Sample 3D lines, each given by two points
    const LINES_3D: [[[f64; 3]; 2]; 2] = [
        [[-1.0, -1.0, -1.0], [1.0, -1.0, -1.0]],
        [[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]],
    ];

    /// Inserts POINTS_2D with ids starting at 100
    fn add_sample_points_to_grid_2d(grid: &mut GridSearch) {
        for (id, x) in (100..).zip(&POINTS_2D) {
            grid.insert(id, x).unwrap();
        }
    }

    /// Inserts POINTS_3D with ids starting at 100
    fn add_sample_points_to_grid_3d(grid: &mut GridSearch) {
        for (id, x) in (100..).zip(&POINTS_3D) {
            grid.insert(id, x).unwrap();
        }
    }

    /// 4 × 8 grid (side length 0.3) around [-0.2, 0.8] × [-0.2, 1.8]
    fn sample_grid_2d() -> GridSearch {
        GridSearch::new(&[-0.2, -0.2], &[0.8, 1.8], Some(8), None, Some(0.1)).unwrap()
    }

    /// 2 × 2 × 2 grid (side length 1.1) around [-1, 1]³
    fn sample_grid_3d() -> GridSearch {
        GridSearch::new(&[-1.0, -1.0, -1.0], &[1.0, 1.0, 1.0], Some(2), None, Some(0.1)).unwrap()
    }

    #[test]
    fn new_handles_wrong_input() {
        assert_eq!(
            GridSearch::new(&[0.0], &[1.0, 1.0], None, None, None).err(),
            Some("xmin.len() = ndim must be 2 or 3")
        );
        assert_eq!(
            GridSearch::new(&[0.0, 0.0], &[1.0], None, None, None).err(),
            Some("xmax.len() must equal ndim = xmin.len()")
        );
        assert_eq!(
            GridSearch::new(&[0.0, 0.0], &[1.0, 1.0], None, Some(-0.1), None).err(),
            Some("tolerance must be > 0.0")
        );
        assert_eq!(
            GridSearch::new(&[0.0, 0.0], &[1.0, 1.0], None, None, Some(-0.1)).err(),
            Some("border_tol must be ≥ 0.0")
        );
        assert_eq!(
            GridSearch::new(&[0.0, 0.0], &[1.0, 1.0], Some(0), None, None).err(),
            Some("ndiv must be ≥ 1")
        );
        assert_eq!(
            GridSearch::new(&[0.0, 0.0], &[0.0, 1.0], None, None, None).err(),
            Some("xmax must be greater than xmin")
        );
        assert_eq!(
            GridSearch::new(&[0.0, 0.0, 0.0], &[1.0, 0.0, 1.0], None, None, None).err(),
            Some("xmax must be greater than xmin")
        );
        assert_eq!(
            GridSearch::new(&[0.0, 0.0, 0.0], &[1.0, 1.0, 0.0], None, None, None).err(),
            Some("xmax must be greater than xmin")
        );
        assert_eq!(
            GridSearch::new(
                &[0.0, 0.0],
                &[1.0, 1.0],
                Some(100),
                Some(0.5 * (1.0 / 100.0)),
                Some(0.0),
            )
            .err(),
            Some("(xmax-xmin)/ndiv must be > 2·tolerance; reduce the tolerance (or ndiv)")
        );
    }

    #[test]
    fn new_works() {
        let grid = GridSearch::new(&[-0.2, -0.2], &[0.8, 1.8], Some(20), None, None).unwrap();
        assert_eq!(grid.ndim, 2);
        assert_eq!(grid.ndiv, [10, 20]);
        approx_eq(grid.side_length, 0.102, 1e-15);
        array_approx_eq(&grid.xmin, &[-0.21, -0.21], 1e-15);
        array_approx_eq(&grid.xmax, &[0.81, -0.21 + 20.0 * 0.102], 1e-15);
        assert_eq!(grid.coefficient, &[1, 10, 10 * 20]);
        assert_eq!(grid.tolerance, GS_DEFAULT_TOLERANCE);
        approx_eq(grid.tol_dist, SQRT_2 * GS_DEFAULT_TOLERANCE, 1e-15);
        approx_eq(grid.radius, SQRT_2 * 0.102 / 2.0, 1e-15);
        assert_eq!(grid.halo.len(), 4);
        assert_eq!(grid.halo_ncorner, 4);
        assert_eq!(grid.containers.len(), 0);

        let grid = GridSearch::new(&[-0.2, -0.2], &[0.8, 1.8], Some(8), None, Some(0.1)).unwrap();
        assert_eq!(grid.ndim, 2);
        assert_eq!(grid.ndiv, [4, 8]);
        approx_eq(grid.side_length, 0.3, 1e-15);
        array_approx_eq(&grid.xmin, &[-0.3, -0.3], 1e-15);
        array_approx_eq(&grid.xmax, &[0.9, 2.1], 1e-15);
        assert_eq!(grid.coefficient, &[1, 4, 4 * 8]);
        assert_eq!(grid.tolerance, GS_DEFAULT_TOLERANCE);
        approx_eq(grid.tol_dist, SQRT_2 * GS_DEFAULT_TOLERANCE, 1e-15);
        approx_eq(grid.radius, SQRT_2 * 0.3 / 2.0, 1e-15);
        assert_eq!(grid.halo.len(), 4);
        assert_eq!(grid.halo_ncorner, 4);
        assert_eq!(grid.containers.len(), 0);

        let grid = GridSearch::new(&[-1.0, -1.0, -1.0], &[1.0, 1.0, 1.0], Some(2), None, Some(0.1)).unwrap();
        assert_eq!(grid.ndim, 3);
        assert_eq!(grid.ndiv, [2, 2, 2]);
        approx_eq(grid.side_length, 1.1, 1e-15);
        array_approx_eq(&grid.xmin, &[-1.1, -1.1, -1.1], 1e-15);
        array_approx_eq(&grid.xmax, &[1.1, 1.1, 1.1], 1e-15);
        assert_eq!(grid.coefficient, &[1, 2, 2 * 2]);
        assert_eq!(grid.tolerance, GS_DEFAULT_TOLERANCE);
        approx_eq(grid.tol_dist, SQRT_3 * GS_DEFAULT_TOLERANCE, 1e-15);
        approx_eq(grid.radius, SQRT_3 * 1.1 / 2.0, 1e-15);
        assert_eq!(grid.halo.len(), 8);
        assert_eq!(grid.halo_ncorner, 8);
        assert_eq!(grid.containers.len(), 0);
    }

    #[test]
    fn display_trait_works() {
        let grid = GridSearch::new(&[-0.2, -0.2], &[0.8, 1.8], Some(6), None, None).unwrap();
        assert_eq!(
            format!("{}", grid),
            "ids = []\n\
             nitem = 0\n\
             ncontainer = 0\n\
             ndiv = [3, 6]\n"
        );
        let grid = GridSearch::new(&[-1.0, -1.0, -1.0], &[1.0, 1.0, 1.0], Some(3), None, None).unwrap();
        assert_eq!(
            format!("{}", grid),
            "ids = []\n\
             nitem = 0\n\
             ncontainer = 0\n\
             ndiv = [3, 3, 3]\n"
        );
    }

    #[test]
    fn draw_works_2d() {
        let mut grid = sample_grid_2d();
        add_sample_points_to_grid_2d(&mut grid);
        let mut plot = Plot::new();
        grid.draw(&mut plot, true).unwrap();
        grid.draw(&mut plot, false).unwrap();
    }

    #[test]
    fn draw_works_3d() {
        let mut grid = sample_grid_3d();
        add_sample_points_to_grid_3d(&mut grid);
        let mut plot = Plot::new();
        grid.draw(&mut plot, true).unwrap();
        grid.draw(&mut plot, false).unwrap();
    }

    #[test]
    fn set_halo_works() {
        let mut grid = sample_grid_2d();
        grid.set_halo(&[0.5, 0.5]);
        assert_eq!(grid.halo[0], [0.4999, 0.4999]);
        assert_eq!(grid.halo[1], [0.5001, 0.4999]);
        assert_eq!(grid.halo[2], [0.5001, 0.5001]);
        assert_eq!(grid.halo[3], [0.4999, 0.5001]);

        let mut grid = sample_grid_3d();
        grid.set_halo(&[0.5, 0.5, 0.5]);
        assert_eq!(grid.halo[0], [0.4999, 0.4999, 0.4999]);
        assert_eq!(grid.halo[1], [0.5001, 0.4999, 0.4999]);
        assert_eq!(grid.halo[2], [0.5001, 0.5001, 0.4999]);
        assert_eq!(grid.halo[3], [0.4999, 0.5001, 0.4999]);
        assert_eq!(grid.halo[4], [0.4999, 0.4999, 0.5001]);
        assert_eq!(grid.halo[5], [0.5001, 0.4999, 0.5001]);
        assert_eq!(grid.halo[6], [0.5001, 0.5001, 0.5001]);
        assert_eq!(grid.halo[7], [0.4999, 0.5001, 0.5001]);
    }

    #[test]
    fn calc_container_key_works() {
        let grid = sample_grid_2d();
        assert_eq!(grid.calc_container_key(&[-10.0, 0.0]), None);
        assert_eq!(grid.calc_container_key(&[10.0, 0.0]), None);
        assert_eq!(grid.calc_container_key(&[0.0, -10.0]), None);
        assert_eq!(grid.calc_container_key(&[0.0, 10.0]), None);
        assert_eq!(grid.calc_container_key(&[-100.0, -100.0]), None);
        assert_eq!(grid.calc_container_key(&[100.0, 100.0]), None);
        assert_eq!(grid.calc_container_key(&[0.0 - 1e-4, 0.0 - 1e-4]), Some(0));
        assert_eq!(grid.calc_container_key(&[0.1, 0.5]), Some(9));
        assert_eq!(grid.calc_container_key(&[0.7, 0.8]), Some(15));
        assert_eq!(grid.calc_container_key(&[-0.2, 1.8]), Some(24));
        assert_eq!(grid.calc_container_key(&[0.8, 1.8]), Some(27));

        let grid = sample_grid_3d();
        assert_eq!(grid.calc_container_key(&[-10.0, -10.0, -10.0]), None);
        assert_eq!(grid.calc_container_key(&[10.0, -10.0, -10.0]), None);
        assert_eq!(grid.calc_container_key(&[-10.0, 10.0, -10.0]), None);
        assert_eq!(grid.calc_container_key(&[10.0, 10.0, -10.0]), None);
        assert_eq!(grid.calc_container_key(&[-10.0, -10.0, 10.0]), None);
        assert_eq!(grid.calc_container_key(&[10.0, -10.0, 10.0]), None);
        assert_eq!(grid.calc_container_key(&[-10.0, 10.0, 10.0]), None);
        assert_eq!(grid.calc_container_key(&[10.0, 10.0, 10.0]), None);
        assert_eq!(grid.calc_container_key(&[-1.0, -1.0, -1.0]), Some(0));
        assert_eq!(grid.calc_container_key(&[1.0, 1.0, 1.0]), Some(7));
    }

    #[test]
    fn container_pivot_indices_works() {
        let grid = sample_grid_2d();
        assert_eq!(grid.container_pivot_indices(0), (0, 0, 0));
        assert_eq!(grid.container_pivot_indices(3), (3, 0, 0));
        assert_eq!(grid.container_pivot_indices(4), (0, 1, 0));
        assert_eq!(grid.container_pivot_indices(20), (0, 5, 0));
        assert_eq!(grid.container_pivot_indices(27), (3, 6, 0));

        let grid = sample_grid_3d();
        assert_eq!(grid.container_pivot_indices(0), (0, 0, 0));
        assert_eq!(grid.container_pivot_indices(2), (0, 1, 0));
        assert_eq!(grid.container_pivot_indices(7), (1, 1, 1));
    }

    #[test]
    fn container_center_works() {
        let grid = sample_grid_2d();
        let mut x = vec![0.0; 2];
        let (xa, ya) = (grid.xmin[0], grid.xmin[1]);
        let (xb, yb) = (grid.xmax[0], grid.xmax[1]);
        let h = grid.side_length / 2.0;
        grid.container_center(&mut x, 0, 0, 0);
        array_approx_eq(&x, &[xa + h, ya + h], 1e-15);
        grid.container_center(&mut x, grid.ndiv[0] - 1, grid.ndiv[1] - 1, 0);
        array_approx_eq(&x, &[xb - h, yb - h], 1e-15);

        let grid = sample_grid_3d();
        let mut x = vec![0.0; 3];
        let (xa, ya, za) = (grid.xmin[0], grid.xmin[1], grid.xmin[2]);
        let (xb, yb, zb) = (grid.xmax[0], grid.xmax[1], grid.xmax[2]);
        let h = grid.side_length / 2.0;
        grid.container_center(&mut x, 0, 0, 0);
        array_approx_eq(&x, &[xa + h, ya + h, za + h], 1e-15);
        grid.container_center(&mut x, grid.ndiv[0] - 1, grid.ndiv[1] - 1, grid.ndiv[2] - 1);
        array_approx_eq(&x, &[xb - h, yb - h, zb - h], 1e-15);
    }

    #[test]
    fn is_outside_works() {
        let grid = sample_grid_2d();
        assert!(grid.is_outside(&[-10.0, 0.0]));
        assert!(grid.is_outside(&[10.0, 0.0]));
        assert!(grid.is_outside(&[0.0, 10.0]));
        assert!(grid.is_outside(&[0.0, -10.0]));
        assert!(!grid.is_outside(&[0.0, 0.0]));

        let grid = sample_grid_3d();
        assert!(grid.is_outside(&[-10.0, 0.0, 0.0]));
        assert!(grid.is_outside(&[0.0, -10.0, 0.0]));
        assert!(!grid.is_outside(&[0.0, 0.0, 0.0]));
    }

    #[test]
    fn insert_handles_wrong_input() {
        let mut grid = sample_grid_2d();
        assert_eq!(grid.insert(0, &[0.0, 0.0, 0.0]), Err("x.len() must equal ndim"));
        assert_eq!(
            grid.insert(1000, &[10.0, 0.0]),
            Err("cannot insert point because its coordinates are outside the grid")
        );

        let mut grid = sample_grid_3d();
        assert_eq!(grid.insert(0, &[0.0, 0.0]), Err("x.len() must equal ndim"));
        assert_eq!(
            grid.insert(1000, &[10.0, 0.0, 0.0]),
            Err("cannot insert point because its coordinates are outside the grid")
        );
    }

    #[test]
    fn insert_works_2d() {
        let mut grid = sample_grid_2d();
        add_sample_points_to_grid_2d(&mut grid);
        assert_eq!(
            format!("{}", grid),
            "0: [100, 101]\n\
             1: [101, 102]\n\
             2: [103]\n\
             3: [103]\n\
             4: [101, 104]\n\
             5: [101, 104]\n\
             6: [103]\n\
             7: [103]\n\
             8: [104]\n\
             9: [104, 105]\n\
             10: [106]\n\
             20: [109]\n\
             21: [111]\n\
             22: [108]\n\
             23: [108]\n\
             25: [110]\n\
             26: [108]\n\
             27: [107, 108]\n\
             29: [110]\n\
             31: [107]\n\
             ids = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111]\n\
             nitem = 12\n\
             ncontainer = 20\n\
             ndiv = [4, 8]\n"
        );
    }

    #[test]
    fn insert_works_3d() {
        let mut grid = sample_grid_3d();
        add_sample_points_to_grid_3d(&mut grid);
        assert_eq!(
            format!("{}", grid),
            "0: [100, 101, 103]\n\
             1: [101, 103, 104, 105]\n\
             2: [101]\n\
             3: [101, 106, 107]\n\
             4: [101]\n\
             5: [101]\n\
             6: [101]\n\
             7: [101, 102]\n\
             ids = [100, 101, 102, 103, 104, 105, 106, 107]\n\
             nitem = 8\n\
             ncontainer = 8\n\
             ndiv = [2, 2, 2]\n"
        );
    }

    #[test]
    fn search_handles_wrong_input() {
        let grid = sample_grid_2d();
        assert_eq!(grid.search(&[0.0, 0.0, 0.0]), Err("x.len() must equal ndim"));
        assert_eq!(
            grid.search(&[10.0, 0.0]),
            Err("cannot find point because the coordinates are outside the grid")
        );

        let grid = sample_grid_3d();
        assert_eq!(grid.search(&[0.0, 0.0]), Err("x.len() must equal ndim"));
        assert_eq!(
            grid.search(&[10.0, 0.0, 0.0]),
            Err("cannot find point because the coordinates are outside the grid")
        );
    }

    #[test]
    fn search_works() {
        // shift applied to the query points (half of the default tolerance)
        const SHIFT: f64 = 1e-4 / 2.0;

        let mut grid = sample_grid_2d();
        add_sample_points_to_grid_2d(&mut grid);
        for (id, x) in (100..).zip(&POINTS_2D) {
            let mut y = *x;
            y[0] += SHIFT;
            y[1] -= SHIFT;
            assert_eq!(grid.search(&y).unwrap(), Some(id));
        }
        assert_eq!(grid.search(&[-0.2, 0.7]).unwrap(), None);

        let mut grid = sample_grid_3d();
        add_sample_points_to_grid_3d(&mut grid);
        for (id, x) in (100..).zip(&POINTS_3D) {
            let mut y = *x;
            y[0] += SHIFT;
            y[1] -= SHIFT;
            y[2] += SHIFT;
            assert_eq!(grid.search(&y).unwrap(), Some(id));
        }
        assert_eq!(grid.search(&[-0.9, 0.9, 0.9]).unwrap(), None);
    }

    #[test]
    fn containers_near_line_works_2d() {
        let mut grid = sample_grid_2d();
        add_sample_points_to_grid_2d(&mut grid);
        let mut indices = grid.containers_near_line(&LINES_2D[0][0], &LINES_2D[0][1]).unwrap();
        indices.sort();
        assert_eq!(indices, &[2, 3, 6, 7, 10, 22, 23, 26, 27, 31]);
        let mut indices = grid.containers_near_line(&LINES_2D[1][0], &LINES_2D[1][1]).unwrap();
        indices.sort();
        assert_eq!(indices, &[25, 26, 27, 29, 31]);
        let mut indices = grid.containers_near_line(&LINES_2D[2][0], &LINES_2D[2][1]).unwrap();
        indices.sort();
        assert_eq!(indices, &[0, 1, 2, 3, 6, 7]);
    }

    #[test]
    fn containers_near_line_works_3d() {
        let mut grid = sample_grid_3d();
        add_sample_points_to_grid_3d(&mut grid);
        let mut indices = grid.containers_near_line(&LINES_3D[0][0], &LINES_3D[0][1]).unwrap();
        indices.sort();
        assert_eq!(indices, &[0, 1]);
        let mut indices = grid.containers_near_line(&LINES_3D[1][0], &LINES_3D[1][1]).unwrap();
        indices.sort();
        assert_eq!(indices, &[0, 1, 2, 3, 4, 5, 6, 7]);
    }

    #[test]
    fn search_on_line_handles_wrong_input() {
        assert!(any_x(&[]));
        let grid = sample_grid_2d();
        assert_eq!(
            grid.search_on_line(&[0.0], &[1.0, 1.0], any_x),
            Err("a.len() must equal ndim")
        );
        assert_eq!(
            grid.search_on_line(&[0.0, 0.0], &[1.0], any_x),
            Err("b.len() must equal ndim")
        );
    }

    #[test]
    fn search_on_line_works_2d() {
        let mut grid = sample_grid_2d();
        add_sample_points_to_grid_2d(&mut grid);

        let res = grid.search_on_line(&LINES_2D[0][0], &LINES_2D[0][1], any_x).unwrap();
        let mut ids: Vec<_> = res.iter().copied().collect();
        ids.sort();
        assert_eq!(ids, [103, 108]);

        let res = grid
            .search_on_line(&LINES_2D[0][0], &LINES_2D[0][1], |x| x[1] > 0.0)
            .unwrap();
        let mut ids: Vec<_> = res.iter().copied().collect();
        ids.sort();
        assert_eq!(ids, [108]);

        let res = grid.search_on_line(&LINES_2D[1][0], &LINES_2D[1][1], any_x).unwrap();
        let mut ids: Vec<_> = res.iter().copied().collect();
        ids.sort();
        assert_eq!(ids, [107, 110]);

        let res = grid
            .search_on_line(&LINES_2D[1][0], &LINES_2D[1][1], |x| x[0] < 0.8)
            .unwrap();
        let mut ids: Vec<_> = res.iter().copied().collect();
        ids.sort();
        assert_eq!(ids, [110]);

        let res = grid.search_on_line(&LINES_2D[2][0], &LINES_2D[2][1], any_x).unwrap();
        let mut ids: Vec<_> = res.iter().copied().collect();
        ids.sort();
        assert_eq!(ids, [103]);
    }

    #[test]
    fn search_on_line_works_3d() {
        let mut grid = sample_grid_3d();
        add_sample_points_to_grid_3d(&mut grid);

        let res = grid.search_on_line(&LINES_3D[0][0], &LINES_3D[0][1], any_x).unwrap();
        let mut ids: Vec<_> = res.iter().copied().collect();
        ids.sort();
        assert_eq!(ids, [100, 104]);

        let res = grid
            .search_on_line(&LINES_3D[0][0], &LINES_3D[0][1], |x| x[0] < 0.0)
            .unwrap();
        let mut ids: Vec<_> = res.iter().copied().collect();
        ids.sort();
        assert_eq!(ids, [100]);

        let res = grid.search_on_line(&LINES_3D[1][0], &LINES_3D[1][1], any_x).unwrap();
        let mut ids: Vec<_> = res.iter().copied().collect();
        ids.sort();
        assert_eq!(ids, [100, 101, 102]);

        let res = grid
            .search_on_line(&LINES_3D[1][0], &LINES_3D[1][1], |x| x[2] <= 0.0)
            .unwrap();
        let mut ids: Vec<_> = res.iter().copied().collect();
        ids.sort();
        assert_eq!(ids, [100, 101]);
    }

    #[test]
    fn containers_near_circle_works() {
        let mut grid = sample_grid_2d();
        add_sample_points_to_grid_2d(&mut grid);
        let mut indices = grid.containers_near_circle(&CIRCLE.0, CIRCLE.1).unwrap();
        indices.sort();
        assert_eq!(indices, &[20, 21, 25, 29]);
    }

    #[test]
    fn search_on_circle_fails_on_wrong_input() {
        let grid = sample_grid_2d();
        assert_eq!(
            grid.search_on_circle(&[-0.2], 0.3, any_x),
            Err("center.len() must equal ndim")
        );
        let grid = sample_grid_3d();
        assert_eq!(
            grid.search_on_circle(&[0.0, 0.0, 0.0], 1.0, any_x),
            Err("search_on_circle works in 2D only")
        );
    }

    #[test]
    fn search_on_circle_works() {
        let mut grid = sample_grid_2d();
        add_sample_points_to_grid_2d(&mut grid);

        let res = grid.search_on_circle(&CIRCLE.0, CIRCLE.1, any_x).unwrap();
        let mut ids: Vec<_> = res.iter().copied().collect();
        ids.sort();
        assert_eq!(ids, [109, 110, 111]);

        let res = grid
            .search_on_circle(&CIRCLE.0, CIRCLE.1, |x| x[0] < 0.0 || x[0] > 0.2)
            .unwrap();
        let mut ids: Vec<_> = res.iter().copied().collect();
        ids.sort();
        assert_eq!(ids, [109, 110]);
    }

    #[test]
    fn containers_near_cylinder_works() {
        let mut grid = sample_grid_3d();
        add_sample_points_to_grid_3d(&mut grid);
        let mut indices = grid
            .containers_near_cylinder(&CYLINDER.0, &CYLINDER.1, CYLINDER.2)
            .unwrap();
        indices.sort();
        assert_eq!(indices, &[1, 3]);
    }

    #[test]
    fn search_on_cylinder_fails_on_wrong_input() {
        let grid = sample_grid_3d();
        assert_eq!(
            grid.search_on_cylinder(&[0.0, 0.0], &[1.0, 0.0, 0.0], 1.0, any_x),
            Err("a.len() must equal ndim")
        );
        assert_eq!(
            grid.search_on_cylinder(&[0.0, 0.0, 0.0], &[1.0, 0.0], 1.0, any_x),
            Err("b.len() must equal ndim")
        );
        let grid = sample_grid_2d();
        assert_eq!(
            grid.search_on_cylinder(&[0.0, 0.0, 0.0], &[1.0, 0.0, 0.0], 1.0, any_x),
            Err("search_on_cylinder works in 3D only")
        );
    }

    #[test]
    fn search_on_cylinder_works() {
        let mut grid = sample_grid_3d();
        add_sample_points_to_grid_3d(&mut grid);

        let res = grid
            .search_on_cylinder(&CYLINDER.0, &CYLINDER.1, CYLINDER.2, any_x)
            .unwrap();
        let mut ids: Vec<_> = res.iter().copied().collect();
        ids.sort();
        assert_eq!(ids, [104, 105, 106, 107]);

        let res = grid
            .search_on_cylinder(&CYLINDER.0, &CYLINDER.1, CYLINDER.2, |x| x[1] > 0.0)
            .unwrap();
        let mut ids: Vec<_> = res.iter().copied().collect();
        ids.sort();
        assert_eq!(ids, [106, 107]);
    }

    #[test]
    fn containers_near_plane_works() {
        let mut grid = sample_grid_3d();
        add_sample_points_to_grid_3d(&mut grid);
        let mut indices = grid.containers_near_plane(0, -1.0);
        indices.sort();
        assert_eq!(indices, &[0, 2, 4, 6]);
        let mut indices = grid.containers_near_plane(0, 1.0);
        indices.sort();
        assert_eq!(indices, &[1, 3, 5, 7]);
        let mut indices = grid.containers_near_plane(1, -1.0);
        indices.sort();
        assert_eq!(indices, &[0, 1, 4, 5]);
        let mut indices = grid.containers_near_plane(1, 1.0);
        indices.sort();
        assert_eq!(indices, &[2, 3, 6, 7]);
        let mut indices = grid.containers_near_plane(2, -1.0);
        indices.sort();
        assert_eq!(indices, &[0, 1, 2, 3]);
        let mut indices = grid.containers_near_plane(2, 1.0);
        indices.sort();
        assert_eq!(indices, &[4, 5, 6, 7]);
        let mut indices = grid.containers_near_plane(2, 0.0);
        indices.sort();
        assert_eq!(indices, &[0, 1, 2, 3, 4, 5, 6, 7]);
    }

    #[test]
    fn search_on_plane_fails_on_wrong_input() {
        let grid = sample_grid_2d();
        assert_eq!(
            grid.search_on_plane_xy(-1.0, any_x),
            Err("search_on_plane_xy works in 3D only")
        );
        assert_eq!(
            grid.search_on_plane_yz(-1.0, any_x),
            Err("search_on_plane_yz works in 3D only")
        );
        assert_eq!(
            grid.search_on_plane_xz(-1.0, any_x),
            Err("search_on_plane_xz works in 3D only")
        );
    }

    #[test]
    fn search_on_plane_works() {
        let mut grid = sample_grid_3d();
        add_sample_points_to_grid_3d(&mut grid);

        let res = grid.search_on_plane_xy(-1.0, any_x).unwrap();
        let mut ids: Vec<_> = res.iter().copied().collect();
        ids.sort();
        assert_eq!(ids, [100, 103, 104, 106]);

        let res = grid.search_on_plane_xy(-1.0, |x| x[1] < 0.0).unwrap();
        let mut ids: Vec<_> = res.iter().copied().collect();
        ids.sort();
        assert_eq!(ids, [100, 103, 104]);

        let res = grid.search_on_plane_yz(-1.0, any_x).unwrap();
        let mut ids: Vec<_> = res.iter().copied().collect();
        ids.sort();
        assert_eq!(ids, [100]);

        let res = grid.search_on_plane_yz(-1.0, |x| x[2] > 1.0).unwrap();
        assert_eq!(res.len(), 0);

        let res = grid.search_on_plane_xz(-1.0, any_x).unwrap();
        let mut ids: Vec<_> = res.iter().copied().collect();
        ids.sort();
        assert_eq!(ids, [100, 104, 105]);

        let res = grid.search_on_plane_xz(-1.0, |x| x[0] > 0.0).unwrap();
        let mut ids: Vec<_> = res.iter().copied().collect();
        ids.sort();
        assert_eq!(ids, [104, 105]);
    }
}