use crate::error::ModelError;
/// Ensures a layer is configured with at least one filter.
///
/// # Errors
///
/// Returns [`ModelError::InputValidationError`] when `filters` is zero.
pub(super) fn validate_filters(filters: usize) -> Result<(), ModelError> {
    match filters {
        0 => Err(ModelError::InputValidationError(String::from(
            "Number of filters must be greater than 0",
        ))),
        _ => Ok(()),
    }
}
/// Ensures a 1D kernel has a non-zero extent.
///
/// # Errors
///
/// Returns [`ModelError::InputValidationError`] when `kernel_size` is zero.
pub(super) fn validate_kernel_size_1d(kernel_size: usize) -> Result<(), ModelError> {
    if kernel_size > 0 {
        Ok(())
    } else {
        Err(ModelError::InputValidationError(String::from(
            "Kernel size must be greater than 0",
        )))
    }
}
/// Ensures both extents of a 2D kernel are non-zero.
///
/// # Errors
///
/// Returns [`ModelError::InputValidationError`] when either component of
/// `kernel_size` is zero.
pub(super) fn validate_kernel_size_2d(kernel_size: (usize, usize)) -> Result<(), ModelError> {
    match kernel_size {
        (0, _) | (_, 0) => Err(ModelError::InputValidationError(String::from(
            "Kernel dimensions must be greater than 0",
        ))),
        _ => Ok(()),
    }
}
/// Ensures all three extents of a 3D kernel are non-zero.
///
/// # Errors
///
/// Returns [`ModelError::InputValidationError`] when any component of
/// `kernel_size` is zero.
pub(super) fn validate_kernel_size_3d(
    kernel_size: (usize, usize, usize),
) -> Result<(), ModelError> {
    match kernel_size {
        (0, _, _) | (_, 0, _) | (_, _, 0) => Err(ModelError::InputValidationError(
            String::from("Kernel dimensions must be greater than 0"),
        )),
        _ => Ok(()),
    }
}
/// Ensures a 1D stride is non-zero.
///
/// # Errors
///
/// Returns [`ModelError::InputValidationError`] when `stride` is zero.
pub(super) fn validate_strides_1d(stride: usize) -> Result<(), ModelError> {
    let is_positive = stride != 0;
    if is_positive {
        return Ok(());
    }
    Err(ModelError::InputValidationError(
        "Stride must be greater than 0".to_owned(),
    ))
}
/// Ensures both components of a 2D stride are non-zero.
///
/// # Errors
///
/// Returns [`ModelError::InputValidationError`] when either component of
/// `strides` is zero.
pub(super) fn validate_strides_2d(strides: (usize, usize)) -> Result<(), ModelError> {
    let (step_a, step_b) = strides;
    match (step_a, step_b) {
        (a, b) if a > 0 && b > 0 => Ok(()),
        _ => Err(ModelError::InputValidationError(
            "Strides must be greater than 0".to_owned(),
        )),
    }
}
/// Ensures all three components of a 3D stride are non-zero.
///
/// # Errors
///
/// Returns [`ModelError::InputValidationError`] when any component of
/// `strides` is zero.
pub(super) fn validate_strides_3d(strides: (usize, usize, usize)) -> Result<(), ModelError> {
    let components = [strides.0, strides.1, strides.2];
    if components.contains(&0) {
        Err(ModelError::InputValidationError(
            "Strides must be greater than 0".to_owned(),
        ))
    } else {
        Ok(())
    }
}
/// Validates the shape of a 1D input against the kernel size.
///
/// The shape must have exactly three entries (named in the error message as
/// `[batch_size, channels, length]`). The channel entry (index 1) must be
/// non-zero, and the length entry (index 2) must be at least `kernel_size`.
/// The checks run in that order, and the first failure is reported.
///
/// # Errors
///
/// Returns [`ModelError::InputValidationError`] on the first failed check.
pub(super) fn validate_input_shape_1d(
    input_shape: &[usize],
    kernel_size: usize,
) -> Result<(), ModelError> {
    match input_shape {
        &[_, channels, length] => {
            if channels == 0 {
                return Err(ModelError::InputValidationError(String::from(
                    "Number of input channels must be greater than 0",
                )));
            }
            if length < kernel_size {
                return Err(ModelError::InputValidationError(String::from(
                    "Input length must be at least as large as the kernel size",
                )));
            }
            Ok(())
        }
        _ => Err(ModelError::InputValidationError(String::from(
            "Input shape must be 3D: [batch_size, channels, length]",
        ))),
    }
}
/// Validates the shape of a 2D input against the kernel size.
///
/// The shape must have exactly four entries (named in the error message as
/// `[batch_size, channels, height, width]`). The channel entry (index 1)
/// must be non-zero. Entries 2 and 3 must be at least `kernel_size.0` and
/// `kernel_size.1` respectively. The checks run in that order, and the
/// first failure is reported.
///
/// # Errors
///
/// Returns [`ModelError::InputValidationError`] on the first failed check.
pub(super) fn validate_input_shape_2d(
    input_shape: &[usize],
    kernel_size: (usize, usize),
) -> Result<(), ModelError> {
    let (kernel_h, kernel_w) = kernel_size;
    match *input_shape {
        [_, channels, height, width] => {
            if channels == 0 {
                Err(ModelError::InputValidationError(String::from(
                    "Number of input channels must be greater than 0",
                )))
            } else if height < kernel_h || width < kernel_w {
                Err(ModelError::InputValidationError(String::from(
                    "Input dimensions must be at least as large as the kernel size",
                )))
            } else {
                Ok(())
            }
        }
        _ => Err(ModelError::InputValidationError(String::from(
            "Input shape must be 4D: [batch_size, channels, height, width]",
        ))),
    }
}
/// Validates the shape of a 3D input.
///
/// The shape must have exactly five entries (named in the error message as
/// `[batch, channels, depth, height, width]`), and every entry must be
/// non-zero. The rank is checked first. Unlike the 1D/2D variants, this
/// function takes no kernel size, so no input-vs-kernel comparison is made.
///
/// # Errors
///
/// Returns [`ModelError::InputValidationError`] on the first failed check.
pub(super) fn validate_input_shape_3d(input_shape: &[usize]) -> Result<(), ModelError> {
    let has_five_dims = input_shape.len() == 5;
    if !has_five_dims {
        return Err(ModelError::InputValidationError(String::from(
            "Input shape must be 5-dimensional: [batch, channels, depth, height, width]",
        )));
    }

    let every_dim_positive = input_shape.iter().all(|&extent| extent > 0);
    if every_dim_positive {
        Ok(())
    } else {
        Err(ModelError::InputValidationError(String::from(
            "All input dimensions must be greater than 0",
        )))
    }
}
/// Ensures the depth multiplier is non-zero.
///
/// # Errors
///
/// Returns [`ModelError::InputValidationError`] when `depth_multiplier` is
/// zero.
pub(super) fn validate_depth_multiplier(depth_multiplier: usize) -> Result<(), ModelError> {
    if depth_multiplier != 0 {
        Ok(())
    } else {
        Err(ModelError::InputValidationError(String::from(
            "Depth multiplier must be greater than 0",
        )))
    }
}