alith_models/local_model/metadata/tokenizer.rs

1use std::str::FromStr;
2
3#[derive(Clone)]
4pub struct TokenizerMetadata {
5    pub ggml: Option<GgmlTokenizerMetadata>,
6    pub huggingface_json: Option<String>,
7    pub rwkv_world: Option<String>,
8    pub chat_template: Option<String>,
9}
10
11#[derive(Debug, Clone)]
12pub struct GgmlTokenizerMetadata {
13    pub model: GgmlTokenizerModel,
14    pub tokens: Vec<String>,
15    pub scores: Option<Vec<f32>>,
16    pub merges: Option<Vec<String>>,
17    pub added_tokens: Option<Vec<String>>,
18    pub bos_token_id: u32,
19    pub eos_token_id: u32,
20    pub unknown_token_id: Option<u32>,
21    pub separator_token_id: Option<u32>,
22    pub padding_token_id: Option<u32>,
23}
24
/// Tokenizer family named by the GGUF key `tokenizer.ggml.model`.
///
/// The enum is fieldless, so it is `Copy` and can be compared with `==`,
/// hashed, and passed around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GgmlTokenizerModel {
    /// Stored in GGUF as `"llama"`.
    Llama,
    /// Stored in GGUF as `"replit"`.
    Replit,
    /// Stored in GGUF as `"gpt2"`.
    Gpt2,
    /// Stored in GGUF as `"rwkv"`.
    Rwkv,
}
32
33impl FromStr for GgmlTokenizerModel {
34    type Err = anyhow::Error;
35
36    fn from_str(s: &str) -> Result<Self, Self::Err> {
37        Ok(match s {
38            "llama" => Self::Llama,
39            "replit" => Self::Replit,
40            "gpt2" => Self::Gpt2,
41            "rwkv" => Self::Rwkv,
42            _ => crate::bail!("Unknown GGML tokenizer model: {}", s),
43        })
44    }
45}
46
47impl GgmlTokenizerModel {
48    pub fn to_str(&self) -> &str {
49        match self {
50            Self::Llama => "llama",
51            Self::Replit => "replit",
52            Self::Gpt2 => "gpt2",
53            Self::Rwkv => "rwkv",
54        }
55    }
56}
57
58impl TokenizerMetadata {
59    pub fn from_gguf(
60        gguf: &crate::local_model::gguf::tools::gguf_file::GgufFile,
61    ) -> crate::Result<Self> {
62        if gguf
63            .get_value::<Option<String>>("tokenizer.ggml.model")?
64            .is_some()
65        {
66            return Ok(Self {
67                ggml: Some(GgmlTokenizerMetadata::from_gguf(gguf)?),
68                huggingface_json: gguf.get_value("tokenizer.huggingface.json")?,
69                rwkv_world: gguf.get_value("tokenizer.rwkv_world")?,
70                chat_template: gguf.get_value("tokenizer.chat_template")?,
71            });
72        }
73        Ok(Self {
74            ggml: None,
75            huggingface_json: gguf.get_value("tokenizer.huggingface.json")?,
76            rwkv_world: gguf.get_value("tokenizer.rwkv_world")?,
77            chat_template: gguf.get_value("tokenizer.chat_template")?,
78        })
79    }
80}
81
82impl GgmlTokenizerMetadata {
83    pub fn from_gguf(
84        gguf: &crate::local_model::gguf::tools::gguf_file::GgufFile,
85    ) -> crate::Result<Self> {
86        let model_string: String = gguf.get_value("tokenizer.ggml.model")?;
87
88        Ok(Self {
89            model: GgmlTokenizerModel::from_str(&model_string)?,
90            tokens: gguf.get_value("tokenizer.ggml.tokens")?,
91            scores: gguf.get_value("tokenizer.ggml.scores")?,
92            merges: gguf.get_value("tokenizer.ggml.merges")?,
93            added_tokens: gguf.get_value("tokenizer.ggml.added_tokens")?,
94            bos_token_id: gguf.get_value("tokenizer.ggml.bos_token_id")?,
95            eos_token_id: gguf.get_value("tokenizer.ggml.eos_token_id")?,
96            unknown_token_id: gguf.get_value("tokenizer.ggml.unknown_token_id")?,
97            separator_token_id: gguf.get_value("tokenizer.ggml.separator_token_id")?,
98            padding_token_id: gguf.get_value("tokenizer.ggml.padding_token_id")?,
99        })
100    }
101}
102
103impl std::fmt::Debug for TokenizerMetadata {
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        let mut debug_struct = f.debug_struct("TokenizerMetadata");
106        if let Some(ggml) = &self.ggml {
107            debug_struct.field("GgmlTokenizerModel", &ggml.model.to_str());
108            debug_struct.field("bos_token_id", &ggml.bos_token_id);
109            debug_struct.field("eos_token_id", &ggml.eos_token_id);
110            if let Some(unknown_token_id) = ggml.unknown_token_id {
111                debug_struct.field("unknown_token_id", &unknown_token_id);
112            }
113            if let Some(separator_token_id) = ggml.separator_token_id {
114                debug_struct.field("separator_token_id", &separator_token_id);
115            }
116        }
117        debug_struct.finish()
118    }
119}