pub mod cache;
pub mod extractor;
pub mod reuse;
pub mod traits;
pub mod types;
pub use cache::{
CachedHiddenState, HiddenStateCache, HiddenStateCacheConfig, HiddenStateCacheStats,
};
pub use extractor::{LayerExtractor, MockHiddenStateProvider, StatePooling, StateSimilarity};
pub use reuse::{
AdaptiveReuseStrategy, HybridReuseStrategy, LengthAwareReuseStrategy, PrefixReuseStrategy,
SemanticReuseStrategy,
};
pub use traits::{
BoxedStateReuseStrategy, HiddenStateProvider, HiddenStateProviderExt, StateReuseStrategy,
};
pub use types::{
DType, Device, HiddenStateConfig, HiddenStateTensor, KVCache, LayerHiddenState,
ModelHiddenStates, ModelKVCache, TensorShape,
};
#[cfg(test)]
// `cast_precision_loss`: `test_layer_extractor` builds layer data with `i as f32`.
// `no_effect_underscore_binding`: `test_all_exports_accessible` binds `_name`s purely to
// prove that every re-export above is constructible from outside its submodule.
#[allow(clippy::cast_precision_loss, clippy::no_effect_underscore_binding)]
mod tests {
    use super::*;

    /// End-to-end: extract states with the mock provider, cache them, then check that a
    /// longer text starting with the cached text is found as a prefix match and judged
    /// reusable by `PrefixReuseStrategy`.
    #[tokio::test]
    async fn test_full_workflow() {
        let provider = MockHiddenStateProvider::new("test-model", 6, 256);
        let config = HiddenStateCacheConfig::new(10, 10 * 1024 * 1024)
            .with_ttl(3600)
            .with_store_kv_cache(true);
        let mut cache = HiddenStateCache::new(config);

        let text1 = "The quick brown fox jumps over the lazy dog.";
        let states1 = provider.extract_hidden_states(text1).await.unwrap();
        // `states1` is not used after this point, so move it instead of cloning.
        cache.put(text1.to_string(), states1, None);
        assert!(cache.contains(text1));

        let cached = cache.get(text1).unwrap();
        assert_eq!(cached.states.model_id, "test-model");

        // `text2` begins with `text1`, so the cache must report a prefix match; failing
        // loudly here (rather than skipping via `if let`) keeps the reuse checks below
        // from being silently bypassed.
        let text2 = "The quick brown fox jumps over the lazy dog. And more text here.";
        let (cached_text, cached_state) = cache
            .find_prefix_match(text2)
            .expect("text1 is a prefix of text2, so a prefix match must exist");

        let strategy = PrefixReuseStrategy::default();
        assert!(strategy.can_reuse(&cached_state.states, text2, cached_text));
        assert!(strategy.reuse_point(&cached_state.states, text2, cached_text) > 0);
        assert!(strategy.reuse_quality(&cached_state.states, text2, cached_text) > 0.0);

        let stats = cache.stats();
        assert_eq!(stats.entry_count, 1);
    }

    /// Extracts twice with KV caching, feeding the first KV cache into the second call,
    /// and checks that the resulting KV cache is at least as large as the first.
    #[tokio::test]
    async fn test_kv_cache_workflow() {
        let provider = MockHiddenStateProvider::new("test-model", 4, 128);

        let (states1, kv1) = provider
            .extract_with_kv_cache("Hello, world!", None)
            .await
            .unwrap();
        assert_eq!(states1.model_id, "test-model");
        // One KV entry per model layer (provider was built with 4 layers).
        assert_eq!(kv1.layers.len(), 4);

        let (states2, kv2) = provider
            .extract_with_kv_cache("How are you?", Some(&kv1))
            .await
            .unwrap();
        assert_eq!(states2.model_id, "test-model");
        assert!(kv2.total_size_bytes() >= kv1.total_size_bytes());
    }

    /// Identical inputs must yield (near-)identical states; a different input must not.
    #[tokio::test]
    async fn test_state_similarity_integration() {
        let provider = MockHiddenStateProvider::new("test-model", 4, 128);

        let states1 = provider.extract_hidden_states("test input").await.unwrap();
        let states2 = provider.extract_hidden_states("test input").await.unwrap();
        let avg_sim = StateSimilarity::average_similarity(&states1, &states2);
        assert!((avg_sim - 1.0).abs() < 0.001);

        let states3 = provider.extract_hidden_states("different").await.unwrap();
        let avg_sim2 = StateSimilarity::average_similarity(&states1, &states3);
        assert!(avg_sim2 < 1.0);
    }

    /// `HybridReuseStrategy` should accept reuse when the new text extends the cached one.
    #[test]
    fn test_hybrid_strategy() {
        let strategy = HybridReuseStrategy::new();

        let mut states = ModelHiddenStates::new("test", 4, 64);
        states.sequence_length = 50;
        for i in 0..4 {
            let hidden = HiddenStateTensor::from_vec_1d(vec![0.5; 64 * 50]);
            states.add_layer(LayerHiddenState::new(i, hidden));
        }
        states.set_pooled_output(HiddenStateTensor::from_vec_1d(vec![0.5; 64]));

        // Argument order: (cached states, new text, cached text).
        let can_reuse = strategy.can_reuse(
            &states,
            "The quick brown fox jumps over the lazy dog.",
            "The quick brown fox",
        );
        assert!(can_reuse);
    }

    /// Basic tensor construction, sizing, slicing and concatenation.
    #[test]
    fn test_tensor_operations() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = TensorShape::new(vec![2, 3]);
        // Neither `data` nor `shape` is needed afterwards, so pass them by value.
        let tensor = HiddenStateTensor::from_vec(data, shape).unwrap();
        assert_eq!(tensor.numel(), 6);
        assert_eq!(tensor.shape.ndim(), 2);
        // 24 bytes = 6 elements x 4 bytes each.
        assert_eq!(tensor.size_bytes(), 24);

        // First row of the 2x3 tensor.
        let sliced = tensor.slice(0, 0, 1).unwrap();
        assert_eq!(sliced.shape.dims, vec![1, 3]);
        assert_eq!(sliced.data, vec![1.0, 2.0, 3.0]);

        let t1 = HiddenStateTensor::from_vec_1d(vec![1.0, 2.0]);
        let t2 = HiddenStateTensor::from_vec_1d(vec![3.0, 4.0]);
        let concat = HiddenStateTensor::concat(&[&t1, &t2], 0).unwrap();
        assert_eq!(concat.data, vec![1.0, 2.0, 3.0, 4.0]);
    }

    /// Layer selection helpers over an 8-layer model.
    #[test]
    fn test_layer_extractor() {
        let mut states = ModelHiddenStates::new("test", 8, 64);
        for i in 0..8 {
            let hidden = HiddenStateTensor::from_vec_1d(vec![i as f32; 64]);
            states.add_layer(LayerHiddenState::new(i, hidden));
        }

        // Last 3 of layers 0..8 are 5, 6, 7.
        let last_3 = LayerExtractor::extract_last_n(&states, 3);
        assert_eq!(last_3.len(), 3);
        assert_eq!(last_3[0].layer_idx, 5);
        assert_eq!(last_3[2].layer_idx, 7);

        // Every 2nd layer: 0, 2, 4, 6.
        let every_2 = LayerExtractor::extract_every_n(&states, 2);
        assert_eq!(every_2.len(), 4);
        assert_eq!(every_2[0].layer_idx, 0);
        assert_eq!(every_2[1].layer_idx, 2);

        // Middle 4 of 8 layers starts at layer 2.
        let middle = LayerExtractor::extract_middle(&states, 4);
        assert_eq!(middle.len(), 4);
        assert_eq!(middle[0].layer_idx, 2);
    }

    /// Pooling over a [1, 2, 4] tensor whose two token rows are [1,2,3,4] and [5,6,7,8].
    #[test]
    fn test_state_pooling() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = HiddenStateTensor::from_vec(data, TensorShape::new(vec![1, 2, 4])).unwrap();

        // Mean of column 0: (1 + 5) / 2 = 3.
        let mean_pooled = StatePooling::mean_pool(&tensor).unwrap();
        assert_eq!(mean_pooled.shape.dims, vec![4]);
        assert!((mean_pooled.data[0] - 3.0).abs() < 0.001);

        // Max of column 0: max(1, 5) = 5.
        let max_pooled = StatePooling::max_pool(&tensor).unwrap();
        assert_eq!(max_pooled.shape.dims, vec![4]);
        assert!((max_pooled.data[0] - 5.0).abs() < 0.001);

        // Expected to pick the first token's vector, whose element 0 is 1.
        let cls_pooled = StatePooling::cls_pool(&tensor).unwrap();
        assert_eq!(cls_pooled.shape.dims, vec![4]);
        assert!((cls_pooled.data[0] - 1.0).abs() < 0.001);
    }

    /// Inserting 5 entries into a cache configured for 3 must not grow it past 3.
    #[test]
    fn test_cache_eviction() {
        // Presumably (max entries = 3, byte limit effectively unbounded); TTL disabled.
        let config = HiddenStateCacheConfig::new(3, usize::MAX).without_ttl();
        let mut cache = HiddenStateCache::new(config);

        for i in 0..5 {
            let states = ModelHiddenStates::new("test", 2, 32);
            cache.put(format!("entry{i}"), states, None);
        }

        assert!(cache.len() <= 3);
        assert!(cache.stats().entry_count <= 3);
    }

    /// Compile-time check that every public re-export can be named and constructed.
    #[test]
    fn test_all_exports_accessible() {
        let _config = HiddenStateConfig::default();
        let _cache_config = HiddenStateCacheConfig::default();
        let _shape = TensorShape::new(vec![1, 2, 3]);
        let _dtype = DType::F32;
        let _device = Device::Cpu;
        let _tensor = HiddenStateTensor::default();
        let _states = ModelHiddenStates::new("test", 4, 64);
        let _kv = KVCache::new(0, 8, 64, 512);
        let _model_kv = ModelKVCache::new("test", 4, 8, 64, 512);
        let _cache = HiddenStateCache::with_defaults();
        let _prefix_strategy = PrefixReuseStrategy::default();
        let _semantic_strategy = SemanticReuseStrategy::default();
        let _hybrid_strategy = HybridReuseStrategy::default();
        let _length_strategy = LengthAwareReuseStrategy::default();
        let _adaptive_strategy = AdaptiveReuseStrategy::default();
    }

    /// `extract_and_cache` (from `HiddenStateProviderExt`) stores the result in the cache,
    /// and a repeated call with the same text still succeeds for the same model.
    #[tokio::test]
    async fn test_provider_ext_trait() {
        let provider = MockHiddenStateProvider::new("test-model", 4, 128);
        let mut cache = HiddenStateCache::with_defaults();

        let states = provider
            .extract_and_cache("test input", &mut cache)
            .await
            .unwrap();
        assert_eq!(states.model_id, "test-model");
        assert!(cache.contains("test input"));

        let states2 = provider
            .extract_and_cache("test input", &mut cache)
            .await
            .unwrap();
        assert_eq!(states2.model_id, states.model_id);
    }

    /// Layer lookup, last-layer access, prefix truncation and size accounting.
    #[test]
    fn test_model_hidden_states_operations() {
        let mut states = ModelHiddenStates::new("test", 4, 64);
        states.sequence_length = 10;
        for i in 0..4 {
            let hidden = HiddenStateTensor::zeros(
                TensorShape::new(vec![1, 10, 64]),
                DType::F32,
                Device::Cpu,
            );
            states.add_layer(LayerHiddenState::new(i, hidden));
        }

        // Layers 0..4 exist; index 5 is out of range.
        assert!(states.get_layer(0).is_some());
        assert!(states.get_layer(5).is_none());
        assert!(states.last_hidden_state().is_some());

        let prefix = states.prefix_states(5).unwrap();
        assert_eq!(prefix.sequence_length, 5);
        assert!(states.total_size_bytes() > 0);
    }

    /// A freshly constructed `KVCache` reports length 512; `clear` resets it to 0.
    #[test]
    fn test_kv_cache_operations() {
        let mut kv = KVCache::new(0, 8, 64, 512);
        assert_eq!(kv.current_length(), 512);

        kv.clear();
        assert_eq!(kv.current_length(), 0);

        // Only checks that `size_bytes` is callable on a cleared cache.
        let _ = kv.size_bytes();
    }
}