use smallvec::SmallVec;
use crate::error::PackError;
use crate::sequence::Sequence;
/// A fixed-capacity bin produced by a packing algorithm.
///
/// `items` records which inputs were assigned to the bin as indices; in
/// [`bins_to_packs`] they are used to index the `sequences` slice.
#[derive(Clone, Debug)]
pub struct Bin {
    /// Identifier of this bin.
    pub id: usize,
    /// Maximum total length the bin may hold.
    pub capacity: usize,
    /// Capacity already consumed by the assigned items.
    pub used: usize,
    /// Indices of the items assigned to this bin.
    pub items: SmallVec<[usize; 8]>,
}

impl Bin {
    /// Creates an empty bin with the given identifier and capacity.
    pub fn new(id: usize, capacity: usize) -> Self {
        Self {
            id,
            capacity,
            used: 0,
            items: SmallVec::new(),
        }
    }

    /// Returns the capacity still available in the bin, saturating at zero.
    ///
    /// All fields are public, so `used` can end up larger than `capacity`.
    /// Plain subtraction would then panic in debug builds and wrap to a huge
    /// value in release builds. Saturating keeps this consistent with
    /// [`Pack::remaining_capacity`].
    pub fn remaining(&self) -> usize {
        self.capacity.saturating_sub(self.used)
    }
}
/// A fixed-capacity container of sequences laid out back to back.
///
/// The metadata accessors ([`Pack::cu_seqlens`], [`Pack::position_ids`],
/// [`Pack::segment_ids`], [`Pack::attention_mask`]) all describe the
/// concatenation of `sequences` in insertion order.
#[derive(Clone, Debug)]
pub struct Pack {
    /// Sequences in the order they occupy the pack.
    pub sequences: Vec<Sequence>,
    /// Maximum total sequence length the pack may hold.
    pub capacity: usize,
    /// Cached sum of the lengths of `sequences`. It is maintained by
    /// [`Pack::add`] and [`bins_to_packs`].
    pub(crate) used: usize,
}

impl Pack {
    /// Creates an empty pack that can hold up to `capacity` total length.
    pub fn new(capacity: usize) -> Self {
        Self {
            sequences: Vec::new(),
            capacity,
            used: 0,
        }
    }

    /// Appends `seq` to the end of the pack.
    ///
    /// # Errors
    ///
    /// Returns [`PackError::PackFull`] carrying `seq.length` if the sequence
    /// does not fit in [`Pack::remaining_capacity`]. The pack is left
    /// unchanged in that case.
    pub fn add(&mut self, seq: Sequence) -> Result<(), PackError> {
        if seq.length > self.remaining_capacity() {
            return Err(PackError::PackFull { length: seq.length });
        }
        self.used += seq.length;
        self.sequences.push(seq);
        Ok(())
    }

    /// Capacity not yet used, saturating at zero.
    pub fn remaining_capacity(&self) -> usize {
        self.capacity.saturating_sub(self.used)
    }

    /// Total length of the sequences currently in the pack.
    pub fn used_capacity(&self) -> usize {
        self.used
    }

    /// Fraction of the capacity in use (`used / capacity`).
    ///
    /// An empty pack reports `0.0` even when `capacity` is zero, where a raw
    /// division would yield `NaN` and poison any aggregate of utilisations.
    /// A non-empty zero-capacity pack (only reachable by building a `Pack`
    /// directly) still reports infinity, signalling the overfill.
    pub fn utilisation(&self) -> f64 {
        if self.used == 0 {
            return 0.0;
        }
        self.used as f64 / self.capacity as f64
    }

    /// Number of sequences in the pack.
    pub fn len(&self) -> usize {
        self.sequences.len()
    }

    /// Returns `true` if the pack contains no sequences.
    pub fn is_empty(&self) -> bool {
        self.sequences.is_empty()
    }

    /// Number of unfilled positions up to `capacity`.
    ///
    /// This is the same value as [`Pack::remaining_capacity`].
    pub fn padding_tokens(&self) -> usize {
        self.remaining_capacity()
    }

    /// Cumulative sequence boundaries `[0, l0, l0 + l1, ...]`.
    ///
    /// The result has `len() + 1` entries; an empty pack yields `[0]`.
    pub fn cu_seqlens(&self) -> Vec<usize> {
        let mut cu = Vec::with_capacity(self.sequences.len() + 1);
        cu.push(0);
        let mut offset = 0;
        for seq in &self.sequences {
            offset += seq.length;
            cu.push(offset);
        }
        cu
    }

    /// Length of the longest sequence in the pack, or `0` if it is empty.
    pub fn max_seqlen_in_pack(&self) -> usize {
        self.sequences.iter().map(|s| s.length).max().unwrap_or(0)
    }

    /// Gives each packed position its offset within its own sequence.
    ///
    /// The offset restarts at `0` at the start of every sequence.
    pub fn position_ids(&self) -> Vec<usize> {
        let mut ids = Vec::with_capacity(self.used);
        for seq in &self.sequences {
            ids.extend(0..seq.length);
        }
        ids
    }

    /// Gives each packed position the index of the sequence it belongs to.
    pub fn segment_ids(&self) -> Vec<usize> {
        let mut ids = Vec::with_capacity(self.used);
        for (seq_id, seq) in self.sequences.iter().enumerate() {
            ids.extend(std::iter::repeat_n(seq_id, seq.length));
        }
        ids
    }

    /// Row-major `n × n` boolean mask over the packed positions.
    ///
    /// Here `n` is the total length of `sequences`. Entry `(i, j)` is `true`
    /// iff positions `i` and `j` belong to the same sequence and `j <= i`.
    /// The result is a block-diagonal matrix of lower-triangular blocks.
    ///
    /// `n` is computed from `sequences` rather than the cached `used` count.
    /// `sequences` is a public field, and a stale cache would otherwise cause
    /// out-of-bounds indexing or a mask that disagrees with
    /// [`Pack::cu_seqlens`].
    ///
    /// # Panics
    ///
    /// Panics if `n * n` overflows `usize`. Memory use is quadratic in `n`.
    pub fn attention_mask(&self) -> Vec<bool> {
        let n: usize = self.sequences.iter().map(|s| s.length).sum();
        let total = n
            .checked_mul(n)
            .expect("attention mask size overflows usize");
        let mut mask = vec![false; total];
        let mut offset = 0;
        for seq in &self.sequences {
            for i in 0..seq.length {
                // Row `offset + i` may attend to columns `offset..=offset + i`.
                let row_start = (offset + i) * n + offset;
                mask[row_start..=row_start + i].fill(true);
            }
            offset += seq.length;
        }
        mask
    }
}
/// Converts packing-algorithm bins into [`Pack`]s of concrete sequences.
///
/// Each bin's `items` are used as indices into `sequences`:
///
/// - A source sequence that carries tokens is cloned in full.
/// - One without tokens is rebuilt via `Sequence::new(id, length)`.
///
/// Each pack takes its `capacity` from the bin. Its used count is recomputed
/// from the selected sequences' lengths; the bin's `id` and `used` fields are
/// not consulted. Output order matches the order of `bins`.
///
/// # Panics
///
/// Panics if any item index is out of bounds for `sequences`.
pub fn bins_to_packs(bins: Vec<Bin>, sequences: &[Sequence]) -> Vec<Pack> {
    let materialise = |bin: Bin| {
        let pack_seqs: Vec<Sequence> = bin
            .items
            .iter()
            .map(|&idx| {
                let source = &sequences[idx];
                if source.tokens.is_none() {
                    Sequence::new(source.id, source.length)
                } else {
                    source.clone()
                }
            })
            .collect();
        let used: usize = pack_seqs.iter().map(|seq| seq.length).sum();
        Pack {
            sequences: pack_seqs,
            capacity: bin.capacity,
            used,
        }
    };
    bins.into_iter().map(materialise).collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sequence::Sequence;

    /// Builds a pack of `capacity` and adds each sequence in order, panicking
    /// if one of them does not fit.
    fn pack_of(capacity: usize, seqs: impl IntoIterator<Item = Sequence>) -> Pack {
        let mut pack = Pack::new(capacity);
        for seq in seqs {
            pack.add(seq).unwrap();
        }
        pack
    }

    /// Builds a bin directly from its field values.
    fn bin_of(id: usize, capacity: usize, used: usize, items: Vec<usize>) -> Bin {
        Bin {
            id,
            capacity,
            used,
            items: items.into(),
        }
    }

    #[test]
    fn test_bin_new() {
        let bin = Bin::new(0, 100);
        assert!(bin.items.is_empty());
        assert_eq!(bin.used, 0);
        assert_eq!(bin.capacity, 100);
        assert_eq!(bin.id, 0);
    }

    #[test]
    fn test_bin_remaining() {
        let bin = bin_of(0, 100, 60, vec![0]);
        assert_eq!(bin.remaining(), 40);
    }

    #[test]
    fn test_pack_new() {
        let pack = Pack::new(100);
        assert!(pack.is_empty());
        assert_eq!(pack.len(), 0);
        assert_eq!(pack.used_capacity(), 0);
        assert_eq!(pack.capacity, 100);
    }

    #[test]
    fn test_pack_add_single() {
        let pack = pack_of(100, [Sequence::new(0, 60)]);
        assert!(!pack.is_empty());
        assert_eq!(pack.len(), 1);
        assert_eq!(pack.used_capacity(), 60);
        assert_eq!(pack.remaining_capacity(), 40);
    }

    #[test]
    fn test_pack_add_multiple() {
        let pack = pack_of(100, [Sequence::new(0, 60), Sequence::new(1, 30)]);
        assert_eq!(pack.len(), 2);
        assert_eq!(pack.used_capacity(), 90);
        assert_eq!(pack.remaining_capacity(), 10);
    }

    #[test]
    fn test_pack_add_exact_fill() {
        let pack = pack_of(100, [Sequence::new(0, 100)]);
        assert_eq!(
            (pack.used_capacity(), pack.remaining_capacity()),
            (100, 0)
        );
    }

    #[test]
    fn test_pack_full_error() {
        let mut pack = pack_of(100, [Sequence::new(0, 60)]);
        let outcome = pack.add(Sequence::new(1, 50));
        assert!(matches!(outcome, Err(PackError::PackFull { length: 50 })));
    }

    #[test]
    fn test_utilisation() {
        let pack = pack_of(100, [Sequence::new(0, 75)]);
        let delta = (pack.utilisation() - 0.75).abs();
        assert!(delta < f64::EPSILON);
    }

    #[test]
    fn test_padding_tokens() {
        assert_eq!(pack_of(100, [Sequence::new(0, 60)]).padding_tokens(), 40);
    }

    #[test]
    fn test_cu_seqlens() {
        let pack = pack_of(
            1024,
            [
                Sequence::new(0, 512),
                Sequence::new(1, 256),
                Sequence::new(2, 128),
            ],
        );
        assert_eq!(pack.cu_seqlens(), vec![0, 512, 768, 896]);
    }

    #[test]
    fn test_cu_seqlens_single_sequence() {
        let pack = pack_of(100, [Sequence::new(0, 50)]);
        assert_eq!(pack.cu_seqlens(), vec![0, 50]);
    }

    #[test]
    fn test_cu_seqlens_empty_pack() {
        assert_eq!(Pack::new(100).cu_seqlens(), vec![0]);
    }

    #[test]
    fn test_max_seqlen_in_pack() {
        let pack = pack_of(
            1024,
            [
                Sequence::new(0, 512),
                Sequence::new(1, 256),
                Sequence::new(2, 128),
            ],
        );
        assert_eq!(pack.max_seqlen_in_pack(), 512);
    }

    #[test]
    fn test_max_seqlen_empty_pack() {
        assert_eq!(Pack::new(100).max_seqlen_in_pack(), 0);
    }

    #[test]
    fn test_position_ids() {
        let pack = pack_of(100, [Sequence::new(0, 3), Sequence::new(1, 2)]);
        assert_eq!(pack.position_ids(), vec![0, 1, 2, 0, 1]);
    }

    #[test]
    fn test_position_ids_single_sequence() {
        let pack = pack_of(100, [Sequence::new(0, 4)]);
        assert_eq!(pack.position_ids(), vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_segment_ids() {
        let pack = pack_of(100, [Sequence::new(0, 3), Sequence::new(1, 2)]);
        assert_eq!(pack.segment_ids(), vec![0, 0, 0, 1, 1]);
    }

    #[test]
    fn test_segment_ids_three_sequences() {
        let pack = pack_of(
            100,
            [Sequence::new(0, 2), Sequence::new(1, 1), Sequence::new(2, 3)],
        );
        assert_eq!(pack.segment_ids(), vec![0, 0, 1, 2, 2, 2]);
    }

    #[test]
    fn test_attention_mask_block_diagonal() {
        let pack = pack_of(100, [Sequence::new(0, 2), Sequence::new(1, 2)]);
        #[rustfmt::skip]
        let expected = vec![
            true,  false, false, false,
            true,  true,  false, false,
            false, false, true,  false,
            false, false, true,  true,
        ];
        assert_eq!(pack.attention_mask(), expected);
    }

    #[test]
    fn test_attention_mask_single_sequence() {
        let pack = pack_of(100, [Sequence::new(0, 3)]);
        #[rustfmt::skip]
        let expected = vec![
            true, false, false,
            true, true,  false,
            true, true,  true,
        ];
        assert_eq!(pack.attention_mask(), expected);
    }

    #[test]
    fn test_bins_to_packs() {
        let sequences = [
            Sequence::new(0, 60),
            Sequence::new(1, 40),
            Sequence::new(2, 50),
        ];
        let bins = vec![bin_of(0, 100, 100, vec![0, 1]), bin_of(1, 100, 50, vec![2])];
        let packs = bins_to_packs(bins, &sequences);
        assert_eq!(packs.len(), 2);
        let summary: Vec<(usize, usize)> = packs
            .iter()
            .map(|p| (p.len(), p.used_capacity()))
            .collect();
        assert_eq!(summary, vec![(2, 100), (1, 50)]);
    }

    #[test]
    fn test_bins_to_packs_preserves_sequence_data() {
        let sequences = vec![
            Sequence::with_tokens(0, vec![10, 20, 30]),
            Sequence::new(1, 2),
        ];
        let packs = bins_to_packs(vec![bin_of(0, 10, 5, vec![0, 1])], &sequences);
        let packed = &packs[0].sequences;
        assert_eq!(packed[0].tokens.as_ref().unwrap(), &vec![10, 20, 30]);
        assert!(packed[1].tokens.is_none());
    }

    #[test]
    fn test_bins_to_packs_empty() {
        let packs = bins_to_packs(Vec::new(), &[]);
        assert!(packs.is_empty());
    }

    #[test]
    fn test_metadata_lengths_consistent() {
        let pack = pack_of(
            100,
            [Sequence::new(0, 3), Sequence::new(1, 2), Sequence::new(2, 4)],
        );
        let total = pack.used_capacity();
        assert_eq!(pack.cu_seqlens().last().copied(), Some(total));
        assert_eq!(pack.position_ids().len(), total);
        assert_eq!(pack.segment_ids().len(), total);
        assert_eq!(pack.attention_mask().len(), total * total);
    }
}