ndarray_glm/model.rs
1//! Collect data for and configure a model
2
3use crate::{
4 data::Dataset,
5 error::{RegressionError, RegressionResult},
6 fit::{self, Fit},
7 glm::Glm,
8 num::Float,
9 response::Yval,
10};
11use fit::options::{FitConfig, FitOptions};
12use ndarray::{Array1, ArrayBase, ArrayView1, ArrayView2, Data, Ix1, Ix2, Order};
13use std::marker::PhantomData;
14
/// Holds the data and configuration settings for a regression.
pub struct Model<M, F>
where
    M: Glm,
    F: Float,
{
    /// Zero-sized marker fixing the GLM family `M`; no data is stored here.
    pub(crate) model: PhantomData<M>,
    /// The dataset: response values, design matrix, and any offsets/weights.
    pub data: Dataset<F>,
}
25
26impl<M, F> Model<M, F>
27where
28 M: Glm,
29 F: Float,
30{
31 /// Perform the regression and return a fit object holding the results.
32 pub fn fit(&self) -> RegressionResult<Fit<'_, M, F>, F> {
33 self.fit_options().fit()
34 }
35
36 /// Fit options builder interface
37 pub fn fit_options(&self) -> FitConfig<'_, M, F> {
38 FitConfig {
39 model: self,
40 options: FitOptions::default(),
41 }
42 }
43
44 /// An experimental interface that would allow fit options to be set externally.
45 pub fn with_options(&self, options: FitOptions<F>) -> FitConfig<'_, M, F> {
46 FitConfig {
47 model: self,
48 options,
49 }
50 }
51}
52
/// Provides an interface to create the full model option struct with convenient
/// type inference.
pub struct ModelBuilder<M: Glm> {
    // Zero-sized marker that pins the GLM family `M` so the builder's data
    // types can be inferred at the call site.
    _model: PhantomData<M>,
}
58
59impl<M: Glm> ModelBuilder<M> {
60 /// Borrow the Y and X data where each row in the arrays is a new
61 /// observation, and create the full model builder with the data to allow
62 /// for adjusting additional options.
63 pub fn data<'a, Y, F, YD, XD>(
64 data_y: &'a ArrayBase<YD, Ix1>,
65 data_x: &'a ArrayBase<XD, Ix2>,
66 ) -> ModelBuilderData<'a, M, Y, F>
67 where
68 Y: Yval<M>,
69 F: Float,
70 YD: Data<Elem = Y>,
71 XD: Data<Elem = F>,
72 {
73 ModelBuilderData {
74 model: PhantomData,
75 data_y: data_y.view(),
76 data_x: data_x.view(),
77 linear_offset: None,
78 var_weights: None,
79 freq_weights: None,
80 use_intercept_term: true,
81 standardize: true,
82 error: None,
83 }
84 }
85}
86
/// Holds the data and all the specifications for the model and provides
/// functions to adjust the settings.
pub struct ModelBuilderData<'a, M, Y, F>
where
    M: Glm,
    Y: Yval<M>,
    F: 'static + Float,
{
    // Zero-sized marker for the GLM family being built.
    model: PhantomData<M>,
    /// Observed response variable data where each entry is a new observation.
    data_y: ArrayView1<'a, Y>,
    /// Design matrix of observed covariate data where each row is a new
    /// observation and each column represents a different independent variable.
    data_x: ArrayView2<'a, F>,
    /// The offset in the linear predictor for each data point. This can be used
    /// to incorporate control terms.
    // TODO: consider making this a reference/ArrayView. Y and X are effectively
    // cloned so perhaps this isn't a big deal.
    linear_offset: Option<Array1<F>>,
    /// The variance/analytic weights for each observation.
    var_weights: Option<Array1<F>>,
    /// The frequency/count of each observation.
    freq_weights: Option<Array1<F>>,
    /// Whether to standardize the input data. Defaults to `true`.
    standardize: bool,
    /// Whether to use an intercept term. Defaults to `true`.
    use_intercept_term: bool,
    /// An error that has come up in the build compilation. Builder methods
    /// record misconfigurations here (later errors overwrite earlier ones)
    /// and `build()` fails with the stored error if one is set.
    error: Option<RegressionError<F>>,
}
117
118/// A builder to generate a Model object
119impl<'a, M, Y, F> ModelBuilderData<'a, M, Y, F>
120where
121 M: Glm,
122 Y: Yval<M> + Copy,
123 F: Float,
124{
125 /// Represents an offset added to the linear predictor for each data point.
126 /// This can be used to control for fixed effects or in multi-level models.
127 pub fn linear_offset(mut self, linear_offset: Array1<F>) -> Self {
128 if self.linear_offset.is_some() {
129 self.error = Some(RegressionError::BuildError(
130 "Offsets specified multiple times".to_string(),
131 ));
132 }
133 self.linear_offset = Some(linear_offset);
134 self
135 }
136
137 /// Frequency weights (a.k.a. counts) for each observation. Traditionally these are positive
138 /// integers representing the number of times each observation appears identically.
139 pub fn freq_weights(mut self, freqs: Array1<usize>) -> Self {
140 if self.freq_weights.is_some() {
141 self.error = Some(RegressionError::BuildError(
142 "Frequency weights specified multiple times".to_string(),
143 ));
144 }
145 let ffreqs: Array1<F> = freqs.mapv(|c| F::from(c).unwrap());
146 // TODO: consider adding a check for non-negative weights
147 self.freq_weights = Some(ffreqs);
148 self
149 }
150
151 /// Variance weights (a.k.a. analytic weights) of each observation. These could represent the
152 /// inverse square of the uncertainties of each measurement.
153 pub fn var_weights(mut self, weights: Array1<F>) -> Self {
154 if self.var_weights.is_some() {
155 self.error = Some(RegressionError::BuildError(
156 "Variance weights specified multiple times".to_string(),
157 ));
158 }
159 // TODO: consider adding a check for non-negative weights
160 self.var_weights = Some(weights);
161 self
162 }
163
164 /// Do not add a constant intercept term of `1`s to the design matrix. This is rarely
165 /// recommended, so you probably don't want to use this option unless you have a very clear
166 /// sense of why. Note that you can supply uniform or per-observation constant terms using
167 /// [`ModelBuilderData::linear_offset`].
168 pub fn no_constant(mut self) -> Self {
169 self.use_intercept_term = false;
170 self
171 }
172
173 /// Don't perform standarization (i.e. scale to 0-mean and 1-variance) of the design matrix.
174 /// Note that the standardization is handled internally, so the reported result coefficients
175 /// should be compatible with the input data directly, meaning the user shouldn't have to
176 /// interact with them.
177 pub fn no_standardize(mut self) -> Self {
178 self.standardize = false;
179 self
180 }
181
182 pub fn build(self) -> RegressionResult<Model<M, F>, F>
183 where
184 M: Glm,
185 F: Float,
186 {
187 if let Some(err) = self.error {
188 return Err(err);
189 }
190
191 let n_data = self.data_y.len();
192 if n_data != self.data_x.nrows() {
193 return Err(RegressionError::BadInput(
194 "y and x data must have same number of points".to_string(),
195 ));
196 }
197 // If they are provided, check that the offsets have the correct number of entries
198 if let Some(lin_off) = &self.linear_offset
199 && n_data != lin_off.len()
200 {
201 return Err(RegressionError::BadInput(
202 "Offsets must have same dimension as observations".to_string(),
203 ));
204 }
205
206 // Check if the data is under-constrained
207 if n_data < self.data_x.ncols() {
208 // The regression can find a solution if n_data == ncols, but there will be
209 // no estimate for the uncertainty. Regularization can solve this, so keep
210 // it to a warning.
211 eprintln!("Warning: data is underconstrained");
212 }
213
214 // Put the data in column-major order, since broadcasting and summing over the observations
215 // are the more common operations.
216 // The shape should be trivially valid, so just unwrap it.
217 let data_x = self.data_x.to_shape((self.data_x.dim(), Order::F)).unwrap();
218
219 // convert y-values to floating-point
220 let data_y: Array1<F> = self
221 .data_y
222 .iter()
223 .map(|&y| y.into_float())
224 .collect::<Result<_, _>>()?;
225
226 // Build the Dataset object
227 let mut data = Dataset {
228 y: data_y,
229 x: data_x.to_owned(),
230 linear_offset: self.linear_offset,
231 weights: self.var_weights,
232 freqs: self.freq_weights,
233 has_intercept: false,
234 standardizer: None,
235 };
236
237 data.finalize_design_matrix(self.standardize, self.use_intercept_term);
238
239 Ok(Model {
240 model: PhantomData,
241 data,
242 })
243 }
244}