gllm_kernels/ops/softmax.rs

use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

/// Numerically stable `ln(exp(a) + exp(b))`.
///
/// Uses the identity `ln(exp(a) + exp(b)) = max + ln(1 + exp(min - max))`,
/// so `exp` is only ever applied to a non-positive argument and cannot
/// overflow. `f64::NEG_INFINITY` acts as the identity element.
#[inline]
pub fn log_add_exp(a: f64, b: f64) -> f64 {
    if a.is_infinite() && a.is_sign_negative() {
        return b;
    }
    if b.is_infinite() && b.is_sign_negative() {
        return a;
    }

    let max = a.max(b);
    let min = a.min(b);

    // `ln_1p` is more accurate than `(1.0 + x).ln()` when `exp(min - max)`
    // is tiny.
    max + (min - max).exp().ln_1p()
}
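
// A quick illustration of why the shifted form matters (the values here
// are arbitrary): the naive formulation overflows to infinity, while
// `log_add_exp` stays finite.
//
//     let naive = (1000.0_f64.exp() + 1000.0_f64.exp()).ln(); // inf: exp(1000) overflows
//     let stable = log_add_exp(1000.0, 1000.0);               // 1000.0 + ln(2)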

/// Numerically stable `ln(Σ exp(x_i))` over a slice.
///
/// Shifts by the maximum before exponentiating; returns
/// `f64::NEG_INFINITY` for an empty slice.
#[inline]
pub fn log_sum_exp(values: &[f64]) -> f64 {
    if values.is_empty() {
        return f64::NEG_INFINITY;
    }

    let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);

    // Either every value is -inf or some value is +inf; the shifted sum
    // is meaningless, so return the max directly.
    if max.is_infinite() {
        return max;
    }

    let sum: f64 = values.iter().map(|&x| (x - max).exp()).sum();

    max + sum.ln()
}

/// Like `log_sum_exp`, but accumulates the shifted exponentials with
/// Kahan compensated summation to reduce rounding error on long slices.
pub fn log_sum_exp_kahan(values: &[f64]) -> f64 {
    use super::stable_accumulator::KahanAccumulator;

    if values.is_empty() {
        return f64::NEG_INFINITY;
    }

    let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);

    if max.is_infinite() {
        return max;
    }

    let mut sum = KahanAccumulator::<f64>::new();
    for &x in values {
        sum.add((x - max).exp());
    }

    max + sum.value().ln()
}

/// Streaming softmax statistics maintained entirely in log space.
///
/// Tracks the running maximum `m` and `log_l = ln(Σ exp(x_i - m))`, the
/// log of the normalizer relative to that maximum, so arbitrarily long
/// sequences never overflow. Blocks are folded in one at a time via
/// `update` / `update_raw`, in the style of online (flash-attention)
/// softmax.
#[derive(Debug, Clone)]
pub struct LogSpaceSoftmax {
    /// Running maximum over all values seen so far.
    m: f64,
    /// `ln(Σ exp(x_i - m))`: the log-normalizer relative to `m`.
    log_l: f64,
    /// Number of blocks folded in so far.
    count: usize,
}

impl Default for LogSpaceSoftmax {
    fn default() -> Self {
        Self::new()
    }
}

impl LogSpaceSoftmax {
    /// Creates an empty accumulator; `NEG_INFINITY` is the identity for
    /// both the running max and the log-sum.
    pub fn new() -> Self {
        Self {
            m: f64::NEG_INFINITY,
            log_l: f64::NEG_INFINITY,
            count: 0,
        }
    }

    /// Folds in one block, given its max and its log-sum-exp relative to
    /// that max. Returns `log_scale = m_old - m_new`: the caller should
    /// rescale any output accumulated so far by `exp(log_scale)`.
    pub fn update(&mut self, block_max: f64, block_log_sum_exp: f64) -> f64 {
        let m_new = self.m.max(block_max);

        // Re-center both the old and the new contribution on m_new, then
        // combine them in log space.
        let prev_contrib = (self.m - m_new) + self.log_l;
        let new_contrib = (block_max - m_new) + block_log_sum_exp;
        let log_l_new = log_add_exp(prev_contrib, new_contrib);

        let log_scale = self.m - m_new;

        self.m = m_new;
        self.log_l = log_l_new;
        self.count += 1;

        log_scale
    }

    /// Like `update`, but takes the block's raw `Σ exp(x - block_max)`
    /// and moves it into log space here.
    pub fn update_raw(&mut self, block_max: f64, block_sum_exp: f64) -> f64 {
        let block_log_sum_exp = block_sum_exp.ln();
        self.update(block_max, block_log_sum_exp)
    }

    /// Running maximum seen so far.
    #[inline]
    pub fn max(&self) -> f64 {
        self.m
    }

    /// `ln(Σ exp(x_i - max))`: the log-normalizer relative to `max()`.
    #[inline]
    pub fn log_sum(&self) -> f64 {
        self.log_l
    }

    /// The normalizer `Σ exp(x_i - max)` in linear space; this can
    /// overflow for extreme inputs, which is exactly what `log_sum`
    /// avoids.
    #[inline]
    pub fn sum(&self) -> f64 {
        self.log_l.exp()
    }

    /// Number of blocks folded in so far.
    #[inline]
    pub fn count(&self) -> usize {
        self.count
    }

    /// `1 / Σ exp(x_i - max)`: the factor that turns accumulated
    /// numerators into softmax outputs.
    #[inline]
    pub fn normalization_factor(&self) -> f64 {
        (-self.log_l).exp()
    }

    /// Merges another accumulator into this one (e.g. when combining
    /// per-shard partial statistics). Mirrors `update`: returns
    /// `log_scale = m_old - m_new` for rescaling this accumulator's
    /// output; `0.0` if `other` is empty, `NEG_INFINITY` if `self` was.
    pub fn merge(&mut self, other: &LogSpaceSoftmax) -> f64 {
        if other.count == 0 {
            return 0.0;
        }
        if self.count == 0 {
            self.m = other.m;
            self.log_l = other.log_l;
            self.count = other.count;
            // Nothing was accumulated locally, so scale it away entirely.
            return f64::NEG_INFINITY;
        }

        let m_new = self.m.max(other.m);

        let self_contrib = (self.m - m_new) + self.log_l;
        let other_contrib = (other.m - m_new) + other.log_l;
        let log_l_new = log_add_exp(self_contrib, other_contrib);

        let log_scale = self.m - m_new;

        self.m = m_new;
        self.log_l = log_l_new;
        self.count += other.count;

        log_scale
    }

    /// Resets to the empty state produced by `new()`.
    pub fn reset(&mut self) {
        self.m = f64::NEG_INFINITY;
        self.log_l = f64::NEG_INFINITY;
        self.count = 0;
    }
}
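
// A minimal streaming sketch of how the returned `log_scale` is used
// (the per-block statistics below are hypothetical): the running
// numerator is rescaled whenever a new block raises the max, then
// normalized once at the end.
//
//     let blocks = [(5.0, 10.0, 2.0), (7.0, 20.0, 3.0)]; // (max, Σexp, numerator)
//     let mut acc = LogSpaceSoftmax::new();
//     let mut numerator = 0.0_f64;
//     for &(block_max, block_sum_exp, block_num) in &blocks {
//         let log_scale = acc.update_raw(block_max, block_sum_exp);
//         // Old output decays by exp(m_old - m_new); the new block is
//         // re-centered from block_max onto the updated max.
//         numerator = numerator * log_scale.exp()
//             + block_num * (block_max - acc.max()).exp();
//     }
//     let softmax_weighted = numerator * acc.normalization_factor();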

/// Log-space reductions over Burn tensors.
pub struct TensorLogOps;

impl TensorLogOps {
    /// `ln(Σ exp(x))` along `dim`, computed with the max-shift trick.
    /// The reduced dimension is kept with size 1.
    pub fn log_sum_exp<B: Backend, const D: usize>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
        let max = tensor.clone().max_dim(dim);
        let shifted = tensor - max.clone();
        let sum_exp = shifted.exp().sum_dim(dim);
        max + sum_exp.log()
    }

    /// `x - ln(Σ exp(x))` along `dim`; the log-sum broadcasts back over
    /// the reduced dimension.
    pub fn log_softmax<B: Backend, const D: usize>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
        let log_sum = Self::log_sum_exp(tensor.clone(), dim);
        tensor - log_sum
    }

    /// Returns the per-slice `(max, ln(Σ exp(x - max)))` pair along `dim`,
    /// i.e. the block statistics consumed by `LogSpaceSoftmax::update`.
    pub fn extract_softmax_stats<B: Backend>(
        tensor: Tensor<B, 4>,
        dim: usize,
    ) -> (Tensor<B, 4>, Tensor<B, 4>) {
        let max = tensor.clone().max_dim(dim);
        let shifted = tensor - max.clone();
        let sum_exp = shifted.exp().sum_dim(dim);
        let log_sum = sum_exp.log();
        (max, log_sum)
    }
}
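
// A usage sketch for the tensor helpers (the NdArray backend and the
// literal values are illustrative, not part of this module):
//
//     use burn::backend::NdArray;
//     let device = Default::default();
//     let logits = Tensor::<NdArray, 2>::from_floats([[1.0, 2.0, 3.0]], &device);
//     let log_probs = TensorLogOps::log_softmax(logits, 1);
//     // log_probs.exp() sums to 1.0 along dim 1.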

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_log_add_exp_basic() {
        let result = log_add_exp(0.0, 0.0);
        assert!((result - 2.0_f64.ln()).abs() < 1e-10);

        let result = log_add_exp(1.0, 2.0);
        let expected = (1.0_f64.exp() + 2.0_f64.exp()).ln();
        assert!((result - expected).abs() < 1e-10);
    }

    #[test]
    fn test_log_add_exp_extreme() {
        let result = log_add_exp(1000.0, 1000.0);
        assert!((result - (1000.0 + 2.0_f64.ln())).abs() < 1e-10);

        let result = log_add_exp(-1000.0, -1000.0);
        assert!((result - (-1000.0 + 2.0_f64.ln())).abs() < 1e-10);

        let result = log_add_exp(1000.0, 0.0);
        assert!((result - 1000.0).abs() < 1e-10);
    }

    #[test]
    fn test_log_add_exp_neg_infinity() {
        assert_eq!(log_add_exp(f64::NEG_INFINITY, 5.0), 5.0);
        assert_eq!(log_add_exp(5.0, f64::NEG_INFINITY), 5.0);
        assert_eq!(
            log_add_exp(f64::NEG_INFINITY, f64::NEG_INFINITY),
            f64::NEG_INFINITY
        );
    }

    #[test]
    fn test_log_sum_exp() {
        let values = vec![1.0, 2.0, 3.0];
        let result = log_sum_exp(&values);
        let expected = (1.0_f64.exp() + 2.0_f64.exp() + 3.0_f64.exp()).ln();
        assert!((result - expected).abs() < 1e-10);
    }

    #[test]
    fn test_log_sum_exp_large_sequence() {
        let values: Vec<f64> = (0..10000).map(|i| (i as f64) * 0.001).collect();
        let result = log_sum_exp(&values);

        assert!(result.is_finite());
        assert!(result > 0.0);
    }

    #[test]
    fn test_log_space_softmax_update() {
        let mut acc = LogSpaceSoftmax::new();

        acc.update_raw(5.0, 10.0);
        assert_eq!(acc.max(), 5.0);
        assert!((acc.sum() - 10.0).abs() < 1e-10);

        acc.update_raw(3.0, 5.0);
        assert_eq!(acc.max(), 5.0);
        let expected_sum = 10.0 + 5.0 * (-2.0_f64).exp();
        assert!((acc.sum() - expected_sum).abs() < 1e-8);
    }

    #[test]
    fn test_log_space_softmax_merge() {
        let mut acc1 = LogSpaceSoftmax::new();
        acc1.update_raw(5.0, 10.0);

        let mut acc2 = LogSpaceSoftmax::new();
        acc2.update_raw(3.0, 5.0);

        acc1.merge(&acc2);

        assert_eq!(acc1.max(), 5.0);
        let expected_sum = 10.0 + 5.0 * (-2.0_f64).exp();
        assert!((acc1.sum() - expected_sum).abs() < 1e-8);
    }

    #[test]
    fn test_log_space_vs_standard() {
        use super::super::stable_accumulator::StableAccumulator;

        let blocks: Vec<(f64, f64)> = vec![(5.0, 10.0), (3.0, 5.0), (7.0, 20.0), (1.0, 2.0)];

        let mut std_acc = StableAccumulator::default_config();
        for (max, sum_exp) in blocks.iter() {
            std_acc.update(*max, *sum_exp);
        }

        let mut log_acc = LogSpaceSoftmax::new();
        for (max, sum_exp) in blocks.iter() {
            log_acc.update_raw(*max, *sum_exp);
        }

        assert_eq!(std_acc.max(), log_acc.max());
        let diff = (std_acc.sum() - log_acc.sum()).abs() / std_acc.sum();
        assert!(diff < 1e-10, "Relative difference: {}", diff);
    }
}