use crate::error::Result;
use crate::types::Float;
#[cfg(test)]
use crate::validation::ValidationContext;
use crate::validation::{ConfigValidation, Validate};
/// Hyperparameters for a linear regression estimator.
///
/// Per-field constraints are enforced by this type's [`Validate`] implementation;
/// [`ConfigValidation::validate_config`] additionally checks the `solver` /
/// `fit_intercept` combination.
#[derive(Debug, Clone)]
pub struct LinearRegressionConfig {
/// Learning rate, checked by `crate::validation::ml::validate_learning_rate`.
/// Values above `0.1` produce a warning from `get_warnings`.
pub learning_rate: Float,
/// Regularization parameter, checked by `crate::validation::ml::validate_regularization`.
pub alpha: Float,
/// Iteration cap, checked by `crate::validation::ml::validate_max_iter`.
/// Values below `100` produce a warning from `get_warnings`.
pub max_iter: usize,
/// Tolerance; must be positive and finite.
pub tol: Float,
/// Intercept flag. Must be `true` when `solver` is `"cholesky"` (checked by
/// `validate_config`).
pub fit_intercept: bool,
/// Solver name; one of `"auto"`, `"svd"`, `"cholesky"`, `"lsqr"`, `"sparse_cg"`,
/// `"sag"` or `"saga"`.
pub solver: String,
}
impl Validate for LinearRegressionConfig {
fn validate(&self) -> Result<()> {
crate::validation::ml::validate_learning_rate(self.learning_rate)?;
crate::validation::ml::validate_regularization(self.alpha)?;
crate::validation::ml::validate_max_iter(self.max_iter)?;
crate::validation::ValidationRules::new("tol")
.add_rule(crate::validation::ValidationRule::Positive)
.add_rule(crate::validation::ValidationRule::Finite)
.validate_numeric(&self.tol)?;
crate::validation::ValidationRules::new("solver")
.add_rule(crate::validation::ValidationRule::OneOf(vec![
"auto".to_string(),
"svd".to_string(),
"cholesky".to_string(),
"lsqr".to_string(),
"sparse_cg".to_string(),
"sag".to_string(),
"saga".to_string(),
]))
.validate_string(&self.solver)?;
Ok(())
}
}
impl Default for LinearRegressionConfig {
    /// Baseline configuration: `learning_rate = 0.01`, `alpha = 1.0`,
    /// `max_iter = 1000`, `tol = 1e-4`, intercept enabled and `solver = "auto"`.
    fn default() -> Self {
        let solver = String::from("auto");
        Self {
            solver,
            fit_intercept: true,
            tol: 0.0001,
            max_iter: 1_000,
            alpha: 1.0,
            learning_rate: 0.01,
        }
    }
}
impl ConfigValidation for LinearRegressionConfig {
fn validate_config(&self) -> Result<()> {
self.validate()?;
if self.solver == "cholesky" && !self.fit_intercept {
return Err(crate::error::SklearsError::InvalidParameter {
name: "solver".to_string(),
reason: "cholesky solver requires fit_intercept=true".to_string(),
});
}
Ok(())
}
fn get_warnings(&self) -> Vec<String> {
let mut warnings = Vec::new();
if self.learning_rate > 0.1 {
warnings
.push("Learning rate is quite high, consider using a smaller value".to_string());
}
if self.max_iter < 100 {
warnings.push("Maximum iterations is quite low, model may not converge".to_string());
}
warnings
}
}
/// Hyperparameters for k-means clustering.
///
/// Per-field constraints are enforced by this type's [`Validate`] implementation.
#[derive(Debug, Clone)]
pub struct KMeansConfig {
/// Cluster count, checked by `crate::validation::ml::validate_n_clusters(n_clusters, 100)`.
/// A value of `1` is accepted but logged as a warning by `validate_config`.
pub n_clusters: usize,
/// Iteration cap, checked by `crate::validation::ml::validate_max_iter`.
pub max_iter: usize,
/// Tolerance; must be positive and finite.
pub tol: Float,
/// Initialisation strategy; one of `"k-means++"`, `"random"` or `"custom"`.
pub init: String,
/// Must be non-zero.
pub n_init: usize,
/// Optional random state; not checked by the validation code in this module.
pub random_state: Option<u64>,
}
impl Validate for KMeansConfig {
fn validate(&self) -> Result<()> {
crate::validation::ml::validate_n_clusters(self.n_clusters, 100)?;
crate::validation::ml::validate_max_iter(self.max_iter)?;
crate::validation::ValidationRules::new("tol")
.add_rule(crate::validation::ValidationRule::Positive)
.add_rule(crate::validation::ValidationRule::Finite)
.validate_numeric(&self.tol)?;
crate::validation::ValidationRules::new("init")
.add_rule(crate::validation::ValidationRule::OneOf(vec![
"k-means++".to_string(),
"random".to_string(),
"custom".to_string(),
]))
.validate_string(&self.init)?;
if self.n_init == 0 {
return Err(crate::error::SklearsError::InvalidParameter {
name: "n_init".to_string(),
reason: "must be positive".to_string(),
});
}
Ok(())
}
}
impl Default for KMeansConfig {
    /// Baseline configuration: `n_clusters = 8`, `max_iter = 300`, `tol = 1e-4`,
    /// `init = "k-means++"`, `n_init = 10` and no `random_state`.
    fn default() -> Self {
        let init = String::from("k-means++");
        Self {
            random_state: None,
            n_init: 10,
            init,
            tol: 0.0001,
            max_iter: 300,
            n_clusters: 8,
        }
    }
}
impl ConfigValidation for KMeansConfig {
    /// Runs [`Validate::validate`], then performs configuration-level checks.
    ///
    /// A single cluster is accepted, but a warning is logged suggesting that clustering
    /// may be unnecessary.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Validate::validate`].
    fn validate_config(&self) -> Result<()> {
        self.validate()?;
        // The log line is kept for callers that rely on it; the same advisory is also
        // returned by `get_warnings` so it can be surfaced programmatically.
        if self.n_clusters == 1 {
            log::warn!("Using only 1 cluster - consider if clustering is necessary");
        }
        Ok(())
    }

    /// Returns non-fatal advisories about this configuration.
    ///
    /// Reports the single-cluster case that `validate_config` logs. Without this
    /// override, callers collecting advisories through this method (as they do for
    /// the other configs in this module) would never see it.
    fn get_warnings(&self) -> Vec<String> {
        let mut warnings = Vec::new();
        if self.n_clusters == 1 {
            warnings.push("Using only 1 cluster - consider if clustering is necessary".to_string());
        }
        warnings
    }
}
/// Hyperparameters for a multi-layer perceptron.
///
/// Per-field constraints are enforced by this type's [`Validate`] implementation;
/// [`ConfigValidation::validate_config`] adds a cross-field `solver` check.
#[derive(Debug, Clone)]
pub struct MLPConfig {
/// Hidden layer sizes; must contain at least one entry. With `solver == "lbfgs"`,
/// more than three entries are rejected by `validate_config`.
pub hidden_layer_sizes: Vec<usize>,
/// Learning rate, checked by `crate::validation::ml::validate_learning_rate`.
pub learning_rate: Float,
/// Iteration cap, checked by `crate::validation::ml::validate_max_iter`.
pub max_iter: usize,
/// Dropout, checked by `crate::validation::ml::validate_probability`.
/// Values above `0.5` produce a warning from `get_warnings`.
pub dropout: Float,
/// Batch size; must be non-zero. Values above `1000` are logged as a warning by
/// `validate_config`.
pub batch_size: usize,
/// Regularization parameter, checked by `crate::validation::ml::validate_regularization`.
pub alpha: Float,
/// Activation name; one of `"relu"`, `"tanh"`, `"sigmoid"` or `"identity"`.
pub activation: String,
/// Solver name; one of `"adam"`, `"sgd"` or `"lbfgs"`.
pub solver: String,
}
impl Validate for MLPConfig {
fn validate(&self) -> Result<()> {
crate::validation::ValidationRules::new("hidden_layer_sizes")
.add_rule(crate::validation::ValidationRule::MinLength(1))
.validate_array(&self.hidden_layer_sizes)?;
crate::validation::ml::validate_learning_rate(self.learning_rate)?;
crate::validation::ml::validate_max_iter(self.max_iter)?;
crate::validation::ml::validate_probability(self.dropout)?;
if self.batch_size == 0 {
return Err(crate::error::SklearsError::InvalidParameter {
name: "batch_size".to_string(),
reason: "must be positive".to_string(),
});
}
crate::validation::ml::validate_regularization(self.alpha)?;
crate::validation::ValidationRules::new("activation")
.add_rule(crate::validation::ValidationRule::OneOf(vec![
"relu".to_string(),
"tanh".to_string(),
"sigmoid".to_string(),
"identity".to_string(),
]))
.validate_string(&self.activation)?;
crate::validation::ValidationRules::new("solver")
.add_rule(crate::validation::ValidationRule::OneOf(vec![
"adam".to_string(),
"sgd".to_string(),
"lbfgs".to_string(),
]))
.validate_string(&self.solver)?;
Ok(())
}
}
impl Default for MLPConfig {
    /// Baseline configuration: one hidden layer of 100, `learning_rate = 1e-3`,
    /// `max_iter = 200`, no dropout, `batch_size = 32`, `alpha = 1e-4`,
    /// `activation = "relu"` and `solver = "adam"`.
    fn default() -> Self {
        let activation = String::from("relu");
        let solver = String::from("adam");
        Self {
            solver,
            activation,
            alpha: 0.0001,
            batch_size: 32,
            dropout: 0.0,
            max_iter: 200,
            learning_rate: 1e-3,
            hidden_layer_sizes: Vec::from([100]),
        }
    }
}
impl ConfigValidation for MLPConfig {
    /// Runs [`Validate::validate`], then applies cross-field rules.
    ///
    /// Batch sizes above `1000` are accepted but logged as a warning.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Validate::validate`], and returns
    /// `SklearsError::InvalidParameter` for `solver` when `solver == "lbfgs"` and
    /// `hidden_layer_sizes` has more than three entries.
    fn validate_config(&self) -> Result<()> {
        self.validate()?;
        // NOTE(review): the reason text is advisory ("may be inefficient"), but this is a
        // hard error, so such configurations cannot be used at all. It is kept as an
        // error to preserve the existing contract; consider demoting it to a warning.
        if self.solver == "lbfgs" && self.hidden_layer_sizes.len() > 3 {
            return Err(crate::error::SklearsError::InvalidParameter {
                name: "solver".to_string(),
                reason: "lbfgs solver may be inefficient for deep networks".to_string(),
            });
        }
        // Kept for callers relying on the log output; the same advisory is also
        // returned by `get_warnings`. Keep the two thresholds in sync.
        if self.batch_size > 1000 {
            log::warn!("Large batch size may lead to poor generalization");
        }
        Ok(())
    }

    /// Returns non-fatal advisories in a fixed order: any hidden layer size above
    /// `1000`, dropout above `0.5`, then batch size above `1000`.
    fn get_warnings(&self) -> Vec<String> {
        let mut warnings = Vec::new();
        if self.hidden_layer_sizes.iter().any(|&size| size > 1000) {
            warnings.push("Very large hidden layers may cause overfitting".to_string());
        }
        if self.dropout > 0.5 {
            warnings.push("High dropout rate may prevent learning".to_string());
        }
        // Mirrors the log line in `validate_config`, so callers collecting advisories
        // through this method see the large-batch case too.
        if self.batch_size > 1000 {
            warnings.push("Large batch size may lead to poor generalization".to_string());
        }
        warnings
    }
}
/// Example configuration demonstrating a cross-field validation rule.
///
/// `param1` must be strictly positive. Whenever `param2 > 0`, `dependent_param` must
/// not exceed `2 * param1`.
#[derive(Debug, Clone)]
pub struct CustomValidationExample {
    /// Base value; must be strictly positive (NaN is rejected).
    pub param1: Float,
    /// When non-zero, activates the upper bound on `dependent_param`.
    pub param2: usize,
    /// Bounded by `2 * param1` whenever `param2 > 0` (NaN is rejected in that case).
    pub dependent_param: Float,
}

impl Validate for CustomValidationExample {
    /// Checks `param1` and then the `dependent_param` bound.
    ///
    /// # Errors
    ///
    /// Returns `SklearsError::InvalidParameter` naming `param1` when it is not strictly
    /// positive (including NaN), or naming `dependent_param` when `param2 > 0` and
    /// `dependent_param` is NaN or greater than `2 * param1`.
    fn validate(&self) -> Result<()> {
        // NaN compares false with everything, so `param1 <= 0.0` on its own would let a
        // NaN through the "must be positive" check.
        if self.param1.is_nan() || self.param1 <= 0.0 {
            return Err(crate::error::SklearsError::InvalidParameter {
                name: "param1".to_string(),
                reason: "must be positive".to_string(),
            });
        }
        // Same NaN hazard: a NaN `dependent_param` can never be shown to respect the bound.
        if self.param2 > 0
            && (self.dependent_param.is_nan() || self.dependent_param > self.param1 * 2.0)
        {
            return Err(crate::error::SklearsError::InvalidParameter {
                name: "dependent_param".to_string(),
                reason: "cannot be more than twice param1 when param2 > 0".to_string(),
            });
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Field-level validation of `LinearRegressionConfig`.
    #[test]
    fn test_linear_regression_config_validation() {
        assert!(LinearRegressionConfig::default().validate().is_ok());

        let negative_learning_rate = LinearRegressionConfig {
            learning_rate: -0.1,
            ..Default::default()
        };
        assert!(negative_learning_rate.validate().is_err());

        let unknown_solver = LinearRegressionConfig {
            solver: "invalid_solver".to_string(),
            ..Default::default()
        };
        assert!(unknown_solver.validate().is_err());
    }

    /// Field-level validation of `KMeansConfig`.
    #[test]
    fn test_kmeans_config_validation() {
        assert!(KMeansConfig::default().validate().is_ok());

        let no_clusters = KMeansConfig {
            n_clusters: 0,
            ..Default::default()
        };
        assert!(no_clusters.validate().is_err());

        let negative_tol = KMeansConfig {
            tol: -1.0,
            ..Default::default()
        };
        assert!(negative_tol.validate().is_err());
    }

    /// Field-level validation of `MLPConfig`.
    #[test]
    fn test_mlp_config_validation() {
        assert!(MLPConfig::default().validate().is_ok());

        let no_hidden_layers = MLPConfig {
            hidden_layer_sizes: Vec::new(),
            ..Default::default()
        };
        assert!(no_hidden_layers.validate().is_err());

        let excessive_dropout = MLPConfig {
            dropout: 1.5,
            ..Default::default()
        };
        assert!(excessive_dropout.validate().is_err());
    }

    /// The default linear-regression config passes `validate_config` with no warnings.
    #[test]
    fn test_config_validation_trait() {
        let config = LinearRegressionConfig::default();
        assert!(config.validate_config().is_ok());
        assert!(config.get_warnings().is_empty());
    }

    /// `ValidationContext::format_error` mentions the estimator, method and data shape.
    #[test]
    fn test_validation_context() {
        let context =
            ValidationContext::new("LinearRegression", "fit").with_data_info(100, 5, "float64");
        let error = crate::error::SklearsError::InvalidParameter {
            name: "learning_rate".to_string(),
            reason: "must be positive".to_string(),
        };
        let formatted = context.format_error(&error);
        for expected in ["LinearRegression", "fit", "100 samples", "5 features"] {
            assert!(formatted.contains(expected));
        }
    }

    /// Cross-field rule: `dependent_param` is bounded only when `param2 > 0`.
    #[test]
    fn test_custom_validation() {
        let unbounded = CustomValidationExample {
            param1: 1.0,
            param2: 0,
            dependent_param: 1.5,
        };
        assert!(unbounded.validate().is_ok());

        let exceeds_bound = CustomValidationExample {
            param1: 1.0,
            param2: 1,
            dependent_param: 3.0,
        };
        assert!(exceeds_bound.validate().is_err());
    }
}