ruvllm 2.2.0

LLM serving runtime with Ruvector integration - Paged attention, KV cache, and SONA learning
//
// GEMV (General Matrix-Vector Multiplication) - Metal Compute Shader
// Optimized for Apple Silicon M4 Pro with simdgroup operations
//
// Computes y = A * x where A is (m x n), x is (n), y is (m)
// Target: 100+ GFLOPS on M4 Pro GPU (vs ~35 GFLOPS CPU)
//
// Optimizations:
// - Simdgroup reduction for efficient parallel reduction
// - Tiled memory access for optimal bandwidth
// - FP16 compute path for 2x throughput
// - Vectorized loads (float4/half4) for coalesced access
// - Optimal threadgroup memory layout for 16KB L1
//
// M4 Pro Specifications:
// - 16KB L1 data cache per core
// - 192KB L2 per core cluster
// - 32-wide SIMD groups
// - 1024 threads per threadgroup max
// - ~3 TFLOPS FP16 compute
//
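// Rough roofline check for the targets above (an estimate, assuming the
// published ~273 GB/s unified memory bandwidth of M4 Pro):
//   FP32 GEMV: 2*m*n FLOPs vs ~4*m*n bytes of A traffic => ~0.5 FLOP/byte
//              => ~135 GFLOPS bandwidth-bound ceiling
//   FP16 GEMV: same FLOPs, half the bytes => the ceiling roughly doubles
// GEMV is memory-bound, so the 100+/200+ GFLOPS targets amount to sustaining
// most of the available memory bandwidth rather than approaching ALU peak.
//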

#include <metal_stdlib>
using namespace metal;

// ============================================================================
// M4 Pro Tuned Constants for GEMV
// ============================================================================

// Threads per output row - optimal for M4 Pro SIMD width
constant uint GEMV_THREADS_PER_ROW = 32;

// Number of rows processed per threadgroup
constant uint GEMV_ROWS_PER_BLOCK = 4;

// Vector elements processed per thread per iteration
constant uint GEMV_ELEMENTS_PER_THREAD = 8;

// Block size for K dimension tiling (fits in threadgroup memory)
constant uint GEMV_K_BLOCK = 256;

// =============================================================================
// GEMV Parameters Structure (matches Rust GemvParams)
// =============================================================================
struct GemvParams {
    uint m;      // Rows of A (output dimension)
    uint n;      // Columns of A (input dimension)
    uint lda;    // Leading dimension of A
    float alpha; // Scale factor (default 1.0)
    float beta;  // Output scale factor (default 0.0, for y = alpha*A*x + beta*y)
};
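
// Layout assumption (not enforced here): the vectorized kernels below
// reinterpret A, x, and y through float4/half4 pointers, which requires the
// bound buffer offsets to be vector-aligned (16 bytes for float4, 8 for half4)
// and lda to be a multiple of 4 so every row starts on a vector boundary.
// Matrices with an odd, unpadded lda should be padded before using these kernels.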

// =============================================================================
// High-Performance FP32 GEMV with simdgroup reduction
// Grid: (tiles_m, 1, 1) where tiles_m = ceil(m / GEMV_ROWS_PER_BLOCK)
// Threadgroup: (GEMV_THREADS_PER_ROW, GEMV_ROWS_PER_BLOCK, 1) = (32, 4, 1) = 128 threads
// Target: 100+ GFLOPS on M4 Pro GPU
// =============================================================================
kernel void gemv_optimized_f32(
    device const float* A [[buffer(0)]],      // Matrix (m x n)
    device const float* x [[buffer(1)]],      // Input vector (n)
    device float* y [[buffer(2)]],            // Output vector (m)
    constant GemvParams& params [[buffer(3)]],
    uint2 gid [[threadgroup_position_in_grid]],
    uint2 tid [[thread_position_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_group [[simdgroup_index_in_threadgroup]]
) {
    const uint row_base = gid.x * GEMV_ROWS_PER_BLOCK;
    const uint local_row = tid.y;
    const uint row = row_base + local_row;

    // Early exit if row out of bounds
    if (row >= params.m) return;

    // Each thread in the row processes a portion of the dot product
    const uint lane = tid.x;  // 0-31
    const uint n = params.n;
    const uint lda = params.lda;

    // Accumulator for this thread's partial sum
    float sum = 0.0f;

    // Get row pointer
    device const float* a_row = A + row * lda;

    // Process elements in vectorized chunks of 4
    const uint vec_start = lane * 4;
    const uint vec_stride = GEMV_THREADS_PER_ROW * 4;  // 128 elements per iteration

    uint col = vec_start;
    while (col + 4 <= n) {
        // Vectorized load from A and x
        float4 a_val = *reinterpret_cast<device const float4*>(a_row + col);
        float4 x_val = *reinterpret_cast<device const float4*>(x + col);

        // Fused multiply-add
        sum = fma(a_val.x, x_val.x, sum);
        sum = fma(a_val.y, x_val.y, sum);
        sum = fma(a_val.z, x_val.z, sum);
        sum = fma(a_val.w, x_val.w, sum);

        col += vec_stride;
    }

    // Handle remaining elements (scalar). The vectorized loop above covers
    // every 4-aligned column below (n / 4) * 4, so the scalar tail starts there.
    col = (n / 4) * 4 + lane;
    while (col < n) {
        sum = fma(a_row[col], x[col], sum);
        col += GEMV_THREADS_PER_ROW;
    }

    // Simdgroup reduction across all 32 lanes
    // simd_sum gives an efficient simdgroup-wide reduction on Apple GPUs
    float row_sum = simd_sum(sum);

    // Lane 0 writes the final result
    if (lane == 0) {
        if (params.beta != 0.0f) {
            y[row] = params.alpha * row_sum + params.beta * y[row];
        } else {
            y[row] = params.alpha * row_sum;
        }
    }
}
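
// Example dispatch geometry for gemv_optimized_f32 (illustrative): for
// m = 4096 the grid is (ceil(4096 / GEMV_ROWS_PER_BLOCK) = 1024, 1, 1) with a
// (32, 4, 1) threadgroup, so each threadgroup produces 4 output rows and each
// row is reduced by a single 32-lane simdgroup.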

// =============================================================================
// High-Performance FP16 GEMV with simdgroup reduction
// Achieves 2x throughput vs FP32 on M4 Pro's FP16 units
// Target: 200+ GFLOPS theoretical on M4 Pro GPU
// =============================================================================
kernel void gemv_optimized_f16(
    device const half* A [[buffer(0)]],       // Matrix (m x n) in FP16
    device const half* x [[buffer(1)]],       // Input vector (n) in FP16
    device half* y [[buffer(2)]],             // Output vector (m) in FP16
    constant GemvParams& params [[buffer(3)]],
    uint2 gid [[threadgroup_position_in_grid]],
    uint2 tid [[thread_position_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_group [[simdgroup_index_in_threadgroup]]
) {
    const uint row_base = gid.x * GEMV_ROWS_PER_BLOCK;
    const uint local_row = tid.y;
    const uint row = row_base + local_row;

    if (row >= params.m) return;

    const uint lane = tid.x;
    const uint n = params.n;
    const uint lda = params.lda;

    // Use FP32 accumulator for precision, FP16 for memory bandwidth
    float sum = 0.0f;

    device const half* a_row = A + row * lda;

    // Process elements in vectorized chunks of 8 (half4 * 2)
    const uint vec_start = lane * 8;
    const uint vec_stride = GEMV_THREADS_PER_ROW * 8;  // 256 elements per iteration

    uint col = vec_start;
    while (col + 8 <= n) {
        // Vectorized load from A and x (half4 for optimal bandwidth)
        half4 a_val0 = *reinterpret_cast<device const half4*>(a_row + col);
        half4 a_val1 = *reinterpret_cast<device const half4*>(a_row + col + 4);
        half4 x_val0 = *reinterpret_cast<device const half4*>(x + col);
        half4 x_val1 = *reinterpret_cast<device const half4*>(x + col + 4);

        // Accumulate in FP32 for precision
        sum = fma(float(a_val0.x), float(x_val0.x), sum);
        sum = fma(float(a_val0.y), float(x_val0.y), sum);
        sum = fma(float(a_val0.z), float(x_val0.z), sum);
        sum = fma(float(a_val0.w), float(x_val0.w), sum);
        sum = fma(float(a_val1.x), float(x_val1.x), sum);
        sum = fma(float(a_val1.y), float(x_val1.y), sum);
        sum = fma(float(a_val1.z), float(x_val1.z), sum);
        sum = fma(float(a_val1.w), float(x_val1.w), sum);

        col += vec_stride;
    }

    // Handle remaining chunks of 4
    while (col + 4 <= n) {
        half4 a_val = *reinterpret_cast<device const half4*>(a_row + col);
        half4 x_val = *reinterpret_cast<device const half4*>(x + col);

        sum = fma(float(a_val.x), float(x_val.x), sum);
        sum = fma(float(a_val.y), float(x_val.y), sum);
        sum = fma(float(a_val.z), float(x_val.z), sum);
        sum = fma(float(a_val.w), float(x_val.w), sum);

        col += GEMV_THREADS_PER_ROW * 4;
    }

    // Handle remaining scalar elements
    col = (n / 4) * 4 + lane;
    while (col < n) {
        sum = fma(float(a_row[col]), float(x[col]), sum);
        col += GEMV_THREADS_PER_ROW;
    }

    // Simdgroup reduction
    float row_sum = simd_sum(sum);

    // Lane 0 writes result
    if (lane == 0) {
        half alpha_h = half(params.alpha);
        if (params.beta != 0.0f) {
            half beta_h = half(params.beta);
            y[row] = alpha_h * half(row_sum) + beta_h * y[row];
        } else {
            y[row] = alpha_h * half(row_sum);
        }
    }
}

// =============================================================================
// Batched GEMV for multi-head attention (FP32)
// Each batch is an independent GEMV: y[b] = A[b] * x[b]
// Grid: (tiles_m, batch_size, 1)
// =============================================================================
kernel void batched_gemv_f32(
    device const float* A [[buffer(0)]],      // Batched matrices (batch, m, n)
    device const float* x [[buffer(1)]],      // Batched vectors (batch, n)
    device float* y [[buffer(2)]],            // Output (batch, m)
    constant uint4& dims [[buffer(3)]],       // (m, n, batch_size, 0)
    uint2 gid [[threadgroup_position_in_grid]],
    uint2 tid [[thread_position_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]]
) {
    const uint m = dims.x;
    const uint n = dims.y;
    const uint batch = gid.y;

    if (batch >= dims.z) return;

    const uint row_base = gid.x * GEMV_ROWS_PER_BLOCK;
    const uint local_row = tid.y;
    const uint row = row_base + local_row;

    if (row >= m) return;

    const uint lane = tid.x;

    // Get batch offsets
    device const float* a_batch = A + batch * m * n;
    device const float* x_batch = x + batch * n;
    device float* y_batch = y + batch * m;

    device const float* a_row = a_batch + row * n;

    float sum = 0.0f;

    // Vectorized processing
    uint col = lane * 4;
    const uint vec_stride = GEMV_THREADS_PER_ROW * 4;

    while (col + 4 <= n) {
        float4 a_val = *reinterpret_cast<device const float4*>(a_row + col);
        float4 x_val = *reinterpret_cast<device const float4*>(x_batch + col);

        sum = fma(a_val.x, x_val.x, sum);
        sum = fma(a_val.y, x_val.y, sum);
        sum = fma(a_val.z, x_val.z, sum);
        sum = fma(a_val.w, x_val.w, sum);

        col += vec_stride;
    }

    // Remaining elements
    for (uint c = (n / 4) * 4 + lane; c < n; c += GEMV_THREADS_PER_ROW) {
        sum = fma(a_row[c], x_batch[c], sum);
    }

    // Simdgroup reduction
    float row_sum = simd_sum(sum);

    if (lane == 0) {
        y_batch[row] = row_sum;
    }
}

// =============================================================================
// Tiled GEMV with shared memory for larger K dimensions
// Uses threadgroup memory to cache x vector for reuse across rows
// Grid: (tiles_m, 1, 1)
// Threadgroup: (256, 1, 1) - 8 rows * 32 threads
// =============================================================================
constant uint GEMV_TILED_ROWS = 8;
constant uint GEMV_TILED_THREADS = 32;
constant uint GEMV_TILED_K_BLOCK = 512;
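
// shared_x below occupies GEMV_TILED_K_BLOCK * 4 bytes = 2 KB of threadgroup
// memory, well within the 32 KB threadgroup-memory limit on Apple GPUs.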

kernel void gemv_tiled_f32(
    device const float* A [[buffer(0)]],
    device const float* x [[buffer(1)]],
    device float* y [[buffer(2)]],
    constant GemvParams& params [[buffer(3)]],
    uint gid [[threadgroup_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_group [[simdgroup_index_in_threadgroup]]
) {
    // Shared memory for x vector tile
    threadgroup float shared_x[GEMV_TILED_K_BLOCK];

    const uint row = gid * GEMV_TILED_ROWS + simd_group;
    // No early return here: every thread must reach the threadgroup barriers
    // below, so out-of-range rows stay resident but skip their own
    // accumulation and store.
    const bool row_valid = (row < params.m);

    const uint lane = simd_lane;
    const uint n = params.n;
    const uint lda = params.lda;

    // Clamp the row used for address arithmetic; invalid rows never read it.
    device const float* a_row = A + min(row, params.m - 1) * lda;

    float sum = 0.0f;

    // Process K in blocks to maximize cache reuse
    for (uint k_block = 0; k_block < n; k_block += GEMV_TILED_K_BLOCK) {
        // Cooperative load of x into shared memory
        const uint load_start = tid;
        const uint block_size = min(GEMV_TILED_K_BLOCK, n - k_block);

        for (uint i = load_start; i < block_size; i += GEMV_TILED_ROWS * GEMV_TILED_THREADS) {
            shared_x[i] = x[k_block + i];
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (row_valid) {
            // Process block with vectorization
            uint col = lane * 4;
            while (col + 4 <= block_size) {
                float4 a_val = *reinterpret_cast<device const float4*>(a_row + k_block + col);
                float4 x_val = *reinterpret_cast<threadgroup const float4*>(shared_x + col);

                sum = fma(a_val.x, x_val.x, sum);
                sum = fma(a_val.y, x_val.y, sum);
                sum = fma(a_val.z, x_val.z, sum);
                sum = fma(a_val.w, x_val.w, sum);

                col += GEMV_TILED_THREADS * 4;
            }

            // Remaining elements: the vectorized loop covers every 4-aligned
            // column below (block_size / 4) * 4, so the scalar tail starts there.
            for (uint c = (block_size / 4) * 4 + lane; c < block_size; c += GEMV_TILED_THREADS) {
                sum = fma(a_row[k_block + c], shared_x[c], sum);
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Simdgroup reduction
    float row_sum = simd_sum(sum);

    if (row_valid && lane == 0) {
        if (params.beta != 0.0f) {
            y[row] = params.alpha * row_sum + params.beta * y[row];
        } else {
            y[row] = params.alpha * row_sum;
        }
    }
}

// =============================================================================
// Simple GEMV for compatibility (single thread per row)
// Grid: (m, 1, 1)
// =============================================================================
kernel void gemv_simple_f32(
    device const float* A [[buffer(0)]],
    device const float* x [[buffer(1)]],
    device float* y [[buffer(2)]],
    constant GemvParams& params [[buffer(3)]],
    uint gid [[thread_position_in_grid]]
) {
    if (gid >= params.m) return;

    device const float* a_row = A + gid * params.lda;

    float sum = 0.0f;

    // Vectorized loop
    const uint n_vec = params.n / 4;
    for (uint i = 0; i < n_vec; i++) {
        float4 a_val = *reinterpret_cast<device const float4*>(a_row + i * 4);
        float4 x_val = *reinterpret_cast<device const float4*>(x + i * 4);
        sum = fma(a_val.x, x_val.x, sum);
        sum = fma(a_val.y, x_val.y, sum);
        sum = fma(a_val.z, x_val.z, sum);
        sum = fma(a_val.w, x_val.w, sum);
    }

    // Remainder
    for (uint i = n_vec * 4; i < params.n; i++) {
        sum = fma(a_row[i], x[i], sum);
    }

    if (params.beta != 0.0f) {
        y[gid] = params.alpha * sum + params.beta * y[gid];
    } else {
        y[gid] = params.alpha * sum;
    }
}

// =============================================================================
// Mixed precision GEMV: FP16 matrix, FP32 vector -> FP32 output
// For inference with quantized weights
// =============================================================================
kernel void gemv_mixed_f16_f32(
    device const half* A [[buffer(0)]],       // Matrix in FP16
    device const float* x [[buffer(1)]],      // Vector in FP32
    device float* y [[buffer(2)]],            // Output in FP32
    constant GemvParams& params [[buffer(3)]],
    uint2 gid [[threadgroup_position_in_grid]],
    uint2 tid [[thread_position_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]]
) {
    const uint row_base = gid.x * GEMV_ROWS_PER_BLOCK;
    const uint local_row = tid.y;
    const uint row = row_base + local_row;

    if (row >= params.m) return;

    const uint lane = tid.x;
    const uint n = params.n;
    const uint lda = params.lda;

    float sum = 0.0f;

    device const half* a_row = A + row * lda;

    // Process in chunks of 8 (half4 * 2)
    uint col = lane * 8;
    const uint vec_stride = GEMV_THREADS_PER_ROW * 8;

    while (col + 8 <= n) {
        half4 a_val0 = *reinterpret_cast<device const half4*>(a_row + col);
        half4 a_val1 = *reinterpret_cast<device const half4*>(a_row + col + 4);
        float4 x_val0 = *reinterpret_cast<device const float4*>(x + col);
        float4 x_val1 = *reinterpret_cast<device const float4*>(x + col + 4);

        sum = fma(float(a_val0.x), x_val0.x, sum);
        sum = fma(float(a_val0.y), x_val0.y, sum);
        sum = fma(float(a_val0.z), x_val0.z, sum);
        sum = fma(float(a_val0.w), x_val0.w, sum);
        sum = fma(float(a_val1.x), x_val1.x, sum);
        sum = fma(float(a_val1.y), x_val1.y, sum);
        sum = fma(float(a_val1.z), x_val1.z, sum);
        sum = fma(float(a_val1.w), x_val1.w, sum);

        col += vec_stride;
    }

    // Remaining elements
    for (uint c = (n / 8) * 8 + lane; c < n; c += GEMV_THREADS_PER_ROW) {
        sum = fma(float(a_row[c]), x[c], sum);
    }

    // Simdgroup reduction
    float row_sum = simd_sum(sum);

    if (lane == 0) {
        if (params.beta != 0.0f) {
            y[row] = params.alpha * row_sum + params.beta * y[row];
        } else {
            y[row] = params.alpha * row_sum;
        }
    }
}