1use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
13#[serde(rename_all = "kebab-case")]
14pub enum EmbeddingModel {
15 #[default]
20 BgeLarge,
21
22 MiniLM,
27
28 BgeSmall,
33
34 E5Small,
39}
40
41impl EmbeddingModel {
42 pub fn model_id(&self) -> &'static str {
44 match self {
45 EmbeddingModel::BgeLarge => "BAAI/bge-large-en-v1.5",
46 EmbeddingModel::MiniLM => "sentence-transformers/all-MiniLM-L6-v2",
47 EmbeddingModel::BgeSmall => "BAAI/bge-small-en-v1.5",
48 EmbeddingModel::E5Small => "intfloat/e5-small-v2",
49 }
50 }
51
52 pub fn dimension(&self) -> usize {
54 match self {
55 EmbeddingModel::BgeLarge => 1024,
56 EmbeddingModel::MiniLM => 384,
57 EmbeddingModel::BgeSmall => 384,
58 EmbeddingModel::E5Small => 384,
59 }
60 }
61
62 pub fn max_seq_length(&self) -> usize {
64 match self {
65 EmbeddingModel::BgeLarge => 512,
66 EmbeddingModel::MiniLM => 256,
67 EmbeddingModel::BgeSmall => 512,
68 EmbeddingModel::E5Small => 512,
69 }
70 }
71
72 pub fn query_prefix(&self) -> Option<&'static str> {
75 match self {
76 EmbeddingModel::BgeLarge => None,
77 EmbeddingModel::MiniLM => None,
78 EmbeddingModel::BgeSmall => None,
79 EmbeddingModel::E5Small => Some("query: "),
80 }
81 }
82
83 pub fn document_prefix(&self) -> Option<&'static str> {
85 match self {
86 EmbeddingModel::BgeLarge => None,
87 EmbeddingModel::MiniLM => None,
88 EmbeddingModel::BgeSmall => None,
89 EmbeddingModel::E5Small => Some("passage: "),
90 }
91 }
92
93 pub fn use_mean_pooling(&self) -> bool {
95 match self {
96 EmbeddingModel::BgeLarge => true,
97 EmbeddingModel::MiniLM => true,
98 EmbeddingModel::BgeSmall => true,
99 EmbeddingModel::E5Small => true,
100 }
101 }
102
103 pub fn normalize_embeddings(&self) -> bool {
105 true }
107
108 pub fn tokens_per_second_cpu(&self) -> usize {
110 match self {
111 EmbeddingModel::BgeLarge => 1000,
112 EmbeddingModel::MiniLM => 5000,
113 EmbeddingModel::BgeSmall => 3000,
114 EmbeddingModel::E5Small => 3000,
115 }
116 }
117
118 pub fn onnx_repo_id(&self) -> &'static str {
123 match self {
124 EmbeddingModel::BgeLarge => "Xenova/bge-large-en-v1.5",
125 EmbeddingModel::MiniLM => "Xenova/all-MiniLM-L6-v2",
126 EmbeddingModel::BgeSmall => "Xenova/bge-small-en-v1.5",
127 EmbeddingModel::E5Small => "Xenova/e5-small-v2",
128 }
129 }
130
131 pub fn onnx_filename(&self) -> &'static str {
133 "onnx/model_quantized.onnx"
134 }
135
136 pub fn all() -> &'static [EmbeddingModel] {
138 &[
139 EmbeddingModel::BgeLarge,
140 EmbeddingModel::MiniLM,
141 EmbeddingModel::BgeSmall,
142 EmbeddingModel::E5Small,
143 ]
144 }
145
146 pub fn parse(s: &str) -> Option<Self> {
148 match s.to_lowercase().as_str() {
149 "bge-large" | "bge-large-en" | "bge-large-en-v1.5" => Some(EmbeddingModel::BgeLarge),
150 "minilm" | "all-minilm-l6-v2" | "mini-lm" => Some(EmbeddingModel::MiniLM),
151 "bge-small" | "bge" | "bge-small-en" => Some(EmbeddingModel::BgeSmall),
152 "e5-small" | "e5" | "e5-small-v2" => Some(EmbeddingModel::E5Small),
153 _ => None,
154 }
155 }
156}
157
158impl std::fmt::Display for EmbeddingModel {
159 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160 match self {
161 EmbeddingModel::BgeLarge => write!(f, "bge-large-en-v1.5"),
162 EmbeddingModel::MiniLM => write!(f, "all-MiniLM-L6-v2"),
163 EmbeddingModel::BgeSmall => write!(f, "bge-small-en-v1.5"),
164 EmbeddingModel::E5Small => write!(f, "e5-small-v2"),
165 }
166 }
167}
168
/// Configuration for loading and running an [`EmbeddingModel`].
///
/// `Default` fills every field; the `with_*` builder methods override
/// individual values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Which embedding model to use.
    pub model: EmbeddingModel,

    /// Optional cache directory; `None` by default.
    // NOTE(review): how a missing directory is resolved is not visible in this file.
    pub cache_dir: Option<String>,

    /// Upper bound on batch size. The default comes from the
    /// `DAKERA_ONNX_BATCH_SIZE` environment variable when it parses to a
    /// value >= 1, otherwise 8.
    pub max_batch_size: usize,

    /// Whether to use the GPU; `false` by default.
    pub use_gpu: bool,

    /// Optional thread count; `None` by default.
    // NOTE(review): presumably `None` defers to the runtime's own default — confirm.
    pub num_threads: Option<usize>,

    /// Session pool size. The default comes from the `DAKERA_ONNX_POOL_SIZE`
    /// environment variable when it parses to a value >= 1, otherwise 4.
    pub session_pool_size: usize,
}
196
197impl Default for ModelConfig {
198 fn default() -> Self {
199 let pool_size = std::env::var("DAKERA_ONNX_POOL_SIZE")
206 .ok()
207 .and_then(|v| v.parse::<usize>().ok())
208 .filter(|&n| n >= 1)
209 .unwrap_or(4);
210 let max_batch_size = std::env::var("DAKERA_ONNX_BATCH_SIZE")
214 .ok()
215 .and_then(|v| v.parse::<usize>().ok())
216 .filter(|&n| n >= 1)
217 .unwrap_or(8);
218 Self {
219 model: EmbeddingModel::default(),
220 cache_dir: None,
221 max_batch_size,
222 use_gpu: false,
223 num_threads: None,
224 session_pool_size: pool_size,
225 }
226 }
227}
228
229impl ModelConfig {
230 pub fn new(model: EmbeddingModel) -> Self {
232 Self {
233 model,
234 ..Default::default()
235 }
236 }
237
238 pub fn with_cache_dir(mut self, dir: impl Into<String>) -> Self {
240 self.cache_dir = Some(dir.into());
241 self
242 }
243
244 pub fn with_max_batch_size(mut self, size: usize) -> Self {
246 self.max_batch_size = size;
247 self
248 }
249
250 pub fn with_gpu(mut self, use_gpu: bool) -> Self {
252 self.use_gpu = use_gpu;
253 self
254 }
255
256 pub fn with_num_threads(mut self, threads: usize) -> Self {
258 self.num_threads = Some(threads);
259 self
260 }
261
262 pub fn with_session_pool_size(mut self, size: usize) -> Self {
264 self.session_pool_size = size.max(1);
265 self
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant reports its expected upstream model identifier.
    #[test]
    fn test_model_ids() {
        let expected = [
            (EmbeddingModel::BgeLarge, "BAAI/bge-large-en-v1.5"),
            (
                EmbeddingModel::MiniLM,
                "sentence-transformers/all-MiniLM-L6-v2",
            ),
            (EmbeddingModel::BgeSmall, "BAAI/bge-small-en-v1.5"),
            (EmbeddingModel::E5Small, "intfloat/e5-small-v2"),
        ];
        for (model, id) in expected {
            assert_eq!(model.model_id(), id);
        }
    }

    /// Dimensions match the known values and are non-zero for every model.
    #[test]
    fn test_dimensions() {
        let expected = [
            (EmbeddingModel::BgeLarge, 1024),
            (EmbeddingModel::MiniLM, 384),
            (EmbeddingModel::BgeSmall, 384),
            (EmbeddingModel::E5Small, 384),
        ];
        for (model, dim) in expected {
            assert_eq!(model.dimension(), dim);
        }
        assert!(EmbeddingModel::all().iter().all(|m| m.dimension() > 0));
    }

    /// `parse` resolves aliases case-insensitively and rejects unknown names.
    #[test]
    fn test_from_str() {
        let cases = [
            ("bge-large", Some(EmbeddingModel::BgeLarge)),
            ("minilm", Some(EmbeddingModel::MiniLM)),
            ("BGE-SMALL", Some(EmbeddingModel::BgeSmall)),
            ("e5", Some(EmbeddingModel::E5Small)),
            ("unknown", None),
        ];
        for (input, want) in cases {
            assert_eq!(EmbeddingModel::parse(input), want);
        }
    }

    /// Only E5 reports prefixes; MiniLM reports none for queries.
    #[test]
    fn test_e5_prefixes() {
        let e5 = EmbeddingModel::E5Small;
        assert_eq!(e5.query_prefix(), Some("query: "));
        assert_eq!(e5.document_prefix(), Some("passage: "));
        assert_eq!(EmbeddingModel::MiniLM.query_prefix(), None);
    }
}