use crate::bounding_box::are_well_separated;
use crate::point::Point;
use crate::split_tree::{SplitTree, SplitTreeNode};
use num_traits::{FromPrimitive, Zero};
use std::sync::Arc;
/// A pair of split-tree nodes reported by the decomposition as well separated.
///
/// Both nodes are held through `Arc`, so cloning a pair bumps reference
/// counts instead of copying the subtrees.
#[derive(Clone, Debug)]
pub struct WellSeparatedPair<P: Point> {
    /// First node of the pair (the "A" side).
    pub node_a: Arc<SplitTreeNode<P>>,
    /// Second node of the pair (the "B" side).
    pub node_b: Arc<SplitTreeNode<P>>,
}

impl<P: Point> WellSeparatedPair<P> {
    /// Representative point index stored on `node_a`, if any.
    pub fn representative_a(&self) -> Option<usize> {
        let node: &SplitTreeNode<P> = &self.node_a;
        node.representative
    }

    /// Representative point index stored on `node_b`, if any.
    pub fn representative_b(&self) -> Option<usize> {
        let node: &SplitTreeNode<P> = &self.node_b;
        node.representative
    }

    /// All point indices stored in the subtree rooted at `node_a`.
    pub fn points_a(&self) -> Vec<usize> {
        let root = &self.node_a;
        SplitTree::collect_points(root)
    }

    /// All point indices stored in the subtree rooted at `node_b`.
    pub fn points_b(&self) -> Vec<usize> {
        let root = &self.node_b;
        SplitTree::collect_points(root)
    }

    /// Number of point pairs this pair represents, i.e. `|A| * |B|`.
    ///
    /// Both point lists are materialized to obtain their lengths.
    pub fn pair_count(&self) -> usize {
        let size_a = self.points_a().len();
        let size_b = self.points_b().len();
        size_a * size_b
    }
}
/// Well-separated pair decomposition (WSPD) of a point set.
///
/// Owns the [`SplitTree`] built over the input points, together with the list
/// of [`WellSeparatedPair`]s computed from it for one fixed separation factor.
pub struct WSPD<P: Point> {
    /// Split tree over the input points; pair nodes are `Arc`-shared with it.
    tree: SplitTree<P>,
    /// The well-separated pairs, in the order they were discovered.
    pairs: Vec<WellSeparatedPair<P>>,
    /// Separation factor `s` handed to `are_well_separated`.
    separation: P::Scalar,
}

impl<P: Point> WSPD<P> {
    /// Builds the split tree over `points` and computes its decomposition.
    ///
    /// # Panics
    ///
    /// Panics if `points` is empty, or if `separation` is not strictly
    /// greater than zero.
    pub fn new(points: Vec<P>, separation: P::Scalar) -> Self {
        assert!(!points.is_empty(), "Point set must not be empty");
        assert!(
            separation > P::Scalar::zero(),
            "Separation factor must be positive"
        );
        let tree = SplitTree::new(points);
        let pairs = Self::compute_pairs(&tree.root, &tree.root, separation);
        Self {
            tree,
            pairs,
            separation,
        }
    }

    /// Number of well-separated pairs in the decomposition.
    pub fn num_pairs(&self) -> usize {
        self.pairs.len()
    }

    /// The well-separated pairs.
    pub fn pairs(&self) -> &[WellSeparatedPair<P>] {
        &self.pairs
    }

    /// The input points, as exposed by the underlying split tree.
    pub fn points(&self) -> &[P] {
        self.tree.points()
    }

    /// The separation factor this decomposition was built with.
    pub fn separation(&self) -> P::Scalar {
        self.separation
    }

    /// Collects summary statistics about the decomposition and its tree.
    ///
    /// `expected_pairs` is `n * (n - 1) / 2`. When every unordered pair of
    /// distinct points is covered exactly once, `total_point_pairs` equals it.
    pub fn stats(&self) -> WSPDStats {
        let total_point_pairs = self.pairs.iter().map(|p| p.pair_count()).sum();
        let n = self.points().len();
        // `new` rejects empty inputs, so `n >= 1` and `n - 1` cannot underflow.
        let expected_pairs = n * (n - 1) / 2;
        WSPDStats {
            num_points: n,
            num_pairs: self.num_pairs(),
            total_point_pairs,
            expected_pairs,
            tree_nodes: self.tree.node_count(),
            tree_height: self.tree.height(),
        }
    }

    /// Computes the well-separated pairs between the subtrees `u` and `v`.
    ///
    /// When `u` and `v` are the same node (the initial call passes the root
    /// twice), the pairs *within* that subtree are produced instead. Each
    /// child is decomposed against itself, and the two children are
    /// decomposed against each other. All other recursive calls therefore
    /// compare disjoint subtrees, and each unordered point pair is reported
    /// once. The previous behaviour reported each pair once per ordering.
    fn compute_pairs(
        u: &Arc<SplitTreeNode<P>>,
        v: &Arc<SplitTreeNode<P>>,
        s: P::Scalar,
    ) -> Vec<WellSeparatedPair<P>> {
        let mut pairs = Vec::new();
        Self::collect_pairs(u, v, s, &mut pairs);
        pairs
    }

    /// Recursive worker for [`WSPD::compute_pairs`]; appends pairs to `out`.
    fn collect_pairs(
        u: &Arc<SplitTreeNode<P>>,
        v: &Arc<SplitTreeNode<P>>,
        s: P::Scalar,
        out: &mut Vec<WellSeparatedPair<P>>,
    ) {
        if Arc::ptr_eq(u, v) {
            // Pairs internal to one subtree: recurse into each child on its
            // own, then pair the (disjoint) children with each other.
            match (&u.left, &u.right) {
                (Some(left), Some(right)) => {
                    Self::collect_pairs(left, left, s, out);
                    Self::collect_pairs(right, right, s, out);
                    Self::collect_pairs(left, right, s, out);
                }
                (Some(child), None) | (None, Some(child)) => {
                    Self::collect_pairs(child, child, s, out);
                }
                (None, None) => {}
            }
            return;
        }

        // Defensive: two leaves standing for the same point never form a pair.
        if u.is_leaf() && v.is_leaf() {
            if let (Some(u_idx), Some(v_idx)) = (u.representative, v.representative) {
                if u_idx == v_idx {
                    return;
                }
            }
        }

        if are_well_separated(&u.bbox, &v.bbox, s) {
            out.push(WellSeparatedPair {
                node_a: Arc::clone(u),
                node_b: Arc::clone(v),
            });
            return;
        }

        // Never split a leaf. Otherwise split the node with the lower (or
        // equal) `level`. NOTE(review): this assumes a lower `level` means a
        // larger node (e.g. depth from the root) — confirm against SplitTreeNode.
        let split_u = if v.is_leaf() {
            true
        } else if u.is_leaf() {
            false
        } else {
            u.level <= v.level
        };

        if split_u {
            for child in u.left.iter().chain(&u.right) {
                Self::collect_pairs(child, v, s, out);
            }
        } else {
            for child in v.left.iter().chain(&v.right) {
                Self::collect_pairs(u, child, s, out);
            }
        }
    }

    /// Expands the decomposition into explicit point-index pairs `(i, j)`
    /// with `i < j`.
    ///
    /// Pairs whose two sides name the same index are skipped. The point lists
    /// of each well-separated pair are materialized lazily, as iteration
    /// reaches that pair.
    pub fn all_pairs(&self) -> impl Iterator<Item = (usize, usize)> + '_ {
        self.pairs.iter().flat_map(|wsp| {
            let points_a = wsp.points_a();
            let points_b = wsp.points_b();
            points_a.into_iter().flat_map(move |a| {
                // The inner iterator must own its data, so each `a` gets its
                // own copy of the B-side indices.
                points_b
                    .clone()
                    .into_iter()
                    .filter(move |&b| b != a)
                    .map(move |b| (a.min(b), a.max(b)))
            })
        })
    }
}
/// Summary statistics describing a WSPD and its underlying split tree.
#[derive(Clone, Debug, Copy)]
pub struct WSPDStats {
    /// Number of input points.
    pub num_points: usize,
    /// Number of well-separated pairs in the decomposition.
    pub num_pairs: usize,
    /// Sum over all well-separated pairs of `|A| * |B|`.
    pub total_point_pairs: usize,
    /// `n * (n - 1) / 2`: the number of unordered pairs of distinct points.
    pub expected_pairs: usize,
    /// Number of nodes in the split tree.
    pub tree_nodes: usize,
    /// Height of the split tree, as reported by the tree.
    pub tree_height: usize,
}

impl WSPDStats {
    /// Average number of point pairs represented by one well-separated pair,
    /// i.e. `total_point_pairs / num_pairs`.
    ///
    /// Returns `None` when the decomposition has no pairs (e.g. a single
    /// point), instead of producing `NaN`.
    pub fn compression_ratio(&self) -> Option<f64> {
        if self.num_pairs == 0 {
            None
        } else {
            Some(self.total_point_pairs as f64 / self.num_pairs as f64)
        }
    }

    /// Prints the statistics to stdout, one tab-indented line per metric.
    ///
    /// The compression ratio is [`WSPDStats::compression_ratio`]. It is shown
    /// as `n/a` when undefined.
    pub fn print(&self) {
        println!("WSPD Statistics:");
        println!("\tPoints: {}", self.num_points);
        println!("\tWell-Separated Pairs: {}", self.num_pairs);
        println!("\tPoint pairs covered: {}", self.total_point_pairs);
        println!("\tExpected point pairs: {}", self.expected_pairs);
        match self.compression_ratio() {
            Some(ratio) => println!("\tCompression Ratio: {:.2}x", ratio),
            None => println!("\tCompression Ratio: n/a"),
        }
        println!("\tTree Nodes: {}", self.tree_nodes);
        println!("\tTree Height: {}", self.tree_height);
    }
}
pub struct WSPDBuilder<P: Point> {
points: Vec<P>,
separation: Option<P::Scalar>,
}
impl<P: Point> WSPDBuilder<P> {
pub fn new(points: Vec<P>) -> Self {
Self {
points,
separation: None,
}
}
pub fn separation(mut self, separation: P::Scalar) -> Self {
self.separation = Some(separation);
self
}
pub fn build(self) -> WSPD<P> {
let separation = self.separation.unwrap_or_else(|| {
FromPrimitive::from_f64(2.0).expect("Failed to convert default separation")
});
WSPD::new(self.points, separation)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::point::Point2D;
    use std::collections::HashSet;

    #[test]
    fn test_wspd_construction() {
        let wspd = WSPD::new(
            vec![
                Point2D::new(0.0, 0.0),
                Point2D::new(1.0, 0.0),
                Point2D::new(0.0, 1.0),
                Point2D::new(10.0, 10.0),
            ],
            2.0,
        );
        assert!(wspd.num_pairs() > 0);

        let stats = wspd.stats();
        assert_eq!(stats.num_points, 4);
        // Four points have 4 * 3 / 2 = 6 unordered pairs; the decomposition
        // is expected to group some of them into fewer well-separated pairs.
        assert!(stats.num_pairs < 6);
    }

    #[test]
    fn test_all_pairs_converge() {
        let wspd = WSPD::new(
            vec![
                Point2D::new(0.0, 0.0),
                Point2D::new(1.0, 0.0),
                Point2D::new(2.0, 2.0),
            ],
            2.0,
        );
        // Three points yield exactly three distinct unordered index pairs.
        let unique: HashSet<(usize, usize)> = wspd.all_pairs().collect();
        assert_eq!(unique.len(), 3);
    }
}