use loess_rs::internals::algorithms::regression::SolverLinalg;
use loess_rs::internals::math::distance::DistanceLinalg;
use loess_rs::internals::math::linalg::FloatLinalg;
use loess_rs::internals::math::neighborhood::{KDNode, KDTree};
use num_traits::Float;
#[cfg(feature = "cpu")]
use rayon::join;
#[cfg(feature = "cpu")]
pub fn build_kdtree_parallel<T>(points: &[T], dimensions: usize) -> KDTree<T>
where
T: FloatLinalg + DistanceLinalg + SolverLinalg + Float + Send + Sync + 'static,
{
let n = points.len() / dimensions;
let mut nodes = vec![KDNode::default(); n];
let mut permuted_points = vec![T::zero(); points.len()];
let mut indices: Vec<usize> = (0..n).collect();
if n > 0 {
let nodes_ptr = nodes.as_mut_ptr() as usize;
let permuted_ptr = permuted_points.as_mut_ptr() as usize;
build_recursive_parallel(
points,
&mut indices,
dimensions,
nodes_ptr,
permuted_ptr,
0,
1,
);
}
KDTree::from_parts(nodes, permuted_points, dimensions)
}
#[cfg(feature = "cpu")]
fn build_recursive_parallel<T>(
points: &[T],
indices: &mut [usize],
dimensions: usize,
nodes_ptr: usize,
permuted_ptr: usize,
depth: usize,
v: usize,
) where
T: FloatLinalg + DistanceLinalg + SolverLinalg + Float + Send + Sync + 'static,
{
let n = indices.len();
if n == 0 {
return;
}
let axis = depth % dimensions;
let left_size = KDTree::<T>::calculate_left_subtree_size(n);
let mid = left_size;
indices.select_nth_unstable_by(mid, |&a, &b| {
let val_a = points[a * dimensions + axis];
let val_b = points[b * dimensions + axis];
val_a
.partial_cmp(&val_b)
.unwrap_or(core::cmp::Ordering::Equal)
});
unsafe {
let node_ref = &mut *(nodes_ptr as *mut KDNode).add(v - 1);
node_ref.index = indices[mid];
let dest_offset = (v - 1) * dimensions;
let src_offset = indices[mid] * dimensions;
let dest_ptr = (permuted_ptr as *mut T).add(dest_offset);
let src_ptr = points.as_ptr().add(src_offset);
std::ptr::copy_nonoverlapping(src_ptr, dest_ptr, dimensions);
}
let (left_indices, right_indices_with_mid) = indices.split_at_mut(mid);
let right_indices = &mut right_indices_with_mid[1..];
if n > 1024 {
join(
|| {
build_recursive_parallel(
points,
left_indices,
dimensions,
nodes_ptr,
permuted_ptr,
depth + 1,
2 * v,
)
},
|| {
build_recursive_parallel(
points,
right_indices,
dimensions,
nodes_ptr,
permuted_ptr,
depth + 1,
2 * v + 1,
)
},
);
} else {
build_recursive_sequential(
points,
left_indices,
dimensions,
nodes_ptr,
permuted_ptr,
depth + 1,
2 * v,
);
build_recursive_sequential(
points,
right_indices,
dimensions,
nodes_ptr,
permuted_ptr,
depth + 1,
2 * v + 1,
);
}
}
#[cfg(feature = "cpu")]
fn build_recursive_sequential<T>(
points: &[T],
indices: &mut [usize],
dimensions: usize,
nodes_ptr: usize,
permuted_ptr: usize,
depth: usize,
v: usize,
) where
T: FloatLinalg + DistanceLinalg + SolverLinalg + Float + Send + Sync + 'static,
{
let n = indices.len();
if n == 0 {
return;
}
let axis = depth % dimensions;
let left_size = KDTree::<T>::calculate_left_subtree_size(n);
let mid = left_size;
indices.select_nth_unstable_by(mid, |&a, &b| {
let val_a = points[a * dimensions + axis];
let val_b = points[b * dimensions + axis];
val_a
.partial_cmp(&val_b)
.unwrap_or(core::cmp::Ordering::Equal)
});
unsafe {
let node_ref = &mut *(nodes_ptr as *mut KDNode).add(v - 1);
node_ref.index = indices[mid];
let dest_offset = (v - 1) * dimensions;
let src_offset = indices[mid] * dimensions;
let dest_ptr = (permuted_ptr as *mut T).add(dest_offset);
let src_ptr = points.as_ptr().add(src_offset);
std::ptr::copy_nonoverlapping(src_ptr, dest_ptr, dimensions);
}
let (left_indices, right_indices_with_mid) = indices.split_at_mut(mid);
let right_indices = &mut right_indices_with_mid[1..];
build_recursive_sequential(
points,
left_indices,
dimensions,
nodes_ptr,
permuted_ptr,
depth + 1,
2 * v,
);
build_recursive_sequential(
points,
right_indices,
dimensions,
nodes_ptr,
permuted_ptr,
depth + 1,
2 * v + 1,
);
}
#[cfg(not(feature = "cpu"))]
pub fn build_kdtree_parallel<T>(points: &[T], dimensions: usize) -> KDTree<T>
where
T: FloatLinalg + DistanceLinalg + SolverLinalg + Float + Send + Sync + 'static,
{
KDTree::new(points, dimensions)
}