kizzasi_core/
sequences.rs

1//! Variable-length Sequence Handling
2//!
3//! Provides utilities for efficiently handling sequences of variable length:
4//! - Padding and masking
5//! - Packed sequence representation
6//! - Efficient batch processing
7//! - Length-aware operations
8
9use crate::error::{CoreError, CoreResult};
10use scirs2_core::ndarray::{Array1, Array2, Array3};
11
/// Padding strategy for variable-length sequences.
///
/// Selects where padding is inserted when sequences of different lengths are
/// combined into one rectangular batch. [`PaddingStrategy::Right`] is the
/// value produced by [`Default::default`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum PaddingStrategy {
    /// Pad sequences to the right (default)
    #[default]
    Right,
    /// Pad sequences to the left
    Left,
    /// No padding (all sequences must have same length)
    None,
}
22
23/// Sequence mask for variable-length batches
24///
25/// Tracks which positions in a batch are valid vs. padding
26#[derive(Debug, Clone)]
27pub struct SequenceMask {
28    /// Boolean mask: true = valid position, false = padding
29    /// Shape: (batch_size, seq_len)
30    mask: Array2<bool>,
31    /// Actual lengths of each sequence in the batch
32    lengths: Array1<usize>,
33    /// Maximum sequence length in the batch
34    max_len: usize,
35}
36
37impl SequenceMask {
38    /// Create a new sequence mask from lengths
39    pub fn from_lengths(lengths: &[usize]) -> CoreResult<Self> {
40        if lengths.is_empty() {
41            return Err(CoreError::InvalidConfig(
42                "Cannot create mask from empty lengths".to_string(),
43            ));
44        }
45
46        let batch_size = lengths.len();
47        let max_len = *lengths.iter().max().unwrap();
48
49        if max_len == 0 {
50            return Err(CoreError::InvalidConfig(
51                "Max length must be greater than 0".to_string(),
52            ));
53        }
54
55        // Create mask array
56        let mut mask = Array2::from_elem((batch_size, max_len), false);
57
58        for (i, &length) in lengths.iter().enumerate() {
59            if length > max_len {
60                return Err(CoreError::InvalidConfig(format!(
61                    "Length {} exceeds max_len {}",
62                    length, max_len
63                )));
64            }
65            for j in 0..length {
66                mask[[i, j]] = true;
67            }
68        }
69
70        let lengths_array = Array1::from_vec(lengths.to_vec());
71
72        Ok(Self {
73            mask,
74            lengths: lengths_array,
75            max_len,
76        })
77    }
78
79    /// Get the boolean mask
80    pub fn mask(&self) -> &Array2<bool> {
81        &self.mask
82    }
83
84    /// Get sequence lengths
85    pub fn lengths(&self) -> &Array1<usize> {
86        &self.lengths
87    }
88
89    /// Get maximum length
90    pub fn max_len(&self) -> usize {
91        self.max_len
92    }
93
94    /// Get batch size
95    pub fn batch_size(&self) -> usize {
96        self.lengths.len()
97    }
98
99    /// Check if a position is valid (not padding)
100    pub fn is_valid(&self, batch_idx: usize, seq_idx: usize) -> bool {
101        if batch_idx >= self.batch_size() || seq_idx >= self.max_len {
102            return false;
103        }
104        self.mask[[batch_idx, seq_idx]]
105    }
106
107    /// Count total number of valid (non-padding) positions
108    pub fn count_valid(&self) -> usize {
109        self.mask.iter().filter(|&&x| x).count()
110    }
111}
112
113/// Packed sequence representation for efficient processing
114///
115/// Stores only the valid (non-padded) elements in a contiguous array
116#[derive(Debug, Clone)]
117pub struct PackedSequence {
118    /// Packed data (only valid elements)
119    /// Shape: (total_valid_elements, feature_dim)
120    data: Array2<f32>,
121    /// Batch indices for each element
122    batch_indices: Array1<usize>,
123    /// Sorted lengths (for efficient unpacking)
124    sorted_lengths: Array1<usize>,
125    /// Original batch size
126    batch_size: usize,
127    /// Feature dimension
128    feature_dim: usize,
129}
130
131impl PackedSequence {
132    /// Pack a batch of variable-length sequences
133    ///
134    /// Input shape: (batch_size, max_seq_len, feature_dim)
135    /// Mask shape: (batch_size, max_seq_len)
136    pub fn pack(sequences: &Array3<f32>, mask: &SequenceMask) -> CoreResult<Self> {
137        let (batch_size, max_seq_len, feature_dim) = sequences.dim();
138
139        if batch_size != mask.batch_size() {
140            return Err(CoreError::DimensionMismatch {
141                expected: mask.batch_size(),
142                got: batch_size,
143            });
144        }
145
146        if max_seq_len != mask.max_len() {
147            return Err(CoreError::DimensionMismatch {
148                expected: mask.max_len(),
149                got: max_seq_len,
150            });
151        }
152
153        let total_valid = mask.count_valid();
154
155        // Allocate packed arrays
156        let mut data = Array2::zeros((total_valid, feature_dim));
157        let mut batch_indices = Array1::zeros(total_valid);
158
159        // Pack data
160        let mut idx = 0;
161        for b in 0..batch_size {
162            let length = mask.lengths()[b];
163            for t in 0..length {
164                // Copy features
165                for f in 0..feature_dim {
166                    data[[idx, f]] = sequences[[b, t, f]];
167                }
168                batch_indices[idx] = b;
169                idx += 1;
170            }
171        }
172
173        Ok(Self {
174            data,
175            batch_indices,
176            sorted_lengths: mask.lengths().clone(),
177            batch_size,
178            feature_dim,
179        })
180    }
181
182    /// Unpack back to padded batch format
183    ///
184    /// Output shape: (batch_size, max_seq_len, feature_dim)
185    pub fn unpack(&self, padding_value: f32) -> CoreResult<Array3<f32>> {
186        let max_len = *self.sorted_lengths.iter().max().unwrap();
187        let mut output =
188            Array3::from_elem((self.batch_size, max_len, self.feature_dim), padding_value);
189
190        let mut idx = 0;
191        for b in 0..self.batch_size {
192            let length = self.sorted_lengths[b];
193            for t in 0..length {
194                for f in 0..self.feature_dim {
195                    output[[b, t, f]] = self.data[[idx, f]];
196                }
197                idx += 1;
198            }
199        }
200
201        Ok(output)
202    }
203
204    /// Get packed data
205    pub fn data(&self) -> &Array2<f32> {
206        &self.data
207    }
208
209    /// Get batch indices
210    pub fn batch_indices(&self) -> &Array1<usize> {
211        &self.batch_indices
212    }
213
214    /// Get total number of valid elements
215    pub fn num_elements(&self) -> usize {
216        self.data.nrows()
217    }
218}
219
220/// Pad sequences to the same length
221///
222/// Input: Vec of sequences with shape (seq_len, feature_dim)
223/// Output: Padded array with shape (batch_size, max_seq_len, feature_dim) and mask
224pub fn pad_sequences(
225    sequences: &[Array2<f32>],
226    padding_value: f32,
227    strategy: PaddingStrategy,
228) -> CoreResult<(Array3<f32>, SequenceMask)> {
229    if sequences.is_empty() {
230        return Err(CoreError::InvalidConfig(
231            "Cannot pad empty sequence list".to_string(),
232        ));
233    }
234
235    let batch_size = sequences.len();
236    let feature_dim = sequences[0].ncols();
237
238    // Collect lengths and find max
239    let lengths: Vec<usize> = sequences.iter().map(|s| s.nrows()).collect();
240    let max_len = *lengths.iter().max().unwrap();
241
242    // Check feature dimensions match
243    for (i, seq) in sequences.iter().enumerate() {
244        if seq.ncols() != feature_dim {
245            return Err(CoreError::InvalidConfig(format!(
246                "Feature dimension mismatch at index {}: expected {}, got {}",
247                i,
248                feature_dim,
249                seq.ncols()
250            )));
251        }
252    }
253
254    // Create padded array
255    let mut padded = Array3::from_elem((batch_size, max_len, feature_dim), padding_value);
256
257    // Fill in sequences based on padding strategy
258    for (b, seq) in sequences.iter().enumerate() {
259        let seq_len = seq.nrows();
260
261        match strategy {
262            PaddingStrategy::Right => {
263                // Pad on the right (default)
264                for t in 0..seq_len {
265                    for f in 0..feature_dim {
266                        padded[[b, t, f]] = seq[[t, f]];
267                    }
268                }
269            }
270            PaddingStrategy::Left => {
271                // Pad on the left
272                let offset = max_len - seq_len;
273                for t in 0..seq_len {
274                    for f in 0..feature_dim {
275                        padded[[b, offset + t, f]] = seq[[t, f]];
276                    }
277                }
278            }
279            PaddingStrategy::None => {
280                if seq_len != max_len {
281                    return Err(CoreError::InvalidConfig(format!(
282                        "Sequence {} has length {} but max_len is {}. Use padding strategy.",
283                        b, seq_len, max_len
284                    )));
285                }
286                for t in 0..seq_len {
287                    for f in 0..feature_dim {
288                        padded[[b, t, f]] = seq[[t, f]];
289                    }
290                }
291            }
292        }
293    }
294
295    // Create mask
296    let mask = SequenceMask::from_lengths(&lengths)?;
297
298    Ok((padded, mask))
299}
300
301/// Apply mask to a tensor by zeroing out padding positions
302pub fn apply_mask(tensor: &mut Array3<f32>, mask: &SequenceMask, mask_value: f32) {
303    let (batch_size, seq_len, feature_dim) = tensor.dim();
304
305    for b in 0..batch_size {
306        for t in 0..seq_len {
307            if !mask.is_valid(b, t) {
308                for f in 0..feature_dim {
309                    tensor[[b, t, f]] = mask_value;
310                }
311            }
312        }
313    }
314}
315
316/// Compute sequence-aware mean (ignoring padding)
317pub fn masked_mean(tensor: &Array3<f32>, mask: &SequenceMask) -> CoreResult<Array2<f32>> {
318    let (batch_size, seq_len, feature_dim) = tensor.dim();
319
320    if batch_size != mask.batch_size() {
321        return Err(CoreError::DimensionMismatch {
322            expected: mask.batch_size(),
323            got: batch_size,
324        });
325    }
326
327    let mut result = Array2::zeros((batch_size, feature_dim));
328
329    for b in 0..batch_size {
330        let length = mask.lengths()[b] as f32;
331        if length == 0.0 {
332            continue;
333        }
334
335        for t in 0..seq_len {
336            if mask.is_valid(b, t) {
337                for f in 0..feature_dim {
338                    result[[b, f]] += tensor[[b, t, f]] / length;
339                }
340            }
341        }
342    }
343
344    Ok(result)
345}
346
347/// Compute sequence-aware sum (ignoring padding)
348pub fn masked_sum(tensor: &Array3<f32>, mask: &SequenceMask) -> CoreResult<Array2<f32>> {
349    let (batch_size, seq_len, feature_dim) = tensor.dim();
350
351    if batch_size != mask.batch_size() {
352        return Err(CoreError::DimensionMismatch {
353            expected: mask.batch_size(),
354            got: batch_size,
355        });
356    }
357
358    let mut result = Array2::zeros((batch_size, feature_dim));
359
360    for b in 0..batch_size {
361        for t in 0..seq_len {
362            if mask.is_valid(b, t) {
363                for f in 0..feature_dim {
364                    result[[b, f]] += tensor[[b, t, f]];
365                }
366            }
367        }
368    }
369
370    Ok(result)
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// A mask built from lengths reports the expected shape, valid count and
    /// per-position validity.
    #[test]
    fn test_sequence_mask() {
        let mask = SequenceMask::from_lengths(&[3, 5, 2]).unwrap();

        assert_eq!(mask.batch_size(), 3);
        assert_eq!(mask.max_len(), 5);
        assert_eq!(mask.count_valid(), 10); // 3 + 5 + 2

        // (batch index, position, expected validity)
        let expectations: [(usize, usize, bool); 7] = [
            (0, 0, true),
            (0, 2, true),
            (0, 3, false),
            (1, 4, true),
            (1, 5, false),
            (2, 1, true),
            (2, 2, false),
        ];
        for &(b, t, expected) in expectations.iter() {
            assert_eq!(mask.is_valid(b, t), expected, "position ({}, {})", b, t);
        }
    }

    /// Right padding keeps the data at the start of each row and fills the
    /// tail with the padding value.
    #[test]
    fn test_pad_sequences() {
        // (number of rows, constant fill value) for each input sequence
        let specs: [(usize, f32); 3] = [(2, 1.0), (4, 2.0), (3, 3.0)];
        let sequences: Vec<Array2<f32>> = specs
            .iter()
            .map(|&(rows, fill)| Array2::from_elem((rows, 3), fill))
            .collect();

        let (padded, mask) = pad_sequences(&sequences, 0.0, PaddingStrategy::Right).unwrap();

        // batch_size=3, max_len=4, feature_dim=3
        assert_eq!(padded.dim(), (3, 4, 3));
        assert_eq!(mask.max_len(), 4);
        assert_eq!(mask.lengths().to_vec(), vec![2, 4, 3]);

        // Check values
        assert_eq!(padded[[0, 0, 0]], 1.0);
        assert_eq!(padded[[0, 2, 0]], 0.0); // padding
        assert_eq!(padded[[1, 3, 0]], 2.0);
        assert_eq!(padded[[2, 2, 0]], 3.0);
    }

    /// Packing followed by unpacking reproduces every valid position.
    #[test]
    fn test_packed_sequence() {
        let lengths: [usize; 3] = [2, 3, 1];
        let mask = SequenceMask::from_lengths(&lengths).unwrap();

        // batch=3, max_len=3, features=2; valid steps carry `b * 10 + t`,
        // padding stays at zero.
        let sequences: Array3<f32> = Array3::from_shape_fn((3, 3, 2), |(b, t, _)| {
            if t < lengths[b] {
                (b * 10 + t) as f32
            } else {
                0.0
            }
        });

        let packed = PackedSequence::pack(&sequences, &mask).unwrap();
        assert_eq!(packed.num_elements(), 6); // 2 + 3 + 1

        let unpacked = packed.unpack(0.0).unwrap();
        assert_eq!(unpacked.dim(), (3, 3, 2));

        // Check that valid positions match
        for (b, &len) in lengths.iter().enumerate() {
            for t in 0..len {
                for f in 0..2 {
                    assert_eq!(sequences[[b, t, f]], unpacked[[b, t, f]]);
                }
            }
        }
    }

    /// The mean only takes valid time steps into account.
    #[test]
    fn test_masked_mean() {
        let mask = SequenceMask::from_lengths(&[2, 3]).unwrap();

        // Batch 0: [[1, 1], [2, 2], [0, 0]] with length 2 -> mean = [1.5, 1.5]
        // Batch 1: [[3, 3], [4, 4], [5, 5]] with length 3 -> mean = [4, 4]
        let values: Vec<f32> = vec![
            1.0, 1.0, 2.0, 2.0, 0.0, 0.0, // batch 0
            3.0, 3.0, 4.0, 4.0, 5.0, 5.0, // batch 1
        ];
        let sequences = Array3::from_shape_vec((2, 3, 2), values).unwrap();

        let mean = masked_mean(&sequences, &mask).unwrap();

        let expected: [f32; 2] = [1.5, 4.0];
        for (b, &want) in expected.iter().enumerate() {
            for f in 0..2 {
                assert!((mean[[b, f]] - want).abs() < 1e-6);
            }
        }
    }

    /// Padding positions are overwritten, valid ones are kept.
    #[test]
    fn test_apply_mask() {
        let mask = SequenceMask::from_lengths(&[2, 1]).unwrap();

        let mut sequences = Array3::from_elem((2, 3, 2), 1.0);
        apply_mask(&mut sequences, &mask, 0.0);

        // Expected first-feature value per (batch, time step); 0.0 marks padding.
        let expected: [[f32; 3]; 2] = [[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]];
        for (b, row) in expected.iter().enumerate() {
            for (t, &want) in row.iter().enumerate() {
                assert_eq!(sequences[[b, t, 0]], want);
            }
        }
    }
}