use dfdx::{
nn::modules::{Linear, Module, ModuleVisitor, ReLU, TensorCollection},
prelude::BuildModule,
shapes::{Dtype, Rank1, Rank2},
tensor::{AutoDevice, SampleTensor, Tape, Tensor, Trace},
tensor_ops::Device,
};
/// A small two-layer perceptron built from dfdx layers.
///
/// The forward passes below apply `l1` (`IN -> INNER`), then `relu`, then `l2`
/// (`INNER -> OUT`). The type parameters `E` (element dtype) and `D` (device) are
/// forwarded to both linear layers.
struct Mlp<const IN: usize, const INNER: usize, const OUT: usize, E, D>
where
    E: Dtype,
    D: Device<E>,
{
    /// First linear layer, mapping `IN` features to `INNER` hidden units.
    l1: Linear<IN, INNER, E, D>,
    /// Second linear layer, mapping `INNER` hidden units to `OUT` outputs.
    l2: Linear<INNER, OUT, E, D>,
    /// Activation applied between `l1` and `l2`.
    relu: ReLU,
}
impl<const IN: usize, const INNER: usize, const OUT: usize, E, D: Device<E>> TensorCollection<E, D>
for Mlp<IN, INNER, OUT, E, D>
where
E: Dtype + num_traits::Float + rand_distr::uniform::SampleUniform,
{
type To<E2: Dtype, D2: Device<E2>> = Mlp<IN, INNER, OUT, E2, D2>;
fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
visitor: &mut V,
) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err> {
visitor.visit_fields(
(
Self::module("l1", |s| &s.l1, |s| &mut s.l1),
Self::module("l2", |s| &s.l2, |s| &mut s.l2),
),
|(l1, l2)| Mlp {
l1,
l2,
relu: Default::default(),
},
)
}
}
impl<const IN: usize, const INNER: usize, const OUT: usize, E: Dtype, D: Device<E>>
Module<Tensor<Rank1<IN>, E, D>> for Mlp<IN, INNER, OUT, E, D>
{
type Output = Tensor<Rank1<OUT>, E, D>;
type Error = D::Err;
fn try_forward(&self, x: Tensor<Rank1<IN>, E, D>) -> Result<Self::Output, D::Err> {
let x = self.l1.try_forward(x)?;
let x = self.relu.try_forward(x)?;
self.l2.try_forward(x)
}
}
// Batched forward pass: maps a `BATCH x IN` input to `BATCH x OUT`, carrying the
// input's gradient tape `T` through to the output.
impl<const BATCH: usize, const IN: usize, const INNER: usize, const OUT: usize, E, D, T>
    Module<Tensor<Rank2<BATCH, IN>, E, D, T>> for Mlp<IN, INNER, OUT, E, D>
where
    E: Dtype,
    D: Device<E>,
    T: Tape<E, D>,
{
    type Output = Tensor<Rank2<BATCH, OUT>, E, D, T>;
    type Error = D::Err;

    /// Runs `l1`, then `relu`, then `l2` on every row of `batch`, stopping at the first
    /// device error.
    fn try_forward(
        &self,
        batch: Tensor<Rank2<BATCH, IN>, E, D, T>,
    ) -> Result<Self::Output, Self::Error> {
        let hidden = self.relu.try_forward(self.l1.try_forward(batch)?)?;
        self.l2.try_forward(hidden)
    }
}
fn main() {
    // Build a 10 -> 512 -> 20 MLP of f32s on the automatically selected device.
    let device = AutoDevice::default();
    let mlp = Mlp::<10, 512, 20, f32, AutoDevice>::build(&device);

    // Single, untraced input vector.
    let sample: Tensor<Rank1<10>, f32, _> = device.sample_normal();
    let _: Tensor<Rank1<20>, f32, _> = mlp.forward(sample);

    // A batch of 32 inputs. `leaky_trace` puts it on a tape, which is why the
    // output type spells out a fourth (tape) parameter.
    let inputs: Tensor<Rank2<32, 10>, f32, _> = device.sample_normal();
    let traced_inputs = inputs.leaky_trace();
    let _: Tensor<Rank2<32, 20>, f32, _, _> = mlp.forward(traced_inputs);
}