/// Builds a flat, row-major `seq x seq` causal mask.
///
/// The entry for query row `q` and key column `k` lives at index `q * seq + k`.
/// It is `1.0` when `k <= q` (on or below the diagonal) and `0.0` otherwise.
/// The result is a lower-triangular matrix of ones. `seq == 0` yields an empty vector.
pub fn causal_mask(seq: usize) -> Vec<f32> {
    let mut mask = Vec::with_capacity(seq * seq);
    for query in 0..seq {
        // Each row keeps keys up to and including the query position.
        mask.extend((0..seq).map(|key| if key <= query { 1.0f32 } else { 0.0 }));
    }
    mask
}
/// Builds a flat, row-major `seq x seq` banded causal mask.
///
/// The entry for query row `q` and key column `k` lives at index `q * seq + k`.
/// It is `1.0` when `q - window <= k <= q`, with the lower bound saturating at 0,
/// and `0.0` otherwise. Each row therefore keeps at most `window + 1` positions:
/// the query itself plus up to `window` earlier ones. With `window >= seq - 1`
/// the result equals a full causal (lower-triangular) mask.
pub fn sliding_window_mask(seq: usize, window: usize) -> Vec<f32> {
    (0..seq)
        .flat_map(|query| {
            // Earliest key still inside the band for this query row.
            let oldest = query.saturating_sub(window);
            (0..seq).map(move |key| {
                if oldest <= key && key <= query {
                    1.0f32
                } else {
                    0.0
                }
            })
        })
        .collect()
}
/// Builds a flat, row-major `lengths.len() x max_seq` padding mask.
///
/// Row `i` holds `min(lengths[i], max_seq)` leading `1.0` values (valid positions)
/// followed by `0.0` for the remaining padded positions. Lengths longer than
/// `max_seq` are clipped rather than rejected.
pub fn padding_mask(lengths: &[usize], max_seq: usize) -> Vec<f32> {
    let mut mask = Vec::with_capacity(lengths.len() * max_seq);
    for &len in lengths {
        let valid = len.min(max_seq);
        mask.extend(std::iter::repeat(1.0f32).take(valid));
        mask.extend(std::iter::repeat(0.0f32).take(max_seq - valid));
    }
    mask
}
/// Right-pads variable-length rows into one flat, row-major buffer.
///
/// Shared implementation behind [`collate_padded_f32`] and [`collate_padded_i64`].
/// The two public functions previously duplicated this body verbatim.
///
/// Returns `(flat, lengths)`:
/// - `flat` has `rows.len() * max_len` elements. Row `i` occupies
///   `flat[i * max_len..(i + 1) * max_len]`: its original values come first,
///   followed by `pad_value`.
/// - `lengths[i]` is the original length of `rows[i]`.
///
/// An empty batch, or a batch whose rows are all empty, yields an empty `flat`.
fn collate_padded<T: Copy>(rows: &[Vec<T>], pad_value: T) -> (Vec<T>, Vec<usize>) {
    let lengths: Vec<usize> = rows.iter().map(Vec::len).collect();
    let max_seq = lengths.iter().copied().max().unwrap_or(0);
    let mut out = vec![pad_value; rows.len() * max_seq];
    for (i, row) in rows.iter().enumerate() {
        // Every row is at most `max_seq` long, so this slice is always in bounds.
        let start = i * max_seq;
        out[start..start + row.len()].copy_from_slice(row);
    }
    (out, lengths)
}

/// Right-pads `f32` rows to the longest row, filling with `pad_value`.
///
/// Returns the flat row-major `rows.len() x max_len` buffer and each row's original length.
pub fn collate_padded_f32(rows: &[Vec<f32>], pad_value: f32) -> (Vec<f32>, Vec<usize>) {
    collate_padded(rows, pad_value)
}

/// Right-pads `i64` rows to the longest row, filling with `pad_value`.
///
/// Returns the flat row-major `rows.len() x max_len` buffer and each row's original length.
pub fn collate_padded_i64(rows: &[Vec<i64>], pad_value: i64) -> (Vec<i64>, Vec<usize>) {
    collate_padded(rows, pad_value)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the mask builders and padded-collation helpers.
    //! Expected matrices are written one row per line for readability.
    use super::*;

    #[test]
    fn causal_is_lower_triangular() {
        let m = causal_mask(4);
        assert_eq!(
            m,
            vec![
                1.0, 0.0, 0.0, 0.0,
                1.0, 1.0, 0.0, 0.0,
                1.0, 1.0, 1.0, 0.0,
                1.0, 1.0, 1.0, 1.0,
            ]
        );
    }

    #[test]
    fn causal_empty_sequence() {
        assert!(causal_mask(0).is_empty());
    }

    #[test]
    fn sliding_window_band() {
        let m = sliding_window_mask(5, 1);
        assert_eq!(
            m,
            vec![
                1.0, 0.0, 0.0, 0.0, 0.0,
                1.0, 1.0, 0.0, 0.0, 0.0,
                0.0, 1.0, 1.0, 0.0, 0.0,
                0.0, 0.0, 1.0, 1.0, 0.0,
                0.0, 0.0, 0.0, 1.0, 1.0,
            ]
        );
    }

    #[test]
    fn sliding_window_zero_is_identity() {
        assert_eq!(
            sliding_window_mask(3, 0),
            vec![
                1.0, 0.0, 0.0,
                0.0, 1.0, 0.0,
                0.0, 0.0, 1.0,
            ]
        );
    }

    #[test]
    fn sliding_window_wider_than_seq_matches_causal() {
        assert_eq!(sliding_window_mask(4, 10), causal_mask(4));
    }

    #[test]
    fn padding_zeros_after_length() {
        let m = padding_mask(&[3, 1, 2], 4);
        assert_eq!(
            m,
            vec![
                1.0, 1.0, 1.0, 0.0,
                1.0, 0.0, 0.0, 0.0,
                1.0, 1.0, 0.0, 0.0,
            ]
        );
    }

    #[test]
    fn padding_clips_lengths_to_max_seq() {
        // A length beyond `max_seq` fills the whole row; a zero length leaves it all zeros.
        assert_eq!(
            padding_mask(&[6, 0], 3),
            vec![
                1.0, 1.0, 1.0,
                0.0, 0.0, 0.0,
            ]
        );
    }

    #[test]
    fn collate_pads_to_longest() {
        let rows = vec![vec![1, 2, 3], vec![4], vec![5, 6]];
        let (flat, lens) = collate_padded_i64(&rows, 0);
        assert_eq!(lens, vec![3, 1, 2]);
        assert_eq!(flat, vec![1, 2, 3, 4, 0, 0, 5, 6, 0]);
    }

    #[test]
    fn collate_f32_uses_pad_value() {
        let rows = vec![vec![0.5f32, 1.5], vec![], vec![2.5]];
        let (flat, lens) = collate_padded_f32(&rows, -1.0);
        assert_eq!(lens, vec![2, 0, 1]);
        assert_eq!(flat, vec![0.5, 1.5, -1.0, -1.0, 2.5, -1.0]);
    }

    #[test]
    fn collate_empty_batch() {
        let rows: Vec<Vec<i64>> = Vec::new();
        let (flat, lens) = collate_padded_i64(&rows, 9);
        assert!(flat.is_empty());
        assert!(lens.is_empty());
    }
}