use std::ops::Range;
use ndarray::{ArrayView, ArrayView1, ArrayView2, Dimension, Ix1, Ix2};
use crate::{
SegmentCostFunction,
cost::{Cost1D, Cost2D},
};
mod sealed {
use ndarray::{Ix1, Ix2};
pub trait Sealed {}
impl Sealed for Ix1 {}
impl Sealed for Ix2 {}
}
/// Dimensionality of a signal array: either a 1-D view ([`Ix1`]) or a 2-D view ([`Ix2`]).
///
/// The trait is sealed (it requires the private `sealed::Sealed` supertrait), so no
/// dimension types other than `Ix1` and `Ix2` can implement it.
pub trait OneOrTwoDimensions: Dimension + sealed::Sealed {
    /// Precomputed cost state returned by `precalculate` and later passed to `loss`.
    type PrecalculationOutput;

    /// Length of `array` along its first axis: the element count of a 1-D view, or the
    /// row count of a 2-D view.
    #[doc(hidden)]
    fn len_or_nrows(array: &ArrayView<'_, f64, Self>) -> usize;

    /// Builds the dimension-specific precomputed state for `signal` under `cost`.
    #[doc(hidden)]
    fn precalculate(
        cost: SegmentCostFunction,
        signal: &ArrayView<'_, f64, Self>,
    ) -> Self::PrecalculationOutput;

    /// Evaluates the half-open segment `range` of `signal` with the precomputed `cost`,
    /// writing the result through `total_loss`.
    ///
    /// NOTE(review): whether `total_loss` is overwritten or added to is decided by the
    /// `loss` method of the concrete cost type in the `cost` module — confirm there.
    #[doc(hidden)]
    fn loss(
        cost: &Self::PrecalculationOutput,
        total_loss: &mut f64,
        signal: &ArrayView<'_, f64, Self>,
        range: Range<usize>,
    );

    /// Returns a 1-D view borrowed from `array` when the implementation chooses to expose
    /// one, otherwise `None`.
    #[doc(hidden)]
    fn try_as_1d<'a>(array: &'a ArrayView<'_, f64, Self>) -> Option<ArrayView1<'a, f64>>;
}
impl OneOrTwoDimensions for Ix1 {
type PrecalculationOutput = Cost1D;
#[inline]
fn len_or_nrows(array: &ArrayView1<f64>) -> usize {
array.len()
}
#[inline]
fn precalculate(
cost: SegmentCostFunction,
signal: &ArrayView1<f64>,
) -> Self::PrecalculationOutput {
Self::PrecalculationOutput::precalculate(cost, signal)
}
#[inline]
fn loss(
cost: &Self::PrecalculationOutput,
total_loss: &mut f64,
signal: &ArrayView1<f64>,
range: Range<usize>,
) {
cost.loss(total_loss, signal, range)
}
#[inline]
fn try_as_1d<'a>(_array: &'a ArrayView1<f64>) -> Option<ArrayView1<'a, f64>> {
None
}
}
impl OneOrTwoDimensions for Ix2 {
type PrecalculationOutput = Cost2D;
#[inline]
fn len_or_nrows(array: &ArrayView2<f64>) -> usize {
array.nrows()
}
#[inline]
fn precalculate(
cost: SegmentCostFunction,
signal: &ArrayView2<f64>,
) -> Self::PrecalculationOutput {
Self::PrecalculationOutput::precalculate(cost, signal)
}
#[inline]
fn loss(
cost: &Self::PrecalculationOutput,
total_loss: &mut f64,
signal: &ArrayView2<f64>,
range: Range<usize>,
) {
cost.loss(total_loss, signal, range)
}
#[inline]
fn try_as_1d<'a>(array: &'a ArrayView<f64, Self>) -> Option<ArrayView1<'a, f64>> {
(array.ncols() == 1).then(|| array.column(0))
}
}