use crate::builders::builder::Builder;
use crate::core::activations::activation_functions::ActivationFn;
use crate::core::error::ModelError;
use crate::core::types::{Matrix, Vector};
use crate::model::logistic_regression::LogisticRegression;
/// Fluent builder for [`LogisticRegression`].
///
/// Every setter takes the builder by value and returns it, so calls are meant
/// to be chained (or the result rebound). The collected values are handed to
/// [`LogisticRegression::new`] when the model is built.
#[derive(Clone)]
#[must_use = "builder methods return the updated builder; use it or call `build()`"]
pub struct LogisticRegressionBuilder {
    /// Number of features forwarded to the model constructor.
    n_features: usize,
    /// Activation function forwarded to the model constructor.
    activation_fn: ActivationFn,
    /// Threshold forwarded to the model constructor; the setter only accepts
    /// values in the closed interval `[0, 1]`.
    threshold: f64,
}
impl LogisticRegressionBuilder {
    /// Creates a builder with one feature, the sigmoid activation and a
    /// threshold of `0.5`.
    pub fn new() -> Self {
        Self {
            n_features: 1,
            activation_fn: ActivationFn::Sigmoid,
            threshold: 0.5,
        }
    }

    /// Sets the number of features that will be passed to the model
    /// constructor. No validation is performed here.
    pub fn n_features(mut self, n_features: usize) -> Self {
        self.n_features = n_features;
        self
    }

    /// Sets the activation function that will be passed to the model
    /// constructor.
    pub fn activation_function(mut self, activation_function: ActivationFn) -> Self {
        self.activation_fn = activation_function;
        self
    }

    /// Sets the threshold that will be passed to the model constructor
    /// (presumably the probability cut-off used for classification — confirm
    /// against `LogisticRegression`).
    ///
    /// # Panics
    ///
    /// Panics if `threshold` lies outside the closed interval `[0, 1]`,
    /// including when it is NaN.
    pub fn threshold(mut self, threshold: f64) -> Self {
        // `RangeInclusive::contains` returns false for NaN, so NaN is rejected too.
        assert!(
            (0.0..=1.0).contains(&threshold),
            "Threshold must be between 0 and 1, got {}",
            threshold
        );
        self.threshold = threshold;
        self
    }
}
impl Builder<LogisticRegression, Matrix, Vector> for LogisticRegressionBuilder {
    /// Constructs a [`LogisticRegression`] from the values collected so far.
    ///
    /// The builder is only borrowed, so it can be reused to build further
    /// models. This implementation always returns `Ok`.
    fn build(&self) -> Result<LogisticRegression, ModelError> {
        let Self {
            n_features,
            activation_fn,
            threshold,
        } = *self;
        let model = LogisticRegression::new(n_features, activation_fn, threshold);
        Ok(model)
    }
}
impl Default for LogisticRegressionBuilder {
fn default() -> Self {
Self::new()
}
}