/// Masks every position after `query_idx` in a single `f32` score row.
///
/// Every element whose index is strictly greater than `query_idx` is overwritten
/// with `-1.0e9`. Elements at indices `0..=query_idx` are left untouched.
/// If `query_idx` is at or beyond the last index (including `usize::MAX`), the
/// row is not modified.
///
/// The mask is a large *finite* negative number, not `f32::NEG_INFINITY`.
/// NOTE(review): presumably this is so a later softmax sends masked entries to
/// ~0 without producing `inf - inf` NaNs. Confirm with callers before changing it.
pub fn apply_causal_mask_row_f32(scores_row: &mut [f32], query_idx: usize) {
    const MASK_VALUE: f32 = -1.0e9;

    // `saturating_add` avoids overflow when `query_idx == usize::MAX`.
    // `get_mut` returns `None` (or an empty slice) when the start is past the
    // end, so out-of-range indices leave the row unchanged.
    if let Some(future) = scores_row.get_mut(query_idx.saturating_add(1)..) {
        future.fill(MASK_VALUE);
    }
}
/// Masks every position after `query_idx` in a single `f64` score row.
///
/// Every element whose index is strictly greater than `query_idx` is overwritten
/// with `-1.0e9`. Elements at indices `0..=query_idx` are left untouched.
/// If `query_idx` is at or beyond the last index (including `usize::MAX`), the
/// row is not modified.
///
/// The mask is a large *finite* negative number, not `f64::NEG_INFINITY`.
/// NOTE(review): presumably this is so a later softmax sends masked entries to
/// ~0 without producing `inf - inf` NaNs. Confirm with callers before changing it.
pub fn apply_causal_mask_row_f64(scores_row: &mut [f64], query_idx: usize) {
    const MASK_VALUE: f64 = -1.0e9;

    // `saturating_add` avoids overflow when `query_idx == usize::MAX`.
    // `get_mut` returns `None` (or an empty slice) when the start is past the
    // end, so out-of-range indices leave the row unchanged.
    if let Some(future) = scores_row.get_mut(query_idx.saturating_add(1)..) {
        future.fill(MASK_VALUE);
    }
}