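//! Bar-chart compute shaders (`runmat_plot/gpu/shaders/bar.rs`).
//!
//! `F32` and `F64` hold WGSL source that expands per-row bar values into
//! triangle vertices; they differ only in the element type of the input buffer.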

/// WGSL compute shader that expands one series of `f32` bar values into
/// triangle-list vertices, six per bar (two triangles).
///
/// The source starts with a `{{WORKGROUP_SIZE}}` placeholder that must be
/// replaced by an integer literal before the shader is compiled.
///
/// Bindings (group 0):
/// - `0`: `values`, read-only; the element for series `s`, row `r` lives at
///   `s * row_count + r`.
/// - `1`: `out_vertices`; row `r` writes vertices `r * 6 ..= r * 6 + 5`.
/// - `2`: `BarParams` uniform.
///
/// One invocation handles one row; invocations with `idx >= row_count` exit
/// without writing. Non-finite values (inf/NaN) produce a zero-height bar at 0
/// and are skipped when accumulating stacked bases.
pub const F32: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;

const VERTICES_PER_BAR: u32 = 6u;

// 12 floats per vertex: position xyz, color rgba, then five trailing floats
// that are always written as (0, 0, 1, 0, 0).
struct VertexRaw {
    data: array<f32, 12u>,
};

// Field order and types must match the uniform buffer written by the host.
struct BarParams {
    color: vec4<f32>,
    bar_width: f32,
    row_count: u32,
    series_index: u32,
    series_count: u32,
    group_index: u32,
    group_count: u32,
    orientation: u32, // 0 = vertical bars, anything else = horizontal bars
    layout: u32,      // 1 = stacked on earlier series, anything else = from zero
};

@group(0) @binding(0)
var<storage, read> values: array<f32>;

@group(0) @binding(1)
var<storage, read_write> out_vertices: array<VertexRaw>;

@group(0) @binding(2)
var<uniform> params: BarParams;

// Pack a position together with the uniform bar color into one vertex.
fn encode_vertex(position: vec3<f32>) -> VertexRaw {
    var vertex: VertexRaw;
    vertex.data[0u] = position.x;
    vertex.data[1u] = position.y;
    vertex.data[2u] = position.z;
    vertex.data[3u] = params.color.x;
    vertex.data[4u] = params.color.y;
    vertex.data[5u] = params.color.z;
    vertex.data[6u] = params.color.w;
    vertex.data[7u] = 0.0;
    vertex.data[8u] = 0.0;
    vertex.data[9u] = 1.0;
    vertex.data[10u] = 0.0;
    vertex.data[11u] = 0.0;
    return vertex;
}

// Emit the quad as two triangles: corners (0, 1, 2) and (0, 2, 3).
fn write_vertices(base_index: u32, quad: array<vec3<f32>, 4u>) {
    out_vertices[base_index + 0u] = encode_vertex(quad[0u]);
    out_vertices[base_index + 1u] = encode_vertex(quad[1u]);
    out_vertices[base_index + 2u] = encode_vertex(quad[2u]);
    out_vertices[base_index + 3u] = encode_vertex(quad[0u]);
    out_vertices[base_index + 4u] = encode_vertex(quad[2u]);
    out_vertices[base_index + 5u] = encode_vertex(quad[3u]);
}

// Vertical bar: width along x, centered at (idx + 1) + local_offset;
// height along y from min(start, end) to max(start, end).
fn build_vertical_quad(idx: u32, start: f32, end: f32, per_group_width: f32, local_offset: f32) -> array<vec3<f32>, 4u> {
    let center = (f32(idx) + 1.0) + local_offset;
    let half_width = per_group_width * 0.5;
    let left = center - half_width;
    let right = center + half_width;
    let bottom = min(start, end);
    let top = max(start, end);

    return array<vec3<f32>, 4u>(
        vec3<f32>(left, bottom, 0.0),
        vec3<f32>(right, bottom, 0.0),
        vec3<f32>(right, top, 0.0),
        vec3<f32>(left, top, 0.0),
    );
}

// Horizontal bar: thickness along y, centered at (idx + 1) + local_offset;
// length along x from min(start, end) to max(start, end).
fn build_horizontal_quad(idx: u32, start: f32, end: f32, per_group_width: f32, local_offset: f32) -> array<vec3<f32>, 4u> {
    let center = (f32(idx) + 1.0) + local_offset;
    let half_width = per_group_width * 0.5;
    let bottom = center - half_width;
    let top = center + half_width;
    let left = min(start, end);
    let right = max(start, end);

    return array<vec3<f32>, 4u>(
        vec3<f32>(left, bottom, 0.0),
        vec3<f32>(right, bottom, 0.0),
        vec3<f32>(right, top, 0.0),
        vec3<f32>(left, top, 0.0),
    );
}

// WGSL has no isFinite builtin. Test the IEEE-754 exponent bits instead:
// an all-ones exponent means the value is +/-inf or NaN.
fn is_finite_f32(x: f32) -> bool {
    return (bitcast<u32>(x) & 0x7f800000u) != 0x7f800000u;
}

@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.row_count) {
        return;
    }

    // Split bar_width evenly between groups and center this group's slot
    // within the bar's footprint.
    let safe_group_count = max(params.group_count, 1u);
    let per_group_width = max(params.bar_width / f32(safe_group_count), 0.01);
    let group_offset_start = -params.bar_width * 0.5;
    let local_offset = group_offset_start
        + per_group_width * f32(min(params.group_index, safe_group_count - 1u))
        + per_group_width * 0.5;
    let stride = params.row_count;
    let column_offset = params.series_index * stride;
    let value = values[column_offset + idx];

    // Stacked layout: sum the finite values of earlier series at this row.
    // Positive and negative contributions go on separate stacks.
    var base_pos: f32 = 0.0;
    var base_neg: f32 = 0.0;
    if (params.layout == 1u) {
        for (var col: u32 = 0u; col < params.series_index; col += 1u) {
            let prev = values[col * stride + idx];
            if (is_finite_f32(prev)) {
                if (prev >= 0.0) {
                    base_pos += prev;
                } else {
                    base_neg += prev;
                }
            }
        }
    }

    // Non-finite values leave start == end == 0: a zero-height bar at 0.
    var start: f32 = 0.0;
    var end: f32 = 0.0;
    if (is_finite_f32(value)) {
        if (params.layout == 1u) {
            if (value >= 0.0) {
                start = base_pos;
                end = base_pos + value;
            } else {
                start = base_neg + value;
                end = base_neg;
            }
        } else {
            end = value;
        }
    }

    // WGSL has no if-expressions, so pick the quad with a statement.
    var quad: array<vec3<f32>, 4u>;
    if (params.orientation == 0u) {
        quad = build_vertical_quad(idx, start, end, per_group_width, local_offset);
    } else {
        quad = build_horizontal_quad(idx, start, end, per_group_width, local_offset);
    }

    write_vertices(idx * VERTICES_PER_BAR, quad);
}
"#;

/// WGSL compute shader that expands one series of `f64` bar values into
/// triangle-list vertices, six per bar (two triangles).
///
/// Each input element is converted to `f32` before the finiteness check,
/// stacking, and vertex generation, so all geometry is produced in `f32`.
/// `f64` storage is not part of core WGSL; this shader presumably relies on the
/// device's shader-f64 support — confirm against the pipeline setup.
///
/// The source starts with a `{{WORKGROUP_SIZE}}` placeholder that must be
/// replaced by an integer literal before the shader is compiled.
///
/// Bindings (group 0):
/// - `0`: `values`, read-only; the element for series `s`, row `r` lives at
///   `s * row_count + r`.
/// - `1`: `out_vertices`; row `r` writes vertices `r * 6 ..= r * 6 + 5`.
/// - `2`: `BarParams` uniform.
///
/// One invocation handles one row; invocations with `idx >= row_count` exit
/// without writing. Non-finite values (inf/NaN after conversion) produce a
/// zero-height bar at 0 and are skipped when accumulating stacked bases.
pub const F64: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;

const VERTICES_PER_BAR: u32 = 6u;

// 12 floats per vertex: position xyz, color rgba, then five trailing floats
// that are always written as (0, 0, 1, 0, 0).
struct VertexRaw {
    data: array<f32, 12u>,
};

// Field order and types must match the uniform buffer written by the host.
struct BarParams {
    color: vec4<f32>,
    bar_width: f32,
    row_count: u32,
    series_index: u32,
    series_count: u32,
    group_index: u32,
    group_count: u32,
    orientation: u32, // 0 = vertical bars, anything else = horizontal bars
    layout: u32,      // 1 = stacked on earlier series, anything else = from zero
};

@group(0) @binding(0)
var<storage, read> values: array<f64>;

@group(0) @binding(1)
var<storage, read_write> out_vertices: array<VertexRaw>;

@group(0) @binding(2)
var<uniform> params: BarParams;

// Pack a position together with the uniform bar color into one vertex.
fn encode_vertex(position: vec3<f32>) -> VertexRaw {
    var vertex: VertexRaw;
    vertex.data[0u] = position.x;
    vertex.data[1u] = position.y;
    vertex.data[2u] = position.z;
    vertex.data[3u] = params.color.x;
    vertex.data[4u] = params.color.y;
    vertex.data[5u] = params.color.z;
    vertex.data[6u] = params.color.w;
    vertex.data[7u] = 0.0;
    vertex.data[8u] = 0.0;
    vertex.data[9u] = 1.0;
    vertex.data[10u] = 0.0;
    vertex.data[11u] = 0.0;
    return vertex;
}

// Emit the quad as two triangles: corners (0, 1, 2) and (0, 2, 3).
fn write_vertices(base_index: u32, quad: array<vec3<f32>, 4u>) {
    out_vertices[base_index + 0u] = encode_vertex(quad[0u]);
    out_vertices[base_index + 1u] = encode_vertex(quad[1u]);
    out_vertices[base_index + 2u] = encode_vertex(quad[2u]);
    out_vertices[base_index + 3u] = encode_vertex(quad[0u]);
    out_vertices[base_index + 4u] = encode_vertex(quad[2u]);
    out_vertices[base_index + 5u] = encode_vertex(quad[3u]);
}

// Vertical bar: width along x, centered at (idx + 1) + local_offset;
// height along y from min(start, end) to max(start, end).
fn build_vertical_quad(idx: u32, start: f32, end: f32, per_group_width: f32, local_offset: f32) -> array<vec3<f32>, 4u> {
    let center = (f32(idx) + 1.0) + local_offset;
    let half_width = per_group_width * 0.5;
    let left = center - half_width;
    let right = center + half_width;
    let bottom = min(start, end);
    let top = max(start, end);

    return array<vec3<f32>, 4u>(
        vec3<f32>(left, bottom, 0.0),
        vec3<f32>(right, bottom, 0.0),
        vec3<f32>(right, top, 0.0),
        vec3<f32>(left, top, 0.0),
    );
}

// Horizontal bar: thickness along y, centered at (idx + 1) + local_offset;
// length along x from min(start, end) to max(start, end).
fn build_horizontal_quad(idx: u32, start: f32, end: f32, per_group_width: f32, local_offset: f32) -> array<vec3<f32>, 4u> {
    let center = (f32(idx) + 1.0) + local_offset;
    let half_width = per_group_width * 0.5;
    let bottom = center - half_width;
    let top = center + half_width;
    let left = min(start, end);
    let right = max(start, end);

    return array<vec3<f32>, 4u>(
        vec3<f32>(left, bottom, 0.0),
        vec3<f32>(right, bottom, 0.0),
        vec3<f32>(right, top, 0.0),
        vec3<f32>(left, top, 0.0),
    );
}

// WGSL has no isFinite builtin. Test the IEEE-754 exponent bits instead:
// an all-ones exponent means the value is +/-inf or NaN.
fn is_finite_f32(x: f32) -> bool {
    return (bitcast<u32>(x) & 0x7f800000u) != 0x7f800000u;
}

@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.row_count) {
        return;
    }

    // Split bar_width evenly between groups and center this group's slot
    // within the bar's footprint.
    let safe_group_count = max(params.group_count, 1u);
    let per_group_width = max(params.bar_width / f32(safe_group_count), 0.01);
    let group_offset_start = -params.bar_width * 0.5;
    let local_offset = group_offset_start
        + per_group_width * f32(min(params.group_index, safe_group_count - 1u))
        + per_group_width * 0.5;
    let stride = params.row_count;
    let column_offset = params.series_index * stride;
    let value = f32(values[column_offset + idx]);

    // Stacked layout: sum the finite values of earlier series at this row.
    // Positive and negative contributions go on separate stacks.
    var base_pos: f32 = 0.0;
    var base_neg: f32 = 0.0;
    if (params.layout == 1u) {
        for (var col: u32 = 0u; col < params.series_index; col += 1u) {
            let prev = f32(values[col * stride + idx]);
            if (is_finite_f32(prev)) {
                if (prev >= 0.0) {
                    base_pos += prev;
                } else {
                    base_neg += prev;
                }
            }
        }
    }

    // Non-finite values leave start == end == 0: a zero-height bar at 0.
    var start: f32 = 0.0;
    var end: f32 = 0.0;
    if (is_finite_f32(value)) {
        if (params.layout == 1u) {
            if (value >= 0.0) {
                start = base_pos;
                end = base_pos + value;
            } else {
                start = base_neg + value;
                end = base_neg;
            }
        } else {
            end = value;
        }
    }

    // WGSL has no if-expressions, so pick the quad with a statement.
    var quad: array<vec3<f32>, 4u>;
    if (params.orientation == 0u) {
        quad = build_vertical_quad(idx, start, end, per_group_width, local_offset);
    } else {
        quad = build_horizontal_quad(idx, start, end, per_group_width, local_offset);
    }

    write_vertices(idx * VERTICES_PER_BAR, quad);
}
"#;