1use crate::error::{CoreError, CoreResult};
10use scirs2_core::ndarray::{Array1, Array2, Array3};
11
/// Where padding goes when sequences of different lengths are batched by
/// `pad_sequences`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PaddingStrategy {
    /// Sequence data starts at time step 0; padding fills the trailing steps.
    Right,
    /// Sequence data ends at the last time step; padding fills the leading steps.
    Left,
    /// No padding is added: `pad_sequences` returns an error unless every
    /// sequence already has the maximum length.
    None,
}
22
/// Boolean validity mask for a batch of variable-length sequences.
#[derive(Debug, Clone)]
pub struct SequenceMask {
    /// `(batch_size, max_len)` grid; `true` marks a valid (non-padding) step.
    mask: Array2<bool>,
    /// Number of valid steps in each sequence, in batch order.
    lengths: Array1<usize>,
    /// Largest entry of `lengths`, which is also the width of `mask`.
    max_len: usize,
}
36
37impl SequenceMask {
38 pub fn from_lengths(lengths: &[usize]) -> CoreResult<Self> {
40 if lengths.is_empty() {
41 return Err(CoreError::InvalidConfig(
42 "Cannot create mask from empty lengths".to_string(),
43 ));
44 }
45
46 let batch_size = lengths.len();
47 let max_len = *lengths.iter().max().unwrap();
48
49 if max_len == 0 {
50 return Err(CoreError::InvalidConfig(
51 "Max length must be greater than 0".to_string(),
52 ));
53 }
54
55 let mut mask = Array2::from_elem((batch_size, max_len), false);
57
58 for (i, &length) in lengths.iter().enumerate() {
59 if length > max_len {
60 return Err(CoreError::InvalidConfig(format!(
61 "Length {} exceeds max_len {}",
62 length, max_len
63 )));
64 }
65 for j in 0..length {
66 mask[[i, j]] = true;
67 }
68 }
69
70 let lengths_array = Array1::from_vec(lengths.to_vec());
71
72 Ok(Self {
73 mask,
74 lengths: lengths_array,
75 max_len,
76 })
77 }
78
79 pub fn mask(&self) -> &Array2<bool> {
81 &self.mask
82 }
83
84 pub fn lengths(&self) -> &Array1<usize> {
86 &self.lengths
87 }
88
89 pub fn max_len(&self) -> usize {
91 self.max_len
92 }
93
94 pub fn batch_size(&self) -> usize {
96 self.lengths.len()
97 }
98
99 pub fn is_valid(&self, batch_idx: usize, seq_idx: usize) -> bool {
101 if batch_idx >= self.batch_size() || seq_idx >= self.max_len {
102 return false;
103 }
104 self.mask[[batch_idx, seq_idx]]
105 }
106
107 pub fn count_valid(&self) -> usize {
109 self.mask.iter().filter(|&&x| x).count()
110 }
111}
112
/// Batch of sequences stored without padding: only valid time steps are kept,
/// stacked row by row.
#[derive(Debug, Clone)]
pub struct PackedSequence {
    /// `(total_valid_steps, feature_dim)` matrix of the retained time steps.
    data: Array2<f32>,
    /// For each row of `data`, the index of the batch item it came from.
    batch_indices: Array1<usize>,
    /// Per-sequence lengths. `pack` stores a plain clone of the mask's
    /// lengths here, so despite the name they are in batch order, not sorted.
    sorted_lengths: Array1<usize>,
    /// Number of sequences in the original batch.
    batch_size: usize,
    /// Size of the feature axis of the original tensor.
    feature_dim: usize,
}
130
131impl PackedSequence {
132 pub fn pack(sequences: &Array3<f32>, mask: &SequenceMask) -> CoreResult<Self> {
137 let (batch_size, max_seq_len, feature_dim) = sequences.dim();
138
139 if batch_size != mask.batch_size() {
140 return Err(CoreError::DimensionMismatch {
141 expected: mask.batch_size(),
142 got: batch_size,
143 });
144 }
145
146 if max_seq_len != mask.max_len() {
147 return Err(CoreError::DimensionMismatch {
148 expected: mask.max_len(),
149 got: max_seq_len,
150 });
151 }
152
153 let total_valid = mask.count_valid();
154
155 let mut data = Array2::zeros((total_valid, feature_dim));
157 let mut batch_indices = Array1::zeros(total_valid);
158
159 let mut idx = 0;
161 for b in 0..batch_size {
162 let length = mask.lengths()[b];
163 for t in 0..length {
164 for f in 0..feature_dim {
166 data[[idx, f]] = sequences[[b, t, f]];
167 }
168 batch_indices[idx] = b;
169 idx += 1;
170 }
171 }
172
173 Ok(Self {
174 data,
175 batch_indices,
176 sorted_lengths: mask.lengths().clone(),
177 batch_size,
178 feature_dim,
179 })
180 }
181
182 pub fn unpack(&self, padding_value: f32) -> CoreResult<Array3<f32>> {
186 let max_len = *self.sorted_lengths.iter().max().unwrap();
187 let mut output =
188 Array3::from_elem((self.batch_size, max_len, self.feature_dim), padding_value);
189
190 let mut idx = 0;
191 for b in 0..self.batch_size {
192 let length = self.sorted_lengths[b];
193 for t in 0..length {
194 for f in 0..self.feature_dim {
195 output[[b, t, f]] = self.data[[idx, f]];
196 }
197 idx += 1;
198 }
199 }
200
201 Ok(output)
202 }
203
204 pub fn data(&self) -> &Array2<f32> {
206 &self.data
207 }
208
209 pub fn batch_indices(&self) -> &Array1<usize> {
211 &self.batch_indices
212 }
213
214 pub fn num_elements(&self) -> usize {
216 self.data.nrows()
217 }
218}
219
220pub fn pad_sequences(
225 sequences: &[Array2<f32>],
226 padding_value: f32,
227 strategy: PaddingStrategy,
228) -> CoreResult<(Array3<f32>, SequenceMask)> {
229 if sequences.is_empty() {
230 return Err(CoreError::InvalidConfig(
231 "Cannot pad empty sequence list".to_string(),
232 ));
233 }
234
235 let batch_size = sequences.len();
236 let feature_dim = sequences[0].ncols();
237
238 let lengths: Vec<usize> = sequences.iter().map(|s| s.nrows()).collect();
240 let max_len = *lengths.iter().max().unwrap();
241
242 for (i, seq) in sequences.iter().enumerate() {
244 if seq.ncols() != feature_dim {
245 return Err(CoreError::InvalidConfig(format!(
246 "Feature dimension mismatch at index {}: expected {}, got {}",
247 i,
248 feature_dim,
249 seq.ncols()
250 )));
251 }
252 }
253
254 let mut padded = Array3::from_elem((batch_size, max_len, feature_dim), padding_value);
256
257 for (b, seq) in sequences.iter().enumerate() {
259 let seq_len = seq.nrows();
260
261 match strategy {
262 PaddingStrategy::Right => {
263 for t in 0..seq_len {
265 for f in 0..feature_dim {
266 padded[[b, t, f]] = seq[[t, f]];
267 }
268 }
269 }
270 PaddingStrategy::Left => {
271 let offset = max_len - seq_len;
273 for t in 0..seq_len {
274 for f in 0..feature_dim {
275 padded[[b, offset + t, f]] = seq[[t, f]];
276 }
277 }
278 }
279 PaddingStrategy::None => {
280 if seq_len != max_len {
281 return Err(CoreError::InvalidConfig(format!(
282 "Sequence {} has length {} but max_len is {}. Use padding strategy.",
283 b, seq_len, max_len
284 )));
285 }
286 for t in 0..seq_len {
287 for f in 0..feature_dim {
288 padded[[b, t, f]] = seq[[t, f]];
289 }
290 }
291 }
292 }
293 }
294
295 let mask = SequenceMask::from_lengths(&lengths)?;
297
298 Ok((padded, mask))
299}
300
301pub fn apply_mask(tensor: &mut Array3<f32>, mask: &SequenceMask, mask_value: f32) {
303 let (batch_size, seq_len, feature_dim) = tensor.dim();
304
305 for b in 0..batch_size {
306 for t in 0..seq_len {
307 if !mask.is_valid(b, t) {
308 for f in 0..feature_dim {
309 tensor[[b, t, f]] = mask_value;
310 }
311 }
312 }
313 }
314}
315
316pub fn masked_mean(tensor: &Array3<f32>, mask: &SequenceMask) -> CoreResult<Array2<f32>> {
318 let (batch_size, seq_len, feature_dim) = tensor.dim();
319
320 if batch_size != mask.batch_size() {
321 return Err(CoreError::DimensionMismatch {
322 expected: mask.batch_size(),
323 got: batch_size,
324 });
325 }
326
327 let mut result = Array2::zeros((batch_size, feature_dim));
328
329 for b in 0..batch_size {
330 let length = mask.lengths()[b] as f32;
331 if length == 0.0 {
332 continue;
333 }
334
335 for t in 0..seq_len {
336 if mask.is_valid(b, t) {
337 for f in 0..feature_dim {
338 result[[b, f]] += tensor[[b, t, f]] / length;
339 }
340 }
341 }
342 }
343
344 Ok(result)
345}
346
347pub fn masked_sum(tensor: &Array3<f32>, mask: &SequenceMask) -> CoreResult<Array2<f32>> {
349 let (batch_size, seq_len, feature_dim) = tensor.dim();
350
351 if batch_size != mask.batch_size() {
352 return Err(CoreError::DimensionMismatch {
353 expected: mask.batch_size(),
354 got: batch_size,
355 });
356 }
357
358 let mut result = Array2::zeros((batch_size, feature_dim));
359
360 for b in 0..batch_size {
361 for t in 0..seq_len {
362 if mask.is_valid(b, t) {
363 for f in 0..feature_dim {
364 result[[b, f]] += tensor[[b, t, f]];
365 }
366 }
367 }
368 }
369
370 Ok(result)
371}
372
// Unit tests for mask construction, padding, packing and masked reductions.
#[cfg(test)]
mod tests {
    use super::*;

    // A mask built from lengths exposes the right batch size, width and
    // per-position validity.
    #[test]
    fn test_sequence_mask() {
        let lengths = vec![3, 5, 2];
        let mask = SequenceMask::from_lengths(&lengths).unwrap();

        assert_eq!(mask.batch_size(), 3);
        assert_eq!(mask.max_len(), 5);
        // 3 + 5 + 2 valid positions in total; then sequence 0 (length 3).
        assert_eq!(mask.count_valid(), 10); assert!(mask.is_valid(0, 0));
        assert!(mask.is_valid(0, 2));
        assert!(!mask.is_valid(0, 3));

        // Sequence 1 (length 5): index 5 is outside the mask and reads as invalid.
        assert!(mask.is_valid(1, 4));
        assert!(!mask.is_valid(1, 5));

        // Sequence 2 (length 2).
        assert!(mask.is_valid(2, 1));
        assert!(!mask.is_valid(2, 2));
    }

    // Right padding keeps the data at the start of each row and fills the
    // tail with the padding value.
    #[test]
    fn test_pad_sequences() {
        let seq1 = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
        let seq2 = Array2::from_shape_vec((4, 3), vec![2.0; 12]).unwrap();
        let seq3 = Array2::from_shape_vec((3, 3), vec![3.0; 9]).unwrap();

        let sequences = vec![seq1, seq2, seq3];
        let (padded, mask) = pad_sequences(&sequences, 0.0, PaddingStrategy::Right).unwrap();

        // Output is (batch, longest sequence, features).
        assert_eq!(padded.dim(), (3, 4, 3)); assert_eq!(mask.max_len(), 4);
        assert_eq!(mask.lengths()[0], 2);
        assert_eq!(mask.lengths()[1], 4);
        assert_eq!(mask.lengths()[2], 3);

        assert_eq!(padded[[0, 0, 0]], 1.0);
        // Step 2 of the length-2 sequence is padding; the longest sequence is full.
        assert_eq!(padded[[0, 2, 0]], 0.0); assert_eq!(padded[[1, 3, 0]], 2.0);
        assert_eq!(padded[[2, 2, 0]], 3.0);
    }

    // Packing followed by unpacking reproduces every valid element.
    #[test]
    fn test_packed_sequence() {
        let lengths = vec![2, 3, 1];
        let mask = SequenceMask::from_lengths(&lengths).unwrap();

        // Fill only the valid steps with a value that encodes (batch, time).
        let mut sequences = Array3::zeros((3, 3, 2)); for b in 0..3 {
            for t in 0..lengths[b] {
                for f in 0..2 {
                    sequences[[b, t, f]] = (b * 10 + t) as f32;
                }
            }
        }

        let packed = PackedSequence::pack(&sequences, &mask).unwrap();
        // 2 + 3 + 1 valid steps are packed.
        assert_eq!(packed.num_elements(), 6); let unpacked = packed.unpack(0.0).unwrap();
        assert_eq!(unpacked.dim(), (3, 3, 2));

        // Only the valid region is compared.
        for b in 0..3 {
            for t in 0..lengths[b] {
                for f in 0..2 {
                    assert_eq!(sequences[[b, t, f]], unpacked[[b, t, f]]);
                }
            }
        }
    }

    // The mean is taken over valid steps only.
    #[test]
    fn test_masked_mean() {
        let lengths = vec![2, 3];
        let mask = SequenceMask::from_lengths(&lengths).unwrap();

        let mut sequences = Array3::zeros((2, 3, 2));
        // Sequence 0: steps 1.0 and 2.0 (step 2 stays zero and is masked out).
        sequences[[0, 0, 0]] = 1.0;
        sequences[[0, 0, 1]] = 1.0;
        sequences[[0, 1, 0]] = 2.0;
        sequences[[0, 1, 1]] = 2.0;

        // Sequence 1: steps 3.0, 4.0 and 5.0, all valid.
        sequences[[1, 0, 0]] = 3.0;
        sequences[[1, 0, 1]] = 3.0;
        sequences[[1, 1, 0]] = 4.0;
        sequences[[1, 1, 1]] = 4.0;
        sequences[[1, 2, 0]] = 5.0;
        sequences[[1, 2, 1]] = 5.0;

        let mean = masked_mean(&sequences, &mask).unwrap();

        // (1 + 2) / 2 = 1.5 and (3 + 4 + 5) / 3 = 4.0.
        assert!((mean[[0, 0]] - 1.5).abs() < 1e-6);
        assert!((mean[[0, 1]] - 1.5).abs() < 1e-6);
        assert!((mean[[1, 0]] - 4.0).abs() < 1e-6);
        assert!((mean[[1, 1]] - 4.0).abs() < 1e-6);
    }

    // Invalid steps are overwritten, including steps beyond the mask width
    // (the tensor has 3 steps while the mask is only 2 wide).
    #[test]
    fn test_apply_mask() {
        let lengths = vec![2, 1];
        let mask = SequenceMask::from_lengths(&lengths).unwrap();

        let mut sequences = Array3::from_elem((2, 3, 2), 1.0);
        apply_mask(&mut sequences, &mask, 0.0);

        assert_eq!(sequences[[0, 0, 0]], 1.0);
        assert_eq!(sequences[[0, 1, 0]], 1.0);
        assert_eq!(sequences[[0, 2, 0]], 0.0); assert_eq!(sequences[[1, 0, 0]], 1.0);
        assert_eq!(sequences[[1, 1, 0]], 0.0); assert_eq!(sequences[[1, 2, 0]], 0.0); }
}