ripvec_core/backend/driver/mod.rs
1//! Hardware-agnostic compute driver trait.
2//!
3//! The [`Driver`] trait exposes low-level compute primitives (GEMM, layer-norm,
4//! activations, etc.) that each hardware backend implements. Model architectures
5//! are generic over `D: Driver` and compose these primitives into a forward pass.
6//!
7//! # Design
8//!
9//! - **Associated type `Tensor`**: each driver defines its own opaque tensor
10//! handle (Metal: buffer+offset, CUDA: device pointer, CPU: ndarray).
11//! - **Not object-safe**: architectures use `D: Driver` generics so the compiler
12//! can monomorphize and inline driver calls.
13//! - **Send + Sync**: drivers are shared across the pipeline.
14
15#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
16pub mod cpu;
17#[cfg(feature = "cuda")]
18pub mod cuda;
19#[cfg(feature = "metal")]
20pub mod metal;
21#[cfg(feature = "mlx")]
22pub mod mlx;
23
24use super::Encoding;
25
/// Hardware-agnostic compute primitives for BERT inference.
///
/// Each method corresponds to one operation in the forward pass. Drivers handle
/// memory allocation, kernel dispatch, and synchronization. Architectures
/// compose these primitives via the [`super::arch::ModelArch`] trait.
pub trait Driver: Send + Sync {
    /// Opaque tensor handle.
    ///
    /// Metal: `MTLBuffer` + byte offset. CUDA: `CUdeviceptr`. CPU: `Array2<f32>`.
    type Tensor;

    /// Create a new driver instance for a cloned worker thread.
    ///
    /// CPU drivers are zero-size and always succeed. GPU drivers typically
    /// cannot be cloned this way (they share device state) and should keep
    /// the default implementation, which returns an error.
    ///
    /// # Errors
    ///
    /// The default implementation always returns an error; backends that
    /// support per-thread cloning override it.
    fn new_for_clone() -> crate::Result<Self>
    where
        Self: Sized,
    {
        Err(crate::Error::Other(anyhow::anyhow!(
            "this driver does not support cloning"
        )))
    }

    // --- Batching ---

    /// Begin batched mode: all subsequent operations encode into one dispatch.
    ///
    /// GPU drivers accumulate into a single command buffer; CPU is a no-op.
    /// Call [`end_batch`](Driver::end_batch) to commit. This eliminates
    /// per-call overhead.
    fn begin_batch(&self) -> crate::Result<()> {
        Ok(())
    }

    /// End batched mode: commit all accumulated operations and wait.
    fn end_batch(&self) -> crate::Result<()> {
        Ok(())
    }

    /// Flush the current command buffer and start a new one, preserving pool
    /// state. Use mid-forward-pass to prevent GPU timeouts on deep models.
    fn flush_batch(&self) -> crate::Result<()> {
        Ok(())
    }

    /// Close and reopen the compute encoder within the same command buffer.
    ///
    /// This segments a long sequence of compute dispatches into multiple
    /// encoders without committing or waiting. Metal processes encoders
    /// back-to-back from the same CB — zero sync overhead.
    ///
    /// Use every few layers to prevent encoder state overflow (>~60 dispatches
    /// per encoder can cause hangs on some Apple Silicon GPUs).
    fn segment_encoder(&self) {
        // No-op for non-Metal backends
    }

    /// Save the current pool cursor position. Call BEFORE a layer's work.
    ///
    /// The default (for drivers without a tensor pool) always returns 0.
    fn save_pool_cursor(&self) -> usize {
        0
    }

    /// Restore the pool cursor to a previously saved position. Call AFTER
    /// a layer's transient tensors have been dropped (out of scope).
    ///
    /// The architecture must ensure only the output tensor (`hidden_states`)
    /// survives — all layer-internal tensors (qkv, scores, context, etc.)
    /// must be dropped before this call so their pool slots can be recycled.
    fn restore_pool_cursor(&self, _saved: usize) {}

    // --- Allocation ---

    /// Allocate a zero-initialized tensor with `n` float elements on device.
    ///
    /// Used by architectures to create workspace buffers (QKV projections,
    /// attention scores, intermediate activations, etc.).
    ///
    /// # Errors
    ///
    /// Returns an error if device memory allocation fails.
    fn alloc_zeros(&self, n: usize) -> crate::Result<Self::Tensor>;

    /// Clone a tensor, producing an independent copy of the data.
    ///
    /// Used when an operation needs both the original and a mutable output
    /// referencing the same logical data (e.g., in-place layer normalization
    /// where input == output).
    ///
    /// # Errors
    ///
    /// Returns an error if device memory allocation or the copy fails.
    fn clone_tensor(&self, tensor: &Self::Tensor, n: usize) -> crate::Result<Self::Tensor>;

    // --- Batch preparation ---

    /// Prepare a batch of encodings for inference, returning input tensors on device.
    ///
    /// Pads all sequences to `max_seq` and uploads `input_ids`, `attention_mask`,
    /// `token_type_ids`, `position_ids`, and a float attention mask to device memory.
    fn prepare_batch(
        &self,
        encodings: &[Encoding],
        max_seq: usize,
    ) -> crate::Result<BatchInputs<Self::Tensor>>;

    /// Prepare a batch WITHOUT padding — concatenate all tokens flat.
    ///
    /// Returns `BatchInputs` with `total_tokens` actual tokens (no padding),
    /// `cu_seqlens` for attention boundaries, and per-token position IDs.
    /// Linear layers (GEMM, LN, GELU) process `total_tokens` rows.
    /// Attention must pad/unpad around the per-head operations.
    fn prepare_batch_unpadded(
        &self,
        encodings: &[Encoding],
    ) -> crate::Result<BatchInputs<Self::Tensor>> {
        // Default: fall back to padded (backends override for unpadded support)
        // The longest sequence is rounded up to a multiple of 8 — presumably
        // for GPU tile/SIMD alignment; confirm against backend kernel
        // requirements. An empty batch yields max_seq == 0.
        let max_seq = encodings
            .iter()
            .map(|e| e.input_ids.len())
            .max()
            .unwrap_or(0)
            .next_multiple_of(8);
        self.prepare_batch(encodings, max_seq)
    }

    /// Scatter flat `[total_tokens, dim]` tensor into padded `[batch, max_seq, dim]`.
    ///
    /// Used before attention: linear layers produce unpadded output, but the
    /// QKV split + batched attention GEMM need aligned `[batch*heads, seq, head_dim]`.
    /// Padding positions are zeroed.
    fn pad_to_batch(
        &self,
        flat: &Self::Tensor,
        padded: &mut Self::Tensor,
        seq_lengths: &[usize],
        max_seq: usize,
        dim: usize,
    ) -> crate::Result<()>;

    /// Gather padded `[batch, max_seq, dim]` back to flat `[total_tokens, dim]`.
    ///
    /// Used after attention: extracts only the real tokens, discarding padding.
    fn unpad_from_batch(
        &self,
        padded: &Self::Tensor,
        flat: &mut Self::Tensor,
        seq_lengths: &[usize],
        max_seq: usize,
        dim: usize,
    ) -> crate::Result<()>;

    // --- Embedding operations ---

    /// Word/position/token-type embedding lookup via gather.
    ///
    /// Reads `seq_len` token IDs from `word_ids`, gathers rows from
    /// `embedding_table`, and writes `[seq_len, hidden]` floats to the result.
    fn embedding_lookup(
        &self,
        word_ids: &Self::Tensor,
        embedding_table: &Self::Tensor,
        seq_len: usize,
        hidden: usize,
    ) -> crate::Result<Self::Tensor>;

    /// Element-wise add an embedding table lookup into `hidden`.
    ///
    /// Used for position and token-type embeddings:
    /// `hidden[i] += table[ids[i]]` for each token position.
    fn add_embeddings(
        &self,
        hidden: &mut Self::Tensor,
        table: &Self::Tensor,
        ids: &Self::Tensor,
        seq_len: usize,
        hidden_dim: usize,
    ) -> crate::Result<()>;

    // --- Normalization ---

    /// Layer normalization: `output = (input - mean) / sqrt(var + eps) * weight + bias`.
    fn layer_norm(
        &self,
        output: &mut Self::Tensor,
        input: &Self::Tensor,
        weight: &Self::Tensor,
        bias: &Self::Tensor,
        rows: usize,
        cols: usize,
        eps: f32,
    ) -> crate::Result<()>;

    // --- Linear algebra ---

    /// General matrix multiply: `output = A * B` (or `A * B^T` if `transpose_b`).
    ///
    /// Dimensions: A is `[m, k]`, B is `[k, n]` (or `[n, k]` if transposed),
    /// output is `[m, n]`.
    fn gemm(
        &self,
        a: &Self::Tensor,
        b: &Self::Tensor,
        output: &mut Self::Tensor,
        m: usize,
        n: usize,
        k: usize,
        transpose_b: bool,
    ) -> crate::Result<()>;

    /// Batched GEMM for multi-head attention.
    ///
    /// Performs `batch_count` independent GEMMs with strided access into
    /// contiguous buffers. Used for per-head Q*K^T and attn*V.
    fn gemm_batched(
        &self,
        a: &Self::Tensor,
        b: &Self::Tensor,
        output: &mut Self::Tensor,
        m: usize,
        n: usize,
        k: usize,
        transpose_b: bool,
        stride_a: usize,
        stride_b: usize,
        stride_c: usize,
        batch_count: usize,
    ) -> crate::Result<()>;

    // --- Attention ---

    /// Fused scale + mask + softmax for attention scores.
    ///
    /// `scores = softmax(scores * scale + mask)` computed per-head.
    fn fused_scale_mask_softmax(
        &self,
        scores: &mut Self::Tensor,
        mask: &Self::Tensor,
        batch: usize,
        num_heads: usize,
        seq_len: usize,
        scale: f32,
    ) -> crate::Result<()>;

    /// Fused scale + mask + sliding window + softmax for attention scores.
    ///
    /// Like [`fused_scale_mask_softmax`](Driver::fused_scale_mask_softmax) but
    /// additionally masks out positions where `|query_pos - key_pos| > window_size / 2`.
    /// Used by `ModernBERT`'s local attention layers.
    fn fused_scale_mask_softmax_windowed(
        &self,
        scores: &mut Self::Tensor,
        mask: &Self::Tensor,
        batch: usize,
        num_heads: usize,
        seq_len: usize,
        scale: f32,
        window_size: usize,
    ) -> crate::Result<()>;

    /// Build a float attention mask from an integer mask.
    ///
    /// Converts `[batch * seq]` int mask (0/1) to `[batch * seq]` float mask
    /// (0.0 / -10000.0) for use with [`fused_scale_mask_softmax`](Driver::fused_scale_mask_softmax).
    fn build_attn_mask(
        &self,
        output: &mut Self::Tensor,
        int_mask: &Self::Tensor,
        n: usize,
    ) -> crate::Result<()>;

    /// Split a fused QKV projection into separate Q, K, V tensors.
    fn qkv_split(
        &self,
        q: &mut Self::Tensor,
        k: &mut Self::Tensor,
        v: &mut Self::Tensor,
        qkv: &Self::Tensor,
        batch: usize,
        seq: usize,
        hidden: usize,
        num_heads: usize,
        head_dim: usize,
    ) -> crate::Result<()>;

    // --- Banded (local/sliding-window) attention ---

    /// Banded Q@K^T: compute attention scores only within a sliding window.
    ///
    /// Output shape: `[batch * num_heads, seq, window]` (NOT `[seq, seq]`).
    /// `scores[h, i, w]` = dot(Q[h, i, :], K[h, i - window/2 + w, :])
    /// where out-of-bounds positions are set to `-inf` (masked in softmax).
    ///
    /// Reduces attention compute from O(seq²) to O(seq × window).
    /// For `seq=512, window=128`: **4× less compute** per local layer.
    fn banded_qk(
        &self,
        q: &Self::Tensor,
        k: &Self::Tensor,
        scores: &mut Self::Tensor,
        batch_heads: usize,
        seq: usize,
        head_dim: usize,
        window: usize,
        stride_qk: usize,
        stride_scores: usize,
    ) -> crate::Result<()>;

    /// Banded scores@V: weighted sum using banded attention scores.
    ///
    /// Input scores: `[batch * num_heads, seq, window]` (from `banded_qk`).
    /// Output: `[batch * num_heads, seq, head_dim]`.
    /// `output[h, i, d]` = sum_w scores[h, i, w] * V[h, i - window/2 + w, d]
    fn banded_sv(
        &self,
        scores: &Self::Tensor,
        v: &Self::Tensor,
        output: &mut Self::Tensor,
        batch_heads: usize,
        seq: usize,
        head_dim: usize,
        window: usize,
        stride_scores: usize,
        stride_v: usize,
        stride_out: usize,
    ) -> crate::Result<()>;

    /// Fused scale + softmax over the window dimension (no padding mask needed).
    ///
    /// Operates on `[batch * num_heads * seq, window]` rows.
    fn banded_softmax(
        &self,
        scores: &mut Self::Tensor,
        total_rows: usize,
        window: usize,
        scale: f32,
    ) -> crate::Result<()>;

    /// Reshape attention output from `[batch, num_heads, seq, head_dim]` to
    /// `[batch * seq, hidden]`.
    fn attn_reshape(
        &self,
        output: &mut Self::Tensor,
        input: &Self::Tensor,
        batch: usize,
        seq: usize,
        num_heads: usize,
        head_dim: usize,
    ) -> crate::Result<()>;

    /// Apply Rotary Position Embedding (RoPE) to Q/K tensors.
    ///
    /// Used by ModernBERT (not ClassicBert which uses learned position embeddings).
    fn apply_rope(
        &self,
        qk: &mut Self::Tensor,
        cos: &Self::Tensor,
        sin: &Self::Tensor,
        num_rows: usize,
        seq_len: usize,
        head_dim: usize,
        num_heads: usize,
    ) -> crate::Result<()>;

    // --- Tensor manipulation ---

    /// Split a `[rows, 2*cols]` matrix into two `[rows, cols]` halves.
    ///
    /// Each row of `input` is `[first_half | second_half]`. The first `cols`
    /// elements go to `first`, the remaining `cols` to `second`.
    /// Used by `ModernBERT` for gated MLP splits.
    fn split_gate_value(
        &self,
        first: &mut Self::Tensor,
        second: &mut Self::Tensor,
        input: &Self::Tensor,
        rows: usize,
        cols: usize,
    ) -> crate::Result<()>;

    // --- Activations ---

    /// GELU activation (Gaussian Error Linear Unit), applied in-place.
    fn gelu(&self, x: &mut Self::Tensor, n: usize) -> crate::Result<()>;

    /// SwiGLU gated activation: `output = value * silu(gate)`.
    ///
    /// The gate and value come from splitting the intermediate projection.
    fn swiglu(
        &self,
        value: &Self::Tensor,
        gate: &Self::Tensor,
        output: &mut Self::Tensor,
        n: usize,
    ) -> crate::Result<()>;

    /// `GeGLU` gated activation: `output = gelu(value) * gate`.
    ///
    /// Used by `ModernBERT`. The value and gate come from splitting the
    /// MLP `Wi` projection output in half.
    fn geglu(
        &self,
        value: &Self::Tensor,
        gate: &Self::Tensor,
        output: &mut Self::Tensor,
        n: usize,
    ) -> crate::Result<()>;

    /// Fused bias + GELU: `x = gelu(x + bias)` row-wise.
    fn fused_bias_gelu(
        &self,
        x: &mut Self::Tensor,
        bias: &Self::Tensor,
        rows: usize,
        cols: usize,
    ) -> crate::Result<()>;

    // --- Fused residual operations ---

    /// Fused bias + residual add: `output = input + bias + residual`.
    ///
    /// Bias is broadcast row-wise (`cols`-wide) across `n / cols` rows.
    fn fused_bias_residual(
        &self,
        output: &mut Self::Tensor,
        input: &Self::Tensor,
        bias: &Self::Tensor,
        residual: &Self::Tensor,
        n: usize,
        cols: usize,
    ) -> crate::Result<()>;

    /// Fused residual add + layer normalization.
    ///
    /// `output = layer_norm(hidden + residual, weight, bias, eps)`.
    fn fused_residual_layernorm(
        &self,
        output: &mut Self::Tensor,
        hidden: &Self::Tensor,
        residual: &Self::Tensor,
        weight: &Self::Tensor,
        bias: &Self::Tensor,
        rows: usize,
        cols: usize,
        eps: f32,
    ) -> crate::Result<()>;

    /// Residual add without bias: `output = hidden + residual`.
    ///
    /// Used by `ModernBERT` which has no bias terms.
    fn residual_add(
        &self,
        output: &mut Self::Tensor,
        hidden: &Self::Tensor,
        residual: &Self::Tensor,
        n: usize,
    ) -> crate::Result<()>;

    /// Add bias to a matrix row-wise: `x[row] += bias` for each row.
    fn add_bias(
        &self,
        x: &mut Self::Tensor,
        bias: &Self::Tensor,
        rows: usize,
        cols: usize,
    ) -> crate::Result<()>;

    // --- Pooling ---

    /// CLS pooling: extract the first token's hidden state per batch element.
    fn cls_pool(
        &self,
        output: &mut Self::Tensor,
        hidden: &Self::Tensor,
        batch: usize,
        seq: usize,
        hidden_dim: usize,
    ) -> crate::Result<()>;

    /// Mean pooling: attention-mask-weighted average of hidden states.
    fn mean_pool(
        &self,
        output: &mut Self::Tensor,
        hidden: &Self::Tensor,
        mask: &Self::Tensor,
        batch: usize,
        seq: usize,
        hidden_dim: usize,
    ) -> crate::Result<()>;

    // --- Post-processing ---

    /// L2-normalize each row vector in-place.
    fn l2_normalize(&self, data: &mut Self::Tensor, rows: usize, cols: usize) -> crate::Result<()>;

    /// Copy tensor data back to host memory as `Vec<Vec<f32>>`.
    ///
    /// Returns one `Vec<f32>` of length `dim` per batch element.
    fn to_host(
        &self,
        tensor: &Self::Tensor,
        batch: usize,
        dim: usize,
    ) -> crate::Result<Vec<Vec<f32>>>;

    // =======================================================================
    // FP16 operations for full half-precision pipeline
    //
    // These methods mirror the FP32 counterparts but operate on FP16 tensors.
    // Internal reductions (softmax, layer-norm) use FP32 accumulators but
    // all tensor I/O is half precision. Default implementations return an
    // error — only backends with FP16 support override them.
    //
    // NOTE(review): the default stubs below return `crate::Error::Metal` even
    // when a non-Metal backend hits them, unlike `new_for_clone` which uses
    // `Error::Other`. Consider a backend-neutral variant — confirm no caller
    // matches on the `Metal` variant before changing.
    // =======================================================================

    /// Allocate a zero-initialized FP16 tensor with `n` half-precision elements.
    ///
    /// # Errors
    ///
    /// Returns an error if device memory allocation fails or FP16 is unsupported.
    fn alloc_zeros_f16(&self, _n: usize) -> crate::Result<Self::Tensor> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// Convert FP32 tensor to FP16 (element-wise narrowing).
    fn f32_to_f16(
        &self,
        _output: &mut Self::Tensor,
        _input: &Self::Tensor,
        _n: usize,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// Convert FP16 tensor back to FP32 (element-wise widening).
    fn f16_to_f32(
        &self,
        _output: &mut Self::Tensor,
        _input: &Self::Tensor,
        _n: usize,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// Mixed-precision GEMM: FP16 inputs → FP32 output via native simdgroup ops.
    fn gemm_mixed(
        &self,
        _a_f16: &Self::Tensor,
        _b_f16: &Self::Tensor,
        _output_f32: &mut Self::Tensor,
        _m: usize,
        _n: usize,
        _k: usize,
        _transpose_b: bool,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "gemm_mixed not supported by this driver".into(),
        ))
    }

    /// FP16 GEMM: `output = A * B` (or `A * B^T`). All tensors are half.
    fn gemm_f16(
        &self,
        _a: &Self::Tensor,
        _b: &Self::Tensor,
        _output: &mut Self::Tensor,
        _m: usize,
        _n: usize,
        _k: usize,
        _transpose_b: bool,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 batched GEMM for multi-head attention. All tensors are half.
    #[expect(
        clippy::too_many_arguments,
        reason = "matches FP32 gemm_batched signature"
    )]
    fn gemm_batched_f16(
        &self,
        _a: &Self::Tensor,
        _b: &Self::Tensor,
        _output: &mut Self::Tensor,
        _m: usize,
        _n: usize,
        _k: usize,
        _transpose_b: bool,
        _stride_a: usize,
        _stride_b: usize,
        _stride_c: usize,
        _batch_count: usize,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 layer normalization. Half I/O, FP32 reductions.
    fn layer_norm_f16(
        &self,
        _output: &mut Self::Tensor,
        _input: &Self::Tensor,
        _weight: &Self::Tensor,
        _bias: &Self::Tensor,
        _rows: usize,
        _cols: usize,
        _eps: f32,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 fused scale + mask + softmax. Half scores, FP32 reductions.
    fn fused_scale_mask_softmax_f16(
        &self,
        _scores: &mut Self::Tensor,
        _mask: &Self::Tensor,
        _batch: usize,
        _num_heads: usize,
        _seq_len: usize,
        _scale: f32,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 fused scale + mask + sliding window + softmax.
    fn fused_scale_mask_softmax_windowed_f16(
        &self,
        _scores: &mut Self::Tensor,
        _mask: &Self::Tensor,
        _batch: usize,
        _num_heads: usize,
        _seq_len: usize,
        _scale: f32,
        _window_size: usize,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 QKV split: `[batch*seq, 3*hidden]` into Q, K, V per-head layout.
    fn qkv_split_f16(
        &self,
        _q: &mut Self::Tensor,
        _k: &mut Self::Tensor,
        _v: &mut Self::Tensor,
        _qkv: &Self::Tensor,
        _batch: usize,
        _seq: usize,
        _hidden: usize,
        _num_heads: usize,
        _head_dim: usize,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 attention output reshape: `[batch*num_heads, seq, head_dim]` to
    /// `[batch*seq, hidden]`.
    fn attn_reshape_f16(
        &self,
        _output: &mut Self::Tensor,
        _input: &Self::Tensor,
        _batch: usize,
        _seq: usize,
        _num_heads: usize,
        _head_dim: usize,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 scatter flat `[total_tokens, dim]` to padded `[batch, max_seq, dim]`.
    fn pad_to_batch_f16(
        &self,
        _flat: &Self::Tensor,
        _padded: &mut Self::Tensor,
        _seq_lengths: &[usize],
        _max_seq: usize,
        _dim: usize,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 gather padded `[batch, max_seq, dim]` back to flat `[total_tokens, dim]`.
    fn unpad_from_batch_f16(
        &self,
        _padded: &Self::Tensor,
        _flat: &mut Self::Tensor,
        _seq_lengths: &[usize],
        _max_seq: usize,
        _dim: usize,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 RoPE: apply rotary position embedding. Half Q/K, float cos/sin tables.
    fn rope_encode_f16(
        &self,
        _qk: &mut Self::Tensor,
        _cos: &Self::Tensor,
        _sin: &Self::Tensor,
        _num_rows: usize,
        _seq_len: usize,
        _head_dim: usize,
        _num_heads: usize,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 `GeGLU` gated activation: `output = gelu(value) * gate`. Half I/O.
    fn geglu_f16(
        &self,
        _value: &Self::Tensor,
        _gate: &Self::Tensor,
        _output: &mut Self::Tensor,
        _n: usize,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 fused residual add + layer normalization.
    fn fused_residual_layernorm_f16(
        &self,
        _output: &mut Self::Tensor,
        _hidden: &Self::Tensor,
        _residual: &Self::Tensor,
        _weight: &Self::Tensor,
        _bias: &Self::Tensor,
        _rows: usize,
        _cols: usize,
        _eps: f32,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 residual add (no bias): `output = hidden + residual`.
    fn residual_add_f16(
        &self,
        _output: &mut Self::Tensor,
        _hidden: &Self::Tensor,
        _residual: &Self::Tensor,
        _n: usize,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// FP16 split `[rows, 2*cols]` into two `[rows, cols]` halves.
    fn split_gate_value_f16(
        &self,
        _first: &mut Self::Tensor,
        _second: &mut Self::Tensor,
        _input: &Self::Tensor,
        _rows: usize,
        _cols: usize,
    ) -> crate::Result<()> {
        Err(crate::Error::Metal(
            "FP16 not supported by this driver".into(),
        ))
    }

    /// Fused split + `GeGLU`: read `[rows, 2*cols]`, write `[rows, cols]`.
    ///
    /// Combines [`split_gate_value_f16`](Driver::split_gate_value_f16) and
    /// [`geglu_f16`](Driver::geglu_f16) into a single kernel, eliminating
    /// two intermediate `[rows, cols]` buffers and halving HBM round-trips.
    ///
    /// Default falls back to separate split + geglu calls.
    ///
    /// # Errors
    ///
    /// The default propagates errors from the FP16 stubs above, so it fails
    /// on any backend without FP16 support.
    fn fused_split_geglu_f16(
        &self,
        output: &mut Self::Tensor,
        input: &Self::Tensor,
        rows: usize,
        cols: usize,
    ) -> crate::Result<()> {
        // Default: allocate intermediates and call separately.
        let n = rows * cols;
        let mut value = self.alloc_zeros_f16(n)?;
        let mut gate = self.alloc_zeros_f16(n)?;
        self.split_gate_value_f16(&mut value, &mut gate, input, rows, cols)?;
        self.geglu_f16(&value, &gate, output, n)
    }

    /// Fused pad + QKV split: flat `[total_tokens, 3*hidden]` → Q, K, V
    /// each `[batch*heads, max_seq, head_dim]`.
    ///
    /// Eliminates the padded intermediate buffer. Default calls pad then split.
    #[expect(clippy::too_many_arguments, reason = "mirrors pad + qkv_split args")]
    fn fused_pad_qkv_split_f16(
        &self,
        q: &mut Self::Tensor,
        k: &mut Self::Tensor,
        v: &mut Self::Tensor,
        qkv_flat: &Self::Tensor,
        seq_lengths: &[usize],
        max_seq: usize,
        batch: usize,
        hidden: usize,
        num_heads: usize,
        head_dim: usize,
    ) -> crate::Result<()> {
        // Default: pad then split.
        let padded_tokens = batch * max_seq;
        let mut qkv_padded = self.alloc_zeros_f16(padded_tokens * 3 * hidden)?;
        self.pad_to_batch_f16(qkv_flat, &mut qkv_padded, seq_lengths, max_seq, 3 * hidden)?;
        self.qkv_split_f16(
            q,
            k,
            v,
            &qkv_padded,
            batch,
            max_seq,
            hidden,
            num_heads,
            head_dim,
        )
    }

    /// Fused attn_reshape + unpad: `[batch*heads, max_seq, head_dim]` →
    /// `[total_tokens, hidden]`.
    ///
    /// Eliminates the padded context intermediate. Default calls reshape then unpad.
    fn fused_reshape_unpad_f16(
        &self,
        flat: &mut Self::Tensor,
        heads: &Self::Tensor,
        seq_lengths: &[usize],
        max_seq: usize,
        batch: usize,
        num_heads: usize,
        head_dim: usize,
    ) -> crate::Result<()> {
        // Default: reshape then unpad.
        let hidden = num_heads * head_dim;
        let padded_tokens = batch * max_seq;
        let mut context = self.alloc_zeros_f16(padded_tokens * hidden)?;
        self.attn_reshape_f16(&mut context, heads, batch, max_seq, num_heads, head_dim)?;
        self.unpad_from_batch_f16(&context, flat, seq_lengths, max_seq, hidden)
    }
}
892
/// Batch input tensors on device, produced by [`Driver::prepare_batch`].
///
/// Supports both padded and unpadded modes:
/// - **Padded**: all sequences padded to `max_seq`. `cu_seqlens` is `None`.
/// - **Unpadded**: sequences concatenated without padding. `cu_seqlens`
///   contains cumulative lengths `[0, len0, len0+len1, ...]` so attention
///   knows where each sequence starts. Eliminates ALL padding compute.
pub struct BatchInputs<T> {
    /// Token IDs — `[batch * max_seq]` (padded) or `[total_tokens]` (unpadded).
    pub input_ids: T,
    /// Attention mask `[batch * max_seq]` as int32 (0 or 1). Unused in unpadded mode.
    pub attention_mask: T,
    /// Token type IDs — same layout as `input_ids`.
    pub token_type_ids: T,
    /// Position IDs — same layout as `input_ids`.
    pub position_ids: T,
    /// Float attention bias mask `[batch * max_seq]` (0.0 or -1e9) for softmax.
    ///
    /// NOTE(review): [`Driver::build_attn_mask`] documents the sentinel as
    /// -10000.0 — the exact value is backend-defined; confirm and unify docs.
    pub float_mask: T,
    /// Float pooling mask `[batch * max_seq]` (1.0 or 0.0) for mean pooling.
    pub pooling_mask: T,
    /// Number of sequences in this batch.
    pub batch: usize,
    /// Maximum sequence length (all sequences padded to this). In unpadded mode,
    /// this is the longest sequence (used for workspace sizing, not padding).
    pub max_seq: usize,
    /// Total actual tokens across all sequences (no padding).
    pub total_tokens: usize,
    /// Per-sequence lengths: `[batch]` — each element is the actual token count.
    pub seq_lengths: Vec<usize>,
    /// Cumulative sequence lengths for unpadded attention: `[batch + 1]`.
    /// `cu_seqlens[i]..cu_seqlens[i+1]` is the token range for sequence `i`.
    /// `None` in padded mode (all sequences padded to max_seq).
    pub cu_seqlens: Option<Vec<usize>>,
}