// gllm_kernels/ops/softmax.rs

1//! Log-space softmax computation for numerical stability.
2
3use burn::tensor::backend::Backend;
4use burn::tensor::Tensor;
5
/// Compute `log(exp(a) + exp(b))` in a numerically stable way.
///
/// Shifts by the larger argument so neither `exp` can overflow, and uses
/// [`f64::ln_1p`] so accuracy is preserved when the arguments are far apart
/// (where `exp(min - max)` is tiny and `ln(1.0 + x)` would round `1.0 + x`
/// to `1.0`).
///
/// A `-inf` input represents `exp(x) = 0` and is short-circuited, which also
/// avoids the `-inf - -inf = NaN` intermediate when both inputs are `-inf`.
#[inline]
pub fn log_add_exp(a: f64, b: f64) -> f64 {
    if a == f64::NEG_INFINITY {
        return b;
    }
    if b == f64::NEG_INFINITY {
        return a;
    }

    let max = a.max(b);
    let min = a.min(b);

    // ln_1p(exp(d)) is more accurate than ln(1 + exp(d)) for d << 0.
    max + (min - max).exp().ln_1p()
}
21
/// Numerically stable `log(sum(exp(x)))` over a slice.
///
/// Returns `-inf` for an empty slice (the log of an empty sum).
#[inline]
pub fn log_sum_exp(values: &[f64]) -> f64 {
    if values.is_empty() {
        return f64::NEG_INFINITY;
    }

    // Shift by the peak value so no exponential can overflow.
    let peak = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);

    // Peak is -inf (all terms are log(0)) or +inf (one term dominates
    // everything): the peak itself is the answer either way.
    if peak.is_infinite() {
        return peak;
    }

    let total: f64 = values.iter().map(|&v| (v - peak).exp()).sum();
    peak + total.ln()
}
39
40/// Compute log(sum(exp(x))) using Kahan summation for the exp sum.
41pub fn log_sum_exp_kahan(values: &[f64]) -> f64 {
42    use super::stable_accumulator::KahanAccumulator;
43
44    if values.is_empty() {
45        return f64::NEG_INFINITY;
46    }
47
48    let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
49
50    if max.is_infinite() {
51        return max;
52    }
53
54    let mut sum = KahanAccumulator::<f64>::new();
55    for &x in values {
56        sum.add((x - max).exp());
57    }
58
59    max + sum.value().ln()
60}
61
/// Log-space softmax accumulator for online (streaming) computation.
///
/// Keeps the running maximum and the running sum entirely in log-space, so
/// arbitrarily long score streams cannot overflow or underflow a linear sum.
/// Blocks are folded in via `update`/`update_raw`, and independently built
/// accumulators can be combined with `merge`.
#[derive(Debug, Clone)]
pub struct LogSpaceSoftmax {
    /// Current maximum score seen so far (`-inf` when empty; see `new`).
    m: f64,
    /// Log of the running shifted sum: `log(sum(exp(scores - m)))`
    /// (`-inf` when empty, i.e. log of a zero sum).
    log_l: f64,
    /// Number of blocks folded in via `update`/`merge`.
    count: usize,
}
72
73impl Default for LogSpaceSoftmax {
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79impl LogSpaceSoftmax {
80    /// Create a new log-space softmax accumulator.
81    pub fn new() -> Self {
82        Self {
83            m: f64::NEG_INFINITY,
84            log_l: f64::NEG_INFINITY,
85            count: 0,
86        }
87    }
88
89    /// Update with a new block of scores.
90    pub fn update(&mut self, block_max: f64, block_log_sum_exp: f64) -> f64 {
91        let m_new = self.m.max(block_max);
92
93        let prev_contrib = (self.m - m_new) + self.log_l;
94        let new_contrib = (block_max - m_new) + block_log_sum_exp;
95        let log_l_new = log_add_exp(prev_contrib, new_contrib);
96
97        let log_scale = self.m - m_new;
98
99        self.m = m_new;
100        self.log_l = log_l_new;
101        self.count += 1;
102
103        log_scale
104    }
105
106    /// Update with raw block statistics (not in log-space).
107    pub fn update_raw(&mut self, block_max: f64, block_sum_exp: f64) -> f64 {
108        let block_log_sum_exp = block_sum_exp.ln();
109        self.update(block_max, block_log_sum_exp)
110    }
111
112    /// Get the current maximum.
113    #[inline]
114    pub fn max(&self) -> f64 {
115        self.m
116    }
117
118    /// Get the log of the running sum.
119    #[inline]
120    pub fn log_sum(&self) -> f64 {
121        self.log_l
122    }
123
124    /// Get the running sum (converted from log-space).
125    #[inline]
126    pub fn sum(&self) -> f64 {
127        self.log_l.exp()
128    }
129
130    /// Get the number of blocks processed.
131    #[inline]
132    pub fn count(&self) -> usize {
133        self.count
134    }
135
136    /// Compute the final normalization factor: 1 / sum(exp(scores - m)).
137    #[inline]
138    pub fn normalization_factor(&self) -> f64 {
139        (-self.log_l).exp()
140    }
141
142    /// Merge another log-space accumulator into this one.
143    pub fn merge(&mut self, other: &LogSpaceSoftmax) -> f64 {
144        if other.count == 0 {
145            return 0.0;
146        }
147        if self.count == 0 {
148            self.m = other.m;
149            self.log_l = other.log_l;
150            self.count = other.count;
151            return f64::NEG_INFINITY;
152        }
153
154        let m_new = self.m.max(other.m);
155
156        let self_contrib = (self.m - m_new) + self.log_l;
157        let other_contrib = (other.m - m_new) + other.log_l;
158        let log_l_new = log_add_exp(self_contrib, other_contrib);
159
160        let log_scale = self.m - m_new;
161
162        self.m = m_new;
163        self.log_l = log_l_new;
164        self.count += other.count;
165
166        log_scale
167    }
168
169    /// Reset the accumulator.
170    pub fn reset(&mut self) {
171        self.m = f64::NEG_INFINITY;
172        self.log_l = f64::NEG_INFINITY;
173        self.count = 0;
174    }
175}
176
/// Tensor-based log-space operations for GPU computation.
///
/// Stateless namespace type: all functionality lives in associated functions.
pub struct TensorLogOps;
179
180impl TensorLogOps {
181    /// Compute log(sum(exp(tensor))) along a dimension.
182    pub fn log_sum_exp<B: Backend, const D: usize>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
183        let max = tensor.clone().max_dim(dim);
184        let shifted = tensor - max.clone();
185        let sum_exp = shifted.exp().sum_dim(dim);
186        max + sum_exp.log()
187    }
188
189    /// Compute stable softmax in log-space.
190    pub fn log_softmax<B: Backend, const D: usize>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
191        let log_sum = Self::log_sum_exp(tensor.clone(), dim);
192        tensor - log_sum
193    }
194
195    /// Extract maximum and log-sum-exp from a tensor for online softmax.
196    pub fn extract_softmax_stats<B: Backend>(
197        tensor: Tensor<B, 4>,
198        dim: usize,
199    ) -> (Tensor<B, 4>, Tensor<B, 4>) {
200        let max = tensor.clone().max_dim(dim);
201        let shifted = tensor - max.clone();
202        let sum_exp = shifted.exp().sum_dim(dim);
203        let log_sum = sum_exp.log();
204        (max, log_sum)
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    // log_add_exp agrees with the direct formula where it is safe to evaluate.
    #[test]
    fn test_log_add_exp_basic() {
        let result = log_add_exp(0.0, 0.0);
        assert!((result - 2.0_f64.ln()).abs() < 1e-10);

        let result = log_add_exp(1.0, 2.0);
        let expected = (1.0_f64.exp() + 2.0_f64.exp()).ln();
        assert!((result - expected).abs() < 1e-10);
    }

    // Inputs where a naive exp() would overflow or underflow f64.
    #[test]
    fn test_log_add_exp_extreme() {
        let result = log_add_exp(1000.0, 1000.0);
        assert!((result - (1000.0 + 2.0_f64.ln())).abs() < 1e-10);

        let result = log_add_exp(-1000.0, -1000.0);
        assert!((result - (-1000.0 + 2.0_f64.ln())).abs() < 1e-10);

        // exp(0 - 1000) underflows to 0, so the larger argument dominates.
        let result = log_add_exp(1000.0, 0.0);
        assert!((result - 1000.0).abs() < 1e-10);
    }

    // -inf represents log(0); adding a zero term must be the identity.
    #[test]
    fn test_log_add_exp_neg_infinity() {
        assert_eq!(log_add_exp(f64::NEG_INFINITY, 5.0), 5.0);
        assert_eq!(log_add_exp(5.0, f64::NEG_INFINITY), 5.0);
        assert_eq!(
            log_add_exp(f64::NEG_INFINITY, f64::NEG_INFINITY),
            f64::NEG_INFINITY
        );
    }

    #[test]
    fn test_log_sum_exp() {
        let values = vec![1.0, 2.0, 3.0];
        let result = log_sum_exp(&values);
        let expected = (1.0_f64.exp() + 2.0_f64.exp() + 3.0_f64.exp()).ln();
        assert!((result - expected).abs() < 1e-10);
    }

    // A long, dense range of values must not overflow and stays finite.
    #[test]
    fn test_log_sum_exp_large_sequence() {
        let values: Vec<f64> = (0..10000).map(|i| (i as f64) * 0.001).collect();
        let result = log_sum_exp(&values);

        assert!(result.is_finite());
        assert!(result > 0.0);
    }

    // Online updates must track the analytic running sum exp-shifted by m.
    #[test]
    fn test_log_space_softmax_update() {
        let mut acc = LogSpaceSoftmax::new();

        acc.update_raw(5.0, 10.0);
        assert_eq!(acc.max(), 5.0);
        assert!((acc.sum() - 10.0).abs() < 1e-10);

        // Second block has a lower max, so its sum is scaled by exp(3 - 5).
        acc.update_raw(3.0, 5.0);
        assert_eq!(acc.max(), 5.0);
        let expected_sum = 10.0 + 5.0 * (-2.0_f64).exp();
        assert!((acc.sum() - expected_sum).abs() < 1e-8);
    }

    // Merging two accumulators must equal processing both blocks in one.
    #[test]
    fn test_log_space_softmax_merge() {
        let mut acc1 = LogSpaceSoftmax::new();
        acc1.update_raw(5.0, 10.0);

        let mut acc2 = LogSpaceSoftmax::new();
        acc2.update_raw(3.0, 5.0);

        acc1.merge(&acc2);

        assert_eq!(acc1.max(), 5.0);
        let expected_sum = 10.0 + 5.0 * (-2.0_f64).exp();
        assert!((acc1.sum() - expected_sum).abs() < 1e-8);
    }

    // Cross-check: the log-space accumulator must agree with the linear-space
    // StableAccumulator to within floating-point round-off.
    #[test]
    fn test_log_space_vs_standard() {
        use super::super::stable_accumulator::StableAccumulator;

        let blocks: Vec<(f64, f64)> = vec![(5.0, 10.0), (3.0, 5.0), (7.0, 20.0), (1.0, 2.0)];

        let mut std_acc = StableAccumulator::default_config();
        for (max, sum_exp) in blocks.iter() {
            std_acc.update(*max, *sum_exp);
        }

        let mut log_acc = LogSpaceSoftmax::new();
        for (max, sum_exp) in blocks.iter() {
            log_acc.update_raw(*max, *sum_exp);
        }

        assert_eq!(std_acc.max(), log_acc.max());
        let diff = (std_acc.sum() - log_acc.sum()).abs() / std_acc.sum();
        assert!(diff < 1e-10, "Relative difference: {}", diff);
    }
}