ripvec_core/backend/driver/mod.rs
1//! Hardware-agnostic compute driver trait.
2//!
3//! The [`Driver`] trait exposes low-level compute primitives (GEMM, layer-norm,
4//! activations, etc.) that each hardware backend implements. Model architectures
5//! are generic over `D: Driver` and compose these primitives into a forward pass.
6//!
7//! # Design
8//!
9//! - **Associated type `Tensor`**: each driver defines its own opaque tensor
10//! handle (Metal: buffer+offset, CUDA: device pointer, CPU: ndarray).
11//! - **Not object-safe**: architectures use `D: Driver` generics so the compiler
12//! can monomorphize and inline driver calls.
13//! - **Send + Sync**: drivers are shared across the pipeline.
14
15#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
16pub mod cpu;
17#[cfg(feature = "cuda")]
18pub mod cuda;
19#[cfg(feature = "metal")]
20pub mod metal;
21#[cfg(feature = "mlx")]
22pub mod mlx;
23
24use super::Encoding;
25
26/// Hardware-agnostic compute primitives for BERT inference.
27///
28/// Each method corresponds to one operation in the forward pass. Drivers handle
29/// memory allocation, kernel dispatch, and synchronization. Architectures
30/// compose these primitives via the [`super::arch::ModelArch`] trait.
31pub trait Driver: Send + Sync {
32 /// Opaque tensor handle.
33 ///
34 /// Metal: `MTLBuffer` + byte offset. CUDA: `CUdeviceptr`. CPU: `Array2<f32>`.
35 type Tensor;
36
37 /// Short human-readable label for diagnostics (e.g. "Metal", "CUDA", "CPU").
38 /// Surfaced via [`super::EmbedBackend::name`].
39 fn name(&self) -> &'static str;
40
41 /// Create a new driver instance for a cloned worker thread.
42 ///
43 /// CPU drivers are zero-size and always succeed. GPU drivers typically
44 /// cannot be cloned this way (they share device state) and should leave
45 /// the default panic implementation.
46 fn new_for_clone() -> crate::Result<Self>
47 where
48 Self: Sized,
49 {
50 Err(crate::Error::Other(anyhow::anyhow!(
51 "this driver does not support cloning"
52 )))
53 }
54
55 // --- Batching ---
56
57 /// Begin batched mode: all subsequent operations encode into one dispatch.
58 ///
59 /// GPU drivers accumulate into a single command buffer; CPU is a no-op.
60 /// Call [`Self::end_batch`] to commit. This eliminates per-call overhead.
61 fn begin_batch(&self) -> crate::Result<()> {
62 Ok(())
63 }
64
65 /// End batched mode: commit all accumulated operations and wait.
66 fn end_batch(&self) -> crate::Result<()> {
67 Ok(())
68 }
69
70 /// Flush the current command buffer and start a new one, preserving pool
71 /// state. Use mid-forward-pass to prevent GPU timeouts on deep models.
72 fn flush_batch(&self) -> crate::Result<()> {
73 Ok(())
74 }
75
76 /// Close and reopen the compute encoder within the same command buffer.
77 ///
78 /// This segments a long sequence of compute dispatches into multiple
79 /// encoders without committing or waiting. Metal processes encoders
80 /// back-to-back from the same CB — zero sync overhead.
81 ///
82 /// Use every few layers to prevent encoder state overflow (>~60 dispatches
83 /// per encoder can cause hangs on some Apple Silicon GPUs).
84 fn segment_encoder(&self) {
85 // No-op for non-Metal backends
86 }
87
88 /// Save the current pool cursor position. Call BEFORE a layer's work.
89 fn save_pool_cursor(&self) -> usize {
90 0
91 }
92
93 /// Restore the pool cursor to a previously saved position. Call AFTER
94 /// a layer's transient tensors have been dropped (out of scope).
95 ///
96 /// The architecture must ensure only the output tensor (`hidden_states`)
97 /// survives — all layer-internal tensors (qkv, scores, context, etc.)
98 /// must be dropped before this call so their pool slots can be recycled.
99 fn restore_pool_cursor(&self, _saved: usize) {}
100
101 // --- Allocation ---
102
103 /// Allocate a zero-initialized tensor with `n` float elements on device.
104 ///
105 /// Used by architectures to create workspace buffers (QKV projections,
106 /// attention scores, intermediate activations, etc.).
107 ///
108 /// # Errors
109 ///
110 /// Returns an error if device memory allocation fails.
111 fn alloc_zeros(&self, n: usize) -> crate::Result<Self::Tensor>;
112
113 /// Clone a tensor, producing an independent copy of the data.
114 ///
115 /// Used when an operation needs both the original and a mutable output
116 /// referencing the same logical data (e.g., in-place layer normalization
117 /// where input == output).
118 ///
119 /// # Errors
120 ///
121 /// Returns an error if device memory allocation or the copy fails.
122 fn clone_tensor(&self, tensor: &Self::Tensor, n: usize) -> crate::Result<Self::Tensor>;
123
124 // --- Batch preparation ---
125
126 /// Prepare a batch of encodings for inference, returning input tensors on device.
127 ///
128 /// Pads all sequences to `max_seq` and uploads `input_ids`, `attention_mask`,
129 /// `token_type_ids`, `position_ids`, and a float attention mask to device memory.
130 fn prepare_batch(
131 &self,
132 encodings: &[Encoding],
133 max_seq: usize,
134 ) -> crate::Result<BatchInputs<Self::Tensor>>;
135
136 /// Prepare a batch WITHOUT padding — concatenate all tokens flat.
137 ///
138 /// Returns `BatchInputs` with `total_tokens` actual tokens (no padding),
139 /// `cu_seqlens` for attention boundaries, and per-token position IDs.
140 /// Linear layers (GEMM, LN, GELU) process `total_tokens` rows.
141 /// Attention must pad/unpad around the per-head operations.
142 fn prepare_batch_unpadded(
143 &self,
144 encodings: &[Encoding],
145 ) -> crate::Result<BatchInputs<Self::Tensor>> {
146 // Default: fall back to padded (backends override for unpadded support)
147 let max_seq = encodings
148 .iter()
149 .map(|e| e.input_ids.len())
150 .max()
151 .unwrap_or(0)
152 .next_multiple_of(8);
153 self.prepare_batch(encodings, max_seq)
154 }
155
156 /// Scatter flat `[total_tokens, dim]` tensor into padded `[batch, max_seq, dim]`.
157 ///
158 /// Used before attention: linear layers produce unpadded output, but the
159 /// QKV split + batched attention GEMM need aligned `[batch*heads, seq, head_dim]`.
160 /// Padding positions are zeroed.
161 fn pad_to_batch(
162 &self,
163 flat: &Self::Tensor,
164 padded: &mut Self::Tensor,
165 seq_lengths: &[usize],
166 max_seq: usize,
167 dim: usize,
168 ) -> crate::Result<()>;
169
170 /// Gather padded `[batch, max_seq, dim]` back to flat `[total_tokens, dim]`.
171 ///
172 /// Used after attention: extracts only the real tokens, discarding padding.
173 fn unpad_from_batch(
174 &self,
175 padded: &Self::Tensor,
176 flat: &mut Self::Tensor,
177 seq_lengths: &[usize],
178 max_seq: usize,
179 dim: usize,
180 ) -> crate::Result<()>;
181
182 // --- Embedding operations ---
183
184 /// Word/position/token-type embedding lookup via gather.
185 ///
186 /// Reads `seq_len` token IDs from `word_ids`, gathers rows from
187 /// `embedding_table`, and writes `[seq_len, hidden]` floats to the result.
188 fn embedding_lookup(
189 &self,
190 word_ids: &Self::Tensor,
191 embedding_table: &Self::Tensor,
192 seq_len: usize,
193 hidden: usize,
194 ) -> crate::Result<Self::Tensor>;
195
196 /// Element-wise add an embedding table lookup into `hidden`.
197 ///
198 /// Used for position and token-type embeddings:
199 /// `hidden[i] += table[ids[i]]` for each token position.
200 fn add_embeddings(
201 &self,
202 hidden: &mut Self::Tensor,
203 table: &Self::Tensor,
204 ids: &Self::Tensor,
205 seq_len: usize,
206 hidden_dim: usize,
207 ) -> crate::Result<()>;
208
209 // --- Normalization ---
210
211 /// Layer normalization: `output = (input - mean) / sqrt(var + eps) * weight + bias`.
212 fn layer_norm(
213 &self,
214 output: &mut Self::Tensor,
215 input: &Self::Tensor,
216 weight: &Self::Tensor,
217 bias: &Self::Tensor,
218 rows: usize,
219 cols: usize,
220 eps: f32,
221 ) -> crate::Result<()>;
222
223 // --- Linear algebra ---
224
225 /// General matrix multiply: `output = A * B` (or `A * B^T` if `transpose_b`).
226 ///
227 /// Dimensions: A is `[m, k]`, B is `[k, n]` (or `[n, k]` if transposed),
228 /// output is `[m, n]`.
229 fn gemm(
230 &self,
231 a: &Self::Tensor,
232 b: &Self::Tensor,
233 output: &mut Self::Tensor,
234 m: usize,
235 n: usize,
236 k: usize,
237 transpose_b: bool,
238 ) -> crate::Result<()>;
239
240 /// Batched GEMM for multi-head attention.
241 ///
242 /// Performs `batch_count` independent GEMMs with strided access into
243 /// contiguous buffers. Used for per-head Q*K^T and attn*V.
244 fn gemm_batched(
245 &self,
246 a: &Self::Tensor,
247 b: &Self::Tensor,
248 output: &mut Self::Tensor,
249 m: usize,
250 n: usize,
251 k: usize,
252 transpose_b: bool,
253 stride_a: usize,
254 stride_b: usize,
255 stride_c: usize,
256 batch_count: usize,
257 ) -> crate::Result<()>;
258
259 // --- Attention ---
260
261 /// Fused scale + mask + softmax for attention scores.
262 ///
263 /// `scores = softmax(scores * scale + mask)` computed per-head.
264 fn fused_scale_mask_softmax(
265 &self,
266 scores: &mut Self::Tensor,
267 mask: &Self::Tensor,
268 batch: usize,
269 num_heads: usize,
270 seq_len: usize,
271 scale: f32,
272 ) -> crate::Result<()>;
273
274 /// Fused scale + mask + sliding window + softmax for attention scores.
275 ///
276 /// Like [`fused_scale_mask_softmax`](Driver::fused_scale_mask_softmax) but
277 /// additionally masks out positions where `|query_pos - key_pos| > window_size / 2`.
278 /// Used by `ModernBERT`'s local attention layers.
279 fn fused_scale_mask_softmax_windowed(
280 &self,
281 scores: &mut Self::Tensor,
282 mask: &Self::Tensor,
283 batch: usize,
284 num_heads: usize,
285 seq_len: usize,
286 scale: f32,
287 window_size: usize,
288 ) -> crate::Result<()>;
289
290 /// Build a float attention mask from an integer mask.
291 ///
292 /// Converts `[batch * seq]` int mask (0/1) to `[batch * seq]` float mask
293 /// (0.0 / -10000.0) for use with [`fused_scale_mask_softmax`](Driver::fused_scale_mask_softmax).
294 fn build_attn_mask(
295 &self,
296 output: &mut Self::Tensor,
297 int_mask: &Self::Tensor,
298 n: usize,
299 ) -> crate::Result<()>;
300
301 /// Split a fused QKV projection into separate Q, K, V tensors.
302 fn qkv_split(
303 &self,
304 q: &mut Self::Tensor,
305 k: &mut Self::Tensor,
306 v: &mut Self::Tensor,
307 qkv: &Self::Tensor,
308 batch: usize,
309 seq: usize,
310 hidden: usize,
311 num_heads: usize,
312 head_dim: usize,
313 ) -> crate::Result<()>;
314
315 // --- Banded (local/sliding-window) attention ---
316
317 /// Banded Q@K^T: compute attention scores only within a sliding window.
318 ///
319 /// Output shape: `[batch * num_heads, seq, window]` (NOT `[seq, seq]`).
320 /// `scores[h, i, w]` = dot(Q[h, i, :], K[h, i - window/2 + w, :])
321 /// where out-of-bounds positions are set to `-inf` (masked in softmax).
322 ///
323 /// Reduces attention compute from O(seq²) to O(seq × window).
324 /// For `seq=512, window=128`: **4× less compute** per local layer.
325 fn banded_qk(
326 &self,
327 q: &Self::Tensor,
328 k: &Self::Tensor,
329 scores: &mut Self::Tensor,
330 batch_heads: usize,
331 seq: usize,
332 head_dim: usize,
333 window: usize,
334 stride_qk: usize,
335 stride_scores: usize,
336 ) -> crate::Result<()>;
337
338 /// Banded scores@V: weighted sum using banded attention scores.
339 ///
340 /// Input scores: `[batch * num_heads, seq, window]` (from `banded_qk`).
341 /// Output: `[batch * num_heads, seq, head_dim]`.
342 /// `output[h, i, d]` = sum_w scores[h, i, w] * V[h, i - window/2 + w, d]
343 fn banded_sv(
344 &self,
345 scores: &Self::Tensor,
346 v: &Self::Tensor,
347 output: &mut Self::Tensor,
348 batch_heads: usize,
349 seq: usize,
350 head_dim: usize,
351 window: usize,
352 stride_scores: usize,
353 stride_v: usize,
354 stride_out: usize,
355 ) -> crate::Result<()>;
356
357 /// Fused scale + softmax over the window dimension (no padding mask needed).
358 ///
359 /// Operates on `[batch * num_heads * seq, window]` rows.
360 fn banded_softmax(
361 &self,
362 scores: &mut Self::Tensor,
363 total_rows: usize,
364 window: usize,
365 scale: f32,
366 ) -> crate::Result<()>;
367
368 /// Reshape attention output from `[batch, num_heads, seq, head_dim]` to
369 /// `[batch * seq, hidden]`.
370 fn attn_reshape(
371 &self,
372 output: &mut Self::Tensor,
373 input: &Self::Tensor,
374 batch: usize,
375 seq: usize,
376 num_heads: usize,
377 head_dim: usize,
378 ) -> crate::Result<()>;
379
380 /// Apply Rotary Position Embedding (RoPE) to Q/K tensors.
381 ///
382 /// Used by ModernBERT (not ClassicBert which uses learned position embeddings).
383 fn apply_rope(
384 &self,
385 qk: &mut Self::Tensor,
386 cos: &Self::Tensor,
387 sin: &Self::Tensor,
388 num_rows: usize,
389 seq_len: usize,
390 head_dim: usize,
391 num_heads: usize,
392 ) -> crate::Result<()>;
393
394 // --- Tensor manipulation ---
395
396 /// Split a `[rows, 2*cols]` matrix into two `[rows, cols]` halves.
397 ///
398 /// Each row of `input` is `[first_half | second_half]`. The first `cols`
399 /// elements go to `first`, the remaining `cols` to `second`.
400 /// Used by `ModernBERT` for gated MLP splits.
401 fn split_gate_value(
402 &self,
403 first: &mut Self::Tensor,
404 second: &mut Self::Tensor,
405 input: &Self::Tensor,
406 rows: usize,
407 cols: usize,
408 ) -> crate::Result<()>;
409
410 // --- Activations ---
411
412 /// GELU activation (Gaussian Error Linear Unit), applied in-place.
413 fn gelu(&self, x: &mut Self::Tensor, n: usize) -> crate::Result<()>;
414
415 /// SwiGLU gated activation: `output = value * silu(gate)`.
416 ///
417 /// The gate and value come from splitting the intermediate projection.
418 fn swiglu(
419 &self,
420 value: &Self::Tensor,
421 gate: &Self::Tensor,
422 output: &mut Self::Tensor,
423 n: usize,
424 ) -> crate::Result<()>;
425
426 /// `GeGLU` gated activation: `output = gelu(value) * gate`.
427 ///
428 /// Used by `ModernBERT`. The value and gate come from splitting the
429 /// MLP `Wi` projection output in half.
430 fn geglu(
431 &self,
432 value: &Self::Tensor,
433 gate: &Self::Tensor,
434 output: &mut Self::Tensor,
435 n: usize,
436 ) -> crate::Result<()>;
437
438 /// Fused bias + GELU: `x = gelu(x + bias)` row-wise.
439 fn fused_bias_gelu(
440 &self,
441 x: &mut Self::Tensor,
442 bias: &Self::Tensor,
443 rows: usize,
444 cols: usize,
445 ) -> crate::Result<()>;
446
447 // --- Fused residual operations ---
448
449 /// Fused bias + residual add: `output = input + bias + residual`.
450 ///
451 /// Bias is broadcast row-wise (`cols`-wide) across `n / cols` rows.
452 fn fused_bias_residual(
453 &self,
454 output: &mut Self::Tensor,
455 input: &Self::Tensor,
456 bias: &Self::Tensor,
457 residual: &Self::Tensor,
458 n: usize,
459 cols: usize,
460 ) -> crate::Result<()>;
461
462 /// Fused residual add + layer normalization.
463 ///
464 /// `output = layer_norm(hidden + residual, weight, bias, eps)`.
465 fn fused_residual_layernorm(
466 &self,
467 output: &mut Self::Tensor,
468 hidden: &Self::Tensor,
469 residual: &Self::Tensor,
470 weight: &Self::Tensor,
471 bias: &Self::Tensor,
472 rows: usize,
473 cols: usize,
474 eps: f32,
475 ) -> crate::Result<()>;
476
477 /// Residual add without bias: `output = hidden + residual`.
478 ///
479 /// Used by `ModernBERT` which has no bias terms.
480 fn residual_add(
481 &self,
482 output: &mut Self::Tensor,
483 hidden: &Self::Tensor,
484 residual: &Self::Tensor,
485 n: usize,
486 ) -> crate::Result<()>;
487
488 /// Add bias to a matrix row-wise: `x[row] += bias` for each row.
489 fn add_bias(
490 &self,
491 x: &mut Self::Tensor,
492 bias: &Self::Tensor,
493 rows: usize,
494 cols: usize,
495 ) -> crate::Result<()>;
496
497 // --- Pooling ---
498
499 /// CLS pooling: extract the first token's hidden state per batch element.
500 fn cls_pool(
501 &self,
502 output: &mut Self::Tensor,
503 hidden: &Self::Tensor,
504 batch: usize,
505 seq: usize,
506 hidden_dim: usize,
507 ) -> crate::Result<()>;
508
509 /// Mean pooling: attention-mask-weighted average of hidden states.
510 fn mean_pool(
511 &self,
512 output: &mut Self::Tensor,
513 hidden: &Self::Tensor,
514 mask: &Self::Tensor,
515 batch: usize,
516 seq: usize,
517 hidden_dim: usize,
518 ) -> crate::Result<()>;
519
520 // --- Post-processing ---
521
522 /// L2-normalize each row vector in-place.
523 fn l2_normalize(&self, data: &mut Self::Tensor, rows: usize, cols: usize) -> crate::Result<()>;
524
525 /// Copy tensor data back to host memory as `Vec<Vec<f32>>`.
526 ///
527 /// Returns one `Vec<f32>` of length `dim` per batch element.
528 fn to_host(
529 &self,
530 tensor: &Self::Tensor,
531 batch: usize,
532 dim: usize,
533 ) -> crate::Result<Vec<Vec<f32>>>;
534
535 // =======================================================================
536 // FP16 operations for full half-precision pipeline
537 //
538 // These methods mirror the FP32 counterparts but operate on FP16 tensors.
539 // Internal reductions (softmax, layer-norm) use FP32 accumulators but
540 // all tensor I/O is half precision. Default implementations return an
541 // error — only backends with FP16 support override them.
542 // =======================================================================
543
544 /// Allocate a zero-initialized FP16 tensor with `n` half-precision elements.
545 ///
546 /// # Errors
547 ///
548 /// Returns an error if device memory allocation fails or FP16 is unsupported.
549 fn alloc_zeros_f16(&self, _n: usize) -> crate::Result<Self::Tensor> {
550 Err(crate::Error::Metal(
551 "FP16 not supported by this driver".into(),
552 ))
553 }
554
555 /// Convert FP32 tensor to FP16 (element-wise narrowing).
556 fn f32_to_f16(
557 &self,
558 _output: &mut Self::Tensor,
559 _input: &Self::Tensor,
560 _n: usize,
561 ) -> crate::Result<()> {
562 Err(crate::Error::Metal(
563 "FP16 not supported by this driver".into(),
564 ))
565 }
566
567 /// Convert FP16 tensor back to FP32 (element-wise widening).
568 fn f16_to_f32(
569 &self,
570 _output: &mut Self::Tensor,
571 _input: &Self::Tensor,
572 _n: usize,
573 ) -> crate::Result<()> {
574 Err(crate::Error::Metal(
575 "FP16 not supported by this driver".into(),
576 ))
577 }
578
579 /// Mixed-precision GEMM: FP16 inputs → FP32 output via native simdgroup ops.
580 fn gemm_mixed(
581 &self,
582 _a_f16: &Self::Tensor,
583 _b_f16: &Self::Tensor,
584 _output_f32: &mut Self::Tensor,
585 _m: usize,
586 _n: usize,
587 _k: usize,
588 _transpose_b: bool,
589 ) -> crate::Result<()> {
590 Err(crate::Error::Metal(
591 "gemm_mixed not supported by this driver".into(),
592 ))
593 }
594
595 /// FP16 GEMM: `output = A * B` (or `A * B^T`). All tensors are half.
596 fn gemm_f16(
597 &self,
598 _a: &Self::Tensor,
599 _b: &Self::Tensor,
600 _output: &mut Self::Tensor,
601 _m: usize,
602 _n: usize,
603 _k: usize,
604 _transpose_b: bool,
605 ) -> crate::Result<()> {
606 Err(crate::Error::Metal(
607 "FP16 not supported by this driver".into(),
608 ))
609 }
610
611 /// FP16 batched GEMM for multi-head attention. All tensors are half.
612 #[expect(
613 clippy::too_many_arguments,
614 reason = "matches FP32 gemm_batched signature"
615 )]
616 fn gemm_batched_f16(
617 &self,
618 _a: &Self::Tensor,
619 _b: &Self::Tensor,
620 _output: &mut Self::Tensor,
621 _m: usize,
622 _n: usize,
623 _k: usize,
624 _transpose_b: bool,
625 _stride_a: usize,
626 _stride_b: usize,
627 _stride_c: usize,
628 _batch_count: usize,
629 ) -> crate::Result<()> {
630 Err(crate::Error::Metal(
631 "FP16 not supported by this driver".into(),
632 ))
633 }
634
635 /// FP16 layer normalization. Half I/O, FP32 reductions.
636 fn layer_norm_f16(
637 &self,
638 _output: &mut Self::Tensor,
639 _input: &Self::Tensor,
640 _weight: &Self::Tensor,
641 _bias: &Self::Tensor,
642 _rows: usize,
643 _cols: usize,
644 _eps: f32,
645 ) -> crate::Result<()> {
646 Err(crate::Error::Metal(
647 "FP16 not supported by this driver".into(),
648 ))
649 }
650
651 /// FP16 fused scale + mask + softmax. Half scores, FP32 reductions.
652 fn fused_scale_mask_softmax_f16(
653 &self,
654 _scores: &mut Self::Tensor,
655 _mask: &Self::Tensor,
656 _batch: usize,
657 _num_heads: usize,
658 _seq_len: usize,
659 _scale: f32,
660 ) -> crate::Result<()> {
661 Err(crate::Error::Metal(
662 "FP16 not supported by this driver".into(),
663 ))
664 }
665
666 /// FP16 fused scale + mask + sliding window + softmax.
667 fn fused_scale_mask_softmax_windowed_f16(
668 &self,
669 _scores: &mut Self::Tensor,
670 _mask: &Self::Tensor,
671 _batch: usize,
672 _num_heads: usize,
673 _seq_len: usize,
674 _scale: f32,
675 _window_size: usize,
676 ) -> crate::Result<()> {
677 Err(crate::Error::Metal(
678 "FP16 not supported by this driver".into(),
679 ))
680 }
681
682 /// FP16 QKV split: `[batch*seq, 3*hidden]` into Q, K, V per-head layout.
683 fn qkv_split_f16(
684 &self,
685 _q: &mut Self::Tensor,
686 _k: &mut Self::Tensor,
687 _v: &mut Self::Tensor,
688 _qkv: &Self::Tensor,
689 _batch: usize,
690 _seq: usize,
691 _hidden: usize,
692 _num_heads: usize,
693 _head_dim: usize,
694 ) -> crate::Result<()> {
695 Err(crate::Error::Metal(
696 "FP16 not supported by this driver".into(),
697 ))
698 }
699
700 /// FP16 attention output reshape: `[batch*num_heads, seq, head_dim]` to
701 /// `[batch*seq, hidden]`.
702 fn attn_reshape_f16(
703 &self,
704 _output: &mut Self::Tensor,
705 _input: &Self::Tensor,
706 _batch: usize,
707 _seq: usize,
708 _num_heads: usize,
709 _head_dim: usize,
710 ) -> crate::Result<()> {
711 Err(crate::Error::Metal(
712 "FP16 not supported by this driver".into(),
713 ))
714 }
715
716 /// FP16 scatter flat `[total_tokens, dim]` to padded `[batch, max_seq, dim]`.
717 fn pad_to_batch_f16(
718 &self,
719 _flat: &Self::Tensor,
720 _padded: &mut Self::Tensor,
721 _seq_lengths: &[usize],
722 _max_seq: usize,
723 _dim: usize,
724 ) -> crate::Result<()> {
725 Err(crate::Error::Metal(
726 "FP16 not supported by this driver".into(),
727 ))
728 }
729
730 /// FP16 gather padded `[batch, max_seq, dim]` back to flat `[total_tokens, dim]`.
731 fn unpad_from_batch_f16(
732 &self,
733 _padded: &Self::Tensor,
734 _flat: &mut Self::Tensor,
735 _seq_lengths: &[usize],
736 _max_seq: usize,
737 _dim: usize,
738 ) -> crate::Result<()> {
739 Err(crate::Error::Metal(
740 "FP16 not supported by this driver".into(),
741 ))
742 }
743
744 /// FP16 RoPE: apply rotary position embedding. Half Q/K, float cos/sin tables.
745 fn rope_encode_f16(
746 &self,
747 _qk: &mut Self::Tensor,
748 _cos: &Self::Tensor,
749 _sin: &Self::Tensor,
750 _num_rows: usize,
751 _seq_len: usize,
752 _head_dim: usize,
753 _num_heads: usize,
754 ) -> crate::Result<()> {
755 Err(crate::Error::Metal(
756 "FP16 not supported by this driver".into(),
757 ))
758 }
759
760 /// FP16 `GeGLU` gated activation: `output = gelu(value) * gate`. Half I/O.
761 fn geglu_f16(
762 &self,
763 _value: &Self::Tensor,
764 _gate: &Self::Tensor,
765 _output: &mut Self::Tensor,
766 _n: usize,
767 ) -> crate::Result<()> {
768 Err(crate::Error::Metal(
769 "FP16 not supported by this driver".into(),
770 ))
771 }
772
773 /// FP16 fused residual add + layer normalization.
774 fn fused_residual_layernorm_f16(
775 &self,
776 _output: &mut Self::Tensor,
777 _hidden: &Self::Tensor,
778 _residual: &Self::Tensor,
779 _weight: &Self::Tensor,
780 _bias: &Self::Tensor,
781 _rows: usize,
782 _cols: usize,
783 _eps: f32,
784 ) -> crate::Result<()> {
785 Err(crate::Error::Metal(
786 "FP16 not supported by this driver".into(),
787 ))
788 }
789
790 /// FP16 residual add (no bias): `output = hidden + residual`.
791 fn residual_add_f16(
792 &self,
793 _output: &mut Self::Tensor,
794 _hidden: &Self::Tensor,
795 _residual: &Self::Tensor,
796 _n: usize,
797 ) -> crate::Result<()> {
798 Err(crate::Error::Metal(
799 "FP16 not supported by this driver".into(),
800 ))
801 }
802
803 /// FP16 split `[rows, 2*cols]` into two `[rows, cols]` halves.
804 fn split_gate_value_f16(
805 &self,
806 _first: &mut Self::Tensor,
807 _second: &mut Self::Tensor,
808 _input: &Self::Tensor,
809 _rows: usize,
810 _cols: usize,
811 ) -> crate::Result<()> {
812 Err(crate::Error::Metal(
813 "FP16 not supported by this driver".into(),
814 ))
815 }
816
817 /// Fused split + `GeGLU`: read `[rows, 2*cols]`, write `[rows, cols]`.
818 ///
819 /// Combines [`split_gate_value_f16`](Driver::split_gate_value_f16) and
820 /// [`geglu_f16`](Driver::geglu_f16) into a single kernel, eliminating
821 /// two intermediate `[rows, cols]` buffers and halving HBM round-trips.
822 ///
823 /// Default falls back to separate split + geglu calls.
824 fn fused_split_geglu_f16(
825 &self,
826 output: &mut Self::Tensor,
827 input: &Self::Tensor,
828 rows: usize,
829 cols: usize,
830 ) -> crate::Result<()> {
831 // Default: allocate intermediates and call separately.
832 let n = rows * cols;
833 let mut value = self.alloc_zeros_f16(n)?;
834 let mut gate = self.alloc_zeros_f16(n)?;
835 self.split_gate_value_f16(&mut value, &mut gate, input, rows, cols)?;
836 self.geglu_f16(&value, &gate, output, n)
837 }
838
839 /// Fused pad + QKV split: flat `[total_tokens, 3*hidden]` → Q, K, V
840 /// each `[batch*heads, max_seq, head_dim]`.
841 ///
842 /// Eliminates the padded intermediate buffer. Default calls pad then split.
843 #[expect(clippy::too_many_arguments, reason = "mirrors pad + qkv_split args")]
844 fn fused_pad_qkv_split_f16(
845 &self,
846 q: &mut Self::Tensor,
847 k: &mut Self::Tensor,
848 v: &mut Self::Tensor,
849 qkv_flat: &Self::Tensor,
850 seq_lengths: &[usize],
851 max_seq: usize,
852 batch: usize,
853 hidden: usize,
854 num_heads: usize,
855 head_dim: usize,
856 ) -> crate::Result<()> {
857 // Default: pad then split.
858 let padded_tokens = batch * max_seq;
859 let mut qkv_padded = self.alloc_zeros_f16(padded_tokens * 3 * hidden)?;
860 self.pad_to_batch_f16(qkv_flat, &mut qkv_padded, seq_lengths, max_seq, 3 * hidden)?;
861 self.qkv_split_f16(
862 q,
863 k,
864 v,
865 &qkv_padded,
866 batch,
867 max_seq,
868 hidden,
869 num_heads,
870 head_dim,
871 )
872 }
873
874 /// Fused attn_reshape + unpad: `[batch*heads, max_seq, head_dim]` →
875 /// `[total_tokens, hidden]`.
876 ///
877 /// Eliminates the padded context intermediate. Default calls reshape then unpad.
878 fn fused_reshape_unpad_f16(
879 &self,
880 flat: &mut Self::Tensor,
881 heads: &Self::Tensor,
882 seq_lengths: &[usize],
883 max_seq: usize,
884 batch: usize,
885 num_heads: usize,
886 head_dim: usize,
887 ) -> crate::Result<()> {
888 // Default: reshape then unpad.
889 let hidden = num_heads * head_dim;
890 let padded_tokens = batch * max_seq;
891 let mut context = self.alloc_zeros_f16(padded_tokens * hidden)?;
892 self.attn_reshape_f16(&mut context, heads, batch, max_seq, num_heads, head_dim)?;
893 self.unpad_from_batch_f16(&context, flat, seq_lengths, max_seq, hidden)
894 }
895}
896
/// Device-resident input tensors for one batch, as returned by
/// [`Driver::prepare_batch`].
///
/// Two layouts are supported:
/// - **Padded**: every sequence is padded out to `max_seq`, and `cu_seqlens`
///   is `None`.
/// - **Unpadded**: sequences are packed back-to-back with no padding, and
///   `cu_seqlens` holds the running offsets `[0, len0, len0+len1, ...]` that
///   tell attention where each sequence begins. No compute is spent on padding.
pub struct BatchInputs<T> {
    /// Token IDs, laid out as `[batch * max_seq]` when padded or
    /// `[total_tokens]` when unpadded.
    pub input_ids: T,
    /// Int32 attention mask of shape `[batch * max_seq]`, each entry 0 or 1.
    /// Not used in unpadded mode.
    pub attention_mask: T,
    /// Token type IDs; same layout as `input_ids`.
    pub token_type_ids: T,
    /// Position IDs; same layout as `input_ids`.
    pub position_ids: T,
    /// Additive float attention bias of shape `[batch * max_seq]` for softmax;
    /// each entry is 0.0 or -1e9.
    pub float_mask: T,
    /// Float pooling mask of shape `[batch * max_seq]` used by mean pooling;
    /// each entry is 1.0 or 0.0.
    pub pooling_mask: T,
    /// How many sequences the batch contains.
    pub batch: usize,
    /// Longest sequence length in the batch. Padded mode pads every sequence
    /// to this length; unpadded mode uses it only to size workspaces.
    pub max_seq: usize,
    /// Count of real tokens summed over all sequences, excluding padding.
    pub total_tokens: usize,
    /// Real token count of each sequence, one entry per sequence (`[batch]`).
    pub seq_lengths: Vec<usize>,
    /// Cumulative sequence lengths (`[batch + 1]`) used by unpadded attention:
    /// sequence `i` occupies tokens `cu_seqlens[i]..cu_seqlens[i + 1]`.
    /// Always `None` in padded mode.
    pub cu_seqlens: Option<Vec<usize>>,
}