Skip to main content

runmat_plot/gpu/shaders/
stairs.rs

/// WGSL compute shader that expands `f32` stairs-plot samples into line vertices.
///
/// Invocation `seg` reads points `seg` and `seg + 1` from `buf_x` / `buf_y` and writes
/// four vertices to `out_vertices[seg * 4 .. seg * 4 + 4]`:
/// `(x0, y0), (x1, y0), (x1, y0), (x1, y1)`. That is a horizontal step followed by a
/// vertical riser, emitted as vertex pairs (presumably drawn with a line-list
/// topology — confirm against the render pipeline).
///
/// Invocations with `seg + 1 >= point_count` return without writing. The output
/// buffer therefore needs room for `4 * (point_count - 1)` vertices.
///
/// The `{{WORKGROUP_SIZE}}` placeholder is not valid WGSL. The host must substitute
/// it before compiling the module.
///
/// `StairsParams` is padded with three scalar `u32`s rather than a `vec3<u32>`.
/// A `vec3<u32>` member is 16-byte aligned, which would push it to offset 32 and
/// grow the struct to 48 bytes. With scalars the struct is 32 bytes (`color` at 0,
/// `point_count` at 16), so any host uniform buffer of at least 32 bytes binds.
pub const F32: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;

// One output vertex, packed as 12 f32 values.
struct VertexRaw {
    data: array<f32, 12u>,
};

// Scalar padding keeps this struct at 32 bytes; a vec3<u32> member would be
// 16-byte aligned and grow it to 48 bytes.
struct StairsParams {
    color: vec4<f32>,
    point_count: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
};

@group(0) @binding(0)
var<storage, read> buf_x: array<f32>;

@group(0) @binding(1)
var<storage, read> buf_y: array<f32>;

@group(0) @binding(2)
var<storage, read_write> out_vertices: array<VertexRaw>;

@group(0) @binding(3)
var<uniform> params: StairsParams;

// Layout: [0..3) position (z = 0), [3..7) params.color, then 0, 0, 1, 0, 0.
// The trailing five floats presumably encode a normal and texture coordinates;
// they must match the host-side vertex layout.
fn encode_vertex(px: f32, py: f32) -> VertexRaw {
    var vertex: VertexRaw;
    vertex.data[0u] = px;
    vertex.data[1u] = py;
    vertex.data[2u] = 0.0;
    vertex.data[3u] = params.color.x;
    vertex.data[4u] = params.color.y;
    vertex.data[5u] = params.color.z;
    vertex.data[6u] = params.color.w;
    vertex.data[7u] = 0.0;
    vertex.data[8u] = 0.0;
    vertex.data[9u] = 1.0;
    vertex.data[10u] = 0.0;
    vertex.data[11u] = 0.0;
    return vertex;
}

// One invocation per segment between consecutive points; emits the horizontal
// step (x0,y0)->(x1,y0) and the vertical riser (x1,y0)->(x1,y1).
@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let seg = gid.x;
    if (seg + 1u >= params.point_count) {
        return;
    }
    let base = seg * 4u;
    let x0 = buf_x[seg];
    let y0 = buf_y[seg];
    let x1 = buf_x[seg + 1u];
    let y1 = buf_y[seg + 1u];

    out_vertices[base + 0u] = encode_vertex(x0, y0);
    out_vertices[base + 1u] = encode_vertex(x1, y0);
    out_vertices[base + 2u] = encode_vertex(x1, y0);
    out_vertices[base + 3u] = encode_vertex(x1, y1);
}
"#;
60
/// WGSL compute shader that expands `f64` stairs-plot samples into line vertices.
///
/// It is identical to the `f32` variant except that `buf_x` / `buf_y` hold `f64`.
/// Each sample is narrowed with `f32(...)` before being written, so output
/// positions carry only `f32` precision.
///
/// NOTE(review): `array<f64>` is not core WGSL. This relies on the backend
/// accepting 64-bit floats (e.g. a wgpu `SHADER_F64`-style device feature).
/// Confirm the feature is checked before this variant is selected.
///
/// Invocation `seg` reads points `seg` and `seg + 1` and writes four vertices to
/// `out_vertices[seg * 4 .. seg * 4 + 4]`: `(x0, y0), (x1, y0), (x1, y0), (x1, y1)`.
/// These are presumably consumed as a line list. Invocations with
/// `seg + 1 >= point_count` write nothing, so the output buffer needs room for
/// `4 * (point_count - 1)` vertices.
///
/// The `{{WORKGROUP_SIZE}}` placeholder must be substituted by the host before
/// compiling.
///
/// `StairsParams` uses three scalar `u32` pads instead of a 16-byte-aligned
/// `vec3<u32>`. This keeps the uniform struct at 32 bytes (`color` at 0,
/// `point_count` at 16) rather than 48.
pub const F64: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;

// One output vertex, packed as 12 f32 values.
struct VertexRaw {
    data: array<f32, 12u>,
};

// Scalar padding keeps this struct at 32 bytes; a vec3<u32> member would be
// 16-byte aligned and grow it to 48 bytes.
struct StairsParams {
    color: vec4<f32>,
    point_count: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
};

@group(0) @binding(0)
var<storage, read> buf_x: array<f64>;

@group(0) @binding(1)
var<storage, read> buf_y: array<f64>;

@group(0) @binding(2)
var<storage, read_write> out_vertices: array<VertexRaw>;

@group(0) @binding(3)
var<uniform> params: StairsParams;

// Layout: [0..3) position (z = 0), [3..7) params.color, then 0, 0, 1, 0, 0.
// The trailing five floats presumably encode a normal and texture coordinates;
// they must match the host-side vertex layout.
fn encode_vertex(px: f32, py: f32) -> VertexRaw {
    var vertex: VertexRaw;
    vertex.data[0u] = px;
    vertex.data[1u] = py;
    vertex.data[2u] = 0.0;
    vertex.data[3u] = params.color.x;
    vertex.data[4u] = params.color.y;
    vertex.data[5u] = params.color.z;
    vertex.data[6u] = params.color.w;
    vertex.data[7u] = 0.0;
    vertex.data[8u] = 0.0;
    vertex.data[9u] = 1.0;
    vertex.data[10u] = 0.0;
    vertex.data[11u] = 0.0;
    return vertex;
}

// One invocation per segment between consecutive points; emits the horizontal
// step (x0,y0)->(x1,y0) and the vertical riser (x1,y0)->(x1,y1).
@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let seg = gid.x;
    if (seg + 1u >= params.point_count) {
        return;
    }
    let base = seg * 4u;
    // Narrow to f32: the vertex format stores f32 positions.
    let x0 = f32(buf_x[seg]);
    let y0 = f32(buf_y[seg]);
    let x1 = f32(buf_x[seg + 1u]);
    let y1 = f32(buf_y[seg + 1u]);

    out_vertices[base + 0u] = encode_vertex(x0, y0);
    out_vertices[base + 1u] = encode_vertex(x1, y0);
    out_vertices[base + 2u] = encode_vertex(x1, y0);
    out_vertices[base + 3u] = encode_vertex(x1, y1);
}
"#;