Skip to main content

runmat_plot/gpu/shaders/
scatter3.rs

/// WGSL compute shader source that packs 3-D scatter points into vertex records,
/// reading coordinates from `f32` storage buffers.
///
/// The text contains the placeholder `{{WORKGROUP_SIZE}}`, which has to be replaced
/// with an integer literal before the source is handed to a shader compiler.
///
/// Per invocation (`idx = global_invocation_id.x`) the shader:
///
/// * returns early when `idx >= params.count` or when `idx` is not a multiple of
///   `max(params.lod_stride, 1)`;
/// * reads `buf_x[idx]`, `buf_y[idx]`, `buf_z[idx]` as the position;
/// * takes the colour from `params.color`, or, when `params.has_colors != 0`, from
///   `buf_colors` starting at `idx * params.color_stride` (alpha is read only when
///   `color_stride > 3`, otherwise it is `1.0`);
/// * takes the size from `params.point_size`, or from `buf_sizes[idx]` when
///   `params.has_sizes != 0`;
/// * reserves an output slot with `atomicAdd` on `indirect.vertex_count` and writes
///   the record to `out_vertices` at that slot, so output slots are handed out by the
///   counter rather than by input index.
///
/// `VertexRaw.data` layout as written by the shader: `[0..3)` position, `[3..7)` RGBA,
/// `[9]` point size; elements 7, 8, 10 and 11 are written as `0.0`.
///
/// NOTE(review): presumably the host resets `indirect.vertex_count` to zero before each
/// dispatch; confirm against the caller.
pub const F32: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;

struct VertexRaw {
    data: array<f32, 12u>,
};

struct Scatter3Params {
    color: vec4<f32>,
    point_size: f32,
    count: u32,
    lod_stride: u32,
    has_sizes: u32,
    has_colors: u32,
    color_stride: u32,
};

struct IndirectArgs {
    vertex_count: atomic<u32>,
    instance_count: u32,
    first_vertex: u32,
    first_instance: u32,
};

@group(0) @binding(0)
var<storage, read> buf_x: array<f32>;

@group(0) @binding(1)
var<storage, read> buf_y: array<f32>;

@group(0) @binding(2)
var<storage, read> buf_z: array<f32>;

@group(0) @binding(3)
var<storage, read_write> out_vertices: array<VertexRaw>;

@group(0) @binding(4)
var<uniform> params: Scatter3Params;

@group(0) @binding(5)
var<storage, read> buf_sizes: array<f32>;

@group(0) @binding(6)
var<storage, read> buf_colors: array<f32>;

@group(0) @binding(7)
var<storage, read_write> indirect: IndirectArgs;

@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.count) {
        return;
    }
    let stride = max(params.lod_stride, 1u);
    if ((idx % stride) != 0u) {
        return;
    }

    let px = buf_x[idx];
    let py = buf_y[idx];
    let pz = buf_z[idx];

    var vertex: VertexRaw;
    vertex.data[0u] = px;
    vertex.data[1u] = py;
    vertex.data[2u] = pz;

    var v_color = params.color;
    if (params.has_colors != 0u) {
        let base = idx * params.color_stride;
        let r = buf_colors[base];
        let g = buf_colors[base + 1u];
        let b = buf_colors[base + 2u];
        var a: f32 = 1.0;
        if (params.color_stride > 3u) {
            a = buf_colors[base + 3u];
        }
        v_color = vec4<f32>(r, g, b, a);
    }

    var point_size = params.point_size;
    if (params.has_sizes != 0u) {
        point_size = buf_sizes[idx];
    }

    vertex.data[3u] = v_color.x;
    vertex.data[4u] = v_color.y;
    vertex.data[5u] = v_color.z;
    vertex.data[6u] = v_color.w;
    vertex.data[7u] = 0.0;
    vertex.data[8u] = 0.0;
    vertex.data[9u] = point_size;
    vertex.data[10u] = 0.0;
    vertex.data[11u] = 0.0;

    let out_idx = atomicAdd(&(indirect.vertex_count), 1u);
    out_vertices[out_idx] = vertex;
}
"#;
101
/// WGSL compute shader source that packs 3-D scatter points into vertex records,
/// reading coordinates from `f64` storage buffers and narrowing them to `f32`.
///
/// The text contains the placeholder `{{WORKGROUP_SIZE}}`, which has to be replaced
/// with an integer literal before the source is handed to a shader compiler.
///
/// Per invocation (`idx = global_invocation_id.x`) the shader:
///
/// * returns early when `idx >= params.count` or when `idx` is not a multiple of
///   `max(params.lod_stride, 1)`;
/// * reads `buf_x[idx]`, `buf_y[idx]`, `buf_z[idx]` and converts each with `f32(...)`;
/// * takes the colour from `params.color`, or, when `params.has_colors != 0`, from
///   `buf_colors` starting at `idx * params.color_stride` (alpha is read only when
///   `color_stride > 3`, otherwise it is `1.0`);
/// * takes the size from `params.point_size`, or from `buf_sizes[idx]` when
///   `params.has_sizes != 0`;
/// * reserves an output slot with `atomicAdd` on `indirect.vertex_count` and writes
///   the record to `out_vertices` at that slot, so output slots are handed out by the
///   counter rather than by input index.
///
/// `VertexRaw.data` layout as written by the shader: `[0..3)` position, `[3..7)` RGBA,
/// `[9]` point size; elements 7, 8, 10 and 11 are written as `0.0`. Sizes and colours
/// stay `f32` in this variant; only the coordinate buffers are `f64`.
///
/// NOTE(review): `f64` is not part of core WGSL, so this source presumably requires a
/// compiler/device with 64-bit float support enabled; confirm against the caller.
/// NOTE(review): presumably the host resets `indirect.vertex_count` to zero before each
/// dispatch; confirm against the caller.
pub const F64: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;

struct VertexRaw {
    data: array<f32, 12u>,
};

struct Scatter3Params {
    color: vec4<f32>,
    point_size: f32,
    count: u32,
    lod_stride: u32,
    has_sizes: u32,
    has_colors: u32,
    color_stride: u32,
};

struct IndirectArgs {
    vertex_count: atomic<u32>,
    instance_count: u32,
    first_vertex: u32,
    first_instance: u32,
};

@group(0) @binding(0)
var<storage, read> buf_x: array<f64>;

@group(0) @binding(1)
var<storage, read> buf_y: array<f64>;

@group(0) @binding(2)
var<storage, read> buf_z: array<f64>;

@group(0) @binding(3)
var<storage, read_write> out_vertices: array<VertexRaw>;

@group(0) @binding(4)
var<uniform> params: Scatter3Params;

@group(0) @binding(5)
var<storage, read> buf_sizes: array<f32>;

@group(0) @binding(6)
var<storage, read> buf_colors: array<f32>;

@group(0) @binding(7)
var<storage, read_write> indirect: IndirectArgs;

@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.count) {
        return;
    }
    let stride = max(params.lod_stride, 1u);
    if ((idx % stride) != 0u) {
        return;
    }

    let px = f32(buf_x[idx]);
    let py = f32(buf_y[idx]);
    let pz = f32(buf_z[idx]);

    var vertex: VertexRaw;
    vertex.data[0u] = px;
    vertex.data[1u] = py;
    vertex.data[2u] = pz;

    var v_color = params.color;
    if (params.has_colors != 0u) {
        let base = idx * params.color_stride;
        let r = buf_colors[base];
        let g = buf_colors[base + 1u];
        let b = buf_colors[base + 2u];
        var a: f32 = 1.0;
        if (params.color_stride > 3u) {
            a = buf_colors[base + 3u];
        }
        v_color = vec4<f32>(r, g, b, a);
    }

    var point_size = params.point_size;
    if (params.has_sizes != 0u) {
        point_size = buf_sizes[idx];
    }

    vertex.data[3u] = v_color.x;
    vertex.data[4u] = v_color.y;
    vertex.data[5u] = v_color.z;
    vertex.data[6u] = v_color.w;
    vertex.data[7u] = 0.0;
    vertex.data[8u] = 0.0;
    vertex.data[9u] = point_size;
    vertex.data[10u] = 0.0;
    vertex.data[11u] = 0.0;

    let out_idx = atomicAdd(&(indirect.vertex_count), 1u);
    out_vertices[out_idx] = vertex;
}
"#;