moeflux 0.1.0-pre.3

Pure-Rust streaming-experts MoE inference on Metal. Forked from flash-moe; only the Metal kernels remain from upstream.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
//! GPU MLA per-layer attention forward — Phase 4 of the GPU MLA port.
//!
//! [`mla_attn_layer_forward_gpu`] mirrors
//! [`crate::riir::attn::mla_attn_cpu::mla_attn_layer_forward_cpu`] step-for-step
//! but runs the heavy compute on Metal:
//!
//! 1. Q chain: `q_a_proj` → `q_a_layernorm` → `q_b_proj` (two matvecs
//!    and one rmsnorm over the `q_lora_rank` latent).
//! 2. KV chain: `kv_a_proj_with_mqa` → `kv_a_layernorm` (matvec +
//!    rmsnorm on the latent half).
//! 3. YaRN RoPE on `q_pe` and `k_pe` halves.
//! 4. Append `(kv_lat, k_pe)` to the per-layer
//!    [`crate::riir::snapshot::state::MlaKvCacheGpu`].
//! 5. Folded SDPA via the Phase 3 kernels (`q_prime`, `mla_sdpa_folded`,
//!    `mla_out_per_head_4bit`).
//! 6. `o_proj` final matvec.
//!
//! The output is the post-`o_proj` hidden state in the `out` buffer of
//! [`MlaForwardBuffers`] (shared storage; the caller reads it host-side
//! or chains the buffer into the next kernel). On entry the caller must
//! already have written the rms-normed input into the `pre_norm` buffer —
//! the caller owns the pre-attn norm so the same machinery can drive both
//! GPU-MLA and the CPU diff path without two norm calls.
//!
//! ## Hybrid mode (this slice)
//!
//! The MoE / dense MLP block stays CPU-side in this first slice — see
//! [`crate::riir::step_internal_mla_gpu`]. That bounces post-MLA hidden
//! states through host every layer, but the bounce is cheap (`hidden_dim
//! = 7168` floats) compared to the projection / SDPA cost we just moved
//! to GPU. Full-GPU MoE integration with the deferred ring is a
//! follow-up perf slice.

use metal::{
    Buffer, Device, MTLResourceOptions, MTLSize, NSUInteger,
};

use crate::riir::backend::gpu::gpu_matvec::{
    encode_matvec, MatvecPipelines, MatvecSpec,
};
use crate::riir::attn::gpu_mla::{
    encode_mla_kv_cache_append, encode_mla_out_per_head_4bit,
    encode_mla_q_prime_4bit, encode_mla_sdpa_folded,
    encode_mla_split_q_kv, GpuMlaError, MlaPipelines,
};
use crate::riir::backend::gpu::gpu_norm::{
    encode_rms_norm_bf16_into, RmsNormBf16Pipelines,
};
use crate::riir::attn::gpu_rope::encode_yarn_rope_apply;
use crate::riir::backend::gpu::metal::{MetalContext, MetalError};
use crate::riir::io::mtl_weight_buf::MtlWeightBuf;
use crate::riir::snapshot::state::MlaKvCacheGpu;
use crate::riir::variants::{RMS_NORM_EPS, VARIANT};
use crate::riir::io::weight_file::WeightFile;

/// Per-token GPU scratch for the MLA forward.
///
/// One instance is reused by every layer: attention for a token runs
/// layer after layer, so layer N+1 can overwrite layer N's scratch. All
/// buffers are `f32` arrays in shared storage (see
/// [`MlaForwardBuffers::new`]). The footprint is dominated by the two
/// per-head `num_heads × kv_lora_rank` tensors, `q_prime` and `v_combine`.
pub struct MlaForwardBuffers {
    // ---- Q chain ----
    /// `q_a_proj` output, normed in place by `q_a_layernorm`.
    /// Shape `[q_lora_rank]`.
    pub q_lat: Buffer,
    /// `q_b_proj` output. Shape `[num_heads, qk_head_dim]`.
    pub q_full: Buffer,
    /// Non-rotary part of `q_full`, packed for `mla_q_prime`.
    /// Shape `[num_heads, qk_nope_head_dim]`.
    pub q_nope: Buffer,
    /// Rotary part of `q_full`, rotated in place by YaRN RoPE.
    /// Shape `[num_heads, qk_rope_head_dim]`.
    pub q_pe: Buffer,

    // ---- KV chain ----
    /// `kv_a_proj_with_mqa` output. Shape `[kv_lora_rank + qk_rope_head_dim]`.
    pub kv_pre: Buffer,
    /// Latent half of `kv_pre` after `kv_a_layernorm`; also appended to
    /// `MlaKvCacheGpu`'s latent cache. Shape `[kv_lora_rank]`.
    pub kv_lat: Buffer,
    /// Rotary half of `kv_pre` after RoPE; also appended to
    /// `MlaKvCacheGpu`'s rope-k cache. Shape `[qk_rope_head_dim]`.
    pub k_pe: Buffer,

    // ---- Folded MLA scratch ----
    /// Query folded into latent space. Shape `[num_heads, kv_lora_rank]`.
    pub q_prime: Buffer,
    /// Attention-weighted latent values from SDPA. Shape `[num_heads, kv_lora_rank]`.
    pub v_combine: Buffer,
    /// Per-head value output fed to `o_proj`. Shape `[num_heads, v_head_dim]`.
    pub out_per_head: Buffer,

    // ---- Forward I/O ----
    /// Input: the caller writes the rms-normed hidden state here.
    /// Shape `[hidden_dim]`.
    pub pre_norm: Buffer,
    /// Output: the post-`o_proj` result. Shape `[hidden_dim]`.
    pub out: Buffer,

    // ---- Norm reduction scratch ----
    /// Single-float sum-of-squares scratch for `q_a_layernorm`.
    pub q_a_sum_sq: Buffer,
    /// Single-float sum-of-squares scratch for `kv_a_layernorm`.
    pub kv_lat_sum_sq: Buffer,
}

impl MlaForwardBuffers {
    pub fn new(device: &Device) -> Self {
        let v = VARIANT;
        let f32_buf = |n: usize| {
            let b = device.new_buffer(
                (n * std::mem::size_of::<f32>()) as NSUInteger,
                MTLResourceOptions::StorageModeShared,
            );
            // SAFETY: shared storage, no GPU work in flight on a
            // freshly allocated buffer.
            unsafe {
                std::ptr::write_bytes(
                    b.contents() as *mut u8,
                    0,
                    n * std::mem::size_of::<f32>(),
                );
            }
            b
        };
        let qk_head_dim = v.qk_nope_head_dim + v.qk_rope_head_dim;
        Self {
            q_lat: f32_buf(v.q_lora_rank),
            q_full: f32_buf(v.num_attn_heads * qk_head_dim),
            q_nope: f32_buf(v.num_attn_heads * v.qk_nope_head_dim),
            q_pe: f32_buf(v.num_attn_heads * v.qk_rope_head_dim),
            kv_pre: f32_buf(v.kv_lora_rank + v.qk_rope_head_dim),
            kv_lat: f32_buf(v.kv_lora_rank),
            k_pe: f32_buf(v.qk_rope_head_dim),
            q_prime: f32_buf(v.num_attn_heads * v.kv_lora_rank),
            v_combine: f32_buf(v.num_attn_heads * v.kv_lora_rank),
            out_per_head: f32_buf(v.num_attn_heads * v.v_head_dim),
            pre_norm: f32_buf(v.hidden_dim),
            out: f32_buf(v.hidden_dim),
            q_a_sum_sq: f32_buf(1),
            kv_lat_sum_sq: f32_buf(1),
        }
    }
}

/// YaRN RoPE tables for MLA, built lazily once per `RsCtx`.
///
/// The inverse-frequency table is kept in a shared-storage Metal buffer
/// so the GPU YaRN kernel can bind it as a constant buffer. The magnitude
/// scale is a plain scalar computed at build time and handed to the
/// kernel as a scalar `set_bytes` argument.
pub struct MlaYarnTables {
    /// Inverse rotation frequencies (`f32`) from `compute_yarn_inv_freq`.
    pub inv_freq: Buffer,
    /// YaRN magnitude scale; used by the RoPE kernel and squared into the
    /// SDPA softmax scale.
    pub mscale: f32,
}

impl MlaYarnTables {
    /// Compute the YaRN inverse-frequency table and magnitude scale from
    /// the compile-time [`VARIANT`], and upload the table to `device`.
    pub fn new(device: &Device) -> Self {
        use crate::riir::attn::rope::{compute_yarn_inv_freq, yarn_get_mscale_full};
        use crate::riir::variants::ROPE_THETA;

        let cfg = VARIANT;
        let freqs = compute_yarn_inv_freq(
            cfg.qk_rope_head_dim,
            ROPE_THETA,
            cfg.yarn_factor,
            cfg.yarn_original_max_pos as f32,
            cfg.yarn_beta_fast,
            cfg.yarn_beta_slow,
        );
        let mscale =
            yarn_get_mscale_full(cfg.yarn_factor, cfg.yarn_mscale, cfg.yarn_mscale_all_dim);

        // Copy the host-side table into shared storage so the kernel can
        // bind it directly.
        let table_bytes = (freqs.len() * std::mem::size_of::<f32>()) as NSUInteger;
        let inv_freq = device.new_buffer_with_data(
            freqs.as_ptr().cast(),
            table_bytes,
            MTLResourceOptions::StorageModeShared,
        );

        Self { inv_freq, mscale }
    }
}

/// Compute pipelines used by the MLA forward.
///
/// Fetched once at engine init and kept for the engine's lifetime, the
/// same ownership model as `MoeBuffers`.
pub struct MlaForwardPipelines {
    /// MLA-specific kernels (split, cache append, q_prime, folded SDPA,
    /// per-head output).
    pub mla: MlaPipelines,
    /// Quantized matvec pipelines used for every projection.
    pub matvec: MatvecPipelines,
    /// RMS-norm pipelines for `q_a_layernorm` / `kv_a_layernorm`.
    pub norms: RmsNormBf16Pipelines,
    /// The `yarn_rope_apply` kernel.
    pub yarn_rope: metal::ComputePipelineState,
}

impl MlaForwardPipelines {
    /// Look up every pipeline the MLA forward needs from `metal`.
    ///
    /// # Errors
    ///
    /// Propagates the first [`MetalError`] from a failed pipeline lookup.
    pub fn new(metal: &mut MetalContext) -> Result<Self, MetalError> {
        let mla = MlaPipelines::fetch(metal)?;
        let matvec = MatvecPipelines::fetch(metal)?;
        let norms = RmsNormBf16Pipelines::fetch(metal)?;
        let yarn_rope = metal.pipeline("yarn_rope_apply")?.clone();
        Ok(Self {
            mla,
            matvec,
            norms,
            yarn_rope,
        })
    }
}

/// Errors produced by [`mla_attn_layer_forward_gpu`].
#[derive(Debug, thiserror::Error)]
pub enum MlaForwardGpuError {
    /// The compiled variant does not use MLA attention.
    #[error("MLA only valid on MLA variants (this build's attn_kind is {kind:?})")]
    NotMlaVariant {
        kind: crate::riir::variants::AttnKind,
    },
    /// The KV cache already holds `MAX_SEQ_LEN` rows, so no row is left
    /// for the append.
    #[error("kv_cache.len {len} would exceed MAX_SEQ_LEN={max} after append")]
    CacheFull { len: i32, max: usize },
    /// The decode position is not the next free cache row.
    #[error("pos {pos} != kv_cache.len {cache_len} (single-step decode)")]
    PosMismatch { pos: i32, cache_len: i32 },
    /// The cache's latent or rope-k GPU buffer has not been allocated.
    #[error("kv_cache buffers not allocated (call ensure_buffers first)")]
    CacheNotReady,
    /// A required weight tensor could not be located in the Metal weight buffer.
    #[error("Metal weight tensor: {name}")]
    MissingTensor { name: String },
    /// Error from the Metal context.
    #[error("Metal: {0}")]
    Metal(#[from] MetalError),
    /// Error returned by an MLA kernel encoder.
    #[error("MLA dispatch: {0}")]
    Mla(#[from] GpuMlaError),
}

/// Per-layer GPU MLA forward. Synchronous: encodes all dispatches into
/// `cmdbuf`, the caller commits + waits.
///
/// Reads `pre_norm` (shared storage, `hidden_dim` floats) — the
/// rms-normed input. Writes `out` (shared storage, `hidden_dim`
/// floats) — the post-`o_proj` MLA contribution to the residual
/// stream. Caller does the residual add.
///
/// `kv_cache.len` is bumped by 1 on success.
#[allow(clippy::too_many_arguments)]
pub fn mla_attn_layer_forward_gpu(
    metal: &mut MetalContext,
    pipes: &MlaForwardPipelines,
    wf: &WeightFile,
    wf_buf: &MtlWeightBuf,
    yarn: &MlaYarnTables,
    bufs: &mut MlaForwardBuffers,
    kv_cache: &mut MlaKvCacheGpu,
    layer_idx: usize,
    pos: i32,
) -> Result<(), MlaForwardGpuError> {
    use crate::riir::variants::AttnKind;

    if VARIANT.attn_kind != AttnKind::Mla {
        return Err(MlaForwardGpuError::NotMlaVariant {
            kind: VARIANT.attn_kind,
        });
    }
    if pos != kv_cache.len {
        return Err(MlaForwardGpuError::PosMismatch {
            pos,
            cache_len: kv_cache.len,
        });
    }
    if (kv_cache.len as usize) >= crate::riir::variants::MAX_SEQ_LEN {
        return Err(MlaForwardGpuError::CacheFull {
            len: kv_cache.len,
            max: crate::riir::variants::MAX_SEQ_LEN,
        });
    }
    let latent_buf =
        kv_cache.latent_cache.as_ref().ok_or(MlaForwardGpuError::CacheNotReady)?;
    let rope_k_buf =
        kv_cache.rope_k_cache.as_ref().ok_or(MlaForwardGpuError::CacheNotReady)?;

    let v = VARIANT;
    let hidden_dim = v.hidden_dim as u32;
    let q_lora_rank = v.q_lora_rank as u32;
    let kv_lora_rank = v.kv_lora_rank as u32;
    let nope = v.qk_nope_head_dim as u32;
    let rope_dim = v.qk_rope_head_dim as u32;
    let qk_head_dim = nope + rope_dim;
    let v_head_dim = v.v_head_dim as u32;
    let num_heads = v.num_attn_heads as u32;
    let kv_b_per_head = nope + v_head_dim;

    // Helpers: resolve a 4-bit projection's `(weight, scales, biases)`
    // byte offsets in the shared weight buffer.
    let resolve_proj = |name: &str| -> Result<(u64, u64, u64), MlaForwardGpuError> {
        let w = format!("{name}.weight");
        let s = format!("{name}.scales");
        let b = format!("{name}.biases");
        let w_off = wf_buf
            .tensor_offset(wf, &w)
            .map_err(|_| MlaForwardGpuError::MissingTensor { name: w.clone() })?
            .ok_or(MlaForwardGpuError::MissingTensor { name: w })?;
        let s_off = wf_buf
            .tensor_offset(wf, &s)
            .map_err(|_| MlaForwardGpuError::MissingTensor { name: s.clone() })?
            .ok_or(MlaForwardGpuError::MissingTensor { name: s })?;
        let b_off = wf_buf
            .tensor_offset(wf, &b)
            .map_err(|_| MlaForwardGpuError::MissingTensor { name: b.clone() })?
            .ok_or(MlaForwardGpuError::MissingTensor { name: b })?;
        Ok((w_off, s_off, b_off))
    };
    let resolve_norm = |name: &str| -> Result<u64, MlaForwardGpuError> {
        let n = format!("{name}.weight");
        wf_buf
            .tensor_offset(wf, &n)
            .map_err(|_| MlaForwardGpuError::MissingTensor { name: n.clone() })?
            .ok_or(MlaForwardGpuError::MissingTensor { name: n })
    };

    let layer_prefix = format!("model.layers.{layer_idx}.self_attn");
    let q_a_off = resolve_proj(&format!("{layer_prefix}.q_a_proj"))?;
    let q_a_norm_off = resolve_norm(&format!("{layer_prefix}.q_a_layernorm"))?;
    let q_b_off = resolve_proj(&format!("{layer_prefix}.q_b_proj"))?;
    let kv_a_off = resolve_proj(&format!("{layer_prefix}.kv_a_proj_with_mqa"))?;
    let kv_a_norm_off = resolve_norm(&format!("{layer_prefix}.kv_a_layernorm"))?;
    let kv_b_off = resolve_proj(&format!("{layer_prefix}.kv_b_proj"))?;
    let o_off = resolve_proj(&format!("{layer_prefix}.o_proj"))?;

    // Pre-fetch pipelines as locals so the encode passes don't need
    // to borrow `metal` once we start the cmdbuf.
    let pipe_qprime = pipes.mla.q_prime.clone();
    let pipe_sdpa = pipes.mla.sdpa.clone();
    let pipe_outhead = pipes.mla.out_per_head.clone();
    let pipe_split = pipes.mla.split_q_kv.clone();
    let pipe_cache_append = pipes.mla.cache_append.clone();
    let pipe_yarn = pipes.yarn_rope.clone();

    // Phase 4a — single command buffer for the entire MLA forward.
    // The three host-side scatters that used to require commit+wait
    // sync points are replaced by `mla_split_q_kv` and
    // `mla_kv_cache_append` Metal kernels. Within one cmdbuf the
    // queue serializes encoder-internal dispatches, so all
    // intra-buffer data dependencies (matvec → split → norm → RoPE →
    // append → SDPA → o_proj) are honored without explicit barriers.
    let queue = metal.queue();
    let cmdbuf = queue.new_command_buffer();

    // ---- Q chain (pre-norm → q_lat → q_full) ----
    encode_matvec(
        cmdbuf,
        &pipes.matvec,
        wf_buf,
        &MatvecSpec {
            w_off: q_a_off.0,
            s_off: q_a_off.1,
            b_off: q_a_off.2,
            input: &bufs.pre_norm,
            output: &bufs.q_lat,
            out_dim: q_lora_rank,
            in_dim: hidden_dim,
            bits: 4,
        },
    );
    encode_rms_norm_bf16_into(
        cmdbuf,
        &pipes.norms,
        &bufs.q_lat,
        wf_buf.buffer(),
        q_a_norm_off,
        &bufs.q_a_sum_sq,
        &bufs.q_lat,
        q_lora_rank,
        RMS_NORM_EPS,
    );
    encode_matvec(
        cmdbuf,
        &pipes.matvec,
        wf_buf,
        &MatvecSpec {
            w_off: q_b_off.0,
            s_off: q_b_off.1,
            b_off: q_b_off.2,
            input: &bufs.q_lat,
            output: &bufs.q_full,
            out_dim: num_heads * qk_head_dim,
            in_dim: q_lora_rank,
            bits: 4,
        },
    );

    // ---- KV-A chain (pre-norm → kv_pre) ----
    encode_matvec(
        cmdbuf,
        &pipes.matvec,
        wf_buf,
        &MatvecSpec {
            w_off: kv_a_off.0,
            s_off: kv_a_off.1,
            b_off: kv_a_off.2,
            input: &bufs.pre_norm,
            output: &bufs.kv_pre,
            out_dim: kv_lora_rank + rope_dim,
            in_dim: hidden_dim,
            bits: 4,
        },
    );

    // ---- Fan-out scatter: q_full → (q_nope, q_pe); kv_pre → (kv_lat, k_pe) ----
    // Replaces sync point #1 of the pre-Phase-4a forward.
    encode_mla_split_q_kv(
        cmdbuf,
        &pipe_split,
        &bufs.q_full,
        &bufs.kv_pre,
        &bufs.q_nope,
        &bufs.q_pe,
        &bufs.kv_lat,
        &bufs.k_pe,
        num_heads,
        nope,
        rope_dim,
        kv_lora_rank,
    );

    // ---- kv_lat = rms_norm(kv_lat) (kv_a_layernorm) ----
    encode_rms_norm_bf16_into(
        cmdbuf,
        &pipes.norms,
        &bufs.kv_lat,
        wf_buf.buffer(),
        kv_a_norm_off,
        &bufs.kv_lat_sum_sq,
        &bufs.kv_lat,
        kv_lora_rank,
        RMS_NORM_EPS,
    );

    // ---- YaRN RoPE on q_pe and k_pe (in-place rotation) ----
    encode_yarn_rope_apply(
        cmdbuf,
        &pipe_yarn,
        &bufs.q_pe,
        &yarn.inv_freq,
        num_heads,
        rope_dim,
        pos,
        yarn.mscale,
    )
    .map_err(|_| MlaForwardGpuError::Metal(MetalError::NoDevice))?;
    encode_yarn_rope_apply(
        cmdbuf,
        &pipe_yarn,
        &bufs.k_pe,
        &yarn.inv_freq,
        1, // shared k_pe, broadcast-style
        rope_dim,
        pos,
        yarn.mscale,
    )
    .map_err(|_| MlaForwardGpuError::Metal(MetalError::NoDevice))?;

    // ---- Cache append: write (kv_lat, k_pe) into row[pos] ----
    // Replaces sync point #2. The kernel sees a kernel-correct write
    // ordering (after RoPE and norm), so subsequent SDPA reads the
    // freshly-appended row through the cache's shared-storage buffer.
    encode_mla_kv_cache_append(
        cmdbuf,
        &pipe_cache_append,
        &bufs.kv_lat,
        &bufs.k_pe,
        latent_buf,
        rope_k_buf,
        kv_lora_rank,
        rope_dim,
        pos,
    );

    let cache_len = (pos + 1) as u32;

    // ---- Folded SDPA chain (q_prime → SDPA → out_per_head → o_proj) ----
    encode_mla_q_prime_4bit(
        cmdbuf,
        &pipe_qprime,
        wf_buf.buffer(),
        kv_b_off.0,
        wf_buf.buffer(),
        kv_b_off.1,
        wf_buf.buffer(),
        kv_b_off.2,
        &bufs.q_nope,
        &bufs.q_prime,
        num_heads,
        nope,
        kv_lora_rank,
        kv_b_per_head,
        64, // group_size
    );

    let softmax_scale =
        (1.0 / (qk_head_dim as f32).sqrt()) * yarn.mscale * yarn.mscale;
    encode_mla_sdpa_folded(
        cmdbuf,
        &pipe_sdpa,
        &bufs.q_prime,
        &bufs.q_pe,
        latent_buf,
        rope_k_buf,
        &bufs.v_combine,
        num_heads,
        kv_lora_rank,
        rope_dim,
        cache_len,
        softmax_scale,
    )?;

    encode_mla_out_per_head_4bit(
        cmdbuf,
        &pipe_outhead,
        wf_buf.buffer(),
        kv_b_off.0,
        wf_buf.buffer(),
        kv_b_off.1,
        wf_buf.buffer(),
        kv_b_off.2,
        &bufs.v_combine,
        &bufs.out_per_head,
        num_heads,
        nope,
        kv_lora_rank,
        v_head_dim,
        kv_b_per_head,
        64,
    );

    encode_matvec(
        cmdbuf,
        &pipes.matvec,
        wf_buf,
        &MatvecSpec {
            w_off: o_off.0,
            s_off: o_off.1,
            b_off: o_off.2,
            input: &bufs.out_per_head,
            output: &bufs.out,
            out_dim: hidden_dim,
            in_dim: num_heads * v_head_dim,
            bits: 4,
        },
    );

    cmdbuf.commit();
    cmdbuf.wait_until_completed();
    // GPU-side cache row is now populated; bump the Rust-side bookkeeping.
    kv_cache.len = pos + 1;
    Ok(())
}

/// Encode a 1-D dispatch of `threadgroups` threadgroups, each holding
/// `threads` threads (the Y and Z extents are both 1).
///
/// Private helper kept for symmetry with the other dispatchers; the
/// `dead_code` allow indicates it may currently have no callers.
#[allow(dead_code)] // retained as a shared helper even while unused
fn dispatch_1d(
    enc: &metal::ComputeCommandEncoderRef,
    threadgroups: u64,
    threads: u64,
) {
    enc.dispatch_thread_groups(
        MTLSize::new(threadgroups, 1, 1),
        MTLSize::new(threads, 1, 1),
    );
}