use crate::{
shapes::*,
tensor::*,
tensor_ops::{Device, TryAdd},
};
use super::*;
/// A residual (skip) connection wrapped around the inner module `F`.
///
/// The forward pass runs the input through `F` and then adds the original input
/// back onto `F`'s result, i.e. it computes `F(x) + x`. Because of that addition,
/// `F` is required to produce an output of the same type as its input.
///
/// The wrapped module is exposed as the public tuple field `.0`.
#[derive(Default, Clone, Debug)]
pub struct Residual<F>(pub F);
/// Building a `Residual` on a device builds the wrapped module and re-wraps
/// the built result, so the built type is `Residual<F::Built>`.
impl<D, E, F> BuildOnDevice<D, E> for Residual<F>
where
    D: Device<E>,
    E: Dtype,
    F: BuildOnDevice<D, E>,
{
    type Built = Residual<F::Built>;
}
/// Exposes the wrapped module's tensors to visitors under the field name `"0"`.
///
/// When a visitor produces a converted inner module, it is re-wrapped with the
/// `Residual` constructor.
impl<E, D, F> TensorCollection<E, D> for Residual<F>
where
    E: Dtype,
    D: Device<E>,
    F: TensorCollection<E, D>,
{
    type To<E2: Dtype, D2: Device<E2>> = Residual<F::To<E2, D2>>;

    fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
        visitor: &mut V,
    ) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err> {
        // Single field: the inner module, accessed by shared and mutable reference.
        let inner_field = Self::module("0", |s| &s.0, |s| &mut s.0);
        visitor.visit_fields(inner_field, Residual)
    }
}
impl<T: WithEmptyTape + TryAdd<T>, F: Module<T, Output = T, Error = T::Err>> Module<T>
for Residual<F>
{
type Output = T;
type Error = F::Error;
fn try_forward(&self, x: T) -> Result<Self::Output, F::Error> {
self.0.try_forward(x.with_empty_tape())?.try_add(x)
}
}
impl<T: WithEmptyTape + TryAdd<T>, F: ModuleMut<T, Output = T, Error = T::Err>> ModuleMut<T>
for Residual<F>
{
type Output = T;
type Error = F::Error;
fn try_forward_mut(&mut self, x: T) -> Result<Self::Output, F::Error> {
self.0.try_forward_mut(x.with_empty_tape())?.try_add(x)
}
}
#[cfg(test)]
mod tests {
// Unit tests for `Residual`: parameter initialization and forward/backward values.
use super::*;
use crate::tests::*;
use crate::{nn::builders::Linear, tensor_ops::*};
// Building a `Residual<Linear<..>>` must initialize the wrapped Linear's parameters,
// i.e. they must not all be equal to `TestDtype::default()`.
#[test]
fn test_residual_reset() {
let dev: TestDevice = Default::default();
let model = dev.build_module::<Residual<Linear<2, 5>>, TestDtype>();
// Linear<2, 5>: weight compared against a 5x2 array, bias against length 5.
assert_ne!(model.0.weight.array(), [[TestDtype::default(); 2]; 5]);
assert_ne!(model.0.bias.array(), [TestDtype::default(); 5]);
}
// Checks the forward values of `Linear(x) + x` and the gradients of `mean(y)`.
// NOTE(review): the literal expectations depend on TestDevice's RNG sequence
// (module init first, then `sample_normal`), so the order of these calls matters.
#[test]
fn test_residual_gradients() {
let dev: TestDevice = Default::default();
// The module is built in f32 and then converted, so the sampled values are
// presumably independent of TestDtype — TODO confirm against `to_dtype`.
let model = dev
.build_module::<Residual<Linear<2, 2>>, f32>()
.to_dtype::<TestDtype>();
let x: Tensor<Rank2<4, 2>, f32, _> = dev.sample_normal();
let x = x.to_dtype::<TestDtype>();
// `leaky_trace()` records operations on `x` so that `g.get(&x)` works below.
let y = model.forward(x.leaky_trace());
#[rustfmt::skip]
assert_close_to_literal!(y, [[0.25372928, -2.4258814],[1.7892148, -2.6242268],[1.5131638, 0.23407778],[3.4201493, 1.597525]]);
// `y` is 4x2, so `mean()` gives every element of y an upstream gradient of 1/8.
let g = y.mean().backward();
// dW[j][k] = sum_i x[i][k] / 8, which does not depend on j, so both rows are equal.
assert_close_to_literal!(g.get(&model.0.weight), [[0.475242, -0.075136]; 2]);
// db[j] = 4 rows * 1/8 = 0.5.
assert_close_to_literal!(g.get(&model.0.bias), [0.5; 2]);
// dx[i][k] = (sum_j W[j][k]) / 8 + 1/8 (the skip path); this is identical for every row i.
assert_close_to_literal!(g.get(&x), [[0.18806472, 0.21419683]; 4]);
}
}