use nalgebra::base::{Matrix3, Unit, Vector3};
use rstar::primitives::PointWithData;
use rstar::RTree;
use std::collections::{HashMap, HashSet};
#[cfg(feature = "parallel")]
use rayon::prelude::*;
pub use nalgebra;
/// Floating-point type used for coordinates, distances and scores in this file.
pub type Precision = f64;
/// A 3D location stored as a fixed-size array of three coordinates.
pub type Point3 = [Precision; 3];
/// A unit-length 3D vector (nalgebra `Unit` wrapper); used here for tangents.
pub type Normal3 = Unit<Vector3<Precision>>;
// R*-tree entry: a point's position plus a `usize` payload. The constructors in
// this file store the point's position in the input sequence as the payload, and
// it is later used to index into the matching tangent vector.
type PointWithIndex = PointWithData<usize, Point3>;
/// Strategy for combining a forward (query -> target) score and a backward
/// (target -> query) score into a single value that does not depend on direction.
pub enum Symmetry {
/// `(forward + backward) / 2`.
ArithmeticMean,
/// Square root of the product of the two scores, each clamped below at zero.
GeometricMean,
/// `2 / (1/forward + 1/backward)`, or zero when the product of the clamped
/// (non-negative) scores is zero.
HarmonicMean,
/// The smaller of the two scores.
Min,
/// The larger of the two scores.
Max,
}
/// Combine a forward and a backward score according to `symmetry`.
///
/// For the geometric and harmonic means, negative scores are treated as zero
/// when deciding the result, so either mean is `0.0` whenever at least one of
/// the inputs is not strictly positive.
fn apply_symmetry(
    symmetry: &Symmetry,
    query_score: Precision,
    target_score: Precision,
) -> Precision {
    // Product of the scores after clamping each to be non-negative.
    let clamped_product = query_score.max(0.0) * target_score.max(0.0);
    match *symmetry {
        Symmetry::Min => query_score.min(target_score),
        Symmetry::Max => query_score.max(target_score),
        Symmetry::ArithmeticMean => (query_score + target_score) / 2.0,
        Symmetry::GeometricMean => clamped_product.sqrt(),
        // Guard against dividing by zero (and against negative inputs).
        Symmetry::HarmonicMean if clamped_product == 0.0 => 0.0,
        Symmetry::HarmonicMean => 2.0 / (1.0 / query_score + 1.0 / target_score),
    }
}
/// The two quantities a score function is evaluated on for a matched pair of
/// points: a distance and a dot product.
#[derive(Debug, Clone, Copy)]
pub struct DistDot {
    /// Distance between the two matched points.
    pub dist: Precision,
    /// Dot product of the two tangents (the implementations in this file store
    /// its absolute value).
    pub dot: Precision,
}

impl Default for DistDot {
    /// The value for a point matched with itself: zero distance and a dot
    /// product of one.
    fn default() -> Self {
        DistDot {
            dot: 1.0,
            dist: 0.0,
        }
    }
}
/// A neuron, represented as points with tangents, which can be scored against
/// a [`TargetNeuron`].
pub trait QueryNeuron {
    /// Number of points in the neuron.
    fn len(&self) -> usize;

    /// Whether the neuron has no points.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Raw (un-normalized) score of this neuron against `target`, using
    /// `score_fn` to convert each point match into a score.
    fn query(
        &self,
        target: &impl TargetNeuron,
        score_fn: &impl Fn(&DistDot) -> Precision,
    ) -> Precision;

    /// Score of this neuron against itself.
    ///
    /// Computed without any spatial query: every point is taken to match
    /// itself perfectly (`DistDot::default()`), so the result is that single
    /// score multiplied by the number of points.
    fn self_hit(&self, score_fn: &impl Fn(&DistDot) -> Precision) -> Precision {
        let perfect_match = DistDot::default();
        let per_point = score_fn(&perfect_match);
        per_point * self.len() as Precision
    }

    /// Owned copy of the neuron's points.
    fn points(&self) -> Vec<Point3>;

    /// Owned copy of the neuron's tangents.
    fn tangents(&self) -> Vec<Normal3>;
}
/// Component-wise difference `p1 - p2`.
fn subtract_points(p1: &Point3, p2: &Point3) -> Point3 {
    [p1[0] - p2[0], p1[1] - p2[1], p1[2] - p2[2]]
}
/// Translate the given points so that their centroid lies at the origin.
///
/// The input is consumed eagerly (it has to be, to compute the mean); the
/// returned iterator yields the shifted points in their original order. An
/// empty input yields an empty iterator.
fn center_points<'a>(points: impl Iterator<Item = &'a Point3>) -> impl Iterator<Item = Point3> {
    let collected: Vec<Point3> = points.copied().collect();

    // Per-axis sum, then divided by the count to give the centroid.
    let mut centroid: Point3 = [0.0, 0.0, 0.0];
    for pt in &collected {
        for (total, coord) in centroid.iter_mut().zip(pt.iter()) {
            *total += coord;
        }
    }
    let count = collected.len() as Precision;
    for total in centroid.iter_mut() {
        *total /= count;
    }

    collected
        .into_iter()
        .map(move |pt| subtract_points(&pt, &centroid))
}
/// Dot product of two slices.
///
/// If the slices differ in length, the surplus elements of the longer one are
/// ignored. Two empty slices give `0.0`.
fn dot(a: &[Precision], b: &[Precision]) -> Precision {
    let mut total = 0.0;
    for (lhs, rhs) in a.iter().zip(b) {
        total += lhs * rhs;
    }
    total
}
/// Build the 3x3 matrix of second moments (sums of coordinate products) of the
/// given points about their centroid.
///
/// Only the diagonal and the lower triangle are filled in; the upper triangle
/// is left at zero.
// NOTE(review): the result is passed to nalgebra's `symmetric_eigen`, which is
// documented to read only the lower triangle - confirm against the nalgebra
// version in use before relying on the zeroed upper triangle.
fn calc_inertia<'a>(points: impl Iterator<Item = &'a Point3>) -> Matrix3<Precision> {
    let centered: Vec<Point3> = center_points(points).collect();

    // Split the centred points into one coordinate list per axis.
    let xs: Vec<Precision> = centered.iter().map(|p| p[0]).collect();
    let ys: Vec<Precision> = centered.iter().map(|p| p[1]).collect();
    let zs: Vec<Precision> = centered.iter().map(|p| p[2]).collect();

    let xx = dot(&xs, &xs);
    let yx = dot(&ys, &xs);
    let yy = dot(&ys, &ys);
    let zx = dot(&zs, &xs);
    let zy = dot(&zs, &ys);
    let zz = dot(&zs, &zs);

    // `Matrix3::new` takes its entries row by row.
    Matrix3::new(xx, 0.0, 0.0, yx, yy, 0.0, zx, zy, zz)
}
fn points_to_tangent_eig<'a>(points: impl Iterator<Item = &'a Point3>) -> Option<Normal3> {
let inertia = calc_inertia(points);
let eig = inertia.symmetric_eigen();
Some(Unit::new_normalize(
eig.eigenvectors.column(eig.eigenvalues.argmax().0).into(),
))
}
/// Bulk-load the points into an R*-tree, tagging each entry with its position
/// in the input sequence.
///
/// The current implementation always returns `Ok`.
fn points_to_rtree(
    points: impl Iterator<Item = impl std::borrow::Borrow<Point3>>,
) -> Result<RTree<PointWithIndex>, &'static str> {
    let mut indexed: Vec<PointWithIndex> = Vec::new();
    for (idx, point) in points.enumerate() {
        indexed.push(PointWithIndex::new(idx, *point.borrow()));
    }
    Ok(RTree::bulk_load(indexed))
}
/// Build an R*-tree over the points and estimate one tangent per point from
/// its `k` nearest neighbours in that tree.
///
/// The tangents are returned in the same order as the input points.
///
/// # Errors
///
/// Returns an error if there are fewer than `k` points, or if a tangent could
/// not be computed for some point.
fn points_to_rtree_tangents(
    points: impl Iterator<Item = impl std::borrow::Borrow<Point3>> + ExactSizeIterator + Clone,
    k: usize,
) -> Result<(RTree<PointWithIndex>, Vec<Normal3>), &'static str> {
    if k > points.len() {
        return Err("Too few points to generate tangents");
    }

    // The iterator is cloned so it can be walked a second time below.
    let rtree = points_to_rtree(points.clone())?;

    let tangents = points
        .map(|point| {
            let neighbourhood = rtree
                .nearest_neighbor_iter(point.borrow())
                .take(k)
                .map(|neighbour| neighbour.position());
            points_to_tangent_eig(neighbourhood).ok_or("Failed to SVD")
        })
        .collect::<Result<Vec<Normal3>, &'static str>>()?;

    Ok((rtree, tangents))
}
/// Minimal query-only neuron: a list of points with one tangent per point and
/// no spatial index.
#[derive(Clone)]
pub struct QueryPointTangents {
    /// Sample locations, in the order they were supplied.
    points: Vec<Point3>,
    /// Tangent for each entry of `points`, in the same order.
    tangents: Vec<Normal3>,
}

impl QueryPointTangents {
    /// Create a query neuron, estimating each point's tangent from its `k`
    /// nearest neighbours.
    ///
    /// # Errors
    ///
    /// Propagates the error from tangent estimation, e.g. when fewer than `k`
    /// points are given.
    pub fn new(points: Vec<Point3>, k: usize) -> Result<Self, &'static str> {
        // A tree is built only to find neighbours; it is dropped afterwards.
        let (_rtree, tangents) = points_to_rtree_tangents(points.iter(), k)?;
        Ok(Self { points, tangents })
    }
}
impl QueryNeuron for QueryPointTangents {
    /// Number of points in the neuron.
    fn len(&self) -> usize {
        self.points.len()
    }

    /// Sum, over every point of this neuron, of the score of its nearest match
    /// in `target`.
    fn query(
        &self,
        target: &impl TargetNeuron,
        score_fn: &impl Fn(&DistDot) -> Precision,
    ) -> Precision {
        self.points
            .iter()
            .zip(&self.tangents)
            .map(|(point, tangent)| score_fn(&target.nearest_match_dist_dot(point, tangent)))
            .fold(0.0, |total, score| total + score)
    }

    /// Owned copy of the points, in insertion order.
    fn points(&self) -> Vec<Point3> {
        self.points.to_vec()
    }

    /// Owned copy of the tangents, in insertion order.
    fn tangents(&self) -> Vec<Normal3> {
        self.tangents.to_vec()
    }
}
/// A neuron which can be the target of a query as well as a query itself
/// (it must also implement [`QueryNeuron`]).
pub trait TargetNeuron: QueryNeuron {
/// For the given query `point` and `tangent`, return a [`DistDot`] describing
/// its match in this neuron.
// NOTE(review): the implementation in this file uses the single nearest point,
// the Euclidean distance to it and the absolute dot product of the tangents;
// other implementors are presumably expected to do the same - confirm.
fn nearest_match_dist_dot(&self, point: &Point3, tangent: &Normal3) -> DistDot;
}
/// Neuron backed by an R*-tree of its points, usable as both a query and a
/// target.
#[derive(Clone)]
pub struct RStarPointTangents {
// Spatial index of the points; each entry's payload is the point's position in
// the original input, which is also its index into `tangents`.
rtree: RTree<PointWithIndex>,
// One tangent per point, in original input order.
tangents: Vec<Normal3>,
}
impl RStarPointTangents {
    /// Create a neuron from points, estimating each point's tangent from its
    /// `k` nearest neighbours.
    ///
    /// # Errors
    ///
    /// Propagates the error from tangent estimation, e.g. when fewer than `k`
    /// points are given.
    pub fn new<T: std::borrow::Borrow<Point3>>(
        points: impl IntoIterator<
            Item = T,
            IntoIter = impl Iterator<Item = T> + ExactSizeIterator + Clone,
        >,
        k: usize,
    ) -> Result<Self, &'static str> {
        let (rtree, tangents) = points_to_rtree_tangents(points.into_iter(), k)?;
        Ok(Self { rtree, tangents })
    }

    /// Create a neuron from points and pre-computed tangents.
    ///
    /// `tangents[i]` must be the tangent of the `i`th point.
    ///
    /// # Errors
    ///
    /// Returns an error if the number of tangents differs from the number of
    /// points. Tangents are looked up by point index during queries, so a
    /// mismatch would otherwise surface later as an out-of-bounds panic (too
    /// few tangents) or as a wrong `len()` (too many).
    pub fn new_with_tangents<T: std::borrow::Borrow<Point3>>(
        points: impl IntoIterator<
            Item = T,
            IntoIter = impl Iterator<Item = T> + ExactSizeIterator + Clone,
        >,
        tangents: Vec<Normal3>,
    ) -> Result<Self, &'static str> {
        let points = points.into_iter();
        if points.len() != tangents.len() {
            return Err("Number of tangents does not match number of points");
        }
        let rtree = points_to_rtree(points)?;
        Ok(Self { rtree, tangents })
    }
}
impl QueryNeuron for RStarPointTangents {
    /// Number of points, taken from the number of stored tangents.
    fn len(&self) -> usize {
        self.tangents.len()
    }

    /// Sum, over every point of this neuron, of the score of its nearest match
    /// in `target`. Points are visited in the tree's iteration order.
    fn query(
        &self,
        target: &impl TargetNeuron,
        score_fn: &impl Fn(&DistDot) -> Precision,
    ) -> Precision {
        self.rtree.iter().fold(0.0, |total, entry| {
            let tangent = &self.tangents[entry.data];
            let dist_dot = target.nearest_match_dist_dot(entry.position(), tangent);
            total + score_fn(&dist_dot)
        })
    }

    /// Owned copy of the points, restored to their original input order.
    fn points(&self) -> Vec<Point3> {
        // The tree iterates in its own order, so sort by the stored index.
        let mut indexed: Vec<(usize, Point3)> = self
            .rtree
            .iter()
            .map(|entry| (entry.data, *entry.position()))
            .collect();
        indexed.sort_by_key(|&(idx, _)| idx);
        indexed.into_iter().map(|(_, point)| point).collect()
    }

    /// Owned copy of the tangents, in original input order.
    fn tangents(&self) -> Vec<Normal3> {
        self.tangents.to_vec()
    }
}
impl TargetNeuron for RStarPointTangents {
    /// Find the point of this neuron nearest to `point` and report the
    /// distance to it together with the absolute dot product of its tangent
    /// with `tangent`.
    ///
    /// # Panics
    ///
    /// Panics if this neuron's tree contains no points.
    fn nearest_match_dist_dot(&self, point: &Point3, tangent: &Normal3) -> DistDot {
        let (nearest, dist_squared) = self
            .rtree
            .nearest_neighbor_iter_with_distance(point)
            .next()
            .expect("impossible");

        let this_tangent = self.tangents[nearest.data];
        DistDot {
            // The tree reports a squared distance (it is square-rooted here).
            dist: dist_squared.sqrt(),
            // The sign of a tangent is arbitrary, so only the magnitude is kept.
            dot: this_tangent.dot(tangent).abs(),
        }
    }
}
/// Find the index of the bin which `value` falls into.
///
/// `upper_bounds` must be sorted in ascending order. Bin `i` covers values
/// from `upper_bounds[i - 1]` (inclusive) up to `upper_bounds[i]` (exclusive);
/// values at or above the last bound are clamped into the last bin, and values
/// below the first bound fall into bin 0.
///
/// This never panics: a NaN `value` compares below every bound and so lands in
/// bin 0, and an empty `upper_bounds` yields 0.
fn find_bin_binary(value: Precision, upper_bounds: &[Precision]) -> usize {
    // Number of bounds which are <= value, i.e. the index of the first bound
    // strictly greater than it.
    let bounds_at_or_below = upper_bounds.partition_point(|bound| *bound <= value);
    let last_bin = upper_bounds.len().saturating_sub(1);
    bounds_at_or_below.min(last_bin)
}
/// Turn a lookup table into a score function.
///
/// `dist_thresholds` and `dot_thresholds` are the upper bounds of the distance
/// bins (rows) and dot-product bins (columns) respectively. `cells` holds the
/// scores in row-major order, i.e. all dot bins of the first distance bin come
/// first.
///
/// # Panics
///
/// Panics if `cells.len()` is not `dist_thresholds.len() * dot_thresholds.len()`.
pub fn table_to_fn(
    dist_thresholds: Vec<Precision>,
    dot_thresholds: Vec<Precision>,
    cells: Vec<Precision>,
) -> impl Fn(&DistDot) -> Precision {
    let n_cols = dot_thresholds.len();
    let expected_cells = dist_thresholds.len() * n_cols;
    if cells.len() != expected_cells {
        panic!("Number of cells in table do not match number of columns/rows");
    }

    move |dd: &DistDot| -> Precision {
        let col = find_bin_binary(dd.dot, &dot_thresholds);
        let row = find_bin_binary(dd.dist, &dist_thresholds);
        cells[row * n_cols + col]
    }
}
/// Collection of neurons which share a score function and can be queried
/// against each other by index.
#[derive(Clone)]
pub struct NblastArena<N, F>
where
N: TargetNeuron,
F: Fn(&DistDot) -> Precision,
{
// Each neuron paired with its self-hit score, which is computed once when the
// neuron is added and later used as the normalization denominator.
neurons_scores: Vec<(N, Precision)>,
// Converts a point match into a score.
score_fn: F,
}
/// Identifier of a neuron within an [`NblastArena`]: its position in insertion
/// order.
pub type NeuronIdx = usize;
impl<N, F> NblastArena<N, F>
where
    N: TargetNeuron + Sync,
    F: Fn(&DistDot) -> Precision + Sync,
{
    /// Create an empty arena whose queries will all use `score_fn`.
    pub fn new(score_fn: F) -> Self {
        Self {
            score_fn,
            neurons_scores: Vec::new(),
        }
    }

    /// Index which the next added neuron will receive.
    fn next_id(&self) -> NeuronIdx {
        self.neurons_scores.len()
    }

    /// Add a neuron to the arena and return its index.
    ///
    /// The neuron's self-hit score is computed here and cached.
    pub fn add_neuron(&mut self, neuron: N) -> NeuronIdx {
        let idx = self.next_id();
        let self_score = neuron.self_hit(&self.score_fn);
        self.neurons_scores.push((neuron, self_score));
        idx
    }

    /// Score one neuron against another.
    ///
    /// With `normalize`, each directed score is divided by the self-hit score
    /// of the neuron acting as the query in that direction. With a `symmetry`,
    /// the reverse score is computed as well and the two are combined.
    ///
    /// Returns `None` if either index is not in the arena.
    pub fn query_target(
        &self,
        query_idx: NeuronIdx,
        target_idx: NeuronIdx,
        normalize: bool,
        symmetry: &Option<Symmetry>,
    ) -> Option<Precision> {
        let (query, query_self_hit) = self.neurons_scores.get(query_idx)?;
        let (target, target_self_hit) = self.neurons_scores.get(target_idx)?;

        // Score in one direction, optionally normalized by the given self-hit.
        let directed_score = |from: &N, to: &N, self_hit: Precision| -> Precision {
            let raw = from.query(to, &self.score_fn);
            if normalize {
                raw / self_hit
            } else {
                raw
            }
        };

        let forward = directed_score(query, target, *query_self_hit);
        let result = match symmetry {
            None => forward,
            Some(s) => {
                let backward = directed_score(target, query, *target_self_hit);
                apply_symmetry(s, forward, backward)
            }
        };
        Some(result)
    }

    /// Score every query against every target.
    ///
    /// The result is keyed by `(query index, target index)`. Pairs involving
    /// an index which is not in the arena are absent from the result. A neuron
    /// paired with itself is not queried: its cached self-hit score is used
    /// (or `1.0` when normalizing).
    ///
    /// `threads` is handed on to `pairs_to_raw`, which decides how the pairs
    /// are scheduled.
    pub fn queries_targets(
        &self,
        query_idxs: &[NeuronIdx],
        target_idxs: &[NeuronIdx],
        normalize: bool,
        symmetry: &Option<Symmetry>,
        threads: Option<usize>,
    ) -> HashMap<(NeuronIdx, NeuronIdx), Precision> {
        let mut out = HashMap::with_capacity(query_idxs.len() * target_idxs.len());
        // Pairs the caller asked for (excluding self-pairs)...
        let mut wanted: HashSet<(NeuronIdx, NeuronIdx)> = HashSet::new();
        // ...and the directed queries needed to answer them, de-duplicated.
        let mut jobs: HashSet<(NeuronIdx, NeuronIdx)> = HashSet::new();

        for &q_idx in query_idxs {
            for &t_idx in target_idxs {
                let key = (q_idx, t_idx);
                if q_idx != t_idx {
                    wanted.insert(key);
                    jobs.insert(key);
                    if symmetry.is_some() {
                        // The reverse direction is needed for the combination.
                        jobs.insert((t_idx, q_idx));
                    }
                } else if let Some((_, self_hit)) = self.neurons_scores.get(q_idx) {
                    out.insert(key, if normalize { 1.0 } else { *self_hit });
                }
            }
        }

        let jobs_vec: Vec<_> = jobs.into_iter().collect();
        let raw = pairs_to_raw(self, &jobs_vec, normalize, threads);

        for key in wanted {
            let forward = match raw.get(&key) {
                Some(score) => *score,
                None => continue,
            };
            let combined = match symmetry {
                None => Some(forward),
                Some(s) => raw
                    .get(&(key.1, key.0))
                    .map(|backward| apply_symmetry(s, forward, *backward)),
            };
            if let Some(score) = combined {
                out.insert(key, score);
            }
        }
        out
    }

    /// Cached self-hit score of the neuron at `idx`, if it exists.
    pub fn self_hit(&self, idx: NeuronIdx) -> Option<Precision> {
        self.neurons_scores.get(idx).map(|entry| entry.1)
    }

    /// Score every neuron in the arena against every other neuron.
    ///
    /// See [`NblastArena::queries_targets`] for the meaning of the arguments
    /// and of the result.
    pub fn all_v_all(
        &self,
        normalize: bool,
        symmetry: &Option<Symmetry>,
        threads: Option<usize>,
    ) -> HashMap<(NeuronIdx, NeuronIdx), Precision> {
        let every_idx: Vec<NeuronIdx> = (0..self.len()).collect();
        self.queries_targets(&every_idx, &every_idx, normalize, symmetry, threads)
    }

    /// Whether the arena holds no neurons.
    pub fn is_empty(&self) -> bool {
        self.neurons_scores.is_empty()
    }

    /// Number of neurons in the arena.
    pub fn len(&self) -> usize {
        self.neurons_scores.len()
    }

    /// Owned copy of the points of the neuron at `idx`, if it exists.
    pub fn points(&self, idx: NeuronIdx) -> Option<Vec<Point3>> {
        let (neuron, _) = self.neurons_scores.get(idx)?;
        Some(neuron.points())
    }

    /// Owned copy of the tangents of the neuron at `idx`, if it exists.
    pub fn tangents(&self, idx: NeuronIdx) -> Option<Vec<Normal3>> {
        let (neuron, _) = self.neurons_scores.get(idx)?;
        Some(neuron.tangents())
    }
}
/// Compute the directed (non-symmetric) score of every `(query, target)` pair,
/// one after the other on the calling thread.
///
/// Pairs for which the arena returns no score (an index is not in the arena)
/// are left out of the result.
fn pairs_to_raw_serial<N, F>(
    arena: &NblastArena<N, F>,
    pairs: &[(NeuronIdx, NeuronIdx)],
    normalize: bool,
) -> HashMap<(NeuronIdx, NeuronIdx), Precision>
where
    N: TargetNeuron + Sync,
    F: Fn(&DistDot) -> Precision + Sync,
{
    let mut raw = HashMap::new();
    for &(q_idx, t_idx) in pairs {
        if let Some(score) = arena.query_target(q_idx, t_idx, normalize, &None) {
            raw.insert((q_idx, t_idx), score);
        }
    }
    raw
}
/// Compute the directed score of every `(query, target)` pair.
///
/// This is the variant compiled without the `parallel` feature: the work is
/// always done serially and `threads` has no effect.
#[cfg(not(feature = "parallel"))]
fn pairs_to_raw<N, F>(
    arena: &NblastArena<N, F>,
    pairs: &[(NeuronIdx, NeuronIdx)],
    normalize: bool,
    threads: Option<usize>,
) -> HashMap<(NeuronIdx, NeuronIdx), Precision>
where
    N: TargetNeuron + Sync,
    F: Fn(&DistDot) -> Precision + Sync,
{
    // Accepted only so the signature matches the parallel variant.
    let _ = threads;
    pairs_to_raw_serial(arena, pairs, normalize)
}
/// Compute the directed score of every `(query, target)` pair.
///
/// This is the variant compiled with the `parallel` feature. With
/// `threads == None` the work is done serially on the calling thread;
/// otherwise a dedicated rayon thread pool is built with the requested number
/// of threads and the pairs are scored in parallel inside it.
///
/// # Panics
///
/// Panics if the thread pool cannot be built.
#[cfg(feature = "parallel")]
fn pairs_to_raw<N, F>(
    arena: &NblastArena<N, F>,
    pairs: &[(NeuronIdx, NeuronIdx)],
    normalize: bool,
    threads: Option<usize>,
) -> HashMap<(NeuronIdx, NeuronIdx), Precision>
where
    N: TargetNeuron + Sync,
    F: Fn(&DistDot) -> Precision + Sync,
{
    match threads {
        None => pairs_to_raw_serial(arena, pairs, normalize),
        Some(n_threads) => {
            let pool = rayon::ThreadPoolBuilder::new()
                .num_threads(n_threads)
                .build()
                .unwrap();
            pool.install(|| {
                pairs
                    .par_iter()
                    .filter_map(|&(q_idx, t_idx)| {
                        // Pairs with an unknown index give `None` and are dropped.
                        let score = arena.query_target(q_idx, t_idx, normalize, &None)?;
                        Some(((q_idx, t_idx), score))
                    })
                    .collect()
            })
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by `is_close`.
    const EPSILON: Precision = 0.001;
    /// Neighbourhood size used when estimating tangents in these tests.
    const N_NEIGHBORS: usize = 5;

    /// Component-wise sum of two points.
    fn add_points(a: &Point3, b: &Point3) -> Point3 {
        [a[0] + b[0], a[1] + b[1], a[2] + b[2]]
    }

    /// `count` points starting at `offset`, each one `step` further along
    /// than the previous. A `count` of zero gives an empty vector.
    fn make_points(offset: &Point3, step: &Point3, count: usize) -> Vec<Point3> {
        let mut out = Vec::with_capacity(count);
        let mut current = *offset;
        for _ in 0..count {
            out.push(current);
            current = add_points(&current, step);
        }
        out
    }

    #[test]
    fn construct() {
        let points = make_points(&[0., 0., 0.], &[1., 0., 0.], 10);
        RStarPointTangents::new(&points, N_NEIGHBORS).expect("Target construction failed");
        QueryPointTangents::new(points, N_NEIGHBORS).expect("Query construction failed");
    }

    /// Whether the two values differ by less than `EPSILON`.
    fn is_close(val1: Precision, val2: Precision) -> bool {
        (val1 - val2).abs() < EPSILON
    }

    /// Panic with both values if they are not close.
    fn assert_close(val1: Precision, val2: Precision) {
        if !is_close(val1, val2) {
            panic!("Not close:\n\t{:?}\n\t{:?}", val1, val2);
        }
    }

    #[test]
    fn unit_tangents_eig() {
        let (points, _) = tangent_data();
        let tangent = points_to_tangent_eig(points.iter()).expect("eig failed");
        assert_close(tangent.dot(&tangent), 1.0)
    }

    /// Whether two tangents are parallel; the sign of a tangent is arbitrary,
    /// so anti-parallel tangents count as equivalent.
    fn equivalent_tangents(tan1: &Normal3, tan2: &Normal3) -> bool {
        is_close(tan1.dot(tan2).abs(), 1.0)
    }

    /// Five points and the tangent expected to be estimated from them.
    fn tangent_data() -> (Vec<Point3>, Normal3) {
        let expected = Unit::new_normalize(Vector3::from_column_slice(&[
            -0.939_392_2,
            0.313_061_82,
            0.139_766_18,
        ]));
        let points = vec![
            [
                329.679_962_158_203,
                72.718_803_405_761_7,
                31.028_469_085_693_4,
            ],
            [
                328.647_399_902_344,
                73.046_119_689_941_4,
                31.537_061_691_284_2,
            ],
            [
                335.219_879_150_391,
                70.710_479_736_328_1,
                30.398_145_675_659_2,
            ],
            [
                332.611_389_160_156,
                72.322_929_382_324_2,
                30.887_334_823_608_4,
            ],
            [
                331.770_782_470_703,
                72.434_440_612_793,
                31.169_372_558_593_8,
            ],
        ];
        (points, expected)
    }

    #[test]
    fn test_tangent_eig() {
        let (points, expected) = tangent_data();
        let tangent = points_to_tangent_eig(points.iter()).expect("Failed to create tangent");
        if !equivalent_tangents(&tangent, &expected) {
            panic!(
                "Non-equivalent tangents:\n\t{:?}\n\t{:?}",
                tangent, expected
            )
        }
    }

    /// Distance thresholds, dot thresholds and a table whose cell values are
    /// simply their own linear index, so tests can tell which cell was hit.
    fn score_mat() -> (Vec<Precision>, Vec<Precision>, Vec<Precision>) {
        let dists = vec![10.0, 20.0, 30.0, 40.0, 50.0];
        let dots = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0];
        let n_values = dots.len() * dists.len();
        let values = (0..n_values).map(|v| v as Precision).collect();
        (dists, dots, values)
    }

    #[test]
    fn test_score_fn() {
        let (dists, dots, values) = score_mat();
        let func = table_to_fn(dists, dots, values);

        // (dist, dot, expected cell value)
        let cases = [
            (0.0, 0.0, 0.0),
            (0.0, 0.1, 1.0),
            (11.0, 0.0, 10.0),
            (55.0, 0.0, 40.0),
            (55.0, 10.0, 49.0),
            (15.0, 0.15, 11.0),
        ];
        for &(dist, dot, expected) in cases.iter() {
            assert_close(func(&DistDot { dist, dot }), expected);
        }
    }

    #[test]
    fn test_find_bin_binary() {
        let dots = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0];
        assert_eq!(find_bin_binary(0.0, &dots), 0);
        assert_eq!(find_bin_binary(0.15, &dots), 1);
        assert_eq!(find_bin_binary(0.95, &dots), 9);
        assert_eq!(find_bin_binary(-10.0, &dots), 0);
        assert_eq!(find_bin_binary(10.0, &dots), 9);
        assert_eq!(find_bin_binary(0.1, &dots), 1);
    }

    #[test]
    fn score_function() {
        let dist_thresholds = vec![1.0, 2.0];
        let dot_thresholds = vec![0.5, 1.0];
        let cells = vec![1.0, 2.0, 4.0, 8.0];
        let score_fn = table_to_fn(dist_thresholds, dot_thresholds, cells);

        let q_points = make_points(&[0., 0., 0.], &[1.0, 0.0, 0.0], 10);
        let query = QueryPointTangents::new(q_points.clone(), N_NEIGHBORS)
            .expect("Query construction failed");
        let query2 = RStarPointTangents::new(&q_points, N_NEIGHBORS).expect("Construction failed");
        let target = RStarPointTangents::new(
            &make_points(&[0.5, 0., 0.], &[1.1, 0., 0.], 10),
            N_NEIGHBORS,
        )
        .expect("Construction failed");

        // Both query representations must score identically.
        assert_close(
            query.query(&target, &score_fn),
            query2.query(&target, &score_fn),
        );
        assert_close(query.self_hit(&score_fn), query2.self_hit(&score_fn));

        // Querying against the same points must reproduce the self-hit score.
        let score = query.query(&query2, &score_fn);
        let self_hit = query.self_hit(&score_fn);
        println!("score: {:?}, self-hit {:?}", score, self_hit);
        assert_close(score, self_hit);
    }

    #[test]
    fn arena() {
        let dist_thresholds = vec![1.0, 2.0];
        let dot_thresholds = vec![0.5, 1.0];
        let cells = vec![1.0, 2.0, 4.0, 8.0];
        let score_fn = table_to_fn(dist_thresholds, dot_thresholds, cells);

        let query =
            RStarPointTangents::new(&make_points(&[0., 0., 0.], &[1., 0., 0.], 10), N_NEIGHBORS)
                .expect("Construction failed");
        let target = RStarPointTangents::new(
            &make_points(&[0.5, 0., 0.], &[1.1, 0., 0.], 10),
            N_NEIGHBORS,
        )
        .expect("Construction failed");

        let mut arena = NblastArena::new(score_fn);
        let q_idx = arena.add_neuron(query);
        let t_idx = arena.add_neuron(target);

        let no_norm = arena
            .query_target(q_idx, t_idx, false, &None)
            .expect("should exist");
        let self_hit = arena
            .query_target(q_idx, q_idx, false, &None)
            .expect("should exist");
        let normalized = arena
            .query_target(q_idx, t_idx, true, &None)
            .expect("should exist");
        // Compare in both directions: a one-sided `a - b < EPSILON` would also
        // accept a normalized score that is far too low.
        assert_close(normalized, no_norm / self_hit);

        assert_eq!(
            arena.query_target(q_idx, t_idx, false, &Some(Symmetry::ArithmeticMean)),
            arena.query_target(t_idx, q_idx, false, &Some(Symmetry::ArithmeticMean)),
        );

        let out = arena.queries_targets(&[q_idx, t_idx], &[t_idx, q_idx], false, &None, None);
        assert_eq!(out.len(), 4);
    }

    /// Assert that swapping the arguments does not change the combined score.
    fn test_symmetry(symmetry: &Symmetry, a: Precision, b: Precision) {
        assert_close(
            apply_symmetry(symmetry, a, b),
            apply_symmetry(symmetry, b, a),
        )
    }

    /// Run `test_symmetry` over positive, zero, negative and large scores.
    fn test_symmetry_multiple(symmetry: &Symmetry) {
        let pairs = [(0.3, 0.7), (0.0, 0.7), (-1.0, 0.7), (100.0, 1000.0)];
        for &(a, b) in pairs.iter() {
            test_symmetry(symmetry, a, b);
        }
    }

    #[test]
    fn symmetry_arithmetic() {
        test_symmetry_multiple(&Symmetry::ArithmeticMean)
    }

    #[test]
    fn symmetry_harmonic() {
        test_symmetry_multiple(&Symmetry::HarmonicMean)
    }

    #[test]
    fn symmetry_geometric() {
        test_symmetry_multiple(&Symmetry::GeometricMean)
    }

    #[test]
    fn symmetry_min() {
        test_symmetry_multiple(&Symmetry::Min)
    }

    #[test]
    fn symmetry_max() {
        test_symmetry_multiple(&Symmetry::Max)
    }
}