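//! `encoderfile/common/model_type.rs`
//!
//! Defines the [`ModelType`] enum and one zero-sized marker type per variant,
//! linked together through the [`ModelTypeSpec`] trait.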

/// Declares the `ModelType` enum together with one unit marker struct per variant.
///
/// For every identifier passed in, the expansion generates:
///
/// * a variant of `ModelType`, serialized by serde in `snake_case`
///   (e.g. `SequenceClassification` <-> `"sequence_classification"`);
/// * a unit struct of the same name implementing `ModelTypeSpec`, whose
///   `enum_val()` returns the matching variant.
///
/// `ModelType` also gets `FromStr` and `Display` impls that go through the serde
/// representation, so parsing and printing use the same `snake_case` names as
/// serialization.
///
/// Attributes (including `///` doc comments) written before an identifier are
/// forwarded to the corresponding enum variant only, not to the marker struct.
///
/// The expansion names `ModelTypeSpec` unqualified, so that trait must be in scope
/// at the invocation site.
macro_rules! model_type {
    [ $( $(#[$meta:meta])* $x:ident ),* $(,)? ] => {
        // The set of model kinds; one variant per identifier given to `model_type!`.
        // Plain `//` comments are used here on purpose: `///` docs on this enum would
        // change the descriptions emitted by the utoipa / schemars derives.
        #[derive(
            Debug,
            Clone,
            PartialEq,
            Eq,
            Hash,
            serde::Serialize,
            serde::Deserialize,
            utoipa::ToSchema,
            schemars::JsonSchema
        )]
        #[serde(rename_all = "snake_case")]
        pub enum ModelType {
            $(
                $(#[$meta])*
                $x,
            )*
        }

        impl std::str::FromStr for ModelType {
            type Err = String;

            // Deserialize from a JSON string so the accepted spellings are exactly
            // the serde (`snake_case`) names; any other input is rejected.
            fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
                serde_json::from_value::<ModelType>(serde_json::Value::String(s.to_string()))
                    .map_err(|_| format!("Invalid model type: {}", s))
            }
        }

        impl std::fmt::Display for ModelType {
            // Prints the serde (`snake_case`) name, mirroring `FromStr`.
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                let value = serde_json::to_value(self).map_err(|_| std::fmt::Error)?;
                let name = value.as_str().ok_or(std::fmt::Error)?;
                // `Formatter::pad` is what `<str as Display>::fmt` does (it honours
                // width, fill, alignment and precision). Unlike `name.fmt(f)`, it is an
                // inherent method, so it neither requires `Display` to be imported at
                // the call site nor becomes ambiguous when `Debug` is also in scope.
                f.pad(name)
            }
        }

        $(
            #[doc = concat!("Type-level marker for [`ModelType::", stringify!($x), "`].")]
            #[derive(Debug, Clone)]
            pub struct $x;

            impl ModelTypeSpec for $x {
                fn enum_val() -> ModelType {
                    ModelType::$x
                }
            }
        )*
    }
}
43
44model_type![
45    Embedding,
46    SequenceClassification,
47    TokenClassification,
48    SentenceEmbedding,
49];
50
51pub trait ModelTypeSpec: Send + Sync + Clone + std::fmt::Debug + 'static {
52    fn enum_val() -> ModelType;
53}