//! Model architecture trait and variant enum.
//!
//! The [`ModelArch`] trait defines how a model architecture composes
//! [`Driver`] primitives into a complete forward pass
//! (embeddings -> encoder layers -> pooling -> normalization).
//!
//! Each architecture (`ClassicBert`, `ModernBert`) is implemented once,
//! generic over the driver, so the same code works with any backend.
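//!
//! # Examples
//!
//! A minimal, self-contained sketch of this composition pattern, under stated
//! assumptions: the `Driver` and `ModelArch` traits below are hypothetical
//! stand-ins (their method names and signatures are not this crate's API),
//! token ids as `&[usize]` stand in for `Encoding`, plain `Vec<f32>` rows
//! replace device tensors, and the encoder layer is a placeholder transform
//! rather than real attention and feed-forward blocks. It shows one generic
//! `forward` driving embeddings -> encoder layers -> pooling -> normalization.
//!
//! ```rust
//! /// Hypothetical, simplified driver interface (not the real `Driver` trait).
//! trait Driver {
//!     fn embed(&self, table: &[Vec<f32>], ids: &[usize]) -> Vec<Vec<f32>>;
//!     fn encoder_layer(&self, hidden: Vec<Vec<f32>>) -> Vec<Vec<f32>>;
//!     fn mean_pool(&self, hidden: &[Vec<f32>]) -> Vec<f32>;
//!     fn l2_normalize(&self, v: Vec<f32>) -> Vec<f32>;
//! }
//!
//! /// Naive CPU backend used only for this example.
//! struct ToyCpu;
//!
//! impl Driver for ToyCpu {
//!     fn embed(&self, table: &[Vec<f32>], ids: &[usize]) -> Vec<Vec<f32>> {
//!         // Embedding lookup: one row of the table per token id.
//!         ids.iter().map(|&id| table[id].clone()).collect()
//!     }
//!     fn encoder_layer(&self, hidden: Vec<Vec<f32>>) -> Vec<Vec<f32>> {
//!         // Placeholder "layer": adds 1.0 to every element (no attention/FFN).
//!         hidden
//!             .into_iter()
//!             .map(|row| row.into_iter().map(|x| x + 1.0).collect::<Vec<f32>>())
//!             .collect()
//!     }
//!     fn mean_pool(&self, hidden: &[Vec<f32>]) -> Vec<f32> {
//!         // Mean over tokens; assumes at least one token.
//!         let mut out = vec![0.0; hidden[0].len()];
//!         for row in hidden {
//!             for (o, x) in out.iter_mut().zip(row) {
//!                 *o += x;
//!             }
//!         }
//!         let n = hidden.len() as f32;
//!         out.iter_mut().for_each(|o| *o /= n);
//!         out
//!     }
//!     fn l2_normalize(&self, v: Vec<f32>) -> Vec<f32> {
//!         // Divide by the Euclidean norm, guarding against division by zero.
//!         let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
//!         v.into_iter().map(|x| x / norm).collect()
//!     }
//! }
//!
//! /// Hypothetical architecture trait, generic over the backend `D`.
//! trait ModelArch<D: Driver> {
//!     fn forward(&self, driver: &D, ids: &[usize]) -> Vec<f32>;
//! }
//!
//! /// Toy model: owns its weights and composes driver primitives.
//! struct ToyBert {
//!     embeddings: Vec<Vec<f32>>,
//!     num_layers: usize,
//! }
//!
//! impl<D: Driver> ModelArch<D> for ToyBert {
//!     fn forward(&self, driver: &D, ids: &[usize]) -> Vec<f32> {
//!         let mut hidden = driver.embed(&self.embeddings, ids); // embeddings
//!         for _ in 0..self.num_layers {
//!             hidden = driver.encoder_layer(hidden); // encoder layers
//!         }
//!         let pooled = driver.mean_pool(&hidden); // pooling
//!         driver.l2_normalize(pooled) // normalization
//!     }
//! }
//!
//! fn main() {
//!     // Two 2-d token embeddings; zero layers keeps the arithmetic easy to follow.
//!     let model = ToyBert { embeddings: vec![vec![3.0, 0.0], vec![0.0, 4.0]], num_layers: 0 };
//!     let out = model.forward(&ToyCpu, &[0, 1]);
//!     // Mean of [3, 0] and [0, 4] is [1.5, 2.0]; its L2 norm is 2.5, giving [0.6, 0.8].
//!     assert!((out[0] - 0.6).abs() < 1e-6 && (out[1] - 0.8).abs() < 1e-6);
//! }
//! ```
//!
//! Because `forward` is generic over `D`, the compiler emits a separate,
//! statically dispatched copy per backend, so calls into driver primitives
//! should not need dynamic dispatch.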

use Encoding;
use Driver;

/// Model architecture that composes [`Driver`] primitives into a forward pass.
///
/// Implementations own their weights (resident on the device) and the model
/// config, and orchestrate the driver through embedding lookup, the encoder
/// layers, pooling, and L2 normalization.
///
/// # Type parameter
///
/// `D: Driver` — the hardware backend. Architectures are generic over the
/// driver so they can be monomorphized for each backend (Metal, CUDA, CPU).
/// Supported model architectures.
///
/// Each variant corresponds to a distinct BERT family with different attention
/// mechanisms, activations, position encodings, and pooling strategies.
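///
/// # Examples
///
/// A sketch of how such a variant enum might be selected at runtime and
/// queried for per-family settings. The enum name `ModelVariant`, the
/// `from_model_type` helper, and the `model_type` strings it accepts are
/// hypothetical; only the two family names come from this module. The
/// position encodings and activations listed are those of the published BERT
/// and ModernBERT designs.
///
/// ```rust
/// /// Hypothetical variant enum; the real enum's name and variants may differ.
/// #[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// enum ModelVariant {
///     ClassicBert,
///     ModernBert,
/// }
///
/// impl ModelVariant {
///     /// Maps a `config.json` `model_type` value to a variant.
///     /// The accepted strings are an assumption for illustration.
///     fn from_model_type(model_type: &str) -> Option<Self> {
///         match model_type {
///             "bert" => Some(Self::ClassicBert),
///             "modernbert" => Some(Self::ModernBert),
///             _ => None,
///         }
///     }
///
///     /// Position encoding used by each family.
///     fn position_encoding(self) -> &'static str {
///         match self {
///             Self::ClassicBert => "learned absolute embeddings",
///             Self::ModernBert => "rotary (RoPE)",
///         }
///     }
///
///     /// Feed-forward activation used by each family.
///     fn activation(self) -> &'static str {
///         match self {
///             Self::ClassicBert => "GELU",
///             Self::ModernBert => "GeGLU",
///         }
///     }
/// }
///
/// fn main() {
///     let variant = ModelVariant::from_model_type("modernbert").expect("unsupported model_type");
///     assert_eq!(variant, ModelVariant::ModernBert);
///     assert_eq!(variant.position_encoding(), "rotary (RoPE)");
///     assert_eq!(variant.activation(), "GeGLU");
///     // Unknown families are rejected rather than guessed.
///     assert_eq!(ModelVariant::from_model_type("roberta"), None);
/// }
/// ```
///
/// A runtime `match` on the variant can then construct the corresponding
/// generic architecture, so model selection stays dynamic while the backend
/// is still chosen at compile time.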