use flodl::{Result, TensorError};
use crate::models::albert::{
AlbertConfig, AlbertForMaskedLM, AlbertForQuestionAnswering,
AlbertForSequenceClassification, AlbertForTokenClassification,
};
use crate::models::bert::{
BertConfig, BertForMaskedLM, BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification,
};
use crate::models::deberta_v2::{
DebertaV2Config, DebertaV2ForMaskedLM, DebertaV2ForQuestionAnswering,
DebertaV2ForSequenceClassification, DebertaV2ForTokenClassification,
};
use crate::models::distilbert::{
DistilBertConfig, DistilBertForMaskedLM, DistilBertForQuestionAnswering,
DistilBertForSequenceClassification, DistilBertForTokenClassification,
};
use crate::models::roberta::{
RobertaConfig, RobertaForMaskedLM, RobertaForQuestionAnswering,
RobertaForSequenceClassification, RobertaForTokenClassification,
};
use crate::models::xlm_roberta::{
XlmRobertaConfig, XlmRobertaForMaskedLM, XlmRobertaForQuestionAnswering,
XlmRobertaForSequenceClassification, XlmRobertaForTokenClassification,
};
/// Family-tagged model configuration, selected from the `model_type` field of a
/// `config.json` by [`AutoConfig::from_json_str`].
///
/// `#[non_exhaustive]`: matches outside this crate must include a wildcard arm,
/// so further families can be added without a breaking change.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum AutoConfig {
    /// Parsed from `"model_type": "bert"`.
    Bert(BertConfig),
    /// Parsed from `"model_type": "roberta"`.
    Roberta(RobertaConfig),
    /// Parsed from `"model_type": "distilbert"`.
    DistilBert(DistilBertConfig),
    /// Parsed from `"model_type": "xlm-roberta"`.
    XlmRoberta(XlmRobertaConfig),
    /// Parsed from `"model_type": "albert"`.
    Albert(AlbertConfig),
    /// Parsed from `"model_type": "deberta-v2"`.
    DebertaV2(DebertaV2Config),
}
impl AutoConfig {
    /// `model_type` strings accepted by [`AutoConfig::from_json_str`], in the
    /// order they are listed in the "unsupported model_type" error message.
    ///
    /// Keep in sync with the `match` in `from_json_str` and with
    /// [`AutoConfig::model_type`]; the error message is built from this list so
    /// it can no longer drift from a hand-written copy.
    const SUPPORTED_MODEL_TYPES: &'static [&'static str] = &[
        "bert",
        "roberta",
        "distilbert",
        "xlm-roberta",
        "albert",
        "deberta-v2",
    ];

    /// Parses a `config.json` document and dispatches on its `model_type`
    /// field to the matching family-specific config parser.
    ///
    /// # Errors
    ///
    /// - `s` is not valid JSON (message starts with `config.json parse error:`).
    /// - `model_type` is missing or rejected by `crate::config_json::required_string`.
    /// - `model_type` is not one of [`Self::SUPPORTED_MODEL_TYPES`].
    /// - The family-specific `from_json_str` rejects the document.
    pub fn from_json_str(s: &str) -> Result<Self> {
        use crate::config_json::required_string;
        // Only `model_type` is read from this generic parse; the chosen family
        // parser re-reads the full document from `s`.
        let v: serde_json::Value = serde_json::from_str(s)
            .map_err(|e| TensorError::new(&format!("config.json parse error: {e}")))?;
        let model_type = required_string(&v, "model_type")?;
        match model_type {
            "bert" => Ok(AutoConfig::Bert(BertConfig::from_json_str(s)?)),
            "roberta" => Ok(AutoConfig::Roberta(RobertaConfig::from_json_str(s)?)),
            "distilbert" => Ok(AutoConfig::DistilBert(DistilBertConfig::from_json_str(s)?)),
            "xlm-roberta" => Ok(AutoConfig::XlmRoberta(XlmRobertaConfig::from_json_str(s)?)),
            "albert" => Ok(AutoConfig::Albert(AlbertConfig::from_json_str(s)?)),
            "deberta-v2" => Ok(AutoConfig::DebertaV2(DebertaV2Config::from_json_str(s)?)),
            other => Self::unsupported_model_type(other),
        }
    }

    /// Builds the `Err` returned for an unrecognised `model_type`, naming the
    /// offending value and every entry of [`Self::SUPPORTED_MODEL_TYPES`].
    fn unsupported_model_type(other: &str) -> Result<Self> {
        // Debug-format each entry so families appear quoted: "bert", "roberta", ...
        let supported = Self::SUPPORTED_MODEL_TYPES
            .iter()
            .map(|t| format!("{t:?}"))
            .collect::<Vec<_>>()
            .join(", ");
        Err(TensorError::new(&format!(
            "AutoConfig: unsupported model_type {other:?}. \
             Supported families: {supported}. ModernBERT and other architectures \
             are planned for a future release.",
        )))
    }

    /// Canonical `model_type` string for the wrapped family; the same key
    /// [`AutoConfig::from_json_str`] accepts for that variant.
    pub fn model_type(&self) -> &'static str {
        match self {
            AutoConfig::Bert(_) => "bert",
            AutoConfig::Roberta(_) => "roberta",
            AutoConfig::DistilBert(_) => "distilbert",
            AutoConfig::XlmRoberta(_) => "xlm-roberta",
            AutoConfig::Albert(_) => "albert",
            AutoConfig::DebertaV2(_) => "deberta-v2",
        }
    }

    /// Serializes the wrapped config with its family's `to_json_str`.
    pub fn to_json_str(&self) -> String {
        match self {
            AutoConfig::Bert(c) => c.to_json_str(),
            AutoConfig::Roberta(c) => c.to_json_str(),
            AutoConfig::DistilBert(c) => c.to_json_str(),
            AutoConfig::XlmRoberta(c) => c.to_json_str(),
            AutoConfig::Albert(c) => c.to_json_str(),
            AutoConfig::DebertaV2(c) => c.to_json_str(),
        }
    }

    /// The wrapped config's `architectures` list, or `None` when it carries none.
    pub fn architectures(&self) -> Option<&[String]> {
        match self {
            AutoConfig::Bert(c) => c.architectures.as_deref(),
            AutoConfig::Roberta(c) => c.architectures.as_deref(),
            AutoConfig::DistilBert(c) => c.architectures.as_deref(),
            AutoConfig::XlmRoberta(c) => c.architectures.as_deref(),
            AutoConfig::Albert(c) => c.architectures.as_deref(),
            AutoConfig::DebertaV2(c) => c.architectures.as_deref(),
        }
    }

    /// Class name of the headless base model for the wrapped family
    /// (e.g. `"BertModel"`, `"XLMRobertaModel"`).
    pub fn base_class_name(&self) -> &'static str {
        match self {
            AutoConfig::Bert(_) => "BertModel",
            AutoConfig::Roberta(_) => "RobertaModel",
            AutoConfig::DistilBert(_) => "DistilBertModel",
            AutoConfig::XlmRoberta(_) => "XLMRobertaModel",
            AutoConfig::Albert(_) => "AlbertModel",
            AutoConfig::DebertaV2(_) => "DebertaV2Model",
        }
    }

    /// Consumes the config, applies `with_architectures(arch_class)` to it and
    /// serializes the result.
    ///
    /// NOTE(review): whether `with_architectures` replaces or extends the
    /// existing list is defined by each family config; presumably it replaces.
    pub fn into_normalized_config_json(self, arch_class: &str) -> String {
        match self {
            AutoConfig::Bert(c) => c.with_architectures(arch_class).to_json_str(),
            AutoConfig::Roberta(c) => c.with_architectures(arch_class).to_json_str(),
            AutoConfig::DistilBert(c) => c.with_architectures(arch_class).to_json_str(),
            AutoConfig::XlmRoberta(c) => c.with_architectures(arch_class).to_json_str(),
            AutoConfig::Albert(c) => c.with_architectures(arch_class).to_json_str(),
            AutoConfig::DebertaV2(c) => c.with_architectures(arch_class).to_json_str(),
        }
    }
}
/// Unit marker type for the family-agnostic base-model entry point
/// (cf. [`AutoConfig::base_class_name`]).
///
/// NOTE(review): no fields or methods are visible here; any loading API is
/// presumably provided by an `impl AutoModel` elsewhere — confirm.
pub struct AutoModel;
/// Sequence-classification head that dispatches at runtime to one of the
/// supported model families.
///
/// `#[non_exhaustive]`: matches outside this crate need a wildcard arm.
#[non_exhaustive]
pub enum AutoModelForSequenceClassification {
    /// BERT-family head.
    Bert(BertForSequenceClassification),
    /// RoBERTa-family head.
    Roberta(RobertaForSequenceClassification),
    /// DistilBERT-family head.
    DistilBert(DistilBertForSequenceClassification),
    /// XLM-RoBERTa-family head.
    XlmRoberta(XlmRobertaForSequenceClassification),
    /// ALBERT-family head.
    Albert(AlbertForSequenceClassification),
    /// DeBERTa-v2-family head.
    DebertaV2(DebertaV2ForSequenceClassification),
}
/// Token-classification head that dispatches at runtime to one of the
/// supported model families.
///
/// `#[non_exhaustive]`: matches outside this crate need a wildcard arm.
#[non_exhaustive]
pub enum AutoModelForTokenClassification {
    /// BERT-family head.
    Bert(BertForTokenClassification),
    /// RoBERTa-family head.
    Roberta(RobertaForTokenClassification),
    /// DistilBERT-family head.
    DistilBert(DistilBertForTokenClassification),
    /// XLM-RoBERTa-family head.
    XlmRoberta(XlmRobertaForTokenClassification),
    /// ALBERT-family head.
    Albert(AlbertForTokenClassification),
    /// DeBERTa-v2-family head.
    DebertaV2(DebertaV2ForTokenClassification),
}
/// Question-answering head that dispatches at runtime to one of the
/// supported model families.
///
/// `#[non_exhaustive]`: matches outside this crate need a wildcard arm.
#[non_exhaustive]
pub enum AutoModelForQuestionAnswering {
    /// BERT-family head.
    Bert(BertForQuestionAnswering),
    /// RoBERTa-family head.
    Roberta(RobertaForQuestionAnswering),
    /// DistilBERT-family head.
    DistilBert(DistilBertForQuestionAnswering),
    /// XLM-RoBERTa-family head.
    XlmRoberta(XlmRobertaForQuestionAnswering),
    /// ALBERT-family head.
    Albert(AlbertForQuestionAnswering),
    /// DeBERTa-v2-family head.
    DebertaV2(DebertaV2ForQuestionAnswering),
}
impl AutoModelForSequenceClassification {
    /// Borrows the computation graph of the wrapped family head.
    pub fn graph(&self) -> &flodl::Graph {
        match self {
            Self::Bert(h) => h.graph(),
            Self::Roberta(h) => h.graph(),
            Self::DistilBert(h) => h.graph(),
            Self::XlmRoberta(h) => h.graph(),
            Self::Albert(h) => h.graph(),
            Self::DebertaV2(h) => h.graph(),
        }
    }

    /// Consumes the wrapper and returns the wrapped head's graph.
    pub fn into_graph(self) -> flodl::Graph {
        match self {
            Self::Bert(h) => h.into_graph(),
            Self::Roberta(h) => h.into_graph(),
            Self::DistilBert(h) => h.into_graph(),
            Self::XlmRoberta(h) => h.into_graph(),
            Self::Albert(h) => h.into_graph(),
            Self::DebertaV2(h) => h.into_graph(),
        }
    }

    /// Label names reported by the wrapped head's `labels()`.
    pub fn labels(&self) -> &[String] {
        match self {
            Self::Bert(h) => h.labels(),
            Self::Roberta(h) => h.labels(),
            Self::DistilBert(h) => h.labels(),
            Self::XlmRoberta(h) => h.labels(),
            Self::Albert(h) => h.labels(),
            Self::DebertaV2(h) => h.labels(),
        }
    }

    /// Attaches `tok` to the wrapped head, keeping the same family variant.
    ///
    /// Consumes `self`: discarding the return value drops the model, hence
    /// `#[must_use]`.
    #[cfg(feature = "tokenizer")]
    #[must_use]
    pub fn with_tokenizer(self, tok: crate::tokenizer::HfTokenizer) -> Self {
        match self {
            Self::Bert(h) => Self::Bert(h.with_tokenizer(tok)),
            Self::Roberta(h) => Self::Roberta(h.with_tokenizer(tok)),
            Self::DistilBert(h) => Self::DistilBert(h.with_tokenizer(tok)),
            Self::XlmRoberta(h) => Self::XlmRoberta(h.with_tokenizer(tok)),
            Self::Albert(h) => Self::Albert(h.with_tokenizer(tok)),
            Self::DebertaV2(h) => Self::DebertaV2(h.with_tokenizer(tok)),
        }
    }

    /// Classifies each text in `texts` with the wrapped head's `predict`.
    ///
    /// Returns one `Vec<(String, f32)>` per input text — presumably
    /// `(label, score)` pairs; see the family head for ordering/normalization.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped head's `predict`.
    #[cfg(feature = "tokenizer")]
    pub fn predict(&self, texts: &[&str]) -> Result<Vec<Vec<(String, f32)>>> {
        match self {
            Self::Bert(h) => h.predict(texts),
            Self::Roberta(h) => h.predict(texts),
            Self::DistilBert(h) => h.predict(texts),
            Self::XlmRoberta(h) => h.predict(texts),
            Self::Albert(h) => h.predict(texts),
            Self::DebertaV2(h) => h.predict(texts),
        }
    }
}
impl AutoModelForTokenClassification {
    /// Borrows the computation graph of the wrapped family head.
    pub fn graph(&self) -> &flodl::Graph {
        match self {
            Self::Bert(h) => h.graph(),
            Self::Roberta(h) => h.graph(),
            Self::DistilBert(h) => h.graph(),
            Self::XlmRoberta(h) => h.graph(),
            Self::Albert(h) => h.graph(),
            Self::DebertaV2(h) => h.graph(),
        }
    }

    /// Consumes the wrapper and returns the wrapped head's graph.
    pub fn into_graph(self) -> flodl::Graph {
        match self {
            Self::Bert(h) => h.into_graph(),
            Self::Roberta(h) => h.into_graph(),
            Self::DistilBert(h) => h.into_graph(),
            Self::XlmRoberta(h) => h.into_graph(),
            Self::Albert(h) => h.into_graph(),
            Self::DebertaV2(h) => h.into_graph(),
        }
    }

    /// Label names reported by the wrapped head's `labels()`.
    pub fn labels(&self) -> &[String] {
        match self {
            Self::Bert(h) => h.labels(),
            Self::Roberta(h) => h.labels(),
            Self::DistilBert(h) => h.labels(),
            Self::XlmRoberta(h) => h.labels(),
            Self::Albert(h) => h.labels(),
            Self::DebertaV2(h) => h.labels(),
        }
    }

    /// Attaches `tok` to the wrapped head, keeping the same family variant.
    ///
    /// Consumes `self`: discarding the return value drops the model, hence
    /// `#[must_use]`.
    #[cfg(feature = "tokenizer")]
    #[must_use]
    pub fn with_tokenizer(self, tok: crate::tokenizer::HfTokenizer) -> Self {
        match self {
            Self::Bert(h) => Self::Bert(h.with_tokenizer(tok)),
            Self::Roberta(h) => Self::Roberta(h.with_tokenizer(tok)),
            Self::DistilBert(h) => Self::DistilBert(h.with_tokenizer(tok)),
            Self::XlmRoberta(h) => Self::XlmRoberta(h.with_tokenizer(tok)),
            Self::Albert(h) => Self::Albert(h.with_tokenizer(tok)),
            Self::DebertaV2(h) => Self::DebertaV2(h.with_tokenizer(tok)),
        }
    }

    /// Tags tokens of each text in `texts` with the wrapped head's `predict`.
    ///
    /// Returns one `Vec<TokenPrediction>` per input text.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped head's `predict`.
    #[cfg(feature = "tokenizer")]
    pub fn predict(
        &self,
        texts: &[&str],
    ) -> Result<Vec<Vec<crate::task_heads::TokenPrediction>>> {
        match self {
            Self::Bert(h) => h.predict(texts),
            Self::Roberta(h) => h.predict(texts),
            Self::DistilBert(h) => h.predict(texts),
            Self::XlmRoberta(h) => h.predict(texts),
            Self::Albert(h) => h.predict(texts),
            Self::DebertaV2(h) => h.predict(texts),
        }
    }
}
impl AutoModelForQuestionAnswering {
    /// Borrows the computation graph of the wrapped family head.
    pub fn graph(&self) -> &flodl::Graph {
        match self {
            Self::Bert(h) => h.graph(),
            Self::Roberta(h) => h.graph(),
            Self::DistilBert(h) => h.graph(),
            Self::XlmRoberta(h) => h.graph(),
            Self::Albert(h) => h.graph(),
            Self::DebertaV2(h) => h.graph(),
        }
    }

    /// Consumes the wrapper and returns the wrapped head's graph.
    pub fn into_graph(self) -> flodl::Graph {
        match self {
            Self::Bert(h) => h.into_graph(),
            Self::Roberta(h) => h.into_graph(),
            Self::DistilBert(h) => h.into_graph(),
            Self::XlmRoberta(h) => h.into_graph(),
            Self::Albert(h) => h.into_graph(),
            Self::DebertaV2(h) => h.into_graph(),
        }
    }

    /// Attaches `tok` to the wrapped head, keeping the same family variant.
    ///
    /// Consumes `self`: discarding the return value drops the model, hence
    /// `#[must_use]`.
    #[cfg(feature = "tokenizer")]
    #[must_use]
    pub fn with_tokenizer(self, tok: crate::tokenizer::HfTokenizer) -> Self {
        match self {
            Self::Bert(h) => Self::Bert(h.with_tokenizer(tok)),
            Self::Roberta(h) => Self::Roberta(h.with_tokenizer(tok)),
            Self::DistilBert(h) => Self::DistilBert(h.with_tokenizer(tok)),
            Self::XlmRoberta(h) => Self::XlmRoberta(h.with_tokenizer(tok)),
            Self::Albert(h) => Self::Albert(h.with_tokenizer(tok)),
            Self::DebertaV2(h) => Self::DebertaV2(h.with_tokenizer(tok)),
        }
    }

    /// Answers `question` against `context` using the wrapped head's `answer`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped head's `answer`.
    #[cfg(feature = "tokenizer")]
    pub fn answer(
        &self,
        question: &str,
        context: &str,
    ) -> Result<crate::task_heads::Answer> {
        match self {
            Self::Bert(h) => h.answer(question, context),
            Self::Roberta(h) => h.answer(question, context),
            Self::DistilBert(h) => h.answer(question, context),
            Self::XlmRoberta(h) => h.answer(question, context),
            Self::Albert(h) => h.answer(question, context),
            Self::DebertaV2(h) => h.answer(question, context),
        }
    }

    /// Batched form of [`Self::answer`] over `(question, context)` pairs.
    ///
    /// Presumably yields one `Answer` per pair in input order — confirm
    /// against the family head's `answer_batch`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped head's `answer_batch`.
    #[cfg(feature = "tokenizer")]
    pub fn answer_batch(
        &self,
        pairs: &[(&str, &str)],
    ) -> Result<Vec<crate::task_heads::Answer>> {
        match self {
            Self::Bert(h) => h.answer_batch(pairs),
            Self::Roberta(h) => h.answer_batch(pairs),
            Self::DistilBert(h) => h.answer_batch(pairs),
            Self::XlmRoberta(h) => h.answer_batch(pairs),
            Self::Albert(h) => h.answer_batch(pairs),
            Self::DebertaV2(h) => h.answer_batch(pairs),
        }
    }
}
/// Masked-language-model head that dispatches at runtime to one of the
/// supported model families.
///
/// `#[non_exhaustive]`: matches outside this crate need a wildcard arm.
#[non_exhaustive]
pub enum AutoModelForMaskedLM {
    /// BERT-family head.
    Bert(BertForMaskedLM),
    /// RoBERTa-family head.
    Roberta(RobertaForMaskedLM),
    /// DistilBERT-family head.
    DistilBert(DistilBertForMaskedLM),
    /// XLM-RoBERTa-family head.
    XlmRoberta(XlmRobertaForMaskedLM),
    /// ALBERT-family head.
    Albert(AlbertForMaskedLM),
    /// DeBERTa-v2-family head.
    DebertaV2(DebertaV2ForMaskedLM),
}
impl AutoModelForMaskedLM {
    /// Borrows the computation graph of the wrapped family head.
    pub fn graph(&self) -> &flodl::Graph {
        match self {
            Self::Bert(h) => h.graph(),
            Self::Roberta(h) => h.graph(),
            Self::DistilBert(h) => h.graph(),
            Self::XlmRoberta(h) => h.graph(),
            Self::Albert(h) => h.graph(),
            Self::DebertaV2(h) => h.graph(),
        }
    }

    /// Consumes the wrapper and returns the wrapped head's graph.
    pub fn into_graph(self) -> flodl::Graph {
        match self {
            Self::Bert(h) => h.into_graph(),
            Self::Roberta(h) => h.into_graph(),
            Self::DistilBert(h) => h.into_graph(),
            Self::XlmRoberta(h) => h.into_graph(),
            Self::Albert(h) => h.into_graph(),
            Self::DebertaV2(h) => h.into_graph(),
        }
    }

    /// Attaches `tok` to the wrapped head, keeping the same family variant.
    ///
    /// Consumes `self`: discarding the return value drops the model, hence
    /// `#[must_use]`.
    #[cfg(feature = "tokenizer")]
    #[must_use]
    pub fn with_tokenizer(self, tok: crate::tokenizer::HfTokenizer) -> Self {
        match self {
            Self::Bert(h) => Self::Bert(h.with_tokenizer(tok)),
            Self::Roberta(h) => Self::Roberta(h.with_tokenizer(tok)),
            Self::DistilBert(h) => Self::DistilBert(h.with_tokenizer(tok)),
            Self::XlmRoberta(h) => Self::XlmRoberta(h.with_tokenizer(tok)),
            Self::Albert(h) => Self::Albert(h.with_tokenizer(tok)),
            Self::DebertaV2(h) => Self::DebertaV2(h.with_tokenizer(tok)),
        }
    }

    /// Runs the wrapped head's `fill_mask` on `text`, requesting `top_k`
    /// candidates.
    ///
    /// Presumably returns one list of `(token, score)` pairs per mask position
    /// in `text` — confirm against the family head's `fill_mask`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped head's `fill_mask`.
    #[cfg(feature = "tokenizer")]
    pub fn fill_mask(
        &self,
        text: &str,
        top_k: usize,
    ) -> Result<Vec<Vec<(String, f32)>>> {
        match self {
            Self::Bert(h) => h.fill_mask(text, top_k),
            Self::Roberta(h) => h.fill_mask(text, top_k),
            Self::DistilBert(h) => h.fill_mask(text, top_k),
            Self::XlmRoberta(h) => h.fill_mask(text, top_k),
            Self::Albert(h) => h.fill_mask(text, top_k),
            Self::DebertaV2(h) => h.fill_mask(text, top_k),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn auto_config_dispatches_bert() {
        let json = r#"{
            "model_type": "bert",
            "vocab_size": 30522,
            "hidden_size": 768,
            "num_hidden_layers": 12,
            "num_attention_heads": 12,
            "intermediate_size": 3072,
            "max_position_embeddings": 512,
            "type_vocab_size": 2,
            "pad_token_id": 0
        }"#;
        let c = AutoConfig::from_json_str(json).unwrap();
        assert_eq!(c.model_type(), "bert");
        match c {
            AutoConfig::Bert(b) => {
                assert_eq!(b.vocab_size, 30522);
                assert_eq!(b.hidden_size, 768);
            }
            other => panic!("expected Bert, got {:?}", other.model_type()),
        }
    }

    #[test]
    fn auto_config_dispatches_roberta() {
        let json = r#"{
            "model_type": "roberta",
            "vocab_size": 50265,
            "hidden_size": 768,
            "num_hidden_layers": 12,
            "num_attention_heads": 12,
            "intermediate_size": 3072,
            "max_position_embeddings": 514,
            "type_vocab_size": 1,
            "pad_token_id": 1
        }"#;
        let c = AutoConfig::from_json_str(json).unwrap();
        assert_eq!(c.model_type(), "roberta");
        match c {
            AutoConfig::Roberta(r) => {
                assert_eq!(r.vocab_size, 50265);
                assert_eq!(r.pad_token_id, 1);
            }
            other => panic!("expected Roberta, got {:?}", other.model_type()),
        }
    }

    #[test]
    fn auto_config_dispatches_distilbert() {
        let json = r#"{
            "model_type": "distilbert",
            "vocab_size": 30522,
            "dim": 768,
            "n_layers": 6,
            "n_heads": 12,
            "hidden_dim": 3072,
            "max_position_embeddings": 512,
            "pad_token_id": 0
        }"#;
        let c = AutoConfig::from_json_str(json).unwrap();
        assert_eq!(c.model_type(), "distilbert");
        match c {
            AutoConfig::DistilBert(d) => {
                assert_eq!(d.vocab_size, 30522);
                assert_eq!(d.n_layers, 6);
            }
            other => panic!("expected DistilBert, got {:?}", other.model_type()),
        }
    }

    #[test]
    fn auto_config_dispatches_xlm_roberta() {
        let json = r#"{
            "model_type": "xlm-roberta",
            "vocab_size": 250002,
            "hidden_size": 768,
            "num_hidden_layers": 12,
            "num_attention_heads": 12,
            "intermediate_size": 3072,
            "max_position_embeddings": 514,
            "type_vocab_size": 1,
            "pad_token_id": 1
        }"#;
        let c = AutoConfig::from_json_str(json).unwrap();
        assert_eq!(c.model_type(), "xlm-roberta");
        match c {
            AutoConfig::XlmRoberta(x) => {
                assert_eq!(x.vocab_size, 250_002);
                assert_eq!(x.pad_token_id, 1);
            }
            other => panic!("expected XlmRoberta, got {:?}", other.model_type()),
        }
    }

    #[test]
    fn auto_config_rejects_unknown_model_type() {
        let json = r#"{
            "model_type": "modernbert",
            "vocab_size": 50368,
            "hidden_size": 768
        }"#;
        let err = AutoConfig::from_json_str(json).unwrap_err().to_string();
        // Check the quoted forms: a bare `contains("bert")` is already satisfied
        // by the offending name "modernbert", and `contains("roberta")` by
        // "xlm-roberta", so unquoted checks would prove nothing.
        assert!(err.contains("\"modernbert\""), "error names offending type: {err}");
        for family in ["bert", "roberta", "distilbert", "xlm-roberta", "albert", "deberta-v2"] {
            let quoted = format!("\"{family}\"");
            assert!(err.contains(&quoted), "error lists supported {quoted}: {err}");
        }
    }

    #[test]
    fn auto_config_dispatches_deberta_v2() {
        let json = r#"{
            "model_type": "deberta-v2",
            "vocab_size": 128100,
            "hidden_size": 768,
            "num_hidden_layers": 12,
            "num_attention_heads": 12,
            "intermediate_size": 3072,
            "max_position_embeddings": 512,
            "relative_attention": true,
            "position_buckets": 256,
            "norm_rel_ebd": "layer_norm",
            "share_att_key": true,
            "pos_att_type": "p2c|c2p",
            "layer_norm_eps": 1e-7,
            "max_relative_positions": -1,
            "position_biased_input": false,
            "type_vocab_size": 0
        }"#;
        let c = AutoConfig::from_json_str(json).unwrap();
        assert_eq!(c.model_type(), "deberta-v2");
        match c {
            AutoConfig::DebertaV2(d) => {
                assert_eq!(d.vocab_size, 128_100);
                assert_eq!(d.hidden_size, 768);
                assert_eq!(d.position_buckets, 256);
                assert_eq!(d.max_relative_positions, 512);
            }
            other => panic!("expected DebertaV2, got {:?}", other.model_type()),
        }
    }

    #[test]
    fn auto_config_dispatches_albert() {
        let json = r#"{
            "model_type": "albert",
            "vocab_size": 30000,
            "embedding_size": 128,
            "hidden_size": 768,
            "num_hidden_layers": 12,
            "num_attention_heads": 12,
            "intermediate_size": 3072,
            "max_position_embeddings": 512
        }"#;
        let c = AutoConfig::from_json_str(json).unwrap();
        assert_eq!(c.model_type(), "albert");
        match c {
            AutoConfig::Albert(a) => {
                assert_eq!(a.vocab_size, 30000);
                assert_eq!(a.embedding_size, 128);
                assert_eq!(a.hidden_size, 768);
            }
            other => panic!("expected Albert, got {:?}", other.model_type()),
        }
    }

    #[test]
    fn auto_config_base_class_name_follows_family() {
        // XLM-R is the family whose class-name casing ("XLM") differs from
        // its Rust variant name, so pin it explicitly.
        let json = r#"{
            "model_type": "xlm-roberta",
            "vocab_size": 250002,
            "hidden_size": 768,
            "num_hidden_layers": 12,
            "num_attention_heads": 12,
            "intermediate_size": 3072,
            "max_position_embeddings": 514,
            "type_vocab_size": 1,
            "pad_token_id": 1
        }"#;
        let c = AutoConfig::from_json_str(json).unwrap();
        assert_eq!(c.base_class_name(), "XLMRobertaModel");
    }

    #[test]
    fn auto_config_rejects_missing_model_type() {
        let json = r#"{
            "vocab_size": 30522,
            "hidden_size": 768
        }"#;
        let err = AutoConfig::from_json_str(json).unwrap_err().to_string();
        assert!(
            err.contains("model_type"),
            "error must name the missing field: {err}",
        );
    }

    #[test]
    fn auto_config_rejects_invalid_json() {
        let err = AutoConfig::from_json_str("not json").unwrap_err().to_string();
        assert!(err.contains("parse error"), "got: {err}");
    }
}