use fastrand;
#[cfg(feature = "rayon")]
use rayon::prelude::*;
use std::collections::HashSet;
use std::error;
use std::fmt;
use std::iter;
#[cfg(feature = "rayon")]
use std::sync::mpsc::channel;
use crate::{BinLookup, NdBinLookup, RangeTable};
use crate::{DistDot, Precision, TargetNeuron};
/// Pseudocount added to both numerator and denominator of the odds ratio in
/// `build`, so that empty bins do not produce `log2(0)` or division by zero.
const EPSILON: Precision = 1e-6;
/// A set of (query index, target index) pairs to be compared.
type JobSet = HashSet<(usize, usize)>;
/// Error returned by [`ScoreMatrixBuilder::build`] when the builder is not
/// fully configured: distance/dot bins missing, or no matching sets given.
#[derive(Debug)]
pub struct ScoreMatBuildErr {}
impl fmt::Display for ScoreMatBuildErr {
    /// Human-readable description of why the score matrix could not be built.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("Bins not set or no matching neurons given")
    }
}
impl error::Error for ScoreMatBuildErr {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
None
}
}
/// Return `count` values of `base` raised to exponents evenly spaced over
/// `[min_exp, max_exp]` (both endpoints included).
///
/// # Panics
///
/// Panics if `count < 2`: at least two points are needed to define the
/// exponent step (and `count == 1` would divide by zero).
fn logspace(
    base: Precision,
    min_exp: Precision,
    max_exp: Precision,
    count: usize,
) -> Vec<Precision> {
    // `count >= 2` is sufficient; the previous `count > 2` rejected the
    // valid two-point case ([base^min_exp, base^max_exp]) for no reason.
    assert!(count >= 2);
    let step = (max_exp - min_exp) / (count - 1) as Precision;
    (0..count)
        .map(|idx| base.powf(min_exp + idx as Precision * step))
        .collect()
}
/// Builder which accumulates a pool of neurons plus configuration, and
/// produces a `RangeTable` score matrix from matching vs non-matching
/// point-pair statistics via [`ScoreMatrixBuilder::build`].
pub struct ScoreMatrixBuilder<T: TargetNeuron> {
/// Pool of neurons; every index-based field below refers into this Vec.
neurons: Vec<T>,
/// Seed for reproducible random sampling of non-matching pairs.
seed: u64,
/// Each set holds indices of neurons known to match each other.
matching_sets: Vec<HashSet<usize>>,
/// Indices eligible for non-matching pairs; `None` means all neurons.
nonmatching: Option<Vec<usize>>,
/// Passed through to `query_dist_dots`; presumably weights dot products
/// by alpha values — confirm against `TargetNeuron`'s documentation.
use_alpha: bool,
/// `None` = serial; `Some(0)` = thread pool default; `Some(n)` = n threads.
threads: Option<usize>,
/// Bin boundaries for point-pair distances (set before `build`).
dist_bin_lookup: Option<BinLookup<Precision>>,
/// Bin boundaries for point-pair dot products (set before `build`).
dot_bin_lookup: Option<BinLookup<Precision>>,
}
impl<T: TargetNeuron + Sync> ScoreMatrixBuilder<T> {
    /// Create a builder over the given pool of neurons.
    ///
    /// `seed` makes the random sampling of non-matching pairs deterministic.
    pub fn new(neurons: Vec<T>, seed: u64) -> Self {
        Self {
            neurons,
            seed,
            matching_sets: Vec::default(),
            nonmatching: None,
            use_alpha: false,
            // Some(0) asks the thread pool for its default thread count when
            // the "rayon" feature is enabled; None means "run serially".
            threads: Some(0),
            dist_bin_lookup: None,
            dot_bin_lookup: None,
        }
    }

    /// Register a set of neuron indices (into the pool given to `new`) which
    /// are known matches of each other. May be called repeatedly; at least
    /// one set is required by `build`. Out-of-range indices are skipped.
    pub fn add_matching_set(&mut self, matching: HashSet<usize>) -> &mut Self {
        self.matching_sets.push(matching);
        self
    }

    /// Restrict the indices from which non-matching pairs are sampled.
    /// By default all neurons in the pool are eligible.
    pub fn set_nonmatching(&mut self, nonmatching: Vec<usize>) -> &mut Self {
        self.nonmatching = Some(nonmatching);
        self
    }

    /// Whether to pass `use_alpha` to `query_dist_dots` when collecting
    /// distance/dot statistics (defaults to `false`).
    pub fn set_use_alpha(&mut self, use_alpha: bool) -> &mut Self {
        self.use_alpha = use_alpha;
        self
    }

    /// Thread configuration: `None` runs serially, `Some(0)` lets the pool
    /// choose, `Some(n)` uses `n` threads. Only effective with the "rayon"
    /// feature enabled.
    pub fn set_threads(&mut self, threads: Option<usize>) -> &mut Self {
        self.threads = threads;
        self
    }

    /// Set explicit distance bin boundaries.
    ///
    /// # Panics
    ///
    /// Panics if the boundaries do not form a legal `BinLookup`.
    pub fn set_dist_bins(&mut self, bins: Vec<Precision>) -> &mut Self {
        self.dist_bin_lookup = Some(BinLookup::new(bins, (true, true)).expect("Illegal bins"));
        self
    }

    /// Use `n_bins` logarithmically spaced distance bins: boundaries at
    /// `base^min_exp .. base^max_exp`, plus a final boundary at +infinity
    /// so arbitrarily large distances fall into the last bin.
    pub fn set_n_dist_bins(
        &mut self,
        n_bins: usize,
        base: Precision,
        min_exp: Precision,
        max_exp: Precision,
    ) -> &mut Self {
        let mut v = logspace(base, min_exp, max_exp, n_bins);
        // Clearer than the previous `1.0 / 0.0`; same value.
        v.push(Precision::INFINITY);
        self.set_dist_bins(v)
    }

    /// Set explicit dot-product bin boundaries.
    ///
    /// # Panics
    ///
    /// Panics if the boundaries do not form a legal `BinLookup`.
    pub fn set_dot_bins(&mut self, dot_bin_boundaries: Vec<Precision>) -> &mut Self {
        self.dot_bin_lookup =
            Some(BinLookup::new(dot_bin_boundaries, (true, true)).expect("Illegal bins"));
        self
    }

    /// Use `n_bins` evenly spaced dot-product bins.
    ///
    /// NOTE(review): the boundaries span `[0, n_bins/(n_bins+1)]` rather than
    /// `[0, 1]`; dot products of exactly 1.0 rely on the lookup's
    /// out-of-range snapping (the `(true, true)` argument) — confirm this
    /// binning is intended.
    pub fn set_n_dot_bins(&mut self, n_bins: usize) -> &mut Self {
        let step = 1.0 / (n_bins + 1) as Precision;
        self.set_dot_bins((0..(n_bins + 1)).map(|n| step * n as Precision).collect())
    }

    /// Build the score matrix: a table over (distance, dot) bins whose cells
    /// are log2 odds ratios of matching vs non-matching point pairs.
    ///
    /// # Errors
    ///
    /// Returns `ScoreMatBuildErr` if either set of bins was not configured,
    /// if no matching sets were given, or if no usable matching pairs could
    /// be derived from them.
    pub fn build(&self) -> Result<RangeTable<Precision, Precision>, ScoreMatBuildErr> {
        if self.matching_sets.is_empty() {
            return Err(ScoreMatBuildErr {});
        }
        let dist_bin_lookup = match &self.dist_bin_lookup {
            Some(lookup) => lookup.clone(),
            None => return Err(ScoreMatBuildErr {}),
        };
        let dot_bin_lookup = match &self.dot_bin_lookup {
            Some(lookup) => lookup.clone(),
            None => return Err(ScoreMatBuildErr {}),
        };
        let dist_dot_lookup = NdBinLookup::new(vec![dist_bin_lookup, dot_bin_lookup]);
        let (match_jobs, nonmatch_jobs) = self._match_nonmatch_jobs();
        // Guard against an empty matching job set (e.g. every index in the
        // matching sets was out of range): previously this produced a
        // 0/0 = NaN scaling factor and a NaN-filled table.
        if match_jobs.is_empty() {
            return Err(ScoreMatBuildErr {});
        }
        // Scale matching counts so the two distributions are comparable even
        // though they come from different numbers of pairs.
        let matching_factor = nonmatch_jobs.len() as Precision / match_jobs.len() as Precision;
        let match_counts = self._pairs_to_counts(match_jobs, &dist_dot_lookup);
        let nonmatch_counts = self._pairs_to_counts(nonmatch_jobs, &dist_dot_lookup);
        let cells = match_counts
            .into_iter()
            .zip(nonmatch_counts)
            .map(|(match_count, nonmatch_count)| {
                // EPSILON is a pseudocount: avoids log2(0) and division by
                // zero in bins that one distribution never hits.
                ((match_count as Precision * matching_factor + EPSILON)
                    / (nonmatch_count as Precision + EPSILON))
                    .log2()
            })
            .collect();
        Ok(RangeTable {
            bins_lookup: dist_dot_lookup,
            cells,
        })
    }

    /// Produce the (query, target) index pairs for matching and non-matching
    /// comparisons. Non-matching pairs are sampled randomly (seeded) until
    /// they cover at least as many query points as the matching pairs do.
    ///
    /// # Panics
    ///
    /// Panics if there are more matching pairs than the non-matching pool
    /// can possibly supply.
    fn _match_nonmatch_jobs(&self) -> (JobSet, JobSet) {
        // Total number of query points covered by the matching jobs so far.
        let mut matching_len = 0;
        let mut matching_jobs = JobSet::default();
        for matching_set in self.matching_sets.iter() {
            for q_idx in matching_set.iter() {
                // Silently skip indices outside the neuron pool.
                let q_len = match self.neurons.get(*q_idx) {
                    Some(n) => n.len(),
                    None => continue,
                };
                for t_idx in matching_set.iter() {
                    if t_idx != q_idx
                        && *t_idx < self.neurons.len()
                        && matching_jobs.insert((*q_idx, *t_idx))
                    {
                        matching_len += q_len
                    }
                }
            }
        }
        // Use the explicit non-matching pool if given, otherwise all neurons.
        let nonmatching_idxs = self
            .nonmatching
            .as_ref()
            .cloned()
            .unwrap_or_else(|| (0..self.neurons.len()).collect());
        // saturating_sub avoids a usize underflow panic when the pool is
        // empty (previously `len * (len - 1)` with len == 0 would panic in
        // debug builds before the intended error message was reached).
        let max_nonmatching_pairs =
            nonmatching_idxs.len() * nonmatching_idxs.len().saturating_sub(1);
        if matching_jobs.len() > max_nonmatching_pairs {
            panic!("Not enough non-matching neurons")
        }
        let idx_range = ..nonmatching_idxs.len();
        let rng = fastrand::Rng::new();
        rng.seed(self.seed);
        // NOTE(review): this discarded draw looks like leftover debug code;
        // it is kept so the sampled sequence for a given seed is unchanged.
        rng.usize(idx_range);
        let mut nonmatching_jobs: HashSet<(usize, usize)> = HashSet::default();
        while matching_len > 0 {
            let q_idx = nonmatching_idxs[rng.usize(idx_range)];
            let t_idx = nonmatching_idxs[rng.usize(idx_range)];
            let key = (q_idx, t_idx);
            if q_idx != t_idx && !matching_jobs.contains(&key) && nonmatching_jobs.insert(key) {
                // saturating_sub: the final sampled pair may overshoot the
                // remaining count; a plain `-=` would underflow usize and
                // panic in debug builds.
                matching_len = matching_len.saturating_sub(self.neurons[q_idx].len())
            }
        }
        (matching_jobs, nonmatching_jobs)
    }

    /// All distance/dot pairs from querying neuron `q_idx` against `t_idx`,
    /// or `None` if either index is out of range.
    fn _idx_to_distdots(&self, q_idx: usize, t_idx: usize) -> Option<Vec<DistDot>> {
        let q = self.neurons.get(q_idx)?;
        let t = self.neurons.get(t_idx)?;
        Some(q.query_dist_dots(t, self.use_alpha))
    }

    /// Serially count how many distance/dot pairs from the given jobs fall
    /// into each linear cell of the lookup.
    fn _pairs_to_counts_ser(
        &self,
        jobs: impl IntoIterator<Item = (usize, usize)>,
        dist_dot_lookup: &NdBinLookup<Precision>,
    ) -> Vec<usize> {
        let mut counts: Vec<usize> = iter::repeat(0).take(dist_dot_lookup.n_cells).collect();
        let distdots = jobs
            .into_iter()
            .filter_map(|(q_idx, t_idx)| self._idx_to_distdots(q_idx, t_idx))
            .flatten();
        for dd in distdots {
            let idx = dist_dot_lookup.to_linear_idx(&[dd.dist, dd.dot]).unwrap();
            counts[idx] += 1;
        }
        counts
    }

    /// Count distance/dot pairs per cell (serial fallback).
    // BUG FIX: this was gated on `feature = "parallel"` while the rayon
    // imports at the top of the file are gated on `feature = "rayon"`, so
    // the two could never agree; unified on "rayon" to match the imports.
    // Confirm the feature name against Cargo.toml.
    #[cfg(not(feature = "rayon"))]
    fn _pairs_to_counts(
        &self,
        jobs: HashSet<(usize, usize)>,
        dist_dot_lookup: &NdBinLookup<Precision>,
    ) -> Vec<usize> {
        self._pairs_to_counts_ser(jobs, dist_dot_lookup)
    }

    /// Count distance/dot pairs per cell, in parallel when threads are
    /// configured (`self.threads` is `Some`), otherwise serially.
    // BUG FIX: gated on "rayon" to match the imports; see serial variant.
    #[cfg(feature = "rayon")]
    fn _pairs_to_counts(
        &self,
        jobs: HashSet<(usize, usize)>,
        dist_dot_lookup: &NdBinLookup<Precision>,
    ) -> Vec<usize> {
        if let Some(t) = self.threads {
            // num_threads(0) means "use the rayon default".
            let pool = rayon::ThreadPoolBuilder::new()
                .num_threads(t)
                .build()
                .unwrap();
            pool.install(|| {
                // Workers stream bin indices through a channel; the receiver
                // accumulates them into the counts vector.
                let (sender, receiver) = channel();
                jobs.into_par_iter()
                    .filter_map(|(q_idx, t_idx)| self._idx_to_distdots(q_idx, t_idx))
                    .flatten()
                    .map(|dd| dist_dot_lookup.to_linear_idx(&[dd.dist, dd.dot]).unwrap())
                    .for_each_with(sender, |s, x| s.send(x).unwrap());
                let mut counts: Vec<usize> =
                    iter::repeat(0).take(dist_dot_lookup.n_cells).collect();
                for idx in receiver.iter() {
                    counts[idx] += 1;
                }
                counts
            })
        } else {
            self._pairs_to_counts_ser(jobs, dist_dot_lookup)
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Panic with a diagnostic dump of both slices unless they have equal
    /// length and agree element-wise within machine epsilon.
    fn assert_slice_eq(test: &[Precision], reference: &[Precision]) {
        let msg = format!("\ttest: {:?}\n\t ref: {:?}", test, reference);
        if test.len() != reference.len() {
            panic!("Slices have different length\n{}", msg);
        }
        for (test_val, ref_val) in test.iter().zip(reference.iter()) {
            if (test_val - ref_val).abs() > Precision::EPSILON {
                // Typo fixed in the panic message: "mave" -> "have".
                panic!("Slices have mismatched values\n{}", msg)
            }
        }
    }

    /// `logspace` should return base^e for evenly spaced exponents,
    /// including both endpoints.
    #[test]
    fn test_logspace() {
        let base: Precision = 10.0;
        let count: usize = 5;
        assert_slice_eq(
            &logspace(base, 1.0, 5.0, count),
            &[
                (10f64).powf(1.0),
                (10f64).powf(2.0),
                (10f64).powf(3.0),
                (10f64).powf(4.0),
                (10f64).powf(5.0),
            ],
        );
    }
}