ripvec_core/backend/driver/mod.rs
1//! Hardware-agnostic compute driver trait.
2//!
3//! The [`Driver`] trait exposes low-level compute primitives (GEMM, layer-norm,
4//! activations, etc.) that each hardware backend implements. Model architectures
5//! are generic over `D: Driver` and compose these primitives into a forward pass.
6//!
7//! # Design
8//!
9//! - **Associated type `Tensor`**: each driver defines its own opaque tensor
10//! handle (Metal: buffer+offset, CUDA: device pointer, CPU: ndarray).
11//! - **Not object-safe**: architectures use `D: Driver` generics so the compiler
12//! can monomorphize and inline driver calls.
13//! - **Send + Sync**: drivers are shared across the pipeline.
14
15#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
16pub mod cpu;
17#[cfg(feature = "cuda")]
18pub mod cuda;
19#[cfg(feature = "metal")]
20pub mod metal;
21#[cfg(feature = "mlx")]
22pub mod mlx;
23
24use super::Encoding;
25
26/// Hardware-agnostic compute primitives for BERT inference.
27///
28/// Each method corresponds to one operation in the forward pass. Drivers handle
29/// memory allocation, kernel dispatch, and synchronization. Architectures
30/// compose these primitives via the [`super::arch::ModelArch`] trait.
31pub trait Driver: Send + Sync {
32 /// Opaque tensor handle.
33 ///
34 /// Metal: `MTLBuffer` + byte offset. CUDA: `CUdeviceptr`. CPU: `Array2<f32>`.
35 type Tensor;
36
37 /// Short human-readable label for diagnostics (e.g. "Metal", "CUDA", "CPU").
38 /// Surfaced via [`super::EmbedBackend::name`].
39 fn name(&self) -> &'static str;
40
41 /// Create a new driver instance for a cloned worker thread.
42 ///
43 /// CPU drivers are zero-size and always succeed. GPU drivers typically
44 /// cannot be cloned this way (they share device state) and should leave
45 /// the default panic implementation.
46 fn new_for_clone() -> crate::Result<Self>
47 where
48 Self: Sized,
49 {
50 Err(crate::Error::Other(anyhow::anyhow!(
51 "this driver does not support cloning"
52 )))
53 }
54
55 // --- Batching ---
56
57 /// Begin batched mode: all subsequent operations encode into one dispatch.
58 ///
59 /// GPU drivers accumulate into a single command buffer; CPU is a no-op.
60 /// Call [`Self::end_batch`] to commit. This eliminates per-call overhead.
61 fn begin_batch(&self) -> crate::Result<()> {
62 Ok(())
63 }
64
65 /// End batched mode: commit all accumulated operations and wait.
66 fn end_batch(&self) -> crate::Result<()> {
67 Ok(())
68 }
69
70 /// Flush the current command buffer and start a new one, preserving pool
71 /// state. Use mid-forward-pass to prevent GPU timeouts on deep models.
72 fn flush_batch(&self) -> crate::Result<()> {
73 Ok(())
74 }
75
76 /// Close and reopen the compute encoder within the same command buffer.
77 ///
78 /// This segments a long sequence of compute dispatches into multiple
79 /// encoders without committing or waiting. Metal processes encoders
80 /// back-to-back from the same CB — zero sync overhead.
81 ///
82 /// Use every few layers to prevent encoder state overflow (>~60 dispatches
83 /// per encoder can cause hangs on some Apple Silicon GPUs).
84 fn segment_encoder(&self) {
85 // No-op for non-Metal backends
86 }
87
88 /// Save the current pool cursor position. Call BEFORE a layer's work.
89 fn save_pool_cursor(&self) -> usize {
90 0
91 }
92
93 /// Restore the pool cursor to a previously saved position. Call AFTER
94 /// a layer's transient tensors have been dropped (out of scope).
95 ///
96 /// The architecture must ensure only the output tensor (`hidden_states`)
97 /// survives — all layer-internal tensors (qkv, scores, context, etc.)
98 /// must be dropped before this call so their pool slots can be recycled.
99 fn restore_pool_cursor(&self, _saved: usize) {}
100
101 // --- Allocation ---
102
103 /// Allocate a zero-initialized tensor with `n` float elements on device.
104 ///
105 /// Used by architectures to create workspace buffers (QKV projections,
106 /// attention scores, intermediate activations, etc.).
107 ///
108 /// # Errors
109 ///
110 /// Returns an error if device memory allocation fails.
111 fn alloc_zeros(&self, n: usize) -> crate::Result<Self::Tensor>;
112
113 /// Clone a tensor, producing an independent copy of the data.
114 ///
115 /// Used when an operation needs both the original and a mutable output
116 /// referencing the same logical data (e.g., in-place layer normalization
117 /// where input == output).
118 ///
119 /// # Errors
120 ///
121 /// Returns an error if device memory allocation or the copy fails.
122 fn clone_tensor(&self, tensor: &Self::Tensor, n: usize) -> crate::Result<Self::Tensor>;
123
124 // --- Batch preparation ---
125
126 /// Prepare a batch of encodings for inference, returning input tensors on device.
127 ///
128 /// Pads all sequences to `max_seq` and uploads `input_ids`, `attention_mask`,
129 /// `token_type_ids`, `position_ids`, and a float attention mask to device memory.
130 fn prepare_batch(
131 &self,
132 encodings: &[Encoding],
133 max_seq: usize,
134 ) -> crate::Result<BatchInputs<Self::Tensor>>;
135
136 /// Prepare a batch WITHOUT padding — concatenate all tokens flat.
137 ///
138 /// Returns `BatchInputs` with `total_tokens` actual tokens (no padding),
139 /// `cu_seqlens` for attention boundaries, and per-token position IDs.
140 /// Linear layers (GEMM, LN, GELU) process `total_tokens` rows.
141 /// Attention must pad/unpad around the per-head operations.
142 fn prepare_batch_unpadded(
143 &self,
144 encodings: &[Encoding],
145 ) -> crate::Result<BatchInputs<Self::Tensor>> {
146 // Default: fall back to padded (backends override for unpadded support)
147 let max_seq = encodings
148 .iter()
149 .map(|e| e.input_ids.len())
150 .max()
151 .unwrap_or(0)
152 .next_multiple_of(8);
153 self.prepare_batch(encodings, max_seq)
154 }
155
156 /// Scatter flat `[total_tokens, dim]` tensor into padded `[batch, max_seq, dim]`.
157 ///
158 /// Used before attention: linear layers produce unpadded output, but the
159 /// QKV split + batched attention GEMM need aligned `[batch*heads, seq, head_dim]`.
160 /// Padding positions are zeroed.
161 fn pad_to_batch(
162 &self,
163 flat: &Self::Tensor,
164 padded: &mut Self::Tensor,
165 seq_lengths: &[usize],
166 max_seq: usize,
167 dim: usize,
168 ) -> crate::Result<()>;
169
170 /// Gather padded `[batch, max_seq, dim]` back to flat `[total_tokens, dim]`.
171 ///
172 /// Used after attention: extracts only the real tokens, discarding padding.
173 fn unpad_from_batch(
174 &self,
175 padded: &Self::Tensor,
176 flat: &mut Self::Tensor,
177 seq_lengths: &[usize],
178 max_seq: usize,
179 dim: usize,
180 ) -> crate::Result<()>;
181
182 // --- Embedding operations ---
183
184 /// Word/position/token-type embedding lookup via gather.
185 ///
186 /// Reads `seq_len` token IDs from `word_ids`, gathers rows from
187 /// `embedding_table`, and writes `[seq_len, hidden]` floats to the result.
188 fn embedding_lookup(
189 &self,
190 word_ids: &Self::Tensor,
191 embedding_table: &Self::Tensor,
192 seq_len: usize,
193 hidden: usize,
194 ) -> crate::Result<Self::Tensor>;
195
196 /// Element-wise add an embedding table lookup into `hidden`.
197 ///
198 /// Used for position and token-type embeddings:
199 /// `hidden[i] += table[ids[i]]` for each token position.
200 fn add_embeddings(
201 &self,
202 hidden: &mut Self::Tensor,
203 table: &Self::Tensor,
204 ids: &Self::Tensor,
205 seq_len: usize,
206 hidden_dim: usize,
207 ) -> crate::Result<()>;
208
209 // --- Normalization ---
210
211 /// Layer normalization: `output = (input - mean) / sqrt(var + eps) * weight + bias`.
212 fn layer_norm(
213 &self,
214 output: &mut Self::Tensor,
215 input: &Self::Tensor,
216 weight: &Self::Tensor,
217 bias: &Self::Tensor,
218 rows: usize,
219 cols: usize,
220 eps: f32,
221 ) -> crate::Result<()>;
222
223 // --- Linear algebra ---
224
225 /// General matrix multiply: `output = A * B` (or `A * B^T` if `transpose_b`).
226 ///
227 /// Dimensions: A is `[m, k]`, B is `[k, n]` (or `[n, k]` if transposed),
228 /// output is `[m, n]`.
229 fn gemm(
230 &self,
231 a: &Self::Tensor,
232 b: &Self::Tensor,
233 output: &mut Self::Tensor,
234 m: usize,
235 n: usize,
236 k: usize,
237 transpose_b: bool,
238 ) -> crate::Result<()>;
239
240 /// Batched GEMM for multi-head attention.
241 ///
242 /// Performs `batch_count` independent GEMMs with strided access into
243 /// contiguous buffers. Used for per-head Q*K^T and attn*V.
244 fn gemm_batched(
245 &self,
246 a: &Self::Tensor,
247 b: &Self::Tensor,
248 output: &mut Self::Tensor,
249 m: usize,
250 n: usize,
251 k: usize,
252 transpose_b: bool,
253 stride_a: usize,
254 stride_b: usize,
255 stride_c: usize,
256 batch_count: usize,
257 ) -> crate::Result<()>;
258
259 // --- Attention ---
260
261 /// Fused scale + mask + softmax for attention scores.
262 ///
263 /// `scores = softmax(scores * scale + mask)` computed per-head.
264 fn fused_scale_mask_softmax(
265 &self,
266 scores: &mut Self::Tensor,
267 mask: &Self::Tensor,
268 batch: usize,
269 num_heads: usize,
270 seq_len: usize,
271 scale: f32,
272 ) -> crate::Result<()>;
273
274 /// Fused scale + mask + sliding window + softmax for attention scores.
275 ///
276 /// Like [`fused_scale_mask_softmax`](Driver::fused_scale_mask_softmax) but
277 /// additionally masks out positions where `|query_pos - key_pos| > window_size / 2`.
278 /// Used by `ModernBERT`'s local attention layers.
279 fn fused_scale_mask_softmax_windowed(
280 &self,
281 scores: &mut Self::Tensor,
282 mask: &Self::Tensor,
283 batch: usize,
284 num_heads: usize,
285 seq_len: usize,
286 scale: f32,
287 window_size: usize,
288 ) -> crate::Result<()>;
289
290 /// Build a float attention mask from an integer mask.
291 ///
292 /// Converts `[batch * seq]` int mask (0/1) to `[batch * seq]` float mask
293 /// (0.0 / -10000.0) for use with [`fused_scale_mask_softmax`](Driver::fused_scale_mask_softmax).
294 fn build_attn_mask(
295 &self,
296 output: &mut Self::Tensor,
297 int_mask: &Self::Tensor,
298 n: usize,
299 ) -> crate::Result<()>;
300
301 /// Split a fused QKV projection into separate Q, K, V tensors.
302 fn qkv_split(
303 &self,
304 q: &mut Self::Tensor,
305 k: &mut Self::Tensor,
306 v: &mut Self::Tensor,
307 qkv: &Self::Tensor,
308 batch: usize,
309 seq: usize,
310 hidden: usize,
311 num_heads: usize,
312 head_dim: usize,
313 ) -> crate::Result<()>;
314
315 // --- Banded (local/sliding-window) attention ---
316
317 /// Banded Q@K^T: compute attention scores only within a sliding window.
318 ///
319 /// Output shape: `[batch * num_heads, seq, window]` (NOT `[seq, seq]`).
320 /// `scores[h, i, w]` = dot(Q[h, i, :], K[h, i - window/2 + w, :])
321 /// where out-of-bounds positions are set to `-inf` (masked in softmax).
322 ///
323 /// Reduces attention compute from O(seq²) to O(seq × window).
324 /// For `seq=512, window=128`: **4× less compute** per local layer.
325 fn banded_qk(
326 &self,
327 q: &Self::Tensor,
328 k: &Self::Tensor,
329 scores: &mut Self::Tensor,
330 batch_heads: usize,
331 seq: usize,
332 head_dim: usize,
333 window: usize,
334 stride_qk: usize,
335 stride_scores: usize,
336 ) -> crate::Result<()>;
337
338 /// Banded scores@V: weighted sum using banded attention scores.
339 ///
340 /// Input scores: `[batch * num_heads, seq, window]` (from `banded_qk`).
341 /// Output: `[batch * num_heads, seq, head_dim]`.
342 /// `output[h, i, d]` = sum_w scores[h, i, w] * V[h, i - window/2 + w, d]
343 fn banded_sv(
344 &self,
345 scores: &Self::Tensor,
346 v: &Self::Tensor,
347 output: &mut Self::Tensor,
348 batch_heads: usize,
349 seq: usize,
350 head_dim: usize,
351 window: usize,
352 stride_scores: usize,
353 stride_v: usize,
354 stride_out: usize,
355 ) -> crate::Result<()>;
356
357 /// Fused scale + softmax over the window dimension (no padding mask needed).
358 ///
359 /// Operates on `[batch * num_heads * seq, window]` rows.
360 fn banded_softmax(
361 &self,
362 scores: &mut Self::Tensor,
363 total_rows: usize,
364 window: usize,
365 scale: f32,
366 ) -> crate::Result<()>;
367
368 /// Reshape attention output from `[batch, num_heads, seq, head_dim]` to
369 /// `[batch * seq, hidden]`.
370 fn attn_reshape(
371 &self,
372 output: &mut Self::Tensor,
373 input: &Self::Tensor,
374 batch: usize,
375 seq: usize,
376 num_heads: usize,
377 head_dim: usize,
378 ) -> crate::Result<()>;
379
380 /// Apply Rotary Position Embedding (RoPE) to Q/K tensors.
381 ///
382 /// Used by ModernBERT (not ClassicBert which uses learned position embeddings).
383 fn apply_rope(
384 &self,
385 qk: &mut Self::Tensor,
386 cos: &Self::Tensor,
387 sin: &Self::Tensor,
388 num_rows: usize,
389 seq_len: usize,
390 head_dim: usize,
391 num_heads: usize,
392 ) -> crate::Result<()>;
393
394 // --- Tensor manipulation ---
395
396 /// Split a `[rows, 2*cols]` matrix into two `[rows, cols]` halves.
397 ///
398 /// Each row of `input` is `[first_half | second_half]`. The first `cols`
399 /// elements go to `first`, the remaining `cols` to `second`.
400 /// Used by `ModernBERT` for gated MLP splits.
401 fn split_gate_value(
402 &self,
403 first: &mut Self::Tensor,
404 second: &mut Self::Tensor,
405 input: &Self::Tensor,
406 rows: usize,
407 cols: usize,
408 ) -> crate::Result<()>;
409
410 // --- Activations ---
411
412 /// GELU activation (Gaussian Error Linear Unit), applied in-place.
413 fn gelu(&self, x: &mut Self::Tensor, n: usize) -> crate::Result<()>;
414
415 /// SwiGLU gated activation: `output = value * silu(gate)`.
416 ///
417 /// The gate and value come from splitting the intermediate projection.
418 fn swiglu(
419 &self,
420 value: &Self::Tensor,
421 gate: &Self::Tensor,
422 output: &mut Self::Tensor,
423 n: usize,
424 ) -> crate::Result<()>;
425
426 /// `GeGLU` gated activation: `output = gelu(value) * gate`.
427 ///
428 /// Used by `ModernBERT`. The value and gate come from splitting the
429 /// MLP `Wi` projection output in half.
430 fn geglu(
431 &self,
432 value: &Self::Tensor,
433 gate: &Self::Tensor,
434 output: &mut Self::Tensor,
435 n: usize,
436 ) -> crate::Result<()>;
437
438 /// Fused bias + GELU: `x = gelu(x + bias)` row-wise.
439 fn fused_bias_gelu(
440 &self,
441 x: &mut Self::Tensor,
442 bias: &Self::Tensor,
443 rows: usize,
444 cols: usize,
445 ) -> crate::Result<()>;
446
447 // --- Fused residual operations ---
448
449 /// Fused bias + residual add: `output = input + bias + residual`.
450 ///
451 /// Bias is broadcast row-wise (`cols`-wide) across `n / cols` rows.
452 fn fused_bias_residual(
453 &self,
454 output: &mut Self::Tensor,
455 input: &Self::Tensor,
456 bias: &Self::Tensor,
457 residual: &Self::Tensor,
458 n: usize,
459 cols: usize,
460 ) -> crate::Result<()>;
461
462 /// Fused residual add + layer normalization.
463 ///
464 /// `output = layer_norm(hidden + residual, weight, bias, eps)`.
465 fn fused_residual_layernorm(
466 &self,
467 output: &mut Self::Tensor,
468 hidden: &Self::Tensor,
469 residual: &Self::Tensor,
470 weight: &Self::Tensor,
471 bias: &Self::Tensor,
472 rows: usize,
473 cols: usize,
474 eps: f32,
475 ) -> crate::Result<()>;
476
477 /// Residual add without bias: `output = hidden + residual`.
478 ///
479 /// Used by `ModernBERT` which has no bias terms.
480 fn residual_add(
481 &self,
482 output: &mut Self::Tensor,
483 hidden: &Self::Tensor,
484 residual: &Self::Tensor,
485 n: usize,
486 ) -> crate::Result<()>;
487
488 /// Add bias to a matrix row-wise: `x[row] += bias` for each row.
489 fn add_bias(
490 &self,
491 x: &mut Self::Tensor,
492 bias: &Self::Tensor,
493 rows: usize,
494 cols: usize,
495 ) -> crate::Result<()>;
496
497 // --- Pooling ---
498
499 /// CLS pooling: extract the first token's hidden state per batch element.
500 fn cls_pool(
501 &self,
502 output: &mut Self::Tensor,
503 hidden: &Self::Tensor,
504 batch: usize,
505 seq: usize,
506 hidden_dim: usize,
507 ) -> crate::Result<()>;
508
509 /// Mean pooling: attention-mask-weighted average of hidden states.
510 fn mean_pool(
511 &self,
512 output: &mut Self::Tensor,
513 hidden: &Self::Tensor,
514 mask: &Self::Tensor,
515 batch: usize,
516 seq: usize,
517 hidden_dim: usize,
518 ) -> crate::Result<()>;
519
520 // --- Post-processing ---
521
522 /// L2-normalize each row vector in-place.
523 fn l2_normalize(&self, data: &mut Self::Tensor, rows: usize, cols: usize) -> crate::Result<()>;
524
525 /// Copy tensor data back to host memory as `Vec<Vec<f32>>`.
526 ///
527 /// Returns one `Vec<f32>` of length `dim` per batch element.
528 fn to_host(
529 &self,
530 tensor: &Self::Tensor,
531 batch: usize,
532 dim: usize,
533 ) -> crate::Result<Vec<Vec<f32>>>;
534
535 /// Optional finite-value diagnostic hook for backend tensors.
536 ///
537 /// Backends should keep this cheap or disabled by default. The CUDA driver
538 /// enables full tensor readback only with `RIPVEC_CUDA_DEBUG_TENSORS=1`.
539 fn debug_tensor(
540 &self,
541 _label: &str,
542 _tensor: &Self::Tensor,
543 _rows: usize,
544 _cols: usize,
545 ) -> crate::Result<()> {
546 Ok(())
547 }
548
549 /// Whether calls to [`Driver::debug_tensor`] will inspect tensor contents.
550 ///
551 /// Architecture code uses this to avoid allocating and converting probe
552 /// tensors when diagnostics are disabled.
553 fn debug_tensors_enabled(&self) -> bool {
554 false
555 }
556
557 // =======================================================================
558 // FP16 operations for full half-precision pipeline
559 //
560 // These methods mirror the FP32 counterparts but operate on FP16 tensors.
561 // Internal reductions (softmax, layer-norm) use FP32 accumulators but
562 // all tensor I/O is half precision. Default implementations return an
563 // error — only backends with FP16 support override them.
564 // =======================================================================
565
566 /// Allocate a zero-initialized FP16 tensor with `n` half-precision elements.
567 ///
568 /// # Errors
569 ///
570 /// Returns an error if device memory allocation fails or FP16 is unsupported.
571 fn alloc_zeros_f16(&self, _n: usize) -> crate::Result<Self::Tensor> {
572 Err(crate::Error::Metal(
573 "FP16 not supported by this driver".into(),
574 ))
575 }
576
577 /// Convert FP32 tensor to FP16 (element-wise narrowing).
578 fn f32_to_f16(
579 &self,
580 _output: &mut Self::Tensor,
581 _input: &Self::Tensor,
582 _n: usize,
583 ) -> crate::Result<()> {
584 Err(crate::Error::Metal(
585 "FP16 not supported by this driver".into(),
586 ))
587 }
588
589 /// Convert FP16 tensor back to FP32 (element-wise widening).
590 fn f16_to_f32(
591 &self,
592 _output: &mut Self::Tensor,
593 _input: &Self::Tensor,
594 _n: usize,
595 ) -> crate::Result<()> {
596 Err(crate::Error::Metal(
597 "FP16 not supported by this driver".into(),
598 ))
599 }
600
601 /// Mixed-precision GEMM: FP16 inputs → FP32 output via native simdgroup ops.
602 fn gemm_mixed(
603 &self,
604 _a_f16: &Self::Tensor,
605 _b_f16: &Self::Tensor,
606 _output_f32: &mut Self::Tensor,
607 _m: usize,
608 _n: usize,
609 _k: usize,
610 _transpose_b: bool,
611 ) -> crate::Result<()> {
612 Err(crate::Error::Metal(
613 "gemm_mixed not supported by this driver".into(),
614 ))
615 }
616
617 /// FP16 GEMM: `output = A * B` (or `A * B^T`). All tensors are half.
618 fn gemm_f16(
619 &self,
620 _a: &Self::Tensor,
621 _b: &Self::Tensor,
622 _output: &mut Self::Tensor,
623 _m: usize,
624 _n: usize,
625 _k: usize,
626 _transpose_b: bool,
627 ) -> crate::Result<()> {
628 Err(crate::Error::Metal(
629 "FP16 not supported by this driver".into(),
630 ))
631 }
632
633 /// FP16 batched GEMM for multi-head attention. All tensors are half.
634 #[expect(
635 clippy::too_many_arguments,
636 reason = "matches FP32 gemm_batched signature"
637 )]
638 fn gemm_batched_f16(
639 &self,
640 _a: &Self::Tensor,
641 _b: &Self::Tensor,
642 _output: &mut Self::Tensor,
643 _m: usize,
644 _n: usize,
645 _k: usize,
646 _transpose_b: bool,
647 _stride_a: usize,
648 _stride_b: usize,
649 _stride_c: usize,
650 _batch_count: usize,
651 ) -> crate::Result<()> {
652 Err(crate::Error::Metal(
653 "FP16 not supported by this driver".into(),
654 ))
655 }
656
657 /// FP16 layer normalization. Half I/O, FP32 reductions.
658 fn layer_norm_f16(
659 &self,
660 _output: &mut Self::Tensor,
661 _input: &Self::Tensor,
662 _weight: &Self::Tensor,
663 _bias: &Self::Tensor,
664 _rows: usize,
665 _cols: usize,
666 _eps: f32,
667 ) -> crate::Result<()> {
668 Err(crate::Error::Metal(
669 "FP16 not supported by this driver".into(),
670 ))
671 }
672
673 /// FP16 fused scale + mask + softmax. Half scores, FP32 reductions.
674 fn fused_scale_mask_softmax_f16(
675 &self,
676 _scores: &mut Self::Tensor,
677 _mask: &Self::Tensor,
678 _batch: usize,
679 _num_heads: usize,
680 _seq_len: usize,
681 _scale: f32,
682 ) -> crate::Result<()> {
683 Err(crate::Error::Metal(
684 "FP16 not supported by this driver".into(),
685 ))
686 }
687
688 /// FP16 fused scale + mask + sliding window + softmax.
689 fn fused_scale_mask_softmax_windowed_f16(
690 &self,
691 _scores: &mut Self::Tensor,
692 _mask: &Self::Tensor,
693 _batch: usize,
694 _num_heads: usize,
695 _seq_len: usize,
696 _scale: f32,
697 _window_size: usize,
698 ) -> crate::Result<()> {
699 Err(crate::Error::Metal(
700 "FP16 not supported by this driver".into(),
701 ))
702 }
703
704 /// FP16 QKV split: `[batch*seq, 3*hidden]` into Q, K, V per-head layout.
705 fn qkv_split_f16(
706 &self,
707 _q: &mut Self::Tensor,
708 _k: &mut Self::Tensor,
709 _v: &mut Self::Tensor,
710 _qkv: &Self::Tensor,
711 _batch: usize,
712 _seq: usize,
713 _hidden: usize,
714 _num_heads: usize,
715 _head_dim: usize,
716 ) -> crate::Result<()> {
717 Err(crate::Error::Metal(
718 "FP16 not supported by this driver".into(),
719 ))
720 }
721
722 /// FP16 attention output reshape: `[batch*num_heads, seq, head_dim]` to
723 /// `[batch*seq, hidden]`.
724 fn attn_reshape_f16(
725 &self,
726 _output: &mut Self::Tensor,
727 _input: &Self::Tensor,
728 _batch: usize,
729 _seq: usize,
730 _num_heads: usize,
731 _head_dim: usize,
732 ) -> crate::Result<()> {
733 Err(crate::Error::Metal(
734 "FP16 not supported by this driver".into(),
735 ))
736 }
737
738 /// FP16 scatter flat `[total_tokens, dim]` to padded `[batch, max_seq, dim]`.
739 fn pad_to_batch_f16(
740 &self,
741 _flat: &Self::Tensor,
742 _padded: &mut Self::Tensor,
743 _seq_lengths: &[usize],
744 _max_seq: usize,
745 _dim: usize,
746 ) -> crate::Result<()> {
747 Err(crate::Error::Metal(
748 "FP16 not supported by this driver".into(),
749 ))
750 }
751
752 /// FP16 gather padded `[batch, max_seq, dim]` back to flat `[total_tokens, dim]`.
753 fn unpad_from_batch_f16(
754 &self,
755 _padded: &Self::Tensor,
756 _flat: &mut Self::Tensor,
757 _seq_lengths: &[usize],
758 _max_seq: usize,
759 _dim: usize,
760 ) -> crate::Result<()> {
761 Err(crate::Error::Metal(
762 "FP16 not supported by this driver".into(),
763 ))
764 }
765
766 /// FP16 RoPE: apply rotary position embedding. Half Q/K, float cos/sin tables.
767 fn rope_encode_f16(
768 &self,
769 _qk: &mut Self::Tensor,
770 _cos: &Self::Tensor,
771 _sin: &Self::Tensor,
772 _num_rows: usize,
773 _seq_len: usize,
774 _head_dim: usize,
775 _num_heads: usize,
776 ) -> crate::Result<()> {
777 Err(crate::Error::Metal(
778 "FP16 not supported by this driver".into(),
779 ))
780 }
781
782 /// FP16 `GeGLU` gated activation: `output = gelu(value) * gate`. Half I/O.
783 fn geglu_f16(
784 &self,
785 _value: &Self::Tensor,
786 _gate: &Self::Tensor,
787 _output: &mut Self::Tensor,
788 _n: usize,
789 ) -> crate::Result<()> {
790 Err(crate::Error::Metal(
791 "FP16 not supported by this driver".into(),
792 ))
793 }
794
795 /// FP16 fused residual add + layer normalization.
796 fn fused_residual_layernorm_f16(
797 &self,
798 _output: &mut Self::Tensor,
799 _hidden: &Self::Tensor,
800 _residual: &Self::Tensor,
801 _weight: &Self::Tensor,
802 _bias: &Self::Tensor,
803 _rows: usize,
804 _cols: usize,
805 _eps: f32,
806 ) -> crate::Result<()> {
807 Err(crate::Error::Metal(
808 "FP16 not supported by this driver".into(),
809 ))
810 }
811
812 /// FP16 residual add (no bias): `output = hidden + residual`.
813 fn residual_add_f16(
814 &self,
815 _output: &mut Self::Tensor,
816 _hidden: &Self::Tensor,
817 _residual: &Self::Tensor,
818 _n: usize,
819 ) -> crate::Result<()> {
820 Err(crate::Error::Metal(
821 "FP16 not supported by this driver".into(),
822 ))
823 }
824
825 /// FP16 split `[rows, 2*cols]` into two `[rows, cols]` halves.
826 fn split_gate_value_f16(
827 &self,
828 _first: &mut Self::Tensor,
829 _second: &mut Self::Tensor,
830 _input: &Self::Tensor,
831 _rows: usize,
832 _cols: usize,
833 ) -> crate::Result<()> {
834 Err(crate::Error::Metal(
835 "FP16 not supported by this driver".into(),
836 ))
837 }
838
839 /// Fused split + `GeGLU`: read `[rows, 2*cols]`, write `[rows, cols]`.
840 ///
841 /// Combines [`split_gate_value_f16`](Driver::split_gate_value_f16) and
842 /// [`geglu_f16`](Driver::geglu_f16) into a single kernel, eliminating
843 /// two intermediate `[rows, cols]` buffers and halving HBM round-trips.
844 ///
845 /// Default falls back to separate split + geglu calls.
846 fn fused_split_geglu_f16(
847 &self,
848 output: &mut Self::Tensor,
849 input: &Self::Tensor,
850 rows: usize,
851 cols: usize,
852 ) -> crate::Result<()> {
853 // Default: allocate intermediates and call separately.
854 let n = rows * cols;
855 let mut value = self.alloc_zeros_f16(n)?;
856 let mut gate = self.alloc_zeros_f16(n)?;
857 self.split_gate_value_f16(&mut value, &mut gate, input, rows, cols)?;
858 self.geglu_f16(&value, &gate, output, n)
859 }
860
861 /// Fused pad + QKV split: flat `[total_tokens, 3*hidden]` → Q, K, V
862 /// each `[batch*heads, max_seq, head_dim]`.
863 ///
864 /// Eliminates the padded intermediate buffer. Default calls pad then split.
865 #[expect(clippy::too_many_arguments, reason = "mirrors pad + qkv_split args")]
866 fn fused_pad_qkv_split_f16(
867 &self,
868 q: &mut Self::Tensor,
869 k: &mut Self::Tensor,
870 v: &mut Self::Tensor,
871 qkv_flat: &Self::Tensor,
872 seq_lengths: &[usize],
873 max_seq: usize,
874 batch: usize,
875 hidden: usize,
876 num_heads: usize,
877 head_dim: usize,
878 ) -> crate::Result<()> {
879 // Default: pad then split.
880 let padded_tokens = batch * max_seq;
881 let mut qkv_padded = self.alloc_zeros_f16(padded_tokens * 3 * hidden)?;
882 self.pad_to_batch_f16(qkv_flat, &mut qkv_padded, seq_lengths, max_seq, 3 * hidden)?;
883 self.qkv_split_f16(
884 q,
885 k,
886 v,
887 &qkv_padded,
888 batch,
889 max_seq,
890 hidden,
891 num_heads,
892 head_dim,
893 )
894 }
895
896 /// Fused attn_reshape + unpad: `[batch*heads, max_seq, head_dim]` →
897 /// `[total_tokens, hidden]`.
898 ///
899 /// Eliminates the padded context intermediate. Default calls reshape then unpad.
900 fn fused_reshape_unpad_f16(
901 &self,
902 flat: &mut Self::Tensor,
903 heads: &Self::Tensor,
904 seq_lengths: &[usize],
905 max_seq: usize,
906 batch: usize,
907 num_heads: usize,
908 head_dim: usize,
909 ) -> crate::Result<()> {
910 // Default: reshape then unpad.
911 let hidden = num_heads * head_dim;
912 let padded_tokens = batch * max_seq;
913 let mut context = self.alloc_zeros_f16(padded_tokens * hidden)?;
914 self.attn_reshape_f16(&mut context, heads, batch, max_seq, num_heads, head_dim)?;
915 self.unpad_from_batch_f16(&context, flat, seq_lengths, max_seq, hidden)
916 }
917}
918
/// Batch input tensors on device, produced by [`Driver::prepare_batch`].
///
/// Supports both padded and unpadded modes:
/// - **Padded**: all sequences padded to `max_seq`. `cu_seqlens` is `None`.
/// - **Unpadded**: sequences concatenated without padding. `cu_seqlens`
///   contains cumulative lengths `[0, len0, len0+len1, ...]` so attention
///   knows where each sequence starts. Eliminates ALL padding compute.
///
/// `T` is the driver's tensor handle type. `Debug` is derived so a batch can
/// be logged or inspected in tests whenever `T` itself is `Debug`.
#[derive(Debug)]
pub struct BatchInputs<T> {
    /// Token IDs — `[batch * max_seq]` (padded) or `[total_tokens]` (unpadded).
    pub input_ids: T,
    /// Attention mask `[batch * max_seq]` as int32 (0 or 1). Unused in unpadded mode.
    pub attention_mask: T,
    /// Token type IDs — same layout as `input_ids`.
    pub token_type_ids: T,
    /// Position IDs — same layout as `input_ids`.
    pub position_ids: T,
    /// Float attention bias mask `[batch * max_seq]` (0.0 or -1e9) for softmax.
    ///
    /// NOTE(review): `Driver::build_attn_mask` documents `-10000.0` as the
    /// masked value; confirm which constant the backends actually write.
    pub float_mask: T,
    /// Float pooling mask `[batch * max_seq]` (1.0 or 0.0) for mean pooling.
    pub pooling_mask: T,
    /// Number of sequences in this batch.
    pub batch: usize,
    /// Maximum sequence length (all sequences padded to this). In unpadded mode,
    /// this is the longest sequence (used for workspace sizing, not padding).
    pub max_seq: usize,
    /// Total actual tokens across all sequences (no padding).
    pub total_tokens: usize,
    /// Per-sequence lengths: `[batch]` — each element is the actual token count.
    pub seq_lengths: Vec<usize>,
    /// Cumulative sequence lengths for unpadded attention: `[batch + 1]`.
    /// `cu_seqlens[i]..cu_seqlens[i+1]` is the token range for sequence `i`.
    /// `None` in padded mode (all sequences padded to `max_seq`).
    pub cu_seqlens: Option<Vec<usize>>,
}