use super::{DataCollator, DataCollatorConfig};
use crate::auto::data_collators::language_modeling::{
LanguageModelingCollatorConfig, LanguageModelingDataCollator,
};
use crate::auto::types::{CollatedBatch, DataExample, PaddingStrategy};
use crate::error::{Result, TrustformersError};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Data collator for extractive question-answering batches.
///
/// Delegates padding/truncation to the language-modeling collator and then
/// attaches `[start, end]` answer-span labels, using `[-100, -100]` as the
/// ignore sentinel for missing or invalid spans.
#[derive(Debug, Clone)]
pub struct QuestionAnsweringDataCollator {
    // Collation settings: padding strategy, doc stride, answer-length limit.
    config: QuestionAnsweringCollatorConfig,
}
impl QuestionAnsweringDataCollator {
    /// Builds a collator from the given configuration.
    pub fn new(config: QuestionAnsweringCollatorConfig) -> Self {
        Self { config }
    }

    /// Converts each example's answer span into a model-ready `[start, end]`
    /// label pair.
    ///
    /// Spans that are missing, shorter than two elements, negative, outside
    /// `sequence_length`, or inverted are replaced with the ignore sentinel
    /// `[-100, -100]`.
    fn process_qa_labels(&self, examples: &[DataExample], sequence_length: usize) -> Vec<Vec<i64>> {
        examples
            .iter()
            .map(|example| {
                let valid_span = example.labels.as_ref().and_then(|labels| {
                    if labels.len() < 2 {
                        return None;
                    }
                    let (start, end) = (labels[0], labels[1]);
                    let usable = start >= 0
                        && end >= 0
                        && (start as usize) < sequence_length
                        && (end as usize) < sequence_length
                        && start <= end;
                    if usable {
                        Some(vec![start, end])
                    } else {
                        None
                    }
                });
                valid_span.unwrap_or_else(|| vec![-100, -100])
            })
            .collect()
    }

    /// Checks every non-sentinel answer span against the padded sequence.
    ///
    /// # Errors
    /// Returns an `invalid_input` error when a span has negative positions,
    /// positions at or beyond `sequence_length`, an inverted start/end order,
    /// or a length exceeding `config.max_answer_length`. Spans where either
    /// position is `-100` (unanswerable) are skipped.
    fn validate_answer_spans(
        &self,
        examples: &[DataExample],
        sequence_length: usize,
    ) -> Result<()> {
        for (i, example) in examples.iter().enumerate() {
            let Some(labels) = example.labels.as_ref() else { continue };
            if labels.len() < 2 {
                continue;
            }
            let (start_pos, end_pos) = (labels[0], labels[1]);
            // Sentinel spans mark unanswerable questions; nothing to check.
            if start_pos == -100 || end_pos == -100 {
                continue;
            }
            if start_pos < 0 || end_pos < 0 {
                return Err(TrustformersError::invalid_input(
                    format!(
                        "Negative answer positions in example {}: start={}, end={}",
                        i, start_pos, end_pos
                    ),
                    Some("answer_positions"),
                    Some("non-negative start and end positions"),
                    Some(format!("start={}, end={}", start_pos, end_pos)),
                ));
            }
            if (start_pos as usize) >= sequence_length || (end_pos as usize) >= sequence_length {
                return Err(TrustformersError::invalid_input(
                    format!("Answer positions exceed sequence length in example {}: start={}, end={}, seq_len={}", i, start_pos, end_pos, sequence_length),
                    Some("answer_positions"),
                    Some(format!("positions within sequence length {}", sequence_length)),
                    Some(format!("start={}, end={}", start_pos, end_pos)),
                ));
            }
            if start_pos > end_pos {
                return Err(TrustformersError::invalid_input(
                    format!(
                        "Invalid answer span in example {}: start={} > end={}",
                        i, start_pos, end_pos
                    ),
                    Some("answer_span"),
                    Some("start position <= end position"),
                    Some(format!("start={}, end={}", start_pos, end_pos)),
                ));
            }
            // Inclusive span: [start, end] covers end - start + 1 tokens.
            let answer_length = (end_pos - start_pos + 1) as usize;
            if answer_length > self.config.max_answer_length {
                return Err(TrustformersError::invalid_input(
                    format!(
                        "Answer span too long in example {}: length={}, max={}",
                        i, answer_length, self.config.max_answer_length
                    ),
                    Some("answer_length"),
                    Some(format!("length <= {}", self.config.max_answer_length)),
                    Some(answer_length.to_string()),
                ));
            }
        }
        Ok(())
    }

    /// Derives per-example question/context layout metadata from
    /// `token_type_ids`.
    ///
    /// The context is assumed to start at the first token with type id `1`
    /// (everything before it counts as the question), so `question_lengths`
    /// and `context_starts` hold the same offsets. Examples without
    /// `token_type_ids` contribute `0` for both.
    fn create_sequence_metadata(
        &self,
        examples: &[DataExample],
    ) -> HashMap<String, serde_json::Value> {
        let (question_lengths, context_starts): (Vec<usize>, Vec<usize>) = examples
            .iter()
            .map(|example| match example.token_type_ids {
                Some(ref token_types) => {
                    // No type-1 token at all => the whole sequence is question.
                    let context_start =
                        token_types.iter().position(|&t| t == 1).unwrap_or(token_types.len());
                    (context_start, context_start)
                },
                None => (0, 0),
            })
            .unzip();

        let mut metadata = HashMap::new();
        metadata.insert(
            "question_lengths".to_string(),
            serde_json::to_value(question_lengths).unwrap_or(serde_json::Value::Null),
        );
        metadata.insert(
            "context_starts".to_string(),
            serde_json::to_value(context_starts).unwrap_or(serde_json::Value::Null),
        );
        metadata
    }
}
impl DataCollator for QuestionAnsweringDataCollator {
    /// Collates QA examples into a padded batch.
    ///
    /// Padding/truncation is delegated to the language-modeling collator
    /// (with masking disabled); answer spans are then validated against the
    /// resulting sequence length, converted to `[start, end]` label pairs,
    /// and task/layout metadata is attached to the batch.
    ///
    /// # Errors
    /// Fails on an empty batch, on collation failure, or on an invalid
    /// answer span (see `validate_answer_spans`).
    fn collate(&self, examples: &[DataExample]) -> Result<CollatedBatch> {
        if examples.is_empty() {
            return Err(TrustformersError::invalid_input(
                "Cannot collate empty batch for question answering".to_string(),
                Some("examples".to_string()),
                Some("non-empty batch".to_string()),
                Some("empty batch".to_string()),
            ));
        }

        // Reuse the LM collator purely for padding/truncation: masking is
        // disabled via mlm_probability = 0.0, so mask_token_id is unused.
        let base_config = LanguageModelingCollatorConfig {
            max_length: self.config.max_length,
            padding: self.config.padding,
            truncation: self.config.truncation,
            pad_token_id: self.config.pad_token_id,
            mask_token_id: 0,
            mlm_probability: 0.0,
        };
        let mut batch = LanguageModelingDataCollator::new(base_config).collate(examples)?;

        // Validate against the padded length before producing labels.
        self.validate_answer_spans(examples, batch.sequence_length)?;
        batch.labels = Some(self.process_qa_labels(examples, batch.sequence_length));

        batch.metadata.insert(
            "task_type".to_string(),
            serde_json::Value::String("question_answering".to_string()),
        );
        batch.metadata.insert(
            "doc_stride".to_string(),
            serde_json::Value::Number(self.config.doc_stride.into()),
        );
        batch.metadata.insert(
            "max_answer_length".to_string(),
            serde_json::Value::Number(self.config.max_answer_length.into()),
        );
        batch.metadata.extend(self.create_sequence_metadata(examples));
        Ok(batch)
    }

    fn config(&self) -> &dyn DataCollatorConfig {
        &self.config
    }

    /// QA requires no extra preprocessing; examples pass through unchanged.
    fn preprocess_examples(&self, examples: &[DataExample]) -> Result<Vec<DataExample>> {
        Ok(examples.to_vec())
    }
}
/// Configuration for [`QuestionAnsweringDataCollator`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuestionAnsweringCollatorConfig {
    // Maximum sequence length after padding/truncation; `None` = unbounded.
    pub max_length: Option<usize>,
    // How sequences in a batch are padded.
    pub padding: PaddingStrategy,
    // Whether sequences longer than `max_length` are truncated.
    pub truncation: bool,
    // Token id used to pad shorter sequences.
    pub pad_token_id: u32,
    // Overlap (in tokens) between sliding windows over long documents.
    pub doc_stride: usize,
    // Maximum allowed answer-span length (inclusive token count).
    pub max_answer_length: usize,
}
impl QuestionAnsweringCollatorConfig {
    /// Derives a collator config from a model's JSON configuration.
    ///
    /// `max_length` comes from `max_position_embeddings`, falling back to
    /// `max_length`; other fields fall back to defaults (`pad_token_id` = 0,
    /// `doc_stride` = 128, `max_answer_length` = 30) when absent.
    pub fn from_config(config: &serde_json::Value) -> Result<Self> {
        // Read an integer field, substituting `default` when missing or
        // not representable as u64.
        let int_or = |key: &str, default: u64| {
            config.get(key).and_then(|v| v.as_u64()).unwrap_or(default)
        };
        let max_length = config
            .get("max_position_embeddings")
            .or_else(|| config.get("max_length"))
            .and_then(|v| v.as_u64())
            .map(|v| v as usize);
        Ok(Self {
            max_length,
            padding: PaddingStrategy::Longest,
            truncation: true,
            pad_token_id: int_or("pad_token_id", 0) as u32,
            doc_stride: int_or("doc_stride", 128) as usize,
            max_answer_length: int_or("max_answer_length", 30) as usize,
        })
    }

    /// Preset for SQuAD-style extractive QA (384-token windows, stride 128,
    /// answers up to 30 tokens), overriding whatever the model config says.
    pub fn for_squad(config: &serde_json::Value) -> Result<Self> {
        Ok(Self {
            max_length: Some(384),
            doc_stride: 128,
            max_answer_length: 30,
            ..Self::from_config(config)?
        })
    }

    /// Preset for long-form QA over lengthy contexts.
    ///
    /// When `context_length` is given it becomes the max length and the doc
    /// stride is set to a quarter of it; `answer_length`, when given,
    /// overrides the maximum answer span.
    pub fn for_long_form_qa(
        config: &serde_json::Value,
        context_length: Option<usize>,
        answer_length: Option<usize>,
    ) -> Result<Self> {
        let mut cfg = Self::from_config(config)?;
        if let Some(max_len) = context_length {
            cfg.max_length = Some(max_len);
            // Windows overlap by 25% of the context length.
            cfg.doc_stride = max_len / 4;
        }
        if let Some(max_ans) = answer_length {
            cfg.max_answer_length = max_ans;
        }
        Ok(cfg)
    }

    /// Preset for conversational QA (512-token windows, stride 256, answers
    /// up to 50 tokens).
    pub fn for_conversational_qa(config: &serde_json::Value) -> Result<Self> {
        Ok(Self {
            max_length: Some(512),
            doc_stride: 256,
            max_answer_length: 50,
            ..Self::from_config(config)?
        })
    }
}
impl DataCollatorConfig for QuestionAnsweringCollatorConfig {
    /// Maximum sequence length after padding/truncation, if bounded.
    fn max_length(&self) -> Option<usize> {
        self.max_length
    }
    /// Batch padding strategy.
    fn padding(&self) -> PaddingStrategy {
        self.padding
    }
    /// Whether over-long sequences are truncated.
    fn truncation(&self) -> bool {
        self.truncation
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    // Construction stores the config unchanged and exposes it via the trait.
    #[test]
    fn test_qa_collator_creation() {
        let config = QuestionAnsweringCollatorConfig {
            max_length: Some(384),
            padding: PaddingStrategy::Longest,
            truncation: true,
            pad_token_id: 0,
            doc_stride: 128,
            max_answer_length: 30,
        };
        let collator = QuestionAnsweringDataCollator::new(config);
        assert_eq!(collator.config().max_length(), Some(384));
        assert_eq!(collator.config.doc_stride, 128);
        assert_eq!(collator.config.max_answer_length, 30);
    }
    // from_config reads max_position_embeddings and the optional overrides.
    #[test]
    fn test_qa_config_from_json() {
        let config_json = serde_json::json!({
            "max_position_embeddings": 512,
            "pad_token_id": 1,
            "doc_stride": 64,
            "max_answer_length": 20,
            "vocab_size": 30522
        });
        let config = QuestionAnsweringCollatorConfig::from_config(&config_json)
            .expect("operation failed in test");
        assert_eq!(config.max_length, Some(512));
        assert_eq!(config.pad_token_id, 1);
        assert_eq!(config.doc_stride, 64);
        assert_eq!(config.max_answer_length, 20);
    }
    // Valid spans survive collation unchanged as [start, end] label pairs.
    #[test]
    fn test_collate_qa_examples() {
        let config = QuestionAnsweringCollatorConfig {
            max_length: Some(20),
            padding: PaddingStrategy::Longest,
            truncation: true,
            pad_token_id: 0,
            doc_stride: 128,
            max_answer_length: 10,
        };
        let collator = QuestionAnsweringDataCollator::new(config);
        let examples = vec![
            DataExample {
                input_ids: vec![101, 2054, 2003, 102, 1996, 3438, 2003, 2769, 102],
                attention_mask: Some(vec![1, 1, 1, 1, 1, 1, 1, 1, 1]),
                token_type_ids: Some(vec![0, 0, 0, 0, 1, 1, 1, 1, 1]),
                labels: Some(vec![7, 7]), metadata: HashMap::new(),
            },
            DataExample {
                input_ids: vec![101, 2073, 102, 3376, 102],
                attention_mask: Some(vec![1, 1, 1, 1, 1]),
                token_type_ids: Some(vec![0, 0, 0, 1, 1]),
                labels: Some(vec![3, 3]), metadata: HashMap::new(),
            },
        ];
        let batch = collator.collate(&examples).expect("operation failed in test");
        assert_eq!(batch.batch_size, 2);
        assert_eq!(batch.input_ids.len(), 2);
        assert!(batch.labels.is_some());
        let labels = batch.labels.as_ref().expect("operation failed in test");
        assert_eq!(labels.len(), 2);
        assert_eq!(labels[0], vec![7, 7]);
        assert_eq!(labels[1], vec![3, 3]);
    }
    // Sentinel [-100, -100] spans (unanswerable) pass through collation.
    #[test]
    fn test_unanswerable_questions() {
        let config = QuestionAnsweringCollatorConfig {
            max_length: Some(20),
            padding: PaddingStrategy::Longest,
            truncation: true,
            pad_token_id: 0,
            doc_stride: 128,
            max_answer_length: 10,
        };
        let collator = QuestionAnsweringDataCollator::new(config);
        let examples = vec![DataExample {
            input_ids: vec![101, 2054, 102, 1045, 2123, 1005, 1056, 2113, 102],
            attention_mask: Some(vec![1, 1, 1, 1, 1, 1, 1, 1, 1]),
            token_type_ids: Some(vec![0, 0, 0, 1, 1, 1, 1, 1, 1]),
            labels: Some(vec![-100, -100]), metadata: HashMap::new(),
        }];
        let batch = collator.collate(&examples).expect("operation failed in test");
        let labels = batch.labels.as_ref().expect("operation failed in test");
        assert_eq!(labels[0], vec![-100, -100]);
    }
    // validate_answer_spans: ok for in-range spans, err for inverted spans
    // (start > end) and for spans longer than max_answer_length.
    #[test]
    fn test_answer_span_validation() {
        let config = QuestionAnsweringCollatorConfig {
            max_length: Some(10),
            padding: PaddingStrategy::Longest,
            truncation: true,
            pad_token_id: 0,
            doc_stride: 128,
            max_answer_length: 3,
        };
        let collator = QuestionAnsweringDataCollator::new(config);
        let valid_examples = vec![DataExample {
            input_ids: vec![101, 102, 103, 104],
            attention_mask: Some(vec![1, 1, 1, 1]),
            token_type_ids: None,
            labels: Some(vec![2, 3]), metadata: HashMap::new(),
        }];
        assert!(collator.validate_answer_spans(&valid_examples, 4).is_ok());
        let invalid_examples = vec![DataExample {
            input_ids: vec![101, 102, 103, 104],
            attention_mask: Some(vec![1, 1, 1, 1]),
            token_type_ids: None,
            labels: Some(vec![3, 2]), metadata: HashMap::new(),
        }];
        assert!(collator.validate_answer_spans(&invalid_examples, 4).is_err());
        let long_span_examples = vec![DataExample {
            input_ids: vec![101, 102, 103, 104, 105, 106],
            attention_mask: Some(vec![1, 1, 1, 1, 1, 1]),
            token_type_ids: None,
            labels: Some(vec![1, 5]), metadata: HashMap::new(),
        }];
        assert!(collator.validate_answer_spans(&long_span_examples, 6).is_err());
    }
    // for_squad overrides the model's max length with the SQuAD preset.
    #[test]
    fn test_squad_config() {
        let model_config = serde_json::json!({
            "max_position_embeddings": 512,
            "pad_token_id": 0
        });
        let config = QuestionAnsweringCollatorConfig::for_squad(&model_config)
            .expect("operation failed in test");
        assert_eq!(config.max_length, Some(384));
        assert_eq!(config.doc_stride, 128);
        assert_eq!(config.max_answer_length, 30);
    }
    // create_sequence_metadata always emits both layout keys.
    #[test]
    fn test_sequence_metadata() {
        let config = QuestionAnsweringCollatorConfig {
            max_length: Some(20),
            padding: PaddingStrategy::Longest,
            truncation: true,
            pad_token_id: 0,
            doc_stride: 128,
            max_answer_length: 10,
        };
        let collator = QuestionAnsweringDataCollator::new(config);
        let examples = vec![DataExample {
            input_ids: vec![101, 2054, 102, 1996, 3438, 102],
            attention_mask: Some(vec![1, 1, 1, 1, 1, 1]),
            token_type_ids: Some(vec![0, 0, 0, 1, 1, 1]),
            labels: Some(vec![4, 4]),
            metadata: HashMap::new(),
        }];
        let metadata = collator.create_sequence_metadata(&examples);
        assert!(metadata.contains_key("question_lengths"));
        assert!(metadata.contains_key("context_starts"));
    }
}