// runmat_plot/gpu/shaders/surface.rs

/// WGSL compute shader that expands a (possibly strided, level-of-detail)
/// surface grid read from `f32` storage buffers into interleaved vertices.
///
/// Each output vertex is 12 `f32`s: position `xyz` (0..=2), color `rgba`
/// (3..=6), the constant vector `(0, 0, 1)` (7..=9, presumably a +Z normal),
/// and LOD grid coordinates normalized to `[0, 1]` (10..=11).
///
/// The `{{WORKGROUP_SIZE}}` placeholder is not valid WGSL and must be replaced
/// with an integer literal before the module is compiled.
pub const F32: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;

// One interleaved output vertex: position (0..2), color (3..6),
// constant (0, 0, 1) (7..9), normalized LOD grid coordinates (10..11).
struct VertexRaw {
    data: array<f32, 12u>,
};

// Must match the host-side uniform layout field-for-field (12 x 4 = 48 bytes).
struct SurfaceParams {
    min_z: f32,
    max_z: f32,
    alpha: f32,
    flatten: u32,
    x_len: u32,
    y_len: u32,
    lod_x_len: u32,
    lod_y_len: u32,
    x_stride: u32,
    y_stride: u32,
    color_table_len: u32,
    _pad: u32,
};

@group(0) @binding(0)
var<storage, read> buf_x: array<f32>;

@group(0) @binding(1)
var<storage, read> buf_y: array<f32>;

@group(0) @binding(2)
var<storage, read> buf_z: array<f32>;

@group(0) @binding(3)
var<storage, read> color_table: array<vec4<f32>>;

@group(0) @binding(4)
var<storage, read_write> out_vertices: array<VertexRaw>;

@group(0) @binding(5)
var<uniform> params: SurfaceParams;

// Linearly interpolates the color table at t, clamped to [0, 1].
// A table of length 0 or 1 always yields its first entry.
fn sample_color(t: f32) -> vec4<f32> {
    let table_len = params.color_table_len;
    if (table_len <= 1u) {
        return color_table[0u];
    }
    let clamped = clamp(t, 0.0, 1.0);
    let scaled = clamped * f32(table_len - 1u);
    let lower = u32(scaled);
    let upper = min(lower + 1u, table_len - 1u);
    let frac = scaled - f32(lower);
    return mix(color_table[lower], color_table[upper], frac);
}

// True unless v is NaN or +/-Inf, i.e. unless all exponent bits are set.
// WGSL has no finiteness builtin, and NaN self-comparisons may be folded
// away by fast-math backends, so the bit pattern is tested directly.
fn is_finite_f32(v: f32) -> bool {
    return (bitcast<u32>(v) & 0x7f800000u) != 0x7f800000u;
}

// Returns v when it is finite, otherwise fallback.
// select(f, t, cond) yields t when cond is true.
fn sanitize_finite(v: f32, fallback: f32) -> f32 {
    return select(fallback, v, is_finite_f32(v));
}

// One invocation per LOD vertex. Only gid.x is used, so the host must dispatch
// at least ceil(lod_x_len * lod_y_len / WORKGROUP_SIZE) workgroups along x.
@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let lod_x_len = params.lod_x_len;
    let lod_y_len = params.lod_y_len;
    let total = lod_x_len * lod_y_len;

    let idx = gid.x;
    // Also bail out on an empty source grid: x_len - 1u / y_len - 1u below
    // would otherwise wrap around and index past the input buffers.
    if (idx >= total || params.x_len == 0u || params.y_len == 0u) {
        return;
    }

    let stride_x = max(params.x_stride, 1u);
    let stride_y = max(params.y_stride, 1u);
    // lod_row varies fastest; map LOD indices back to source indices, clamping
    // so the last LOD sample never overruns the source grid.
    let lod_row = idx % lod_x_len;
    let lod_col = idx / lod_x_len;
    let row = min(lod_row * stride_x, params.x_len - 1u);
    let col = min(lod_col * stride_y, params.y_len - 1u);
    // buf_z is indexed with col (the y index) varying fastest.
    let source_idx = col + params.y_len * row;

    // Sanitize the range first so it can serve as the fallback for bad samples.
    let min_z = sanitize_finite(params.min_z, 0.0);
    let max_z = sanitize_finite(params.max_z, min_z + 1.0);
    let safe_max_z = max(max_z, min_z + 1e-6);
    let z_extent = safe_max_z - min_z;

    let px = buf_x[row];
    let py = buf_y[col];
    let raw_z = sanitize_finite(buf_z[source_idx], min_z);
    // For large |min_z| the 1e-6 padding can round away, leaving a zero
    // extent; a non-finite ratio then maps to the middle of the colormap.
    let norm_z = sanitize_finite((raw_z - min_z) / z_extent, 0.5);

    let position_z = select(raw_z, 0.0, params.flatten == 1u);
    let tex_x = f32(lod_row) / max(f32(lod_x_len - 1u), 1.0);
    let tex_y = f32(lod_col) / max(f32(lod_y_len - 1u), 1.0);
    let color = sample_color(norm_z) * vec4<f32>(1.0, 1.0, 1.0, params.alpha);

    var vertex: VertexRaw;
    vertex.data[0u] = px;
    vertex.data[1u] = py;
    vertex.data[2u] = position_z;

    vertex.data[3u] = color.x;
    vertex.data[4u] = color.y;
    vertex.data[5u] = color.z;
    vertex.data[6u] = color.w;

    vertex.data[7u] = 0.0;
    vertex.data[8u] = 0.0;
    vertex.data[9u] = 1.0;

    vertex.data[10u] = tex_x;
    vertex.data[11u] = tex_y;

    out_vertices[idx] = vertex;
}
"#;

/// The `f64`-input counterpart of `F32`. It reads the X/Y/Z grids from `f64`
/// storage buffers and narrows each sample to `f32` before use.
///
/// Output vertices are 12 `f32`s: position `xyz` (0..=2), color `rgba`
/// (3..=6), the constant vector `(0, 0, 1)` (7..=9, presumably a +Z normal),
/// and LOD grid coordinates normalized to `[0, 1]` (10..=11).
///
/// `f64` is not part of core WGSL. NOTE(review): presumably the host only uses
/// this variant when wgpu's `SHADER_F64` feature is enabled — confirm.
/// The `{{WORKGROUP_SIZE}}` placeholder must be replaced with an integer
/// literal before the module is compiled.
pub const F64: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;

// One interleaved output vertex: position (0..2), color (3..6),
// constant (0, 0, 1) (7..9), normalized LOD grid coordinates (10..11).
struct VertexRaw {
    data: array<f32, 12u>,
};

// Must match the host-side uniform layout field-for-field (12 x 4 = 48 bytes).
struct SurfaceParams {
    min_z: f32,
    max_z: f32,
    alpha: f32,
    flatten: u32,
    x_len: u32,
    y_len: u32,
    lod_x_len: u32,
    lod_y_len: u32,
    x_stride: u32,
    y_stride: u32,
    color_table_len: u32,
    _pad: u32,
};

@group(0) @binding(0)
var<storage, read> buf_x: array<f64>;

@group(0) @binding(1)
var<storage, read> buf_y: array<f64>;

@group(0) @binding(2)
var<storage, read> buf_z: array<f64>;

@group(0) @binding(3)
var<storage, read> color_table: array<vec4<f32>>;

@group(0) @binding(4)
var<storage, read_write> out_vertices: array<VertexRaw>;

@group(0) @binding(5)
var<uniform> params: SurfaceParams;

// Linearly interpolates the color table at t, clamped to [0, 1].
// A table of length 0 or 1 always yields its first entry.
fn sample_color(t: f32) -> vec4<f32> {
    let table_len = params.color_table_len;
    if (table_len <= 1u) {
        return color_table[0u];
    }
    let clamped = clamp(t, 0.0, 1.0);
    let scaled = clamped * f32(table_len - 1u);
    let lower = u32(scaled);
    let upper = min(lower + 1u, table_len - 1u);
    let frac = scaled - f32(lower);
    return mix(color_table[lower], color_table[upper], frac);
}

// True unless v is NaN or +/-Inf, i.e. unless all exponent bits are set.
// WGSL has no finiteness builtin, and NaN self-comparisons may be folded
// away by fast-math backends, so the bit pattern is tested directly.
fn is_finite_f32(v: f32) -> bool {
    return (bitcast<u32>(v) & 0x7f800000u) != 0x7f800000u;
}

// Returns v when it is finite, otherwise fallback.
// select(f, t, cond) yields t when cond is true.
fn sanitize_finite(v: f32, fallback: f32) -> f32 {
    return select(fallback, v, is_finite_f32(v));
}

// One invocation per LOD vertex. Only gid.x is used, so the host must dispatch
// at least ceil(lod_x_len * lod_y_len / WORKGROUP_SIZE) workgroups along x.
@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let lod_x_len = params.lod_x_len;
    let lod_y_len = params.lod_y_len;
    let total = lod_x_len * lod_y_len;

    let idx = gid.x;
    // Also bail out on an empty source grid: x_len - 1u / y_len - 1u below
    // would otherwise wrap around and index past the input buffers.
    if (idx >= total || params.x_len == 0u || params.y_len == 0u) {
        return;
    }

    let stride_x = max(params.x_stride, 1u);
    let stride_y = max(params.y_stride, 1u);
    // lod_row varies fastest; map LOD indices back to source indices, clamping
    // so the last LOD sample never overruns the source grid.
    let lod_row = idx % lod_x_len;
    let lod_col = idx / lod_x_len;
    let row = min(lod_row * stride_x, params.x_len - 1u);
    let col = min(lod_col * stride_y, params.y_len - 1u);
    // buf_z is indexed with col (the y index) varying fastest.
    let source_idx = col + params.y_len * row;

    // Sanitize the range first so it can serve as the fallback for bad samples.
    let min_z = sanitize_finite(params.min_z, 0.0);
    let max_z = sanitize_finite(params.max_z, min_z + 1.0);
    let safe_max_z = max(max_z, min_z + 1e-6);
    let z_extent = safe_max_z - min_z;

    let px = f32(buf_x[row]);
    let py = f32(buf_y[col]);
    // Narrow to f32, then replace a non-finite sample with min_z.
    let raw_z = sanitize_finite(f32(buf_z[source_idx]), min_z);
    // For large |min_z| the 1e-6 padding can round away, leaving a zero
    // extent; a non-finite ratio then maps to the middle of the colormap.
    let norm_z = sanitize_finite((raw_z - min_z) / z_extent, 0.5);

    let position_z = select(raw_z, 0.0, params.flatten == 1u);
    let tex_x = f32(lod_row) / max(f32(lod_x_len - 1u), 1.0);
    let tex_y = f32(lod_col) / max(f32(lod_y_len - 1u), 1.0);
    let color = sample_color(norm_z) * vec4<f32>(1.0, 1.0, 1.0, params.alpha);

    var vertex: VertexRaw;
    vertex.data[0u] = px;
    vertex.data[1u] = py;
    vertex.data[2u] = position_z;

    vertex.data[3u] = color.x;
    vertex.data[4u] = color.y;
    vertex.data[5u] = color.z;
    vertex.data[6u] = color.w;

    vertex.data[7u] = 0.0;
    vertex.data[8u] = 0.0;
    vertex.data[9u] = 1.0;

    vertex.data[10u] = tex_x;
    vertex.data[11u] = tex_y;

    out_vertices[idx] = vertex;
}
"#;