1use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
12#[serde(rename_all = "kebab-case")]
13pub enum EmbeddingModel {
14 #[default]
19 MiniLM,
20
21 BgeSmall,
26
27 E5Small,
32}
33
34impl EmbeddingModel {
35 pub fn model_id(&self) -> &'static str {
37 match self {
38 EmbeddingModel::MiniLM => "sentence-transformers/all-MiniLM-L6-v2",
39 EmbeddingModel::BgeSmall => "BAAI/bge-small-en-v1.5",
40 EmbeddingModel::E5Small => "intfloat/e5-small-v2",
41 }
42 }
43
44 pub fn dimension(&self) -> usize {
46 match self {
47 EmbeddingModel::MiniLM => 384,
48 EmbeddingModel::BgeSmall => 384,
49 EmbeddingModel::E5Small => 384,
50 }
51 }
52
53 pub fn max_seq_length(&self) -> usize {
55 match self {
56 EmbeddingModel::MiniLM => 256,
57 EmbeddingModel::BgeSmall => 512,
58 EmbeddingModel::E5Small => 512,
59 }
60 }
61
62 pub fn query_prefix(&self) -> Option<&'static str> {
65 match self {
66 EmbeddingModel::MiniLM => None,
67 EmbeddingModel::BgeSmall => None,
68 EmbeddingModel::E5Small => Some("query: "),
69 }
70 }
71
72 pub fn document_prefix(&self) -> Option<&'static str> {
74 match self {
75 EmbeddingModel::MiniLM => None,
76 EmbeddingModel::BgeSmall => None,
77 EmbeddingModel::E5Small => Some("passage: "),
78 }
79 }
80
81 pub fn use_mean_pooling(&self) -> bool {
83 match self {
84 EmbeddingModel::MiniLM => true,
85 EmbeddingModel::BgeSmall => true,
86 EmbeddingModel::E5Small => true,
87 }
88 }
89
90 pub fn normalize_embeddings(&self) -> bool {
92 true }
94
95 pub fn tokens_per_second_cpu(&self) -> usize {
97 match self {
98 EmbeddingModel::MiniLM => 5000,
99 EmbeddingModel::BgeSmall => 3000,
100 EmbeddingModel::E5Small => 3000,
101 }
102 }
103
104 pub fn onnx_repo_id(&self) -> &'static str {
109 match self {
110 EmbeddingModel::MiniLM => "Xenova/all-MiniLM-L6-v2",
111 EmbeddingModel::BgeSmall => "Xenova/bge-small-en-v1.5",
112 EmbeddingModel::E5Small => "Xenova/e5-small-v2",
113 }
114 }
115
116 pub fn onnx_filename(&self) -> &'static str {
118 "onnx/model_quantized.onnx"
119 }
120
121 pub fn all() -> &'static [EmbeddingModel] {
123 &[
124 EmbeddingModel::MiniLM,
125 EmbeddingModel::BgeSmall,
126 EmbeddingModel::E5Small,
127 ]
128 }
129
130 pub fn parse(s: &str) -> Option<Self> {
132 match s.to_lowercase().as_str() {
133 "minilm" | "all-minilm-l6-v2" | "mini-lm" => Some(EmbeddingModel::MiniLM),
134 "bge-small" | "bge" | "bge-small-en" => Some(EmbeddingModel::BgeSmall),
135 "e5-small" | "e5" | "e5-small-v2" => Some(EmbeddingModel::E5Small),
136 _ => None,
137 }
138 }
139}
140
141impl std::fmt::Display for EmbeddingModel {
142 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143 match self {
144 EmbeddingModel::MiniLM => write!(f, "all-MiniLM-L6-v2"),
145 EmbeddingModel::BgeSmall => write!(f, "bge-small-en-v1.5"),
146 EmbeddingModel::E5Small => write!(f, "e5-small-v2"),
147 }
148 }
149}
150
151#[derive(Debug, Clone, Serialize, Deserialize)]
153pub struct ModelConfig {
154 pub model: EmbeddingModel,
156
157 pub cache_dir: Option<String>,
160
161 pub max_batch_size: usize,
163
164 pub use_gpu: bool,
166
167 pub num_threads: Option<usize>,
169}
170
171impl Default for ModelConfig {
172 fn default() -> Self {
173 Self {
174 model: EmbeddingModel::default(),
175 cache_dir: None,
176 max_batch_size: 32,
177 use_gpu: false,
178 num_threads: None,
179 }
180 }
181}
182
183impl ModelConfig {
184 pub fn new(model: EmbeddingModel) -> Self {
186 Self {
187 model,
188 ..Default::default()
189 }
190 }
191
192 pub fn with_cache_dir(mut self, dir: impl Into<String>) -> Self {
194 self.cache_dir = Some(dir.into());
195 self
196 }
197
198 pub fn with_max_batch_size(mut self, size: usize) -> Self {
200 self.max_batch_size = size;
201 self
202 }
203
204 pub fn with_gpu(mut self, use_gpu: bool) -> Self {
206 self.use_gpu = use_gpu;
207 self
208 }
209
210 pub fn with_num_threads(mut self, threads: usize) -> Self {
212 self.num_threads = Some(threads);
213 self
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant maps to its expected model identifier.
    #[test]
    fn test_model_ids() {
        let expected = [
            (
                EmbeddingModel::MiniLM,
                "sentence-transformers/all-MiniLM-L6-v2",
            ),
            (EmbeddingModel::BgeSmall, "BAAI/bge-small-en-v1.5"),
            (EmbeddingModel::E5Small, "intfloat/e5-small-v2"),
        ];
        for (model, id) in expected.iter() {
            assert_eq!(model.model_id(), *id);
        }
    }

    /// All supported models produce 384-dimensional embeddings.
    #[test]
    fn test_dimensions() {
        let models = EmbeddingModel::all();
        assert!(models.iter().all(|model| model.dimension() == 384));
    }

    /// `parse` resolves aliases regardless of case and rejects unknown names.
    #[test]
    fn test_from_str() {
        let cases = [
            ("minilm", Some(EmbeddingModel::MiniLM)),
            ("BGE-SMALL", Some(EmbeddingModel::BgeSmall)),
            ("e5", Some(EmbeddingModel::E5Small)),
            ("unknown", None),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(EmbeddingModel::parse(input), *expected);
        }
    }

    /// E5 defines query/passage prefixes; MiniLM has no query prefix.
    #[test]
    fn test_e5_prefixes() {
        let e5 = EmbeddingModel::E5Small;
        assert_eq!(e5.query_prefix(), Some("query: "));
        assert_eq!(e5.document_prefix(), Some("passage: "));
        assert_eq!(EmbeddingModel::MiniLM.query_prefix(), None);
    }
}