Skip to main content

runmat_plot/gpu/shaders/
scatter3.rs

/// WGSL compute shader source that expands 3-D scatter points read from `f32`
/// coordinate buffers into six vertices per point.
///
/// The text contains a `{{WORKGROUP_SIZE}}` placeholder, so it is a template:
/// presumably the caller substitutes a concrete integer before handing the
/// source to the shader compiler (TODO confirm against the call site).
///
/// Each emitted vertex is twelve `f32` values:
/// `[x, y, z, r, g, b, a, 0.0, 0.0, point_size, corner_x, corner_y]`.
pub const F32: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;

// One output vertex, stored as twelve packed f32 values.
struct VertexRaw {
    data: array<f32, 12u>,
};

struct Scatter3Params {
    // Colour used when `has_colors` is zero.
    color: vec4<f32>,
    // Point size used when `has_sizes` is zero.
    point_size: f32,
    // Number of input points.
    count: u32,
    // Only points whose index is a multiple of this stride are emitted.
    lod_stride: u32,
    has_sizes: u32,
    has_colors: u32,
    // Number of floats per point in `buf_colors`; more than 3 means alpha is present.
    color_stride: u32,
};

struct IndirectArgs {
    // Incremented by 6 for every emitted point.
    vertex_count: atomic<u32>,
    instance_count: u32,
    first_vertex: u32,
    first_instance: u32,
};

@group(0) @binding(0)
var<storage, read> buf_x: array<f32>;

@group(0) @binding(1)
var<storage, read> buf_y: array<f32>;

@group(0) @binding(2)
var<storage, read> buf_z: array<f32>;

@group(0) @binding(3)
var<storage, read_write> out_vertices: array<VertexRaw>;

@group(0) @binding(4)
var<uniform> params: Scatter3Params;

@group(0) @binding(5)
var<storage, read> buf_sizes: array<f32>;

@group(0) @binding(6)
var<storage, read> buf_colors: array<f32>;

@group(0) @binding(7)
var<storage, read_write> indirect: IndirectArgs;

@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.count) {
        return;
    }
    // A stride of 0 is treated as 1 so the modulo below never divides by zero.
    let stride = max(params.lod_stride, 1u);
    if ((idx % stride) != 0u) {
        return;
    }

    let px = buf_x[idx];
    let py = buf_y[idx];
    let pz = buf_z[idx];

    var v_color = params.color;
    if (params.has_colors != 0u) {
        let base = idx * params.color_stride;
        let r = buf_colors[base];
        let g = buf_colors[base + 1u];
        let b = buf_colors[base + 2u];
        // WGSL has no `if` expression: default to opaque and overwrite only
        // when the colour buffer carries an alpha channel.
        var a: f32 = 1.0;
        if (params.color_stride > 3u) {
            a = buf_colors[base + 3u];
        }
        v_color = vec4<f32>(r, g, b, a);
    }

    // `var`, not `let`: the value is reassigned when per-point sizes exist.
    var point_size = params.point_size;
    if (params.has_sizes != 0u) {
        point_size = buf_sizes[idx];
    }

    // Corner offsets of the two triangles that make up the point's quad.
    // Declared with `var` because some WGSL front ends only allow dynamic
    // indexing of arrays held in variables.
    var corners = array<vec2<f32>, 6u>(
        vec2<f32>(-1.0, -1.0),
        vec2<f32>( 1.0, -1.0),
        vec2<f32>( 1.0,  1.0),
        vec2<f32>(-1.0, -1.0),
        vec2<f32>( 1.0,  1.0),
        vec2<f32>(-1.0,  1.0)
    );
    // Reserve six consecutive output slots; the returned value is the old count.
    let out_base = atomicAdd(&(indirect.vertex_count), 6u);
    for (var i: u32 = 0u; i < 6u; i = i + 1u) {
        var vertex: VertexRaw;
        vertex.data[0u] = px;
        vertex.data[1u] = py;
        vertex.data[2u] = pz;
        vertex.data[3u] = v_color.x;
        vertex.data[4u] = v_color.y;
        vertex.data[5u] = v_color.z;
        vertex.data[6u] = v_color.w;
        vertex.data[7u] = 0.0;
        vertex.data[8u] = 0.0;
        vertex.data[9u] = point_size;
        vertex.data[10u] = corners[i].x;
        vertex.data[11u] = corners[i].y;
        out_vertices[out_base + i] = vertex;
    }
}
"#;
109
/// WGSL compute shader source that expands 3-D scatter points read from `f64`
/// coordinate buffers into six vertices per point.
///
/// Coordinates are narrowed to `f32` before being written; size and colour
/// buffers are `f32`. The text contains a `{{WORKGROUP_SIZE}}` placeholder, so
/// it is a template: presumably the caller substitutes a concrete integer
/// before handing the source to the shader compiler (TODO confirm against the
/// call site).
///
/// Each emitted vertex is twelve `f32` values:
/// `[x, y, z, r, g, b, a, 0.0, 0.0, point_size, corner_x, corner_y]`.
pub const F64: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;

// One output vertex, stored as twelve packed f32 values.
struct VertexRaw {
    data: array<f32, 12u>,
};

struct Scatter3Params {
    // Colour used when `has_colors` is zero.
    color: vec4<f32>,
    // Point size used when `has_sizes` is zero.
    point_size: f32,
    // Number of input points.
    count: u32,
    // Only points whose index is a multiple of this stride are emitted.
    lod_stride: u32,
    has_sizes: u32,
    has_colors: u32,
    // Number of floats per point in `buf_colors`; more than 3 means alpha is present.
    color_stride: u32,
};

struct IndirectArgs {
    // Incremented by 6 for every emitted point.
    vertex_count: atomic<u32>,
    instance_count: u32,
    first_vertex: u32,
    first_instance: u32,
};

@group(0) @binding(0)
var<storage, read> buf_x: array<f64>;

@group(0) @binding(1)
var<storage, read> buf_y: array<f64>;

@group(0) @binding(2)
var<storage, read> buf_z: array<f64>;

@group(0) @binding(3)
var<storage, read_write> out_vertices: array<VertexRaw>;

@group(0) @binding(4)
var<uniform> params: Scatter3Params;

@group(0) @binding(5)
var<storage, read> buf_sizes: array<f32>;

@group(0) @binding(6)
var<storage, read> buf_colors: array<f32>;

@group(0) @binding(7)
var<storage, read_write> indirect: IndirectArgs;

@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.count) {
        return;
    }
    // A stride of 0 is treated as 1 so the modulo below never divides by zero.
    let stride = max(params.lod_stride, 1u);
    if ((idx % stride) != 0u) {
        return;
    }

    // Narrow the double-precision inputs to the f32 vertex format.
    let px = f32(buf_x[idx]);
    let py = f32(buf_y[idx]);
    let pz = f32(buf_z[idx]);

    var v_color = params.color;
    if (params.has_colors != 0u) {
        let base = idx * params.color_stride;
        let r = buf_colors[base];
        let g = buf_colors[base + 1u];
        let b = buf_colors[base + 2u];
        // WGSL has no `if` expression: default to opaque and overwrite only
        // when the colour buffer carries an alpha channel.
        var a: f32 = 1.0;
        if (params.color_stride > 3u) {
            a = buf_colors[base + 3u];
        }
        v_color = vec4<f32>(r, g, b, a);
    }

    // `var`, not `let`: the value is reassigned when per-point sizes exist.
    var point_size = params.point_size;
    if (params.has_sizes != 0u) {
        point_size = buf_sizes[idx];
    }

    // Corner offsets of the two triangles that make up the point's quad.
    // Declared with `var` because some WGSL front ends only allow dynamic
    // indexing of arrays held in variables.
    var corners = array<vec2<f32>, 6u>(
        vec2<f32>(-1.0, -1.0),
        vec2<f32>( 1.0, -1.0),
        vec2<f32>( 1.0,  1.0),
        vec2<f32>(-1.0, -1.0),
        vec2<f32>( 1.0,  1.0),
        vec2<f32>(-1.0,  1.0)
    );
    // Reserve six consecutive output slots; the returned value is the old count.
    let out_base = atomicAdd(&(indirect.vertex_count), 6u);
    for (var i: u32 = 0u; i < 6u; i = i + 1u) {
        var vertex: VertexRaw;
        vertex.data[0u] = px;
        vertex.data[1u] = py;
        vertex.data[2u] = pz;
        vertex.data[3u] = v_color.x;
        vertex.data[4u] = v_color.y;
        vertex.data[5u] = v_color.z;
        vertex.data[6u] = v_color.w;
        vertex.data[7u] = 0.0;
        vertex.data[8u] = 0.0;
        vertex.data[9u] = point_size;
        vertex.data[10u] = corners[i].x;
        vertex.data[11u] = corners[i].y;
        out_vertices[out_base + i] = vertex;
    }
}
"#;