//! runmat_plot/gpu/shaders/surface.rs
//!
//! WGSL compute shader sources that expand surface-plot grid data into vertices.

/// WGSL compute shader that expands an `f32` surface grid into interleaved
/// vertices (one vertex per level-of-detail sample).
///
/// `{{WORKGROUP_SIZE}}` is a template placeholder: it must be replaced with an
/// integer literal before the source is valid WGSL.
///
/// Bindings (all in group 0):
/// - 0 `buf_x`: x coordinates, indexed by the source x index.
/// - 1 `buf_y`: y coordinates, indexed by the source y index.
/// - 2 `buf_z`: z values, read at `y_index * x_len + x_index` (x varies fastest).
/// - 3 `color_table`: RGBA colormap, linearly interpolated by normalized z.
/// - 4 `out_vertices`: one 12-float vertex per LOD sample, written at
///   `lod_y * lod_x_len + lod_x`.
/// - 5 `params`: `SurfaceParams` uniform.
///
/// Only `global_invocation_id.x` is used, so the dispatch must provide at least
/// `lod_x_len * lod_y_len` invocations along x; surplus invocations exit early.
pub const F32: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;

// One output vertex is 12 floats:
//   [0..3)   position (x, y, z), z forced to 0 when flatten == 1
//   [3..7)   RGBA color, alpha scaled by params.alpha
//   [7..10)  constant (0, 0, 1)
//   [10..12) texture coordinates in [0, 1]
struct VertexRaw {
    data: array<f32, 12u>,
};

struct SurfaceParams {
    min_z: f32,
    max_z: f32,
    alpha: f32,
    flatten: u32,
    x_len: u32,
    y_len: u32,
    lod_x_len: u32,
    lod_y_len: u32,
    x_stride: u32,
    y_stride: u32,
    color_table_len: u32,
    _pad: u32,
};

@group(0) @binding(0)
var<storage, read> buf_x: array<f32>;

@group(0) @binding(1)
var<storage, read> buf_y: array<f32>;

@group(0) @binding(2)
var<storage, read> buf_z: array<f32>;

@group(0) @binding(3)
var<storage, read> color_table: array<vec4<f32>>;

@group(0) @binding(4)
var<storage, read_write> out_vertices: array<VertexRaw>;

@group(0) @binding(5)
var<uniform> params: SurfaceParams;

// Linearly interpolate the color table at t (clamped to [0, 1]).
// A table with at most one entry always yields its first entry.
fn sample_color(t: f32) -> vec4<f32> {
    let table_len = params.color_table_len;
    if (table_len <= 1u) {
        return color_table[0u];
    }
    let clamped = clamp(t, 0.0, 1.0);
    let scaled = clamped * f32(table_len - 1u);
    let lower = u32(scaled);
    let upper = min(lower + 1u, table_len - 1u);
    let frac = scaled - f32(lower);
    return mix(color_table[lower], color_table[upper], frac);
}

@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let lod_x_len = params.lod_x_len;
    let lod_y_len = params.lod_y_len;
    let total = lod_x_len * lod_y_len;

    // One invocation per LOD sample; invocations past the end do nothing.
    let idx = gid.x;
    if (idx >= total) {
        return;
    }

    // Map the LOD sample back to a source grid cell. A stride of 0 is treated
    // as 1, and indices are clamped to the last source x / y entry.
    let stride_x = max(params.x_stride, 1u);
    let stride_y = max(params.y_stride, 1u);
    let lod_row = idx % lod_x_len;
    let lod_col = idx / lod_x_len;
    var row = lod_row * stride_x;
    var col = lod_col * stride_y;
    row = min(row, params.x_len - 1u);
    col = min(col, params.y_len - 1u);
    let source_idx = col * params.x_len + row;

    let px = buf_x[row];
    let py = buf_y[col];
    let raw_z = buf_z[source_idx];
    // Normalize z against [min_z, max_z] for the colormap; the 1e-6 floor
    // avoids dividing by zero when the surface is flat.
    let z_extent = max(params.max_z - params.min_z, 1e-6);
    let norm_z = (raw_z - params.min_z) / z_extent;

    let position_z = select(raw_z, 0.0, params.flatten == 1u);
    let tex_x = f32(lod_row) / max(f32(lod_x_len - 1u), 1.0);
    let tex_y = f32(lod_col) / max(f32(lod_y_len - 1u), 1.0);
    let color = sample_color(norm_z) * vec4<f32>(1.0, 1.0, 1.0, params.alpha);

    var vertex: VertexRaw;
    vertex.data[0u] = px;
    vertex.data[1u] = py;
    vertex.data[2u] = position_z;

    vertex.data[3u] = color.x;
    vertex.data[4u] = color.y;
    vertex.data[5u] = color.z;
    vertex.data[6u] = color.w;

    vertex.data[7u] = 0.0;
    vertex.data[8u] = 0.0;
    vertex.data[9u] = 1.0;

    vertex.data[10u] = tex_x;
    vertex.data[11u] = tex_y;

    out_vertices[idx] = vertex;
}
"#;

/// WGSL compute shader that expands an `f64` surface grid into interleaved
/// `f32` vertices (one vertex per level-of-detail sample).
///
/// The input buffers use `f64`, which is not part of core WGSL; this source
/// presumably requires the device's 64-bit float shader feature — confirm on
/// the host side before selecting this variant.
///
/// `{{WORKGROUP_SIZE}}` is a template placeholder: it must be replaced with an
/// integer literal before the source is valid WGSL.
///
/// Bindings (all in group 0):
/// - 0 `buf_x`: x coordinates (`f64`), indexed by the source x index.
/// - 1 `buf_y`: y coordinates (`f64`), indexed by the source y index.
/// - 2 `buf_z`: z values (`f64`), read at `y_index * x_len + x_index`.
/// - 3 `color_table`: RGBA colormap, linearly interpolated by normalized z.
/// - 4 `out_vertices`: one 12-float vertex per LOD sample, written at
///   `lod_y * lod_x_len + lod_x`.
/// - 5 `params`: `SurfaceParams` uniform (its z bounds are `f32`).
///
/// Only `global_invocation_id.x` is used, so the dispatch must provide at least
/// `lod_x_len * lod_y_len` invocations along x; surplus invocations exit early.
pub const F64: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;

// One output vertex is 12 floats:
//   [0..3)   position (x, y, z), z forced to 0 when flatten == 1
//   [3..7)   RGBA color, alpha scaled by params.alpha
//   [7..10)  constant (0, 0, 1)
//   [10..12) texture coordinates in [0, 1]
struct VertexRaw {
    data: array<f32, 12u>,
};

struct SurfaceParams {
    min_z: f32,
    max_z: f32,
    alpha: f32,
    flatten: u32,
    x_len: u32,
    y_len: u32,
    lod_x_len: u32,
    lod_y_len: u32,
    x_stride: u32,
    y_stride: u32,
    color_table_len: u32,
    _pad: u32,
};

@group(0) @binding(0)
var<storage, read> buf_x: array<f64>;

@group(0) @binding(1)
var<storage, read> buf_y: array<f64>;

@group(0) @binding(2)
var<storage, read> buf_z: array<f64>;

@group(0) @binding(3)
var<storage, read> color_table: array<vec4<f32>>;

@group(0) @binding(4)
var<storage, read_write> out_vertices: array<VertexRaw>;

@group(0) @binding(5)
var<uniform> params: SurfaceParams;

// Linearly interpolate the color table at t (clamped to [0, 1]).
// A table with at most one entry always yields its first entry.
fn sample_color(t: f32) -> vec4<f32> {
    let table_len = params.color_table_len;
    if (table_len <= 1u) {
        return color_table[0u];
    }
    let clamped = clamp(t, 0.0, 1.0);
    let scaled = clamped * f32(table_len - 1u);
    let lower = u32(scaled);
    let upper = min(lower + 1u, table_len - 1u);
    let frac = scaled - f32(lower);
    return mix(color_table[lower], color_table[upper], frac);
}

@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let lod_x_len = params.lod_x_len;
    let lod_y_len = params.lod_y_len;
    let total = lod_x_len * lod_y_len;

    // One invocation per LOD sample; invocations past the end do nothing.
    let idx = gid.x;
    if (idx >= total) {
        return;
    }

    // Map the LOD sample back to a source grid cell. A stride of 0 is treated
    // as 1, and indices are clamped to the last source x / y entry.
    let stride_x = max(params.x_stride, 1u);
    let stride_y = max(params.y_stride, 1u);
    let lod_row = idx % lod_x_len;
    let lod_col = idx / lod_x_len;
    var row = lod_row * stride_x;
    var col = lod_col * stride_y;
    row = min(row, params.x_len - 1u);
    col = min(col, params.y_len - 1u);
    let source_idx = col * params.x_len + row;

    let px = f32(buf_x[row]);
    let py = f32(buf_y[col]);
    let raw_z64 = buf_z[source_idx];
    let raw_z = f32(raw_z64);
    // The 1e-6 floor avoids dividing by zero when the surface is flat.
    let z_extent = max(params.max_z - params.min_z, 1e-6);
    // Take the offset from min_z in f64 before narrowing: subtracting after
    // rounding z to f32 cancels catastrophically when |z| is much larger than
    // the z range, collapsing the colormap to a few discrete levels.
    let norm_z = f32((raw_z64 - f64(params.min_z)) / f64(z_extent));

    let position_z = select(raw_z, 0.0, params.flatten == 1u);
    let tex_x = f32(lod_row) / max(f32(lod_x_len - 1u), 1.0);
    let tex_y = f32(lod_col) / max(f32(lod_y_len - 1u), 1.0);
    let color = sample_color(norm_z) * vec4<f32>(1.0, 1.0, 1.0, params.alpha);

    var vertex: VertexRaw;
    vertex.data[0u] = px;
    vertex.data[1u] = py;
    vertex.data[2u] = position_z;

    vertex.data[3u] = color.x;
    vertex.data[4u] = color.y;
    vertex.data[5u] = color.z;
    vertex.data[6u] = color.w;

    vertex.data[7u] = 0.0;
    vertex.data[8u] = 0.0;
    vertex.data[9u] = 1.0;

    vertex.data[10u] = tex_x;
    vertex.data[11u] = tex_y;

    out_vertices[idx] = vertex;
}
"#;