1use ferrolearn_core::error::FerroError;
31use ferrolearn_core::introspection::{HasClasses, HasCoefficients};
32use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
33use ferrolearn_core::traits::{Fit, Predict};
34use ndarray::{Array1, Array2, Axis, ScalarOperand};
35use num_traits::{Float, FromPrimitive, ToPrimitive};
36
37use crate::optim::lbfgs::LbfgsOptimizer;
38
/// Logistic regression classifier configured via builder-style setters.
///
/// Binary problems are fit with a single sigmoid model; three or more
/// classes use a multinomial (softmax) model. Training minimizes the
/// L2-regularized cross-entropy with L-BFGS.
#[derive(Debug, Clone)]
pub struct LogisticRegression<F> {
    /// Inverse regularization strength (must be positive); smaller values
    /// mean stronger L2 regularization, matching scikit-learn's `C`.
    pub c: F,
    /// Maximum number of L-BFGS iterations.
    pub max_iter: usize,
    /// Convergence tolerance passed to the optimizer.
    pub tol: F,
    /// Whether to fit an intercept (bias) term in addition to the weights.
    pub fit_intercept: bool,
}
59
60impl<F: Float> LogisticRegression<F> {
61 #[must_use]
66 pub fn new() -> Self {
67 Self {
68 c: F::one(),
69 max_iter: 1000,
70 tol: F::from(1e-4).unwrap(),
71 fit_intercept: true,
72 }
73 }
74
75 #[must_use]
77 pub fn with_c(mut self, c: F) -> Self {
78 self.c = c;
79 self
80 }
81
82 #[must_use]
84 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
85 self.max_iter = max_iter;
86 self
87 }
88
89 #[must_use]
91 pub fn with_tol(mut self, tol: F) -> Self {
92 self.tol = tol;
93 self
94 }
95
96 #[must_use]
98 pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
99 self.fit_intercept = fit_intercept;
100 self
101 }
102}
103
104impl<F: Float> Default for LogisticRegression<F> {
105 fn default() -> Self {
106 Self::new()
107 }
108}
109
/// A trained logistic-regression model produced by `LogisticRegression::fit`.
///
/// Stores both a flat (binary) and a matrix (multinomial) view of the
/// parameters so the introspection traits and both predict paths can be
/// served without conversion.
#[derive(Debug, Clone)]
pub struct FittedLogisticRegression<F> {
    // Coefficient vector; for multinomial models this is row 0 of
    // `weight_matrix` (see `fit_multinomial`).
    coefficients: Array1<F>,
    // Scalar intercept; for multinomial models, the first class's intercept.
    intercept: F,
    // Per-class weights, shape (n_classes, n_features); (1, n_features) for
    // binary models.
    weight_matrix: Array2<F>,
    // Per-class intercepts; a single element for binary models.
    intercept_vec: Array1<F>,
    // Sorted, distinct class labels observed during fitting.
    classes: Vec<usize>,
    // True when the model was trained on exactly two classes.
    is_binary: bool,
}
133
134fn sigmoid<F: Float>(z: F) -> F {
136 if z >= F::zero() {
137 F::one() / (F::one() + (-z).exp())
138 } else {
139 let ez = z.exp();
140 ez / (F::one() + ez)
141 }
142}
143
144impl<F: Float + Send + Sync + ScalarOperand + 'static> Fit<Array2<F>, Array1<usize>>
145 for LogisticRegression<F>
146{
147 type Fitted = FittedLogisticRegression<F>;
148 type Error = FerroError;
149
150 fn fit(
160 &self,
161 x: &Array2<F>,
162 y: &Array1<usize>,
163 ) -> Result<FittedLogisticRegression<F>, FerroError> {
164 let (n_samples, n_features) = x.dim();
165
166 if n_samples != y.len() {
167 return Err(FerroError::ShapeMismatch {
168 expected: vec![n_samples],
169 actual: vec![y.len()],
170 context: "y length must match number of samples in X".into(),
171 });
172 }
173
174 if self.c <= F::zero() {
175 return Err(FerroError::InvalidParameter {
176 name: "C".into(),
177 reason: "must be positive".into(),
178 });
179 }
180
181 if n_samples == 0 {
182 return Err(FerroError::InsufficientSamples {
183 required: 1,
184 actual: 0,
185 context: "LogisticRegression requires at least one sample".into(),
186 });
187 }
188
189 let mut classes: Vec<usize> = y.to_vec();
191 classes.sort_unstable();
192 classes.dedup();
193
194 if classes.len() < 2 {
195 return Err(FerroError::InsufficientSamples {
196 required: 2,
197 actual: classes.len(),
198 context: "LogisticRegression requires at least 2 distinct classes".into(),
199 });
200 }
201
202 let n_classes = classes.len();
203
204 if n_classes == 2 {
205 self.fit_binary(x, y, n_samples, n_features, &classes)
206 } else {
207 self.fit_multinomial(x, y, n_samples, n_features, &classes)
208 }
209 }
210}
211
impl<F: Float + Send + Sync + ScalarOperand + 'static> LogisticRegression<F> {
    /// Fit a two-class sigmoid model by minimizing the mean L2-regularized
    /// log-loss with L-BFGS.
    ///
    /// `classes` is sorted and deduplicated by the caller; `classes[1]` is
    /// treated as the positive class.
    fn fit_binary(
        &self,
        x: &Array2<F>,
        y: &Array1<usize>,
        n_samples: usize,
        n_features: usize,
        classes: &[usize],
    ) -> Result<FittedLogisticRegression<F>, FerroError> {
        // n_samples > 0 is guaranteed by `fit`, so this conversion is safe
        // for any float type.
        let n_f = F::from(n_samples).unwrap();
        // Regularization strength is 1/C (scikit-learn convention).
        let reg = F::one() / self.c;

        // Encode labels as 0/1: classes[1] (the larger label) is positive.
        let y_binary: Array1<F> = y.mapv(|label| {
            if label == classes[1] {
                F::one()
            } else {
                F::zero()
            }
        });

        // Parameter vector layout: [w_0 .. w_{d-1}, (optional) b].
        let n_params = if self.fit_intercept {
            n_features + 1
        } else {
            n_features
        };

        // Objective returning (loss, gradient) for L-BFGS.
        let objective = |params: &Array1<F>| -> (F, Array1<F>) {
            let w = params.slice(ndarray::s![..n_features]);
            let b = if self.fit_intercept {
                params[n_features]
            } else {
                F::zero()
            };

            let logits = x.dot(&w.to_owned()) + b;

            let mut loss = F::zero();
            let mut grad_w = Array1::<F>::zeros(n_features);
            let mut grad_b = F::zero();

            for i in 0..n_samples {
                let p = sigmoid(logits[i]);
                let yi = y_binary[i];

                // Clip probabilities away from 0/1 so ln() stays finite; the
                // gradient deliberately uses the unclipped p (exact for the
                // sigmoid/log-loss pair).
                let eps = F::from(1e-15).unwrap();
                let p_clipped = p.max(eps).min(F::one() - eps);
                loss = loss - (yi * p_clipped.ln() + (F::one() - yi) * (F::one() - p_clipped).ln());

                // d(loss_i)/d(logit) = p - y for the log-loss + sigmoid.
                let diff = p - yi;
                let xi = x.row(i);
                for j in 0..n_features {
                    grad_w[j] = grad_w[j] + diff * xi[j];
                }
                if self.fit_intercept {
                    grad_b = grad_b + diff;
                }
            }

            // Average over samples so the loss scale is independent of n.
            loss = loss / n_f;
            grad_w.mapv_inplace(|v| v / n_f);
            grad_b = grad_b / n_f;

            // L2 penalty (reg/2)*||w||^2; the intercept is not penalized.
            let reg_loss: F = w.iter().fold(F::zero(), |acc, &wi| acc + wi * wi);
            loss = loss + reg / (F::from(2.0).unwrap()) * reg_loss;

            for j in 0..n_features {
                grad_w[j] = grad_w[j] + reg * w[j];
            }

            // Pack [grad_w, grad_b] into a single flat gradient vector.
            let mut grad = Array1::<F>::zeros(n_params);
            for j in 0..n_features {
                grad[j] = grad_w[j];
            }
            if self.fit_intercept {
                grad[n_features] = grad_b;
            }

            (loss, grad)
        };

        // Minimize from a zero initial point.
        let optimizer = LbfgsOptimizer::new(self.max_iter, self.tol);
        let x0 = Array1::<F>::zeros(n_params);
        let params = optimizer.minimize(objective, x0)?;

        // Unpack the optimized parameter vector.
        let coefficients = params.slice(ndarray::s![..n_features]).to_owned();
        let intercept = if self.fit_intercept {
            params[n_features]
        } else {
            F::zero()
        };

        // Also store the (1, n_features) matrix view so the multiclass code
        // paths and introspection work uniformly.
        let weight_matrix = coefficients
            .clone()
            .into_shape_with_order((1, n_features))
            .map_err(|_| FerroError::NumericalInstability {
                message: "failed to reshape coefficients".into(),
            })?;

        Ok(FittedLogisticRegression {
            coefficients,
            intercept,
            weight_matrix,
            intercept_vec: Array1::from_vec(vec![intercept]),
            classes: classes.to_vec(),
            is_binary: true,
        })
    }

    /// Fit a multinomial (softmax) model over three or more classes by
    /// minimizing the mean L2-regularized cross-entropy with L-BFGS.
    fn fit_multinomial(
        &self,
        x: &Array2<F>,
        y: &Array1<usize>,
        n_samples: usize,
        n_features: usize,
        classes: &[usize],
    ) -> Result<FittedLogisticRegression<F>, FerroError> {
        let n_classes = classes.len();
        let n_f = F::from(n_samples).unwrap();
        let reg = F::one() / self.c;

        // Map each label to its index in the sorted class list; every label
        // in y is present in `classes` by construction, so position() is Some.
        let class_indices: Vec<usize> = y
            .iter()
            .map(|&label| classes.iter().position(|&c| c == label).unwrap())
            .collect();

        // One-hot target matrix, shape (n_samples, n_classes).
        let mut y_onehot = Array2::<F>::zeros((n_samples, n_classes));
        for (i, &ci) in class_indices.iter().enumerate() {
            y_onehot[[i, ci]] = F::one();
        }

        // Flat parameter layout: row-major weights first, then (optionally)
        // one intercept per class.
        let n_weight_params = n_classes * n_features;
        let n_params = if self.fit_intercept {
            n_weight_params + n_classes
        } else {
            n_weight_params
        };

        // Copy the flag so the closure can be `move` without borrowing self.
        let fit_intercept = self.fit_intercept;

        let objective = move |params: &Array1<F>| -> (F, Array1<F>) {
            // Rebuild the (n_classes, n_features) weight matrix from the
            // flat parameter vector.
            let mut w_mat = Array2::<F>::zeros((n_classes, n_features));
            for c in 0..n_classes {
                for j in 0..n_features {
                    w_mat[[c, j]] = params[c * n_features + j];
                }
            }

            let b_vec: Array1<F> = if fit_intercept {
                Array1::from_shape_fn(n_classes, |c| params[n_weight_params + c])
            } else {
                Array1::zeros(n_classes)
            };

            // Per-class scores, shape (n_samples, n_classes).
            let logits = x.dot(&w_mat.t()) + &b_vec;

            let probs = softmax_2d(&logits);

            // Cross-entropy; clip p away from 0 so ln() stays finite.
            let mut loss = F::zero();
            let eps = F::from(1e-15).unwrap();
            for i in 0..n_samples {
                for c in 0..n_classes {
                    let p = probs[[i, c]].max(eps);
                    loss = loss - y_onehot[[i, c]] * p.ln();
                }
            }
            loss = loss / n_f;

            // L2 penalty on weights only (intercepts are not penalized).
            let reg_loss: F = w_mat.iter().fold(F::zero(), |acc, &wi| acc + wi * wi);
            loss = loss + reg / F::from(2.0).unwrap() * reg_loss;

            // Softmax + cross-entropy gradient w.r.t. logits is (p - y).
            let diff = &probs - &y_onehot;

            // grad_W = (p - y)^T X / n, shape (n_classes, n_features).
            let grad_w = diff.t().dot(x) / n_f;

            let mut grad = Array1::<F>::zeros(n_params);
            for c in 0..n_classes {
                for j in 0..n_features {
                    grad[c * n_features + j] = grad_w[[c, j]] + reg * w_mat[[c, j]];
                }
            }

            if fit_intercept {
                // Intercept gradient: column sums of (p - y) / n.
                let grad_b = diff.sum_axis(Axis(0)) / n_f;
                for c in 0..n_classes {
                    grad[n_weight_params + c] = grad_b[c];
                }
            }

            (loss, grad)
        };

        // Minimize from a zero initial point.
        let optimizer = LbfgsOptimizer::new(self.max_iter, self.tol);
        let x0 = Array1::<F>::zeros(n_params);
        let params = optimizer.minimize(objective, x0)?;

        // Unpack the optimized flat vector back into matrix form.
        let mut weight_matrix = Array2::<F>::zeros((n_classes, n_features));
        for c in 0..n_classes {
            for j in 0..n_features {
                weight_matrix[[c, j]] = params[c * n_features + j];
            }
        }

        let intercept_vec = if self.fit_intercept {
            Array1::from_shape_fn(n_classes, |c| params[n_weight_params + c])
        } else {
            Array1::zeros(n_classes)
        };

        // The scalar views expose the first class's parameters so the
        // HasCoefficients trait has something meaningful to return.
        let coefficients = weight_matrix.row(0).to_owned();
        let intercept = intercept_vec[0];

        Ok(FittedLogisticRegression {
            coefficients,
            intercept,
            weight_matrix,
            intercept_vec,
            classes: classes.to_vec(),
            is_binary: false,
        })
    }
}
457
458fn softmax_2d<F: Float>(logits: &Array2<F>) -> Array2<F> {
460 let n_rows = logits.nrows();
461 let n_cols = logits.ncols();
462 let mut probs = Array2::<F>::zeros((n_rows, n_cols));
463
464 for i in 0..n_rows {
465 let max_logit = logits
467 .row(i)
468 .iter()
469 .fold(F::neg_infinity(), |a, &b| a.max(b));
470
471 let mut sum = F::zero();
472 for j in 0..n_cols {
473 let exp_val = (logits[[i, j]] - max_logit).exp();
474 probs[[i, j]] = exp_val;
475 sum = sum + exp_val;
476 }
477
478 if sum > F::zero() {
479 for j in 0..n_cols {
480 probs[[i, j]] = probs[[i, j]] / sum;
481 }
482 }
483 }
484
485 probs
486}
487
488impl<F: Float + Send + Sync + ScalarOperand + 'static> FittedLogisticRegression<F> {
489 #[must_use]
494 pub fn weight_matrix(&self) -> &Array2<F> {
495 &self.weight_matrix
496 }
497
498 #[must_use]
500 pub fn intercept_vec(&self) -> &Array1<F> {
501 &self.intercept_vec
502 }
503
504 #[must_use]
506 pub fn is_binary(&self) -> bool {
507 self.is_binary
508 }
509
510 pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
520 let n_features = x.ncols();
521 let expected_features = self.weight_matrix.ncols();
522
523 if n_features != expected_features {
524 return Err(FerroError::ShapeMismatch {
525 expected: vec![expected_features],
526 actual: vec![n_features],
527 context: "number of features must match fitted model".into(),
528 });
529 }
530
531 if self.is_binary {
532 let logits = x.dot(&self.coefficients) + self.intercept;
533 let n_samples = x.nrows();
534 let mut probs = Array2::<F>::zeros((n_samples, 2));
535 for i in 0..n_samples {
536 let p1 = sigmoid(logits[i]);
537 probs[[i, 0]] = F::one() - p1;
538 probs[[i, 1]] = p1;
539 }
540 Ok(probs)
541 } else {
542 let logits = x.dot(&self.weight_matrix.t()) + &self.intercept_vec;
543 Ok(softmax_2d(&logits))
544 }
545 }
546}
547
548impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
549 for FittedLogisticRegression<F>
550{
551 type Output = Array1<usize>;
552 type Error = FerroError;
553
554 fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
563 let proba = self.predict_proba(x)?;
564 let n_samples = proba.nrows();
565 let n_classes = proba.ncols();
566
567 let mut predictions = Array1::<usize>::zeros(n_samples);
568 for i in 0..n_samples {
569 let mut best_class = 0;
570 let mut best_prob = proba[[i, 0]];
571 for c in 1..n_classes {
572 if proba[[i, c]] > best_prob {
573 best_prob = proba[[i, c]];
574 best_class = c;
575 }
576 }
577 predictions[i] = self.classes[best_class];
578 }
579
580 Ok(predictions)
581 }
582}
583
impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
    for FittedLogisticRegression<F>
{
    // Coefficient vector of the fitted model; for multinomial models this
    // is the first row of the weight matrix (see `fit_multinomial`).
    fn coefficients(&self) -> &Array1<F> {
        &self.coefficients
    }

    // Intercept term; for multinomial models, the first class's intercept.
    fn intercept(&self) -> F {
        self.intercept
    }
}
595
596impl<F: Float + Send + Sync + ScalarOperand + 'static> HasClasses for FittedLogisticRegression<F> {
597 fn classes(&self) -> &[usize] {
598 &self.classes
599 }
600
601 fn n_classes(&self) -> usize {
602 self.classes.len()
603 }
604}
605
606impl<F> PipelineEstimator<F> for LogisticRegression<F>
608where
609 F: Float + ToPrimitive + FromPrimitive + ScalarOperand + Send + Sync + 'static,
610{
611 fn fit_pipeline(
612 &self,
613 x: &Array2<F>,
614 y: &Array1<F>,
615 ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
616 let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
618 let fitted = self.fit(x, &y_usize)?;
619 Ok(Box::new(FittedLogisticRegressionPipeline(fitted)))
620 }
621}
622
623struct FittedLogisticRegressionPipeline<F>(FittedLogisticRegression<F>)
625where
626 F: Float + Send + Sync + 'static;
627
628unsafe impl<F: Float + Send + Sync + 'static> Send for FittedLogisticRegressionPipeline<F> {}
630unsafe impl<F: Float + Send + Sync + 'static> Sync for FittedLogisticRegressionPipeline<F> {}
631
632impl<F> FittedPipelineEstimator<F> for FittedLogisticRegressionPipeline<F>
633where
634 F: Float + ToPrimitive + FromPrimitive + ScalarOperand + Send + Sync + 'static,
635{
636 fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
637 let preds = self.0.predict(x)?;
638 Ok(preds.mapv(|v| F::from_usize(v).unwrap_or(F::nan())))
639 }
640}
641
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // Sanity-checks the stable sigmoid: midpoint, saturation, and the
    // symmetry sigmoid(z) + sigmoid(-z) == 1.
    #[test]
    fn test_sigmoid() {
        assert_relative_eq!(sigmoid(0.0_f64), 0.5, epsilon = 1e-10);
        assert!(sigmoid(10.0_f64) > 0.99);
        assert!(sigmoid(-10.0_f64) < 0.01);
        assert_relative_eq!(sigmoid(1.0_f64) + sigmoid(-1.0_f64), 1.0, epsilon = 1e-10);
    }

    // Two well-separated 2-D clusters; the binary solver should classify
    // most points correctly.
    #[test]
    fn test_binary_classification() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0, 5.0, 5.0, 5.0, 6.0, 6.0, 5.0, 6.0, 6.0, ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = LogisticRegression::<f64>::new()
            .with_c(1.0)
            .with_max_iter(1000);
        let fitted = model.fit(&x, &y).unwrap();

        let preds = fitted.predict(&x).unwrap();

        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 6, "expected at least 6 correct, got {correct}");
    }

    // Binary probabilities: rows sum to 1 and the ordering matches the
    // obvious 1-D decision boundary.
    #[test]
    fn test_binary_predict_proba() {
        let x = Array2::from_shape_vec((6, 1), vec![-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = LogisticRegression::<f64>::new().with_c(1.0);
        let fitted = model.fit(&x, &y).unwrap();

        let proba = fitted.predict_proba(&x).unwrap();

        for i in 0..proba.nrows() {
            assert_relative_eq!(proba.row(i).sum(), 1.0, epsilon = 1e-10);
        }

        // Leftmost sample favors class 0, rightmost favors class 1.
        assert!(proba[[0, 0]] > proba[[0, 1]]);
        assert!(proba[[5, 1]] > proba[[5, 0]]);
    }

    // Three separated clusters exercise the multinomial (softmax) path.
    #[test]
    fn test_multiclass_classification() {
        let x = Array2::from_shape_vec(
            (9, 2),
            vec![
                0.0, 0.0, 0.5, 0.0, 0.0, 0.5, 5.0, 0.0, 5.5, 0.0, 5.0, 0.5, 0.0, 5.0, 0.5, 5.0, 0.0, 5.5, ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let model = LogisticRegression::<f64>::new()
            .with_c(10.0)
            .with_max_iter(2000);
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.n_classes(), 3);
        assert_eq!(fitted.classes(), &[0, 1, 2]);

        let preds = fitted.predict(&x).unwrap();
        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 7, "expected at least 7 correct, got {correct}");
    }

    // Multinomial probabilities must form a distribution per row.
    #[test]
    fn test_multiclass_predict_proba() {
        let x = Array2::from_shape_vec(
            (9, 2),
            vec![
                0.0, 0.0, 0.5, 0.0, 0.0, 0.5, 5.0, 0.0, 5.5, 0.0, 5.0, 0.5, 0.0, 5.0, 0.5, 5.0,
                0.0, 5.5,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let model = LogisticRegression::<f64>::new()
            .with_c(10.0)
            .with_max_iter(2000);
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();

        for i in 0..proba.nrows() {
            assert_relative_eq!(proba.row(i).sum(), 1.0, epsilon = 1e-10);
        }
    }

    // fit() must reject x/y with different sample counts.
    #[test]
    fn test_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 1]; let model = LogisticRegression::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    // C must be strictly positive; zero and negative are both rejected.
    #[test]
    fn test_invalid_c() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = LogisticRegression::<f64>::new().with_c(0.0);
        assert!(model.fit(&x, &y).is_err());

        let model_neg = LogisticRegression::<f64>::new().with_c(-1.0);
        assert!(model_neg.fit(&x, &y).is_err());
    }

    // A single distinct label cannot define a classification problem.
    #[test]
    fn test_single_class_error() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0]; let model = LogisticRegression::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    // The HasCoefficients introspection trait exposes one weight per feature.
    #[test]
    fn test_has_coefficients() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 5.0, 5.0, 5.0, 6.0, 6.0, 5.0],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = LogisticRegression::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.coefficients().len(), 2);
    }

    // The HasClasses trait reports the sorted, distinct labels.
    #[test]
    fn test_has_classes() {
        let x = Array2::from_shape_vec((6, 1), vec![-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = LogisticRegression::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.classes(), &[0, 1]);
        assert_eq!(fitted.n_classes(), 2);
    }

    // Round-trip through the boxed pipeline adapters (float labels in,
    // float predictions out).
    #[test]
    fn test_pipeline_integration() {
        let x = Array2::from_shape_vec((6, 1), vec![-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]).unwrap();
        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);

        let model = LogisticRegression::<f64>::new();
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    // With fit_intercept disabled, the fitted intercept must be exactly 0.
    #[test]
    fn test_no_intercept() {
        let x = Array2::from_shape_vec((6, 1), vec![-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = LogisticRegression::<f64>::new().with_fit_intercept(false);
        let fitted = model.fit(&x, &y).unwrap();
        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    // softmax_2d: rows normalize to 1, and equal logits give the uniform
    // distribution.
    #[test]
    fn test_softmax_2d() {
        let logits = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 1.0, 1.0, 1.0]).unwrap();
        let probs = softmax_2d(&logits);

        assert_relative_eq!(probs.row(0).sum(), 1.0, epsilon = 1e-10);
        assert_relative_eq!(probs.row(1).sum(), 1.0, epsilon = 1e-10);

        assert_relative_eq!(probs[[1, 0]], 1.0 / 3.0, epsilon = 1e-10);
        assert_relative_eq!(probs[[1, 1]], 1.0 / 3.0, epsilon = 1e-10);
        assert_relative_eq!(probs[[1, 2]], 1.0 / 3.0, epsilon = 1e-10);
    }
}