1use std::{fmt, marker::PhantomData, sync::Arc};
2
3use gamlss_core::{ModelError, ObservationView};
4
5use crate::FormulaError;
6
/// Zero-sized marker type used as the element tag of categorical columns,
/// as in `Col<Category>` (see [`DataView::cat_col`]).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Category;
10
/// Typed handle that names a column whose elements have type `T`.
///
/// Only the column name is stored; `T` is a compile-time tag carried by
/// [`PhantomData`]. Cloning, comparison, ordering and hashing are implemented
/// by hand instead of derived: `#[derive]` would add a `T: Trait` bound to
/// every impl, which would leave `Col<f64>` without `Eq`, `Ord` and `Hash`
/// (`f64` implements none of them) even though no `T` value is ever held.
/// Every impl looks at the name only, which is exactly what the derived
/// versions computed.
pub struct Col<T> {
    name: Arc<str>,
    marker: PhantomData<T>,
}

impl<T> Clone for Col<T> {
    /// Clones the handle by sharing the name allocation (reference-count bump).
    #[inline]
    fn clone(&self) -> Self {
        Self {
            name: Arc::clone(&self.name),
            marker: PhantomData,
        }
    }
}

impl<T> PartialEq for Col<T> {
    /// Two handles are equal when their names are equal.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name
    }
}

impl<T> Eq for Col<T> {}

impl<T> PartialOrd for Col<T> {
    /// Delegates to [`Ord::cmp`] so both orderings always agree.
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl<T> Ord for Col<T> {
    /// Orders handles by name.
    #[inline]
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.name.cmp(&other.name)
    }
}

impl<T> std::hash::Hash for Col<T> {
    /// Hashes the name only, consistent with the `Eq` impl.
    #[inline]
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Fully qualified call so the trait does not need to be imported.
        std::hash::Hash::hash(&self.name, state);
    }
}
17
18impl<T> Col<T> {
19 #[must_use]
21 #[inline(always)]
22 pub fn name(&self) -> &str {
23 &self.name
24 }
25}
26
27impl<T> fmt::Debug for Col<T> {
28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29 f.debug_tuple("Col").field(&self.name).finish()
30 }
31}
32
33#[must_use]
35#[inline]
36pub fn col<T>(name: impl Into<Arc<str>>) -> Col<T> {
37 Col {
38 name: name.into(),
39 marker: PhantomData,
40 }
41}
42
43#[derive(Debug, Clone, PartialEq)]
45pub enum NumericCol<'a> {
46 Borrowed(&'a [f64]),
48 Owned(Vec<f64>),
50}
51
52impl<'a> NumericCol<'a> {
53 #[must_use]
55 #[inline(always)]
56 pub fn as_slice(&self) -> &[f64] {
57 match self {
58 Self::Borrowed(values) => values,
59 Self::Owned(values) => values,
60 }
61 }
62
63 pub(crate) fn into_response(self) -> NumericResponse<'a> {
64 match self {
65 Self::Borrowed(values) => NumericResponse::Borrowed(values),
66 Self::Owned(values) => NumericResponse::Owned(values),
67 }
68 }
69}
70
/// A column of `bool` values that is either borrowed or owned.
#[derive(Debug, Clone, PartialEq)]
pub enum BoolCol<'a> {
    /// Values borrowed for the lifetime `'a`.
    Borrowed(&'a [bool]),
    /// Values held in an owned buffer.
    Owned(Vec<bool>),
}

impl BoolCol<'_> {
    /// Returns the values as a slice, whichever variant holds them.
    #[must_use]
    #[inline(always)]
    pub fn as_slice(&self) -> &[bool] {
        match *self {
            Self::Borrowed(slice) => slice,
            Self::Owned(ref buffer) => buffer.as_slice(),
        }
    }
}
91
/// A column of `String` values that is either borrowed or owned; this is the
/// column type returned by [`DataView::cat_col`].
#[derive(Debug, Clone, PartialEq)]
pub enum CatCol<'a> {
    /// Values borrowed for the lifetime `'a`.
    Borrowed(&'a [String]),
    /// Values held in an owned buffer.
    Owned(Vec<String>),
}

impl CatCol<'_> {
    /// Returns the values as a slice, whichever variant holds them.
    #[must_use]
    #[inline(always)]
    pub fn as_slice(&self) -> &[String] {
        match *self {
            Self::Borrowed(slice) => slice,
            Self::Owned(ref buffer) => buffer.as_slice(),
        }
    }
}
112
113pub trait DataView {
115 fn nrows(&self) -> usize;
117
118 fn f64_col(&self, col: &Col<f64>) -> Result<NumericCol<'_>, FormulaError>;
123
124 fn bool_col(&self, col: &Col<bool>) -> Result<BoolCol<'_>, FormulaError> {
126 Err(FormulaError::UnsupportedColumnType {
127 name: col.name().to_owned(),
128 requested: "bool",
129 })
130 }
131
132 fn cat_col(&self, col: &Col<Category>) -> Result<CatCol<'_>, FormulaError> {
134 Err(FormulaError::UnsupportedColumnType {
135 name: col.name().to_owned(),
136 requested: "category",
137 })
138 }
139}
140
141#[derive(Debug, Clone, PartialEq)]
143pub enum NumericResponse<'a> {
144 Borrowed(&'a [f64]),
146 Owned(Vec<f64>),
148 Weighted {
150 values: NumericCol<'a>,
152 weights: NumericCol<'a>,
154 },
155}
156
157impl NumericResponse<'_> {
158 #[must_use]
160 #[inline(always)]
161 pub fn as_slice(&self) -> &[f64] {
162 match self {
163 Self::Borrowed(values) => values,
164 Self::Owned(values) => values,
165 Self::Weighted { values, .. } => values.as_slice(),
166 }
167 }
168
169 #[must_use]
171 #[inline(always)]
172 pub fn weights(&self) -> Option<&[f64]> {
173 match self {
174 Self::Borrowed(_) | Self::Owned(_) => None,
175 Self::Weighted { weights, .. } => Some(weights.as_slice()),
176 }
177 }
178}
179
180impl<'row> ObservationView<'row> for NumericResponse<'_> {
181 type Observation = f64;
182
183 #[inline(always)]
184 fn len(&self) -> usize {
185 self.as_slice().len()
186 }
187
188 #[inline(always)]
189 fn observation_at(&'row self, row: usize) -> Self::Observation {
190 self.as_slice()[row]
191 }
192
193 fn weight_at(&self, _row: usize) -> f64 {
194 self.weights().map_or(1.0, |weights| weights[_row])
195 }
196
197 fn validate(&self) -> Result<(), ModelError> {
198 if let Some(weights) = self.weights() {
199 let expected = self.as_slice().len();
200 let actual = weights.len();
201 if actual != expected {
202 return Err(ModelError::WeightLength { expected, actual });
203 }
204 for (index, weight) in weights.iter().copied().enumerate() {
205 if !weight.is_finite() || weight < 0.0 {
206 return Err(ModelError::InvalidWeight { index });
207 }
208 }
209 }
210 Ok(())
211 }
212}