// erniebot_rs/embedding/model.rs

1use serde::{Deserialize, Serialize};
2use strum_macros::{Display, EnumString};
3
4#[derive(Debug, Default, Clone, Serialize, Deserialize, EnumString, Display, PartialEq, Eq)]
5#[non_exhaustive]
6pub enum EmbeddingModel {
7    #[default]
8    #[strum(serialize = "embedding-v1")]
9    #[serde(rename = "embedding-v1")]
10    EmbeddingV1,
11    #[strum(serialize = "bge_large_zh")]
12    #[serde(rename = "bge_large_zh")]
13    BgeLargeZh,
14    #[strum(serialize = "bge_large_en")]
15    #[serde(rename = "bge_large_en")]
16    BgeLagreEn,
17    #[strum(serialize = "tao_8k")]
18    #[serde(rename = "tao_8k")]
19    Tao8k,
20}
21
#[cfg(test)]
mod tests {
    use super::EmbeddingModel;
    use std::str::FromStr;

    /// Every variant paired with the identifier it must format to and parse from.
    fn cases() -> [(EmbeddingModel, &'static str); 4] {
        [
            (EmbeddingModel::EmbeddingV1, "embedding-v1"),
            (EmbeddingModel::BgeLargeZh, "bge_large_zh"),
            (EmbeddingModel::BgeLagreEn, "bge_large_en"),
            (EmbeddingModel::Tao8k, "tao_8k"),
        ]
    }

    #[test]
    fn test_embedding_model_to_string() {
        for (model, identifier) in cases() {
            assert_eq!(model.to_string(), identifier);
        }
    }

    #[test]
    fn test_embedding_model_from_str() {
        for (model, identifier) in cases() {
            let parsed = EmbeddingModel::from_str(identifier).unwrap();
            assert_eq!(parsed, model);
        }
    }
}