1use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
12#[serde(rename_all = "kebab-case")]
13pub enum EmbeddingModel {
14 #[default]
19 MiniLM,
20
21 BgeSmall,
26
27 E5Small,
32}
33
34impl EmbeddingModel {
35 pub fn model_id(&self) -> &'static str {
37 match self {
38 EmbeddingModel::MiniLM => "sentence-transformers/all-MiniLM-L6-v2",
39 EmbeddingModel::BgeSmall => "BAAI/bge-small-en-v1.5",
40 EmbeddingModel::E5Small => "intfloat/e5-small-v2",
41 }
42 }
43
44 pub fn dimension(&self) -> usize {
46 match self {
47 EmbeddingModel::MiniLM => 384,
48 EmbeddingModel::BgeSmall => 384,
49 EmbeddingModel::E5Small => 384,
50 }
51 }
52
53 pub fn max_seq_length(&self) -> usize {
55 match self {
56 EmbeddingModel::MiniLM => 256,
57 EmbeddingModel::BgeSmall => 512,
58 EmbeddingModel::E5Small => 512,
59 }
60 }
61
62 pub fn query_prefix(&self) -> Option<&'static str> {
65 match self {
66 EmbeddingModel::MiniLM => None,
67 EmbeddingModel::BgeSmall => None,
68 EmbeddingModel::E5Small => Some("query: "),
69 }
70 }
71
72 pub fn document_prefix(&self) -> Option<&'static str> {
74 match self {
75 EmbeddingModel::MiniLM => None,
76 EmbeddingModel::BgeSmall => None,
77 EmbeddingModel::E5Small => Some("passage: "),
78 }
79 }
80
81 pub fn use_mean_pooling(&self) -> bool {
83 match self {
84 EmbeddingModel::MiniLM => true,
85 EmbeddingModel::BgeSmall => true,
86 EmbeddingModel::E5Small => true,
87 }
88 }
89
90 pub fn normalize_embeddings(&self) -> bool {
92 true }
94
95 pub fn tokens_per_second_cpu(&self) -> usize {
97 match self {
98 EmbeddingModel::MiniLM => 5000,
99 EmbeddingModel::BgeSmall => 3000,
100 EmbeddingModel::E5Small => 3000,
101 }
102 }
103
104 pub fn all() -> &'static [EmbeddingModel] {
106 &[
107 EmbeddingModel::MiniLM,
108 EmbeddingModel::BgeSmall,
109 EmbeddingModel::E5Small,
110 ]
111 }
112
113 pub fn parse(s: &str) -> Option<Self> {
115 match s.to_lowercase().as_str() {
116 "minilm" | "all-minilm-l6-v2" | "mini-lm" => Some(EmbeddingModel::MiniLM),
117 "bge-small" | "bge" | "bge-small-en" => Some(EmbeddingModel::BgeSmall),
118 "e5-small" | "e5" | "e5-small-v2" => Some(EmbeddingModel::E5Small),
119 _ => None,
120 }
121 }
122}
123
124impl std::fmt::Display for EmbeddingModel {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 match self {
127 EmbeddingModel::MiniLM => write!(f, "all-MiniLM-L6-v2"),
128 EmbeddingModel::BgeSmall => write!(f, "bge-small-en-v1.5"),
129 EmbeddingModel::E5Small => write!(f, "e5-small-v2"),
130 }
131 }
132}
133
134#[derive(Debug, Clone, Serialize, Deserialize)]
136pub struct ModelConfig {
137 pub model: EmbeddingModel,
139
140 pub cache_dir: Option<String>,
143
144 pub max_batch_size: usize,
146
147 pub use_gpu: bool,
149
150 pub num_threads: Option<usize>,
152}
153
154impl Default for ModelConfig {
155 fn default() -> Self {
156 Self {
157 model: EmbeddingModel::default(),
158 cache_dir: None,
159 max_batch_size: 32,
160 use_gpu: false,
161 num_threads: None,
162 }
163 }
164}
165
166impl ModelConfig {
167 pub fn new(model: EmbeddingModel) -> Self {
169 Self {
170 model,
171 ..Default::default()
172 }
173 }
174
175 pub fn with_cache_dir(mut self, dir: impl Into<String>) -> Self {
177 self.cache_dir = Some(dir.into());
178 self
179 }
180
181 pub fn with_max_batch_size(mut self, size: usize) -> Self {
183 self.max_batch_size = size;
184 self
185 }
186
187 pub fn with_gpu(mut self, use_gpu: bool) -> Self {
189 self.use_gpu = use_gpu;
190 self
191 }
192
193 pub fn with_num_threads(mut self, threads: usize) -> Self {
195 self.num_threads = Some(threads);
196 self
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Each variant must report its exact upstream repository id.
    #[test]
    fn test_model_ids() {
        let expected = [
            (EmbeddingModel::MiniLM, "sentence-transformers/all-MiniLM-L6-v2"),
            (EmbeddingModel::BgeSmall, "BAAI/bge-small-en-v1.5"),
            (EmbeddingModel::E5Small, "intfloat/e5-small-v2"),
        ];
        for &(model, id) in &expected {
            assert_eq!(model.model_id(), id);
        }
    }

    /// All supported models share a 384-wide embedding.
    #[test]
    fn test_dimensions() {
        for &model in EmbeddingModel::all() {
            assert_eq!(model.dimension(), 384, "{model:?}");
        }
    }

    /// `parse` is case-insensitive and rejects unknown names.
    #[test]
    fn test_from_str() {
        let cases = [
            ("minilm", Some(EmbeddingModel::MiniLM)),
            ("BGE-SMALL", Some(EmbeddingModel::BgeSmall)),
            ("e5", Some(EmbeddingModel::E5Small)),
            ("unknown", None),
        ];
        for &(input, expected) in &cases {
            assert_eq!(EmbeddingModel::parse(input), expected, "parse({input:?})");
        }
    }

    /// E5 carries query/passage instruction prefixes; MiniLM has none.
    #[test]
    fn test_e5_prefixes() {
        let e5 = EmbeddingModel::E5Small;
        assert_eq!(e5.query_prefix(), Some("query: "));
        assert_eq!(e5.document_prefix(), Some("passage: "));
        assert!(EmbeddingModel::MiniLM.query_prefix().is_none());
    }
}