use std::any::Any;
use std::cell::RefCell;
use crate::object::Tensor;
use crate::{Error, Result};
use super::layer::Layer;
use super::parameter::Parameter;
/// State recorded by `Sequential::forward` and consumed by the following `backward` call.
pub struct ForwardCache {
    /// Copies of the tensors that were passed to `forward`, in call order.
    pub inputs: Vec<Tensor<f32>>,
    /// `layer_caches[i][l]` is the opaque cache that layer `l` returned while processing
    /// input `i`; it is handed back to that layer's `backward`.
    pub layer_caches: Vec<Vec<Box<dyn Any + Send>>>,
}
/// A model that maps a slice of input tensors to output tensors and can expose its
/// trainable parameters.
///
/// Implementors must be `Send`. Support for `backward` is implementation-specific:
/// `Sequential` implements it, the `Box<dyn Layer>` adapter returns an error.
pub trait Model: Send {
    /// Runs the forward pass, producing one output tensor per input tensor.
    fn forward(&self, inputs: &[Tensor<f32>]) -> Result<Vec<Tensor<f32>>>;

    /// Runs the backward pass for the preceding `forward` call.
    ///
    /// `grad_outputs` holds one gradient per output of that pass. The result is
    /// `(input_grads, param_grads)`: gradients with respect to each input, and gradients for
    /// each parameter in the order reported by [`Model::parameters`].
    fn backward(
        &self,
        grad_outputs: &[Tensor<f32>],
    ) -> Result<(Vec<Tensor<f32>>, Vec<Tensor<f32>>)>;

    /// Shared references to every trainable parameter of the model.
    fn parameters(&self) -> Vec<&Parameter>;

    /// Mutable references to every trainable parameter, in the same order as `parameters`.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter>;
}
/// A model that feeds every input through a stack of layers, one after another.
pub struct Sequential {
    /// The layer stack: applied first-to-last by `forward`, last-to-first by `backward`.
    pub layers: Vec<Box<dyn Layer>>,
    /// Cache from the most recent successful `forward`, taken (and thereby cleared) by
    /// `backward`. Wrapped in `RefCell` so `forward(&self)` can record it without `&mut self`.
    last_cache: RefCell<Option<ForwardCache>>,
}
impl Sequential {
    /// Builds a model that applies `layers` in the given order.
    ///
    /// The new model holds no forward cache, so `backward` fails until `forward` has run.
    pub fn new(layers: Vec<Box<dyn Layer>>) -> Self {
        let last_cache = RefCell::new(None);
        Self { layers, last_cache }
    }

    /// Number of layers in the stack.
    pub fn layer_count(&self) -> usize {
        let stack = &self.layers;
        stack.len()
    }

    /// Mutable references to every layer's parameters, concatenated in layer order.
    ///
    /// As an inherent method, this is callable without the `Model` trait in scope.
    pub fn parameters_mut(&mut self) -> Vec<&mut Parameter> {
        self.layers
            .iter_mut()
            .flat_map(|layer| layer.parameters_mut())
            .collect()
    }
}
impl Model for Sequential {
fn forward(&self, inputs: &[Tensor<f32>]) -> Result<Vec<Tensor<f32>>> {
let mut current: Vec<Tensor<f32>> = inputs.to_vec();
let mut layer_caches: Vec<Vec<Box<dyn Any + Send>>> = (0..inputs.len())
.map(|_| Vec::with_capacity(self.layers.len()))
.collect();
for layer in &self.layers {
let mut next = Vec::with_capacity(current.len());
for (i, inp) in current.iter().enumerate() {
let (out, cache) = Layer::forward(layer.as_ref(), inp).map_err(|e| {
Error::backend(format!(
"Sequential::forward layer={} idx={}: {e}",
layer.name(),
i
))
})?;
next.push(out);
layer_caches[i].push(cache);
}
current = next;
}
*self.last_cache.borrow_mut() = Some(ForwardCache {
inputs: inputs.to_vec(),
layer_caches,
});
Ok(current)
}
fn backward(
&self,
grad_outputs: &[Tensor<f32>],
) -> Result<(Vec<Tensor<f32>>, Vec<Tensor<f32>>)> {
let cache = self
.last_cache
.borrow_mut()
.take()
.ok_or_else(|| Error::backend("Sequential::backward called before forward"))?;
if cache.layer_caches.len() != grad_outputs.len() {
return Err(Error::backend(format!(
"Sequential::backward: {} output tensors but {} grad_output tensors",
cache.layer_caches.len(),
grad_outputs.len()
)));
}
let n_inputs = grad_outputs.len();
let n_layers = self.layers.len();
let mut per_input_grads: Vec<Tensor<f32>> = grad_outputs.to_vec();
let mut param_grads: Vec<Vec<Tensor<f32>>> = Vec::new();
for layer_idx in (0..n_layers).rev() {
let layer = &self.layers[layer_idx];
let n_params = layer.parameters().len();
let mut summed = vec![Vec::new(); n_params];
let mut next_per_input_grads = Vec::with_capacity(n_inputs);
for i in 0..n_inputs {
let c = cache.layer_caches[i][layer_idx].as_ref();
let (grad_input, pg) = Layer::backward(layer.as_ref(), &per_input_grads[i], c)
.map_err(|e| {
Error::backend(format!(
"Sequential::backward layer={} idx={}: {e}",
layer.name(),
i
))
})?;
for (j, g) in pg.into_iter().enumerate() {
if summed[j].is_empty() {
summed[j] = g.data.clone();
} else {
for (k, v) in summed[j].iter_mut().zip(g.data.iter()) {
*k += *v;
}
}
}
next_per_input_grads.push(grad_input);
}
per_input_grads = next_per_input_grads;
param_grads.push(
summed
.into_iter()
.enumerate()
.map(|(j, data)| {
let param = layer.parameters()[j];
Tensor::dense_cpu(
param.data.meta.domain.clone(),
param.data.meta.shape.clone(),
data,
)
})
.collect(),
);
}
param_grads.reverse();
let parameters: Vec<&Parameter> = self.parameters();
let expected = parameters.len();
let flat: Vec<Tensor<f32>> = param_grads.into_iter().flatten().collect();
if flat.len() != expected {
return Err(Error::backend(format!(
"Sequential::backward produced {} param gradients but model has {} parameters",
flat.len(),
expected
)));
}
Ok((per_input_grads, flat))
}
fn parameters(&self) -> Vec<&Parameter> {
let mut out: Vec<&Parameter> = Vec::new();
for layer in &self.layers {
out.extend(layer.parameters());
}
out
}
fn parameters_mut(&mut self) -> Vec<&mut Parameter> {
let mut out: Vec<&mut Parameter> = Vec::new();
for layer in &mut self.layers {
out.extend(layer.parameters_mut());
}
out
}
}
impl Model for Box<dyn Layer> {
fn forward(&self, inputs: &[Tensor<f32>]) -> Result<Vec<Tensor<f32>>> {
let mut outs = Vec::with_capacity(inputs.len());
for inp in inputs {
let (out, _cache) = Layer::forward(self.as_ref(), inp)?;
outs.push(out);
}
Ok(outs)
}
fn backward(
&self,
_grad_outputs: &[Tensor<f32>],
) -> Result<(Vec<Tensor<f32>>, Vec<Tensor<f32>>)> {
Err(Error::backend(
"Box<dyn Layer> does not support Model::backward; use Sequential",
))
}
fn parameters(&self) -> Vec<&Parameter> {
Layer::parameters(self.as_ref())
}
fn parameters_mut(&mut self) -> Vec<&mut Parameter> {
Layer::parameters_mut(self.as_mut())
}
}