boostr 0.1.0

ML framework built on numr - attention, quantization, model architectures
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
//! CPU quantized matmul kernels
//!
//! Dequantize-and-accumulate per block row for cache efficiency.
//! Computes: activation [M, K] × weight^T → output [M, N]
//!
//! Weight is stored as [N, K] (N output rows, K input cols each), matching
//! the packing axis contract: quantization blocks run along the last (K) axis.
//! We iterate over weight rows (output columns), dequantize one row at a time,
//! and accumulate the dot product contribution.
//!
//! Optimizations:
//! - Rayon parallelism over N (weight rows / output columns)
//! - Thread-local dequant buffers to avoid contention
//! - AVX2+FMA SIMD dot product

use rayon::prelude::*;

use super::{dequant, dequant_k_quants};
use crate::quant::QuantFormat;

/// Dot product of two f32 slices, using a SIMD kernel when one is available.
///
/// The slices are expected to have the same length (checked by a debug
/// assertion). Should they differ in a release build, every code path uses
/// only the common prefix of `min(a.len(), b.len())` elements, so the
/// raw-pointer SIMD kernels are never asked to read past the shorter slice.
fn dot_f32(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());

    #[cfg(target_arch = "x86_64")]
    {
        // Clamp to the shorter slice so the pointer-based kernel agrees with
        // the `zip`-based scalar fallback below, which also stops at the
        // shorter input.
        let len = a.len().min(b.len());
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // SAFETY: `len` does not exceed the length of either slice, so
            // both pointers cover `len` valid f32 values, and AVX2 + FMA
            // support was confirmed at runtime on the line above.
            return unsafe { super::simd::dot_f32::dot_f32_avx2_fma(a.as_ptr(), b.as_ptr(), len) };
        }
        a.iter().zip(b.iter()).map(|(&ai, &bi)| ai * bi).sum()
    }

    #[cfg(target_arch = "aarch64")]
    {
        let len = a.len().min(b.len());
        // SAFETY: `len` does not exceed the length of either slice, so both
        // pointers cover `len` valid f32 values.
        // NOTE(review): assumes the NEON kernel needs no runtime feature
        // check on aarch64 targets — confirm against its definition.
        unsafe { super::simd::aarch64::dot_f32::dot_f32_neon(a.as_ptr(), b.as_ptr(), len) }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        let mut sum = 0.0f32;
        for (&ai, &bi) in a.iter().zip(b.iter()) {
            sum += ai * bi;
        }
        sum
    }
}

/// Quantized matmul: activation \[M, K\] × weight\[N, K\]^T → output \[M, N\]
///
/// `act`: \[M * K\] f32 values (row-major)
/// `weight_bytes`: raw quantized bytes for \[N, K\] weight matrix (blocks along K)
/// `output`: \[M * N\] f32 values (row-major)
/// `m`, `k`, `n`: matrix dimensions
/// `format`: quantization format of `weight_bytes`
///
/// Three code paths are selected from `format` and `k`:
/// - Q2K..Q6K with `k % 256 == 0`: activations are quantized to Q8_K once,
///   then the integer dot kernels are used.
/// - Q2K..Q6K otherwise: fused f32 dequant+dot kernels.
/// - every other format: dequantize a weight row into a scratch buffer and
///   take an f32 dot product.
///
/// Work is split over ranges of output columns and run on the Rayon pool.
///
/// # Panics
///
/// Panics if `output.len() != m * n` (checked in all build profiles, because
/// results are written through a raw pointer), and if `act` or `weight_bytes`
/// is too short for the requested dimensions (slice indexing). The `act` and
/// `weight_bytes` lengths are additionally checked for exact equality in
/// debug builds.
pub fn quant_matmul_f32(
    act: &[f32],
    weight_bytes: &[u8],
    output: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
    format: QuantFormat,
) {
    debug_assert_eq!(act.len(), m * k);
    // Hard assertion rather than debug-only: the parallel workers below write
    // `output` through a raw pointer at offsets up to `m * n - 1`, so a short
    // buffer would be an out-of-bounds write in release builds.
    assert_eq!(
        output.len(),
        m * n,
        "output must hold exactly m * n elements"
    );

    let block_size = format.block_size();
    let block_bytes = format.block_bytes();
    let blocks_per_row = k / block_size;
    let row_bytes = blocks_per_row * block_bytes;

    debug_assert_eq!(weight_bytes.len(), n * row_bytes);

    // Parallel over chunks of output columns (N dimension).
    // Each chunk processes a range of weight rows independently, writing to
    // disjoint output columns so no synchronization is needed.
    //
    // For decode (M=1), the output is [1, N] so we split N across threads.

    // Choose chunk size based on M:
    // - M=1 (decode GEMV): one chunk per thread to minimize Rayon scheduling overhead
    // - M>1 (prefill): more chunks for better load balancing
    let num_threads = rayon::current_num_threads();
    let target_chunks = if m == 1 { num_threads } else { num_threads * 4 };
    let chunk_size = n.div_ceil(target_chunks);
    let chunk_size = chunk_size.max(16); // minimum 16 rows per chunk

    // Half-open column ranges [start, end), one per parallel task.
    let col_ranges: Vec<(usize, usize)> = (0..n)
        .step_by(chunk_size)
        .map(|start| (start, (start + chunk_size).min(n)))
        .collect();

    // Each `par_iter` task handles one column range and writes
    // output[i * n + j] for j in its own range only. The pointer is carried
    // as a usize so the closure is Send + Sync.
    let output_ptr = output.as_mut_ptr() as usize;

    // K-quants with dedicated fused dequant+dot kernels
    let use_fused = matches!(
        format,
        QuantFormat::Q2K
            | QuantFormat::Q3K
            | QuantFormat::Q4K
            | QuantFormat::Q5K
            | QuantFormat::Q6K
    );

    // Q8_K integer dot product path — quantize activations to Q8_K then integer arithmetic.
    // Matches llama.cpp's approach for K-quant formats.
    // NOTE: Q8_0 intentionally NOT routed here (per-32 scales vs Q8_K's per-256 scales).
    let use_q8k = use_fused && k % 256 == 0;

    // Pre-quantize activation rows to Q8_K (one per M row) if using integer path
    let q8k_block_bytes = super::simd::quantize_act_q8k::Q8K_BLOCK_BYTES;
    let q8k_blocks_per_row = k / 256;
    let q8k_row_size = q8k_blocks_per_row * q8k_block_bytes;
    let act_q8k: Vec<u8> = if use_q8k {
        let mut buf = vec![0u8; m * q8k_row_size];
        for i in 0..m {
            let act_row = &act[i * k..(i + 1) * k];
            let q8k_row = &mut buf[i * q8k_row_size..(i + 1) * q8k_row_size];
            super::simd::quantize_act_q8k::quantize_f32_to_q8k(act_row, q8k_row);
        }
        buf
    } else {
        Vec::new()
    };
    // Shared read-only view; a plain borrow is Sync, so no pointer tricks are
    // needed to read it from the worker closure.
    let act_q8k = act_q8k.as_slice();

    col_ranges.par_iter().for_each(|&(j_start, j_end)| {
        let out = output_ptr as *mut f32;

        if use_q8k {
            // Integer path: Q8_K activation × K-quant weight
            for i in 0..m {
                let q8k_row = &act_q8k[i * q8k_row_size..(i + 1) * q8k_row_size];

                for j in j_start..j_end {
                    let row_data = &weight_bytes[j * row_bytes..(j + 1) * row_bytes];
                    let val = fused_dot_q8k_dispatch(q8k_row, row_data, k, format);
                    // SAFETY: i < m and j < n, so i * n + j < m * n, which
                    // equals output.len() (asserted above). Column ranges are
                    // disjoint across tasks, so no two tasks write the same
                    // element.
                    unsafe {
                        *out.add(i * n + j) = val;
                    }
                }
            }
        } else if use_fused {
            // f32 FMA fused path (fallback for non-256-aligned K)
            let cols = j_end - j_start;
            let pairs = cols / 2;
            let remainder = cols % 2;

            for i in 0..m {
                let act_row = &act[i * k..(i + 1) * k];

                // Process pairs of weight rows
                for p in 0..pairs {
                    let j0 = j_start + p * 2;
                    let j1 = j0 + 1;
                    let row_data0 = &weight_bytes[j0 * row_bytes..(j0 + 1) * row_bytes];
                    let row_data1 = &weight_bytes[j1 * row_bytes..(j1 + 1) * row_bytes];

                    let val0 = fused_dot_dispatch(act_row, row_data0, k, format);
                    let val1 = fused_dot_dispatch(act_row, row_data1, k, format);
                    // SAFETY: j0 and j1 lie in [j_start, j_end) and i < m, so
                    // both offsets are below m * n == output.len(); the
                    // column range belongs to this task alone.
                    unsafe {
                        *out.add(i * n + j0) = val0;
                        *out.add(i * n + j1) = val1;
                    }
                }

                // Handle odd remainder: the last column of the range
                if remainder > 0 {
                    let j = j_end - 1;
                    let row_data = &weight_bytes[j * row_bytes..(j + 1) * row_bytes];
                    let val = fused_dot_dispatch(act_row, row_data, k, format);
                    // SAFETY: j = j_end - 1 < n and i < m, so the offset is
                    // in bounds and owned by this task.
                    unsafe {
                        *out.add(i * n + j) = val;
                    }
                }
            }
        } else {
            // Dequant to buffer then SIMD dot for other formats.
            // One scratch row per task, reused for every weight row.
            let mut dequant_row = vec![0.0f32; k];
            for j in j_start..j_end {
                let row_start = j * row_bytes;
                let row_data = &weight_bytes[row_start..row_start + row_bytes];

                dequant_row_f32(row_data, &mut dequant_row, format);

                for i in 0..m {
                    let act_row = &act[i * k..(i + 1) * k];
                    let val = dot_f32(act_row, &dequant_row);
                    // SAFETY: i < m and j < n, so the offset is below
                    // m * n == output.len(); column j belongs to this task.
                    unsafe {
                        *out.add(i * n + j) = val;
                    }
                }
            }
        }
    });
}

/// Select the fused f32 dequant+dot kernel for `format` and run it.
///
/// `act_row` is one f32 activation row, `row_data` the quantized bytes of one
/// weight row, and `k` the row length; all three are forwarded unchanged to
/// the chosen kernel. Only Q2K, Q3K, Q4K, Q5K and Q6K are handled — any other
/// format hits `unreachable!()`.
fn fused_dot_dispatch(act_row: &[f32], row_data: &[u8], k: usize, format: QuantFormat) -> f32 {
    use super::simd::{
        fused_q2k_dot::fused_dot_q2k, fused_q3k_dot::fused_dot_q3k,
        fused_q4k_dot::fused_dot_q4k, fused_q5k_dot::fused_dot_q5k,
        fused_q6k_dot::fused_dot_q6k,
    };

    match format {
        QuantFormat::Q2K => fused_dot_q2k(act_row, row_data, k),
        QuantFormat::Q3K => fused_dot_q3k(act_row, row_data, k),
        QuantFormat::Q4K => fused_dot_q4k(act_row, row_data, k),
        QuantFormat::Q5K => fused_dot_q5k(act_row, row_data, k),
        QuantFormat::Q6K => fused_dot_q6k(act_row, row_data, k),
        _ => unreachable!(),
    }
}

/// Select the Q8_K integer dot kernel for `format` and run it.
///
/// `act_q8k` is one activation row already quantized to Q8_K bytes,
/// `row_data` the quantized bytes of one weight row, and `k` the row length;
/// all three are forwarded unchanged to the chosen kernel. Only Q2K, Q3K,
/// Q4K, Q5K and Q6K are handled — any other format hits `unreachable!()`.
fn fused_dot_q8k_dispatch(act_q8k: &[u8], row_data: &[u8], k: usize, format: QuantFormat) -> f32 {
    use super::simd::{
        fused_q2k_q8k_dot::fused_dot_q2k_q8k, fused_q3k_q8k_dot::fused_dot_q3k_q8k,
        fused_q4k_q8k_dot::fused_dot_q4k_q8k, fused_q5k_q8k_dot::fused_dot_q5k_q8k,
        fused_q6k_q8k_dot::fused_dot_q6k_q8k,
    };

    match format {
        QuantFormat::Q2K => fused_dot_q2k_q8k(act_q8k, row_data, k),
        QuantFormat::Q3K => fused_dot_q3k_q8k(act_q8k, row_data, k),
        QuantFormat::Q4K => fused_dot_q4k_q8k(act_q8k, row_data, k),
        QuantFormat::Q5K => fused_dot_q5k_q8k(act_q8k, row_data, k),
        QuantFormat::Q6K => fused_dot_q6k_q8k(act_q8k, row_data, k),
        _ => unreachable!(),
    }
}

/// Batched quantized matmul: activation \[M, K\] × multiple weight\[Ni, K\]^T → multiple output\[M, Ni\]
///
/// Processes all weight matrices together so the activation stays in L2 cache.
/// For M=1 decode with QKV (3 projections) or gate+up (2 projections), this avoids
/// re-reading the activation vector 3-5x from L3/memory.
///
/// `act`: \[M * K\] f32 values (row-major)
/// `weight_list`: one `(weight_bytes, n)` pair per matrix; each `weight_bytes`
/// holds the quantized \[n, K\] matrix (blocks along K)
/// `outputs`: one buffer per matrix, in the same order as `weight_list`; the
/// buffer for a matrix with `n` rows receives \[M * n\] f32 values (row-major)
/// `m`, `k`: shared activation dimensions
/// `format`: quantization format shared by every weight matrix
///
/// Returns without writing anything when every `n` is zero or `weight_list`
/// is empty.
///
/// # Panics
///
/// Panics if `outputs` has fewer entries than `weight_list`, if any output
/// buffer is shorter than `m * n` for its matrix, or if `act` or a weight
/// buffer is too short for the requested dimensions (slice indexing).
pub fn quant_matmul_batch_f32(
    act: &[f32],
    weight_list: &[(&[u8], usize)], // (weight_bytes, n) per matrix
    outputs: &mut [&mut [f32]],
    m: usize,
    k: usize,
    format: QuantFormat,
) {
    let block_size = format.block_size();
    let block_bytes = format.block_bytes();
    let blocks_per_row = k / block_size;
    let row_bytes = blocks_per_row * block_bytes;

    let use_fused = matches!(
        format,
        QuantFormat::Q2K
            | QuantFormat::Q3K
            | QuantFormat::Q4K
            | QuantFormat::Q5K
            | QuantFormat::Q6K
    );
    let use_q8k = use_fused && k % 256 == 0;

    // Pre-quantize activation rows to Q8_K if using integer path
    let q8k_block_bytes = super::simd::quantize_act_q8k::Q8K_BLOCK_BYTES;
    let q8k_blocks_per_row = k / 256;
    let q8k_row_size = q8k_blocks_per_row * q8k_block_bytes;
    let act_q8k: Vec<u8> = if use_q8k {
        let mut buf = vec![0u8; m * q8k_row_size];
        for i in 0..m {
            let act_row = &act[i * k..(i + 1) * k];
            let q8k_row = &mut buf[i * q8k_row_size..(i + 1) * q8k_row_size];
            super::simd::quantize_act_q8k::quantize_f32_to_q8k(act_row, q8k_row);
        }
        buf
    } else {
        Vec::new()
    };
    // Shared read-only view for the worker closure (a plain borrow is Sync).
    let act_q8k = act_q8k.as_slice();

    // We parallelize over the N dimension (same as single matmul), but each
    // task processes all matrices for its column range before moving on.

    // Find the max N across all weight matrices for chunking
    let max_n: usize = weight_list.iter().map(|&(_, n)| n).max().unwrap_or(0);
    if max_n == 0 {
        return;
    }

    // Validate before any raw-pointer write: every matrix needs an output
    // buffer, and each buffer must cover offsets up to m * n - 1.
    assert!(
        outputs.len() >= weight_list.len(),
        "each weight matrix needs a matching output buffer"
    );

    let num_threads = rayon::current_num_threads();
    let target_chunks = if m == 1 { num_threads } else { num_threads * 4 };
    let chunk_size = max_n.div_ceil(target_chunks);
    let chunk_size = chunk_size.max(16);

    // Output base pointers, carried as usize so the closure is Send + Sync.
    // They are taken from mutable borrows (`as_mut_ptr`) because the workers
    // write through them.
    let output_ptrs: Vec<usize> = outputs
        .iter_mut()
        .zip(weight_list.iter())
        .map(|(out, &(_, n))| {
            assert!(
                out.len() >= m * n,
                "output buffer is shorter than m * n for its weight matrix"
            );
            out.as_mut_ptr() as usize
        })
        .collect();

    // Half-open column ranges [start, end), one per parallel task.
    let col_ranges: Vec<(usize, usize)> = (0..max_n)
        .step_by(chunk_size)
        .map(|start| (start, (start + chunk_size).min(max_n)))
        .collect();

    col_ranges.par_iter().for_each(|&(j_start, j_end)| {
        // Scratch row for the dequant-then-dot path, allocated once per task
        // and reused for every activation row and matrix. The fused paths do
        // not need it.
        let mut dequant_row: Vec<f32> = if use_fused {
            Vec::new()
        } else {
            vec![0.0f32; k]
        };

        // For each activation row
        for i in 0..m {
            // Process all weight matrices for this activation row and column range
            for (&(weights, n), &out_ptr) in weight_list.iter().zip(output_ptrs.iter()) {
                // Matrices narrower than max_n have no columns in later ranges.
                if j_start >= n {
                    continue;
                }
                let j_end_clamped = j_end.min(n);
                let out = out_ptr as *mut f32;

                if use_q8k {
                    let q8k_row = &act_q8k[i * q8k_row_size..(i + 1) * q8k_row_size];
                    for j in j_start..j_end_clamped {
                        let row_data = &weights[j * row_bytes..(j + 1) * row_bytes];
                        let val = fused_dot_q8k_dispatch(q8k_row, row_data, k, format);
                        // SAFETY: i < m and j < n, so i * n + j < m * n, which
                        // is within this output buffer (asserted above). Tasks
                        // own disjoint column ranges and the output buffers
                        // are distinct `&mut` slices, so writes never overlap.
                        unsafe {
                            *out.add(i * n + j) = val;
                        }
                    }
                } else if use_fused {
                    let act_row = &act[i * k..(i + 1) * k];
                    for j in j_start..j_end_clamped {
                        let row_data = &weights[j * row_bytes..(j + 1) * row_bytes];
                        let val = fused_dot_dispatch(act_row, row_data, k, format);
                        // SAFETY: same bounds and disjointness argument as the
                        // Q8_K branch above.
                        unsafe {
                            *out.add(i * n + j) = val;
                        }
                    }
                } else {
                    let act_row = &act[i * k..(i + 1) * k];
                    // Dequantize each weight row into the scratch buffer, then dot
                    for j in j_start..j_end_clamped {
                        let row_data = &weights[j * row_bytes..(j + 1) * row_bytes];
                        dequant_row_f32(row_data, &mut dequant_row, format);
                        let val = dot_f32(act_row, &dequant_row);
                        // SAFETY: same bounds and disjointness argument as the
                        // Q8_K branch above.
                        unsafe {
                            *out.add(i * n + j) = val;
                        }
                    }
                }
            }
        }
    });
}

/// Decode one row of quantized blocks into f32 values.
///
/// `row_bytes` is the packed data for a single row, `output` receives the
/// decoded values, and `format` picks the decoder. Both slices are passed to
/// the decoder unchanged; no length validation happens here.
pub fn dequant_row_f32(row_bytes: &[u8], output: &mut [f32], format: QuantFormat) {
    let (src, dst) = (row_bytes, output);

    match format {
        // K-quant family (Q4K and Q6K decoders live in the `dequant` module)
        QuantFormat::Q2K => dequant_k_quants::dequant_q2k(src, dst),
        QuantFormat::Q3K => dequant_k_quants::dequant_q3k(src, dst),
        QuantFormat::Q4K => dequant::dequant_q4k(src, dst),
        QuantFormat::Q5K => dequant_k_quants::dequant_q5k(src, dst),
        QuantFormat::Q6K => dequant::dequant_q6k(src, dst),
        QuantFormat::Q8K => dequant_k_quants::dequant_q8k(src, dst),

        // Simple block quants
        QuantFormat::Q4_0 => dequant::dequant_q4_0(src, dst),
        QuantFormat::Q4_1 => dequant::dequant_q4_1(src, dst),
        QuantFormat::Q5_0 => dequant::dequant_q5_0(src, dst),
        QuantFormat::Q5_1 => dequant::dequant_q5_1(src, dst),
        QuantFormat::Q8_0 => dequant::dequant_q8_0(src, dst),
        QuantFormat::Q8_1 => dequant::dequant_q8_1(src, dst),

        // IQ formats
        QuantFormat::IQ1S => dequant::dequant_iq1_s(src, dst),
        QuantFormat::IQ1M => dequant::dequant_iq1_m(src, dst),
        QuantFormat::IQ2XXS => dequant::dequant_iq2_xxs(src, dst),
        QuantFormat::IQ2XS => dequant::dequant_iq2_xs(src, dst),
        QuantFormat::IQ2S => dequant::dequant_iq2_s(src, dst),
        QuantFormat::IQ3XXS => dequant::dequant_iq3_xxs(src, dst),
        QuantFormat::IQ3S => dequant::dequant_iq3_s(src, dst),
        QuantFormat::IQ4NL => dequant::dequant_iq4_nl(src, dst),
        QuantFormat::IQ4XS => dequant::dequant_iq4_xs(src, dst),

        // TQ formats
        QuantFormat::TQ1_0 => dequant::dequant_tq1_0(src, dst),
        QuantFormat::TQ2_0 => dequant::dequant_tq2_0(src, dst),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use half::f16;

    /// One 18-byte Q4_0 block: little-endian f16 scale, then 16 bytes that
    /// are all set to `packed`.
    fn q4_0_block(scale: f32, packed: u8) -> [u8; 18] {
        let mut bytes = [packed; 18];
        bytes[..2].copy_from_slice(&f16::from_f32(scale).to_le_bytes());
        bytes
    }

    /// One 34-byte Q8_0 block: little-endian f16 scale, then 32 bytes that
    /// are all set to `quant`.
    fn q8_0_block(scale: f32, quant: u8) -> [u8; 34] {
        let mut bytes = [quant; 34];
        bytes[..2].copy_from_slice(&f16::from_f32(scale).to_le_bytes());
        bytes
    }

    #[test]
    fn test_quant_matmul_q4_0_identity_like() {
        // [1, 32] activation of ones against a single 32-wide weight row.
        // With scale 2.0 and nibbles 0x99 each weight decodes to
        // (9 - 8) * 2.0 = 2.0, so the one output is 32 * 1.0 * 2.0 = 64.0.
        let (m, k, n) = (1, 32, 1);
        let act = vec![1.0f32; m * k];
        let weight = q4_0_block(2.0, 0x99);

        let mut output = vec![0.0f32; m * n];
        quant_matmul_f32(&act, &weight, &mut output, m, k, n, QuantFormat::Q4_0);

        assert!(
            (output[0] - 64.0).abs() < 0.5,
            "expected ~64.0, got {}",
            output[0]
        );
    }

    #[test]
    fn test_quant_matmul_q8_0_2x1() {
        // [2, 32] activation against one 32-wide weight row → [2, 1] output.
        let (m, k, n) = (2, 32, 1);

        // First row is all 1.0, second row is all 0.5.
        let mut act = vec![1.0f32; k];
        act.resize(m * k, 0.5);

        // Scale 0.5 with every quant equal to 4 → each weight is 4 * 0.5 = 2.0.
        let weight = q8_0_block(0.5, 4);

        let mut output = vec![0.0f32; m * n];
        quant_matmul_f32(&act, &weight, &mut output, m, k, n, QuantFormat::Q8_0);

        // Row 0: 32 * 1.0 * 2.0 = 64.0
        assert!(
            (output[0] - 64.0).abs() < 0.5,
            "expected ~64.0, got {}",
            output[0]
        );
        // Row 1: 32 * 0.5 * 2.0 = 32.0
        assert!(
            (output[1] - 32.0).abs() < 0.5,
            "expected ~32.0, got {}",
            output[1]
        );
    }

    #[test]
    fn test_quant_matmul_multiple_output_cols() {
        // [1, 32] activation of ones against two 32-wide weight rows → [1, 2] output.
        let (m, k, n) = (1, 32, 2);
        let act = vec![1.0f32; m * k];

        // Row 0 decodes to all 1.0 (dot = 32.0); row 1 to all 3.0 (dot = 96.0).
        let weight_bytes = [q4_0_block(1.0, 0x99), q4_0_block(3.0, 0x99)].concat();

        let mut output = vec![0.0f32; m * n];
        quant_matmul_f32(&act, &weight_bytes, &mut output, m, k, n, QuantFormat::Q4_0);

        assert!(
            (output[0] - 32.0).abs() < 0.5,
            "expected ~32.0, got {}",
            output[0]
        );
        assert!(
            (output[1] - 96.0).abs() < 0.5,
            "expected ~96.0, got {}",
            output[1]
        );
    }
}