use crate::autograd::Variable;
use crate::tensor::Result;
use super::Module;
use super::parameter::Parameter;
/// Pads the spatial dimensions of its input with zeros.
///
/// Padding amounts are stored as `[left, right, top, bottom]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ZeroPad2d {
    /// Padding amounts in `[left, right, top, bottom]` order.
    padding: [i64; 4],
}

impl ZeroPad2d {
    /// Creates a zero-padding layer that pads every side by `padding`.
    #[must_use]
    pub fn new(padding: i64) -> Self {
        Self::asymmetric(padding, padding, padding, padding)
    }

    /// Creates a zero-padding layer with independent amounts per side.
    #[must_use]
    pub fn asymmetric(left: i64, right: i64, top: i64, bottom: i64) -> Self {
        Self { padding: [left, right, top, bottom] }
    }
}
impl Module for ZeroPad2d {
    fn name(&self) -> &str {
        "zero_pad2d"
    }

    /// Padding layers carry no learnable state.
    fn parameters(&self) -> Vec<Parameter> {
        Vec::new()
    }

    /// Returns `input` padded with zeros by `self.padding` (`[left, right, top, bottom]`).
    ///
    /// NOTE(review): the result is built by `Variable::wrap` around the raw tensor obtained
    /// from `input.data()`. Whether gradients propagate back to `input` depends on `wrap`'s
    /// semantics — confirm this does not detach the autograd graph.
    fn forward(&self, input: &Variable) -> Result<Variable> {
        let padded = input.data().pad(&self.padding, 0.0)?;
        Ok(Variable::wrap(padded))
    }
}
/// Pads the spatial dimensions of its input by reflecting it at the borders.
///
/// Padding amounts are stored as `[left, right, top, bottom]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ReflectionPad2d {
    /// Padding amounts in `[left, right, top, bottom]` order.
    padding: [i64; 4],
}

impl ReflectionPad2d {
    /// Creates a reflection-padding layer that pads every side by `padding`.
    #[must_use]
    pub fn new(padding: i64) -> Self {
        Self::asymmetric(padding, padding, padding, padding)
    }

    /// Creates a reflection-padding layer with independent amounts per side.
    #[must_use]
    pub fn asymmetric(left: i64, right: i64, top: i64, bottom: i64) -> Self {
        Self { padding: [left, right, top, bottom] }
    }
}
impl Module for ReflectionPad2d {
    fn name(&self) -> &str {
        "reflection_pad2d"
    }

    /// Padding layers carry no learnable state.
    fn parameters(&self) -> Vec<Parameter> {
        Vec::new()
    }

    /// Returns `input` reflection-padded by `self.padding` (`[left, right, top, bottom]`).
    ///
    /// NOTE(review): mode code `1` presumably selects reflection in `pad_mode`, and the
    /// `0.0` fill value is presumably ignored for non-constant modes — confirm against
    /// `pad_mode`. Reflection padding usually requires each amount to be smaller than the
    /// corresponding input dimension; no check is done here.
    /// As with zero padding, confirm `Variable::wrap` does not detach the autograd graph.
    fn forward(&self, input: &Variable) -> Result<Variable> {
        let reflected = input.data().pad_mode(&self.padding, 1, 0.0)?;
        Ok(Variable::wrap(reflected))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{DType, Tensor, TensorOptions};

    /// Float32 tensor options on the shared test device.
    fn f32_opts() -> TensorOptions {
        TensorOptions { dtype: DType::Float32, device: crate::tensor::test_device() }
    }

    /// Asserts element-wise equality of two f32 slices within 1e-5.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < 1e-5, "index {}: got {}, expected {}", i, a, e);
        }
    }

    /// Reference reflection mapping into `0..n` that does not repeat the edge sample
    /// (e.g. `-1 -> 1`, `n -> n - 2`); valid while the overshoot is smaller than `n`.
    fn reflect(i: i64, n: i64) -> i64 {
        if i < 0 {
            -i
        } else if i >= n {
            2 * (n - 1) - i
        } else {
            i
        }
    }

    #[test]
    fn test_zero_pad2d() {
        let input = Variable::new(Tensor::ones(&[1, 1, 2, 2], f32_opts()).unwrap(), false);
        let pad = ZeroPad2d::new(1);
        let y = pad.forward(&input).unwrap();
        assert_eq!(y.shape(), vec![1, 1, 4, 4]);
        let data = y.data().to_f32_vec().unwrap();
        // The 2x2 block of ones sits in the centre, surrounded by a ring of zeros.
        let expected: [f32; 16] = [
            0.0, 0.0, 0.0, 0.0, //
            0.0, 1.0, 1.0, 0.0, //
            0.0, 1.0, 1.0, 0.0, //
            0.0, 0.0, 0.0, 0.0,
        ];
        assert_close(&data, &expected);
    }

    #[test]
    fn test_reflection_pad2d() {
        let device = crate::tensor::test_device();
        let input = Variable::new(
            Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2], device).unwrap(),
            false,
        );
        let pad = ReflectionPad2d::new(1);
        let y = pad.forward(&input).unwrap();
        assert_eq!(y.shape(), vec![1, 1, 4, 4]);
    }

    #[test]
    fn test_zero_pad2d_asymmetric() {
        let input = Variable::new(Tensor::ones(&[1, 1, 3, 3], f32_opts()).unwrap(), false);
        let pad = ZeroPad2d::asymmetric(1, 2, 0, 3);
        let y = pad.forward(&input).unwrap();
        assert_eq!(y.shape(), vec![1, 1, 6, 6]);
    }

    #[test]
    fn test_zero_pad2d_asymmetric_layout() {
        // Non-square input so width (left/right) and height (top/bottom) padding are
        // distinguishable in the output shape.
        let input = Variable::new(Tensor::ones(&[1, 1, 2, 3], f32_opts()).unwrap(), false);
        let pad = ZeroPad2d::asymmetric(1, 2, 0, 3);
        let y = pad.forward(&input).unwrap();
        // Height: 2 + 0 (top) + 3 (bottom); width: 3 + 1 (left) + 2 (right).
        assert_eq!(y.shape(), vec![1, 1, 5, 6]);
        let data = y.data().to_f32_vec().unwrap();
        // Ones occupy rows 0..2 (no top padding) and columns 1..4 (one column of left padding).
        let expected: Vec<f32> = (0..5)
            .flat_map(|r| (0..6).map(move |c| (r, c)))
            .map(|(r, c)| if r < 2 && (1..4).contains(&c) { 1.0 } else { 0.0 })
            .collect();
        assert_close(&data, &expected);
    }

    #[test]
    fn test_reflection_pad2d_asymmetric() {
        let device = crate::tensor::test_device();
        // Distinct values 0..16 so every reflected position is identifiable.
        let values: Vec<f32> = (0..16).map(|v| v as f32).collect();
        let input = Variable::new(
            Tensor::from_f32(&values, &[1, 1, 4, 4], device).unwrap(),
            false,
        );
        let pad = ReflectionPad2d::asymmetric(1, 2, 1, 2);
        let y = pad.forward(&input).unwrap();
        assert_eq!(y.shape(), vec![1, 1, 7, 7]);
        let data = y.data().to_f32_vec().unwrap();
        // Output (r, c) maps back to input (reflect(r - top), reflect(c - left)).
        let expected: Vec<f32> = (0..7i64)
            .flat_map(|r| (0..7i64).map(move |c| (r, c)))
            .map(|(r, c)| values[(reflect(r - 1, 4) * 4 + reflect(c - 1, 4)) as usize])
            .collect();
        assert_close(&data, &expected);
    }

    #[test]
    fn test_zero_pad2d_no_parameters() {
        let pad = ZeroPad2d::new(2);
        assert_eq!(pad.parameters().len(), 0);
    }

    #[test]
    fn test_reflection_pad2d_no_parameters() {
        let pad = ReflectionPad2d::new(2);
        assert_eq!(pad.parameters().len(), 0);
    }

    #[test]
    fn test_reflection_pad2d_values() {
        let device = crate::tensor::test_device();
        let input = Variable::new(
            Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2], device).unwrap(),
            false,
        );
        let pad = ReflectionPad2d::new(1);
        let y = pad.forward(&input).unwrap();
        let data = y.data().to_f32_vec().unwrap();
        // Input [[1, 2], [3, 4]] reflected by one on every side (edge not repeated).
        // Checking the border, not just the interior, is what actually exercises reflection.
        let expected: [f32; 16] = [
            4.0, 3.0, 4.0, 3.0, //
            2.0, 1.0, 2.0, 1.0, //
            4.0, 3.0, 4.0, 3.0, //
            2.0, 1.0, 2.0, 1.0,
        ];
        assert_close(&data, &expected);
    }
}