use std::path::PathBuf;
use hf_hub::api::sync::{Api, ApiBuilder};
use flodl::{Device, Graph, Result, TensorError};
use crate::models::auto::{
AutoConfig, AutoModel, AutoModelForQuestionAnswering, AutoModelForSequenceClassification,
AutoModelForTokenClassification,
};
use crate::models::bert::{
BertConfig, BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification, BertModel,
};
use crate::models::distilbert::{
DistilBertConfig, DistilBertForQuestionAnswering, DistilBertForSequenceClassification,
DistilBertForTokenClassification, DistilBertModel,
};
use crate::models::roberta::{
RobertaConfig, RobertaForQuestionAnswering, RobertaForSequenceClassification,
RobertaForTokenClassification, RobertaModel,
};
use crate::safetensors_io::{
bert_legacy_key_rename, load_safetensors_into_graph_with_rename_allow_unused,
};
#[cfg(feature = "tokenizer")]
use crate::tokenizer::HfTokenizer;
impl BertModel {
    /// Downloads `repo_id` from the Hugging Face Hub and builds it on the CPU.
    pub fn from_pretrained(repo_id: &str) -> Result<Graph> {
        Self::from_pretrained_on_device(repo_id, Device::CPU)
    }

    /// Downloads config + weights for `repo_id`, builds the model on `device`,
    /// and loads the checkpoint into the resulting graph.
    pub fn from_pretrained_on_device(repo_id: &str, device: Device) -> Result<Graph> {
        let (config, weights) = fetch_bert_config_and_weights(repo_id)?;
        let graph = Self::on_device(&config, device)?;
        load_weights_with_logging(repo_id, &graph, &weights)?;
        Ok(graph)
    }
}
const HF_HOME_ENV: &str = "HF_HOME";
/// Fallback Hugging Face cache root when `HF_HOME` is unset:
/// `$HOME/.cache/huggingface`, or `/tmp/huggingface` when `$HOME` is missing.
fn default_hf_home() -> PathBuf {
    std::env::var_os("HOME")
        .map(|home| PathBuf::from(home).join(".cache").join("huggingface"))
        .unwrap_or_else(|| PathBuf::from("/tmp/huggingface"))
}
fn flodl_converted_path(repo_id: &str) -> PathBuf {
let hf_home = std::env::var_os(HF_HOME_ENV)
.map(PathBuf::from)
.unwrap_or_else(default_hf_home);
hf_home
.join("flodl-converted")
.join(repo_id)
.join("model.safetensors")
}
/// Resolves the safetensors weights for `repo_id`.
///
/// A locally converted copy (see [`flodl_converted_path`]) takes precedence;
/// otherwise `model.safetensors` is fetched from the Hub. The error message
/// points users at the conversion CLI for repos that only ship a pickle file.
fn fetch_safetensors(api: &Api, repo_id: &str) -> Result<PathBuf> {
    let converted = flodl_converted_path(repo_id);
    if converted.exists() {
        eprintln!(
            "from_pretrained({repo_id}): using flodl-converted safetensors at {}",
            converted.display(),
        );
        return Ok(converted);
    }
    let repo = api.model(repo_id.to_string());
    repo.get("model.safetensors").map_err(|e| {
        TensorError::new(&format!(
            "hf-hub fetch {repo_id}/model.safetensors: {e}\n\
            If this repo ships only `pytorch_model.bin`, convert it first:\n \
            fdl flodl-hf convert {repo_id}",
        ))
    })
}
/// Downloads `config.json` and the safetensors weights for `repo_id`,
/// returning the raw config text together with the weight bytes.
fn fetch_config_str_and_weights(repo_id: &str) -> Result<(String, Vec<u8>)> {
    let api = ApiBuilder::from_env()
        .build()
        .map_err(|e| TensorError::new(&format!("hf-hub init: {e}")))?;
    let repo = api.model(repo_id.to_string());
    let cfg_path = repo
        .get("config.json")
        .map_err(|e| TensorError::new(&format!("hf-hub fetch {repo_id}/config.json: {e}")))?;
    let cfg_text = std::fs::read_to_string(&cfg_path)
        .map_err(|e| TensorError::new(&format!("read {}: {e}", cfg_path.display())))?;
    // Weights may resolve to a locally converted file rather than a Hub download.
    let st_path = fetch_safetensors(&api, repo_id)?;
    let st_bytes = std::fs::read(&st_path)
        .map_err(|e| TensorError::new(&format!("read {}: {e}", st_path.display())))?;
    Ok((cfg_text, st_bytes))
}
/// Fetches config + weights for `repo_id` and parses the config as [`BertConfig`].
fn fetch_bert_config_and_weights(repo_id: &str) -> Result<(BertConfig, Vec<u8>)> {
    let (json, weights) = fetch_config_str_and_weights(repo_id)?;
    BertConfig::from_json_str(&json).map(|cfg| (cfg, weights))
}
/// Fetches config + weights for `repo_id` and parses the config as [`RobertaConfig`].
fn fetch_roberta_config_and_weights(repo_id: &str) -> Result<(RobertaConfig, Vec<u8>)> {
    let (json, weights) = fetch_config_str_and_weights(repo_id)?;
    RobertaConfig::from_json_str(&json).map(|cfg| (cfg, weights))
}
/// Fetches config + weights for `repo_id` and parses the config as [`DistilBertConfig`].
fn fetch_distilbert_config_and_weights(repo_id: &str) -> Result<(DistilBertConfig, Vec<u8>)> {
    let (json, weights) = fetch_config_str_and_weights(repo_id)?;
    DistilBertConfig::from_json_str(&json).map(|cfg| (cfg, weights))
}
/// Best-effort tokenizer fetch for `repo_id`.
///
/// Returns `None` (after logging a one-line warning) when the tokenizer cannot
/// be loaded; callers then need an explicit `.with_tokenizer()` before using
/// `predict()`/`answer()`.
#[cfg(feature = "tokenizer")]
fn try_load_tokenizer(repo_id: &str) -> Option<HfTokenizer> {
    match HfTokenizer::from_pretrained(repo_id) {
        Ok(tok) => Some(tok),
        Err(e) => {
            // Format the error exactly once (the original rendered it twice:
            // once for the `contains` probe and again for the fallback).
            // A 404 just means the repo ships no tokenizer.json, so keep
            // that common case terse.
            let msg = e.to_string();
            let terse = if msg.contains("404") {
                "no tokenizer.json on Hub".to_string()
            } else {
                msg
            };
            eprintln!(
                "from_pretrained({repo_id}): tokenizer not attached ({terse}) \
                — predict()/answer() need .with_tokenizer()",
            );
            None
        }
    }
}
/// Loads safetensors `bytes` into `graph`, applying the legacy BERT key
/// rename, and logs any checkpoint keys the model does not consume.
///
/// Unused keys are expected for task heads (e.g. a base checkpoint loaded
/// into a classifier), so they are reported rather than treated as errors.
fn load_weights_with_logging(
    repo_id: &str,
    graph: &Graph,
    bytes: &[u8],
) -> Result<()> {
    // Cap the per-key listing so a wildly mismatched checkpoint cannot flood
    // stderr. Named once so the listing loop and the "… and N more" overflow
    // line can never use different limits (the literal 20 appeared 3 times).
    const MAX_LISTED: usize = 20;
    let unused = load_safetensors_into_graph_with_rename_allow_unused(
        graph, bytes, bert_legacy_key_rename,
    )?;
    if !unused.is_empty() {
        eprintln!(
            "from_pretrained({repo_id}): ignored {} checkpoint key(s) not used by the model:",
            unused.len(),
        );
        for k in unused.iter().take(MAX_LISTED) {
            eprintln!(" - {k}");
        }
        if unused.len() > MAX_LISTED {
            eprintln!(" ... and {} more", unused.len() - MAX_LISTED);
        }
    }
    Ok(())
}
impl BertForSequenceClassification {
    /// CPU convenience wrapper around [`Self::from_pretrained_on_device`].
    pub fn from_pretrained(repo_id: &str) -> Result<Self> {
        Self::from_pretrained_on_device(repo_id, Device::CPU)
    }

    /// Fetches config + weights for `repo_id`, builds the classification head
    /// on `device` with the label count taken from the config, and (behind the
    /// `tokenizer` feature) attaches a tokenizer when one is available.
    pub fn from_pretrained_on_device(repo_id: &str, device: Device) -> Result<Self> {
        let (config, weights) = fetch_bert_config_and_weights(repo_id)?;
        let labels = Self::num_labels_from_config(&config)?;
        let head = Self::on_device(&config, labels, device)?;
        load_weights_with_logging(repo_id, head.graph(), &weights)?;
        #[cfg(feature = "tokenizer")]
        let head = if let Some(tok) = try_load_tokenizer(repo_id) {
            head.with_tokenizer(tok)
        } else {
            head
        };
        Ok(head)
    }
}
impl BertForTokenClassification {
    /// CPU convenience wrapper around [`Self::from_pretrained_on_device`].
    pub fn from_pretrained(repo_id: &str) -> Result<Self> {
        Self::from_pretrained_on_device(repo_id, Device::CPU)
    }

    /// Fetches config + weights for `repo_id`, builds the token-classification
    /// head on `device` (label count from the config), and optionally attaches
    /// a tokenizer (behind the `tokenizer` feature).
    pub fn from_pretrained_on_device(repo_id: &str, device: Device) -> Result<Self> {
        let (config, weights) = fetch_bert_config_and_weights(repo_id)?;
        let labels = Self::num_labels_from_config(&config)?;
        let head = Self::on_device(&config, labels, device)?;
        load_weights_with_logging(repo_id, head.graph(), &weights)?;
        #[cfg(feature = "tokenizer")]
        let head = if let Some(tok) = try_load_tokenizer(repo_id) {
            head.with_tokenizer(tok)
        } else {
            head
        };
        Ok(head)
    }
}
impl BertForQuestionAnswering {
    /// CPU convenience wrapper around [`Self::from_pretrained_on_device`].
    pub fn from_pretrained(repo_id: &str) -> Result<Self> {
        Self::from_pretrained_on_device(repo_id, Device::CPU)
    }

    /// Fetches config + weights for `repo_id`, builds the QA head on `device`,
    /// and optionally attaches a tokenizer (behind the `tokenizer` feature).
    pub fn from_pretrained_on_device(repo_id: &str, device: Device) -> Result<Self> {
        let (config, weights) = fetch_bert_config_and_weights(repo_id)?;
        let head = Self::on_device(&config, device)?;
        load_weights_with_logging(repo_id, head.graph(), &weights)?;
        #[cfg(feature = "tokenizer")]
        let head = if let Some(tok) = try_load_tokenizer(repo_id) {
            head.with_tokenizer(tok)
        } else {
            head
        };
        Ok(head)
    }
}
impl RobertaModel {
    /// Downloads `repo_id` from the Hub and builds it on the CPU.
    pub fn from_pretrained(repo_id: &str) -> Result<Graph> {
        Self::from_pretrained_on_device(repo_id, Device::CPU)
    }

    /// Downloads config + weights for `repo_id` and builds the encoder
    /// (without the pooler head) on `device`.
    pub fn from_pretrained_on_device(repo_id: &str, device: Device) -> Result<Graph> {
        let (config, weights) = fetch_roberta_config_and_weights(repo_id)?;
        let graph = Self::on_device_without_pooler(&config, device)?;
        load_weights_with_logging(repo_id, &graph, &weights)?;
        Ok(graph)
    }
}
impl RobertaForSequenceClassification {
    /// CPU convenience wrapper around [`Self::from_pretrained_on_device`].
    pub fn from_pretrained(repo_id: &str) -> Result<Self> {
        Self::from_pretrained_on_device(repo_id, Device::CPU)
    }

    /// Fetches config + weights for `repo_id`, builds the classification head
    /// on `device` (label count from the config), and optionally attaches a
    /// tokenizer (behind the `tokenizer` feature).
    pub fn from_pretrained_on_device(repo_id: &str, device: Device) -> Result<Self> {
        let (config, weights) = fetch_roberta_config_and_weights(repo_id)?;
        let labels = Self::num_labels_from_config(&config)?;
        let head = Self::on_device(&config, labels, device)?;
        load_weights_with_logging(repo_id, head.graph(), &weights)?;
        #[cfg(feature = "tokenizer")]
        let head = if let Some(tok) = try_load_tokenizer(repo_id) {
            head.with_tokenizer(tok)
        } else {
            head
        };
        Ok(head)
    }
}
impl RobertaForTokenClassification {
    /// CPU convenience wrapper around [`Self::from_pretrained_on_device`].
    pub fn from_pretrained(repo_id: &str) -> Result<Self> {
        Self::from_pretrained_on_device(repo_id, Device::CPU)
    }

    /// Fetches config + weights for `repo_id`, builds the token-classification
    /// head on `device` (label count from the config), and optionally attaches
    /// a tokenizer (behind the `tokenizer` feature).
    pub fn from_pretrained_on_device(repo_id: &str, device: Device) -> Result<Self> {
        let (config, weights) = fetch_roberta_config_and_weights(repo_id)?;
        let labels = Self::num_labels_from_config(&config)?;
        let head = Self::on_device(&config, labels, device)?;
        load_weights_with_logging(repo_id, head.graph(), &weights)?;
        #[cfg(feature = "tokenizer")]
        let head = if let Some(tok) = try_load_tokenizer(repo_id) {
            head.with_tokenizer(tok)
        } else {
            head
        };
        Ok(head)
    }
}
impl RobertaForQuestionAnswering {
    /// CPU convenience wrapper around [`Self::from_pretrained_on_device`].
    pub fn from_pretrained(repo_id: &str) -> Result<Self> {
        Self::from_pretrained_on_device(repo_id, Device::CPU)
    }

    /// Fetches config + weights for `repo_id`, builds the QA head on `device`,
    /// and optionally attaches a tokenizer (behind the `tokenizer` feature).
    pub fn from_pretrained_on_device(repo_id: &str, device: Device) -> Result<Self> {
        let (config, weights) = fetch_roberta_config_and_weights(repo_id)?;
        let head = Self::on_device(&config, device)?;
        load_weights_with_logging(repo_id, head.graph(), &weights)?;
        #[cfg(feature = "tokenizer")]
        let head = if let Some(tok) = try_load_tokenizer(repo_id) {
            head.with_tokenizer(tok)
        } else {
            head
        };
        Ok(head)
    }
}
impl DistilBertModel {
    /// Downloads `repo_id` from the Hub and builds it on the CPU.
    pub fn from_pretrained(repo_id: &str) -> Result<Graph> {
        Self::from_pretrained_on_device(repo_id, Device::CPU)
    }

    /// Downloads config + weights for `repo_id` and builds the model on `device`.
    pub fn from_pretrained_on_device(repo_id: &str, device: Device) -> Result<Graph> {
        let (config, weights) = fetch_distilbert_config_and_weights(repo_id)?;
        let graph = Self::on_device(&config, device)?;
        load_weights_with_logging(repo_id, &graph, &weights)?;
        Ok(graph)
    }
}
impl DistilBertForSequenceClassification {
    /// CPU convenience wrapper around [`Self::from_pretrained_on_device`].
    pub fn from_pretrained(repo_id: &str) -> Result<Self> {
        Self::from_pretrained_on_device(repo_id, Device::CPU)
    }

    /// Fetches config + weights for `repo_id`, builds the classification head
    /// on `device` (label count from the config), and optionally attaches a
    /// tokenizer (behind the `tokenizer` feature).
    pub fn from_pretrained_on_device(repo_id: &str, device: Device) -> Result<Self> {
        let (config, weights) = fetch_distilbert_config_and_weights(repo_id)?;
        let labels = Self::num_labels_from_config(&config)?;
        let head = Self::on_device(&config, labels, device)?;
        load_weights_with_logging(repo_id, head.graph(), &weights)?;
        #[cfg(feature = "tokenizer")]
        let head = if let Some(tok) = try_load_tokenizer(repo_id) {
            head.with_tokenizer(tok)
        } else {
            head
        };
        Ok(head)
    }
}
impl DistilBertForTokenClassification {
    /// CPU convenience wrapper around [`Self::from_pretrained_on_device`].
    pub fn from_pretrained(repo_id: &str) -> Result<Self> {
        Self::from_pretrained_on_device(repo_id, Device::CPU)
    }

    /// Fetches config + weights for `repo_id`, builds the token-classification
    /// head on `device` (label count from the config), and optionally attaches
    /// a tokenizer (behind the `tokenizer` feature).
    pub fn from_pretrained_on_device(repo_id: &str, device: Device) -> Result<Self> {
        let (config, weights) = fetch_distilbert_config_and_weights(repo_id)?;
        let labels = Self::num_labels_from_config(&config)?;
        let head = Self::on_device(&config, labels, device)?;
        load_weights_with_logging(repo_id, head.graph(), &weights)?;
        #[cfg(feature = "tokenizer")]
        let head = if let Some(tok) = try_load_tokenizer(repo_id) {
            head.with_tokenizer(tok)
        } else {
            head
        };
        Ok(head)
    }
}
impl DistilBertForQuestionAnswering {
    /// CPU convenience wrapper around [`Self::from_pretrained_on_device`].
    pub fn from_pretrained(repo_id: &str) -> Result<Self> {
        Self::from_pretrained_on_device(repo_id, Device::CPU)
    }

    /// Fetches config + weights for `repo_id`, builds the QA head on `device`,
    /// and optionally attaches a tokenizer (behind the `tokenizer` feature).
    pub fn from_pretrained_on_device(repo_id: &str, device: Device) -> Result<Self> {
        let (config, weights) = fetch_distilbert_config_and_weights(repo_id)?;
        let head = Self::on_device(&config, device)?;
        load_weights_with_logging(repo_id, head.graph(), &weights)?;
        #[cfg(feature = "tokenizer")]
        let head = if let Some(tok) = try_load_tokenizer(repo_id) {
            head.with_tokenizer(tok)
        } else {
            head
        };
        Ok(head)
    }
}
/// Fetches config + weights for `repo_id` and parses the config as [`AutoConfig`],
/// which selects the concrete architecture variant.
fn fetch_auto_config_and_weights(repo_id: &str) -> Result<(AutoConfig, Vec<u8>)> {
    let (json, weights) = fetch_config_str_and_weights(repo_id)?;
    AutoConfig::from_json_str(&json).map(|cfg| (cfg, weights))
}
impl AutoModel {
    /// CPU convenience wrapper around [`Self::from_pretrained_on_device`].
    pub fn from_pretrained(repo_id: &str) -> Result<Graph> {
        Self::from_pretrained_on_device(repo_id, Device::CPU)
    }

    /// Loads `repo_id` on `device`, dispatching on the parsed [`AutoConfig`]
    /// variant to build the matching base encoder (no pooler for BERT/RoBERTa).
    pub fn from_pretrained_on_device(repo_id: &str, device: Device) -> Result<Graph> {
        let (config, weights) = fetch_auto_config_and_weights(repo_id)?;
        let graph = match config {
            AutoConfig::Bert(cfg) => BertModel::on_device_without_pooler(&cfg, device)?,
            AutoConfig::Roberta(cfg) => RobertaModel::on_device_without_pooler(&cfg, device)?,
            AutoConfig::DistilBert(cfg) => DistilBertModel::on_device(&cfg, device)?,
        };
        load_weights_with_logging(repo_id, &graph, &weights)?;
        Ok(graph)
    }
}
impl AutoModelForSequenceClassification {
    /// CPU convenience wrapper around [`Self::from_pretrained_on_device`].
    pub fn from_pretrained(repo_id: &str) -> Result<Self> {
        Self::from_pretrained_on_device(repo_id, Device::CPU)
    }

    /// Builds the architecture-specific sequence-classification head selected
    /// by the checkpoint's config, loads its weights, and optionally attaches
    /// a tokenizer (behind the `tokenizer` feature).
    pub fn from_pretrained_on_device(repo_id: &str, device: Device) -> Result<Self> {
        let (config, weights) = fetch_auto_config_and_weights(repo_id)?;
        let head = match config {
            AutoConfig::Bert(cfg) => {
                let labels = BertForSequenceClassification::num_labels_from_config(&cfg)?;
                let inner = BertForSequenceClassification::on_device(&cfg, labels, device)?;
                load_weights_with_logging(repo_id, inner.graph(), &weights)?;
                Self::Bert(inner)
            }
            AutoConfig::Roberta(cfg) => {
                let labels = RobertaForSequenceClassification::num_labels_from_config(&cfg)?;
                let inner = RobertaForSequenceClassification::on_device(&cfg, labels, device)?;
                load_weights_with_logging(repo_id, inner.graph(), &weights)?;
                Self::Roberta(inner)
            }
            AutoConfig::DistilBert(cfg) => {
                let labels = DistilBertForSequenceClassification::num_labels_from_config(&cfg)?;
                let inner = DistilBertForSequenceClassification::on_device(&cfg, labels, device)?;
                load_weights_with_logging(repo_id, inner.graph(), &weights)?;
                Self::DistilBert(inner)
            }
        };
        #[cfg(feature = "tokenizer")]
        let head = if let Some(tok) = try_load_tokenizer(repo_id) {
            head.with_tokenizer(tok)
        } else {
            head
        };
        Ok(head)
    }
}
impl AutoModelForTokenClassification {
    /// CPU convenience wrapper around [`Self::from_pretrained_on_device`].
    pub fn from_pretrained(repo_id: &str) -> Result<Self> {
        Self::from_pretrained_on_device(repo_id, Device::CPU)
    }

    /// Builds the architecture-specific token-classification head selected by
    /// the checkpoint's config, loads its weights, and optionally attaches a
    /// tokenizer (behind the `tokenizer` feature).
    pub fn from_pretrained_on_device(repo_id: &str, device: Device) -> Result<Self> {
        let (config, weights) = fetch_auto_config_and_weights(repo_id)?;
        let head = match config {
            AutoConfig::Bert(cfg) => {
                let labels = BertForTokenClassification::num_labels_from_config(&cfg)?;
                let inner = BertForTokenClassification::on_device(&cfg, labels, device)?;
                load_weights_with_logging(repo_id, inner.graph(), &weights)?;
                Self::Bert(inner)
            }
            AutoConfig::Roberta(cfg) => {
                let labels = RobertaForTokenClassification::num_labels_from_config(&cfg)?;
                let inner = RobertaForTokenClassification::on_device(&cfg, labels, device)?;
                load_weights_with_logging(repo_id, inner.graph(), &weights)?;
                Self::Roberta(inner)
            }
            AutoConfig::DistilBert(cfg) => {
                let labels = DistilBertForTokenClassification::num_labels_from_config(&cfg)?;
                let inner = DistilBertForTokenClassification::on_device(&cfg, labels, device)?;
                load_weights_with_logging(repo_id, inner.graph(), &weights)?;
                Self::DistilBert(inner)
            }
        };
        #[cfg(feature = "tokenizer")]
        let head = if let Some(tok) = try_load_tokenizer(repo_id) {
            head.with_tokenizer(tok)
        } else {
            head
        };
        Ok(head)
    }
}
impl AutoModelForQuestionAnswering {
    /// CPU convenience wrapper around [`Self::from_pretrained_on_device`].
    pub fn from_pretrained(repo_id: &str) -> Result<Self> {
        Self::from_pretrained_on_device(repo_id, Device::CPU)
    }

    /// Builds the architecture-specific QA head selected by the checkpoint's
    /// config, loads its weights, and optionally attaches a tokenizer
    /// (behind the `tokenizer` feature).
    pub fn from_pretrained_on_device(repo_id: &str, device: Device) -> Result<Self> {
        let (config, weights) = fetch_auto_config_and_weights(repo_id)?;
        let head = match config {
            AutoConfig::Bert(cfg) => {
                let inner = BertForQuestionAnswering::on_device(&cfg, device)?;
                load_weights_with_logging(repo_id, inner.graph(), &weights)?;
                Self::Bert(inner)
            }
            AutoConfig::Roberta(cfg) => {
                let inner = RobertaForQuestionAnswering::on_device(&cfg, device)?;
                load_weights_with_logging(repo_id, inner.graph(), &weights)?;
                Self::Roberta(inner)
            }
            AutoConfig::DistilBert(cfg) => {
                let inner = DistilBertForQuestionAnswering::on_device(&cfg, device)?;
                load_weights_with_logging(repo_id, inner.graph(), &weights)?;
                Self::DistilBert(inner)
            }
        };
        #[cfg(feature = "tokenizer")]
        let head = if let Some(tok) = try_load_tokenizer(repo_id) {
            head.with_tokenizer(tok)
        } else {
            head
        };
        Ok(head)
    }
}
#[cfg(feature = "tokenizer")]
impl HfTokenizer {
    /// Downloads `tokenizer.json` for `repo_id` from the Hub and parses it.
    pub fn from_pretrained(repo_id: &str) -> Result<Self> {
        let api = ApiBuilder::from_env()
            .build()
            .map_err(|e| TensorError::new(&format!("hf-hub init: {e}")))?;
        let tok_path = api
            .model(repo_id.to_string())
            .get("tokenizer.json")
            .map_err(|e| TensorError::new(&format!("hf-hub fetch {repo_id}/tokenizer.json: {e}")))?;
        Self::from_file(&tok_path)
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Live smoke test: downloads bert-base-uncased from the Hub and runs one tiny
// forward pass. Ignored by default since it needs network access and writes a
// large checkpoint into the local cache.
#[test]
#[ignore = "network + ~440MB cache write"]
fn bert_from_pretrained_live() {
use flodl::nn::Module;
use flodl::{DType, Tensor, TensorOptions, Variable};
use crate::models::bert::build_extended_attention_mask;
let graph = BertModel::from_pretrained("bert-base-uncased").unwrap();
// Switch the loaded graph to eval mode before inference.
graph.eval();
let dev = Device::CPU;
let batch = 1;
let seq = 4;
// 101/102 are BERT's [CLS]/[SEP] ids; the middle ids are two word-piece
// tokens (presumably "hello"/"world" under the uncased vocab — not verified
// here, and irrelevant to the shape assertion below).
let input_ids = Variable::new(
Tensor::from_i64(&[101, 7592, 2088, 102], &[batch, seq], dev).unwrap(),
false,
);
let position_ids = Variable::new(
Tensor::from_i64(&[0, 1, 2, 3], &[batch, seq], dev).unwrap(),
false,
);
// Single-segment input: all token_type ids are 0.
let token_type_ids = Variable::new(
Tensor::from_i64(&[0, 0, 0, 0], &[batch, seq], dev).unwrap(),
false,
);
// All-ones padding mask, expanded to the extended (additive) attention-mask
// layout via build_extended_attention_mask.
let mask_flat = Tensor::ones(&[batch, seq], TensorOptions {
dtype: DType::Float32, device: dev,
}).unwrap();
let attention_mask = Variable::new(
build_extended_attention_mask(&mask_flat).unwrap(),
false,
);
let out = graph
.forward_multi(&[input_ids, position_ids, token_type_ids, attention_mask])
.unwrap();
// One 768-dim vector per sequence (bert-base hidden size) — presumably the
// pooled output, since the plain BertModel loader builds the pooler.
assert_eq!(out.shape(), vec![batch, 768]);
}
}