use super::{DataCollator, DataCollatorConfig};
use crate::auto::data_collators::language_modeling::{
LanguageModelingCollatorConfig, LanguageModelingDataCollator,
};
use crate::auto::types::{CollatedBatch, DataExample, PaddingStrategy};
use crate::error::{Result, TrustformersError};
use serde::{Deserialize, Serialize};
/// Data collator for sequence-to-sequence (encoder/decoder) batches.
///
/// Delegates encoder-side padding to a language-modeling collator (with
/// MLM disabled) and pads/truncates target label sequences itself,
/// exposing the decoder attention mask via batch metadata.
#[derive(Debug, Clone)]
pub struct Seq2SeqDataCollator {
    // Length limits, padding strategy, and pad token id for collation.
    config: Seq2SeqCollatorConfig,
}
impl Seq2SeqDataCollator {
    /// Creates a collator that applies `config` to every batch.
    pub fn new(config: Seq2SeqCollatorConfig) -> Self {
        Self { config }
    }

    /// Builds decoder input ids by right-shifting `target_labels`: an
    /// optional BOS token followed by every label except the last, with
    /// ignore-index (-100) entries dropped.
    ///
    /// NOTE(review): not called from `collate` in this file; presumably
    /// kept for future decoder-input wiring — confirm against callers.
    fn prepare_decoder_inputs(&self, target_labels: &[i64], bos_token_id: Option<u32>) -> Vec<u32> {
        let keep = target_labels.len().saturating_sub(1);
        let shifted = target_labels
            .iter()
            .take(keep)
            .filter(|&&label| label != -100)
            .map(|&label| label as u32);
        let mut decoder_inputs = Vec::with_capacity(target_labels.len());
        decoder_inputs.extend(bos_token_id);
        decoder_inputs.extend(shifted);
        decoder_inputs
    }

    /// Pads (with -100) or truncates each example's labels to
    /// `max_target_len`, producing matching decoder attention masks
    /// (1 for real tokens, 0 for padding). Examples without labels get
    /// all `-100` labels and an all-zero mask.
    fn process_target_sequences(
        &self,
        examples: &[DataExample],
        max_target_len: usize,
    ) -> Result<(Vec<Vec<i64>>, Vec<Vec<u32>>)> {
        let mut processed_labels = Vec::with_capacity(examples.len());
        let mut decoder_attention_masks = Vec::with_capacity(examples.len());
        for example in examples {
            match example.labels {
                Some(ref labels) => {
                    let mut padded = labels.clone();
                    if self.config.truncation && padded.len() > max_target_len {
                        padded.truncate(max_target_len);
                    }
                    let real_len = padded.len();
                    let mut mask = vec![1u32; real_len];
                    // With truncation off an over-long sequence is kept
                    // as-is; only shorter sequences are padded out.
                    if real_len < max_target_len {
                        padded.resize(max_target_len, -100);
                        mask.resize(max_target_len, 0);
                    }
                    processed_labels.push(padded);
                    decoder_attention_masks.push(mask);
                },
                None => {
                    processed_labels.push(vec![-100i64; max_target_len]);
                    decoder_attention_masks.push(vec![0u32; max_target_len]);
                },
            }
        }
        Ok((processed_labels, decoder_attention_masks))
    }
}
impl DataCollator for Seq2SeqDataCollator {
    /// Collates examples into a single encoder/decoder batch.
    ///
    /// Encoder inputs are padded by an internal
    /// `LanguageModelingDataCollator` with masking disabled
    /// (`mlm_probability: 0.0`); decoder labels are padded/truncated here
    /// and the decoder attention masks plus the target length are stored
    /// in batch metadata.
    ///
    /// # Errors
    /// Returns an error when `examples` is empty, when the encoder
    /// collator fails, or when the decoder attention masks cannot be
    /// serialized into JSON metadata.
    fn collate(&self, examples: &[DataExample]) -> Result<CollatedBatch> {
        if examples.is_empty() {
            return Err(TrustformersError::invalid_input_simple(
                "Cannot collate empty batch for sequence-to-sequence".to_string(),
            ));
        }
        // Encoder target length, driven by the configured padding strategy.
        let max_encoder_len = match self.config.padding {
            PaddingStrategy::Longest => examples
                .iter()
                .map(|ex| ex.input_ids.len())
                .max()
                .unwrap_or(0)
                .min(self.config.max_length.unwrap_or(usize::MAX)),
            PaddingStrategy::MaxLength => self.config.max_length.unwrap_or(512),
            PaddingStrategy::DoNotPad => {
                examples.iter().map(|ex| ex.input_ids.len()).max().unwrap_or(0)
            },
            // `None`: use the first example's length (batch is non-empty here).
            PaddingStrategy::None => examples[0].input_ids.len(),
        };
        // Decoder target length mirrors the encoder logic but is driven by
        // the label sequences; examples without labels are skipped.
        let max_decoder_len = match self.config.padding {
            PaddingStrategy::Longest => examples
                .iter()
                .filter_map(|ex| ex.labels.as_ref())
                .map(|labels| labels.len())
                .max()
                .unwrap_or(0)
                .min(self.config.max_target_length.unwrap_or(usize::MAX)),
            PaddingStrategy::MaxLength => self.config.max_target_length.unwrap_or(max_encoder_len),
            PaddingStrategy::DoNotPad => examples
                .iter()
                .filter_map(|ex| ex.labels.as_ref())
                .map(|labels| labels.len())
                .max()
                .unwrap_or(0),
            PaddingStrategy::None => examples
                .first()
                .and_then(|ex| ex.labels.as_ref())
                .map(|labels| labels.len())
                .unwrap_or(0),
        };
        // Reuse the language-modeling collator purely for encoder padding;
        // MLM is disabled so inputs pass through unmasked.
        let encoder_collator = LanguageModelingDataCollator::new(LanguageModelingCollatorConfig {
            max_length: Some(max_encoder_len),
            padding: self.config.padding,
            truncation: self.config.truncation,
            pad_token_id: self.config.pad_token_id,
            mask_token_id: 0,
            mlm_probability: 0.0,
        });
        let mut batch = encoder_collator.collate(examples)?;
        let (processed_labels, decoder_attention_masks) =
            self.process_target_sequences(examples, max_decoder_len)?;
        batch.labels = Some(processed_labels);
        batch.metadata.insert(
            "decoder_attention_mask".to_string(),
            serde_json::to_value(decoder_attention_masks)
                .map_err(|e| TrustformersError::runtime_error(e.to_string()))?,
        );
        batch.metadata.insert(
            "target_sequence_length".to_string(),
            serde_json::Value::Number(max_decoder_len.into()),
        );
        Ok(batch)
    }

    /// Returns the collator configuration as a trait object.
    fn config(&self) -> &dyn DataCollatorConfig {
        &self.config
    }

    /// No-op preprocessing: examples are copied through unchanged.
    fn preprocess_examples(&self, examples: &[DataExample]) -> Result<Vec<DataExample>> {
        Ok(examples.to_vec())
    }
}
/// Configuration for `Seq2SeqDataCollator`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Seq2SeqCollatorConfig {
    // Maximum encoder (source) sequence length; `None` means uncapped for
    // `Longest` padding (a 512 default applies under `MaxLength`).
    pub max_length: Option<usize>,
    // Maximum decoder (target) sequence length; `from_config` falls back
    // to `max_length` when this is absent.
    pub max_target_length: Option<usize>,
    // Strategy used to pick padded batch lengths.
    pub padding: PaddingStrategy,
    // Whether over-long target sequences are truncated during collation.
    pub truncation: bool,
    // Token id forwarded to the encoder collator for padding.
    pub pad_token_id: u32,
}
impl Seq2SeqCollatorConfig {
    /// Builds a config from a model's JSON configuration.
    ///
    /// `max_length` is read from `max_position_embeddings` (preferred) or
    /// `max_length`; `max_target_length` falls back to `max_length` when
    /// absent. Padding defaults to `Longest`, truncation is enabled, and a
    /// missing `pad_token_id` defaults to 0.
    ///
    /// # Errors
    /// Currently infallible; returns `Result` for interface consistency.
    pub fn from_config(config: &serde_json::Value) -> Result<Self> {
        let max_length = config
            .get("max_position_embeddings")
            .or_else(|| config.get("max_length"))
            .and_then(|v| v.as_u64())
            .map(|v| v as usize);
        let max_target_length = config
            .get("max_target_length")
            .and_then(|v| v.as_u64())
            .map(|v| v as usize)
            .or(max_length);
        Ok(Self {
            max_length,
            max_target_length,
            padding: PaddingStrategy::Longest,
            truncation: true,
            pad_token_id: config.get("pad_token_id").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
        })
    }

    /// Builds a translation config from a model's JSON configuration.
    ///
    /// `from_config` already falls back to `max_length` whenever
    /// `max_target_length` is missing, so no further adjustment is needed
    /// here (the previous explicit fallback was dead code: the field could
    /// only be `None` when `max_length` was also `None`).
    pub fn for_translation(config: &serde_json::Value) -> Result<Self> {
        Self::from_config(config)
    }

    /// Builds a summarization config whose target length is
    /// `max_length * summary_ratio` (default ratio 0.25), overriding any
    /// `max_target_length` derived by `from_config`.
    pub fn for_summarization(
        config: &serde_json::Value,
        summary_ratio: Option<f32>,
    ) -> Result<Self> {
        let mut summarization_config = Self::from_config(config)?;
        let ratio = summary_ratio.unwrap_or(0.25);
        if let Some(max_len) = summarization_config.max_length {
            // Float multiply then truncating cast: e.g. 1024 * 0.2 -> 204.
            summarization_config.max_target_length = Some((max_len as f32 * ratio) as usize);
        }
        Ok(summarization_config)
    }
}
impl DataCollatorConfig for Seq2SeqCollatorConfig {
    /// Maximum encoder sequence length, if bounded.
    fn max_length(&self) -> Option<usize> {
        self.max_length
    }
    /// Padding strategy applied when collating.
    fn padding(&self) -> PaddingStrategy {
        self.padding
    }
    /// Whether sequences longer than the configured maximum are truncated.
    fn truncation(&self) -> bool {
        self.truncation
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn test_seq2seq_collator_creation() {
        // Constructing a collator must preserve the supplied limits.
        let cfg = Seq2SeqCollatorConfig {
            max_length: Some(512),
            max_target_length: Some(128),
            padding: PaddingStrategy::Longest,
            truncation: true,
            pad_token_id: 0,
        };
        let collator = Seq2SeqDataCollator::new(cfg);
        assert_eq!(collator.config.max_target_length, Some(128));
        assert_eq!(collator.config().max_length(), Some(512));
    }

    #[test]
    fn test_seq2seq_config_from_json() {
        let raw = serde_json::json!({
            "max_position_embeddings": 512,
            "max_target_length": 64,
            "pad_token_id": 1,
            "vocab_size": 32000
        });
        let parsed =
            Seq2SeqCollatorConfig::from_config(&raw).expect("operation failed in test");
        assert_eq!(parsed.pad_token_id, 1);
        assert_eq!(parsed.max_length, Some(512));
        assert_eq!(parsed.max_target_length, Some(64));
    }

    #[test]
    fn test_collate_seq2seq_examples() {
        let cfg = Seq2SeqCollatorConfig {
            max_length: Some(10),
            max_target_length: Some(8),
            padding: PaddingStrategy::Longest,
            truncation: true,
            pad_token_id: 0,
        };
        let collator = Seq2SeqDataCollator::new(cfg);
        let first = DataExample {
            input_ids: vec![1, 2, 3, 4],
            attention_mask: Some(vec![1, 1, 1, 1]),
            token_type_ids: None,
            labels: Some(vec![5, 6, 7]),
            metadata: HashMap::new(),
        };
        let second = DataExample {
            input_ids: vec![1, 2],
            attention_mask: Some(vec![1, 1]),
            token_type_ids: None,
            labels: Some(vec![8, 9, 10, 11]),
            metadata: HashMap::new(),
        };
        let batch = collator
            .collate(&[first, second])
            .expect("operation failed in test");
        assert_eq!(batch.batch_size, 2);
        assert_eq!(batch.input_ids.len(), 2);
        assert!(batch.labels.is_some());
        let labels = batch.labels.as_ref().expect("operation failed in test");
        assert_eq!(labels.len(), 2);
    }

    #[test]
    fn test_summarization_config() {
        let model_config = serde_json::json!({
            "max_position_embeddings": 1024,
            "pad_token_id": 0
        });
        let cfg = Seq2SeqCollatorConfig::for_summarization(&model_config, Some(0.2))
            .expect("operation failed in test");
        // 1024 * 0.2 = 204.8, truncated to 204 by the float-to-usize cast.
        assert_eq!(cfg.max_target_length, Some(204));
        assert_eq!(cfg.max_length, Some(1024));
    }
}