Skip to main content

ripvec_core/backend/driver/
mod.rs

1//! Hardware-agnostic compute driver trait.
2//!
3//! The [`Driver`] trait exposes low-level compute primitives (GEMM, layer-norm,
4//! activations, etc.) that each hardware backend implements. Model architectures
5//! are generic over `D: Driver` and compose these primitives into a forward pass.
6//!
7//! # Design
8//!
9//! - **Associated type `Tensor`**: each driver defines its own opaque tensor
10//!   handle (Metal: buffer+offset, CUDA: device pointer, CPU: ndarray).
11//! - **Not object-safe**: architectures use `D: Driver` generics so the compiler
12//!   can monomorphize and inline driver calls.
13//! - **Send + Sync**: drivers are shared across the pipeline.
14
15#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
16pub mod cpu;
17#[cfg(feature = "cuda")]
18pub mod cuda;
19#[cfg(feature = "metal")]
20pub mod metal;
21#[cfg(feature = "mlx")]
22pub mod mlx;
23
24use super::Encoding;
25
26/// Hardware-agnostic compute primitives for BERT inference.
27///
28/// Each method corresponds to one operation in the forward pass. Drivers handle
29/// memory allocation, kernel dispatch, and synchronization. Architectures
30/// compose these primitives via the [`super::arch::ModelArch`] trait.
31pub trait Driver: Send + Sync {
32    /// Opaque tensor handle.
33    ///
34    /// Metal: `MTLBuffer` + byte offset. CUDA: `CUdeviceptr`. CPU: `Array2<f32>`.
35    type Tensor;
36
37    /// Short human-readable label for diagnostics (e.g. "Metal", "CUDA", "CPU").
38    /// Surfaced via [`super::EmbedBackend::name`].
39    fn name(&self) -> &'static str;
40
41    /// Create a new driver instance for a cloned worker thread.
42    ///
43    /// CPU drivers are zero-size and always succeed. GPU drivers typically
44    /// cannot be cloned this way (they share device state) and should leave
45    /// the default panic implementation.
46    fn new_for_clone() -> crate::Result<Self>
47    where
48        Self: Sized,
49    {
50        Err(crate::Error::Other(anyhow::anyhow!(
51            "this driver does not support cloning"
52        )))
53    }
54
55    // --- Batching ---
56
57    /// Begin batched mode: all subsequent operations encode into one dispatch.
58    ///
59    /// GPU drivers accumulate into a single command buffer; CPU is a no-op.
60    /// Call [`Self::end_batch`] to commit. This eliminates per-call overhead.
61    fn begin_batch(&self) -> crate::Result<()> {
62        Ok(())
63    }
64
65    /// End batched mode: commit all accumulated operations and wait.
66    fn end_batch(&self) -> crate::Result<()> {
67        Ok(())
68    }
69
70    /// Flush the current command buffer and start a new one, preserving pool
71    /// state. Use mid-forward-pass to prevent GPU timeouts on deep models.
72    fn flush_batch(&self) -> crate::Result<()> {
73        Ok(())
74    }
75
76    /// Close and reopen the compute encoder within the same command buffer.
77    ///
78    /// This segments a long sequence of compute dispatches into multiple
79    /// encoders without committing or waiting. Metal processes encoders
80    /// back-to-back from the same CB — zero sync overhead.
81    ///
82    /// Use every few layers to prevent encoder state overflow (>~60 dispatches
83    /// per encoder can cause hangs on some Apple Silicon GPUs).
84    fn segment_encoder(&self) {
85        // No-op for non-Metal backends
86    }
87
88    /// Save the current pool cursor position. Call BEFORE a layer's work.
89    fn save_pool_cursor(&self) -> usize {
90        0
91    }
92
93    /// Restore the pool cursor to a previously saved position. Call AFTER
94    /// a layer's transient tensors have been dropped (out of scope).
95    ///
96    /// The architecture must ensure only the output tensor (`hidden_states`)
97    /// survives — all layer-internal tensors (qkv, scores, context, etc.)
98    /// must be dropped before this call so their pool slots can be recycled.
99    fn restore_pool_cursor(&self, _saved: usize) {}
100
101    // --- Allocation ---
102
103    /// Allocate a zero-initialized tensor with `n` float elements on device.
104    ///
105    /// Used by architectures to create workspace buffers (QKV projections,
106    /// attention scores, intermediate activations, etc.).
107    ///
108    /// # Errors
109    ///
110    /// Returns an error if device memory allocation fails.
111    fn alloc_zeros(&self, n: usize) -> crate::Result<Self::Tensor>;
112
113    /// Clone a tensor, producing an independent copy of the data.
114    ///
115    /// Used when an operation needs both the original and a mutable output
116    /// referencing the same logical data (e.g., in-place layer normalization
117    /// where input == output).
118    ///
119    /// # Errors
120    ///
121    /// Returns an error if device memory allocation or the copy fails.
122    fn clone_tensor(&self, tensor: &Self::Tensor, n: usize) -> crate::Result<Self::Tensor>;
123
124    // --- Batch preparation ---
125
126    /// Prepare a batch of encodings for inference, returning input tensors on device.
127    ///
128    /// Pads all sequences to `max_seq` and uploads `input_ids`, `attention_mask`,
129    /// `token_type_ids`, `position_ids`, and a float attention mask to device memory.
130    fn prepare_batch(
131        &self,
132        encodings: &[Encoding],
133        max_seq: usize,
134    ) -> crate::Result<BatchInputs<Self::Tensor>>;
135
136    /// Prepare a batch WITHOUT padding — concatenate all tokens flat.
137    ///
138    /// Returns `BatchInputs` with `total_tokens` actual tokens (no padding),
139    /// `cu_seqlens` for attention boundaries, and per-token position IDs.
140    /// Linear layers (GEMM, LN, GELU) process `total_tokens` rows.
141    /// Attention must pad/unpad around the per-head operations.
142    fn prepare_batch_unpadded(
143        &self,
144        encodings: &[Encoding],
145    ) -> crate::Result<BatchInputs<Self::Tensor>> {
146        // Default: fall back to padded (backends override for unpadded support)
147        let max_seq = encodings
148            .iter()
149            .map(|e| e.input_ids.len())
150            .max()
151            .unwrap_or(0)
152            .next_multiple_of(8);
153        self.prepare_batch(encodings, max_seq)
154    }
155
156    /// Scatter flat `[total_tokens, dim]` tensor into padded `[batch, max_seq, dim]`.
157    ///
158    /// Used before attention: linear layers produce unpadded output, but the
159    /// QKV split + batched attention GEMM need aligned `[batch*heads, seq, head_dim]`.
160    /// Padding positions are zeroed.
161    fn pad_to_batch(
162        &self,
163        flat: &Self::Tensor,
164        padded: &mut Self::Tensor,
165        seq_lengths: &[usize],
166        max_seq: usize,
167        dim: usize,
168    ) -> crate::Result<()>;
169
170    /// Gather padded `[batch, max_seq, dim]` back to flat `[total_tokens, dim]`.
171    ///
172    /// Used after attention: extracts only the real tokens, discarding padding.
173    fn unpad_from_batch(
174        &self,
175        padded: &Self::Tensor,
176        flat: &mut Self::Tensor,
177        seq_lengths: &[usize],
178        max_seq: usize,
179        dim: usize,
180    ) -> crate::Result<()>;
181
182    // --- Embedding operations ---
183
184    /// Word/position/token-type embedding lookup via gather.
185    ///
186    /// Reads `seq_len` token IDs from `word_ids`, gathers rows from
187    /// `embedding_table`, and writes `[seq_len, hidden]` floats to the result.
188    fn embedding_lookup(
189        &self,
190        word_ids: &Self::Tensor,
191        embedding_table: &Self::Tensor,
192        seq_len: usize,
193        hidden: usize,
194    ) -> crate::Result<Self::Tensor>;
195
196    /// Element-wise add an embedding table lookup into `hidden`.
197    ///
198    /// Used for position and token-type embeddings:
199    /// `hidden[i] += table[ids[i]]` for each token position.
200    fn add_embeddings(
201        &self,
202        hidden: &mut Self::Tensor,
203        table: &Self::Tensor,
204        ids: &Self::Tensor,
205        seq_len: usize,
206        hidden_dim: usize,
207    ) -> crate::Result<()>;
208
209    // --- Normalization ---
210
211    /// Layer normalization: `output = (input - mean) / sqrt(var + eps) * weight + bias`.
212    fn layer_norm(
213        &self,
214        output: &mut Self::Tensor,
215        input: &Self::Tensor,
216        weight: &Self::Tensor,
217        bias: &Self::Tensor,
218        rows: usize,
219        cols: usize,
220        eps: f32,
221    ) -> crate::Result<()>;
222
223    // --- Linear algebra ---
224
225    /// General matrix multiply: `output = A * B` (or `A * B^T` if `transpose_b`).
226    ///
227    /// Dimensions: A is `[m, k]`, B is `[k, n]` (or `[n, k]` if transposed),
228    /// output is `[m, n]`.
229    fn gemm(
230        &self,
231        a: &Self::Tensor,
232        b: &Self::Tensor,
233        output: &mut Self::Tensor,
234        m: usize,
235        n: usize,
236        k: usize,
237        transpose_b: bool,
238    ) -> crate::Result<()>;
239
240    /// Batched GEMM for multi-head attention.
241    ///
242    /// Performs `batch_count` independent GEMMs with strided access into
243    /// contiguous buffers. Used for per-head Q*K^T and attn*V.
244    fn gemm_batched(
245        &self,
246        a: &Self::Tensor,
247        b: &Self::Tensor,
248        output: &mut Self::Tensor,
249        m: usize,
250        n: usize,
251        k: usize,
252        transpose_b: bool,
253        stride_a: usize,
254        stride_b: usize,
255        stride_c: usize,
256        batch_count: usize,
257    ) -> crate::Result<()>;
258
259    // --- Attention ---
260
261    /// Fused scale + mask + softmax for attention scores.
262    ///
263    /// `scores = softmax(scores * scale + mask)` computed per-head.
264    fn fused_scale_mask_softmax(
265        &self,
266        scores: &mut Self::Tensor,
267        mask: &Self::Tensor,
268        batch: usize,
269        num_heads: usize,
270        seq_len: usize,
271        scale: f32,
272    ) -> crate::Result<()>;
273
274    /// Fused scale + mask + sliding window + softmax for attention scores.
275    ///
276    /// Like [`fused_scale_mask_softmax`](Driver::fused_scale_mask_softmax) but
277    /// additionally masks out positions where `|query_pos - key_pos| > window_size / 2`.
278    /// Used by `ModernBERT`'s local attention layers.
279    fn fused_scale_mask_softmax_windowed(
280        &self,
281        scores: &mut Self::Tensor,
282        mask: &Self::Tensor,
283        batch: usize,
284        num_heads: usize,
285        seq_len: usize,
286        scale: f32,
287        window_size: usize,
288    ) -> crate::Result<()>;
289
290    /// Build a float attention mask from an integer mask.
291    ///
292    /// Converts `[batch * seq]` int mask (0/1) to `[batch * seq]` float mask
293    /// (0.0 / -10000.0) for use with [`fused_scale_mask_softmax`](Driver::fused_scale_mask_softmax).
294    fn build_attn_mask(
295        &self,
296        output: &mut Self::Tensor,
297        int_mask: &Self::Tensor,
298        n: usize,
299    ) -> crate::Result<()>;
300
301    /// Split a fused QKV projection into separate Q, K, V tensors.
302    fn qkv_split(
303        &self,
304        q: &mut Self::Tensor,
305        k: &mut Self::Tensor,
306        v: &mut Self::Tensor,
307        qkv: &Self::Tensor,
308        batch: usize,
309        seq: usize,
310        hidden: usize,
311        num_heads: usize,
312        head_dim: usize,
313    ) -> crate::Result<()>;
314
315    // --- Banded (local/sliding-window) attention ---
316
317    /// Banded Q@K^T: compute attention scores only within a sliding window.
318    ///
319    /// Output shape: `[batch * num_heads, seq, window]` (NOT `[seq, seq]`).
320    /// `scores[h, i, w]` = dot(Q[h, i, :], K[h, i - window/2 + w, :])
321    /// where out-of-bounds positions are set to `-inf` (masked in softmax).
322    ///
323    /// Reduces attention compute from O(seq²) to O(seq × window).
324    /// For `seq=512, window=128`: **4× less compute** per local layer.
325    fn banded_qk(
326        &self,
327        q: &Self::Tensor,
328        k: &Self::Tensor,
329        scores: &mut Self::Tensor,
330        batch_heads: usize,
331        seq: usize,
332        head_dim: usize,
333        window: usize,
334        stride_qk: usize,
335        stride_scores: usize,
336    ) -> crate::Result<()>;
337
338    /// Banded scores@V: weighted sum using banded attention scores.
339    ///
340    /// Input scores: `[batch * num_heads, seq, window]` (from `banded_qk`).
341    /// Output: `[batch * num_heads, seq, head_dim]`.
342    /// `output[h, i, d]` = sum_w scores[h, i, w] * V[h, i - window/2 + w, d]
343    fn banded_sv(
344        &self,
345        scores: &Self::Tensor,
346        v: &Self::Tensor,
347        output: &mut Self::Tensor,
348        batch_heads: usize,
349        seq: usize,
350        head_dim: usize,
351        window: usize,
352        stride_scores: usize,
353        stride_v: usize,
354        stride_out: usize,
355    ) -> crate::Result<()>;
356
357    /// Fused scale + softmax over the window dimension (no padding mask needed).
358    ///
359    /// Operates on `[batch * num_heads * seq, window]` rows.
360    fn banded_softmax(
361        &self,
362        scores: &mut Self::Tensor,
363        total_rows: usize,
364        window: usize,
365        scale: f32,
366    ) -> crate::Result<()>;
367
368    /// Reshape attention output from `[batch, num_heads, seq, head_dim]` to
369    /// `[batch * seq, hidden]`.
370    fn attn_reshape(
371        &self,
372        output: &mut Self::Tensor,
373        input: &Self::Tensor,
374        batch: usize,
375        seq: usize,
376        num_heads: usize,
377        head_dim: usize,
378    ) -> crate::Result<()>;
379
380    /// Apply Rotary Position Embedding (RoPE) to Q/K tensors.
381    ///
382    /// Used by ModernBERT (not ClassicBert which uses learned position embeddings).
383    fn apply_rope(
384        &self,
385        qk: &mut Self::Tensor,
386        cos: &Self::Tensor,
387        sin: &Self::Tensor,
388        num_rows: usize,
389        seq_len: usize,
390        head_dim: usize,
391        num_heads: usize,
392    ) -> crate::Result<()>;
393
394    // --- Tensor manipulation ---
395
396    /// Split a `[rows, 2*cols]` matrix into two `[rows, cols]` halves.
397    ///
398    /// Each row of `input` is `[first_half | second_half]`. The first `cols`
399    /// elements go to `first`, the remaining `cols` to `second`.
400    /// Used by `ModernBERT` for gated MLP splits.
401    fn split_gate_value(
402        &self,
403        first: &mut Self::Tensor,
404        second: &mut Self::Tensor,
405        input: &Self::Tensor,
406        rows: usize,
407        cols: usize,
408    ) -> crate::Result<()>;
409
410    // --- Activations ---
411
412    /// GELU activation (Gaussian Error Linear Unit), applied in-place.
413    fn gelu(&self, x: &mut Self::Tensor, n: usize) -> crate::Result<()>;
414
415    /// SwiGLU gated activation: `output = value * silu(gate)`.
416    ///
417    /// The gate and value come from splitting the intermediate projection.
418    fn swiglu(
419        &self,
420        value: &Self::Tensor,
421        gate: &Self::Tensor,
422        output: &mut Self::Tensor,
423        n: usize,
424    ) -> crate::Result<()>;
425
426    /// `GeGLU` gated activation: `output = gelu(value) * gate`.
427    ///
428    /// Used by `ModernBERT`. The value and gate come from splitting the
429    /// MLP `Wi` projection output in half.
430    fn geglu(
431        &self,
432        value: &Self::Tensor,
433        gate: &Self::Tensor,
434        output: &mut Self::Tensor,
435        n: usize,
436    ) -> crate::Result<()>;
437
438    /// Fused bias + GELU: `x = gelu(x + bias)` row-wise.
439    fn fused_bias_gelu(
440        &self,
441        x: &mut Self::Tensor,
442        bias: &Self::Tensor,
443        rows: usize,
444        cols: usize,
445    ) -> crate::Result<()>;
446
447    // --- Fused residual operations ---
448
449    /// Fused bias + residual add: `output = input + bias + residual`.
450    ///
451    /// Bias is broadcast row-wise (`cols`-wide) across `n / cols` rows.
452    fn fused_bias_residual(
453        &self,
454        output: &mut Self::Tensor,
455        input: &Self::Tensor,
456        bias: &Self::Tensor,
457        residual: &Self::Tensor,
458        n: usize,
459        cols: usize,
460    ) -> crate::Result<()>;
461
462    /// Fused residual add + layer normalization.
463    ///
464    /// `output = layer_norm(hidden + residual, weight, bias, eps)`.
465    fn fused_residual_layernorm(
466        &self,
467        output: &mut Self::Tensor,
468        hidden: &Self::Tensor,
469        residual: &Self::Tensor,
470        weight: &Self::Tensor,
471        bias: &Self::Tensor,
472        rows: usize,
473        cols: usize,
474        eps: f32,
475    ) -> crate::Result<()>;
476
477    /// Residual add without bias: `output = hidden + residual`.
478    ///
479    /// Used by `ModernBERT` which has no bias terms.
480    fn residual_add(
481        &self,
482        output: &mut Self::Tensor,
483        hidden: &Self::Tensor,
484        residual: &Self::Tensor,
485        n: usize,
486    ) -> crate::Result<()>;
487
488    /// Add bias to a matrix row-wise: `x[row] += bias` for each row.
489    fn add_bias(
490        &self,
491        x: &mut Self::Tensor,
492        bias: &Self::Tensor,
493        rows: usize,
494        cols: usize,
495    ) -> crate::Result<()>;
496
497    // --- Pooling ---
498
499    /// CLS pooling: extract the first token's hidden state per batch element.
500    fn cls_pool(
501        &self,
502        output: &mut Self::Tensor,
503        hidden: &Self::Tensor,
504        batch: usize,
505        seq: usize,
506        hidden_dim: usize,
507    ) -> crate::Result<()>;
508
509    /// Mean pooling: attention-mask-weighted average of hidden states.
510    fn mean_pool(
511        &self,
512        output: &mut Self::Tensor,
513        hidden: &Self::Tensor,
514        mask: &Self::Tensor,
515        batch: usize,
516        seq: usize,
517        hidden_dim: usize,
518    ) -> crate::Result<()>;
519
520    // --- Post-processing ---
521
522    /// L2-normalize each row vector in-place.
523    fn l2_normalize(&self, data: &mut Self::Tensor, rows: usize, cols: usize) -> crate::Result<()>;
524
525    /// Copy tensor data back to host memory as `Vec<Vec<f32>>`.
526    ///
527    /// Returns one `Vec<f32>` of length `dim` per batch element.
528    fn to_host(
529        &self,
530        tensor: &Self::Tensor,
531        batch: usize,
532        dim: usize,
533    ) -> crate::Result<Vec<Vec<f32>>>;
534
535    /// Optional finite-value diagnostic hook for backend tensors.
536    ///
537    /// Backends should keep this cheap or disabled by default. The CUDA driver
538    /// enables full tensor readback only with `RIPVEC_CUDA_DEBUG_TENSORS=1`.
539    fn debug_tensor(
540        &self,
541        _label: &str,
542        _tensor: &Self::Tensor,
543        _rows: usize,
544        _cols: usize,
545    ) -> crate::Result<()> {
546        Ok(())
547    }
548
549    /// Whether calls to [`Driver::debug_tensor`] will inspect tensor contents.
550    ///
551    /// Architecture code uses this to avoid allocating and converting probe
552    /// tensors when diagnostics are disabled.
553    fn debug_tensors_enabled(&self) -> bool {
554        false
555    }
556
557    // =======================================================================
558    // FP16 operations for full half-precision pipeline
559    //
560    // These methods mirror the FP32 counterparts but operate on FP16 tensors.
561    // Internal reductions (softmax, layer-norm) use FP32 accumulators but
562    // all tensor I/O is half precision. Default implementations return an
563    // error — only backends with FP16 support override them.
564    // =======================================================================
565
566    /// Allocate a zero-initialized FP16 tensor with `n` half-precision elements.
567    ///
568    /// # Errors
569    ///
570    /// Returns an error if device memory allocation fails or FP16 is unsupported.
571    fn alloc_zeros_f16(&self, _n: usize) -> crate::Result<Self::Tensor> {
572        Err(crate::Error::Metal(
573            "FP16 not supported by this driver".into(),
574        ))
575    }
576
577    /// Convert FP32 tensor to FP16 (element-wise narrowing).
578    fn f32_to_f16(
579        &self,
580        _output: &mut Self::Tensor,
581        _input: &Self::Tensor,
582        _n: usize,
583    ) -> crate::Result<()> {
584        Err(crate::Error::Metal(
585            "FP16 not supported by this driver".into(),
586        ))
587    }
588
589    /// Convert FP16 tensor back to FP32 (element-wise widening).
590    fn f16_to_f32(
591        &self,
592        _output: &mut Self::Tensor,
593        _input: &Self::Tensor,
594        _n: usize,
595    ) -> crate::Result<()> {
596        Err(crate::Error::Metal(
597            "FP16 not supported by this driver".into(),
598        ))
599    }
600
601    /// Mixed-precision GEMM: FP16 inputs → FP32 output via native simdgroup ops.
602    fn gemm_mixed(
603        &self,
604        _a_f16: &Self::Tensor,
605        _b_f16: &Self::Tensor,
606        _output_f32: &mut Self::Tensor,
607        _m: usize,
608        _n: usize,
609        _k: usize,
610        _transpose_b: bool,
611    ) -> crate::Result<()> {
612        Err(crate::Error::Metal(
613            "gemm_mixed not supported by this driver".into(),
614        ))
615    }
616
617    /// FP16 GEMM: `output = A * B` (or `A * B^T`). All tensors are half.
618    fn gemm_f16(
619        &self,
620        _a: &Self::Tensor,
621        _b: &Self::Tensor,
622        _output: &mut Self::Tensor,
623        _m: usize,
624        _n: usize,
625        _k: usize,
626        _transpose_b: bool,
627    ) -> crate::Result<()> {
628        Err(crate::Error::Metal(
629            "FP16 not supported by this driver".into(),
630        ))
631    }
632
633    /// FP16 batched GEMM for multi-head attention. All tensors are half.
634    #[expect(
635        clippy::too_many_arguments,
636        reason = "matches FP32 gemm_batched signature"
637    )]
638    fn gemm_batched_f16(
639        &self,
640        _a: &Self::Tensor,
641        _b: &Self::Tensor,
642        _output: &mut Self::Tensor,
643        _m: usize,
644        _n: usize,
645        _k: usize,
646        _transpose_b: bool,
647        _stride_a: usize,
648        _stride_b: usize,
649        _stride_c: usize,
650        _batch_count: usize,
651    ) -> crate::Result<()> {
652        Err(crate::Error::Metal(
653            "FP16 not supported by this driver".into(),
654        ))
655    }
656
657    /// FP16 layer normalization. Half I/O, FP32 reductions.
658    fn layer_norm_f16(
659        &self,
660        _output: &mut Self::Tensor,
661        _input: &Self::Tensor,
662        _weight: &Self::Tensor,
663        _bias: &Self::Tensor,
664        _rows: usize,
665        _cols: usize,
666        _eps: f32,
667    ) -> crate::Result<()> {
668        Err(crate::Error::Metal(
669            "FP16 not supported by this driver".into(),
670        ))
671    }
672
673    /// FP16 fused scale + mask + softmax. Half scores, FP32 reductions.
674    fn fused_scale_mask_softmax_f16(
675        &self,
676        _scores: &mut Self::Tensor,
677        _mask: &Self::Tensor,
678        _batch: usize,
679        _num_heads: usize,
680        _seq_len: usize,
681        _scale: f32,
682    ) -> crate::Result<()> {
683        Err(crate::Error::Metal(
684            "FP16 not supported by this driver".into(),
685        ))
686    }
687
688    /// FP16 fused scale + mask + sliding window + softmax.
689    fn fused_scale_mask_softmax_windowed_f16(
690        &self,
691        _scores: &mut Self::Tensor,
692        _mask: &Self::Tensor,
693        _batch: usize,
694        _num_heads: usize,
695        _seq_len: usize,
696        _scale: f32,
697        _window_size: usize,
698    ) -> crate::Result<()> {
699        Err(crate::Error::Metal(
700            "FP16 not supported by this driver".into(),
701        ))
702    }
703
704    /// FP16 QKV split: `[batch*seq, 3*hidden]` into Q, K, V per-head layout.
705    fn qkv_split_f16(
706        &self,
707        _q: &mut Self::Tensor,
708        _k: &mut Self::Tensor,
709        _v: &mut Self::Tensor,
710        _qkv: &Self::Tensor,
711        _batch: usize,
712        _seq: usize,
713        _hidden: usize,
714        _num_heads: usize,
715        _head_dim: usize,
716    ) -> crate::Result<()> {
717        Err(crate::Error::Metal(
718            "FP16 not supported by this driver".into(),
719        ))
720    }
721
722    /// FP16 attention output reshape: `[batch*num_heads, seq, head_dim]` to
723    /// `[batch*seq, hidden]`.
724    fn attn_reshape_f16(
725        &self,
726        _output: &mut Self::Tensor,
727        _input: &Self::Tensor,
728        _batch: usize,
729        _seq: usize,
730        _num_heads: usize,
731        _head_dim: usize,
732    ) -> crate::Result<()> {
733        Err(crate::Error::Metal(
734            "FP16 not supported by this driver".into(),
735        ))
736    }
737
738    /// FP16 scatter flat `[total_tokens, dim]` to padded `[batch, max_seq, dim]`.
739    fn pad_to_batch_f16(
740        &self,
741        _flat: &Self::Tensor,
742        _padded: &mut Self::Tensor,
743        _seq_lengths: &[usize],
744        _max_seq: usize,
745        _dim: usize,
746    ) -> crate::Result<()> {
747        Err(crate::Error::Metal(
748            "FP16 not supported by this driver".into(),
749        ))
750    }
751
752    /// FP16 gather padded `[batch, max_seq, dim]` back to flat `[total_tokens, dim]`.
753    fn unpad_from_batch_f16(
754        &self,
755        _padded: &Self::Tensor,
756        _flat: &mut Self::Tensor,
757        _seq_lengths: &[usize],
758        _max_seq: usize,
759        _dim: usize,
760    ) -> crate::Result<()> {
761        Err(crate::Error::Metal(
762            "FP16 not supported by this driver".into(),
763        ))
764    }
765
766    /// FP16 RoPE: apply rotary position embedding. Half Q/K, float cos/sin tables.
767    fn rope_encode_f16(
768        &self,
769        _qk: &mut Self::Tensor,
770        _cos: &Self::Tensor,
771        _sin: &Self::Tensor,
772        _num_rows: usize,
773        _seq_len: usize,
774        _head_dim: usize,
775        _num_heads: usize,
776    ) -> crate::Result<()> {
777        Err(crate::Error::Metal(
778            "FP16 not supported by this driver".into(),
779        ))
780    }
781
782    /// FP16 `GeGLU` gated activation: `output = gelu(value) * gate`. Half I/O.
783    fn geglu_f16(
784        &self,
785        _value: &Self::Tensor,
786        _gate: &Self::Tensor,
787        _output: &mut Self::Tensor,
788        _n: usize,
789    ) -> crate::Result<()> {
790        Err(crate::Error::Metal(
791            "FP16 not supported by this driver".into(),
792        ))
793    }
794
795    /// FP16 fused residual add + layer normalization.
796    fn fused_residual_layernorm_f16(
797        &self,
798        _output: &mut Self::Tensor,
799        _hidden: &Self::Tensor,
800        _residual: &Self::Tensor,
801        _weight: &Self::Tensor,
802        _bias: &Self::Tensor,
803        _rows: usize,
804        _cols: usize,
805        _eps: f32,
806    ) -> crate::Result<()> {
807        Err(crate::Error::Metal(
808            "FP16 not supported by this driver".into(),
809        ))
810    }
811
812    /// FP16 residual add (no bias): `output = hidden + residual`.
813    fn residual_add_f16(
814        &self,
815        _output: &mut Self::Tensor,
816        _hidden: &Self::Tensor,
817        _residual: &Self::Tensor,
818        _n: usize,
819    ) -> crate::Result<()> {
820        Err(crate::Error::Metal(
821            "FP16 not supported by this driver".into(),
822        ))
823    }
824
825    /// FP16 split `[rows, 2*cols]` into two `[rows, cols]` halves.
826    fn split_gate_value_f16(
827        &self,
828        _first: &mut Self::Tensor,
829        _second: &mut Self::Tensor,
830        _input: &Self::Tensor,
831        _rows: usize,
832        _cols: usize,
833    ) -> crate::Result<()> {
834        Err(crate::Error::Metal(
835            "FP16 not supported by this driver".into(),
836        ))
837    }
838
839    /// Fused split + `GeGLU`: read `[rows, 2*cols]`, write `[rows, cols]`.
840    ///
841    /// Combines [`split_gate_value_f16`](Driver::split_gate_value_f16) and
842    /// [`geglu_f16`](Driver::geglu_f16) into a single kernel, eliminating
843    /// two intermediate `[rows, cols]` buffers and halving HBM round-trips.
844    ///
845    /// Default falls back to separate split + geglu calls.
846    fn fused_split_geglu_f16(
847        &self,
848        output: &mut Self::Tensor,
849        input: &Self::Tensor,
850        rows: usize,
851        cols: usize,
852    ) -> crate::Result<()> {
853        // Default: allocate intermediates and call separately.
854        let n = rows * cols;
855        let mut value = self.alloc_zeros_f16(n)?;
856        let mut gate = self.alloc_zeros_f16(n)?;
857        self.split_gate_value_f16(&mut value, &mut gate, input, rows, cols)?;
858        self.geglu_f16(&value, &gate, output, n)
859    }
860
861    /// Fused pad + QKV split: flat `[total_tokens, 3*hidden]` → Q, K, V
862    /// each `[batch*heads, max_seq, head_dim]`.
863    ///
864    /// Eliminates the padded intermediate buffer. Default calls pad then split.
865    #[expect(clippy::too_many_arguments, reason = "mirrors pad + qkv_split args")]
866    fn fused_pad_qkv_split_f16(
867        &self,
868        q: &mut Self::Tensor,
869        k: &mut Self::Tensor,
870        v: &mut Self::Tensor,
871        qkv_flat: &Self::Tensor,
872        seq_lengths: &[usize],
873        max_seq: usize,
874        batch: usize,
875        hidden: usize,
876        num_heads: usize,
877        head_dim: usize,
878    ) -> crate::Result<()> {
879        // Default: pad then split.
880        let padded_tokens = batch * max_seq;
881        let mut qkv_padded = self.alloc_zeros_f16(padded_tokens * 3 * hidden)?;
882        self.pad_to_batch_f16(qkv_flat, &mut qkv_padded, seq_lengths, max_seq, 3 * hidden)?;
883        self.qkv_split_f16(
884            q,
885            k,
886            v,
887            &qkv_padded,
888            batch,
889            max_seq,
890            hidden,
891            num_heads,
892            head_dim,
893        )
894    }
895
896    /// Fused attn_reshape + unpad: `[batch*heads, max_seq, head_dim]` →
897    /// `[total_tokens, hidden]`.
898    ///
899    /// Eliminates the padded context intermediate. Default calls reshape then unpad.
900    fn fused_reshape_unpad_f16(
901        &self,
902        flat: &mut Self::Tensor,
903        heads: &Self::Tensor,
904        seq_lengths: &[usize],
905        max_seq: usize,
906        batch: usize,
907        num_heads: usize,
908        head_dim: usize,
909    ) -> crate::Result<()> {
910        // Default: reshape then unpad.
911        let hidden = num_heads * head_dim;
912        let padded_tokens = batch * max_seq;
913        let mut context = self.alloc_zeros_f16(padded_tokens * hidden)?;
914        self.attn_reshape_f16(&mut context, heads, batch, max_seq, num_heads, head_dim)?;
915        self.unpad_from_batch_f16(&context, flat, seq_lengths, max_seq, hidden)
916    }
917}
918
/// Batch input tensors on device, produced by [`Driver::prepare_batch`] or
/// [`Driver::prepare_batch_unpadded`].
///
/// Supports both padded and unpadded modes:
/// - **Padded**: all sequences padded to `max_seq`. `cu_seqlens` is `None`.
/// - **Unpadded**: sequences concatenated without padding. `cu_seqlens`
///   contains cumulative lengths `[0, len0, len0+len1, ...]` so attention
///   knows where each sequence starts. This eliminates padding compute in
///   the linear layers.
///
/// Note that the default `Driver::prepare_batch_unpadded` falls back to
/// padded mode, so consumers should check `cu_seqlens` rather than assume a
/// mode from which constructor was called.
///
/// Field declaration order is also drop order; keep it stable, since the
/// fields may hold device resources.
pub struct BatchInputs<T> {
    /// Token IDs — `[batch * max_seq]` (padded) or `[total_tokens]` (unpadded).
    pub input_ids: T,
    /// Attention mask `[batch * max_seq]` as int32 (0 or 1). Unused in unpadded mode.
    pub attention_mask: T,
    /// Token type IDs — same layout as `input_ids`.
    pub token_type_ids: T,
    /// Position IDs — same layout as `input_ids`.
    pub position_ids: T,
    /// Float attention bias mask `[batch * max_seq]` for softmax: 0.0 for real
    /// tokens, a large negative bias for padding (documented here as -1e9).
    ///
    /// NOTE(review): `Driver::build_attn_mask` documents -10000.0 as the
    /// masked value. Confirm which bias each backend actually writes.
    pub float_mask: T,
    /// Float pooling mask `[batch * max_seq]` (1.0 or 0.0) for mean pooling.
    pub pooling_mask: T,
    /// Number of sequences in this batch.
    pub batch: usize,
    /// Maximum sequence length (all sequences padded to this). In unpadded mode,
    /// this is the longest sequence (used for workspace sizing, not padding).
    pub max_seq: usize,
    /// Total actual tokens across all sequences (no padding).
    pub total_tokens: usize,
    /// Per-sequence lengths: `[batch]`. Each element is that sequence's actual
    /// token count.
    pub seq_lengths: Vec<usize>,
    /// Cumulative sequence lengths for unpadded attention: `[batch + 1]`.
    /// `cu_seqlens[i]..cu_seqlens[i+1]` is the token range for sequence `i`.
    /// `None` in padded mode (all sequences padded to `max_seq`).
    pub cu_seqlens: Option<Vec<usize>>,
}