Skip to main content

ripvec_core/backend/driver/
mod.rs

//! Hardware-agnostic compute driver trait.
//!
//! The [`Driver`] trait exposes low-level compute primitives (GEMM, layer-norm,
//! activations, etc.) that each hardware backend implements. Model architectures
//! are generic over `D: Driver` and compose these primitives into a forward pass.
//!
//! # Design
//!
//! - **Associated type `Tensor`**: each driver defines its own opaque tensor
//!   handle (Metal: buffer+offset, CUDA: device pointer, CPU: ndarray).
//! - **Static dispatch**: architectures use `D: Driver` generics rather than
//!   trait objects so the compiler can monomorphize and inline driver calls.
//! - **Send + Sync**: drivers are shared across the pipeline.
14
15#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
16pub mod cpu;
17#[cfg(feature = "cuda")]
18pub mod cuda;
19#[cfg(feature = "metal")]
20pub mod metal;
21#[cfg(feature = "mlx")]
22pub mod mlx;
23
24use super::Encoding;
25
26/// Hardware-agnostic compute primitives for BERT inference.
27///
28/// Each method corresponds to one operation in the forward pass. Drivers handle
29/// memory allocation, kernel dispatch, and synchronization. Architectures
30/// compose these primitives via the [`super::arch::ModelArch`] trait.
31pub trait Driver: Send + Sync {
32    /// Opaque tensor handle.
33    ///
34    /// Metal: `MTLBuffer` + byte offset. CUDA: `CUdeviceptr`. CPU: `Array2<f32>`.
35    type Tensor;
36
37    /// Short human-readable label for diagnostics (e.g. "Metal", "CUDA", "CPU").
38    /// Surfaced via [`super::EmbedBackend::name`].
39    fn name(&self) -> &'static str;
40
41    /// Create a new driver instance for a cloned worker thread.
42    ///
43    /// CPU drivers are zero-size and always succeed. GPU drivers typically
44    /// cannot be cloned this way (they share device state) and should leave
45    /// the default panic implementation.
46    fn new_for_clone() -> crate::Result<Self>
47    where
48        Self: Sized,
49    {
50        Err(crate::Error::Other(anyhow::anyhow!(
51            "this driver does not support cloning"
52        )))
53    }
54
55    // --- Batching ---
56
57    /// Begin batched mode: all subsequent operations encode into one dispatch.
58    ///
59    /// GPU drivers accumulate into a single command buffer; CPU is a no-op.
60    /// Call [`Self::end_batch`] to commit. This eliminates per-call overhead.
61    fn begin_batch(&self) -> crate::Result<()> {
62        Ok(())
63    }
64
65    /// End batched mode: commit all accumulated operations and wait.
66    fn end_batch(&self) -> crate::Result<()> {
67        Ok(())
68    }
69
70    /// Flush the current command buffer and start a new one, preserving pool
71    /// state. Use mid-forward-pass to prevent GPU timeouts on deep models.
72    fn flush_batch(&self) -> crate::Result<()> {
73        Ok(())
74    }
75
76    /// Close and reopen the compute encoder within the same command buffer.
77    ///
78    /// This segments a long sequence of compute dispatches into multiple
79    /// encoders without committing or waiting. Metal processes encoders
80    /// back-to-back from the same CB — zero sync overhead.
81    ///
82    /// Use every few layers to prevent encoder state overflow (>~60 dispatches
83    /// per encoder can cause hangs on some Apple Silicon GPUs).
84    fn segment_encoder(&self) {
85        // No-op for non-Metal backends
86    }
87
88    /// Save the current pool cursor position. Call BEFORE a layer's work.
89    fn save_pool_cursor(&self) -> usize {
90        0
91    }
92
93    /// Restore the pool cursor to a previously saved position. Call AFTER
94    /// a layer's transient tensors have been dropped (out of scope).
95    ///
96    /// The architecture must ensure only the output tensor (`hidden_states`)
97    /// survives — all layer-internal tensors (qkv, scores, context, etc.)
98    /// must be dropped before this call so their pool slots can be recycled.
99    fn restore_pool_cursor(&self, _saved: usize) {}
100
101    // --- Allocation ---
102
103    /// Allocate a zero-initialized tensor with `n` float elements on device.
104    ///
105    /// Used by architectures to create workspace buffers (QKV projections,
106    /// attention scores, intermediate activations, etc.).
107    ///
108    /// # Errors
109    ///
110    /// Returns an error if device memory allocation fails.
111    fn alloc_zeros(&self, n: usize) -> crate::Result<Self::Tensor>;
112
113    /// Clone a tensor, producing an independent copy of the data.
114    ///
115    /// Used when an operation needs both the original and a mutable output
116    /// referencing the same logical data (e.g., in-place layer normalization
117    /// where input == output).
118    ///
119    /// # Errors
120    ///
121    /// Returns an error if device memory allocation or the copy fails.
122    fn clone_tensor(&self, tensor: &Self::Tensor, n: usize) -> crate::Result<Self::Tensor>;
123
124    // --- Batch preparation ---
125
126    /// Prepare a batch of encodings for inference, returning input tensors on device.
127    ///
128    /// Pads all sequences to `max_seq` and uploads `input_ids`, `attention_mask`,
129    /// `token_type_ids`, `position_ids`, and a float attention mask to device memory.
130    fn prepare_batch(
131        &self,
132        encodings: &[Encoding],
133        max_seq: usize,
134    ) -> crate::Result<BatchInputs<Self::Tensor>>;
135
136    /// Prepare a batch WITHOUT padding — concatenate all tokens flat.
137    ///
138    /// Returns `BatchInputs` with `total_tokens` actual tokens (no padding),
139    /// `cu_seqlens` for attention boundaries, and per-token position IDs.
140    /// Linear layers (GEMM, LN, GELU) process `total_tokens` rows.
141    /// Attention must pad/unpad around the per-head operations.
142    fn prepare_batch_unpadded(
143        &self,
144        encodings: &[Encoding],
145    ) -> crate::Result<BatchInputs<Self::Tensor>> {
146        // Default: fall back to padded (backends override for unpadded support)
147        let max_seq = encodings
148            .iter()
149            .map(|e| e.input_ids.len())
150            .max()
151            .unwrap_or(0)
152            .next_multiple_of(8);
153        self.prepare_batch(encodings, max_seq)
154    }
155
156    /// Scatter flat `[total_tokens, dim]` tensor into padded `[batch, max_seq, dim]`.
157    ///
158    /// Used before attention: linear layers produce unpadded output, but the
159    /// QKV split + batched attention GEMM need aligned `[batch*heads, seq, head_dim]`.
160    /// Padding positions are zeroed.
161    fn pad_to_batch(
162        &self,
163        flat: &Self::Tensor,
164        padded: &mut Self::Tensor,
165        seq_lengths: &[usize],
166        max_seq: usize,
167        dim: usize,
168    ) -> crate::Result<()>;
169
170    /// Gather padded `[batch, max_seq, dim]` back to flat `[total_tokens, dim]`.
171    ///
172    /// Used after attention: extracts only the real tokens, discarding padding.
173    fn unpad_from_batch(
174        &self,
175        padded: &Self::Tensor,
176        flat: &mut Self::Tensor,
177        seq_lengths: &[usize],
178        max_seq: usize,
179        dim: usize,
180    ) -> crate::Result<()>;
181
182    // --- Embedding operations ---
183
184    /// Word/position/token-type embedding lookup via gather.
185    ///
186    /// Reads `seq_len` token IDs from `word_ids`, gathers rows from
187    /// `embedding_table`, and writes `[seq_len, hidden]` floats to the result.
188    fn embedding_lookup(
189        &self,
190        word_ids: &Self::Tensor,
191        embedding_table: &Self::Tensor,
192        seq_len: usize,
193        hidden: usize,
194    ) -> crate::Result<Self::Tensor>;
195
196    /// Element-wise add an embedding table lookup into `hidden`.
197    ///
198    /// Used for position and token-type embeddings:
199    /// `hidden[i] += table[ids[i]]` for each token position.
200    fn add_embeddings(
201        &self,
202        hidden: &mut Self::Tensor,
203        table: &Self::Tensor,
204        ids: &Self::Tensor,
205        seq_len: usize,
206        hidden_dim: usize,
207    ) -> crate::Result<()>;
208
209    // --- Normalization ---
210
211    /// Layer normalization: `output = (input - mean) / sqrt(var + eps) * weight + bias`.
212    fn layer_norm(
213        &self,
214        output: &mut Self::Tensor,
215        input: &Self::Tensor,
216        weight: &Self::Tensor,
217        bias: &Self::Tensor,
218        rows: usize,
219        cols: usize,
220        eps: f32,
221    ) -> crate::Result<()>;
222
223    // --- Linear algebra ---
224
225    /// General matrix multiply: `output = A * B` (or `A * B^T` if `transpose_b`).
226    ///
227    /// Dimensions: A is `[m, k]`, B is `[k, n]` (or `[n, k]` if transposed),
228    /// output is `[m, n]`.
229    fn gemm(
230        &self,
231        a: &Self::Tensor,
232        b: &Self::Tensor,
233        output: &mut Self::Tensor,
234        m: usize,
235        n: usize,
236        k: usize,
237        transpose_b: bool,
238    ) -> crate::Result<()>;
239
240    /// Batched GEMM for multi-head attention.
241    ///
242    /// Performs `batch_count` independent GEMMs with strided access into
243    /// contiguous buffers. Used for per-head Q*K^T and attn*V.
244    fn gemm_batched(
245        &self,
246        a: &Self::Tensor,
247        b: &Self::Tensor,
248        output: &mut Self::Tensor,
249        m: usize,
250        n: usize,
251        k: usize,
252        transpose_b: bool,
253        stride_a: usize,
254        stride_b: usize,
255        stride_c: usize,
256        batch_count: usize,
257    ) -> crate::Result<()>;
258
259    // --- Attention ---
260
261    /// Fused scale + mask + softmax for attention scores.
262    ///
263    /// `scores = softmax(scores * scale + mask)` computed per-head.
264    fn fused_scale_mask_softmax(
265        &self,
266        scores: &mut Self::Tensor,
267        mask: &Self::Tensor,
268        batch: usize,
269        num_heads: usize,
270        seq_len: usize,
271        scale: f32,
272    ) -> crate::Result<()>;
273
274    /// Fused scale + mask + sliding window + softmax for attention scores.
275    ///
276    /// Like [`fused_scale_mask_softmax`](Driver::fused_scale_mask_softmax) but
277    /// additionally masks out positions where `|query_pos - key_pos| > window_size / 2`.
278    /// Used by `ModernBERT`'s local attention layers.
279    fn fused_scale_mask_softmax_windowed(
280        &self,
281        scores: &mut Self::Tensor,
282        mask: &Self::Tensor,
283        batch: usize,
284        num_heads: usize,
285        seq_len: usize,
286        scale: f32,
287        window_size: usize,
288    ) -> crate::Result<()>;
289
290    /// Build a float attention mask from an integer mask.
291    ///
292    /// Converts `[batch * seq]` int mask (0/1) to `[batch * seq]` float mask
293    /// (0.0 / -10000.0) for use with [`fused_scale_mask_softmax`](Driver::fused_scale_mask_softmax).
294    fn build_attn_mask(
295        &self,
296        output: &mut Self::Tensor,
297        int_mask: &Self::Tensor,
298        n: usize,
299    ) -> crate::Result<()>;
300
301    /// Split a fused QKV projection into separate Q, K, V tensors.
302    fn qkv_split(
303        &self,
304        q: &mut Self::Tensor,
305        k: &mut Self::Tensor,
306        v: &mut Self::Tensor,
307        qkv: &Self::Tensor,
308        batch: usize,
309        seq: usize,
310        hidden: usize,
311        num_heads: usize,
312        head_dim: usize,
313    ) -> crate::Result<()>;
314
315    // --- Banded (local/sliding-window) attention ---
316
317    /// Banded Q@K^T: compute attention scores only within a sliding window.
318    ///
319    /// Output shape: `[batch * num_heads, seq, window]` (NOT `[seq, seq]`).
320    /// `scores[h, i, w]` = dot(Q[h, i, :], K[h, i - window/2 + w, :])
321    /// where out-of-bounds positions are set to `-inf` (masked in softmax).
322    ///
323    /// Reduces attention compute from O(seq²) to O(seq × window).
324    /// For `seq=512, window=128`: **4× less compute** per local layer.
325    fn banded_qk(
326        &self,
327        q: &Self::Tensor,
328        k: &Self::Tensor,
329        scores: &mut Self::Tensor,
330        batch_heads: usize,
331        seq: usize,
332        head_dim: usize,
333        window: usize,
334        stride_qk: usize,
335        stride_scores: usize,
336    ) -> crate::Result<()>;
337
338    /// Banded scores@V: weighted sum using banded attention scores.
339    ///
340    /// Input scores: `[batch * num_heads, seq, window]` (from `banded_qk`).
341    /// Output: `[batch * num_heads, seq, head_dim]`.
342    /// `output[h, i, d]` = sum_w scores[h, i, w] * V[h, i - window/2 + w, d]
343    fn banded_sv(
344        &self,
345        scores: &Self::Tensor,
346        v: &Self::Tensor,
347        output: &mut Self::Tensor,
348        batch_heads: usize,
349        seq: usize,
350        head_dim: usize,
351        window: usize,
352        stride_scores: usize,
353        stride_v: usize,
354        stride_out: usize,
355    ) -> crate::Result<()>;
356
357    /// Fused scale + softmax over the window dimension (no padding mask needed).
358    ///
359    /// Operates on `[batch * num_heads * seq, window]` rows.
360    fn banded_softmax(
361        &self,
362        scores: &mut Self::Tensor,
363        total_rows: usize,
364        window: usize,
365        scale: f32,
366    ) -> crate::Result<()>;
367
368    /// Reshape attention output from `[batch, num_heads, seq, head_dim]` to
369    /// `[batch * seq, hidden]`.
370    fn attn_reshape(
371        &self,
372        output: &mut Self::Tensor,
373        input: &Self::Tensor,
374        batch: usize,
375        seq: usize,
376        num_heads: usize,
377        head_dim: usize,
378    ) -> crate::Result<()>;
379
380    /// Apply Rotary Position Embedding (RoPE) to Q/K tensors.
381    ///
382    /// Used by ModernBERT (not ClassicBert which uses learned position embeddings).
383    fn apply_rope(
384        &self,
385        qk: &mut Self::Tensor,
386        cos: &Self::Tensor,
387        sin: &Self::Tensor,
388        num_rows: usize,
389        seq_len: usize,
390        head_dim: usize,
391        num_heads: usize,
392    ) -> crate::Result<()>;
393
394    // --- Tensor manipulation ---
395
396    /// Split a `[rows, 2*cols]` matrix into two `[rows, cols]` halves.
397    ///
398    /// Each row of `input` is `[first_half | second_half]`. The first `cols`
399    /// elements go to `first`, the remaining `cols` to `second`.
400    /// Used by `ModernBERT` for gated MLP splits.
401    fn split_gate_value(
402        &self,
403        first: &mut Self::Tensor,
404        second: &mut Self::Tensor,
405        input: &Self::Tensor,
406        rows: usize,
407        cols: usize,
408    ) -> crate::Result<()>;
409
410    // --- Activations ---
411
412    /// GELU activation (Gaussian Error Linear Unit), applied in-place.
413    fn gelu(&self, x: &mut Self::Tensor, n: usize) -> crate::Result<()>;
414
415    /// SwiGLU gated activation: `output = value * silu(gate)`.
416    ///
417    /// The gate and value come from splitting the intermediate projection.
418    fn swiglu(
419        &self,
420        value: &Self::Tensor,
421        gate: &Self::Tensor,
422        output: &mut Self::Tensor,
423        n: usize,
424    ) -> crate::Result<()>;
425
426    /// `GeGLU` gated activation: `output = gelu(value) * gate`.
427    ///
428    /// Used by `ModernBERT`. The value and gate come from splitting the
429    /// MLP `Wi` projection output in half.
430    fn geglu(
431        &self,
432        value: &Self::Tensor,
433        gate: &Self::Tensor,
434        output: &mut Self::Tensor,
435        n: usize,
436    ) -> crate::Result<()>;
437
438    /// Fused bias + GELU: `x = gelu(x + bias)` row-wise.
439    fn fused_bias_gelu(
440        &self,
441        x: &mut Self::Tensor,
442        bias: &Self::Tensor,
443        rows: usize,
444        cols: usize,
445    ) -> crate::Result<()>;
446
447    // --- Fused residual operations ---
448
449    /// Fused bias + residual add: `output = input + bias + residual`.
450    ///
451    /// Bias is broadcast row-wise (`cols`-wide) across `n / cols` rows.
452    fn fused_bias_residual(
453        &self,
454        output: &mut Self::Tensor,
455        input: &Self::Tensor,
456        bias: &Self::Tensor,
457        residual: &Self::Tensor,
458        n: usize,
459        cols: usize,
460    ) -> crate::Result<()>;
461
462    /// Fused residual add + layer normalization.
463    ///
464    /// `output = layer_norm(hidden + residual, weight, bias, eps)`.
465    fn fused_residual_layernorm(
466        &self,
467        output: &mut Self::Tensor,
468        hidden: &Self::Tensor,
469        residual: &Self::Tensor,
470        weight: &Self::Tensor,
471        bias: &Self::Tensor,
472        rows: usize,
473        cols: usize,
474        eps: f32,
475    ) -> crate::Result<()>;
476
477    /// Residual add without bias: `output = hidden + residual`.
478    ///
479    /// Used by `ModernBERT` which has no bias terms.
480    fn residual_add(
481        &self,
482        output: &mut Self::Tensor,
483        hidden: &Self::Tensor,
484        residual: &Self::Tensor,
485        n: usize,
486    ) -> crate::Result<()>;
487
488    /// Add bias to a matrix row-wise: `x[row] += bias` for each row.
489    fn add_bias(
490        &self,
491        x: &mut Self::Tensor,
492        bias: &Self::Tensor,
493        rows: usize,
494        cols: usize,
495    ) -> crate::Result<()>;
496
497    // --- Pooling ---
498
499    /// CLS pooling: extract the first token's hidden state per batch element.
500    fn cls_pool(
501        &self,
502        output: &mut Self::Tensor,
503        hidden: &Self::Tensor,
504        batch: usize,
505        seq: usize,
506        hidden_dim: usize,
507    ) -> crate::Result<()>;
508
509    /// Mean pooling: attention-mask-weighted average of hidden states.
510    fn mean_pool(
511        &self,
512        output: &mut Self::Tensor,
513        hidden: &Self::Tensor,
514        mask: &Self::Tensor,
515        batch: usize,
516        seq: usize,
517        hidden_dim: usize,
518    ) -> crate::Result<()>;
519
520    // --- Post-processing ---
521
522    /// L2-normalize each row vector in-place.
523    fn l2_normalize(&self, data: &mut Self::Tensor, rows: usize, cols: usize) -> crate::Result<()>;
524
525    /// Copy tensor data back to host memory as `Vec<Vec<f32>>`.
526    ///
527    /// Returns one `Vec<f32>` of length `dim` per batch element.
528    fn to_host(
529        &self,
530        tensor: &Self::Tensor,
531        batch: usize,
532        dim: usize,
533    ) -> crate::Result<Vec<Vec<f32>>>;
534
535    // =======================================================================
536    // FP16 operations for full half-precision pipeline
537    //
538    // These methods mirror the FP32 counterparts but operate on FP16 tensors.
539    // Internal reductions (softmax, layer-norm) use FP32 accumulators but
540    // all tensor I/O is half precision. Default implementations return an
541    // error — only backends with FP16 support override them.
542    // =======================================================================
543
544    /// Allocate a zero-initialized FP16 tensor with `n` half-precision elements.
545    ///
546    /// # Errors
547    ///
548    /// Returns an error if device memory allocation fails or FP16 is unsupported.
549    fn alloc_zeros_f16(&self, _n: usize) -> crate::Result<Self::Tensor> {
550        Err(crate::Error::Metal(
551            "FP16 not supported by this driver".into(),
552        ))
553    }
554
555    /// Convert FP32 tensor to FP16 (element-wise narrowing).
556    fn f32_to_f16(
557        &self,
558        _output: &mut Self::Tensor,
559        _input: &Self::Tensor,
560        _n: usize,
561    ) -> crate::Result<()> {
562        Err(crate::Error::Metal(
563            "FP16 not supported by this driver".into(),
564        ))
565    }
566
567    /// Convert FP16 tensor back to FP32 (element-wise widening).
568    fn f16_to_f32(
569        &self,
570        _output: &mut Self::Tensor,
571        _input: &Self::Tensor,
572        _n: usize,
573    ) -> crate::Result<()> {
574        Err(crate::Error::Metal(
575            "FP16 not supported by this driver".into(),
576        ))
577    }
578
579    /// Mixed-precision GEMM: FP16 inputs → FP32 output via native simdgroup ops.
580    fn gemm_mixed(
581        &self,
582        _a_f16: &Self::Tensor,
583        _b_f16: &Self::Tensor,
584        _output_f32: &mut Self::Tensor,
585        _m: usize,
586        _n: usize,
587        _k: usize,
588        _transpose_b: bool,
589    ) -> crate::Result<()> {
590        Err(crate::Error::Metal(
591            "gemm_mixed not supported by this driver".into(),
592        ))
593    }
594
595    /// FP16 GEMM: `output = A * B` (or `A * B^T`). All tensors are half.
596    fn gemm_f16(
597        &self,
598        _a: &Self::Tensor,
599        _b: &Self::Tensor,
600        _output: &mut Self::Tensor,
601        _m: usize,
602        _n: usize,
603        _k: usize,
604        _transpose_b: bool,
605    ) -> crate::Result<()> {
606        Err(crate::Error::Metal(
607            "FP16 not supported by this driver".into(),
608        ))
609    }
610
611    /// FP16 batched GEMM for multi-head attention. All tensors are half.
612    #[expect(
613        clippy::too_many_arguments,
614        reason = "matches FP32 gemm_batched signature"
615    )]
616    fn gemm_batched_f16(
617        &self,
618        _a: &Self::Tensor,
619        _b: &Self::Tensor,
620        _output: &mut Self::Tensor,
621        _m: usize,
622        _n: usize,
623        _k: usize,
624        _transpose_b: bool,
625        _stride_a: usize,
626        _stride_b: usize,
627        _stride_c: usize,
628        _batch_count: usize,
629    ) -> crate::Result<()> {
630        Err(crate::Error::Metal(
631            "FP16 not supported by this driver".into(),
632        ))
633    }
634
635    /// FP16 layer normalization. Half I/O, FP32 reductions.
636    fn layer_norm_f16(
637        &self,
638        _output: &mut Self::Tensor,
639        _input: &Self::Tensor,
640        _weight: &Self::Tensor,
641        _bias: &Self::Tensor,
642        _rows: usize,
643        _cols: usize,
644        _eps: f32,
645    ) -> crate::Result<()> {
646        Err(crate::Error::Metal(
647            "FP16 not supported by this driver".into(),
648        ))
649    }
650
651    /// FP16 fused scale + mask + softmax. Half scores, FP32 reductions.
652    fn fused_scale_mask_softmax_f16(
653        &self,
654        _scores: &mut Self::Tensor,
655        _mask: &Self::Tensor,
656        _batch: usize,
657        _num_heads: usize,
658        _seq_len: usize,
659        _scale: f32,
660    ) -> crate::Result<()> {
661        Err(crate::Error::Metal(
662            "FP16 not supported by this driver".into(),
663        ))
664    }
665
666    /// FP16 fused scale + mask + sliding window + softmax.
667    fn fused_scale_mask_softmax_windowed_f16(
668        &self,
669        _scores: &mut Self::Tensor,
670        _mask: &Self::Tensor,
671        _batch: usize,
672        _num_heads: usize,
673        _seq_len: usize,
674        _scale: f32,
675        _window_size: usize,
676    ) -> crate::Result<()> {
677        Err(crate::Error::Metal(
678            "FP16 not supported by this driver".into(),
679        ))
680    }
681
682    /// FP16 QKV split: `[batch*seq, 3*hidden]` into Q, K, V per-head layout.
683    fn qkv_split_f16(
684        &self,
685        _q: &mut Self::Tensor,
686        _k: &mut Self::Tensor,
687        _v: &mut Self::Tensor,
688        _qkv: &Self::Tensor,
689        _batch: usize,
690        _seq: usize,
691        _hidden: usize,
692        _num_heads: usize,
693        _head_dim: usize,
694    ) -> crate::Result<()> {
695        Err(crate::Error::Metal(
696            "FP16 not supported by this driver".into(),
697        ))
698    }
699
700    /// FP16 attention output reshape: `[batch*num_heads, seq, head_dim]` to
701    /// `[batch*seq, hidden]`.
702    fn attn_reshape_f16(
703        &self,
704        _output: &mut Self::Tensor,
705        _input: &Self::Tensor,
706        _batch: usize,
707        _seq: usize,
708        _num_heads: usize,
709        _head_dim: usize,
710    ) -> crate::Result<()> {
711        Err(crate::Error::Metal(
712            "FP16 not supported by this driver".into(),
713        ))
714    }
715
716    /// FP16 scatter flat `[total_tokens, dim]` to padded `[batch, max_seq, dim]`.
717    fn pad_to_batch_f16(
718        &self,
719        _flat: &Self::Tensor,
720        _padded: &mut Self::Tensor,
721        _seq_lengths: &[usize],
722        _max_seq: usize,
723        _dim: usize,
724    ) -> crate::Result<()> {
725        Err(crate::Error::Metal(
726            "FP16 not supported by this driver".into(),
727        ))
728    }
729
730    /// FP16 gather padded `[batch, max_seq, dim]` back to flat `[total_tokens, dim]`.
731    fn unpad_from_batch_f16(
732        &self,
733        _padded: &Self::Tensor,
734        _flat: &mut Self::Tensor,
735        _seq_lengths: &[usize],
736        _max_seq: usize,
737        _dim: usize,
738    ) -> crate::Result<()> {
739        Err(crate::Error::Metal(
740            "FP16 not supported by this driver".into(),
741        ))
742    }
743
744    /// FP16 RoPE: apply rotary position embedding. Half Q/K, float cos/sin tables.
745    fn rope_encode_f16(
746        &self,
747        _qk: &mut Self::Tensor,
748        _cos: &Self::Tensor,
749        _sin: &Self::Tensor,
750        _num_rows: usize,
751        _seq_len: usize,
752        _head_dim: usize,
753        _num_heads: usize,
754    ) -> crate::Result<()> {
755        Err(crate::Error::Metal(
756            "FP16 not supported by this driver".into(),
757        ))
758    }
759
760    /// FP16 `GeGLU` gated activation: `output = gelu(value) * gate`. Half I/O.
761    fn geglu_f16(
762        &self,
763        _value: &Self::Tensor,
764        _gate: &Self::Tensor,
765        _output: &mut Self::Tensor,
766        _n: usize,
767    ) -> crate::Result<()> {
768        Err(crate::Error::Metal(
769            "FP16 not supported by this driver".into(),
770        ))
771    }
772
773    /// FP16 fused residual add + layer normalization.
774    fn fused_residual_layernorm_f16(
775        &self,
776        _output: &mut Self::Tensor,
777        _hidden: &Self::Tensor,
778        _residual: &Self::Tensor,
779        _weight: &Self::Tensor,
780        _bias: &Self::Tensor,
781        _rows: usize,
782        _cols: usize,
783        _eps: f32,
784    ) -> crate::Result<()> {
785        Err(crate::Error::Metal(
786            "FP16 not supported by this driver".into(),
787        ))
788    }
789
790    /// FP16 residual add (no bias): `output = hidden + residual`.
791    fn residual_add_f16(
792        &self,
793        _output: &mut Self::Tensor,
794        _hidden: &Self::Tensor,
795        _residual: &Self::Tensor,
796        _n: usize,
797    ) -> crate::Result<()> {
798        Err(crate::Error::Metal(
799            "FP16 not supported by this driver".into(),
800        ))
801    }
802
803    /// FP16 split `[rows, 2*cols]` into two `[rows, cols]` halves.
804    fn split_gate_value_f16(
805        &self,
806        _first: &mut Self::Tensor,
807        _second: &mut Self::Tensor,
808        _input: &Self::Tensor,
809        _rows: usize,
810        _cols: usize,
811    ) -> crate::Result<()> {
812        Err(crate::Error::Metal(
813            "FP16 not supported by this driver".into(),
814        ))
815    }
816
817    /// Fused split + `GeGLU`: read `[rows, 2*cols]`, write `[rows, cols]`.
818    ///
819    /// Combines [`split_gate_value_f16`](Driver::split_gate_value_f16) and
820    /// [`geglu_f16`](Driver::geglu_f16) into a single kernel, eliminating
821    /// two intermediate `[rows, cols]` buffers and halving HBM round-trips.
822    ///
823    /// Default falls back to separate split + geglu calls.
824    fn fused_split_geglu_f16(
825        &self,
826        output: &mut Self::Tensor,
827        input: &Self::Tensor,
828        rows: usize,
829        cols: usize,
830    ) -> crate::Result<()> {
831        // Default: allocate intermediates and call separately.
832        let n = rows * cols;
833        let mut value = self.alloc_zeros_f16(n)?;
834        let mut gate = self.alloc_zeros_f16(n)?;
835        self.split_gate_value_f16(&mut value, &mut gate, input, rows, cols)?;
836        self.geglu_f16(&value, &gate, output, n)
837    }
838
839    /// Fused pad + QKV split: flat `[total_tokens, 3*hidden]` → Q, K, V
840    /// each `[batch*heads, max_seq, head_dim]`.
841    ///
842    /// Eliminates the padded intermediate buffer. Default calls pad then split.
843    #[expect(clippy::too_many_arguments, reason = "mirrors pad + qkv_split args")]
844    fn fused_pad_qkv_split_f16(
845        &self,
846        q: &mut Self::Tensor,
847        k: &mut Self::Tensor,
848        v: &mut Self::Tensor,
849        qkv_flat: &Self::Tensor,
850        seq_lengths: &[usize],
851        max_seq: usize,
852        batch: usize,
853        hidden: usize,
854        num_heads: usize,
855        head_dim: usize,
856    ) -> crate::Result<()> {
857        // Default: pad then split.
858        let padded_tokens = batch * max_seq;
859        let mut qkv_padded = self.alloc_zeros_f16(padded_tokens * 3 * hidden)?;
860        self.pad_to_batch_f16(qkv_flat, &mut qkv_padded, seq_lengths, max_seq, 3 * hidden)?;
861        self.qkv_split_f16(
862            q,
863            k,
864            v,
865            &qkv_padded,
866            batch,
867            max_seq,
868            hidden,
869            num_heads,
870            head_dim,
871        )
872    }
873
874    /// Fused attn_reshape + unpad: `[batch*heads, max_seq, head_dim]` →
875    /// `[total_tokens, hidden]`.
876    ///
877    /// Eliminates the padded context intermediate. Default calls reshape then unpad.
878    fn fused_reshape_unpad_f16(
879        &self,
880        flat: &mut Self::Tensor,
881        heads: &Self::Tensor,
882        seq_lengths: &[usize],
883        max_seq: usize,
884        batch: usize,
885        num_heads: usize,
886        head_dim: usize,
887    ) -> crate::Result<()> {
888        // Default: reshape then unpad.
889        let hidden = num_heads * head_dim;
890        let padded_tokens = batch * max_seq;
891        let mut context = self.alloc_zeros_f16(padded_tokens * hidden)?;
892        self.attn_reshape_f16(&mut context, heads, batch, max_seq, num_heads, head_dim)?;
893        self.unpad_from_batch_f16(&context, flat, seq_lengths, max_seq, hidden)
894    }
895}
896
/// Batch input tensors on device, produced by [`Driver::prepare_batch`] or
/// [`Driver::prepare_batch_unpadded`].
///
/// Supports both padded and unpadded modes:
/// - **Padded**: all sequences padded to `max_seq`. `cu_seqlens` is `None`.
/// - **Unpadded**: sequences concatenated without padding. `cu_seqlens`
///   contains cumulative lengths `[0, len0, len0+len1, ...]` so attention
///   knows where each sequence starts. Eliminates ALL padding compute.
///
/// `T` is the driver's tensor handle type (`Driver::Tensor`). The struct is
/// `Debug` whenever `T` is, so batches can be logged during diagnostics.
#[derive(Debug)]
pub struct BatchInputs<T> {
    /// Token IDs — `[batch * max_seq]` (padded) or `[total_tokens]` (unpadded).
    pub input_ids: T,
    /// Attention mask `[batch * max_seq]` as int32 (0 or 1). Unused in unpadded mode.
    pub attention_mask: T,
    /// Token type IDs — same layout as `input_ids`.
    pub token_type_ids: T,
    /// Position IDs — same layout as `input_ids`.
    pub position_ids: T,
    /// Float attention bias mask `[batch * max_seq]` (0.0 or -1e9) for softmax.
    ///
    /// NOTE(review): `Driver::build_attn_mask` documents the masked value as
    /// -10000.0 rather than -1e9 — confirm which value backends actually use.
    pub float_mask: T,
    /// Float pooling mask `[batch * max_seq]` (1.0 or 0.0) for mean pooling.
    pub pooling_mask: T,
    /// Number of sequences in this batch.
    pub batch: usize,
    /// Maximum sequence length (all sequences padded to this). In unpadded mode,
    /// this is the longest sequence (used for workspace sizing, not padding).
    pub max_seq: usize,
    /// Total actual tokens across all sequences (no padding).
    pub total_tokens: usize,
    /// Per-sequence lengths: `[batch]` — each element is the actual token count.
    pub seq_lengths: Vec<usize>,
    /// Cumulative sequence lengths for unpadded attention: `[batch + 1]`.
    /// `cu_seqlens[i]..cu_seqlens[i+1]` is the token range for sequence `i`.
    /// `None` in padded mode (all sequences padded to max_seq).
    pub cu_seqlens: Option<Vec<usize>>,
}