use crate::error::ModelError;
use crate::neural_network::Tensor;
use crate::neural_network::layer::TrainingParameters;
use crate::neural_network::layer::activation_layer::format_output_shape;
use crate::neural_network::layer::layer_weight::LayerWeight;
use crate::neural_network::neural_network_trait::{ActivationLayer, Layer};
/// Identity ("linear") activation layer: the forward pass returns its input
/// unchanged and the backward pass lets the gradient flow through unchanged.
pub struct Linear {
// Input tensor from the most recent `forward` call; `None` until `forward`
// has run. Kept so `backward` can validate the gradient's shape against it
// and `output_shape` can report the current shape.
input_cache: Option<Tensor>,
}
impl Linear {
    /// Creates a new `Linear` layer with an empty input cache.
    ///
    /// The cache is populated by the first call to `forward`.
    pub fn new() -> Self {
        Linear { input_cache: None }
    }
}

/// `Default` mirrors `new()` so the layer can be built via
/// `Linear::default()` / struct-update syntax, and so the type satisfies
/// clippy's `new_without_default` lint.
impl Default for Linear {
    fn default() -> Self {
        Self::new()
    }
}
impl Layer for Linear {
    /// Identity forward pass.
    ///
    /// Validates the input (non-empty, all values finite), caches a copy for
    /// the backward pass, and returns the input unchanged.
    ///
    /// # Errors
    /// `InputValidationError` if the tensor is empty or contains NaN/infinite
    /// values.
    fn forward(&mut self, input: &Tensor) -> Result<Tensor, ModelError> {
        // Guard: reject empty tensors up front.
        if input.is_empty() {
            return Err(ModelError::InputValidationError(
                "Input tensor is empty".to_string(),
            ));
        }
        // Guard: a float is non-finite exactly when it is NaN or infinite.
        if input.iter().any(|&x| !x.is_finite()) {
            return Err(ModelError::InputValidationError(
                "Input tensor contains NaN or infinite values".to_string(),
            ));
        }
        // Remember the input so backward/output_shape can consult it.
        self.input_cache = Some(input.clone());
        Ok(input.clone())
    }

    /// Identity backward pass.
    ///
    /// After validating shape and finiteness against the cached forward
    /// input, the gradient flows through unchanged.
    ///
    /// # Errors
    /// `ProcessingError` if `forward` has not run or the shapes mismatch;
    /// `InputValidationError` if the gradient contains NaN/infinite values.
    fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor, ModelError> {
        // Guard: forward must have populated the cache first.
        let Some(cached) = self.input_cache.as_ref() else {
            return Err(ModelError::ProcessingError(
                "Forward pass has not been run yet".to_string(),
            ));
        };
        // The incoming gradient must match the forward input's shape.
        if grad_output.shape() != cached.shape() {
            return Err(ModelError::ProcessingError(format!(
                "Gradient output shape {:?} doesn't match input shape {:?}",
                grad_output.shape(),
                cached.shape()
            )));
        }
        // Reject NaN/infinite gradient values.
        if grad_output.iter().any(|&x| !x.is_finite()) {
            return Err(ModelError::InputValidationError(
                "Gradient output contains NaN or infinite values".to_string(),
            ));
        }
        Ok(grad_output.clone())
    }

    /// Human-readable layer name.
    fn layer_type(&self) -> &str {
        "Linear"
    }

    /// Shape of the last forward input, formatted by the shared helper
    /// (handles the not-yet-run `None` case).
    fn output_shape(&self) -> String {
        format_output_shape(&self.input_cache)
    }

    no_trainable_parameters_layer_functions!();
}
impl ActivationLayer for Linear {}