//! runmat_plot/gpu/shaders/contour.rs — WGSL compute shaders for GPU contour-line extraction.

/// WGSL compute shader that extracts contour-line segments (marching squares)
/// from a rectilinear grid stored in `f32` buffers.
///
/// One invocation handles one (cell, level) pair and always writes
/// `VERTICES_PER_INVOCATION` (4) vertices to `out_vertices`, starting at index
/// `invocation * 4`. Vertex slots that carry no segment endpoint are written at
/// `(0, 0, base_z)` with zero alpha. `{{WORKGROUP_SIZE}}` is a text placeholder
/// that has to be replaced with an integer before the source is valid WGSL.
pub const F32: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;

// Up to two segments per cell and level, two vertices per segment.
const VERTICES_PER_INVOCATION: u32 = 4u;

// One output vertex: position xyz in slots 0..2, color in slots 3..6, then the
// constants 0, 0, 1, 0, 0 in slots 7..11 (their meaning is defined by the
// host-side vertex layout).
struct VertexRaw {
    data: array<f32, 12u>,
};

struct ContourParams {
    // Range used to map a level value onto the color table.
    min_z: f32,
    max_z: f32,
    // Z coordinate given to every emitted vertex.
    base_z: f32,
    // Number of entries in level_values.
    level_count: u32,
    // Number of grid samples along x; also the row stride of buf_z.
    x_len: u32,
    // Number of grid samples along y (not read by this shader).
    y_len: u32,
    // Number of entries in color_table.
    color_table_len: u32,
    // Number of grid cells; main enumerates them in rows of x_len - 1 cells.
    cell_count: u32,
};

// Grid x coordinates, indexed by ix.
@group(0) @binding(0)
var<storage, read> buf_x: array<f32>;

// Grid y coordinates, indexed by iy.
@group(0) @binding(1)
var<storage, read> buf_y: array<f32>;

// Grid samples, x-fastest: z at (ix, iy) is buf_z[ix + iy * x_len].
@group(0) @binding(2)
var<storage, read> buf_z: array<f32>;

// Color ramp sampled by sample_color.
@group(0) @binding(3)
var<storage, read> color_table: array<vec4<f32>>;

// VERTICES_PER_INVOCATION vertices per invocation.
@group(0) @binding(4)
var<storage, read_write> out_vertices: array<VertexRaw>;

@group(0) @binding(5)
var<uniform> params: ContourParams;

// Contour level values, one per level index.
@group(0) @binding(6)
var<storage, read> level_values: array<f32>;

// Linearly interpolates color_table at t, with t clamped to [0, 1].
// A table with at most one entry always yields entry 0.
fn sample_color(t: f32) -> vec4<f32> {
    let table_len = params.color_table_len;
    if (table_len <= 1u) {
        return color_table[0u];
    }
    let clamped = clamp(t, 0.0, 1.0);
    let scaled = clamped * f32(table_len - 1u);
    let lower = u32(scaled);
    let upper = min(lower + 1u, table_len - 1u);
    let frac = scaled - f32(lower);
    return mix(color_table[lower], color_table[upper], frac);
}

// Packs a position and color into the 12-float output vertex layout.
fn encode_vertex(position: vec3<f32>, color: vec4<f32>) -> VertexRaw {
    var vertex: VertexRaw;
    vertex.data[0u] = position.x;
    vertex.data[1u] = position.y;
    vertex.data[2u] = position.z;
    vertex.data[3u] = color.x;
    vertex.data[4u] = color.y;
    vertex.data[5u] = color.z;
    vertex.data[6u] = color.w;
    vertex.data[7u] = 0.0;
    vertex.data[8u] = 0.0;
    vertex.data[9u] = 1.0;
    vertex.data[10u] = 0.0;
    vertex.data[11u] = 0.0;
    return vertex;
}

// Returns the point on a cell edge where the linearly interpolated value equals
// level. Edges: 0 = c0-c1, 1 = c1-c2, 2 = c2-c3, anything else = c3-c0.
// A (near-)flat edge yields its midpoint.
fn interpolate_edge(edge: u32, corners: array<vec2<f32>, 4>, values: array<f32, 4>, level: f32) -> vec2<f32> {
    var a: vec2<f32>;
    var b: vec2<f32>;
    var va: f32;
    var vb: f32;
    switch edge {
        case 0u: { a = corners[0u]; b = corners[1u]; va = values[0u]; vb = values[1u]; }
        case 1u: { a = corners[1u]; b = corners[2u]; va = values[1u]; vb = values[2u]; }
        case 2u: { a = corners[2u]; b = corners[3u]; va = values[2u]; vb = values[3u]; }
        default: { a = corners[3u]; b = corners[0u]; va = values[3u]; vb = values[0u]; }
    }
    let delta = vb - va;
    // WGSL has no if-expression. select(f, t, cond) evaluates both operands, so the
    // division uses a safe denominator to keep the discarded operand finite.
    let flat = abs(delta) <= 1e-6;
    let safe_delta = select(delta, 1.0, flat);
    let t = select(clamp((level - va) / safe_delta, 0.0, 1.0), 0.5, flat);
    return mix(a, b, t);
}

// Resolves the saddle cases 5 and 10 using the sign of q = f00 * f11 - f10 * f01
// (the numerator of the bilinear saddle value). For both cases q > 0 pairs edges
// 3-2 and 0-1; otherwise edges 3-0 and 1-2 are paired. A near-zero q picks the
// first pairing only for case 5.
fn add_ambiguous_segments(
    case_index: u32,
    corners: array<vec2<f32>, 4>,
    values: array<f32, 4>,
    level: f32,
    io_segments: ptr<function, array<vec2<f32>, 4>>,
    io_count: ptr<function, u32>,
) {
    let f00 = values[0u] - level;
    let f10 = values[1u] - level;
    let f11 = values[2u] - level;
    let f01 = values[3u] - level;
    let q = f00 * f11 - f10 * f01;
    let use_default = q > 0.0 || (abs(q) <= 1e-6 && case_index == 5u);
    if (use_default) {
        add_segment(3u, 2u, corners, values, level, io_segments, io_count);
        add_segment(0u, 1u, corners, values, level, io_segments, io_count);
    } else {
        add_segment(3u, 0u, corners, values, level, io_segments, io_count);
        add_segment(1u, 2u, corners, values, level, io_segments, io_count);
    }
}

// Writes exactly VERTICES_PER_INVOCATION vertices starting at base_index. The
// first 2 * segment_count carry segment endpoints; the rest are placed at
// (0, 0, base_z) with zero alpha.
fn write_vertex_range(base_index: u32, segment_points: array<vec2<f32>, 4>, segment_count: u32, color: vec4<f32>) {
    // Index through a function-scope variable rather than a value array.
    var points = segment_points;
    let used = segment_count * 2u;
    let unused = encode_vertex(vec3<f32>(0.0, 0.0, params.base_z), vec4<f32>(color.xyz, 0.0));
    for (var i: u32 = 0u; i < VERTICES_PER_INVOCATION; i = i + 1u) {
        if (i < used) {
            out_vertices[base_index + i] = encode_vertex(vec3<f32>(points[i], params.base_z), color);
        } else {
            out_vertices[base_index + i] = unused;
        }
    }
}

// Appends the segment joining the level crossings on edge_a and edge_b.
// Segments beyond the two-slot capacity are dropped.
fn add_segment(
    edge_a: u32,
    edge_b: u32,
    corners: array<vec2<f32>, 4>,
    values: array<f32, 4>,
    level: f32,
    io_segments: ptr<function, array<vec2<f32>, 4>>,
    io_count: ptr<function, u32>,
) {
    if (*io_count) >= 2u {
        return;
    }
    let idx = (*io_count) * 2u;
    (*io_segments)[idx] = interpolate_edge(edge_a, corners, values, level);
    (*io_segments)[idx + 1u] = interpolate_edge(edge_b, corners, values, level);
    *io_count = *io_count + 1u;
}

@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let total = params.cell_count * params.level_count;
    let invocation = gid.x;
    if (invocation >= total) {
        return;
    }

    // Consecutive invocations walk the levels of one cell.
    let level_idx = invocation % params.level_count;
    let cell_idx = invocation / params.level_count;
    // Cells are enumerated x-fastest in rows of x_len - 1 cells.
    let cells_x = params.x_len - 1u;
    let ix = cell_idx % cells_x;
    let iy = cell_idx / cells_x;

    let idx00 = ix + iy * params.x_len;
    let idx10 = idx00 + 1u;
    let idx01 = idx00 + params.x_len;
    let idx11 = idx01 + 1u;

    let x0 = buf_x[ix];
    let x1 = buf_x[ix + 1u];
    let y0 = buf_y[iy];
    let y1 = buf_y[iy + 1u];

    let z00 = buf_z[idx00];
    let z10 = buf_z[idx10];
    let z11 = buf_z[idx11];
    let z01 = buf_z[idx01];

    // Corner order c0..c3 = (x0, y0), (x1, y0), (x1, y1), (x0, y1).
    let corners = array<vec2<f32>, 4>(
        vec2<f32>(x0, y0),
        vec2<f32>(x1, y0),
        vec2<f32>(x1, y1),
        vec2<f32>(x0, y1)
    );
    let values = array<f32, 4>(z00, z10, z11, z01);

    let level = level_values[level_idx];

    // Bit k is set when corner k lies strictly above the level.
    var case_index: u32 = 0u;
    if (z00 > level) { case_index = case_index | 1u; }
    if (z10 > level) { case_index = case_index | 2u; }
    if (z11 > level) { case_index = case_index | 4u; }
    if (z01 > level) { case_index = case_index | 8u; }

    var segments: array<vec2<f32>, 4>;
    var segment_count: u32 = 0u;

    switch case_index {
        // No crossing; WGSL requires a default clause, and case_index never exceeds 15.
        case 0u, 15u, default: {}
        // Corner 0 differs from the others.
        case 1u, 14u: { add_segment(3u, 0u, corners, values, level, &segments, &segment_count); }
        // Corner 1 differs from the others.
        case 2u, 13u: { add_segment(0u, 1u, corners, values, level, &segments, &segment_count); }
        // Bottom pair vs top pair.
        case 3u, 12u: { add_segment(3u, 1u, corners, values, level, &segments, &segment_count); }
        // Corner 2 differs from the others.
        case 4u, 11u: { add_segment(1u, 2u, corners, values, level, &segments, &segment_count); }
        // Diagonal saddle.
        case 5u, 10u: { add_ambiguous_segments(case_index, corners, values, level, &segments, &segment_count); }
        // Left pair vs right pair.
        case 6u, 9u: { add_segment(0u, 2u, corners, values, level, &segments, &segment_count); }
        // Corner 3 differs from the others.
        case 7u, 8u: { add_segment(3u, 2u, corners, values, level, &segments, &segment_count); }
    }

    let norm = (level - params.min_z) / max(params.max_z - params.min_z, 1e-6);
    let color = sample_color(norm);
    let base_vertex = invocation * VERTICES_PER_INVOCATION;
    write_vertex_range(base_vertex, segments, segment_count, color);
}
"#;
207
/// WGSL compute shader that extracts contour-line segments (marching squares)
/// from a rectilinear grid stored in `f64` buffers.
///
/// The coordinates and samples are converted to `f32` on load; everything after
/// that matches the `f32` variant. One invocation handles one (cell, level) pair
/// and always writes `VERTICES_PER_INVOCATION` (4) vertices to `out_vertices`,
/// starting at index `invocation * 4`. Vertex slots that carry no segment endpoint
/// are written at `(0, 0, base_z)` with zero alpha. `{{WORKGROUP_SIZE}}` is a text
/// placeholder that has to be replaced with an integer before the source is valid
/// WGSL.
///
/// NOTE(review): `array<f64>` requires 64-bit float shader support (e.g. wgpu's
/// `SHADER_F64` feature) — confirm the host only selects this variant when enabled.
pub const F64: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;

// Up to two segments per cell and level, two vertices per segment.
const VERTICES_PER_INVOCATION: u32 = 4u;

// One output vertex: position xyz in slots 0..2, color in slots 3..6, then the
// constants 0, 0, 1, 0, 0 in slots 7..11 (their meaning is defined by the
// host-side vertex layout).
struct VertexRaw {
    data: array<f32, 12u>,
};

struct ContourParams {
    // Range used to map a level value onto the color table.
    min_z: f32,
    max_z: f32,
    // Z coordinate given to every emitted vertex.
    base_z: f32,
    // Number of entries in level_values.
    level_count: u32,
    // Number of grid samples along x; also the row stride of buf_z.
    x_len: u32,
    // Number of grid samples along y (not read by this shader).
    y_len: u32,
    // Number of entries in color_table.
    color_table_len: u32,
    // Number of grid cells; main enumerates them in rows of x_len - 1 cells.
    cell_count: u32,
};

// Grid x coordinates, indexed by ix.
@group(0) @binding(0)
var<storage, read> buf_x: array<f64>;

// Grid y coordinates, indexed by iy.
@group(0) @binding(1)
var<storage, read> buf_y: array<f64>;

// Grid samples, x-fastest: z at (ix, iy) is buf_z[ix + iy * x_len].
@group(0) @binding(2)
var<storage, read> buf_z: array<f64>;

// Color ramp sampled by sample_color.
@group(0) @binding(3)
var<storage, read> color_table: array<vec4<f32>>;

// VERTICES_PER_INVOCATION vertices per invocation.
@group(0) @binding(4)
var<storage, read_write> out_vertices: array<VertexRaw>;

@group(0) @binding(5)
var<uniform> params: ContourParams;

// Contour level values, one per level index.
@group(0) @binding(6)
var<storage, read> level_values: array<f32>;

// Linearly interpolates color_table at t, with t clamped to [0, 1].
// A table with at most one entry always yields entry 0.
fn sample_color(t: f32) -> vec4<f32> {
    let table_len = params.color_table_len;
    if (table_len <= 1u) {
        return color_table[0u];
    }
    let clamped = clamp(t, 0.0, 1.0);
    let scaled = clamped * f32(table_len - 1u);
    let lower = u32(scaled);
    let upper = min(lower + 1u, table_len - 1u);
    let frac = scaled - f32(lower);
    return mix(color_table[lower], color_table[upper], frac);
}

// Packs a position and color into the 12-float output vertex layout.
fn encode_vertex(position: vec3<f32>, color: vec4<f32>) -> VertexRaw {
    var vertex: VertexRaw;
    vertex.data[0u] = position.x;
    vertex.data[1u] = position.y;
    vertex.data[2u] = position.z;
    vertex.data[3u] = color.x;
    vertex.data[4u] = color.y;
    vertex.data[5u] = color.z;
    vertex.data[6u] = color.w;
    vertex.data[7u] = 0.0;
    vertex.data[8u] = 0.0;
    vertex.data[9u] = 1.0;
    vertex.data[10u] = 0.0;
    vertex.data[11u] = 0.0;
    return vertex;
}

// Returns the point on a cell edge where the linearly interpolated value equals
// level. Edges: 0 = c0-c1, 1 = c1-c2, 2 = c2-c3, anything else = c3-c0.
// A (near-)flat edge yields its midpoint.
fn interpolate_edge(edge: u32, corners: array<vec2<f32>, 4>, values: array<f32, 4>, level: f32) -> vec2<f32> {
    var a: vec2<f32>;
    var b: vec2<f32>;
    var va: f32;
    var vb: f32;
    switch edge {
        case 0u: { a = corners[0u]; b = corners[1u]; va = values[0u]; vb = values[1u]; }
        case 1u: { a = corners[1u]; b = corners[2u]; va = values[1u]; vb = values[2u]; }
        case 2u: { a = corners[2u]; b = corners[3u]; va = values[2u]; vb = values[3u]; }
        default: { a = corners[3u]; b = corners[0u]; va = values[3u]; vb = values[0u]; }
    }
    let delta = vb - va;
    // WGSL has no if-expression. select(f, t, cond) evaluates both operands, so the
    // division uses a safe denominator to keep the discarded operand finite.
    let flat = abs(delta) <= 1e-6;
    let safe_delta = select(delta, 1.0, flat);
    let t = select(clamp((level - va) / safe_delta, 0.0, 1.0), 0.5, flat);
    return mix(a, b, t);
}

// Resolves the saddle cases 5 and 10 using the sign of q = f00 * f11 - f10 * f01
// (the numerator of the bilinear saddle value). For both cases q > 0 pairs edges
// 3-2 and 0-1; otherwise edges 3-0 and 1-2 are paired. A near-zero q picks the
// first pairing only for case 5.
fn add_ambiguous_segments(
    case_index: u32,
    corners: array<vec2<f32>, 4>,
    values: array<f32, 4>,
    level: f32,
    io_segments: ptr<function, array<vec2<f32>, 4>>,
    io_count: ptr<function, u32>,
) {
    let f00 = values[0u] - level;
    let f10 = values[1u] - level;
    let f11 = values[2u] - level;
    let f01 = values[3u] - level;
    let q = f00 * f11 - f10 * f01;
    let use_default = q > 0.0 || (abs(q) <= 1e-6 && case_index == 5u);
    if (use_default) {
        add_segment(3u, 2u, corners, values, level, io_segments, io_count);
        add_segment(0u, 1u, corners, values, level, io_segments, io_count);
    } else {
        add_segment(3u, 0u, corners, values, level, io_segments, io_count);
        add_segment(1u, 2u, corners, values, level, io_segments, io_count);
    }
}

// Writes exactly VERTICES_PER_INVOCATION vertices starting at base_index. The
// first 2 * segment_count carry segment endpoints; the rest are placed at
// (0, 0, base_z) with zero alpha.
fn write_vertex_range(base_index: u32, segment_points: array<vec2<f32>, 4>, segment_count: u32, color: vec4<f32>) {
    // Index through a function-scope variable rather than a value array.
    var points = segment_points;
    let used = segment_count * 2u;
    let unused = encode_vertex(vec3<f32>(0.0, 0.0, params.base_z), vec4<f32>(color.xyz, 0.0));
    for (var i: u32 = 0u; i < VERTICES_PER_INVOCATION; i = i + 1u) {
        if (i < used) {
            out_vertices[base_index + i] = encode_vertex(vec3<f32>(points[i], params.base_z), color);
        } else {
            out_vertices[base_index + i] = unused;
        }
    }
}

// Appends the segment joining the level crossings on edge_a and edge_b.
// Segments beyond the two-slot capacity are dropped.
fn add_segment(
    edge_a: u32,
    edge_b: u32,
    corners: array<vec2<f32>, 4>,
    values: array<f32, 4>,
    level: f32,
    io_segments: ptr<function, array<vec2<f32>, 4>>,
    io_count: ptr<function, u32>,
) {
    if (*io_count) >= 2u {
        return;
    }
    let idx = (*io_count) * 2u;
    (*io_segments)[idx] = interpolate_edge(edge_a, corners, values, level);
    (*io_segments)[idx + 1u] = interpolate_edge(edge_b, corners, values, level);
    *io_count = *io_count + 1u;
}

@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let total = params.cell_count * params.level_count;
    let invocation = gid.x;
    if (invocation >= total) {
        return;
    }

    // Consecutive invocations walk the levels of one cell.
    let level_idx = invocation % params.level_count;
    let cell_idx = invocation / params.level_count;
    // Cells are enumerated x-fastest in rows of x_len - 1 cells.
    let cells_x = params.x_len - 1u;
    let ix = cell_idx % cells_x;
    let iy = cell_idx / cells_x;

    let idx00 = ix + iy * params.x_len;
    let idx10 = idx00 + 1u;
    let idx01 = idx00 + params.x_len;
    let idx11 = idx01 + 1u;

    // Inputs are f64; contouring itself runs in f32.
    let x0 = f32(buf_x[ix]);
    let x1 = f32(buf_x[ix + 1u]);
    let y0 = f32(buf_y[iy]);
    let y1 = f32(buf_y[iy + 1u]);

    let z00 = f32(buf_z[idx00]);
    let z10 = f32(buf_z[idx10]);
    let z11 = f32(buf_z[idx11]);
    let z01 = f32(buf_z[idx01]);

    // Corner order c0..c3 = (x0, y0), (x1, y0), (x1, y1), (x0, y1).
    let corners = array<vec2<f32>, 4>(
        vec2<f32>(x0, y0),
        vec2<f32>(x1, y0),
        vec2<f32>(x1, y1),
        vec2<f32>(x0, y1)
    );
    let values = array<f32, 4>(z00, z10, z11, z01);

    let level = level_values[level_idx];

    // Bit k is set when corner k lies strictly above the level.
    var case_index: u32 = 0u;
    if (z00 > level) { case_index = case_index | 1u; }
    if (z10 > level) { case_index = case_index | 2u; }
    if (z11 > level) { case_index = case_index | 4u; }
    if (z01 > level) { case_index = case_index | 8u; }

    var segments: array<vec2<f32>, 4>;
    var segment_count: u32 = 0u;

    switch case_index {
        // No crossing; WGSL requires a default clause, and case_index never exceeds 15.
        case 0u, 15u, default: {}
        // Corner 0 differs from the others.
        case 1u, 14u: { add_segment(3u, 0u, corners, values, level, &segments, &segment_count); }
        // Corner 1 differs from the others.
        case 2u, 13u: { add_segment(0u, 1u, corners, values, level, &segments, &segment_count); }
        // Bottom pair vs top pair.
        case 3u, 12u: { add_segment(3u, 1u, corners, values, level, &segments, &segment_count); }
        // Corner 2 differs from the others.
        case 4u, 11u: { add_segment(1u, 2u, corners, values, level, &segments, &segment_count); }
        // Diagonal saddle.
        case 5u, 10u: { add_ambiguous_segments(case_index, corners, values, level, &segments, &segment_count); }
        // Left pair vs right pair.
        case 6u, 9u: { add_segment(0u, 2u, corners, values, level, &segments, &segment_count); }
        // Corner 3 differs from the others.
        case 7u, 8u: { add_segment(3u, 2u, corners, values, level, &segments, &segment_count); }
    }

    let norm = (level - params.min_z) / max(params.max_z - params.min_z, 1e-6);
    let color = sample_color(norm);
    let base_vertex = invocation * VERTICES_PER_INVOCATION;
    write_vertex_range(base_vertex, segments, segment_count, color);
}
"#;