1use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
13#[serde(rename_all = "kebab-case")]
14pub enum EmbeddingModel {
15 #[default]
20 BgeLarge,
21
22 MiniLM,
27
28 BgeSmall,
33
34 E5Small,
39}
40
41impl EmbeddingModel {
42 pub fn model_id(&self) -> &'static str {
44 match self {
45 EmbeddingModel::BgeLarge => "BAAI/bge-large-en-v1.5",
46 EmbeddingModel::MiniLM => "sentence-transformers/all-MiniLM-L6-v2",
47 EmbeddingModel::BgeSmall => "BAAI/bge-small-en-v1.5",
48 EmbeddingModel::E5Small => "intfloat/e5-small-v2",
49 }
50 }
51
52 pub fn dimension(&self) -> usize {
54 match self {
55 EmbeddingModel::BgeLarge => 1024,
56 EmbeddingModel::MiniLM => 384,
57 EmbeddingModel::BgeSmall => 384,
58 EmbeddingModel::E5Small => 384,
59 }
60 }
61
62 pub fn max_seq_length(&self) -> usize {
64 match self {
65 EmbeddingModel::BgeLarge => 512,
66 EmbeddingModel::MiniLM => 256,
67 EmbeddingModel::BgeSmall => 512,
68 EmbeddingModel::E5Small => 512,
69 }
70 }
71
72 pub fn query_prefix(&self) -> Option<&'static str> {
75 match self {
76 EmbeddingModel::BgeLarge => None,
77 EmbeddingModel::MiniLM => None,
78 EmbeddingModel::BgeSmall => None,
79 EmbeddingModel::E5Small => Some("query: "),
80 }
81 }
82
83 pub fn document_prefix(&self) -> Option<&'static str> {
85 match self {
86 EmbeddingModel::BgeLarge => None,
87 EmbeddingModel::MiniLM => None,
88 EmbeddingModel::BgeSmall => None,
89 EmbeddingModel::E5Small => Some("passage: "),
90 }
91 }
92
93 pub fn use_mean_pooling(&self) -> bool {
95 match self {
96 EmbeddingModel::BgeLarge => true,
97 EmbeddingModel::MiniLM => true,
98 EmbeddingModel::BgeSmall => true,
99 EmbeddingModel::E5Small => true,
100 }
101 }
102
103 pub fn normalize_embeddings(&self) -> bool {
105 true }
107
108 pub fn tokens_per_second_cpu(&self) -> usize {
110 match self {
111 EmbeddingModel::BgeLarge => 1000,
112 EmbeddingModel::MiniLM => 5000,
113 EmbeddingModel::BgeSmall => 3000,
114 EmbeddingModel::E5Small => 3000,
115 }
116 }
117
118 pub fn onnx_repo_id(&self) -> &'static str {
123 match self {
124 EmbeddingModel::BgeLarge => "Xenova/bge-large-en-v1.5",
125 EmbeddingModel::MiniLM => "Xenova/all-MiniLM-L6-v2",
126 EmbeddingModel::BgeSmall => "Xenova/bge-small-en-v1.5",
127 EmbeddingModel::E5Small => "Xenova/e5-small-v2",
128 }
129 }
130
131 pub fn onnx_filename(&self) -> &'static str {
133 "onnx/model_quantized.onnx"
134 }
135
136 pub fn all() -> &'static [EmbeddingModel] {
138 &[
139 EmbeddingModel::BgeLarge,
140 EmbeddingModel::MiniLM,
141 EmbeddingModel::BgeSmall,
142 EmbeddingModel::E5Small,
143 ]
144 }
145
146 pub fn parse(s: &str) -> Option<Self> {
148 match s.to_lowercase().as_str() {
149 "bge-large" | "bge-large-en" | "bge-large-en-v1.5" => Some(EmbeddingModel::BgeLarge),
150 "minilm" | "all-minilm-l6-v2" | "mini-lm" => Some(EmbeddingModel::MiniLM),
151 "bge-small" | "bge" | "bge-small-en" => Some(EmbeddingModel::BgeSmall),
152 "e5-small" | "e5" | "e5-small-v2" => Some(EmbeddingModel::E5Small),
153 _ => None,
154 }
155 }
156}
157
158impl std::fmt::Display for EmbeddingModel {
159 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160 match self {
161 EmbeddingModel::BgeLarge => write!(f, "bge-large-en-v1.5"),
162 EmbeddingModel::MiniLM => write!(f, "all-MiniLM-L6-v2"),
163 EmbeddingModel::BgeSmall => write!(f, "bge-small-en-v1.5"),
164 EmbeddingModel::E5Small => write!(f, "e5-small-v2"),
165 }
166 }
167}
168
169#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct ModelConfig {
172 pub model: EmbeddingModel,
174
175 pub cache_dir: Option<String>,
178
179 pub max_batch_size: usize,
181
182 pub use_gpu: bool,
184
185 pub num_threads: Option<usize>,
187}
188
189impl Default for ModelConfig {
190 fn default() -> Self {
191 Self {
192 model: EmbeddingModel::default(),
193 cache_dir: None,
194 max_batch_size: 32,
195 use_gpu: false,
196 num_threads: None,
197 }
198 }
199}
200
201impl ModelConfig {
202 pub fn new(model: EmbeddingModel) -> Self {
204 Self {
205 model,
206 ..Default::default()
207 }
208 }
209
210 pub fn with_cache_dir(mut self, dir: impl Into<String>) -> Self {
212 self.cache_dir = Some(dir.into());
213 self
214 }
215
216 pub fn with_max_batch_size(mut self, size: usize) -> Self {
218 self.max_batch_size = size;
219 self
220 }
221
222 pub fn with_gpu(mut self, use_gpu: bool) -> Self {
224 self.use_gpu = use_gpu;
225 self
226 }
227
228 pub fn with_num_threads(mut self, threads: usize) -> Self {
230 self.num_threads = Some(threads);
231 self
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_ids() {
        let cases = [
            (EmbeddingModel::BgeLarge, "BAAI/bge-large-en-v1.5"),
            (EmbeddingModel::MiniLM, "sentence-transformers/all-MiniLM-L6-v2"),
            (EmbeddingModel::BgeSmall, "BAAI/bge-small-en-v1.5"),
            (EmbeddingModel::E5Small, "intfloat/e5-small-v2"),
        ];
        for (model, expected_id) in cases {
            assert_eq!(model.model_id(), expected_id);
        }
    }

    #[test]
    fn test_dimensions() {
        let cases = [
            (EmbeddingModel::BgeLarge, 1024),
            (EmbeddingModel::MiniLM, 384),
            (EmbeddingModel::BgeSmall, 384),
            (EmbeddingModel::E5Small, 384),
        ];
        for (model, expected_dim) in cases {
            assert_eq!(model.dimension(), expected_dim);
        }
        // Every registered model must report a non-zero dimension.
        assert!(EmbeddingModel::all().iter().all(|model| model.dimension() > 0));
    }

    #[test]
    fn test_from_str() {
        let cases = [
            ("bge-large", Some(EmbeddingModel::BgeLarge)),
            ("minilm", Some(EmbeddingModel::MiniLM)),
            // Parsing is case-insensitive.
            ("BGE-SMALL", Some(EmbeddingModel::BgeSmall)),
            ("e5", Some(EmbeddingModel::E5Small)),
            ("unknown", None),
        ];
        for (input, expected) in cases {
            assert_eq!(EmbeddingModel::parse(input), expected);
        }
    }

    #[test]
    fn test_e5_prefixes() {
        let e5 = EmbeddingModel::E5Small;
        assert_eq!(e5.query_prefix(), Some("query: "));
        assert_eq!(e5.document_prefix(), Some("passage: "));
        // Models without a prefix embed queries verbatim.
        assert!(EmbeddingModel::MiniLM.query_prefix().is_none());
    }
}