Skip to main content

entrenar/train/transformer_trainer/
batch.rs

//! Language model batch utilities

/// A batch of tokenized sequences for language model training.
///
/// # Memory Layout (ALB-100)
///
/// For causal LM, `target[i] = input[i+1]` — storing both wastes 50%.
/// `LMBatch` deduplicates by storing a single `tokens` buffer:
///
/// - **Shared layout** (`stride = seq_len + 1`): Used when all sequences have
///   the same length of at least 2 tokens (the production pre-tokenized path).
///   Input and target are derived as overlapping slices with a 1-token offset.
///
/// - **Split layout** (`stride = 0`): Used when sequences have different
///   lengths (padding breaks the shift invariant), or when they are too short
///   (fewer than 2 tokens) to be sliced into a non-empty input/target pair.
///   Stores `[input_ids..., target_ids...]` concatenated, matching the legacy
///   layout.
#[derive(Debug, Clone)]
pub struct LMBatch {
    /// Token storage. Layout depends on `stride`:
    /// - stride > 0 (shared): batch_size * stride tokens, input/target overlap
    /// - stride == 0 (split): batch_size * seq_len * 2 (input then target)
    tokens: Vec<u32>,
    /// Batch size
    pub batch_size: usize,
    /// Sequence length (tokens per input/target per batch item)
    pub seq_len: usize,
    /// Stride between batch items in shared layout, or 0 for split layout.
    stride: usize,
}

impl LMBatch {
    /// Create a new LM batch from token sequences.
    ///
    /// For causal LM, targets are inputs shifted by 1:
    /// ```text
    /// input:  [BOS, A, B, C, D]
    /// target: [A, B, C, D, EOS]
    /// ```
    ///
    /// `seq_len` is the longest sequence length minus one, clamped to at
    /// least 1. An empty `sequences` slice yields an empty batch
    /// (`batch_size == 0`, `seq_len == 0`).
    ///
    /// When all sequences have the same length `max_len >= 2`, uses shared
    /// layout (ALB-100): one buffer of `batch_size * (seq_len + 1)` tokens,
    /// saving ~50% memory. `pad_id` and `eos_id` are unused in that case.
    ///
    /// Otherwise falls back to split layout, where each sequence `seq` is
    /// right-padded to `seq_len`:
    /// - input holds `seq[..len - 1]`, then `pad_id`;
    /// - target holds `seq[1..]`, then `eos_id` at the first position past
    ///   it (if there is room), then `pad_id`.
    ///
    /// Empty sequences become all `pad_id` on both sides.
    pub fn from_sequences(sequences: &[Vec<u32>], pad_id: u32, eos_id: u32) -> Self {
        if sequences.is_empty() {
            return Self { tokens: Vec::new(), batch_size: 0, seq_len: 0, stride: 0 };
        }

        let max_len = sequences.iter().map(Vec::len).max().unwrap_or(0);
        let seq_len = max_len.saturating_sub(1).max(1);

        // The shared layout needs every stored row to hold exactly
        // `seq_len + 1` tokens. When `max_len < 2`, `seq_len` is clamped up
        // to 1, so the rows would be too short and slicing would run past
        // the end of the buffer. Those batches must use the split layout.
        let shared = max_len == seq_len + 1 && sequences.iter().all(|s| s.len() == max_len);

        if shared {
            Self::shared_layout(sequences, seq_len)
        } else {
            Self::split_layout(sequences, seq_len, pad_id, eos_id)
        }
    }

    /// Builds a shared-layout batch; caller guarantees every sequence has
    /// exactly `seq_len + 1` tokens.
    fn shared_layout(sequences: &[Vec<u32>], seq_len: usize) -> Self {
        let stride = seq_len + 1;
        // Raw sequences are stored back to back; input/target are derived
        // later as overlapping windows with a 1-token offset.
        let tokens: Vec<u32> = sequences.concat();
        debug_assert_eq!(tokens.len(), sequences.len() * stride);
        Self { tokens, batch_size: sequences.len(), seq_len, stride }
    }

    /// Builds a split-layout batch: all padded inputs, followed by all padded
    /// targets (the legacy layout).
    fn split_layout(sequences: &[Vec<u32>], seq_len: usize, pad_id: u32, eos_id: u32) -> Self {
        let batch_size = sequences.len();
        let mut tokens = Vec::with_capacity(batch_size * seq_len * 2);

        // First half: input_ids. Written as `i + 1 < len` rather than
        // `i < len - 1` so an empty sequence cannot underflow.
        for seq in sequences {
            tokens.extend((0..seq_len).map(|i| if i + 1 < seq.len() { seq[i] } else { pad_id }));
        }

        // Second half: target_ids, shifted by one with EOS right after the
        // last real token.
        // NOTE(review): for sequences shorter than `seq_len + 1`, the EOS
        // target sits at a position whose input is `pad_id` (the last real
        // token is never an input). This is preserved legacy behavior —
        // confirm it is intended.
        for seq in sequences {
            tokens.extend((0..seq_len).map(|i| match (i + 1).cmp(&seq.len()) {
                std::cmp::Ordering::Less => seq[i + 1],
                std::cmp::Ordering::Equal => eos_id,
                std::cmp::Ordering::Greater => pad_id,
            }));
        }

        Self { tokens, batch_size, seq_len, stride: 0 }
    }

    /// Create a batch from a single sequence pair (for testing).
    ///
    /// Uses split layout since caller provides separate input/target vecs.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `input_ids` and `target_ids` differ in
    /// length (the split layout would otherwise be silently malformed).
    pub fn single(input_ids: Vec<u32>, target_ids: Vec<u32>) -> Self {
        debug_assert_eq!(
            input_ids.len(),
            target_ids.len(),
            "LMBatch::single: input and target lengths differ"
        );
        let seq_len = input_ids.len();
        // Reuse the input allocation as the start of the token buffer.
        let mut tokens = input_ids;
        tokens.extend_from_slice(&target_ids);
        Self { tokens, batch_size: 1, seq_len, stride: 0 }
    }

    /// Get input IDs for a specific batch item.
    ///
    /// Returns `None` if `batch_idx` is out of range (or if the public
    /// `batch_size`/`seq_len` fields were altered so the slice no longer fits).
    pub fn get_input(&self, batch_idx: usize) -> Option<&[u32]> {
        if batch_idx >= self.batch_size {
            return None;
        }
        let start = if self.stride > 0 {
            // Shared layout: input is tokens[b*stride .. b*stride + seq_len]
            batch_idx * self.stride
        } else {
            // Split layout: first half is input_ids
            batch_idx * self.seq_len
        };
        self.tokens.get(start..start + self.seq_len)
    }

    /// Get target IDs for a specific batch item.
    ///
    /// Returns `None` if `batch_idx` is out of range (or if the public
    /// `batch_size`/`seq_len` fields were altered so the slice no longer fits).
    pub fn get_target(&self, batch_idx: usize) -> Option<&[u32]> {
        if batch_idx >= self.batch_size {
            return None;
        }
        let start = if self.stride > 0 {
            // Shared layout: target is tokens[b*stride + 1 .. b*stride + 1 + seq_len]
            batch_idx * self.stride + 1
        } else {
            // Split layout: second half (after batch_size * seq_len inputs) is target_ids
            (self.batch_size + batch_idx) * self.seq_len
        };
        self.tokens.get(start..start + self.seq_len)
    }

    /// Total number of tokens in batch (input side).
    pub fn num_tokens(&self) -> usize {
        self.batch_size * self.seq_len
    }

    /// Returns true if this batch uses shared (deduplicated) token storage.
    #[cfg(test)]
    pub fn is_shared_layout(&self) -> bool {
        self.stride > 0
    }

    /// Returns true if the token buffer is non-empty.
    pub fn has_tokens(&self) -> bool {
        !self.tokens.is_empty()
    }
}
153    /// Returns true if the token buffer is non-empty.
154    pub fn has_tokens(&self) -> bool {
155        !self.tokens.is_empty()
156    }
157}