use super::errors::BinNotFound;
use super::grid::Grid;
use ndarray::prelude::*;
/// An n-dimensional histogram: a table of per-bin observation counts paired
/// with the [`Grid`] that defines those bins.
pub struct Histogram<A>
where
    A: Ord,
{
    /// Number of observations recorded in each bin, stored as a
    /// dynamic-dimensional array of `usize`.
    counts: ArrayD<usize>,
    /// The grid used to map an observation to the index of its bin.
    grid: Grid<A>,
}
impl<A> Histogram<A>
where
    A: Ord,
{
    /// Creates an empty histogram over `grid`.
    ///
    /// The counts array is allocated with the shape reported by
    /// `grid.shape()` and every bin starts at zero.
    pub fn new(grid: Grid<A>) -> Self {
        Self {
            // Evaluated before `grid` is moved into the struct below.
            counts: ArrayD::zeros(grid.shape()),
            grid,
        }
    }

    /// Records a single observation.
    ///
    /// The observation is looked up with `Grid::index_of`; on success the
    /// count of the matching bin is incremented by one.
    ///
    /// # Errors
    ///
    /// Returns [`BinNotFound`] when the grid reports no bin for
    /// `observation`; the counts are left untouched in that case.
    ///
    /// NOTE(review): `Grid::index_of` presumably panics when the observation's
    /// length differs from the grid's dimensionality — confirm against `Grid`.
    pub fn add_observation(&mut self, observation: &ArrayRef<A, Ix1>) -> Result<(), BinNotFound> {
        let bin = self.grid.index_of(observation).ok_or(BinNotFound)?;
        self.counts[&*bin] += 1;
        Ok(())
    }

    /// Returns the number of dimensions of the histogram.
    ///
    /// In debug builds this also checks that the counts array and the grid
    /// agree on the dimensionality.
    pub fn ndim(&self) -> usize {
        let counts_ndim = self.counts.ndim();
        debug_assert_eq!(counts_ndim, self.grid.ndim());
        counts_ndim
    }

    /// Returns a read-only view of the per-bin counts.
    pub fn counts(&self) -> ArrayViewD<'_, usize> {
        self.counts.view()
    }

    /// Returns a shared reference to the grid backing this histogram.
    pub fn grid(&self) -> &Grid<A> {
        &self.grid
    }
}
/// Extension trait that adds a histogram-building method to array types.
///
/// NOTE(review): `private_decl!` is defined outside this file; presumably it
/// seals the trait so it cannot be implemented downstream — confirm against
/// the macro definition.
pub trait HistogramExt<A> {
/// Builds a [`Histogram`] over `grid` from the data held in `self`.
///
/// Takes ownership of `grid` and returns the populated histogram.
/// How the data is split into observations is decided by the implementor.
fn histogram(&self, grid: Grid<A>) -> Histogram<A>
where
A: Ord;
private_decl! {}
}
impl<A: Ord> HistogramExt<A> for ArrayRef<A, Ix2> {
    /// Builds a histogram by treating each subview along `Axis(0)` (i.e. each
    /// row of the 2-d array) as one observation.
    ///
    /// Observations for which `add_observation` reports an error (no matching
    /// bin in `grid`) are skipped, so they do not contribute to any count.
    ///
    /// NOTE(review): the number of columns is presumably expected to equal
    /// `grid.ndim()` — confirm against `Grid::index_of`.
    fn histogram(&self, grid: Grid<A>) -> Histogram<A> {
        self.axis_iter(Axis(0))
            .fold(Histogram::new(grid), |mut acc, row| {
                // Best-effort: a row that falls outside the grid is dropped.
                let _ = acc.add_observation(&row);
                acc
            })
    }

    private_impl! {}
}