// ripvec_core/backend/generic.rs
//! Generic backend that pairs a [`Driver`] with a [`ModelArch`].
//!
//! [`GenericBackend`] implements [`EmbedBackend`] by delegating to the
//! architecture's `forward()` method, which composes driver primitives into
//! the full inference pipeline. This decouples weight loading from the
//! backend interface — any `(Driver, ModelArch)` pair can serve as an
//! embedding backend.
//!
//! The `_mmap` field keeps the memory-mapped safetensors file alive as long
//! as the backend exists, since Metal zero-copy buffers reference its pages.
use super::arch::ModelArch;
use super::driver::Driver;
use super::{EmbedBackend, Encoding};
15
/// Generic backend that pairs a [`Driver`] with a [`ModelArch`].
///
/// Implements [`EmbedBackend`] by calling `arch.forward(driver, encodings)`.
/// The driver provides hardware-specific compute primitives; the architecture
/// orchestrates them into a full forward pass.
///
/// # Lifetime invariant
///
/// `_mmap` **must** be declared after `arch` so it is dropped last (struct
/// fields drop in declaration order). The architecture's weight tensors
/// reference pages in the memory-mapped file via zero-copy Metal buffers;
/// dropping the mmap first would invalidate them.
pub struct GenericBackend<D: Driver, A: ModelArch<D>> {
    /// Hardware compute driver (Metal, CUDA, CPU).
    driver: D,
    /// Model architecture with loaded weights.
    arch: A,
    /// Maximum token count the model supports.
    max_tokens: usize,
    /// Whether this backend runs on a GPU.
    is_gpu: bool,
    /// Maximum encodings per forward pass. Larger batches saturate GPU SMs better
    /// but use more memory. Default: 32 (Metal-tuned). CUDA can handle 128+.
    max_batch: usize,
    /// Keeps the memory-mapped safetensors file alive.
    ///
    /// Must outlive the weight tensors in `arch` — declared last for correct
    /// drop order (see the struct-level lifetime invariant).
    _mmap: memmap2::Mmap,
}
45
46impl<D: Driver, A: ModelArch<D>> GenericBackend<D, A> {
47    /// Create a new generic backend from a driver, architecture, and mmap.
48    ///
49    /// The `mmap` must be the memory-mapped safetensors file whose pages back
50    /// the weight tensors stored in `arch`.
51    /// Create a new generic backend.
52    ///
53    /// For GPU backends, runs a warm-up forward pass to prime the buffer pool.
54    /// This is skipped for large models (max_tokens > 1024) where the warm-up
55    /// cost exceeds the benefit.
56    /// Create a new generic backend.
57    ///
58    /// `max_batch` controls how many encodings are sent in each forward pass.
59    /// Metal: 32 (optimal for M2 Max AMX). CUDA: 128+ (needs more work to
60    /// saturate 128 SMs on RTX 4090).
61    pub fn new(driver: D, arch: A, max_tokens: usize, is_gpu: bool, mmap: memmap2::Mmap) -> Self {
62        Self::with_max_batch(driver, arch, max_tokens, is_gpu, mmap, 32)
63    }
64
65    /// Create with explicit max batch size.
66    #[expect(clippy::cast_possible_wrap, reason = "warmup seq length is small")]
67    pub fn with_max_batch(
68        driver: D,
69        arch: A,
70        max_tokens: usize,
71        is_gpu: bool,
72        mmap: memmap2::Mmap,
73        max_batch: usize,
74    ) -> Self {
75        let backend = Self {
76            driver,
77            arch,
78            max_tokens,
79            is_gpu,
80            max_batch,
81            _mmap: mmap,
82        };
83        // Warm up buffer pool: run a dummy forward to pre-allocate Metal buffers.
84        // Without this, the first real batch pays 160-330 fresh newBufferWithLength
85        // calls. The warm-up fills the pool; subsequent batches with similar
86        // dimensions get exact-match hits (within 8× tolerance).
87        //
88        // Small models (BGE-small, 12L): batch=32 × seq=512, ~80ms.
89        // Large models (ModernBERT, 22L): batch=32 × seq=64, ~300ms.
90        //   (Smaller seq keeps cost down; 8× pool tolerance covers seq up to 512.)
91        if is_gpu && max_tokens <= 1024 {
92            let seq = if max_tokens <= 1024 {
93                512.min(max_tokens)
94            } else {
95                64
96            };
97            let mut dummy = Vec::with_capacity(32);
98            for _ in 0..32 {
99                let ids: Vec<i64> = (0..seq as i64).collect();
100                dummy.push(Encoding {
101                    input_ids: ids,
102                    attention_mask: vec![1; seq],
103                    token_type_ids: vec![0; seq],
104                });
105            }
106            let _ = backend.arch.forward(&backend.driver, &dummy);
107        }
108        backend
109    }
110}
111
112impl<D, A> EmbedBackend for GenericBackend<D, A>
113where
114    D: Driver + Send + Sync + 'static,
115    A: ModelArch<D> + Send + Sync + 'static,
116{
117    fn embed_batch(&self, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>> {
118        let max_batch = self.max_batch;
119        if encodings.len() <= max_batch {
120            return self.arch.forward(&self.driver, encodings);
121        }
122        let mut all = Vec::with_capacity(encodings.len());
123        for chunk in encodings.chunks(max_batch) {
124            let mut results = self.arch.forward(&self.driver, chunk)?;
125            all.append(&mut results);
126        }
127        Ok(all)
128    }
129
130    fn supports_clone(&self) -> bool {
131        false
132    }
133
134    fn clone_backend(&self) -> Box<dyn EmbedBackend> {
135        panic!("GenericBackend does not support cloning")
136    }
137
138    fn is_gpu(&self) -> bool {
139        self.is_gpu
140    }
141
142    fn max_tokens(&self) -> usize {
143        self.max_tokens
144    }
145}