use crate::error::ModelError;
use crate::neural_network::Tensor;
use crate::neural_network::layer::TrainingParameters;
use crate::neural_network::layer::layer_weight::LayerWeight;
use crate::neural_network::neural_network_trait::Layer;
use ndarray::IxDyn;
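/// Layer that reshapes `[batch_size, d1, d2, ...]` input into
/// `[batch_size, d1 * d2 * ...]`, leaving the batch dimension untouched.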
pub struct Flatten {
    /// Number of features after flattening: the product of all non-batch dimensions.
    flattened_features: usize,
    /// Input cached by `forward`; `backward` uses its shape to restore the gradient.
    input_cache: Option<Tensor>,
}
impl Flatten {
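    /// Creates a `Flatten` layer for inputs of shape `input_shape`,
    /// given as `[batch_size, dim1, dim2, ...]` (batch dimension included).
    ///
    /// # Errors
    ///
    /// Returns `ModelError::InputValidationError` if the shape has fewer than
    /// two dimensions or any dimension is zero.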
pub fn new(input_shape: Vec<usize>) -> Result<Self, ModelError> {
if input_shape.len() < 2 {
return Err(ModelError::InputValidationError(format!(
"Input shape must have at least 2 dimensions [batch_size, features...], got {}D",
input_shape.len()
)));
}
for (i, &dim) in input_shape.iter().enumerate() {
if dim == 0 {
return Err(ModelError::InputValidationError(format!(
"Dimension {} must be greater than 0, got {}",
i, dim
)));
}
}
        // Everything past the batch dimension collapses into one feature axis.
        let flattened_features = input_shape[1..].iter().product();
Ok(Flatten {
flattened_features,
input_cache: None,
})
}
}
impl Layer for Flatten {
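    /// Flattens a 3D, 4D, or 5D `input` to `[batch_size, flattened_features]`,
    /// caching the input so `backward` can restore the original shape.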
    fn forward(&mut self, input: &Tensor) -> Result<Tensor, ModelError> {
        let input_shape = input.shape();
        if input_shape.len() < 3 || input_shape.len() > 5 {
            return Err(ModelError::InputValidationError(format!(
                "Flatten layer expects 3D, 4D, or 5D input, got {}D tensor",
                input_shape.len()
            )));
        }
        // Reject inputs that disagree with the shape the layer was constructed
        // with; otherwise `output_shape()` would report the wrong feature count.
        let flattened_features: usize = input_shape[1..].iter().product();
        if flattened_features != self.flattened_features {
            return Err(ModelError::InputValidationError(format!(
                "Input flattens to {} features, but layer was configured for {}",
                flattened_features, self.flattened_features
            )));
        }
        // Cache the input so `backward` can restore the original shape.
        self.input_cache = Some(input.clone());
        let batch_size = input_shape[0];
        // Propagate a reshape failure as an error instead of panicking.
        Ok(input
            .to_shape(IxDyn(&[batch_size, flattened_features]))
            .map_err(|e| {
                ModelError::ProcessingError(format!("Failed to flatten input: {}", e))
            })?
            .to_owned())
    }
    fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor, ModelError> {
        if let Some(input) = &self.input_cache {
            let input_shape = input.shape().to_vec();
            // The incoming gradient must match the flattened forward output.
            let expected_grad_shape = [input_shape[0], input_shape[1..].iter().product()];
            if grad_output.shape() != expected_grad_shape {
                return Err(ModelError::ProcessingError(format!(
                    "Gradient output shape {:?} doesn't match expected shape {:?}",
                    grad_output.shape(),
                    expected_grad_shape
                )));
            }
            // Invert the forward reshape to recover the original input shape.
            let reshaped_grad = grad_output
                .to_shape(IxDyn(&input_shape))
                .map_err(|e| {
                    ModelError::ProcessingError(format!("Failed to reshape gradient: {}", e))
                })?
                .to_owned();
            Ok(reshaped_grad)
        } else {
            Err(ModelError::ProcessingError(
                "Forward pass has not been run yet".to_string(),
            ))
        }
    }
fn layer_type(&self) -> &str {
"Flatten"
}
fn output_shape(&self) -> String {
format!("(batch_size, {})", self.flattened_features)
}
no_trainable_parameters_layer_functions!();
}
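
#[cfg(test)]
mod tests {
    // Minimal usage sketch. Assumes `Tensor` is an alias for `ndarray::ArrayD<f32>`;
    // the element type is not confirmed by this file, and inference adapts via the
    // `Tensor` annotation below if the alias differs.
    use super::*;
    use ndarray::ArrayD;

    #[test]
    fn flatten_forward_backward_roundtrip() {
        let mut layer = Flatten::new(vec![2, 3, 4]).unwrap();

        // 3D input: batch of 2 samples, each 3 x 4.
        let input: Tensor = ArrayD::zeros(IxDyn(&[2, 3, 4]));
        let output = layer.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 12]);

        // The gradient flows back reshaped to the cached input shape.
        let grad = layer.backward(&output).unwrap();
        assert_eq!(grad.shape(), &[2, 3, 4]);
    }
}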