// Kernel wrappers that pair a kernel source with its work-group launch configuration.
use super::Kernel;
use crate::kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource};
use core::marker::PhantomData;

/// Provides launch information specifying the number of work groups to be used by a compute shader.
///
/// The number of work groups is stored separately for each of the three axes. The derived
/// constructor is used as `WorkGroup::new(x, y, z)`.
#[derive(new, Clone, Debug)]
pub struct WorkGroup {
    /// Number of work groups along the x axis.
    pub x: u32,
    /// Number of work groups along the y axis.
    pub y: u32,
    /// Number of work groups along the z axis.
    pub z: u32,
}

impl WorkGroup {
    /// Calculate the number of invocations of a compute shader.
    ///
    /// The result is the product of the work-group counts along the x, y and z axes.
    ///
    /// Each count is widened to `usize` *before* multiplying. Multiplying the `u32` fields
    /// directly would overflow (panic in debug builds, wrap in release builds) as soon as the
    /// product exceeds `u32::MAX`, even though the return type can hold a larger value on
    /// 64-bit targets.
    ///
    /// # Panics
    ///
    /// In builds with overflow checks enabled, panics if the product does not fit in `usize`.
    pub fn num_invocations(&self) -> usize {
        // `u32 -> usize` is a lossless widening on 32-bit and 64-bit targets.
        let (x, y, z) = (self.x as usize, self.y as usize, self.z as usize);
        x * y * z
    }
}

/// Wraps a [dynamic kernel source](DynamicKernelSource) into a [kernel](Kernel) with launch
/// information such as [workgroup](WorkGroup).
///
/// The kernel source is stored as a value, and the [`Kernel`] implementation forwards `source()`
/// and `id()` to that stored instance.
#[derive(new)]
pub struct DynamicKernel<K> {
    /// The wrapped kernel source instance that provides the shader source and the identifier.
    kernel: K,
    /// Work-group counts handed back (as a clone) by [`Kernel::workgroup`].
    workgroup: WorkGroup,
}

/// Wraps a [static kernel source](StaticKernelSource) into a [kernel](Kernel) with launch
/// information such as [workgroup](WorkGroup).
///
/// No value of `K` is stored: the [`Kernel`] implementation calls `K::source()` as an associated
/// function and derives the identifier from the `TypeId` of `K`.
#[derive(new)]
pub struct StaticKernel<K> {
    /// Work-group counts handed back (as a clone) by [`Kernel::workgroup`].
    workgroup: WorkGroup,
    /// Marker tying this wrapper to the kernel source type `K` without holding an instance.
    _kernel: PhantomData<K>,
}

impl<K> Kernel for DynamicKernel<K>
where
    K: DynamicKernelSource + 'static,
{
    fn source(&self) -> SourceTemplate {
        self.kernel.source()
    }

    fn id(&self) -> String {
        self.kernel.id()
    }

    fn workgroup(&self) -> WorkGroup {
        self.workgroup.clone()
    }
}

impl<K> Kernel for StaticKernel<K>
where
    K: StaticKernelSource + 'static,
{
    fn source(&self) -> SourceTemplate {
        K::source()
    }

    fn id(&self) -> String {
        format!("{:?}", core::any::TypeId::of::<K>())
    }

    fn workgroup(&self) -> WorkGroup {
        self.workgroup.clone()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        binary_elemwise, compute::compute_client, kernel::KernelSettings, AutoGraphicsApi,
        WgpuDevice,
    };

    /// Launches an element-wise `+` kernel through a [`StaticKernel`] and checks the values read
    /// back from the output buffer.
    #[test]
    fn can_run_kernel() {
        binary_elemwise!(Add, "+");

        let device = WgpuDevice::default();
        let client = compute_client::<AutoGraphicsApi>(&device);

        let lhs_values: Vec<f32> = vec![0., 1., 2., 3., 4., 5., 6., 7.];
        let rhs_values: Vec<f32> = vec![10., 11., 12., 6., 7., 3., 1., 0.];
        // NOTE(review): the layout of this metadata buffer is dictated by the kernel template,
        // which is not visible from this module — confirm before changing these values.
        let info_values: Vec<u32> = vec![1, 1, 1, 1, 8, 8, 8];

        // Upload the inputs and reserve room for eight `f32` results.
        let lhs = client.create(bytemuck::cast_slice(&lhs_values));
        let rhs = client.create(bytemuck::cast_slice(&rhs_values));
        let out = client.empty(core::mem::size_of::<f32>() * 8);
        let info = client.create(bytemuck::cast_slice(&info_values));

        type AddKernel = KernelSettings<Add, f32, i32, 16, 16, 1>;
        let workgroup = WorkGroup::new(1, 1, 1);
        let kernel = Box::new(StaticKernel::<AddKernel>::new(workgroup));

        client.execute(kernel, &[&lhs, &rhs, &out, &info]);

        let bytes = client.read(&out).read_sync().unwrap();
        let output: &[f32] = bytemuck::cast_slice(&bytes);

        assert_eq!(output, [10., 12., 14., 9., 11., 8., 7., 7.]);
    }
}