use std::collections::HashMap;
use std::sync::{LazyLock, Mutex};
use candle_core::{DType, Device, DeviceLocation, Tensor};
use crate::error::{MIError, Result};
/// Cache key for a causal mask: sequence length, device location, and element dtype.
type MaskKey = (usize, DeviceLocation, DType);

/// Lazily constructed, mutex-protected map from [`MaskKey`] to a previously built mask.
type CausalMaskCache = LazyLock<Mutex<HashMap<MaskKey, Tensor>>>;

/// Process-wide causal-mask cache, initialised empty on first access.
static CAUSAL_MASK_CACHE: CausalMaskCache = LazyLock::new(|| Mutex::new(HashMap::new()));

/// Lock guard over the global cache; the cache is `'static`, so the guard borrows it as such.
type CacheGuard = std::sync::MutexGuard<'static, HashMap<MaskKey, Tensor>>;

/// Lock the global causal-mask cache.
///
/// # Errors
///
/// Returns [`MIError::Hook`] if the mutex is poisoned (a thread panicked while holding it).
fn lock_cache() -> Result<CacheGuard> {
    match CAUSAL_MASK_CACHE.lock() {
        Ok(guard) => Ok(guard),
        Err(poisoned) => Err(MIError::Hook(format!(
            "mask cache lock poisoned: {poisoned}"
        ))),
    }
}
/// Return an additive causal attention mask, building and caching it on first use.
///
/// The tensor has shape `(1, 1, seq_len, seq_len)`. Entry `[0, 0, row, col]` is `0.0` when
/// `col <= row` and `-inf` when `col > row`. Values are built as `f32` and then converted to
/// `dtype` on `device`. Results are memoised in a process-wide cache keyed by
/// `(seq_len, device.location(), dtype)`, so later calls return a clone of the stored tensor.
///
/// # Errors
///
/// Returns an error if the cache mutex is poisoned, or if tensor construction or dtype
/// conversion fails.
pub fn create_causal_mask(seq_len: usize, device: &Device, dtype: DType) -> Result<Tensor> {
    let key = (seq_len, device.location(), dtype);

    // The guard is a temporary, so the lock is released at the end of this statement.
    let hit = lock_cache()?.get(&key).cloned();
    if let Some(existing) = hit {
        return Ok(existing);
    }

    // Row-major lower-triangular pattern: row `r` may see columns `0..=r`.
    let values: Vec<f32> = (0..seq_len)
        .flat_map(|r| {
            (0..seq_len).map(move |c| if c > r { f32::NEG_INFINITY } else { 0.0 })
        })
        .collect();

    let built = Tensor::from_vec(values, (1, 1, seq_len, seq_len), device)?.to_dtype(dtype)?;

    // Re-acquire the lock only for the insert; the guard again drops at statement end.
    lock_cache()?.insert(key, built.clone());

    Ok(built)
}
/// Build an additive attention mask for a chunk of new query tokens against a key sequence.
///
/// The query tokens are taken to sit at absolute positions
/// `start_pos .. start_pos + new_seq_len` among `total_seq_len` key positions. The result has
/// shape `(1, 1, new_seq_len, total_seq_len)`. Entry `[0, 0, i, j]` is `0.0` when
/// `j <= start_pos + i` and `-inf` otherwise. Values are built as `f32` and converted to
/// `dtype` on `device`. The result is not cached.
///
/// # Errors
///
/// Returns an error if tensor construction or dtype conversion fails.
pub fn create_generation_mask(
    new_seq_len: usize,
    total_seq_len: usize,
    start_pos: usize,
    device: &Device,
    dtype: DType,
) -> Result<Tensor> {
    // Row 0 has the narrowest visibility (`start_pos`). If even that row sees every key
    // column, nothing is masked and an all-zeros tensor is equivalent to the general path.
    // The shortcut must depend on `start_pos`, not only on `new_seq_len == 1`: a single
    // query at `start_pos` still has to mask keys at positions `> start_pos`.
    if start_pos.saturating_add(1) >= total_seq_len {
        return Ok(Tensor::zeros(
            (1, 1, new_seq_len, total_seq_len),
            dtype,
            device,
        )?);
    }

    let mask: Vec<f32> = (0..new_seq_len)
        .flat_map(|i| {
            // Saturate so absurdly large positions mean "everything visible" instead of
            // wrapping around.
            let visible_up_to = start_pos.saturating_add(i);
            (0..total_seq_len).map(move |j| {
                if j <= visible_up_to {
                    0.0
                } else {
                    f32::NEG_INFINITY
                }
            })
        })
        .collect();
    let mask_tensor =
        Tensor::from_vec(mask, (1, 1, new_seq_len, total_seq_len), device)?.to_dtype(dtype)?;
    Ok(mask_tensor)
}
/// Remove every entry from the global causal-mask cache.
///
/// Subsequent calls to [`create_causal_mask`] rebuild their masks from scratch.
///
/// # Errors
///
/// Returns an error if the cache mutex is poisoned.
pub fn clear_mask_caches() -> Result<()> {
    let mut guard = lock_cache()?;
    guard.clear();
    Ok(())
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::float_cmp)]
mod tests {
    use super::*;

    /// Flatten an `f32` mask tensor into a row-major `Vec<f32>`.
    fn flat_f32(mask: &Tensor) -> Vec<f32> {
        mask.flatten_all().unwrap().to_vec1().unwrap()
    }

    /// Assert `data[k] == 0.0` where `visible[k]` is true, and `data[k] == -inf` elsewhere.
    fn assert_visibility(data: &[f32], visible: &[bool]) {
        assert_eq!(data.len(), visible.len(), "mask has unexpected element count");
        for (k, (&value, &open)) in data.iter().zip(visible).enumerate() {
            if open {
                assert_eq!(value, 0.0, "expected 0.0 at flat index {k}");
            } else {
                assert!(
                    value.is_infinite() && value < 0.0,
                    "expected -inf at flat index {k}, got {value}"
                );
            }
        }
    }

    #[test]
    fn causal_mask_shape() {
        let mask = create_causal_mask(4, &Device::Cpu, DType::F32).unwrap();
        assert_eq!(mask.dims(), &[1, 1, 4, 4]);
    }

    #[test]
    fn causal_mask_values() {
        clear_mask_caches().unwrap();
        let mask = create_causal_mask(3, &Device::Cpu, DType::F32).unwrap();
        let visible = [
            true, false, false, //
            true, true, false, //
            true, true, true,
        ];
        assert_visibility(&flat_f32(&mask), &visible);
    }

    #[test]
    fn causal_mask_cache_hit_matches_fresh_build() {
        // Same key twice: the second call is normally served from the cache (or rebuilt if a
        // concurrent test cleared it). Either way, the values must agree.
        let first = flat_f32(&create_causal_mask(5, &Device::Cpu, DType::F32).unwrap());
        let second = flat_f32(&create_causal_mask(5, &Device::Cpu, DType::F32).unwrap());
        assert_eq!(first, second);
    }

    #[test]
    fn causal_mask_respects_dtype() {
        // dtype is part of the cache key, so an F64 request must not return an F32 entry.
        let mask = create_causal_mask(2, &Device::Cpu, DType::F64).unwrap();
        assert_eq!(mask.dtype(), DType::F64);
        let data: Vec<f64> = mask.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(data[0], 0.0);
        assert!(data[1].is_infinite() && data[1] < 0.0);
        assert_eq!(data[2], 0.0);
        assert_eq!(data[3], 0.0);
    }

    #[test]
    fn generation_mask_single_token() {
        let mask = create_generation_mask(1, 5, 4, &Device::Cpu, DType::F32).unwrap();
        assert_eq!(mask.dims(), &[1, 1, 1, 5]);
        assert!(flat_f32(&mask).iter().all(|&v| v == 0.0));
    }

    #[test]
    fn generation_mask_multi_token() {
        let mask = create_generation_mask(2, 5, 3, &Device::Cpu, DType::F32).unwrap();
        assert_eq!(mask.dims(), &[1, 1, 2, 5]);
        let visible = [
            true, true, true, true, false, //
            true, true, true, true, true,
        ];
        assert_visibility(&flat_f32(&mask), &visible);
    }
}