use rand::{thread_rng, Rng};
use crate::prelude::*;
/// A linear layer without a bias term.
///
/// The layer is applied by multiplying the input with `weight` via `matmul`;
/// the `Module` implementations in this file accept inputs whose last
/// dimension is `A` and produce outputs whose last dimension is `B`.
pub struct Linear<const A: usize, const B: usize> {
/// Weight tensor of shape `R2<A, B>`, holding `A * B` values.
pub weight: GraphTensor<R2<A, B>>,
}
impl<const A: usize, const B: usize> InitModule for Linear<A, B> {
fn initialize(cx: &mut Graph) -> Self {
let s = Self {
weight: cx.named_tensor("Weight"),
};
let mut rng = thread_rng();
s.weight.set(
(0..(A * B))
.map(|_| rng.gen_range(-1_f32..1_f32))
.collect::<Vec<_>>(),
);
s
}
}
/// Serialization support for [`Linear`].
impl<const A: usize, const B: usize> SerializeModule for Linear<A, B> {
    /// Registers the layer's weight tensor with the serializer under the key
    /// `"weight"`.
    fn serialize(&self, s: &mut crate::serialization::Serializer) {
        let weight = self.weight;
        s.tensor("weight", weight);
    }
}
/// Forward pass for an `R1<A>` input, producing an `R1<B>` output.
impl<const A: usize, const B: usize> Module<GraphTensor<R1<A>>> for Linear<A, B> {
    type Output = GraphTensor<R1<B>>;

    /// Multiplies `input` by the layer's weight tensor.
    fn forward(&self, input: GraphTensor<R1<A>>) -> Self::Output {
        let weight = self.weight;
        input.matmul(weight)
    }
}
/// Forward pass for a `(C, Const<A>)` input, producing a `(C, Const<B>)`
/// output.
impl<const A: usize, const B: usize, C: Dimension> Module<GraphTensor<(C, Const<A>)>>
    for Linear<A, B>
{
    type Output = GraphTensor<(C, Const<B>)>;

    /// Multiplies `input` by the layer's weight tensor.
    fn forward(&self, input: GraphTensor<(C, Const<A>)>) -> Self::Output {
        let weight = self.weight;
        input.matmul(weight)
    }
}
/// Forward pass for a `(C, D, Const<A>)` input, producing a
/// `(C, D, Const<B>)` output.
impl<const A: usize, const B: usize, C: Dimension, D: Dimension>
    Module<GraphTensor<(C, D, Const<A>)>> for Linear<A, B>
{
    type Output = GraphTensor<(C, D, Const<B>)>;

    /// Multiplies `input` by the layer's weight tensor.
    fn forward(&self, input: GraphTensor<(C, D, Const<A>)>) -> Self::Output {
        let weight = self.weight;
        input.matmul(weight)
    }
}
/// Forward pass for a `(C, D, E, Const<A>)` input, producing a
/// `(C, D, E, Const<B>)` output.
impl<const A: usize, const B: usize, C: Dimension, D: Dimension, E: Dimension>
    Module<GraphTensor<(C, D, E, Const<A>)>> for Linear<A, B>
{
    type Output = GraphTensor<(C, D, E, Const<B>)>;

    /// Multiplies `input` by the layer's weight tensor after calling
    /// `expand()` on it.
    fn forward(&self, input: GraphTensor<(C, D, E, Const<A>)>) -> Self::Output {
        let weight = self.weight;
        // `expand()` stays inline so its target shape is inferred from the
        // `matmul` argument position, exactly as before.
        input.matmul(weight.expand())
    }
}
#[cfg(test)]
mod tests {
    use super::Linear;
    use crate::{prelude::*, tests::assert_close};

    /// Runs a `Linear<3, 4>` layer on an `R1<3>` input and an `R2<2, 3>`
    /// input, then checks that the outputs after compiling the graph with
    /// `GenericCompiler` match the outputs obtained before compiling.
    #[test]
    fn test_linear() {
        let mut graph = Graph::new();
        let batched_input = graph
            .tensor::<R2<2, 3>>()
            .set(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
        let single_input = graph.tensor::<R1<3>>().set(vec![1.0, 2.0, 3.0]);

        let layer: Linear<3, 4> = Linear::initialize(&mut graph);
        let mut single_out = layer.forward(single_input).retrieve();
        let mut batched_out = layer.forward(batched_input).retrieve();

        // First run: the graph exactly as it was built.
        graph.execute();
        let expected_single = single_out.data();
        let expected_batched = batched_out.data();

        // Second run: the same graph after compilation.
        graph.compile(
            GenericCompiler::default(),
            (&mut single_out, &mut batched_out),
        );
        graph.execute();

        assert_close(&expected_single, &single_out.data());
        assert_close(&expected_batched, &batched_out.data());
    }
}