use crate::error::ModelError;
pub(super) fn validate_rate(rate: f32, param_name: &str) -> Result<(), ModelError> {
if !(0.0..=1.0).contains(&rate) {
return Err(ModelError::InputValidationError(format!(
"{} must be between 0 and 1, got {}",
param_name, rate
)));
}
Ok(())
}
pub(super) fn validate_rate_exclusive(rate: f32, param_name: &str) -> Result<(), ModelError> {
if rate < 0.0 || rate >= 1.0 {
return Err(ModelError::InputValidationError(format!(
"{} must be in range [0, 1), got {}",
param_name, rate
)));
}
Ok(())
}
pub(super) fn validate_input_shape(
input_shape: &[usize],
expected_shape: &[usize],
) -> Result<(), ModelError> {
if !expected_shape.is_empty() && input_shape != expected_shape {
return Err(ModelError::InputValidationError(format!(
"Input shape mismatch: expected {:?}, got {:?}",
expected_shape, input_shape
)));
}
Ok(())
}
pub(super) fn validate_input_ndim(
input_ndim: usize,
expected_ndim: usize,
layer_name: &str,
) -> Result<(), ModelError> {
if input_ndim != expected_ndim {
return Err(ModelError::InputValidationError(format!(
"{} expects {}D input, got {}D",
layer_name, expected_ndim, input_ndim
)));
}
Ok(())
}
pub(super) fn validate_min_input_ndim(
input_ndim: usize,
min_ndim: usize,
layer_name: &str,
) -> Result<(), ModelError> {
if input_ndim < min_ndim {
return Err(ModelError::InputValidationError(format!(
"{} expects at least {}D input, got {}D",
layer_name, min_ndim, input_ndim
)));
}
Ok(())
}
pub(super) fn validate_stddev(stddev: f32) -> Result<(), ModelError> {
if stddev < 0.0 {
return Err(ModelError::InputValidationError(
"Standard deviation cannot be negative".to_string(),
));
}
Ok(())
}
pub(super) fn validate_epsilon(epsilon: f32) -> Result<(), ModelError> {
if epsilon <= 0.0 {
return Err(ModelError::InputValidationError(format!(
"Epsilon must be positive, got {}",
epsilon
)));
}
if !epsilon.is_finite() {
return Err(ModelError::InputValidationError(format!(
"Epsilon must be finite, got {}",
epsilon
)));
}
Ok(())
}
pub(super) fn validate_momentum(momentum: f32) -> Result<(), ModelError> {
if !(0.0..=1.0).contains(&momentum) {
return Err(ModelError::InputValidationError(format!(
"Momentum must be between 0.0 and 1.0, got {}",
momentum
)));
}
Ok(())
}
pub(super) fn validate_channel_axis(
channel_axis: usize,
input_ndim: usize,
) -> Result<(), ModelError> {
if channel_axis == 0 {
return Err(ModelError::InputValidationError(
"Channel axis cannot be 0 (batch axis)".to_string(),
));
}
if channel_axis >= input_ndim {
return Err(ModelError::InputValidationError(format!(
"Channel axis {} is out of bounds for input with {} dimensions",
channel_axis, input_ndim
)));
}
Ok(())
}
pub(super) fn validate_num_groups(
num_channels: usize,
num_groups: usize,
) -> Result<(), ModelError> {
if num_channels % num_groups != 0 {
return Err(ModelError::InputValidationError(format!(
"Number of channels ({}) must be divisible by num_groups ({})",
num_channels, num_groups
)));
}
Ok(())
}
pub(super) fn validate_input_shape_not_empty(input_shape: &[usize]) -> Result<(), ModelError> {
if input_shape.is_empty() {
return Err(ModelError::InputValidationError(
"Input shape cannot be empty".to_string(),
));
}
Ok(())
}
pub(super) fn validate_num_groups_positive(num_groups: usize) -> Result<(), ModelError> {
if num_groups == 0 {
return Err(ModelError::InputValidationError(
"Number of groups must be greater than 0".to_string(),
));
}
Ok(())
}
pub(super) fn validate_channel_axis_with_shape(
channel_axis: usize,
input_shape: &[usize],
) -> Result<(), ModelError> {
if channel_axis >= input_shape.len() {
return Err(ModelError::InputValidationError(format!(
"Channel axis {} is out of bounds for input shape with {} dimensions",
channel_axis,
input_shape.len()
)));
}
if channel_axis == 0 {
return Err(ModelError::InputValidationError(
"Channel axis cannot be 0 (batch axis)".to_string(),
));
}
Ok(())
}