use flodl::nn::{Dropout, GeluApprox, Linear};
use flodl::{DType, Device, Graph, Result, TensorError, Variable};
use crate::models::roberta::{
roberta_backbone_flow, roberta_masked_lm_graph, RobertaClassificationHead, RobertaConfig,
RobertaModel,
};
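/// Configuration for XLM-RoBERTa, mirroring the fields of a Hugging Face
/// `config.json`. The architecture is identical to RoBERTa; only the
/// SentencePiece vocabulary and tokenizer differ, so the config converts
/// losslessly into a [`RobertaConfig`] (see the `From` impl below) and the
/// RoBERTa graph builders are reused wholesale.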
#[derive(Debug, Clone)]
pub struct XlmRobertaConfig {
pub vocab_size: i64,
pub hidden_size: i64,
pub num_hidden_layers: i64,
pub num_attention_heads: i64,
pub intermediate_size: i64,
pub max_position_embeddings: i64,
pub type_vocab_size: i64,
pub pad_token_id: i64,
pub layer_norm_eps: f64,
pub hidden_dropout_prob: f64,
pub attention_probs_dropout_prob: f64,
pub hidden_act: GeluApprox,
pub num_labels: Option<i64>,
pub id2label: Option<Vec<String>>,
pub architectures: Option<Vec<String>>,
}
impl XlmRobertaConfig {
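    /// Preset matching the published `xlm-roberta-base` checkpoint:
    /// 12 layers, 12 heads, hidden size 768, a 250 002-token vocabulary,
    /// and 514 position embeddings.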
pub fn xlm_roberta_base() -> Self {
XlmRobertaConfig {
vocab_size: 250_002,
hidden_size: 768,
num_hidden_layers: 12,
num_attention_heads: 12,
intermediate_size: 3072,
max_position_embeddings: 514,
type_vocab_size: 1,
pad_token_id: 1,
layer_norm_eps: 1e-5,
hidden_dropout_prob: 0.1,
attention_probs_dropout_prob: 0.1,
hidden_act: GeluApprox::Exact,
num_labels: None,
id2label: None,
architectures: None,
}
}
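    /// Parses a Hugging Face `config.json`. The shape fields are required;
    /// `type_vocab_size`, `pad_token_id`, `layer_norm_eps`, the dropout
    /// probabilities, and `hidden_act` fall back to the XLM-RoBERTa defaults
    /// when absent. `num_labels` may be inferred from `id2label` when only
    /// the latter is present.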
pub fn from_json_str(s: &str) -> Result<Self> {
use crate::config_json::{
optional_f64, optional_hidden_act, optional_i64, parse_architectures, parse_id2label,
parse_num_labels, required_i64,
};
let v: serde_json::Value = serde_json::from_str(s)
.map_err(|e| TensorError::new(&format!("config.json parse error: {e}")))?;
let id2label = parse_id2label(&v)?;
let num_labels = parse_num_labels(&v, id2label.as_deref());
let architectures = parse_architectures(&v);
Ok(XlmRobertaConfig {
vocab_size: required_i64(&v, "vocab_size")?,
hidden_size: required_i64(&v, "hidden_size")?,
num_hidden_layers: required_i64(&v, "num_hidden_layers")?,
num_attention_heads: required_i64(&v, "num_attention_heads")?,
intermediate_size: required_i64(&v, "intermediate_size")?,
max_position_embeddings: required_i64(&v, "max_position_embeddings")?,
type_vocab_size: optional_i64(&v, "type_vocab_size", 1),
pad_token_id: optional_i64(&v, "pad_token_id", 1),
layer_norm_eps: optional_f64(&v, "layer_norm_eps", 1e-5),
hidden_dropout_prob: optional_f64(&v, "hidden_dropout_prob", 0.1),
attention_probs_dropout_prob: optional_f64(&v, "attention_probs_dropout_prob", 0.1),
hidden_act: optional_hidden_act(&v, "hidden_act", "gelu")?,
num_labels,
id2label,
architectures,
})
}
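    /// Builder-style override for the `architectures` entry emitted by
    /// [`Self::to_json_str`], e.g. `"XLMRobertaForSequenceClassification"`
    /// for a fine-tuned head.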
pub fn with_architectures(mut self, arch_class: &str) -> Self {
self.architectures = Some(vec![arch_class.to_string()]);
self
}
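    /// Serializes the config back into a Hugging Face-compatible
    /// `config.json` string with `model_type` set to `"xlm-roberta"`.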
pub fn to_json_str(&self) -> String {
use crate::config_json::{emit_architectures, emit_hidden_act, emit_id2label};
let mut m = serde_json::Map::new();
m.insert("model_type".into(), "xlm-roberta".into());
m.insert(
"architectures".into(),
emit_architectures(self.architectures.as_deref(), "XLMRobertaModel"),
);
m.insert("vocab_size".into(), self.vocab_size.into());
m.insert("hidden_size".into(), self.hidden_size.into());
m.insert("num_hidden_layers".into(), self.num_hidden_layers.into());
m.insert("num_attention_heads".into(), self.num_attention_heads.into());
m.insert("intermediate_size".into(), self.intermediate_size.into());
m.insert(
"max_position_embeddings".into(),
self.max_position_embeddings.into(),
);
m.insert("type_vocab_size".into(), self.type_vocab_size.into());
m.insert("pad_token_id".into(), self.pad_token_id.into());
m.insert("layer_norm_eps".into(), self.layer_norm_eps.into());
m.insert("hidden_dropout_prob".into(), self.hidden_dropout_prob.into());
m.insert(
"attention_probs_dropout_prob".into(),
self.attention_probs_dropout_prob.into(),
);
m.insert("hidden_act".into(), emit_hidden_act(self.hidden_act).into());
emit_id2label(&mut m, self.id2label.as_deref());
if let Some(n) = self.num_labels {
m.insert("num_labels".into(), n.into());
}
serde_json::to_string_pretty(&serde_json::Value::Object(m))
.expect("serde_json::Map serialization is infallible")
}
}
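/// XLM-RoBERTa and RoBERTa share the same encoder, so the conversion is a
/// straight field-by-field copy.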
impl From<&XlmRobertaConfig> for RobertaConfig {
fn from(c: &XlmRobertaConfig) -> Self {
RobertaConfig {
vocab_size: c.vocab_size,
hidden_size: c.hidden_size,
num_hidden_layers: c.num_hidden_layers,
num_attention_heads: c.num_attention_heads,
intermediate_size: c.intermediate_size,
max_position_embeddings: c.max_position_embeddings,
type_vocab_size: c.type_vocab_size,
pad_token_id: c.pad_token_id,
layer_norm_eps: c.layer_norm_eps,
hidden_dropout_prob: c.hidden_dropout_prob,
attention_probs_dropout_prob: c.attention_probs_dropout_prob,
hidden_act: c.hidden_act,
num_labels: c.num_labels,
id2label: c.id2label.clone(),
architectures: c.architectures.clone(),
}
}
}
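/// Graph builder for the bare XLM-RoBERTa encoder. Construction is delegated
/// to [`RobertaModel`], so serialized parameter keys use the `roberta.*`
/// prefix, matching Hugging Face `xlm-roberta-*` checkpoints.
///
/// A minimal usage sketch (weights would be loaded separately):
///
/// ```ignore
/// let config = XlmRobertaConfig::xlm_roberta_base();
/// let graph = XlmRobertaModel::on_device(&config, Device::CPU)?;
/// ```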
pub struct XlmRobertaModel;
impl XlmRobertaModel {
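    /// Builds the encoder graph on [`Device::CPU`].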
pub fn build(config: &XlmRobertaConfig) -> Result<Graph> {
Self::on_device(config, Device::CPU)
}
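    /// Builds the encoder graph (including the pooler) on the given device.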
pub fn on_device(config: &XlmRobertaConfig, device: Device) -> Result<Graph> {
let rc: RobertaConfig = config.into();
RobertaModel::on_device(&rc, device)
}
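    /// Builds the encoder graph without the pooler layer.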
pub fn on_device_without_pooler(
config: &XlmRobertaConfig,
device: Device,
) -> Result<Graph> {
let rc: RobertaConfig = config.into();
RobertaModel::on_device_without_pooler(&rc, device)
}
}
use crate::task_heads::{
check_num_labels, ClassificationHead, EncoderInputs, MaskedLmHead, QaHead, TaggingHead,
};
pub use crate::task_heads::{Answer, TokenPrediction};
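/// Adapts tokenized batches to the encoder's three inputs: `input_ids`,
/// `token_type_ids`, and the padding mask expanded via
/// `build_extended_attention_mask` into the shape the attention layers expect.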
#[cfg(feature = "tokenizer")]
impl EncoderInputs for XlmRobertaConfig {
const FAMILY_NAME: &'static str = "XlmRoberta";
const MASK_TOKEN: &'static str = "<mask>";
fn encoder_inputs(enc: &crate::tokenizer::EncodedBatch) -> Result<Vec<Variable>> {
let mask_f32 = enc.attention_mask.data().to_dtype(DType::Float32)?;
let mask = Variable::new(
crate::models::bert::build_extended_attention_mask(&mask_f32)?,
false,
);
Ok(vec![
enc.input_ids.clone(),
enc.token_type_ids.clone(),
mask,
])
}
}
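/// Sequence classification on top of the pooler-free backbone, using
/// `RobertaClassificationHead` (`classifier.dense` + `classifier.out_proj`)
/// with a Hugging Face-compatible key layout.
///
/// A minimal sketch for a three-label task on CPU:
///
/// ```ignore
/// let config = XlmRobertaConfig::xlm_roberta_base();
/// let head = XlmRobertaForSequenceClassification::on_device(&config, 3, Device::CPU)?;
/// ```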
pub type XlmRobertaForSequenceClassification = ClassificationHead<XlmRobertaConfig>;
impl ClassificationHead<XlmRobertaConfig> {
pub fn on_device(
config: &XlmRobertaConfig,
num_labels: i64,
device: Device,
) -> Result<Self> {
let num_labels = check_num_labels(num_labels)?;
let rc: RobertaConfig = config.into();
let graph = roberta_backbone_flow(&rc, device, false)?
.through(RobertaClassificationHead::on_device(&rc, num_labels, device)?)
.tag("classifier")
.build()?;
Ok(Self::from_graph(graph, config, num_labels, config.id2label.clone()))
}
pub(crate) fn num_labels_from_config(config: &XlmRobertaConfig) -> Result<i64> {
config.num_labels.ok_or_else(|| {
TensorError::new(
"XlmRobertaForSequenceClassification: config.json has no `num_labels` \
(nor `id2label`); cannot infer head size",
)
})
}
}
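/// Token classification (tagging): dropout followed by a per-token
/// `Linear(hidden_size, num_labels)`, serialized under the `classifier.*`
/// prefix.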
pub type XlmRobertaForTokenClassification = TaggingHead<XlmRobertaConfig>;
impl TaggingHead<XlmRobertaConfig> {
pub fn on_device(
config: &XlmRobertaConfig,
num_labels: i64,
device: Device,
) -> Result<Self> {
let num_labels = check_num_labels(num_labels)?;
let rc: RobertaConfig = config.into();
let graph = roberta_backbone_flow(&rc, device, false)?
.through(Dropout::new(config.hidden_dropout_prob))
.through(Linear::on_device(config.hidden_size, num_labels, device)?)
.tag("classifier")
.build()?;
Ok(Self::from_graph(graph, config, num_labels, config.id2label.clone()))
}
pub(crate) fn num_labels_from_config(config: &XlmRobertaConfig) -> Result<i64> {
config.num_labels.ok_or_else(|| {
TensorError::new(
"XlmRobertaForTokenClassification: config.json has no `num_labels` \
(nor `id2label`); cannot infer head size",
)
})
}
}
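/// Extractive question answering: a single `Linear(hidden_size, 2)` producing
/// start/end logits, serialized under the `qa_outputs.*` prefix.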
pub type XlmRobertaForQuestionAnswering = QaHead<XlmRobertaConfig>;
impl QaHead<XlmRobertaConfig> {
pub fn on_device(config: &XlmRobertaConfig, device: Device) -> Result<Self> {
let rc: RobertaConfig = config.into();
let graph = roberta_backbone_flow(&rc, device, false)?
.through(Linear::on_device(config.hidden_size, 2, device)?)
.tag("qa_outputs")
.build()?;
Ok(Self::from_graph(graph, config))
}
}
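/// Masked language modeling. The LM head decoder is tied to the input
/// embeddings, so checkpoints carry a single `[vocab_size, hidden_size]`
/// matrix (see the dedup test below).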
pub type XlmRobertaForMaskedLM = MaskedLmHead<XlmRobertaConfig>;
impl MaskedLmHead<XlmRobertaConfig> {
pub fn on_device(config: &XlmRobertaConfig, device: Device) -> Result<Self> {
let rc: RobertaConfig = config.into();
let graph = roberta_masked_lm_graph(&rc, device)?;
Ok(Self::from_graph(graph, config))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::safetensors_io::expected_from_graph;
#[test]
fn xlm_roberta_config_to_json_str_round_trip() {
let preset = XlmRobertaConfig::xlm_roberta_base();
let s = preset.to_json_str();
let recovered = XlmRobertaConfig::from_json_str(&s).unwrap();
assert_eq!(preset.to_json_str(), recovered.to_json_str());
let v: serde_json::Value = serde_json::from_str(&s).unwrap();
assert_eq!(
v.get("model_type").and_then(|x| x.as_str()),
Some("xlm-roberta"),
);
}
#[test]
fn xlm_roberta_config_conversion_preserves_shape_fields() {
let c = XlmRobertaConfig::xlm_roberta_base();
let rc: RobertaConfig = (&c).into();
assert_eq!(rc.vocab_size, c.vocab_size);
assert_eq!(rc.hidden_size, c.hidden_size);
assert_eq!(rc.num_hidden_layers, c.num_hidden_layers);
assert_eq!(rc.num_attention_heads, c.num_attention_heads);
assert_eq!(rc.intermediate_size, c.intermediate_size);
assert_eq!(rc.max_position_embeddings, c.max_position_embeddings);
assert_eq!(rc.type_vocab_size, c.type_vocab_size);
assert_eq!(rc.pad_token_id, c.pad_token_id);
assert!((rc.layer_norm_eps - c.layer_norm_eps).abs() < 1e-12);
assert!((rc.hidden_dropout_prob - c.hidden_dropout_prob).abs() < 1e-12);
assert!(
(rc.attention_probs_dropout_prob - c.attention_probs_dropout_prob).abs() < 1e-12,
);
}
#[test]
fn xlm_roberta_base_preset_matches_hf_defaults() {
let c = XlmRobertaConfig::xlm_roberta_base();
assert_eq!(c.vocab_size, 250_002);
assert_eq!(c.hidden_size, 768);
assert_eq!(c.num_hidden_layers, 12);
assert_eq!(c.max_position_embeddings, 514);
assert_eq!(c.pad_token_id, 1);
}
#[test]
fn xlm_roberta_config_from_json_parses_base() {
let json = r#"{
"model_type": "xlm-roberta",
"vocab_size": 250002,
"hidden_size": 768,
"num_hidden_layers": 12,
"num_attention_heads": 12,
"intermediate_size": 3072,
"max_position_embeddings": 514,
"type_vocab_size": 1,
"pad_token_id": 1
}"#;
let c = XlmRobertaConfig::from_json_str(json).unwrap();
assert_eq!(c.vocab_size, 250_002);
assert_eq!(c.hidden_size, 768);
assert_eq!(c.pad_token_id, 1);
}
#[test]
fn xlm_roberta_backbone_emits_roberta_prefix() {
let config = XlmRobertaConfig::xlm_roberta_base();
let graph = XlmRobertaModel::on_device_without_pooler(&config, Device::CPU).unwrap();
let expected = expected_from_graph(&graph);
let keys: Vec<&str> = expected.iter().map(|p| p.key.as_str()).collect();
assert!(
keys.contains(&"roberta.embeddings.word_embeddings.weight"),
"expected roberta.embeddings.word_embeddings.weight, got {keys:?}",
);
assert!(
keys.iter().any(|k| k.starts_with("roberta.encoder.layer.0.attention.self.query.")),
"expected roberta.encoder.* layer keys, got {keys:?}",
);
assert!(
!keys.iter().any(|k| k.starts_with("xlm_roberta.")),
"no keys should use an xlm_roberta.* prefix (HF uses roberta.*): {keys:?}",
);
}
#[test]
fn xlm_roberta_masked_lm_keeps_tied_weight_dedup() {
let config = XlmRobertaConfig::xlm_roberta_base();
let head = XlmRobertaForMaskedLM::on_device(&config, Device::CPU).unwrap();
let expected = expected_from_graph(head.graph());
let keys: Vec<&str> = expected.iter().map(|p| p.key.as_str()).collect();
assert!(
keys.contains(&"roberta.embeddings.word_embeddings.weight"),
"tied weight must surface under roberta.embeddings tag: {keys:?}",
);
assert!(
!keys.contains(&"lm_head.decoder.weight"),
"lm_head.decoder.weight must be absent (tied, dedup kept embeddings entry)",
);
let named = head.graph().named_parameters();
let vocab_shaped = named
.iter()
.filter(|(_, p)| p.variable.shape() == vec![config.vocab_size, config.hidden_size])
.count();
assert_eq!(
vocab_shaped, 1,
"exactly one [V, H]-shaped Parameter expected under tying",
);
}
#[test]
    fn xlm_roberta_seqcls_head_emits_two_layer_classifier_keys() {
let config = XlmRobertaConfig::xlm_roberta_base();
        let head = XlmRobertaForSequenceClassification::on_device(&config, 3, Device::CPU)
            .unwrap();
let expected = expected_from_graph(head.graph());
let keys: Vec<String> = expected.iter().map(|p| p.key.clone()).collect();
assert!(keys.contains(&"classifier.dense.weight".to_string()));
assert!(keys.contains(&"classifier.dense.bias".to_string()));
assert!(keys.contains(&"classifier.out_proj.weight".to_string()));
assert!(keys.contains(&"classifier.out_proj.bias".to_string()));
assert!(!keys.iter().any(|k| k == "classifier.weight"));
}
}