use std::collections::HashSet;
use std::ops::BitXor;

use fastrand::Rng;
// NOTE: gated on "parallel" to match the rest of this file — all rayon
// consumers (`all_distdots_par`, `set_threads`, `_all_distdots`) use
// `#[cfg(feature = "parallel")]`; gating the import on a separate "rayon"
// feature breaks compilation when only "parallel" is enabled.
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use thiserror::Error;

use crate::{BinLookup, NdBinLookup, RangeTable};
use crate::{DistDot, Precision, TargetNeuron};
const EPSILON: Precision = 1e-6;
type JobSet = HashSet<(usize, usize)>;
/// Errors raised when a score matrix cannot be built from the
/// configuration given to [`ScoreMatrixBuilder`].
#[derive(Debug, Error)]
pub enum ScoreMatBuildErr {
    /// No matching sets were registered before building.
    #[error("No matching sets given")]
    MatchingSets,
    /// Neither an explicit distance lookup nor a distance bin count was set.
    #[error("No distance bin information given")]
    DistBins,
    /// Neither an explicit dot-product lookup nor a dot bin count was set.
    #[error("No dot product bin information given")]
    DotBins,
}
/// Draws ordered pairs of distinct neuron indices from a collection of
/// index sets, where sets are selected with probability proportional to a
/// per-set weight.
struct PairSampler {
    // Each inner Vec is one set of neuron indices (sorted by `from_sets`
    // for determinism).
    sets: Vec<Vec<usize>>,
    // Cumulative normalised per-set weights, ending at ~1.0; a uniform draw
    // in [0, 1) is mapped to a set by binary search over this.
    cumu_weights: Vec<f64>,
    pub rng: Rng,
}
impl PairSampler {
    /// Build a sampler over `sets`, weighting each set by the square of its
    /// size (roughly proportional to the number of ordered pairs in it).
    /// The weights are converted in place into a cumulative distribution.
    fn new(sets: Vec<Vec<usize>>, seed: Option<u64>) -> Self {
        let rng = match seed {
            Some(s) => Rng::with_seed(s),
            None => Rng::new(),
        };
        let mut weights: Vec<_> = sets.iter().map(|v| v.len().pow(2) as f64).collect();
        let total: f64 = weights.iter().sum();
        // Turn the raw weights into cumulative normalised weights (last ~1.0).
        let mut acc = 0.0;
        for w in &mut weights {
            acc += *w / total;
            *w = acc;
        }
        Self {
            sets,
            cumu_weights: weights,
            rng,
        }
    }

    /// Build from hash sets, sorting each set's members so that a given
    /// seed always produces the same sequence of pairs.
    ///
    /// # Panics
    ///
    /// Panics if any set has fewer than 2 members (no pair can be drawn).
    pub fn from_sets(sets: &[HashSet<usize>], seed: Option<u64>) -> Self {
        let vecs = sets
            .iter()
            .map(|s| {
                if s.len() < 2 {
                    panic!("Gave set with <2 components");
                }
                let mut v: Vec<_> = s.iter().cloned().collect();
                v.sort();
                v
            })
            .collect();
        Self::new(vecs, seed)
    }

    /// Pick a set index with probability proportional to its weight.
    fn outer_idx(&mut self) -> usize {
        if self.cumu_weights.len() == 1 {
            return 0;
        }
        let rand = self.rng.f64();
        let idx = match self
            .cumu_weights
            .binary_search_by(|w| w.partial_cmp(&rand).expect("NaN weight"))
        {
            Ok(idx) => idx,
            Err(idx) => idx,
        };
        // FIX: floating-point rounding can leave the final cumulative weight
        // slightly below 1.0; a draw above it would yield idx == len and
        // panic in `sample`. Clamp to the last valid set index.
        idx.min(self.cumu_weights.len() - 1)
    }

    /// Draw one ordered pair of distinct *neuron indices* from a
    /// weighted-randomly chosen set.
    pub fn sample(&mut self) -> (usize, usize) {
        let set_idx = self.outer_idx();
        let set = &self.sets[set_idx];
        // Rejection-sample two distinct positions within the set
        // (`from_sets` guarantees len >= 2, so this terminates).
        let first = self.rng.usize(0..set.len());
        let mut second = first;
        while second == first {
            second = self.rng.usize(0..set.len());
        }
        // FIX: return the neuron indices stored in the set, not the positions
        // within it. `exhaust` and all consumers (e.g. `all_distdots`, which
        // does `neurons[*q]`) work in neuron-index space; previously the two
        // branches of `sample_n` produced pairs in different index spaces.
        (set[first], set[second])
    }

    /// Collect up to `n` distinct pairs.
    ///
    /// Returns `Err` carrying the partial result when `n` pairs could not be
    /// gathered: either rejection sampling hit its 10n-draw budget, or fewer
    /// than `n` distinct pairs exist.
    pub fn sample_n(&mut self, n: usize) -> Result<JobSet, JobSet> {
        let mut out = HashSet::default();
        if n * 10 < self.n_pairs() {
            // Sparse request: rejection-sample with a bounded budget.
            let mut count = 0;
            while out.len() < n {
                out.insert(self.sample());
                count += 1;
                if count >= n * 10 {
                    return Err(out);
                }
            }
        } else {
            // Dense request: enumerate every pair, shuffle, and take n.
            let mut v = self.exhaust();
            self.rng.shuffle(&mut v);
            while out.len() < n {
                if let Some(s) = v.pop() {
                    out.insert(s);
                } else {
                    return Err(out);
                }
            }
        }
        Ok(out)
    }

    /// Total number of ordered pairs of distinct indices across all sets.
    pub fn n_pairs(&self) -> usize {
        self.sets
            .iter()
            .map(|v| {
                let len = v.len();
                if len < 2 {
                    0
                } else {
                    len * (len - 1)
                }
            })
            .sum()
    }

    /// Enumerate every ordered pair of distinct neuron indices within each set.
    pub fn exhaust(&self) -> Vec<(usize, usize)> {
        let mut out = Vec::default();
        for set in self.sets.iter() {
            for q_idx in set.iter() {
                for t_idx in set.iter() {
                    if t_idx != q_idx {
                        out.push((*q_idx, *t_idx));
                    }
                }
            }
        }
        out
    }
}
/// Construct an RNG, seeded for reproducibility when a seed is supplied.
fn make_rng(seed: Option<u64>) -> Rng {
    match seed {
        Some(s) => Rng::with_seed(s),
        None => Rng::new(),
    }
}
/// Selects the matching and non-matching neuron index pairs used to train
/// a score matrix.
pub struct TrainingSampler {
    // Base RNG seed; None means unseeded (non-reproducible sampling).
    seed: Option<u64>,
    // Sets of neuron indices known to be mutual matches.
    matching_sets: Vec<HashSet<usize>>,
    // Explicit non-matching sets; None means "treat all neurons as one
    // non-matching pool" (see `make_jobs`).
    nonmatching_sets: Option<Vec<HashSet<usize>>>,
    // Total neuron count; all stored indices are < this.
    n_neurons: usize,
}
impl TrainingSampler {
pub fn new(n_neurons: usize, seed: Option<u64>) -> Self {
Self {
seed,
matching_sets: Vec::default(),
nonmatching_sets: None,
n_neurons,
}
}
pub fn add_matching_set(&mut self, matching: &[usize]) -> &mut Self {
let set: HashSet<usize> = matching
.iter()
.filter_map(|idx| {
if idx < &self.n_neurons {
Some(*idx)
} else {
None
}
})
.collect();
if set.len() >= 2 {
self.matching_sets.push(set);
}
self
}
#[allow(clippy::vec_init_then_push)]
pub fn add_nonmatching_set(&mut self, nonmatching: &[usize]) -> &mut Self {
let set: HashSet<usize> = nonmatching
.iter()
.filter_map(|idx| {
if idx < &self.n_neurons {
Some(*idx)
} else {
None
}
})
.collect();
if set.len() < 2 {
return self;
}
if let Some(ref mut vector) = self.nonmatching_sets {
vector.push(set);
} else {
let mut v = Vec::default();
v.push(set);
self.nonmatching_sets = Some(v);
}
self
}
pub fn make_jobs(
&self,
n_matching: Option<usize>,
n_nonmatching: Option<usize>,
) -> (JobSet, JobSet) {
let seed = self.seed.map(|s1| {
let mut s = s1;
if let Some(s2) = n_matching {
s = s.bitxor(s2 as u64);
}
if let Some(s3) = n_nonmatching {
s = s.bitxor(s3 as u64);
}
s
});
let rng = make_rng(seed);
let mut match_sampler =
PairSampler::from_sets(&self.matching_sets, Some(rng.u64(0..u64::MAX)));
let matching_jobs: JobSet = match n_matching {
Some(n) => match match_sampler.sample_n(n) {
Ok(jobs) => jobs,
Err(jobs) => jobs,
},
None => match_sampler.exhaust().into_iter().collect(),
};
let nonmatch_seed = Some(rng.u64(0..u64::MAX));
let mut nonmatch_sampler = self.nonmatching_sets.as_ref().map_or_else(
|| {
let v = vec![(0..self.n_neurons).collect()];
PairSampler::from_sets(&v, nonmatch_seed)
},
|s| PairSampler::from_sets(s, nonmatch_seed),
);
let n_nm = n_nonmatching.unwrap_or(matching_jobs.len());
if n_nm > nonmatch_sampler.n_pairs() {
panic!("Not enough non-matching neurons")
}
let nonmatching_jobs = nonmatch_sampler.sample_n(n_nm).unwrap_or_else(|s| s);
(matching_jobs, nonmatching_jobs)
}
}
/// Serially collect the dist-dot pairs for every (query, target) job,
/// indexing `neurons` with the job's neuron indices.
fn all_distdots<T: TargetNeuron>(
    neurons: &[T],
    jobs: &[(usize, usize)],
    use_alpha: bool,
) -> Vec<DistDot> {
    let mut out = Vec::new();
    for &(q, t) in jobs {
        out.extend(neurons[q].query_dist_dots(&neurons[t], use_alpha));
    }
    out
}
/// Parallel equivalent of `all_distdots`: jobs are distributed across the
/// current rayon thread pool.
#[cfg(feature = "parallel")]
pub fn all_distdots_par<T: TargetNeuron + Sync>(
    neurons: &[T],
    jobs: &[(usize, usize)],
    use_alpha: bool,
) -> Vec<DistDot> {
    jobs.par_iter()
        // `flat_map_iter` consumes each job's serial iterator on the worker
        // thread that produced it, avoiding the per-job intermediate Vec
        // that `map(...collect()).flatten()` allocated.
        .flat_map_iter(|(q, t)| neurons[*q].query_dist_dots(&neurons[*t], use_alpha))
        .collect()
}
/// How one axis of the bin lookup is configured: either supplied directly,
/// or built later as `n` quantile bins over the matching data.
#[derive(Clone, Debug)]
enum LookupArgs {
    // Use this lookup as-is.
    Lookup(BinLookup<Precision>),
    // Build a lookup with this many quantile bins (see `_get_lookup`).
    NBins(usize),
}
/// Builds a `RangeTable` of log2 odds ratios over (distance, dot product)
/// bins by comparing matching and non-matching neuron pairs.
#[allow(dead_code)]
pub struct ScoreMatrixBuilder<T: TargetNeuron> {
    neurons: Vec<T>,
    // Chooses which neuron pairs to compare; seeded in `new`.
    sampler: TrainingSampler,
    // Passed through to `query_dist_dots`; defaults to false.
    use_alpha: bool,
    // Explicit rayon thread count; None means the serial code path
    // (see `_all_distdots`).
    threads: Option<usize>,
    // Distance-axis bin configuration; must be set before `build`.
    dist_bin_lookup: Option<LookupArgs>,
    // Dot-product-axis bin configuration; must be set before `build`.
    dot_bin_lookup: Option<LookupArgs>,
    // Cap on sampled matching pairs; None means use all.
    max_matching_pairs: Option<usize>,
    // Cap on sampled non-matching pairs; None means match the matching count.
    max_nonmatching_pairs: Option<usize>,
}
impl<T: TargetNeuron + Sync> ScoreMatrixBuilder<T> {
/// Create a builder over `neurons` with a fixed RNG seed.
///
/// At least one matching set and both bin configurations (distance and
/// dot product) must be supplied before `build` can succeed.
pub fn new(neurons: Vec<T>, seed: u64) -> Self {
    let n_neurons = neurons.len();
    Self {
        neurons,
        sampler: TrainingSampler::new(n_neurons, Some(seed)),
        use_alpha: false,
        threads: None,
        dist_bin_lookup: None,
        dot_bin_lookup: None,
        max_matching_pairs: None,
        max_nonmatching_pairs: None,
    }
}
/// Register a set of mutually matching neuron indices
/// (delegates to `TrainingSampler::add_matching_set`).
pub fn add_matching_set(&mut self, matching: &[usize]) -> &mut Self {
    self.sampler.add_matching_set(matching);
    self
}
/// Register a set of mutually non-matching neuron indices
/// (delegates to `TrainingSampler::add_nonmatching_set`).
pub fn add_nonmatching_set(&mut self, nonmatching: &[usize]) -> &mut Self {
    self.sampler.add_nonmatching_set(nonmatching);
    self
}
/// Set whether dist-dot queries use alpha weighting (default false).
pub fn set_use_alpha(&mut self, use_alpha: bool) -> &mut Self {
    self.use_alpha = use_alpha;
    self
}
/// Set an explicit thread count for the parallel dist-dot computation;
/// `None` selects the serial code path.
#[cfg(feature = "parallel")]
pub fn set_threads(&mut self, threads: Option<usize>) -> &mut Self {
    self.threads = threads;
    self
}
/// Use an explicit distance-axis lookup.
/// Overrides any earlier `set_n_dist_bins` (later call wins).
pub fn set_dist_lookup(&mut self, lookup: BinLookup<Precision>) -> &mut Self {
    self.dist_bin_lookup = Some(LookupArgs::Lookup(lookup));
    self
}
/// Derive the distance axis as `n_bins` quantile bins of the matching data.
/// Overrides any earlier `set_dist_lookup` (later call wins).
pub fn set_n_dist_bins(&mut self, n_bins: usize) -> &mut Self {
    self.dist_bin_lookup = Some(LookupArgs::NBins(n_bins));
    self
}
/// Use an explicit dot-product-axis lookup.
/// Overrides any earlier `set_n_dot_bins` (later call wins).
pub fn set_dot_lookup(&mut self, lookup: BinLookup<Precision>) -> &mut Self {
    self.dot_bin_lookup = Some(LookupArgs::Lookup(lookup));
    self
}
/// Derive the dot-product axis as `n_bins` quantile bins of the matching data.
/// Overrides any earlier `set_dot_lookup` (later call wins).
pub fn set_n_dot_bins(&mut self, n_bins: usize) -> &mut Self {
    self.dot_bin_lookup = Some(LookupArgs::NBins(n_bins));
    self
}
/// Cap the number of matching pairs sampled by `build`.
pub fn set_max_matching_pairs(&mut self, n_pairs: usize) -> &mut Self {
    self.max_matching_pairs = Some(n_pairs);
    self
}
/// Cap the number of non-matching pairs sampled by `build`.
pub fn set_max_nonmatching_pairs(&mut self, n_pairs: usize) -> &mut Self {
    self.max_nonmatching_pairs = Some(n_pairs);
    self
}
/// Fetch the configured (distance, dot) bin arguments, reporting which
/// axis is missing if either was never set.
fn _get_lookup_args(&self) -> Result<(LookupArgs, LookupArgs), ScoreMatBuildErr> {
    let dist_lookup_args = self
        .dist_bin_lookup
        .clone()
        .ok_or(ScoreMatBuildErr::DistBins)?;
    let dot_lookup_args = self
        .dot_bin_lookup
        .clone()
        .ok_or(ScoreMatBuildErr::DotBins)?;
    Ok((dist_lookup_args, dot_lookup_args))
}
/// Resolve both axes into a concrete 2-D bin lookup, deriving quantile
/// bins from the matching dist-dots where a bin count was given.
fn _get_lookup(
    &self,
    match_distdots: &[DistDot],
) -> Result<NdBinLookup<Precision>, ScoreMatBuildErr> {
    let (dist_lookup_args, dot_lookup_args) = self._get_lookup_args()?;
    // Distance axis: explicit lookup, or quantiles of the matching distances.
    let dist_bin_lookup = match dist_lookup_args {
        LookupArgs::Lookup(lookup) => lookup,
        LookupArgs::NBins(n) => {
            let mut dists: Vec<_> = match_distdots.iter().map(|dd| dd.dist).collect();
            dists.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
            BinLookup::new_n_quantiles(&dists, n, (true, true))
                .map_err(|_| ScoreMatBuildErr::DistBins)?
        }
    };
    // Dot-product axis: explicit lookup, or quantiles of the matching dots.
    let dot_bin_lookup = match dot_lookup_args {
        LookupArgs::Lookup(lookup) => lookup,
        LookupArgs::NBins(n) => {
            let mut dots: Vec<_> = match_distdots.iter().map(|dd| dd.dot).collect();
            dots.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
            BinLookup::new_n_quantiles(&dots, n, (true, true))
                .map_err(|_| ScoreMatBuildErr::DotBins)?
        }
    };
    Ok(NdBinLookup::new(vec![dist_bin_lookup, dot_bin_lookup]))
}
/// Sample pairs, bin their dist-dots, and produce the log-odds score table.
///
/// # Errors
///
/// Fails if no matching sets were added, or if either bin axis was never
/// configured (or its quantile derivation fails).
pub fn build(&self) -> Result<RangeTable<Precision, Precision>, ScoreMatBuildErr> {
    if self.sampler.matching_sets.is_empty() {
        return Err(ScoreMatBuildErr::MatchingSets);
    }
    let (match_jobs, nonmatch_jobs) = self
        .sampler
        .make_jobs(self.max_matching_pairs, self.max_nonmatching_pairs);
    let match_jobs: Vec<_> = match_jobs.into_iter().collect();
    let nonmatch_jobs: Vec<_> = nonmatch_jobs.into_iter().collect();
    // The bin edges may be derived from the matching data, so compute
    // those dist-dots first.
    let match_distdots = self._all_distdots(&match_jobs);
    let dist_dot_lookup = self._get_lookup(&match_distdots)?;
    let match_counts = cell_counts(&dist_dot_lookup, match_distdots);
    let nonmatch_distdots = self._all_distdots(&nonmatch_jobs);
    let nonmatch_counts = cell_counts(&dist_dot_lookup, nonmatch_distdots);
    Ok(RangeTable {
        bins_lookup: dist_dot_lookup,
        cells: log_odds_ratio(match_counts, nonmatch_counts),
    })
}
/// Compute dist-dots for the given jobs (serial build: no "parallel" feature).
#[cfg(not(feature = "parallel"))]
fn _all_distdots(&self, jobs: &[(usize, usize)]) -> Vec<DistDot> {
    all_distdots(&self.neurons, jobs, self.use_alpha)
}
/// Compute dist-dots for the given jobs: in a dedicated rayon pool when an
/// explicit thread count was set, serially otherwise.
#[cfg(feature = "parallel")]
fn _all_distdots(&self, jobs: &[(usize, usize)]) -> Vec<DistDot> {
    match self.threads {
        Some(t) => {
            let pool = rayon::ThreadPoolBuilder::new()
                .num_threads(t)
                .build()
                .unwrap();
            pool.install(|| all_distdots_par(&self.neurons, jobs, self.use_alpha))
        }
        None => all_distdots(&self.neurons, jobs, self.use_alpha),
    }
}
}
/// Histogram the (dist, dot) pairs over the lookup's cells.
/// Pairs falling outside every bin are silently dropped.
fn cell_counts(lookup: &NdBinLookup<Precision>, distdots: Vec<DistDot>) -> Vec<Precision> {
    let mut counts = vec![0.0; lookup.n_cells];
    let indices = distdots
        .into_iter()
        .filter_map(|dd| lookup.to_linear_idx(&[dd.dist, dd.dot]).ok());
    for idx in indices {
        counts[idx] += 1.0;
    }
    counts
}
/// Per-cell log2 odds ratio of the two (unnormalised) histograms.
/// Each histogram is normalised to a probability distribution first;
/// EPSILON keeps empty cells away from log(0) and division by zero.
fn log_odds_ratio(match_counts: Vec<Precision>, nonmatch_counts: Vec<Precision>) -> Vec<Precision> {
    let match_total: Precision = match_counts.iter().sum();
    let nonmatch_total: Precision = nonmatch_counts.iter().sum();
    let mut cells = Vec::with_capacity(match_counts.len());
    for (match_count, nonmatch_count) in match_counts.into_iter().zip(nonmatch_counts) {
        let p_match = match_count / match_total;
        let p_nonmatch = nonmatch_count / nonmatch_total;
        cells.push(((p_match + EPSILON) / (p_nonmatch + EPSILON)).log2());
    }
    cells
}