Skip to main content

ripvec_core/backend/arch/
mod.rs

//! Model architecture trait and variant enum.
//!
//! The [`ModelArch`] trait defines how a model architecture composes
//! [`Driver`] primitives into a complete forward pass
//! (embeddings -> encoder layers -> pooling -> normalization).
//!
//! Each architecture (ClassicBert, ModernBert) is implemented once
//! and works with any driver backend via generics.
10pub mod classic_bert;
11pub mod modern_bert;
12
13use super::Encoding;
14use super::driver::Driver;
15
16/// Model architecture that composes [`Driver`] primitives into a forward pass.
17///
18/// Implementations store their weights (on device) and model config, then
19/// orchestrate the driver to execute embedding lookup, encoder layers, pooling,
20/// and L2 normalization.
21///
22/// # Type parameter
23///
24/// `D: Driver` — the hardware backend. Architectures are generic over the
25/// driver so they can be monomorphized for each backend (Metal, CUDA, CPU).
26pub trait ModelArch<D: Driver> {
27    /// Run the full forward pass: embeddings -> encoder layers -> pool -> L2 normalize.
28    ///
29    /// Returns one L2-normalized embedding vector per input encoding.
30    ///
31    /// # Errors
32    ///
33    /// Returns an error if any driver operation fails (buffer allocation,
34    /// kernel dispatch, synchronization, etc.).
35    fn forward(&self, driver: &D, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>>;
36}
37
/// Supported model architectures.
///
/// Each variant corresponds to a distinct BERT family with different attention
/// mechanisms, activations, position encodings, and pooling strategies.
///
/// The enum is a plain, fieldless tag: it is `Copy`, comparable, and hashable,
/// so it can be passed by value and used as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ArchVariant {
    /// BGE-small: learned position embeddings, GELU, CLS pooling, bias.
    ClassicBert,
    /// ModernBERT: alternating local/global attention, GeGLU, unpadding.
    ModernBert,
}