// ferrolearn_linear/logistic_regression.rs
1//! Logistic regression classifier.
2//!
3//! This module provides [`LogisticRegression`], a linear classifier that uses
4//! the logistic (sigmoid) function for binary classification and softmax for
5//! multiclass classification. Parameters are estimated using a custom L-BFGS
6//! optimizer with Wolfe line search.
7//!
8//! The regularization parameter `C` is the inverse of regularization strength
9//! (matching scikit-learn's convention): smaller values specify stronger
10//! regularization.
11//!
12//! # Examples
13//!
14//! ```
15//! use ferrolearn_linear::LogisticRegression;
16//! use ferrolearn_core::{Fit, Predict};
17//! use ndarray::{array, Array1, Array2};
18//!
19//! let model = LogisticRegression::<f64>::new();
20//! let x = Array2::from_shape_vec(
21//!     (6, 2),
22//!     vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
23//! ).unwrap();
24//! let y = array![0, 0, 0, 1, 1, 1];
25//!
26//! let fitted = model.fit(&x, &y).unwrap();
27//! let preds = fitted.predict(&x).unwrap();
28//! ```
29
30use ferrolearn_core::error::FerroError;
31use ferrolearn_core::introspection::{HasClasses, HasCoefficients};
32use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
33use ferrolearn_core::traits::{Fit, Predict};
34use ndarray::{Array1, Array2, Axis, ScalarOperand};
35use num_traits::{Float, FromPrimitive, ToPrimitive};
36
37use crate::optim::lbfgs::LbfgsOptimizer;
38
/// Logistic regression classifier.
///
/// Uses L-BFGS optimization to minimize the regularized logistic loss.
/// Supports both binary and multiclass (multinomial) classification.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct LogisticRegression<F> {
    /// Inverse regularization strength. Smaller values specify stronger
    /// regularization (matching scikit-learn's convention).
    /// Must be strictly positive; `fit` rejects `C <= 0`.
    pub c: F,
    /// Maximum number of L-BFGS iterations.
    pub max_iter: usize,
    /// Convergence tolerance for the optimizer.
    pub tol: F,
    /// Whether to fit an intercept (bias) term. The intercept is never
    /// L2-penalized.
    pub fit_intercept: bool,
}
59
60impl<F: Float> LogisticRegression<F> {
61    /// Create a new `LogisticRegression` with default settings.
62    ///
63    /// Defaults: `C = 1.0`, `max_iter = 1000`, `tol = 1e-4`,
64    /// `fit_intercept = true`.
65    #[must_use]
66    pub fn new() -> Self {
67        Self {
68            c: F::one(),
69            max_iter: 1000,
70            tol: F::from(1e-4).unwrap(),
71            fit_intercept: true,
72        }
73    }
74
75    /// Set the inverse regularization strength.
76    #[must_use]
77    pub fn with_c(mut self, c: F) -> Self {
78        self.c = c;
79        self
80    }
81
82    /// Set the maximum number of iterations.
83    #[must_use]
84    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
85        self.max_iter = max_iter;
86        self
87    }
88
89    /// Set the convergence tolerance.
90    #[must_use]
91    pub fn with_tol(mut self, tol: F) -> Self {
92        self.tol = tol;
93        self
94    }
95
96    /// Set whether to fit an intercept term.
97    #[must_use]
98    pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
99        self.fit_intercept = fit_intercept;
100        self
101    }
102}
103
impl<F: Float> Default for LogisticRegression<F> {
    /// Same as [`LogisticRegression::new`].
    fn default() -> Self {
        Self::new()
    }
}
109
/// Fitted logistic regression classifier.
///
/// Stores the learned coefficients, intercept, and class labels.
/// For binary classification, stores a single coefficient vector.
/// For multiclass, stores one coefficient vector per class.
#[derive(Debug, Clone)]
pub struct FittedLogisticRegression<F> {
    /// Coefficient vector exposed through `HasCoefficients`, shape
    /// `(n_features,)`. For binary models this is the full weight vector;
    /// for multiclass models it is only the row for the first (smallest)
    /// class label — the complete set lives in `weight_matrix`.
    coefficients: Array1<F>,
    /// Intercept exposed through `HasCoefficients`: the fitted intercept
    /// for binary models, or the first class's intercept for multiclass.
    intercept: F,
    /// All coefficient vectors for multiclass, shape `(n_classes, n_features)`.
    /// For binary, this has shape `(1, n_features)`.
    weight_matrix: Array2<F>,
    /// Intercept vector, one per class. For binary models this holds the
    /// single intercept.
    intercept_vec: Array1<F>,
    /// Sorted unique class labels seen during fitting.
    classes: Vec<usize>,
    /// Whether this is a binary (two-class) problem.
    is_binary: bool,
}
133
134/// Sigmoid function: 1 / (1 + exp(-z)).
135fn sigmoid<F: Float>(z: F) -> F {
136    if z >= F::zero() {
137        F::one() / (F::one() + (-z).exp())
138    } else {
139        let ez = z.exp();
140        ez / (F::one() + ez)
141    }
142}
143
144impl<F: Float + Send + Sync + ScalarOperand + 'static> Fit<Array2<F>, Array1<usize>>
145    for LogisticRegression<F>
146{
147    type Fitted = FittedLogisticRegression<F>;
148    type Error = FerroError;
149
150    /// Fit the logistic regression model using L-BFGS optimization.
151    ///
152    /// # Errors
153    ///
154    /// Returns [`FerroError::ShapeMismatch`] if the number of samples in
155    /// `x` and `y` differ.
156    /// Returns [`FerroError::InvalidParameter`] if `C` is not positive.
157    /// Returns [`FerroError::InsufficientSamples`] if there are fewer
158    /// than 2 distinct classes.
159    fn fit(
160        &self,
161        x: &Array2<F>,
162        y: &Array1<usize>,
163    ) -> Result<FittedLogisticRegression<F>, FerroError> {
164        let (n_samples, n_features) = x.dim();
165
166        if n_samples != y.len() {
167            return Err(FerroError::ShapeMismatch {
168                expected: vec![n_samples],
169                actual: vec![y.len()],
170                context: "y length must match number of samples in X".into(),
171            });
172        }
173
174        if self.c <= F::zero() {
175            return Err(FerroError::InvalidParameter {
176                name: "C".into(),
177                reason: "must be positive".into(),
178            });
179        }
180
181        if n_samples == 0 {
182            return Err(FerroError::InsufficientSamples {
183                required: 1,
184                actual: 0,
185                context: "LogisticRegression requires at least one sample".into(),
186            });
187        }
188
189        // Determine unique classes.
190        let mut classes: Vec<usize> = y.to_vec();
191        classes.sort_unstable();
192        classes.dedup();
193
194        if classes.len() < 2 {
195            return Err(FerroError::InsufficientSamples {
196                required: 2,
197                actual: classes.len(),
198                context: "LogisticRegression requires at least 2 distinct classes".into(),
199            });
200        }
201
202        let n_classes = classes.len();
203
204        if n_classes == 2 {
205            self.fit_binary(x, y, n_samples, n_features, &classes)
206        } else {
207            self.fit_multinomial(x, y, n_samples, n_features, &classes)
208        }
209    }
210}
211
impl<F: Float + Send + Sync + ScalarOperand + 'static> LogisticRegression<F> {
    /// Fit binary logistic regression.
    ///
    /// Labels equal to `classes[1]` are treated as the positive class (1);
    /// everything else maps to 0. Minimizes the mean binary cross-entropy
    /// plus an L2 penalty of `(1/C) * ||w||^2 / 2` (intercept excluded) via
    /// L-BFGS, starting from all-zero parameters.
    ///
    /// # Errors
    ///
    /// Propagates any error from the optimizer, and returns
    /// [`FerroError::NumericalInstability`] if the final coefficient vector
    /// cannot be reshaped into the `(1, n_features)` weight matrix.
    fn fit_binary(
        &self,
        x: &Array2<F>,
        y: &Array1<usize>,
        n_samples: usize,
        n_features: usize,
        classes: &[usize],
    ) -> Result<FittedLogisticRegression<F>, FerroError> {
        let n_f = F::from(n_samples).unwrap();
        // Regularization strength is the inverse of `C` (scikit-learn style).
        let reg = F::one() / self.c;

        // Convert labels to 0/1 float.
        let y_binary: Array1<F> = y.mapv(|label| {
            if label == classes[1] {
                F::one()
            } else {
                F::zero()
            }
        });

        // Parameter vector: [w_0, w_1, ..., w_{n_features-1}, (intercept)]
        let n_params = if self.fit_intercept {
            n_features + 1
        } else {
            n_features
        };

        // Objective returns (loss, gradient) for the current flat parameters.
        let objective = |params: &Array1<F>| -> (F, Array1<F>) {
            let w = params.slice(ndarray::s![..n_features]);
            let b = if self.fit_intercept {
                params[n_features]
            } else {
                F::zero()
            };

            // Compute logits: X @ w + b
            let logits = x.dot(&w.to_owned()) + b;

            // Compute loss and gradient.
            let mut loss = F::zero();
            let mut grad_w = Array1::<F>::zeros(n_features);
            let mut grad_b = F::zero();

            for i in 0..n_samples {
                let p = sigmoid(logits[i]);
                let yi = y_binary[i];

                // Binary cross-entropy loss (negative log-likelihood).
                // Clip probabilities away from 0/1 so `ln` stays finite.
                let eps = F::from(1e-15).unwrap();
                let p_clipped = p.max(eps).min(F::one() - eps);
                loss = loss - (yi * p_clipped.ln() + (F::one() - yi) * (F::one() - p_clipped).ln());

                // Gradient of the cross-entropy term: (p - y_i) * x_i.
                let diff = p - yi;
                let xi = x.row(i);
                for j in 0..n_features {
                    grad_w[j] = grad_w[j] + diff * xi[j];
                }
                if self.fit_intercept {
                    grad_b = grad_b + diff;
                }
            }

            // Average loss and add regularization.
            loss = loss / n_f;
            grad_w.mapv_inplace(|v| v / n_f);
            grad_b = grad_b / n_f;

            // L2 regularization (on weights only, not intercept).
            let reg_loss: F = w.iter().fold(F::zero(), |acc, &wi| acc + wi * wi);
            loss = loss + reg / (F::from(2.0).unwrap()) * reg_loss;

            for j in 0..n_features {
                grad_w[j] = grad_w[j] + reg * w[j];
            }

            // Pack [grad_w, grad_b] back into a single flat gradient vector.
            let mut grad = Array1::<F>::zeros(n_params);
            for j in 0..n_features {
                grad[j] = grad_w[j];
            }
            if self.fit_intercept {
                grad[n_features] = grad_b;
            }

            (loss, grad)
        };

        let optimizer = LbfgsOptimizer::new(self.max_iter, self.tol);
        let x0 = Array1::<F>::zeros(n_params);
        let params = optimizer.minimize(objective, x0)?;

        // Unpack the optimized flat parameters into weights and intercept.
        let coefficients = params.slice(ndarray::s![..n_features]).to_owned();
        let intercept = if self.fit_intercept {
            params[n_features]
        } else {
            F::zero()
        };

        // Binary models still expose a (1, n_features) weight matrix so the
        // binary and multiclass code paths share the same accessors.
        let weight_matrix = coefficients
            .clone()
            .into_shape_with_order((1, n_features))
            .map_err(|_| FerroError::NumericalInstability {
                message: "failed to reshape coefficients".into(),
            })?;

        Ok(FittedLogisticRegression {
            coefficients,
            intercept,
            weight_matrix,
            intercept_vec: Array1::from_vec(vec![intercept]),
            classes: classes.to_vec(),
            is_binary: true,
        })
    }

    /// Fit multinomial logistic regression.
    ///
    /// Minimizes the mean softmax cross-entropy plus an L2 penalty of
    /// `(1/C) * ||W||^2 / 2` (intercepts excluded) via L-BFGS over a flat
    /// parameter vector laid out as
    /// `[W (n_classes x n_features, row-major), b (n_classes)]`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the optimizer.
    fn fit_multinomial(
        &self,
        x: &Array2<F>,
        y: &Array1<usize>,
        n_samples: usize,
        n_features: usize,
        classes: &[usize],
    ) -> Result<FittedLogisticRegression<F>, FerroError> {
        let n_classes = classes.len();
        let n_f = F::from(n_samples).unwrap();
        let reg = F::one() / self.c;

        // Create class index map.
        // `unwrap` is safe: `classes` was derived from the labels in `y`.
        let class_indices: Vec<usize> = y
            .iter()
            .map(|&label| classes.iter().position(|&c| c == label).unwrap())
            .collect();

        // One-hot encode targets.
        let mut y_onehot = Array2::<F>::zeros((n_samples, n_classes));
        for (i, &ci) in class_indices.iter().enumerate() {
            y_onehot[[i, ci]] = F::one();
        }

        // Parameter vector: flattened [W (n_classes x n_features), b (n_classes)]
        let n_weight_params = n_classes * n_features;
        let n_params = if self.fit_intercept {
            n_weight_params + n_classes
        } else {
            n_weight_params
        };

        // Copy the flag so the `move` closure below need not capture `self`.
        let fit_intercept = self.fit_intercept;

        let objective = move |params: &Array1<F>| -> (F, Array1<F>) {
            // Extract weight matrix W (n_classes x n_features).
            let mut w_mat = Array2::<F>::zeros((n_classes, n_features));
            for c in 0..n_classes {
                for j in 0..n_features {
                    w_mat[[c, j]] = params[c * n_features + j];
                }
            }

            let b_vec: Array1<F> = if fit_intercept {
                Array1::from_shape_fn(n_classes, |c| params[n_weight_params + c])
            } else {
                Array1::zeros(n_classes)
            };

            // Compute logits: X @ W^T + b^T, shape (n_samples, n_classes).
            let logits = x.dot(&w_mat.t()) + &b_vec;

            // Softmax probabilities.
            let probs = softmax_2d(&logits);

            // Multinomial cross-entropy loss.
            // Clip probabilities from below so `ln` stays finite.
            let mut loss = F::zero();
            let eps = F::from(1e-15).unwrap();
            for i in 0..n_samples {
                for c in 0..n_classes {
                    let p = probs[[i, c]].max(eps);
                    loss = loss - y_onehot[[i, c]] * p.ln();
                }
            }
            loss = loss / n_f;

            // L2 regularization (weights only; intercepts are unpenalized).
            let reg_loss: F = w_mat.iter().fold(F::zero(), |acc, &wi| acc + wi * wi);
            loss = loss + reg / F::from(2.0).unwrap() * reg_loss;

            // Gradient.
            // diff = probs - y_onehot, shape (n_samples, n_classes)
            let diff = &probs - &y_onehot;

            // grad_W = diff^T @ X / n, shape (n_classes, n_features)
            let grad_w = diff.t().dot(x) / n_f;

            let mut grad = Array1::<F>::zeros(n_params);
            for c in 0..n_classes {
                for j in 0..n_features {
                    grad[c * n_features + j] = grad_w[[c, j]] + reg * w_mat[[c, j]];
                }
            }

            if fit_intercept {
                // grad_b = sum(diff, axis=0) / n
                let grad_b = diff.sum_axis(Axis(0)) / n_f;
                for c in 0..n_classes {
                    grad[n_weight_params + c] = grad_b[c];
                }
            }

            (loss, grad)
        };

        let optimizer = LbfgsOptimizer::new(self.max_iter, self.tol);
        let x0 = Array1::<F>::zeros(n_params);
        let params = optimizer.minimize(objective, x0)?;

        // Extract results from the optimized flat parameter vector.
        let mut weight_matrix = Array2::<F>::zeros((n_classes, n_features));
        for c in 0..n_classes {
            for j in 0..n_features {
                weight_matrix[[c, j]] = params[c * n_features + j];
            }
        }

        let intercept_vec = if self.fit_intercept {
            Array1::from_shape_fn(n_classes, |c| params[n_weight_params + c])
        } else {
            Array1::zeros(n_classes)
        };

        // For HasCoefficients, store the first class coefficients.
        let coefficients = weight_matrix.row(0).to_owned();
        let intercept = intercept_vec[0];

        Ok(FittedLogisticRegression {
            coefficients,
            intercept,
            weight_matrix,
            intercept_vec,
            classes: classes.to_vec(),
            is_binary: false,
        })
    }
}
457
458/// Compute softmax probabilities row-wise for a 2D array.
459fn softmax_2d<F: Float>(logits: &Array2<F>) -> Array2<F> {
460    let n_rows = logits.nrows();
461    let n_cols = logits.ncols();
462    let mut probs = Array2::<F>::zeros((n_rows, n_cols));
463
464    for i in 0..n_rows {
465        // Numerical stability: subtract max.
466        let max_logit = logits
467            .row(i)
468            .iter()
469            .fold(F::neg_infinity(), |a, &b| a.max(b));
470
471        let mut sum = F::zero();
472        for j in 0..n_cols {
473            let exp_val = (logits[[i, j]] - max_logit).exp();
474            probs[[i, j]] = exp_val;
475            sum = sum + exp_val;
476        }
477
478        if sum > F::zero() {
479            for j in 0..n_cols {
480                probs[[i, j]] = probs[[i, j]] / sum;
481            }
482        }
483    }
484
485    probs
486}
487
488impl<F: Float + Send + Sync + ScalarOperand + 'static> FittedLogisticRegression<F> {
489    /// Returns a reference to the full weight matrix.
490    ///
491    /// For binary classification, shape is `(1, n_features)`.
492    /// For multiclass, shape is `(n_classes, n_features)`.
493    #[must_use]
494    pub fn weight_matrix(&self) -> &Array2<F> {
495        &self.weight_matrix
496    }
497
498    /// Returns a reference to the intercept vector (one per class).
499    #[must_use]
500    pub fn intercept_vec(&self) -> &Array1<F> {
501        &self.intercept_vec
502    }
503
504    /// Returns whether this is a binary classification model.
505    #[must_use]
506    pub fn is_binary(&self) -> bool {
507        self.is_binary
508    }
509
510    /// Predict class probabilities for the given feature matrix.
511    ///
512    /// For binary classification, returns an array of shape `(n_samples, 2)`.
513    /// For multiclass, returns shape `(n_samples, n_classes)`.
514    ///
515    /// # Errors
516    ///
517    /// Returns [`FerroError::ShapeMismatch`] if the number of features
518    /// does not match the fitted model.
519    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
520        let n_features = x.ncols();
521        let expected_features = self.weight_matrix.ncols();
522
523        if n_features != expected_features {
524            return Err(FerroError::ShapeMismatch {
525                expected: vec![expected_features],
526                actual: vec![n_features],
527                context: "number of features must match fitted model".into(),
528            });
529        }
530
531        if self.is_binary {
532            let logits = x.dot(&self.coefficients) + self.intercept;
533            let n_samples = x.nrows();
534            let mut probs = Array2::<F>::zeros((n_samples, 2));
535            for i in 0..n_samples {
536                let p1 = sigmoid(logits[i]);
537                probs[[i, 0]] = F::one() - p1;
538                probs[[i, 1]] = p1;
539            }
540            Ok(probs)
541        } else {
542            let logits = x.dot(&self.weight_matrix.t()) + &self.intercept_vec;
543            Ok(softmax_2d(&logits))
544        }
545    }
546}
547
548impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
549    for FittedLogisticRegression<F>
550{
551    type Output = Array1<usize>;
552    type Error = FerroError;
553
554    /// Predict class labels for the given feature matrix.
555    ///
556    /// Returns the class with the highest predicted probability.
557    ///
558    /// # Errors
559    ///
560    /// Returns [`FerroError::ShapeMismatch`] if the number of features
561    /// does not match the fitted model.
562    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
563        let proba = self.predict_proba(x)?;
564        let n_samples = proba.nrows();
565        let n_classes = proba.ncols();
566
567        let mut predictions = Array1::<usize>::zeros(n_samples);
568        for i in 0..n_samples {
569            let mut best_class = 0;
570            let mut best_prob = proba[[i, 0]];
571            for c in 1..n_classes {
572                if proba[[i, c]] > best_prob {
573                    best_prob = proba[[i, c]];
574                    best_class = c;
575                }
576            }
577            predictions[i] = self.classes[best_class];
578        }
579
580        Ok(predictions)
581    }
582}
583
impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
    for FittedLogisticRegression<F>
{
    /// Coefficient vector of shape `(n_features,)`. For multiclass models
    /// this is only the first class's row; the full matrix is available
    /// via [`FittedLogisticRegression::weight_matrix`].
    fn coefficients(&self) -> &Array1<F> {
        &self.coefficients
    }

    /// Intercept paired with [`Self::coefficients`]: the fitted intercept
    /// for binary models, or the first class's intercept for multiclass.
    fn intercept(&self) -> F {
        self.intercept
    }
}
595
impl<F: Float + Send + Sync + ScalarOperand + 'static> HasClasses for FittedLogisticRegression<F> {
    /// Sorted unique class labels seen during fitting.
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    /// Number of distinct classes (at least 2 for any fitted model, since
    /// `fit` rejects single-class targets).
    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}
605
606// Pipeline integration.
607impl<F> PipelineEstimator<F> for LogisticRegression<F>
608where
609    F: Float + ToPrimitive + FromPrimitive + ScalarOperand + Send + Sync + 'static,
610{
611    fn fit_pipeline(
612        &self,
613        x: &Array2<F>,
614        y: &Array1<F>,
615    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
616        // Convert f64 labels to usize.
617        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
618        let fitted = self.fit(x, &y_usize)?;
619        Ok(Box::new(FittedLogisticRegressionPipeline(fitted)))
620    }
621}
622
623/// Wrapper for pipeline integration that converts predictions to float.
624struct FittedLogisticRegressionPipeline<F>(FittedLogisticRegression<F>)
625where
626    F: Float + Send + Sync + 'static;
627
628// Safety: the inner type is Send + Sync.
629unsafe impl<F: Float + Send + Sync + 'static> Send for FittedLogisticRegressionPipeline<F> {}
630unsafe impl<F: Float + Send + Sync + 'static> Sync for FittedLogisticRegressionPipeline<F> {}
631
impl<F> FittedPipelineEstimator<F> for FittedLogisticRegressionPipeline<F>
where
    F: Float + ToPrimitive + FromPrimitive + ScalarOperand + Send + Sync + 'static,
{
    /// Predict class labels and convert them into the pipeline's float type.
    ///
    /// NOTE(review): a label that `from_usize` cannot represent becomes NaN;
    /// presumably this only happens for labels beyond `F`'s exact integer
    /// range — confirm callers tolerate NaN outputs.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let preds = self.0.predict(x)?;
        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or(F::nan())))
    }
}
641
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // Unit checks for the numerically stable sigmoid helper.
    #[test]
    fn test_sigmoid() {
        assert_relative_eq!(sigmoid(0.0_f64), 0.5, epsilon = 1e-10);
        assert!(sigmoid(10.0_f64) > 0.99);
        assert!(sigmoid(-10.0_f64) < 0.01);
        // Check symmetry: sigmoid(z) + sigmoid(-z) == 1.
        assert_relative_eq!(sigmoid(1.0_f64) + sigmoid(-1.0_f64), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_binary_classification() {
        // Linearly separable binary data.
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0, // class 0
                5.0, 5.0, 5.0, 6.0, 6.0, 5.0, 6.0, 6.0, // class 1
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = LogisticRegression::<f64>::new()
            .with_c(1.0)
            .with_max_iter(1000);
        let fitted = model.fit(&x, &y).unwrap();

        let preds = fitted.predict(&x).unwrap();

        // At minimum, most samples should be correctly classified.
        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 6, "expected at least 6 correct, got {correct}");
    }

    #[test]
    fn test_binary_predict_proba() {
        let x = Array2::from_shape_vec((6, 1), vec![-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = LogisticRegression::<f64>::new().with_c(1.0);
        let fitted = model.fit(&x, &y).unwrap();

        let proba = fitted.predict_proba(&x).unwrap();

        // Probabilities should sum to 1.
        for i in 0..proba.nrows() {
            assert_relative_eq!(proba.row(i).sum(), 1.0, epsilon = 1e-10);
        }

        // Class 0 should have higher probability for negative x.
        assert!(proba[[0, 0]] > proba[[0, 1]]);
        // Class 1 should have higher probability for positive x.
        assert!(proba[[5, 1]] > proba[[5, 0]]);
    }

    #[test]
    fn test_multiclass_classification() {
        // Three linearly separable clusters.
        let x = Array2::from_shape_vec(
            (9, 2),
            vec![
                0.0, 0.0, 0.5, 0.0, 0.0, 0.5, // class 0
                5.0, 0.0, 5.5, 0.0, 5.0, 0.5, // class 1
                0.0, 5.0, 0.5, 5.0, 0.0, 5.5, // class 2
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let model = LogisticRegression::<f64>::new()
            .with_c(10.0)
            .with_max_iter(2000);
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.n_classes(), 3);
        assert_eq!(fitted.classes(), &[0, 1, 2]);

        let preds = fitted.predict(&x).unwrap();
        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 7, "expected at least 7 correct, got {correct}");
    }

    #[test]
    fn test_multiclass_predict_proba() {
        let x = Array2::from_shape_vec(
            (9, 2),
            vec![
                0.0, 0.0, 0.5, 0.0, 0.0, 0.5, 5.0, 0.0, 5.5, 0.0, 5.0, 0.5, 0.0, 5.0, 0.5, 5.0,
                0.0, 5.5,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let model = LogisticRegression::<f64>::new()
            .with_c(10.0)
            .with_max_iter(2000);
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();

        // Probabilities should sum to 1 for each sample.
        for i in 0..proba.nrows() {
            assert_relative_eq!(proba.row(i).sum(), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 1]; // Wrong length

        let model = LogisticRegression::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    // C <= 0 must be rejected before any optimization happens.
    #[test]
    fn test_invalid_c() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = LogisticRegression::<f64>::new().with_c(0.0);
        assert!(model.fit(&x, &y).is_err());

        let model_neg = LogisticRegression::<f64>::new().with_c(-1.0);
        assert!(model_neg.fit(&x, &y).is_err());
    }

    #[test]
    fn test_single_class_error() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0]; // Only one class

        let model = LogisticRegression::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_has_coefficients() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 5.0, 5.0, 5.0, 6.0, 6.0, 5.0],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = LogisticRegression::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.coefficients().len(), 2);
    }

    #[test]
    fn test_has_classes() {
        let x = Array2::from_shape_vec((6, 1), vec![-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = LogisticRegression::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.classes(), &[0, 1]);
        assert_eq!(fitted.n_classes(), 2);
    }

    // End-to-end check of the float-label pipeline wrapper.
    #[test]
    fn test_pipeline_integration() {
        let x = Array2::from_shape_vec((6, 1), vec![-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]).unwrap();
        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);

        let model = LogisticRegression::<f64>::new();
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_no_intercept() {
        let x = Array2::from_shape_vec((6, 1), vec![-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = LogisticRegression::<f64>::new().with_fit_intercept(false);
        let fitted = model.fit(&x, &y).unwrap();
        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_softmax_2d() {
        let logits = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 1.0, 1.0, 1.0]).unwrap();
        let probs = softmax_2d(&logits);

        // Each row should sum to 1.
        assert_relative_eq!(probs.row(0).sum(), 1.0, epsilon = 1e-10);
        assert_relative_eq!(probs.row(1).sum(), 1.0, epsilon = 1e-10);

        // Uniform logits should give uniform probs.
        assert_relative_eq!(probs[[1, 0]], 1.0 / 3.0, epsilon = 1e-10);
        assert_relative_eq!(probs[[1, 1]], 1.0 / 3.0, epsilon = 1e-10);
        assert_relative_eq!(probs[[1, 2]], 1.0 / 3.0, epsilon = 1e-10);
    }
}