use alloc::vec::Vec;
use num_traits::Float;
use crate::common::vec::{convolve, ConvolutionMode};
use crate::common::vec::{shift, ShiftMode};
/// Rescales `pdf` in place so that its elements sum to one.
///
/// Every element is divided by the running left-to-right total of all elements.
/// The total is not checked: if it is zero, the elements become non-finite
/// (for `f32`/`f64`, NaN or ±infinity). An empty slice is left untouched.
pub fn normalize<F: Float>(pdf: &mut [F]) {
    let mut total = F::zero();
    for &value in pdf.iter() {
        total = total + value;
    }
    for value in pdf.iter_mut() {
        *value = *value / total;
    }
}
/// Combines a measurement `likelihood` with a `prior` distribution (discrete Bayes' rule).
///
/// The element-wise product `likelihood[i] * prior[i]` is rescaled with [`normalize`]
/// so that it sums to one, and returned as the posterior. If every product is zero,
/// the normalization yields non-finite values.
///
/// # Errors
///
/// Returns `Err(())` when `likelihood` and `prior` have different lengths.
pub fn update<F: Float>(likelihood: &[F], prior: &[F]) -> Result<Vec<F>, ()> {
    if likelihood.len() == prior.len() {
        let mut posterior = Vec::with_capacity(prior.len());
        for (l, p) in likelihood.iter().zip(prior) {
            posterior.push(*l * *p);
        }
        normalize(&mut posterior);
        Ok(posterior)
    } else {
        Err(())
    }
}
/// Selects how [`predict`] treats cells beyond either end of the distribution.
///
/// `Clone`/`Copy`/`PartialEq` are derived (when `F` supports them) so a mode can be
/// reused across repeated `predict` calls, which take it by value, and compared.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EdgeHandling<F> {
    /// Out-of-range cells take the given constant value; `predict` forwards it as
    /// `ShiftMode::Extend(c)` and `ConvolutionMode::Extended(c)`.
    Constant(F),
    /// The distribution is presumably treated as circular; `predict` forwards this as
    /// `ShiftMode::Wrap` and `ConvolutionMode::Wrap`.
    Wrap,
}
/// Predicts the next distribution by shifting `pdf` by `offset` cells and then
/// convolving the shifted distribution with `kernel`.
///
/// `mode` chooses the boundary behaviour used by *both* steps:
/// `EdgeHandling::Constant(c)` maps to `ShiftMode::Extend(c)` / `ConvolutionMode::Extended(c)`
/// (presumably padding with `c`), and `EdgeHandling::Wrap` maps to `ShiftMode::Wrap` /
/// `ConvolutionMode::Wrap` (presumably a circular boundary).
pub fn predict<F: Float>(pdf: &[F], offset: i64, kernel: &[F], mode: EdgeHandling<F>) -> Vec<F> {
    let (shift_mode, convolution_mode) = match mode {
        EdgeHandling::Constant(fill) => (ShiftMode::Extend(fill), ConvolutionMode::Extended(fill)),
        EdgeHandling::Wrap => (ShiftMode::Wrap, ConvolutionMode::Wrap),
    };
    let shifted = shift(pdf, offset, shift_mode);
    convolve(&shifted, kernel, convolution_mode)
}
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    /// Absolute tolerance for element-wise comparisons. The reference values are quoted
    /// to eight decimal places; this equals the default epsilon of `assert_approx_eq!`.
    const TOLERANCE: f64 = 1.0e-6;

    /// Returns `pdf` rescaled by [`normalize`] so that it sums to one.
    fn normalized(mut pdf: [f64; 7]) -> [f64; 7] {
        normalize(&mut pdf);
        pdf
    }

    /// Asserts that `result` has the same length as `reference` and agrees with it
    /// element-wise to within `TOLERANCE`. Failure messages name the offending index
    /// and print both slices.
    fn assert_all_close(result: &[f64], reference: &[f64]) {
        // `assert_eq!`, not `debug_assert_eq!`: the length check must not be compiled out
        // in `--release` test runs, where an over-long `result` would otherwise pass.
        assert_eq!(
            reference.len(),
            result.len(),
            "length mismatch: result = {:?}, reference = {:?}",
            result,
            reference
        );
        for (i, (&expected, &actual)) in reference.iter().zip(result.iter()).enumerate() {
            assert!(
                (expected - actual).abs() < TOLERANCE,
                "element {}: expected {}, got {}\n  result    = {:?}\n  reference = {:?}",
                i,
                expected,
                actual,
                result,
                reference
            );
        }
    }

    #[test]
    fn test_prediction_wrap_uniform_kernel_4() {
        let pdf = normalized([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
        let kernel = [0.5, 0.5, 0.5, 0.5];
        let result = predict(&pdf, -1, &kernel, EdgeHandling::Wrap);
        assert_all_close(&result, &[0.5, 0.0, 0.0, 0.0, 0.5, 0.5, 0.5]);
    }

    #[test]
    fn test_prediction_extend_kernel_4() {
        let pdf = normalized([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
        let kernel = [0.5, 0.5, 0.5, 0.5];
        let result = predict(&pdf, -1, &kernel, EdgeHandling::Constant(99.0));
        // Inputs are exact binary fractions, so the interior cells are exactly zero.
        assert_all_close(&result, &[49.5, 0.0, 0.0, 0.0, 49.5, 99.0, 148.5]);
    }

    #[test]
    fn test_prediction_wrap_kernel_4() {
        let pdf = normalized([0.0, 1.0, 2.0, 4.0, 8.0, 16.0, 8.0]);
        let kernel = [0.25, 0.5, 0.125, 0.125];
        let result = predict(&pdf, 3, &kernel, EdgeHandling::Wrap);
        assert_all_close(
            &result,
            &[
                0.29487179, 0.17948718, 0.08333333, 0.05128205, 0.05448718, 0.11217949,
                0.22435897,
            ],
        );
    }

    #[test]
    fn test_prediction_wrap_kernel_5() {
        let pdf = normalized([0.0, 1.0, 2.0, 4.0, 8.0, 16.0, 8.0]);
        let kernel = [0.25, 0.5, 0.125, 0.125, 10.0];
        let result = predict(&pdf, 3, &kernel, EdgeHandling::Wrap);
        assert_all_close(
            &result,
            &[
                0.80769231, 1.20512821, 2.13461538, 4.15384615, 2.10576923, 0.11217949,
                0.48076923,
            ],
        );
    }

    #[test]
    fn test_prediction_constant_kernel_4() {
        let pdf = normalized([0.0, 1.0, 2.0, 4.0, 8.0, 16.0, 8.0]);
        let kernel = [0.25, 0.5, 0.125, 0.125];
        let result = predict(&pdf, 3, &kernel, EdgeHandling::Constant(10.0));
        assert_all_close(
            &result,
            &[10.0, 7.5, 2.50641026, 1.27564103, 0.05448718, 2.56089744, 7.51923077],
        );
    }
}