use crate::ModelError;
/// Row-indexed access to a batch of observations and the weight attached to each row.
///
/// Rows are addressed by `0..self.len()`. The `'row` parameter is the lifetime of the
/// borrow taken by [`ObservationView::observation_at`]. The implementations in this file
/// all return owned copies; presumably the parameter exists so other implementors can hand
/// out rows that borrow from the view (TODO confirm).
pub trait ObservationView<'row> {
    /// The value produced for a single row.
    type Observation;

    /// Number of rows in the view.
    fn len(&self) -> usize;

    /// Returns `true` when the view has no rows.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Returns the observation stored at `row`.
    fn observation_at(&'row self, row: usize) -> Self::Observation;

    /// Returns the weight attached to `row`.
    fn weight_at(&self, row: usize) -> f64;

    /// Checks every row's weight from the first row to the last, stopping at the first
    /// failure.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `validate_observation_weight`, which is
    /// called with the row index and that row's weight.
    fn validate(&self) -> Result<(), ModelError> {
        (0..self.len()).try_for_each(|row| validate_observation_weight(row, self.weight_at(row)))
    }
}
/// A bare slice of scalars: each element is one observation, and every row carries an
/// implicit weight of `1.0`.
impl<'row> ObservationView<'row> for &[f64] {
    type Observation = f64;

    fn len(&self) -> usize {
        (**self).len()
    }

    /// Returns the scalar at `row`; panics if `row` is out of bounds (slice indexing).
    fn observation_at(&'row self, row: usize) -> Self::Observation {
        let values: &[f64] = self;
        values[row]
    }

    fn weight_at(&self, _: usize) -> f64 {
        1.0
    }

    /// Always succeeds. Every weight is the constant `1.0`, so there is nothing to check.
    fn validate(&self) -> Result<(), ModelError> {
        Ok(())
    }
}
/// Scalar observations (`.0`) paired with explicit per-row weights (`.1`).
///
/// The row count comes from the observation slice. `weight_at` indexes the weight slice
/// directly and panics if `row` is past its end, so call `validate` first when the two
/// slices might differ in length.
impl<'row> ObservationView<'row> for (&[f64], &[f64]) {
    type Observation = f64;

    fn len(&self) -> usize {
        let (observations, _) = *self;
        observations.len()
    }

    fn observation_at(&'row self, row: usize) -> Self::Observation {
        let (observations, _) = *self;
        observations[row]
    }

    fn weight_at(&self, row: usize) -> f64 {
        let (_, weights) = *self;
        weights[row]
    }

    /// Checks that there is exactly one weight per observation and that every weight
    /// passes `validate_observation_weight`.
    ///
    /// # Errors
    ///
    /// - `ModelError::WeightLength` when the slice lengths differ. `expected` is the
    ///   observation count and `actual` is the weight count.
    /// - Otherwise, the first error from `validate_observation_weight`, in index order.
    fn validate(&self) -> Result<(), ModelError> {
        let (observations, weights) = *self;
        if weights.len() != observations.len() {
            return Err(ModelError::WeightLength {
                expected: observations.len(),
                actual: weights.len(),
            });
        }
        weights
            .iter()
            .enumerate()
            .try_for_each(|(index, &weight)| validate_observation_weight(index, weight))
    }
}
/// A slice of fixed-width rows: each `[f64; N]` is one observation, and every row carries
/// an implicit weight of `1.0`.
impl<'row, const N: usize> ObservationView<'row> for &[[f64; N]] {
    type Observation = [f64; N];

    fn len(&self) -> usize {
        (**self).len()
    }

    /// Returns a copy of the row at `row`; panics if `row` is out of bounds (slice indexing).
    fn observation_at(&'row self, row: usize) -> Self::Observation {
        let rows: &[[f64; N]] = self;
        rows[row]
    }

    fn weight_at(&self, _: usize) -> f64 {
        1.0
    }

    /// Always succeeds. Every weight is the constant `1.0`, so there is nothing to check.
    fn validate(&self) -> Result<(), ModelError> {
        Ok(())
    }
}
/// Fixed-width row observations (`.0`) paired with explicit per-row weights (`.1`).
///
/// The row count comes from the observation slice. `weight_at` indexes the weight slice
/// directly and panics if `row` is past its end, so call `validate` first when the two
/// slices might differ in length.
impl<'row, const N: usize> ObservationView<'row> for (&[[f64; N]], &[f64]) {
    type Observation = [f64; N];

    fn len(&self) -> usize {
        let (rows, _) = *self;
        rows.len()
    }

    /// Returns a copy of the row at `row`.
    fn observation_at(&'row self, row: usize) -> Self::Observation {
        let (rows, _) = *self;
        rows[row]
    }

    fn weight_at(&self, row: usize) -> f64 {
        let (_, weights) = *self;
        weights[row]
    }

    /// Checks that there is exactly one weight per row and that every weight passes
    /// `validate_observation_weight`.
    ///
    /// # Errors
    ///
    /// - `ModelError::WeightLength` when the slice lengths differ. `expected` is the row
    ///   count and `actual` is the weight count.
    /// - Otherwise, the first error from `validate_observation_weight`, in index order.
    fn validate(&self) -> Result<(), ModelError> {
        let (rows, weights) = *self;
        if weights.len() != rows.len() {
            return Err(ModelError::WeightLength {
                expected: rows.len(),
                actual: weights.len(),
            });
        }
        weights
            .iter()
            .enumerate()
            .try_for_each(|(index, &weight)| validate_observation_weight(index, weight))
    }
}
/// Accepts a single observation weight if it is finite and not negative.
///
/// `-0.0` is accepted because it compares equal to `0.0`. NaN and ±infinity are rejected.
///
/// # Errors
///
/// Returns `ModelError::InvalidWeight` carrying `index` when the weight is rejected.
fn validate_observation_weight(index: usize, weight: f64) -> Result<(), ModelError> {
    // The finiteness test short-circuits first, so the sign comparison only ever sees
    // ordinary numbers (NaN and the infinities never reach it).
    if !weight.is_finite() || weight < 0.0 {
        return Err(ModelError::InvalidWeight { index });
    }
    Ok(())
}