ruvllm 2.1.0

LLM serving runtime with Ruvector integration - Paged attention, KV cache, and SONA learning
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
//
// Rotary Position Embeddings (RoPE) - Metal Compute Shader
// Optimized for Apple Silicon M4 Pro
//
// Applies rotary embeddings to query and key tensors for position encoding.
// Used in LLaMA, Mistral, and other modern transformer architectures.
//
// Optimizations:
// - Precomputed sin/cos tables in constant memory
// - Batch processing of multiple positions
// - Vectorized memory access (float2/half2)
// - Fused rotation operations
//

#include <metal_stdlib>
using namespace metal;

// Compile-time limits shared by this file's RoPE kernels.
// NOTE(review): none of these is referenced by the kernels visible in this
// file, and no kernel checks its inputs against them — confirm whether the
// host relies on them before changing values.
constant uint SIMD_SIZE = 32;        // presumably the Apple GPU SIMD-group width
constant uint MAX_HEAD_DIM = 128;    // Maximum head dimension
constant uint MAX_SEQ_LEN = 8192;    // Maximum sequence length

// RoPE parameters structure (matches Rust RopeParams).
// Four 4-byte fields (16 bytes) bound as a constant buffer; field order and
// types form the host/GPU contract, so keep them in sync with the Rust side.
struct RopeParams {
    uint head_dim;      // Head dimension (must be even); kernels work on head_dim/2 pairs
    uint num_heads;     // Number of heads per batch entry (gid.y range)
    uint position;      // Current position; row of the cos/sin tables for table-driven kernels
    float theta_base;   // Base for frequency calculation (default 10000); read only by kernels computing frequencies inline
};

// Extended RoPE parameters for batch processing (read by apply_rope_batched_v2).
// Eight 4-byte slots (32 bytes including the explicit padding); presumably this
// mirrors a host-side struct — keep field order and sizes in sync with it.
struct RopeBatchParams {
    uint head_dim;      // Head dimension (even); kernels index head_dim/2 pairs
    uint num_heads;     // Heads per token
    uint seq_len;       // Tokens per batch entry; splits gid.z into (batch, seq_pos)
    uint batch_size;    // Batch entries; threads decoding to batch >= batch_size exit
    float theta_base;   // Not read by apply_rope_batched_v2 (it uses precomputed tables)
    uint _padding[3];   // Explicit padding; never read
};

// =============================================================================
// Apply RoPE with precomputed sin/cos tables (fastest version)
// Tables should be precomputed once per model load
// =============================================================================
// Dispatch: gid.x = rotation-pair index in [0, head_dim/2), gid.y = head,
// gid.z = batch. Each (batch, head) row of x holds head_dim floats and pair p
// rotates the interleaved elements (2p, 2p+1) in place. cos/sin come from row
// params.position of tables laid out [position, head_dim/2]; every thread uses
// that same position.
kernel void apply_rope_precomputed(
    device float* x [[buffer(0)]],
    constant float* cos_table [[buffer(1)]],  // [position, head_dim/2] in constant memory
    constant float* sin_table [[buffer(2)]],  // [position, head_dim/2] in constant memory
    constant RopeParams& params [[buffer(3)]],
    uint3 gid [[thread_position_in_grid]]
) {
    const uint pairs_per_head = params.head_dim / 2;
    const uint pair = gid.x;
    const uint h = gid.y;

    // Threads past the pair or head range (rounded-up grid) do nothing.
    const bool active = (pair < pairs_per_head) && (h < params.num_heads);
    if (!active) {
        return;
    }

    // Angle for this pair at the shared position.
    const uint t = params.position * pairs_per_head + pair;
    const float c = cos_table[t];
    const float s = sin_table[t];

    // Flat indices of the two interleaved elements this thread rotates.
    const uint row_base = (gid.z * params.num_heads + h) * params.head_dim;
    const uint i_even = row_base + 2 * pair;
    const uint i_odd = i_even + 1;

    const float v0 = x[i_even];
    const float v1 = x[i_odd];

    // (v0, v1) -> (v0*c - v1*s, v0*s + v1*c)
    x[i_even] = fma(v0, c, -v1 * s);
    x[i_odd] = fma(v0, s, v1 * c);
}

// =============================================================================
// Vectorized RoPE using float2 for paired elements
// =============================================================================
// Same rotation and dispatch as apply_rope_precomputed, but x is bound as
// float2 so each thread moves one interleaved pair with a single load/store.
// The float2 slot for (batch, head, pair) is
// (batch * num_heads + head) * (head_dim/2) + pair.
kernel void apply_rope_vec2(
    device float2* x [[buffer(0)]],  // Reinterpret as pairs
    constant float* cos_table [[buffer(1)]],
    constant float* sin_table [[buffer(2)]],
    constant RopeParams& params [[buffer(3)]],
    uint3 gid [[thread_position_in_grid]]
) {
    const uint pairs_per_head = params.head_dim / 2;
    const uint pair = gid.x;
    const uint h = gid.y;

    if (!(pair < pairs_per_head && h < params.num_heads)) {
        return;
    }

    const uint slot = (gid.z * params.num_heads + h) * pairs_per_head + pair;
    const float2 v = x[slot];

    const uint t = params.position * pairs_per_head + pair;
    const float c = cos_table[t];
    const float s = sin_table[t];

    // (v.x, v.y) -> (v.x*c - v.y*s, v.x*s + v.y*c)
    x[slot] = float2(fma(v.x, c, -v.y * s),
                     fma(v.x, s, v.y * c));
}

// =============================================================================
// Original RoPE with precomputed tables (kept for compatibility)
// =============================================================================
// Dispatch: gid.x = pair index d in [0, head_dim/2), gid.y = head, gid.z = batch.
// Rotates the interleaved elements (2d, 2d+1) of each (batch, head) row in place.
//
// Unlike apply_rope_precomputed, the tables are indexed by d alone and
// params.position is not read, so the bound cos/sin buffers must already hold
// the head_dim/2 values for the target position.
// NOTE(review): presumably the host binds the tables with a per-position
// offset — confirm before changing the indexing.
kernel void apply_rope(
    device float* x [[buffer(0)]],
    device const float* cos_table [[buffer(1)]],
    device const float* sin_table [[buffer(2)]],
    constant RopeParams& params [[buffer(3)]],
    uint3 gid [[thread_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]]
) {
    uint d = gid.x;
    uint head = gid.y;
    uint batch = gid.z;

    uint head_dim = params.head_dim;
    uint num_heads = params.num_heads;

    // Head guard: on a rounded-up grid, a head index >= num_heads would alias
    // the next batch entry's heads (rotating them twice) or run past the end
    // of x for the last batch entry.
    if (d >= head_dim / 2 || head >= num_heads) {
        return;
    }

    uint offset = (batch * num_heads + head) * head_dim;

    float x0 = x[offset + 2 * d];
    float x1 = x[offset + 2 * d + 1];

    float cos_val = cos_table[d];
    float sin_val = sin_table[d];

    // (x0, x1) -> (x0*cos - x1*sin, x0*sin + x1*cos)
    x[offset + 2 * d] = fma(x0, cos_val, -x1 * sin_val);
    x[offset + 2 * d + 1] = fma(x0, sin_val, x1 * cos_val);
}

// =============================================================================
// RoPE with inline frequency computation (no precomputed tables)
// Useful when tables aren't available or for dynamic positions
// =============================================================================
// Dispatch: gid.x = pair index d in [0, head_dim/2), gid.y = head, gid.z = batch.
// Every thread rotates its interleaved pair (2d, 2d+1) at params.position with
// angle = position * theta_base^(-2d / head_dim).
kernel void apply_rope_inline(
    device float* x [[buffer(0)]],
    constant RopeParams& params [[buffer(1)]],
    uint3 gid [[thread_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]]
) {
    uint d = gid.x;
    uint head = gid.y;
    uint batch = gid.z;

    uint head_dim = params.head_dim;
    uint num_heads = params.num_heads;
    uint position = params.position;
    float theta_base = params.theta_base;

    // Head guard: on a rounded-up grid, a head index >= num_heads would alias
    // the next batch entry's heads (rotating them twice) or run past the end
    // of x for the last batch entry.
    if (d >= head_dim / 2 || head >= num_heads) {
        return;
    }

    uint offset = (batch * num_heads + head) * head_dim;

    // freq = 1 / (theta_base ^ (2d / head_dim))
    float freq = 1.0f / pow(theta_base, float(2 * d) / float(head_dim));
    float angle = float(position) * freq;

    // precise:: instead of the default fast-math sin/cos: fast-math accuracy is
    // only specified on [-pi, pi], while angle grows linearly with position
    // (it equals position radians at d == 0).
    float cos_val = precise::cos(angle);
    float sin_val = precise::sin(angle);

    float x0 = x[offset + 2 * d];
    float x1 = x[offset + 2 * d + 1];

    x[offset + 2 * d] = fma(x0, cos_val, -x1 * sin_val);
    x[offset + 2 * d + 1] = fma(x0, sin_val, x1 * cos_val);
}

// =============================================================================
// Batched RoPE for multiple positions (efficient for prefill)
// Processes entire sequences in parallel with precomputed tables
// =============================================================================
// Dispatch: gid.x = pair index in [0, head_dim/2), gid.y = head,
// gid.z = batch * seq_len + seq_pos. Each token looks up its own position in
// `positions` and uses that row of the cos/sin tables. Threads whose decoded
// batch index is >= params.batch_size exit without touching memory.
kernel void apply_rope_batched_v2(
    device float* x [[buffer(0)]],                // [batch, seq_len, num_heads, head_dim]
    constant float* cos_table [[buffer(1)]],      // [max_seq_len, head_dim/2]
    constant float* sin_table [[buffer(2)]],      // [max_seq_len, head_dim/2]
    device const uint* positions [[buffer(3)]],   // [batch, seq_len] position indices
    constant RopeBatchParams& params [[buffer(4)]],
    uint3 gid [[thread_position_in_grid]]
) {
    const uint pairs_per_head = params.head_dim / 2;
    const uint pair = gid.x;
    const uint h = gid.y;

    if (pair >= pairs_per_head || h >= params.num_heads) {
        return;
    }

    // Split the combined z index into (batch, seq_pos).
    const uint seq_len = params.seq_len;
    const uint b_idx = gid.z / seq_len;
    if (b_idx >= params.batch_size) {
        return;
    }
    const uint s_idx = gid.z % seq_len;
    const uint token = b_idx * seq_len + s_idx;

    // Per-token position selects the table row.
    const uint t = positions[token] * pairs_per_head + pair;
    const float c = cos_table[t];
    const float s = sin_table[t];

    // Index of the even element of this thread's interleaved pair.
    const uint even = ((token * params.num_heads + h) * params.head_dim) + 2 * pair;
    const float v0 = x[even];
    const float v1 = x[even + 1];

    x[even] = fma(v0, c, -v1 * s);
    x[even + 1] = fma(v0, s, v1 * c);
}

// =============================================================================
// Original batched RoPE (kept for compatibility)
// =============================================================================
// Computes cos/sin inline (no tables) from per-token positions.
// Dispatch: gid.x = pair index d in [0, head_dim/2), gid.y = head,
// gid.z = batch * seq_len + seq_pos. x is indexed as
// [batch, seq_len, num_heads, head_dim]; positions as [batch, seq_len].
// There is no batch-size argument, so the host must not dispatch gid.z past
// batch * seq_len.
kernel void apply_rope_batched(
    device float* x [[buffer(0)]],
    device const uint* positions [[buffer(1)]],
    constant uint& num_heads [[buffer(2)]],
    constant uint& head_dim [[buffer(3)]],
    constant uint& seq_len [[buffer(4)]],
    constant float& theta_base [[buffer(5)]],
    uint3 gid [[thread_position_in_grid]],
    uint3 group_id [[threadgroup_position_in_grid]]
) {
    uint d = gid.x;
    uint head = gid.y;
    uint seq_batch = gid.z;

    // Reject out-of-range pairs and heads (rounded-up grids; a stray head
    // index would alias the next token's heads) and an empty sequence, which
    // would otherwise divide by zero below.
    if (d >= head_dim / 2 || head >= num_heads || seq_len == 0) {
        return;
    }

    uint batch = seq_batch / seq_len;
    uint seq_pos = seq_batch % seq_len;

    uint position = positions[batch * seq_len + seq_pos];
    uint offset = ((batch * seq_len + seq_pos) * num_heads + head) * head_dim;

    // freq = theta_base^(-2d / head_dim)
    float freq = 1.0f / pow(theta_base, float(2 * d) / float(head_dim));
    float angle = float(position) * freq;

    // precise:: variants: fast-math sin/cos accuracy is only specified on
    // [-pi, pi], and angle grows linearly with position.
    float cos_val = precise::cos(angle);
    float sin_val = precise::sin(angle);

    float x0 = x[offset + 2 * d];
    float x1 = x[offset + 2 * d + 1];

    x[offset + 2 * d] = fma(x0, cos_val, -x1 * sin_val);
    x[offset + 2 * d + 1] = fma(x0, sin_val, x1 * cos_val);
}

// =============================================================================
// FP16 RoPE with half2 vectorization
// =============================================================================
// Dispatch: gid.x = pair index d in [0, head_dim/2), gid.y = head, gid.z = batch.
// x is bound as half2 so each thread loads/stores one interleaved pair; the
// half cos/sin tables are indexed [position, head_dim/2] at params.position.
kernel void apply_rope_f16_v2(
    device half2* x [[buffer(0)]],  // Reinterpret as pairs for vectorized access
    constant half* cos_table [[buffer(1)]],
    constant half* sin_table [[buffer(2)]],
    constant RopeParams& params [[buffer(3)]],
    uint3 gid [[thread_position_in_grid]]
) {
    uint d = gid.x;
    uint head = gid.y;
    uint batch = gid.z;

    // Head guard: on a rounded-up grid, a head index >= num_heads would alias
    // the next batch entry's heads (rotating them twice) or run past the end
    // of x for the last batch entry.
    if (d >= params.head_dim / 2 || head >= params.num_heads) {
        return;
    }

    uint offset = (batch * params.num_heads + head) * (params.head_dim / 2) + d;
    uint table_offset = params.position * (params.head_dim / 2) + d;

    // Rotate in float and round to half once at the store, instead of also
    // rounding the intermediate products in half.
    float2 val = float2(x[offset]);
    float cos_val = float(cos_table[table_offset]);
    float sin_val = float(sin_table[table_offset]);

    float2 rotated;
    rotated.x = fma(val.x, cos_val, -val.y * sin_val);
    rotated.y = fma(val.x, sin_val, val.y * cos_val);

    x[offset] = half2(rotated);
}

// =============================================================================
// Original FP16 RoPE (kept for compatibility)
// =============================================================================
// Dispatch: gid.x = pair index d in [0, head_dim/2), gid.y = head, gid.z = batch.
// Rotates the interleaved half elements (2d, 2d+1) of each (batch, head) row.
// As in apply_rope, the tables are indexed by d alone and params.position is
// not read. NOTE(review): presumably the host binds per-position table slices
// — confirm.
kernel void apply_rope_f16(
    device half* x [[buffer(0)]],
    device const half* cos_table [[buffer(1)]],
    device const half* sin_table [[buffer(2)]],
    constant RopeParams& params [[buffer(3)]],
    uint3 gid [[thread_position_in_grid]]
) {
    uint d = gid.x;
    uint head = gid.y;
    uint batch = gid.z;

    // Head guard: on a rounded-up grid, a head index >= num_heads would alias
    // the next batch entry's heads (rotating them twice) or run past the end
    // of x for the last batch entry.
    if (d >= params.head_dim / 2 || head >= params.num_heads) {
        return;
    }

    uint offset = (batch * params.num_heads + head) * params.head_dim;

    // Rotate in float and round to half once at the store.
    float x0 = float(x[offset + 2 * d]);
    float x1 = float(x[offset + 2 * d + 1]);

    float cos_val = float(cos_table[d]);
    float sin_val = float(sin_table[d]);

    x[offset + 2 * d] = half(fma(x0, cos_val, -x1 * sin_val));
    x[offset + 2 * d + 1] = half(fma(x0, sin_val, x1 * cos_val));
}

// =============================================================================
// Precompute RoPE cos/sin tables (run once per model load)
// Output can be stored in constant memory for fast access
// =============================================================================
// Dispatch: gid.x = pair index d in [0, head_dim/2), gid.y = position in
// [0, max_seq_len). Writes cos/sin(pos * theta_base^(-2d/head_dim)) into
// row-major [max_seq_len, head_dim/2] tables, the layout the table-driven
// kernels index as position * (head_dim/2) + d.
// tid and simd_lane are accepted but unused.
kernel void precompute_rope_tables_v2(
    device float* cos_table [[buffer(0)]],  // [max_seq_len, head_dim/2]
    device float* sin_table [[buffer(1)]],  // [max_seq_len, head_dim/2]
    constant uint& head_dim [[buffer(2)]],
    constant uint& max_seq_len [[buffer(3)]],
    constant float& theta_base [[buffer(4)]],
    uint2 gid [[thread_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]]
) {
    uint pos = gid.y;
    uint d = gid.x;

    if (pos >= max_seq_len || d >= head_dim / 2) {
        return;
    }

    // This runs once per model load, so use precise:: math throughout:
    // - fast-math sin/cos accuracy is only specified on [-pi, pi], yet angle
    //   reaches ~max_seq_len radians at d == 0;
    // - fast-math pow is an exp2/log2 approximation.
    // theta_base^(-2d/head_dim) directly avoids a separate reciprocal.
    float inv_freq = precise::pow(theta_base, -float(2 * d) / float(head_dim));
    float angle = float(pos) * inv_freq;

    uint idx = pos * (head_dim / 2) + d;
    cos_table[idx] = precise::cos(angle);
    sin_table[idx] = precise::sin(angle);
}

// =============================================================================
// Original table precomputation (kept for compatibility)
// =============================================================================
// Dispatch: gid.x = pair index d in [0, head_dim/2), gid.y = position in
// [0, max_seq_len). Fills row-major [max_seq_len, head_dim/2] tables with
// cos/sin(pos * theta_base^(-2d/head_dim)).
kernel void precompute_rope_tables(
    device float* cos_table [[buffer(0)]],
    device float* sin_table [[buffer(1)]],
    constant uint& head_dim [[buffer(2)]],
    constant uint& max_seq_len [[buffer(3)]],
    constant float& theta_base [[buffer(4)]],
    uint2 gid [[thread_position_in_grid]]
) {
    uint pos = gid.y;
    uint d = gid.x;

    if (pos >= max_seq_len || d >= head_dim / 2) {
        return;
    }

    // One-time work: use precise:: math. Fast-math sin/cos accuracy is only
    // specified on [-pi, pi], and angle reaches ~max_seq_len radians at d == 0.
    float freq = precise::pow(theta_base, -float(2 * d) / float(head_dim));
    float angle = float(pos) * freq;

    uint idx = pos * (head_dim / 2) + d;
    cos_table[idx] = precise::cos(angle);
    sin_table[idx] = precise::sin(angle);
}

// =============================================================================
// ALiBi (Attention with Linear Biases) - alternative to RoPE
// =============================================================================
// Adds a per-head linear distance bias to attention scores in place.
// Dispatch: gid.x = key index in [0, kv_len), gid.y = query index in
// [0, seq_len), gid.z = batch * num_heads + head. attn_scores is indexed as
// [batch, num_heads, seq_len, kv_len].
kernel void apply_alibi(
    device float* attn_scores [[buffer(0)]],
    constant uint& seq_len [[buffer(1)]],
    constant uint& kv_len [[buffer(2)]],
    constant uint& num_heads [[buffer(3)]],
    uint3 gid [[thread_position_in_grid]]
) {
    uint q_pos = gid.y;
    uint k_pos = gid.x;
    uint batch_head = gid.z;

    // num_heads == 0 would divide by zero below.
    if (num_heads == 0 || q_pos >= seq_len || k_pos >= kv_len) {
        return;
    }

    uint batch = batch_head / num_heads;
    uint head = batch_head % num_heads;

    // Head slopes from the ALiBi paper. n is the largest power of two
    // <= num_heads.
    // - Heads [0, n) get 2^(-8(h+1)/n). This equals 2^(-8(h+1)/H) whenever
    //   num_heads is a power of two.
    // - Remaining heads take the odd-indexed slopes of the 2n-head schedule,
    //   2^(-4(2k+1)/n) for k = h - n.
    uint n = 1u << (31u - clz(num_heads));
    float slope = (head < n)
        ? exp2(-8.0f * float(head + 1) / float(n))
        : exp2(-4.0f * float(2 * (head - n) + 1) / float(n));

    // ALiBi penalizes distance: key k seen from query q gets slope * (k - q),
    // which is <= 0 and shrinks as the key gets further into the past.
    // Assuming these scores feed a row-wise softmax, a per-row constant shift
    // is harmless, so the local query index works even when kv_len > seq_len.
    int pos_diff = int(k_pos) - int(q_pos);
    float bias = slope * float(pos_diff);

    uint idx = ((batch * num_heads + head) * seq_len + q_pos) * kv_len + k_pos;
    attn_scores[idx] += bias;
}

// =============================================================================
// YaRN (Yet another RoPE extension) for extended context
// =============================================================================
// Parameters for apply_rope_yarn: nine 4-byte fields (36 bytes). Presumably
// mirrors a host-side struct — keep field order and types in sync with it.
struct YaRNParams {
    uint head_dim;         // Head dimension (even); kernels work on head_dim/2 pairs
    uint num_heads;        // Heads per batch entry
    uint position;         // Position all threads rotate at
    float theta_base;      // Base for frequency calculation
    float scale;           // Position scale factor; fully-interpolated frequencies are divided by it
    float attn_scale;      // Attention scale factor. NOTE(review): not read by apply_rope_yarn — confirm where it is applied
    float beta_fast;       // High-frequency extrapolation factor; wavelengths < original_max_len / beta_fast stay unscaled
    float beta_slow;       // Low-frequency interpolation factor; wavelengths > original_max_len / beta_slow are fully scaled
    uint original_max_len; // Original training context length
};

// YaRN-scaled RoPE at a single position (params.position).
// Dispatch: gid.x = pair index d in [0, head_dim/2), gid.y = head, gid.z = batch.
// Each pair's base frequency is blended between unscaled (short wavelengths)
// and divided by params.scale (long wavelengths) with a linear ramp in between.
// NOTE(review): params.attn_scale is not read here. If YaRN's attention
// temperature is meant to be folded into RoPE (by scaling cos/sin), it is
// missing — confirm the host applies it elsewhere.
kernel void apply_rope_yarn(
    device float* x [[buffer(0)]],
    constant YaRNParams& params [[buffer(1)]],
    uint3 gid [[thread_position_in_grid]]
) {
    uint d = gid.x;
    uint head = gid.y;
    uint batch = gid.z;

    // Head guard: on a rounded-up grid, a head index >= num_heads would alias
    // the next batch entry's heads (rotating them twice) or run past the end
    // of x for the last batch entry.
    if (d >= params.head_dim / 2 || head >= params.num_heads) {
        return;
    }

    uint offset = (batch * params.num_heads + head) * params.head_dim;

    // YaRN frequency scaling
    float freq_base = 1.0f / pow(params.theta_base, float(2 * d) / float(params.head_dim));
    float wavelength = 2.0f * M_PI_F / freq_base;

    // Compute ramp function
    float low = float(params.original_max_len) / params.beta_fast;
    float high = float(params.original_max_len) / params.beta_slow;

    float ramp;
    if (wavelength < low) {
        ramp = 0.0f;  // High frequency: extrapolate
    } else if (wavelength > high) {
        ramp = 1.0f;  // Low frequency: interpolate
    } else {
        // low <= wavelength <= high. A zero-width band (low == high) would be
        // 0/0 = NaN, so treat it as the extrapolation edge.
        float span = high - low;
        ramp = span > 0.0f ? (wavelength - low) / span : 0.0f;
    }

    // Scale frequency with ramp interpolation
    float freq = freq_base * (1.0f - ramp + ramp / params.scale);
    float angle = float(params.position) * freq;

    // precise:: variants: fast-math sin/cos accuracy is only specified on
    // [-pi, pi], and angle grows with position.
    float cos_val = precise::cos(angle);
    float sin_val = precise::sin(angle);

    float x0 = x[offset + 2 * d];
    float x1 = x[offset + 2 * d + 1];

    x[offset + 2 * d] = fma(x0, cos_val, -x1 * sin_val);
    x[offset + 2 * d + 1] = fma(x0, sin_val, x1 * cos_val);
}

// =============================================================================
// Fused RoPE for Q and K together (common pattern in transformers)
// Applies RoPE to both query and key tensors in a single kernel launch
// =============================================================================
// Dispatch: gid.x = pair index in [0, head_dim/2); gid.y = head index, which
// should cover max(num_q_heads, num_kv_heads); gid.z = batch * seq_len + seq_pos.
// Q and K at a token share the cos/sin row selected by positions[token]. A
// head index rotates query when < num_q_heads and key when < num_kv_heads,
// so GQA layouts with fewer KV heads work. There is no batch-size argument:
// the host must not dispatch gid.z past batch * seq_len.
kernel void apply_rope_qk_fused(
    device float* query [[buffer(0)]],           // [batch, seq_len, num_heads, head_dim]
    device float* key [[buffer(1)]],             // [batch, seq_len, num_kv_heads, head_dim]
    constant float* cos_table [[buffer(2)]],     // [max_seq_len, head_dim/2]
    constant float* sin_table [[buffer(3)]],     // [max_seq_len, head_dim/2]
    device const uint* positions [[buffer(4)]],  // [batch, seq_len]
    constant uint& num_q_heads [[buffer(5)]],
    constant uint& num_kv_heads [[buffer(6)]],
    constant uint& head_dim [[buffer(7)]],
    constant uint& seq_len [[buffer(8)]],
    uint3 gid [[thread_position_in_grid]]
) {
    const uint pairs_per_head = head_dim / 2;
    const uint pair = gid.x;

    // Only the pair index is checked up front; each tensor checks its own
    // head count below.
    if (pair >= pairs_per_head) {
        return;
    }

    // Decompose gid.z into (batch, seq_pos) and rebuild the flat token index.
    const uint b_idx = gid.z / seq_len;
    const uint s_idx = gid.z % seq_len;
    const uint token = b_idx * seq_len + s_idx;

    // One cos/sin lookup shared by Q and K.
    const uint t = positions[token] * pairs_per_head + pair;
    const float c = cos_table[t];
    const float s = sin_table[t];

    const uint h = gid.y;
    const uint even = 2 * pair;

    // Query: num_q_heads heads per token.
    if (h < num_q_heads) {
        const uint qi = (token * num_q_heads + h) * head_dim + even;
        const float q0 = query[qi];
        const float q1 = query[qi + 1];
        query[qi] = fma(q0, c, -q1 * s);
        query[qi + 1] = fma(q0, s, q1 * c);
    }

    // Key: num_kv_heads heads per token (fewer than Q under GQA).
    if (h < num_kv_heads) {
        const uint ki = (token * num_kv_heads + h) * head_dim + even;
        const float k0 = key[ki];
        const float k1 = key[ki + 1];
        key[ki] = fma(k0, c, -k1 * s);
        key[ki + 1] = fma(k0, s, k1 * c);
    }
}