// nblast/lib.rs

1//! Implementation of the NBLAST algorithm for quantifying neurons' morphological similarity.
2//! Originally published in
3//! [Costa et al. (2016)](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4961245/)
4//! and implemented as part of the
5//! [NeuroAnatomy Toolbox](http://natverse.org/).
6//!
7//! # Algorithm
8//!
9//! Each neuron is passed in as a point cloud sample (the links between the points are not required).
10//! A tangent vector is calculated for each point, based on its location and that of its nearest neighbors.
11//! Additionally, an `alpha` value is calculated, which describes how colinear the neighbors are,
12//! between 0 and 1.
13//!
14//! To query the similarity of neuron `Q` to neuron `T`:
15//!
16//! - Take a point and its associated tangent in `Q`
17//!   - Find the nearest point in `T`, and its associated tangent
18//!   - Compute the distance between the two points
19//!   - Compute the absolute dot product of the two tangents
20//!   - Apply some empirically-derived function to the (distance, dot_product) tuple
//!     - As published, this is the log probability ratio of any pair belonging to closely related or unrelated neurons
22//! - Repeat for all points, summing the results
23//!
24//! The result is not easily comparable:
25//! it is highly dependent on the size of the point cloud
26//! and is not commutative, i.e. `f(Q, T) != f(T, Q)`.
27//!
28//! To make queries between two pairs of neurons comparable,
29//! the result can be normalized by the "self-hit" score of the query, i.e. `f(Q, Q)`.
30//!
31//! To make the result commutative, the forward `f(Q, T)` and backward `f(T, Q)` scores can be combined in some way.
32//! This library supports several means (arithmetic, harmonic, and geometric), the minimum, and the maximum.
33//! The choice will depend on the application.
34//! This can be applied after the scores are normalized.
35//!
36//! The backbone of the neuron is the most easily sampled and most stereotyped part of its morphology,
37//! and therefore should be focused on for comparisons.
38//! However, a lot of cable is in dendrites, which can cause problems when reconstructed in high resolution.
39//! Queries can be weighted towards straighter, less branched regions by multiplying the absolute dot product
40//! for each point match by the geometric mean of the two alpha values.
41//!
42//! More information on the algorithm can be found
43//! [here](http://jefferislab.org/si/nblast).
44//!
45//! # Usage
46//!
47//! The [QueryNeuron](trait.QueryNeuron.html) and [TargetNeuron](trait.TargetNeuron.html) traits
48//! define types which can be compared with NBLAST.
49//! All `TargetNeuron`s are also `QueryNeuron`s.
50//! Both are [Neuron](trait.Neuron.html)s.
51//!
52//! [PointsTangentsAlphas](struct.PointsTangentsAlphas.html) and
53//! [RstarNeuron](struct.RstarNeuron.html) implement these, respectively.
54//! Both can be created with pre-calculated tangents and alphas, or calculate them on instantiation.
55//!
56//! The [NblastArena](struct.NblastArena.html) contains a collection of `TargetNeuron`s
57//! and a function to apply to pointwise [DistDot](struct.DistDot.html)s to generate
58//! a score for that point match, for convenient many-to-many comparisons.
59//! A pre-calculated table of point match scores can be converted into a function with [table_to_fn](fn.table_to_fn.html).
60//!
61//! ```
62//! use nblast::{NblastArena, ScoreCalc, Neuron, Symmetry};
63//!
64//! // Create a lookup table for the point match scores
65//! let smat = ScoreCalc::table_from_bins(
66//!   vec![0.0, 0.1, 0.25, 0.5, 1.0, 5.0, f64::INFINITY], // distance thresholds
67//!   vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0], // dot product thresholds
68//!   vec![ // table values in dot-major order
69//!     0.0, 0.1, 0.2, 0.3, 0.4,
70//!     1.0, 1.1, 1.2, 1.3, 1.4,
71//!     2.0, 2.1, 2.2, 2.3, 2.4,
72//!     3.0, 3.1, 3.2, 3.3, 3.4,
73//!     4.0, 4.1, 4.2, 4.3, 4.4,
74//!     5.0, 5.1, 5.2, 5.3, 5.4,
75//!   ],
76//! ).expect("could not build score matrix");
77//!
78//! // See the ScoreMatrixBuilder for constructing a score matrix from test data.
79//!
80//! // Create an arena to hold your neurons with this score function, and
81//! // whether it should scale the dot products by the colinearity value.
82//! let mut arena = NblastArena::new(smat, false);
83//! // if the "parallel" feature is enabled, use e.g. `.with_threads(5)` to set 5 threads for multi-neuron queries
84//!
85//! let mut rng = fastrand::Rng::with_seed(1991);
86//!
87//! fn random_points(n: usize, rng: &mut fastrand::Rng) -> Vec<[f64; 3]> {
88//!     std::iter::repeat_with(|| [
89//!         10.0 * rng.f64(),
90//!         10.0 * rng.f64(),
91//!         10.0 * rng.f64(),
92//!     ]).take(n).collect()
93//! }
94//!
95//! // Add some neurons built from points and a neighborhood size,
96//! // returning their indices in the arena
97//! let idx1 = arena.add_neuron(
98//!     Neuron::new(random_points(6, &mut rng), 5).expect("cannot construct neuron")
99//! );
100//! let idx2 = arena.add_neuron(
101//!     Neuron::new(random_points(8, &mut rng), 5).expect("cannot construct neuron")
102//! );
103//!
104//! // get a raw score (not normalized by self-hit, no symmetry)
105//! let raw = arena.query_target(idx1, idx2, false, &None);
106//!
107//! // get all the scores, normalized, made symmetric, and with a centroid distance cut-off
108//! let results = arena.all_v_all(true, &Some(Symmetry::ArithmeticMean), Some(10.0));
109//!
110//! ```
111use nalgebra::base::{Matrix3, Unit, Vector3};
112use std::collections::{HashMap, HashSet};
113use table_lookup::InvalidRangeTable;
114
115#[cfg(feature = "parallel")]
116pub use rayon;
117#[cfg(feature = "parallel")]
118use rayon::prelude::*;
119
120pub use nalgebra;
121
122mod smat;
123pub use smat::ScoreMatrixBuilder;
124
125mod table_lookup;
126pub use table_lookup::{BinLookup, NdBinLookup, RangeTable};
127
128pub mod neurons;
129pub use neurons::{NblastNeuron, Neuron, QueryNeuron, TargetNeuron};
130
131#[cfg(not(any(
132    feature = "nabo",
133    feature = "rstar",
134    feature = "kiddo",
135    feature = "bosque"
136)))]
137compile_error!("no spatial backend feature enabled");
138
139/// Floating point precision type used internally
140pub type Precision = f64;
141/// 3D point type used internally
142pub type Point3 = [Precision; 3];
143/// 3D unit-length vector type used internally
144pub type Normal3 = Unit<Vector3<Precision>>;
145
146fn centroid<T: IntoIterator<Item = Point3>>(points: T) -> Point3 {
147    let mut len: f64 = 0.0;
148    let mut out = [0.0; 3];
149    for p in points {
150        len += 1.0;
151        for idx in 0..3 {
152            out[idx] += p[idx];
153        }
154    }
155    for el in &mut out {
156        *el /= len;
157    }
158    out
159}
160
161fn geometric_mean(a: Precision, b: Precision) -> Precision {
162    (a.max(0.0) * b.max(0.0)).sqrt()
163}
164
165fn harmonic_mean(a: Precision, b: Precision) -> Precision {
166    if a <= 0.0 || b <= 0.0 {
167        0.0
168    } else {
169        2.0 / (1.0 / a + 1.0 / b)
170    }
171}
172
/// A tangent, alpha pair associated with a point.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct TangentAlpha {
    /// Unit-length tangent estimated from the point's neighborhood.
    pub tangent: Normal3,
    /// Colinearity measure of the neighborhood, between 0 and 1.
    pub alpha: Precision,
}
179
impl TangentAlpha {
    /// Estimate the tangent and alpha for a neighborhood of points
    /// via eigendecomposition of the neighborhood's inertia matrix:
    /// the tangent is the eigenvector of the largest eigenvalue,
    /// and alpha is the gap between the two largest eigenvalues over their sum.
    fn new_from_points<'a>(points: impl Iterator<Item = &'a Point3>) -> Self {
        let inertia = calc_inertia(points);
        let eig = inertia.symmetric_eigen();
        let mut sum = 0.0;
        // pair each eigenvalue with its column index so we can find
        // the matching eigenvector after sorting
        let mut vals: Vec<_> = eig
            .eigenvalues
            .iter()
            .enumerate()
            .map(|(idx, v)| {
                sum += v;
                (idx, v)
            })
            .collect();
        // descending order by eigenvalue; unwrap panics on NaN eigenvalues
        vals.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
        let alpha = (vals[0].1 - vals[1].1) / sum;

        // ? new_unchecked
        // re-normalize defensively, although the eigenvector should already be unit-length
        let tangent = Unit::new_normalize(eig.eigenvectors.column(vals[0].0).into());

        Self { tangent, alpha }
    }
}
203
/// Enumeration of methods to ensure that queries are symmetric/ commutative
/// (i.e. `f(q, t) = f(t, q)`).
/// Specific applications will require different methods.
/// Geometric and harmonic means bound the output to be >= 0.0.
/// Geometric mean may work best with non-normalized queries.
/// Max may work if an unknown one of the query and target is incomplete.
#[derive(Default)]
pub enum Symmetry {
    /// `(forward + backward) / 2`
    ArithmeticMean,
    /// `sqrt(max(forward, 0) * max(backward, 0))`; the default.
    #[default]
    GeometricMean,
    /// `2 / (1/forward + 1/backward)`, or 0 if either is non-positive.
    HarmonicMean,
    /// The smaller of the two scores.
    Min,
    /// The larger of the two scores.
    Max,
}
219
220impl Symmetry {
221    pub fn apply(&self, query_score: Precision, target_score: Precision) -> Precision {
222        match self {
223            Symmetry::ArithmeticMean => (query_score + target_score) / 2.0,
224            Symmetry::GeometricMean => geometric_mean(query_score, target_score),
225            Symmetry::HarmonicMean => harmonic_mean(query_score, target_score),
226            Symmetry::Min => query_score.min(target_score),
227            Symmetry::Max => query_score.max(target_score),
228        }
229    }
230}
231
/// The result of comparing two (point, tangent) tuples.
/// Contains the Euclidean distance between the points,
/// and the absolute dot product of the (unit) tangents,
/// i.e. the absolute cosine of the angle between them
/// (possibly scaled by the geometric mean of the alphas).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DistDot {
    /// Euclidean distance between the two points.
    pub dist: Precision,
    /// Absolute dot product of the two unit tangents (possibly alpha-scaled).
    pub dot: Precision,
}
242
243impl DistDot {
244    fn to_idxs(
245        self,
246        dist_thresholds: &[Precision],
247        dot_thresholds: &[Precision],
248    ) -> (usize, usize) {
249        let dist_bin = find_bin_binary(self.dist, dist_thresholds);
250        let dot_bin = find_bin_binary(self.dot, dot_thresholds);
251        (dist_bin, dot_bin)
252    }
253
254    fn to_linear_idx(self, dist_thresholds: &[Precision], dot_thresholds: &[Precision]) -> usize {
255        let (row_idx, col_idx) = self.to_idxs(dist_thresholds, dot_thresholds);
256        row_idx * dot_thresholds.len() + col_idx
257    }
258}
259
260impl Default for DistDot {
261    fn default() -> Self {
262        Self {
263            dist: 0.0,
264            dot: 1.0,
265        }
266    }
267}
268
269fn subtract_points(p1: &Point3, p2: &Point3) -> Point3 {
270    let mut result = [0.0; 3];
271    for ((rref, v1), v2) in result.iter_mut().zip(p1).zip(p2) {
272        *rref = v1 - v2;
273    }
274    result
275}
276
277fn center_points<'a>(points: impl Iterator<Item = &'a Point3>) -> impl Iterator<Item = Point3> {
278    let mut points_vec = Vec::default();
279    let mut means: Point3 = [0.0, 0.0, 0.0];
280    for pt in points {
281        points_vec.push(*pt);
282        for (sum, v) in means.iter_mut().zip(pt.iter()) {
283            *sum += v;
284        }
285    }
286
287    for val in means.iter_mut() {
288        *val /= points_vec.len() as Precision;
289    }
290    let subtract = move |p| subtract_points(&p, &means);
291    points_vec.into_iter().map(subtract)
292}
293
294fn dot(a: &[Precision], b: &[Precision]) -> Precision {
295    a.iter()
296        .zip(b.iter())
297        .fold(0.0, |sum, (ax, bx)| sum + ax * bx)
298}
299
300/// Calculate inertia from iterator of points.
301/// This is an implementation of matrix * matrix.transpose(),
302/// to sidestep the fixed-size constraints of linalg's built-in classes.
303/// Only calculates the lower triangle and diagonal.
304fn calc_inertia<'a>(points: impl Iterator<Item = &'a Point3>) -> Matrix3<Precision> {
305    let mut xs = Vec::default();
306    let mut ys = Vec::default();
307    let mut zs = Vec::default();
308    for point in center_points(points) {
309        xs.push(point[0]);
310        ys.push(point[1]);
311        zs.push(point[2]);
312    }
313    Matrix3::new(
314        dot(&xs, &xs),
315        0.0,
316        0.0,
317        dot(&ys, &xs),
318        dot(&ys, &ys),
319        0.0,
320        dot(&zs, &xs),
321        dot(&zs, &ys),
322        dot(&zs, &zs),
323    )
324}
325
/// Minimal struct to use as the query (not the target) of an NBLAST
/// comparison.
/// Equivalent to "dotprops" in the reference implementation.
#[derive(Clone)]
pub struct PointsTangentsAlphas {
    /// Locations of points in point cloud.
    points: Vec<Point3>,
    /// For each point in the cloud, a unit-length vector and a colinearity metric.
    // NOTE(review): assumed parallel to `points` (same length, same order);
    // not validated — downstream queries zip the two, truncating to the shorter.
    tangents_alphas: Vec<TangentAlpha>,
}
336
impl PointsTangentsAlphas {
    /// Create from pre-calculated points and (tangent, alpha) pairs.
    ///
    /// The two vectors are assumed to be parallel (equal length, matching order);
    /// this is not checked here.
    pub fn new(points: Vec<Point3>, tangents_alphas: Vec<TangentAlpha>) -> Self {
        Self {
            points,
            tangents_alphas,
        }
    }
}
345
346impl NblastNeuron for PointsTangentsAlphas {
347    fn len(&self) -> usize {
348        self.points.len()
349    }
350
351    fn points(&self) -> impl Iterator<Item = Point3> + '_ {
352        self.points.iter().cloned()
353    }
354
355    fn centroid(&self) -> Point3 {
356        centroid(self.points())
357    }
358
359    fn tangents(&self) -> impl Iterator<Item = Normal3> + '_ {
360        self.tangents_alphas.iter().map(|ta| ta.tangent)
361    }
362
363    fn alphas(&self) -> impl Iterator<Item = Precision> + '_ {
364        self.tangents_alphas.iter().map(|ta| ta.alpha)
365    }
366}
367
368impl QueryNeuron for PointsTangentsAlphas {
369    fn query_dist_dots<'a>(
370        &'a self,
371        target: &'a impl TargetNeuron,
372        use_alpha: bool,
373    ) -> impl Iterator<Item = DistDot> + 'a {
374        self.points
375            .iter()
376            .zip(self.tangents_alphas.iter())
377            .map(move |(q_pt, q_ta)| {
378                let alpha = if use_alpha { Some(q_ta.alpha) } else { None };
379                target.nearest_match_dist_dot(q_pt, &q_ta.tangent, alpha)
380            })
381    }
382
383    fn query(
384        &self,
385        target: &impl TargetNeuron,
386        use_alpha: bool,
387        score_calc: &ScoreCalc,
388    ) -> Precision {
389        let mut score_total: Precision = 0.0;
390
391        for (q_pt, q_ta) in self.points.iter().zip(self.tangents_alphas.iter()) {
392            let alpha = if use_alpha { Some(q_ta.alpha) } else { None };
393            score_total +=
394                score_calc.calc(&target.nearest_match_dist_dot(q_pt, &q_ta.tangent, alpha));
395        }
396        score_total
397    }
398
399    fn self_hit(&self, score_calc: &ScoreCalc, use_alpha: bool) -> Precision {
400        if use_alpha {
401            self.tangents_alphas
402                .iter()
403                .map(|ta| {
404                    score_calc.calc(&DistDot {
405                        dist: 0.0,
406                        dot: ta.alpha,
407                    })
408                })
409                .fold(0.0, |total, s| total + s)
410        } else {
411            score_calc.calc(&DistDot {
412                dist: 0.0,
413                dot: 1.0,
414            }) * self.len() as Precision
415        }
416    }
417}
418
419// ? consider using nalgebra's Point3 in PointWithIndex, for consistency
420// ^ can't implement rstar::Point for nalgebra::geometry::Point3 because of orphan rules
421// TODO: replace Precision with float generic
422
423/// Given the upper bounds of a number of bins, find which bin the value falls into.
424/// Values outside of the range fall into the bottom or top bin.
425fn find_bin_binary(value: Precision, upper_bounds: &[Precision]) -> usize {
426    let raw = match upper_bounds.binary_search_by(|bound| bound.partial_cmp(&value).unwrap()) {
427        Ok(v) => v + 1,
428        Err(v) => v,
429    };
430    let highest = upper_bounds.len() - 1;
431    if raw > highest {
432        highest
433    } else {
434        raw
435    }
436}
437
438/// Convert an empirically-derived table mapping pointwise distance and tangent absolute dot products
439/// to pointwise scores into a function which can be passed to neuron queries.
440/// These scores are then summed across all points in the query to give the raw NBLAST score.
441///
442/// Cells are passed in dist-major order
443/// i.e. if the original table had distance bins in the left margin
444/// and dot product bins on the top margin,
445/// the cells should be given in row-major order.
446///
447/// Each bin is identified by its upper bound:
448/// the lower bound is implicitly the previous bin's upper bound, or zero.
449/// The output is constrained to the limits of the table.
450pub fn table_to_fn(
451    dist_thresholds: Vec<Precision>,
452    dot_thresholds: Vec<Precision>,
453    cells: Vec<Precision>,
454) -> impl Fn(&DistDot) -> Precision {
455    if dist_thresholds.len() * dot_thresholds.len() != cells.len() {
456        panic!("Number of cells in table do not match number of columns/rows");
457    }
458
459    move |dd: &DistDot| -> Precision { cells[dd.to_linear_idx(&dist_thresholds, &dot_thresholds)] }
460}
461
/// Convert a pre-built [RangeTable](struct.RangeTable.html) of pointwise scores
/// into a function which can be passed to neuron queries,
/// looking up each (distance, dot) pair in the table.
pub fn range_table_to_fn(
    range_table: RangeTable<Precision, Precision>,
) -> impl Fn(&DistDot) -> Precision {
    move |dd: &DistDot| -> Precision { *range_table.lookup(&[dd.dist, dd.dot]) }
}
467
468trait Location {
469    fn location(&self) -> &Point3;
470
471    fn distance2_to<T: Location>(&self, other: T) -> Precision {
472        self.location()
473            .iter()
474            .zip(other.location().iter())
475            .map(|(a, b)| a * a + b * b)
476            .sum()
477    }
478
479    fn distance_to<T: Location>(&self, other: T) -> Precision {
480        self.distance2_to(other).sqrt()
481    }
482}
483
// A point is its own location.
impl Location for Point3 {
    fn location(&self) -> &Point3 {
        self
    }
}
489
// A borrowed point is also its own location.
impl Location for &Point3 {
    fn location(&self) -> &Point3 {
        self
    }
}
495
/// Internal cache entry: a neuron plus precomputed values used by the arena.
#[derive(Clone)]
struct NeuronSelfHit<N: QueryNeuron> {
    neuron: N,
    /// Cached self-hit score, used for normalization.
    self_hit: Precision,
    /// Cached centroid, used for centroid-distance filtering.
    centroid: [Precision; 3],
}
502
503impl<N: QueryNeuron> NeuronSelfHit<N> {
504    fn new(neuron: N, self_hit: Precision) -> Self {
505        let centroid = neuron.centroid();
506        Self {
507            neuron,
508            self_hit,
509            centroid,
510        }
511    }
512
513    fn score(&self) -> Precision {
514        self.self_hit
515    }
516}
517
/// Different ways of converting point match statistics into a single score.
#[derive(Debug, Clone)]
pub enum ScoreCalc {
    // Func(Box<dyn Fn(&DistDot) -> Precision + Sync>),
    /// Lookup into a 2D table of pre-calculated scores, binned by (distance, dot).
    Table(RangeTable<Precision, Precision>),
}
524
525impl ScoreCalc {
526    /// Construct a table from `N` dist bins, `M` dot bins,
527    /// and the values in `(N-1)*(M-1)` cells.
528    ///
529    /// The bins vecs must be monotonically increasing.
530    /// The first and last values of each
531    /// are effectively ignored and replaced by -inf and +inf respectively.
532    /// Values are given in dot-major order, i.e.
533    /// cells in the same dist bin are next to each other.
534    pub fn table_from_bins(
535        dists: Vec<Precision>,
536        dots: Vec<Precision>,
537        values: Vec<Precision>,
538    ) -> Result<Self, InvalidRangeTable> {
539        Ok(Self::Table(RangeTable::new_from_bins(
540            vec![dists, dots],
541            values,
542        )?))
543    }
544
545    /// Apply the score function.
546    pub fn calc(&self, dist_dot: &DistDot) -> Precision {
547        match self {
548            // Self::Func(func) => func(dist_dot),
549            Self::Table(tab) => *tab.lookup(&[dist_dot.dist, dist_dot.dot]),
550        }
551    }
552}
553
/// Struct for caching a number of neurons for multiple comparable NBLAST queries.
#[allow(dead_code)]
pub struct NblastArena<N>
where
    N: TargetNeuron,
{
    /// Neurons with their cached self-hit scores and centroids;
    /// a neuron's index in this vec is its public `NeuronIdx`.
    neurons_scores: Vec<NeuronSelfHit<N>>,
    /// Converts pointwise (distance, dot) statistics into scores.
    score_calc: ScoreCalc,
    /// Whether to scale dot products by the colinearity (alpha) values.
    use_alpha: bool,
    // shows as dead code without parallel feature
    /// Thread count for multi-neuron queries; `None` means run in serial.
    threads: Option<usize>,
}
566
/// Index identifying a neuron within an [NblastArena](struct.NblastArena.html).
pub type NeuronIdx = usize;
568
569impl<N> NblastArena<N>
570where
571    N: TargetNeuron + Sync,
572{
573    /// By default, runs in serial.
574    /// See `NblastArena.with_threads` if the "parallel" feature is enabled.
575    pub fn new(score_calc: ScoreCalc, use_alpha: bool) -> Self {
576        Self {
577            neurons_scores: Vec::default(),
578            score_calc,
579            use_alpha,
580            threads: None,
581        }
582    }
583
584    /// 0 means use all available.
585    #[cfg(feature = "parallel")]
586    pub fn with_threads(self, threads: usize) -> Self {
587        Self {
588            neurons_scores: self.neurons_scores,
589            score_calc: self.score_calc,
590            use_alpha: self.use_alpha,
591            threads: Some(threads),
592        }
593    }
594
595    pub fn size_of(&self, idx: NeuronIdx) -> Option<usize> {
596        self.neurons_scores.get(idx).map(|n| n.neuron.len())
597    }
598
599    fn next_id(&self) -> NeuronIdx {
600        self.neurons_scores.len()
601    }
602
603    /// Returns an index which is then used to make queries.
604    pub fn add_neuron(&mut self, neuron: N) -> NeuronIdx {
605        let idx = self.next_id();
606        let self_hit = neuron.self_hit(&self.score_calc, self.use_alpha);
607        self.neurons_scores
608            .push(NeuronSelfHit::new(neuron, self_hit));
609        idx
610    }
611
612    /// Make a single query using the given indexes.
613    /// `normalize` divides the result by the self-hit score of the query neuron.
614    /// `symmetry`, if `Some`, also calculates the reverse score
615    /// (normalizing it if necessary), and then applies a function to ensure
616    /// that the query is symmetric/ commutative.
617    pub fn query_target(
618        &self,
619        query_idx: NeuronIdx,
620        target_idx: NeuronIdx,
621        normalize: bool,
622        symmetry: &Option<Symmetry>,
623    ) -> Option<Precision> {
624        let q = self.neurons_scores.get(query_idx)?;
625
626        if query_idx == target_idx {
627            return if normalize {
628                Some(1.0)
629            } else {
630                Some(q.score())
631            };
632        }
633
634        let t = self.neurons_scores.get(target_idx)?;
635        let mut score = q.neuron.query(&t.neuron, self.use_alpha, &self.score_calc);
636        if normalize {
637            score /= q.score()
638        }
639        match symmetry {
640            Some(s) => {
641                let mut score2 = t.neuron.query(&q.neuron, self.use_alpha, &self.score_calc);
642                if normalize {
643                    score2 /= t.score();
644                }
645                Some(s.apply(score, score2))
646            }
647            _ => Some(score),
648        }
649    }
650
651    /// Make many queries using the Cartesian product of the query and target indices.
652    ///
653    /// See [query_target](#method.query_target) for details on `normalize` and `symmetry`.
654    /// Neurons whose centroids are greater than `max_centroid_dist` apart will return NaN.
655    /// Indices which do not exist will be silently ignored.
656    pub fn queries_targets(
657        &self,
658        query_idxs: &[NeuronIdx],
659        target_idxs: &[NeuronIdx],
660        normalize: bool,
661        symmetry: &Option<Symmetry>,
662        max_centroid_dist: Option<Precision>,
663    ) -> HashMap<(NeuronIdx, NeuronIdx), Precision> {
664        // filter out neurons which don't exist,
665        // but leave filters for centroid distance, duplication, and self-hits to query_target_pairs
666        let pairs: Vec<_> = query_idxs
667            .iter()
668            .filter_map(|q| {
669                let q2 = *q;
670                if q2 >= self.len() {
671                    None
672                } else {
673                    Some(target_idxs.iter().filter_map(move |t| {
674                        if t >= &self.len() {
675                            None
676                        } else {
677                            Some((q2, *t))
678                        }
679                    }))
680                }
681            })
682            .flatten()
683            .collect();
684
685        self.query_target_pairs(&pairs, normalize, symmetry, max_centroid_dist)
686    }
687
688    /// See [query_target](#method.query_target) for details on `normalize` and `symmetry`.
689    /// Neurons whose centroids are greater than `max_centroid_dist` apart will return NaN.
690    /// Indices which do not exist will be silently ignored.
691    pub fn query_target_pairs(
692        &self,
693        query_target_idxs: &[(NeuronIdx, NeuronIdx)],
694        normalize: bool,
695        symmetry: &Option<Symmetry>,
696        max_centroid_dist: Option<Precision>,
697    ) -> HashMap<(NeuronIdx, NeuronIdx), Precision> {
698        let mut max_jobs = query_target_idxs.len();
699
700        let mut out = HashMap::with_capacity(query_target_idxs.len());
701        if symmetry.is_some() {
702            max_jobs *= 2;
703        }
704        let mut jobs = HashSet::with_capacity(max_jobs);
705        for (q, t) in query_target_idxs {
706            if q > &self.len() || t > &self.len() {
707                continue;
708            }
709
710            let key = (*q, *t);
711
712            if q == t {
713                out.insert(
714                    key,
715                    if normalize {
716                        1.0
717                    } else {
718                        self.neurons_scores[*q].score()
719                    },
720                );
721                continue;
722            } else {
723                out.insert(key, Precision::NAN);
724            }
725
726            if jobs.contains(&(*q, *t)) {
727                continue;
728            }
729
730            if let Some(d) = max_centroid_dist {
731                if !self
732                    .centroids_within_distance(*q, *t, d)
733                    .expect("Already checked indices")
734                {
735                    continue;
736                }
737            }
738
739            jobs.insert(key);
740            if symmetry.is_some() {
741                jobs.insert((key.1, key.0));
742            }
743        }
744
745        let raw = pairs_to_raw(self, &jobs.into_iter().collect::<Vec<_>>(), normalize);
746
747        for (key, value) in out.iter_mut() {
748            // if a query doesn't appear in jobs, it was ignored
749            // as being self-hit (already populated),
750            // or centroids too far (already NaN)
751            if let Some(forward) = raw.get(key) {
752                if let Some(s) = symmetry {
753                    // if symmetry is some and the forward request exists,
754                    // so does the backward
755                    let backward = raw[&(key.1, key.0)];
756                    // ! this applies symmetry twice if idx is in both queries and targets,
757                    // but it's a cheap function
758                    *value = s.apply(*forward, backward);
759                } else {
760                    *value = *forward;
761                }
762            }
763        }
764
765        out
766    }
767
768    pub fn centroids_within_distance(
769        &self,
770        query_idx: NeuronIdx,
771        target_idx: NeuronIdx,
772        max_centroid_dist: Precision,
773    ) -> Option<bool> {
774        if query_idx == target_idx {
775            return Some(true);
776        }
777        self.neurons_scores.get(query_idx).and_then(|q| {
778            self.neurons_scores
779                .get(target_idx)
780                .map(|t| q.centroid.distance_to(t.centroid) < max_centroid_dist)
781        })
782    }
783
784    pub fn self_hit(&self, idx: NeuronIdx) -> Option<Precision> {
785        self.neurons_scores.get(idx).map(|n| n.score())
786    }
787
788    /// Query every neuron against every other neuron.
789    /// See [queries_targets](#method.queries_targets) for more details.
790    pub fn all_v_all(
791        &self,
792        normalize: bool,
793        symmetry: &Option<Symmetry>,
794        max_centroid_dist: Option<Precision>,
795    ) -> HashMap<(NeuronIdx, NeuronIdx), Precision> {
796        let idxs: Vec<NeuronIdx> = (0..self.len()).collect();
797        self.queries_targets(&idxs, &idxs, normalize, symmetry, max_centroid_dist)
798    }
799
800    pub fn is_empty(&self) -> bool {
801        self.neurons_scores.is_empty()
802    }
803
804    /// Number of neurons in the arena.
805    pub fn len(&self) -> usize {
806        self.neurons_scores.len()
807    }
808
809    pub fn points(&self, idx: NeuronIdx) -> Option<impl Iterator<Item = Point3> + '_> {
810        self.neurons_scores.get(idx).map(|n| n.neuron.points())
811    }
812
813    pub fn tangents(&self, idx: NeuronIdx) -> Option<impl Iterator<Item = Normal3> + '_> {
814        self.neurons_scores.get(idx).map(|n| n.neuron.tangents())
815    }
816
817    pub fn alphas(&self, idx: NeuronIdx) -> Option<impl Iterator<Item = Precision> + '_> {
818        self.neurons_scores.get(idx).map(|n| n.neuron.alphas())
819    }
820}
821
822fn pairs_to_raw_serial<N: TargetNeuron + Sync>(
823    arena: &NblastArena<N>,
824    pairs: &[(NeuronIdx, NeuronIdx)],
825    normalize: bool,
826) -> HashMap<(NeuronIdx, NeuronIdx), Precision> {
827    pairs
828        .iter()
829        .filter_map(|(q_idx, t_idx)| {
830            arena
831                .query_target(*q_idx, *t_idx, normalize, &None)
832                .map(|s| ((*q_idx, *t_idx), s))
833        })
834        .collect()
835}
836
/// Without the "parallel" feature, pair scoring always falls through to serial.
#[cfg(not(feature = "parallel"))]
fn pairs_to_raw<N>(
    arena: &NblastArena<N>,
    pairs: &[(NeuronIdx, NeuronIdx)],
    normalize: bool,
) -> HashMap<(NeuronIdx, NeuronIdx), Precision>
where
    N: TargetNeuron + Sync,
{
    pairs_to_raw_serial(arena, pairs, normalize)
}
848
/// With the "parallel" feature, pair scoring runs on a rayon thread pool
/// when the arena has a thread count configured, and in serial otherwise.
#[cfg(feature = "parallel")]
fn pairs_to_raw<N: TargetNeuron + Sync>(
    arena: &NblastArena<N>,
    pairs: &[(NeuronIdx, NeuronIdx)],
    normalize: bool,
) -> HashMap<(NeuronIdx, NeuronIdx), Precision> {
    if let Some(t) = arena.threads {
        // build a dedicated pool; num_threads(0) lets rayon use all cores
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(t)
            .build()
            .unwrap();
        pool.install(|| {
            pairs
                .par_iter()
                .filter_map(|(q_idx, t_idx)| {
                    arena
                        .query_target(*q_idx, *t_idx, normalize, &None)
                        .map(|s| ((*q_idx, *t_idx), s))
                })
                .collect()
        })
    } else {
        pairs_to_raw_serial(arena, pairs, normalize)
    }
}
874
875#[cfg(test)]
876mod test {
877    use super::*;
878
879    const EPSILON: Precision = 0.001;
880    const N_NEIGHBORS: usize = 5;
881
882    fn add_points(a: &Point3, b: &Point3) -> Point3 {
883        let mut out = [0., 0., 0.];
884        for (idx, (x, y)) in a.iter().zip(b.iter()).enumerate() {
885            out[idx] = x + y;
886        }
887        out
888    }
889
890    fn make_points(offset: &Point3, step: &Point3, count: usize) -> Vec<Point3> {
891        let mut out = Vec::default();
892        out.push(*offset);
893
894        for _ in 0..count - 1 {
895            let to_push = add_points(out.last().unwrap(), step);
896            out.push(to_push);
897        }
898
899        out
900    }
901
    /// Smoke test: a straight line of points builds a valid neuron.
    #[test]
    fn construct() {
        let points = make_points(&[0., 0., 0.], &[1., 0., 0.], 10);
        Neuron::new(points, N_NEIGHBORS).unwrap();
    }
907
    /// Approximate float equality within EPSILON, logging both values.
    fn is_close(val1: Precision, val2: Precision) -> bool {
        println!("Comparing values:\n\tval1: {:?}\n\tval2: {:?}", val1, val2);
        (val1 - val2).abs() < EPSILON
    }
912
    /// Panic if the two values are not approximately equal.
    fn assert_close(val1: Precision, val2: Precision) {
        if !is_close(val1, val2) {
            panic!("Not close:\n\t{:?}\n\t{:?}", val1, val2);
        }
    }
918
919    #[test]
920    fn unit_tangents_eig() {
921        let (points, _, _) = tangent_data();
922        let tangent = TangentAlpha::new_from_points(points.iter()).tangent;
923        assert_close(tangent.dot(&tangent), 1.0)
924    }
925
926    fn equivalent_tangents(tan1: &Normal3, tan2: &Normal3) -> bool {
927        is_close(tan1.dot(tan2).abs(), 1.0)
928    }
929
930    fn tangent_data() -> (Vec<Point3>, Normal3, Precision) {
931        // calculated from implementation known to be correct
932        let tangent = Unit::new_normalize(Vector3::from_column_slice(&[
933            -0.939_392_2,
934            0.313_061_82,
935            0.139_766_18,
936        ]));
937
938        // points in first row of data/dotprops/ChaMARCM-F000586_seg002.csv
939        let points = vec![
940            [
941                329.679_962_158_203,
942                72.718_803_405_761_7,
943                31.028_469_085_693_4,
944            ],
945            [
946                328.647_399_902_344,
947                73.046_119_689_941_4,
948                31.537_061_691_284_2,
949            ],
950            [
951                335.219_879_150_391,
952                70.710_479_736_328_1,
953                30.398_145_675_659_2,
954            ],
955            [
956                332.611_389_160_156,
957                72.322_929_382_324_2,
958                30.887_334_823_608_4,
959            ],
960            [
961                331.770_782_470_703,
962                72.434_440_612_793,
963                31.169_372_558_593_8,
964            ],
965        ];
966
967        let alpha = 0.844_842_871_450_449;
968
969        (points, tangent, alpha)
970    }
971
972    #[test]
973    fn test_tangent_eig() {
974        let (points, exp_tan, _exp_alpha) = tangent_data();
975        let ta = TangentAlpha::new_from_points(points.iter());
976        if !equivalent_tangents(&ta.tangent, &exp_tan) {
977            panic!(
978                "Non-equivalent tangents:\n\t{:?}\n\t{:?}",
979                ta.tangent, exp_tan
980            )
981        }
982        // tested from the python side
983        // assert_close(ta.alpha, exp_alpha);
984    }
985
986    #[test]
987    fn test_neuron() {
988        let (points, exp_tan, _exp_alpha) = tangent_data();
989        let tgt = Neuron::new(points, N_NEIGHBORS).unwrap();
990        assert!(equivalent_tangents(
991            &tgt.tangents().next().unwrap(),
992            &exp_tan
993        ));
994        // tested from the python side
995        // assert_close(tgt.alphas()[0], exp_alpha);
996    }
997
998    /// dist_thresholds, dot_thresholds, values
999    fn score_mat() -> (Vec<Precision>, Vec<Precision>, Vec<Precision>) {
1000        let dists = vec![10.0, 20.0, 30.0, 40.0, 50.0];
1001        let dots = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0];
1002        let mut values = vec![];
1003        let n_values = dots.len() * dists.len();
1004        for v in 0..n_values {
1005            values.push(v as Precision);
1006        }
1007        (dists, dots, values)
1008    }
1009
1010    #[test]
1011    fn test_score_calc() {
1012        let (dists, dots, values) = score_mat();
1013        let func = table_to_fn(dists, dots, values);
1014        assert_close(
1015            func(&DistDot {
1016                dist: 0.0,
1017                dot: 0.0,
1018            }),
1019            0.0,
1020        );
1021        assert_close(
1022            func(&DistDot {
1023                dist: 0.0,
1024                dot: 0.1,
1025            }),
1026            1.0,
1027        );
1028        assert_close(
1029            func(&DistDot {
1030                dist: 11.0,
1031                dot: 0.0,
1032            }),
1033            10.0,
1034        );
1035        assert_close(
1036            func(&DistDot {
1037                dist: 55.0,
1038                dot: 0.0,
1039            }),
1040            40.0,
1041        );
1042        assert_close(
1043            func(&DistDot {
1044                dist: 55.0,
1045                dot: 10.0,
1046            }),
1047            49.0,
1048        );
1049        assert_close(
1050            func(&DistDot {
1051                dist: 15.0,
1052                dot: 0.15,
1053            }),
1054            11.0,
1055        );
1056    }
1057
1058    #[test]
1059    fn test_find_bin_binary() {
1060        let dots = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0];
1061        assert_eq!(find_bin_binary(0.0, &dots), 0);
1062        assert_eq!(find_bin_binary(0.15, &dots), 1);
1063        assert_eq!(find_bin_binary(0.95, &dots), 9);
1064        assert_eq!(find_bin_binary(-10.0, &dots), 0);
1065        assert_eq!(find_bin_binary(10.0, &dots), 9);
1066        assert_eq!(find_bin_binary(0.1, &dots), 1);
1067    }
1068
1069    // #[test]
1070    // fn score_function() {
1071    //     let dist_thresholds = vec![1.0, 2.0];
1072    //     let dot_thresholds = vec![0.5, 1.0];
1073    //     let cells = vec![1.0, 2.0, 4.0, 8.0];
1074
1075    //     let score_calc = ScoreCalc::Func(Box::new(table_to_fn(dist_thresholds, dot_thresholds, cells)));
1076
1077    //     let q_points = make_points(&[0., 0., 0.], &[1.0, 0.0, 0.0], 10);
1078    //     let query = PointsTangentsAlphas::new(q_points.clone(), N_NEIGHBORS)
1079    //         .expect("Query construction failed");
1080    //     let query2 = RstarNeuron::new(&q_points, N_NEIGHBORS).expect("Construction failed");
1081    //     let target = RstarNeuron::new(
1082    //         &make_points(&[0.5, 0., 0.], &[1.1, 0., 0.], 10),
1083    //         N_NEIGHBORS,
1084    //     )
1085    //     .expect("Construction failed");
1086
1087    //     assert_close(
1088    //         query.query(&target, false, &score_calc),
1089    //         query2.query(&target, false, &score_calc),
1090    //     );
1091    //     assert_close(
1092    //         query.self_hit(&score_calc, false),
1093    //         query2.self_hit(&score_calc, false),
1094    //     );
1095    //     let score = query.query(&query2, false, &score_calc);
1096    //     let self_hit = query.self_hit(&score_calc, false);
1097    //     println!("score: {:?}, self-hit {:?}", score, self_hit);
1098    //     assert_close(
1099    //         query.query(&query2, false, &score_calc),
1100    //         query.self_hit(&score_calc, false),
1101    //     );
1102    // }
1103
1104    #[test]
1105    fn arena() {
1106        let dist_thresholds = vec![0.0, 1.0, 2.0];
1107        let dot_thresholds = vec![0.0, 0.5, 1.0];
1108        let cells = vec![1.0, 2.0, 4.0, 8.0];
1109
1110        // let score_calc = ScoreCalc::Func(Box::new(table_to_fn(dist_thresholds, dot_thresholds, cells)));
1111        let score_calc = ScoreCalc::Table(
1112            RangeTable::new_from_bins(vec![dist_thresholds, dot_thresholds], cells).unwrap(),
1113        );
1114
1115        let query =
1116            Neuron::new(make_points(&[0., 0., 0.], &[1., 0., 0.], 10), N_NEIGHBORS).unwrap();
1117        let target =
1118            Neuron::new(make_points(&[0.5, 0., 0.], &[1.1, 0., 0.], 10), N_NEIGHBORS).unwrap();
1119
1120        let mut arena = NblastArena::new(score_calc, false);
1121        let q_idx = arena.add_neuron(query);
1122        let t_idx = arena.add_neuron(target);
1123
1124        let no_norm = arena
1125            .query_target(q_idx, t_idx, false, &None)
1126            .expect("should exist");
1127        let self_hit = arena
1128            .query_target(q_idx, q_idx, false, &None)
1129            .expect("should exist");
1130
1131        assert!(
1132            arena
1133                .query_target(q_idx, t_idx, true, &None)
1134                .expect("should exist")
1135                - no_norm / self_hit
1136                < EPSILON
1137        );
1138        assert_eq!(
1139            arena.query_target(q_idx, t_idx, false, &Some(Symmetry::ArithmeticMean)),
1140            arena.query_target(t_idx, q_idx, false, &Some(Symmetry::ArithmeticMean)),
1141        );
1142
1143        let out = arena.queries_targets(&[q_idx, t_idx], &[t_idx, q_idx], false, &None, None);
1144        assert_eq!(out.len(), 4);
1145    }
1146
1147    fn test_symmetry(symmetry: &Symmetry, a: Precision, b: Precision) {
1148        assert_close(symmetry.apply(a, b), symmetry.apply(b, a))
1149    }
1150
1151    fn test_symmetry_multiple(symmetry: &Symmetry) {
1152        for (a, b) in vec![(0.3, 0.7), (0.0, 0.7), (-1.0, 0.7), (100.0, 1000.0)].into_iter() {
1153            test_symmetry(symmetry, a, b);
1154        }
1155    }
1156
1157    #[test]
1158    fn symmetry_arithmetic() {
1159        test_symmetry_multiple(&Symmetry::ArithmeticMean)
1160    }
1161
1162    #[test]
1163    fn symmetry_harmonic() {
1164        test_symmetry_multiple(&Symmetry::HarmonicMean)
1165    }
1166
1167    #[test]
1168    fn symmetry_geometric() {
1169        test_symmetry_multiple(&Symmetry::GeometricMean)
1170    }
1171
1172    #[test]
1173    fn symmetry_min() {
1174        test_symmetry_multiple(&Symmetry::Min)
1175    }
1176
1177    #[test]
1178    fn symmetry_max() {
1179        test_symmetry_multiple(&Symmetry::Max)
1180    }
1181
1182    // #[test]
1183    // fn alpha_changes_results() {
1184    //     let (points, _, _) = tangent_data();
1185    //     let neuron = RstarNeuron::new(points, N_NEIGHBORS).unwrap();
1186    //     let score_calc = ScoreCalc::Func(Box::new(|dd: &DistDot| dd.dot));
1187
1188    //     let sh = neuron.self_hit(&score_calc, false);
1189    //     let sh_a = neuron.self_hit(&score_calc, true);
1190    //     assert!(!is_close(sh, sh_a));
1191
1192    //     let q = neuron.query(&neuron, false, &score_calc);
1193    //     let q_a = neuron.query(&neuron, true, &score_calc);
1194
1195    //     assert!(!is_close(q, q_a));
1196    // }
1197}