use std::path::PathBuf;
use hf_hub::api::sync::{Api, ApiBuilder};
use flodl::{Graph, Result, TensorError};
use crate::safetensors_io::{
bert_legacy_key_rename, load_safetensors_into_graph_with_rename_allow_unused,
};
#[cfg(feature = "tokenizer")]
use crate::tokenizer::HfTokenizer;
mod albert;
mod auto;
mod bert;
mod deberta_v2;
mod distilbert;
mod roberta;
mod xlm_roberta;
pub use auto::HubExportHead;
/// Environment variable that overrides the Hugging Face cache root.
const HF_HOME_ENV: &str = "HF_HOME";

/// Fallback cache root when `HF_HOME` is unset: `$HOME/.cache/huggingface`,
/// or `/tmp/huggingface` when `HOME` is missing too (e.g. bare containers).
fn default_hf_home() -> PathBuf {
    std::env::var_os("HOME")
        .map(|home| PathBuf::from(home).join(".cache").join("huggingface"))
        .unwrap_or_else(|| PathBuf::from("/tmp/huggingface"))
}
fn flodl_converted_path(repo_id: &str) -> PathBuf {
let hf_home = std::env::var_os(HF_HOME_ENV)
.map(PathBuf::from)
.unwrap_or_else(default_hf_home);
hf_home
.join("flodl-converted")
.join(repo_id)
.join("model.safetensors")
}
/// Resolve the `model.safetensors` weights for `repo_id`.
///
/// A locally converted copy (produced by `fdl flodl-hf convert`) takes
/// precedence over the Hub; otherwise the file is fetched through `api`.
fn fetch_safetensors(api: &Api, repo_id: &str) -> Result<PathBuf> {
    let converted = flodl_converted_path(repo_id);
    if converted.exists() {
        eprintln!(
            "from_pretrained({repo_id}): using flodl-converted safetensors at {}",
            converted.display(),
        );
        return Ok(converted);
    }
    let repo = api.model(repo_id.to_string());
    repo.get("model.safetensors").map_err(|e| {
        TensorError::new(&format!(
            "hf-hub fetch {repo_id}/model.safetensors: {e}\n\
             If this repo ships only `pytorch_model.bin`, convert it first:\n \
             fdl flodl-hf convert {repo_id}",
        ))
    })
}
/// Download `config.json` for `repo_id` from the Hub and return its contents.
fn fetch_config_str(repo_id: &str) -> Result<String> {
    let api = ApiBuilder::from_env()
        .build()
        .map_err(|e| TensorError::new(&format!("hf-hub init: {e}")))?;
    let config_path = api
        .model(repo_id.to_string())
        .get("config.json")
        .map_err(|e| TensorError::new(&format!("hf-hub fetch {repo_id}/config.json: {e}")))?;
    std::fs::read_to_string(&config_path)
        .map_err(|e| TensorError::new(&format!("read {}: {e}", config_path.display())))
}
/// Fetch both the raw `config.json` text and the safetensors bytes.
///
/// NOTE(review): a second `Api` is built here even though `fetch_config_str`
/// builds its own — harmless, but worth unifying if init cost ever matters.
fn fetch_config_str_and_weights(repo_id: &str) -> Result<(String, Vec<u8>)> {
    let config = fetch_config_str(repo_id)?;
    let api = ApiBuilder::from_env()
        .build()
        .map_err(|e| TensorError::new(&format!("hf-hub init: {e}")))?;
    let path = fetch_safetensors(&api, repo_id)?;
    let weights = std::fs::read(&path)
        .map_err(|e| TensorError::new(&format!("read {}: {e}", path.display())))?;
    Ok((config, weights))
}
/// Like `fetch_config_str_and_weights`, but additionally runs `parse` on the
/// raw JSON so callers receive a typed config alongside the weight bytes.
fn fetch_config_and_weights<C, F>(repo_id: &str, parse: F) -> Result<(C, Vec<u8>)>
where
    F: FnOnce(&str) -> Result<C>,
{
    let (raw, weights) = fetch_config_str_and_weights(repo_id)?;
    let config = parse(&raw)?;
    Ok((config, weights))
}
/// Best-effort tokenizer load for `repo_id`.
///
/// Returns `None` (with a diagnostic on stderr) when the tokenizer cannot be
/// fetched, so model loading can still proceed; callers that need
/// `predict()`/`answer()` must then attach a tokenizer explicitly.
#[cfg(feature = "tokenizer")]
fn try_load_tokenizer(repo_id: &str) -> Option<HfTokenizer> {
    match HfTokenizer::from_pretrained(repo_id) {
        Ok(tok) => Some(tok),
        Err(e) => {
            // Render the error once (the original called `e.to_string()`
            // twice). A 404 just means the repo ships no tokenizer.json,
            // which deserves a terse message rather than the raw error.
            let msg = e.to_string();
            let terse = if msg.contains("404") {
                "no tokenizer.json on Hub".to_string()
            } else {
                msg
            };
            eprintln!(
                "from_pretrained({repo_id}): tokenizer not attached ({terse}) \
                 — predict()/answer() need .with_tokenizer()",
            );
            None
        }
    }
}
/// Load a safetensors blob into `graph`, renaming legacy BERT checkpoint keys
/// via `bert_legacy_key_rename`, and warn (on stderr) about any checkpoint
/// keys the graph did not consume.
///
/// Unused keys are tolerated — checkpoints often carry extra heads the target
/// architecture drops — but they are listed so mismatches stay visible.
fn load_weights_with_logging(
    repo_id: &str,
    graph: &Graph,
    bytes: &[u8],
) -> Result<()> {
    // Cap how many ignored keys are enumerated so a badly mismatched
    // checkpoint cannot flood stderr (was a thrice-repeated magic 20).
    const MAX_LISTED: usize = 20;
    let unused = load_safetensors_into_graph_with_rename_allow_unused(
        graph, bytes, bert_legacy_key_rename,
    )?;
    if !unused.is_empty() {
        eprintln!(
            "from_pretrained({repo_id}): ignored {} checkpoint key(s) not used by the model:",
            unused.len(),
        );
        for k in unused.iter().take(MAX_LISTED) {
            eprintln!(" - {k}");
        }
        if unused.len() > MAX_LISTED {
            eprintln!(" ... and {} more", unused.len() - MAX_LISTED);
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use crate::models::bert::BertModel;
    use flodl::Device;

    /// End-to-end smoke test against the real Hub; ignored by default because
    /// it downloads and caches the full bert-base-uncased checkpoint.
    #[test]
    #[ignore = "network + ~440MB cache write"]
    fn bert_from_pretrained_live() {
        use flodl::nn::Module;
        use flodl::{DType, Tensor, TensorOptions, Variable};
        use crate::models::bert::build_extended_attention_mask;

        let graph = BertModel::from_pretrained("bert-base-uncased").unwrap();
        graph.eval();

        let dev = Device::CPU;
        let (batch, seq) = (1, 4);
        // "[CLS] hello world [SEP]" token ids, trivial positions, a single
        // segment, and a fully-unmasked attention pattern.
        let make_input = |data: &[i64]| {
            Variable::new(Tensor::from_i64(data, &[batch, seq], dev).unwrap(), false)
        };
        let input_ids = make_input(&[101, 7592, 2088, 102]);
        let position_ids = make_input(&[0, 1, 2, 3]);
        let token_type_ids = make_input(&[0, 0, 0, 0]);

        let mask_flat = Tensor::ones(&[batch, seq], TensorOptions {
            dtype: DType::Float32, device: dev,
        }).unwrap();
        let attention_mask =
            Variable::new(build_extended_attention_mask(&mask_flat).unwrap(), false);

        let out = graph
            .forward_multi(&[input_ids, position_ids, token_type_ids, attention_mask])
            .unwrap();
        assert_eq!(out.shape(), vec![batch, 768]);
    }

    /// Minimal BERT config for fast, offline structural tests.
    fn tiny_bert_config() -> crate::models::bert::BertConfig {
        use crate::models::bert::BertConfig;
        BertConfig {
            vocab_size: 32,
            hidden_size: 16,
            num_hidden_layers: 1,
            num_attention_heads: 4,
            intermediate_size: 32,
            max_position_embeddings: 8,
            type_vocab_size: 2,
            pad_token_id: Some(0),
            layer_norm_eps: 1e-12,
            hidden_dropout_prob: 0.0,
            attention_probs_dropout_prob: 0.0,
            hidden_act: flodl::nn::GeluApprox::Exact,
            num_labels: None,
            id2label: None,
            architectures: None,
        }
    }

    /// Mirror of the sidecar naming convention used by `save_checkpoint`:
    /// strip a trailing `.gz`, then swap the remaining extension for
    /// `config.json`.
    fn sidecar_for_checkpoint(checkpoint: &str) -> std::path::PathBuf {
        let mut path = std::path::PathBuf::from(checkpoint);
        let gzipped = matches!(path.extension().and_then(|e| e.to_str()), Some("gz"));
        if gzipped {
            path.set_extension("");
        }
        path.set_extension("config.json");
        path
    }

    /// The config.json sidecar written next to a checkpoint must carry the
    /// architecture of the head that was actually built, not the upstream
    /// checkpoint's original `architectures` entry.
    #[test]
    fn head_save_checkpoint_emits_normalised_architectures_sidecar() {
        use crate::export::build_for_export;
        use crate::models::auto::AutoConfig;
        use crate::models::bert::BertForMaskedLM;

        let upstream = tiny_bert_config().with_architectures("BertForPreTraining");
        let head = BertForMaskedLM::on_device(&upstream, Device::CPU).unwrap();
        head.graph().set_source_config(
            upstream.with_architectures("BertForMaskedLM").to_json_str(),
        );

        // Unique temp path per process so parallel runs don't collide.
        let ckpt = std::env::temp_dir()
            .join(format!("flodl_hf_mlm_norm_{pid}.fdl", pid = std::process::id()));
        let ckpt_str = ckpt.to_string_lossy().into_owned();
        head.graph().save_checkpoint(&ckpt_str).unwrap();

        let sidecar = sidecar_for_checkpoint(&ckpt_str);
        let sidecar_json = std::fs::read_to_string(&sidecar).unwrap();
        let parsed = AutoConfig::from_json_str(&sidecar_json).unwrap();
        assert_eq!(
            parsed.architectures().unwrap(),
            ["BertForMaskedLM"],
            "save_checkpoint sidecar must reflect the head class actually built; \
             without the with_architectures call upstream's BertForPreTraining \
             would leak through and fail classify_architecture on re-export",
        );

        let rebuilt = build_for_export(&parsed, false, Device::CPU).unwrap();
        assert_eq!(
            rebuilt.structural_hash(),
            head.graph().structural_hash(),
            "build_for_export from sidecar must rebuild the same MLM topology",
        );

        let _ = std::fs::remove_file(&ckpt);
        let _ = std::fs::remove_file(&sidecar);
    }
}