Skip to main content

ferrolearn_linear/
glm.rs

//! Generalized Linear Models (GLM).
//!
//! This module provides IRLS-based GLM regressors for count and positive
//! continuous data:
//!
//! - **[`GLMRegressor`]** — Generic GLM with selectable [`GLMFamily`]
//! - **[`PoissonRegressor`]** — Convenience wrapper with Poisson family
//! - **[`GammaRegressor`]** — Convenience wrapper with Gamma family
//! - **[`TweedieRegressor`]** — Convenience wrapper with Tweedie family
//!
//! All models are fitted with Iteratively Reweighted Least Squares (IRLS)
//! using a log link function and L2 regularization controlled by `alpha`.
//!
//! # Examples
//!
//! ```
//! use ferrolearn_linear::PoissonRegressor;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array1, Array2};
//!
//! let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
//! let y = array![2.0, 5.0, 10.0, 20.0];
//!
//! let model = PoissonRegressor::<f64>::new().with_alpha(0.0);
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! assert_eq!(preds.len(), 4);
//! ```
29
30use ferrolearn_core::error::FerroError;
31use ferrolearn_core::introspection::HasCoefficients;
32use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
33use ferrolearn_core::traits::{Fit, Predict};
34use ndarray::{Array1, Array2, ScalarOperand};
35use num_traits::{Float, FromPrimitive};
36
// ---------------------------------------------------------------------------
// GLMFamily
// ---------------------------------------------------------------------------

/// Distribution family used by a Generalized Linear Model.
///
/// The family fixes the variance function `V(mu)` that IRLS uses to weight
/// each observation:
///
/// | Family       | `V(mu)` |
/// |--------------|---------|
/// | `Poisson`    | `mu`    |
/// | `Gamma`      | `mu^2`  |
/// | `Tweedie(p)` | `mu^p`  |
#[derive(Clone, Copy, Debug)]
pub enum GLMFamily {
    /// Poisson family: the variance grows linearly with the mean.
    Poisson,
    /// Gamma family: the variance grows with the square of the mean.
    Gamma,
    /// Tweedie family with power parameter `p`, i.e. `V(mu) = mu^p`.
    ///
    /// Notable special cases:
    /// - `p = 0`: Normal (variance independent of the mean)
    /// - `p = 1`: Poisson
    /// - `1 < p < 2`: compound Poisson–Gamma
    /// - `p = 2`: Gamma
    Tweedie(f64),
}
61
62impl GLMFamily {
63    /// Compute the variance function V(mu) for a given mean `mu`.
64    fn variance<F: Float + FromPrimitive>(&self, mu: F) -> F {
65        match self {
66            GLMFamily::Poisson => mu,
67            GLMFamily::Gamma => mu * mu,
68            GLMFamily::Tweedie(p) => {
69                let power = F::from(*p).unwrap();
70                mu.powf(power)
71            }
72        }
73    }
74}
75
76// ---------------------------------------------------------------------------
77// GLMRegressor
78// ---------------------------------------------------------------------------
79
80/// Generalized Linear Model regressor.
81///
82/// Fitted via IRLS with a log link function. The [`GLMFamily`] controls
83/// the assumed variance-mean relationship.
84///
85/// # Type Parameters
86///
87/// - `F`: The floating-point type (`f32` or `f64`).
88#[derive(Debug, Clone)]
89pub struct GLMRegressor<F> {
90    /// Distributional family (Poisson, Gamma, or Tweedie).
91    pub family: GLMFamily,
92    /// L2 regularization strength.
93    pub alpha: F,
94    /// Maximum number of IRLS iterations.
95    pub max_iter: usize,
96    /// Convergence tolerance on the maximum coefficient change.
97    pub tol: F,
98    /// Whether to fit an intercept (bias) term.
99    pub fit_intercept: bool,
100}
101
102impl<F: Float + FromPrimitive> GLMRegressor<F> {
103    /// Create a new `GLMRegressor` with the given family.
104    ///
105    /// Defaults: `alpha = 1.0`, `max_iter = 100`, `tol = 1e-4`,
106    /// `fit_intercept = true`.
107    #[must_use]
108    pub fn new(family: GLMFamily) -> Self {
109        Self {
110            family,
111            alpha: F::one(),
112            max_iter: 100,
113            tol: F::from(1e-4).unwrap(),
114            fit_intercept: true,
115        }
116    }
117
118    /// Set the L2 regularization strength.
119    #[must_use]
120    pub fn with_alpha(mut self, alpha: F) -> Self {
121        self.alpha = alpha;
122        self
123    }
124
125    /// Set the maximum number of IRLS iterations.
126    #[must_use]
127    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
128        self.max_iter = max_iter;
129        self
130    }
131
132    /// Set the convergence tolerance.
133    #[must_use]
134    pub fn with_tol(mut self, tol: F) -> Self {
135        self.tol = tol;
136        self
137    }
138
139    /// Set whether to fit an intercept term.
140    #[must_use]
141    pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
142        self.fit_intercept = fit_intercept;
143        self
144    }
145}
146
147/// Fitted GLM regressor.
148///
149/// Stores the learned coefficients and intercept on the log-link scale.
150/// Predictions are computed as `exp(X @ coef + intercept)`.
151#[derive(Debug, Clone)]
152pub struct FittedGLMRegressor<F> {
153    /// Learned coefficient vector on the log scale.
154    coefficients: Array1<F>,
155    /// Learned intercept on the log scale.
156    intercept: F,
157}
158
// ---------------------------------------------------------------------------
// Convenience wrappers
// ---------------------------------------------------------------------------

/// Poisson regressor: a GLM with the Poisson family and a log link.
///
/// Intended for count data (non-negative, integer-valued targets).
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Clone, Debug)]
pub struct PoissonRegressor<F> {
    /// Strength of the L2 (ridge) penalty.
    pub alpha: F,
    /// Upper bound on the number of IRLS iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the coefficient change.
    pub tol: F,
    /// When `true`, an intercept term is estimated as well.
    pub fit_intercept: bool,
}
181
182impl<F: Float + FromPrimitive> PoissonRegressor<F> {
183    /// Create a new `PoissonRegressor` with default settings.
184    ///
185    /// Defaults: `alpha = 1.0`, `max_iter = 100`, `tol = 1e-4`,
186    /// `fit_intercept = true`.
187    #[must_use]
188    pub fn new() -> Self {
189        Self {
190            alpha: F::one(),
191            max_iter: 100,
192            tol: F::from(1e-4).unwrap(),
193            fit_intercept: true,
194        }
195    }
196
197    /// Set the L2 regularization strength.
198    #[must_use]
199    pub fn with_alpha(mut self, alpha: F) -> Self {
200        self.alpha = alpha;
201        self
202    }
203
204    /// Set the maximum number of IRLS iterations.
205    #[must_use]
206    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
207        self.max_iter = max_iter;
208        self
209    }
210
211    /// Set the convergence tolerance.
212    #[must_use]
213    pub fn with_tol(mut self, tol: F) -> Self {
214        self.tol = tol;
215        self
216    }
217
218    /// Set whether to fit an intercept.
219    #[must_use]
220    pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
221        self.fit_intercept = fit_intercept;
222        self
223    }
224}
225
226impl<F: Float + FromPrimitive> Default for PoissonRegressor<F> {
227    fn default() -> Self {
228        Self::new()
229    }
230}
231
/// Gamma regressor: a GLM with the Gamma family and a log link.
///
/// Intended for strictly positive, continuous targets.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Clone, Debug)]
pub struct GammaRegressor<F> {
    /// Strength of the L2 (ridge) penalty.
    pub alpha: F,
    /// Upper bound on the number of IRLS iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the coefficient change.
    pub tol: F,
    /// When `true`, an intercept term is estimated as well.
    pub fit_intercept: bool,
}
250
251impl<F: Float + FromPrimitive> GammaRegressor<F> {
252    /// Create a new `GammaRegressor` with default settings.
253    ///
254    /// Defaults: `alpha = 1.0`, `max_iter = 100`, `tol = 1e-4`,
255    /// `fit_intercept = true`.
256    #[must_use]
257    pub fn new() -> Self {
258        Self {
259            alpha: F::one(),
260            max_iter: 100,
261            tol: F::from(1e-4).unwrap(),
262            fit_intercept: true,
263        }
264    }
265
266    /// Set the L2 regularization strength.
267    #[must_use]
268    pub fn with_alpha(mut self, alpha: F) -> Self {
269        self.alpha = alpha;
270        self
271    }
272
273    /// Set the maximum number of IRLS iterations.
274    #[must_use]
275    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
276        self.max_iter = max_iter;
277        self
278    }
279
280    /// Set the convergence tolerance.
281    #[must_use]
282    pub fn with_tol(mut self, tol: F) -> Self {
283        self.tol = tol;
284        self
285    }
286
287    /// Set whether to fit an intercept.
288    #[must_use]
289    pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
290        self.fit_intercept = fit_intercept;
291        self
292    }
293}
294
295impl<F: Float + FromPrimitive> Default for GammaRegressor<F> {
296    fn default() -> Self {
297        Self::new()
298    }
299}
300
/// Tweedie regressor: a GLM with the Tweedie family and a log link.
///
/// The variance–mean relationship is `V(mu) = mu^power`, where `power` is a
/// plain `f64` regardless of the floating-point type `F`.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Clone, Debug)]
pub struct TweedieRegressor<F> {
    /// Tweedie power parameter `p` in `V(mu) = mu^p`.
    pub power: f64,
    /// Strength of the L2 (ridge) penalty.
    pub alpha: F,
    /// Upper bound on the number of IRLS iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the coefficient change.
    pub tol: F,
    /// When `true`, an intercept term is estimated as well.
    pub fit_intercept: bool,
}
322
323impl<F: Float + FromPrimitive> TweedieRegressor<F> {
324    /// Create a new `TweedieRegressor` with default settings.
325    ///
326    /// Defaults: `power = 1.5`, `alpha = 1.0`, `max_iter = 100`,
327    /// `tol = 1e-4`, `fit_intercept = true`.
328    #[must_use]
329    pub fn new() -> Self {
330        Self {
331            power: 1.5,
332            alpha: F::one(),
333            max_iter: 100,
334            tol: F::from(1e-4).unwrap(),
335            fit_intercept: true,
336        }
337    }
338
339    /// Set the Tweedie power parameter.
340    #[must_use]
341    pub fn with_power(mut self, power: f64) -> Self {
342        self.power = power;
343        self
344    }
345
346    /// Set the L2 regularization strength.
347    #[must_use]
348    pub fn with_alpha(mut self, alpha: F) -> Self {
349        self.alpha = alpha;
350        self
351    }
352
353    /// Set the maximum number of IRLS iterations.
354    #[must_use]
355    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
356        self.max_iter = max_iter;
357        self
358    }
359
360    /// Set the convergence tolerance.
361    #[must_use]
362    pub fn with_tol(mut self, tol: F) -> Self {
363        self.tol = tol;
364        self
365    }
366
367    /// Set whether to fit an intercept.
368    #[must_use]
369    pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
370        self.fit_intercept = fit_intercept;
371        self
372    }
373}
374
375impl<F: Float + FromPrimitive> Default for TweedieRegressor<F> {
376    fn default() -> Self {
377        Self::new()
378    }
379}
380
381// ---------------------------------------------------------------------------
382// Internal helpers
383// ---------------------------------------------------------------------------
384
385/// Cholesky solve for `A x = b`.
386fn cholesky_solve<F: Float>(a: &Array2<F>, b: &Array1<F>) -> Result<Array1<F>, FerroError> {
387    let n = a.nrows();
388    let mut l = Array2::<F>::zeros((n, n));
389
390    for i in 0..n {
391        for j in 0..=i {
392            let mut s = a[[i, j]];
393            for k in 0..j {
394                s = s - l[[i, k]] * l[[j, k]];
395            }
396            if i == j {
397                if s <= F::zero() {
398                    return Err(FerroError::NumericalInstability {
399                        message: "Cholesky: matrix not positive definite".into(),
400                    });
401                }
402                l[[i, j]] = s.sqrt();
403            } else {
404                l[[i, j]] = s / l[[j, j]];
405            }
406        }
407    }
408
409    let mut z = Array1::<F>::zeros(n);
410    for i in 0..n {
411        let mut s = b[i];
412        for k in 0..i {
413            s = s - l[[i, k]] * z[k];
414        }
415        z[i] = s / l[[i, i]];
416    }
417
418    let mut x_sol = Array1::<F>::zeros(n);
419    for i in (0..n).rev() {
420        let mut s = z[i];
421        for k in (i + 1)..n {
422            s = s - l[[k, i]] * x_sol[k];
423        }
424        x_sol[i] = s / l[[i, i]];
425    }
426
427    Ok(x_sol)
428}
429
430/// Gaussian elimination with partial pivoting.
431fn gaussian_solve<F: Float>(
432    n: usize,
433    a: &Array2<F>,
434    b: &Array1<F>,
435) -> Result<Array1<F>, FerroError> {
436    let mut aug = Array2::<F>::zeros((n, n + 1));
437    for i in 0..n {
438        for j in 0..n {
439            aug[[i, j]] = a[[i, j]];
440        }
441        aug[[i, n]] = b[i];
442    }
443
444    for col in 0..n {
445        let mut max_val = aug[[col, col]].abs();
446        let mut max_row = col;
447        for row in (col + 1)..n {
448            let v = aug[[row, col]].abs();
449            if v > max_val {
450                max_val = v;
451                max_row = row;
452            }
453        }
454
455        if max_val < F::from(1e-12).unwrap_or_else(F::epsilon) {
456            return Err(FerroError::NumericalInstability {
457                message: "singular matrix in Gaussian elimination".into(),
458            });
459        }
460
461        if max_row != col {
462            for j in 0..=n {
463                let tmp = aug[[col, j]];
464                aug[[col, j]] = aug[[max_row, j]];
465                aug[[max_row, j]] = tmp;
466            }
467        }
468
469        let pivot = aug[[col, col]];
470        for row in (col + 1)..n {
471            let factor = aug[[row, col]] / pivot;
472            for j in col..=n {
473                let above = aug[[col, j]];
474                aug[[row, j]] = aug[[row, j]] - factor * above;
475            }
476        }
477    }
478
479    let mut x_sol = Array1::<F>::zeros(n);
480    for i in (0..n).rev() {
481        let mut s = aug[[i, n]];
482        for j in (i + 1)..n {
483            s = s - aug[[i, j]] * x_sol[j];
484        }
485        if aug[[i, i]].abs() < F::from(1e-12).unwrap_or_else(F::epsilon) {
486            return Err(FerroError::NumericalInstability {
487                message: "near-zero pivot in back substitution".into(),
488            });
489        }
490        x_sol[i] = s / aug[[i, i]];
491    }
492
493    Ok(x_sol)
494}
495
496/// Solve the weighted ridge system `(X^T W X + alpha I) w = X^T W z`.
497fn weighted_ridge_solve<F: Float + FromPrimitive>(
498    x: &Array2<F>,
499    z: &Array1<F>,
500    weights: &Array1<F>,
501    alpha: F,
502) -> Result<Array1<F>, FerroError> {
503    let (n_samples, n_features) = x.dim();
504
505    let mut xtwx = Array2::<F>::zeros((n_features, n_features));
506    let mut xtwz = Array1::<F>::zeros(n_features);
507
508    for i in 0..n_samples {
509        let wi = weights[i];
510        let xi = x.row(i);
511        for r in 0..n_features {
512            xtwz[r] = xtwz[r] + wi * xi[r] * z[i];
513            for c in 0..n_features {
514                xtwx[[r, c]] = xtwx[[r, c]] + wi * xi[r] * xi[c];
515            }
516        }
517    }
518
519    // Add L2 regularization (do not penalise intercept column if present).
520    for i in 0..n_features {
521        xtwx[[i, i]] = xtwx[[i, i]] + alpha;
522    }
523
524    cholesky_solve(&xtwx, &xtwz).or_else(|_| gaussian_solve(n_features, &xtwx, &xtwz))
525}
526
527/// Core IRLS fitting logic shared by all GLM variants.
528fn fit_glm_irls<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static>(
529    x: &Array2<F>,
530    y: &Array1<F>,
531    family: &GLMFamily,
532    alpha: F,
533    max_iter: usize,
534    tol: F,
535    fit_intercept: bool,
536) -> Result<FittedGLMRegressor<F>, FerroError> {
537    let (n_samples, n_features_orig) = x.dim();
538
539    if n_samples != y.len() {
540        return Err(FerroError::ShapeMismatch {
541            expected: vec![n_samples],
542            actual: vec![y.len()],
543            context: "y length must match number of samples in X".into(),
544        });
545    }
546
547    if n_samples == 0 {
548        return Err(FerroError::InsufficientSamples {
549            required: 1,
550            actual: 0,
551            context: "GLM requires at least one sample".into(),
552        });
553    }
554
555    if alpha < F::zero() {
556        return Err(FerroError::InvalidParameter {
557            name: "alpha".into(),
558            reason: "must be non-negative".into(),
559        });
560    }
561
562    // All y values must be positive for log link.
563    let min_y = F::from(1e-10).unwrap();
564    for &yi in y.iter() {
565        if yi < F::zero() {
566            return Err(FerroError::InvalidParameter {
567                name: "y".into(),
568                reason: "target values must be non-negative for GLM with log link".into(),
569            });
570        }
571    }
572
573    // Build design matrix (optionally prepend intercept column).
574    let n_cols = if fit_intercept {
575        n_features_orig + 1
576    } else {
577        n_features_orig
578    };
579
580    let mut x_design = Array2::<F>::zeros((n_samples, n_cols));
581    if fit_intercept {
582        for i in 0..n_samples {
583            x_design[[i, 0]] = F::one();
584            for j in 0..n_features_orig {
585                x_design[[i, j + 1]] = x[[i, j]];
586            }
587        }
588    } else {
589        x_design.assign(x);
590    }
591
592    // Clamp y for log.
593    let y_safe: Array1<F> = y.mapv(|v| if v < min_y { min_y } else { v });
594
595    // Initialise eta = log(y), mu = y.
596    let mut eta: Array1<F> = y_safe.mapv(|v| v.ln());
597    let mut mu: Array1<F> = y_safe.clone();
598    let mut coef = Array1::<F>::zeros(n_cols);
599
600    let min_mu = F::from(1e-10).unwrap();
601    let max_mu = F::from(1e10).unwrap();
602
603    for _iter in 0..max_iter {
604        let coef_old = coef.clone();
605
606        // Compute IRLS weights and working response.
607        let mut weights = Array1::<F>::zeros(n_samples);
608        let mut z = Array1::<F>::zeros(n_samples);
609
610        for i in 0..n_samples {
611            let mu_i = mu[i].max(min_mu).min(max_mu);
612            let var_i = family.variance(mu_i).max(min_mu);
613            // Log link: g'(mu) = 1/mu, so working response = eta + (y - mu)/mu
614            //           weight = mu^2 / V(mu)  (from W = 1/(g'^2 * V))
615            let g_prime = F::one() / mu_i; // derivative of log link
616            z[i] = eta[i] + (y_safe[i] - mu_i) * g_prime;
617            weights[i] = F::one() / (g_prime * g_prime * var_i);
618            // Clamp weight.
619            if weights[i] < min_mu {
620                weights[i] = min_mu;
621            }
622        }
623
624        // Solve weighted ridge.
625        coef = weighted_ridge_solve(&x_design, &z, &weights, alpha)?;
626
627        // Update eta and mu.
628        eta = x_design.dot(&coef);
629        for i in 0..n_samples {
630            // Clamp eta to prevent overflow in exp.
631            let eta_i = eta[i].max(F::from(-20.0).unwrap()).min(F::from(20.0).unwrap());
632            eta[i] = eta_i;
633            mu[i] = eta_i.exp().max(min_mu).min(max_mu);
634        }
635
636        // Check convergence.
637        let max_change = coef
638            .iter()
639            .zip(coef_old.iter())
640            .map(|(&c, &co)| (c - co).abs())
641            .fold(F::zero(), |a, b| if b > a { b } else { a });
642
643        if max_change < tol {
644            break;
645        }
646    }
647
648    // Extract intercept and feature coefficients.
649    let (intercept, coefficients) = if fit_intercept {
650        let intercept = coef[0];
651        let coefficients = Array1::from_iter(coef.iter().skip(1).copied());
652        (intercept, coefficients)
653    } else {
654        (F::zero(), coef)
655    };
656
657    Ok(FittedGLMRegressor {
658        coefficients,
659        intercept,
660    })
661}
662
663// ---------------------------------------------------------------------------
664// Fit — GLMRegressor
665// ---------------------------------------------------------------------------
666
667impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
668    for GLMRegressor<F>
669{
670    type Fitted = FittedGLMRegressor<F>;
671    type Error = FerroError;
672
673    /// Fit the GLM via IRLS.
674    ///
675    /// # Errors
676    ///
677    /// - [`FerroError::ShapeMismatch`] — sample count mismatch.
678    /// - [`FerroError::InsufficientSamples`] — zero samples.
679    /// - [`FerroError::InvalidParameter`] — negative alpha or negative y.
680    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedGLMRegressor<F>, FerroError> {
681        fit_glm_irls(
682            x,
683            y,
684            &self.family,
685            self.alpha,
686            self.max_iter,
687            self.tol,
688            self.fit_intercept,
689        )
690    }
691}
692
693// ---------------------------------------------------------------------------
694// Fit — PoissonRegressor
695// ---------------------------------------------------------------------------
696
697impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
698    for PoissonRegressor<F>
699{
700    type Fitted = FittedGLMRegressor<F>;
701    type Error = FerroError;
702
703    /// Fit the Poisson GLM via IRLS.
704    ///
705    /// # Errors
706    ///
707    /// See [`GLMRegressor::fit`].
708    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedGLMRegressor<F>, FerroError> {
709        fit_glm_irls(
710            x,
711            y,
712            &GLMFamily::Poisson,
713            self.alpha,
714            self.max_iter,
715            self.tol,
716            self.fit_intercept,
717        )
718    }
719}
720
721// ---------------------------------------------------------------------------
722// Fit — GammaRegressor
723// ---------------------------------------------------------------------------
724
725impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
726    for GammaRegressor<F>
727{
728    type Fitted = FittedGLMRegressor<F>;
729    type Error = FerroError;
730
731    /// Fit the Gamma GLM via IRLS.
732    ///
733    /// # Errors
734    ///
735    /// See [`GLMRegressor::fit`].
736    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedGLMRegressor<F>, FerroError> {
737        fit_glm_irls(
738            x,
739            y,
740            &GLMFamily::Gamma,
741            self.alpha,
742            self.max_iter,
743            self.tol,
744            self.fit_intercept,
745        )
746    }
747}
748
749// ---------------------------------------------------------------------------
750// Fit — TweedieRegressor
751// ---------------------------------------------------------------------------
752
753impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
754    for TweedieRegressor<F>
755{
756    type Fitted = FittedGLMRegressor<F>;
757    type Error = FerroError;
758
759    /// Fit the Tweedie GLM via IRLS.
760    ///
761    /// # Errors
762    ///
763    /// See [`GLMRegressor::fit`].
764    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedGLMRegressor<F>, FerroError> {
765        fit_glm_irls(
766            x,
767            y,
768            &GLMFamily::Tweedie(self.power),
769            self.alpha,
770            self.max_iter,
771            self.tol,
772            self.fit_intercept,
773        )
774    }
775}
776
777// ---------------------------------------------------------------------------
778// Predict / HasCoefficients / Pipeline — FittedGLMRegressor
779// ---------------------------------------------------------------------------
780
781impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
782    for FittedGLMRegressor<F>
783{
784    type Output = Array1<F>;
785    type Error = FerroError;
786
787    /// Predict using the fitted GLM.
788    ///
789    /// Computes `exp(X @ coefficients + intercept)` (inverse log link).
790    ///
791    /// # Errors
792    ///
793    /// Returns [`FerroError::ShapeMismatch`] if the number of features
794    /// does not match the fitted model.
795    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
796        if x.ncols() != self.coefficients.len() {
797            return Err(FerroError::ShapeMismatch {
798                expected: vec![self.coefficients.len()],
799                actual: vec![x.ncols()],
800                context: "number of features must match fitted model".into(),
801            });
802        }
803        let eta = x.dot(&self.coefficients) + self.intercept;
804        Ok(eta.mapv(|v| v.exp()))
805    }
806}
807
808impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
809    for FittedGLMRegressor<F>
810{
811    fn coefficients(&self) -> &Array1<F> {
812        &self.coefficients
813    }
814
815    fn intercept(&self) -> F {
816        self.intercept
817    }
818}
819
820// Pipeline integration for GLMRegressor.
821impl<F> PipelineEstimator<F> for GLMRegressor<F>
822where
823    F: Float + FromPrimitive + ScalarOperand + Send + Sync + 'static,
824{
825    fn fit_pipeline(
826        &self,
827        x: &Array2<F>,
828        y: &Array1<F>,
829    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
830        let fitted = self.fit(x, y)?;
831        Ok(Box::new(fitted))
832    }
833}
834
835impl<F> FittedPipelineEstimator<F> for FittedGLMRegressor<F>
836where
837    F: Float + ScalarOperand + Send + Sync + 'static,
838{
839    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
840        self.predict(x)
841    }
842}
843
844// Pipeline integration for PoissonRegressor.
845impl<F> PipelineEstimator<F> for PoissonRegressor<F>
846where
847    F: Float + FromPrimitive + ScalarOperand + Send + Sync + 'static,
848{
849    fn fit_pipeline(
850        &self,
851        x: &Array2<F>,
852        y: &Array1<F>,
853    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
854        let fitted = self.fit(x, y)?;
855        Ok(Box::new(fitted))
856    }
857}
858
859// Pipeline integration for GammaRegressor.
860impl<F> PipelineEstimator<F> for GammaRegressor<F>
861where
862    F: Float + FromPrimitive + ScalarOperand + Send + Sync + 'static,
863{
864    fn fit_pipeline(
865        &self,
866        x: &Array2<F>,
867        y: &Array1<F>,
868    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
869        let fitted = self.fit(x, y)?;
870        Ok(Box::new(fitted))
871    }
872}
873
874// Pipeline integration for TweedieRegressor.
875impl<F> PipelineEstimator<F> for TweedieRegressor<F>
876where
877    F: Float + FromPrimitive + ScalarOperand + Send + Sync + 'static,
878{
879    fn fit_pipeline(
880        &self,
881        x: &Array2<F>,
882        y: &Array1<F>,
883    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
884        let fitted = self.fit(x, y)?;
885        Ok(Box::new(fitted))
886    }
887}
888
889// ---------------------------------------------------------------------------
890// Tests
891// ---------------------------------------------------------------------------
892
893#[cfg(test)]
894mod tests {
895    use super::*;
896    use approx::assert_relative_eq;
897    use ndarray::array;
898
899    // ---- GLMRegressor ----
900
901    #[test]
902    fn test_glm_poisson_defaults() {
903        let m = GLMRegressor::<f64>::new(GLMFamily::Poisson);
904        assert_relative_eq!(m.alpha, 1.0);
905        assert_eq!(m.max_iter, 100);
906        assert!(m.fit_intercept);
907    }
908
909    #[test]
910    fn test_glm_builder() {
911        let m = GLMRegressor::<f64>::new(GLMFamily::Gamma)
912            .with_alpha(0.5)
913            .with_max_iter(200)
914            .with_tol(1e-6)
915            .with_fit_intercept(false);
916        assert_relative_eq!(m.alpha, 0.5);
917        assert_eq!(m.max_iter, 200);
918        assert!(!m.fit_intercept);
919    }
920
921    #[test]
922    fn test_glm_shape_mismatch() {
923        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
924        let y = array![1.0, 2.0];
925        assert!(GLMRegressor::<f64>::new(GLMFamily::Poisson)
926            .fit(&x, &y)
927            .is_err());
928    }
929
930    #[test]
931    fn test_glm_negative_alpha() {
932        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
933        let y = array![1.0, 2.0, 3.0];
934        assert!(GLMRegressor::<f64>::new(GLMFamily::Poisson)
935            .with_alpha(-1.0)
936            .fit(&x, &y)
937            .is_err());
938    }
939
940    #[test]
941    fn test_glm_poisson_fit_predict() {
942        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
943        let y = array![2.0, 5.0, 10.0, 20.0];
944
945        let fitted = GLMRegressor::<f64>::new(GLMFamily::Poisson)
946            .with_alpha(0.0)
947            .with_max_iter(200)
948            .fit(&x, &y)
949            .unwrap();
950        let preds = fitted.predict(&x).unwrap();
951        assert_eq!(preds.len(), 4);
952        // Predictions should be positive.
953        for &p in preds.iter() {
954            assert!(p > 0.0);
955        }
956    }
957
958    #[test]
959    fn test_glm_gamma_fit_predict() {
960        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
961        let y = array![2.0, 5.0, 10.0, 20.0];
962
963        let fitted = GLMRegressor::<f64>::new(GLMFamily::Gamma)
964            .with_alpha(0.0)
965            .with_max_iter(200)
966            .fit(&x, &y)
967            .unwrap();
968        let preds = fitted.predict(&x).unwrap();
969        assert_eq!(preds.len(), 4);
970        for &p in preds.iter() {
971            assert!(p > 0.0);
972        }
973    }
974
975    #[test]
976    fn test_glm_tweedie_fit_predict() {
977        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
978        let y = array![2.0, 5.0, 10.0, 20.0];
979
980        let fitted = GLMRegressor::<f64>::new(GLMFamily::Tweedie(1.5))
981            .with_alpha(0.0)
982            .with_max_iter(200)
983            .fit(&x, &y)
984            .unwrap();
985        let preds = fitted.predict(&x).unwrap();
986        assert_eq!(preds.len(), 4);
987        for &p in preds.iter() {
988            assert!(p > 0.0);
989        }
990    }
991
992    #[test]
993    fn test_glm_predict_feature_mismatch() {
994        let x = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0]).unwrap();
995        let y = array![1.0, 2.0, 3.0];
996        let fitted = GLMRegressor::<f64>::new(GLMFamily::Poisson)
997            .fit(&x, &y)
998            .unwrap();
999        let x_bad = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
1000        assert!(fitted.predict(&x_bad).is_err());
1001    }
1002
1003    #[test]
1004    fn test_glm_has_coefficients() {
1005        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1006        let y = array![1.0, 2.0, 3.0];
1007        let fitted = GLMRegressor::<f64>::new(GLMFamily::Poisson)
1008            .fit(&x, &y)
1009            .unwrap();
1010        assert_eq!(fitted.coefficients().len(), 2);
1011    }
1012
1013    #[test]
1014    fn test_glm_pipeline() {
1015        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1016        let y = array![2.0, 5.0, 10.0, 20.0];
1017        let model = GLMRegressor::<f64>::new(GLMFamily::Poisson).with_alpha(0.0);
1018        let fitted = model.fit_pipeline(&x, &y).unwrap();
1019        let preds = fitted.predict_pipeline(&x).unwrap();
1020        assert_eq!(preds.len(), 4);
1021    }
1022
1023    // ---- PoissonRegressor ----
1024
1025    #[test]
1026    fn test_poisson_defaults() {
1027        let m = PoissonRegressor::<f64>::new();
1028        assert_relative_eq!(m.alpha, 1.0);
1029        assert_eq!(m.max_iter, 100);
1030        assert!(m.fit_intercept);
1031    }
1032
1033    #[test]
1034    fn test_poisson_fit_predict() {
1035        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1036        let y = array![2.0, 5.0, 10.0, 20.0];
1037
1038        let fitted = PoissonRegressor::<f64>::new()
1039            .with_alpha(0.0)
1040            .with_max_iter(200)
1041            .fit(&x, &y)
1042            .unwrap();
1043        let preds = fitted.predict(&x).unwrap();
1044        assert_eq!(preds.len(), 4);
1045        for &p in preds.iter() {
1046            assert!(p > 0.0);
1047        }
1048    }
1049
1050    #[test]
1051    fn test_poisson_pipeline() {
1052        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1053        let y = array![2.0, 5.0, 10.0, 20.0];
1054        let fitted = PoissonRegressor::<f64>::new()
1055            .with_alpha(0.0)
1056            .fit_pipeline(&x, &y)
1057            .unwrap();
1058        let preds = fitted.predict_pipeline(&x).unwrap();
1059        assert_eq!(preds.len(), 4);
1060    }
1061
1062    // ---- GammaRegressor ----
1063
1064    #[test]
1065    fn test_gamma_defaults() {
1066        let m = GammaRegressor::<f64>::new();
1067        assert_relative_eq!(m.alpha, 1.0);
1068        assert_eq!(m.max_iter, 100);
1069    }
1070
1071    #[test]
1072    fn test_gamma_fit_predict() {
1073        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1074        let y = array![2.0, 5.0, 10.0, 20.0];
1075
1076        let fitted = GammaRegressor::<f64>::new()
1077            .with_alpha(0.0)
1078            .with_max_iter(200)
1079            .fit(&x, &y)
1080            .unwrap();
1081        let preds = fitted.predict(&x).unwrap();
1082        assert_eq!(preds.len(), 4);
1083        for &p in preds.iter() {
1084            assert!(p > 0.0);
1085        }
1086    }
1087
1088    #[test]
1089    fn test_gamma_pipeline() {
1090        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1091        let y = array![2.0, 5.0, 10.0, 20.0];
1092        let fitted = GammaRegressor::<f64>::new()
1093            .with_alpha(0.0)
1094            .fit_pipeline(&x, &y)
1095            .unwrap();
1096        let preds = fitted.predict_pipeline(&x).unwrap();
1097        assert_eq!(preds.len(), 4);
1098    }
1099
1100    // ---- TweedieRegressor ----
1101
1102    #[test]
1103    fn test_tweedie_defaults() {
1104        let m = TweedieRegressor::<f64>::new();
1105        assert_relative_eq!(m.power, 1.5);
1106        assert_relative_eq!(m.alpha, 1.0);
1107    }
1108
1109    #[test]
1110    fn test_tweedie_fit_predict() {
1111        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1112        let y = array![2.0, 5.0, 10.0, 20.0];
1113
1114        let fitted = TweedieRegressor::<f64>::new()
1115            .with_power(1.5)
1116            .with_alpha(0.0)
1117            .with_max_iter(200)
1118            .fit(&x, &y)
1119            .unwrap();
1120        let preds = fitted.predict(&x).unwrap();
1121        assert_eq!(preds.len(), 4);
1122        for &p in preds.iter() {
1123            assert!(p > 0.0);
1124        }
1125    }
1126
1127    #[test]
1128    fn test_tweedie_pipeline() {
1129        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1130        let y = array![2.0, 5.0, 10.0, 20.0];
1131        let fitted = TweedieRegressor::<f64>::new()
1132            .with_alpha(0.0)
1133            .fit_pipeline(&x, &y)
1134            .unwrap();
1135        let preds = fitted.predict_pipeline(&x).unwrap();
1136        assert_eq!(preds.len(), 4);
1137    }
1138
1139    // ---- Variance function ----
1140
1141    #[test]
1142    fn test_variance_poisson() {
1143        let v = GLMFamily::Poisson.variance(3.0_f64);
1144        assert_relative_eq!(v, 3.0);
1145    }
1146
1147    #[test]
1148    fn test_variance_gamma() {
1149        let v = GLMFamily::Gamma.variance(3.0_f64);
1150        assert_relative_eq!(v, 9.0);
1151    }
1152
1153    #[test]
1154    fn test_variance_tweedie() {
1155        let v = GLMFamily::Tweedie(1.5).variance(4.0_f64);
1156        assert_relative_eq!(v, 4.0_f64.powf(1.5), epsilon = 1e-10);
1157    }
1158
1159    #[test]
1160    fn test_glm_negative_y() {
1161        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
1162        let y = array![1.0, -2.0, 3.0];
1163        assert!(GLMRegressor::<f64>::new(GLMFamily::Poisson)
1164            .fit(&x, &y)
1165            .is_err());
1166    }
1167}