use crate::{EmbeddingModel, Vector};
use anyhow::Result;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::hash::Hash;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use tokio::task::JoinHandle;
use tracing::{debug, info, warn};
use uuid::Uuid;
/// Shared, thread-safe L3 cache: similarity query string -> ranked
/// (entity, score) results.
type SimilarityCache = Arc<RwLock<LRUCache<String, Vec<(String, f64)>>>>;
/// Three-level cache for embedding-model workloads:
///
/// * L1: entity name -> cached embedding
/// * L2: [`ComputationKey`] -> [`CachedComputation`]
/// * L3: similarity query string -> ranked (entity, score) results
pub struct CacheManager {
    // L1 embedding cache.
    l1_cache: Arc<RwLock<LRUCache<String, CachedEmbedding>>>,
    // L2 computation-result cache.
    l2_cache: Arc<RwLock<LRUCache<ComputationKey, CachedComputation>>>,
    // L3 similarity-result cache.
    l3_cache: SimilarityCache,
    // Tuning knobs; `adaptive_resize` mutates `l1_max_size` at runtime.
    config: CacheConfig,
    // Aggregate hit/miss/time-saved statistics shared with the cleanup task.
    stats: Arc<RwLock<CacheStats>>,
    // Handle of the background cleanup task spawned by `start`.
    cleanup_task: Option<JoinHandle<()>>,
    // Currently unused, as the `allow` attribute indicates — presumably
    // reserved for future warming policies.
    #[allow(dead_code)]
    warming_strategy: WarmingStrategy,
}
/// Tuning knobs for [`CacheManager`].
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum entry count of the L1 embedding cache.
    pub l1_max_size: usize,
    /// Maximum entry count of the L2 computation cache.
    pub l2_max_size: usize,
    /// Maximum entry count of the L3 similarity cache.
    pub l3_max_size: usize,
    /// Sliding per-entry TTL, in seconds, applied to all three levels.
    pub ttl_seconds: u64,
    /// When false, `warm_cache` is a no-op.
    pub enable_warming: bool,
    /// NOTE(review): never consulted in this file — eviction is always LRU.
    pub eviction_policy: EvictionPolicy,
    /// Period of the background cleanup task, in seconds.
    pub cleanup_interval_seconds: u64,
    /// NOTE(review): never read in this file; compression appears unimplemented.
    pub enable_compression: bool,
    /// Soft memory budget consulted by `adaptive_resize` before growing L1.
    pub max_memory_mb: usize,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
l1_max_size: 10_000,
l2_max_size: 50_000,
l3_max_size: 100_000,
ttl_seconds: 3600, enable_warming: true,
eviction_policy: EvictionPolicy::LRU,
cleanup_interval_seconds: 300, enable_compression: true,
max_memory_mb: 1024, }
}
}
/// Eviction strategy selector.
///
/// NOTE(review): only stored in [`CacheConfig`]; the caches in this file
/// always evict in LRU order regardless of this value — confirm intended
/// use elsewhere before relying on the other variants.
#[derive(Debug, Clone, Copy)]
pub enum EvictionPolicy {
    /// Least recently used.
    LRU,
    /// Least frequently used.
    LFU,
    /// Pure time-to-live expiry.
    TTL,
    /// Policy chosen dynamically at runtime.
    Adaptive,
}
/// Selection policy for which entities `warm_cache` should pre-load.
///
/// NOTE(review): the stored strategy is currently unused (see the
/// `#[allow(dead_code)]` field on `CacheManager`); variant payloads are
/// presumably entity counts — confirm when wiring this up.
#[derive(Debug, Clone)]
pub enum WarmingStrategy {
    /// Warm the N most frequently accessed entities.
    MostFrequent(usize),
    /// Warm entities seen in the N most recent queries.
    RecentQueries(usize),
    /// Warm the N most central entities of the graph.
    GraphCentrality(usize),
    /// Disable warming.
    None,
}
impl Default for WarmingStrategy {
fn default() -> Self {
WarmingStrategy::MostFrequent(1000)
}
}
/// An entity embedding stored in L1, together with access bookkeeping.
#[derive(Debug, Clone)]
pub struct CachedEmbedding {
    /// The cached vector itself.
    pub embedding: Vector,
    /// Insertion timestamp.
    pub cached_at: DateTime<Utc>,
    /// Wall-clock time of the most recent read.
    pub last_accessed: DateTime<Utc>,
    /// Number of reads recorded for this entry.
    pub access_count: u64,
    /// Payload size recorded at insert time
    /// (`values.len() * size_of::<f32>()`).
    pub size_bytes: usize,
    /// NOTE(review): never set to `true` anywhere in this file.
    pub is_compressed: bool,
}
/// Identity of one cached computation: the operation tag, its ordered
/// inputs, and the model that produced the result.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ComputationKey {
    /// Operation tag, e.g. "score_triple" or "attention_weights_<layer>".
    pub operation: String,
    /// Ordered inputs (entity names, input hashes, checkpoint names, ...).
    pub inputs: Vec<String>,
    /// Model the result belongs to; distinguishes identical operations
    /// across models.
    pub model_id: Uuid,
}
/// An L2 entry: one computation result plus bookkeeping.
#[derive(Debug, Clone)]
pub struct CachedComputation {
    /// The cached result payload.
    pub result: ComputationResult,
    /// Insertion timestamp.
    pub cached_at: DateTime<Utc>,
    /// NOTE(review): set at insert but never refreshed on reads in this file.
    pub last_accessed: DateTime<Utc>,
    /// NOTE(review): initialized to 1 and never incremented in this file.
    pub access_count: u64,
    /// Cost of the original computation in microseconds — i.e. what a cache
    /// hit saves.
    pub time_saved_us: u64,
}
/// Result payloads cacheable in L2; one variant per operation family
/// (matching the `cache_*`/`get_*` accessor pairs on `CacheManager`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ComputationResult {
    /// Plausibility score of a single (subject, predicate, object) triple.
    TripleScore(f64),
    /// Ranked (entity, similarity) pairs.
    EntitySimilarity(Vec<(String, f64)>),
    /// Ranked (entity, score) link-prediction results.
    PredictionResults(Vec<(String, f64)>),
    /// Attention weights for one layer/input.
    AttentionWeights(Vec<f64>),
    /// Activations of an intermediate layer.
    IntermediateActivations(Vec<f64>),
    /// Gradient rows for one batch.
    Gradients(Vec<Vec<f64>>),
    /// Weight matrices of a model checkpoint.
    ModelWeights(Vec<Vec<f64>>),
    /// Feature vector for a task input.
    FeatureVectors(Vec<f64>),
    /// Catch-all numeric result for ad-hoc operations.
    GenericResult(Vec<f64>),
    /// Batched embedding matrices.
    EmbeddingMatrices(Vec<Vec<f64>>),
    /// Loss values for an epoch/batch.
    LossValues(Vec<f64>),
}
/// Aggregate statistics across all three cache levels.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Hits summed over L1 + L2 + L3.
    pub total_hits: u64,
    /// Misses summed over L1 + L2 + L3.
    pub total_misses: u64,
    /// `total_hits / (total_hits + total_misses)`; refreshed by the
    /// background cleanup task.
    pub hit_rate: f64,
    /// NOTE(review): never written anywhere in this file — always 0.
    pub memory_usage_bytes: usize,
    /// Counters for the L1 embedding cache.
    pub l1_stats: LevelStats,
    /// Counters for the L2 computation cache.
    pub l2_stats: LevelStats,
    /// Counters for the L3 similarity cache.
    pub l3_stats: LevelStats,
    /// Accumulated estimate of compute time avoided by cache hits.
    pub total_time_saved_seconds: f64,
}
/// Hit/miss and occupancy counters for a single cache level.
#[derive(Debug, Clone)]
pub struct LevelStats {
    /// Lookups that found a live (non-expired) entry.
    pub hits: u64,
    /// Lookups that missed or found an expired entry.
    pub misses: u64,
    /// Entry count as of the last cleanup sweep.
    pub size: usize,
    /// Configured maximum entry count.
    pub capacity: usize,
    /// NOTE(review): never written anywhere in this file — always 0.
    pub memory_bytes: usize,
}
impl Default for CacheStats {
    /// All counters start at zero; the three per-level blocks are identical.
    fn default() -> Self {
        let zeroed = LevelStats {
            hits: 0,
            misses: 0,
            size: 0,
            capacity: 0,
            memory_bytes: 0,
        };
        Self {
            total_hits: 0,
            total_misses: 0,
            hit_rate: 0.0,
            memory_usage_bytes: 0,
            l1_stats: zeroed.clone(),
            l2_stats: zeroed.clone(),
            l3_stats: zeroed,
            total_time_saved_seconds: 0.0,
        }
    }
}
/// A least-recently-used cache with a sliding per-entry TTL.
///
/// Recency lives in `order` (most recent at the front); `access_times`
/// drives TTL expiry and is refreshed on both reads and writes. Note that
/// `move_to_front` and `remove` scan `order` linearly, so mutations are
/// O(n) in the entry count — adequate for moderate capacities.
pub struct LRUCache<K, V>
where
    K: Clone + Eq + Hash,
    V: Clone,
{
    // Maximum number of entries; 0 means the cache stores nothing.
    capacity: usize,
    // Key -> value storage.
    map: HashMap<K, V>,
    // Keys ordered most-recently-used first.
    order: VecDeque<K>,
    // Last access instant per key, used for TTL checks.
    access_times: HashMap<K, Instant>,
    // Sliding time-to-live applied to every entry.
    ttl: Duration,
}
impl<K, V> LRUCache<K, V>
where
    K: Clone + Eq + Hash,
    V: Clone,
{
    /// Creates an empty cache with the given capacity and TTL.
    pub fn new(capacity: usize, ttl: Duration) -> Self {
        Self {
            capacity,
            map: HashMap::new(),
            order: VecDeque::new(),
            access_times: HashMap::new(),
            ttl,
        }
    }
    /// Returns a clone of the value for `key`, refreshing its recency and
    /// TTL clock. An expired entry is removed and reported as a miss.
    pub fn get(&mut self, key: &K) -> Option<V> {
        if let Some(access_time) = self.access_times.get(key) {
            if access_time.elapsed() > self.ttl {
                self.remove(key);
                return None;
            }
        }
        match self.map.get(key).cloned() {
            Some(value) => {
                self.move_to_front(key);
                self.access_times.insert(key.clone(), Instant::now());
                Some(value)
            }
            _ => None,
        }
    }
    /// Inserts or replaces an entry, evicting the least-recently-used entry
    /// when full.
    ///
    /// Fix: a zero-capacity cache now stores nothing. Previously eviction
    /// ran before insertion, so one entry was always retained and
    /// `len() == 1 > capacity` for `capacity == 0`.
    pub fn put(&mut self, key: K, value: V) {
        if self.capacity == 0 {
            return;
        }
        if self.map.contains_key(&key) {
            self.map.insert(key.clone(), value);
            self.move_to_front(&key);
        } else {
            if self.map.len() >= self.capacity {
                self.evict_lru();
            }
            self.map.insert(key.clone(), value);
            self.order.push_front(key.clone());
        }
        self.access_times.insert(key, Instant::now());
    }
    /// Removes an entry, returning its value if it was present.
    pub fn remove(&mut self, key: &K) -> Option<V> {
        match self.map.remove(key) {
            Some(value) => {
                self.order.retain(|k| k != key);
                self.access_times.remove(key);
                Some(value)
            }
            _ => None,
        }
    }
    /// Drops every entry.
    pub fn clear(&mut self) {
        self.map.clear();
        self.order.clear();
        self.access_times.clear();
    }
    /// Current number of entries.
    pub fn len(&self) -> usize {
        self.map.len()
    }
    /// True when the cache holds no entries.
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }
    /// Marks `key` as most recently used (O(n) scan of `order`).
    fn move_to_front(&mut self, key: &K) {
        self.order.retain(|k| k != key);
        self.order.push_front(key.clone());
    }
    /// Evicts the least-recently-used entry, if any.
    fn evict_lru(&mut self) {
        if let Some(key) = self.order.pop_back() {
            self.map.remove(&key);
            self.access_times.remove(&key);
        }
    }
    /// Removes every TTL-expired entry and returns how many were dropped.
    pub fn cleanup_expired(&mut self) -> usize {
        let now = Instant::now();
        let mut expired_keys = Vec::new();
        for (key, access_time) in &self.access_times {
            if now.duration_since(*access_time) > self.ttl {
                expired_keys.push(key.clone());
            }
        }
        let count = expired_keys.len();
        for key in expired_keys {
            self.remove(&key);
        }
        count
    }
}
impl CacheManager {
/// Builds a manager with three empty caches sized from `config`; the
/// cleanup task is not started until [`CacheManager::start`] is called.
pub fn new(config: CacheConfig) -> Self {
    let ttl = Duration::from_secs(config.ttl_seconds);
    let l1 = LRUCache::new(config.l1_max_size, ttl);
    let l2 = LRUCache::new(config.l2_max_size, ttl);
    let l3 = LRUCache::new(config.l3_max_size, ttl);
    Self {
        l1_cache: Arc::new(RwLock::new(l1)),
        l2_cache: Arc::new(RwLock::new(l2)),
        l3_cache: Arc::new(RwLock::new(l3)),
        stats: Arc::new(RwLock::new(CacheStats::default())),
        cleanup_task: None,
        warming_strategy: WarmingStrategy::default(),
        config,
    }
}
/// Spawns a background task that runs every `cleanup_interval_seconds`:
/// it drops TTL-expired entries from all three levels and refreshes the
/// per-level sizes and the aggregate hit rate.
///
/// NOTE(review): calling `start` twice overwrites `cleanup_task`, leaving
/// the first task running detached — confirm callers only start once.
pub async fn start(&mut self) -> Result<()> {
    let cleanup_interval = Duration::from_secs(self.config.cleanup_interval_seconds);
    // Clone the shared handles so the spawned task owns its own references.
    let l1_cache = Arc::clone(&self.l1_cache);
    let l2_cache = Arc::clone(&self.l2_cache);
    let l3_cache = Arc::clone(&self.l3_cache);
    let stats = Arc::clone(&self.stats);
    let cleanup_task = tokio::spawn(async move {
        let mut interval = tokio::time::interval(cleanup_interval);
        loop {
            interval.tick().await;
            // Each level is locked separately and briefly so readers are
            // never blocked for the whole sweep.
            let expired_l1 = {
                let mut cache = l1_cache.write().expect("lock poisoned");
                cache.cleanup_expired()
            };
            let expired_l2 = {
                let mut cache = l2_cache.write().expect("lock poisoned");
                cache.cleanup_expired()
            };
            let expired_l3 = {
                let mut cache = l3_cache.write().expect("lock poisoned");
                cache.cleanup_expired()
            };
            let total_expired = expired_l1 + expired_l2 + expired_l3;
            if total_expired > 0 {
                debug!("Cleaned up {} expired cache entries", total_expired);
            }
            // Refresh derived statistics after the sweep.
            {
                let mut stats = stats.write().expect("lock poisoned");
                stats.l1_stats.size = l1_cache.read().expect("lock poisoned").len();
                stats.l2_stats.size = l2_cache.read().expect("lock poisoned").len();
                stats.l3_stats.size = l3_cache.read().expect("lock poisoned").len();
                let total_requests = stats.total_hits + stats.total_misses;
                if total_requests > 0 {
                    stats.hit_rate = stats.total_hits as f64 / total_requests as f64;
                }
            }
        }
    });
    self.cleanup_task = Some(cleanup_task);
    info!(
        "Cache manager started with cleanup interval: {:?}",
        cleanup_interval
    );
    Ok(())
}
/// Aborts the background cleanup task, if one is running.
pub async fn stop(&mut self) {
    let Some(task) = self.cleanup_task.take() else {
        return;
    };
    task.abort();
    info!("Cache manager stopped");
}
/// Looks up the L1 embedding for `entity`, updating hit/miss statistics.
///
/// Fix: the original mutated `last_accessed`/`access_count` on a clone
/// that was immediately discarded, so per-entry counters never advanced.
/// The refreshed metadata is now written back into the cache.
pub fn get_embedding(&self, entity: &str) -> Option<Vector> {
    let start = Instant::now();
    let key = entity.to_string();
    let result = {
        let mut cache = self.l1_cache.write().expect("lock poisoned");
        cache.get(&key)
    };
    {
        let mut stats = self.stats.write().expect("lock poisoned");
        if result.is_some() {
            stats.total_hits += 1;
            stats.l1_stats.hits += 1;
            let time_saved = start.elapsed().as_micros() as f64 / 1_000_000.0;
            stats.total_time_saved_seconds += time_saved;
        } else {
            stats.total_misses += 1;
            stats.l1_stats.misses += 1;
        }
    }
    result.map(|mut cached| {
        // Persist the refreshed access metadata; `put` on an existing key
        // replaces the value and keeps the entry most-recently-used.
        cached.last_accessed = Utc::now();
        cached.access_count += 1;
        let embedding = cached.embedding.clone();
        self.l1_cache
            .write()
            .expect("lock poisoned")
            .put(key, cached);
        embedding
    })
}
/// Stores `embedding` in L1 under `entity`, recording its payload size,
/// and keeps the reported L1 capacity in sync with the config.
pub fn put_embedding(&self, entity: String, embedding: Vector) {
    let entry = CachedEmbedding {
        size_bytes: embedding.values.len() * std::mem::size_of::<f32>(),
        embedding,
        cached_at: Utc::now(),
        last_accessed: Utc::now(),
        access_count: 1,
        is_compressed: false,
    };
    self.l1_cache
        .write()
        .expect("lock poisoned")
        .put(entity, entry);
    self.stats
        .write()
        .expect("lock poisoned")
        .l1_stats
        .capacity = self.config.l1_max_size;
}
pub fn get_computation(&self, key: &ComputationKey) -> Option<ComputationResult> {
let start = Instant::now();
let result = {
let mut cache = self.l2_cache.write().expect("lock poisoned");
cache.get(key)
};
{
let mut stats = self.stats.write().expect("lock poisoned");
if result.is_some() {
stats.total_hits += 1;
stats.l2_stats.hits += 1;
let time_saved = start.elapsed().as_micros() as f64 / 1_000_000.0;
stats.total_time_saved_seconds += time_saved;
} else {
stats.total_misses += 1;
stats.l2_stats.misses += 1;
}
}
result.map(|cached| cached.result)
}
/// Stores a computation result in L2 along with the time it took to
/// produce, so later hits can report realistic savings.
pub fn put_computation(
    &self,
    key: ComputationKey,
    result: ComputationResult,
    computation_time_us: u64,
) {
    let entry = CachedComputation {
        result,
        cached_at: Utc::now(),
        last_accessed: Utc::now(),
        access_count: 1,
        time_saved_us: computation_time_us,
    };
    self.l2_cache
        .write()
        .expect("lock poisoned")
        .put(key, entry);
}
/// Looks up cached similarity results for `query`, updating L3 hit/miss
/// statistics.
pub fn get_similarity_cache(&self, query: &str) -> Option<Vec<(String, f64)>> {
    let lookup_start = Instant::now();
    let hit = self
        .l3_cache
        .write()
        .expect("lock poisoned")
        .get(&query.to_string());
    let mut stats = self.stats.write().expect("lock poisoned");
    if hit.is_some() {
        stats.total_hits += 1;
        stats.l3_stats.hits += 1;
        stats.total_time_saved_seconds +=
            lookup_start.elapsed().as_micros() as f64 / 1_000_000.0;
    } else {
        stats.total_misses += 1;
        stats.l3_stats.misses += 1;
    }
    drop(stats);
    hit
}
/// Stores ranked similarity results for `query` in L3.
pub fn put_similarity_cache(&self, query: String, results: Vec<(String, f64)>) {
    self.l3_cache
        .write()
        .expect("lock poisoned")
        .put(query, results);
}
/// Pre-loads L1 with embeddings for `entities`, skipping those already
/// cached. Returns how many embeddings were fetched from the model.
/// A no-op (returning 0) when warming is disabled in the config.
pub async fn warm_cache(
    &self,
    model: &dyn EmbeddingModel,
    entities: Vec<String>,
) -> Result<usize> {
    if !self.config.enable_warming {
        return Ok(0);
    }
    info!(
        "Starting cache warming with {entities_len} entities",
        entities_len = entities.len()
    );
    let mut warmed_count = 0usize;
    for entity in entities {
        // Skip entities already resident in L1.
        if self.get_embedding(&entity).is_none() {
            match model.get_entity_embedding(&entity) {
                Ok(embedding) => {
                    self.put_embedding(entity, embedding);
                    warmed_count += 1;
                }
                Err(e) => warn!("Failed to warm cache for entity {entity}: {e}"),
            }
        }
    }
    info!("Cache warming completed: {warmed_count} entities cached");
    Ok(warmed_count)
}
/// Runs `predict_objects` (top-10) for each (subject, predicate) pair and
/// caches the results in L2, skipping pairs that are already cached.
/// Returns the number of newly cached predictions.
pub async fn precompute_common_operations(
    &self,
    model: &dyn EmbeddingModel,
    common_queries: Vec<(String, String)>,
) -> Result<usize> {
    info!(
        "Starting precomputation for {} common queries",
        common_queries.len()
    );
    let mut precomputed_count = 0usize;
    for (subject, predicate) in common_queries {
        let key = ComputationKey {
            operation: "predict_objects".to_string(),
            inputs: vec![subject.clone(), predicate.clone()],
            model_id: *model.model_id(),
        };
        // Already cached — nothing to do for this pair.
        if self.get_computation(&key).is_some() {
            continue;
        }
        let timer = Instant::now();
        match model.predict_objects(&subject, &predicate, 10) {
            Ok(predictions) => {
                self.put_computation(
                    key,
                    ComputationResult::PredictionResults(predictions),
                    timer.elapsed().as_micros() as u64,
                );
                precomputed_count += 1;
            }
            Err(e) => {
                warn!(
                    "Failed to precompute prediction for ({}, {}): {}",
                    subject, predicate, e
                );
            }
        }
    }
    info!(
        "Precomputation completed: {} operations cached",
        precomputed_count
    );
    Ok(precomputed_count)
}
/// Returns a snapshot of the current aggregate statistics.
pub fn get_stats(&self) -> CacheStats {
    let guard = self.stats.read().expect("lock poisoned");
    guard.clone()
}
/// Empties all three cache levels and resets statistics to zero.
pub fn clear_all(&self) {
    self.l1_cache.write().expect("lock poisoned").clear();
    self.l2_cache.write().expect("lock poisoned").clear();
    self.l3_cache.write().expect("lock poisoned").clear();
    *self.stats.write().expect("lock poisoned") = CacheStats::default();
    info!("All caches cleared");
}
/// Rough estimate of bytes held across the three levels.
///
/// Fix: for L1 the previous shallow `size_of::<CachedEmbedding>()` ignored
/// the heap-allocated vector data entirely, badly underestimating usage
/// even though each entry records its payload size in `size_bytes`. L1 now
/// sums those recorded payloads plus per-entry struct overhead. L2/L3 do
/// not track payload sizes, so they remain shallow struct-size estimates
/// (heap data of results is still not counted there).
pub fn estimate_memory_usage(&self) -> usize {
    let l1_size = {
        let cache = self.l1_cache.read().expect("lock poisoned");
        cache
            .map
            .values()
            .map(|entry| entry.size_bytes + std::mem::size_of::<CachedEmbedding>())
            .sum::<usize>()
    };
    let l2_size = {
        let cache = self.l2_cache.read().expect("lock poisoned");
        cache.len() * std::mem::size_of::<CachedComputation>()
    };
    let l3_size = {
        let cache = self.l3_cache.read().expect("lock poisoned");
        cache.len() * std::mem::size_of::<Vec<(String, f64)>>()
    };
    l1_size + l2_size + l3_size
}
/// Caches attention weights for one layer/input pair in L2.
pub fn cache_attention_weights(
    &self,
    layer_id: &str,
    input_hash: &str,
    model_id: Uuid,
    attention_weights: Vec<f64>,
    computation_time_us: u64,
) {
    self.put_computation(
        ComputationKey {
            operation: format!("attention_weights_{layer_id}"),
            inputs: vec![input_hash.to_string()],
            model_id,
        },
        ComputationResult::AttentionWeights(attention_weights),
        computation_time_us,
    );
    debug!(
        "Cached attention weights for layer {} (input: {})",
        layer_id, input_hash
    );
}
/// Retrieves cached attention weights, or `None` on a miss or a key that
/// maps to a different result variant.
pub fn get_attention_weights(
    &self,
    layer_id: &str,
    input_hash: &str,
    model_id: Uuid,
) -> Option<Vec<f64>> {
    let key = ComputationKey {
        operation: format!("attention_weights_{layer_id}"),
        inputs: vec![input_hash.to_string()],
        model_id,
    };
    if let ComputationResult::AttentionWeights(weights) = self.get_computation(&key)? {
        debug!(
            "Cache hit for attention weights layer {} (input: {})",
            layer_id, input_hash
        );
        Some(weights)
    } else {
        None
    }
}
/// Caches intermediate-layer activations for one layer/input pair in L2.
pub fn cache_intermediate_activations(
    &self,
    layer_id: &str,
    input_hash: &str,
    model_id: Uuid,
    activations: Vec<f64>,
    computation_time_us: u64,
) {
    self.put_computation(
        ComputationKey {
            operation: format!("intermediate_activations_{layer_id}"),
            inputs: vec![input_hash.to_string()],
            model_id,
        },
        ComputationResult::IntermediateActivations(activations),
        computation_time_us,
    );
    debug!(
        "Cached intermediate activations for layer {} (input: {})",
        layer_id, input_hash
    );
}
/// Retrieves cached intermediate activations, or `None` on a miss or a
/// variant mismatch.
pub fn get_intermediate_activations(
    &self,
    layer_id: &str,
    input_hash: &str,
    model_id: Uuid,
) -> Option<Vec<f64>> {
    let key = ComputationKey {
        operation: format!("intermediate_activations_{layer_id}"),
        inputs: vec![input_hash.to_string()],
        model_id,
    };
    if let ComputationResult::IntermediateActivations(activations) = self.get_computation(&key)? {
        debug!(
            "Cache hit for intermediate activations layer {} (input: {})",
            layer_id, input_hash
        );
        Some(activations)
    } else {
        None
    }
}
/// Caches per-layer gradients for one batch in L2.
pub fn cache_gradients(
    &self,
    layer_id: &str,
    batch_hash: &str,
    model_id: Uuid,
    gradients: Vec<Vec<f64>>,
    computation_time_us: u64,
) {
    self.put_computation(
        ComputationKey {
            operation: format!("gradients_{layer_id}"),
            inputs: vec![batch_hash.to_string()],
            model_id,
        },
        ComputationResult::Gradients(gradients),
        computation_time_us,
    );
    debug!(
        "Cached gradients for layer {} (batch: {})",
        layer_id, batch_hash
    );
}
/// Retrieves cached gradients, or `None` on a miss or a variant mismatch.
pub fn get_gradients(
    &self,
    layer_id: &str,
    batch_hash: &str,
    model_id: Uuid,
) -> Option<Vec<Vec<f64>>> {
    let key = ComputationKey {
        operation: format!("gradients_{layer_id}"),
        inputs: vec![batch_hash.to_string()],
        model_id,
    };
    if let ComputationResult::Gradients(gradients) = self.get_computation(&key)? {
        debug!(
            "Cache hit for gradients layer {} (batch: {})",
            layer_id, batch_hash
        );
        Some(gradients)
    } else {
        None
    }
}
/// Caches a checkpoint's weight matrices in L2, keyed by model name and
/// checkpoint identifier.
pub fn cache_model_weights(
    &self,
    model_name: &str,
    checkpoint: &str,
    model_id: Uuid,
    weights: Vec<Vec<f64>>,
    computation_time_us: u64,
) {
    self.put_computation(
        ComputationKey {
            operation: "model_weights".to_string(),
            inputs: vec![model_name.to_string(), checkpoint.to_string()],
            model_id,
        },
        ComputationResult::ModelWeights(weights),
        computation_time_us,
    );
    info!(
        "Cached model weights for {} (checkpoint: {})",
        model_name, checkpoint
    );
}
/// Retrieves cached checkpoint weights, or `None` on a miss or a variant
/// mismatch.
pub fn get_model_weights(
    &self,
    model_name: &str,
    checkpoint: &str,
    model_id: Uuid,
) -> Option<Vec<Vec<f64>>> {
    let key = ComputationKey {
        operation: "model_weights".to_string(),
        inputs: vec![model_name.to_string(), checkpoint.to_string()],
        model_id,
    };
    if let ComputationResult::ModelWeights(weights) = self.get_computation(&key)? {
        info!(
            "Cache hit for model weights {} (checkpoint: {})",
            model_name, checkpoint
        );
        Some(weights)
    } else {
        None
    }
}
/// Caches a task's feature vector for one input in L2.
pub fn cache_feature_vectors(
    &self,
    task_name: &str,
    input_hash: &str,
    model_id: Uuid,
    features: Vec<f64>,
    computation_time_us: u64,
) {
    self.put_computation(
        ComputationKey {
            operation: format!("feature_vectors_{task_name}"),
            inputs: vec![input_hash.to_string()],
            model_id,
        },
        ComputationResult::FeatureVectors(features),
        computation_time_us,
    );
    debug!(
        "Cached feature vectors for task {} (input: {})",
        task_name, input_hash
    );
}
/// Retrieves cached feature vectors, or `None` on a miss or a variant
/// mismatch.
pub fn get_feature_vectors(
    &self,
    task_name: &str,
    input_hash: &str,
    model_id: Uuid,
) -> Option<Vec<f64>> {
    let key = ComputationKey {
        operation: format!("feature_vectors_{task_name}"),
        inputs: vec![input_hash.to_string()],
        model_id,
    };
    if let ComputationResult::FeatureVectors(features) = self.get_computation(&key)? {
        debug!(
            "Cache hit for feature vectors task {} (input: {})",
            task_name, input_hash
        );
        Some(features)
    } else {
        None
    }
}
/// Caches batched embedding matrices for one operation/batch pair in L2.
pub fn cache_embedding_matrices(
    &self,
    operation: &str,
    batch_hash: &str,
    model_id: Uuid,
    matrices: Vec<Vec<f64>>,
    computation_time_us: u64,
) {
    self.put_computation(
        ComputationKey {
            operation: format!("embedding_matrices_{operation}"),
            inputs: vec![batch_hash.to_string()],
            model_id,
        },
        ComputationResult::EmbeddingMatrices(matrices),
        computation_time_us,
    );
    debug!(
        "Cached embedding matrices for {} (batch: {})",
        operation, batch_hash
    );
}
/// Retrieves cached embedding matrices, or `None` on a miss or a variant
/// mismatch.
pub fn get_embedding_matrices(
    &self,
    operation: &str,
    batch_hash: &str,
    model_id: Uuid,
) -> Option<Vec<Vec<f64>>> {
    let key = ComputationKey {
        operation: format!("embedding_matrices_{operation}"),
        inputs: vec![batch_hash.to_string()],
        model_id,
    };
    if let ComputationResult::EmbeddingMatrices(matrices) = self.get_computation(&key)? {
        debug!(
            "Cache hit for embedding matrices {} (batch: {})",
            operation, batch_hash
        );
        Some(matrices)
    } else {
        None
    }
}
/// Caches loss values for one loss-type/epoch-batch pair in L2.
pub fn cache_loss_values(
    &self,
    loss_type: &str,
    epoch_batch: &str,
    model_id: Uuid,
    losses: Vec<f64>,
    computation_time_us: u64,
) {
    self.put_computation(
        ComputationKey {
            operation: format!("loss_values_{loss_type}"),
            inputs: vec![epoch_batch.to_string()],
            model_id,
        },
        ComputationResult::LossValues(losses),
        computation_time_us,
    );
    debug!(
        "Cached loss values for {} (epoch/batch: {})",
        loss_type, epoch_batch
    );
}
/// Retrieves cached loss values, or `None` on a miss or a variant mismatch.
pub fn get_loss_values(
    &self,
    loss_type: &str,
    epoch_batch: &str,
    model_id: Uuid,
) -> Option<Vec<f64>> {
    let key = ComputationKey {
        operation: format!("loss_values_{loss_type}"),
        inputs: vec![epoch_batch.to_string()],
        model_id,
    };
    if let ComputationResult::LossValues(losses) = self.get_computation(&key)? {
        debug!(
            "Cache hit for loss values {} (epoch/batch: {})",
            loss_type, epoch_batch
        );
        Some(losses)
    } else {
        None
    }
}
/// Caches an ad-hoc numeric result under a caller-supplied operation tag.
pub fn cache_generic_result(
    &self,
    operation: &str,
    input_hash: &str,
    model_id: Uuid,
    result: Vec<f64>,
    computation_time_us: u64,
) {
    self.put_computation(
        ComputationKey {
            operation: operation.to_string(),
            inputs: vec![input_hash.to_string()],
            model_id,
        },
        ComputationResult::GenericResult(result),
        computation_time_us,
    );
    debug!(
        "Cached generic result for {} (input: {})",
        operation, input_hash
    );
}
/// Retrieves a cached generic result, or `None` on a miss or a variant
/// mismatch.
pub fn get_generic_result(
    &self,
    operation: &str,
    input_hash: &str,
    model_id: Uuid,
) -> Option<Vec<f64>> {
    let key = ComputationKey {
        operation: operation.to_string(),
        inputs: vec![input_hash.to_string()],
        model_id,
    };
    if let ComputationResult::GenericResult(result) = self.get_computation(&key)? {
        debug!(
            "Cache hit for generic result {} (input: {})",
            operation, input_hash
        );
        Some(result)
    } else {
        None
    }
}
pub fn clear_computation_cache(&self, operation_prefix: &str) -> usize {
let mut removed_count = 0;
{
let mut cache = self.l2_cache.write().expect("lock poisoned");
let keys_to_remove: Vec<_> = cache
.map
.keys()
.filter(|key| key.operation.starts_with(operation_prefix))
.cloned()
.collect();
for key in keys_to_remove {
cache.remove(&key);
removed_count += 1;
}
}
info!(
"Cleared {} cache entries for operation: {}",
removed_count, operation_prefix
);
removed_count
}
/// Returns an approximate hit rate per operation *type* — the prefix of
/// `ComputationKey::operation` before the first `_`.
///
/// NOTE(review): per-operation hits are not actually tracked. The `0.8`
/// below is a hard-coded placeholder, so every operation type currently
/// present in L2 reports a fixed 0.8 "hit rate". Wire this to real
/// counters before relying on the returned values.
pub fn get_cache_hit_rates(&self) -> HashMap<String, f64> {
    let mut hit_rates = HashMap::new();
    let cache = self.l2_cache.read().expect("lock poisoned");
    let mut operation_stats = HashMap::new();
    for key in cache.map.keys() {
        // Group keys by the leading token, e.g. "attention" from
        // "attention_weights_layer1".
        let operation_type = key.operation.split('_').next().unwrap_or("unknown");
        let entry = operation_stats
            .entry(operation_type.to_string())
            .or_insert((0u64, 0u64));
        // Only the total (entry.0) is ever incremented; the hit count
        // (entry.1) stays 0 and is ignored below.
        entry.0 += 1;
    }
    for (operation, (total, _hits)) in operation_stats {
        // Placeholder rate: `total` is always >= 1 for a present type.
        let hit_rate = if total > 0 { 0.8 } else { 0.0 };
        hit_rates.insert(operation, hit_rate);
    }
    hit_rates
}
/// Grows or shrinks the L1 limit based on observed hit/miss ratios:
/// grow 20% when hits dominate (and memory headroom remains), shrink 20%
/// when misses dominate.
///
/// Fix: the original only mutated `config.l1_max_size`; the live
/// `LRUCache` kept its construction-time capacity, making this method a
/// no-op in practice. The new limit is now propagated to the cache. The
/// shrink path is also clamped to at least 1 so repeated shrinks cannot
/// drive the capacity to zero.
pub fn adaptive_resize(&mut self) {
    let stats = self.get_stats();
    if stats.l1_stats.hits > stats.l1_stats.misses * 2
        && stats.memory_usage_bytes < self.config.max_memory_mb * 1024 * 1024 / 2
    {
        self.config.l1_max_size = (self.config.l1_max_size as f64 * 1.2) as usize;
        info!("Increased L1 cache size to {}", self.config.l1_max_size);
    } else if stats.l1_stats.misses > stats.l1_stats.hits * 2 {
        self.config.l1_max_size = ((self.config.l1_max_size as f64 * 0.8) as usize).max(1);
        info!("Decreased L1 cache size to {}", self.config.l1_max_size);
    }
    // Apply the (possibly unchanged) limit to the live cache. Note the
    // cache evicts lazily: a shrink takes effect on subsequent `put`s.
    self.l1_cache.write().expect("lock poisoned").capacity = self.config.l1_max_size;
}
/// Bulk-inserts precomputed results into L2. Entries cached this way carry
/// a zero `time_saved_us` since no computation time was measured.
pub fn batch_cache_computations(&self, computations: Vec<(ComputationKey, ComputationResult)>) {
    let count = computations.len();
    computations
        .into_iter()
        .for_each(|(key, result)| self.put_computation(key, result, 0));
    info!("Batch cached {count} computation results");
}
/// Returns (hits, misses) per computation type.
///
/// NOTE(review): this is a stub — all counters are hard-coded to (0, 0)
/// and nothing in this file ever updates them. Replace with real
/// per-operation tracking before exposing these numbers.
pub fn get_computation_type_stats(&self) -> HashMap<String, (u64, u64)> {
    let mut type_stats = HashMap::new();
    type_stats.insert("attention_weights".to_string(), (0, 0));
    type_stats.insert("gradients".to_string(), (0, 0));
    type_stats.insert("model_weights".to_string(), (0, 0));
    type_stats.insert("intermediate_activations".to_string(), (0, 0));
    type_stats.insert("feature_vectors".to_string(), (0, 0));
    type_stats
}
}
/// Decorator that routes an [`EmbeddingModel`]'s expensive calls through a
/// shared [`CacheManager`].
pub struct CachedEmbeddingModel {
    // The wrapped model that performs the real computations on a miss.
    model: Box<dyn EmbeddingModel>,
    // Shared cache consulted before delegating to `model`.
    cache_manager: Arc<CacheManager>,
}
impl CachedEmbeddingModel {
/// Wraps `model` with the given shared cache.
pub fn new(model: Box<dyn EmbeddingModel>, cache_manager: Arc<CacheManager>) -> Self {
    Self { model, cache_manager }
}
/// Returns the entity embedding, serving from L1 when possible and
/// populating the cache after a model call on a miss.
pub fn get_entity_embedding_cached(&self, entity: &str) -> Result<Vector> {
    match self.cache_manager.get_embedding(entity) {
        Some(hit) => Ok(hit),
        None => {
            let embedding = self.model.get_entity_embedding(entity)?;
            self.cache_manager
                .put_embedding(entity.to_string(), embedding.clone());
            Ok(embedding)
        }
    }
}
/// Scores a triple, serving from L2 when a matching `score_triple` entry
/// exists; otherwise computes, records the elapsed time, and caches.
pub fn score_triple_cached(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
    let key = ComputationKey {
        operation: "score_triple".to_string(),
        inputs: vec![
            subject.to_string(),
            predicate.to_string(),
            object.to_string(),
        ],
        model_id: *self.model.model_id(),
    };
    match self.cache_manager.get_computation(&key) {
        Some(ComputationResult::TripleScore(score)) => Ok(score),
        _ => {
            let timer = Instant::now();
            let score = self.model.score_triple(subject, predicate, object)?;
            self.cache_manager.put_computation(
                key,
                ComputationResult::TripleScore(score),
                timer.elapsed().as_micros() as u64,
            );
            Ok(score)
        }
    }
}
/// Predicts the top-`k` objects for (subject, predicate), serving from L2
/// (keyed per `k`) when possible; otherwise computes, times, and caches.
pub fn predict_objects_cached(
    &self,
    subject: &str,
    predicate: &str,
    k: usize,
) -> Result<Vec<(String, f64)>> {
    let key = ComputationKey {
        operation: format!("predict_objects_{k}"),
        inputs: vec![subject.to_string(), predicate.to_string()],
        model_id: *self.model.model_id(),
    };
    match self.cache_manager.get_computation(&key) {
        Some(ComputationResult::PredictionResults(predictions)) => Ok(predictions),
        _ => {
            let timer = Instant::now();
            let predictions = self.model.predict_objects(subject, predicate, k)?;
            self.cache_manager.put_computation(
                key,
                ComputationResult::PredictionResults(predictions.clone()),
                timer.elapsed().as_micros() as u64,
            );
            Ok(predictions)
        }
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    /// LRU basics: insertion, lookup, eviction at capacity, and size.
    #[test]
    fn test_lru_cache_basic() {
        let mut cache = LRUCache::new(3, Duration::from_secs(60));
        cache.put("a".to_string(), 1);
        cache.put("b".to_string(), 2);
        cache.put("c".to_string(), 3);
        assert_eq!(cache.get(&"a".to_string()), Some(1));
        assert_eq!(cache.get(&"b".to_string()), Some(2));
        assert_eq!(cache.get(&"c".to_string()), Some(3));
        assert_eq!(cache.len(), 3);
        // After the three gets above, "a" is least recently used, so a
        // fourth insert must evict it while keeping len at capacity.
        cache.put("d".to_string(), 4);
        assert_eq!(cache.len(), 3);
        assert_eq!(cache.get(&"a".to_string()), None);
        assert_eq!(cache.get(&"d".to_string()), Some(4));
    }
    /// The default config carries the documented sizes and TTL.
    #[test]
    fn test_cache_config_default() {
        let config = CacheConfig::default();
        assert_eq!(config.l1_max_size, 10_000);
        assert_eq!(config.l2_max_size, 50_000);
        assert_eq!(config.l3_max_size, 100_000);
        assert_eq!(config.ttl_seconds, 3600);
        assert!(config.enable_warming);
    }
    /// End-to-end round trips through L1 (embeddings) and L2 (computations).
    #[tokio::test]
    async fn test_cache_manager_basic() {
        let config = CacheConfig {
            l1_max_size: 100,
            l2_max_size: 100,
            l3_max_size: 100,
            ..Default::default()
        };
        let cache_manager = CacheManager::new(config);
        let embedding = Vector::new(vec![1.0, 2.0, 3.0]);
        cache_manager.put_embedding("test_entity".to_string(), embedding.clone());
        let cached = cache_manager.get_embedding("test_entity");
        assert!(cached.is_some());
        assert_eq!(cached.expect("should succeed").values, embedding.values);
        let key = ComputationKey {
            operation: "test_op".to_string(),
            inputs: vec!["input1".to_string()],
            model_id: Uuid::new_v4(),
        };
        let result = ComputationResult::TripleScore(0.85);
        cache_manager.put_computation(key.clone(), result, 1000);
        let cached_result = cache_manager.get_computation(&key);
        assert!(cached_result.is_some());
        if let Some(ComputationResult::TripleScore(score)) = cached_result {
            assert_eq!(score, 0.85);
        } else {
            panic!("Expected TripleScore result");
        }
    }
    /// Hit/miss counters advance as lookups miss and then hit.
    #[test]
    fn test_cache_stats() {
        let config = CacheConfig::default();
        let cache_manager = CacheManager::new(config);
        let stats = cache_manager.get_stats();
        assert_eq!(stats.total_hits, 0);
        assert_eq!(stats.total_misses, 0);
        // A lookup on an absent key registers as a miss.
        let result = cache_manager.get_embedding("nonexistent");
        assert!(result.is_none());
        let stats = cache_manager.get_stats();
        assert_eq!(stats.total_misses, 1);
        // A lookup on a present key registers as a hit.
        let embedding = Vector::new(vec![1.0, 2.0, 3.0]);
        cache_manager.put_embedding("test".to_string(), embedding);
        let cached = cache_manager.get_embedding("test");
        assert!(cached.is_some());
        let stats = cache_manager.get_stats();
        assert_eq!(stats.total_hits, 1);
    }
    /// Keys compare by all three fields: operation, inputs, and model id.
    #[test]
    fn test_computation_key_equality() {
        let key1 = ComputationKey {
            operation: "test".to_string(),
            inputs: vec!["a".to_string(), "b".to_string()],
            model_id: Uuid::new_v4(),
        };
        let key2 = ComputationKey {
            operation: "test".to_string(),
            inputs: vec!["a".to_string(), "b".to_string()],
            model_id: key1.model_id,
        };
        let key3 = ComputationKey {
            operation: "different".to_string(),
            inputs: vec!["a".to_string(), "b".to_string()],
            model_id: key1.model_id,
        };
        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }
}