use hf_hub::api::sync::ApiBuilder;
use hf_hub::{Repo, RepoType};
use serde::{Deserialize, Serialize};
use tokenizers::{FromPretrainedParameters, Tokenizer};
use crate::primitives::*;
/// Places where an EOS token is looked for, in the order they are probed.
///
/// `HFLocator::locate_eos_token_id` walks this slice front to back and stops at
/// the first entry whose lookup yields a token id, so earlier entries take
/// precedence over later ones.
const COMMON_LOCATIONS: &[EosTokenLocation] = &[
    // `{"eos_token_id": <number>}` in the generation config.
    EosTokenLocation {
        file: "generation_config.json",
        location: EosTokenField::Id,
    },
    // `{"eos_token": "<text>"}` in the tokenizer config.
    EosTokenLocation {
        file: "tokenizer_config.json",
        location: EosTokenField::Value,
    },
    // `{"eos_token": {"content": "<text>"}}` in the tokenizer config.
    EosTokenLocation {
        file: "tokenizer_config.json",
        location: EosTokenField::Object,
    },
];
/// JSON shape that carries the EOS token directly as a numeric id,
/// e.g. `{"eos_token_id": 50256}`.
#[derive(Debug, Serialize, Deserialize)]
struct Id {
    /// Numeric id of the EOS token as written in the config file.
    eos_token_id: u64,
}
/// JSON shape that carries the EOS token as plain text,
/// e.g. `{"eos_token": "</s>"}`.
#[derive(Debug, Serialize, Deserialize)]
struct Value {
    /// Text of the EOS token; it still has to be mapped to an id by a tokenizer.
    eos_token: String,
}
/// JSON shape that carries the EOS token as a nested object,
/// e.g. `{"eos_token": {"content": "</s>"}}`.
#[derive(Debug, Serialize, Deserialize)]
struct Object {
    /// Nested object holding the token text.
    eos_token: Content,
}
/// Inner object of [`Object`]: the token text stored under a `content` key.
#[derive(Debug, Serialize, Deserialize)]
struct Content {
    /// Text of the token.
    content: String,
}
/// How the EOS token is encoded inside a configuration file.
///
/// Each variant selects the JSON shape a lookup deserializes the file into.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EosTokenField {
    /// The file holds a numeric `eos_token_id`.
    Id,
    /// The file holds `eos_token` as a plain string.
    Value,
    /// The file holds `eos_token` as an object with a `content` string.
    Object,
}
/// A single place to look for an EOS token: a file name plus the JSON shape
/// expected inside that file.
struct EosTokenLocation {
    /// Name of the file to fetch from the model repository.
    file: &'static str,
    /// Encoding of the EOS token expected in `file`.
    location: EosTokenField,
}
/// Strategy for finding the id of a model's end-of-sequence token.
pub(crate) trait Locator {
    /// Returns the EOS token id for `model`, or `None` when it cannot be found.
    ///
    /// * `model` - identifier of the model to inspect.
    /// * `tokenizer` - tokenizer available to the implementation for mapping
    ///   token text to an id.
    /// * `parameters` - optional pretrained-download parameters passed through
    ///   to the implementation.
    fn locate_eos_token_id(
        model: &str,
        tokenizer: &Tokenizer,
        parameters: &Option<FromPretrainedParameters>,
    ) -> Option<TokenId>;
}
/// [`Locator`] that probes the entries of `COMMON_LOCATIONS`.
pub(crate) struct HFLocator;

impl Locator for HFLocator {
    /// Tries every entry of `COMMON_LOCATIONS` in order and returns the first
    /// token id that a lookup produces; `None` when every entry fails.
    fn locate_eos_token_id(
        model: &str,
        tokenizer: &Tokenizer,
        parameters: &Option<FromPretrainedParameters>,
    ) -> Option<TokenId> {
        for candidate in COMMON_LOCATIONS {
            if let Some(token_id) = candidate.lookup(model, tokenizer, parameters) {
                return Some(token_id);
            }
        }
        None
    }
}
impl EosTokenLocation {
    /// Tries to resolve the EOS token id of `model` from this location.
    ///
    /// Fetches `self.file` for the model, parses it as JSON in the shape
    /// selected by `self.location`, and turns the result into a token id. For
    /// the `Value` and `Object` shapes the token text is looked up in
    /// `tokenizer`.
    ///
    /// Returns `None` when the file cannot be fetched or opened, does not match
    /// the expected shape, holds an id that does not fit the token id type, or
    /// names a token the tokenizer does not know.
    fn lookup(
        &self,
        model: &str,
        tokenizer: &Tokenizer,
        parameters: &Option<FromPretrainedParameters>,
    ) -> Option<TokenId> {
        let file_path = Self::download_config(model, self.file, parameters).ok()?;
        let file = std::fs::File::open(file_path).ok()?;
        // `serde_json::from_reader` does no buffering of its own, so wrap the
        // file to avoid issuing many tiny reads against it.
        let reader = std::io::BufReader::new(file);

        match self.location {
            EosTokenField::Id => {
                let config: Id = serde_json::from_reader(reader).ok()?;
                // Checked narrowing: an out-of-range id yields `None`, not a wrapped value.
                u32::try_from(config.eos_token_id).ok()
            }
            EosTokenField::Value => {
                let config: Value = serde_json::from_reader(reader).ok()?;
                tokenizer.token_to_id(&config.eos_token)
            }
            EosTokenField::Object => {
                let config: Object = serde_json::from_reader(reader).ok()?;
                tokenizer.token_to_id(&config.eos_token.content)
            }
        }
    }

    /// Fetches `file` from the `project` model repository and returns the local
    /// path reported by the hub API.
    ///
    /// The revision and token are taken from `parameters`, falling back to the
    /// defaults of `FromPretrainedParameters` when it is `None`.
    ///
    /// # Errors
    ///
    /// Fails when `project` or the revision contains characters rejected by
    /// [`Self::validate`], when the API client cannot be built, or when the
    /// file cannot be retrieved.
    fn download_config(
        project: &str,
        file: &str,
        parameters: &Option<FromPretrainedParameters>,
    ) -> tokenizers::Result<std::path::PathBuf> {
        // Reject a bad project name before doing any other work.
        Self::validate(project)?;

        let params = parameters.clone().unwrap_or_default();
        Self::validate(&params.revision)?;

        let repo = Repo::with_revision(project.to_string(), RepoType::Model, params.revision);
        let api = ApiBuilder::new()
            .with_token(params.token)
            .build()?
            .repo(repo);

        Ok(api.get(file)?)
    }

    /// Checks that `input` consists only of alphanumeric characters and the
    /// separators `-`, `_`, `.` and `/`.
    ///
    /// An empty string passes, since it contains no offending character.
    ///
    /// # Errors
    ///
    /// Returns an error naming `input` and the accepted separators when any
    /// other character is present.
    fn validate(input: &str) -> tokenizers::Result<()> {
        let valid_chars = ['-', '_', '.', '/'];

        if input
            .chars()
            .all(|c| c.is_alphanumeric() || valid_chars.contains(&c))
        {
            return Ok(());
        }

        let allowed = valid_chars
            .iter()
            .map(|c| format!("'{c}'"))
            .collect::<Vec<_>>()
            .join(", ");
        Err(format!(
            "Input {input} contains invalid characters, expected only alphanumeric or {allowed}"
        )
        .into())
    }
}
// NOTE(review): apart from `validate_config_input`, these tests load tokenizers
// and config files for real model names, so they presumably need network access
// or a warm local cache — confirm before running them offline.
#[cfg(test)]
mod tests {
    use super::*;

    /// Known models resolve to the expected EOS id, and that id maps back to
    /// the expected token text.
    #[test]
    fn common_locations() {
        let cases = [
            ("openai-community/gpt2", 50256, "<|endoftext|>"),
            ("microsoft/phi-2", 50256, "<|endoftext|>"),
            ("hf-internal-testing/llama-tokenizer", 2, "</s>"),
        ];

        for (model, expected_token_id, expected_token) in &cases {
            let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");

            let located = HFLocator::locate_eos_token_id(model, &tokenizer, &None)
                .expect("Token id is not located");
            assert_eq!(located, *expected_token_id);

            let token = tokenizer.id_to_token(located).expect("Token is not found");
            assert_eq!(token, *expected_token);
        }
    }

    /// A file paired with the wrong field shape yields no token id.
    #[test]
    fn bad_location() {
        let model = "microsoft/phi-2";
        let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");

        let mismatched = [
            // Tokenizer config probed for a numeric id.
            EosTokenLocation {
                file: "tokenizer_config.json",
                location: EosTokenField::Id,
            },
            // Generation config probed for a textual token.
            EosTokenLocation {
                file: "generation_config.json",
                location: EosTokenField::Value,
            },
        ];

        for location in &mismatched {
            let token_id = location.lookup(model, &tokenizer, &None);
            assert!(token_id.is_none());
        }
    }

    /// Input containing a character outside the accepted set is rejected.
    #[test]
    fn validate_config_input() {
        let outcome = EosTokenLocation::validate("bad_model_name*");
        assert!(outcome.is_err());
    }
}