use minijinja::context;
mod from_pretrained;
pub use from_pretrained::FromPretrainedParameters;
/// A single chat message: who is speaking (`role`) and what was said (`content`).
///
/// Serializes as an object with `role` and `content` keys (in that order). Presumably
/// intended as the element type of the `messages` handed to
/// [`AutoTokenizer::apply_chat_template`] — confirm against callers.
#[derive(Debug, serde::Deserialize, serde::Serialize, Clone)]
pub struct DefaultPromptMessage {
/// Speaker role, e.g. whatever role names the chat template expects.
pub role: String,
/// Message text.
pub content: String,
}
impl DefaultPromptMessage {
    /// Builds a message by copying the borrowed `role` and `content` into owned strings.
    pub fn new(role: &str, content: &str) -> Self {
        let (role, content) = (String::from(role), String::from(content));
        Self { role, content }
    }
}
/// The object form of a special token in the config (as opposed to a bare string),
/// presumably the serialized Hugging Face `AddedToken` shape — confirm against configs.
///
/// Only `content` is required. `__type` and the boolean flags fall back to `""` / `false`
/// when absent. Otherwise a token object missing one of those keys would make the
/// surrounding untagged [`Token`] fail to match any variant, which aborts parsing of the
/// whole [`AutoTokenizer`]. An empty `token_type` is omitted again when serializing.
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
pub struct TokenObj {
    /// The `__type` tag; carried through but not interpreted by this code.
    #[serde(rename = "__type", default, skip_serializing_if = "String::is_empty")]
    pub token_type: String,
    /// The literal token text.
    pub content: String,
    // The flags below are carried through unchanged but are not read by this code;
    // presumably they mirror the Hugging Face `AddedToken` options.
    #[serde(default)]
    pub lstrip: bool,
    #[serde(default)]
    pub normalized: bool,
    #[serde(default)]
    pub rstrip: bool,
    #[serde(default)]
    pub single_word: bool,
}
/// A special token as written in the config: either a bare string or a [`TokenObj`].
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order: a JSON string
/// becomes [`Token::String`], and an object is attempted as [`Token::TokenObj`]. If neither
/// matches, deserialization of the containing value fails.
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
#[serde(untagged)]
pub enum Token {
String(String),
TokenObj(TokenObj),
}
impl Token {
pub fn content(&self) -> &str {
match self {
Token::String(s) => s.as_str(),
Token::TokenObj(obj) => obj.content.as_str(),
}
}
}
/// One entry of a multi-template `chat_template` list: a name plus its template source.
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
pub struct ChatTemplateEntry {
/// Name used to select this entry (matched exactly by `ChatTemplate::resolve`).
pub name: String,
/// Jinja template source; this is what gets compiled when the entry is selected.
pub template: String,
}
/// The config's `chat_template` value: one template string, or a list of named templates.
///
/// `#[serde(untagged)]`: a JSON string parses as [`ChatTemplate::Single`], and a JSON array
/// of `{"name": ..., "template": ...}` objects parses as [`ChatTemplate::Multiple`].
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
#[serde(untagged)]
pub enum ChatTemplate {
Single(String),
Multiple(Vec<ChatTemplateEntry>),
}
impl ChatTemplate {
    /// Picks the template source to render.
    ///
    /// * [`ChatTemplate::Single`] is returned as-is; `name` is ignored.
    /// * [`ChatTemplate::Multiple`] with `Some(name)`: the entry whose name matches exactly,
    ///   or `None` if there is none.
    /// * [`ChatTemplate::Multiple`] with `None`: the entry named `"default"`, falling back
    ///   to the first entry; an empty list yields `None`.
    pub fn resolve(&self, name: Option<&str>) -> Option<&str> {
        let entries = match self {
            ChatTemplate::Single(template) => return Some(template.as_str()),
            ChatTemplate::Multiple(entries) => entries,
        };
        let by_name = |wanted: &str| entries.iter().find(|entry| entry.name == wanted);
        let chosen = match name {
            Some(wanted) => by_name(wanted),
            None => by_name("default").or_else(|| entries.first()),
        };
        chosen.map(|entry| entry.template.as_str())
    }
}
/// Lenient deserializer for `AutoTokenizer::model_max_length`.
///
/// Any JSON value is accepted. The result is `Some(n)` only when the value is an integer
/// that fits in a `u64`. Anything else becomes `None` rather than an error: `null`, negative
/// numbers, floats (by default `serde_json` also reads integers too large for `u64` as
/// floats), strings, and so on. An error is returned only if the underlying deserializer
/// itself fails.
fn deserialize_model_max_length<'de, D: serde::Deserializer<'de>>(
    d: D,
) -> Result<Option<u64>, D::Error> {
    let raw = <Option<serde_json::Value> as serde::Deserialize>::deserialize(d)?;
    Ok(raw.as_ref().and_then(serde_json::Value::as_u64))
}
/// The subset of a `tokenizer_config.json` that this crate reads.
///
/// The struct has no `deny_unknown_fields`, so keys not listed here are ignored. Missing
/// `Option` keys deserialize as `None`.
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
pub struct AutoTokenizer {
pub tokenizer_class: Option<String>,
// `default` is required: with `deserialize_with`, serde no longer treats a missing key as
// `None` on its own. The custom deserializer maps non-`u64` values to `None`.
#[serde(default, deserialize_with = "deserialize_model_max_length")]
pub model_max_length: Option<u64>,
pub padding_side: Option<String>,
pub truncation_side: Option<String>,
pub add_bos_token: Option<bool>,
pub add_eos_token: Option<bool>,
pub clean_up_tokenization_spaces: Option<bool>,
pub legacy: Option<bool>,
// Special tokens; each may be written as a bare string or as a token object (see `Token`).
pub bos_token: Option<Token>,
pub eos_token: Option<Token>,
pub pad_token: Option<Token>,
pub unk_token: Option<Token>,
pub sep_token: Option<Token>,
pub cls_token: Option<Token>,
pub mask_token: Option<Token>,
/// Jinja chat template(s) used by `apply_chat_template`.
pub chat_template: Option<ChatTemplate>,
}
impl AutoTokenizer {
pub fn from_file<P: AsRef<std::path::Path>>(
file: P,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
let content = std::fs::read_to_string(file)?;
Ok(serde_json::from_str(&content)?)
}
pub fn from_pretrained(
identifier: impl AsRef<str>,
params: Option<FromPretrainedParameters>,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
let tokenizer_file = from_pretrained::from_pretrained(identifier, params)?;
AutoTokenizer::from_file(tokenizer_file)
}
pub fn apply_chat_template<S: serde::Serialize>(
&self,
messages: S,
add_generation_prompt: bool,
template_name: Option<&str>,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let template_str = self
.chat_template
.as_ref()
.and_then(|t| t.resolve(template_name))
.ok_or("tokenizer_config.json does not contain a chat_template")?;
let mut env = minijinja::Environment::new();
env.add_template("chat", template_str)?;
env.add_function(
"raise_exception",
|msg: String| -> Result<(), minijinja::Error> {
Err(minijinja::Error::new(
minijinja::ErrorKind::UndefinedError,
msg,
))
},
);
let eos = self.eos_token.as_ref().map_or("", Token::content);
let bos = self.bos_token.as_ref().map_or("", Token::content);
let pad = self.pad_token.as_ref().map_or("", Token::content);
let unk = self.unk_token.as_ref().map_or("", Token::content);
let tmpl = env.get_template("chat")?;
Ok(tmpl.render(context! {
messages => messages,
eos_token => eos,
bos_token => bos,
pad_token => pad,
unk_token => unk,
add_generation_prompt => add_generation_prompt,
})?)
}
}