use crate::prelude::*;
pub struct ReLU;
impl InitModule for ReLU {
fn initialize(_: &mut Graph) -> Self {
Self
}
}
impl SerializeModule for ReLU {
fn serialize(&self, _: &mut Serializer) {}
}
impl<S: Shape> Module<GraphTensor<S>> for ReLU {
type Output = GraphTensor<S>;
fn forward(&self, input: GraphTensor<S>) -> Self::Output {
input.relu()
}
}
pub struct Sigmoid;
impl InitModule for Sigmoid {
fn initialize(_: &mut Graph) -> Self {
Self
}
}
impl SerializeModule for Sigmoid {
fn serialize(&self, _: &mut Serializer) {}
}
impl<S: ConstShape> Module<GraphTensor<S>> for Sigmoid {
type Output = GraphTensor<S>;
fn forward(&self, input: GraphTensor<S>) -> Self::Output {
input.sigmoid()
}
}
pub struct Swish;
impl InitModule for Swish {
fn initialize(_: &mut Graph) -> Self {
Self
}
}
impl SerializeModule for Swish {
fn serialize(&self, _: &mut Serializer) {}
}
impl<S: ConstShape> Module<GraphTensor<S>> for Swish {
type Output = GraphTensor<S>;
fn forward(&self, input: GraphTensor<S>) -> Self::Output {
input.swish()
}
}
pub struct Tanh;
impl InitModule for Tanh {
fn initialize(_: &mut Graph) -> Self {
Self
}
}
impl SerializeModule for Tanh {
fn serialize(&self, _: &mut Serializer) {}
}
impl<S: ConstShape> Module<GraphTensor<S>> for Tanh {
type Output = GraphTensor<S>;
fn forward(&self, input: GraphTensor<S>) -> Self::Output {
input.tanh()
}
}
#[cfg(test)]
mod tests {
use super::ReLU;
use crate::{
nn::linear::Linear,
prelude::{Module, *},
tests::assert_close,
};
use dfdx::prelude::{Module as DfdxModule, *};
#[test]
fn test_relu_and_linear() {
let mut cx = Graph::new();
let batch = cx
.tensor::<R2<2, 3>>()
.set(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
let a = cx.tensor::<R1<3>>().set(vec![1.0, 2.0, 3.0]);
let model: (Linear<3, 4>, ReLU, Linear<4, 2>) = InitModule::initialize(&mut cx);
model
.0
.weight
.set(vec![1., 2., 3., 1., 2., 3., 1., 2., 3., 1., 2., 3.]);
model.2.weight.set(vec![1., 2., 3., 1., 2., 3., 1., 2.]);
let mut b = model.forward(a).retrieve();
let mut batch_out = model.forward(batch).retrieve();
cx.execute();
let unoptimized_b = b.data();
let unoptimized_batch_out = batch_out.data();
cx.compile(
<(GenericCompiler, CPUCompiler)>::default(),
(&mut b, &mut batch_out),
);
cx.execute();
assert_close(&unoptimized_b, &b.data());
assert_close(&unoptimized_batch_out, &batch_out.data());
let dev = Cpu::default();
let mut model = <(
dfdx::nn::modules::builders::UnbiasedLinear<3, 4>,
dfdx::nn::modules::builders::ReLU,
dfdx::nn::modules::builders::UnbiasedLinear<4, 2>,
)>::build_on_device(&dev);
model.0.weight = dev
.tensor_from_vec(
vec![1., 2., 3., 1., 2., 3., 1., 2., 3., 1., 2., 3.],
(dfdx::shapes::Const::<3>, dfdx::shapes::Const::<4>),
)
.permute();
model.2.weight = dev
.tensor_from_vec(
vec![1., 2., 3., 1., 2., 3., 1., 2.],
(dfdx::shapes::Const::<4>, dfdx::shapes::Const::<2>),
)
.permute();
let a = dev.tensor_from_vec(vec![1.0, 2.0, 3.0], (dfdx::shapes::Const::<3>,));
let d_batch = dev.tensor_from_vec(
vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0],
(dfdx::shapes::Const::<2>, dfdx::shapes::Const::<3>),
);
let out = model.forward(a);
let d_batch_out = model.forward(d_batch);
assert_close(&unoptimized_b, &out.as_vec());
assert_close(&unoptimized_batch_out, &d_batch_out.as_vec());
}
}