meuron 0.4.0

Meuron is a modular neural network library written in Rust for training simple neural networks.
Documentation
use bytemuck::{Pod, bytes_of};
use ndarray::Dimension;
use std::sync::Arc;
use wgpu::util::DeviceExt;

use super::context::GpuContext;
use super::params::{BinopParams, ScalarParams, UnaryParams};
use super::tensor::GpuTensor;

/// Creates an unlabeled buffer on `ctx.device` pre-filled with the raw bytes of `data`.
///
/// The buffer is created with `UNIFORM | COPY_DST` usage, so it can be bound as a
/// uniform and later overwritten by a copy/write operation.
pub fn uniform_buf<T: Pod>(ctx: &Arc<GpuContext>, data: &T) -> wgpu::Buffer {
    let usage = wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST;
    let descriptor = wgpu::util::BufferInitDescriptor {
        label: None,
        contents: bytes_of(data),
        usage,
    };
    ctx.device.create_buffer_init(&descriptor)
}

/// Creates an unlabeled storage buffer on `ctx.device` pre-filled with the raw bytes of `data`.
///
/// The only usage flag requested is `STORAGE`; no copy or map usages are added, so the
/// contents are fixed at creation as far as host-side writes are concerned.
pub fn storage_ro_buf<T: Pod>(ctx: &Arc<GpuContext>, data: &T) -> wgpu::Buffer {
    let descriptor = wgpu::util::BufferInitDescriptor {
        label: None,
        contents: bytes_of(data),
        usage: wgpu::BufferUsages::STORAGE,
    };
    ctx.device.create_buffer_init(&descriptor)
}

/// Builds an unlabeled bind group for `pipeline` from `entries`.
///
/// The layout is queried from the pipeline itself (bind group index 0), so the binding
/// numbers in `entries` must line up with what the pipeline's shader declares in group 0.
fn make_bind_group(
    ctx: &Arc<GpuContext>,
    pipeline: &wgpu::ComputePipeline,
    entries: &[wgpu::BindGroupEntry],
) -> wgpu::BindGroup {
    let group0_layout = pipeline.get_bind_group_layout(0);
    let descriptor = wgpu::BindGroupDescriptor {
        label: None,
        layout: &group0_layout,
        entries,
    };
    ctx.device.create_bind_group(&descriptor)
}

/// Records a one-dimensional compute dispatch of `pipeline` covering `count` items.
///
/// A bind group is created from `entries` (bound at index 0), then a compute pass is
/// recorded through `ctx.with_encoder`. The number of workgroups along x is
/// `ceil(count / 256)`; y and z are 1.
///
/// NOTE(review): the workgroup count is not checked against the device's
/// per-dimension dispatch limit — confirm callers never exceed it.
pub fn dispatch_1d(
    ctx: &Arc<GpuContext>,
    pipeline: &wgpu::ComputePipeline,
    entries: &[wgpu::BindGroupEntry],
    count: u32,
) {
    // Items handled per workgroup along x.
    // NOTE(review): presumably mirrors the shaders' `@workgroup_size` — confirm.
    const ITEMS_PER_GROUP: u32 = 256;

    let bind_group = make_bind_group(ctx, pipeline, entries);
    let groups_x = count.div_ceil(ITEMS_PER_GROUP);

    ctx.with_encoder(|encoder| {
        let mut compute = encoder.begin_compute_pass(&Default::default());
        compute.set_pipeline(pipeline);
        compute.set_bind_group(0, &bind_group, &[]);
        compute.dispatch_workgroups(groups_x, 1, 1);
    });
}

/// Records a three-dimensional compute dispatch of `pipeline`.
///
/// A bind group is created from `entries` (bound at index 0), then a compute pass is
/// recorded through `ctx.with_encoder`. `x` and `y` are item counts that are divided by 8
/// (rounding up) to obtain workgroup counts; `z` is passed through unchanged as the
/// workgroup count along the third dimension.
pub fn dispatch_3d(
    ctx: &Arc<GpuContext>,
    pipeline: &wgpu::ComputePipeline,
    entries: &[wgpu::BindGroupEntry],
    x: u32,
    y: u32,
    z: u32,
) {
    // Items handled per workgroup along x and along y.
    // NOTE(review): presumably mirrors the shaders' `@workgroup_size` — confirm.
    const TILE: u32 = 8;

    let bind_group = make_bind_group(ctx, pipeline, entries);
    let groups_x = x.div_ceil(TILE);
    let groups_y = y.div_ceil(TILE);

    ctx.with_encoder(|encoder| {
        let mut compute = encoder.begin_compute_pass(&Default::default());
        compute.set_pipeline(pipeline);
        compute.set_bind_group(0, &bind_group, &[]);
        compute.dispatch_workgroups(groups_x, groups_y, z);
    });
}

/// Runs the `binop` compute pipeline on `a` and `b` and returns the result as a new tensor.
///
/// `op` is forwarded to the shader in the uniform parameters (presumably it selects which
/// binary operation is applied — see the shader for the valid codes). The output buffer is
/// obtained from `a.alloc_like()`, and the returned tensor takes `a`'s shape, size and
/// context.
///
/// Bindings: 0 = `a`, 1 = `b`, 2 = output, 3 = uniform parameters.
///
/// # Panics
///
/// Panics if `a.size != b.size`, or if `a.size` does not fit in a `u32`.
pub fn binop<D: Dimension>(a: &GpuTensor<D>, b: &GpuTensor<D>, op: u32) -> GpuTensor<D> {
    // The dispatch extent and the output allocation are derived from `a` alone, so a
    // differently sized `b` would not match the invocation range.
    // NOTE(review): assumes the shader does no broadcasting of `b` — confirm.
    assert_eq!(
        a.size, b.size,
        "binop: operands must have the same size"
    );

    let ctx = a.ctx.clone();
    // Checked conversion: a plain `as u32` would silently truncate and process only part
    // of the tensor.
    let size = u32::try_from(a.size).expect("binop: tensor size must fit in u32 for GPU dispatch");
    let out_buf = a.alloc_like();
    let ub = uniform_buf(
        &ctx,
        &BinopParams {
            size,
            op,
            pad0: 0,
            pad1: 0,
        },
    );
    dispatch_1d(
        &ctx,
        &ctx.pipelines.binop,
        &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: a.buffer.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: b.buffer.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: out_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: ub.as_entire_binding(),
            },
        ],
        size,
    );
    GpuTensor {
        buffer: Arc::new(out_buf),
        shape: a.shape.clone(),
        size: a.size,
        ctx,
    }
}

/// Runs the `scalar` compute pipeline on `tensor` and returns the result as a new tensor.
///
/// `op` and `scalar` are forwarded to the shader in the uniform parameters (presumably
/// `op` selects the operation combining each element with `scalar` — see the shader for
/// the valid codes). The output buffer is obtained from `tensor.alloc_like()`, and the
/// returned tensor takes `tensor`'s shape, size and context.
///
/// Bindings: 0 = input, 1 = output, 2 = uniform parameters.
///
/// # Panics
///
/// Panics if `tensor.size` does not fit in a `u32`.
pub fn scalar_op<D: Dimension>(tensor: &GpuTensor<D>, op: u32, scalar: f32) -> GpuTensor<D> {
    let ctx = tensor.ctx.clone();
    // Checked conversion: a plain `as u32` would silently truncate and process only part
    // of the tensor.
    let size = u32::try_from(tensor.size)
        .expect("scalar_op: tensor size must fit in u32 for GPU dispatch");
    let out_buf = tensor.alloc_like();
    let ub = uniform_buf(
        &ctx,
        &ScalarParams {
            size,
            op,
            scalar,
            pad: 0,
        },
    );
    dispatch_1d(
        &ctx,
        &ctx.pipelines.scalar,
        &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: tensor.buffer.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: out_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: ub.as_entire_binding(),
            },
        ],
        size,
    );
    GpuTensor {
        buffer: Arc::new(out_buf),
        shape: tensor.shape.clone(),
        size: tensor.size,
        ctx,
    }
}

/// Runs the `unary` compute pipeline on `tensor` and returns the result as a new tensor.
///
/// `op` is forwarded to the shader in the uniform parameters (presumably it selects which
/// unary operation is applied — see the shader for the valid codes). The output buffer is
/// obtained from `tensor.alloc_like()`, and the returned tensor takes `tensor`'s shape,
/// size and context.
///
/// Bindings: 0 = input, 1 = output, 2 = uniform parameters.
///
/// # Panics
///
/// Panics if `tensor.size` does not fit in a `u32`.
pub fn unary_op<D: Dimension>(tensor: &GpuTensor<D>, op: u32) -> GpuTensor<D> {
    let ctx = tensor.ctx.clone();
    // Checked conversion: a plain `as u32` would silently truncate and process only part
    // of the tensor.
    let size = u32::try_from(tensor.size)
        .expect("unary_op: tensor size must fit in u32 for GPU dispatch");
    let out_buf = tensor.alloc_like();
    let ub = uniform_buf(
        &ctx,
        &UnaryParams {
            size,
            op,
            pad0: 0,
            pad1: 0,
        },
    );
    dispatch_1d(
        &ctx,
        &ctx.pipelines.unary,
        &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: tensor.buffer.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: out_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: ub.as_entire_binding(),
            },
        ],
        size,
    );
    GpuTensor {
        buffer: Arc::new(out_buf),
        shape: tensor.shape.clone(),
        size: tensor.size,
        ctx,
    }
}