use std::fmt::Debug;
use crate::linalg::basic::arrays::Array2;
use crate::metrics::distance::euclidian::*;
use crate::numbers::basenum::Number;
/// Binary space-partitioning tree built over the rows of a data matrix.
///
/// Every node describes the axis-aligned bounding box of a contiguous slice of
/// `index`; `clustering` walks the tree to assign whole boxes of points to a
/// centroid at once instead of testing each point individually.
#[derive(Debug)]
pub struct BBDTree {
/// Flat storage for all nodes; children are referenced by position in this vector.
nodes: Vec<BBDTreeNode>,
/// Row indices of the data matrix, starting as `0..n` and reordered in place
/// while the tree is built so that each node owns a contiguous range.
index: Vec<usize>,
/// Position of the root node inside `nodes`.
root: usize,
}
/// A single cell of the [`BBDTree`]: an axis-aligned box plus the aggregate
/// statistics of the points that fall inside it.
#[derive(Debug)]
struct BBDTreeNode {
    /// Number of points covered by this node.
    count: usize,
    /// Offset into the tree's `index` vector where this node's points start.
    index: usize,
    /// Midpoint of the bounding box, one entry per dimension.
    center: Vec<f64>,
    /// Half-width of the bounding box, one entry per dimension.
    radius: Vec<f64>,
    /// Per-dimension sum of the coordinates of all covered points.
    sum: Vec<f64>,
    /// Scatter of the covered points about their own mean (zero for leaves).
    cost: f64,
    /// Position of the child holding points below the split cutoff; `None` for leaves.
    lower: Option<usize>,
    /// Position of the child holding points at or above the split cutoff; `None` for leaves.
    upper: Option<usize>,
}

impl BBDTreeNode {
    /// Creates an empty leaf for `d`-dimensional data: no points, no children,
    /// and every per-dimension vector zero-filled with length `d`.
    fn new(d: usize) -> Self {
        let zeros = vec![0f64; d];
        Self {
            count: 0,
            index: 0,
            cost: 0f64,
            lower: None,
            upper: None,
            center: zeros.clone(),
            radius: zeros.clone(),
            sum: zeros,
        }
    }
}
impl BBDTree {
/// Builds a tree over every row of `data`.
///
/// The row order is tracked through an index permutation that starts as
/// `0..n`; `build_node` rearranges it recursively and returns the position
/// of the root node, which is then recorded on the tree.
///
/// NOTE(review): `build_node` reads `index[begin]` for each column, so a
/// matrix with zero rows but a non-zero column count looks like it would
/// panic here — confirm callers never pass empty data.
pub fn new<T: Number, M: Array2<T>>(data: &M) -> BBDTree {
    let (n, _) = data.shape();
    let mut tree = BBDTree {
        nodes: Vec::new(),
        index: (0..n).collect(),
        root: 0,
    };
    tree.root = tree.build_node(data, 0, n);
    tree
}
/// Assigns every point in the tree to one of `centroids` and accumulates
/// the per-cluster statistics.
///
/// * `centroids` - current cluster centers; its length is the cluster count `k`.
/// * `sums` - overwritten: the first `k` rows are zeroed, then filled with the
///   per-dimension coordinate sums of the points assigned to each cluster.
/// * `counts` - overwritten: zeroed, then filled with the number of points
///   assigned to each cluster.
/// * `membership` - entry `i` is set to the cluster chosen for data row `i`.
///
/// Returns the value accumulated by `filter` starting from the root, i.e.
/// the sum of `node_cost` over every node that was assigned as a whole.
pub(crate) fn clustering(
    &self,
    centroids: &[Vec<f64>],
    sums: &mut Vec<Vec<f64>>,
    counts: &mut Vec<usize>,
    membership: &mut Vec<usize>,
) -> f64 {
    let k = centroids.len();

    // Start from clean accumulators; results of a previous pass are discarded.
    counts.fill(0);
    for row in sums[..k].iter_mut() {
        row.fill(0f64);
    }

    // At the root every centroid is still a candidate.
    let candidates: Vec<usize> = (0..k).collect();

    self.filter(
        self.root,
        centroids,
        &candidates,
        k,
        sums,
        counts,
        membership,
    )
}
fn filter(
&self,
node: usize,
centroids: &[Vec<f64>],
candidates: &[usize],
k: usize,
sums: &mut Vec<Vec<f64>>,
counts: &mut Vec<usize>,
membership: &mut Vec<usize>,
) -> f64 {
let d = centroids[0].len();
let mut min_dist =
Euclidian::squared_distance(&self.nodes[node].center, ¢roids[candidates[0]]);
let mut closest = candidates[0];
for i in 1..k {
let dist =
Euclidian::squared_distance(&self.nodes[node].center, ¢roids[candidates[i]]);
if dist < min_dist {
min_dist = dist;
closest = candidates[i];
}
}
if self.nodes[node].lower.is_some() {
let mut new_candidates = vec![0; k];
let mut newk = 0;
for candidate in candidates.iter().take(k) {
if !BBDTree::prune(
&self.nodes[node].center,
&self.nodes[node].radius,
centroids,
closest,
*candidate,
) {
new_candidates[newk] = *candidate;
newk += 1;
}
}
if newk > 1 {
return self.filter(
self.nodes[node].lower.unwrap(),
centroids,
&new_candidates,
newk,
sums,
counts,
membership,
) + self.filter(
self.nodes[node].upper.unwrap(),
centroids,
&new_candidates,
newk,
sums,
counts,
membership,
);
}
}
for i in 0..d {
sums[closest][i] += self.nodes[node].sum[i];
}
counts[closest] += self.nodes[node].count;
let last = self.nodes[node].index + self.nodes[node].count;
for i in self.nodes[node].index..last {
membership[self.index[i]] = closest;
}
BBDTree::node_cost(&self.nodes[node], ¢roids[closest])
}
/// Decides whether centroid `test_index` can be discarded for a bounding
/// box because `best_index` is at least as close to every point in it.
///
/// * `center`, `radius` - midpoint and half-width of the box per dimension.
/// * `centroids` - all cluster centers; the dimensionality is taken from
///   `centroids[0]`.
/// * `best_index` - centroid currently closest to the box midpoint.
/// * `test_index` - centroid being challenged.
///
/// Returns `true` when the test centroid can be pruned. A centroid is never
/// pruned against itself.
fn prune(
    center: &[f64],
    radius: &[f64],
    centroids: &[Vec<f64>],
    best_index: usize,
    test_index: usize,
) -> bool {
    if best_index == test_index {
        return false;
    }

    let d = centroids[0].len();
    let best = &centroids[best_index];
    let test = &centroids[test_index];

    let mut lhs = 0f64;
    let mut rhs = 0f64;
    for i in 0..d {
        let diff = test[i] - best[i];
        lhs += diff * diff;

        // Use the box face that lies furthest in the direction of `test`:
        // if even that extreme corner prefers `best`, the whole box does.
        let vertex = if diff > 0f64 {
            center[i] + radius[i]
        } else {
            center[i] - radius[i]
        };
        rhs += (vertex - best[i]) * diff;
    }

    lhs >= 2f64 * rhs
}
/// Recursively builds the subtree covering `self.index[begin..end]` and
/// returns the position of its root in `self.nodes`.
///
/// The node's bounding box is computed from the covered rows. If the box
/// is narrower than `1E-10` in every dimension the node becomes a leaf
/// whose points are treated as identical; otherwise the index range is
/// partitioned in place around the midpoint of the widest dimension and
/// both halves are built recursively before this node is stored.
fn build_node<T: Number, M: Array2<T>>(&mut self, data: &M, begin: usize, end: usize) -> usize {
    let (_, d) = data.shape();
    // Reads one matrix entry as `f64`.
    let value_at = |row: usize, col: usize| -> f64 { data.get((row, col)).to_f64().unwrap() };

    let mut node = BBDTreeNode::new(d);
    node.index = begin;
    node.count = end - begin;

    // Bounding box: seed with the first covered row, then widen over the range.
    let mut lower_bound: Vec<f64> = (0..d).map(|j| value_at(self.index[begin], j)).collect();
    let mut upper_bound = lower_bound.clone();
    for &row in &self.index[begin..end] {
        for j in 0..d {
            let c = value_at(row, j);
            if c < lower_bound[j] {
                lower_bound[j] = c;
            }
            if c > upper_bound[j] {
                upper_bound[j] = c;
            }
        }
    }

    // Box midpoint and half-width; remember the widest dimension for splitting.
    let mut max_radius = -1f64;
    let mut split_index = 0;
    for (j, (lo, hi)) in lower_bound.iter().zip(upper_bound.iter()).enumerate() {
        let half_width = (hi - lo) / 2f64;
        node.center[j] = (lo + hi) / 2f64;
        node.radius[j] = half_width;
        if half_width > max_radius {
            max_radius = half_width;
            split_index = j;
        }
    }

    // Degenerate box: all covered points coincide, so this is a leaf whose
    // sum is the first point scaled by the number of points.
    if max_radius < 1E-10 {
        let multiplicity = if end > begin + 1 {
            Some((end - begin) as f64)
        } else {
            None
        };
        for j in 0..d {
            let coordinate = value_at(self.index[begin], j);
            node.sum[j] = match multiplicity {
                Some(m) => coordinate * m,
                None => coordinate,
            };
        }
        node.cost = 0f64;
        return self.add_node(node);
    }

    // Two-pointer partition of `index[begin..end]`: rows below the cutoff
    // move to the front, rows at or above it to the back.
    // NOTE(review): if no row is strictly below the cutoff, `right -= 1`
    // steps below `begin` and underflows when `begin == 0` — confirm this
    // cannot happen for the inputs used.
    let split_cutoff = node.center[split_index];
    let mut left = begin;
    let mut right = end - 1;
    let mut lower_count = 0;
    while left <= right {
        let mut left_ok = value_at(self.index[left], split_index) < split_cutoff;
        let mut right_ok = value_at(self.index[right], split_index) >= split_cutoff;

        // Both rows sit on the wrong side: exchanging them fixes both.
        if !(left_ok || right_ok) {
            self.index.swap(left, right);
            left_ok = true;
            right_ok = true;
        }

        if left_ok {
            left += 1;
            lower_count += 1;
        }
        if right_ok {
            right -= 1;
        }
    }

    // Children are stored before the parent, so their positions are known here.
    let lower = self.build_node(data, begin, begin + lower_count);
    let upper = self.build_node(data, begin + lower_count, end);
    node.lower = Some(lower);
    node.upper = Some(upper);

    let lower_node = &self.nodes[lower];
    let upper_node = &self.nodes[upper];
    for (j, total) in node.sum.iter_mut().enumerate() {
        *total = lower_node.sum[j] + upper_node.sum[j];
    }

    // The node's cost is the children's scatter measured about the combined mean.
    let mean: Vec<f64> = node.sum.iter().map(|&s| s / node.count as f64).collect();
    node.cost = BBDTree::node_cost(lower_node, &mean) + BBDTree::node_cost(upper_node, &mean);

    self.add_node(node)
}
/// Cost of assigning every point of `node` to `center`.
///
/// Computed as the node's stored `cost` plus `count` times the squared
/// distance between the node's mean (`sum / count`) and `center`.
fn node_cost(node: &BBDTreeNode, center: &[f64]) -> f64 {
    let count = node.count as f64;
    let scatter = center
        .iter()
        .enumerate()
        .fold(0f64, |acc, (i, &center_i)| {
            let offset = (node.sum[i] / count) - center_i;
            acc + offset * offset
        });
    node.cost + count * scatter
}
/// Appends `new_node` to the node storage and returns the position it was
/// stored at, which is how other nodes refer to it.
fn add_node(&mut self, new_node: BBDTreeNode) -> usize {
    self.nodes.push(new_node);
    self.nodes.len() - 1
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::basic::matrix::DenseMatrix;

    /// Runs one assignment pass over 20 four-dimensional samples with two
    /// fixed centroids and checks the returned cost, two entries of the
    /// accumulated sums, and that row 17 is assigned to cluster 1.
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn bbdtree_iris() {
        let data = DenseMatrix::from_2d_array(&[
            &[5.1, 3.5, 1.4, 0.2],
            &[4.9, 3.0, 1.4, 0.2],
            &[4.7, 3.2, 1.3, 0.2],
            &[4.6, 3.1, 1.5, 0.2],
            &[5.0, 3.6, 1.4, 0.2],
            &[5.4, 3.9, 1.7, 0.4],
            &[4.6, 3.4, 1.4, 0.3],
            &[5.0, 3.4, 1.5, 0.2],
            &[4.4, 2.9, 1.4, 0.2],
            &[4.9, 3.1, 1.5, 0.1],
            &[7.0, 3.2, 4.7, 1.4],
            &[6.4, 3.2, 4.5, 1.5],
            &[6.9, 3.1, 4.9, 1.5],
            &[5.5, 2.3, 4.0, 1.3],
            &[6.5, 2.8, 4.6, 1.5],
            &[5.7, 2.8, 4.5, 1.3],
            &[6.3, 3.3, 4.7, 1.6],
            &[4.9, 2.4, 3.3, 1.0],
            &[6.6, 2.9, 4.6, 1.3],
            &[5.2, 2.7, 3.9, 1.4],
        ]);

        let tree = BBDTree::new(&data);

        let centroids = vec![vec![4.86, 3.22, 1.61, 0.29], vec![6.23, 2.92, 4.48, 1.42]];

        // The accumulators are deliberately pre-filled; `clustering` must overwrite them.
        let mut sums = vec![vec![0f64; 4], vec![0f64; 4]];
        let mut counts = vec![11, 9];
        let mut membership = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1];

        let dist = tree.clustering(&centroids, &mut sums, &mut counts, &mut membership);

        assert!((dist - 10.68).abs() < 1e-2);
        assert!((sums[0][0] - 48.6).abs() < 1e-2);
        assert!((sums[1][3] - 13.8).abs() < 1e-2);
        // Row 17 starts in cluster 0 and must be moved to cluster 1.
        assert_eq!(membership[17], 1);
    }
}