Skip to main content

runmat_plot/gpu/shaders/
scatter2.rs

1pub const F32: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;
2
3struct VertexRaw {
4    data: array<f32, 12u>,
5};
6
7struct ScatterParams {
8    color: vec4<f32>,
9    point_size: f32,
10    count: u32,
11    lod_stride: u32,
12    has_sizes: u32,
13    has_colors: u32,
14    color_stride: u32,
15};
16
17struct IndirectArgs {
18    vertex_count: atomic<u32>,
19    instance_count: u32,
20    first_vertex: u32,
21    first_instance: u32,
22};
23
24@group(0) @binding(0)
25var<storage, read> buf_x: array<f32>;
26
27@group(0) @binding(1)
28var<storage, read> buf_y: array<f32>;
29
30@group(0) @binding(2)
31var<storage, read_write> out_vertices: array<VertexRaw>;
32
33@group(0) @binding(3)
34var<uniform> params: ScatterParams;
35
36@group(0) @binding(4)
37var<storage, read> buf_sizes: array<f32>;
38
39@group(0) @binding(5)
40var<storage, read> buf_colors: array<f32>;
41
42@group(0) @binding(6)
43var<storage, read_write> indirect: IndirectArgs;
44
45@compute @workgroup_size(WORKGROUP_SIZE)
46fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
47    let idx = gid.x;
48    if (idx >= params.count) {
49        return;
50    }
51    let stride = max(params.lod_stride, 1u);
52    if ((idx % stride) != 0u) {
53        return;
54    }
55
56    let px = buf_x[idx];
57    let py = buf_y[idx];
58
59    var v_color = params.color;
60    if (params.has_colors != 0u) {
61        let base = idx * params.color_stride;
62        let r = buf_colors[base];
63        let g = buf_colors[base + 1u];
64        let b = buf_colors[base + 2u];
65        let a = if params.color_stride > 3u {
66            buf_colors[base + 3u]
67        } else {
68            1.0
69        };
70        v_color = vec4<f32>(r, g, b, a);
71    }
72
73    let mut marker_size = params.point_size;
74    if (params.has_sizes != 0u) {
75        marker_size = buf_sizes[idx];
76    }
77
78    let corners = array<vec2<f32>, 6u>(
79        vec2<f32>(-1.0, -1.0),
80        vec2<f32>( 1.0, -1.0),
81        vec2<f32>( 1.0,  1.0),
82        vec2<f32>(-1.0, -1.0),
83        vec2<f32>( 1.0,  1.0),
84        vec2<f32>(-1.0,  1.0)
85    );
86    let out_base = atomicAdd(&(indirect.vertex_count), 6u);
87    for (var i: u32 = 0u; i < 6u; i = i + 1u) {
88        var vertex: VertexRaw;
89        vertex.data[0u] = px;
90        vertex.data[1u] = py;
91        vertex.data[2u] = 0.0;
92        vertex.data[3u] = v_color.x;
93        vertex.data[4u] = v_color.y;
94        vertex.data[5u] = v_color.z;
95        vertex.data[6u] = v_color.w;
96        vertex.data[7u] = 0.0;
97        vertex.data[8u] = 0.0;
98        vertex.data[9u] = marker_size;
99        vertex.data[10u] = corners[i].x;
100        vertex.data[11u] = corners[i].y;
101        out_vertices[out_base + i] = vertex;
102    }
103}
104"#;
105
106pub const F64: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;
107
108struct VertexRaw {
109    data: array<f32, 12u>,
110};
111
112struct ScatterParams {
113    color: vec4<f32>,
114    point_size: f32,
115    count: u32,
116    lod_stride: u32,
117    has_sizes: u32,
118    has_colors: u32,
119    color_stride: u32,
120};
121
122struct IndirectArgs {
123    vertex_count: atomic<u32>,
124    instance_count: u32,
125    first_vertex: u32,
126    first_instance: u32,
127};
128
129@group(0) @binding(0)
130var<storage, read> buf_x: array<f64>;
131
132@group(0) @binding(1)
133var<storage, read> buf_y: array<f64>;
134
135@group(0) @binding(2)
136var<storage, read_write> out_vertices: array<VertexRaw>;
137
138@group(0) @binding(3)
139var<uniform> params: ScatterParams;
140
141@group(0) @binding(4)
142var<storage, read> buf_sizes: array<f32>;
143
144@group(0) @binding(5)
145var<storage, read> buf_colors: array<f32>;
146
147@group(0) @binding(6)
148var<storage, read_write> indirect: IndirectArgs;
149
150@compute @workgroup_size(WORKGROUP_SIZE)
151fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
152    let idx = gid.x;
153    if (idx >= params.count) {
154        return;
155    }
156    let stride = max(params.lod_stride, 1u);
157    if ((idx % stride) != 0u) {
158        return;
159    }
160
161    let px = f32(buf_x[idx]);
162    let py = f32(buf_y[idx]);
163
164    var v_color = params.color;
165    if (params.has_colors != 0u) {
166        let base = idx * params.color_stride;
167        let r = buf_colors[base];
168        let g = buf_colors[base + 1u];
169        let b = buf_colors[base + 2u];
170        let a = if params.color_stride > 3u {
171            buf_colors[base + 3u]
172        } else {
173            1.0
174        };
175        v_color = vec4<f32>(r, g, b, a);
176    }
177
178    let mut marker_size = params.point_size;
179    if (params.has_sizes != 0u) {
180        marker_size = buf_sizes[idx];
181    }
182
183    let corners = array<vec2<f32>, 6u>(
184        vec2<f32>(-1.0, -1.0),
185        vec2<f32>( 1.0, -1.0),
186        vec2<f32>( 1.0,  1.0),
187        vec2<f32>(-1.0, -1.0),
188        vec2<f32>( 1.0,  1.0),
189        vec2<f32>(-1.0,  1.0)
190    );
191    let out_base = atomicAdd(&(indirect.vertex_count), 6u);
192    for (var i: u32 = 0u; i < 6u; i = i + 1u) {
193        var vertex: VertexRaw;
194        vertex.data[0u] = px;
195        vertex.data[1u] = py;
196        vertex.data[2u] = 0.0;
197        vertex.data[3u] = v_color.x;
198        vertex.data[4u] = v_color.y;
199        vertex.data[5u] = v_color.z;
200        vertex.data[6u] = v_color.w;
201        vertex.data[7u] = 0.0;
202        vertex.data[8u] = 0.0;
203        vertex.data[9u] = marker_size;
204        vertex.data[10u] = corners[i].x;
205        vertex.data[11u] = corners[i].y;
206        out_vertices[out_base + i] = vertex;
207    }
208}
209"#;