// trueno/backends/gpu/shaders/reductions.rs

//! Reduction and softmax WGSL compute shaders.

/// Max reduction compute shader (WGSL).
///
/// Computes max(input) using a parallel reduction; used as the first pass in
/// softmax to ensure numerical stability.
///
/// # Bindings (group 0)
/// - binding 0: `input`, read-only storage `array<f32>`.
/// - binding 1: `result`, read-write storage `array<f32>`.
///
/// # Behavior
/// - Runs with `@workgroup_size(256)`; each invocation loads `input[global_id.x]`
///   into a 256-entry `var<workgroup>` scratch array.
/// - Invocations whose index is at or beyond `arrayLength(&input)` contribute the
///   literal `-3.402823466e+38` (annotated `-FLT_MAX` in the shader), not a true
///   infinity, despite the in-shader comment wording.
/// - The scratch array is folded with `max` using strides 128, 64, ..., 1, with a
///   `workgroupBarrier()` after every step.
/// - Local invocation 0 writes the workgroup's maximum to
///   `result[global_id.x / 256u]`, so `result` holds one partial maximum per
///   workgroup rather than a single scalar.
pub(crate) const MAX_REDUCTION_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> result: array<f32>;

var<workgroup> partial_max: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
) {
    let idx = global_id.x;
    let local_idx = local_id.x;
    let len = arrayLength(&input);

    // Load value or negative infinity
    var max_val: f32 = -3.402823466e+38; // -FLT_MAX
    if (idx < len) {
        max_val = input[idx];
    }
    partial_max[local_idx] = max_val;

    workgroupBarrier();

    // Parallel reduction within workgroup (find max)
    var stride: u32 = 128u;
    while (stride > 0u) {
        if (local_idx < stride) {
            partial_max[local_idx] = max(partial_max[local_idx], partial_max[local_idx + stride]);
        }
        stride = stride / 2u;
        workgroupBarrier();
    }

    // First thread writes workgroup result
    if (local_idx == 0u) {
        result[global_id.x / 256u] = partial_max[0];
    }
}
"#;

/// Sum reduction compute shader (WGSL).
///
/// Computes sum(input) using a parallel reduction; used in softmax to sum the
/// exp values.
///
/// # Bindings (group 0)
/// - binding 0: `input`, read-only storage `array<f32>`.
/// - binding 1: `result`, read-write storage `array<f32>`.
///
/// # Behavior
/// - Runs with `@workgroup_size(256)`; each invocation loads `input[global_id.x]`
///   into a 256-entry `var<workgroup>` scratch array.
/// - Invocations whose index is at or beyond `arrayLength(&input)` contribute
///   `0.0`, so they do not affect the total.
/// - The scratch array is folded by addition using strides 128, 64, ..., 1, with
///   a `workgroupBarrier()` after every step.
/// - Local invocation 0 writes the workgroup's total to
///   `result[global_id.x / 256u]`, so `result` holds one partial sum per
///   workgroup rather than a single scalar.
pub(crate) const SUM_REDUCTION_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> result: array<f32>;

var<workgroup> partial_sums: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
) {
    let idx = global_id.x;
    let local_idx = local_id.x;
    let len = arrayLength(&input);

    // Load value
    var sum: f32 = 0.0;
    if (idx < len) {
        sum = input[idx];
    }
    partial_sums[local_idx] = sum;

    workgroupBarrier();

    // Parallel reduction within workgroup
    var stride: u32 = 128u;
    while (stride > 0u) {
        if (local_idx < stride) {
            partial_sums[local_idx] = partial_sums[local_idx] + partial_sums[local_idx + stride];
        }
        stride = stride / 2u;
        workgroupBarrier();
    }

    // First thread writes workgroup result
    if (local_idx == 0u) {
        result[global_id.x / 256u] = partial_sums[0];
    }
}
"#;

/// Softmax exp-subtract compute shader (WGSL).
///
/// Computes `exp(input[i] - max_val)` for each element. Second pass in softmax:
/// subtracting the maximum before exponentiating keeps the computation
/// numerically stable.
///
/// # Bindings (group 0)
/// - binding 0: `input`, read-only storage `array<f32>`.
/// - binding 1: `output`, read-write storage `array<f32>`.
/// - binding 2: `params`, uniform `MaxValue { max_val: f32 }`.
///
/// # Behavior
/// Runs with `@workgroup_size(256)`. Each invocation whose `global_id.x` is below
/// `arrayLength(&input)` writes one element of `output`; other invocations do
/// nothing.
pub(crate) const SOFTMAX_EXP_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct MaxValue {
    max_val: f32,
}

@group(0) @binding(2) var<uniform> params: MaxValue;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&input);

    if (idx < len) {
        // exp(x - max) for numerical stability
        output[idx] = exp(input[idx] - params.max_val);
    }
}
"#;

/// Softmax normalize compute shader (WGSL).
///
/// Computes `output[i] = input[i] / sum_val` for each element. Fourth pass in
/// softmax: normalize by the sum of the exp values.
///
/// # Bindings (group 0)
/// - binding 0: `input`, read-only storage `array<f32>`.
/// - binding 1: `output`, read-write storage `array<f32>`.
/// - binding 2: `params`, uniform `SumValue { sum_val: f32 }`.
///
/// # Behavior
/// Runs with `@workgroup_size(256)`. Each invocation whose `global_id.x` is below
/// `arrayLength(&input)` writes one element of `output`; other invocations do
/// nothing. The division is unconditional: the shader does not check `sum_val`
/// for zero.
pub(crate) const SOFTMAX_NORMALIZE_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct SumValue {
    sum_val: f32,
}

@group(0) @binding(2) var<uniform> params: SumValue;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&input);

    if (idx < len) {
        // Normalize by sum
        output[idx] = input[idx] / params.sum_val;
    }
}
"#;

/// Log-softmax compute shader (WGSL).
///
/// Computes `output[i] = input[i] - max_val - log_sum_exp` for each element.
/// Numerically stable log-softmax in a single pass after the reductions.
///
/// # Bindings (group 0)
/// - binding 0: `input`, read-only storage `array<f32>`.
/// - binding 1: `output`, read-write storage `array<f32>`.
/// - binding 2: `params`, uniform `LogSoftmaxParams { max_val: f32, log_sum_exp: f32 }`.
///
/// # Behavior
/// Runs with `@workgroup_size(256)`. Each invocation whose `global_id.x` is below
/// `arrayLength(&input)` writes one element of `output`; other invocations do
/// nothing. The shader subtracts `log_sum_exp` as given and takes no logarithm
/// itself; per the in-shader comment the value is expected to already be
/// `log(sum(exp(x - max)))` (NOTE(review): confirm against the caller).
pub(crate) const LOG_SOFTMAX_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct LogSoftmaxParams {
    max_val: f32,
    log_sum_exp: f32,
}

@group(0) @binding(2) var<uniform> params: LogSoftmaxParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&input);

    if (idx < len) {
        // log_softmax(x)[i] = x[i] - max - log(sum(exp(x - max)))
        output[idx] = input[idx] - params.max_val - params.log_sum_exp;
    }
}
"#;