1use super::pooling::Pooling;
19use std::collections::HashMap;
20use std::sync::OnceLock;
21
/// Identifiers for the text embedding models registered in `init_models_map`.
///
/// Each variant is the key of one `ModelInfo` entry; the repository names and
/// dimensions quoted below are the values stored in that registry.
/// The `#[default]` variant is [`EmbeddingModel::AllMiniLML6V2`].
#[derive(Default, Debug, Clone, PartialEq, Eq, Hash)]
pub enum EmbeddingModel {
    /// `sentence-transformers/all-MiniLM-L6-v2`, registered with `dim: 384`.
    #[default]
    AllMiniLML6V2,
    /// `sentence-transformers/all-MiniLM-L12-v2`, registered with `dim: 384`.
    AllMiniLML12V2,
    /// `sentence-transformers/all-mpnet-base-v2`, registered with `dim: 768`.
    AllMpnetBaseV2,

    /// `BAAI/bge-small-en-v1.5`, registered with `dim: 384`.
    BGESmallENV15,
    /// `BAAI/bge-base-en-v1.5`, registered with `dim: 768`.
    BGEBaseENV15,
    /// `BAAI/bge-large-en-v1.5`, registered with `dim: 1024`.
    BGELargeENV15,
    /// `BAAI/bge-small-zh-v1.5`, registered with `dim: 512`.
    BGESmallZHV15,
    /// `BAAI/bge-large-zh-v1.5`, registered with `dim: 1024`.
    BGELargeZHV15,

    /// `intfloat/multilingual-e5-small`, registered with `dim: 384`.
    MultilingualE5Small,
    /// `intfloat/multilingual-e5-base`, registered with `dim: 768`.
    MultilingualE5Base,
    /// `intfloat/multilingual-e5-large`, registered with `dim: 1024`.
    MultilingualE5Large,

    /// `sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2`, registered with `dim: 384`.
    ParaphraseMLMiniLML12V2,
    /// `sentence-transformers/paraphrase-multilingual-mpnet-base-v2`, registered with `dim: 768`.
    ParaphraseMLMpnetBaseV2,

    /// `snowflake/snowflake-arctic-embed-xs`, registered with `dim: 384`.
    SnowflakeArcticEmbedXS,
    /// `snowflake/snowflake-arctic-embed-s`, registered with `dim: 384`.
    SnowflakeArcticEmbedS,
    /// `Snowflake/snowflake-arctic-embed-m`, registered with `dim: 768`.
    SnowflakeArcticEmbedM,
    /// `snowflake/snowflake-arctic-embed-l`, registered with `dim: 1024`.
    SnowflakeArcticEmbedL,

    /// `mixedbread-ai/mxbai-embed-large-v1`, registered with `dim: 1024`.
    MxbaiEmbedLargeV1,

    /// `nomic-ai/nomic-embed-text-v1.5`, registered with `dim: 768`; the only
    /// entry registered with `ModelArch::NomicBert`.
    NomicEmbedTextV15,
}
83
/// Identifiers for the image embedding models registered in
/// `init_image_models_map`.
#[derive(Default, Debug, Clone, PartialEq, Eq, Hash)]
pub enum ImageEmbeddingModel {
    /// `nomic-ai/nomic-embed-vision-v1.5`, registered with `dim: 768` and
    /// `img_size: 224`. This is the `#[default]` (and currently only) variant.
    #[default]
    NomicEmbedVisionV15,
}
91
/// Static metadata describing one image embedding model.
///
/// Instances are created in `init_image_models_map` and handed out by
/// `ImageEmbeddingModel::get_info`.
#[derive(Debug, Clone)]
pub struct ImageModelInfo {
    /// The enum variant this record describes; also used as the registry key.
    pub model: ImageEmbeddingModel,
    /// Presumably the length of the produced embedding vector — confirm
    /// against the model-loading code.
    pub dim: usize,
    /// Short human-readable summary of the model.
    pub description: &'static str,
    /// Repository identifier in `owner/name` form (presumably on the
    /// Hugging Face Hub, going by the field name — confirm).
    pub hf_repo: &'static str,
    /// Name of the weights file, e.g. `model.safetensors`.
    pub model_file: &'static str,
    /// Presumably the input image edge length in pixels (the only entry
    /// pairs `224` with a "224px" description) — confirm against the
    /// preprocessing code.
    pub img_size: usize,
}
102
/// Process-wide image model registry. It starts empty and is filled on first
/// use by `ImageEmbeddingModel::get_info`, which calls
/// `get_or_init(init_image_models_map)`.
static IMAGE_MODEL_MAP: OnceLock<HashMap<ImageEmbeddingModel, ImageModelInfo>> = OnceLock::new();
104
105fn init_image_models_map() -> HashMap<ImageEmbeddingModel, ImageModelInfo> {
106 vec![ImageModelInfo {
107 model: ImageEmbeddingModel::NomicEmbedVisionV15,
108 dim: 768,
109 description: "Nomic embed vision v1.5, 12 layers, 224px",
110 hf_repo: "nomic-ai/nomic-embed-vision-v1.5",
111 model_file: "model.safetensors",
112 img_size: 224,
113 }]
114 .into_iter()
115 .map(|info| (info.model.clone(), info))
116 .collect()
117}
118
119impl ImageEmbeddingModel {
120 pub fn get_info(&self) -> Option<&'static ImageModelInfo> {
121 IMAGE_MODEL_MAP.get_or_init(init_image_models_map).get(self)
122 }
123}
124
/// Architecture tag stored in each `ModelInfo`.
///
/// How the tag is consumed is not visible in this file; presumably it selects
/// the model implementation to load — confirm against the loader.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ModelArch {
    /// Used by every registry entry except `EmbeddingModel::NomicEmbedTextV15`.
    Bert,
    /// Used only by the `EmbeddingModel::NomicEmbedTextV15` registry entry.
    NomicBert,
}
133
/// Static metadata describing one text embedding model.
///
/// Instances are created in `init_models_map` and handed out by
/// `EmbeddingModel::get_info` and `EmbeddingModel::list_supported`.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// The enum variant this record describes; also used as the registry key.
    pub model: EmbeddingModel,
    /// Presumably the length of the produced embedding vector — confirm
    /// against the model-loading code.
    pub dim: usize,
    /// Short human-readable summary of the model.
    pub description: &'static str,
    /// Repository identifier in `owner/name` form (presumably on the
    /// Hugging Face Hub, going by the field name — confirm).
    pub hf_repo: &'static str,
    /// Name of the weights file, e.g. `model.safetensors`.
    pub model_file: &'static str,
    /// Pooling strategy recorded for the model (`Pooling::Mean` or
    /// `Pooling::Cls` in the current registry).
    pub pooling: Pooling,
    /// Presumably the maximum input sequence length in tokens (the Nomic
    /// entry pairs `8192` with an "8192 context" description) — confirm.
    pub max_length: usize,
    /// Architecture tag for the model.
    pub arch: ModelArch,
}
153
/// Process-wide text model registry. It starts empty and is filled on first
/// use by `models_map`, which calls `get_or_init(init_models_map)`.
static MODEL_MAP: OnceLock<HashMap<EmbeddingModel, ModelInfo>> = OnceLock::new();
155
156fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo> {
157 vec![
158 ModelInfo {
159 model: EmbeddingModel::AllMiniLML6V2,
160 dim: 384,
161 description: "MiniLM-L6-v2, 6 layers, fast and lightweight",
162 hf_repo: "sentence-transformers/all-MiniLM-L6-v2",
163 model_file: "model.safetensors",
164 pooling: Pooling::Mean,
165 max_length: 256,
166 arch: ModelArch::Bert,
167 },
168 ModelInfo {
169 model: EmbeddingModel::AllMiniLML12V2,
170 dim: 384,
171 description: "MiniLM-L12-v2, 12 layers, higher quality",
172 hf_repo: "sentence-transformers/all-MiniLM-L12-v2",
173 model_file: "model.safetensors",
174 pooling: Pooling::Mean,
175 max_length: 256,
176 arch: ModelArch::Bert,
177 },
178 ModelInfo {
179 model: EmbeddingModel::BGESmallENV15,
180 dim: 384,
181 description: "BGE small English v1.5, compact and fast",
182 hf_repo: "BAAI/bge-small-en-v1.5",
183 model_file: "model.safetensors",
184 pooling: Pooling::Cls,
185 max_length: 512,
186 arch: ModelArch::Bert,
187 },
188 ModelInfo {
189 model: EmbeddingModel::BGEBaseENV15,
190 dim: 768,
191 description: "BGE base English v1.5, balanced quality and speed",
192 hf_repo: "BAAI/bge-base-en-v1.5",
193 model_file: "model.safetensors",
194 pooling: Pooling::Cls,
195 max_length: 512,
196 arch: ModelArch::Bert,
197 },
198 ModelInfo {
199 model: EmbeddingModel::BGELargeENV15,
200 dim: 1024,
201 description: "BGE large English v1.5, highest quality",
202 hf_repo: "BAAI/bge-large-en-v1.5",
203 model_file: "model.safetensors",
204 pooling: Pooling::Cls,
205 max_length: 512,
206 arch: ModelArch::Bert,
207 },
208 ModelInfo {
209 model: EmbeddingModel::BGESmallZHV15,
210 dim: 512,
211 description: "BGE small Chinese v1.5, CLS pooling",
212 hf_repo: "BAAI/bge-small-zh-v1.5",
213 model_file: "model.safetensors",
214 pooling: Pooling::Cls,
215 max_length: 512,
216 arch: ModelArch::Bert,
217 },
218 ModelInfo {
219 model: EmbeddingModel::ParaphraseMLMiniLML12V2,
220 dim: 384,
221 description: "Paraphrase multilingual MiniLM L12 v2, mean pooling",
222 hf_repo: "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
223 model_file: "model.safetensors",
224 pooling: Pooling::Mean,
225 max_length: 128,
226 arch: ModelArch::Bert,
227 },
228 ModelInfo {
230 model: EmbeddingModel::AllMpnetBaseV2,
231 dim: 768,
232 description: "mpnet-base-v2, strong general-purpose embeddings",
233 hf_repo: "sentence-transformers/all-mpnet-base-v2",
234 model_file: "model.safetensors",
235 pooling: Pooling::Mean,
236 max_length: 384,
237 arch: ModelArch::Bert,
238 },
239 ModelInfo {
241 model: EmbeddingModel::BGELargeZHV15,
242 dim: 1024,
243 description: "BGE large Chinese v1.5",
244 hf_repo: "BAAI/bge-large-zh-v1.5",
245 model_file: "model.safetensors",
246 pooling: Pooling::Cls,
247 max_length: 512,
248 arch: ModelArch::Bert,
249 },
250 ModelInfo {
252 model: EmbeddingModel::MultilingualE5Small,
253 dim: 384,
254 description: "Multilingual E5 small, 100+ languages",
255 hf_repo: "intfloat/multilingual-e5-small",
256 model_file: "model.safetensors",
257 pooling: Pooling::Mean,
258 max_length: 512,
259 arch: ModelArch::Bert,
260 },
261 ModelInfo {
262 model: EmbeddingModel::MultilingualE5Base,
263 dim: 768,
264 description: "Multilingual E5 base, 100+ languages",
265 hf_repo: "intfloat/multilingual-e5-base",
266 model_file: "model.safetensors",
267 pooling: Pooling::Mean,
268 max_length: 512,
269 arch: ModelArch::Bert,
270 },
271 ModelInfo {
272 model: EmbeddingModel::MultilingualE5Large,
273 dim: 1024,
274 description: "Multilingual E5 large, 100+ languages",
275 hf_repo: "intfloat/multilingual-e5-large",
276 model_file: "model.safetensors",
277 pooling: Pooling::Mean,
278 max_length: 512,
279 arch: ModelArch::Bert,
280 },
281 ModelInfo {
283 model: EmbeddingModel::ParaphraseMLMpnetBaseV2,
284 dim: 768,
285 description: "Paraphrase multilingual mpnet base v2",
286 hf_repo: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
287 model_file: "model.safetensors",
288 pooling: Pooling::Mean,
289 max_length: 384,
290 arch: ModelArch::Bert,
291 },
292 ModelInfo {
294 model: EmbeddingModel::SnowflakeArcticEmbedXS,
295 dim: 384,
296 description: "Snowflake Arctic Embed XS",
297 hf_repo: "snowflake/snowflake-arctic-embed-xs",
298 model_file: "model.safetensors",
299 pooling: Pooling::Cls,
300 max_length: 512,
301 arch: ModelArch::Bert,
302 },
303 ModelInfo {
304 model: EmbeddingModel::SnowflakeArcticEmbedS,
305 dim: 384,
306 description: "Snowflake Arctic Embed S",
307 hf_repo: "snowflake/snowflake-arctic-embed-s",
308 model_file: "model.safetensors",
309 pooling: Pooling::Cls,
310 max_length: 512,
311 arch: ModelArch::Bert,
312 },
313 ModelInfo {
314 model: EmbeddingModel::SnowflakeArcticEmbedM,
315 dim: 768,
316 description: "Snowflake Arctic Embed M",
317 hf_repo: "Snowflake/snowflake-arctic-embed-m",
318 model_file: "model.safetensors",
319 pooling: Pooling::Cls,
320 max_length: 512,
321 arch: ModelArch::Bert,
322 },
323 ModelInfo {
324 model: EmbeddingModel::SnowflakeArcticEmbedL,
325 dim: 1024,
326 description: "Snowflake Arctic Embed L",
327 hf_repo: "snowflake/snowflake-arctic-embed-l",
328 model_file: "model.safetensors",
329 pooling: Pooling::Cls,
330 max_length: 512,
331 arch: ModelArch::Bert,
332 },
333 ModelInfo {
335 model: EmbeddingModel::MxbaiEmbedLargeV1,
336 dim: 1024,
337 description: "MxBai embed large v1",
338 hf_repo: "mixedbread-ai/mxbai-embed-large-v1",
339 model_file: "model.safetensors",
340 pooling: Pooling::Cls,
341 max_length: 512,
342 arch: ModelArch::Bert,
343 },
344 ModelInfo {
346 model: EmbeddingModel::NomicEmbedTextV15,
347 dim: 768,
348 description: "Nomic embed text v1.5, RoPE, SwiGLU, 8192 context",
349 hf_repo: "nomic-ai/nomic-embed-text-v1.5",
350 model_file: "model.safetensors",
351 pooling: Pooling::Mean,
352 max_length: 8192,
353 arch: ModelArch::NomicBert,
354 },
355 ]
356 .into_iter()
357 .map(|info| (info.model.clone(), info))
358 .collect()
359}
360
361pub fn models_map() -> &'static HashMap<EmbeddingModel, ModelInfo> {
363 MODEL_MAP.get_or_init(init_models_map)
364}
365
366impl EmbeddingModel {
367 pub fn get_info(&self) -> Option<&'static ModelInfo> {
369 models_map().get(self)
370 }
371
372 pub fn list_supported() -> Vec<&'static ModelInfo> {
374 models_map().values().collect()
375 }
376}