1use nalgebra::base::{Matrix3, Unit, Vector3};
112use std::collections::{HashMap, HashSet};
113use table_lookup::InvalidRangeTable;
114
115#[cfg(feature = "parallel")]
116pub use rayon;
117#[cfg(feature = "parallel")]
118use rayon::prelude::*;
119
120pub use nalgebra;
121
122mod smat;
123pub use smat::ScoreMatrixBuilder;
124
125mod table_lookup;
126pub use table_lookup::{BinLookup, NdBinLookup, RangeTable};
127
128pub mod neurons;
129pub use neurons::{NblastNeuron, Neuron, QueryNeuron, TargetNeuron};
130
// The crate needs at least one spatial-index backend (presumably for the nearest-neighbour
// lookups behind `nearest_match_dist_dot`), chosen through cargo features. Fail the build
// early with a clear message if none of them is enabled.
#[cfg(not(any(
    feature = "nabo",
    feature = "rstar",
    feature = "kiddo",
    feature = "bosque"
)))]
compile_error!("no spatial backend feature enabled");
138
/// Floating-point type used for all coordinates, distances and scores.
pub type Precision = f64;
/// A point in 3D space, stored as `[x, y, z]`.
pub type Point3 = [Precision; 3];
/// A unit-length 3D vector, used e.g. for tangent directions.
pub type Normal3 = Unit<Vector3<Precision>>;
145
146fn centroid<T: IntoIterator<Item = Point3>>(points: T) -> Point3 {
147 let mut len: f64 = 0.0;
148 let mut out = [0.0; 3];
149 for p in points {
150 len += 1.0;
151 for idx in 0..3 {
152 out[idx] += p[idx];
153 }
154 }
155 for el in &mut out {
156 *el /= len;
157 }
158 out
159}
160
161fn geometric_mean(a: Precision, b: Precision) -> Precision {
162 (a.max(0.0) * b.max(0.0)).sqrt()
163}
164
165fn harmonic_mean(a: Precision, b: Precision) -> Precision {
166 if a <= 0.0 || b <= 0.0 {
167 0.0
168 } else {
169 2.0 / (1.0 / a + 1.0 / b)
170 }
171}
172
173#[derive(Copy, Clone, Debug, PartialEq)]
175pub struct TangentAlpha {
176 pub tangent: Normal3,
177 pub alpha: Precision,
178}
179
180impl TangentAlpha {
181 fn new_from_points<'a>(points: impl Iterator<Item = &'a Point3>) -> Self {
182 let inertia = calc_inertia(points);
183 let eig = inertia.symmetric_eigen();
184 let mut sum = 0.0;
185 let mut vals: Vec<_> = eig
186 .eigenvalues
187 .iter()
188 .enumerate()
189 .map(|(idx, v)| {
190 sum += v;
191 (idx, v)
192 })
193 .collect();
194 vals.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
195 let alpha = (vals[0].1 - vals[1].1) / sum;
196
197 let tangent = Unit::new_normalize(eig.eigenvectors.column(vals[0].0).into());
199
200 Self { tangent, alpha }
201 }
202}
203
204#[derive(Default)]
211pub enum Symmetry {
212 ArithmeticMean,
213 #[default]
214 GeometricMean,
215 HarmonicMean,
216 Min,
217 Max,
218}
219
220impl Symmetry {
221 pub fn apply(&self, query_score: Precision, target_score: Precision) -> Precision {
222 match self {
223 Symmetry::ArithmeticMean => (query_score + target_score) / 2.0,
224 Symmetry::GeometricMean => geometric_mean(query_score, target_score),
225 Symmetry::HarmonicMean => harmonic_mean(query_score, target_score),
226 Symmetry::Min => query_score.min(target_score),
227 Symmetry::Max => query_score.max(target_score),
228 }
229 }
230}
231
232#[derive(Debug, Clone, Copy, PartialEq)]
238pub struct DistDot {
239 pub dist: Precision,
240 pub dot: Precision,
241}
242
243impl DistDot {
244 fn to_idxs(
245 self,
246 dist_thresholds: &[Precision],
247 dot_thresholds: &[Precision],
248 ) -> (usize, usize) {
249 let dist_bin = find_bin_binary(self.dist, dist_thresholds);
250 let dot_bin = find_bin_binary(self.dot, dot_thresholds);
251 (dist_bin, dot_bin)
252 }
253
254 fn to_linear_idx(self, dist_thresholds: &[Precision], dot_thresholds: &[Precision]) -> usize {
255 let (row_idx, col_idx) = self.to_idxs(dist_thresholds, dot_thresholds);
256 row_idx * dot_thresholds.len() + col_idx
257 }
258}
259
260impl Default for DistDot {
261 fn default() -> Self {
262 Self {
263 dist: 0.0,
264 dot: 1.0,
265 }
266 }
267}
268
269fn subtract_points(p1: &Point3, p2: &Point3) -> Point3 {
270 let mut result = [0.0; 3];
271 for ((rref, v1), v2) in result.iter_mut().zip(p1).zip(p2) {
272 *rref = v1 - v2;
273 }
274 result
275}
276
277fn center_points<'a>(points: impl Iterator<Item = &'a Point3>) -> impl Iterator<Item = Point3> {
278 let mut points_vec = Vec::default();
279 let mut means: Point3 = [0.0, 0.0, 0.0];
280 for pt in points {
281 points_vec.push(*pt);
282 for (sum, v) in means.iter_mut().zip(pt.iter()) {
283 *sum += v;
284 }
285 }
286
287 for val in means.iter_mut() {
288 *val /= points_vec.len() as Precision;
289 }
290 let subtract = move |p| subtract_points(&p, &means);
291 points_vec.into_iter().map(subtract)
292}
293
294fn dot(a: &[Precision], b: &[Precision]) -> Precision {
295 a.iter()
296 .zip(b.iter())
297 .fold(0.0, |sum, (ax, bx)| sum + ax * bx)
298}
299
300fn calc_inertia<'a>(points: impl Iterator<Item = &'a Point3>) -> Matrix3<Precision> {
305 let mut xs = Vec::default();
306 let mut ys = Vec::default();
307 let mut zs = Vec::default();
308 for point in center_points(points) {
309 xs.push(point[0]);
310 ys.push(point[1]);
311 zs.push(point[2]);
312 }
313 Matrix3::new(
314 dot(&xs, &xs),
315 0.0,
316 0.0,
317 dot(&ys, &xs),
318 dot(&ys, &ys),
319 0.0,
320 dot(&zs, &xs),
321 dot(&zs, &ys),
322 dot(&zs, &zs),
323 )
324}
325
326#[derive(Clone)]
330pub struct PointsTangentsAlphas {
331 points: Vec<Point3>,
333 tangents_alphas: Vec<TangentAlpha>,
335}
336
337impl PointsTangentsAlphas {
338 pub fn new(points: Vec<Point3>, tangents_alphas: Vec<TangentAlpha>) -> Self {
339 Self {
340 points,
341 tangents_alphas,
342 }
343 }
344}
345
346impl NblastNeuron for PointsTangentsAlphas {
347 fn len(&self) -> usize {
348 self.points.len()
349 }
350
351 fn points(&self) -> impl Iterator<Item = Point3> + '_ {
352 self.points.iter().cloned()
353 }
354
355 fn centroid(&self) -> Point3 {
356 centroid(self.points())
357 }
358
359 fn tangents(&self) -> impl Iterator<Item = Normal3> + '_ {
360 self.tangents_alphas.iter().map(|ta| ta.tangent)
361 }
362
363 fn alphas(&self) -> impl Iterator<Item = Precision> + '_ {
364 self.tangents_alphas.iter().map(|ta| ta.alpha)
365 }
366}
367
368impl QueryNeuron for PointsTangentsAlphas {
369 fn query_dist_dots<'a>(
370 &'a self,
371 target: &'a impl TargetNeuron,
372 use_alpha: bool,
373 ) -> impl Iterator<Item = DistDot> + 'a {
374 self.points
375 .iter()
376 .zip(self.tangents_alphas.iter())
377 .map(move |(q_pt, q_ta)| {
378 let alpha = if use_alpha { Some(q_ta.alpha) } else { None };
379 target.nearest_match_dist_dot(q_pt, &q_ta.tangent, alpha)
380 })
381 }
382
383 fn query(
384 &self,
385 target: &impl TargetNeuron,
386 use_alpha: bool,
387 score_calc: &ScoreCalc,
388 ) -> Precision {
389 let mut score_total: Precision = 0.0;
390
391 for (q_pt, q_ta) in self.points.iter().zip(self.tangents_alphas.iter()) {
392 let alpha = if use_alpha { Some(q_ta.alpha) } else { None };
393 score_total +=
394 score_calc.calc(&target.nearest_match_dist_dot(q_pt, &q_ta.tangent, alpha));
395 }
396 score_total
397 }
398
399 fn self_hit(&self, score_calc: &ScoreCalc, use_alpha: bool) -> Precision {
400 if use_alpha {
401 self.tangents_alphas
402 .iter()
403 .map(|ta| {
404 score_calc.calc(&DistDot {
405 dist: 0.0,
406 dot: ta.alpha,
407 })
408 })
409 .fold(0.0, |total, s| total + s)
410 } else {
411 score_calc.calc(&DistDot {
412 dist: 0.0,
413 dot: 1.0,
414 }) * self.len() as Precision
415 }
416 }
417}
418
419fn find_bin_binary(value: Precision, upper_bounds: &[Precision]) -> usize {
426 let raw = match upper_bounds.binary_search_by(|bound| bound.partial_cmp(&value).unwrap()) {
427 Ok(v) => v + 1,
428 Err(v) => v,
429 };
430 let highest = upper_bounds.len() - 1;
431 if raw > highest {
432 highest
433 } else {
434 raw
435 }
436}
437
438pub fn table_to_fn(
451 dist_thresholds: Vec<Precision>,
452 dot_thresholds: Vec<Precision>,
453 cells: Vec<Precision>,
454) -> impl Fn(&DistDot) -> Precision {
455 if dist_thresholds.len() * dot_thresholds.len() != cells.len() {
456 panic!("Number of cells in table do not match number of columns/rows");
457 }
458
459 move |dd: &DistDot| -> Precision { cells[dd.to_linear_idx(&dist_thresholds, &dot_thresholds)] }
460}
461
462pub fn range_table_to_fn(
463 range_table: RangeTable<Precision, Precision>,
464) -> impl Fn(&DistDot) -> Precision {
465 move |dd: &DistDot| -> Precision { *range_table.lookup(&[dd.dist, dd.dot]) }
466}
467
468trait Location {
469 fn location(&self) -> &Point3;
470
471 fn distance2_to<T: Location>(&self, other: T) -> Precision {
472 self.location()
473 .iter()
474 .zip(other.location().iter())
475 .map(|(a, b)| a * a + b * b)
476 .sum()
477 }
478
479 fn distance_to<T: Location>(&self, other: T) -> Precision {
480 self.distance2_to(other).sqrt()
481 }
482}
483
484impl Location for Point3 {
485 fn location(&self) -> &Point3 {
486 self
487 }
488}
489
490impl Location for &Point3 {
491 fn location(&self) -> &Point3 {
492 self
493 }
494}
495
496#[derive(Clone)]
497struct NeuronSelfHit<N: QueryNeuron> {
498 neuron: N,
499 self_hit: Precision,
500 centroid: [Precision; 3],
501}
502
503impl<N: QueryNeuron> NeuronSelfHit<N> {
504 fn new(neuron: N, self_hit: Precision) -> Self {
505 let centroid = neuron.centroid();
506 Self {
507 neuron,
508 self_hit,
509 centroid,
510 }
511 }
512
513 fn score(&self) -> Precision {
514 self.self_hit
515 }
516}
517
518#[derive(Debug, Clone)]
520pub enum ScoreCalc {
521 Table(RangeTable<Precision, Precision>),
523}
524
525impl ScoreCalc {
526 pub fn table_from_bins(
535 dists: Vec<Precision>,
536 dots: Vec<Precision>,
537 values: Vec<Precision>,
538 ) -> Result<Self, InvalidRangeTable> {
539 Ok(Self::Table(RangeTable::new_from_bins(
540 vec![dists, dots],
541 values,
542 )?))
543 }
544
545 pub fn calc(&self, dist_dot: &DistDot) -> Precision {
547 match self {
548 Self::Table(tab) => *tab.lookup(&[dist_dot.dist, dist_dot.dot]),
550 }
551 }
552}
553
554#[allow(dead_code)]
556pub struct NblastArena<N>
557where
558 N: TargetNeuron,
559{
560 neurons_scores: Vec<NeuronSelfHit<N>>,
561 score_calc: ScoreCalc,
562 use_alpha: bool,
563 threads: Option<usize>,
565}
566
567pub type NeuronIdx = usize;
568
569impl<N> NblastArena<N>
570where
571 N: TargetNeuron + Sync,
572{
573 pub fn new(score_calc: ScoreCalc, use_alpha: bool) -> Self {
576 Self {
577 neurons_scores: Vec::default(),
578 score_calc,
579 use_alpha,
580 threads: None,
581 }
582 }
583
584 #[cfg(feature = "parallel")]
586 pub fn with_threads(self, threads: usize) -> Self {
587 Self {
588 neurons_scores: self.neurons_scores,
589 score_calc: self.score_calc,
590 use_alpha: self.use_alpha,
591 threads: Some(threads),
592 }
593 }
594
595 pub fn size_of(&self, idx: NeuronIdx) -> Option<usize> {
596 self.neurons_scores.get(idx).map(|n| n.neuron.len())
597 }
598
599 fn next_id(&self) -> NeuronIdx {
600 self.neurons_scores.len()
601 }
602
603 pub fn add_neuron(&mut self, neuron: N) -> NeuronIdx {
605 let idx = self.next_id();
606 let self_hit = neuron.self_hit(&self.score_calc, self.use_alpha);
607 self.neurons_scores
608 .push(NeuronSelfHit::new(neuron, self_hit));
609 idx
610 }
611
612 pub fn query_target(
618 &self,
619 query_idx: NeuronIdx,
620 target_idx: NeuronIdx,
621 normalize: bool,
622 symmetry: &Option<Symmetry>,
623 ) -> Option<Precision> {
624 let q = self.neurons_scores.get(query_idx)?;
625
626 if query_idx == target_idx {
627 return if normalize {
628 Some(1.0)
629 } else {
630 Some(q.score())
631 };
632 }
633
634 let t = self.neurons_scores.get(target_idx)?;
635 let mut score = q.neuron.query(&t.neuron, self.use_alpha, &self.score_calc);
636 if normalize {
637 score /= q.score()
638 }
639 match symmetry {
640 Some(s) => {
641 let mut score2 = t.neuron.query(&q.neuron, self.use_alpha, &self.score_calc);
642 if normalize {
643 score2 /= t.score();
644 }
645 Some(s.apply(score, score2))
646 }
647 _ => Some(score),
648 }
649 }
650
651 pub fn queries_targets(
657 &self,
658 query_idxs: &[NeuronIdx],
659 target_idxs: &[NeuronIdx],
660 normalize: bool,
661 symmetry: &Option<Symmetry>,
662 max_centroid_dist: Option<Precision>,
663 ) -> HashMap<(NeuronIdx, NeuronIdx), Precision> {
664 let pairs: Vec<_> = query_idxs
667 .iter()
668 .filter_map(|q| {
669 let q2 = *q;
670 if q2 >= self.len() {
671 None
672 } else {
673 Some(target_idxs.iter().filter_map(move |t| {
674 if t >= &self.len() {
675 None
676 } else {
677 Some((q2, *t))
678 }
679 }))
680 }
681 })
682 .flatten()
683 .collect();
684
685 self.query_target_pairs(&pairs, normalize, symmetry, max_centroid_dist)
686 }
687
688 pub fn query_target_pairs(
692 &self,
693 query_target_idxs: &[(NeuronIdx, NeuronIdx)],
694 normalize: bool,
695 symmetry: &Option<Symmetry>,
696 max_centroid_dist: Option<Precision>,
697 ) -> HashMap<(NeuronIdx, NeuronIdx), Precision> {
698 let mut max_jobs = query_target_idxs.len();
699
700 let mut out = HashMap::with_capacity(query_target_idxs.len());
701 if symmetry.is_some() {
702 max_jobs *= 2;
703 }
704 let mut jobs = HashSet::with_capacity(max_jobs);
705 for (q, t) in query_target_idxs {
706 if q > &self.len() || t > &self.len() {
707 continue;
708 }
709
710 let key = (*q, *t);
711
712 if q == t {
713 out.insert(
714 key,
715 if normalize {
716 1.0
717 } else {
718 self.neurons_scores[*q].score()
719 },
720 );
721 continue;
722 } else {
723 out.insert(key, Precision::NAN);
724 }
725
726 if jobs.contains(&(*q, *t)) {
727 continue;
728 }
729
730 if let Some(d) = max_centroid_dist {
731 if !self
732 .centroids_within_distance(*q, *t, d)
733 .expect("Already checked indices")
734 {
735 continue;
736 }
737 }
738
739 jobs.insert(key);
740 if symmetry.is_some() {
741 jobs.insert((key.1, key.0));
742 }
743 }
744
745 let raw = pairs_to_raw(self, &jobs.into_iter().collect::<Vec<_>>(), normalize);
746
747 for (key, value) in out.iter_mut() {
748 if let Some(forward) = raw.get(key) {
752 if let Some(s) = symmetry {
753 let backward = raw[&(key.1, key.0)];
756 *value = s.apply(*forward, backward);
759 } else {
760 *value = *forward;
761 }
762 }
763 }
764
765 out
766 }
767
768 pub fn centroids_within_distance(
769 &self,
770 query_idx: NeuronIdx,
771 target_idx: NeuronIdx,
772 max_centroid_dist: Precision,
773 ) -> Option<bool> {
774 if query_idx == target_idx {
775 return Some(true);
776 }
777 self.neurons_scores.get(query_idx).and_then(|q| {
778 self.neurons_scores
779 .get(target_idx)
780 .map(|t| q.centroid.distance_to(t.centroid) < max_centroid_dist)
781 })
782 }
783
784 pub fn self_hit(&self, idx: NeuronIdx) -> Option<Precision> {
785 self.neurons_scores.get(idx).map(|n| n.score())
786 }
787
788 pub fn all_v_all(
791 &self,
792 normalize: bool,
793 symmetry: &Option<Symmetry>,
794 max_centroid_dist: Option<Precision>,
795 ) -> HashMap<(NeuronIdx, NeuronIdx), Precision> {
796 let idxs: Vec<NeuronIdx> = (0..self.len()).collect();
797 self.queries_targets(&idxs, &idxs, normalize, symmetry, max_centroid_dist)
798 }
799
800 pub fn is_empty(&self) -> bool {
801 self.neurons_scores.is_empty()
802 }
803
804 pub fn len(&self) -> usize {
806 self.neurons_scores.len()
807 }
808
809 pub fn points(&self, idx: NeuronIdx) -> Option<impl Iterator<Item = Point3> + '_> {
810 self.neurons_scores.get(idx).map(|n| n.neuron.points())
811 }
812
813 pub fn tangents(&self, idx: NeuronIdx) -> Option<impl Iterator<Item = Normal3> + '_> {
814 self.neurons_scores.get(idx).map(|n| n.neuron.tangents())
815 }
816
817 pub fn alphas(&self, idx: NeuronIdx) -> Option<impl Iterator<Item = Precision> + '_> {
818 self.neurons_scores.get(idx).map(|n| n.neuron.alphas())
819 }
820}
821
822fn pairs_to_raw_serial<N: TargetNeuron + Sync>(
823 arena: &NblastArena<N>,
824 pairs: &[(NeuronIdx, NeuronIdx)],
825 normalize: bool,
826) -> HashMap<(NeuronIdx, NeuronIdx), Precision> {
827 pairs
828 .iter()
829 .filter_map(|(q_idx, t_idx)| {
830 arena
831 .query_target(*q_idx, *t_idx, normalize, &None)
832 .map(|s| ((*q_idx, *t_idx), s))
833 })
834 .collect()
835}
836
837#[cfg(not(feature = "parallel"))]
838fn pairs_to_raw<N>(
839 arena: &NblastArena<N>,
840 pairs: &[(NeuronIdx, NeuronIdx)],
841 normalize: bool,
842) -> HashMap<(NeuronIdx, NeuronIdx), Precision>
843where
844 N: TargetNeuron + Sync,
845{
846 pairs_to_raw_serial(arena, pairs, normalize)
847}
848
/// Score each `(query, target)` pair.
///
/// If the arena has a thread count configured, the pairs are scored on a freshly built
/// rayon pool of that size. Otherwise they are scored on the calling thread.
///
/// # Panics
/// If the rayon thread pool cannot be built.
#[cfg(feature = "parallel")]
fn pairs_to_raw<N>(
    arena: &NblastArena<N>,
    pairs: &[(NeuronIdx, NeuronIdx)],
    normalize: bool,
) -> HashMap<(NeuronIdx, NeuronIdx), Precision>
where
    N: TargetNeuron + Sync,
{
    let Some(n_threads) = arena.threads else {
        return pairs_to_raw_serial(arena, pairs, normalize);
    };

    let pool = rayon::ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .build()
        .unwrap();
    let score_pair = |&(q_idx, t_idx): &(NeuronIdx, NeuronIdx)| {
        arena
            .query_target(q_idx, t_idx, normalize, &None)
            .map(|score| ((q_idx, t_idx), score))
    };
    pool.install(|| pairs.par_iter().filter_map(score_pair).collect())
}
874
#[cfg(test)]
mod test {
    use super::*;

    /// Tolerance for floating-point comparisons in these tests.
    const EPSILON: Precision = 0.001;
    /// Neighbourhood size used when constructing test neurons.
    const N_NEIGHBORS: usize = 5;

    /// Component-wise sum of two points.
    fn add_points(a: &Point3, b: &Point3) -> Point3 {
        let mut out = [0., 0., 0.];
        for (idx, (x, y)) in a.iter().zip(b.iter()).enumerate() {
            out[idx] = x + y;
        }
        out
    }

    /// `count` evenly spaced points starting at `offset`, each `step` apart.
    fn make_points(offset: &Point3, step: &Point3, count: usize) -> Vec<Point3> {
        let mut out = Vec::default();
        out.push(*offset);

        for _ in 0..count - 1 {
            let to_push = add_points(out.last().unwrap(), step);
            out.push(to_push);
        }

        out
    }

    #[test]
    fn construct() {
        let points = make_points(&[0., 0., 0.], &[1., 0., 0.], 10);
        Neuron::new(points, N_NEIGHBORS).unwrap();
    }

    fn is_close(val1: Precision, val2: Precision) -> bool {
        println!("Comparing values:\n\tval1: {:?}\n\tval2: {:?}", val1, val2);
        (val1 - val2).abs() < EPSILON
    }

    fn assert_close(val1: Precision, val2: Precision) {
        if !is_close(val1, val2) {
            panic!("Not close:\n\t{:?}\n\t{:?}", val1, val2);
        }
    }

    #[test]
    fn unit_tangents_eig() {
        let (points, _, _) = tangent_data();
        let tangent = TangentAlpha::new_from_points(points.iter()).tangent;
        assert_close(tangent.dot(&tangent), 1.0)
    }

    /// Tangents are equivalent if they are parallel or anti-parallel.
    fn equivalent_tangents(tan1: &Normal3, tan2: &Normal3) -> bool {
        is_close(tan1.dot(tan2).abs(), 1.0)
    }

    /// A small point cloud with its expected tangent and alpha.
    fn tangent_data() -> (Vec<Point3>, Normal3, Precision) {
        let tangent = Unit::new_normalize(Vector3::from_column_slice(&[
            -0.939_392_2,
            0.313_061_82,
            0.139_766_18,
        ]));

        let points = vec![
            [
                329.679_962_158_203,
                72.718_803_405_761_7,
                31.028_469_085_693_4,
            ],
            [
                328.647_399_902_344,
                73.046_119_689_941_4,
                31.537_061_691_284_2,
            ],
            [
                335.219_879_150_391,
                70.710_479_736_328_1,
                30.398_145_675_659_2,
            ],
            [
                332.611_389_160_156,
                72.322_929_382_324_2,
                30.887_334_823_608_4,
            ],
            [
                331.770_782_470_703,
                72.434_440_612_793,
                31.169_372_558_593_8,
            ],
        ];

        let alpha = 0.844_842_871_450_449;

        (points, tangent, alpha)
    }

    #[test]
    fn test_tangent_eig() {
        let (points, exp_tan, _exp_alpha) = tangent_data();
        let ta = TangentAlpha::new_from_points(points.iter());
        if !equivalent_tangents(&ta.tangent, &exp_tan) {
            panic!(
                "Non-equivalent tangents:\n\t{:?}\n\t{:?}",
                ta.tangent, exp_tan
            )
        }
    }

    #[test]
    fn test_neuron() {
        let (points, exp_tan, _exp_alpha) = tangent_data();
        let tgt = Neuron::new(points, N_NEIGHBORS).unwrap();
        assert!(equivalent_tangents(
            &tgt.tangents().next().unwrap(),
            &exp_tan
        ));
    }

    /// Thresholds and sequentially numbered cells for a 5 x 10 score table.
    fn score_mat() -> (Vec<Precision>, Vec<Precision>, Vec<Precision>) {
        let dists = vec![10.0, 20.0, 30.0, 40.0, 50.0];
        let dots = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0];
        let mut values = vec![];
        let n_values = dots.len() * dists.len();
        for v in 0..n_values {
            values.push(v as Precision);
        }
        (dists, dots, values)
    }

    #[test]
    fn test_score_calc() {
        let (dists, dots, values) = score_mat();
        let func = table_to_fn(dists, dots, values);
        assert_close(
            func(&DistDot {
                dist: 0.0,
                dot: 0.0,
            }),
            0.0,
        );
        assert_close(
            func(&DistDot {
                dist: 0.0,
                dot: 0.1,
            }),
            1.0,
        );
        assert_close(
            func(&DistDot {
                dist: 11.0,
                dot: 0.0,
            }),
            10.0,
        );
        assert_close(
            func(&DistDot {
                dist: 55.0,
                dot: 0.0,
            }),
            40.0,
        );
        assert_close(
            func(&DistDot {
                dist: 55.0,
                dot: 10.0,
            }),
            49.0,
        );
        assert_close(
            func(&DistDot {
                dist: 15.0,
                dot: 0.15,
            }),
            11.0,
        );
    }

    #[test]
    fn test_find_bin_binary() {
        let dots = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0];
        assert_eq!(find_bin_binary(0.0, &dots), 0);
        assert_eq!(find_bin_binary(0.15, &dots), 1);
        assert_eq!(find_bin_binary(0.95, &dots), 9);
        assert_eq!(find_bin_binary(-10.0, &dots), 0);
        assert_eq!(find_bin_binary(10.0, &dots), 9);
        assert_eq!(find_bin_binary(0.1, &dots), 1);
    }

    #[test]
    fn arena() {
        let dist_thresholds = vec![0.0, 1.0, 2.0];
        let dot_thresholds = vec![0.0, 0.5, 1.0];
        let cells = vec![1.0, 2.0, 4.0, 8.0];

        let score_calc = ScoreCalc::Table(
            RangeTable::new_from_bins(vec![dist_thresholds, dot_thresholds], cells).unwrap(),
        );

        let query =
            Neuron::new(make_points(&[0., 0., 0.], &[1., 0., 0.], 10), N_NEIGHBORS).unwrap();
        let target =
            Neuron::new(make_points(&[0.5, 0., 0.], &[1.1, 0., 0.], 10), N_NEIGHBORS).unwrap();

        let mut arena = NblastArena::new(score_calc, false);
        let q_idx = arena.add_neuron(query);
        let t_idx = arena.add_neuron(target);

        let no_norm = arena
            .query_target(q_idx, t_idx, false, &None)
            .expect("should exist");
        let self_hit = arena
            .query_target(q_idx, q_idx, false, &None)
            .expect("should exist");
        let normalized = arena
            .query_target(q_idx, t_idx, true, &None)
            .expect("should exist");

        // Compare in both directions: a too-small normalised score must fail as well.
        assert_close(normalized, no_norm / self_hit);
        assert_eq!(
            arena.query_target(q_idx, t_idx, false, &Some(Symmetry::ArithmeticMean)),
            arena.query_target(t_idx, q_idx, false, &Some(Symmetry::ArithmeticMean)),
        );

        let out = arena.queries_targets(&[q_idx, t_idx], &[t_idx, q_idx], false, &None, None);
        assert_eq!(out.len(), 4);
    }

    fn test_symmetry(symmetry: &Symmetry, a: Precision, b: Precision) {
        assert_close(symmetry.apply(a, b), symmetry.apply(b, a))
    }

    fn test_symmetry_multiple(symmetry: &Symmetry) {
        for (a, b) in [(0.3, 0.7), (0.0, 0.7), (-1.0, 0.7), (100.0, 1000.0)] {
            test_symmetry(symmetry, a, b);
        }
    }

    #[test]
    fn symmetry_arithmetic() {
        test_symmetry_multiple(&Symmetry::ArithmeticMean)
    }

    #[test]
    fn symmetry_harmonic() {
        test_symmetry_multiple(&Symmetry::HarmonicMean)
    }

    #[test]
    fn symmetry_geometric() {
        test_symmetry_multiple(&Symmetry::GeometricMean)
    }

    #[test]
    fn symmetry_min() {
        test_symmetry_multiple(&Symmetry::Min)
    }

    #[test]
    fn symmetry_max() {
        test_symmetry_multiple(&Symmetry::Max)
    }
}