//! Batch encoding: padding, truncation, and attention masks for token-id batches.
use crate::tokenizer::TransformerTokenizer;
/// How sequences in a batch are padded to a common length.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PaddingStrategy {
    /// Leave each sequence at its own length.
    NoPadding,
    /// Pad to `max_length` (or to the longest sequence when it is unset).
    MaxLength,
    /// Pad to the longest sequence in the batch, capped by `max_length` if set.
    LongestInBatch,
}
/// Which end of an over-long sequence is cut off.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TruncationStrategy {
    /// Never truncate, even when `max_length` is exceeded.
    NoTruncation,
    /// Keep the first `max_length` tokens.
    Right,
    /// Keep the last `max_length` tokens.
    Left,
}
/// Which side of a sequence receives the padding tokens.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PaddingSide {
    /// Append padding after the real tokens.
    Right,
    /// Prepend padding before the real tokens.
    Left,
}
/// Options controlling how a batch is padded and truncated.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Upper bound on sequence length; `None` means no limit.
    pub max_length: Option<usize>,
    /// Padding strategy applied after truncation.
    pub padding: PaddingStrategy,
    /// Truncation strategy applied when `max_length` is set.
    pub truncation: TruncationStrategy,
    /// Token id written into padded positions.
    pub pad_token_id: u32,
}
impl Default for BatchConfig {
fn default() -> Self {
Self {
max_length: None,
padding: PaddingStrategy::LongestInBatch,
truncation: TruncationStrategy::Right,
pad_token_id: 0,
}
}
}
/// [`BatchConfig`] extended with a configurable padding side.
#[derive(Debug, Clone)]
pub struct BatchConfigExt {
    /// The underlying padding and truncation options.
    pub base: BatchConfig,
    /// Whether padding is added on the right (default) or the left of each sequence.
    pub padding_side: PaddingSide,
}
impl Default for BatchConfigExt {
fn default() -> Self {
Self {
base: BatchConfig::default(),
padding_side: PaddingSide::Right,
}
}
}
/// The result of encoding a batch: token ids, attention masks, and lengths.
#[derive(Debug, Clone)]
pub struct BatchEncoding {
    /// Token ids, one row per input text, padded/truncated per the config.
    pub input_ids: Vec<Vec<u32>>,
    /// 1 for real tokens, 0 for padding, aligned with `input_ids`.
    pub attention_mask: Vec<Vec<u32>>,
    /// Encoded length of each text, recorded before truncation or padding.
    pub lengths: Vec<usize>,
}
impl BatchEncoding {
    /// Number of sequences in the batch.
    pub fn batch_size(&self) -> usize {
        self.input_ids.len()
    }

    /// Length of the first row (the common padded length when padding is enabled).
    pub fn seq_length(&self) -> usize {
        self.input_ids.first().map_or(0, |v| v.len())
    }

    /// Counts attention-mask entries equal to 1, i.e. non-padding tokens.
    pub fn total_real_tokens(&self) -> usize {
        self.attention_mask
            .iter()
            .flat_map(|mask| mask.iter())
            .filter(|&&v| v == 1)
            .count()
    }
}
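/// Truncates `ids` to at most `max_length` tokens, dropping from the right or
/// the left depending on `strategy`; `NoTruncation` returns the ids unchanged.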
fn truncate(ids: &[u32], strategy: TruncationStrategy, max_length: usize) -> Vec<u32> {
if ids.len() <= max_length {
return ids.to_vec();
}
match strategy {
TruncationStrategy::NoTruncation => ids.to_vec(),
TruncationStrategy::Right => ids[..max_length].to_vec(),
TruncationStrategy::Left => ids[ids.len() - max_length..].to_vec(),
}
}
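/// Right-pads `ids` with `pad_id` up to `target_length` and builds the matching
/// attention mask (1 for real tokens, 0 for padding). Inputs longer than
/// `target_length` keep only their first `target_length` ids.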
fn pad_right(ids: &[u32], target_length: usize, pad_id: u32) -> (Vec<u32>, Vec<u32>) {
let real_len = ids.len();
if real_len >= target_length {
let truncated = &ids[..target_length];
let mask = vec![1u32; target_length];
return (truncated.to_vec(), mask);
}
let mut padded = ids.to_vec();
let mut mask = vec![1u32; real_len];
let pad_count = target_length - real_len;
padded.extend(std::iter::repeat_n(pad_id, pad_count));
mask.extend(std::iter::repeat_n(0u32, pad_count));
(padded, mask)
}
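/// Left-pads `ids` with `pad_id` up to `target_length`; the attention mask is 0
/// over the leading padding and 1 over the real tokens. Inputs longer than
/// `target_length` keep only their last `target_length` ids.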
fn pad_left(ids: &[u32], target_length: usize, pad_id: u32) -> (Vec<u32>, Vec<u32>) {
let real_len = ids.len();
if real_len >= target_length {
let start = real_len - target_length;
let truncated = &ids[start..];
let mask = vec![1u32; target_length];
return (truncated.to_vec(), mask);
}
let pad_count = target_length - real_len;
let mut padded: Vec<u32> = std::iter::repeat_n(pad_id, pad_count).collect();
let mut mask: Vec<u32> = std::iter::repeat_n(0u32, pad_count).collect();
padded.extend_from_slice(ids);
mask.extend(std::iter::repeat_n(1u32, real_len));
(padded, mask)
}
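/// Encodes each text with `tokenizer`, truncates to `config.max_length` when a
/// truncation strategy is set, then right-pads according to `config.padding`.
///
/// A minimal usage sketch, assuming a `BPETokenizer` trained as in the tests
/// below (marked `ignore` because the trained vocabulary is not fixed):
///
/// ```ignore
/// let tok = BPETokenizer::train(&["the cat sat on the mat"], 100).unwrap();
/// let batch = batch_encode(&["the cat", "the"], &tok, &BatchConfig::default());
/// assert_eq!(batch.batch_size(), 2);
/// assert_eq!(batch.input_ids[0].len(), batch.input_ids[1].len());
/// ```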
pub fn batch_encode<T: TransformerTokenizer>(
texts: &[&str],
tokenizer: &T,
config: &BatchConfig,
) -> BatchEncoding {
if texts.is_empty() {
return BatchEncoding {
input_ids: Vec::new(),
attention_mask: Vec::new(),
lengths: Vec::new(),
};
}
let mut encoded: Vec<Vec<u32>> = texts.iter().map(|t| tokenizer.encode(t)).collect();
let original_lengths: Vec<usize> = encoded.iter().map(|v| v.len()).collect();
if let Some(max_len) = config.max_length {
if config.truncation != TruncationStrategy::NoTruncation {
for seq in &mut encoded {
*seq = truncate(seq, config.truncation, max_len);
}
}
}
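    // Work out the common target length; NoPadding returns the un-padded batch early.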
let target_length = match config.padding {
PaddingStrategy::NoPadding => {
let attention_mask: Vec<Vec<u32>> =
encoded.iter().map(|seq| vec![1u32; seq.len()]).collect();
return BatchEncoding {
input_ids: encoded,
attention_mask,
lengths: original_lengths,
};
}
PaddingStrategy::MaxLength => config
.max_length
.unwrap_or_else(|| encoded.iter().map(|s| s.len()).max().unwrap_or(0)),
PaddingStrategy::LongestInBatch => {
let longest = encoded.iter().map(|s| s.len()).max().unwrap_or(0);
match config.max_length {
Some(ml) => longest.min(ml),
None => longest,
}
}
};
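    // Right-pad every sequence and collect the matching attention masks.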
let mut input_ids = Vec::with_capacity(encoded.len());
let mut attention_mask = Vec::with_capacity(encoded.len());
for seq in &encoded {
let (padded, mask) = pad_right(seq, target_length, config.pad_token_id);
input_ids.push(padded);
attention_mask.push(mask);
}
BatchEncoding {
input_ids,
attention_mask,
lengths: original_lengths,
}
}
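/// Like [`batch_encode`], but also honours `config.padding_side`, so padding
/// can be placed on the left of each sequence instead of the right.
///
/// A minimal left-padding sketch, assuming the same tokenizer setup as in the
/// tests below:
///
/// ```ignore
/// let config = BatchConfigExt {
///     base: BatchConfig::default(),
///     padding_side: PaddingSide::Left,
/// };
/// let batch = batch_encode_ext(&["the", "the cat sat"], &tok, &config);
/// // The shorter row starts with pad ids and zero attention-mask entries.
/// ```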
pub fn batch_encode_ext<T: TransformerTokenizer>(
texts: &[&str],
tokenizer: &T,
config: &BatchConfigExt,
) -> BatchEncoding {
if texts.is_empty() {
return BatchEncoding {
input_ids: Vec::new(),
attention_mask: Vec::new(),
lengths: Vec::new(),
};
}
let mut encoded: Vec<Vec<u32>> = texts.iter().map(|t| tokenizer.encode(t)).collect();
let original_lengths: Vec<usize> = encoded.iter().map(|v| v.len()).collect();
if let Some(max_len) = config.base.max_length {
if config.base.truncation != TruncationStrategy::NoTruncation {
for seq in &mut encoded {
*seq = truncate(seq, config.base.truncation, max_len);
}
}
}
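    // Work out the common target length; NoPadding returns the un-padded batch early.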
let target_length = match config.base.padding {
PaddingStrategy::NoPadding => {
let attention_mask: Vec<Vec<u32>> =
encoded.iter().map(|seq| vec![1u32; seq.len()]).collect();
return BatchEncoding {
input_ids: encoded,
attention_mask,
lengths: original_lengths,
};
}
PaddingStrategy::MaxLength => config
.base
.max_length
.unwrap_or_else(|| encoded.iter().map(|s| s.len()).max().unwrap_or(0)),
PaddingStrategy::LongestInBatch => {
let longest = encoded.iter().map(|s| s.len()).max().unwrap_or(0);
match config.base.max_length {
Some(ml) => longest.min(ml),
None => longest,
}
}
};
let pad_fn = match config.padding_side {
PaddingSide::Right => pad_right,
PaddingSide::Left => pad_left,
};
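    // Pad each sequence on the configured side and collect the matching masks.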
let mut input_ids = Vec::with_capacity(encoded.len());
let mut attention_mask = Vec::with_capacity(encoded.len());
for seq in &encoded {
let (padded, mask) = pad_fn(seq, target_length, config.base.pad_token_id);
input_ids.push(padded);
attention_mask.push(mask);
}
BatchEncoding {
input_ids,
attention_mask,
lengths: original_lengths,
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tokenizer::BPETokenizer;
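    // Shared helper: trains a small BPE tokenizer on a fixed corpus so every
    // test exercises the same vocabulary.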
fn train_tokenizer() -> BPETokenizer {
let corpus = &[
"the cat sat on the mat",
"the dog sat on the log",
"cats and dogs",
"the quick brown fox",
];
BPETokenizer::train(corpus, 100).expect("training should succeed")
}
#[test]
fn test_batch_encode_basic() {
let tok = train_tokenizer();
let texts = &["the cat", "the dog sat"];
let config = BatchConfig {
padding: PaddingStrategy::LongestInBatch,
..Default::default()
};
let batch = batch_encode(texts, &tok, &config);
assert_eq!(batch.batch_size(), 2);
assert_eq!(batch.input_ids[0].len(), batch.input_ids[1].len());
assert_eq!(batch.attention_mask[0].len(), batch.attention_mask[1].len());
}
#[test]
fn test_padding_adds_correct_tokens() {
let tok = train_tokenizer();
let texts = &["the", "the cat sat on the mat"];
let config = BatchConfig {
padding: PaddingStrategy::LongestInBatch,
pad_token_id: 0,
..Default::default()
};
let batch = batch_encode(texts, &tok, &config);
let shorter_len = batch.lengths[0];
let padded_len = batch.input_ids[0].len();
if shorter_len < padded_len {
for i in shorter_len..padded_len {
assert_eq!(
batch.input_ids[0][i], 0,
"padding token should be 0 at position {i}"
);
}
}
}
#[test]
fn test_attention_mask_correct() {
let tok = train_tokenizer();
let texts = &["the", "the cat sat"];
let config = BatchConfig {
padding: PaddingStrategy::LongestInBatch,
pad_token_id: 0,
..Default::default()
};
let batch = batch_encode(texts, &tok, &config);
let shorter_len = batch.lengths[0];
for i in 0..shorter_len.min(batch.attention_mask[0].len()) {
assert_eq!(
batch.attention_mask[0][i], 1,
"real token at {i} should have mask 1"
);
}
for i in shorter_len..batch.attention_mask[0].len() {
assert_eq!(
batch.attention_mask[0][i], 0,
"padding at {i} should have mask 0"
);
}
}
#[test]
fn test_truncation_right() {
let tok = train_tokenizer();
let texts = &["the cat sat on the mat"];
let config = BatchConfig {
max_length: Some(3),
padding: PaddingStrategy::NoPadding,
truncation: TruncationStrategy::Right,
pad_token_id: 0,
};
let batch = batch_encode(texts, &tok, &config);
assert!(
batch.input_ids[0].len() <= 3,
"truncated length should be <= 3, got {}",
batch.input_ids[0].len()
);
}
#[test]
fn test_truncation_left() {
let tok = train_tokenizer();
let texts = &["the cat sat on the mat"];
let config = BatchConfig {
max_length: Some(3),
padding: PaddingStrategy::NoPadding,
truncation: TruncationStrategy::Left,
pad_token_id: 0,
};
let batch = batch_encode(texts, &tok, &config);
assert!(
batch.input_ids[0].len() <= 3,
"truncated length should be <= 3"
);
}
#[test]
fn test_no_padding_varying_lengths() {
let tok = train_tokenizer();
let texts = &["the", "the cat sat"];
let config = BatchConfig {
padding: PaddingStrategy::NoPadding,
truncation: TruncationStrategy::NoTruncation,
..Default::default()
};
let batch = batch_encode(texts, &tok, &config);
assert_eq!(batch.input_ids[0].len(), batch.lengths[0]);
assert_eq!(batch.input_ids[1].len(), batch.lengths[1]);
}
#[test]
fn test_max_length_padding() {
let tok = train_tokenizer();
let texts = &["the"];
let config = BatchConfig {
max_length: Some(10),
padding: PaddingStrategy::MaxLength,
truncation: TruncationStrategy::Right,
pad_token_id: 0,
};
let batch = batch_encode(texts, &tok, &config);
assert_eq!(
batch.input_ids[0].len(),
10,
"should be padded to max_length"
);
}
#[test]
fn test_empty_input() {
let tok = train_tokenizer();
let texts: &[&str] = &[];
let config = BatchConfig::default();
let batch = batch_encode(texts, &tok, &config);
assert_eq!(batch.batch_size(), 0);
}
#[test]
fn test_empty_string_in_batch() {
let tok = train_tokenizer();
let texts = &["", "the cat"];
let config = BatchConfig {
padding: PaddingStrategy::LongestInBatch,
pad_token_id: 0,
..Default::default()
};
let batch = batch_encode(texts, &tok, &config);
assert_eq!(batch.batch_size(), 2);
assert_eq!(batch.lengths[0], 0);
}
#[test]
fn test_left_padding() {
let tok = train_tokenizer();
let texts = &["the", "the cat sat"];
let config = BatchConfigExt {
base: BatchConfig {
padding: PaddingStrategy::LongestInBatch,
pad_token_id: 0,
..Default::default()
},
padding_side: PaddingSide::Left,
};
let batch = batch_encode_ext(texts, &tok, &config);
let shorter_len = batch.lengths[0];
let total_len = batch.input_ids[0].len();
if shorter_len < total_len {
let pad_count = total_len - shorter_len;
for i in 0..pad_count {
assert_eq!(
batch.attention_mask[0][i], 0,
"left padding mask at {i} should be 0"
);
assert_eq!(
batch.input_ids[0][i], 0,
"left padding token at {i} should be pad_id"
);
}
for i in pad_count..total_len {
assert_eq!(
batch.attention_mask[0][i], 1,
"real token mask at {i} should be 1"
);
}
}
}
#[test]
fn test_total_real_tokens() {
let tok = train_tokenizer();
let texts = &["the cat", "the"];
let config = BatchConfig {
padding: PaddingStrategy::LongestInBatch,
pad_token_id: 0,
..Default::default()
};
let batch = batch_encode(texts, &tok, &config);
let total = batch.total_real_tokens();
let expected: usize = batch.lengths.iter().sum();
assert_eq!(
total, expected,
"total real tokens should equal sum of lengths"
);
}
#[test]
fn test_truncation_with_padding() {
let tok = train_tokenizer();
let texts = &["the cat sat on the mat", "the"];
let config = BatchConfig {
max_length: Some(4),
padding: PaddingStrategy::MaxLength,
truncation: TruncationStrategy::Right,
pad_token_id: 0,
};
let batch = batch_encode(texts, &tok, &config);
for seq in &batch.input_ids {
assert_eq!(seq.len(), 4, "all sequences should be padded to max_length");
}
}
}