use candle_core::{Result, Tensor};
use candle_nn::{Linear, Module, VarBuilder};
/// Number of input features per example (784); presumably a flattened 28×28
/// grayscale image — confirm against the data loader.
const IMAGE_DIM: usize = 28 * 28;
/// Number of output classes produced by every model in this file.
const LABELS: usize = 10;
/// Builds a [`Linear`] layer from `in_dim` inputs to `out_dim` outputs whose
/// parameters are requested from `vs` with the `candle_nn::init::ZERO` hint.
///
/// The weight is requested as `"weight"` with shape `(out_dim, in_dim)` and the
/// bias as `"bias"` with shape `out_dim`, both directly under `vs` (no extra
/// prefix is pushed).
///
/// # Errors
/// Propagates any error returned by `VarBuilder::get_with_hints`.
fn linear_z(in_dim: usize, out_dim: usize, vs: &VarBuilder) -> Result<Linear> {
    let weight_shape = (out_dim, in_dim);
    let weight = vs.get_with_hints(weight_shape, "weight", candle_nn::init::ZERO)?;
    let bias = vs.get_with_hints(out_dim, "bias", candle_nn::init::ZERO)?;
    let layer = Linear::new(weight, Some(bias));
    Ok(layer)
}
/// Common interface for the models in this file: construct from a
/// [`VarBuilder`], then run a forward pass on a tensor.
pub trait Model
where
    Self: Sized,
{
    /// Creates the model, obtaining its parameters from `vs`.
    ///
    /// # Errors
    /// Returns an error if any parameter cannot be obtained from `vs`.
    fn new(vs: VarBuilder) -> Result<Self>;

    /// Applies the model to `xs` and returns its output tensor.
    ///
    /// # Errors
    /// Returns an error if any underlying tensor operation fails.
    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}
/// A single linear layer mapping `IMAGE_DIM` input features to `LABELS` outputs.
///
/// Its weight and bias are requested with a zero-initialization hint (see `linear_z`).
#[derive(Debug)]
pub struct LinearModel {
    linear: Linear,
}

impl Model for LinearModel {
    /// Builds the layer via `linear_z`, so its parameters are named `"weight"` and
    /// `"bias"` directly under `vs` (no sub-prefix is pushed).
    fn new(vs: VarBuilder) -> Result<Self> {
        let linear = linear_z(IMAGE_DIM, LABELS, &vs)?;
        Ok(Self { linear })
    }

    /// Applies the linear layer to `xs` and returns the raw outputs (no activation).
    ///
    /// NOTE(review): presumably `xs` is `(batch, IMAGE_DIM)` — confirm against callers.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.linear.forward(xs)
    }
}
/// Two-layer perceptron: `IMAGE_DIM -> Mlp::HIDDEN_DIM -> LABELS`, with a ReLU
/// between the two linear layers.
#[derive(Debug)]
pub struct Mlp {
    ln1: Linear,
    ln2: Linear,
}

impl Mlp {
    /// Width of the hidden layer. It is defined once because it is both the output
    /// size of `ln1` and the input size of `ln2`, so the two must always agree.
    const HIDDEN_DIM: usize = 100;
}

impl Model for Mlp {
    /// Builds both layers with `candle_nn::linear`, storing their parameters under
    /// the `"ln1"` and `"ln2"` prefixes of `vs`.
    fn new(vs: VarBuilder) -> Result<Self> {
        let ln1 = candle_nn::linear(IMAGE_DIM, Self::HIDDEN_DIM, vs.pp("ln1"))?;
        let ln2 = candle_nn::linear(Self::HIDDEN_DIM, LABELS, vs.pp("ln2"))?;
        Ok(Self { ln1, ln2 })
    }

    /// Computes `ln2(relu(ln1(xs)))` and returns the raw outputs (no softmax applied).
    ///
    /// NOTE(review): presumably `xs` is `(batch, IMAGE_DIM)` — confirm against callers.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let hidden = self.ln1.forward(xs)?.relu()?;
        self.ln2.forward(&hidden)
    }
}