use crate::error::Result;
use crate::primitives::{Matrix, Vector};
use serde::{Deserialize, Serialize};
use std::path::Path;
/// Strategy for weighting training samples by class during `fit`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ClassWeight {
/// Every sample contributes equally (weight 1.0).
Uniform,
/// Weight each class by sqrt(n / (2 * n_class)), boosting the minority class.
Balanced,
/// Explicit per-class weights indexed by label: entry `[l]` weights class `l`.
Manual(Vec<f32>),
}
impl Default for ClassWeight {
fn default() -> Self {
Self::Uniform
}
}
/// Binary logistic regression classifier trained by batch gradient descent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogisticRegression {
// Learned feature weights; `None` until `fit` has been called.
coefficients: Option<Vector<f32>>,
// Learned bias term (0.0 before fitting).
intercept: f32,
// Step size for gradient-descent updates.
learning_rate: f32,
// Upper bound on gradient-descent iterations.
max_iter: usize,
// Convergence threshold on gradient magnitudes for early stopping.
tol: f32,
// Per-class sample weighting strategy applied in `fit`.
class_weight: ClassWeight,
// L2 regularization strength on the coefficients (intercept is exempt).
l2_penalty: f32,
}
impl LogisticRegression {
#[must_use]
/// Creates a fresh, unfitted model with the default hyperparameters:
/// learning rate 0.01, up to 1000 iterations, tolerance 1e-4, uniform
/// class weights, and no L2 regularization.
pub fn new() -> Self {
    Self {
        class_weight: ClassWeight::Uniform,
        coefficients: None,
        intercept: 0.0,
        l2_penalty: 0.0,
        learning_rate: 0.01,
        max_iter: 1000,
        tol: 1e-4,
    }
}
#[must_use]
pub fn with_learning_rate(mut self, lr: f32) -> Self {
self.learning_rate = lr;
self
}
#[must_use]
pub fn with_max_iter(mut self, max_iter: usize) -> Self {
self.max_iter = max_iter;
self
}
#[must_use]
pub fn with_tolerance(mut self, tol: f32) -> Self {
self.tol = tol;
self
}
#[must_use]
pub fn with_class_weight(mut self, class_weight: ClassWeight) -> Self {
self.class_weight = class_weight;
self
}
#[must_use]
pub fn with_l2_penalty(mut self, l2_penalty: f32) -> Self {
self.l2_penalty = l2_penalty;
self
}
// Logistic function; delegates to the shared scalar implementation in the
// nn module so the model and the network code agree on numerics.
fn sigmoid(z: f32) -> f32 {
crate::nn::functional::sigmoid_scalar(z)
}
#[must_use]
/// Returns P(y = 1) for every row of `x`.
///
/// # Panics
/// Panics if the model has not been fitted yet.
pub fn predict_proba(&self, x: &Matrix<f32>) -> Vector<f32> {
    let coef = self.coefficients.as_ref().expect("Model not fitted yet");
    let (n_samples, _) = x.shape();
    // Per-row dot product plus intercept; the fold accumulates in the same
    // column order as a plain loop, so the f32 result is identical.
    let probas: Vec<f32> = (0..n_samples)
        .map(|row| {
            let z = (0..coef.len())
                .fold(self.intercept, |acc, col| acc + coef[col] * x.get(row, col));
            Self::sigmoid(z)
        })
        .collect();
    Vector::from_vec(probas)
}
/// Trains the model on `x` (n_samples x n_features) against binary labels `y`
/// using full-batch gradient descent on the (optionally class-weighted)
/// log-loss, with L2 shrinkage applied to the coefficients only — the
/// intercept is never regularized. Stops early once every gradient
/// component's magnitude falls below `tol`.
///
/// # Errors
/// Returns an error if `x` and `y` disagree on sample count, if there are
/// zero samples, or if any label is not 0 or 1.
pub fn fit(&mut self, x: &Matrix<f32>, y: &[usize]) -> Result<()> {
let (n_samples, n_features) = x.shape();
if n_samples != y.len() {
return Err("Number of samples in X and y must match".into());
}
if n_samples == 0 {
return Err("Cannot fit with zero samples".into());
}
for &label in y {
if label != 0 && label != 1 {
return Err("Labels must be 0 or 1 for binary classification".into());
}
}
// Re-fitting always restarts from a zero solution.
self.coefficients = Some(Vector::from_vec(vec![0.0; n_features]));
self.intercept = 0.0;
let sample_weights = self.compute_sample_weights(y);
for _ in 0..self.max_iter {
let probas = self.predict_proba(x);
// Accumulate the weighted log-loss gradient over the whole batch.
let mut coef_grad = vec![0.0; n_features];
let mut intercept_grad = 0.0;
for i in 0..n_samples {
// d(loss)/dz = p - y, scaled by this sample's class weight.
let error = sample_weights[i] * (probas[i] - y[i] as f32);
intercept_grad += error;
for (j, grad) in coef_grad.iter_mut().enumerate() {
*grad += error * x.get(i, j);
}
}
// Average gradients over the batch size (not the weight sum).
let n = n_samples as f32;
intercept_grad /= n;
for grad in &mut coef_grad {
*grad /= n;
}
self.intercept -= self.learning_rate * intercept_grad;
if let Some(ref mut coef) = self.coefficients {
for j in 0..n_features {
// L2 penalty shrinks coefficients toward zero; intercept is exempt.
coef[j] -= self.learning_rate * (coef_grad[j] + self.l2_penalty * coef[j]);
}
}
// Converged: every raw (pre-penalty) gradient component is below tol.
if intercept_grad.abs() < self.tol && coef_grad.iter().all(|&g| g.abs() < self.tol) {
break;
}
}
Ok(())
}
#[must_use]
/// Predicts the class label (0 or 1) for every row of `x`, thresholding
/// the predicted probability at 0.5.
///
/// # Panics
/// Panics if the model has not been fitted yet.
pub fn predict(&self, x: &Matrix<f32>) -> Vec<usize> {
    let probas = self.predict_proba(x);
    let mut labels = Vec::with_capacity(probas.len());
    for &p in probas.as_slice() {
        labels.push(if p >= 0.5 { 1 } else { 0 });
    }
    labels
}
#[must_use]
/// Computes the mean accuracy of the model's predictions on `x` against `y`.
///
/// Returns 0.0 when `y` is empty rather than dividing by zero and
/// producing NaN.
///
/// # Panics
/// Panics if the model has not been fitted yet.
pub fn score(&self, x: &Matrix<f32>, y: &[usize]) -> f32 {
    if y.is_empty() {
        return 0.0;
    }
    let predictions = self.predict(x);
    let correct = predictions
        .iter()
        .zip(y.iter())
        .filter(|(pred, true_label)| pred == true_label)
        .count();
    correct as f32 / y.len() as f32
}
#[must_use]
/// Borrows the fitted coefficient vector.
///
/// # Panics
/// Panics if called before `fit`.
pub fn coefficients(&self) -> &Vector<f32> {
    match self.coefficients {
        Some(ref coef) => coef,
        None => panic!("Model not fitted"),
    }
}
#[must_use]
/// Returns the learned bias term (0.0 for an unfitted model).
pub fn intercept(&self) -> f32 {
self.intercept
}
/// Maps every label in `y` to its per-sample weight according to
/// `self.class_weight`.
///
/// Falls back to uniform weights when `Balanced` sees only one class, or
/// when a `Manual` table has fewer than two entries. Manual labels beyond
/// the table length weigh 1.0.
fn compute_sample_weights(&self, y: &[usize]) -> Vec<f32> {
    let uniform = || vec![1.0; y.len()];
    match &self.class_weight {
        ClassWeight::Uniform => uniform(),
        ClassWeight::Balanced => {
            let total = y.len() as f32;
            let zeros = y.iter().filter(|&&label| label == 0).count() as f32;
            let ones = total - zeros;
            if zeros == 0.0 || ones == 0.0 {
                // Single-class input: balancing is meaningless.
                return uniform();
            }
            // Square-rooted inverse-frequency weights: sqrt(n / (2 * n_class)).
            let per_class = [
                (total / (2.0 * zeros)).sqrt(),
                (total / (2.0 * ones)).sqrt(),
            ];
            y.iter().map(|&label| per_class[label.min(1)]).collect()
        }
        ClassWeight::Manual(table) => {
            if table.len() < 2 {
                return uniform();
            }
            y.iter()
                .map(|&label| table.get(label).copied().unwrap_or(1.0))
                .collect()
        }
    }
}
/// Serializes the fitted parameters (coefficients and intercept) to a
/// SafeTensors file at `path`.
///
/// # Errors
/// Returns an error if the model is unfitted or if writing the file fails.
pub fn save_safetensors<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
    use crate::serialization::safetensors;
    use std::collections::BTreeMap;
    let coefficients = self
        .coefficients
        .as_ref()
        .ok_or("Cannot save unfitted model. Call fit() first.")?;
    // Two named tensors: the weight vector and the scalar intercept.
    let mut tensors = BTreeMap::new();
    let mut coef_data = Vec::with_capacity(coefficients.len());
    for i in 0..coefficients.len() {
        coef_data.push(coefficients[i]);
    }
    tensors.insert(
        "coefficients".to_string(),
        (coef_data, vec![coefficients.len()]),
    );
    tensors.insert("intercept".to_string(), (vec![self.intercept], vec![1]));
    safetensors::save_safetensors(path, &tensors)?;
    Ok(())
}
pub fn load_safetensors<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
use crate::serialization::safetensors;
let (metadata, raw_data) = safetensors::load_safetensors(path)?;
let coef_meta = metadata
.get("coefficients")
.ok_or("Missing 'coefficients' tensor in SafeTensors file")?;
let coef_data = safetensors::extract_tensor(&raw_data, coef_meta)?;
let intercept_meta = metadata
.get("intercept")
.ok_or("Missing 'intercept' tensor in SafeTensors file")?;
let intercept_data = safetensors::extract_tensor(&raw_data, intercept_meta)?;
if intercept_data.len() != 1 {
return Err(format!(
"Invalid intercept tensor: expected 1 value, got {}",
intercept_data.len()
));
}
Ok(Self {
coefficients: Some(Vector::from_vec(coef_data)),
intercept: intercept_data[0],
learning_rate: 0.01,
max_iter: 1000,
tol: 1e-4,
class_weight: ClassWeight::Uniform,
l2_penalty: 0.0,
})
}
}
impl Default for LogisticRegression {
fn default() -> Self {
Self::new()
}
}
/// Distance function selector used by [`KNearestNeighbors`].
// NOTE(review): the actual distance computation is not in this view —
// variant semantics below follow the conventional naming; confirm in the
// k-NN implementation.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DistanceMetric {
/// L2 (straight-line) distance.
Euclidean,
/// L1 (taxicab) distance.
Manhattan,
/// Minkowski distance; the `f32` is presumably the exponent p — confirm.
Minkowski(f32),
}
/// k-nearest-neighbors classifier that stores its training set verbatim.
///
/// `x_train`/`y_train` are `None` until fitted; the `fit`/`predict`
/// implementations are outside this view.
#[derive(Debug, Clone)]
pub struct KNearestNeighbors {
// Number of neighbors consulted per prediction.
k: usize,
// Distance metric used to rank neighbors.
metric: DistanceMetric,
// NOTE(review): presumably toggles distance-weighted voting — confirm in predict.
weights: bool,
// Training features; `None` before fitting.
x_train: Option<Matrix<f32>>,
// Training labels; `None` before fitting.
y_train: Option<Vec<usize>>,
}
// Sibling model implementations, re-exported at this module's level.
mod gaussian_nb;
pub use gaussian_nb::*;
mod linear_svm;
pub use linear_svm::*;
mod sets;
// Contract tests for LogisticRegression live in a separate source file.
#[cfg(test)]
#[path = "tests_logreg_contract.rs"]
mod tests_logreg_contract;