//! Sequence padding and padding-mask helpers (`rnn/batching/batching.rs`).

/// Errors returned by the batching helpers in this module.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BatchError {
    /// The output buffer is too small for the requested shape, or the
    /// requested size overflowed `usize`.
    ShapeMismatch,
    /// The input was empty (no sequences / no tokens) or a zero length was
    /// requested.
    Empty,
}

impl std::fmt::Display for BatchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            BatchError::ShapeMismatch => "shape mismatch",
            BatchError::Empty => "empty input",
        };
        f.write_str(msg)
    }
}

// Lets `BatchError` flow through `?` into `Box<dyn Error>` and similar sinks.
impl std::error::Error for BatchError {}

7pub fn pad_sequences_u32(
8    sequences: &[&[u32]],
9    pad_id: u32,
10    out: &mut [u32],
11    max_len: usize,
12) -> Result<(), BatchError> {
13    if sequences.is_empty() || max_len == 0 {
14        return Err(BatchError::Empty);
15    }
16    let needed = sequences.len().checked_mul(max_len).ok_or(BatchError::ShapeMismatch)?;
17    if out.len() < needed {
18        return Err(BatchError::ShapeMismatch);
19    }
20
21    for b in 0..sequences.len() {
22        let row = &mut out[b * max_len..(b + 1) * max_len];
23        for t in 0..max_len {
24            row[t] = if t < sequences[b].len() { sequences[b][t] } else { pad_id };
25        }
26    }
27    Ok(())
28}
29
30pub fn make_padding_mask(
31    tokens: &[u32],
32    pad_id: u32,
33    out_mask: &mut [u8],
34) -> Result<(), BatchError> {
35    if tokens.is_empty() {
36        return Err(BatchError::Empty);
37    }
38    if out_mask.len() < tokens.len() {
39        return Err(BatchError::ShapeMismatch);
40    }
41
42    for i in 0..tokens.len() {
43        out_mask[i] = if tokens[i] == pad_id { 0 } else { 1 };
44    }
45    Ok(())
46}