use super::*;
use crate::covertree::node::CoverNode;
use crate::covertree::CoverTreeReader;
use crate::plugins::utils::*;
use ndarray::prelude::*;
use ndarray_linalg::svd::*;
/// Gaussian summary of a node's points derived from a singular value decomposition.
///
/// The decomposition is taken of the node's point matrix after each row has been
/// centered on `mean`.
#[derive(Clone, Debug, Default)]
pub struct SvdGaussian {
    /// Mean vector the points were centered on before decomposing.
    pub mean: Array1<f32>,
    /// The `Vᵀ` factor of the decomposition (right singular vectors as rows).
    pub vt: Array2<f32>,
    /// Singular values of the centered point matrix.
    pub singular_vals: Array1<f32>,
}

impl SvdGaussian {
    /// Returns an owned copy of the centering mean.
    pub fn mean(&self) -> Array1<f32> {
        self.mean.to_owned()
    }
}
/// Marker impl: `SvdGaussian` serves as the per-node component of `GokoSvdGaussian`.
impl<D> NodePlugin<D> for SvdGaussian where D: PointCloud {}
/// Tree-level parameters for the plugin that attaches an `SvdGaussian` to nodes.
#[derive(Clone, Debug)]
pub struct GokoSvdGaussian {
    /// Limit handed to the restricted coverage-index plugin this plugin depends on.
    max_points: usize,
    /// A node must cover strictly more than this many points to receive a component.
    min_points: usize,
    /// NOTE(review): stored but not read by any code in this section — confirm its use.
    tau: f32,
}

impl GokoSvdGaussian {
    /// Builds the plugin parameters.
    ///
    /// The arguments are taken as (`min_points`, `max_points`, `tau`), which is a
    /// different order from the struct's field declarations.
    pub fn new(min_points: usize, max_points: usize, tau: f32) -> GokoSvdGaussian {
        Self {
            min_points,
            max_points,
            tau,
        }
    }
}
/// Plugin that computes an `SvdGaussian` for each sufficiently large cover tree node.
///
/// It reads two companion plugins, both registered in `prepare_tree`:
/// - coverage indexes (restricted by `max_points`), which supply the node's point indexes;
/// - the recursive diagonal Gaussian, which supplies the mean used to center those points.
impl<D: PointCloud> GokoPlugin<D> for GokoSvdGaussian {
    type NodeComponent = SvdGaussian;

    /// Registers the plugins `node_component` reads from.
    fn prepare_tree(parameters: &Self, my_tree: &mut CoverTreeWriter<D>) {
        my_tree.add_plugin::<GokoCoverageIndexes>(GokoCoverageIndexes::restricted(
            parameters.max_points,
        ));
        my_tree.add_plugin::<GokoDiagGaussian>(GokoDiagGaussian::recursive());
    }

    /// Builds the `SvdGaussian` for `my_node`.
    ///
    /// Returns `None` in any of these cases, instead of panicking:
    /// - the node covers `min_points` points or fewer;
    /// - the node carries no `CoverageIndexes` or no `DiagGaussian` component;
    /// - the Gaussian's mean cannot be shaped into a vector of `dim()` entries, or its
    ///   length differs from the number of columns of the point matrix;
    /// - the SVD itself fails (e.g. LAPACK does not converge).
    ///
    /// # Panics
    /// Panics if the point cloud cannot produce the dense matrix for the node's own
    /// coverage indexes. Those indexes come from this tree, so that indicates a bug.
    fn node_component(
        parameters: &Self,
        my_node: &CoverNode<D>,
        my_tree: &CoverTreeReader<D>,
    ) -> Option<Self::NodeComponent> {
        if my_node.coverage_count() <= parameters.min_points {
            return None;
        }

        let mut points = my_node.get_plugin_and::<CoverageIndexes, _, _>(|p| {
            my_tree
                .parameters()
                .point_cloud
                .points_dense_matrix(p.point_indexes())
                .expect("coverage indexes must refer to points in this tree's point cloud")
        })?;

        // A missing Gaussian or a malformed mean is recoverable: skip this node.
        let mean = my_node
            .get_plugin_and::<DiagGaussian, _, _>(|p| {
                Array1::from_shape_vec((p.dim(),), p.mean()).ok()
            })
            .flatten()?;

        // Broadcasting below would panic on a dimension mismatch; bail out instead.
        if mean.len() != points.len_of(Axis(1)) {
            return None;
        }

        // Center every row (point) on the mean; `mean` broadcasts across rows.
        points -= &mean;

        // Only the right singular vectors are needed, so skip computing U.
        let (_u, singular_vals, vt) = points.svd(false, true).ok()?;
        Some(SvdGaussian {
            singular_vals,
            vt: vt?,
            mean,
        })
    }
}