// ripvec_core/backend/generic.rs
//! Generic backend that pairs a [`Driver`] with a [`ModelArch`].
//!
//! [`GenericBackend`] implements [`EmbedBackend`] by delegating to the
//! architecture's `forward()` method, which composes driver primitives into
//! the full inference pipeline. This decouples weight loading from the
//! backend interface — any `(Driver, ModelArch)` pair can serve as an
//! embedding backend.
//!
//! The `_mmap` field keeps the memory-mapped safetensors file alive as long
//! as the backend exists, since Metal zero-copy buffers and CPU `MmapTensor`
//! slices reference its pages.

use super::arch::ModelArch;
use super::driver::Driver;
use super::{EmbedBackend, Encoding};
16
17/// Generic backend that pairs a [`Driver`] with a [`ModelArch`].
18///
19/// Implements [`EmbedBackend`] by calling `arch.forward(driver, encodings)`.
20/// The driver provides hardware-specific compute primitives; the architecture
21/// orchestrates them into a full forward pass.
22///
23/// # Lifetime invariant
24///
25/// `_mmap` **must** be declared after `arch` so it is dropped last. The
26/// architecture's weight tensors reference pages in the memory-mapped file
27/// via zero-copy Metal buffers or CPU `MmapTensor::Mapped` slices; dropping
28/// the mmap first would invalidate them.
29pub struct GenericBackend<D: Driver, A: ModelArch<D>> {
30    /// Hardware compute driver (Metal, CUDA, CPU).
31    driver: D,
32    /// Model architecture with loaded weights.
33    arch: A,
34    /// Maximum token count the model supports.
35    max_tokens: usize,
36    /// Whether this backend runs on a GPU.
37    is_gpu: bool,
38    /// Maximum encodings per forward pass. Larger batches saturate GPU SMs better
39    /// but use more memory. Default: 32 (Metal-tuned). CUDA can handle 128+.
40    max_batch: usize,
41    /// Keeps the memory-mapped safetensors file alive.
42    ///
43    /// Must outlive the weight tensors in `arch` — declared last for correct
44    /// drop order.
45    _mmap: MmapHolder,
46}
47
48/// Holds a memory-mapped file, accepting either an owned `Mmap` or an
49/// `Arc<Mmap>`. CPU backends share the `Arc` with `MmapTensor::Mapped`
50/// variants; GPU backends pass an owned `Mmap`.
51pub enum MmapHolder {
52    /// Owned mmap (GPU backends — tensors are device copies, not mmap refs).
53    Owned(memmap2::Mmap),
54    /// Shared mmap (CPU backend — `MmapTensor::Mapped` variants hold `Arc` clones).
55    Shared(std::sync::Arc<memmap2::Mmap>),
56}
57
58impl<D: Driver, A: ModelArch<D>> GenericBackend<D, A> {
59    /// Create a new generic backend from a driver, architecture, and mmap.
60    ///
61    /// The `mmap` must be the memory-mapped safetensors file whose pages back
62    /// the weight tensors stored in `arch`.
63    ///
64    /// For GPU backends, runs a warm-up forward pass to prime the buffer pool.
65    /// This is skipped for large models (max_tokens > 1024) where the warm-up
66    /// cost exceeds the benefit.
67    ///
68    /// `max_batch` controls how many encodings are sent in each forward pass.
69    /// Metal: 32 (optimal for M2 Max AMX). CUDA: 128+ (needs more work to
70    /// saturate 128 SMs on RTX 4090).
71    pub fn new(driver: D, arch: A, max_tokens: usize, is_gpu: bool, mmap: memmap2::Mmap) -> Self {
72        Self::with_max_batch(
73            driver,
74            arch,
75            max_tokens,
76            is_gpu,
77            MmapHolder::Owned(mmap),
78            32,
79        )
80    }
81
82    /// Create with an `Arc<Mmap>` shared with zero-copy tensors.
83    pub fn new_shared(
84        driver: D,
85        arch: A,
86        max_tokens: usize,
87        is_gpu: bool,
88        mmap: std::sync::Arc<memmap2::Mmap>,
89    ) -> Self {
90        Self::with_max_batch(
91            driver,
92            arch,
93            max_tokens,
94            is_gpu,
95            MmapHolder::Shared(mmap),
96            32,
97        )
98    }
99
100    /// Create with explicit max batch size.
101    #[expect(clippy::cast_possible_wrap, reason = "warmup seq length is small")]
102    pub fn with_max_batch(
103        driver: D,
104        arch: A,
105        max_tokens: usize,
106        is_gpu: bool,
107        mmap: MmapHolder,
108        max_batch: usize,
109    ) -> Self {
110        let backend = Self {
111            driver,
112            arch,
113            max_tokens,
114            is_gpu,
115            max_batch,
116            _mmap: mmap,
117        };
118        // Warm up buffer pool: run a dummy forward to pre-allocate Metal buffers.
119        // Without this, the first real batch pays 160-330 fresh newBufferWithLength
120        // calls. The warm-up fills the pool; subsequent batches with similar
121        // dimensions get exact-match hits (within 8× tolerance).
122        //
123        // Small models (BGE-small, 12L): batch=32 × seq=512, ~80ms.
124        // Large models (ModernBERT, 22L): batch=32 × seq=64, ~300ms.
125        //   (Smaller seq keeps cost down; 8× pool tolerance covers seq up to 512.)
126        if is_gpu && max_tokens <= 1024 {
127            let seq = if max_tokens <= 1024 {
128                512.min(max_tokens)
129            } else {
130                64
131            };
132            let mut dummy = Vec::with_capacity(32);
133            for _ in 0..32 {
134                let ids: Vec<i64> = (0..seq as i64).collect();
135                dummy.push(Encoding {
136                    input_ids: ids,
137                    attention_mask: vec![1; seq],
138                    token_type_ids: vec![0; seq],
139                });
140            }
141            let _ = backend.arch.forward(&backend.driver, &dummy);
142        }
143        backend
144    }
145}
146
147impl<D, A> EmbedBackend for GenericBackend<D, A>
148where
149    D: Driver + Send + Sync + 'static,
150    A: ModelArch<D> + Send + Sync + 'static,
151{
152    fn embed_batch(&self, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>> {
153        let max_batch = self.max_batch;
154        if encodings.len() <= max_batch {
155            return self.arch.forward(&self.driver, encodings);
156        }
157        let mut all = Vec::with_capacity(encodings.len());
158        for chunk in encodings.chunks(max_batch) {
159            let mut results = self.arch.forward(&self.driver, chunk)?;
160            all.append(&mut results);
161        }
162        Ok(all)
163    }
164
165    fn supports_clone(&self) -> bool {
166        false
167    }
168
169    fn clone_backend(&self) -> Box<dyn EmbedBackend> {
170        panic!("GenericBackend does not support cloning")
171    }
172
173    fn is_gpu(&self) -> bool {
174        self.is_gpu
175    }
176
177    fn max_tokens(&self) -> usize {
178        self.max_tokens
179    }
180}