use crate::auto::types::{CollatedBatch, DataExample, PaddingStrategy};
use crate::error::Result;
pub mod classification;
pub mod default;
pub mod language_modeling;
pub mod question_answering;
pub mod seq2seq;
/// Factory type that selects and constructs a concrete [`DataCollator`]
/// implementation, either from a model configuration (`from_config` /
/// `from_pretrained`) or from a task name (`for_task`).
#[derive(Debug, Clone)]
pub struct AutoDataCollator;
impl AutoDataCollator {
    /// Loads the model configuration for `model_name_or_path` from the hub
    /// and dispatches to [`Self::from_config`].
    ///
    /// # Errors
    /// Returns an error if the config cannot be loaded or if the selected
    /// collator's config parsing fails.
    pub fn from_pretrained(model_name_or_path: &str) -> Result<Box<dyn DataCollator>> {
        let config = crate::hub::load_config_from_hub(model_name_or_path, None)?;
        Self::from_config(&config)
    }

    /// Builds a collator whose kind is chosen by the config's `model_type`
    /// field. Missing or non-string `model_type` values, as well as
    /// unrecognized model types, fall back to the default collator.
    ///
    /// # Errors
    /// Propagates any error from the chosen collator config's `from_config`.
    pub fn from_config(config: &serde_json::Value) -> Result<Box<dyn DataCollator>> {
        let model_type = config
            .get("model_type")
            .and_then(serde_json::Value::as_str)
            .unwrap_or("default");
        let collator: Box<dyn DataCollator> = match model_type {
            // Masked-LM style collation (BERT-family models).
            "bert" | "roberta" | "electra" => {
                let cfg = LanguageModelingCollatorConfig::from_config(config)?;
                Box::new(LanguageModelingDataCollator::new(cfg))
            }
            // Causal-LM style collation (GPT-family models).
            "gpt2" | "gpt_neo" | "gpt_j" => {
                let cfg = CausalLanguageModelingCollatorConfig::from_config(config)?;
                Box::new(CausalLanguageModelingDataCollator::new(cfg))
            }
            // Sequence-to-sequence collation.
            "t5" | "bart" | "pegasus" => {
                let cfg = seq2seq::Seq2SeqCollatorConfig::from_config(config)?;
                Box::new(seq2seq::Seq2SeqDataCollator::new(cfg))
            }
            // Anything else gets the generic default collator.
            _ => {
                let cfg = default::DefaultCollatorConfig::from_config(config)?;
                Box::new(default::DefaultDataCollator::new(cfg))
            }
        };
        Ok(collator)
    }

    /// Builds a collator for a named task (e.g. `"masked-lm"`,
    /// `"translation"`). Unknown task names fall back to the default
    /// collator.
    ///
    /// # Errors
    /// Propagates any error from the chosen collator config's `from_config`.
    pub fn for_task(task: &str, config: &serde_json::Value) -> Result<Box<dyn DataCollator>> {
        let collator: Box<dyn DataCollator> = match task {
            "masked-lm" | "fill-mask" => {
                let cfg = LanguageModelingCollatorConfig::from_config(config)?;
                Box::new(LanguageModelingDataCollator::new(cfg))
            }
            "causal-lm" | "text-generation" => {
                let cfg = CausalLanguageModelingCollatorConfig::from_config(config)?;
                Box::new(CausalLanguageModelingDataCollator::new(cfg))
            }
            "text2text-generation" | "translation" | "summarization" => {
                let cfg = seq2seq::Seq2SeqCollatorConfig::from_config(config)?;
                Box::new(seq2seq::Seq2SeqDataCollator::new(cfg))
            }
            "text-classification" | "sentiment-analysis" => {
                let cfg = classification::ClassificationCollatorConfig::from_config(config)?;
                Box::new(classification::ClassificationDataCollator::new(cfg))
            }
            "question-answering" => {
                let cfg = question_answering::QuestionAnsweringCollatorConfig::from_config(config)?;
                Box::new(question_answering::QuestionAnsweringDataCollator::new(cfg))
            }
            _ => {
                let cfg = default::DefaultCollatorConfig::from_config(config)?;
                Box::new(default::DefaultDataCollator::new(cfg))
            }
        };
        Ok(collator)
    }
}
/// Core interface implemented by every collator: turns a slice of
/// [`DataExample`]s into a single [`CollatedBatch`].
pub trait DataCollator: Send + Sync {
    /// Collates `examples` into one batch.
    ///
    /// # Errors
    /// Implementations return an error when the batch cannot be built.
    fn collate(&self, examples: &[DataExample]) -> Result<CollatedBatch>;
    /// Returns the configuration governing this collator's behavior.
    fn config(&self) -> &dyn DataCollatorConfig;
    /// Optional pre-collation hook; the default implementation returns an
    /// unchanged copy of the input examples.
    fn preprocess_examples(&self, examples: &[DataExample]) -> Result<Vec<DataExample>> {
        Ok(examples.to_vec())
    }
}
/// Read-only view of the settings shared by all collator configurations.
pub trait DataCollatorConfig: Send + Sync {
    /// Maximum sequence length, or `None` when no maximum is configured.
    fn max_length(&self) -> Option<usize>;
    /// Padding strategy applied when batching.
    fn padding(&self) -> PaddingStrategy;
    /// Whether truncation is enabled.
    fn truncation(&self) -> bool;
}
pub use classification::{ClassificationCollatorConfig, ClassificationDataCollator};
pub use default::{DefaultCollatorConfig, DefaultDataCollator};
pub use language_modeling::{
CausalLanguageModelingCollatorConfig, CausalLanguageModelingDataCollator,
LanguageModelingCollatorConfig, LanguageModelingDataCollator,
};
pub use question_answering::{QuestionAnsweringCollatorConfig, QuestionAnsweringDataCollator};
pub use seq2seq::{Seq2SeqCollatorConfig, Seq2SeqDataCollator};
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auto::types::{DataExample, PaddingStrategy};

    // --- PaddingStrategy behavior ---

    #[test]
    fn test_padding_strategy_none_no_padding() {
        let strategy = PaddingStrategy::None;
        assert!(!strategy.should_pad(), "None should not require padding");
    }
    #[test]
    fn test_padding_strategy_longest_pads() {
        let strategy = PaddingStrategy::Longest;
        assert!(strategy.should_pad(), "Longest should require padding");
    }
    #[test]
    fn test_padding_strategy_max_length_pads() {
        let strategy = PaddingStrategy::MaxLength;
        assert!(strategy.should_pad(), "MaxLength should require padding");
    }
    #[test]
    fn test_padding_strategy_do_not_pad_no_padding() {
        let strategy = PaddingStrategy::DoNotPad;
        assert!(
            !strategy.should_pad(),
            "DoNotPad should not require padding"
        );
    }
    #[test]
    fn test_padding_strategy_longest_is_dynamic() {
        let strategy = PaddingStrategy::Longest;
        assert!(strategy.is_dynamic(), "Longest should be dynamic");
    }
    #[test]
    fn test_padding_strategy_max_length_not_dynamic() {
        let strategy = PaddingStrategy::MaxLength;
        assert!(!strategy.is_dynamic(), "MaxLength should not be dynamic");
    }
    #[test]
    fn test_padding_strategy_none_not_dynamic() {
        let strategy = PaddingStrategy::None;
        assert!(!strategy.is_dynamic(), "None should not be dynamic");
    }

    // --- DataExample construction and accessors ---

    #[test]
    fn test_data_example_new() {
        let example = DataExample::new(vec![101, 2023, 102]);
        assert_eq!(example.input_ids, vec![101, 2023, 102]);
        assert!(
            example.attention_mask.is_none(),
            "attention_mask should be None by default"
        );
        assert!(example.labels.is_none(), "labels should be None by default");
    }
    #[test]
    fn test_data_example_sequence_length() {
        let example = DataExample::new(vec![101, 2023, 3000, 102]);
        assert_eq!(example.sequence_length(), 4);
    }
    #[test]
    fn test_data_example_with_attention_mask() {
        let example = DataExample::new(vec![101, 102]).with_attention_mask(vec![1, 1]);
        assert!(example.attention_mask.is_some());
        if let Some(mask) = &example.attention_mask {
            assert_eq!(mask, &vec![1, 1]);
        }
    }
    #[test]
    fn test_data_example_with_labels() {
        let example = DataExample::new(vec![101, 102]).with_labels(vec![0]);
        assert!(example.has_labels(), "Example should have labels");
    }
    #[test]
    fn test_data_example_without_labels() {
        let example = DataExample::new(vec![101, 102]);
        assert!(
            !example.has_labels(),
            "Example without labels should return false"
        );
    }
    #[test]
    fn test_data_example_with_token_type_ids() {
        let example = DataExample::new(vec![101, 200, 102]).with_token_type_ids(vec![0, 0, 0]);
        assert!(example.token_type_ids.is_some());
    }

    // --- Masked language-modeling collator ---

    #[test]
    fn test_language_modeling_collator_config_creation() {
        let config = LanguageModelingCollatorConfig {
            max_length: Some(512),
            padding: PaddingStrategy::Longest,
            truncation: true,
            pad_token_id: 0,
            mask_token_id: 103,
            mlm_probability: 0.15,
        };
        assert_eq!(config.max_length, Some(512));
        assert_eq!(config.pad_token_id, 0);
        assert_eq!(config.mask_token_id, 103);
        // Compare the float with a tolerance instead of exact equality.
        let diff = (config.mlm_probability - 0.15).abs();
        assert!(diff < 1e-6, "mlm_probability should be 0.15");
    }
    #[test]
    fn test_language_modeling_collator_creation() {
        let config = LanguageModelingCollatorConfig {
            max_length: Some(128),
            padding: PaddingStrategy::Longest,
            truncation: true,
            pad_token_id: 0,
            mask_token_id: 103,
            mlm_probability: 0.0,
        };
        let collator = LanguageModelingDataCollator::new(config);
        assert_eq!(collator.config().max_length(), Some(128));
        assert!(collator.config().truncation());
    }
    #[test]
    fn test_language_modeling_collate_single_example() {
        // mlm_probability of 0.0 keeps the test deterministic (no random masking).
        let config = LanguageModelingCollatorConfig {
            max_length: Some(16),
            padding: PaddingStrategy::Longest,
            truncation: true,
            pad_token_id: 0,
            mask_token_id: 103,
            mlm_probability: 0.0,
        };
        let collator = LanguageModelingDataCollator::new(config);
        let examples = vec![DataExample::new(vec![101_u32, 2023, 2003, 102])];
        let result = collator.collate(&examples);
        assert!(result.is_ok(), "Collation should succeed");
        if let Ok(batch) = result {
            assert_eq!(batch.batch_size, 1);
            assert_eq!(batch.input_ids[0].len(), batch.sequence_length);
        }
    }
    #[test]
    fn test_language_modeling_collate_pads_shorter_sequence() {
        let config = LanguageModelingCollatorConfig {
            max_length: None,
            padding: PaddingStrategy::Longest,
            truncation: false,
            pad_token_id: 0,
            mask_token_id: 103,
            mlm_probability: 0.0,
        };
        let collator = LanguageModelingDataCollator::new(config);
        let examples = vec![
            DataExample::new(vec![101_u32, 200, 300, 102]),
            DataExample::new(vec![101_u32, 400, 102]),
        ];
        let result = collator.collate(&examples);
        // BUG FIX: previously a bare `if let Ok(batch)` silently skipped every
        // assertion when collation returned Err, making the test pass vacuously.
        // Assert success explicitly so a collation failure is reported.
        assert!(result.is_ok(), "Collation should succeed");
        if let Ok(batch) = result {
            assert_eq!(batch.batch_size, 2);
            let len0 = batch.input_ids[0].len();
            let len1 = batch.input_ids[1].len();
            assert_eq!(len0, len1, "All sequences should be padded to same length");
        }
    }

    // --- Causal language-modeling collator ---

    #[test]
    fn test_causal_lm_collator_config_creation() {
        let config = CausalLanguageModelingCollatorConfig {
            max_length: Some(1024),
            padding: PaddingStrategy::Longest,
            truncation: true,
            pad_token_id: 50256,
        };
        assert_eq!(config.max_length, Some(1024));
        assert_eq!(config.pad_token_id, 50256);
        assert!(config.truncation);
    }
    #[test]
    fn test_causal_lm_collator_collate_single() {
        let config = CausalLanguageModelingCollatorConfig {
            max_length: Some(32),
            padding: PaddingStrategy::Longest,
            truncation: true,
            pad_token_id: 0,
        };
        let collator = CausalLanguageModelingDataCollator::new(config);
        let examples = vec![DataExample::new(vec![1_u32, 2, 3, 4, 5])];
        let result = collator.collate(&examples);
        assert!(result.is_ok(), "Causal LM collation should succeed");
        if let Ok(batch) = result {
            assert_eq!(batch.batch_size, 1);
        }
    }

    // --- AutoDataCollator dispatch by model_type ---

    #[test]
    fn test_auto_collator_from_config_bert() {
        let config = serde_json::json!({
            "model_type": "bert",
            "pad_token_id": 0,
            "mask_token_id": 103,
            "max_position_embeddings": 512
        });
        let result = AutoDataCollator::from_config(&config);
        assert!(
            result.is_ok(),
            "AutoDataCollator::from_config for bert should succeed"
        );
    }
    #[test]
    fn test_auto_collator_from_config_gpt2() {
        let config = serde_json::json!({
            "model_type": "gpt2",
            "pad_token_id": 50256,
            "n_positions": 1024
        });
        let result = AutoDataCollator::from_config(&config);
        assert!(
            result.is_ok(),
            "AutoDataCollator::from_config for gpt2 should succeed"
        );
    }
    #[test]
    fn test_auto_collator_from_config_t5() {
        let config = serde_json::json!({
            "model_type": "t5",
            "pad_token_id": 0
        });
        let result = AutoDataCollator::from_config(&config);
        assert!(
            result.is_ok(),
            "AutoDataCollator::from_config for t5 should succeed"
        );
    }
    #[test]
    fn test_auto_collator_from_config_unknown_uses_default() {
        let config = serde_json::json!({
            "model_type": "custom-model-xyz"
        });
        let result = AutoDataCollator::from_config(&config);
        assert!(
            result.is_ok(),
            "Unknown model type should fall back to DefaultDataCollator"
        );
    }

    // --- AutoDataCollator dispatch by task name ---

    #[test]
    fn test_auto_collator_for_task_masked_lm() {
        let config = serde_json::json!({"pad_token_id": 0, "mask_token_id": 103});
        let result = AutoDataCollator::for_task("masked-lm", &config);
        assert!(result.is_ok(), "for_task masked-lm should succeed");
    }
    #[test]
    fn test_auto_collator_for_task_causal_lm() {
        let config = serde_json::json!({"pad_token_id": 0});
        let result = AutoDataCollator::for_task("causal-lm", &config);
        assert!(result.is_ok(), "for_task causal-lm should succeed");
    }
    #[test]
    fn test_auto_collator_for_task_text_generation() {
        let config = serde_json::json!({"pad_token_id": 50256});
        let result = AutoDataCollator::for_task("text-generation", &config);
        assert!(result.is_ok(), "for_task text-generation should succeed");
    }
    #[test]
    fn test_auto_collator_for_task_classification() {
        let config = serde_json::json!({"pad_token_id": 0});
        let result = AutoDataCollator::for_task("text-classification", &config);
        assert!(
            result.is_ok(),
            "for_task text-classification should succeed"
        );
    }
    #[test]
    fn test_auto_collator_for_task_question_answering() {
        let config = serde_json::json!({"pad_token_id": 0});
        let result = AutoDataCollator::for_task("question-answering", &config);
        assert!(result.is_ok(), "for_task question-answering should succeed");
    }
    #[test]
    fn test_auto_collator_for_task_summarization() {
        let config = serde_json::json!({"pad_token_id": 0});
        let result = AutoDataCollator::for_task("summarization", &config);
        assert!(result.is_ok(), "for_task summarization should succeed");
    }
    #[test]
    fn test_auto_collator_for_task_translation() {
        let config = serde_json::json!({"pad_token_id": 0});
        let result = AutoDataCollator::for_task("translation", &config);
        assert!(result.is_ok(), "for_task translation should succeed");
    }
    #[test]
    fn test_auto_collator_for_task_unknown_default() {
        let config = serde_json::json!({"pad_token_id": 0});
        let result = AutoDataCollator::for_task("very-unusual-task-xyz", &config);
        assert!(
            result.is_ok(),
            "Unknown task should fall back to DefaultDataCollator"
        );
    }
}