use crate::neural::traits::{BackwardOutput, Layer};
/// Shape-bookkeeping passthrough layer: `forward` returns its input
/// unchanged and only records the per-sample feature dimension so that
/// `in_size`/`out_size` can report it.
pub struct Flatten {
// Per-sample element count (input.len() / batch) from the most recent forward pass.
pub(crate) dim: usize,
// Batch size recorded on the most recent training-mode forward pass.
// NOTE(review): written in `forward` but never read in this file — presumably
// consumed elsewhere in the crate via pub(crate); confirm before removing.
pub(crate) cache_batch: usize,
}
impl Flatten {
pub fn new() -> Self {
Self {
dim: 0,
cache_batch: 0,
}
}
}
impl Default for Flatten {
fn default() -> Self {
Self::new()
}
}
impl Layer for Flatten {
    /// Passes `input` through unchanged while recording the per-sample
    /// feature count (`dim = input.len() / batch`). In training mode the
    /// batch size is additionally cached.
    ///
    /// Debug builds assert that `batch` is non-zero and divides the input
    /// length exactly; without the check, a non-divisible length would
    /// silently truncate `dim` and `batch == 0` would panic with an
    /// unhelpful divide-by-zero message. Release behavior is unchanged.
    fn forward(&mut self, input: &[f64], batch: usize, training: bool) -> Vec<f64> {
        debug_assert!(batch > 0, "Flatten::forward: batch must be non-zero");
        debug_assert!(
            input.len() % batch == 0,
            "Flatten::forward: input length {} is not divisible by batch {}",
            input.len(),
            batch
        );
        self.dim = input.len() / batch;
        if training {
            self.cache_batch = batch;
        }
        input.to_vec()
    }
    /// Identity backward pass: the upstream gradient flows through
    /// untouched, and there are no parameter gradients.
    fn backward(&self, grad_output: &[f64]) -> BackwardOutput {
        (grad_output.to_vec(), vec![])
    }
    /// Flatten owns no trainable parameters.
    fn n_param_groups(&self) -> usize {
        0
    }
    fn params_mut(&mut self) -> Vec<(&mut Vec<f64>, &mut Vec<f64>)> {
        vec![]
    }
    fn save_params(&self) -> Vec<(Vec<f64>, Vec<f64>)> {
        vec![]
    }
    fn restore_params(&mut self, _saved: &[(Vec<f64>, Vec<f64>)]) {}
    /// Input and output sizes are equal — the layer reshapes only logically.
    fn in_size(&self) -> usize {
        self.dim
    }
    fn out_size(&self) -> usize {
        self.dim
    }
    fn name(&self) -> &'static str {
        "Flatten"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Forward must return the input verbatim and record the per-sample dim.
    #[test]
    fn forward_is_identity_and_records_dim() {
        let mut layer = Flatten::new();
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let out = layer.forward(&data, 2, true);
        assert_eq!(out, data);
        assert_eq!(layer.dim, 3);
    }

    /// Backward must return the upstream gradient verbatim with no
    /// parameter gradients.
    #[test]
    fn backward_is_identity_with_no_param_grads() {
        let mut layer = Flatten::new();
        layer.forward(&[1.0, 2.0, 3.0, 4.0], 1, true);
        let upstream = vec![0.1, 0.2, 0.3, 0.4];
        let (grad_in, param_grads) = layer.backward(&upstream);
        assert_eq!(grad_in, upstream);
        assert!(param_grads.is_empty());
    }
}