Skip to main content

rlx_embed/
registry.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Model registry with metadata for all supported text and image embedding models.
17
18use super::pooling::Pooling;
19use std::collections::HashMap;
20use std::sync::OnceLock;
21
/// Text embedding models known to the registry.
///
/// Every variant identifies one HuggingFace repository whose pre-trained
/// weights are distributed as safetensors. The per-model metadata
/// (dimension, pooling strategy, max sequence length) is obtained through
/// [`EmbeddingModel::get_info`]. The default variant is
/// [`EmbeddingModel::AllMiniLML6V2`].
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub enum EmbeddingModel {
    // ── MiniLM / MPNet ──────────────────────────────────────────────────────
    /// `sentence-transformers/all-MiniLM-L6-v2`: 6 layers, 384-dim output, mean pooling.
    #[default]
    AllMiniLML6V2,
    /// `sentence-transformers/all-MiniLM-L12-v2`: 12 layers, 384-dim output, mean pooling.
    AllMiniLML12V2,
    /// `sentence-transformers/all-mpnet-base-v2`: 12 layers, 768-dim output, mean pooling.
    /// This is an mpnet-architecture checkpoint, so its attention keys are named differently.
    AllMpnetBaseV2,

    // ── BGE ─────────────────────────────────────────────────────────────────
    /// `BAAI/bge-small-en-v1.5`: 12 layers, 384-dim output, CLS pooling.
    BGESmallENV15,
    /// `BAAI/bge-base-en-v1.5`: 12 layers, 768-dim output, CLS pooling.
    BGEBaseENV15,
    /// `BAAI/bge-large-en-v1.5`: 24 layers, 1024-dim output, CLS pooling.
    BGELargeENV15,
    /// `BAAI/bge-small-zh-v1.5`: 12 layers, 512-dim output, CLS pooling.
    BGESmallZHV15,
    /// `BAAI/bge-large-zh-v1.5`: 24 layers, 1024-dim output, CLS pooling.
    BGELargeZHV15,

    // ── Multilingual E5 ─────────────────────────────────────────────────────
    /// `intfloat/multilingual-e5-small`: 12 layers, 384-dim output, mean pooling.
    MultilingualE5Small,
    /// `intfloat/multilingual-e5-base`: 12 layers, 768-dim output, mean pooling.
    MultilingualE5Base,
    /// `intfloat/multilingual-e5-large`: 24 layers, 1024-dim output, mean pooling.
    MultilingualE5Large,

    // ── Paraphrase ──────────────────────────────────────────────────────────
    /// `sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2`: 12 layers,
    /// 384-dim output, mean pooling.
    ParaphraseMLMiniLML12V2,
    /// `sentence-transformers/paraphrase-multilingual-mpnet-base-v2`: 12 layers,
    /// 768-dim output, mean pooling.
    ParaphraseMLMpnetBaseV2,

    // ── Snowflake Arctic ────────────────────────────────────────────────────
    /// `snowflake/snowflake-arctic-embed-xs`: 384-dim output, CLS pooling.
    SnowflakeArcticEmbedXS,
    /// `snowflake/snowflake-arctic-embed-s`: 384-dim output, CLS pooling.
    SnowflakeArcticEmbedS,
    /// `Snowflake/snowflake-arctic-embed-m`: 768-dim output, CLS pooling.
    SnowflakeArcticEmbedM,
    /// `snowflake/snowflake-arctic-embed-l`: 1024-dim output, CLS pooling.
    SnowflakeArcticEmbedL,

    // ── MxBai ───────────────────────────────────────────────────────────────
    /// `mixedbread-ai/mxbai-embed-large-v1`: 1024-dim output, CLS pooling.
    MxbaiEmbedLargeV1,

    // ── Nomic ───────────────────────────────────────────────────────────────
    /// `nomic-ai/nomic-embed-text-v1.5`: 12 layers, 768-dim output, mean pooling,
    /// with RoPE and SwiGLU.
    NomicEmbedTextV15,
}
83
/// Image embedding models known to the registry.
///
/// The default variant is [`ImageEmbeddingModel::NomicEmbedVisionV15`],
/// currently the only one.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub enum ImageEmbeddingModel {
    /// `nomic-ai/nomic-embed-vision-v1.5`: 12 layers, 768-dim output, CLS pooling, 224px input.
    #[default]
    NomicEmbedVisionV15,
}
91
92/// Metadata for an image embedding model.
93#[derive(Debug, Clone)]
94pub struct ImageModelInfo {
95    pub model: ImageEmbeddingModel,
96    pub dim: usize,
97    pub description: &'static str,
98    pub hf_repo: &'static str,
99    pub model_file: &'static str,
100    pub img_size: usize,
101}
102
103static IMAGE_MODEL_MAP: OnceLock<HashMap<ImageEmbeddingModel, ImageModelInfo>> = OnceLock::new();
104
105fn init_image_models_map() -> HashMap<ImageEmbeddingModel, ImageModelInfo> {
106    vec![ImageModelInfo {
107        model: ImageEmbeddingModel::NomicEmbedVisionV15,
108        dim: 768,
109        description: "Nomic embed vision v1.5, 12 layers, 224px",
110        hf_repo: "nomic-ai/nomic-embed-vision-v1.5",
111        model_file: "model.safetensors",
112        img_size: 224,
113    }]
114    .into_iter()
115    .map(|info| (info.model.clone(), info))
116    .collect()
117}
118
119impl ImageEmbeddingModel {
120    pub fn get_info(&self) -> Option<&'static ImageModelInfo> {
121        IMAGE_MODEL_MAP.get_or_init(init_image_models_map).get(self)
122    }
123}
124
/// Architecture family of a text embedding model.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ModelArch {
    /// Standard BERT: absolute position embeddings and a GELU feed-forward network.
    Bert,
    /// NomicBERT: RoPE, fused QKV, and a SwiGLU feed-forward network.
    NomicBert,
}
133
134/// Metadata for an embedding model.
135#[derive(Debug, Clone)]
136pub struct ModelInfo {
137    pub model: EmbeddingModel,
138    /// Output embedding dimension.
139    pub dim: usize,
140    /// Human-readable description.
141    pub description: &'static str,
142    /// HuggingFace repository ID.
143    pub hf_repo: &'static str,
144    /// Weight file name within the repository.
145    pub model_file: &'static str,
146    /// Pooling strategy for this model.
147    pub pooling: Pooling,
148    /// Maximum input sequence length.
149    pub max_length: usize,
150    /// Model architecture.
151    pub arch: ModelArch,
152}
153
154static MODEL_MAP: OnceLock<HashMap<EmbeddingModel, ModelInfo>> = OnceLock::new();
155
156fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo> {
157    vec![
158        ModelInfo {
159            model: EmbeddingModel::AllMiniLML6V2,
160            dim: 384,
161            description: "MiniLM-L6-v2, 6 layers, fast and lightweight",
162            hf_repo: "sentence-transformers/all-MiniLM-L6-v2",
163            model_file: "model.safetensors",
164            pooling: Pooling::Mean,
165            max_length: 256,
166            arch: ModelArch::Bert,
167        },
168        ModelInfo {
169            model: EmbeddingModel::AllMiniLML12V2,
170            dim: 384,
171            description: "MiniLM-L12-v2, 12 layers, higher quality",
172            hf_repo: "sentence-transformers/all-MiniLM-L12-v2",
173            model_file: "model.safetensors",
174            pooling: Pooling::Mean,
175            max_length: 256,
176            arch: ModelArch::Bert,
177        },
178        ModelInfo {
179            model: EmbeddingModel::BGESmallENV15,
180            dim: 384,
181            description: "BGE small English v1.5, compact and fast",
182            hf_repo: "BAAI/bge-small-en-v1.5",
183            model_file: "model.safetensors",
184            pooling: Pooling::Cls,
185            max_length: 512,
186            arch: ModelArch::Bert,
187        },
188        ModelInfo {
189            model: EmbeddingModel::BGEBaseENV15,
190            dim: 768,
191            description: "BGE base English v1.5, balanced quality and speed",
192            hf_repo: "BAAI/bge-base-en-v1.5",
193            model_file: "model.safetensors",
194            pooling: Pooling::Cls,
195            max_length: 512,
196            arch: ModelArch::Bert,
197        },
198        ModelInfo {
199            model: EmbeddingModel::BGELargeENV15,
200            dim: 1024,
201            description: "BGE large English v1.5, highest quality",
202            hf_repo: "BAAI/bge-large-en-v1.5",
203            model_file: "model.safetensors",
204            pooling: Pooling::Cls,
205            max_length: 512,
206            arch: ModelArch::Bert,
207        },
208        ModelInfo {
209            model: EmbeddingModel::BGESmallZHV15,
210            dim: 512,
211            description: "BGE small Chinese v1.5, CLS pooling",
212            hf_repo: "BAAI/bge-small-zh-v1.5",
213            model_file: "model.safetensors",
214            pooling: Pooling::Cls,
215            max_length: 512,
216            arch: ModelArch::Bert,
217        },
218        ModelInfo {
219            model: EmbeddingModel::ParaphraseMLMiniLML12V2,
220            dim: 384,
221            description: "Paraphrase multilingual MiniLM L12 v2, mean pooling",
222            hf_repo: "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
223            model_file: "model.safetensors",
224            pooling: Pooling::Mean,
225            max_length: 128,
226            arch: ModelArch::Bert,
227        },
228        // mpnet
229        ModelInfo {
230            model: EmbeddingModel::AllMpnetBaseV2,
231            dim: 768,
232            description: "mpnet-base-v2, strong general-purpose embeddings",
233            hf_repo: "sentence-transformers/all-mpnet-base-v2",
234            model_file: "model.safetensors",
235            pooling: Pooling::Mean,
236            max_length: 384,
237            arch: ModelArch::Bert,
238        },
239        // BGE Chinese large
240        ModelInfo {
241            model: EmbeddingModel::BGELargeZHV15,
242            dim: 1024,
243            description: "BGE large Chinese v1.5",
244            hf_repo: "BAAI/bge-large-zh-v1.5",
245            model_file: "model.safetensors",
246            pooling: Pooling::Cls,
247            max_length: 512,
248            arch: ModelArch::Bert,
249        },
250        // Multilingual E5
251        ModelInfo {
252            model: EmbeddingModel::MultilingualE5Small,
253            dim: 384,
254            description: "Multilingual E5 small, 100+ languages",
255            hf_repo: "intfloat/multilingual-e5-small",
256            model_file: "model.safetensors",
257            pooling: Pooling::Mean,
258            max_length: 512,
259            arch: ModelArch::Bert,
260        },
261        ModelInfo {
262            model: EmbeddingModel::MultilingualE5Base,
263            dim: 768,
264            description: "Multilingual E5 base, 100+ languages",
265            hf_repo: "intfloat/multilingual-e5-base",
266            model_file: "model.safetensors",
267            pooling: Pooling::Mean,
268            max_length: 512,
269            arch: ModelArch::Bert,
270        },
271        ModelInfo {
272            model: EmbeddingModel::MultilingualE5Large,
273            dim: 1024,
274            description: "Multilingual E5 large, 100+ languages",
275            hf_repo: "intfloat/multilingual-e5-large",
276            model_file: "model.safetensors",
277            pooling: Pooling::Mean,
278            max_length: 512,
279            arch: ModelArch::Bert,
280        },
281        // Paraphrase mpnet
282        ModelInfo {
283            model: EmbeddingModel::ParaphraseMLMpnetBaseV2,
284            dim: 768,
285            description: "Paraphrase multilingual mpnet base v2",
286            hf_repo: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
287            model_file: "model.safetensors",
288            pooling: Pooling::Mean,
289            max_length: 384,
290            arch: ModelArch::Bert,
291        },
292        // Snowflake Arctic
293        ModelInfo {
294            model: EmbeddingModel::SnowflakeArcticEmbedXS,
295            dim: 384,
296            description: "Snowflake Arctic Embed XS",
297            hf_repo: "snowflake/snowflake-arctic-embed-xs",
298            model_file: "model.safetensors",
299            pooling: Pooling::Cls,
300            max_length: 512,
301            arch: ModelArch::Bert,
302        },
303        ModelInfo {
304            model: EmbeddingModel::SnowflakeArcticEmbedS,
305            dim: 384,
306            description: "Snowflake Arctic Embed S",
307            hf_repo: "snowflake/snowflake-arctic-embed-s",
308            model_file: "model.safetensors",
309            pooling: Pooling::Cls,
310            max_length: 512,
311            arch: ModelArch::Bert,
312        },
313        ModelInfo {
314            model: EmbeddingModel::SnowflakeArcticEmbedM,
315            dim: 768,
316            description: "Snowflake Arctic Embed M",
317            hf_repo: "Snowflake/snowflake-arctic-embed-m",
318            model_file: "model.safetensors",
319            pooling: Pooling::Cls,
320            max_length: 512,
321            arch: ModelArch::Bert,
322        },
323        ModelInfo {
324            model: EmbeddingModel::SnowflakeArcticEmbedL,
325            dim: 1024,
326            description: "Snowflake Arctic Embed L",
327            hf_repo: "snowflake/snowflake-arctic-embed-l",
328            model_file: "model.safetensors",
329            pooling: Pooling::Cls,
330            max_length: 512,
331            arch: ModelArch::Bert,
332        },
333        // MxBai
334        ModelInfo {
335            model: EmbeddingModel::MxbaiEmbedLargeV1,
336            dim: 1024,
337            description: "MxBai embed large v1",
338            hf_repo: "mixedbread-ai/mxbai-embed-large-v1",
339            model_file: "model.safetensors",
340            pooling: Pooling::Cls,
341            max_length: 512,
342            arch: ModelArch::Bert,
343        },
344        // Nomic
345        ModelInfo {
346            model: EmbeddingModel::NomicEmbedTextV15,
347            dim: 768,
348            description: "Nomic embed text v1.5, RoPE, SwiGLU, 8192 context",
349            hf_repo: "nomic-ai/nomic-embed-text-v1.5",
350            model_file: "model.safetensors",
351            pooling: Pooling::Mean,
352            max_length: 8192,
353            arch: ModelArch::NomicBert,
354        },
355    ]
356    .into_iter()
357    .map(|info| (info.model.clone(), info))
358    .collect()
359}
360
361/// Get the global model registry.
362pub fn models_map() -> &'static HashMap<EmbeddingModel, ModelInfo> {
363    MODEL_MAP.get_or_init(init_models_map)
364}
365
366impl EmbeddingModel {
367    /// Look up metadata for this model.
368    pub fn get_info(&self) -> Option<&'static ModelInfo> {
369        models_map().get(self)
370    }
371
372    /// List all supported models.
373    pub fn list_supported() -> Vec<&'static ModelInfo> {
374        models_map().values().collect()
375    }
376}