use candle::{Result, Tensor};
/// A fully connected (affine) layer: a weight matrix plus an optional bias.
///
/// The weight is transposed before the matmul in `forward`, so it is expected to be laid
/// out as `(out_dim, in_dim)` — the shape requested by [`linear`] and [`linear_no_bias`].
#[derive(Debug, Clone)]
pub struct Linear {
    /// Weight matrix; transposed in `forward` before multiplying the input.
    weight: Tensor,
    /// Optional bias, broadcast-added to the matmul output when present.
    bias: Option<Tensor>,
}
impl Linear {
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
Self { weight, bias }
}
pub fn weight(&self) -> &Tensor {
&self.weight
}
pub fn bias(&self) -> Option<&Tensor> {
self.bias.as_ref()
}
}
impl super::Module for Linear {
    /// Applies `x @ weight^T`, then broadcast-adds the bias if there is one.
    ///
    /// `weight` is transposed with `t()`, so the last dimension of `x` must match the
    /// weight's second dimension (`in_dim` for layers built by [`linear`]). Inputs of rank
    /// >= 3 may have any number of leading (batch) dimensions; they are preserved in the
    /// output, whose last dimension becomes the weight's first dimension.
    ///
    /// # Errors
    /// Propagates shape/dtype/device errors from `matmul`, `reshape`, `broadcast_left`,
    /// or `broadcast_add`.
    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
        let y = match *x.dims() {
            // Non-contiguous rank-4 / rank-3 inputs: multiply against a weight broadcast
            // over the batch dimensions instead of flattening `x`.
            [b1, b2, _, _] if !x.is_contiguous() => {
                let w = self.weight.broadcast_left((b1, b2))?.t()?;
                x.matmul(&w)?
            }
            [bsize, _, _] if !x.is_contiguous() => {
                let w = self.weight.broadcast_left(bsize)?.t()?;
                x.matmul(&w)?
            }
            // Any other input of rank >= 3: flatten every leading dim into rows, run a
            // single 2D matmul, then restore the leading dims.
            [_, _, .., k] => {
                let dims = x.dims();
                let lead_dims = &dims[..dims.len() - 1];
                let rows: usize = lead_dims.iter().product();
                let y = x.reshape((rows, k))?.matmul(&self.weight.t()?)?;
                // Give the output width explicitly instead of a `()` hole: a hole cannot
                // be inferred when a leading dim is 0 (empty batch or empty sequence).
                let mut out_dims = lead_dims.to_vec();
                out_dims.push(y.dim(1)?);
                y.reshape(out_dims)?
            }
            // Rank <= 2: a plain matmul against the transposed weight.
            _ => x.matmul(&self.weight.t()?)?,
        };
        match &self.bias {
            None => Ok(y),
            Some(bias) => y.broadcast_add(bias),
        }
    }
}
/// Creates a [`Linear`] layer with a bias, requesting its parameters from `vb`.
///
/// * `"weight"` has shape `(out_dim, in_dim)` with `crate::init::DEFAULT_KAIMING_NORMAL`
///   as the initialization hint.
/// * `"bias"` has shape `(out_dim,)` with a `Uniform { lo: -b, up: b }` hint, where
///   `b = 1 / sqrt(in_dim)` (the bound PyTorch's `nn.Linear` uses). For a degenerate
///   `in_dim == 0` the hint is `Const(0.)` instead, since `b` would be infinite.
///
/// # Errors
/// Propagates any error returned by `vb.get_with_hints`.
pub fn linear(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<Linear> {
    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
    let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
    // 1/sqrt(0) = +inf would yield an unsamplable U(-inf, inf) hint; use zeros instead,
    // mirroring PyTorch's `bound = 0` when fan_in is 0.
    let init_bs = if in_dim == 0 {
        crate::Init::Const(0.)
    } else {
        let bound = 1. / (in_dim as f64).sqrt();
        crate::Init::Uniform {
            lo: -bound,
            up: bound,
        }
    };
    let bs = vb.get_with_hints(out_dim, "bias", init_bs)?;
    Ok(Linear::new(ws, Some(bs)))
}
/// Creates a [`Linear`] layer without a bias term.
///
/// The `(out_dim, in_dim)` tensor is requested from `vb` under the name `"weight"`, with
/// `crate::init::DEFAULT_KAIMING_NORMAL` as the initialization hint.
///
/// # Errors
/// Propagates any error returned by `vb.get_with_hints`.
pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<Linear> {
    vb.get_with_hints(
        (out_dim, in_dim),
        "weight",
        crate::init::DEFAULT_KAIMING_NORMAL,
    )
    .map(|weight| Linear::new(weight, None))
}
/// Creates a [`Linear`] layer, including a bias term only when `bias` is `true`.
///
/// Delegates to [`linear_no_bias`] or [`linear`] with the same `in_dim`, `out_dim`, and `vb`.
pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    vb: crate::VarBuilder,
) -> Result<Linear> {
    if !bias {
        return linear_no_bias(in_dim, out_dim, vb);
    }
    linear(in_dim, out_dim, vb)
}