// trueno/backends/gpu/shaders/reductions.rs
//
// WGSL compute shaders for GPU reductions (max, sum) and the element-wise
// passes used to assemble softmax / log-softmax.

/// WGSL shader: per-workgroup maximum of `input`.
///
/// Each 256-invocation workgroup loads up to 256 consecutive elements into
/// workgroup memory, tree-reduces them with `max`, and invocation 0 writes the
/// workgroup's maximum to `result[global_id.x / 256]`. The output is therefore
/// one partial maximum per workgroup; combining the partials is not done here.
///
/// Bindings (group 0):
/// - `0`: `input`, read-only storage `array<f32>`; `arrayLength(&input)` bounds the work.
/// - `1`: `result`, read-write storage `array<f32>`; needs one slot per dispatched workgroup.
///
/// Invocations past the end of `input` contribute `-FLT_MAX` (the most negative
/// finite `f32`) as the identity, because WGSL has no infinity literal. As a
/// consequence, a partially filled workgroup whose real elements are all `-inf`
/// reports `-FLT_MAX` instead of `-inf`. Only the x component of the invocation
/// id is used, so the shader presumably expects a 1-D dispatch along x.
pub(crate) const MAX_REDUCTION_SHADER: &str = r#"
// Max reduction: one partial maximum per 256-element workgroup.
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> result: array<f32>;

var<workgroup> partial_max: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
) {
    let idx = global_id.x;
    let local_idx = local_id.x;
    let len = arrayLength(&input);

    // Load value, or -FLT_MAX (most negative finite f32) for out-of-range lanes
    // so padding never wins the max. WGSL cannot spell -inf as a literal.
    var max_val: f32 = -3.402823466e+38; // -FLT_MAX
    if (idx < len) {
        max_val = input[idx];
    }
    partial_max[local_idx] = max_val;

    workgroupBarrier();

    // Parallel reduction within workgroup (find max). The barrier sits outside
    // the `if` so every invocation reaches it in uniform control flow.
    var stride: u32 = 128u;
    while (stride > 0u) {
        if (local_idx < stride) {
            partial_max[local_idx] = max(partial_max[local_idx], partial_max[local_idx + stride]);
        }
        stride = stride / 2u;
        workgroupBarrier();
    }

    // First thread writes workgroup result
    if (local_idx == 0u) {
        result[global_id.x / 256u] = partial_max[0];
    }
}
"#;

/// WGSL shader: per-workgroup sum of `input`.
///
/// Each 256-invocation workgroup loads up to 256 consecutive elements into
/// workgroup memory, tree-reduces them with `+`, and invocation 0 writes the
/// workgroup's sum to `result[global_id.x / 256]`. The output is one partial
/// sum per workgroup; combining the partials is not done here.
///
/// Bindings (group 0):
/// - `0`: `input`, read-only storage `array<f32>`; `arrayLength(&input)` bounds the work.
/// - `1`: `result`, read-write storage `array<f32>`; needs one slot per dispatched workgroup.
///
/// Invocations past the end of `input` contribute `0.0`, the additive identity.
/// Only the x component of the invocation id is used, so the shader presumably
/// expects a 1-D dispatch along x.
pub(crate) const SUM_REDUCTION_SHADER: &str = r#"
// Sum reduction: one partial sum per 256-element workgroup.
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> result: array<f32>;

var<workgroup> partial_sums: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
) {
    let idx = global_id.x;
    let local_idx = local_id.x;
    let len = arrayLength(&input);

    // Load value (out-of-range lanes contribute 0.0, the additive identity)
    var sum: f32 = 0.0;
    if (idx < len) {
        sum = input[idx];
    }
    partial_sums[local_idx] = sum;

    workgroupBarrier();

    // Parallel reduction within workgroup. The barrier sits outside the `if`
    // so every invocation reaches it in uniform control flow.
    var stride: u32 = 128u;
    while (stride > 0u) {
        if (local_idx < stride) {
            partial_sums[local_idx] = partial_sums[local_idx] + partial_sums[local_idx + stride];
        }
        stride = stride / 2u;
        workgroupBarrier();
    }

    // First thread writes workgroup result
    if (local_idx == 0u) {
        result[global_id.x / 256u] = partial_sums[0];
    }
}
"#;

/// WGSL shader: element-wise `output[i] = exp(input[i] - max_val)`.
///
/// When `max_val` is the true maximum of `input`, every exponent is `<= 0`, so
/// `exp` cannot overflow (the standard softmax stability trick). `max_val` is
/// supplied by the host through the uniform at binding 2. It is presumably the
/// result of `MAX_REDUCTION_SHADER`; confirm against the dispatch code.
///
/// Bindings (group 0):
/// - `0`: `input`, read-only storage `array<f32>`; `arrayLength(&input)` bounds the work.
/// - `1`: `output`, read-write storage `array<f32>`; written at every index `< arrayLength(&input)`.
/// - `2`: uniform `MaxValue { max_val: f32 }`.
///
/// One element per invocation, 256 invocations per workgroup; invocations past
/// the end of `input` do nothing.
pub(crate) const SOFTMAX_EXP_SHADER: &str = r#"
// Softmax pass 1: shifted exponentials exp(x - max).
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct MaxValue {
    max_val: f32,
}

@group(0) @binding(2) var<uniform> params: MaxValue;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&input);

    if (idx < len) {
        // exp(x - max) for numerical stability
        output[idx] = exp(input[idx] - params.max_val);
    }
}
"#;

/// WGSL shader: element-wise `output[i] = input[i] / sum_val`.
///
/// `sum_val` is supplied by the host through the uniform at binding 2. It is
/// presumably the total of the shifted exponentials produced by
/// `SOFTMAX_EXP_SHADER` (reduced with `SUM_REDUCTION_SHADER`); confirm against
/// the dispatch code. No guard against `sum_val == 0.0` is applied here.
///
/// Bindings (group 0):
/// - `0`: `input`, read-only storage `array<f32>`; `arrayLength(&input)` bounds the work.
/// - `1`: `output`, read-write storage `array<f32>`; written at every index `< arrayLength(&input)`.
/// - `2`: uniform `SumValue { sum_val: f32 }`.
///
/// One element per invocation, 256 invocations per workgroup; invocations past
/// the end of `input` do nothing.
pub(crate) const SOFTMAX_NORMALIZE_SHADER: &str = r#"
// Softmax pass 2: divide each shifted exponential by their sum.
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct SumValue {
    sum_val: f32,
}

@group(0) @binding(2) var<uniform> params: SumValue;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&input);

    if (idx < len) {
        // Normalize by sum
        output[idx] = input[idx] / params.sum_val;
    }
}
"#;

/// WGSL shader: element-wise `output[i] = input[i] - max_val - log_sum_exp`.
///
/// This is the final log-softmax pass. The host supplies both terms through
/// the uniform at binding 2:
/// - `max_val`: presumably `max(input)`;
/// - `log_sum_exp`: presumably `log(sum(exp(input - max_val)))`.
///
/// The shader only subtracts the two terms, so their correctness is the
/// caller's responsibility; confirm how the host computes them.
///
/// Bindings (group 0):
/// - `0`: `input`, read-only storage `array<f32>`; `arrayLength(&input)` bounds the work.
/// - `1`: `output`, read-write storage `array<f32>`; written at every index `< arrayLength(&input)`.
/// - `2`: uniform `LogSoftmaxParams { max_val: f32, log_sum_exp: f32 }`.
///
/// One element per invocation, 256 invocations per workgroup; invocations past
/// the end of `input` do nothing.
pub(crate) const LOG_SOFTMAX_SHADER: &str = r#"
// Log-softmax: x - max - log(sum(exp(x - max))), with both terms precomputed.
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct LogSoftmaxParams {
    max_val: f32,
    log_sum_exp: f32,
}

@group(0) @binding(2) var<uniform> params: LogSoftmaxParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&input);

    if (idx < len) {
        // log_softmax(x)[i] = x[i] - max - log(sum(exp(x - max)))
        output[idx] = input[idx] - params.max_val - params.log_sum_exp;
    }
}
"#;