use std::{fmt, marker::PhantomData, sync::Arc};
use gamlss_core::{ModelError, ObservationView};
use crate::FormulaError;
/// Zero-sized marker used as the type parameter of [`Col`] to request a
/// categorical (string-valued) column from a [`DataView`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Category;
/// A typed handle naming a column of a [`DataView`].
///
/// `T` is a marker for the element type the column is expected to hold
/// (`f64`, `bool`, or [`Category`]); no value of type `T` is ever stored.
///
/// The comparison, hashing and cloning impls below are written by hand rather
/// than derived. `#[derive]` would add `T: Trait` bounds on the phantom marker,
/// so e.g. `Col<f64>` would not be `Eq`/`Ord`/`Hash` (because `f64` is not).
/// Equality, ordering and hashing depend only on the column name, which
/// matches what the derives produced for the markers that did satisfy the
/// bounds.
pub struct Col<T> {
    name: Arc<str>,
    marker: PhantomData<T>,
}

impl<T> Clone for Col<T> {
    #[inline]
    fn clone(&self) -> Self {
        Self {
            // Cheap refcount bump; the name buffer is shared, not copied.
            name: Arc::clone(&self.name),
            marker: PhantomData,
        }
    }
}

impl<T> PartialEq for Col<T> {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name
    }
}

impl<T> Eq for Col<T> {}

impl<T> PartialOrd for Col<T> {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Canonical form: defer to the total order.
        Some(self.cmp(other))
    }
}

impl<T> Ord for Col<T> {
    #[inline]
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.name.cmp(&other.name)
    }
}

impl<T> std::hash::Hash for Col<T> {
    #[inline]
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Hash only the name so the impl stays consistent with `PartialEq`.
        std::hash::Hash::hash(&self.name, state);
    }
}
impl<T> Col<T> {
    /// Returns the name of the column this handle refers to.
    #[must_use]
    #[inline(always)]
    pub fn name(&self) -> &str {
        self.name.as_ref()
    }
}
impl<T> fmt::Debug for Col<T> {
    /// Formats as `Col("name")`; the marker type `T` does not appear and need
    /// not implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("Col");
        tuple.field(&self.name);
        tuple.finish()
    }
}
/// Creates a typed [`Col`] handle naming the column `name`.
///
/// The marker `T` determines which [`DataView`] accessor the handle can be
/// passed to (`f64`, `bool`, or [`Category`]).
#[must_use]
#[inline]
pub fn col<T>(name: impl Into<Arc<str>>) -> Col<T> {
    let shared: Arc<str> = name.into();
    Col {
        marker: PhantomData,
        name: shared,
    }
}
/// A column of `f64` values that either borrows from the data source or owns
/// its own buffer.
#[derive(Debug, Clone, PartialEq)]
pub enum NumericCol<'a> {
    /// Values borrowed for the lifetime `'a`.
    Borrowed(&'a [f64]),
    /// Values held in an owned, heap-allocated buffer.
    Owned(Vec<f64>),
}
impl<'a> NumericCol<'a> {
    /// Returns the column's values as a slice, whichever variant holds them.
    #[must_use]
    #[inline(always)]
    pub fn as_slice(&self) -> &[f64] {
        match self {
            Self::Owned(owned) => owned.as_slice(),
            Self::Borrowed(borrowed) => *borrowed,
        }
    }

    /// Converts this column into an unweighted [`NumericResponse`], keeping
    /// borrowed data borrowed and owned data owned.
    pub(crate) fn into_response(self) -> NumericResponse<'a> {
        match self {
            Self::Owned(owned) => NumericResponse::Owned(owned),
            Self::Borrowed(borrowed) => NumericResponse::Borrowed(borrowed),
        }
    }
}
/// A column of `bool` flags that either borrows from the data source or owns
/// its own buffer.
#[derive(Debug, Clone, PartialEq)]
pub enum BoolCol<'a> {
    /// Flags borrowed for the lifetime `'a`.
    Borrowed(&'a [bool]),
    /// Flags held in an owned, heap-allocated buffer.
    Owned(Vec<bool>),
}
impl BoolCol<'_> {
    /// Returns the column's flags as a slice, regardless of ownership.
    #[must_use]
    #[inline(always)]
    pub fn as_slice(&self) -> &[bool] {
        match self {
            Self::Owned(flags) => flags.as_slice(),
            Self::Borrowed(flags) => *flags,
        }
    }
}
/// A categorical column whose levels are stored as `String`s, either borrowed
/// from the data source or owned.
#[derive(Debug, Clone, PartialEq)]
pub enum CatCol<'a> {
    /// Labels borrowed for the lifetime `'a`.
    Borrowed(&'a [String]),
    /// Labels held in an owned buffer.
    Owned(Vec<String>),
}
impl CatCol<'_> {
    /// Returns the column's labels as a slice, regardless of ownership.
    #[must_use]
    #[inline(always)]
    pub fn as_slice(&self) -> &[String] {
        match self {
            Self::Owned(labels) => &labels[..],
            Self::Borrowed(labels) => *labels,
        }
    }
}
/// Read-only access to a tabular data source through typed [`Col`] handles.
///
/// Implementors must provide [`DataView::nrows`] and [`DataView::f64_col`].
/// Boolean and categorical access are optional. The default
/// [`DataView::bool_col`] and [`DataView::cat_col`] reject every request with
/// `FormulaError::UnsupportedColumnType`.
pub trait DataView {
    /// Row count of this view.
    ///
    /// NOTE(review): presumably every column returned by this view has this
    /// length; nothing in this trait enforces it.
    fn nrows(&self) -> usize;

    /// Fetches the numeric column named by `col`.
    ///
    /// # Errors
    ///
    /// Implementation-defined (e.g. an unknown column name); not specified by
    /// the trait itself.
    fn f64_col(&self, col: &Col<f64>) -> Result<NumericCol<'_>, FormulaError>;

    /// Fetches the boolean column named by `col`.
    ///
    /// # Errors
    ///
    /// The default implementation always returns
    /// `FormulaError::UnsupportedColumnType` with `requested: "bool"`.
    fn bool_col(&self, col: &Col<bool>) -> Result<BoolCol<'_>, FormulaError> {
        let name = String::from(col.name());
        Err(FormulaError::UnsupportedColumnType {
            requested: "bool",
            name,
        })
    }

    /// Fetches the categorical (string-valued) column named by `col`.
    ///
    /// # Errors
    ///
    /// The default implementation always returns
    /// `FormulaError::UnsupportedColumnType` with `requested: "category"`.
    fn cat_col(&self, col: &Col<Category>) -> Result<CatCol<'_>, FormulaError> {
        let name = String::from(col.name());
        Err(FormulaError::UnsupportedColumnType {
            requested: "category",
            name,
        })
    }
}
/// Numeric response data, optionally paired with per-row weights.
#[derive(Debug, Clone, PartialEq)]
pub enum NumericResponse<'a> {
    /// Unweighted values borrowed for the lifetime `'a`.
    Borrowed(&'a [f64]),
    /// Unweighted values held in an owned buffer.
    Owned(Vec<f64>),
    /// Values with one weight per row. The two lengths are not checked at
    /// construction; `validate` in the `ObservationView` impl checks them.
    Weighted {
        /// The response values.
        values: NumericCol<'a>,
        /// The per-row weights.
        weights: NumericCol<'a>,
    },
}
impl NumericResponse<'_> {
    /// Returns the response values as a slice, for every variant.
    #[must_use]
    #[inline(always)]
    pub fn as_slice(&self) -> &[f64] {
        match self {
            Self::Weighted { values, .. } => values.as_slice(),
            Self::Owned(owned) => owned.as_slice(),
            Self::Borrowed(borrowed) => *borrowed,
        }
    }

    /// Returns the per-row weights, or `None` for an unweighted response.
    #[must_use]
    #[inline(always)]
    pub fn weights(&self) -> Option<&[f64]> {
        match self {
            Self::Weighted { weights, .. } => Some(weights.as_slice()),
            Self::Owned(_) | Self::Borrowed(_) => None,
        }
    }
}
/// Exposes a [`NumericResponse`] through `gamlss_core`'s observation interface.
///
/// Row `i`'s observation is `self.as_slice()[i]`. Unweighted responses report
/// a weight of `1.0` for every row.
impl<'row> ObservationView<'row> for NumericResponse<'_> {
    type Observation = f64;

    /// Number of response values (rows).
    #[inline(always)]
    fn len(&self) -> usize {
        self.as_slice().len()
    }

    /// Returns the response value at `row`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= self.len()`.
    #[inline(always)]
    fn observation_at(&'row self, row: usize) -> Self::Observation {
        self.as_slice()[row]
    }

    /// Returns the weight of `row`: the stored weight for a weighted response,
    /// `1.0` otherwise.
    ///
    /// # Panics
    ///
    /// Panics if the response is weighted and `row` is out of bounds for the
    /// weights. [`ObservationView::validate`] rejects weight/value length
    /// mismatches up front.
    fn weight_at(&self, row: usize) -> f64 {
        self.weights().map_or(1.0, |weights| weights[row])
    }

    /// Checks the weights of a weighted response; unweighted responses always pass.
    ///
    /// The observation values themselves are not checked here.
    ///
    /// # Errors
    ///
    /// - `ModelError::WeightLength` if the number of weights differs from the
    ///   number of values.
    /// - `ModelError::InvalidWeight` with the index of the first weight that is
    ///   non-finite (NaN or infinite) or negative.
    fn validate(&self) -> Result<(), ModelError> {
        if let Some(weights) = self.weights() {
            let expected = self.as_slice().len();
            let actual = weights.len();
            if actual != expected {
                return Err(ModelError::WeightLength { expected, actual });
            }
            // `!is_finite()` also catches NaN, for which `< 0.0` would be false.
            let first_invalid = weights
                .iter()
                .position(|&weight| !weight.is_finite() || weight < 0.0);
            if let Some(index) = first_invalid {
                return Err(ModelError::InvalidWeight { index });
            }
        }
        Ok(())
    }
}