use crate::error::ModelError;
use crate::neural_network::Tensor;
/// Checks that a dimension/size parameter is strictly positive.
///
/// # Arguments
///
/// * `value` - The dimension to check.
/// * `name` - Human-readable parameter name, interpolated into the error message.
///
/// # Errors
///
/// Returns [`ModelError::InputValidationError`] with the message
/// `"<name> must be greater than 0"` when `value` is zero.
pub(super) fn validate_dimension_greater_than_zero(
    value: usize,
    name: &str,
) -> Result<(), ModelError> {
    // `usize` cannot be negative, so "greater than zero" is the same as "non-zero".
    if value != 0 {
        Ok(())
    } else {
        let message = format!("{} must be greater than 0", name);
        Err(ModelError::InputValidationError(message))
    }
}
/// Validates the size parameters of a recurrent layer.
///
/// Both `input_dim` and `units` must be non-zero. They are checked in that
/// order, and the first failure is returned without checking the second.
///
/// # Errors
///
/// Propagates the [`ModelError::InputValidationError`] produced by
/// `validate_dimension_greater_than_zero` for the first zero-valued argument.
pub(super) fn validate_recurrent_dimensions(
    input_dim: usize,
    units: usize,
) -> Result<(), ModelError> {
    // `and_then` only runs the second check when the first one succeeded,
    // which keeps the same short-circuit order as sequential `?`.
    validate_dimension_greater_than_zero(input_dim, "input_dim")
        .and_then(|()| validate_dimension_greater_than_zero(units, "units"))
}
/// Ensures `input` has exactly three dimensions, as reported by `Tensor::ndim()`.
///
/// NOTE(review): recurrent layers typically expect (batch, timesteps, features).
/// Only the rank is checked here, not the meaning of each axis — confirm the
/// expected axis order at the call sites.
///
/// # Errors
///
/// Returns [`ModelError::InputValidationError`] with the message
/// `"input tensor is not 3D"` for any rank other than 3.
pub(super) fn validate_input_3d(input: &Tensor) -> Result<(), ModelError> {
    match input.ndim() {
        3 => Ok(()),
        _ => Err(ModelError::InputValidationError(
            "input tensor is not 3D".to_string(),
        )),
    }
}