1use crate::{
4 error::{RegressionError, RegressionResult},
5 fit::{self, Fit},
6 glm::Glm,
7 math::is_rank_deficient,
8 num::Float,
9 response::Response,
10 utility::one_pad,
11};
12use fit::options::{FitConfig, FitOptions};
13use ndarray::{Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Data, Ix1, Ix2};
14use std::marker::PhantomData;
15
16#[derive(Clone)]
17pub struct Dataset<F>
18where
19 F: Float,
20{
21 pub y: Array1<F>,
23 pub x: Array2<F>,
25 pub linear_offset: Option<Array1<F>>,
29 pub weights: Option<Array1<F>>,
31 pub freqs: Option<Array1<F>>,
33}
34
35impl<F> Dataset<F>
36where
37 F: Float,
38{
39 pub fn linear_predictor(&self, regressors: &Array1<F>) -> Array1<F> {
44 let linear_predictor: Array1<F> = self.x.dot(regressors);
45 if let Some(lin_offset) = &self.linear_offset {
47 linear_predictor + lin_offset
48 } else {
49 linear_predictor
50 }
51 }
52
53 pub fn n_obs(&self) -> F {
55 match &self.freqs {
56 None => F::from(self.y.len()).unwrap(),
57 Some(f) => f.sum(),
58 }
59 }
60
61 pub(crate) fn sum_weights(&self) -> F {
64 match &self.weights {
65 None => self.n_obs(),
66 Some(w) => self.freq_sum(w),
67 }
68 }
69
70 pub fn n_eff(&self) -> F {
73 match &self.weights {
74 None => self.n_obs(),
75 Some(w) => {
76 let v1 = self.freq_sum(w);
77 let w2 = w * w;
78 let v2 = self.freq_sum(&w2);
79 v1 * v1 / v2
80 }
81 }
82 }
83
84 pub(crate) fn get_variance_weights(&self) -> Array1<F> {
85 match &self.weights {
86 Some(w) => w.clone(),
87 None => Array1::<F>::ones(self.y.len()),
88 }
89 }
90
91 pub(crate) fn apply_freq_weights(&self, rhs: Array1<F>) -> Array1<F> {
93 match &self.freqs {
94 None => rhs,
95 Some(f) => f * rhs,
96 }
97 }
98
99 pub(crate) fn apply_total_weights(&self, rhs: Array1<F>) -> Array1<F> {
101 self.apply_freq_weights(self.apply_var_weights(rhs))
102 }
103
104 pub(crate) fn apply_var_weights(&self, rhs: Array1<F>) -> Array1<F> {
105 match &self.weights {
106 None => rhs,
107 Some(w) => w * rhs,
108 }
109 }
110
111 pub(crate) fn freq_sum(&self, rhs: &Array1<F>) -> F {
115 self.apply_freq_weights(rhs.clone()).sum()
116 }
117
118 pub(crate) fn weighted_sum(&self, rhs: &Array1<F>) -> F {
120 self.freq_sum(&self.apply_var_weights(rhs.clone()))
121 }
122
123 pub(crate) fn x_conj(&self) -> Array2<F> {
125 let xt = self.x.t().to_owned();
126 let xt = match &self.freqs {
127 None => xt,
128 Some(f) => xt * f,
129 };
130 match &self.weights {
131 None => xt,
132 Some(w) => xt * w,
133 }
134 }
135}
136
137pub struct Model<M, F>
139where
140 M: Glm,
141 F: Float,
142{
143 pub(crate) model: PhantomData<M>,
144 pub data: Dataset<F>,
146 pub use_intercept: bool,
148}
149
150impl<M, F> Model<M, F>
151where
152 M: Glm,
153 F: Float,
154{
155 pub fn fit(&self) -> RegressionResult<Fit<'_, M, F>> {
157 self.fit_options().fit()
158 }
159
160 pub fn fit_options(&self) -> FitConfig<'_, M, F> {
162 FitConfig {
163 model: self,
164 options: FitOptions::default(),
165 }
166 }
167
168 pub fn with_options(&self, options: FitOptions<F>) -> FitConfig<'_, M, F> {
170 FitConfig {
171 model: self,
172 options,
173 }
174 }
175}
176
177pub struct ModelBuilder<M: Glm> {
180 _model: PhantomData<M>,
181}
182
183impl<M: Glm> ModelBuilder<M> {
184 pub fn data<'a, Y, F, YD, XD>(
188 data_y: &'a ArrayBase<YD, Ix1>,
189 data_x: &'a ArrayBase<XD, Ix2>,
190 ) -> ModelBuilderData<'a, M, Y, F>
191 where
192 Y: Response<M>,
193 F: Float,
194 YD: Data<Elem = Y>,
195 XD: Data<Elem = F>,
196 {
197 ModelBuilderData {
198 model: PhantomData,
199 data_y: data_y.view(),
200 data_x: data_x.view(),
201 linear_offset: None,
202 var_weights: None,
203 freq_weights: None,
204 use_intercept_term: true,
205 colin_tol: F::epsilon(),
206 error: None,
207 }
208 }
209}
210
211pub struct ModelBuilderData<'a, M, Y, F>
214where
215 M: Glm,
216 Y: Response<M>,
217 F: 'static + Float,
218{
219 model: PhantomData<M>,
220 data_y: ArrayView1<'a, Y>,
222 data_x: ArrayView2<'a, F>,
225 linear_offset: Option<Array1<F>>,
230 var_weights: Option<Array1<F>>,
232 freq_weights: Option<Array1<F>>,
234 use_intercept_term: bool,
236 colin_tol: F,
238 error: Option<RegressionError>,
240}
241
242impl<'a, M, Y, F> ModelBuilderData<'a, M, Y, F>
244where
245 M: Glm,
246 Y: Response<M> + Copy,
247 F: Float,
248{
249 pub fn linear_offset(mut self, linear_offset: Array1<F>) -> Self {
252 if self.linear_offset.is_some() {
253 self.error = Some(RegressionError::BuildError(
254 "Offsets specified multiple times".to_string(),
255 ));
256 }
257 self.linear_offset = Some(linear_offset);
258 self
259 }
260
261 pub fn freq_weights(mut self, freqs: Array1<usize>) -> Self {
264 if self.freq_weights.is_some() {
265 self.error = Some(RegressionError::BuildError(
266 "Frequency weights specified multiple times".to_string(),
267 ));
268 }
269 let ffreqs: Array1<F> = freqs.mapv(|c| F::from(c).unwrap());
270 self.freq_weights = Some(ffreqs);
272 self
273 }
274
275 pub fn var_weights(mut self, weights: Array1<F>) -> Self {
278 if self.var_weights.is_some() {
279 self.error = Some(RegressionError::BuildError(
280 "Variance weights specified multiple times".to_string(),
281 ));
282 }
283 self.var_weights = Some(weights);
285 self
286 }
287
288 pub fn no_constant(mut self) -> Self {
290 self.use_intercept_term = false;
291 self
292 }
293
294 pub fn colinear_tol(mut self, tol: F) -> Self {
297 self.colin_tol = tol;
298 self
299 }
300
301 pub fn build(self) -> RegressionResult<Model<M, F>>
302 where
303 M: Glm,
304 F: Float,
305 {
306 if let Some(err) = self.error {
307 return Err(err);
308 }
309
310 let n_data = self.data_y.len();
311 if n_data != self.data_x.nrows() {
312 return Err(RegressionError::BadInput(
313 "y and x data must have same number of points".to_string(),
314 ));
315 }
316 if let Some(lin_off) = &self.linear_offset {
318 if n_data != lin_off.len() {
319 return Err(RegressionError::BadInput(
320 "Offsets must have same dimension as observations".to_string(),
321 ));
322 }
323 }
324
325 let data_x = if self.use_intercept_term {
327 one_pad(self.data_x)
328 } else {
329 self.data_x.to_owned()
330 };
331 if n_data < data_x.ncols() {
333 eprintln!("Warning: data is underconstrained");
338 }
339 let xtx: Array2<F> = data_x.t().dot(&data_x);
341 if is_rank_deficient(xtx, self.colin_tol)? {
342 return Err(RegressionError::ColinearData);
343 }
344
345 let data_y: Array1<F> = self
347 .data_y
348 .iter()
349 .map(|&y| y.into_float())
350 .collect::<Result<_, _>>()?;
351
352 Ok(Model {
353 model: PhantomData,
354 data: Dataset {
355 y: data_y,
356 x: data_x,
357 linear_offset: self.linear_offset,
358 weights: self.var_weights,
359 freqs: self.freq_weights,
360 },
361 use_intercept: self.use_intercept_term,
362 })
363 }
364}