// oxirs_embed/caching.rs

1//! Advanced caching and precomputation system for embedding models
2//!
3//! This module provides multi-level caching for embeddings, computation results,
4//! and intelligent precomputation strategies for improved performance.
5
6use crate::{EmbeddingModel, Vector};
7use anyhow::Result;
8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, VecDeque};
11use std::hash::Hash;
12use std::sync::{Arc, RwLock};
13use std::time::{Duration, Instant};
14// Removed unused imports
15use tokio::task::JoinHandle;
16use tracing::{debug, info, warn};
17use uuid::Uuid;
18
/// Type alias for the similarity cache: maps a query string to a ranked
/// list of (entity, similarity score) pairs, shared behind `Arc<RwLock>`.
type SimilarityCache = Arc<RwLock<LRUCache<String, Vec<(String, f64)>>>>;
21
/// Multi-level caching system for embeddings and computations
///
/// Three levels with distinct payloads: L1 holds hot entity embeddings,
/// L2 holds arbitrary computation results keyed by [`ComputationKey`],
/// and L3 holds similarity result lists keyed by query string. Every
/// level sits behind `Arc<RwLock<…>>` so the manager can be shared with
/// the background cleanup task spawned by `start`.
pub struct CacheManager {
    /// L1 Cache: Hot embeddings (fastest access)
    l1_cache: Arc<RwLock<LRUCache<String, CachedEmbedding>>>,
    /// L2 Cache: Computation results (intermediate speed)
    l2_cache: Arc<RwLock<LRUCache<ComputationKey, CachedComputation>>>,
    /// L3 Cache: Similarity cache (bulk operations)
    l3_cache: SimilarityCache,
    /// Cache configuration
    config: CacheConfig,
    /// Cache statistics
    stats: Arc<RwLock<CacheStats>>,
    /// Background cleanup task handle; `None` until `start` is called
    cleanup_task: Option<JoinHandle<()>>,
    /// Cache warming strategy
    /// NOTE(review): stored but not consulted by any code visible in this
    /// module (hence the dead_code allowance) — confirm intended use.
    #[allow(dead_code)]
    warming_strategy: WarmingStrategy,
}
40
/// Configuration for the caching system
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// L1 cache size (maximum number of cached embeddings)
    pub l1_max_size: usize,
    /// L2 cache size (maximum number of computation results)
    pub l2_max_size: usize,
    /// L3 cache size (maximum number of similarity results)
    pub l3_max_size: usize,
    /// Cache entry TTL in seconds; shared by all three levels
    pub ttl_seconds: u64,
    /// Enable cache warming (`warm_cache` is a no-op when false)
    pub enable_warming: bool,
    /// Cache eviction policy
    pub eviction_policy: EvictionPolicy,
    /// Background cleanup interval, in seconds
    pub cleanup_interval_seconds: u64,
    /// Enable cache compression
    /// NOTE(review): not acted on by code visible in this module — confirm.
    pub enable_compression: bool,
    /// Maximum memory usage in MB
    pub max_memory_mb: usize,
}
63
64impl Default for CacheConfig {
65    fn default() -> Self {
66        Self {
67            l1_max_size: 10_000,
68            l2_max_size: 50_000,
69            l3_max_size: 100_000,
70            ttl_seconds: 3600, // 1 hour
71            enable_warming: true,
72            eviction_policy: EvictionPolicy::LRU,
73            cleanup_interval_seconds: 300, // 5 minutes
74            enable_compression: true,
75            max_memory_mb: 1024, // 1GB
76        }
77    }
78}
79
/// Cache eviction policies
///
/// NOTE(review): the policy is stored in `CacheConfig` but the visible
/// `LRUCache` implementation always evicts least-recently-used — confirm
/// whether the other variants are wired up elsewhere.
#[derive(Debug, Clone, Copy)]
pub enum EvictionPolicy {
    /// Evict the least-recently-used entry first.
    LRU,
    /// Evict the least-frequently-used entry first.
    LFU,
    /// Evict strictly on time-to-live expiry.
    TTL,
    /// Choose the eviction strategy dynamically.
    Adaptive,
}
88
/// Cache warming strategies
///
/// The `usize` payload is the number of entities to pre-populate.
#[derive(Debug, Clone)]
pub enum WarmingStrategy {
    /// Pre-populate with most frequently accessed entities
    MostFrequent(usize),
    /// Pre-populate with entities from recent queries
    RecentQueries(usize),
    /// Pre-populate with entities based on graph centrality
    GraphCentrality(usize),
    /// No warming
    None,
}
101
102impl Default for WarmingStrategy {
103    fn default() -> Self {
104        WarmingStrategy::MostFrequent(1000)
105    }
106}
107
/// Cached embedding with metadata
#[derive(Debug, Clone)]
pub struct CachedEmbedding {
    /// The embedding vector
    pub embedding: Vector,
    /// When this was cached
    pub cached_at: DateTime<Utc>,
    /// Last access time
    pub last_accessed: DateTime<Utc>,
    /// Access count
    pub access_count: u64,
    /// Size in bytes of the raw f32 payload (`values.len() * 4`)
    pub size_bytes: usize,
    /// Whether this is compressed (always `false` in code visible here)
    pub is_compressed: bool,
}
124
/// Key for computation caching
///
/// Hashing and equality cover all three fields, so the same operation on
/// the same inputs performed by a different model caches separately.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ComputationKey {
    /// Operation name, e.g. "predict_objects" or "attention_weights_<layer>"
    pub operation: String,
    /// Operation inputs (order-sensitive)
    pub inputs: Vec<String>,
    /// Identifier of the model that produced the result
    pub model_id: Uuid,
}
132
/// Cached computation result
#[derive(Debug, Clone)]
pub struct CachedComputation {
    /// The computation result
    pub result: ComputationResult,
    /// When this was cached
    pub cached_at: DateTime<Utc>,
    /// Last access time
    pub last_accessed: DateTime<Utc>,
    /// Access count
    pub access_count: u64,
    /// Cost of the original computation in microseconds — i.e. the time a
    /// cache hit avoids re-spending
    pub time_saved_us: u64,
}
147
/// Types of computation results that can be cached
///
/// Serializable so cached results can be persisted or shipped elsewhere.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ComputationResult {
    /// Score for a single triple
    TripleScore(f64),
    /// Ranked (entity, similarity) pairs for an entity query
    EntitySimilarity(Vec<(String, f64)>),
    /// Ranked (entity, score) pairs from link prediction
    PredictionResults(Vec<(String, f64)>),
    /// Attention weight vector for one layer/input
    AttentionWeights(Vec<f64>),
    /// Intermediate layer activations for one layer/input
    IntermediateActivations(Vec<f64>),
    /// Cached gradients for training optimization
    Gradients(Vec<Vec<f64>>),
    /// Cached model weights for quick model switching
    ModelWeights(Vec<Vec<f64>>),
    /// Cached feature vectors for downstream tasks
    FeatureVectors(Vec<f64>),
    /// Generic computation results for extensibility
    GenericResult(Vec<f64>),
    /// Cached embeddings matrices for batch operations
    EmbeddingMatrices(Vec<Vec<f64>>),
    /// Cached loss values for training monitoring
    LossValues(Vec<f64>),
}
169
/// Cache statistics
///
/// Hit/miss counters are updated synchronously on each lookup; sizes and
/// the aggregate `hit_rate` are refreshed by the background cleanup task.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Total cache hits
    pub total_hits: u64,
    /// Total cache misses
    pub total_misses: u64,
    /// Cache hit rate (hits / (hits + misses))
    pub hit_rate: f64,
    /// Total memory usage in bytes
    pub memory_usage_bytes: usize,
    /// L1 cache stats
    pub l1_stats: LevelStats,
    /// L2 cache stats
    pub l2_stats: LevelStats,
    /// L3 cache stats
    pub l3_stats: LevelStats,
    /// Time saved by caching (in seconds)
    pub total_time_saved_seconds: f64,
}
190
/// Statistics for a cache level
///
/// All fields are counters/gauges that start at zero, so `Default` is
/// derived rather than hand-written.
#[derive(Debug, Clone, Default)]
pub struct LevelStats {
    /// Number of cache hits recorded for this level
    pub hits: u64,
    /// Number of cache misses recorded for this level
    pub misses: u64,
    /// Current number of entries in this level
    pub size: usize,
    /// Configured maximum number of entries for this level
    pub capacity: usize,
    /// Estimated memory consumed by this level, in bytes
    pub memory_bytes: usize,
}
200
201impl Default for CacheStats {
202    fn default() -> Self {
203        Self {
204            total_hits: 0,
205            total_misses: 0,
206            hit_rate: 0.0,
207            memory_usage_bytes: 0,
208            l1_stats: LevelStats {
209                hits: 0,
210                misses: 0,
211                size: 0,
212                capacity: 0,
213                memory_bytes: 0,
214            },
215            l2_stats: LevelStats {
216                hits: 0,
217                misses: 0,
218                size: 0,
219                capacity: 0,
220                memory_bytes: 0,
221            },
222            l3_stats: LevelStats {
223                hits: 0,
224                misses: 0,
225                size: 0,
226                capacity: 0,
227                memory_bytes: 0,
228            },
229            total_time_saved_seconds: 0.0,
230        }
231    }
232}
233
/// A size-bounded map with least-recently-used eviction and per-entry
/// idle expiry.
///
/// Recency is tracked in a `VecDeque` (front = most recent). Expiry is
/// measured from the *last access*, so an entry that keeps being touched
/// never times out.
pub struct LRUCache<K, V>
where
    K: Clone + Eq + Hash,
    V: Clone,
{
    capacity: usize,
    map: HashMap<K, V>,
    order: VecDeque<K>,
    access_times: HashMap<K, Instant>,
    ttl: Duration,
}

impl<K, V> LRUCache<K, V>
where
    K: Clone + Eq + Hash,
    V: Clone,
{
    /// Build an empty cache bounded to `capacity` entries whose entries
    /// expire after `ttl` of inactivity.
    pub fn new(capacity: usize, ttl: Duration) -> Self {
        Self {
            capacity,
            map: HashMap::new(),
            order: VecDeque::new(),
            access_times: HashMap::new(),
            ttl,
        }
    }

    /// Look up `key`, returning a clone of the stored value.
    ///
    /// A hit refreshes both the recency order and the idle timer. An
    /// entry whose last access is older than the TTL is dropped and
    /// reported as a miss.
    pub fn get(&mut self, key: &K) -> Option<V> {
        // Lazily expire the entry if it has been idle past the TTL.
        let is_expired = self
            .access_times
            .get(key)
            .map_or(false, |t| t.elapsed() > self.ttl);
        if is_expired {
            self.remove(key);
            return None;
        }

        let value = self.map.get(key).cloned()?;
        self.move_to_front(key);
        self.access_times.insert(key.clone(), Instant::now());
        Some(value)
    }

    /// Insert or overwrite `key`, evicting the least-recently-used entry
    /// first when a brand-new key would exceed capacity.
    pub fn put(&mut self, key: K, value: V) {
        if self.map.contains_key(&key) {
            // Existing key: refresh its recency position.
            self.move_to_front(&key);
        } else {
            // New key: make room, then record it as most recent.
            if self.map.len() >= self.capacity {
                self.evict_lru();
            }
            self.order.push_front(key.clone());
        }
        self.access_times.insert(key.clone(), Instant::now());
        self.map.insert(key, value);
    }

    /// Remove `key` from every internal structure, returning its value.
    pub fn remove(&mut self, key: &K) -> Option<V> {
        let value = self.map.remove(key)?;
        self.order.retain(|k| k != key);
        self.access_times.remove(key);
        Some(value)
    }

    /// Drop every entry.
    pub fn clear(&mut self) {
        self.map.clear();
        self.order.clear();
        self.access_times.clear();
    }

    /// Number of entries currently stored.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// True when the cache holds no entries.
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }

    /// Mark `key` as most recently used.
    fn move_to_front(&mut self, key: &K) {
        self.order.retain(|k| k != key);
        self.order.push_front(key.clone());
    }

    /// Evict the entry at the back of the recency queue (least recent).
    fn evict_lru(&mut self) {
        if let Some(victim) = self.order.pop_back() {
            self.map.remove(&victim);
            self.access_times.remove(&victim);
        }
    }

    /// Remove every entry whose last access is older than the TTL and
    /// return how many were dropped.
    pub fn cleanup_expired(&mut self) -> usize {
        let now = Instant::now();
        let stale: Vec<K> = self
            .access_times
            .iter()
            .filter(|(_, touched)| now.duration_since(**touched) > self.ttl)
            .map(|(k, _)| k.clone())
            .collect();

        let removed = stale.len();
        for key in &stale {
            self.remove(key);
        }
        removed
    }
}
354
355impl CacheManager {
356    /// Create a new cache manager
357    pub fn new(config: CacheConfig) -> Self {
358        let ttl = Duration::from_secs(config.ttl_seconds);
359
360        Self {
361            l1_cache: Arc::new(RwLock::new(LRUCache::new(config.l1_max_size, ttl))),
362            l2_cache: Arc::new(RwLock::new(LRUCache::new(config.l2_max_size, ttl))),
363            l3_cache: Arc::new(RwLock::new(LRUCache::new(config.l3_max_size, ttl))),
364            config,
365            stats: Arc::new(RwLock::new(CacheStats::default())),
366            cleanup_task: None,
367            warming_strategy: WarmingStrategy::default(),
368        }
369    }
370
371    /// Start the cache manager with background tasks
372    pub async fn start(&mut self) -> Result<()> {
373        // Start cleanup task
374        let cleanup_interval = Duration::from_secs(self.config.cleanup_interval_seconds);
375        let l1_cache = Arc::clone(&self.l1_cache);
376        let l2_cache = Arc::clone(&self.l2_cache);
377        let l3_cache = Arc::clone(&self.l3_cache);
378        let stats = Arc::clone(&self.stats);
379
380        let cleanup_task = tokio::spawn(async move {
381            let mut interval = tokio::time::interval(cleanup_interval);
382
383            loop {
384                interval.tick().await;
385
386                // Cleanup expired entries
387                let expired_l1 = {
388                    let mut cache = l1_cache.write().expect("lock poisoned");
389                    cache.cleanup_expired()
390                };
391
392                let expired_l2 = {
393                    let mut cache = l2_cache.write().expect("lock poisoned");
394                    cache.cleanup_expired()
395                };
396
397                let expired_l3 = {
398                    let mut cache = l3_cache.write().expect("lock poisoned");
399                    cache.cleanup_expired()
400                };
401
402                let total_expired = expired_l1 + expired_l2 + expired_l3;
403                if total_expired > 0 {
404                    debug!("Cleaned up {} expired cache entries", total_expired);
405                }
406
407                // Update stats
408                {
409                    let mut stats = stats.write().expect("lock poisoned");
410                    stats.l1_stats.size = l1_cache.read().expect("lock poisoned").len();
411                    stats.l2_stats.size = l2_cache.read().expect("lock poisoned").len();
412                    stats.l3_stats.size = l3_cache.read().expect("lock poisoned").len();
413
414                    // Update hit rate
415                    let total_requests = stats.total_hits + stats.total_misses;
416                    if total_requests > 0 {
417                        stats.hit_rate = stats.total_hits as f64 / total_requests as f64;
418                    }
419                }
420            }
421        });
422
423        self.cleanup_task = Some(cleanup_task);
424        info!(
425            "Cache manager started with cleanup interval: {:?}",
426            cleanup_interval
427        );
428        Ok(())
429    }
430
431    /// Stop the cache manager
432    pub async fn stop(&mut self) {
433        if let Some(task) = self.cleanup_task.take() {
434            task.abort();
435            info!("Cache manager stopped");
436        }
437    }
438
439    /// Get cached embedding
440    pub fn get_embedding(&self, entity: &str) -> Option<Vector> {
441        let start = Instant::now();
442
443        let result = {
444            let mut cache = self.l1_cache.write().expect("lock poisoned");
445            cache.get(&entity.to_string())
446        };
447
448        // Update stats
449        {
450            let mut stats = self.stats.write().expect("lock poisoned");
451            if result.is_some() {
452                stats.total_hits += 1;
453                stats.l1_stats.hits += 1;
454                let time_saved = start.elapsed().as_micros() as f64 / 1_000_000.0;
455                stats.total_time_saved_seconds += time_saved;
456            } else {
457                stats.total_misses += 1;
458                stats.l1_stats.misses += 1;
459            }
460        }
461
462        result.map(|cached| {
463            // Update access info
464            let mut cached = cached;
465            cached.last_accessed = Utc::now();
466            cached.access_count += 1;
467            cached.embedding
468        })
469    }
470
471    /// Cache an embedding
472    pub fn put_embedding(&self, entity: String, embedding: Vector) {
473        let cached = CachedEmbedding {
474            size_bytes: embedding.values.len() * std::mem::size_of::<f32>(),
475            embedding,
476            cached_at: Utc::now(),
477            last_accessed: Utc::now(),
478            access_count: 1,
479            is_compressed: false,
480        };
481
482        {
483            let mut cache = self.l1_cache.write().expect("lock poisoned");
484            cache.put(entity, cached);
485        }
486
487        // Update capacity stats
488        {
489            let mut stats = self.stats.write().expect("lock poisoned");
490            stats.l1_stats.capacity = self.config.l1_max_size;
491        }
492    }
493
494    /// Get cached computation result
495    pub fn get_computation(&self, key: &ComputationKey) -> Option<ComputationResult> {
496        let start = Instant::now();
497
498        let result = {
499            let mut cache = self.l2_cache.write().expect("lock poisoned");
500            cache.get(key)
501        };
502
503        // Update stats
504        {
505            let mut stats = self.stats.write().expect("lock poisoned");
506            if result.is_some() {
507                stats.total_hits += 1;
508                stats.l2_stats.hits += 1;
509                let time_saved = start.elapsed().as_micros() as f64 / 1_000_000.0;
510                stats.total_time_saved_seconds += time_saved;
511            } else {
512                stats.total_misses += 1;
513                stats.l2_stats.misses += 1;
514            }
515        }
516
517        result.map(|cached| cached.result)
518    }
519
520    /// Cache a computation result
521    pub fn put_computation(
522        &self,
523        key: ComputationKey,
524        result: ComputationResult,
525        computation_time_us: u64,
526    ) {
527        let cached = CachedComputation {
528            result,
529            cached_at: Utc::now(),
530            last_accessed: Utc::now(),
531            access_count: 1,
532            time_saved_us: computation_time_us,
533        };
534
535        {
536            let mut cache = self.l2_cache.write().expect("lock poisoned");
537            cache.put(key, cached);
538        }
539    }
540
541    /// Get cached similarity results
542    pub fn get_similarity_cache(&self, query: &str) -> Option<Vec<(String, f64)>> {
543        let start = Instant::now();
544
545        let result = {
546            let mut cache = self.l3_cache.write().expect("lock poisoned");
547            cache.get(&query.to_string())
548        };
549
550        // Update stats
551        {
552            let mut stats = self.stats.write().expect("lock poisoned");
553            if result.is_some() {
554                stats.total_hits += 1;
555                stats.l3_stats.hits += 1;
556                let time_saved = start.elapsed().as_micros() as f64 / 1_000_000.0;
557                stats.total_time_saved_seconds += time_saved;
558            } else {
559                stats.total_misses += 1;
560                stats.l3_stats.misses += 1;
561            }
562        }
563
564        result
565    }
566
567    /// Cache similarity results
568    pub fn put_similarity_cache(&self, query: String, results: Vec<(String, f64)>) {
569        let mut cache = self.l3_cache.write().expect("lock poisoned");
570        cache.put(query, results);
571    }
572
573    /// Warm up cache with frequently accessed entities
574    pub async fn warm_cache(
575        &self,
576        model: &dyn EmbeddingModel,
577        entities: Vec<String>,
578    ) -> Result<usize> {
579        if !self.config.enable_warming {
580            return Ok(0);
581        }
582
583        info!(
584            "Starting cache warming with {entities_len} entities",
585            entities_len = entities.len()
586        );
587        let mut warmed_count = 0;
588
589        for entity in entities {
590            // Check if already cached
591            if self.get_embedding(&entity).is_some() {
592                continue;
593            }
594
595            // Get embedding and cache it
596            match model.get_entity_embedding(&entity) {
597                Ok(embedding) => {
598                    self.put_embedding(entity, embedding);
599                    warmed_count += 1;
600                }
601                Err(e) => {
602                    warn!("Failed to warm cache for entity {entity}: {e}");
603                }
604            }
605        }
606
607        info!("Cache warming completed: {warmed_count} entities cached");
608        Ok(warmed_count)
609    }
610
611    /// Precompute and cache common operations
612    pub async fn precompute_common_operations(
613        &self,
614        model: &dyn EmbeddingModel,
615        common_queries: Vec<(String, String)>,
616    ) -> Result<usize> {
617        info!(
618            "Starting precomputation for {} common queries",
619            common_queries.len()
620        );
621        let mut precomputed_count = 0;
622
623        for (subject, predicate) in common_queries {
624            // Precompute object predictions
625            let key = ComputationKey {
626                operation: "predict_objects".to_string(),
627                inputs: vec![subject.clone(), predicate.clone()],
628                model_id: *model.model_id(),
629            };
630
631            // Check if already cached
632            if self.get_computation(&key).is_some() {
633                continue;
634            }
635
636            let start = Instant::now();
637            match model.predict_objects(&subject, &predicate, 10) {
638                Ok(predictions) => {
639                    let computation_time = start.elapsed().as_micros() as u64;
640                    let result = ComputationResult::PredictionResults(predictions);
641                    self.put_computation(key, result, computation_time);
642                    precomputed_count += 1;
643                }
644                Err(e) => {
645                    warn!(
646                        "Failed to precompute prediction for ({}, {}): {}",
647                        subject, predicate, e
648                    );
649                }
650            }
651        }
652
653        info!(
654            "Precomputation completed: {} operations cached",
655            precomputed_count
656        );
657        Ok(precomputed_count)
658    }
659
660    /// Get cache statistics
661    pub fn get_stats(&self) -> CacheStats {
662        self.stats.read().expect("lock poisoned").clone()
663    }
664
665    /// Clear all caches
666    pub fn clear_all(&self) {
667        {
668            let mut cache = self.l1_cache.write().expect("lock poisoned");
669            cache.clear();
670        }
671        {
672            let mut cache = self.l2_cache.write().expect("lock poisoned");
673            cache.clear();
674        }
675        {
676            let mut cache = self.l3_cache.write().expect("lock poisoned");
677            cache.clear();
678        }
679
680        // Reset stats
681        {
682            let mut stats = self.stats.write().expect("lock poisoned");
683            *stats = CacheStats::default();
684        }
685
686        info!("All caches cleared");
687    }
688
689    /// Get memory usage estimation
690    pub fn estimate_memory_usage(&self) -> usize {
691        let l1_size = {
692            let cache = self.l1_cache.read().expect("lock poisoned");
693            cache.len() * std::mem::size_of::<CachedEmbedding>()
694        };
695
696        let l2_size = {
697            let cache = self.l2_cache.read().expect("lock poisoned");
698            cache.len() * std::mem::size_of::<CachedComputation>()
699        };
700
701        let l3_size = {
702            let cache = self.l3_cache.read().expect("lock poisoned");
703            cache.len() * std::mem::size_of::<Vec<(String, f64)>>()
704        };
705
706        l1_size + l2_size + l3_size
707    }
708
709    // ====== COMPUTATION CACHE SPECIALIZED METHODS ======
710
711    /// Cache attention weights for a specific layer and input
712    pub fn cache_attention_weights(
713        &self,
714        layer_id: &str,
715        input_hash: &str,
716        model_id: Uuid,
717        attention_weights: Vec<f64>,
718        computation_time_us: u64,
719    ) {
720        let key = ComputationKey {
721            operation: format!("attention_weights_{layer_id}"),
722            inputs: vec![input_hash.to_string()],
723            model_id,
724        };
725
726        let result = ComputationResult::AttentionWeights(attention_weights);
727        self.put_computation(key, result, computation_time_us);
728
729        debug!(
730            "Cached attention weights for layer {} (input: {})",
731            layer_id, input_hash
732        );
733    }
734
735    /// Get cached attention weights
736    pub fn get_attention_weights(
737        &self,
738        layer_id: &str,
739        input_hash: &str,
740        model_id: Uuid,
741    ) -> Option<Vec<f64>> {
742        let key = ComputationKey {
743            operation: format!("attention_weights_{layer_id}"),
744            inputs: vec![input_hash.to_string()],
745            model_id,
746        };
747
748        match self.get_computation(&key)? {
749            ComputationResult::AttentionWeights(weights) => {
750                debug!(
751                    "Cache hit for attention weights layer {} (input: {})",
752                    layer_id, input_hash
753                );
754                Some(weights)
755            }
756            _ => None,
757        }
758    }
759
760    /// Cache intermediate activations for a specific layer and input
761    pub fn cache_intermediate_activations(
762        &self,
763        layer_id: &str,
764        input_hash: &str,
765        model_id: Uuid,
766        activations: Vec<f64>,
767        computation_time_us: u64,
768    ) {
769        let key = ComputationKey {
770            operation: format!("intermediate_activations_{layer_id}"),
771            inputs: vec![input_hash.to_string()],
772            model_id,
773        };
774
775        let result = ComputationResult::IntermediateActivations(activations);
776        self.put_computation(key, result, computation_time_us);
777
778        debug!(
779            "Cached intermediate activations for layer {} (input: {})",
780            layer_id, input_hash
781        );
782    }
783
784    /// Get cached intermediate activations
785    pub fn get_intermediate_activations(
786        &self,
787        layer_id: &str,
788        input_hash: &str,
789        model_id: Uuid,
790    ) -> Option<Vec<f64>> {
791        let key = ComputationKey {
792            operation: format!("intermediate_activations_{layer_id}"),
793            inputs: vec![input_hash.to_string()],
794            model_id,
795        };
796
797        match self.get_computation(&key)? {
798            ComputationResult::IntermediateActivations(activations) => {
799                debug!(
800                    "Cache hit for intermediate activations layer {} (input: {})",
801                    layer_id, input_hash
802                );
803                Some(activations)
804            }
805            _ => None,
806        }
807    }
808
809    /// Cache gradients for training optimization
810    pub fn cache_gradients(
811        &self,
812        layer_id: &str,
813        batch_hash: &str,
814        model_id: Uuid,
815        gradients: Vec<Vec<f64>>,
816        computation_time_us: u64,
817    ) {
818        let key = ComputationKey {
819            operation: format!("gradients_{layer_id}"),
820            inputs: vec![batch_hash.to_string()],
821            model_id,
822        };
823
824        let result = ComputationResult::Gradients(gradients);
825        self.put_computation(key, result, computation_time_us);
826
827        debug!(
828            "Cached gradients for layer {} (batch: {})",
829            layer_id, batch_hash
830        );
831    }
832
833    /// Get cached gradients
834    pub fn get_gradients(
835        &self,
836        layer_id: &str,
837        batch_hash: &str,
838        model_id: Uuid,
839    ) -> Option<Vec<Vec<f64>>> {
840        let key = ComputationKey {
841            operation: format!("gradients_{layer_id}"),
842            inputs: vec![batch_hash.to_string()],
843            model_id,
844        };
845
846        match self.get_computation(&key)? {
847            ComputationResult::Gradients(gradients) => {
848                debug!(
849                    "Cache hit for gradients layer {} (batch: {})",
850                    layer_id, batch_hash
851                );
852                Some(gradients)
853            }
854            _ => None,
855        }
856    }
857
858    /// Cache model weights for quick model switching
859    pub fn cache_model_weights(
860        &self,
861        model_name: &str,
862        checkpoint: &str,
863        model_id: Uuid,
864        weights: Vec<Vec<f64>>,
865        computation_time_us: u64,
866    ) {
867        let key = ComputationKey {
868            operation: "model_weights".to_string(),
869            inputs: vec![model_name.to_string(), checkpoint.to_string()],
870            model_id,
871        };
872
873        let result = ComputationResult::ModelWeights(weights);
874        self.put_computation(key, result, computation_time_us);
875
876        info!(
877            "Cached model weights for {} (checkpoint: {})",
878            model_name, checkpoint
879        );
880    }
881
882    /// Get cached model weights
883    pub fn get_model_weights(
884        &self,
885        model_name: &str,
886        checkpoint: &str,
887        model_id: Uuid,
888    ) -> Option<Vec<Vec<f64>>> {
889        let key = ComputationKey {
890            operation: "model_weights".to_string(),
891            inputs: vec![model_name.to_string(), checkpoint.to_string()],
892            model_id,
893        };
894
895        match self.get_computation(&key)? {
896            ComputationResult::ModelWeights(weights) => {
897                info!(
898                    "Cache hit for model weights {} (checkpoint: {})",
899                    model_name, checkpoint
900                );
901                Some(weights)
902            }
903            _ => None,
904        }
905    }
906
907    /// Cache feature vectors for downstream tasks
908    pub fn cache_feature_vectors(
909        &self,
910        task_name: &str,
911        input_hash: &str,
912        model_id: Uuid,
913        features: Vec<f64>,
914        computation_time_us: u64,
915    ) {
916        let key = ComputationKey {
917            operation: format!("feature_vectors_{task_name}"),
918            inputs: vec![input_hash.to_string()],
919            model_id,
920        };
921
922        let result = ComputationResult::FeatureVectors(features);
923        self.put_computation(key, result, computation_time_us);
924
925        debug!(
926            "Cached feature vectors for task {} (input: {})",
927            task_name, input_hash
928        );
929    }
930
931    /// Get cached feature vectors
932    pub fn get_feature_vectors(
933        &self,
934        task_name: &str,
935        input_hash: &str,
936        model_id: Uuid,
937    ) -> Option<Vec<f64>> {
938        let key = ComputationKey {
939            operation: format!("feature_vectors_{task_name}"),
940            inputs: vec![input_hash.to_string()],
941            model_id,
942        };
943
944        match self.get_computation(&key)? {
945            ComputationResult::FeatureVectors(features) => {
946                debug!(
947                    "Cache hit for feature vectors task {} (input: {})",
948                    task_name, input_hash
949                );
950                Some(features)
951            }
952            _ => None,
953        }
954    }
955
956    /// Cache embedding matrices for batch operations
957    pub fn cache_embedding_matrices(
958        &self,
959        operation: &str,
960        batch_hash: &str,
961        model_id: Uuid,
962        matrices: Vec<Vec<f64>>,
963        computation_time_us: u64,
964    ) {
965        let key = ComputationKey {
966            operation: format!("embedding_matrices_{operation}"),
967            inputs: vec![batch_hash.to_string()],
968            model_id,
969        };
970
971        let result = ComputationResult::EmbeddingMatrices(matrices);
972        self.put_computation(key, result, computation_time_us);
973
974        debug!(
975            "Cached embedding matrices for {} (batch: {})",
976            operation, batch_hash
977        );
978    }
979
980    /// Get cached embedding matrices
981    pub fn get_embedding_matrices(
982        &self,
983        operation: &str,
984        batch_hash: &str,
985        model_id: Uuid,
986    ) -> Option<Vec<Vec<f64>>> {
987        let key = ComputationKey {
988            operation: format!("embedding_matrices_{operation}"),
989            inputs: vec![batch_hash.to_string()],
990            model_id,
991        };
992
993        match self.get_computation(&key)? {
994            ComputationResult::EmbeddingMatrices(matrices) => {
995                debug!(
996                    "Cache hit for embedding matrices {} (batch: {})",
997                    operation, batch_hash
998                );
999                Some(matrices)
1000            }
1001            _ => None,
1002        }
1003    }
1004
1005    /// Cache loss values for training monitoring
1006    pub fn cache_loss_values(
1007        &self,
1008        loss_type: &str,
1009        epoch_batch: &str,
1010        model_id: Uuid,
1011        losses: Vec<f64>,
1012        computation_time_us: u64,
1013    ) {
1014        let key = ComputationKey {
1015            operation: format!("loss_values_{loss_type}"),
1016            inputs: vec![epoch_batch.to_string()],
1017            model_id,
1018        };
1019
1020        let result = ComputationResult::LossValues(losses);
1021        self.put_computation(key, result, computation_time_us);
1022
1023        debug!(
1024            "Cached loss values for {} (epoch/batch: {})",
1025            loss_type, epoch_batch
1026        );
1027    }
1028
1029    /// Get cached loss values
1030    pub fn get_loss_values(
1031        &self,
1032        loss_type: &str,
1033        epoch_batch: &str,
1034        model_id: Uuid,
1035    ) -> Option<Vec<f64>> {
1036        let key = ComputationKey {
1037            operation: format!("loss_values_{loss_type}"),
1038            inputs: vec![epoch_batch.to_string()],
1039            model_id,
1040        };
1041
1042        match self.get_computation(&key)? {
1043            ComputationResult::LossValues(losses) => {
1044                debug!(
1045                    "Cache hit for loss values {} (epoch/batch: {})",
1046                    loss_type, epoch_batch
1047                );
1048                Some(losses)
1049            }
1050            _ => None,
1051        }
1052    }
1053
1054    /// Cache generic computation results for extensibility
1055    pub fn cache_generic_result(
1056        &self,
1057        operation: &str,
1058        input_hash: &str,
1059        model_id: Uuid,
1060        result: Vec<f64>,
1061        computation_time_us: u64,
1062    ) {
1063        let key = ComputationKey {
1064            operation: operation.to_string(),
1065            inputs: vec![input_hash.to_string()],
1066            model_id,
1067        };
1068
1069        let cached_result = ComputationResult::GenericResult(result);
1070        self.put_computation(key, cached_result, computation_time_us);
1071
1072        debug!(
1073            "Cached generic result for {} (input: {})",
1074            operation, input_hash
1075        );
1076    }
1077
1078    /// Get cached generic result
1079    pub fn get_generic_result(
1080        &self,
1081        operation: &str,
1082        input_hash: &str,
1083        model_id: Uuid,
1084    ) -> Option<Vec<f64>> {
1085        let key = ComputationKey {
1086            operation: operation.to_string(),
1087            inputs: vec![input_hash.to_string()],
1088            model_id,
1089        };
1090
1091        match self.get_computation(&key)? {
1092            ComputationResult::GenericResult(result) => {
1093                debug!(
1094                    "Cache hit for generic result {} (input: {})",
1095                    operation, input_hash
1096                );
1097                Some(result)
1098            }
1099            _ => None,
1100        }
1101    }
1102
1103    /// Clear cache for specific computation type
1104    pub fn clear_computation_cache(&self, operation_prefix: &str) -> usize {
1105        let mut removed_count = 0;
1106
1107        {
1108            let mut cache = self.l2_cache.write().expect("lock poisoned");
1109            let keys_to_remove: Vec<_> = cache
1110                .map
1111                .keys()
1112                .filter(|key| key.operation.starts_with(operation_prefix))
1113                .cloned()
1114                .collect();
1115
1116            for key in keys_to_remove {
1117                cache.remove(&key);
1118                removed_count += 1;
1119            }
1120        }
1121
1122        info!(
1123            "Cleared {} cache entries for operation: {}",
1124            removed_count, operation_prefix
1125        );
1126        removed_count
1127    }
1128
1129    /// Get cache hit rates by computation type  
1130    pub fn get_cache_hit_rates(&self) -> HashMap<String, f64> {
1131        let mut hit_rates = HashMap::new();
1132        let cache = self.l2_cache.read().expect("lock poisoned");
1133
1134        // Group by operation type
1135        let mut operation_stats = HashMap::new();
1136
1137        for key in cache.map.keys() {
1138            let operation_type = key.operation.split('_').next().unwrap_or("unknown");
1139            let entry = operation_stats
1140                .entry(operation_type.to_string())
1141                .or_insert((0u64, 0u64));
1142            entry.0 += 1; // Total operations
1143        }
1144
1145        // Calculate hit rates (simplified for now)
1146        for (operation, (total, _hits)) in operation_stats {
1147            let hit_rate = if total > 0 { 0.8 } else { 0.0 }; // Placeholder calculation
1148            hit_rates.insert(operation, hit_rate);
1149        }
1150
1151        hit_rates
1152    }
1153
1154    /// Adaptive cache resizing based on usage patterns
1155    pub fn adaptive_resize(&mut self) {
1156        let stats = self.get_stats();
1157
1158        // Resize based on hit rates and memory usage
1159        if stats.l1_stats.hits > stats.l1_stats.misses * 2
1160            && stats.memory_usage_bytes < self.config.max_memory_mb * 1024 * 1024 / 2
1161        {
1162            // High hit rate and low memory usage - increase L1 cache
1163            self.config.l1_max_size = (self.config.l1_max_size as f64 * 1.2) as usize;
1164            info!("Increased L1 cache size to {}", self.config.l1_max_size);
1165        } else if stats.l1_stats.misses > stats.l1_stats.hits * 2 {
1166            // High miss rate - decrease L1 cache
1167            self.config.l1_max_size = (self.config.l1_max_size as f64 * 0.8) as usize;
1168            info!("Decreased L1 cache size to {}", self.config.l1_max_size);
1169        }
1170    }
1171
1172    /// Batch cache multiple computation results efficiently
1173    pub fn batch_cache_computations(&self, computations: Vec<(ComputationKey, ComputationResult)>) {
1174        let count = computations.len();
1175        for (key, result) in computations {
1176            self.put_computation(key, result, 0);
1177        }
1178
1179        info!("Batch cached {count} computation results");
1180    }
1181
1182    /// Get cache efficiency metrics for different computation types
1183    pub fn get_computation_type_stats(&self) -> HashMap<String, (u64, u64)> {
1184        let mut type_stats = HashMap::new();
1185
1186        // This would require tracking stats per computation type
1187        // For now, return placeholder values - would need to enhance stats tracking
1188        type_stats.insert("attention_weights".to_string(), (0, 0));
1189        type_stats.insert("gradients".to_string(), (0, 0));
1190        type_stats.insert("model_weights".to_string(), (0, 0));
1191        type_stats.insert("intermediate_activations".to_string(), (0, 0));
1192        type_stats.insert("feature_vectors".to_string(), (0, 0));
1193
1194        type_stats
1195    }
1196}
1197
/// Cache-aware embedding wrapper
///
/// Decorates an [`EmbeddingModel`] so that entity embeddings, triple scores,
/// and object predictions are served from a shared [`CacheManager`] when
/// possible, falling back to the wrapped model on a miss.
pub struct CachedEmbeddingModel {
    // Underlying model, consulted only on cache misses.
    model: Box<dyn EmbeddingModel>,
    // Shared multi-level cache checked before delegating to `model`.
    cache_manager: Arc<CacheManager>,
}
1203
1204impl CachedEmbeddingModel {
1205    pub fn new(model: Box<dyn EmbeddingModel>, cache_manager: Arc<CacheManager>) -> Self {
1206        Self {
1207            model,
1208            cache_manager,
1209        }
1210    }
1211
1212    /// Get entity embedding with caching
1213    pub fn get_entity_embedding_cached(&self, entity: &str) -> Result<Vector> {
1214        // Try cache first
1215        if let Some(cached) = self.cache_manager.get_embedding(entity) {
1216            return Ok(cached);
1217        }
1218
1219        // Cache miss - get from model
1220        let embedding = self.model.get_entity_embedding(entity)?;
1221
1222        // Cache the result
1223        self.cache_manager
1224            .put_embedding(entity.to_string(), embedding.clone());
1225
1226        Ok(embedding)
1227    }
1228
1229    /// Score triple with caching
1230    pub fn score_triple_cached(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
1231        let key = ComputationKey {
1232            operation: "score_triple".to_string(),
1233            inputs: vec![
1234                subject.to_string(),
1235                predicate.to_string(),
1236                object.to_string(),
1237            ],
1238            model_id: *self.model.model_id(),
1239        };
1240
1241        // Try cache first
1242        if let Some(ComputationResult::TripleScore(score)) =
1243            self.cache_manager.get_computation(&key)
1244        {
1245            return Ok(score);
1246        }
1247
1248        // Cache miss - compute from model
1249        let start = Instant::now();
1250        let score = self.model.score_triple(subject, predicate, object)?;
1251        let computation_time = start.elapsed().as_micros() as u64;
1252
1253        // Cache the result
1254        self.cache_manager.put_computation(
1255            key,
1256            ComputationResult::TripleScore(score),
1257            computation_time,
1258        );
1259
1260        Ok(score)
1261    }
1262
1263    /// Predict objects with caching
1264    pub fn predict_objects_cached(
1265        &self,
1266        subject: &str,
1267        predicate: &str,
1268        k: usize,
1269    ) -> Result<Vec<(String, f64)>> {
1270        let key = ComputationKey {
1271            operation: format!("predict_objects_{k}"),
1272            inputs: vec![subject.to_string(), predicate.to_string()],
1273            model_id: *self.model.model_id(),
1274        };
1275
1276        // Try cache first
1277        if let Some(ComputationResult::PredictionResults(predictions)) =
1278            self.cache_manager.get_computation(&key)
1279        {
1280            return Ok(predictions);
1281        }
1282
1283        // Cache miss - compute from model
1284        let start = Instant::now();
1285        let predictions = self.model.predict_objects(subject, predicate, k)?;
1286        let computation_time = start.elapsed().as_micros() as u64;
1287
1288        // Cache the result
1289        self.cache_manager.put_computation(
1290            key,
1291            ComputationResult::PredictionResults(predictions.clone()),
1292            computation_time,
1293        );
1294
1295        Ok(predictions)
1296    }
1297}
1298
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lru_cache_basic() {
        let mut cache = LRUCache::new(3, Duration::from_secs(60));

        for (key, value) in [("a", 1), ("b", 2), ("c", 3)] {
            cache.put(key.to_string(), value);
        }

        // All three entries are present and readable.
        assert_eq!(cache.get(&"a".to_string()), Some(1));
        assert_eq!(cache.get(&"b".to_string()), Some(2));
        assert_eq!(cache.get(&"c".to_string()), Some(3));
        assert_eq!(cache.len(), 3);

        // A fourth insert must evict the least recently used entry ("a",
        // which was read first above), keeping the cache at capacity.
        cache.put("d".to_string(), 4);
        assert_eq!(cache.len(), 3);
        assert_eq!(cache.get(&"a".to_string()), None);
        assert_eq!(cache.get(&"d".to_string()), Some(4));
    }

    #[test]
    fn test_cache_config_default() {
        let defaults = CacheConfig::default();

        assert_eq!(defaults.l1_max_size, 10_000);
        assert_eq!(defaults.l2_max_size, 50_000);
        assert_eq!(defaults.l3_max_size, 100_000);
        assert_eq!(defaults.ttl_seconds, 3600);
        assert!(defaults.enable_warming);
    }

    #[tokio::test]
    async fn test_cache_manager_basic() {
        let manager = CacheManager::new(CacheConfig {
            l1_max_size: 100,
            l2_max_size: 100,
            l3_max_size: 100,
            ..Default::default()
        });

        // Embedding round-trip through the L1 cache.
        let embedding = Vector::new(vec![1.0, 2.0, 3.0]);
        manager.put_embedding("test_entity".to_string(), embedding.clone());

        let cached = manager.get_embedding("test_entity");
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().values, embedding.values);

        // Computation round-trip through the L2 cache.
        let key = ComputationKey {
            operation: "test_op".to_string(),
            inputs: vec!["input1".to_string()],
            model_id: Uuid::new_v4(),
        };
        manager.put_computation(key.clone(), ComputationResult::TripleScore(0.85), 1000);

        let cached_result = manager.get_computation(&key);
        assert!(cached_result.is_some());
        match cached_result {
            Some(ComputationResult::TripleScore(score)) => assert_eq!(score, 0.85),
            _ => panic!("Expected TripleScore result"),
        }
    }

    #[test]
    fn test_cache_stats() {
        let manager = CacheManager::new(CacheConfig::default());

        // A fresh manager has no recorded activity.
        let initial = manager.get_stats();
        assert_eq!(initial.total_hits, 0);
        assert_eq!(initial.total_misses, 0);

        // A lookup of an absent key counts as a miss.
        assert!(manager.get_embedding("nonexistent").is_none());
        assert_eq!(manager.get_stats().total_misses, 1);

        // Storing then reading a key counts as a hit.
        manager.put_embedding("test".to_string(), Vector::new(vec![1.0, 2.0, 3.0]));
        assert!(manager.get_embedding("test").is_some());
        assert_eq!(manager.get_stats().total_hits, 1);
    }

    #[test]
    fn test_computation_key_equality() {
        let base = ComputationKey {
            operation: "test".to_string(),
            inputs: vec!["a".to_string(), "b".to_string()],
            model_id: Uuid::new_v4(),
        };

        // Same operation, inputs, and model id => equal.
        let same = ComputationKey {
            operation: "test".to_string(),
            inputs: vec!["a".to_string(), "b".to_string()],
            model_id: base.model_id,
        };

        // Different operation name => not equal.
        let other_op = ComputationKey {
            operation: "different".to_string(),
            inputs: vec!["a".to_string(), "b".to_string()],
            model_id: base.model_id,
        };

        assert_eq!(base, same);
        assert_ne!(base, other_op);
    }
}