use super::{DataCollator, DataCollatorConfig};
use crate::auto::data_collators::language_modeling::{
LanguageModelingCollatorConfig, LanguageModelingDataCollator,
};
use crate::auto::types::{CollatedBatch, DataExample, PaddingStrategy};
use crate::error::{Result, TrustformersError};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// General-purpose collator.
///
/// Padding and truncation are delegated to the language-modeling collator (configured
/// with `mlm_probability` set to 0.0). Each example's labels are passed through as-is.
#[derive(Clone, Debug)]
pub struct DefaultDataCollator {
    /// Length, padding and truncation settings applied to every batch.
    config: DefaultCollatorConfig,
}
impl DefaultDataCollator {
    /// Placeholder label list entry for an example that has no labels when other
    /// examples in the same batch do.
    ///
    /// NOTE(review): -100 matches PyTorch's default `ignore_index` for cross-entropy;
    /// presumably that is the intent — confirm against the training loss.
    const IGNORE_LABEL_ID: i64 = -100;

    /// Creates a collator that applies `config` to every batch it builds.
    pub fn new(config: DefaultCollatorConfig) -> Self {
        Self { config }
    }

    /// Gathers per-example labels for the batch.
    ///
    /// Returns `None` when no example carries labels. Otherwise there is exactly one entry
    /// per example, in order:
    /// - the example's own labels, copied unchanged (no padding or truncation), or
    /// - `[IGNORE_LABEL_ID]` if that example has no labels.
    fn process_generic_labels(&self, examples: &[DataExample]) -> Option<Vec<Vec<i64>>> {
        if !examples.iter().any(|ex| ex.labels.is_some()) {
            return None;
        }
        Some(
            examples
                .iter()
                .map(|ex| {
                    ex.labels
                        .clone()
                        .unwrap_or_else(|| vec![Self::IGNORE_LABEL_ID])
                })
                .collect(),
        )
    }

    /// Builds the collator-level metadata merged into every batch.
    ///
    /// The map contains these keys:
    /// - `collator_type`: always `"default"`.
    /// - `has_token_type_ids` / `has_labels`: whether any example has those fields.
    /// - `original_sequence_stats`: min/max/avg of the pre-collation `input_ids` lengths
    ///   plus the example count. Omitted when `examples` is empty.
    /// - `padding_strategy`: the `Debug` form of the configured strategy.
    /// - `truncation_enabled`.
    ///
    /// `_batch_size` and `_sequence_length` are accepted so existing call sites keep
    /// compiling, but they are not currently recorded.
    fn create_basic_metadata(
        &self,
        examples: &[DataExample],
        _batch_size: usize,
        _sequence_length: usize,
    ) -> HashMap<String, serde_json::Value> {
        let mut metadata = HashMap::new();
        metadata.insert(
            "collator_type".to_string(),
            serde_json::Value::String("default".to_string()),
        );
        metadata.insert(
            "has_token_type_ids".to_string(),
            serde_json::Value::Bool(examples.iter().any(|ex| ex.token_type_ids.is_some())),
        );
        metadata.insert(
            "has_labels".to_string(),
            serde_json::Value::Bool(examples.iter().any(|ex| ex.labels.is_some())),
        );

        let input_lengths: Vec<usize> = examples.iter().map(|ex| ex.input_ids.len()).collect();
        // min/max are `None` exactly when there are no examples; skip the stats then
        // instead of relying on a separately-checked `expect`.
        if let (Some(&min_length), Some(&max_length)) =
            (input_lengths.iter().min(), input_lengths.iter().max())
        {
            let avg_length =
                input_lengths.iter().sum::<usize>() as f64 / input_lengths.len() as f64;
            metadata.insert(
                "original_sequence_stats".to_string(),
                serde_json::json!({
                    "min_length": min_length,
                    "max_length": max_length,
                    "avg_length": avg_length,
                    "total_sequences": input_lengths.len()
                }),
            );
        }

        metadata.insert(
            "padding_strategy".to_string(),
            serde_json::Value::String(format!("{:?}", self.config.padding)),
        );
        metadata.insert(
            "truncation_enabled".to_string(),
            serde_json::Value::Bool(self.config.truncation),
        );
        metadata
    }

    /// Checks that a batch is well-formed before collation.
    ///
    /// # Errors
    /// Returns an invalid-input error when any of these hold:
    /// - `examples` is empty;
    /// - an example has empty `input_ids`;
    /// - an example's `attention_mask` or `token_type_ids` (when present) differs in
    ///   length from its `input_ids`.
    ///
    /// Labels are not validated here.
    fn validate_examples(&self, examples: &[DataExample]) -> Result<()> {
        if examples.is_empty() {
            return Err(TrustformersError::invalid_input(
                "Cannot collate empty batch".to_string(),
                Some("examples".to_string()),
                Some("non-empty batch".to_string()),
                Some("empty batch".to_string()),
            ));
        }
        for (i, example) in examples.iter().enumerate() {
            let expected_len = example.input_ids.len();
            if expected_len == 0 {
                return Err(TrustformersError::invalid_input(
                    format!("Empty input_ids in example {}", i),
                    Some("input_ids"),
                    Some("non-empty input_ids"),
                    Some("empty input_ids"),
                ));
            }
            if let Some(ref attention_mask) = example.attention_mask {
                if attention_mask.len() != expected_len {
                    return Err(TrustformersError::invalid_input(
                        format!(
                            "Attention mask length {} doesn't match input_ids length {} in example {}",
                            attention_mask.len(),
                            expected_len,
                            i
                        ),
                        Some("attention_mask"),
                        Some(format!("length {}", expected_len)),
                        Some(format!("length {}", attention_mask.len())),
                    ));
                }
            }
            if let Some(ref token_type_ids) = example.token_type_ids {
                if token_type_ids.len() != expected_len {
                    return Err(TrustformersError::invalid_input(
                        format!(
                            "Token type IDs length {} doesn't match input_ids length {} in example {}",
                            token_type_ids.len(),
                            expected_len,
                            i
                        ),
                        Some("token_type_ids"),
                        Some(format!("length {}", expected_len)),
                        Some(format!("length {}", token_type_ids.len())),
                    ));
                }
            }
        }
        Ok(())
    }
}
impl DataCollator for DefaultDataCollator {
fn collate(&self, examples: &[DataExample]) -> Result<CollatedBatch> {
self.validate_examples(examples)?;
let batch_size = examples.len();
let base_collator = LanguageModelingDataCollator::new(LanguageModelingCollatorConfig {
max_length: self.config.max_length,
padding: self.config.padding,
truncation: self.config.truncation,
pad_token_id: self.config.pad_token_id,
mask_token_id: 0, mlm_probability: 0.0,
});
let mut batch = base_collator.collate(examples)?;
let processed_labels = self.process_generic_labels(examples);
batch.labels = processed_labels;
let basic_metadata =
self.create_basic_metadata(examples, batch_size, batch.sequence_length);
for (key, value) in basic_metadata {
batch.metadata.insert(key, value);
}
let mut example_metadata = HashMap::new();
for (i, example) in examples.iter().enumerate() {
if !example.metadata.is_empty() {
example_metadata.insert(
format!("example_{}_metadata", i),
serde_json::to_value(&example.metadata).unwrap_or(serde_json::Value::Null),
);
}
}
if !example_metadata.is_empty() {
batch.metadata.insert(
"original_metadata".to_string(),
serde_json::to_value(example_metadata).unwrap_or(serde_json::Value::Null),
);
}
Ok(batch)
}
fn config(&self) -> &dyn DataCollatorConfig {
&self.config
}
fn preprocess_examples(&self, examples: &[DataExample]) -> Result<Vec<DataExample>> {
self.validate_examples(examples)?;
Ok(examples.to_vec())
}
}
/// Settings for `DefaultDataCollator`.
///
/// Field order is intentionally unchanged: serde's derived `Serialize` emits fields in
/// declaration order.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct DefaultCollatorConfig {
    /// Length limit forwarded to the underlying collator; `None` sets no limit here.
    pub max_length: Option<usize>,
    /// Padding strategy forwarded to the underlying collator.
    pub padding: PaddingStrategy,
    /// Forwarded as the underlying collator's `truncation` flag.
    pub truncation: bool,
    /// Token id forwarded to the underlying collator for padding.
    pub pad_token_id: u32,
}
impl DefaultCollatorConfig {
pub fn from_config(config: &serde_json::Value) -> Result<Self> {
Ok(Self {
max_length: config
.get("max_position_embeddings")
.or_else(|| config.get("max_length"))
.or_else(|| config.get("model_max_length"))
.and_then(|v| v.as_u64())
.map(|v| v as usize),
padding: PaddingStrategy::Longest,
truncation: true,
pad_token_id: config.get("pad_token_id").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
})
}
pub fn minimal(pad_token_id: u32) -> Self {
Self {
max_length: None,
padding: PaddingStrategy::Longest,
truncation: false,
pad_token_id,
}
}
pub fn for_inference(config: &serde_json::Value) -> Result<Self> {
let mut inference_config = Self::from_config(config)?;
inference_config.truncation = false;
Ok(inference_config)
}
pub fn for_development(config: &serde_json::Value) -> Result<Self> {
let mut dev_config = Self::from_config(config)?;
if let Some(max_len) = dev_config.max_length {
dev_config.max_length = Some(max_len.min(128));
} else {
dev_config.max_length = Some(128);
}
Ok(dev_config)
}
pub fn with_max_length(config: &serde_json::Value, max_length: usize) -> Result<Self> {
let mut length_config = Self::from_config(config)?;
length_config.max_length = Some(max_length);
length_config.truncation = true;
Ok(length_config)
}
}
impl DataCollatorConfig for DefaultCollatorConfig {
    /// The configured length limit, if any.
    fn max_length(&self) -> Option<usize> {
        let Self { max_length, .. } = self;
        *max_length
    }

    /// The configured padding strategy (returned by value).
    fn padding(&self) -> PaddingStrategy {
        let Self { padding, .. } = self;
        *padding
    }

    /// Whether truncation is enabled in this config.
    fn truncation(&self) -> bool {
        let Self { truncation, .. } = self;
        *truncation
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// `Longest` padding with truncation on and pad id 0, capped at `max_length`.
    fn longest_padding_config(max_length: usize) -> DefaultCollatorConfig {
        DefaultCollatorConfig {
            max_length: Some(max_length),
            padding: PaddingStrategy::Longest,
            truncation: true,
            pad_token_id: 0,
        }
    }

    /// The 10-token collator shared by the collate/metadata/validation tests.
    fn ten_token_collator() -> DefaultDataCollator {
        DefaultDataCollator::new(longest_padding_config(10))
    }

    #[test]
    fn test_default_collator_creation() {
        let collator = DefaultDataCollator::new(longest_padding_config(128));
        let cfg = collator.config();
        assert_eq!(cfg.max_length(), Some(128));
        assert_eq!(cfg.padding(), PaddingStrategy::Longest);
        assert!(cfg.truncation());
    }

    #[test]
    fn test_default_config_from_json() {
        let model_config = serde_json::json!({
            "max_position_embeddings": 512,
            "pad_token_id": 1,
            "vocab_size": 30522
        });
        let parsed =
            DefaultCollatorConfig::from_config(&model_config).expect("operation failed in test");
        assert_eq!(parsed.max_length, Some(512));
        assert_eq!(parsed.pad_token_id, 1);
        assert!(parsed.truncation);
        assert_eq!(parsed.padding, PaddingStrategy::Longest);
    }

    #[test]
    fn test_collate_basic_examples() {
        let longer = DataExample {
            input_ids: vec![1, 2, 3, 4],
            attention_mask: Some(vec![1, 1, 1, 1]),
            token_type_ids: None,
            labels: Some(vec![1, 0, 1]),
            metadata: HashMap::new(),
        };
        let shorter = DataExample {
            input_ids: vec![5, 6],
            attention_mask: Some(vec![1, 1]),
            token_type_ids: None,
            labels: Some(vec![0]),
            metadata: HashMap::new(),
        };
        let batch = ten_token_collator()
            .collate(&[longer, shorter])
            .expect("operation failed in test");

        assert_eq!(batch.batch_size, 2);
        assert_eq!(batch.input_ids.len(), 2);
        assert!(batch.labels.is_some());
        let labels = batch.labels.as_ref().expect("operation failed in test");
        assert_eq!(labels.len(), 2);
        assert_eq!(labels[0], vec![1, 0, 1]);
        assert_eq!(labels[1], vec![0]);
    }

    #[test]
    fn test_examples_without_labels() {
        let unlabeled = DataExample {
            input_ids: vec![1, 2, 3],
            attention_mask: Some(vec![1, 1, 1]),
            token_type_ids: None,
            labels: None,
            metadata: HashMap::new(),
        };
        let batch = ten_token_collator()
            .collate(&[unlabeled])
            .expect("operation failed in test");
        assert_eq!(batch.batch_size, 1);
        assert!(batch.labels.is_none());
    }

    #[test]
    fn test_metadata_creation() {
        let tagged = DataExample {
            input_ids: vec![1, 2, 3, 4, 5],
            attention_mask: Some(vec![1, 1, 1, 1, 1]),
            token_type_ids: Some(vec![0, 0, 1, 1, 1]),
            labels: Some(vec![1]),
            metadata: std::iter::once((
                "source".to_string(),
                serde_json::Value::String("test".to_string()),
            ))
            .collect(),
        };
        let metadata = ten_token_collator().create_basic_metadata(&[tagged], 1, 5);

        assert_eq!(
            metadata.get("collator_type").expect("expected value not found"),
            "default"
        );
        assert_eq!(
            metadata.get("has_token_type_ids").expect("expected value not found"),
            &serde_json::Value::Bool(true)
        );
        assert_eq!(
            metadata.get("has_labels").expect("expected value not found"),
            &serde_json::Value::Bool(true)
        );
        assert!(metadata.contains_key("original_sequence_stats"));
    }

    #[test]
    fn test_example_validation() {
        let collator = ten_token_collator();

        let well_formed = DataExample {
            input_ids: vec![1, 2, 3],
            attention_mask: Some(vec![1, 1, 1]),
            token_type_ids: None,
            labels: None,
            metadata: HashMap::new(),
        };
        assert!(collator.validate_examples(&[well_formed]).is_ok());

        let no_tokens = DataExample {
            input_ids: vec![],
            attention_mask: Some(vec![]),
            token_type_ids: None,
            labels: None,
            metadata: HashMap::new(),
        };
        assert!(collator.validate_examples(&[no_tokens]).is_err());

        // Mask is one element shorter than the ids.
        let short_mask = DataExample {
            input_ids: vec![1, 2, 3],
            attention_mask: Some(vec![1, 1]),
            token_type_ids: None,
            labels: None,
            metadata: HashMap::new(),
        };
        assert!(collator.validate_examples(&[short_mask]).is_err());
    }

    #[test]
    fn test_minimal_config() {
        let minimal = DefaultCollatorConfig::minimal(42);
        assert_eq!(minimal.pad_token_id, 42);
        assert_eq!(minimal.max_length, None);
        assert!(!minimal.truncation);
        assert_eq!(minimal.padding, PaddingStrategy::Longest);
    }

    #[test]
    fn test_inference_config() {
        let model_config = serde_json::json!({
            "max_position_embeddings": 512,
            "pad_token_id": 0
        });
        let inference =
            DefaultCollatorConfig::for_inference(&model_config).expect("operation failed in test");
        assert_eq!(inference.max_length, Some(512));
        assert!(!inference.truncation);
    }

    #[test]
    fn test_development_config() {
        let model_config = serde_json::json!({
            "max_position_embeddings": 1024,
            "pad_token_id": 0
        });
        let development = DefaultCollatorConfig::for_development(&model_config)
            .expect("operation failed in test");
        // 1024 from the model config is capped down to the development limit.
        assert_eq!(development.max_length, Some(128));
    }

    #[test]
    fn test_custom_length_config() {
        let model_config = serde_json::json!({
            "max_position_embeddings": 1024,
            "pad_token_id": 0
        });
        let custom = DefaultCollatorConfig::with_max_length(&model_config, 256)
            .expect("operation failed in test");
        assert_eq!(custom.max_length, Some(256));
        assert!(custom.truncation);
    }

    #[test]
    fn test_preserve_original_metadata() {
        let annotated = DataExample {
            input_ids: vec![1, 2, 3],
            attention_mask: Some(vec![1, 1, 1]),
            token_type_ids: None,
            labels: None,
            metadata: std::iter::once((
                "custom_field".to_string(),
                serde_json::Value::String("test_value".to_string()),
            ))
            .collect(),
        };
        let batch = ten_token_collator()
            .collate(&[annotated])
            .expect("operation failed in test");
        assert!(batch.metadata.contains_key("original_metadata"));
    }
}