1 2 3 4 5 6 7
/// Value written into masked (future) positions by [`apply_causal_mask_row`].
///
/// This is a large finite negative number rather than `f32::NEG_INFINITY`. It was presumably
/// chosen so that a downstream softmax never computes `-inf - -inf` (NaN). Confirm with callers
/// before changing it.
pub const CAUSAL_MASK_VALUE: f32 = -1.0e9;

/// Applies a causal mask to one row of scores, in place.
///
/// Every entry whose key index is strictly greater than `query_idx` is overwritten with
/// [`CAUSAL_MASK_VALUE`]. Entries at indices `0..=query_idx` are left untouched. If
/// `query_idx` is at or past the last index of the row, nothing is changed.
///
/// NOTE(review): this is presumably applied to attention logits before a softmax, so that a
/// query cannot attend to later keys. Confirm against the caller.
pub fn apply_causal_mask_row(scores_row: &mut [f32], query_idx: usize) {
    apply_causal_mask_row_with(scores_row, query_idx, CAUSAL_MASK_VALUE);
}

/// Like [`apply_causal_mask_row`], but writes `mask_value` into the masked positions.
///
/// Use this when a different sentinel is wanted, e.g. `f32::NEG_INFINITY`, instead of the
/// default [`CAUSAL_MASK_VALUE`].
pub fn apply_causal_mask_row_with(scores_row: &mut [f32], query_idx: usize, mask_value: f32) {
    // The masked region is the contiguous tail after `query_idx`, so it can be filled in one go.
    // `saturating_add` avoids overflow when `query_idx == usize::MAX`. `get_mut` returns `None`
    // when the start lies past the end of the row; in that case there is nothing to mask.
    if let Some(future) = scores_row.get_mut(query_idx.saturating_add(1)..) {
        future.fill(mask_value);
    }
}