/// Errors returned by the batch padding and masking helpers.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BatchError {
    /// An output buffer is too small for the requested shape, or the shape's
    /// total element count overflows `usize`.
    ShapeMismatch,
    /// The input is empty, or a zero-length row was requested.
    Empty,
}

impl std::fmt::Display for BatchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            BatchError::ShapeMismatch => "output buffer shape does not match the input",
            BatchError::Empty => "input batch is empty",
        };
        f.write_str(msg)
    }
}

// Lets `BatchError` flow through `?` into `Box<dyn Error>` and similar
// error-reporting machinery; it wraps no underlying cause.
impl std::error::Error for BatchError {}
6
7pub fn pad_sequences_u32(
8 sequences: &[&[u32]],
9 pad_id: u32,
10 out: &mut [u32],
11 max_len: usize,
12) -> Result<(), BatchError> {
13 if sequences.is_empty() || max_len == 0 {
14 return Err(BatchError::Empty);
15 }
16 let needed = sequences.len().checked_mul(max_len).ok_or(BatchError::ShapeMismatch)?;
17 if out.len() < needed {
18 return Err(BatchError::ShapeMismatch);
19 }
20
21 for b in 0..sequences.len() {
22 let row = &mut out[b * max_len..(b + 1) * max_len];
23 for t in 0..max_len {
24 row[t] = if t < sequences[b].len() { sequences[b][t] } else { pad_id };
25 }
26 }
27 Ok(())
28}
29
30pub fn make_padding_mask(
31 tokens: &[u32],
32 pad_id: u32,
33 out_mask: &mut [u8],
34) -> Result<(), BatchError> {
35 if tokens.is_empty() {
36 return Err(BatchError::Empty);
37 }
38 if out_mask.len() < tokens.len() {
39 return Err(BatchError::ShapeMismatch);
40 }
41
42 for i in 0..tokens.len() {
43 out_mask[i] = if tokens[i] == pad_id { 0 } else { 1 };
44 }
45 Ok(())
46}