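//! autotokenizer 0.1.4
//!
//! 我就只是想要 Rust 能有一個簡單的、從 Hugging Face(HF)上拉下 config 並製作 chat prompt 的工具,也這麼難!要我重新發明輪子,天啊!
//!
//! (I just wanted Rust to have a simple tool that pulls the tokenizer config from the
//! Hugging Face Hub and builds a chat prompt — is even that so hard? Now I have to
//! reinvent the wheel!)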
use minijinja::context;

mod from_pretrained;

pub use from_pretrained::FromPretrainedParameters;

/// One chat turn (a `role` plus its `content`) to pass to
/// [`AutoTokenizer::apply_chat_template`].
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct DefaultPromptMessage {
    /// Role string exposed to the template as `message.role`.
    pub role: String,
    /// Text exposed to the template as `message.content`.
    pub content: String,
}

impl DefaultPromptMessage {
    /// Build a message by copying the borrowed `role` and `content` strings.
    pub fn new(role: &str, content: &str) -> Self {
        let (role, content) = (String::from(role), String::from(content));
        Self { role, content }
    }
}

/// The expanded form of a special token as stored in `tokenizer_config.json`.
///
/// Not every file writes the same set of keys. Some include
/// `"__type": "AddedToken"` and all four flags; others omit the tag or some
/// flags. Only `content` is required. Every other field falls back to its
/// `Default` (empty string / `false`), so either shape parses instead of
/// failing the whole [`AutoTokenizer`] deserialization.
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
pub struct TokenObj {
    /// Value of the `__type` key; empty when the key is absent.
    /// An empty tag is not written back out, so re-serializing does not
    /// invent a key the source file never had.
    #[serde(rename = "__type", default, skip_serializing_if = "String::is_empty")]
    pub token_type: String,
    /// The literal token text.
    pub content: String,
    /// `lstrip` flag from the config (`false` when absent).
    #[serde(default)]
    pub lstrip: bool,
    /// `normalized` flag from the config (`false` when absent).
    #[serde(default)]
    pub normalized: bool,
    /// `rstrip` flag from the config (`false` when absent).
    #[serde(default)]
    pub rstrip: bool,
    /// `single_word` flag from the config (`false` when absent).
    #[serde(default)]
    pub single_word: bool,
}

/// A special token, written in the config either as a bare string or as an
/// expanded [`TokenObj`]; `untagged` lets serde accept whichever shape is
/// present.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
#[serde(untagged)]
pub enum Token {
    /// Bare-string form, e.g. `"<s>"`.
    String(String),
    /// Expanded object form.
    TokenObj(TokenObj),
}

impl Token {
    /// The token text, independent of which form the config used.
    pub fn content(&self) -> &str {
        let text: &String = match self {
            Self::String(raw) => raw,
            Self::TokenObj(expanded) => &expanded.content,
        };
        text.as_str()
    }
}

/// A named Jinja template inside a list-valued `chat_template`.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ChatTemplateEntry {
    /// Lookup key for this entry (for example `"default"` or `"tool_use"`).
    pub name: String,
    /// Jinja source of this template.
    pub template: String,
}

/// The value of `chat_template` in `tokenizer_config.json`.
///
/// It is stored either as one Jinja string or as a list of
/// `{"name": …, "template": …}` objects (e.g. for tool-use models); `untagged`
/// lets serde accept both.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
#[serde(untagged)]
pub enum ChatTemplate {
    /// A bare template string.
    Single(String),
    /// Several templates, selected by name.
    Multiple(Vec<ChatTemplateEntry>),
}

impl ChatTemplate {
    /// Look up the Jinja source to use for `name`.
    ///
    /// * `Single` always yields its one template; `name` is not consulted.
    /// * `Multiple` with `Some(name)` yields the entry whose `name` matches
    ///   exactly, or `None` if there is no such entry.
    /// * `Multiple` with `None` prefers the entry called `"default"` and
    ///   otherwise falls back to the first entry (`None` for an empty list).
    pub fn resolve(&self, name: Option<&str>) -> Option<&str> {
        let entries = match self {
            Self::Single(source) => return Some(source.as_str()),
            Self::Multiple(entries) => entries,
        };
        let by_name = |wanted: &str| entries.iter().find(|entry| entry.name == wanted);
        let picked = match name {
            Some(wanted) => by_name(wanted),
            None => by_name("default").or_else(|| entries.first()),
        };
        picked.map(|entry| entry.template.as_str())
    }
}

/// Lenient deserializer for `model_max_length`.
///
/// Accepts any JSON value and keeps it only when it converts to `u64`
/// exactly:
/// * integers that fit in `u64` are kept as-is;
/// * floats are kept when finite, non-negative, integral and below 2^64
///   (e.g. `4096.0`).
///
/// Everything else becomes `None`. That includes negative or fractional
/// numbers, values too large for `u64` (such as the ~1e30 "unlimited"
/// sentinel some models, e.g. Llama, write), and non-numbers like strings,
/// booleans or `null`.
fn deserialize_model_max_length<'de, D: serde::Deserializer<'de>>(
    d: D,
) -> Result<Option<u64>, D::Error> {
    let v: Option<serde_json::Value> = serde::Deserialize::deserialize(d)?;
    Ok(v.and_then(|v| {
        v.as_u64()
            .or_else(|| v.as_f64().and_then(integral_f64_to_u64))
    }))
}

/// Convert `f` to `u64` only when the conversion is exact.
fn integral_f64_to_u64(f: f64) -> Option<u64> {
    // `u64::MAX as f64` rounds up to exactly 2^64, which is not a valid u64.
    // A strict `<` therefore admits only values that `as` converts without
    // saturating.
    let exact = f.is_finite() && f >= 0.0 && f.fract() == 0.0 && f < u64::MAX as f64;
    exact.then_some(f as u64)
}

/// Deserialized view of a Hugging Face `tokenizer_config.json`.
///
/// Only the keys needed to render chat templates and identify special tokens
/// are modelled. No `deny_unknown_fields` is set, so serde skips any other
/// keys in the file.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct AutoTokenizer {
    // General tokenizer settings, copied as-is from the config.
    pub tokenizer_class: Option<String>,
    /// Parsed by `deserialize_model_max_length`. `None` when the key is absent
    /// or its value is not converted by that function.
    #[serde(default, deserialize_with = "deserialize_model_max_length")]
    pub model_max_length: Option<u64>,
    pub padding_side: Option<String>,
    pub truncation_side: Option<String>,
    pub add_bos_token: Option<bool>,
    pub add_eos_token: Option<bool>,
    pub clean_up_tokenization_spaces: Option<bool>,
    pub legacy: Option<bool>,

    // Special tokens: each may be written as a bare string or as an object.
    pub bos_token: Option<Token>,
    pub eos_token: Option<Token>,
    pub pad_token: Option<Token>,
    pub unk_token: Option<Token>,
    pub sep_token: Option<Token>,
    pub cls_token: Option<Token>,
    pub mask_token: Option<Token>,

    /// Jinja source (a single string or a named list) consumed by
    /// `apply_chat_template`.
    pub chat_template: Option<ChatTemplate>,
}

impl AutoTokenizer {
    /// Load a tokenizer config from a local `tokenizer_config.json` file.
    pub fn from_file<P: AsRef<std::path::Path>>(
        file: P,
    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        let content = std::fs::read_to_string(file)?;
        Ok(serde_json::from_str(&content)?)
    }

    /// Download and cache `tokenizer_config.json` from the Hugging Face Hub,
    /// then parse it into an [`AutoTokenizer`].
    pub fn from_pretrained(
        identifier: impl AsRef<str>,
        params: Option<FromPretrainedParameters>,
    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        let tokenizer_file = from_pretrained::from_pretrained(identifier, params)?;
        AutoTokenizer::from_file(tokenizer_file)
    }

    /// Render `messages` through the Jinja chat template.
    ///
    /// `template_name` selects a specific entry when the model provides
    /// multiple named templates (e.g. `Some("tool_use")`).  Pass `None` to
    /// use the default (or only) template.
    pub fn apply_chat_template<S: serde::Serialize>(
        &self,
        messages: S,
        add_generation_prompt: bool,
        template_name: Option<&str>,
    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
        let template_str = self
            .chat_template
            .as_ref()
            .and_then(|t| t.resolve(template_name))
            .ok_or("tokenizer_config.json does not contain a chat_template")?;

        let mut env = minijinja::Environment::new();
        env.add_template("chat", template_str)?;
        env.add_function(
            "raise_exception",
            |msg: String| -> Result<(), minijinja::Error> {
                Err(minijinja::Error::new(
                    minijinja::ErrorKind::UndefinedError,
                    msg,
                ))
            },
        );

        let eos = self.eos_token.as_ref().map_or("", Token::content);
        let bos = self.bos_token.as_ref().map_or("", Token::content);
        let pad = self.pad_token.as_ref().map_or("", Token::content);
        let unk = self.unk_token.as_ref().map_or("", Token::content);

        let tmpl = env.get_template("chat")?;
        Ok(tmpl.render(context! {
            messages => messages,
            eos_token => eos,
            bos_token => bos,
            pad_token => pad,
            unk_token => unk,
            add_generation_prompt => add_generation_prompt,
        })?)
    }
}