use crate::distance::scale_dist;
use crate::knn::{find_k_nearest_neighbors, find_k_nearest_neighbors_approx, KnnError};
use crate::sampling::{
sample_fp_pair, sample_fp_pair_deterministic, sample_mn_pair, sample_mn_pair_deterministic,
sample_neighbors_pair,
};
use crate::Pairs;
use ndarray::{s, Array1, Array2, ArrayView2, Axis};
use std::cmp::min;
pub fn generate_pairs(
x: ArrayView2<f32>,
n_neighbors: usize,
n_mn: usize,
n_fp: usize,
random_state: Option<u64>,
approx_threshold: usize,
) -> Result<Pairs, KnnError> {
let n = x.nrows();
let n_neighbors_extra = (n_neighbors + 50).min(n - 1);
let n_neighbors = n_neighbors.min(n - 1);
let (neighbors, knn_distances) = if n < approx_threshold {
find_k_nearest_neighbors(x, n_neighbors_extra)
} else {
find_k_nearest_neighbors_approx(x, n_neighbors_extra)?
};
let start = min(3, knn_distances.ncols().saturating_sub(1));
let end = min(6, knn_distances.ncols());
let sig = knn_distances
.slice(s![.., start..end])
.mean_axis(Axis(1))
.map_or_else(|| Array1::from_elem(n, 1e-10), |d| d.mapv(|x| x.max(1e-10)));
let neighbors_view = neighbors.view();
let scaled_dist = scale_dist(knn_distances.view(), sig.view(), neighbors_view);
let pair_neighbors =
sample_neighbors_pair(x.view(), scaled_dist.view(), neighbors_view, n_neighbors);
let (pair_mn, pair_fp) = match random_state {
Some(seed) => (
sample_mn_pair_deterministic(x.view(), n_mn, seed),
sample_fp_pair_deterministic(x.view(), pair_neighbors.view(), n_neighbors, n_fp, seed),
),
None => (
sample_mn_pair(x.view(), n_mn),
sample_fp_pair(x.view(), pair_neighbors.view(), n_neighbors, n_fp),
),
};
Ok(Pairs {
pair_neighbors,
pair_mn,
pair_fp,
})
}
/// Samples the `mn` and `fp` pair sets for `x`, reusing an existing `pair_neighbors` set.
///
/// The sets are presumably PaCMAP's mid-near and further pairs. The `mn` set is sampled
/// first, then the `fp` set. The `fp` sampler is also given `pair_neighbors` and
/// `n_neighbors`; NOTE(review): presumably so it can avoid existing neighbor pairs — confirm.
///
/// With `random_seed = Some(seed)`, the deterministic samplers are used and both receive the
/// same `seed`. With `None`, the non-deterministic samplers are used.
///
/// Returns `(mn_pairs, fp_pairs)`.
pub fn generate_pair_no_neighbors(
    x: ArrayView2<f32>,
    n_neighbors: usize,
    n_mn: usize,
    n_fp: usize,
    pair_neighbors: ArrayView2<u32>,
    random_seed: Option<u64>,
) -> (Array2<u32>, Array2<u32>) {
    if let Some(seed) = random_seed {
        let mn_pairs = sample_mn_pair_deterministic(x, n_mn, seed);
        let fp_pairs =
            sample_fp_pair_deterministic(x, pair_neighbors, n_neighbors, n_fp, seed);
        (mn_pairs, fp_pairs)
    } else {
        let mn_pairs = sample_mn_pair(x, n_mn);
        let fp_pairs = sample_fp_pair(x, pair_neighbors, n_neighbors, n_fp);
        (mn_pairs, fp_pairs)
    }
}