wgml/ops/layernorm.rs

use bytemuck::Pod;
use nalgebra::DVector;
use wgcore::kernel::{KernelInvocationBuilder, KernelInvocationQueue};
use wgcore::tensor::GpuVectorView;
use wgcore::Shader;
use wgebra::linalg::Shape;
use wgpu::ComputePipeline;

/// Shader implementing the layer normalization kernel.
#[derive(Shader)]
#[shader(derive(Shape), src = "layernorm.wgsl", composable = false)]
pub struct LayerNorm {
    pub main: ComputePipeline,
}

impl LayerNorm {
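    /// Queues a dispatch of this kernel that layer-normalizes `in_vec` into `out_vec`.
    ///
    /// Panics if the two vectors don't have the same length.
    ///
    /// A rough usage sketch (not compiled here; it assumes a GPU setup like the one in the test
    /// module below, with `my_in` and `my_out` as placeholder GPU vectors):
    ///
    /// ```ignore
    /// let layernorm = LayerNorm::from_device(gpu.device());
    /// let mut invocations = KernelInvocationQueue::new(gpu.device());
    /// layernorm.queue(&mut invocations, &my_out, &my_in);
    /// invocations.encode(&mut encoder, None);
    /// gpu.queue().submit(Some(encoder.finish()));
    /// ```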
    pub fn queue<'a, 'b, T: Pod>(
        &'a self,
        queue: &mut KernelInvocationQueue<'a>,
        out_vec: impl Into<GpuVectorView<'b, T>>,
        in_vec: impl Into<GpuVectorView<'b, T>>,
    ) {
        let in_vec = in_vec.into();
        let out_vec = out_vec.into();

        assert_eq!(
            in_vec.shape().size[0],
            out_vec.shape().size[0],
            "LayerNorm: dimension mismatch."
        );

        let in_shape = queue.shape_buffer(in_vec.shape());
        let out_shape = queue.shape_buffer(out_vec.shape());
        KernelInvocationBuilder::new(queue, &self.main)
            .bind0([&in_shape, &out_shape, in_vec.buffer(), out_vec.buffer()])
            .queue(1);
    }

    /// Reference CPU implementation of layer normalization.
    ///
    /// See <https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html> for details on the
    /// math.
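    /// Concretely, this computes `res = (v - mean(v)) / sqrt(var(v) + eps)` with `eps = 1.0e-5`,
    /// where `var` is the biased (population) variance. Note that, unlike `torch.nn.LayerNorm`,
    /// no learnable affine transform (weight and bias) is applied.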
    pub fn run_cpu(res: &mut DVector<f32>, v: &DVector<f32>) {
        const NUDGE_FACTOR: f32 = 1.0e-5;
        // Center the input around its mean.
        let mean = v.mean();
        res.zip_apply(v, |y, v| *y = v - mean);
        // Biased (population) variance, computed from the centered values.
        let variance = res.norm_squared() / (res.len() as f32);
        // Normalize; the epsilon guards against division by zero.
        let scale = 1.0 / (variance + NUDGE_FACTOR).sqrt();
        *res *= scale;
    }
}

#[cfg(test)]
mod test {
    use crate::ops::LayerNorm;
    use nalgebra::DVector;
    use wgcore::gpu::GpuInstance;
    use wgcore::kernel::KernelInvocationQueue;
    use wgcore::tensor::TensorBuilder;
    use wgcore::Shader;
    use wgpu::BufferUsages;

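    // Compares the GPU layernorm kernel against the CPU reference implementation on a random
    // vector.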
    #[futures_test::test]
    #[serial_test::serial]
    async fn gpu_layernorm() {
        let gpu = GpuInstance::new().await.unwrap();
        let layernorm = super::LayerNorm::from_device(gpu.device());
        let mut queue = KernelInvocationQueue::new(gpu.device());
        let mut encoder = gpu.device().create_command_encoder(&Default::default());

        const LEN: u32 = 1757;

        let v0 = DVector::new_random(LEN as usize);
        let out = DVector::new_random(LEN as usize);
        let gpu_v0 = TensorBuilder::vector(LEN, BufferUsages::STORAGE | BufferUsages::COPY_SRC)
            .build_init(gpu.device(), v0.as_slice());
        // Initialize the output buffer from `out`; its initial contents are overwritten by the
        // kernel anyway.
        let gpu_out = TensorBuilder::vector(LEN, BufferUsages::STORAGE | BufferUsages::COPY_SRC)
            .build_init(gpu.device(), out.as_slice());
        let staging = TensorBuilder::vector(LEN, BufferUsages::MAP_READ | BufferUsages::COPY_DST)
            .build(gpu.device());

        // Queue the GPU kernel and copy its result into a mappable staging buffer.
        layernorm.queue(&mut queue, &gpu_out, &gpu_v0);

        queue.encode(&mut encoder, None);
        staging.copy_from(&mut encoder, &gpu_out);

        gpu.queue().submit(Some(encoder.finish()));

        // Run the CPU reference implementation on the same input.
        let mut cpu_result = out;
        LayerNorm::run_cpu(&mut cpu_result, &v0);

        approx::assert_relative_eq!(
            DVector::from(staging.read(gpu.device()).await.unwrap()),
            cpu_result,
            epsilon = 1.0e-5
        );
    }
}