use flodl::{Result, TensorError};
use crate::models::bert::{
BertConfig, BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification,
};
use crate::models::distilbert::{
DistilBertConfig, DistilBertForQuestionAnswering, DistilBertForSequenceClassification,
DistilBertForTokenClassification,
};
use crate::models::roberta::{
RobertaConfig, RobertaForQuestionAnswering, RobertaForSequenceClassification,
RobertaForTokenClassification,
};
/// Family-specific model configuration, selected by the `"model_type"` field
/// of a `config.json` document (see [`AutoConfig::from_json_str`]).
#[derive(Debug, Clone)]
pub enum AutoConfig {
/// Selected when `model_type` is `"bert"`.
Bert(BertConfig),
/// Selected when `model_type` is `"roberta"`.
Roberta(RobertaConfig),
/// Selected when `model_type` is `"distilbert"`.
DistilBert(DistilBertConfig),
}
impl AutoConfig {
pub fn from_json_str(s: &str) -> Result<Self> {
use crate::config_json::required_string;
let v: serde_json::Value = serde_json::from_str(s)
.map_err(|e| TensorError::new(&format!("config.json parse error: {e}")))?;
let model_type = required_string(&v, "model_type")?;
match model_type {
"bert" => Ok(AutoConfig::Bert(BertConfig::from_json_str(s)?)),
"roberta" => Ok(AutoConfig::Roberta(RobertaConfig::from_json_str(s)?)),
"distilbert" => Ok(AutoConfig::DistilBert(DistilBertConfig::from_json_str(s)?)),
other => Err(TensorError::new(&format!(
"AutoConfig: unsupported model_type {other:?}. \
Supported families: \"bert\", \"roberta\", \"distilbert\". \
ModernBERT and other architectures are planned for a future release.",
))),
}
}
pub fn model_type(&self) -> &'static str {
match self {
AutoConfig::Bert(_) => "bert",
AutoConfig::Roberta(_) => "roberta",
AutoConfig::DistilBert(_) => "distilbert",
}
}
}
/// Unit marker type for the auto-model entry points.
// NOTE(review): no `impl AutoModel` is visible in this file chunk; its API
// presumably lives elsewhere in the crate — confirm before documenting further.
pub struct AutoModel;
/// Sequence-classification head for any supported architecture family,
/// letting callers hold a BERT, RoBERTa or DistilBERT head behind one type.
pub enum AutoModelForSequenceClassification {
/// Wraps a [`BertForSequenceClassification`].
Bert(BertForSequenceClassification),
/// Wraps a [`RobertaForSequenceClassification`].
Roberta(RobertaForSequenceClassification),
/// Wraps a [`DistilBertForSequenceClassification`].
DistilBert(DistilBertForSequenceClassification),
}
/// Token-classification head for any supported architecture family,
/// letting callers hold a BERT, RoBERTa or DistilBERT head behind one type.
pub enum AutoModelForTokenClassification {
/// Wraps a [`BertForTokenClassification`].
Bert(BertForTokenClassification),
/// Wraps a [`RobertaForTokenClassification`].
Roberta(RobertaForTokenClassification),
/// Wraps a [`DistilBertForTokenClassification`].
DistilBert(DistilBertForTokenClassification),
}
/// Question-answering head for any supported architecture family,
/// letting callers hold a BERT, RoBERTa or DistilBERT head behind one type.
pub enum AutoModelForQuestionAnswering {
/// Wraps a [`BertForQuestionAnswering`].
Bert(BertForQuestionAnswering),
/// Wraps a [`RobertaForQuestionAnswering`].
Roberta(RobertaForQuestionAnswering),
/// Wraps a [`DistilBertForQuestionAnswering`].
DistilBert(DistilBertForQuestionAnswering),
}
impl AutoModelForSequenceClassification {
    /// Borrows the `flodl::Graph` of whichever family head is wrapped.
    pub fn graph(&self) -> &flodl::Graph {
        match self {
            Self::Bert(inner) => inner.graph(),
            Self::Roberta(inner) => inner.graph(),
            Self::DistilBert(inner) => inner.graph(),
        }
    }

    /// Borrows the label names reported by the wrapped head.
    pub fn labels(&self) -> &[String] {
        match self {
            Self::Bert(inner) => inner.labels(),
            Self::Roberta(inner) => inner.labels(),
            Self::DistilBert(inner) => inner.labels(),
        }
    }

    /// Consumes `self`, attaches `tok` to the wrapped head, and rewraps the
    /// result in the same family variant.
    ///
    /// Only available with the `tokenizer` feature.
    #[cfg(feature = "tokenizer")]
    pub fn with_tokenizer(self, tok: crate::tokenizer::HfTokenizer) -> Self {
        match self {
            Self::Bert(inner) => {
                let inner = inner.with_tokenizer(tok);
                Self::Bert(inner)
            }
            Self::Roberta(inner) => {
                let inner = inner.with_tokenizer(tok);
                Self::Roberta(inner)
            }
            Self::DistilBert(inner) => {
                let inner = inner.with_tokenizer(tok);
                Self::DistilBert(inner)
            }
        }
    }

    /// Runs the wrapped head's `predict` on `texts` and returns its result
    /// unchanged (presumably one list of `(label, score)` pairs per input
    /// text — confirm against the family implementations).
    ///
    /// Only available with the `tokenizer` feature.
    ///
    /// # Errors
    ///
    /// Propagates any error from the family-specific `predict`.
    #[cfg(feature = "tokenizer")]
    pub fn predict(&self, texts: &[&str]) -> Result<Vec<Vec<(String, f32)>>> {
        match self {
            Self::Bert(inner) => inner.predict(texts),
            Self::Roberta(inner) => inner.predict(texts),
            Self::DistilBert(inner) => inner.predict(texts),
        }
    }
}
impl AutoModelForTokenClassification {
    /// Borrows the `flodl::Graph` of whichever family head is wrapped.
    pub fn graph(&self) -> &flodl::Graph {
        match self {
            Self::Bert(inner) => inner.graph(),
            Self::Roberta(inner) => inner.graph(),
            Self::DistilBert(inner) => inner.graph(),
        }
    }

    /// Borrows the label names reported by the wrapped head.
    pub fn labels(&self) -> &[String] {
        match self {
            Self::Bert(inner) => inner.labels(),
            Self::Roberta(inner) => inner.labels(),
            Self::DistilBert(inner) => inner.labels(),
        }
    }

    /// Consumes `self`, attaches `tok` to the wrapped head, and rewraps the
    /// result in the same family variant.
    ///
    /// Only available with the `tokenizer` feature.
    #[cfg(feature = "tokenizer")]
    pub fn with_tokenizer(self, tok: crate::tokenizer::HfTokenizer) -> Self {
        match self {
            Self::Bert(inner) => {
                let inner = inner.with_tokenizer(tok);
                Self::Bert(inner)
            }
            Self::Roberta(inner) => {
                let inner = inner.with_tokenizer(tok);
                Self::Roberta(inner)
            }
            Self::DistilBert(inner) => {
                let inner = inner.with_tokenizer(tok);
                Self::DistilBert(inner)
            }
        }
    }

    /// Runs the wrapped head's `predict` on `texts` and returns its
    /// `TokenPrediction` lists unchanged (presumably one list per input
    /// text — confirm against the family implementations).
    ///
    /// Only available with the `tokenizer` feature.
    ///
    /// # Errors
    ///
    /// Propagates any error from the family-specific `predict`.
    #[cfg(feature = "tokenizer")]
    pub fn predict(&self, texts: &[&str]) -> Result<Vec<Vec<crate::task_heads::TokenPrediction>>> {
        match self {
            Self::Bert(inner) => inner.predict(texts),
            Self::Roberta(inner) => inner.predict(texts),
            Self::DistilBert(inner) => inner.predict(texts),
        }
    }
}
impl AutoModelForQuestionAnswering {
    /// Borrows the `flodl::Graph` of whichever family head is wrapped.
    pub fn graph(&self) -> &flodl::Graph {
        match self {
            Self::Bert(inner) => inner.graph(),
            Self::Roberta(inner) => inner.graph(),
            Self::DistilBert(inner) => inner.graph(),
        }
    }

    /// Consumes `self`, attaches `tok` to the wrapped head, and rewraps the
    /// result in the same family variant.
    ///
    /// Only available with the `tokenizer` feature.
    #[cfg(feature = "tokenizer")]
    pub fn with_tokenizer(self, tok: crate::tokenizer::HfTokenizer) -> Self {
        match self {
            Self::Bert(inner) => {
                let inner = inner.with_tokenizer(tok);
                Self::Bert(inner)
            }
            Self::Roberta(inner) => {
                let inner = inner.with_tokenizer(tok);
                Self::Roberta(inner)
            }
            Self::DistilBert(inner) => {
                let inner = inner.with_tokenizer(tok);
                Self::DistilBert(inner)
            }
        }
    }

    /// Passes `question` and `context` to the wrapped head's `answer` and
    /// returns its result unchanged.
    ///
    /// Only available with the `tokenizer` feature.
    ///
    /// # Errors
    ///
    /// Propagates any error from the family-specific `answer`.
    #[cfg(feature = "tokenizer")]
    pub fn answer(&self, question: &str, context: &str) -> Result<crate::task_heads::Answer> {
        match self {
            Self::Bert(inner) => inner.answer(question, context),
            Self::Roberta(inner) => inner.answer(question, context),
            Self::DistilBert(inner) => inner.answer(question, context),
        }
    }

    /// Passes `pairs` to the wrapped head's `answer_batch` and returns its
    /// result unchanged. Each pair is presumably `(question, context)`,
    /// mirroring [`Self::answer`]; confirm against the family implementations.
    ///
    /// Only available with the `tokenizer` feature.
    ///
    /// # Errors
    ///
    /// Propagates any error from the family-specific `answer_batch`.
    #[cfg(feature = "tokenizer")]
    pub fn answer_batch(&self, pairs: &[(&str, &str)]) -> Result<Vec<crate::task_heads::Answer>> {
        match self {
            Self::Bert(inner) => inner.answer_batch(pairs),
            Self::Roberta(inner) => inner.answer_batch(pairs),
            Self::DistilBert(inner) => inner.answer_batch(pairs),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn auto_config_dispatches_bert() {
        let json = r#"{
            "model_type": "bert",
            "vocab_size": 30522,
            "hidden_size": 768,
            "num_hidden_layers": 12,
            "num_attention_heads": 12,
            "intermediate_size": 3072,
            "max_position_embeddings": 512,
            "type_vocab_size": 2,
            "pad_token_id": 0
        }"#;
        let c = AutoConfig::from_json_str(json).unwrap();
        assert_eq!(c.model_type(), "bert");
        match c {
            AutoConfig::Bert(b) => {
                assert_eq!(b.vocab_size, 30522);
                assert_eq!(b.hidden_size, 768);
            }
            other => panic!("expected Bert, got {:?}", other.model_type()),
        }
    }

    #[test]
    fn auto_config_dispatches_roberta() {
        let json = r#"{
            "model_type": "roberta",
            "vocab_size": 50265,
            "hidden_size": 768,
            "num_hidden_layers": 12,
            "num_attention_heads": 12,
            "intermediate_size": 3072,
            "max_position_embeddings": 514,
            "type_vocab_size": 1,
            "pad_token_id": 1
        }"#;
        let c = AutoConfig::from_json_str(json).unwrap();
        assert_eq!(c.model_type(), "roberta");
        match c {
            AutoConfig::Roberta(r) => {
                assert_eq!(r.vocab_size, 50265);
                assert_eq!(r.pad_token_id, 1);
            }
            other => panic!("expected Roberta, got {:?}", other.model_type()),
        }
    }

    #[test]
    fn auto_config_dispatches_distilbert() {
        let json = r#"{
            "model_type": "distilbert",
            "vocab_size": 30522,
            "dim": 768,
            "n_layers": 6,
            "n_heads": 12,
            "hidden_dim": 3072,
            "max_position_embeddings": 512,
            "pad_token_id": 0
        }"#;
        let c = AutoConfig::from_json_str(json).unwrap();
        assert_eq!(c.model_type(), "distilbert");
        match c {
            AutoConfig::DistilBert(d) => {
                assert_eq!(d.vocab_size, 30522);
                assert_eq!(d.n_layers, 6);
            }
            other => panic!("expected DistilBert, got {:?}", other.model_type()),
        }
    }

    #[test]
    fn auto_config_rejects_unknown_model_type() {
        let json = r#"{
            "model_type": "modernbert",
            "vocab_size": 50368,
            "hidden_size": 768
        }"#;
        let err = AutoConfig::from_json_str(json).unwrap_err().to_string();
        assert!(err.contains("modernbert"), "error names offending type: {err}");
        // Match the quoted family names from the "Supported families" list. A
        // bare `contains("bert")` is already satisfied by the offending
        // "modernbert" (or by "distilbert") and would prove nothing.
        assert!(err.contains("\"bert\""), "error lists supported: {err}");
        assert!(err.contains("\"roberta\""), "error lists supported: {err}");
        assert!(err.contains("\"distilbert\""), "error lists supported: {err}");
    }

    #[test]
    fn auto_config_rejects_missing_model_type() {
        let json = r#"{
            "vocab_size": 30522,
            "hidden_size": 768
        }"#;
        let err = AutoConfig::from_json_str(json).unwrap_err().to_string();
        assert!(
            err.contains("model_type"),
            "error must name the missing field: {err}",
        );
    }

    #[test]
    fn auto_config_rejects_invalid_json() {
        let err = AutoConfig::from_json_str("not json").unwrap_err().to_string();
        assert!(err.contains("parse error"), "got: {err}");
    }
}