Skip to main content

gamlss_core/model/
observation.rs

1use crate::ModelError;
2
3/// Read-only row-wise observation access for training objective evaluation.
4///
5/// This trait is intentionally small: it describes the row-wise data needed by
6/// the likelihood loop. Implementations should make
7/// [`len`](Self::len) O(1), keep it stable for the lifetime of the model, and
8/// provide deterministic, panic-free access for `row < len()`.
9///
10/// The trait is parameterized by the borrow lifetime so compiled objectives can
11/// remain generic over borrowed storage backends. This permits zero-copy row
12/// views such as `&'row [f64]`.
13pub trait ObservationView<'row> {
14    /// Observation representation returned for one row.
15    type Observation;
16
17    /// Number of observations.
18    fn len(&self) -> usize;
19
20    /// Returns `true` if there are no observations.
21    #[inline(always)]
22    fn is_empty(&self) -> bool {
23        self.len() == 0
24    }
25
26    /// Observation value for `row`.
27    fn observation_at(&'row self, row: usize) -> Self::Observation;
28
29    /// Non-negative finite observation weight for `row`.
30    fn weight_at(&self, row: usize) -> f64;
31
32    /// Validates observation-level invariants before hot-path evaluation.
33    #[inline]
34    fn validate(&self) -> Result<(), ModelError> {
35        for row in 0..self.len() {
36            validate_observation_weight(row, self.weight_at(row))?;
37        }
38        Ok(())
39    }
40}
41
42impl<'row> ObservationView<'row> for &[f64] {
43    type Observation = f64;
44
45    #[inline(always)]
46    fn len(&self) -> usize {
47        <[f64]>::len(self)
48    }
49
50    #[inline(always)]
51    fn observation_at(&'row self, row: usize) -> Self::Observation {
52        self[row]
53    }
54
55    #[inline(always)]
56    fn weight_at(&self, _row: usize) -> f64 {
57        1.0
58    }
59
60    #[inline(always)]
61    fn validate(&self) -> Result<(), ModelError> {
62        Ok(())
63    }
64}
65
66impl<'row> ObservationView<'row> for (&[f64], &[f64]) {
67    type Observation = f64;
68
69    #[inline(always)]
70    fn len(&self) -> usize {
71        self.0.len()
72    }
73
74    #[inline(always)]
75    fn observation_at(&'row self, row: usize) -> Self::Observation {
76        self.0[row]
77    }
78
79    #[inline(always)]
80    fn weight_at(&self, row: usize) -> f64 {
81        self.1[row]
82    }
83
84    fn validate(&self) -> Result<(), ModelError> {
85        let expected = self.0.len();
86        let actual = self.1.len();
87        if actual != expected {
88            return Err(ModelError::WeightLength { expected, actual });
89        }
90
91        for (index, weight) in self.1.iter().copied().enumerate() {
92            validate_observation_weight(index, weight)?;
93        }
94
95        Ok(())
96    }
97}
98
99impl<'row, const N: usize> ObservationView<'row> for &[[f64; N]] {
100    type Observation = [f64; N];
101
102    #[inline(always)]
103    fn len(&self) -> usize {
104        <[[f64; N]]>::len(self)
105    }
106
107    #[inline(always)]
108    fn observation_at(&'row self, row: usize) -> Self::Observation {
109        self[row]
110    }
111
112    #[inline(always)]
113    fn weight_at(&self, _row: usize) -> f64 {
114        1.0
115    }
116
117    #[inline(always)]
118    fn validate(&self) -> Result<(), ModelError> {
119        Ok(())
120    }
121}
122
123impl<'row, const N: usize> ObservationView<'row> for (&[[f64; N]], &[f64]) {
124    type Observation = [f64; N];
125
126    #[inline(always)]
127    fn len(&self) -> usize {
128        self.0.len()
129    }
130
131    #[inline(always)]
132    fn observation_at(&'row self, row: usize) -> Self::Observation {
133        self.0[row]
134    }
135
136    #[inline(always)]
137    fn weight_at(&self, row: usize) -> f64 {
138        self.1[row]
139    }
140
141    fn validate(&self) -> Result<(), ModelError> {
142        let expected = self.0.len();
143        let actual = self.1.len();
144        if actual != expected {
145            return Err(ModelError::WeightLength { expected, actual });
146        }
147
148        for (index, weight) in self.1.iter().copied().enumerate() {
149            validate_observation_weight(index, weight)?;
150        }
151
152        Ok(())
153    }
154}
155
156fn validate_observation_weight(index: usize, weight: f64) -> Result<(), ModelError> {
157    if weight.is_finite() && weight >= 0.0 {
158        Ok(())
159    } else {
160        Err(ModelError::InvalidWeight { index })
161    }
162}