use super::{DataCollator, DataCollatorConfig};
use crate::auto::data_collators::language_modeling::{
    LanguageModelingCollatorConfig, LanguageModelingDataCollator,
};
use crate::auto::types::{CollatedBatch, DataExample, PaddingStrategy};
use crate::error::{Result, TrustformersError};
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
/// Collator that assembles batches for sequence classification.
///
/// Sequence handling is delegated to a `LanguageModelingDataCollator` built
/// from this collator's settings; the resulting batch then has its labels
/// replaced by one class label per example and gains task metadata
/// (`num_labels`, `task_type`, and `multilabel_targets` when applicable).
#[derive(Debug, Clone)]
pub struct ClassificationDataCollator {
/// Settings forwarded to the sequence collator, plus the number of classes.
config: ClassificationCollatorConfig,
}
impl ClassificationDataCollator {
pub fn new(config: ClassificationCollatorConfig) -> Self {
Self { config }
}
fn process_classification_labels(&self, examples: &[DataExample]) -> Vec<Vec<i64>> {
let mut processed_labels = Vec::with_capacity(examples.len());
for example in examples {
if let Some(ref example_labels) = example.labels {
if example_labels.is_empty() {
processed_labels.push(vec![-100]);
} else if example_labels.len() == 1 {
processed_labels.push(vec![example_labels[0]]);
} else {
processed_labels.push(vec![example_labels[0]]);
}
} else {
processed_labels.push(vec![-100]);
}
}
processed_labels
}
fn create_multilabel_encoding(&self, examples: &[DataExample]) -> Vec<Vec<f32>> {
let mut multilabel_vectors = Vec::with_capacity(examples.len());
for example in examples {
let mut label_vector = vec![0.0f32; self.config.num_labels];
if let Some(ref example_labels) = example.labels {
for &label in example_labels {
if label >= 0 && (label as usize) < self.config.num_labels {
label_vector[label as usize] = 1.0;
}
}
}
multilabel_vectors.push(label_vector);
}
multilabel_vectors
}
fn validate_labels(&self, examples: &[DataExample]) -> Result<()> {
for (i, example) in examples.iter().enumerate() {
if let Some(ref labels) = example.labels {
for &label in labels {
if label >= 0 && (label as usize) >= self.config.num_labels {
return Err(TrustformersError::invalid_input_simple(format!(
"Label {} in example {} exceeds num_labels {}",
label, i, self.config.num_labels
)));
}
}
}
}
Ok(())
}
}
impl DataCollator for ClassificationDataCollator {
    /// Collates `examples` into a classification batch.
    ///
    /// Steps, in order: reject an empty slice, validate the labels, let a
    /// `LanguageModelingDataCollator` (configured with a zero mask token id
    /// and zero MLM probability) build the batch, overwrite the batch labels
    /// with one class label per example, then record `num_labels` and
    /// `task_type` in the batch metadata. When any example carries more than
    /// one label, the multi-hot targets are stored under
    /// `multilabel_targets` and the task type is `multilabel`.
    ///
    /// # Errors
    ///
    /// Fails on an empty slice, on a label that fails validation, on any
    /// error from the sequence collator, or if the multi-hot targets cannot
    /// be converted to JSON.
    fn collate(&self, examples: &[DataExample]) -> Result<CollatedBatch> {
        if examples.is_empty() {
            return Err(TrustformersError::invalid_input_simple(
                "Cannot collate empty batch for classification".to_string(),
            ));
        }
        self.validate_labels(examples)?;

        let lm_config = LanguageModelingCollatorConfig {
            max_length: self.config.max_length,
            padding: self.config.padding,
            truncation: self.config.truncation,
            pad_token_id: self.config.pad_token_id,
            mask_token_id: 0,
            mlm_probability: 0.0,
        };
        let inner = LanguageModelingDataCollator::new(lm_config);
        let mut batch = inner.collate(examples)?;

        // Whatever labels the sequence collator produced are replaced by
        // one class label per example.
        batch.labels = Some(self.process_classification_labels(examples));
        batch.metadata.insert(
            "num_labels".to_string(),
            serde_json::Value::Number(self.config.num_labels.into()),
        );

        let has_multiple_labels = examples
            .iter()
            .any(|ex| matches!(&ex.labels, Some(labels) if labels.len() > 1));

        let task_type = if has_multiple_labels {
            let targets = serde_json::to_value(self.create_multilabel_encoding(examples))
                .map_err(|e| TrustformersError::runtime_error(e.to_string()))?;
            batch
                .metadata
                .insert("multilabel_targets".to_string(), targets);
            "multilabel"
        } else {
            "single_label"
        };
        batch.metadata.insert(
            "task_type".to_string(),
            serde_json::Value::String(task_type.to_string()),
        );

        Ok(batch)
    }

    /// Exposes the collator's configuration through the generic config trait.
    fn config(&self) -> &dyn DataCollatorConfig {
        &self.config
    }

    /// Returns the examples unchanged, as an owned copy.
    fn preprocess_examples(&self, examples: &[DataExample]) -> Result<Vec<DataExample>> {
        let copied: Vec<DataExample> = examples.iter().cloned().collect();
        Ok(copied)
    }
}
/// Settings for `ClassificationDataCollator`.
///
/// The first four fields are passed through unchanged to the underlying
/// sequence collator; `num_labels` is specific to classification.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassificationCollatorConfig {
/// Optional maximum sequence length handed to the sequence collator.
pub max_length: Option<usize>,
/// Padding strategy handed to the sequence collator.
pub padding: PaddingStrategy,
/// Truncation flag handed to the sequence collator.
pub truncation: bool,
/// Pad token id handed to the sequence collator.
pub pad_token_id: u32,
/// Number of classes: non-negative labels must be below this value, and it
/// is the width of the multi-hot target vectors.
pub num_labels: usize,
}
impl ClassificationCollatorConfig {
pub fn from_config(config: &serde_json::Value) -> Result<Self> {
Ok(Self {
max_length: config
.get("max_position_embeddings")
.or_else(|| config.get("max_length"))
.and_then(|v| v.as_u64())
.map(|v| v as usize),
padding: PaddingStrategy::Longest,
truncation: true,
pad_token_id: config.get("pad_token_id").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
num_labels: config.get("num_labels").and_then(|v| v.as_u64()).unwrap_or(2) as usize,
})
}
pub fn for_binary_classification(config: &serde_json::Value) -> Result<Self> {
let mut binary_config = Self::from_config(config)?;
binary_config.num_labels = 2;
Ok(binary_config)
}
pub fn for_multilabel_classification(
config: &serde_json::Value,
num_labels: usize,
) -> Result<Self> {
let mut multilabel_config = Self::from_config(config)?;
multilabel_config.num_labels = num_labels;
Ok(multilabel_config)
}
pub fn for_sentiment_analysis(
config: &serde_json::Value,
include_neutral: bool,
) -> Result<Self> {
let mut sentiment_config = Self::from_config(config)?;
sentiment_config.num_labels = if include_neutral { 3 } else { 2 };
Ok(sentiment_config)
}
}
// Exposes the shared sequence-shaping settings through the generic config
// trait; each accessor returns the corresponding field as-is.
impl DataCollatorConfig for ClassificationCollatorConfig {
/// Returns the configured maximum sequence length, if any.
fn max_length(&self) -> Option<usize> {
self.max_length
}
/// Returns the configured padding strategy.
fn padding(&self) -> PaddingStrategy {
self.padding
}
/// Returns whether truncation is enabled.
fn truncation(&self) -> bool {
self.truncation
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds the config shape shared by the tests: longest padding,
    /// truncation on, pad token id 0.
    fn make_config(max_length: usize, num_labels: usize) -> ClassificationCollatorConfig {
        ClassificationCollatorConfig {
            max_length: Some(max_length),
            padding: PaddingStrategy::Longest,
            truncation: true,
            pad_token_id: 0,
            num_labels,
        }
    }

    /// The collator keeps the config it was constructed with.
    #[test]
    fn test_classification_collator_creation() {
        let collator = ClassificationDataCollator::new(make_config(128, 3));

        assert_eq!(collator.config().max_length(), Some(128));
        assert_eq!(collator.config.num_labels, 3);
    }

    /// `from_config` picks up the recognised keys and ignores unrelated ones.
    #[test]
    fn test_classification_config_from_json() {
        let model_json = serde_json::json!({
            "max_position_embeddings": 512,
            "pad_token_id": 1,
            "num_labels": 5,
            "vocab_size": 30522
        });

        let parsed = ClassificationCollatorConfig::from_config(&model_json)
            .expect("operation failed in test");

        assert_eq!(parsed.max_length, Some(512));
        assert_eq!(parsed.pad_token_id, 1);
        assert_eq!(parsed.num_labels, 5);
    }

    /// Two single-label examples collate into a batch of two with one label each.
    #[test]
    fn test_collate_classification_examples() {
        let collator = ClassificationDataCollator::new(make_config(10, 2));

        let first = DataExample {
            input_ids: vec![101, 2023, 2003, 102],
            attention_mask: Some(vec![1, 1, 1, 1]),
            token_type_ids: None,
            labels: Some(vec![1]),
            metadata: HashMap::new(),
        };
        let second = DataExample {
            input_ids: vec![101, 2025, 102],
            attention_mask: Some(vec![1, 1, 1]),
            token_type_ids: None,
            labels: Some(vec![0]),
            metadata: HashMap::new(),
        };

        let batch = collator
            .collate(&[first, second])
            .expect("operation failed in test");

        assert_eq!(batch.batch_size, 2);
        assert_eq!(batch.input_ids.len(), 2);

        let labels = batch.labels.as_ref().expect("operation failed in test");
        assert_eq!(labels, &vec![vec![1i64], vec![0]]);
    }

    /// Labels 0, 2 and 4 out of five classes become a multi-hot vector.
    #[test]
    fn test_multilabel_encoding() {
        let collator = ClassificationDataCollator::new(make_config(10, 5));

        let example = DataExample {
            input_ids: vec![101, 1037, 3231, 102],
            attention_mask: Some(vec![1, 1, 1, 1]),
            token_type_ids: None,
            labels: Some(vec![0, 2, 4]),
            metadata: HashMap::new(),
        };

        let encoded = collator.create_multilabel_encoding(&[example]);

        assert_eq!(encoded, vec![vec![1.0f32, 0.0, 1.0, 0.0, 1.0]]);
    }

    /// The sentiment preset yields three classes with neutral, two without.
    #[test]
    fn test_sentiment_analysis_config() {
        let model_config = serde_json::json!({
            "max_position_embeddings": 128,
            "pad_token_id": 0
        });

        let with_neutral =
            ClassificationCollatorConfig::for_sentiment_analysis(&model_config, true)
                .expect("operation failed in test");
        assert_eq!(with_neutral.num_labels, 3);

        let without_neutral =
            ClassificationCollatorConfig::for_sentiment_analysis(&model_config, false)
                .expect("operation failed in test");
        assert_eq!(without_neutral.num_labels, 2);
    }

    /// With two classes, label 1 is accepted and label 2 is rejected.
    #[test]
    fn test_label_validation() {
        let collator = ClassificationDataCollator::new(make_config(10, 2));

        let in_range = [DataExample {
            input_ids: vec![101, 102],
            attention_mask: Some(vec![1, 1]),
            token_type_ids: None,
            labels: Some(vec![1]),
            metadata: HashMap::new(),
        }];
        assert!(collator.validate_labels(&in_range).is_ok());

        let out_of_range = [DataExample {
            input_ids: vec![101, 102],
            attention_mask: Some(vec![1, 1]),
            token_type_ids: None,
            labels: Some(vec![2]),
            metadata: HashMap::new(),
        }];
        assert!(collator.validate_labels(&out_of_range).is_err());
    }
}