use crate::umap::smooth_knn_dist::SmoothKnnDist;
use crate::utils::parallel_vec::ParallelVec;
use dashmap::DashSet;
use ndarray::Array1;
use ndarray::ArrayView1;
use ndarray::ArrayView2;
use rayon::prelude::*;
use sprs::CsMatI;
use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering;
use std::time::Instant;
use tracing::info;
use typed_builder::TypedBuilder;
/// Sparse matrix with `f32` values, `u32` inner (column) indices and `usize`
/// index pointers. Every instance built in this module comes from
/// `CsMatI::new`, i.e. compressed-sparse-row storage.
pub type SparseMat = CsMatI<f32, u32, usize>;
/// Sparsity pattern of a matrix in column-compressed layout, without values.
///
/// Column `c` owns the row indices `indices[indptr[c]..indptr[c + 1]]`. As
/// filled by `build_csc_structure`, `indptr` is one longer than the number of
/// columns and the rows inside each column are ascending.
struct CscStructure {
    /// Offsets into `indices`, one per column plus a trailing end offset.
    indptr: Vec<usize>,
    /// Row index of every stored entry, grouped by column.
    indices: Vec<u32>,
}

impl CscStructure {
    /// Returns the row indices stored in column `col`.
    ///
    /// # Panics
    /// If `col + 1` is not a valid index into `indptr`.
    fn col_row_indices(&self, col: usize) -> &[u32] {
        let span = self.indptr[col]..self.indptr[col + 1];
        &self.indices[span]
    }
}
/// Builds UMAP's fuzzy simplicial set (the weighted neighbor graph) from
/// precomputed k-nearest-neighbor indices and distances.
///
/// Configure it with `FuzzySimplicialSet::builder()` and run it with `exec`.
#[derive(TypedBuilder, Debug)]
pub struct FuzzySimplicialSet<'a, 'd> {
    /// Number of points. Rows `0..n_samples` of the kNN arrays are used, and
    /// neighbor indices `>= n_samples` are skipped.
    n_samples: usize,
    /// Number of neighbor columns consulted per row.
    n_neighbors: usize,
    /// `knn_indices[(i, j)]` is the point index of the `j`-th neighbor of `i`.
    knn_indices: ArrayView2<'a, u32>,
    /// `knn_dists[(i, j)]` is the distance paired with `knn_indices[(i, j)]`.
    knn_dists: ArrayView2<'a, f32>,
    /// `(i, j)` pairs (point, neighbor slot) to leave out of the graph.
    knn_disconnections: &'d DashSet<(usize, usize)>,
    /// Mix used when symmetrizing: each entry becomes
    /// `r*a + r*b + (1 - 2r)*a*b`, so `1.0` (default) is the fuzzy union and
    /// `0.0` the fuzzy intersection.
    #[builder(default = 1.0)]
    set_op_mix_ratio: f32,
    /// Forwarded to `SmoothKnnDist`; presumably UMAP's local-connectivity
    /// parameter — confirm against `SmoothKnnDist`.
    #[builder(default = 1.0)]
    local_connectivity: f32,
    /// When `false`, the directed membership matrix is returned as-is.
    #[builder(default = true)]
    apply_set_operations: bool,
}

impl<'a, 'd> FuzzySimplicialSet<'a, 'd> {
    /// Computes the fuzzy membership graph.
    ///
    /// Returns `(graph, sigmas, rhos)`:
    /// - `graph` is the `n_samples x n_samples` CSR membership matrix,
    ///   symmetrized when `apply_set_operations` is set.
    /// - `sigmas` and `rhos` are the per-point values produced by `SmoothKnnDist`.
    ///
    /// # Panics
    /// - If `n_samples >= u32::MAX`, because column indices are stored as `u32`.
    /// - If `knn_indices` or `knn_dists` has fewer than `n_samples` rows or
    ///   fewer than `n_neighbors` columns.
    pub fn exec(self) -> (SparseMat, Array1<f32>, Array1<f32>) {
        let Self {
            n_samples,
            n_neighbors,
            knn_indices,
            knn_dists,
            knn_disconnections,
            set_op_mix_ratio,
            local_connectivity,
            apply_set_operations,
        } = self;
        assert!(
            n_samples < u32::MAX as usize,
            "n_samples must be < 2^32 for u32 indices"
        );
        // Validate shapes up front. The graph construction indexes
        // `[(i, j)]` for every i < n_samples and j < n_neighbors. Without this,
        // a mismatch surfaces as a bare index panic inside a rayon worker.
        let (idx_rows, idx_cols) = knn_indices.dim();
        assert!(
            idx_rows >= n_samples && idx_cols >= n_neighbors,
            "knn_indices has shape ({}, {}) but must be at least (n_samples={}, n_neighbors={})",
            idx_rows,
            idx_cols,
            n_samples,
            n_neighbors
        );
        let (dist_rows, dist_cols) = knn_dists.dim();
        assert!(
            dist_rows >= n_samples && dist_cols >= n_neighbors,
            "knn_dists has shape ({}, {}) but must be at least (n_samples={}, n_neighbors={})",
            dist_rows,
            dist_cols,
            n_samples,
            n_neighbors
        );

        let started = Instant::now();
        let (sigmas, rhos) = SmoothKnnDist::builder()
            .distances(knn_dists)
            .k(n_neighbors)
            .local_connectivity(local_connectivity)
            .build()
            .exec();
        info!(
            duration_ms = started.elapsed().as_millis(),
            "smooth_knn_dist complete"
        );

        let started = Instant::now();
        let mut result = build_membership_csr(
            n_samples,
            n_neighbors,
            knn_indices,
            knn_dists,
            knn_disconnections,
            &sigmas.view(),
            &rhos.view(),
        );
        info!(
            duration_ms = started.elapsed().as_millis(),
            nnz = result.nnz(),
            "build_membership_csr complete"
        );

        if apply_set_operations {
            let started = Instant::now();
            result = apply_set_operations_parallel(&result, set_op_mix_ratio);
            info!(
                duration_ms = started.elapsed().as_millis(),
                "set_operations complete"
            );
        }
        (result, sigmas, rhos)
    }
}
/// Builds the directed membership-strength matrix in CSR form.
///
/// Row `i` receives one entry per neighbor slot `j in 0..n_neighbors`. A slot
/// is skipped when any of these hold:
/// - `(i, j)` is in `knn_disconnections`;
/// - `knn_indices[(i, j)]` is `i` itself or `>= n_samples`;
/// - the strength from `compute_membership_strength` is exactly `0.0`.
///
/// Column indices within each row are sorted ascending before the matrix is
/// assembled.
///
/// The work runs in two parallel passes over one shared per-row enumeration
/// (`row_entries`). The first pass counts entries to size each row. The second
/// writes them into that row's private range of the output buffers.
///
/// # Panics
/// If a row yields a different number of entries in the write pass than in
/// the count pass. That can only happen if `knn_disconnections` is modified
/// concurrently (a `DashSet` accepts inserts through a shared reference). The
/// check stops such writes from spilling into another row's range.
fn build_membership_csr(
    n_samples: usize,
    n_neighbors: usize,
    knn_indices: ArrayView2<u32>,
    knn_dists: ArrayView2<f32>,
    knn_disconnections: &DashSet<(usize, usize)>,
    sigmas: &ArrayView1<f32>,
    rhos: &ArrayView1<f32>,
) -> SparseMat {
    // `(column, strength)` for each retained neighbor of row `i`, in slot
    // order. Both passes use this single definition so their counts agree.
    let row_entries = |i: usize| {
        (0..n_neighbors).filter_map(move |j| {
            if knn_disconnections.contains(&(i, j)) {
                return None;
            }
            let knn_idx = knn_indices[(i, j)];
            if knn_idx as usize == i || knn_idx as usize >= n_samples {
                return None;
            }
            let val = compute_membership_strength(i, j, knn_dists, rhos, sigmas);
            if val != 0.0 {
                Some((knn_idx, val))
            } else {
                None
            }
        })
    };

    let started = Instant::now();
    let row_counts: Vec<u32> = (0..n_samples)
        .into_par_iter()
        .map(|i| row_entries(i).count() as u32)
        .collect();
    info!(
        duration_ms = started.elapsed().as_millis(),
        "csr row_counts complete"
    );

    let started = Instant::now();
    let mut indptr: Vec<usize> = Vec::with_capacity(n_samples + 1);
    indptr.push(0);
    let mut total = 0usize;
    for &count in &row_counts {
        total += count as usize;
        indptr.push(total);
    }
    let nnz = total;
    info!(
        duration_ms = started.elapsed().as_millis(),
        nnz, "csr indptr complete"
    );

    let indices_vec = ParallelVec::new(vec![0u32; nnz]);
    let data_vec = ParallelVec::new(vec![0.0f32; nnz]);
    let started = Instant::now();
    (0..n_samples).into_par_iter().for_each(|i| {
        let row_start = indptr[i];
        let row_len = indptr[i + 1] - row_start;
        let mut written = 0usize;
        for (knn_idx, val) in row_entries(i) {
            assert!(
                written < row_len,
                "row {} produced more entries than counted; was knn_disconnections modified concurrently?",
                i
            );
            // SAFETY: `row_start + written < indptr[i + 1] <= nnz`, and the range
            // `indptr[i]..indptr[i + 1]` is written only by the task handling
            // row `i`, so no slot is touched by two threads. (Assumes
            // `ParallelVec::write` requires exactly this: an in-bounds index
            // not concurrently accessed — confirm against its definition.)
            unsafe {
                indices_vec.write(row_start + written, knn_idx);
                data_vec.write(row_start + written, val);
            }
            written += 1;
        }
        assert_eq!(
            written, row_len,
            "row {} produced fewer entries than counted",
            i
        );
    });
    info!(
        duration_ms = started.elapsed().as_millis(),
        "csr fill complete"
    );

    let started = Instant::now();
    (0..n_samples).into_par_iter().for_each(|i| {
        let row_start = indptr[i];
        let row_len = indptr[i + 1] - indptr[i];
        if row_len > 0 {
            // SAFETY: each task borrows only its own row's disjoint range
            // `indptr[i]..indptr[i + 1]`, which lies within `0..nnz`.
            let row_indices = unsafe { indices_vec.get_mut_slice(row_start, row_len) };
            let row_data = unsafe { data_vec.get_mut_slice(row_start, row_len) };
            // Insertion sort by column, permuting values alongside. Rows hold
            // at most `n_neighbors` entries, so the quadratic worst case is small.
            for k in 1..row_len {
                let mut m = k;
                while m > 0 && row_indices[m - 1] > row_indices[m] {
                    row_indices.swap(m - 1, m);
                    row_data.swap(m - 1, m);
                    m -= 1;
                }
            }
        }
    });
    info!(
        duration_ms = started.elapsed().as_millis(),
        "csr row_sort complete"
    );

    let indices = indices_vec.into_inner();
    let data = data_vec.into_inner();
    CsMatI::new((n_samples, n_samples), indptr, indices, data)
}
/// Membership strength of the edge from point `i` to its `j`-th neighbor.
///
/// Let `d = knn_dists[(i, j)]`. The strength is `1.0` when `d <= rhos[i]` or
/// when `sigmas[i]` is zero. Otherwise it is `exp(-(d - rhos[i]) / sigmas[i])`.
fn compute_membership_strength(
    i: usize,
    j: usize,
    knn_dists: ArrayView2<f32>,
    rhos: &ArrayView1<f32>,
    sigmas: &ArrayView1<f32>,
) -> f32 {
    let excess = knn_dists[(i, j)] - rhos[i];
    if excess <= 0.0 {
        // Distance does not exceed rho: full strength.
        return 1.0;
    }
    let sigma = sigmas[i];
    if sigma == 0.0 {
        1.0
    } else {
        (-excess / sigma).exp()
    }
}
/// Derives the column-compressed sparsity pattern of `csr`: row indices grouped
/// by column, with no values.
///
/// Per-column counts are tallied in parallel with relaxed atomic increments.
/// The scatter into `indices` then runs sequentially over rows in ascending
/// order, so each column's row list comes out ascending.
///
/// # Panics
/// If the tallied counts do not sum to `csr.nnz()`.
fn build_csc_structure(csr: &SparseMat) -> CscStructure {
    let (n_rows, n_cols) = csr.shape();
    let nnz = csr.nnz();

    let started = Instant::now();
    let col_counts: Vec<AtomicU32> = std::iter::repeat_with(|| AtomicU32::new(0))
        .take(n_cols)
        .collect();
    (0..n_rows).into_par_iter().for_each(|row| {
        let span = csr.indptr().index(row)..csr.indptr().index(row + 1);
        csr.indices()[span].iter().for_each(|&col| {
            col_counts[col as usize].fetch_add(1, Ordering::Relaxed);
        });
    });
    info!(
        duration_ms = started.elapsed().as_millis(),
        "csc col_counts complete"
    );

    let started = Instant::now();
    let mut running = 0usize;
    let indptr: Vec<usize> = std::iter::once(0)
        .chain(col_counts.iter().map(|count| {
            running += count.load(Ordering::Relaxed) as usize;
            running
        }))
        .collect();
    assert_eq!(running, nnz);
    info!(
        duration_ms = started.elapsed().as_millis(),
        "csc indptr complete"
    );

    let started = Instant::now();
    // Next free slot of every column, starting at that column's offset.
    let mut next_slot: Vec<usize> = indptr[..n_cols].to_vec();
    let mut indices: Vec<u32> = vec![0; nnz];
    for row in 0..n_rows {
        let span = csr.indptr().index(row)..csr.indptr().index(row + 1);
        for &col in &csr.indices()[span] {
            let slot = &mut next_slot[col as usize];
            indices[*slot] = row as u32;
            *slot += 1;
        }
    }
    info!(
        duration_ms = started.elapsed().as_millis(),
        "csc fill complete"
    );

    CscStructure { indptr, indices }
}
/// Returns the value stored at `(row, col)` of `csr`, or `0.0` when no entry
/// is stored there.
///
/// Uses a binary search, so the column indices of `row` must be sorted.
fn csr_get(csr: &SparseMat, row: usize, col: u32) -> f32 {
    let lo = csr.indptr().index(row);
    let hi = csr.indptr().index(row + 1);
    let cols = &csr.indices()[lo..hi];
    let vals = &csr.data()[lo..hi];
    cols.binary_search(&col).map_or(0.0, |pos| vals[pos])
}
/// Symmetrizes the directed membership matrix `input` with UMAP's fuzzy set
/// operations.
///
/// Let `a = input[(row, col)]`, `b = input[(col, row)]` and
/// `r = set_op_mix_ratio`. Every output entry is `r*a + r*b + (1 - 2r)*a*b`,
/// i.e. `r * (A + Aᵀ - A∘Aᵀ) + (1 - r) * A∘Aᵀ`:
/// - `r = 1` gives the fuzzy union `a + b - a*b`;
/// - `r = 0` gives the fuzzy intersection `a*b`.
///
/// Entries that evaluate to exactly `0.0` are not stored.
///
/// `input` is treated as square (`n_samples x n_samples`) with sorted column
/// indices per row. The output is built in two parallel passes (count, then
/// write) over `for_each_set_op_entry`. That helper yields each row's entries
/// already in ascending column order, so no per-row sort is needed.
fn apply_set_operations_parallel(input: &SparseMat, set_op_mix_ratio: f32) -> SparseMat {
    let n_samples = input.shape().0;
    let prod_coeff = 1.0 - 2.0 * set_op_mix_ratio;

    let started = Instant::now();
    let csc = build_csc_structure(input);
    info!(
        duration_ms = started.elapsed().as_millis(),
        "set_operations csc_structure complete"
    );

    let started = Instant::now();
    let row_counts: Vec<u32> = (0..n_samples)
        .into_par_iter()
        .map(|row| {
            let mut count = 0u32;
            for_each_set_op_entry(
                input,
                csc.col_row_indices(row),
                row,
                set_op_mix_ratio,
                prod_coeff,
                |_, _| count += 1,
            );
            count
        })
        .collect();
    info!(
        duration_ms = started.elapsed().as_millis(),
        "set_operations row_counts complete"
    );

    let started = Instant::now();
    let mut indptr: Vec<usize> = Vec::with_capacity(n_samples + 1);
    indptr.push(0);
    let mut total = 0usize;
    for &count in &row_counts {
        total += count as usize;
        indptr.push(total);
    }
    let nnz = total;
    info!(
        duration_ms = started.elapsed().as_millis(),
        nnz, "set_operations indptr complete"
    );

    let indices_vec = ParallelVec::new(vec![0u32; nnz]);
    let data_vec = ParallelVec::new(vec![0.0f32; nnz]);
    let started = Instant::now();
    (0..n_samples).into_par_iter().for_each(|row| {
        let out_start = indptr[row];
        let out_len = indptr[row + 1] - out_start;
        let mut offset = 0usize;
        for_each_set_op_entry(
            input,
            csc.col_row_indices(row),
            row,
            set_op_mix_ratio,
            prod_coeff,
            |col, val| {
                assert!(
                    offset < out_len,
                    "set_operations row {} produced more entries than counted",
                    row
                );
                // SAFETY: `out_start + offset < indptr[row + 1] <= nnz`, and the
                // range `indptr[row]..indptr[row + 1]` is written only by the task
                // handling `row`. (Assumes `ParallelVec::write` requires an
                // in-bounds index not concurrently accessed — confirm.)
                unsafe {
                    indices_vec.write(out_start + offset, col);
                    data_vec.write(out_start + offset, val);
                }
                offset += 1;
            },
        );
        debug_assert_eq!(offset, out_len);
    });
    info!(
        duration_ms = started.elapsed().as_millis(),
        "set_operations fill complete"
    );

    let indices = indices_vec.into_inner();
    let data = data_vec.into_inner();
    CsMatI::new((n_samples, n_samples), indptr, indices, data)
}

/// Calls `emit(col, value)` for every nonzero entry of row `row` of the
/// set-operation result, in ascending column order.
///
/// `reverse_rows` must list, ascending and without duplicates, the rows `c`
/// for which `input[(c, row)]` is stored. This is column `row` of `input`'s
/// sparsity pattern, as returned by `CscStructure::col_row_indices`.
///
/// The stored columns of `input`'s row `row` are merged with `reverse_rows`:
/// - a column in both is a mutual edge, emitted once using both values;
/// - a column only in the row uses `b = 0`;
/// - a column only in `reverse_rows` uses `a = 0`, reducing the formula to `r*b`.
///
/// Emitting each column at most once keeps the output a valid CSR row, even
/// if `input` stores explicit zeros.
fn for_each_set_op_entry(
    input: &SparseMat,
    reverse_rows: &[u32],
    row: usize,
    set_op_mix_ratio: f32,
    prod_coeff: f32,
    mut emit: impl FnMut(u32, f32),
) {
    let row_start = input.indptr().index(row);
    let row_end = input.indptr().index(row + 1);
    let fwd_cols = &input.indices()[row_start..row_end];
    let fwd_vals = &input.data()[row_start..row_end];
    let (mut f, mut r) = (0usize, 0usize);
    loop {
        let fwd = fwd_cols.get(f).copied();
        let rev = reverse_rows.get(r).copied();
        match (fwd, rev) {
            (None, None) => break,
            (Some(col), _) if rev.map_or(true, |c| col <= c) => {
                let val_rc = fwd_vals[f];
                f += 1;
                // All reverse rows below `col` were consumed earlier, so
                // `input[(col, row)]` is stored exactly when it is the next one.
                let val_cr = if rev == Some(col) {
                    r += 1;
                    csr_get(input, col as usize, row as u32)
                } else {
                    0.0
                };
                let final_val = set_op_mix_ratio * val_rc
                    + set_op_mix_ratio * val_cr
                    + prod_coeff * val_rc * val_cr;
                if final_val != 0.0 {
                    emit(col, final_val);
                }
            }
            (_, Some(c)) => {
                // Transpose-only entry: `input[(row, c)]` is not stored.
                r += 1;
                let val_cr = csr_get(input, c as usize, row as u32);
                let final_val = set_op_mix_ratio * val_cr;
                if final_val != 0.0 {
                    emit(c, final_val);
                }
            }
            (Some(_), None) => unreachable!("forward-only entries take the guarded arm"),
        }
    }
}