Skip to main content

fastembed/models/
sparse.rs

1use std::{fmt::Display, str::FromStr};
2
3use crate::ModelInfo;
4
/// Sparse embedding models known to this module.
///
/// The textual form of each variant (used by the `Display` and `FromStr`
/// impls in this file) is the `model_code` of its entry in `models_list()`.
/// The enum is fieldless, so it is `Copy` and `Hash` and can be used directly
/// as a map/set key.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SparseModel {
    /// prithivida/Splade_PP_en_v1 (registered under the model code
    /// `Qdrant/Splade_PP_en_v1`). This is the default model.
    #[default]
    SPLADEPPV1,
    /// BAAI/bge-m3
    BGEM3,
}
13
14pub fn models_list() -> Vec<ModelInfo<SparseModel>> {
15    vec![
16        ModelInfo {
17            model: SparseModel::SPLADEPPV1,
18            dim: 0,
19            description: String::from("Splade sparse vector model for commercial use, v1"),
20            model_code: String::from("Qdrant/Splade_PP_en_v1"),
21            model_file: String::from("model.onnx"),
22            additional_files: Vec::new(),
23            output_key: None,
24        },
25        ModelInfo {
26            model: SparseModel::BGEM3,
27            dim: 0,
28            description: String::from(
29                "BGE-M3 sparse embedding model with 8192 context, supports 100+ languages",
30            ),
31            model_code: String::from("BAAI/bge-m3"),
32            model_file: String::from("onnx/model.onnx"),
33            additional_files: vec![
34                "onnx/model.onnx_data".to_string(),
35                "onnx/Constant_7_attr__value".to_string(),
36            ],
37            output_key: None,
38        },
39    ]
40}
41
42impl Display for SparseModel {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        let model_info = models_list()
45            .into_iter()
46            .find(|model| model.model == *self)
47            .ok_or(std::fmt::Error)?;
48        write!(f, "{}", model_info.model_code)
49    }
50}
51
52impl FromStr for SparseModel {
53    type Err = String;
54
55    fn from_str(s: &str) -> Result<Self, Self::Err> {
56        models_list()
57            .into_iter()
58            .find(|m| m.model_code.eq_ignore_ascii_case(s))
59            .map(|m| m.model)
60            .ok_or_else(|| format!("Unknown sparse model: {s}"))
61    }
62}
63
64impl TryFrom<String> for SparseModel {
65    type Error = String;
66
67    fn try_from(value: String) -> Result<Self, Self::Error> {
68        value.parse()
69    }
70}