Skip to main content

ndarray_glm/
model.rs

1//! Collect data for and configure a model
2
3use crate::{
4    data::Dataset,
5    error::{RegressionError, RegressionResult},
6    fit::{self, Fit},
7    glm::Glm,
8    num::Float,
9    response::Yval,
10};
11use fit::options::{FitConfig, FitOptions};
12use ndarray::{Array1, ArrayBase, ArrayView1, ArrayView2, Data, Ix1, Ix2, Order};
13use std::marker::PhantomData;
14
15/// Holds the data and configuration settings for a regression.
16pub struct Model<M, F>
17where
18    M: Glm,
19    F: Float,
20{
21    pub(crate) model: PhantomData<M>,
22    /// The dataset
23    pub data: Dataset<F>,
24}
25
26impl<M, F> Model<M, F>
27where
28    M: Glm,
29    F: Float,
30{
31    /// Perform the regression and return a fit object holding the results.
32    pub fn fit(&self) -> RegressionResult<Fit<'_, M, F>, F> {
33        self.fit_options().fit()
34    }
35
36    /// Fit options builder interface
37    pub fn fit_options(&self) -> FitConfig<'_, M, F> {
38        FitConfig {
39            model: self,
40            options: FitOptions::default(),
41        }
42    }
43
44    /// An experimental interface that would allow fit options to be set externally.
45    pub fn with_options(&self, options: FitOptions<F>) -> FitConfig<'_, M, F> {
46        FitConfig {
47            model: self,
48            options,
49        }
50    }
51}
52
53/// Provides an interface to create the full model option struct with convenient
54/// type inference.
55pub struct ModelBuilder<M: Glm> {
56    _model: PhantomData<M>,
57}
58
59impl<M: Glm> ModelBuilder<M> {
60    /// Borrow the Y and X data where each row in the arrays is a new
61    /// observation, and create the full model builder with the data to allow
62    /// for adjusting additional options.
63    pub fn data<'a, Y, F, YD, XD>(
64        data_y: &'a ArrayBase<YD, Ix1>,
65        data_x: &'a ArrayBase<XD, Ix2>,
66    ) -> ModelBuilderData<'a, M, Y, F>
67    where
68        Y: Yval<M>,
69        F: Float,
70        YD: Data<Elem = Y>,
71        XD: Data<Elem = F>,
72    {
73        ModelBuilderData {
74            model: PhantomData,
75            data_y: data_y.view(),
76            data_x: data_x.view(),
77            linear_offset: None,
78            var_weights: None,
79            freq_weights: None,
80            use_intercept_term: true,
81            standardize: true,
82            error: None,
83        }
84    }
85}
86
87/// Holds the data and all the specifications for the model and provides
88/// functions to adjust the settings.
89pub struct ModelBuilderData<'a, M, Y, F>
90where
91    M: Glm,
92    Y: Yval<M>,
93    F: 'static + Float,
94{
95    model: PhantomData<M>,
96    /// Observed response variable data where each entry is a new observation.
97    data_y: ArrayView1<'a, Y>,
98    /// Design matrix of observed covariate data where each row is a new
99    /// observation and each column represents a different independent variable.
100    data_x: ArrayView2<'a, F>,
101    /// The offset in the linear predictor for each data point. This can be used
102    /// to incorporate control terms.
103    // TODO: consider making this a reference/ArrayView. Y and X are effectively
104    // cloned so perhaps this isn't a big deal.
105    linear_offset: Option<Array1<F>>,
106    /// The variance/analytic weights for each observation.
107    var_weights: Option<Array1<F>>,
108    /// The frequency/count of each observation.
109    freq_weights: Option<Array1<F>>,
110    /// Whether to standardize the input data. Defaults to `true`.
111    standardize: bool,
112    /// Whether to use an intercept term. Defaults to `true`.
113    use_intercept_term: bool,
114    /// An error that has come up in the build compilation.
115    error: Option<RegressionError<F>>,
116}
117
118/// A builder to generate a Model object
119impl<'a, M, Y, F> ModelBuilderData<'a, M, Y, F>
120where
121    M: Glm,
122    Y: Yval<M> + Copy,
123    F: Float,
124{
125    /// Represents an offset added to the linear predictor for each data point.
126    /// This can be used to control for fixed effects or in multi-level models.
127    pub fn linear_offset(mut self, linear_offset: Array1<F>) -> Self {
128        if self.linear_offset.is_some() {
129            self.error = Some(RegressionError::BuildError(
130                "Offsets specified multiple times".to_string(),
131            ));
132        }
133        self.linear_offset = Some(linear_offset);
134        self
135    }
136
137    /// Frequency weights (a.k.a. counts) for each observation. Traditionally these are positive
138    /// integers representing the number of times each observation appears identically.
139    pub fn freq_weights(mut self, freqs: Array1<usize>) -> Self {
140        if self.freq_weights.is_some() {
141            self.error = Some(RegressionError::BuildError(
142                "Frequency weights specified multiple times".to_string(),
143            ));
144        }
145        let ffreqs: Array1<F> = freqs.mapv(|c| F::from(c).unwrap());
146        // TODO: consider adding a check for non-negative weights
147        self.freq_weights = Some(ffreqs);
148        self
149    }
150
151    /// Variance weights (a.k.a. analytic weights) of each observation. These could represent the
152    /// inverse square of the uncertainties of each measurement.
153    pub fn var_weights(mut self, weights: Array1<F>) -> Self {
154        if self.var_weights.is_some() {
155            self.error = Some(RegressionError::BuildError(
156                "Variance weights specified multiple times".to_string(),
157            ));
158        }
159        // TODO: consider adding a check for non-negative weights
160        self.var_weights = Some(weights);
161        self
162    }
163
164    /// Do not add a constant intercept term of `1`s to the design matrix. This is rarely
165    /// recommended, so you probably don't want to use this option unless you have a very clear
166    /// sense of why. Note that you can supply uniform or per-observation constant terms using
167    /// [`ModelBuilderData::linear_offset`].
168    pub fn no_constant(mut self) -> Self {
169        self.use_intercept_term = false;
170        self
171    }
172
173    /// Don't perform standarization (i.e. scale to 0-mean and 1-variance) of the design matrix.
174    /// Note that the standardization is handled internally, so the reported result coefficients
175    /// should be compatible with the input data directly, meaning the user shouldn't have to
176    /// interact with them.
177    pub fn no_standardize(mut self) -> Self {
178        self.standardize = false;
179        self
180    }
181
182    pub fn build(self) -> RegressionResult<Model<M, F>, F>
183    where
184        M: Glm,
185        F: Float,
186    {
187        if let Some(err) = self.error {
188            return Err(err);
189        }
190
191        let n_data = self.data_y.len();
192        if n_data != self.data_x.nrows() {
193            return Err(RegressionError::BadInput(
194                "y and x data must have same number of points".to_string(),
195            ));
196        }
197        // If they are provided, check that the offsets have the correct number of entries
198        if let Some(lin_off) = &self.linear_offset
199            && n_data != lin_off.len()
200        {
201            return Err(RegressionError::BadInput(
202                "Offsets must have same dimension as observations".to_string(),
203            ));
204        }
205
206        // Check if the data is under-constrained
207        if n_data < self.data_x.ncols() {
208            // The regression can find a solution if n_data == ncols, but there will be
209            // no estimate for the uncertainty. Regularization can solve this, so keep
210            // it to a warning.
211            eprintln!("Warning: data is underconstrained");
212        }
213
214        // Put the data in column-major order, since broadcasting and summing over the observations
215        // are the more common operations.
216        // The shape should be trivially valid, so just unwrap it.
217        let data_x = self.data_x.to_shape((self.data_x.dim(), Order::F)).unwrap();
218
219        // convert y-values to floating-point
220        let data_y: Array1<F> = self
221            .data_y
222            .iter()
223            .map(|&y| y.into_float())
224            .collect::<Result<_, _>>()?;
225
226        // Build the Dataset object
227        let mut data = Dataset {
228            y: data_y,
229            x: data_x.to_owned(),
230            linear_offset: self.linear_offset,
231            weights: self.var_weights,
232            freqs: self.freq_weights,
233            has_intercept: false,
234            standardizer: None,
235        };
236
237        data.finalize_design_matrix(self.standardize, self.use_intercept_term);
238
239        Ok(Model {
240            model: PhantomData,
241            data,
242        })
243    }
244}