/// Dequantized (`f32`) feed-forward weights for a single layer.
///
/// All buffers are owned, so cloning this value copies every element.
/// `Debug` and `PartialEq` are derived so values can be logged and compared
/// directly in assertions (`PartialEq` uses exact `f32` comparison).
#[derive(Debug, Clone, PartialEq)]
pub struct DequantizedFFNWeights {
    /// Flattened `up` weights.
    /// NOTE(review): presumably `hidden_dim * intermediate_dim` elements;
    /// the memory layout is not established by this file — confirm with the producer.
    pub up: Vec<f32>,
    /// Flattened `down` weights.
    /// NOTE(review): presumably `intermediate_dim * hidden_dim` elements — confirm.
    pub down: Vec<f32>,
    /// Optional bias paired with `up`; `None` when no bias was supplied.
    pub up_bias: Option<Vec<f32>>,
    /// Optional bias paired with `down`; `None` when no bias was supplied.
    pub down_bias: Option<Vec<f32>>,
}
/// Cache of dequantized FFN weights keyed by layer index.
///
/// Only compiled when the `gpu` feature is enabled. The map sits behind an
/// `RwLock`, so every method takes `&self`.
#[cfg(feature = "gpu")]
pub struct DequantizedWeightCache {
/// Cached weights per layer index, guarded by a read/write lock.
layers: std::sync::RwLock<std::collections::HashMap<usize, DequantizedFFNWeights>>,
/// Hidden dimension supplied at construction; used for the memory estimate.
hidden_dim: usize,
/// Intermediate dimension supplied at construction; used for the memory estimate.
intermediate_dim: usize,
/// Number of layers a warmup pass populates (indices `0..num_layers`).
num_layers: usize,
}
#[cfg(feature = "gpu")]
impl DequantizedWeightCache {
    /// Creates an empty cache for a model with the given dimensions.
    ///
    /// The backing map is pre-sized for `num_layers` entries; nothing is
    /// dequantized until [`warmup`](Self::warmup) or
    /// [`warmup_with_bias`](Self::warmup_with_bias) is called.
    #[must_use]
    pub fn new(hidden_dim: usize, intermediate_dim: usize, num_layers: usize) -> Self {
        Self {
            layers: std::sync::RwLock::new(std::collections::HashMap::with_capacity(num_layers)),
            hidden_dim,
            intermediate_dim,
            num_layers,
        }
    }

    /// Populates every layer in `0..num_layers` that is not cached yet,
    /// storing the `(up, down)` buffers returned by `dequant_fn` without biases.
    ///
    /// Layers that are already cached are left untouched and `dequant_fn` is
    /// not invoked for them, so repeated calls are idempotent.
    ///
    /// The write lock is held for the whole pass, so readers block until the
    /// warmup finishes.
    ///
    /// # Panics
    ///
    /// Panics if the cache lock is poisoned.
    pub fn warmup<F>(&self, dequant_fn: F)
    where
        F: Fn(usize) -> (Vec<f32>, Vec<f32>),
    {
        self.warmup_missing_layers(|layer_idx| {
            let (up, down) = dequant_fn(layer_idx);
            DequantizedFFNWeights {
                up,
                down,
                up_bias: None,
                down_bias: None,
            }
        });
    }

    /// Same as [`warmup`](Self::warmup), but `dequant_fn` also returns the
    /// optional biases as `(up, down, up_bias, down_bias)`.
    ///
    /// # Panics
    ///
    /// Panics if the cache lock is poisoned.
    pub fn warmup_with_bias<F>(&self, dequant_fn: F)
    where
        F: Fn(usize) -> (Vec<f32>, Vec<f32>, Option<Vec<f32>>, Option<Vec<f32>>),
    {
        self.warmup_missing_layers(|layer_idx| {
            let (up, down, up_bias, down_bias) = dequant_fn(layer_idx);
            DequantizedFFNWeights {
                up,
                down,
                up_bias,
                down_bias,
            }
        });
    }

    /// Shared warmup loop: takes the write lock once and inserts
    /// `build(layer_idx)` for each index in `0..num_layers` that has no entry.
    fn warmup_missing_layers<B>(&self, build: B)
    where
        B: Fn(usize) -> DequantizedFFNWeights,
    {
        let mut cache = self.layers.write().expect("Cache lock poisoned");
        for layer_idx in 0..self.num_layers {
            // `or_insert_with` only runs `build` when the slot is vacant.
            cache.entry(layer_idx).or_insert_with(|| build(layer_idx));
        }
    }

    /// Returns a clone of the cached weights for `layer_idx`, or `None` when
    /// that layer is not cached.
    ///
    /// The returned value is an independent copy: every buffer is cloned on
    /// each call.
    ///
    /// # Panics
    ///
    /// Panics if the cache lock is poisoned.
    #[must_use]
    pub fn get(&self, layer_idx: usize) -> Option<DequantizedFFNWeights> {
        let cache = self.layers.read().expect("Cache lock poisoned");
        cache.get(&layer_idx).cloned()
    }

    /// Reports whether weights for `layer_idx` are present in the cache.
    ///
    /// # Panics
    ///
    /// Panics if the cache lock is poisoned.
    #[must_use]
    pub fn is_cached(&self, layer_idx: usize) -> bool {
        let cache = self.layers.read().expect("Cache lock poisoned");
        cache.contains_key(&layer_idx)
    }

    /// Returns the number of layers currently cached.
    ///
    /// # Panics
    ///
    /// Panics if the cache lock is poisoned.
    #[must_use]
    pub fn cached_count(&self) -> usize {
        let cache = self.layers.read().expect("Cache lock poisoned");
        cache.len()
    }

    /// Estimates the memory held by cached weights, in bytes.
    ///
    /// This is computed from the configured dimensions, not measured from the
    /// stored buffers: two `hidden_dim * intermediate_dim` matrices of `f32`
    /// per cached layer. Bias vectors are not included in the estimate.
    ///
    /// # Panics
    ///
    /// Panics if the cache lock is poisoned.
    #[must_use]
    pub fn memory_bytes(&self) -> usize {
        let bytes_per_layer =
            2 * self.hidden_dim * self.intermediate_dim * std::mem::size_of::<f32>();
        self.cached_count() * bytes_per_layer
    }

    /// Returns `(hidden_dim, intermediate_dim, num_layers)` as given to
    /// [`new`](Self::new).
    #[must_use]
    pub fn dimensions(&self) -> (usize, usize, usize) {
        (self.hidden_dim, self.intermediate_dim, self.num_layers)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Weights built without biases expose both buffers and report no bias.
    #[test]
    fn test_dequantized_ffn_weights_basic() {
        let weights = DequantizedFFNWeights {
            up: vec![1.0, 2.0, 3.0, 4.0],
            down: vec![5.0, 6.0, 7.0, 8.0],
            up_bias: None,
            down_bias: None,
        };

        assert_eq!(weights.up.len(), 4);
        assert_eq!(weights.down.len(), 4);
        assert!(weights.up_bias.is_none());
        assert!(weights.down_bias.is_none());
    }

    /// Both optional biases are retained when supplied.
    #[test]
    fn test_dequantized_ffn_weights_with_bias() {
        let weights = DequantizedFFNWeights {
            up: vec![1.0, 2.0],
            down: vec![3.0, 4.0],
            up_bias: Some(vec![0.1, 0.2]),
            down_bias: Some(vec![0.3, 0.4]),
        };

        assert!(weights.up_bias.is_some());
        assert!(weights.down_bias.is_some());

        let up_bias = weights.up_bias.as_ref().expect("as_ref");
        assert_eq!(up_bias.len(), 2);
    }

    /// A clone carries the same buffers as its source.
    #[test]
    fn test_dequantized_ffn_weights_clone() {
        let original = DequantizedFFNWeights {
            up: vec![1.0, 2.0],
            down: vec![3.0, 4.0],
            up_bias: Some(vec![0.5]),
            down_bias: None,
        };

        let cloned = original.clone();

        assert_eq!(cloned.up, original.up);
        assert_eq!(cloned.down, original.down);
        assert_eq!(cloned.up_bias, original.up_bias);
    }

    /// A fresh cache reports its configured dimensions and holds no layers.
    #[test]
    #[cfg(feature = "gpu")]
    fn test_weight_cache_new() {
        let cache = DequantizedWeightCache::new(256, 1024, 4);

        assert_eq!(cache.dimensions(), (256, 1024, 4));
        assert_eq!(cache.cached_count(), 0);
    }

    /// Warmup fills exactly the layers in `0..num_layers`.
    #[test]
    #[cfg(feature = "gpu")]
    fn test_weight_cache_warmup() {
        let cache = DequantizedWeightCache::new(64, 256, 2);

        cache.warmup(|layer_idx| {
            let scale = layer_idx as f32;
            (vec![scale * 0.1; 64 * 256], vec![scale * 0.2; 256 * 64])
        });

        assert_eq!(cache.cached_count(), 2);
        assert!(cache.is_cached(0));
        assert!(cache.is_cached(1));
        assert!(!cache.is_cached(2));
    }

    /// Biases returned by the dequantization callback end up in the cache.
    #[test]
    #[cfg(feature = "gpu")]
    fn test_weight_cache_warmup_with_bias() {
        let cache = DequantizedWeightCache::new(32, 128, 1);

        cache.warmup_with_bias(|_| {
            (
                vec![1.0; 32 * 128],
                vec![2.0; 128 * 32],
                Some(vec![0.1; 128]),
                Some(vec![0.2; 32]),
            )
        });

        let weights = cache.get(0).expect("weights");
        assert!(weights.up_bias.is_some());
        assert!(weights.down_bias.is_some());
    }

    /// `get` returns the per-layer values and `None` for unknown layers.
    #[test]
    #[cfg(feature = "gpu")]
    fn test_weight_cache_get() {
        let cache = DequantizedWeightCache::new(16, 64, 2);

        cache.warmup(|idx| (vec![idx as f32; 16 * 64], vec![(idx + 10) as f32; 64 * 16]));

        let w0 = cache.get(0).expect("w0");
        assert!((w0.up[0] - 0.0).abs() < f32::EPSILON);
        assert!((w0.down[0] - 10.0).abs() < f32::EPSILON);

        let w1 = cache.get(1).expect("w1");
        assert!((w1.up[0] - 1.0).abs() < f32::EPSILON);

        assert!(cache.get(99).is_none());
    }

    /// `is_cached` flips from false to true for warmed layers only.
    #[test]
    #[cfg(feature = "gpu")]
    fn test_weight_cache_is_cached() {
        let cache = DequantizedWeightCache::new(8, 32, 3);

        for layer in 0..2 {
            assert!(!cache.is_cached(layer));
        }

        cache.warmup(|idx| (vec![idx as f32; 256], vec![idx as f32; 256]));

        for layer in 0..3 {
            assert!(cache.is_cached(layer));
        }
        assert!(!cache.is_cached(3));
    }

    /// The memory estimate is zero when empty and scales with cached layers.
    #[test]
    #[cfg(feature = "gpu")]
    fn test_weight_cache_memory_bytes() {
        const HIDDEN: usize = 64;
        const INTERMEDIATE: usize = 256;
        const LAYERS: usize = 2;

        let cache = DequantizedWeightCache::new(HIDDEN, INTERMEDIATE, LAYERS);
        assert_eq!(cache.memory_bytes(), 0);

        cache.warmup(|_| {
            (
                vec![0.0; HIDDEN * INTERMEDIATE],
                vec![0.0; INTERMEDIATE * HIDDEN],
            )
        });

        // layers * two matrices * elements per matrix * 4 bytes per f32
        let expected = LAYERS * 2 * HIDDEN * INTERMEDIATE * 4;
        assert_eq!(cache.memory_bytes(), expected);
    }

    /// `dimensions` echoes the constructor arguments.
    #[test]
    #[cfg(feature = "gpu")]
    fn test_weight_cache_dimensions() {
        let cache = DequantizedWeightCache::new(512, 2048, 12);

        assert_eq!(cache.dimensions(), (512, 2048, 12));
    }

    /// A second warmup must not overwrite layers that are already cached.
    #[test]
    #[cfg(feature = "gpu")]
    fn test_weight_cache_idempotent_warmup() {
        let cache = DequantizedWeightCache::new(8, 16, 1);

        cache.warmup(|_| (vec![1.0; 128], vec![2.0; 128]));
        let initial_up = cache.get(0).expect("initial_up").up[0];
        assert!((initial_up - 1.0).abs() < f32::EPSILON);

        cache.warmup(|_| (vec![999.0; 128], vec![888.0; 128]));
        let after_up = cache.get(0).expect("after_up").up[0];
        assert!((after_up - 1.0).abs() < f32::EPSILON);
    }

    /// A cache configured with zero layers stays empty.
    #[test]
    #[cfg(feature = "gpu")]
    fn test_weight_cache_empty() {
        let cache = DequantizedWeightCache::new(4, 8, 0);

        assert_eq!(cache.cached_count(), 0);
        assert!(cache.get(0).is_none());
        assert!(!cache.is_cached(0));
    }
}