use super::{DataCollator, DataCollatorConfig};
use crate::auto::types::{CollatedBatch, DataExample, PaddingStrategy};
use crate::error::{Result, TrustformersError};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
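/// Collates tokenized examples into rectangular, padded batches for masked
/// language modeling (MLM), optionally applying BERT-style token masking.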
#[derive(Debug, Clone)]
pub struct LanguageModelingDataCollator {
config: LanguageModelingCollatorConfig,
}
impl LanguageModelingDataCollator {
pub fn new(config: LanguageModelingCollatorConfig) -> Self {
Self { config }
}
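    /// Applies BERT-style MLM corruption to a single sequence. Each eligible
    /// token is selected with probability `mlm_probability`; of the selected
    /// tokens, 80% are replaced with the mask token, 10% with a random
    /// vocabulary token, and 10% are left unchanged. Returns the corrupted
    /// ids together with labels holding the original id at selected positions
    /// and `-100` (the ignore index) everywhere else.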
fn apply_mlm_masking(&self, input_ids: &[u32]) -> (Vec<u32>, Vec<i64>) {
let mut masked_ids = input_ids.to_vec();
let mut labels = vec![-100i64; input_ids.len()];
        if self.config.mlm_probability > 0.0 {
            for (i, &token_id) in input_ids.iter().enumerate() {
                // Skip special and reserved ids. This assumes a BERT-style
                // vocabulary where real tokens start around id 1000, matching
                // the random-replacement range below; [CLS]=101, [SEP]=102,
                // and [MASK]=103 must never be corrupted.
                if token_id < 1000 {
                    continue;
                }
                if rand::random_f32() < self.config.mlm_probability {
                    // Selected position: the label is the original token.
                    labels[i] = token_id as i64;
                    let random_value = rand::random_f32();
                    if random_value < 0.8 {
                        // 80% of selections: replace with [MASK].
                        masked_ids[i] = self.config.mask_token_id;
                    } else if random_value < 0.9 {
                        // 10%: replace with a random id from the assumed
                        // real-token range [1000, 31000).
                        masked_ids[i] = rand::random::<u32>() % 30000 + 1000;
                    }
                    // Remaining 10%: keep the original token.
                }
            }
        }
(masked_ids, labels)
}
}
impl DataCollator for LanguageModelingDataCollator {
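    /// Truncates (when enabled) and pads every example to a common length,
    /// applies MLM masking when `mlm_probability > 0`, and carries
    /// `token_type_ids` through whenever any example provides them.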
fn collate(&self, examples: &[DataExample]) -> Result<CollatedBatch> {
if examples.is_empty() {
return Err(TrustformersError::invalid_input_simple(
"Cannot collate empty batch for language modeling".to_string(),
));
}
let batch_size = examples.len();
        // Settle on one target length: CollatedBatch is rectangular, so even
        // the non-padding strategies must pick a single sequence length.
        let max_len = match self.config.padding {
            PaddingStrategy::Longest => examples
                .iter()
                .map(|ex| ex.input_ids.len())
                .max()
                .unwrap_or(0)
                .min(self.config.max_length.unwrap_or(usize::MAX)),
            PaddingStrategy::MaxLength => self.config.max_length.unwrap_or(512),
            // Batch maximum, uncapped by max_length.
            PaddingStrategy::DoNotPad => {
                examples.iter().map(|ex| ex.input_ids.len()).max().unwrap_or(0)
            },
            // Assumes a length-homogeneous batch.
            PaddingStrategy::None => examples[0].input_ids.len(),
        };
let mut input_ids = Vec::with_capacity(batch_size);
let mut attention_mask = Vec::with_capacity(batch_size);
let mut token_type_ids = Vec::new();
let mut labels = Vec::new();
let has_token_type_ids = examples.iter().any(|ex| ex.token_type_ids.is_some());
if has_token_type_ids {
token_type_ids = vec![vec![0u32; max_len]; batch_size];
}
for (batch_idx, example) in examples.iter().enumerate() {
let mut sequence_input_ids = example.input_ids.clone();
let mut sequence_attention_mask = example
.attention_mask
.clone()
.unwrap_or_else(|| vec![1u32; example.input_ids.len()]);
if self.config.truncation && sequence_input_ids.len() > max_len {
sequence_input_ids.truncate(max_len);
sequence_attention_mask.truncate(max_len);
}
let (masked_input_ids, mlm_labels) = if self.config.mlm_probability > 0.0 {
self.apply_mlm_masking(&sequence_input_ids)
} else {
(
sequence_input_ids.clone(),
example
.labels
.clone()
.unwrap_or_else(|| vec![-100i64; sequence_input_ids.len()]),
)
};
let mut padded_input_ids = masked_input_ids;
let mut padded_attention_mask = sequence_attention_mask;
let mut padded_labels = mlm_labels;
            while padded_input_ids.len() < max_len {
                padded_input_ids.push(self.config.pad_token_id);
                padded_attention_mask.push(0);
                // Padded positions carry the ignore label so they never
                // contribute to the loss.
                padded_labels.push(-100);
            }
input_ids.push(padded_input_ids);
attention_mask.push(padded_attention_mask);
labels.push(padded_labels);
if let Some(token_types) = &example.token_type_ids {
let mut padded_token_types = token_types.clone();
if self.config.truncation && padded_token_types.len() > max_len {
padded_token_types.truncate(max_len);
}
while padded_token_types.len() < max_len {
padded_token_types.push(0);
}
token_type_ids[batch_idx] = padded_token_types;
}
}
Ok(CollatedBatch {
input_ids,
attention_mask,
token_type_ids: if has_token_type_ids { Some(token_type_ids) } else { None },
labels: Some(labels),
batch_size,
sequence_length: max_len,
metadata: HashMap::new(),
})
}
fn config(&self) -> &dyn DataCollatorConfig {
&self.config
}
}
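/// Configuration for [`LanguageModelingDataCollator`]. Setting
/// `mlm_probability` to `0.0` disables masking entirely.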
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LanguageModelingCollatorConfig {
pub max_length: Option<usize>,
pub padding: PaddingStrategy,
pub truncation: bool,
pub pad_token_id: u32,
pub mask_token_id: u32,
pub mlm_probability: f32,
}
impl LanguageModelingCollatorConfig {
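    /// Builds a training config from a model's JSON config (a Hugging
    /// Face-style `config.json`), reading `max_position_embeddings`,
    /// `pad_token_id`, and `mask_token_id` with BERT-style defaults and a
    /// masking probability of 0.15.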
pub fn from_config(config: &serde_json::Value) -> Result<Self> {
Ok(Self {
max_length: config
.get("max_position_embeddings")
.and_then(|v| v.as_u64())
.map(|v| v as usize),
padding: PaddingStrategy::Longest,
truncation: true,
            pad_token_id: config.get("pad_token_id").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
            mask_token_id: config.get("mask_token_id").and_then(|v| v.as_u64()).unwrap_or(103)
                as u32,
            // BERT's canonical 15% masking rate.
            mlm_probability: 0.15,
        })
}
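    /// Like [`Self::from_config`], but with masking disabled for inference.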
pub fn for_inference(config: &serde_json::Value) -> Result<Self> {
let mut mlm_config = Self::from_config(config)?;
mlm_config.mlm_probability = 0.0;
Ok(mlm_config)
}
}
impl DataCollatorConfig for LanguageModelingCollatorConfig {
fn max_length(&self) -> Option<usize> {
self.max_length
}
fn padding(&self) -> PaddingStrategy {
self.padding
}
fn truncation(&self) -> bool {
self.truncation
}
}
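/// Collates tokenized examples into padded batches for causal (left-to-right)
/// language modeling, deriving next-token labels from the inputs.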
#[derive(Debug, Clone)]
pub struct CausalLanguageModelingDataCollator {
config: CausalLanguageModelingCollatorConfig,
}
impl CausalLanguageModelingDataCollator {
pub fn new(config: CausalLanguageModelingCollatorConfig) -> Self {
Self { config }
}
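    /// Builds next-token labels: position `i` is labelled with
    /// `input_ids[i + 1]` when that next position is real (attention mask 1)
    /// and `-100` otherwise. For inputs `[a, b, c, <pad>]` with mask
    /// `[1, 1, 1, 0]`, the labels are `[b, c, -100, -100]`.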
fn create_causal_labels(&self, input_ids: &[u32], attention_mask: &[u32]) -> Vec<i64> {
let mut labels = Vec::with_capacity(input_ids.len());
for i in 0..input_ids.len() {
if i < input_ids.len() - 1 && attention_mask[i + 1] == 1 {
labels.push(input_ids[i + 1] as i64);
} else {
labels.push(-100);
}
}
labels
}
}
impl DataCollator for CausalLanguageModelingDataCollator {
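    /// Truncates (when enabled) and pads every example to a common length,
    /// then derives shifted next-token labels from the padded sequences.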
fn collate(&self, examples: &[DataExample]) -> Result<CollatedBatch> {
if examples.is_empty() {
return Err(TrustformersError::invalid_input_simple(
"Cannot collate empty batch for causal language modeling".to_string(),
));
}
let batch_size = examples.len();
let max_len = match self.config.padding {
PaddingStrategy::Longest => examples
.iter()
.map(|ex| ex.input_ids.len())
.max()
.unwrap_or(0)
.min(self.config.max_length.unwrap_or(usize::MAX)),
PaddingStrategy::MaxLength => self.config.max_length.unwrap_or(1024),
PaddingStrategy::DoNotPad => {
examples.iter().map(|ex| ex.input_ids.len()).max().unwrap_or(0)
},
PaddingStrategy::None => examples[0].input_ids.len(),
};
let mut input_ids = Vec::with_capacity(batch_size);
let mut attention_mask = Vec::with_capacity(batch_size);
let mut labels = Vec::with_capacity(batch_size);
for example in examples {
let mut sequence_input_ids = example.input_ids.clone();
let mut sequence_attention_mask = example
.attention_mask
.clone()
.unwrap_or_else(|| vec![1u32; example.input_ids.len()]);
if self.config.truncation && sequence_input_ids.len() > max_len {
sequence_input_ids.truncate(max_len);
sequence_attention_mask.truncate(max_len);
}
while sequence_input_ids.len() < max_len {
sequence_input_ids.push(self.config.pad_token_id);
sequence_attention_mask.push(0);
}
let sequence_labels =
self.create_causal_labels(&sequence_input_ids, &sequence_attention_mask);
input_ids.push(sequence_input_ids);
attention_mask.push(sequence_attention_mask);
labels.push(sequence_labels);
}
Ok(CollatedBatch {
input_ids,
attention_mask,
            token_type_ids: None,
            labels: Some(labels),
batch_size,
sequence_length: max_len,
metadata: HashMap::new(),
})
}
fn config(&self) -> &dyn DataCollatorConfig {
&self.config
}
}
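/// Configuration for [`CausalLanguageModelingDataCollator`].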
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalLanguageModelingCollatorConfig {
pub max_length: Option<usize>,
pub padding: PaddingStrategy,
pub truncation: bool,
pub pad_token_id: u32,
}
impl CausalLanguageModelingCollatorConfig {
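    /// Builds a config from a model's JSON config, accepting either
    /// `max_position_embeddings` or the GPT-2-style `n_positions` key and
    /// falling back from `pad_token_id` to `eos_token_id`.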
pub fn from_config(config: &serde_json::Value) -> Result<Self> {
Ok(Self {
max_length: config
.get("max_position_embeddings")
.or_else(|| config.get("n_positions"))
.and_then(|v| v.as_u64())
.map(|v| v as usize),
padding: PaddingStrategy::Longest,
truncation: true,
            pad_token_id: config
                .get("pad_token_id")
                // GPT-style configs often define no pad token; fall back to EOS.
                .or_else(|| config.get("eos_token_id"))
                .and_then(|v| v.as_u64())
                .unwrap_or(0) as u32,
})
}
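    /// Like [`Self::from_config`], but with truncation disabled so prompts
    /// reach generation at full length.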
pub fn for_generation(config: &serde_json::Value) -> Result<Self> {
let mut gen_config = Self::from_config(config)?;
gen_config.truncation = false;
Ok(gen_config)
}
}
impl DataCollatorConfig for CausalLanguageModelingCollatorConfig {
fn max_length(&self) -> Option<usize> {
self.max_length
}
fn padding(&self) -> PaddingStrategy {
self.padding
}
fn truncation(&self) -> bool {
self.truncation
}
}
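/// Minimal deterministic pseudo-random source used for masking decisions.
/// This private module shadows the external `rand` crate's name, keeping the
/// collator dependency-free; substitute a real RNG where statistical quality
/// matters.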
mod rand {
    use std::sync::atomic::{AtomicU64, Ordering};

    // Process-wide LCG state. An atomic (rather than `static mut`) keeps
    // concurrent collation free of data races.
    static SEED: AtomicU64 = AtomicU64::new(1);

    // Advance the LCG (glibc-style constants) and return 32 bits of output.
    fn next_u32() -> u32 {
        let prev = SEED
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |s| {
                Some(s.wrapping_mul(1103515245).wrapping_add(12345))
            })
            .expect("closure always returns Some");
        (prev.wrapping_mul(1103515245).wrapping_add(12345) >> 16) as u32
    }

    pub fn random<T>() -> T
    where
        T: From<u32>,
    {
        T::from(next_u32())
    }

    pub fn random_f32() -> f32 {
        // Roughly uniform value in [0.0, 1.0].
        next_u32() as f32 / (u32::MAX as f32)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::auto::types::{DataExample, PaddingStrategy};
use std::collections::HashMap;
const CLS_TOKEN: u32 = 101;
const SEP_TOKEN: u32 = 102;
const PAD_TOKEN: u32 = 0;
const MASK_TOKEN: u32 = 103;
const GPT_PAD_TOKEN: u32 = 50256;
fn make_mlm_config(mlm_probability: f32) -> LanguageModelingCollatorConfig {
LanguageModelingCollatorConfig {
max_length: Some(512),
padding: PaddingStrategy::Longest,
truncation: true,
pad_token_id: PAD_TOKEN,
mask_token_id: MASK_TOKEN,
mlm_probability,
}
}
fn make_clm_config() -> CausalLanguageModelingCollatorConfig {
CausalLanguageModelingCollatorConfig {
max_length: Some(1024),
padding: PaddingStrategy::Longest,
truncation: true,
pad_token_id: GPT_PAD_TOKEN,
}
}
fn make_bert_example(tokens: Vec<u32>) -> DataExample {
let len = tokens.len();
DataExample {
input_ids: tokens,
attention_mask: Some(vec![1u32; len]),
token_type_ids: None,
labels: None,
metadata: HashMap::new(),
}
}
fn make_gpt_example(tokens: Vec<u32>) -> DataExample {
let len = tokens.len();
DataExample {
input_ids: tokens,
attention_mask: Some(vec![1u32; len]),
token_type_ids: None,
labels: None,
metadata: HashMap::new(),
}
}
#[test]
fn test_mlm_collator_empty_batch_returns_error() {
let config = make_mlm_config(0.0);
let collator = LanguageModelingDataCollator::new(config);
let result = collator.collate(&[]);
assert!(result.is_err(), "collating empty batch should return Err");
}
#[test]
fn test_mlm_collator_batch_size_correct() {
let config = make_mlm_config(0.0);
let collator = LanguageModelingDataCollator::new(config);
let examples = vec![
make_bert_example(vec![CLS_TOKEN, 1000, 2000, SEP_TOKEN]),
make_bert_example(vec![CLS_TOKEN, 3000, SEP_TOKEN]),
];
let batch = collator.collate(&examples).expect("collate should succeed");
assert_eq!(
batch.batch_size, 2,
"batch_size should equal number of examples"
);
assert_eq!(
batch.input_ids.len(),
2,
"input_ids should have batch_size rows"
);
}
#[test]
fn test_mlm_collator_padding_to_longest() {
let config = make_mlm_config(0.0);
let collator = LanguageModelingDataCollator::new(config);
        let short = make_bert_example(vec![CLS_TOKEN, 1000, SEP_TOKEN]);
        let long = make_bert_example(vec![CLS_TOKEN, 1000, 2000, 3000, SEP_TOKEN]);
        let batch = collator.collate(&[short, long]).expect("collate should succeed");
assert_eq!(
batch.sequence_length, 5,
"sequence_length should be the longest in batch"
);
assert_eq!(
batch.input_ids[0].len(),
5,
"short sequence should be padded to 5"
);
assert_eq!(
batch.input_ids[1].len(),
5,
"long sequence should remain at 5"
);
}
#[test]
fn test_mlm_collator_padding_adds_pad_token() {
let config = make_mlm_config(0.0);
let collator = LanguageModelingDataCollator::new(config);
        let short = make_bert_example(vec![CLS_TOKEN, 1000, SEP_TOKEN]);
        let long = make_bert_example(vec![CLS_TOKEN, 1000, 2000, 3000, SEP_TOKEN]);
        let batch = collator.collate(&[short, long]).expect("collate should succeed");
assert_eq!(
batch.input_ids[0][3], PAD_TOKEN,
"position 3 of short sequence should be PAD_TOKEN"
);
assert_eq!(
batch.input_ids[0][4], PAD_TOKEN,
"position 4 of short sequence should be PAD_TOKEN"
);
}
#[test]
fn test_mlm_collator_attention_mask_zeros_for_padding() {
let config = make_mlm_config(0.0);
let collator = LanguageModelingDataCollator::new(config);
        let short = make_bert_example(vec![CLS_TOKEN, 1000, SEP_TOKEN]);
        let long = make_bert_example(vec![CLS_TOKEN, 1000, 2000, 3000, SEP_TOKEN]);
        let batch = collator.collate(&[short, long]).expect("collate should succeed");
assert_eq!(
batch.attention_mask[0][3], 0,
"padded position attention_mask should be 0"
);
assert_eq!(
batch.attention_mask[0][4], 0,
"padded position attention_mask should be 0"
);
assert_eq!(
batch.attention_mask[0][0], 1,
"CLS position should have attention_mask 1"
);
}
#[test]
fn test_mlm_collator_labels_present_when_no_masking() {
        let config = make_mlm_config(0.0);
        let collator = LanguageModelingDataCollator::new(config);
let examples = vec![make_bert_example(vec![CLS_TOKEN, 1000, 2000, SEP_TOKEN])];
let batch = collator.collate(&examples).expect("collate should succeed");
assert!(
batch.labels.is_some(),
"labels should be present in the collated batch"
);
}
#[test]
fn test_mlm_collator_truncation() {
let config = LanguageModelingCollatorConfig {
max_length: Some(4),
padding: PaddingStrategy::MaxLength,
truncation: true,
pad_token_id: PAD_TOKEN,
mask_token_id: MASK_TOKEN,
mlm_probability: 0.0,
};
let collator = LanguageModelingDataCollator::new(config);
        let long = make_bert_example(vec![CLS_TOKEN, 1000, 2000, 3000, 4000, SEP_TOKEN]);
        let batch = collator.collate(&[long]).expect("collate should succeed");
assert_eq!(
batch.sequence_length, 4,
"sequence should be truncated to max_length=4"
);
assert_eq!(
batch.input_ids[0].len(),
4,
"truncated input_ids should have length 4"
);
}
#[test]
fn test_mlm_collator_output_shape_consistency() {
let config = make_mlm_config(0.0);
let collator = LanguageModelingDataCollator::new(config);
let examples = vec![
make_bert_example(vec![CLS_TOKEN, 1000, 2000, SEP_TOKEN]),
make_bert_example(vec![CLS_TOKEN, 3000, 4000, 5000, 6000, SEP_TOKEN]),
];
let batch = collator.collate(&examples).expect("collate should succeed");
let seq_len = batch.sequence_length;
for (i, ids) in batch.input_ids.iter().enumerate() {
assert_eq!(
ids.len(),
seq_len,
"row {} of input_ids should have length {}",
i,
seq_len
);
}
for (i, mask) in batch.attention_mask.iter().enumerate() {
assert_eq!(
mask.len(),
seq_len,
"row {} of attention_mask should have length {}",
i,
seq_len
);
}
}
#[test]
fn test_mlm_collator_special_tokens_never_masked() {
        let config = make_mlm_config(0.99);
        let collator = LanguageModelingDataCollator::new(config);
        for _ in 0..5 {
            let examples = vec![make_bert_example(vec![CLS_TOKEN, 1000, 2000, SEP_TOKEN])];
            let batch = collator.collate(&examples).expect("collate should succeed");
            let labels = batch.labels.as_ref().expect("labels should be present");
            assert_eq!(
                labels[0][0], -100,
                "CLS position should never be selected for masking"
            );
            assert_eq!(
                labels[0][3], -100,
                "SEP position should never be selected for masking"
            );
        }
}
#[test]
fn test_clm_collator_empty_batch_returns_error() {
let config = make_clm_config();
let collator = CausalLanguageModelingDataCollator::new(config);
let result = collator.collate(&[]);
assert!(result.is_err(), "collating empty batch should return Err");
}
#[test]
fn test_clm_collator_labels_shifted_left() {
let config = make_clm_config();
let collator = CausalLanguageModelingDataCollator::new(config);
let tokens = vec![100u32, 200, 300, 400, 500];
let example = make_gpt_example(tokens.clone());
let batch = collator.collate(&[example]).expect("collate should succeed");
let labels = batch.labels.as_ref().expect("CLM should produce labels");
assert_eq!(
labels[0][0], 200,
"first label should be second input token (shifted left)"
);
assert_eq!(
labels[0][1], 300,
"second label should be third input token"
);
assert_eq!(
labels[0][4], -100,
"last label should be -100 (no next token)"
);
}
#[test]
fn test_clm_collator_batch_output_shape() {
let config = make_clm_config();
let collator = CausalLanguageModelingDataCollator::new(config);
let examples = vec![
make_gpt_example(vec![100u32, 200, 300]),
make_gpt_example(vec![400u32, 500, 600, 700]),
];
let batch = collator.collate(&examples).expect("collate should succeed");
assert_eq!(batch.batch_size, 2);
assert_eq!(
batch.sequence_length, 4,
"sequence_length should be the longest"
);
assert_eq!(batch.input_ids.len(), 2);
}
#[test]
fn test_clm_collator_padding_adds_pad_token() {
let config = make_clm_config();
let collator = CausalLanguageModelingDataCollator::new(config);
let short = make_gpt_example(vec![100u32, 200, 300]);
let long = make_gpt_example(vec![400u32, 500, 600, 700, 800]);
let batch = collator.collate(&[short, long]).expect("collate should succeed");
assert_eq!(
batch.input_ids[0][3], GPT_PAD_TOKEN,
"padded positions should use GPT_PAD_TOKEN"
);
assert_eq!(batch.input_ids[0][4], GPT_PAD_TOKEN);
}
#[test]
fn test_clm_collator_no_token_type_ids() {
let config = make_clm_config();
let collator = CausalLanguageModelingDataCollator::new(config);
let examples = vec![make_gpt_example(vec![100u32, 200, 300])];
let batch = collator.collate(&examples).expect("collate should succeed");
assert!(
batch.token_type_ids.is_none(),
"CLM collator should not produce token_type_ids"
);
}
#[test]
fn test_mlm_config_from_config_json() {
let model_config = serde_json::json!({
"max_position_embeddings": 512,
"pad_token_id": 0,
"mask_token_id": 103
});
let config = LanguageModelingCollatorConfig::from_config(&model_config)
.expect("from_config should succeed");
assert_eq!(config.max_length, Some(512));
assert_eq!(config.pad_token_id, 0);
assert!(
(config.mlm_probability - 0.15).abs() < 1e-3,
"default mlm_probability should be 0.15, got {}",
config.mlm_probability
);
}
#[test]
fn test_mlm_config_for_inference_disables_masking() {
let model_config = serde_json::json!({"max_position_embeddings": 512, "pad_token_id": 0});
let config = LanguageModelingCollatorConfig::for_inference(&model_config)
.expect("for_inference should succeed");
assert_eq!(
config.mlm_probability, 0.0,
"inference config should have mlm_probability == 0.0"
);
}
#[test]
fn test_clm_config_from_config_json() {
let model_config = serde_json::json!({
"max_position_embeddings": 1024,
"pad_token_id": 50256
});
let config = CausalLanguageModelingCollatorConfig::from_config(&model_config)
.expect("from_config should succeed");
assert_eq!(config.max_length, Some(1024));
assert_eq!(config.pad_token_id, 50256);
assert!(config.truncation, "default truncation should be enabled");
}
#[test]
fn test_clm_config_for_generation_disables_truncation() {
let model_config = serde_json::json!({"max_position_embeddings": 1024, "pad_token_id": 0});
let config = CausalLanguageModelingCollatorConfig::for_generation(&model_config)
.expect("for_generation should succeed");
assert!(
!config.truncation,
"generation config should have truncation == false"
);
}
}