/// Zero-allocation variant of quantize_rmsnorm_q8_0
///
/// Writes results directly into pre-allocated output buffers: `scales` must hold
/// `input.len().div_ceil(32)` entries and `quants` must hold 32 times that many.
pub fn quantize_rmsnorm_q8_0_into(
input: &[f32],
norm_weight: &[f32],
eps: f32,
scales: &mut [f32],
quants: &mut [i8],
) {
let hidden_dim = input.len();
debug_assert_eq!(hidden_dim, norm_weight.len());
// Compute sum of squares for RMSNorm
let sum_sq: f32 = input.iter().map(|x| x * x).sum();
let mean_sq = sum_sq / hidden_dim as f32;
let inv_rms = 1.0 / (mean_sq + eps).sqrt();
let num_blocks = hidden_dim.div_ceil(32);
for block_idx in 0..num_blocks {
let start = block_idx * 32;
let end = (start + 32).min(hidden_dim);
// Find max absolute value of normalized values for this block
let mut max_abs = 0.0f32;
for i in start..end {
let normalized = input[i] * inv_rms * norm_weight[i];
let abs = normalized.abs();
if abs > max_abs {
max_abs = abs;
}
}
// Compute scale
let scale = if max_abs > 1e-10 {
max_abs / 127.0
} else {
1.0 / 127.0
};
let inv_scale = 1.0 / scale;
scales[block_idx] = scale;
// Quantize normalized values
let quant_start = block_idx * 32;
for i in start..end {
let normalized = input[i] * inv_rms * norm_weight[i];
let q = (normalized * inv_scale).round();
quants[quant_start + (i - start)] = q.clamp(-128.0, 127.0) as i8;
}
// Pad to 32 if partial block
for j in (end - start)..32 {
quants[quant_start + j] = 0i8;
}
}
}
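// A minimal usage sketch (test-only, not part of the original API surface): per the doc
// comment above, the caller sizes `scales` to input.len().div_ceil(32) and `quants` to
// 32x that, then reuses the buffers across calls to keep the hot path allocation-free.
#[cfg(test)]
mod quantize_rmsnorm_into_usage {
    use super::*;

    #[test]
    fn fills_preallocated_buffers() {
        let hidden_dim = 48; // deliberately not a multiple of 32 to exercise padding
        let input: Vec<f32> = (0..hidden_dim).map(|i| i as f32 * 0.01 - 0.2).collect();
        let norm_weight = vec![1.0f32; hidden_dim];
        let num_blocks = hidden_dim.div_ceil(32);
        let mut scales = vec![0.0f32; num_blocks];
        let mut quants = vec![0i8; num_blocks * 32];
        quantize_rmsnorm_q8_0_into(&input, &norm_weight, 1e-6, &mut scales, &mut quants);
        // Every block gets a strictly positive scale.
        assert!(scales.iter().all(|&s| s > 0.0));
        // The tail of the final partial block is zero-padded.
        assert!(quants[hidden_dim..].iter().all(|&q| q == 0));
    }
}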
/// SIMD-accelerated fused SwiGLU activation: silu(gate) * up
///
/// Combines silu activation and element-wise multiply in a single pass
/// for better cache locality. Uses AVX2 with FMA where available.
///
/// # Arguments
/// * `gate` - Gate values, modified in-place to contain result
/// * `up` - Up projection values
pub fn fused_swiglu_simd(gate: &mut [f32], up: &[f32]) {
debug_assert_eq!(gate.len(), up.len());
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
// SAFETY: AVX2 and FMA verified at runtime
unsafe {
fused_swiglu_avx2(gate, up);
}
return;
}
}
// Scalar fallback
fused_swiglu_scalar(gate, up);
}
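// A minimal sketch (test-only) of how the fused kernel is called: `gate` is overwritten
// in place with silu(gate) * up. It also cross-checks the dispatching wrapper against the
// scalar reference within a loose tolerance, since the AVX2 path uses an approximate
// exp and reciprocal.
#[cfg(test)]
mod fused_swiglu_usage {
    use super::*;

    #[test]
    fn simd_matches_scalar_reference() {
        let n = 37; // odd length to exercise the scalar remainder loop
        let mut gate: Vec<f32> = (0..n).map(|i| (i as f32 - 18.0) * 0.25).collect();
        let up: Vec<f32> = (0..n).map(|i| 0.5 + i as f32 * 0.1).collect();
        let mut reference = gate.clone();
        fused_swiglu_scalar(&mut reference, &up);
        fused_swiglu_simd(&mut gate, &up);
        for (a, b) in gate.iter().zip(reference.iter()) {
            assert!((a - b).abs() <= 1e-3 * b.abs().max(1.0), "{a} vs {b}");
        }
    }
}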
/// Scalar fused SwiGLU: silu(gate) * up
///
/// Exposed as `pub(crate)` for direct testing on AVX2 machines.
#[inline]
pub(crate) fn fused_swiglu_scalar(gate: &mut [f32], up: &[f32]) {
for (g, &u) in gate.iter_mut().zip(up.iter()) {
// silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
let silu_g = *g / (1.0 + (-*g).exp());
*g = silu_g * u;
}
}
/// AVX2 SIMD fused SwiGLU with FMA
///
/// Computes silu(gate) * up using:
/// - Polynomial approximation for exp(-x)
/// - FMA for efficient multiply-add
/// - 8-wide AVX2 vectors
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
#[allow(clippy::many_single_char_names)]
// SAFETY: Caller must ensure the CPU supports AVX2 and FMA (checked by the dispatching wrapper)
unsafe fn fused_swiglu_avx2(gate: &mut [f32], up: &[f32]) {
use std::arch::x86_64::{
_mm256_add_epi32, _mm256_add_ps, _mm256_castsi256_ps, _mm256_cvtps_epi32, _mm256_floor_ps,
_mm256_fmadd_ps, _mm256_fnmadd_ps, _mm256_loadu_ps, _mm256_max_ps, _mm256_mul_ps,
_mm256_rcp_ps, _mm256_set1_epi32, _mm256_set1_ps, _mm256_setzero_ps, _mm256_slli_epi32,
_mm256_storeu_ps, _mm256_sub_ps,
};
// SAFETY: all loads/stores are unaligned (`loadu`/`storeu`) and stay in bounds (`i + 8 <= n`)
unsafe {
let n = gate.len();
let mut i = 0;
// Constants for the fast exp approximation: exp(x) = 2^n * 2^f with
// n = floor(x / ln2) and f = frac(x / ln2); 2^f uses a degree-5 polynomial,
// and the input is clamped to [-87, 0] to avoid underflow.
let one = _mm256_set1_ps(1.0);
let ln2_inv = _mm256_set1_ps(1.442_695); // 1/ln(2)
let c0 = _mm256_set1_ps(1.0);
let c1 = _mm256_set1_ps(0.693_147_2); // ln(2)
let c2 = _mm256_set1_ps(0.240_226_5); // ln(2)^2 / 2!
let c3 = _mm256_set1_ps(0.055_504_11); // ln(2)^3 / 3!
let c4 = _mm256_set1_ps(0.009_618_13); // ln(2)^4 / 4!
let c5 = _mm256_set1_ps(0.001_333_36); // ln(2)^5 / 5!
let min_exp = _mm256_set1_ps(-87.0); // Minimum input to avoid underflow
let two = _mm256_set1_ps(2.0); // For Newton-Raphson
// Process 8 elements at a time
while i + 8 <= n {
// Load gate and up values
let g = _mm256_loadu_ps(gate.as_ptr().add(i));
let u = _mm256_loadu_ps(up.as_ptr().add(i));
// Compute -g for sigmoid
let neg_g = _mm256_sub_ps(_mm256_setzero_ps(), g);
// Clamp to avoid exp underflow
let neg_g_clamped = _mm256_max_ps(neg_g, min_exp);
// Fast exp approximation using 2^(x/ln2) = 2^n * 2^f where n=floor, f=frac
// n = floor(x * 1/ln2)
let xln2 = _mm256_mul_ps(neg_g_clamped, ln2_inv);
let n_f = _mm256_floor_ps(xln2);
let n_i = _mm256_cvtps_epi32(n_f);
// f = frac(x / ln2), in [0, 1), matching the 2^f polynomial coefficients above
let f = _mm256_sub_ps(xln2, n_f);
// Horner's method: c0 + f*(c1 + f*(c2 + f*(c3 + f*(c4 + f*c5))))
let p = _mm256_fmadd_ps(f, c5, c4);
let p = _mm256_fmadd_ps(f, p, c3);
let p = _mm256_fmadd_ps(f, p, c2);
let p = _mm256_fmadd_ps(f, p, c1);
let p = _mm256_fmadd_ps(f, p, c0);
// Scale by 2^n using integer bit manipulation
// 2^n = reinterpret((n + 127) << 23) as float
let bias = _mm256_set1_epi32(127);
let n_biased = _mm256_add_epi32(n_i, bias);
let exp_scale = _mm256_slli_epi32::<23>(n_biased);
let exp_scale_f = _mm256_castsi256_ps(exp_scale);
// exp(-g) = 2^n * p(f)
let exp_neg_g = _mm256_mul_ps(p, exp_scale_f);
// sigmoid(g) = 1 / (1 + exp(-g))
// Use fast reciprocal approximation with Newton-Raphson refinement
let denom = _mm256_add_ps(one, exp_neg_g);
let rcp = _mm256_rcp_ps(denom); // ~12-bit precision
// One Newton-Raphson iteration: x' = x * (2 - d*x)
let sigmoid = _mm256_mul_ps(rcp, _mm256_fnmadd_ps(denom, rcp, two));
// silu(g) = g * sigmoid(g)
let silu_g = _mm256_mul_ps(g, sigmoid);
// Result = silu(g) * u
let result = _mm256_mul_ps(silu_g, u);
// Store result
_mm256_storeu_ps(gate.as_mut_ptr().add(i), result);
i += 8;
}
// Handle remainder with scalar code
while i < n {
let g = gate[i];
let silu_g = g / (1.0 + (-g).exp());
gate[i] = silu_g * up[i];
i += 1;
}
}
}
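// A scalar mirror of the fast-exp scheme above (test-only sketch, not used by the kernel):
// split x into n = floor(x/ln2) and f = frac(x/ln2), evaluate 2^f with the same degree-5
// polynomial, and rebuild 2^n from the IEEE-754 exponent bits. Handy for checking the
// approximation error independently of the SIMD code path.
#[cfg(test)]
mod fast_exp_reference {
    #[test]
    fn approximation_stays_close_to_libm_exp() {
        let coeffs = [1.0f32, 0.693_147_2, 0.240_226_5, 0.055_504_11, 0.009_618_13, 0.001_333_36];
        let fast_exp = |x: f32| -> f32 {
            let x = x.max(-87.0);
            let xln2 = x * 1.442_695;
            let n = xln2.floor();
            let f = xln2 - n;
            // Horner evaluation of the 2^f polynomial
            let p = coeffs.iter().rev().fold(0.0f32, |acc, &c| acc * f + c);
            // 2^n via exponent-bit construction, mirroring the (n + 127) << 23 trick
            let scale = f32::from_bits(((n as i32 + 127) as u32) << 23);
            p * scale
        };
        for i in 0..200 {
            let x = -(i as f32) * 0.1;
            let expected = x.exp();
            assert!((fast_exp(x) - expected).abs() <= 2e-4 * expected.max(1e-6));
        }
    }
}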
/// SIMD-optimized in-place softmax
///
/// Computes softmax(x) = exp(x - max) / sum(exp(x - max))
/// Uses AVX2 for the max-reduction and normalization passes where available;
/// the exp pass stays scalar (see `softmax_avx2`).
///
/// # Arguments
/// * `x` - Slice to softmax in-place
#[inline]
pub fn softmax_simd(x: &mut [f32]) {
if x.is_empty() {
return;
}
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
// SAFETY: AVX2 verified at runtime
unsafe {
softmax_avx2(x);
}
return;
}
}
// Scalar fallback
softmax_scalar(x);
}
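// A minimal usage sketch (test-only): softmax_simd normalizes the slice in place, so
// afterwards the entries are non-negative, sum to ~1, and preserve the input ordering.
#[cfg(test)]
mod softmax_usage {
    use super::*;

    #[test]
    fn normalizes_in_place() {
        let mut logits: Vec<f32> = (0..19).map(|i| i as f32 * 0.3 - 2.0).collect();
        softmax_simd(&mut logits);
        let sum: f32 = logits.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
        assert!(logits.iter().all(|&p| p >= 0.0));
        // Larger logits map to larger probabilities.
        assert!(logits[18] > logits[0]);
    }
}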
/// Scalar softmax
///
/// Exposed as `pub(crate)` for direct testing on AVX2 machines.
#[inline]
pub(crate) fn softmax_scalar(x: &mut [f32]) {
// Find max for numerical stability
let max = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
// Compute exp(x - max) and sum
let mut sum = 0.0f32;
for v in x.iter_mut() {
*v = (*v - max).exp();
sum += *v;
}
// Normalize
let inv_sum = 1.0 / sum;
for v in x.iter_mut() {
*v *= inv_sum;
}
}
/// AVX2 SIMD softmax - only SIMD for max-find and normalization
/// (exp() uses libm, which is faster than a polynomial approximation for short vectors)
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
#[allow(unsafe_op_in_unsafe_fn)]
// SAFETY: Caller must ensure AVX2 is available on the executing CPU
unsafe fn softmax_avx2(x: &mut [f32]) {
use std::arch::x86_64::{
_mm256_loadu_ps, _mm256_max_ps, _mm256_mul_ps, _mm256_set1_ps, _mm256_storeu_ps,
};
let n = x.len();
if n == 0 {
return;
}
// ============= Phase 1: Find max (SIMD) =============
let mut max_vec = _mm256_set1_ps(f32::NEG_INFINITY);
let mut i = 0;
while i + 8 <= n {
let v = _mm256_loadu_ps(x.as_ptr().add(i));
max_vec = _mm256_max_ps(max_vec, v);
i += 8;
}
let mut max_scalar = horizontal_max_avx2(max_vec);
for j in i..n {
max_scalar = max_scalar.max(x[j]);
}
// ============= Phase 2: Compute exp(x - max) (scalar libm) =============
let mut sum_scalar = 0.0f32;
for j in 0..n {
let exp_v = (x[j] - max_scalar).exp();
x[j] = exp_v;
sum_scalar += exp_v;
}
// ============= Phase 3: Normalize (SIMD) =============
let inv_sum = _mm256_set1_ps(1.0 / sum_scalar);
i = 0;
while i + 8 <= n {
let v = _mm256_loadu_ps(x.as_ptr().add(i));
let normalized = _mm256_mul_ps(v, inv_sum);
_mm256_storeu_ps(x.as_mut_ptr().add(i), normalized);
i += 8;
}
let inv_sum_scalar = 1.0 / sum_scalar;
for j in i..n {
x[j] *= inv_sum_scalar;
}
}
/// Horizontal max of 8-wide AVX2 vector
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
#[allow(unsafe_op_in_unsafe_fn)]
// SAFETY: Caller must ensure AVX2 is available on the executing CPU
unsafe fn horizontal_max_avx2(v: std::arch::x86_64::__m256) -> f32 {
use std::arch::x86_64::{
_mm256_extractf128_ps, _mm_cvtss_f32, _mm_max_ps, _mm_max_ss, _mm_movehl_ps, _mm_shuffle_ps,
};
{
// Extract high and low 128-bit lanes
let hi = _mm256_extractf128_ps::<1>(v);
let lo = _mm256_extractf128_ps::<0>(v);
let max128 = _mm_max_ps(hi, lo);
// Reduce 4 to 2
let max64 = _mm_max_ps(max128, _mm_movehl_ps(max128, max128));
// Reduce 2 to 1
let max32 = _mm_max_ss(max64, _mm_shuffle_ps::<0x55>(max64, max64));
_mm_cvtss_f32(max32)
}
}
/// Quantize f32 activations to Q8_0 format for fast integer matmul
///
/// Returns (scales, quantized_values) where each block of 32 values
/// has one f32 scale and 32 int8 quantized values.
#[inline]
pub fn quantize_activations_q8_0(activations: &[f32]) -> (Vec<f32>, Vec<i8>) {
let num_blocks = activations.len().div_ceil(32);
let mut scales = Vec::with_capacity(num_blocks);
let mut quants = Vec::with_capacity(num_blocks * 32);
for block_idx in 0..num_blocks {
let start = block_idx * 32;
let end = (start + 32).min(activations.len());
// Find max absolute value for symmetric quantization
let mut max_abs = 0.0f32;
for i in start..end {
let abs = activations[i].abs();
if abs > max_abs {
max_abs = abs;
}
}
// Compute scale (avoid division by zero)
let scale = if max_abs > 1e-10 {
max_abs / 127.0
} else {
1.0 / 127.0
};
let inv_scale = 1.0 / scale;
scales.push(scale);
// Quantize values
for i in start..end {
let q = (activations[i] * inv_scale).round();
quants.push(q.clamp(-128.0, 127.0) as i8);
}
// Pad to 32 if partial block
for _ in end..(start + 32) {
quants.push(0i8);
}
}
(scales, quants)
}
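// A minimal sketch (test-only) of the Q8_0 layout returned above: one f32 scale per
// 32-value block and 32 i8 codes per block (the last block zero-padded), with each
// value reconstructable as `code as f32 * scale` to within half a quantization step.
#[cfg(test)]
mod quantize_activations_usage {
    use super::*;

    #[test]
    fn round_trips_within_half_a_step() {
        let activations: Vec<f32> = (0..70).map(|i| ((i * 37) % 97) as f32 * 0.01 - 0.5).collect();
        let (scales, quants) = quantize_activations_q8_0(&activations);
        assert_eq!(scales.len(), activations.len().div_ceil(32));
        assert_eq!(quants.len(), scales.len() * 32);
        for (i, &a) in activations.iter().enumerate() {
            let dequant = f32::from(quants[i]) * scales[i / 32];
            assert!((dequant - a).abs() <= scales[i / 32] * 0.5 + 1e-6, "idx {i}");
        }
    }
}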