use crate::prelude::*;
/// A generalized residual connection around two sub-modules:
/// the output is `self.0.forward(x) + self.1.forward(x)`.
///
/// Field `.0` is the "main" branch and field `.1` is the "residual"
/// branch (these names are also used as the `_main`/`_residual`
/// suffixes when saving/loading npz files).
#[derive(Debug, Clone, Default)]
pub struct GeneralizedResidual<F, R>(F, R);
impl<F: CanUpdateWithGradients, R: CanUpdateWithGradients> CanUpdateWithGradients
    for GeneralizedResidual<F, R>
{
    /// Applies gradient updates to both the main and the residual branch,
    /// accumulating any parameters without gradients into `unused`.
    fn update<G: GradientProvider>(&mut self, grads: &mut G, unused: &mut UnusedTensors) {
        let Self(main, residual) = self;
        main.update(grads, unused);
        residual.update(grads, unused);
    }
}
impl<F: ResetParams, R: ResetParams> ResetParams for GeneralizedResidual<F, R> {
    /// Re-initializes the parameters of both branches using `rng`.
    fn reset_params<RNG: rand::Rng>(&mut self, rng: &mut RNG) {
        let Self(main, residual) = self;
        main.reset_params(rng);
        residual.reset_params(rng);
    }
}
impl<F, R, T, O> Module<T> for GeneralizedResidual<F, R>
where
    T: Tensor<Dtype = f32>,
    O: Tensor<Dtype = T::Dtype, Tape = T::Tape>,
    F: Module<T, Output = O>,
    R: Module<T, Output = O>,
{
    type Output = O;

    /// Computes `self.0.forward(x) + self.1.forward(x)`.
    fn forward(&self, x: T) -> Self::Output {
        // Detach the tape so the same input can be fed through both branches.
        let (input, tape) = x.split_tape();
        // Residual branch runs first on a duplicate of the input; it takes
        // ownership of the tape and hands it back with its output split off.
        let (r_out, tape) = self.1.forward(input.duplicate().put_tape(tape)).split_tape();
        // Main branch consumes the original input together with the tape,
        // so its output carries the tape into the final addition.
        let f_out = self.0.forward(input.put_tape(tape));
        add(f_out, &r_out)
    }
}
impl<F: SaveToNpz, R: SaveToNpz> SaveToNpz for GeneralizedResidual<F, R> {
    /// Saves the main branch under the `{filename_prefix}_main` prefix and
    /// the residual branch under the `{filename_prefix}_residual` prefix.
    fn write<W>(
        &self,
        filename_prefix: &str,
        w: &mut zip::ZipWriter<W>,
    ) -> zip::result::ZipResult<()>
    where
        W: std::io::Write + std::io::Seek,
    {
        let main_prefix = format!("{}_main", filename_prefix);
        let residual_prefix = format!("{}_residual", filename_prefix);
        self.0.write(&main_prefix, w)?;
        self.1.write(&residual_prefix, w)
    }
}
impl<F: LoadFromNpz, R: LoadFromNpz> LoadFromNpz for GeneralizedResidual<F, R> {
    /// Loads the main branch from the `{filename_prefix}_main` prefix and
    /// the residual branch from the `{filename_prefix}_residual` prefix,
    /// mirroring the layout produced by [SaveToNpz].
    fn read<READ>(
        &mut self,
        filename_prefix: &str,
        r: &mut zip::ZipArchive<READ>,
    ) -> Result<(), NpzError>
    where
        READ: std::io::Read + std::io::Seek,
    {
        let main_prefix = format!("{}_main", filename_prefix);
        let residual_prefix = format!("{}_residual", filename_prefix);
        self.0.read(&main_prefix, r)?;
        self.1.read(&residual_prefix, r)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::assert_close;
    use rand::{prelude::StdRng, SeedableRng};
    use std::fs::File;
    use tempfile::NamedTempFile;
    use zip::ZipArchive;

    /// `reset_params` must move both branches away from the all-zero default.
    #[test]
    fn test_reset() {
        let mut rng = StdRng::seed_from_u64(0);
        let mut model: GeneralizedResidual<Linear<2, 5>, Linear<2, 5>> = Default::default();
        assert_eq!(model.0.weight.data(), &[[0.0; 2]; 5]);
        assert_eq!(model.0.bias.data(), &[0.0; 5]);
        assert_eq!(model.1.weight.data(), &[[0.0; 2]; 5]);
        assert_eq!(model.1.bias.data(), &[0.0; 5]);
        model.reset_params(&mut rng);
        assert_ne!(model.0.weight.data(), &[[0.0; 2]; 5]);
        assert_ne!(model.0.bias.data(), &[0.0; 5]);
        assert_ne!(model.1.weight.data(), &[[0.0; 2]; 5]);
        assert_ne!(model.1.bias.data(), &[0.0; 5]);
    }

    // Fixed weights/bias for the first and last Linear layers of the
    // (Linear<2, 5>, ReLU, Linear<5, 2>) sub-model used below.
    const W0: [[f32; 2]; 5] = [
        [0.63315326, 0.3361526],
        [0.60201937, 0.30927354],
        [0.39831632, 0.29526848],
        [-0.4730785, -0.10664469],
        [0.5074884, -0.08458644],
    ];
    const B0: [f32; 5] = [-0.7014593, 0.01725882, 0.67181975, -0.61593556, 0.27809095];
    const W2: [[f32; 5]; 2] = [
        [0.37967658, -0.30938417, -0.4046409, 0.34131002, -0.36532],
        [0.01010674, 0.2922417, -0.28791183, 0.09316397, 0.00722069],
    ];
    const B2: [f32; 2] = [-0.01353309, 0.19437504];
    // Test inputs and the sub-model's expected outputs for them.
    const X: [[f32; 2]; 10] = [
        [0.9706649, -0.50246257],
        [0.36609784, 0.22519696],
        [-0.26957038, -2.4395447],
        [0.729607, 0.06136635],
        [1.0758572, -0.6158074],
        [1.844528, -0.7769507],
        [-0.83232504, 0.26263165],
        [-0.18690403, 0.5396985],
        [-1.0891576, 0.09805013],
        [-0.63034505, 2.4173584],
    ];
    const Y: [[f32; 2]; 10] = [
        [0.15374291, -0.43383744],
        [-0.26277426, 0.25803787],
        [-0.41010314, -2.2426596],
        [-0.062764645, 0.117026225],
        [0.2237711, -0.54089284],
        [0.69048953, -0.6508272],
        [-1.0149324, 0.33670622],
        [-0.57907265, 0.53813595],
        [-1.2107061, 0.21556953],
        [-1.2221863, 2.3977249],
    ];
    // Expected gradients of mean(sub_model(X)) w.r.t. the parameters above.
    const W0G: [[f32; 2]; 5] = [
        [0.035948314, -0.015142122],
        [-0.0035737813, -0.001155745],
        [-0.07784372, -0.059181444],
        [0.0, 0.0],
        [-0.081114516, 0.06281963],
    ];
    const B0G: [f32; 5] = [0.019489167, -0.005999865, -0.3116488, 0.0, -0.12533475];
    const W2G: [[f32; 5]; 2] = [[0.010261777, 0.15239798, 0.37232202, 0.0, 0.22712366]; 2];
    const B2G: [f32; 2] = [0.50000006; 2];

    /// With a ReLU residual branch, the output should be the sub-model's
    /// output plus `relu(x)`, and the sub-model's gradients are unaffected
    /// by the parameter-free residual branch.
    #[test]
    fn test_residual_forward_backward_resadd_as_main() {
        type SubModel = (Linear<2, 5>, ReLU, Linear<5, 2>);
        type Model = GeneralizedResidual<SubModel, ReLU>;
        let mut model: Model = Default::default();
        *model.0 .0.weight.mut_data() = W0;
        *model.0 .0.bias.mut_data() = B0;
        *model.0 .2.weight.mut_data() = W2;
        *model.0 .2.bias.mut_data() = B2;
        let x = Tensor2D::new(X);
        let y = model.forward(x.traced());
        assert_close(
            y.data(),
            add(Tensor2D::new(Y), &(-Tensor2D::new(X)).relu()).data(),
        );
        let gradients = y.mean().backward();
        assert_close(gradients.ref_gradient(&model.0 .0.weight), &W0G);
        assert_close(gradients.ref_gradient(&model.0 .0.bias), &B0G);
        assert_close(gradients.ref_gradient(&model.0 .2.weight), &W2G);
        assert_close(gradients.ref_gradient(&model.0 .2.bias), &B2G);
    }

    /// Two identical branches must behave like a single sub-model with its
    /// last layer doubled, and SGD must update both branches identically.
    #[test]
    fn test_residual_forward_backward_with_update() {
        type SubModel = (Linear<2, 5>, ReLU, Linear<5, 2>);
        type Model = GeneralizedResidual<SubModel, SubModel>;
        let mut model: Model = Default::default();
        *model.0 .0.weight.mut_data() = W0;
        *model.0 .0.bias.mut_data() = B0;
        *model.0 .2.weight.mut_data() = W2;
        *model.0 .2.bias.mut_data() = B2;
        *model.1 .0.weight.mut_data() = W0;
        *model.1 .0.bias.mut_data() = B0;
        *model.1 .2.weight.mut_data() = W2;
        *model.1 .2.bias.mut_data() = B2;
        // Reference model: same first layer, last layer scaled by 2, which
        // equals the sum of the two identical branches above.
        let mut model2: SubModel = Default::default();
        *model2.0.weight.mut_data() = W0;
        *model2.0.bias.mut_data() = B0;
        *model2.2.weight.mut_data() = W2;
        model2.2.weight = model2.2.weight * 2.0;
        *model2.2.bias.mut_data() = B2;
        model2.2.bias = model2.2.bias * 2.0;
        let x = Tensor2D::new(X);
        let y = model.forward(x.traced());
        let x2 = Tensor2D::new(X);
        let y2 = model2.forward(x2.traced());
        assert_close(y.data(), y2.data());
        let gradients = y.mean().backward();
        let gradients2 = y2.mean().backward();
        // Both branches see the same gradients as a standalone sub-model.
        assert_close(gradients.ref_gradient(&model.0 .0.weight), &W0G);
        assert_close(gradients.ref_gradient(&model.0 .0.bias), &B0G);
        assert_close(gradients.ref_gradient(&model.0 .2.weight), &W2G);
        assert_close(gradients.ref_gradient(&model.0 .2.bias), &B2G);
        assert_close(gradients.ref_gradient(&model.1 .0.weight), &W0G);
        assert_close(gradients.ref_gradient(&model.1 .0.bias), &B0G);
        assert_close(gradients.ref_gradient(&model.1 .2.weight), &W2G);
        assert_close(gradients.ref_gradient(&model.1 .2.bias), &B2G);
        // The doubled last layer doubles the first-layer gradients.
        assert_close(
            gradients2.ref_gradient(&model2.0.weight),
            (Tensor2D::new(W0G) * 2.0).data(),
        );
        assert_close(
            gradients2.ref_gradient(&model2.0.bias),
            (Tensor1D::new(B0G) * 2.0).data(),
        );
        assert_close(gradients2.ref_gradient(&model2.2.weight), &W2G);
        assert_close(gradients2.ref_gradient(&model2.2.bias), &B2G);
        // With lr = 1 and no momentum, SGD does `param -= grad` exactly.
        let sgd_config = SgdConfig {
            lr: 1.0,
            momentum: None,
        };
        Sgd::new(sgd_config)
            .update(&mut model, gradients)
            .expect("unused params");
        Sgd::new(sgd_config)
            .update(&mut model2, gradients2)
            .expect("unused params");
        assert_close(
            model.0 .0.weight.data(),
            sub(Tensor2D::new(W0), &Tensor2D::new(W0G)).data(),
        );
        assert_close(
            model.0 .0.bias.data(),
            sub(Tensor1D::new(B0), &Tensor1D::new(B0G)).data(),
        );
        assert_close(
            model.0 .2.weight.data(),
            sub(Tensor2D::new(W2), &Tensor2D::new(W2G)).data(),
        );
        assert_close(
            model.0 .2.bias.data(),
            sub(Tensor1D::new(B2), &Tensor1D::new(B2G)).data(),
        );
        assert_close(
            model.1 .0.weight.data(),
            sub(Tensor2D::new(W0), &Tensor2D::new(W0G)).data(),
        );
        assert_close(
            model.1 .0.bias.data(),
            sub(Tensor1D::new(B0), &Tensor1D::new(B0G)).data(),
        );
        assert_close(
            model.1 .2.weight.data(),
            sub(Tensor2D::new(W2), &Tensor2D::new(W2G)).data(),
        );
        assert_close(
            model.1 .2.bias.data(),
            sub(Tensor1D::new(B2), &Tensor1D::new(B2G)).data(),
        );
    }

    /// Gradients flowing through both branches must sum in the layer below.
    #[test]
    fn test_residual_gradients_correctly_added() {
        type Model = (Linear<1, 1>, GeneralizedResidual<ReLU, ReLU>);
        let mut model: Model = Default::default();
        *model.0.weight.mut_data() = [[1.0]];
        let x = Tensor2D::new([[-1.0], [1.0]]);
        let y = model.forward(x.traced());
        assert_close(y.data(), &[[0.0], [2.0]]);
        let grads = y.mean().backward();
        assert_close(grads.ref_gradient(&model.0.weight), &[[1.0]]);
        assert_close(grads.ref_gradient(&model.0.bias), &[1.0]);
    }

    /// Saving must produce non-empty `_main*`/`_residual*` npy entries.
    #[test]
    fn test_save_residual() {
        let model: GeneralizedResidual<Linear<5, 3>, Linear<5, 3>> = Default::default();
        let file = NamedTempFile::new().expect("failed to create tempfile");
        model
            .save(file.path().to_str().unwrap())
            .expect("failed to save model");
        let f = File::open(file.path()).expect("failed to open resulting file");
        let mut zip = ZipArchive::new(f).expect("failed to create zip archive from file");
        for name in [
            "_mainweight.npy",
            "_mainbias.npy",
            "_residualweight.npy",
            "_residualbias.npy",
        ] {
            let entry = zip
                .by_name(name)
                .unwrap_or_else(|_| panic!("failed to find {} file", name));
            assert!(entry.size() > 0);
        }
    }

    /// Round-trip: loading must restore BOTH branches of a saved model.
    #[test]
    fn test_load_residual() {
        let mut rng = StdRng::seed_from_u64(0);
        let mut saved_model: GeneralizedResidual<Linear<5, 3>, Linear<5, 3>> = Default::default();
        saved_model.reset_params(&mut rng);
        let file = NamedTempFile::new().expect("failed to create tempfile");
        assert!(saved_model.save(file.path().to_str().unwrap()).is_ok());
        let mut loaded_model: GeneralizedResidual<Linear<5, 3>, Linear<5, 3>> = Default::default();
        assert_ne!(loaded_model.0.weight.data(), saved_model.0.weight.data());
        assert_ne!(loaded_model.0.bias.data(), saved_model.0.bias.data());
        assert_ne!(loaded_model.1.weight.data(), saved_model.1.weight.data());
        assert_ne!(loaded_model.1.bias.data(), saved_model.1.bias.data());
        assert!(loaded_model.load(file.path().to_str().unwrap()).is_ok());
        // Fix: the original test only checked the residual branch (.1); also
        // verify the main branch (.0) was actually restored from `_main*`.
        assert_eq!(loaded_model.0.weight.data(), saved_model.0.weight.data());
        assert_eq!(loaded_model.0.bias.data(), saved_model.0.bias.data());
        assert_eq!(loaded_model.1.weight.data(), saved_model.1.weight.data());
        assert_eq!(loaded_model.1.bias.data(), saved_model.1.bias.data());
    }
}