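//! `runmat_plot/gpu/shaders/scatter2.rs`: WGSL compute-shader sources that pack
//! 2-D scatter points into GPU vertices (`F32` and `F64` coordinate-input variants).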

1pub const F32: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;
2
3struct VertexRaw {
4    data: array<f32, 12u>,
5};
6
7struct ScatterParams {
8    color: vec4<f32>,
9    point_size: f32,
10    count: u32,
11    lod_stride: u32,
12    has_sizes: u32,
13    has_colors: u32,
14    color_stride: u32,
15};
16
17struct IndirectArgs {
18    vertex_count: atomic<u32>,
19    instance_count: u32,
20    first_vertex: u32,
21    first_instance: u32,
22};
23
24@group(0) @binding(0)
25var<storage, read> buf_x: array<f32>;
26
27@group(0) @binding(1)
28var<storage, read> buf_y: array<f32>;
29
30@group(0) @binding(2)
31var<storage, read_write> out_vertices: array<VertexRaw>;
32
33@group(0) @binding(3)
34var<uniform> params: ScatterParams;
35
36@group(0) @binding(4)
37var<storage, read> buf_sizes: array<f32>;
38
39@group(0) @binding(5)
40var<storage, read> buf_colors: array<f32>;
41
42@group(0) @binding(6)
43var<storage, read_write> indirect: IndirectArgs;
44
45@compute @workgroup_size(WORKGROUP_SIZE)
46fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
47    let idx = gid.x;
48    if (idx >= params.count) {
49        return;
50    }
51    let stride = max(params.lod_stride, 1u);
52    if ((idx % stride) != 0u) {
53        return;
54    }
55
56    let px = buf_x[idx];
57    let py = buf_y[idx];
58
59    var vertex: VertexRaw;
60    vertex.data[0u] = px;
61    vertex.data[1u] = py;
62    vertex.data[2u] = 0.0;
63
64    var v_color = params.color;
65    if (params.has_colors != 0u) {
66        let base = idx * params.color_stride;
67        let r = buf_colors[base];
68        let g = buf_colors[base + 1u];
69        let b = buf_colors[base + 2u];
70        let a = if params.color_stride > 3u {
71            buf_colors[base + 3u]
72        } else {
73            1.0
74        };
75        v_color = vec4<f32>(r, g, b, a);
76    }
77
78    let mut marker_size = params.point_size;
79    if (params.has_sizes != 0u) {
80        marker_size = buf_sizes[idx];
81    }
82
83    vertex.data[3u] = v_color.x;
84    vertex.data[4u] = v_color.y;
85    vertex.data[5u] = v_color.z;
86    vertex.data[6u] = v_color.w;
87    vertex.data[7u] = 0.0;
88    vertex.data[8u] = 0.0;
89    vertex.data[9u] = marker_size;
90    vertex.data[10u] = 0.0;
91    vertex.data[11u] = 0.0;
92
93    let out_idx = atomicAdd(&(indirect.vertex_count), 1u);
94    out_vertices[out_idx] = vertex;
95}
96"#;
97
98pub const F64: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;
99
100struct VertexRaw {
101    data: array<f32, 12u>,
102};
103
104struct ScatterParams {
105    color: vec4<f32>,
106    point_size: f32,
107    count: u32,
108    lod_stride: u32,
109    has_sizes: u32,
110    has_colors: u32,
111    color_stride: u32,
112};
113
114struct IndirectArgs {
115    vertex_count: atomic<u32>,
116    instance_count: u32,
117    first_vertex: u32,
118    first_instance: u32,
119};
120
121@group(0) @binding(0)
122var<storage, read> buf_x: array<f64>;
123
124@group(0) @binding(1)
125var<storage, read> buf_y: array<f64>;
126
127@group(0) @binding(2)
128var<storage, read_write> out_vertices: array<VertexRaw>;
129
130@group(0) @binding(3)
131var<uniform> params: ScatterParams;
132
133@group(0) @binding(4)
134var<storage, read> buf_sizes: array<f32>;
135
136@group(0) @binding(5)
137var<storage, read> buf_colors: array<f32>;
138
139@group(0) @binding(6)
140var<storage, read_write> indirect: IndirectArgs;
141
142@compute @workgroup_size(WORKGROUP_SIZE)
143fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
144    let idx = gid.x;
145    if (idx >= params.count) {
146        return;
147    }
148    let stride = max(params.lod_stride, 1u);
149    if ((idx % stride) != 0u) {
150        return;
151    }
152
153    let px = f32(buf_x[idx]);
154    let py = f32(buf_y[idx]);
155
156    var vertex: VertexRaw;
157    vertex.data[0u] = px;
158    vertex.data[1u] = py;
159    vertex.data[2u] = 0.0;
160
161    var v_color = params.color;
162    if (params.has_colors != 0u) {
163        let base = idx * params.color_stride;
164        let r = buf_colors[base];
165        let g = buf_colors[base + 1u];
166        let b = buf_colors[base + 2u];
167        let a = if params.color_stride > 3u {
168            buf_colors[base + 3u]
169        } else {
170            1.0
171        };
172        v_color = vec4<f32>(r, g, b, a);
173    }
174
175    let mut marker_size = params.point_size;
176    if (params.has_sizes != 0u) {
177        marker_size = buf_sizes[idx];
178    }
179
180    vertex.data[3u] = v_color.x;
181    vertex.data[4u] = v_color.y;
182    vertex.data[5u] = v_color.z;
183    vertex.data[6u] = v_color.w;
184    vertex.data[7u] = 0.0;
185    vertex.data[8u] = 0.0;
186    vertex.data[9u] = marker_size;
187    vertex.data[10u] = 0.0;
188    vertex.data[11u] = 0.0;
189
190    let out_idx = atomicAdd(&(indirect.vertex_count), 1u);
191    out_vertices[out_idx] = vertex;
192}
193"#;