Skip to main content

runmat_plot/gpu/shaders/contour_fill.rs

/// WGSL compute shader (single-precision grid inputs) that builds filled contour bands.
///
/// Each invocation handles one `(cell, band)` pair:
/// - the grid cell is split into two triangles along its `p0`-`p2` diagonal;
/// - each triangle is clipped to the band's scalar range;
/// - the surviving convex polygon is fan-triangulated into `out_vertices`, with space
///   reserved by an atomic add on `indirect.vertex_count`.
///
/// The source contains the `{{WORKGROUP_SIZE}}` placeholder, which is not valid WGSL.
/// Callers must substitute a concrete workgroup size before compiling it.
pub const F32: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;

// One output vertex: position xyz, color rgba, then five trailing floats
// that encode_vertex always writes as 0, 0, 1, 0, 0.
struct VertexRaw {
    data: array<f32, 12u>,
};

// A 2-D position together with the scalar field value sampled or interpolated there.
struct ScalarPoint {
    pos: vec2<f32>,
    value: f32,
};

// Fixed-capacity polygon. Clipping a triangle against two half-spaces yields
// at most 5 vertices, so 5 slots always suffice.
struct Poly {
    points: array<ScalarPoint, 5u>,
    count: u32,
};

struct ContourFillParams {
    base_z: f32,
    alpha: f32,
    x_len: u32,
    y_len: u32,
    color_table_len: u32,
    band_count: u32,
    cell_count: u32,
    _pad: u32,
};

// Field order mirrors non-indexed draw-indirect arguments; vertex_count is
// accumulated atomically by emit_triangle.
struct IndirectArgs {
    vertex_count: atomic<u32>,
    instance_count: u32,
    first_vertex: u32,
    first_instance: u32,
};

@group(0) @binding(0)
var<storage, read> buf_x: array<f32>;

@group(0) @binding(1)
var<storage, read> buf_y: array<f32>;

@group(0) @binding(2)
var<storage, read> buf_z: array<f32>;

@group(0) @binding(3)
var<storage, read> color_table: array<vec4<f32>>;

@group(0) @binding(4)
var<storage, read> level_values: array<f32>;

@group(0) @binding(5)
var<storage, read_write> out_vertices: array<VertexRaw>;

@group(0) @binding(6)
var<uniform> params: ContourFillParams;

@group(0) @binding(7)
var<storage, read_write> indirect: IndirectArgs;

// Packs a position and color into the flat 12-float vertex record.
fn encode_vertex(position: vec3<f32>, color: vec4<f32>) -> VertexRaw {
    var vertex: VertexRaw;
    vertex.data[0u] = position.x;
    vertex.data[1u] = position.y;
    vertex.data[2u] = position.z;
    vertex.data[3u] = color.x;
    vertex.data[4u] = color.y;
    vertex.data[5u] = color.z;
    vertex.data[6u] = color.w;
    vertex.data[7u] = 0.0;
    vertex.data[8u] = 0.0;
    vertex.data[9u] = 1.0;
    vertex.data[10u] = 0.0;
    vertex.data[11u] = 0.0;
    return vertex;
}

fn make_point(pos: vec2<f32>, value: f32) -> ScalarPoint {
    var point: ScalarPoint;
    point.pos = pos;
    point.value = value;
    return point;
}

// Appends a point; points beyond the 5-slot capacity are silently dropped.
fn push_point(poly: ptr<function, Poly>, point: ScalarPoint) {
    if ((*poly).count < 5u) {
        (*poly).points[(*poly).count] = point;
        (*poly).count = (*poly).count + 1u;
    }
}

// Point on segment a-b where the linearly interpolated field equals threshold.
// Near-degenerate edges (|delta| <= 1e-6) fall back to the midpoint.
fn interpolate_point(a: ScalarPoint, b: ScalarPoint, threshold: f32) -> ScalarPoint {
    let delta = b.value - a.value;
    // WGSL has no if-expression, so select the parameter with a statement.
    var t = 0.5;
    if (abs(delta) > 1e-6) {
        t = clamp((threshold - a.value) / delta, 0.0, 1.0);
    }
    return make_point(mix(a.pos, b.pos, t), threshold);
}

// Half-space membership test.
// - upper == false: value >= threshold.
// - upper == true: value <= threshold if inclusive, otherwise value < threshold.
fn is_inside(value: f32, threshold: f32, upper: bool, inclusive: bool) -> bool {
    if (!upper) {
        return value >= threshold;
    }
    if (inclusive) {
        return value <= threshold;
    }
    return value < threshold;
}

// Sutherland-Hodgman clip of input against one scalar half-space (see is_inside).
// Edge crossings are interpolated onto the threshold.
fn clip_poly(input: Poly, threshold: f32, upper: bool, inclusive: bool) -> Poly {
    var out: Poly;
    out.count = 0u;
    if (input.count == 0u) {
        return out;
    }
    // Index through a function-scope var (a reference) rather than the by-value
    // parameter; older naga releases reject dynamic indexing of array values.
    var src = input;
    var prev = src.points[src.count - 1u];
    var prev_inside = is_inside(prev.value, threshold, upper, inclusive);
    for (var i: u32 = 0u; i < src.count; i = i + 1u) {
        let curr = src.points[i];
        let curr_inside = is_inside(curr.value, threshold, upper, inclusive);
        if (curr_inside != prev_inside) {
            push_point(&out, interpolate_point(prev, curr, threshold));
        }
        if (curr_inside) {
            push_point(&out, curr);
        }
        prev = curr;
        prev_inside = curr_inside;
    }
    return out;
}

// Reserves three vertex slots atomically and writes one flat triangle at base_z.
fn emit_triangle(a: vec2<f32>, b: vec2<f32>, c: vec2<f32>, color: vec4<f32>) {
    let base = atomicAdd(&(indirect.vertex_count), 3u);
    out_vertices[base] = encode_vertex(vec3<f32>(a, params.base_z), color);
    out_vertices[base + 1u] = encode_vertex(vec3<f32>(b, params.base_z), color);
    out_vertices[base + 2u] = encode_vertex(vec3<f32>(c, params.base_z), color);
}

// Clips triangle abc to the band [lo, hi) ([lo, hi] when include_hi) and emits
// the remaining region as a triangle fan.
fn emit_band_triangle(a: ScalarPoint, b: ScalarPoint, c: ScalarPoint, lo: f32, hi: f32, include_hi: bool, color: vec4<f32>) {
    var tri: Poly;
    tri.count = 3u;
    tri.points[0u] = a;
    tri.points[1u] = b;
    tri.points[2u] = c;
    let above_lo = clip_poly(tri, lo, false, false);
    var band = clip_poly(above_lo, hi, true, include_hi);
    if (band.count < 3u) {
        return;
    }
    // The field is linear over the triangle, so the clipped region is convex
    // and a fan from the first vertex covers it.
    let origin = band.points[0u].pos;
    for (var i: u32 = 1u; i + 1u < band.count; i = i + 1u) {
        emit_triangle(origin, band.points[i].pos, band.points[i + 1u].pos, color);
    }
}

@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let total = params.cell_count * params.band_count;
    let invocation = gid.x;
    // A cell needs at least two samples per axis; this also keeps cells_x
    // below from being zero, which would make the modulo divide by zero.
    if (invocation >= total || params.x_len < 2u || params.y_len < 2u) {
        return;
    }

    let band_idx = invocation % params.band_count;
    let cell_idx = invocation / params.band_count;
    let cells_x = params.x_len - 1u;
    // ix indexes buf_x (fastest-varying in buf_z); iy indexes buf_y.
    let ix = cell_idx % cells_x;
    let iy = cell_idx / cells_x;
    if (iy + 1u >= params.y_len) {
        return;
    }
    let idx00 = ix + iy * params.x_len;
    let idx10 = idx00 + 1u;
    let idx01 = idx00 + params.x_len;
    let idx11 = idx01 + 1u;

    let x0 = buf_x[ix];
    let x1 = buf_x[ix + 1u];
    let y0 = buf_y[iy];
    let y1 = buf_y[iy + 1u];
    let p0 = make_point(vec2<f32>(x0, y0), buf_z[idx00]);
    let p1 = make_point(vec2<f32>(x1, y0), buf_z[idx10]);
    let p2 = make_point(vec2<f32>(x1, y1), buf_z[idx11]);
    let p3 = make_point(vec2<f32>(x0, y1), buf_z[idx01]);

    // Band k spans [level_values[k], level_values[k + 1]); the last band also
    // includes its upper level so values equal to the final level are kept.
    let lo = level_values[band_idx];
    let hi = level_values[band_idx + 1u];
    let include_hi = band_idx + 1u == params.band_count;
    // Reuse the last color when there are fewer colors than bands; max() keeps
    // the subtraction from wrapping when color_table_len is 0.
    let color_idx = min(band_idx, max(params.color_table_len, 1u) - 1u);
    let base_color = color_table[color_idx];
    let color = vec4<f32>(base_color.xyz, base_color.w * params.alpha);
    emit_band_triangle(p0, p1, p2, lo, hi, include_hi, color);
    emit_band_triangle(p0, p2, p3, lo, hi, include_hi, color);
}
"#;
209
/// WGSL compute shader (double-precision grid inputs) that builds filled contour bands.
///
/// This is identical in structure to the `f32` variant, with two differences:
/// - `buf_x`, `buf_y` and `buf_z` are bound as `array<f64>`;
/// - each sample is converted to `f32` before clipping and vertex emission.
///
/// Presumably requires the device's 64-bit float shader capability (e.g. wgpu
/// `SHADER_F64`), since core WGSL has no `f64`.
///
/// The `{{WORKGROUP_SIZE}}` placeholder is not valid WGSL and must be substituted
/// with a concrete workgroup size before compiling.
pub const F64: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;

// One output vertex: position xyz, color rgba, then five trailing floats
// that encode_vertex always writes as 0, 0, 1, 0, 0.
struct VertexRaw {
    data: array<f32, 12u>,
};

// A 2-D position together with the scalar field value sampled or interpolated there.
struct ScalarPoint {
    pos: vec2<f32>,
    value: f32,
};

// Fixed-capacity polygon. Clipping a triangle against two half-spaces yields
// at most 5 vertices, so 5 slots always suffice.
struct Poly {
    points: array<ScalarPoint, 5u>,
    count: u32,
};

struct ContourFillParams {
    base_z: f32,
    alpha: f32,
    x_len: u32,
    y_len: u32,
    color_table_len: u32,
    band_count: u32,
    cell_count: u32,
    _pad: u32,
};

// Field order mirrors non-indexed draw-indirect arguments; vertex_count is
// accumulated atomically by emit_triangle.
struct IndirectArgs {
    vertex_count: atomic<u32>,
    instance_count: u32,
    first_vertex: u32,
    first_instance: u32,
};

@group(0) @binding(0)
var<storage, read> buf_x: array<f64>;

@group(0) @binding(1)
var<storage, read> buf_y: array<f64>;

@group(0) @binding(2)
var<storage, read> buf_z: array<f64>;

@group(0) @binding(3)
var<storage, read> color_table: array<vec4<f32>>;

@group(0) @binding(4)
var<storage, read> level_values: array<f32>;

@group(0) @binding(5)
var<storage, read_write> out_vertices: array<VertexRaw>;

@group(0) @binding(6)
var<uniform> params: ContourFillParams;

@group(0) @binding(7)
var<storage, read_write> indirect: IndirectArgs;

// Packs a position and color into the flat 12-float vertex record.
fn encode_vertex(position: vec3<f32>, color: vec4<f32>) -> VertexRaw {
    var vertex: VertexRaw;
    vertex.data[0u] = position.x;
    vertex.data[1u] = position.y;
    vertex.data[2u] = position.z;
    vertex.data[3u] = color.x;
    vertex.data[4u] = color.y;
    vertex.data[5u] = color.z;
    vertex.data[6u] = color.w;
    vertex.data[7u] = 0.0;
    vertex.data[8u] = 0.0;
    vertex.data[9u] = 1.0;
    vertex.data[10u] = 0.0;
    vertex.data[11u] = 0.0;
    return vertex;
}

fn make_point(pos: vec2<f32>, value: f32) -> ScalarPoint {
    var point: ScalarPoint;
    point.pos = pos;
    point.value = value;
    return point;
}

// Appends a point; points beyond the 5-slot capacity are silently dropped.
fn push_point(poly: ptr<function, Poly>, point: ScalarPoint) {
    if ((*poly).count < 5u) {
        (*poly).points[(*poly).count] = point;
        (*poly).count = (*poly).count + 1u;
    }
}

// Point on segment a-b where the linearly interpolated field equals threshold.
// Near-degenerate edges (|delta| <= 1e-6) fall back to the midpoint.
fn interpolate_point(a: ScalarPoint, b: ScalarPoint, threshold: f32) -> ScalarPoint {
    let delta = b.value - a.value;
    // WGSL has no if-expression, so select the parameter with a statement.
    var t = 0.5;
    if (abs(delta) > 1e-6) {
        t = clamp((threshold - a.value) / delta, 0.0, 1.0);
    }
    return make_point(mix(a.pos, b.pos, t), threshold);
}

// Half-space membership test.
// - upper == false: value >= threshold.
// - upper == true: value <= threshold if inclusive, otherwise value < threshold.
fn is_inside(value: f32, threshold: f32, upper: bool, inclusive: bool) -> bool {
    if (!upper) {
        return value >= threshold;
    }
    if (inclusive) {
        return value <= threshold;
    }
    return value < threshold;
}

// Sutherland-Hodgman clip of input against one scalar half-space (see is_inside).
// Edge crossings are interpolated onto the threshold.
fn clip_poly(input: Poly, threshold: f32, upper: bool, inclusive: bool) -> Poly {
    var out: Poly;
    out.count = 0u;
    if (input.count == 0u) {
        return out;
    }
    // Index through a function-scope var (a reference) rather than the by-value
    // parameter; older naga releases reject dynamic indexing of array values.
    var src = input;
    var prev = src.points[src.count - 1u];
    var prev_inside = is_inside(prev.value, threshold, upper, inclusive);
    for (var i: u32 = 0u; i < src.count; i = i + 1u) {
        let curr = src.points[i];
        let curr_inside = is_inside(curr.value, threshold, upper, inclusive);
        if (curr_inside != prev_inside) {
            push_point(&out, interpolate_point(prev, curr, threshold));
        }
        if (curr_inside) {
            push_point(&out, curr);
        }
        prev = curr;
        prev_inside = curr_inside;
    }
    return out;
}

// Reserves three vertex slots atomically and writes one flat triangle at base_z.
fn emit_triangle(a: vec2<f32>, b: vec2<f32>, c: vec2<f32>, color: vec4<f32>) {
    let base = atomicAdd(&(indirect.vertex_count), 3u);
    out_vertices[base] = encode_vertex(vec3<f32>(a, params.base_z), color);
    out_vertices[base + 1u] = encode_vertex(vec3<f32>(b, params.base_z), color);
    out_vertices[base + 2u] = encode_vertex(vec3<f32>(c, params.base_z), color);
}

// Clips triangle abc to the band [lo, hi) ([lo, hi] when include_hi) and emits
// the remaining region as a triangle fan.
fn emit_band_triangle(a: ScalarPoint, b: ScalarPoint, c: ScalarPoint, lo: f32, hi: f32, include_hi: bool, color: vec4<f32>) {
    var tri: Poly;
    tri.count = 3u;
    tri.points[0u] = a;
    tri.points[1u] = b;
    tri.points[2u] = c;
    let above_lo = clip_poly(tri, lo, false, false);
    var band = clip_poly(above_lo, hi, true, include_hi);
    if (band.count < 3u) {
        return;
    }
    // The field is linear over the triangle, so the clipped region is convex
    // and a fan from the first vertex covers it.
    let origin = band.points[0u].pos;
    for (var i: u32 = 1u; i + 1u < band.count; i = i + 1u) {
        emit_triangle(origin, band.points[i].pos, band.points[i + 1u].pos, color);
    }
}

@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let total = params.cell_count * params.band_count;
    let invocation = gid.x;
    // A cell needs at least two samples per axis; this also keeps cells_x
    // below from being zero, which would make the modulo divide by zero.
    if (invocation >= total || params.x_len < 2u || params.y_len < 2u) {
        return;
    }

    let band_idx = invocation % params.band_count;
    let cell_idx = invocation / params.band_count;
    let cells_x = params.x_len - 1u;
    // ix indexes buf_x (fastest-varying in buf_z); iy indexes buf_y.
    let ix = cell_idx % cells_x;
    let iy = cell_idx / cells_x;
    if (iy + 1u >= params.y_len) {
        return;
    }
    let idx00 = ix + iy * params.x_len;
    let idx10 = idx00 + 1u;
    let idx01 = idx00 + params.x_len;
    let idx11 = idx01 + 1u;

    // Inputs are f64; geometry and clipping are done in f32.
    let x0 = f32(buf_x[ix]);
    let x1 = f32(buf_x[ix + 1u]);
    let y0 = f32(buf_y[iy]);
    let y1 = f32(buf_y[iy + 1u]);
    let p0 = make_point(vec2<f32>(x0, y0), f32(buf_z[idx00]));
    let p1 = make_point(vec2<f32>(x1, y0), f32(buf_z[idx10]));
    let p2 = make_point(vec2<f32>(x1, y1), f32(buf_z[idx11]));
    let p3 = make_point(vec2<f32>(x0, y1), f32(buf_z[idx01]));

    // Band k spans [level_values[k], level_values[k + 1]); the last band also
    // includes its upper level so values equal to the final level are kept.
    let lo = level_values[band_idx];
    let hi = level_values[band_idx + 1u];
    let include_hi = band_idx + 1u == params.band_count;
    // Reuse the last color when there are fewer colors than bands; max() keeps
    // the subtraction from wrapping when color_table_len is 0.
    let color_idx = min(band_idx, max(params.color_table_len, 1u) - 1u);
    let base_color = color_table[color_idx];
    let color = vec4<f32>(base_color.xyz, base_color.w * params.alpha);
    emit_band_triangle(p0, p1, p2, lo, hi, include_hi, color);
    emit_band_triangle(p0, p2, p3, lo, hi, include_hi, color);
}
"#;