//! BERT-family `from_pretrained` impls (backbone + 4 task heads).

use flodl::{Device, Graph, Result};

use crate::models::bert::{
    BertConfig, BertForMaskedLM, BertForQuestionAnswering, BertForSequenceClassification,
    BertForTokenClassification, BertModel,
};
use crate::safetensors_io::weights_have_pooler;

use super::{fetch_config_and_weights, load_weights_with_logging};
#[cfg(feature = "tokenizer")]
use super::try_load_tokenizer;
impl BertModel {
    /// Download a pretrained BERT checkpoint from the HuggingFace Hub and
    /// return a fully-initialised [`Graph`] on CPU.
    ///
    /// Convenience wrapper over [`BertModel::from_pretrained_on_device`]
    /// with `Device::CPU`.
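    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming the crate is published as `flodl_hf`
    /// (the public path is an assumption; adjust to the real crate name).
    /// Marked `no_run` because it hits the network:
    ///
    /// ```no_run
    /// use flodl_hf::models::bert::BertModel;
    ///
    /// // Downloads (or reuses the hf_hub on-disk cache for) config + weights.
    /// let graph = BertModel::from_pretrained("bert-base-uncased")
    ///     .expect("hub download / weight load failed");
    /// ```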
    pub fn from_pretrained(repo_id: &str) -> Result<Graph> {
        Self::from_pretrained_on_device(repo_id, Device::CPU)
    }

    /// Download a pretrained BERT checkpoint from the HuggingFace Hub and
    /// return a fully-initialised [`Graph`] on `device`.
    ///
    /// Pulls `config.json` and `model.safetensors` from `repo_id` via
    /// `hf_hub` (using its on-disk cache), parses the config, builds a
    /// matching graph, and loads the safetensors weights.
    ///
    /// Picks `on_device` (with pooler) when the checkpoint ships pooler
    /// weights and `on_device_without_pooler` when it doesn't, so an
    /// encoder-only BERT checkpoint loads strictly without missing-key
    /// errors, while a pooler-bearing one keeps its `pooler_output` slot.
    /// Reach for [`BertModel::on_device`] /
    /// [`BertModel::on_device_without_pooler`] directly when the call
    /// site needs a guaranteed shape regardless of what the Hub repo
    /// happens to ship.
    ///
    /// `repo_id` is the HF-style identifier, e.g. `"bert-base-uncased"`
    /// or `"google-bert/bert-base-multilingual-cased"`.
    ///
    /// Errors on: hub API init failure, network / HTTP failure,
    /// config parse failure, shape/key mismatch against the built graph,
    /// and any I/O error reading the cached safetensors file. Nothing
    /// partial is returned: the graph is either fully loaded or the
    /// call errors out.
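    ///
    /// # Examples
    ///
    /// A device-aware sketch under the same `flodl_hf` crate-path
    /// assumption; only `Device::CPU` is shown, since that is the one
    /// variant this module itself uses:
    ///
    /// ```no_run
    /// use flodl::Device;
    /// use flodl_hf::models::bert::BertModel;
    ///
    /// let graph = BertModel::from_pretrained_on_device(
    ///     "google-bert/bert-base-multilingual-cased",
    ///     Device::CPU,
    /// ).expect("hub download / weight load failed");
    /// ```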
    pub fn from_pretrained_on_device(repo_id: &str, device: Device) -> Result<Graph> {
        let (config, weights) = fetch_config_and_weights(repo_id, BertConfig::from_json_str)?;
        let graph = if weights_have_pooler(&weights)? {
            BertModel::on_device(&config, device)?
        } else {
            BertModel::on_device_without_pooler(&config, device)?
        };
        // HF base checkpoints (e.g. `bert-base-uncased`) ship as
        // `BertForPreTraining`, which carries MLM + NSP heads that a
        // bare `BertModel` has no slot for. `load_weights_with_logging`
        // tolerates those and names them on stderr.
        load_weights_with_logging(repo_id, &graph, &weights)?;
        graph.set_source_config(config.with_architectures("BertModel").to_json_str());
        Ok(graph)
    }
}

impl BertForSequenceClassification {
    /// Download a fine-tuned `BertForSequenceClassification` checkpoint
    /// from the Hub and return a ready-to-use predictor on CPU.
    ///
    /// The config must carry `num_labels` (or `id2label`) so the head's
    /// output width is known. Popular checkpoints:
    /// `nateraw/bert-base-uncased-emotion` (6 emotions; `.bin`-only, so it
    /// needs `fdl flodl-hf convert` first),
    /// `nlptown/bert-base-multilingual-uncased-sentiment` (5-star rating,
    /// safetensors on main), `unitary/toxic-bert` (6-label toxicity).
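    ///
    /// # Examples
    ///
    /// A minimal sketch (`flodl_hf` crate path assumed; this checkpoint
    /// ships safetensors on main, so no conversion step is needed):
    ///
    /// ```no_run
    /// use flodl_hf::models::bert::BertForSequenceClassification;
    ///
    /// // `num_labels` (here, 5 star ratings) is read from the repo's config.json.
    /// let head = BertForSequenceClassification::from_pretrained(
    ///     "nlptown/bert-base-multilingual-uncased-sentiment",
    /// ).expect("hub download / weight load failed");
    /// ```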
    pub fn from_pretrained(repo_id: &str) -> Result<Self> {
        Self::from_pretrained_on_device(repo_id, Device::CPU)
    }

    /// Device-aware variant of [`from_pretrained`](Self::from_pretrained).
    pub fn from_pretrained_on_device(repo_id: &str, device: Device) -> Result<Self> {
        let (config, weights) = fetch_config_and_weights(repo_id, BertConfig::from_json_str)?;
        let num_labels = Self::num_labels_from_config(&config)?;
        let head = Self::on_device(&config, num_labels, device)?;
        load_weights_with_logging(repo_id, head.graph(), &weights)?;
        head.graph().set_source_config(
            config.with_architectures("BertForSequenceClassification").to_json_str(),
        );
        // Attach the repo's tokenizer when the `tokenizer` feature is on
        // and the Hub repo actually ships one.
        #[cfg(feature = "tokenizer")]
        let head = match try_load_tokenizer(repo_id) {
            Some(tok) => head.with_tokenizer(tok),
            None => head,
        };
        Ok(head)
    }
}

impl BertForTokenClassification {
    /// Download a fine-tuned `BertForTokenClassification` checkpoint
    /// (NER, POS tagging, …) from the Hub. Popular checkpoints:
    /// `dslim/bert-base-NER`,
    /// `dbmdz/bert-large-cased-finetuned-conll03-english`.
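    ///
    /// # Examples
    ///
    /// A hedged sketch of loading an NER head (`flodl_hf` crate path
    /// assumed):
    ///
    /// ```no_run
    /// use flodl_hf::models::bert::BertForTokenClassification;
    ///
    /// // Per-token label width comes from the checkpoint's `id2label`.
    /// let head = BertForTokenClassification::from_pretrained("dslim/bert-base-NER")
    ///     .expect("hub download / weight load failed");
    /// ```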
    pub fn from_pretrained(repo_id: &str) -> Result<Self> {
        Self::from_pretrained_on_device(repo_id, Device::CPU)
    }

    /// Device-aware variant of [`from_pretrained`](Self::from_pretrained).
    pub fn from_pretrained_on_device(repo_id: &str, device: Device) -> Result<Self> {
        let (config, weights) = fetch_config_and_weights(repo_id, BertConfig::from_json_str)?;
        let num_labels = Self::num_labels_from_config(&config)?;
        let head = Self::on_device(&config, num_labels, device)?;
        load_weights_with_logging(repo_id, head.graph(), &weights)?;
        head.graph().set_source_config(
            config.with_architectures("BertForTokenClassification").to_json_str(),
        );
        #[cfg(feature = "tokenizer")]
        let head = match try_load_tokenizer(repo_id) {
            Some(tok) => head.with_tokenizer(tok),
            None => head,
        };
        Ok(head)
    }
}

impl BertForQuestionAnswering {
    /// Download a fine-tuned `BertForQuestionAnswering` checkpoint
    /// (SQuAD, etc.) from the Hub. Popular checkpoints:
    /// `csarron/bert-base-uncased-squad-v1`,
    /// `bert-large-uncased-whole-word-masking-finetuned-squad`.
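    ///
    /// # Examples
    ///
    /// A minimal sketch (`flodl_hf` crate path assumed):
    ///
    /// ```no_run
    /// use flodl_hf::models::bert::BertForQuestionAnswering;
    ///
    /// // Span-extraction head: start/end logits over the input tokens.
    /// let head = BertForQuestionAnswering::from_pretrained(
    ///     "csarron/bert-base-uncased-squad-v1",
    /// ).expect("hub download / weight load failed");
    /// ```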
    pub fn from_pretrained(repo_id: &str) -> Result<Self> {
        Self::from_pretrained_on_device(repo_id, Device::CPU)
    }

    /// Device-aware variant of [`from_pretrained`](Self::from_pretrained).
    pub fn from_pretrained_on_device(repo_id: &str, device: Device) -> Result<Self> {
        let (config, weights) = fetch_config_and_weights(repo_id, BertConfig::from_json_str)?;
        let head = Self::on_device(&config, device)?;
        load_weights_with_logging(repo_id, head.graph(), &weights)?;
        head.graph().set_source_config(
            config.with_architectures("BertForQuestionAnswering").to_json_str(),
        );
        #[cfg(feature = "tokenizer")]
        let head = match try_load_tokenizer(repo_id) {
            Some(tok) => head.with_tokenizer(tok),
            None => head,
        };
        Ok(head)
    }
}

impl BertForMaskedLM {
    /// Download a BERT MLM checkpoint (`bert-base-uncased`,
    /// `bert-base-cased`, any `*-mlm` fine-tune) from the Hub.
    ///
    /// The decoder weight is tied to the word-embedding table, so
    /// checkpoints that redundantly save `cls.predictions.decoder.weight`
    /// alongside `bert.embeddings.word_embeddings.weight` (HF's
    /// historical save format) load cleanly: the decoder key is
    /// ignored by `load_safetensors_into_graph_with_rename_allow_unused`,
    /// and the embedding key populates the single tied Parameter.
    /// Checkpoints that skip the redundant decoder key (modern HF saves)
    /// also load cleanly.
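    ///
    /// # Examples
    ///
    /// A minimal sketch (`flodl_hf` crate path assumed):
    ///
    /// ```no_run
    /// use flodl_hf::models::bert::BertForMaskedLM;
    ///
    /// // The tied decoder/embedding weight is resolved during loading.
    /// let head = BertForMaskedLM::from_pretrained("bert-base-uncased")
    ///     .expect("hub download / weight load failed");
    /// ```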
    pub fn from_pretrained(repo_id: &str) -> Result<Self> {
        Self::from_pretrained_on_device(repo_id, Device::CPU)
    }

    /// Device-aware variant of [`from_pretrained`](Self::from_pretrained).
    pub fn from_pretrained_on_device(repo_id: &str, device: Device) -> Result<Self> {
        let (config, weights) = fetch_config_and_weights(repo_id, BertConfig::from_json_str)?;
        let head = Self::on_device(&config, device)?;
        load_weights_with_logging(repo_id, head.graph(), &weights)?;
        head.graph().set_source_config(
            config.with_architectures("BertForMaskedLM").to_json_str(),
        );
        #[cfg(feature = "tokenizer")]
        let head = match try_load_tokenizer(repo_id) {
            Some(tok) => head.with_tokenizer(tok),
            None => head,
        };
        Ok(head)
    }
}