1use std::str::FromStr;
2
3use crate::{HuggingFaceModelSpec, ModelRuntimeError, ModelTask, Result};
4
/// Built-in model presets.
///
/// Each variant stands for one Hugging Face repository plus the task it is used
/// for. [`ModelPreset::as_str`] gives the preset's string identifier and
/// [`ModelPreset::spec`] gives the list of files to fetch.
///
/// `Hash` is derived alongside `Eq` so presets can be used directly as
/// `HashMap` / `HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelPreset {
    /// `facebook/detr-resnet-50` — object detection.
    DetrResnet50,
    /// `hustvl/yolos-tiny` — object detection.
    YolosTiny,
    /// `distilbert-base-uncased-finetuned-sst-2-english` — text classification.
    DistilbertSst2,
    /// `dslim/bert-base-NER` — token classification.
    BertBaseNer,
    /// `sentence-transformers/all-MiniLM-L6-v2` — text embedding.
    MiniLmL6V2,
    /// `Xenova/distilbert-base-uncased-finetuned-sst-2-english` — text classification (ONNX files).
    XenovaDistilbertSst2Onnx,
    /// `Xenova/all-MiniLM-L6-v2` — text embedding (ONNX files).
    XenovaMiniLmL6V2Onnx,
    /// `Xenova/all-mpnet-base-v2` — text embedding (ONNX files).
    XenovaAllMpnetBaseV2Onnx,
    /// `Xenova/twitter-roberta-base-sentiment-latest` — text classification (ONNX files).
    XenovaTwitterRobertaSentimentOnnx,
    /// `Xenova/bart-large-mnli` — zero-shot classification (ONNX files).
    XenovaBartLargeMnliOnnx,
    /// `Xenova/bart-large-cnn` — summarization (ONNX files).
    XenovaBartLargeCnnOnnx,
    /// `Xenova/ms-marco-MiniLM-L-6-v2` — reranking (ONNX files).
    XenovaMsMarcoMiniLmL6V2Onnx,
    /// `Xenova/roberta-base-squad2` — question answering (ONNX files).
    XenovaRobertaBaseSquad2Onnx,
    /// `onnx-community/roberta-base-squad2-ONNX` — question answering (ONNX files).
    OnnxCommunityRobertaBaseSquad2,
    /// `Xenova/vit-base-patch16-224` — image classification (ONNX files).
    XenovaVitBasePatch16_224Onnx,
    /// `Xenova/vit-gpt2-image-captioning` — custom `image_captioning` task (ONNX files).
    XenovaVitGpt2ImageCaptioningOnnx,
    /// `Xenova/trocr-base-printed` — custom `optical_character_recognition` task (ONNX files).
    XenovaTrocrBasePrintedOnnx,
    /// `Xenova/trocr-base-handwritten` — custom `optical_character_recognition` task (ONNX files).
    XenovaTrocrBaseHandwrittenOnnx,
    /// `Xenova/detr-resnet-50` — object detection (ONNX files).
    XenovaDetrResnet50Onnx,
    /// `Xenova/yolov8n-pose` — 2D pose estimation (ONNX files).
    XenovaYolov8nPoseOnnx,
    /// `MIT/ast-finetuned-audioset-10-10-0.4593` — audio classification.
    AstAudioset,
    /// `Xenova/ast-finetuned-audioset-10-10-0.4593` — audio classification (ONNX files).
    XenovaAstAudiosetOnnx,
    /// `laion/clap-htsat-unfused` — audio embedding.
    ClapHtsatUnfused,
    /// `openai/whisper-tiny.en` — speech recognition.
    WhisperTinyEn,
    /// `openai/whisper-large-v3` — speech recognition; spec metadata sets `backend` to `candle`.
    WhisperLargeV3,
    /// `openai/whisper-large-v3-turbo` — speech recognition; spec metadata sets `backend` to `candle`.
    WhisperLargeV3Turbo,
    /// `facebook/wav2vec2-base-960h` — custom `forced_alignment` task; `backend` is `candle`.
    Wav2Vec2Base960h,
    /// `pyannote/speaker-diarization-3.1` — speaker diarization; `backend` is `plan-only`.
    PyannoteSpeakerDiarization31,
    /// `facebook/demucs` — source separation.
    DemucsMusicSeparation,
    /// `facebook/musicgen-small` — audio generation.
    MusicgenSmall,
    /// `SWivid/F5-TTS`, `F5TTS_v1_Base` files — speaker-conditioned TTS.
    F5TtsV1Base,
    /// `SWivid/F5-TTS`, `F5TTS_Base` files — speaker-conditioned TTS.
    F5TtsBase,
    /// `SWivid/E2-TTS`, `E2TTS_Base` files — speaker-conditioned TTS.
    E2TtsBase,
    /// `charactr/vocos-mel-24khz` — audio generation.
    VocosMel24Khz,
}
77
78impl ModelPreset {
79 pub const ALL: &'static [Self] = &[
81 Self::DetrResnet50,
82 Self::YolosTiny,
83 Self::DistilbertSst2,
84 Self::BertBaseNer,
85 Self::MiniLmL6V2,
86 Self::XenovaDistilbertSst2Onnx,
87 Self::XenovaMiniLmL6V2Onnx,
88 Self::XenovaAllMpnetBaseV2Onnx,
89 Self::XenovaTwitterRobertaSentimentOnnx,
90 Self::XenovaBartLargeMnliOnnx,
91 Self::XenovaBartLargeCnnOnnx,
92 Self::XenovaMsMarcoMiniLmL6V2Onnx,
93 Self::XenovaRobertaBaseSquad2Onnx,
94 Self::OnnxCommunityRobertaBaseSquad2,
95 Self::XenovaVitBasePatch16_224Onnx,
96 Self::XenovaVitGpt2ImageCaptioningOnnx,
97 Self::XenovaTrocrBasePrintedOnnx,
98 Self::XenovaTrocrBaseHandwrittenOnnx,
99 Self::XenovaDetrResnet50Onnx,
100 Self::XenovaYolov8nPoseOnnx,
101 Self::AstAudioset,
102 Self::XenovaAstAudiosetOnnx,
103 Self::ClapHtsatUnfused,
104 Self::WhisperTinyEn,
105 Self::WhisperLargeV3,
106 Self::WhisperLargeV3Turbo,
107 Self::Wav2Vec2Base960h,
108 Self::PyannoteSpeakerDiarization31,
109 Self::DemucsMusicSeparation,
110 Self::MusicgenSmall,
111 Self::F5TtsV1Base,
112 Self::F5TtsBase,
113 Self::E2TtsBase,
114 Self::VocosMel24Khz,
115 ];
116
117 pub fn as_str(self) -> &'static str {
119 match self {
120 Self::DetrResnet50 => "detr-resnet-50",
121 Self::YolosTiny => "yolos-tiny",
122 Self::DistilbertSst2 => "distilbert-sst2",
123 Self::BertBaseNer => "bert-base-ner",
124 Self::MiniLmL6V2 => "minilm-l6-v2",
125 Self::XenovaDistilbertSst2Onnx => "xenova-distilbert-sst2-onnx",
126 Self::XenovaMiniLmL6V2Onnx => "xenova-minilm-l6-v2-onnx",
127 Self::XenovaAllMpnetBaseV2Onnx => "xenova-all-mpnet-base-v2-onnx",
128 Self::XenovaTwitterRobertaSentimentOnnx => {
129 "xenova-twitter-roberta-sentiment-latest-onnx"
130 }
131 Self::XenovaBartLargeMnliOnnx => "xenova-bart-large-mnli-onnx",
132 Self::XenovaBartLargeCnnOnnx => "xenova-bart-large-cnn-onnx",
133 Self::XenovaMsMarcoMiniLmL6V2Onnx => "xenova-ms-marco-minilm-l6-v2-onnx",
134 Self::XenovaRobertaBaseSquad2Onnx => "xenova-roberta-base-squad2-onnx",
135 Self::OnnxCommunityRobertaBaseSquad2 => "roberta-base-squad2-onnx",
136 Self::XenovaVitBasePatch16_224Onnx => "vit-base-patch16-224-onnx",
137 Self::XenovaVitGpt2ImageCaptioningOnnx => "vit-gpt2-image-captioning-onnx",
138 Self::XenovaTrocrBasePrintedOnnx => "trocr-base-printed-onnx",
139 Self::XenovaTrocrBaseHandwrittenOnnx => "trocr-base-handwritten-onnx",
140 Self::XenovaDetrResnet50Onnx => "xenova-detr-resnet-50-onnx",
141 Self::XenovaYolov8nPoseOnnx => "xenova-yolov8n-pose-onnx",
142 Self::AstAudioset => "ast-audioset",
143 Self::XenovaAstAudiosetOnnx => "xenova-ast-audioset-onnx",
144 Self::ClapHtsatUnfused => "clap-htsat-unfused",
145 Self::WhisperTinyEn => "whisper-tiny-en",
146 Self::WhisperLargeV3 => "whisper-large-v3",
147 Self::WhisperLargeV3Turbo => "whisper-large-v3-turbo",
148 Self::Wav2Vec2Base960h => "wav2vec2-base-960h",
149 Self::PyannoteSpeakerDiarization31 => "pyannote-speaker-diarization-3-1",
150 Self::DemucsMusicSeparation => "demucs-music-separation",
151 Self::MusicgenSmall => "musicgen-small",
152 Self::F5TtsV1Base => "f5-tts-v1-base",
153 Self::F5TtsBase => "f5-tts-base",
154 Self::E2TtsBase => "e2-tts-base",
155 Self::VocosMel24Khz => "vocos-mel-24khz",
156 }
157 }
158
159 pub fn spec(self) -> HuggingFaceModelSpec {
161 match self {
162 Self::DetrResnet50 => {
163 HuggingFaceModelSpec::new("facebook/detr-resnet-50", ModelTask::ObjectDetection)
164 .name(self.as_str())
165 .file("config.json")
166 .file("preprocessor_config.json")
167 .first_available_file(["model.safetensors", "pytorch_model.bin"])
168 }
169 Self::YolosTiny => {
170 HuggingFaceModelSpec::new("hustvl/yolos-tiny", ModelTask::ObjectDetection)
171 .name(self.as_str())
172 .file("config.json")
173 .file("preprocessor_config.json")
174 .first_available_file(["model.safetensors", "pytorch_model.bin"])
175 }
176 Self::DistilbertSst2 => HuggingFaceModelSpec::new(
177 "distilbert-base-uncased-finetuned-sst-2-english",
178 ModelTask::TextClassification,
179 )
180 .name(self.as_str())
181 .file("config.json")
182 .file("tokenizer_config.json")
183 .file("vocab.txt")
184 .first_available_file(["model.safetensors", "pytorch_model.bin"]),
185 Self::BertBaseNer => {
186 HuggingFaceModelSpec::new("dslim/bert-base-NER", ModelTask::TokenClassification)
187 .name(self.as_str())
188 .file("config.json")
189 .file("tokenizer_config.json")
190 .file("vocab.txt")
191 .optional_file("tokenizer.json")
192 .first_available_file(["model.safetensors", "pytorch_model.bin"])
193 }
194 Self::MiniLmL6V2 => HuggingFaceModelSpec::new(
195 "sentence-transformers/all-MiniLM-L6-v2",
196 ModelTask::TextEmbedding,
197 )
198 .name(self.as_str())
199 .file("config.json")
200 .file("tokenizer.json")
201 .file("tokenizer_config.json")
202 .file("vocab.txt")
203 .file("modules.json")
204 .optional_file("sentence_bert_config.json")
205 .first_available_file(["model.safetensors", "pytorch_model.bin"]),
206 Self::XenovaDistilbertSst2Onnx => HuggingFaceModelSpec::new(
207 "Xenova/distilbert-base-uncased-finetuned-sst-2-english",
208 ModelTask::TextClassification,
209 )
210 .name(self.as_str())
211 .file("config.json")
212 .file("tokenizer.json")
213 .file("tokenizer_config.json")
214 .first_available_file([
215 "onnx/model.onnx",
216 "onnx/model_quantized.onnx",
217 "onnx/model_int8.onnx",
218 ]),
219 Self::XenovaMiniLmL6V2Onnx => {
220 HuggingFaceModelSpec::new("Xenova/all-MiniLM-L6-v2", ModelTask::TextEmbedding)
221 .name(self.as_str())
222 .file("config.json")
223 .file("tokenizer.json")
224 .file("tokenizer_config.json")
225 .first_available_file(["onnx/model.onnx", "onnx/model_quantized.onnx"])
226 }
227 Self::XenovaAllMpnetBaseV2Onnx => {
228 HuggingFaceModelSpec::new("Xenova/all-mpnet-base-v2", ModelTask::TextEmbedding)
229 .name(self.as_str())
230 .file("config.json")
231 .file("tokenizer.json")
232 .file("tokenizer_config.json")
233 .optional_file("modules.json")
234 .first_available_file(["onnx/model.onnx", "onnx/model_quantized.onnx"])
235 }
236 Self::XenovaTwitterRobertaSentimentOnnx => HuggingFaceModelSpec::new(
237 "Xenova/twitter-roberta-base-sentiment-latest",
238 ModelTask::TextClassification,
239 )
240 .name(self.as_str())
241 .file("config.json")
242 .file("tokenizer.json")
243 .file("tokenizer_config.json")
244 .first_available_file(["onnx/model.onnx", "onnx/model_quantized.onnx"]),
245 Self::XenovaBartLargeMnliOnnx => HuggingFaceModelSpec::new(
246 "Xenova/bart-large-mnli",
247 ModelTask::ZeroShotClassification,
248 )
249 .name(self.as_str())
250 .file("config.json")
251 .file("tokenizer.json")
252 .file("tokenizer_config.json")
253 .first_available_file([
254 "onnx/model_quantized.onnx",
255 "onnx/encoder_model.onnx",
256 "onnx/model.onnx",
257 ]),
258 Self::XenovaBartLargeCnnOnnx => {
259 HuggingFaceModelSpec::new("Xenova/bart-large-cnn", ModelTask::Summarization)
260 .name(self.as_str())
261 .file("config.json")
262 .file("tokenizer.json")
263 .file("tokenizer_config.json")
264 .first_available_file([
265 "onnx/encoder_model.onnx",
266 "onnx/model.onnx",
267 "onnx/model_quantized.onnx",
268 ])
269 }
270 Self::XenovaMsMarcoMiniLmL6V2Onnx => {
271 HuggingFaceModelSpec::new("Xenova/ms-marco-MiniLM-L-6-v2", ModelTask::Reranking)
272 .name(self.as_str())
273 .file("config.json")
274 .file("tokenizer.json")
275 .file("tokenizer_config.json")
276 .first_available_file(["onnx/model.onnx", "onnx/model_quantized.onnx"])
277 }
278 Self::XenovaRobertaBaseSquad2Onnx => HuggingFaceModelSpec::new(
279 "Xenova/roberta-base-squad2",
280 ModelTask::QuestionAnswering,
281 )
282 .name(self.as_str())
283 .file("config.json")
284 .file("tokenizer.json")
285 .file("tokenizer_config.json")
286 .first_available_file(["onnx/model.onnx", "onnx/model_quantized.onnx"]),
287 Self::OnnxCommunityRobertaBaseSquad2 => HuggingFaceModelSpec::new(
288 "onnx-community/roberta-base-squad2-ONNX",
289 ModelTask::QuestionAnswering,
290 )
291 .name(self.as_str())
292 .file("config.json")
293 .file("tokenizer.json")
294 .file("tokenizer_config.json")
295 .optional_file("vocab.json")
296 .optional_file("merges.txt")
297 .optional_file("special_tokens_map.json")
298 .first_available_file([
299 "onnx/model_quantized.onnx",
300 "onnx/model.onnx",
301 "onnx/model_uint8.onnx",
302 ]),
303 Self::XenovaVitBasePatch16_224Onnx => HuggingFaceModelSpec::new(
304 "Xenova/vit-base-patch16-224",
305 ModelTask::ImageClassification,
306 )
307 .name(self.as_str())
308 .file("config.json")
309 .file("preprocessor_config.json")
310 .first_available_file(["onnx/model_quantized.onnx", "onnx/model.onnx"]),
311 Self::XenovaVitGpt2ImageCaptioningOnnx => HuggingFaceModelSpec::new(
312 "Xenova/vit-gpt2-image-captioning",
313 ModelTask::Custom("image_captioning".to_string()),
314 )
315 .name(self.as_str())
316 .file("config.json")
317 .file("generation_config.json")
318 .file("preprocessor_config.json")
319 .file("tokenizer.json")
320 .file("tokenizer_config.json")
321 .file("vocab.json")
322 .file("merges.txt")
323 .first_available_file([
324 "onnx/encoder_model_quantized.onnx",
325 "onnx/encoder_model.onnx",
326 ])
327 .first_available_file([
328 "onnx/decoder_model_quantized.onnx",
329 "onnx/decoder_model.onnx",
330 ]),
331 Self::XenovaTrocrBasePrintedOnnx | Self::XenovaTrocrBaseHandwrittenOnnx => {
332 let repo_id = match self {
333 Self::XenovaTrocrBasePrintedOnnx => "Xenova/trocr-base-printed",
334 Self::XenovaTrocrBaseHandwrittenOnnx => "Xenova/trocr-base-handwritten",
335 _ => unreachable!(),
336 };
337 HuggingFaceModelSpec::new(
338 repo_id,
339 ModelTask::Custom("optical_character_recognition".to_string()),
340 )
341 .name(self.as_str())
342 .file("config.json")
343 .file("preprocessor_config.json")
344 .file("tokenizer.json")
345 .optional_file("generation_config.json")
346 .optional_file("tokenizer_config.json")
347 .optional_file("vocab.json")
348 .optional_file("merges.txt")
349 .first_available_file([
350 "onnx/encoder_model_quantized.onnx",
351 "onnx/encoder_model.onnx",
352 "onnx/encoder_model_fp16.onnx",
353 ])
354 .first_available_file([
355 "onnx/decoder_model_quantized.onnx",
356 "onnx/decoder_model.onnx",
357 "onnx/decoder_model_fp16.onnx",
358 ])
359 }
360 Self::XenovaDetrResnet50Onnx => {
361 HuggingFaceModelSpec::new("Xenova/detr-resnet-50", ModelTask::ObjectDetection)
362 .name(self.as_str())
363 .file("config.json")
364 .file("preprocessor_config.json")
365 .first_available_file(["onnx/model.onnx", "onnx/model_quantized.onnx"])
366 }
367 Self::XenovaYolov8nPoseOnnx => {
368 HuggingFaceModelSpec::new("Xenova/yolov8n-pose", ModelTask::PoseEstimation2d)
369 .name(self.as_str())
370 .file("config.json")
371 .file("preprocessor_config.json")
372 .first_available_file([
373 "onnx/model_quantized.onnx",
374 "onnx/model_int8.onnx",
375 "onnx/model.onnx",
376 ])
377 }
378 Self::AstAudioset => HuggingFaceModelSpec::new(
379 "MIT/ast-finetuned-audioset-10-10-0.4593",
380 ModelTask::AudioClassification,
381 )
382 .name(self.as_str())
383 .file("config.json")
384 .file("preprocessor_config.json")
385 .first_available_file(["model.safetensors", "pytorch_model.bin"]),
386 Self::XenovaAstAudiosetOnnx => HuggingFaceModelSpec::new(
387 "Xenova/ast-finetuned-audioset-10-10-0.4593",
388 ModelTask::AudioClassification,
389 )
390 .name(self.as_str())
391 .file("config.json")
392 .file("preprocessor_config.json")
393 .first_available_file(["onnx/model.onnx", "onnx/model_quantized.onnx"]),
394 Self::ClapHtsatUnfused => {
395 HuggingFaceModelSpec::new("laion/clap-htsat-unfused", ModelTask::AudioEmbedding)
396 .name(self.as_str())
397 .file("config.json")
398 .file("preprocessor_config.json")
399 .optional_file("tokenizer.json")
400 .first_available_file(["model.safetensors", "pytorch_model.bin"])
401 }
402 Self::WhisperTinyEn => {
403 HuggingFaceModelSpec::new("openai/whisper-tiny.en", ModelTask::SpeechRecognition)
404 .name(self.as_str())
405 .file("config.json")
406 .file("preprocessor_config.json")
407 .file("tokenizer.json")
408 .first_available_file(["model.safetensors", "pytorch_model.bin"])
409 }
410 Self::WhisperLargeV3 => {
411 let mut spec = HuggingFaceModelSpec::new(
412 "openai/whisper-large-v3",
413 ModelTask::SpeechRecognition,
414 )
415 .name(self.as_str())
416 .file("config.json")
417 .file("generation_config.json")
418 .file("tokenizer.json")
419 .file("preprocessor_config.json")
420 .file("model.safetensors");
421 spec.metadata
422 .insert("backend".to_string(), "candle".to_string());
423 spec
424 }
425 Self::WhisperLargeV3Turbo => {
426 let mut spec = HuggingFaceModelSpec::new(
427 "openai/whisper-large-v3-turbo",
428 ModelTask::SpeechRecognition,
429 )
430 .name(self.as_str())
431 .file("config.json")
432 .file("generation_config.json")
433 .file("tokenizer.json")
434 .file("preprocessor_config.json")
435 .file("model.safetensors");
436 spec.metadata
437 .insert("backend".to_string(), "candle".to_string());
438 spec
439 }
440 Self::Wav2Vec2Base960h => {
441 let mut spec = HuggingFaceModelSpec::new(
442 "facebook/wav2vec2-base-960h",
443 ModelTask::Custom("forced_alignment".to_string()),
444 )
445 .name(self.as_str())
446 .file("config.json")
447 .file("preprocessor_config.json")
448 .first_available_file(["tokenizer.json", "vocab.json"])
449 .file("model.safetensors");
450 spec.metadata
451 .insert("backend".to_string(), "candle".to_string());
452 spec
453 }
454 Self::PyannoteSpeakerDiarization31 => {
455 let mut spec = HuggingFaceModelSpec::new(
456 "pyannote/speaker-diarization-3.1",
457 ModelTask::SpeakerDiarization,
458 )
459 .name(self.as_str())
460 .file("config.yaml")
461 .optional_file("pytorch_model.bin");
462 spec.metadata
463 .insert("backend".to_string(), "plan-only".to_string());
464 spec
465 }
466 Self::DemucsMusicSeparation => {
467 HuggingFaceModelSpec::new("facebook/demucs", ModelTask::SourceSeparation)
468 .name(self.as_str())
469 .file("config.json")
470 .first_available_file(["pytorch_model.bin", "model.safetensors"])
471 }
472 Self::MusicgenSmall => {
473 HuggingFaceModelSpec::new("facebook/musicgen-small", ModelTask::AudioGeneration)
474 .name(self.as_str())
475 .file("config.json")
476 .file("preprocessor_config.json")
477 .file("tokenizer.json")
478 .first_available_file(["model.safetensors", "pytorch_model.bin"])
479 }
480 Self::F5TtsV1Base => tts_preset(
481 HuggingFaceModelSpec::new("SWivid/F5-TTS", ModelTask::SpeakerConditionedTts)
482 .name(self.as_str())
483 .file("F5TTS_v1_Base/model_1250000.safetensors")
484 .file("F5TTS_v1_Base/vocab.txt"),
485 "f5-tts",
486 "F5-TTS v1 base",
487 "cc-by-nc-4.0",
488 "Creative Commons Attribution Non Commercial 4.0",
489 "https://creativecommons.org/licenses/by-nc/4.0/",
490 ),
491 Self::F5TtsBase => tts_preset(
492 HuggingFaceModelSpec::new("SWivid/F5-TTS", ModelTask::SpeakerConditionedTts)
493 .name(self.as_str())
494 .file("F5TTS_Base/model_1200000.safetensors")
495 .file("F5TTS_Base/vocab.txt"),
496 "f5-tts",
497 "F5-TTS base",
498 "cc-by-nc-4.0",
499 "Creative Commons Attribution Non Commercial 4.0",
500 "https://creativecommons.org/licenses/by-nc/4.0/",
501 ),
502 Self::E2TtsBase => tts_preset(
503 HuggingFaceModelSpec::new("SWivid/E2-TTS", ModelTask::SpeakerConditionedTts)
504 .name(self.as_str())
505 .file("E2TTS_Base/model_1200000.safetensors"),
506 "e2-tts",
507 "E2-TTS base",
508 "cc-by-nc-4.0",
509 "Creative Commons Attribution Non Commercial 4.0",
510 "https://creativecommons.org/licenses/by-nc/4.0/",
511 ),
512 Self::VocosMel24Khz => tts_preset(
513 HuggingFaceModelSpec::new("charactr/vocos-mel-24khz", ModelTask::AudioGeneration)
514 .name(self.as_str())
515 .file("config.yaml")
516 .file("pytorch_model.bin"),
517 "vocos",
518 "Vocos mel 24 kHz",
519 "mit",
520 "MIT",
521 "https://opensource.org/license/mit/",
522 ),
523 }
524 }
525}
526
527fn tts_preset(
528 mut spec: HuggingFaceModelSpec,
529 family: &str,
530 display_name: &str,
531 license: &str,
532 license_name: &str,
533 license_url: &str,
534) -> HuggingFaceModelSpec {
535 spec.metadata
536 .insert("modelFamily".to_string(), family.to_string());
537 spec.metadata
538 .insert("displayName".to_string(), display_name.to_string());
539 spec.metadata
540 .insert("license".to_string(), license.to_string());
541 spec.metadata
542 .insert("licenseName".to_string(), license_name.to_string());
543 spec.metadata
544 .insert("licenseUrl".to_string(), license_url.to_string());
545 spec.metadata
546 .insert("licenseScope".to_string(), "model".to_string());
547 spec.metadata
548 .insert("explicitOptIn".to_string(), "true".to_string());
549 spec
550}
551
552impl FromStr for ModelPreset {
553 type Err = ModelRuntimeError;
554
555 fn from_str(input: &str) -> Result<Self> {
556 Self::ALL
557 .iter()
558 .copied()
559 .find(|preset| preset.as_str() == input)
560 .ok_or_else(|| {
561 ModelRuntimeError::InvalidArgument(format!(
562 "unknown model preset `{input}`; expected one of {}",
563 Self::ALL
564 .iter()
565 .map(|preset| preset.as_str())
566 .collect::<Vec<_>>()
567 .join(", ")
568 ))
569 })
570 }
571}