use super::RegularizationType;
use crate::error::ModelError;
use ndarray::{ArrayBase, Data, Ix1, Ix2};
/// Performs basic sanity checks on a feature matrix and an optional target vector.
///
/// Checks run in this order and the first failure is returned:
///
/// 1. `x` must contain at least one element (at least one row and one column).
/// 2. Every element of `x` must be finite (neither NaN nor ±infinity).
/// 3. If `y` is provided, it must not be empty.
/// 4. If `y` is provided, its length must equal the number of rows of `x`.
///
/// The values stored in `y` are not inspected; only its length is checked.
///
/// # Errors
///
/// Returns `ModelError::InputValidationError` with a descriptive message when
/// any of the checks above fails.
pub(super) fn preliminary_check<S>(
    x: &ArrayBase<S, Ix2>,
    y: Option<&ArrayBase<S, Ix1>>,
) -> Result<(), ModelError>
where
    S: Data<Elem = f64>,
{
    // `is_empty` is true when any axis has length zero, so a matrix with rows
    // but no columns is rejected as well as one with no rows.
    if x.is_empty() {
        return Err(ModelError::InputValidationError(
            "Input data is empty".to_string(),
        ));
    }

    // `indexed_iter` walks the matrix in logical (row-major) order, so the
    // reported position is the first offending element by row, then by column.
    if let Some(((i, j), _)) = x.indexed_iter().find(|(_, val)| !val.is_finite()) {
        return Err(ModelError::InputValidationError(format!(
            "Input data contains NaN or infinite value at position [{}][{}]",
            i, j
        )));
    }

    if let Some(y) = y {
        if y.is_empty() {
            return Err(ModelError::InputValidationError(
                "Target vector is empty".to_string(),
            ));
        }

        // One target value is expected per sample, i.e. per row of `x`.
        if y.len() != x.nrows() {
            return Err(ModelError::InputValidationError(format!(
                "Input data and target vector have different lengths, x rows: {}, y length: {}",
                x.nrows(),
                y.len()
            )));
        }
    }

    Ok(())
}
/// Checks that a learning rate is usable.
///
/// A learning rate is accepted only when it is finite and strictly greater
/// than zero; NaN, ±infinity, zero and negative values are all rejected.
///
/// # Errors
///
/// Returns `ModelError::InputValidationError` carrying the offending value
/// when the learning rate is rejected.
pub(super) fn validate_learning_rate(learning_rate: f64) -> Result<(), ModelError> {
    // NaN and ±inf fail `is_finite`; zero and negatives fail the sign test.
    let acceptable = learning_rate.is_finite() && learning_rate > 0.0;

    if acceptable {
        Ok(())
    } else {
        Err(ModelError::InputValidationError(format!(
            "learning_rate must be positive and finite, got {}",
            learning_rate
        )))
    }
}
/// Checks that the iteration budget allows at least one iteration.
///
/// # Errors
///
/// Returns `ModelError::InputValidationError` when `max_iterations` is `0`.
pub(super) fn validate_max_iterations(max_iterations: usize) -> Result<(), ModelError> {
    match max_iterations {
        0 => Err(ModelError::InputValidationError(
            "max_iterations must be greater than 0".to_string(),
        )),
        _ => Ok(()),
    }
}
/// Checks that a convergence tolerance is usable.
///
/// The tolerance is accepted only when it is finite and strictly greater than
/// zero; NaN, ±infinity, zero and negative values are all rejected.
///
/// # Errors
///
/// Returns `ModelError::InputValidationError` carrying the offending value
/// when the tolerance is rejected.
pub(super) fn validate_tolerance(tolerance: f64) -> Result<(), ModelError> {
    // Accept early; everything that falls through is invalid.
    if tolerance.is_finite() && tolerance > 0.0 {
        return Ok(());
    }

    let message = format!("tolerance must be positive and finite, got {}", tolerance);
    Err(ModelError::InputValidationError(message))
}
/// Validates the strength parameter of an optional regularization setting.
///
/// `None` is always accepted. For `L1(alpha)` and `L2(alpha)` the value
/// `alpha` must be finite and non-negative. An `alpha` of exactly zero is
/// accepted, but a warning is written to standard error suggesting `None`
/// instead.
///
/// # Errors
///
/// Returns `ModelError::InputValidationError` carrying the offending value
/// when `alpha` is negative, NaN or infinite.
pub(super) fn validate_regulation_type(
    reg_type: Option<RegularizationType>,
) -> Result<(), ModelError> {
    // Borrow the payload rather than moving it out of the option; the match is
    // kept exhaustive (no `_` arm) so a newly added variant forces a revisit.
    match reg_type.as_ref() {
        None => {}
        Some(RegularizationType::L1(alpha)) | Some(RegularizationType::L2(alpha)) => {
            if *alpha < 0.0 || !alpha.is_finite() {
                return Err(ModelError::InputValidationError(format!(
                    "Regularization alpha must be non-negative and finite, got {}",
                    alpha
                )));
            }

            // Zero is valid but has no effect, so only warn.
            if *alpha == 0.0 {
                eprintln!("Warning: regularization alpha is 0, consider using None instead");
            }
        }
    }

    Ok(())
}