Skip to main content

gamlss_formula/
data.rs

1use std::{fmt, marker::PhantomData, sync::Arc};
2
3use gamlss_core::{ModelError, ObservationView};
4
5use crate::FormulaError;
6
/// Zero-sized tag that identifies a column as categorical.
///
/// It carries no data of its own; it is used as the type parameter of
/// [`Col`] (for example `Col<Category>` in [`DataView::cat_col`]) so that
/// categorical columns are distinguished from `f64` and `bool` columns at
/// compile time.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Category;
10
/// Typed reference to a named input column.
///
/// `T` is a compile-time tag only: it is stored as [`PhantomData`] and never
/// instantiated. Two references are equal, ordered and hashed purely by their
/// column `name`.
///
/// `Clone`, `PartialEq`, `Eq`, `PartialOrd`, `Ord` and `Hash` are implemented
/// by hand below instead of being derived. `#[derive]` would add a
/// `T: Trait` bound to each impl, which would make `Col<f64>` neither `Eq`,
/// `Ord` nor `Hash` (because `f64` is none of those) even though no `f64` is
/// ever stored.
pub struct Col<T> {
    name: Arc<str>,
    marker: PhantomData<T>,
}

impl<T> Clone for Col<T> {
    /// Clones the reference; only the shared `Arc<str>` name is duplicated.
    #[inline]
    fn clone(&self) -> Self {
        Self {
            name: Arc::clone(&self.name),
            marker: PhantomData,
        }
    }
}

impl<T> PartialEq for Col<T> {
    /// Two column references are equal when their names are equal.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name
    }
}

impl<T> Eq for Col<T> {}

impl<T> PartialOrd for Col<T> {
    /// Delegates to [`Ord::cmp`]; the ordering is always total.
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl<T> Ord for Col<T> {
    /// Orders column references by name.
    #[inline]
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        Ord::cmp(&self.name, &other.name)
    }
}

impl<T> std::hash::Hash for Col<T> {
    /// Hashes the column name only, consistent with [`PartialEq`].
    #[inline]
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&self.name, state);
    }
}
17
18impl<T> Col<T> {
19    /// Returns the external column name.
20    #[must_use]
21    #[inline(always)]
22    pub fn name(&self) -> &str {
23        &self.name
24    }
25}
26
27impl<T> fmt::Debug for Col<T> {
28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29        f.debug_tuple("Col").field(&self.name).finish()
30    }
31}
32
33/// Creates a typed column reference.
34#[must_use]
35#[inline]
36pub fn col<T>(name: impl Into<Arc<str>>) -> Col<T> {
37    Col {
38        name: name.into(),
39        marker: PhantomData,
40    }
41}
42
43/// Numeric column storage returned by [`DataView`].
44#[derive(Debug, Clone, PartialEq)]
45pub enum NumericCol<'a> {
46    /// Borrowed contiguous `f64` storage.
47    Borrowed(&'a [f64]),
48    /// Owned contiguous `f64` storage.
49    Owned(Vec<f64>),
50}
51
52impl<'a> NumericCol<'a> {
53    /// Returns the column as a slice.
54    #[must_use]
55    #[inline(always)]
56    pub fn as_slice(&self) -> &[f64] {
57        match self {
58            Self::Borrowed(values) => values,
59            Self::Owned(values) => values,
60        }
61    }
62
63    pub(crate) fn into_response(self) -> NumericResponse<'a> {
64        match self {
65            Self::Borrowed(values) => NumericResponse::Borrowed(values),
66            Self::Owned(values) => NumericResponse::Owned(values),
67        }
68    }
69}
70
/// Boolean column storage returned by [`DataView`].
///
/// Holds either a borrowed `bool` slice or an owned vector;
/// [`BoolCol::as_slice`] reads both the same way.
#[derive(Debug, Clone, PartialEq)]
pub enum BoolCol<'a> {
    /// Slice borrowed for `'a` from contiguous `bool` storage.
    Borrowed(&'a [bool]),
    /// Contiguous `bool` storage owned by this value.
    Owned(Vec<bool>),
}

impl BoolCol<'_> {
    /// Views the column as a plain `bool` slice, whichever variant holds it.
    #[must_use]
    #[inline(always)]
    pub fn as_slice(&self) -> &[bool] {
        match *self {
            BoolCol::Borrowed(flags) => flags,
            BoolCol::Owned(ref flags) => flags.as_slice(),
        }
    }
}
91
/// Categorical column storage returned by [`DataView`].
///
/// Each row is represented by a `String`; the column is either borrowed from
/// the data source or owned by this value.
#[derive(Debug, Clone, PartialEq)]
pub enum CatCol<'a> {
    /// String values borrowed for `'a`.
    Borrowed(&'a [String]),
    /// String values owned by this value.
    Owned(Vec<String>),
}

impl CatCol<'_> {
    /// Views the column as a slice of `String`s, whichever variant holds it.
    #[must_use]
    #[inline(always)]
    pub fn as_slice(&self) -> &[String] {
        match *self {
            CatCol::Borrowed(labels) => labels,
            CatCol::Owned(ref labels) => labels.as_slice(),
        }
    }
}
112
113/// Read-only data access contract for the formula layer.
114pub trait DataView {
115    /// Number of rows visible to the model builder.
116    fn nrows(&self) -> usize;
117
118    /// Returns an `f64` column by typed column reference.
119    ///
120    /// Implementations should return [`FormulaError::UnknownColumn`] when the
121    /// name is not available. The formula layer validates row counts.
122    fn f64_col(&self, col: &Col<f64>) -> Result<NumericCol<'_>, FormulaError>;
123
124    /// Returns a `bool` column by typed column reference.
125    fn bool_col(&self, col: &Col<bool>) -> Result<BoolCol<'_>, FormulaError> {
126        Err(FormulaError::UnsupportedColumnType {
127            name: col.name().to_owned(),
128            requested: "bool",
129        })
130    }
131
132    /// Returns a categorical column by typed column reference.
133    fn cat_col(&self, col: &Col<Category>) -> Result<CatCol<'_>, FormulaError> {
134        Err(FormulaError::UnsupportedColumnType {
135            name: col.name().to_owned(),
136            requested: "category",
137        })
138    }
139}
140
141/// Response storage used by compiled formula models.
142#[derive(Debug, Clone, PartialEq)]
143pub enum NumericResponse<'a> {
144    /// Borrowed response storage.
145    Borrowed(&'a [f64]),
146    /// Owned response storage.
147    Owned(Vec<f64>),
148    /// Response with observation weights.
149    Weighted {
150        /// Response values.
151        values: NumericCol<'a>,
152        /// Observation weights.
153        weights: NumericCol<'a>,
154    },
155}
156
157impl NumericResponse<'_> {
158    /// Returns response values as a slice.
159    #[must_use]
160    #[inline(always)]
161    pub fn as_slice(&self) -> &[f64] {
162        match self {
163            Self::Borrowed(values) => values,
164            Self::Owned(values) => values,
165            Self::Weighted { values, .. } => values.as_slice(),
166        }
167    }
168
169    /// Returns observation weights when present.
170    #[must_use]
171    #[inline(always)]
172    pub fn weights(&self) -> Option<&[f64]> {
173        match self {
174            Self::Borrowed(_) | Self::Owned(_) => None,
175            Self::Weighted { weights, .. } => Some(weights.as_slice()),
176        }
177    }
178}
179
180impl<'row> ObservationView<'row> for NumericResponse<'_> {
181    type Observation = f64;
182
183    #[inline(always)]
184    fn len(&self) -> usize {
185        self.as_slice().len()
186    }
187
188    #[inline(always)]
189    fn observation_at(&'row self, row: usize) -> Self::Observation {
190        self.as_slice()[row]
191    }
192
193    fn weight_at(&self, _row: usize) -> f64 {
194        self.weights().map_or(1.0, |weights| weights[_row])
195    }
196
197    fn validate(&self) -> Result<(), ModelError> {
198        if let Some(weights) = self.weights() {
199            let expected = self.as_slice().len();
200            let actual = weights.len();
201            if actual != expected {
202                return Err(ModelError::WeightLength { expected, actual });
203            }
204            for (index, weight) in weights.iter().copied().enumerate() {
205                if !weight.is_finite() || weight < 0.0 {
206                    return Err(ModelError::InvalidWeight { index });
207                }
208            }
209        }
210        Ok(())
211    }
212}