// ripvec_core/backend/driver/mod.rs
//! Hardware-agnostic compute driver trait.
//!
//! The [`Driver`] trait exposes low-level compute primitives (GEMM, layer-norm,
//! activations, etc.) that each hardware backend implements. Model architectures
//! are generic over `D: Driver` and compose these primitives into a forward pass.
//!
//! # Design
//!
//! - **Associated type `Tensor`**: each driver defines its own opaque tensor
//!   handle (Metal: buffer+offset, CUDA: device pointer, CPU: ndarray).
//! - **Not object-safe**: architectures use `D: Driver` generics so the compiler
//!   can monomorphize and inline driver calls.
//! - **Send + Sync**: drivers are shared across the pipeline.
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
pub mod cpu;
#[cfg(feature = "cuda")]
pub mod cuda;
#[cfg(feature = "metal")]
pub mod metal;
#[cfg(feature = "mlx")]
pub mod mlx;

use super::Encoding;
25
/// Hardware-agnostic compute primitives for BERT inference.
///
/// Each method corresponds to one operation in the forward pass. Drivers handle
/// memory allocation, kernel dispatch, and synchronization. Architectures
/// compose these primitives via the [`super::arch::ModelArch`] trait.
pub trait Driver: Send + Sync {
    /// Opaque tensor handle.
    ///
    /// Metal: `MTLBuffer` + byte offset. CUDA: `CUdeviceptr`. CPU: `Array2<f32>`.
    type Tensor;

    /// Create a new driver instance for a cloned worker thread.
    ///
    /// CPU drivers are zero-size and always succeed. GPU drivers typically
    /// cannot be cloned this way (they share device state) and should leave
    /// this default implementation, which returns an error (it does not panic).
    ///
    /// # Errors
    ///
    /// The default implementation always returns `Err` — override only if the
    /// driver can be constructed independently per worker thread.
    fn new_for_clone() -> crate::Result<Self>
    where
        Self: Sized,
    {
        Err(crate::Error::Other(anyhow::anyhow!(
            "this driver does not support cloning"
        )))
    }

    // --- Batching ---

    /// Begin batched mode: all subsequent operations encode into one dispatch.
    ///
    /// GPU drivers accumulate into a single command buffer; CPU is a no-op.
    /// Call [`Self::end_batch`] to commit. This eliminates per-call overhead.
    fn begin_batch(&self) -> crate::Result<()> {
        Ok(())
    }

    /// End batched mode: commit all accumulated operations and wait.
    fn end_batch(&self) -> crate::Result<()> {
        Ok(())
    }

    /// Flush the current command buffer and start a new one, preserving pool
    /// state. Use mid-forward-pass to prevent GPU timeouts on deep models.
    fn flush_batch(&self) -> crate::Result<()> {
        Ok(())
    }

    /// Close and reopen the compute encoder within the same command buffer.
    ///
    /// This segments a long sequence of compute dispatches into multiple
    /// encoders without committing or waiting. Metal processes encoders
    /// back-to-back from the same CB — zero sync overhead.
    ///
    /// Use every few layers to prevent encoder state overflow (>~60 dispatches
    /// per encoder can cause hangs on some Apple Silicon GPUs).
    fn segment_encoder(&self) {
        // No-op for non-Metal backends
    }

    /// Save the current pool cursor position. Call BEFORE a layer's work.
    ///
    /// The default (for drivers without a tensor pool) always returns 0.
    fn save_pool_cursor(&self) -> usize {
        0
    }

    /// Restore the pool cursor to a previously saved position. Call AFTER
    /// a layer's transient tensors have been dropped (out of scope).
    ///
    /// The architecture must ensure only the output tensor (`hidden_states`)
    /// survives — all layer-internal tensors (qkv, scores, context, etc.)
    /// must be dropped before this call so their pool slots can be recycled.
    fn restore_pool_cursor(&self, _saved: usize) {}

    // --- Allocation ---

    /// Allocate a zero-initialized tensor with `n` float elements on device.
    ///
    /// Used by architectures to create workspace buffers (QKV projections,
    /// attention scores, intermediate activations, etc.).
    ///
    /// # Errors
    ///
    /// Returns an error if device memory allocation fails.
    fn alloc_zeros(&self, n: usize) -> crate::Result<Self::Tensor>;

    /// Clone a tensor, producing an independent copy of the data.
    ///
    /// Used when an operation needs both the original and a mutable output
    /// referencing the same logical data (e.g., in-place layer normalization
    /// where input == output).
    ///
    /// # Errors
    ///
    /// Returns an error if device memory allocation or the copy fails.
    fn clone_tensor(&self, tensor: &Self::Tensor, n: usize) -> crate::Result<Self::Tensor>;

    // --- Batch preparation ---

    /// Prepare a batch of encodings for inference, returning input tensors on device.
    ///
    /// Pads all sequences to `max_seq` and uploads `input_ids`, `attention_mask`,
    /// `token_type_ids`, `position_ids`, and a float attention mask to device memory.
    fn prepare_batch(
        &self,
        encodings: &[Encoding],
        max_seq: usize,
    ) -> crate::Result<BatchInputs<Self::Tensor>>;

    /// Prepare a batch WITHOUT padding — concatenate all tokens flat.
    ///
    /// Returns `BatchInputs` with `total_tokens` actual tokens (no padding),
    /// `cu_seqlens` for attention boundaries, and per-token position IDs.
    /// Linear layers (GEMM, LN, GELU) process `total_tokens` rows.
    /// Attention must pad/unpad around the per-head operations.
    fn prepare_batch_unpadded(
        &self,
        encodings: &[Encoding],
    ) -> crate::Result<BatchInputs<Self::Tensor>> {
        // Default: fall back to padded (backends override for unpadded support)
        // max_seq is rounded up to a multiple of 8 for GPU-friendly alignment;
        // an empty batch yields max_seq == 0.
        let max_seq = encodings
            .iter()
            .map(|e| e.input_ids.len())
            .max()
            .unwrap_or(0)
            .next_multiple_of(8);
        self.prepare_batch(encodings, max_seq)
    }

    /// Scatter flat `[total_tokens, dim]` tensor into padded `[batch, max_seq, dim]`.
    ///
    /// Used before attention: linear layers produce unpadded output, but the
    /// QKV split + batched attention GEMM need aligned `[batch*heads, seq, head_dim]`.
    /// Padding positions are zeroed.
    fn pad_to_batch(
        &self,
        flat: &Self::Tensor,
        padded: &mut Self::Tensor,
        seq_lengths: &[usize],
        max_seq: usize,
        dim: usize,
    ) -> crate::Result<()>;

    /// Gather padded `[batch, max_seq, dim]` back to flat `[total_tokens, dim]`.
    ///
    /// Used after attention: extracts only the real tokens, discarding padding.
    fn unpad_from_batch(
        &self,
        padded: &Self::Tensor,
        flat: &mut Self::Tensor,
        seq_lengths: &[usize],
        max_seq: usize,
        dim: usize,
    ) -> crate::Result<()>;

    // --- Embedding operations ---

    /// Word/position/token-type embedding lookup via gather.
    ///
    /// Reads `seq_len` token IDs from `word_ids`, gathers rows from
    /// `embedding_table`, and writes `[seq_len, hidden]` floats to the result.
    fn embedding_lookup(
        &self,
        word_ids: &Self::Tensor,
        embedding_table: &Self::Tensor,
        seq_len: usize,
        hidden: usize,
    ) -> crate::Result<Self::Tensor>;

    /// Element-wise add an embedding table lookup into `hidden`.
    ///
    /// Used for position and token-type embeddings:
    /// `hidden[i] += table[ids[i]]` for each token position.
    fn add_embeddings(
        &self,
        hidden: &mut Self::Tensor,
        table: &Self::Tensor,
        ids: &Self::Tensor,
        seq_len: usize,
        hidden_dim: usize,
    ) -> crate::Result<()>;

    // --- Normalization ---

    /// Layer normalization: `output = (input - mean) / sqrt(var + eps) * weight + bias`.
    fn layer_norm(
        &self,
        output: &mut Self::Tensor,
        input: &Self::Tensor,
        weight: &Self::Tensor,
        bias: &Self::Tensor,
        rows: usize,
        cols: usize,
        eps: f32,
    ) -> crate::Result<()>;

    // --- Linear algebra ---

    /// General matrix multiply: `output = A * B` (or `A * B^T` if `transpose_b`).
    ///
    /// Dimensions: A is `[m, k]`, B is `[k, n]` (or `[n, k]` if transposed),
    /// output is `[m, n]`.
    fn gemm(
        &self,
        a: &Self::Tensor,
        b: &Self::Tensor,
        output: &mut Self::Tensor,
        m: usize,
        n: usize,
        k: usize,
        transpose_b: bool,
    ) -> crate::Result<()>;

    /// Batched GEMM for multi-head attention.
    ///
    /// Performs `batch_count` independent GEMMs with strided access into
    /// contiguous buffers. Used for per-head Q*K^T and attn*V.
    fn gemm_batched(
        &self,
        a: &Self::Tensor,
        b: &Self::Tensor,
        output: &mut Self::Tensor,
        m: usize,
        n: usize,
        k: usize,
        transpose_b: bool,
        stride_a: usize,
        stride_b: usize,
        stride_c: usize,
        batch_count: usize,
    ) -> crate::Result<()>;

    // --- Attention ---

    /// Fused scale + mask + softmax for attention scores.
    ///
    /// `scores = softmax(scores * scale + mask)` computed per-head.
    fn fused_scale_mask_softmax(
        &self,
        scores: &mut Self::Tensor,
        mask: &Self::Tensor,
        batch: usize,
        num_heads: usize,
        seq_len: usize,
        scale: f32,
    ) -> crate::Result<()>;

    /// Fused scale + mask + sliding window + softmax for attention scores.
    ///
    /// Like [`fused_scale_mask_softmax`](Driver::fused_scale_mask_softmax) but
    /// additionally masks out positions where `|query_pos - key_pos| > window_size / 2`.
    /// Used by `ModernBERT`'s local attention layers.
    fn fused_scale_mask_softmax_windowed(
        &self,
        scores: &mut Self::Tensor,
        mask: &Self::Tensor,
        batch: usize,
        num_heads: usize,
        seq_len: usize,
        scale: f32,
        window_size: usize,
    ) -> crate::Result<()>;

    /// Build a float attention mask from an integer mask.
    ///
    /// Converts `[batch * seq]` int mask (0/1) to `[batch * seq]` float mask
    /// (0.0 / -10000.0) for use with [`fused_scale_mask_softmax`](Driver::fused_scale_mask_softmax).
    ///
    /// NOTE(review): [`BatchInputs::float_mask`] docs cite -1e9 as the bias
    /// value — the exact magnitude may be backend-defined; confirm and unify.
    fn build_attn_mask(
        &self,
        output: &mut Self::Tensor,
        int_mask: &Self::Tensor,
        n: usize,
    ) -> crate::Result<()>;

    /// Split a fused QKV projection into separate Q, K, V tensors.
    fn qkv_split(
        &self,
        q: &mut Self::Tensor,
        k: &mut Self::Tensor,
        v: &mut Self::Tensor,
        qkv: &Self::Tensor,
        batch: usize,
        seq: usize,
        hidden: usize,
        num_heads: usize,
        head_dim: usize,
    ) -> crate::Result<()>;

    // --- Banded (local/sliding-window) attention ---

    /// Banded Q@K^T: compute attention scores only within a sliding window.
    ///
    /// Output shape: `[batch * num_heads, seq, window]` (NOT `[seq, seq]`).
    /// `scores[h, i, w]` = dot(Q[h, i, :], K[h, i - window/2 + w, :])
    /// where out-of-bounds positions are set to `-inf` (masked in softmax).
    ///
    /// Reduces attention compute from O(seq²) to O(seq × window).
    /// For `seq=512, window=128`: **4× less compute** per local layer.
    fn banded_qk(
        &self,
        q: &Self::Tensor,
        k: &Self::Tensor,
        scores: &mut Self::Tensor,
        batch_heads: usize,
        seq: usize,
        head_dim: usize,
        window: usize,
        stride_qk: usize,
        stride_scores: usize,
    ) -> crate::Result<()>;

    /// Banded scores@V: weighted sum using banded attention scores.
    ///
    /// Input scores: `[batch * num_heads, seq, window]` (from `banded_qk`).
    /// Output: `[batch * num_heads, seq, head_dim]`.
    /// `output[h, i, d]` = sum_w scores[h, i, w] * V[h, i - window/2 + w, d]
    fn banded_sv(
        &self,
        scores: &Self::Tensor,
        v: &Self::Tensor,
        output: &mut Self::Tensor,
        batch_heads: usize,
        seq: usize,
        head_dim: usize,
        window: usize,
        stride_scores: usize,
        stride_v: usize,
        stride_out: usize,
    ) -> crate::Result<()>;

    /// Fused scale + softmax over the window dimension (no padding mask needed).
    ///
    /// Operates on `[batch * num_heads * seq, window]` rows.
    fn banded_softmax(
        &self,
        scores: &mut Self::Tensor,
        total_rows: usize,
        window: usize,
        scale: f32,
    ) -> crate::Result<()>;

    /// Reshape attention output from `[batch, num_heads, seq, head_dim]` to
    /// `[batch * seq, hidden]`.
    fn attn_reshape(
        &self,
        output: &mut Self::Tensor,
        input: &Self::Tensor,
        batch: usize,
        seq: usize,
        num_heads: usize,
        head_dim: usize,
    ) -> crate::Result<()>;

    /// Apply Rotary Position Embedding (RoPE) to Q/K tensors.
    ///
    /// Used by ModernBERT (not ClassicBert which uses learned position embeddings).
    fn apply_rope(
        &self,
        qk: &mut Self::Tensor,
        cos: &Self::Tensor,
        sin: &Self::Tensor,
        num_rows: usize,
        seq_len: usize,
        head_dim: usize,
        num_heads: usize,
    ) -> crate::Result<()>;

    // --- Tensor manipulation ---

    /// Split a `[rows, 2*cols]` matrix into two `[rows, cols]` halves.
    ///
    /// Each row of `input` is `[first_half | second_half]`. The first `cols`
    /// elements go to `first`, the remaining `cols` to `second`.
    /// Used by `ModernBERT` for gated MLP splits.
    fn split_gate_value(
        &self,
        first: &mut Self::Tensor,
        second: &mut Self::Tensor,
        input: &Self::Tensor,
        rows: usize,
        cols: usize,
    ) -> crate::Result<()>;

    // --- Activations ---

    /// GELU activation (Gaussian Error Linear Unit), applied in-place.
    fn gelu(&self, x: &mut Self::Tensor, n: usize) -> crate::Result<()>;

    /// SwiGLU gated activation: `output = value * silu(gate)`.
    ///
    /// The gate and value come from splitting the intermediate projection.
    fn swiglu(
        &self,
        value: &Self::Tensor,
        gate: &Self::Tensor,
        output: &mut Self::Tensor,
        n: usize,
    ) -> crate::Result<()>;

    /// `GeGLU` gated activation: `output = gelu(value) * gate`.
    ///
    /// Used by `ModernBERT`. The value and gate come from splitting the
    /// MLP `Wi` projection output in half.
    fn geglu(
        &self,
        value: &Self::Tensor,
        gate: &Self::Tensor,
        output: &mut Self::Tensor,
        n: usize,
    ) -> crate::Result<()>;

    /// Fused bias + GELU: `x = gelu(x + bias)` row-wise.
    fn fused_bias_gelu(
        &self,
        x: &mut Self::Tensor,
        bias: &Self::Tensor,
        rows: usize,
        cols: usize,
    ) -> crate::Result<()>;

    // --- Fused residual operations ---

    /// Fused bias + residual add: `output = input + bias + residual`.
    ///
    /// Bias is broadcast row-wise (`cols`-wide) across `n / cols` rows.
    fn fused_bias_residual(
        &self,
        output: &mut Self::Tensor,
        input: &Self::Tensor,
        bias: &Self::Tensor,
        residual: &Self::Tensor,
        n: usize,
        cols: usize,
    ) -> crate::Result<()>;

    /// Fused residual add + layer normalization.
    ///
    /// `output = layer_norm(hidden + residual, weight, bias, eps)`.
    fn fused_residual_layernorm(
        &self,
        output: &mut Self::Tensor,
        hidden: &Self::Tensor,
        residual: &Self::Tensor,
        weight: &Self::Tensor,
        bias: &Self::Tensor,
        rows: usize,
        cols: usize,
        eps: f32,
    ) -> crate::Result<()>;

    /// Residual add without bias: `output = hidden + residual`.
    ///
    /// Used by `ModernBERT` which has no bias terms.
    fn residual_add(
        &self,
        output: &mut Self::Tensor,
        hidden: &Self::Tensor,
        residual: &Self::Tensor,
        n: usize,
    ) -> crate::Result<()>;

    /// Add bias to a matrix row-wise: `x[row] += bias` for each row.
    fn add_bias(
        &self,
        x: &mut Self::Tensor,
        bias: &Self::Tensor,
        rows: usize,
        cols: usize,
    ) -> crate::Result<()>;

    // --- Pooling ---

    /// CLS pooling: extract the first token's hidden state per batch element.
    fn cls_pool(
        &self,
        output: &mut Self::Tensor,
        hidden: &Self::Tensor,
        batch: usize,
        seq: usize,
        hidden_dim: usize,
    ) -> crate::Result<()>;

    /// Mean pooling: attention-mask-weighted average of hidden states.
    fn mean_pool(
        &self,
        output: &mut Self::Tensor,
        hidden: &Self::Tensor,
        mask: &Self::Tensor,
        batch: usize,
        seq: usize,
        hidden_dim: usize,
    ) -> crate::Result<()>;

    // --- Post-processing ---

    /// L2-normalize each row vector in-place.
    fn l2_normalize(&self, data: &mut Self::Tensor, rows: usize, cols: usize) -> crate::Result<()>;

    /// Copy tensor data back to host memory as `Vec<Vec<f32>>`.
    ///
    /// Returns one `Vec<f32>` of length `dim` per batch element.
    fn to_host(
        &self,
        tensor: &Self::Tensor,
        batch: usize,
        dim: usize,
    ) -> crate::Result<Vec<Vec<f32>>>;

    // =======================================================================
    // FP16 operations for full half-precision pipeline
    //
    // These methods mirror the FP32 counterparts but operate on FP16 tensors.
    // Internal reductions (softmax, layer-norm) use FP32 accumulators but
    // all tensor I/O is half precision. Default implementations return an
    // error — only backends with FP16 support override them.
    //
    // NOTE(review): the default bodies return `crate::Error::Metal` even on
    // non-Metal backends — confirm whether a backend-neutral error variant
    // would be more appropriate.
    // =======================================================================

    /// Allocate a zero-initialized FP16 tensor with `n` half-precision elements.
    ///
    /// # Errors
    ///
    /// Returns an error if device memory allocation fails or FP16 is unsupported.
    fn alloc_zeros_f16(&self, _n: usize) -> crate::Result<Self::Tensor> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// Convert FP32 tensor to FP16 (element-wise narrowing).
    fn f32_to_f16(
        &self,
        _output: &mut Self::Tensor,
        _input: &Self::Tensor,
        _n: usize,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// Convert FP16 tensor back to FP32 (element-wise widening).
    fn f16_to_f32(
        &self,
        _output: &mut Self::Tensor,
        _input: &Self::Tensor,
        _n: usize,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// Mixed-precision GEMM: FP16 inputs → FP32 output via native simdgroup ops.
    fn gemm_mixed(
        &self,
        _a_f16: &Self::Tensor,
        _b_f16: &Self::Tensor,
        _output_f32: &mut Self::Tensor,
        _m: usize,
        _n: usize,
        _k: usize,
        _transpose_b: bool,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "gemm_mixed not supported by this driver".into(),
        ))
    }

    /// FP16 GEMM: `output = A * B` (or `A * B^T`). All tensors are half.
    fn gemm_f16(
        &self,
        _a: &Self::Tensor,
        _b: &Self::Tensor,
        _output: &mut Self::Tensor,
        _m: usize,
        _n: usize,
        _k: usize,
        _transpose_b: bool,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 batched GEMM for multi-head attention. All tensors are half.
    #[expect(
        clippy::too_many_arguments,
        reason = "matches FP32 gemm_batched signature"
    )]
    fn gemm_batched_f16(
        &self,
        _a: &Self::Tensor,
        _b: &Self::Tensor,
        _output: &mut Self::Tensor,
        _m: usize,
        _n: usize,
        _k: usize,
        _transpose_b: bool,
        _stride_a: usize,
        _stride_b: usize,
        _stride_c: usize,
        _batch_count: usize,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 layer normalization. Half I/O, FP32 reductions.
    fn layer_norm_f16(
        &self,
        _output: &mut Self::Tensor,
        _input: &Self::Tensor,
        _weight: &Self::Tensor,
        _bias: &Self::Tensor,
        _rows: usize,
        _cols: usize,
        _eps: f32,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 fused scale + mask + softmax. Half scores, FP32 reductions.
    fn fused_scale_mask_softmax_f16(
        &self,
        _scores: &mut Self::Tensor,
        _mask: &Self::Tensor,
        _batch: usize,
        _num_heads: usize,
        _seq_len: usize,
        _scale: f32,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 fused scale + mask + sliding window + softmax.
    fn fused_scale_mask_softmax_windowed_f16(
        &self,
        _scores: &mut Self::Tensor,
        _mask: &Self::Tensor,
        _batch: usize,
        _num_heads: usize,
        _seq_len: usize,
        _scale: f32,
        _window_size: usize,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 QKV split: `[batch*seq, 3*hidden]` into Q, K, V per-head layout.
    fn qkv_split_f16(
        &self,
        _q: &mut Self::Tensor,
        _k: &mut Self::Tensor,
        _v: &mut Self::Tensor,
        _qkv: &Self::Tensor,
        _batch: usize,
        _seq: usize,
        _hidden: usize,
        _num_heads: usize,
        _head_dim: usize,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 attention output reshape: `[batch*num_heads, seq, head_dim]` to
    /// `[batch*seq, hidden]`.
    fn attn_reshape_f16(
        &self,
        _output: &mut Self::Tensor,
        _input: &Self::Tensor,
        _batch: usize,
        _seq: usize,
        _num_heads: usize,
        _head_dim: usize,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 scatter flat `[total_tokens, dim]` to padded `[batch, max_seq, dim]`.
    fn pad_to_batch_f16(
        &self,
        _flat: &Self::Tensor,
        _padded: &mut Self::Tensor,
        _seq_lengths: &[usize],
        _max_seq: usize,
        _dim: usize,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 gather padded `[batch, max_seq, dim]` back to flat `[total_tokens, dim]`.
    fn unpad_from_batch_f16(
        &self,
        _padded: &Self::Tensor,
        _flat: &mut Self::Tensor,
        _seq_lengths: &[usize],
        _max_seq: usize,
        _dim: usize,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 RoPE: apply rotary position embedding. Half Q/K, float cos/sin tables.
    ///
    /// NOTE(review): named `rope_encode_f16` while the FP32 counterpart is
    /// `apply_rope` — consider aligning the names in a future breaking change.
    fn rope_encode_f16(
        &self,
        _qk: &mut Self::Tensor,
        _cos: &Self::Tensor,
        _sin: &Self::Tensor,
        _num_rows: usize,
        _seq_len: usize,
        _head_dim: usize,
        _num_heads: usize,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 `GeGLU` gated activation: `output = gelu(value) * gate`. Half I/O.
    fn geglu_f16(
        &self,
        _value: &Self::Tensor,
        _gate: &Self::Tensor,
        _output: &mut Self::Tensor,
        _n: usize,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 fused residual add + layer normalization.
    fn fused_residual_layernorm_f16(
        &self,
        _output: &mut Self::Tensor,
        _hidden: &Self::Tensor,
        _residual: &Self::Tensor,
        _weight: &Self::Tensor,
        _bias: &Self::Tensor,
        _rows: usize,
        _cols: usize,
        _eps: f32,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 residual add (no bias): `output = hidden + residual`.
    fn residual_add_f16(
        &self,
        _output: &mut Self::Tensor,
        _hidden: &Self::Tensor,
        _residual: &Self::Tensor,
        _n: usize,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 split `[rows, 2*cols]` into two `[rows, cols]` halves.
    fn split_gate_value_f16(
        &self,
        _first: &mut Self::Tensor,
        _second: &mut Self::Tensor,
        _input: &Self::Tensor,
        _rows: usize,
        _cols: usize,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// Fused split + `GeGLU`: read `[rows, 2*cols]`, write `[rows, cols]`.
    ///
    /// Combines [`split_gate_value_f16`](Driver::split_gate_value_f16) and
    /// [`geglu_f16`](Driver::geglu_f16) into a single kernel, eliminating
    /// two intermediate `[rows, cols]` buffers and halving HBM round-trips.
    ///
    /// Default falls back to separate split + geglu calls.
    ///
    /// # Errors
    ///
    /// The default errors on backends without FP16 support (it allocates FP16
    /// intermediates via [`Self::alloc_zeros_f16`]).
    fn fused_split_geglu_f16(
        &self,
        output: &mut Self::Tensor,
        input: &Self::Tensor,
        rows: usize,
        cols: usize,
    ) -> crate::Result<()> {
        // Default: allocate intermediates and call separately.
        let n = rows * cols;
        let mut value = self.alloc_zeros_f16(n)?;
        let mut gate = self.alloc_zeros_f16(n)?;
        self.split_gate_value_f16(&mut value, &mut gate, input, rows, cols)?;
        self.geglu_f16(&value, &gate, output, n)
    }

    /// Fused pad + QKV split: flat `[total_tokens, 3*hidden]` → Q, K, V
    /// each `[batch*heads, max_seq, head_dim]`.
    ///
    /// Eliminates the padded intermediate buffer. Default calls pad then split.
    #[expect(clippy::too_many_arguments, reason = "mirrors pad + qkv_split args")]
    fn fused_pad_qkv_split_f16(
        &self,
        q: &mut Self::Tensor,
        k: &mut Self::Tensor,
        v: &mut Self::Tensor,
        qkv_flat: &Self::Tensor,
        seq_lengths: &[usize],
        max_seq: usize,
        batch: usize,
        hidden: usize,
        num_heads: usize,
        head_dim: usize,
    ) -> crate::Result<()> {
        // Default: pad then split.
        let padded_tokens = batch * max_seq;
        let mut qkv_padded = self.alloc_zeros_f16(padded_tokens * 3 * hidden)?;
        self.pad_to_batch_f16(qkv_flat, &mut qkv_padded, seq_lengths, max_seq, 3 * hidden)?;
        self.qkv_split_f16(
            q,
            k,
            v,
            &qkv_padded,
            batch,
            max_seq,
            hidden,
            num_heads,
            head_dim,
        )
    }

    /// Fused attn_reshape + unpad: `[batch*heads, max_seq, head_dim]` →
    /// `[total_tokens, hidden]`.
    ///
    /// Eliminates the padded context intermediate. Default calls reshape then unpad.
    fn fused_reshape_unpad_f16(
        &self,
        flat: &mut Self::Tensor,
        heads: &Self::Tensor,
        seq_lengths: &[usize],
        max_seq: usize,
        batch: usize,
        num_heads: usize,
        head_dim: usize,
    ) -> crate::Result<()> {
        // Default: reshape then unpad.
        let hidden = num_heads * head_dim;
        let padded_tokens = batch * max_seq;
        let mut context = self.alloc_zeros_f16(padded_tokens * hidden)?;
        self.attn_reshape_f16(&mut context, heads, batch, max_seq, num_heads, head_dim)?;
        self.unpad_from_batch_f16(&context, flat, seq_lengths, max_seq, hidden)
    }
}
892
/// Batch input tensors on device, produced by [`Driver::prepare_batch`].
///
/// `T` is the driver's opaque tensor handle ([`Driver::Tensor`]).
///
/// Supports both padded and unpadded modes:
/// - **Padded**: all sequences padded to `max_seq`. `cu_seqlens` is `None`.
/// - **Unpadded**: sequences concatenated without padding. `cu_seqlens`
///   contains cumulative lengths `[0, len0, len0+len1, ...]` so attention
///   knows where each sequence starts. Eliminates ALL padding compute.
pub struct BatchInputs<T> {
    /// Token IDs — `[batch * max_seq]` (padded) or `[total_tokens]` (unpadded).
    pub input_ids: T,
    /// Attention mask `[batch * max_seq]` as int32 (0 or 1). Unused in unpadded mode.
    pub attention_mask: T,
    /// Token type IDs — same layout as `input_ids`.
    pub token_type_ids: T,
    /// Position IDs — same layout as `input_ids`.
    pub position_ids: T,
    /// Float attention bias mask `[batch * max_seq]` (0.0 or -1e9) for softmax.
    ///
    /// NOTE(review): [`Driver::build_attn_mask`] docs cite -10000.0 — the
    /// exact bias magnitude may be backend-defined; confirm and unify.
    pub float_mask: T,
    /// Float pooling mask `[batch * max_seq]` (1.0 or 0.0) for mean pooling.
    pub pooling_mask: T,
    /// Number of sequences in this batch.
    pub batch: usize,
    /// Maximum sequence length (all sequences padded to this). In unpadded mode,
    /// this is the longest sequence (used for workspace sizing, not padding).
    pub max_seq: usize,
    /// Total actual tokens across all sequences (no padding).
    pub total_tokens: usize,
    /// Per-sequence lengths: `[batch]` — each element is the actual token count.
    pub seq_lengths: Vec<usize>,
    /// Cumulative sequence lengths for unpadded attention: `[batch + 1]`.
    /// `cu_seqlens[i]..cu_seqlens[i+1]` is the token range for sequence `i`.
    /// `None` in padded mode (all sequences padded to max_seq).
    pub cu_seqlens: Option<Vec<usize>>,
}