use candle_core::{Result, Tensor, DType, Device, Module};
use candle_nn::{VarBuilder, Linear, linear, Init};
/// Rounds `a` up to the nearest multiple of `b`.
///
/// Values that are already a multiple of `b` (including `0`) are returned unchanged.
///
/// # Panics
///
/// Panics if `b` is zero. In builds with overflow checks enabled it also panics when the
/// rounded-up value itself does not fit in a `usize`.
fn find_multiple(a: usize, b: usize) -> usize {
    // Work from the remainder instead of `(a + b - 1) / b * b`: the latter overflows on the
    // intermediate sum even when the final result is representable.
    match a % b {
        0 => a,
        rem => a + (b - rem),
    }
}
/// Linear layer that casts its parameters to the input's dtype inside `forward`.
pub struct CastedLinear {
/// Weight matrix, requested from the `VarBuilder` with shape `(out_features, in_features)`.
weight: Tensor,
/// Optional bias, requested with shape `out_features` when the layer is built with a bias.
bias: Option<Tensor>,
}
impl CastedLinear {
pub fn new(
in_features: usize,
out_features: usize,
bias: bool,
vb: VarBuilder,
) -> Result<Self> {
let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
let weight = vb.get_with_hints((out_features, in_features), "weight", init_ws)?;
let bias = if bias {
let bound = 1. / (in_features as f64).sqrt();
let init_bs = Init::Uniform { lo: -bound, up: bound };
Some(vb.get_with_hints(out_features, "bias", init_bs)?)
} else {
None
};
Ok(Self { weight, bias })
}
pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
let input_dtype = input.dtype();
let weight = if self.weight.dtype() != input_dtype {
self.weight.to_dtype(input_dtype)?
} else {
self.weight.clone()
};
let weight_t = weight.t()?;
let output = input.broadcast_matmul(&weight_t)?;
if let Some(ref b) = self.bias {
let bias = if b.dtype() != input_dtype {
b.to_dtype(input_dtype)?
} else {
b.clone()
};
output.broadcast_add(&bias)
} else {
Ok(output)
}
}
}
/// Gated feed-forward block: a SiLU-activated gate multiplied by an "up" projection,
/// followed by a projection back to the hidden size.
pub struct SwiGLU {
/// Single bias-free projection whose output holds the gate half followed by the up half
/// along the last dimension.
gate_up_proj: CastedLinear,
/// Bias-free projection from the intermediate size back to `hidden_size`.
down_proj: CastedLinear,
}
impl SwiGLU {
pub fn new(hidden_size: usize, expansion: f32, vb: VarBuilder) -> Result<Self> {
let inter = find_multiple(((expansion * hidden_size as f32 * 2.0 / 3.0).round() as usize), 256);
let gate_up_proj = CastedLinear::new(
hidden_size,
inter * 2,
false,
vb.pp("gate_up_proj"),
)?;
let down_proj = CastedLinear::new(
inter,
hidden_size,
false,
vb.pp("down_proj"),
)?;
Ok(Self {
gate_up_proj,
down_proj,
})
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
let gate_up = self.gate_up_proj.forward(x)?;
let last_dim = gate_up.dims().len() - 1;
let inter_size = gate_up.dim(last_dim)? / 2;
let gate = gate_up.narrow(last_dim, 0, inter_size)?;
let up = gate_up.narrow(last_dim, inter_size, inter_size)?;
let gate_activated = candle_nn::ops::silu(&gate)?;
let gated = gate_activated.mul(&up)?;
self.down_proj.forward(&gated)
}
}
/// A bias-free square linear layer combined with a SiLU activation.
pub struct LinearSwish {
/// `hidden_size -> hidden_size` projection without bias.
linear: CastedLinear,
/// When `true`, SiLU is applied after the projection; otherwise it is applied before it.
reverse: bool,
}
impl LinearSwish {
    /// Builds the layer with a bias-free `hidden_size x hidden_size` projection whose
    /// parameters live under the `"linear"` prefix of `vb`.
    ///
    /// `reverse` selects the order of operations used by `forward`.
    ///
    /// # Errors
    ///
    /// Propagates any error from constructing the inner `CastedLinear`.
    pub fn new(hidden_size: usize, reverse: bool, vb: VarBuilder) -> Result<Self> {
        CastedLinear::new(hidden_size, hidden_size, false, vb.pp("linear"))
            .map(|linear| Self { linear, reverse })
    }

    /// Returns `silu(linear(x))` when `reverse` is set, and `linear(silu(x))` otherwise.
    ///
    /// # Errors
    ///
    /// Propagates errors from the activation or the projection.
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        if !self.reverse {
            // Default order: activate first, then project.
            let activated = candle_nn::ops::silu(x)?;
            return self.linear.forward(&activated);
        }
        let projected = self.linear.forward(x)?;
        candle_nn::ops::silu(&projected)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_nn::VarMap;

    /// `find_multiple` rounds up to the next multiple and leaves exact multiples untouched.
    #[test]
    fn test_find_multiple() {
        // (a, b, expected)
        let cases: [(usize, usize, usize); 4] = [
            (100, 256, 256),
            (300, 256, 512),
            (256, 256, 256),
            (1, 256, 256),
        ];
        for &(a, b, expected) in cases.iter() {
            assert_eq!(find_multiple(a, b), expected);
        }
    }

    /// A 64 -> 128 layer with bias maps `(2, 16, 64)` to `(2, 16, 128)`.
    #[test]
    fn test_casted_linear_shape() -> Result<()> {
        let cpu = Device::Cpu;
        let vars = VarMap::new();
        let builder = VarBuilder::from_varmap(&vars, DType::F32, &cpu);
        let layer = CastedLinear::new(64, 128, true, builder)?;
        let input = Tensor::randn(0f32, 1.0, (2, 16, 64), &cpu)?;
        let output = layer.forward(&input)?;
        assert_eq!(output.dims(), &[2, 16, 128]);
        Ok(())
    }

    /// The SwiGLU block preserves the shape of its input.
    #[test]
    fn test_swiglu_shape() -> Result<()> {
        let cpu = Device::Cpu;
        let vars = VarMap::new();
        let builder = VarBuilder::from_varmap(&vars, DType::F32, &cpu);
        let block = SwiGLU::new(256, 4.0, builder)?;
        let input = Tensor::randn(0f32, 1.0, (2, 16, 256), &cpu)?;
        let output = block.forward(&input)?;
        assert_eq!(output.dims(), input.dims());
        Ok(())
    }

    /// The non-reversed LinearSwish layer preserves the shape of its input.
    #[test]
    fn test_linear_swish_shape() -> Result<()> {
        let cpu = Device::Cpu;
        let vars = VarMap::new();
        let builder = VarBuilder::from_varmap(&vars, DType::F32, &cpu);
        let layer = LinearSwish::new(256, false, builder)?;
        let input = Tensor::randn(0f32, 1.0, (2, 16, 256), &cpu)?;
        let output = layer.forward(&input)?;
        assert_eq!(output.dims(), input.dims());
        Ok(())
    }
}