// ripvec_core/backend/generic.rs

//! Generic backend that pairs a [`Driver`] with a [`ModelArch`].
//!
//! [`GenericBackend`] implements [`EmbedBackend`] by delegating to the
//! architecture's `forward()` method, which composes driver primitives into
//! the full inference pipeline. This decouples weight loading from the
//! backend interface — any `(Driver, ModelArch)` pair can serve as an
//! embedding backend.
//!
//! The `_mmap` field keeps the memory-mapped safetensors file alive as long
//! as the backend exists, since Metal zero-copy buffers and CPU `MmapTensor`
//! slices reference its pages.
13use super::arch::ModelArch;
14use super::driver::Driver;
15use super::{EmbedBackend, Encoding};
16
17/// Generic backend that pairs a [`Driver`] with a [`ModelArch`].
18///
19/// Implements [`EmbedBackend`] by calling `arch.forward(driver, encodings)`.
20/// The driver provides hardware-specific compute primitives; the architecture
21/// orchestrates them into a full forward pass.
22///
23/// # Lifetime invariant
24///
25/// `_mmap` **must** be declared after `arch` so it is dropped last. The
26/// architecture's weight tensors reference pages in the memory-mapped file
27/// via zero-copy Metal buffers or CPU `MmapTensor::Mapped` slices; dropping
28/// the mmap first would invalidate them.
/// Generic backend that pairs a [`Driver`] with a [`ModelArch`].
///
/// Implements [`EmbedBackend`] by calling `arch.forward(driver, encodings)`.
/// The driver provides hardware-specific compute primitives; the architecture
/// orchestrates them into a full forward pass.
///
/// # Lifetime invariant
///
/// `_mmap` **must** be declared after `arch` so it is dropped last (Rust drops
/// struct fields in declaration order). The architecture's weight tensors
/// reference pages in the memory-mapped file via zero-copy Metal buffers or
/// CPU `MmapTensor::Mapped` slices; dropping the mmap first would invalidate
/// them. Do not reorder these fields.
pub struct GenericBackend<D: Driver, A: ModelArch<D>> {
    /// Hardware compute driver (Metal, CUDA, CPU).
    driver: D,
    /// Model architecture with loaded weights.
    arch: A,
    /// Maximum token count the model supports.
    max_tokens: usize,
    /// Whether this backend runs on a GPU.
    is_gpu: bool,
    /// Maximum encodings per forward pass. Larger batches saturate GPU SMs better
    /// but use more memory. Default: 32 (Metal-tuned). CUDA can handle 128+.
    max_batch: usize,
    /// Keeps the memory-mapped safetensors file alive.
    ///
    /// Must outlive the weight tensors in `arch` — declared last for correct
    /// drop order (see the struct-level lifetime invariant).
    _mmap: MmapHolder,
}
47
48/// Holds a memory-mapped file, accepting either an owned `Mmap` or an
49/// `Arc<Mmap>`. CPU backends share the `Arc` with `MmapTensor::Mapped`
50/// variants; GPU backends pass an owned `Mmap`.
/// Holds a memory-mapped file, accepting either an owned `Mmap` or an
/// `Arc<Mmap>`.
///
/// CPU backends share the `Arc` with `MmapTensor::Mapped` variants; GPU
/// backends pass an owned `Mmap` because their tensors are device-side
/// copies and nothing else needs to co-own the mapping.
pub enum MmapHolder {
    /// Owned mmap (GPU backends — tensors are device copies, not mmap refs).
    Owned(memmap2::Mmap),
    /// Shared mmap (CPU backend — `MmapTensor::Mapped` variants hold `Arc` clones).
    Shared(std::sync::Arc<memmap2::Mmap>),
}
57
58impl<D: Driver, A: ModelArch<D>> GenericBackend<D, A> {
59 /// Create a new generic backend from a driver, architecture, and mmap.
60 ///
61 /// The `mmap` must be the memory-mapped safetensors file whose pages back
62 /// the weight tensors stored in `arch`.
63 ///
64 /// For GPU backends, runs a warm-up forward pass to prime the buffer pool.
65 /// This is skipped for large models (max_tokens > 1024) where the warm-up
66 /// cost exceeds the benefit.
67 ///
68 /// `max_batch` controls how many encodings are sent in each forward pass.
69 /// Metal: 32 (optimal for M2 Max AMX). CUDA: 128+ (needs more work to
70 /// saturate 128 SMs on RTX 4090).
71 pub fn new(driver: D, arch: A, max_tokens: usize, is_gpu: bool, mmap: memmap2::Mmap) -> Self {
72 Self::with_max_batch(
73 driver,
74 arch,
75 max_tokens,
76 is_gpu,
77 MmapHolder::Owned(mmap),
78 32,
79 )
80 }
81
82 /// Create with an `Arc<Mmap>` shared with zero-copy tensors.
83 pub fn new_shared(
84 driver: D,
85 arch: A,
86 max_tokens: usize,
87 is_gpu: bool,
88 mmap: std::sync::Arc<memmap2::Mmap>,
89 ) -> Self {
90 Self::with_max_batch(
91 driver,
92 arch,
93 max_tokens,
94 is_gpu,
95 MmapHolder::Shared(mmap),
96 32,
97 )
98 }
99
100 /// Create with explicit max batch size.
101 #[expect(clippy::cast_possible_wrap, reason = "warmup seq length is small")]
102 pub fn with_max_batch(
103 driver: D,
104 arch: A,
105 max_tokens: usize,
106 is_gpu: bool,
107 mmap: MmapHolder,
108 max_batch: usize,
109 ) -> Self {
110 let backend = Self {
111 driver,
112 arch,
113 max_tokens,
114 is_gpu,
115 max_batch,
116 _mmap: mmap,
117 };
118 // Warm up buffer pool: run a dummy forward to pre-allocate Metal buffers.
119 // Without this, the first real batch pays 160-330 fresh newBufferWithLength
120 // calls. The warm-up fills the pool; subsequent batches with similar
121 // dimensions get exact-match hits (within 8× tolerance).
122 //
123 // Small models (BGE-small, 12L): batch=32 × seq=512, ~80ms.
124 // Large models (ModernBERT, 22L): batch=32 × seq=64, ~300ms.
125 // (Smaller seq keeps cost down; 8× pool tolerance covers seq up to 512.)
126 if is_gpu && max_tokens <= 1024 {
127 let seq = if max_tokens <= 1024 {
128 512.min(max_tokens)
129 } else {
130 64
131 };
132 let mut dummy = Vec::with_capacity(32);
133 for _ in 0..32 {
134 let ids: Vec<i64> = (0..seq as i64).collect();
135 dummy.push(Encoding {
136 input_ids: ids,
137 attention_mask: vec![1; seq],
138 token_type_ids: vec![0; seq],
139 });
140 }
141 let _ = backend.arch.forward(&backend.driver, &dummy);
142 }
143 backend
144 }
145}
146
147impl<D, A> EmbedBackend for GenericBackend<D, A>
148where
149 D: Driver + Send + Sync + 'static,
150 A: ModelArch<D> + Send + Sync + 'static,
151{
152 fn embed_batch(&self, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>> {
153 let max_batch = self.max_batch;
154 if encodings.len() <= max_batch {
155 return self.arch.forward(&self.driver, encodings);
156 }
157 let mut all = Vec::with_capacity(encodings.len());
158 for chunk in encodings.chunks(max_batch) {
159 let mut results = self.arch.forward(&self.driver, chunk)?;
160 all.append(&mut results);
161 }
162 Ok(all)
163 }
164
165 fn supports_clone(&self) -> bool {
166 false
167 }
168
169 fn clone_backend(&self) -> Box<dyn EmbedBackend> {
170 panic!("GenericBackend does not support cloning")
171 }
172
173 fn is_gpu(&self) -> bool {
174 self.is_gpu
175 }
176
177 fn max_tokens(&self) -> usize {
178 self.max_tokens
179 }
180}