use candle_core::{Device, Tensor};
use crate::error::Result;
/// SwiGLU feed-forward block: `down(silu(x @ gate^T) * (x @ up^T))`.
///
/// Weights follow the `(out_features, in_features)` layout, which is why the
/// forward pass transposes each weight before the matmul.
pub struct SwiGLU {
// (intermediate_size, hidden_size) — gating projection, passed through SiLU.
gate_weight: Tensor,
// (intermediate_size, hidden_size) — linear "up" projection.
up_weight: Tensor,
// (hidden_size, intermediate_size) — projects back down to the model dimension.
down_weight: Tensor,
}
impl SwiGLU {
    /// Creates a SwiGLU block with randomly initialized weights on `device`.
    ///
    /// Weights are drawn from N(0, 1/hidden_size), i.e. std = sqrt(1/hidden),
    /// so pre-activation variance stays roughly constant across the projection.
    ///
    /// # Errors
    /// Returns an error if tensor allocation on `device` fails.
    pub fn new(hidden_size: usize, intermediate_size: usize, device: &Device) -> Result<Self> {
        let std = (1.0 / hidden_size as f64).sqrt() as f32;
        let gate_weight = Tensor::randn(0.0, std, (intermediate_size, hidden_size), device)?;
        let up_weight = Tensor::randn(0.0, std, (intermediate_size, hidden_size), device)?;
        let down_weight = Tensor::randn(0.0, std, (hidden_size, intermediate_size), device)?;
        Ok(Self {
            gate_weight,
            up_weight,
            down_weight,
        })
    }

    /// Applies SwiGLU to `x`: `down(silu(x @ gate^T) * (x @ up^T))`.
    ///
    /// Dispatches on the input's device; the CUDA path currently delegates to
    /// the generic candle implementation (see [`Self::forward_cuda`]).
    ///
    /// # Errors
    /// Propagates any tensor-operation error (e.g. shape mismatch).
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let device = x.device();
        if device.is_cuda() {
            self.forward_cuda(x)
        } else {
            self.forward_cpu(x)
        }
    }

    /// Generic (device-agnostic) SwiGLU forward pass via candle ops.
    fn forward_cpu(&self, x: &Tensor) -> Result<Tensor> {
        // broadcast_matmul handles the leading batch dims of `x` against the
        // 2-D weight; weights are (out, in), hence the transposes.
        let gate = x.broadcast_matmul(&self.gate_weight.t()?)?;
        let gate = candle_nn::ops::silu(&gate)?;
        let up = x.broadcast_matmul(&self.up_weight.t()?)?;
        let hidden = (gate * up)?;
        let output = hidden.broadcast_matmul(&self.down_weight.t()?)?;
        Ok(output)
    }

    /// CUDA forward pass.
    ///
    /// NOTE(review): this is a placeholder — there is no fused CUDA kernel yet,
    /// so it falls back to the generic path (candle still executes the
    /// individual ops on the GPU since the tensors live on a CUDA device).
    fn forward_cuda(&self, x: &Tensor) -> Result<Tensor> {
        tracing::debug!("Using CUDA SwiGLU path for input shape {:?}", x.shape());
        self.forward_cpu(x)
    }

    /// Rough activation-memory bound, in bytes, for one forward pass of
    /// `batch_size * seq_len` tokens, assuming f32 activations.
    ///
    /// Counts the three intermediate-sized buffers (gate, up, gate*up) plus
    /// the hidden-sized output buffer. Weight storage is NOT included.
    #[must_use]
    pub fn vram_estimate(&self, batch_size: usize, seq_len: usize) -> usize {
        // The weights are always 2-D (created in `new`), so dim 0 exists;
        // fall back to 0 only defensively rather than panicking in an
        // estimator.
        let intermediate = self.gate_weight.dims().first().copied().unwrap_or(0);
        let hidden = self.down_weight.dims().first().copied().unwrap_or(0);
        let bytes_per_elem = std::mem::size_of::<f32>();
        let tokens = batch_size * seq_len;
        // gate + up + elementwise product are intermediate-sized; the final
        // down projection adds one hidden-sized buffer (missing from the
        // previous estimate).
        tokens * (3 * intermediate + hidden) * bytes_per_elem
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::DType;

    /// Construction on CPU should succeed for typical transformer sizes.
    #[test]
    fn test_swiglu_creation() {
        let device = Device::Cpu;
        let swiglu = SwiGLU::new(768, 2048, &device);
        assert!(swiglu.is_ok());
    }

    /// The forward pass must preserve the (batch, seq, hidden) shape.
    #[test]
    fn test_swiglu_forward() {
        let device = Device::Cpu;
        let swiglu = SwiGLU::new(768, 2048, &device).unwrap();
        let input = Tensor::zeros(&[2, 10, 768], DType::F32, &device).unwrap();
        let output = swiglu.forward(&input).unwrap();
        assert_eq!(output.shape().dims(), &[2, 10, 768]);
    }

    /// The VRAM estimate must be positive and scale with the token count.
    #[test]
    fn test_vram_estimate_scales_with_tokens() {
        let device = Device::Cpu;
        let swiglu = SwiGLU::new(768, 2048, &device).unwrap();
        let one_token = swiglu.vram_estimate(1, 1);
        assert!(one_token > 0);
        // More tokens => strictly larger activation footprint.
        assert!(swiglu.vram_estimate(2, 10) > one_token);
    }
}