use crate::{EmbeddingModel, Vector};
use anyhow::Result;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::hash::Hash;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use tokio::task::JoinHandle;
use tracing::{debug, info, warn};
use uuid::Uuid;

type SimilarityCache = Arc<RwLock<LRUCache<String, Vec<(String, f64)>>>>;

pub struct CacheManager {
    /// L1: entity embeddings keyed by entity name.
    l1_cache: Arc<RwLock<LRUCache<String, CachedEmbedding>>>,
    /// L2: computation results keyed by operation, inputs, and model id.
    l2_cache: Arc<RwLock<LRUCache<ComputationKey, CachedComputation>>>,
    /// L3: similarity-search results keyed by query string.
    l3_cache: SimilarityCache,
    config: CacheConfig,
    stats: Arc<RwLock<CacheStats>>,
    cleanup_task: Option<JoinHandle<()>>,
    #[allow(dead_code)]
    warming_strategy: WarmingStrategy,
}

#[derive(Debug, Clone)]
pub struct CacheConfig {
    pub l1_max_size: usize,
    pub l2_max_size: usize,
    pub l3_max_size: usize,
    pub ttl_seconds: u64,
    pub enable_warming: bool,
    pub eviction_policy: EvictionPolicy,
    pub cleanup_interval_seconds: u64,
    pub enable_compression: bool,
    pub max_memory_mb: usize,
}

impl Default for CacheConfig {
    fn default() -> Self {
        Self {
            l1_max_size: 10_000,
            l2_max_size: 50_000,
            l3_max_size: 100_000,
            ttl_seconds: 3600,
            enable_warming: true,
            eviction_policy: EvictionPolicy::LRU,
            cleanup_interval_seconds: 300,
            enable_compression: true,
            max_memory_mb: 1024,
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub enum EvictionPolicy {
    LRU,
    LFU,
    TTL,
    Adaptive,
}

#[derive(Debug, Clone)]
pub enum WarmingStrategy {
    MostFrequent(usize),
    RecentQueries(usize),
    GraphCentrality(usize),
    None,
}

impl Default for WarmingStrategy {
    fn default() -> Self {
        WarmingStrategy::MostFrequent(1000)
    }
}

#[derive(Debug, Clone)]
pub struct CachedEmbedding {
    pub embedding: Vector,
    pub cached_at: DateTime<Utc>,
    pub last_accessed: DateTime<Utc>,
    pub access_count: u64,
    pub size_bytes: usize,
    pub is_compressed: bool,
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ComputationKey {
    pub operation: String,
    pub inputs: Vec<String>,
    pub model_id: Uuid,
}

#[derive(Debug, Clone)]
pub struct CachedComputation {
    pub result: ComputationResult,
    pub cached_at: DateTime<Utc>,
    pub last_accessed: DateTime<Utc>,
    pub access_count: u64,
    pub time_saved_us: u64,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ComputationResult {
    TripleScore(f64),
    EntitySimilarity(Vec<(String, f64)>),
    PredictionResults(Vec<(String, f64)>),
    AttentionWeights(Vec<f64>),
    IntermediateActivations(Vec<f64>),
    Gradients(Vec<Vec<f64>>),
    ModelWeights(Vec<Vec<f64>>),
    FeatureVectors(Vec<f64>),
    GenericResult(Vec<f64>),
    EmbeddingMatrices(Vec<Vec<f64>>),
    LossValues(Vec<f64>),
}

#[derive(Debug, Clone)]
pub struct CacheStats {
    pub total_hits: u64,
    pub total_misses: u64,
    pub hit_rate: f64,
    pub memory_usage_bytes: usize,
    pub l1_stats: LevelStats,
    pub l2_stats: LevelStats,
    pub l3_stats: LevelStats,
    pub total_time_saved_seconds: f64,
}

#[derive(Debug, Clone)]
pub struct LevelStats {
    pub hits: u64,
    pub misses: u64,
    pub size: usize,
    pub capacity: usize,
    pub memory_bytes: usize,
}

impl Default for CacheStats {
    fn default() -> Self {
        Self {
            total_hits: 0,
            total_misses: 0,
            hit_rate: 0.0,
            memory_usage_bytes: 0,
            l1_stats: LevelStats {
                hits: 0,
                misses: 0,
                size: 0,
                capacity: 0,
                memory_bytes: 0,
            },
            l2_stats: LevelStats {
                hits: 0,
                misses: 0,
                size: 0,
                capacity: 0,
                memory_bytes: 0,
            },
            l3_stats: LevelStats {
                hits: 0,
                misses: 0,
                size: 0,
                capacity: 0,
                memory_bytes: 0,
            },
            total_time_saved_seconds: 0.0,
        }
    }
}

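// Illustrative sketch (comments only, not compiled): how the LRU + TTL cache
// below behaves. `get` refreshes an entry's recency, so the least recently
// touched key is the one evicted when capacity is exceeded.
//
//     let mut cache = LRUCache::new(2, Duration::from_secs(60));
//     cache.put("a", 1);
//     cache.put("b", 2);
//     cache.get(&"a");   // "a" is now most recent; "b" becomes the LRU entry
//     cache.put("c", 3); // capacity 2 exceeded: "b" is evicted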
pub struct LRUCache<K, V>
where
    K: Clone + Eq + Hash,
    V: Clone,
{
    capacity: usize,
    map: HashMap<K, V>,
    order: VecDeque<K>,
    access_times: HashMap<K, Instant>,
    ttl: Duration,
}

impl<K, V> LRUCache<K, V>
where
    K: Clone + Eq + Hash,
    V: Clone,
{
    pub fn new(capacity: usize, ttl: Duration) -> Self {
        Self {
            capacity,
            map: HashMap::new(),
            order: VecDeque::new(),
            access_times: HashMap::new(),
            ttl,
        }
    }

    pub fn get(&mut self, key: &K) -> Option<V> {
        // Enforce the TTL lazily: an expired entry is dropped on lookup.
        if let Some(access_time) = self.access_times.get(key) {
            if access_time.elapsed() > self.ttl {
                self.remove(key);
                return None;
            }
        }

        match self.map.get(key).cloned() {
            Some(value) => {
                self.move_to_front(key);
                self.access_times.insert(key.clone(), Instant::now());
                Some(value)
            }
            _ => None,
        }
    }

    pub fn put(&mut self, key: K, value: V) {
        if self.map.contains_key(&key) {
            self.map.insert(key.clone(), value);
            self.move_to_front(&key);
        } else {
            if self.map.len() >= self.capacity {
                self.evict_lru();
            }
            self.map.insert(key.clone(), value);
            self.order.push_front(key.clone());
        }
        self.access_times.insert(key, Instant::now());
    }

    pub fn remove(&mut self, key: &K) -> Option<V> {
        match self.map.remove(key) {
            Some(value) => {
                self.order.retain(|k| k != key);
                self.access_times.remove(key);
                Some(value)
            }
            _ => None,
        }
    }

    pub fn clear(&mut self) {
        self.map.clear();
        self.order.clear();
        self.access_times.clear();
    }

    pub fn len(&self) -> usize {
        self.map.len()
    }

    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }

    // O(n) in the cache size; acceptable for the moderate capacities used here.
    fn move_to_front(&mut self, key: &K) {
        self.order.retain(|k| k != key);
        self.order.push_front(key.clone());
    }

    fn evict_lru(&mut self) {
        if let Some(key) = self.order.pop_back() {
            self.map.remove(&key);
            self.access_times.remove(&key);
        }
    }

    pub fn cleanup_expired(&mut self) -> usize {
        let now = Instant::now();
        let mut expired_keys = Vec::new();

        for (key, access_time) in &self.access_times {
            if now.duration_since(*access_time) > self.ttl {
                expired_keys.push(key.clone());
            }
        }

        let count = expired_keys.len();
        for key in expired_keys {
            self.remove(&key);
        }

        count
    }
}

impl CacheManager {
    pub fn new(config: CacheConfig) -> Self {
        let ttl = Duration::from_secs(config.ttl_seconds);

        Self {
            l1_cache: Arc::new(RwLock::new(LRUCache::new(config.l1_max_size, ttl))),
            l2_cache: Arc::new(RwLock::new(LRUCache::new(config.l2_max_size, ttl))),
            l3_cache: Arc::new(RwLock::new(LRUCache::new(config.l3_max_size, ttl))),
            config,
            stats: Arc::new(RwLock::new(CacheStats::default())),
            cleanup_task: None,
            warming_strategy: WarmingStrategy::default(),
        }
    }

    pub async fn start(&mut self) -> Result<()> {
        let cleanup_interval = Duration::from_secs(self.config.cleanup_interval_seconds);
        let l1_cache = Arc::clone(&self.l1_cache);
        let l2_cache = Arc::clone(&self.l2_cache);
        let l3_cache = Arc::clone(&self.l3_cache);
        let stats = Arc::clone(&self.stats);

        let cleanup_task = tokio::spawn(async move {
            let mut interval = tokio::time::interval(cleanup_interval);

            loop {
                interval.tick().await;

                let expired_l1 = {
                    let mut cache = l1_cache.write().expect("lock poisoned");
                    cache.cleanup_expired()
                };

                let expired_l2 = {
                    let mut cache = l2_cache.write().expect("lock poisoned");
                    cache.cleanup_expired()
                };

                let expired_l3 = {
                    let mut cache = l3_cache.write().expect("lock poisoned");
                    cache.cleanup_expired()
                };

                let total_expired = expired_l1 + expired_l2 + expired_l3;
                if total_expired > 0 {
                    debug!("Cleaned up {} expired cache entries", total_expired);
                }

                {
                    let mut stats = stats.write().expect("lock poisoned");
                    stats.l1_stats.size = l1_cache.read().expect("lock poisoned").len();
                    stats.l2_stats.size = l2_cache.read().expect("lock poisoned").len();
                    stats.l3_stats.size = l3_cache.read().expect("lock poisoned").len();

                    let total_requests = stats.total_hits + stats.total_misses;
                    if total_requests > 0 {
                        stats.hit_rate = stats.total_hits as f64 / total_requests as f64;
                    }
                }
            }
        });

        self.cleanup_task = Some(cleanup_task);
        info!(
            "Cache manager started with cleanup interval: {:?}",
            cleanup_interval
        );
        Ok(())
    }

    pub async fn stop(&mut self) {
        if let Some(task) = self.cleanup_task.take() {
            task.abort();
            info!("Cache manager stopped");
        }
    }

    pub fn get_embedding(&self, entity: &str) -> Option<Vector> {
        let result = {
            let mut cache = self.l1_cache.write().expect("lock poisoned");
            match cache.get(&entity.to_string()) {
                Some(mut cached) => {
                    // The cache hands out clones, so write the updated access
                    // metadata back; otherwise the bookkeeping is lost.
                    cached.last_accessed = Utc::now();
                    cached.access_count += 1;
                    cache.put(entity.to_string(), cached.clone());
                    Some(cached.embedding)
                }
                None => None,
            }
        };

        {
            let mut stats = self.stats.write().expect("lock poisoned");
            if result.is_some() {
                stats.total_hits += 1;
                stats.l1_stats.hits += 1;
                // No original computation time is recorded for embeddings, so
                // no time-saved credit is added for L1 hits.
            } else {
                stats.total_misses += 1;
                stats.l1_stats.misses += 1;
            }
        }

        result
    }

    pub fn put_embedding(&self, entity: String, embedding: Vector) {
        let cached = CachedEmbedding {
            size_bytes: embedding.values.len() * std::mem::size_of::<f32>(),
            embedding,
            cached_at: Utc::now(),
            last_accessed: Utc::now(),
            access_count: 1,
            is_compressed: false,
        };

        {
            let mut cache = self.l1_cache.write().expect("lock poisoned");
            cache.put(entity, cached);
        }

        {
            let mut stats = self.stats.write().expect("lock poisoned");
            stats.l1_stats.capacity = self.config.l1_max_size;
        }
    }

    pub fn get_computation(&self, key: &ComputationKey) -> Option<ComputationResult> {
        let result = {
            let mut cache = self.l2_cache.write().expect("lock poisoned");
            cache.get(key)
        };

        {
            let mut stats = self.stats.write().expect("lock poisoned");
            match &result {
                Some(cached) => {
                    stats.total_hits += 1;
                    stats.l2_stats.hits += 1;
                    // A hit saves the recorded cost of the original computation.
                    stats.total_time_saved_seconds += cached.time_saved_us as f64 / 1_000_000.0;
                }
                None => {
                    stats.total_misses += 1;
                    stats.l2_stats.misses += 1;
                }
            }
        }

        result.map(|cached| cached.result)
    }

    pub fn put_computation(
        &self,
        key: ComputationKey,
        result: ComputationResult,
        computation_time_us: u64,
    ) {
        let cached = CachedComputation {
            result,
            cached_at: Utc::now(),
            last_accessed: Utc::now(),
            access_count: 1,
            time_saved_us: computation_time_us,
        };

        {
            let mut cache = self.l2_cache.write().expect("lock poisoned");
            cache.put(key, cached);
        }
    }

    pub fn get_similarity_cache(&self, query: &str) -> Option<Vec<(String, f64)>> {
        let result = {
            let mut cache = self.l3_cache.write().expect("lock poisoned");
            cache.get(&query.to_string())
        };

        {
            let mut stats = self.stats.write().expect("lock poisoned");
            if result.is_some() {
                stats.total_hits += 1;
                stats.l3_stats.hits += 1;
                // As with embeddings, no original cost is recorded for
                // similarity results, so no time-saved credit is added.
            } else {
                stats.total_misses += 1;
                stats.l3_stats.misses += 1;
            }
        }

        result
    }

    pub fn put_similarity_cache(&self, query: String, results: Vec<(String, f64)>) {
        let mut cache = self.l3_cache.write().expect("lock poisoned");
        cache.put(query, results);
    }

    pub async fn warm_cache(
        &self,
        model: &dyn EmbeddingModel,
        entities: Vec<String>,
    ) -> Result<usize> {
        if !self.config.enable_warming {
            return Ok(0);
        }

        info!(
            "Starting cache warming with {entities_len} entities",
            entities_len = entities.len()
        );
        let mut warmed_count = 0;

        for entity in entities {
            if self.get_embedding(&entity).is_some() {
                continue;
            }

            match model.get_entity_embedding(&entity) {
                Ok(embedding) => {
                    self.put_embedding(entity, embedding);
                    warmed_count += 1;
                }
                Err(e) => {
                    warn!("Failed to warm cache for entity {entity}: {e}");
                }
            }
        }

        info!("Cache warming completed: {warmed_count} entities cached");
        Ok(warmed_count)
    }

    pub async fn precompute_common_operations(
        &self,
        model: &dyn EmbeddingModel,
        common_queries: Vec<(String, String)>,
    ) -> Result<usize> {
        info!(
            "Starting precomputation for {} common queries",
            common_queries.len()
        );
        let mut precomputed_count = 0;

        for (subject, predicate) in common_queries {
            let key = ComputationKey {
                operation: "predict_objects".to_string(),
                inputs: vec![subject.clone(), predicate.clone()],
                model_id: *model.model_id(),
            };

            if self.get_computation(&key).is_some() {
                continue;
            }

            let start = Instant::now();
            match model.predict_objects(&subject, &predicate, 10) {
                Ok(predictions) => {
                    let computation_time = start.elapsed().as_micros() as u64;
                    let result = ComputationResult::PredictionResults(predictions);
                    self.put_computation(key, result, computation_time);
                    precomputed_count += 1;
                }
                Err(e) => {
                    warn!(
                        "Failed to precompute prediction for ({}, {}): {}",
                        subject, predicate, e
                    );
                }
            }
        }

        info!(
            "Precomputation completed: {} operations cached",
            precomputed_count
        );
        Ok(precomputed_count)
    }

    pub fn get_stats(&self) -> CacheStats {
        self.stats.read().expect("lock poisoned").clone()
    }

    pub fn clear_all(&self) {
        {
            let mut cache = self.l1_cache.write().expect("lock poisoned");
            cache.clear();
        }
        {
            let mut cache = self.l2_cache.write().expect("lock poisoned");
            cache.clear();
        }
        {
            let mut cache = self.l3_cache.write().expect("lock poisoned");
            cache.clear();
        }

        {
            let mut stats = self.stats.write().expect("lock poisoned");
            *stats = CacheStats::default();
        }

        info!("All caches cleared");
    }

    pub fn estimate_memory_usage(&self) -> usize {
        // A shallow estimate: entry count times struct size. Heap-allocated
        // contents (embedding values, strings, nested vectors) are not counted.
        let l1_size = {
            let cache = self.l1_cache.read().expect("lock poisoned");
            cache.len() * std::mem::size_of::<CachedEmbedding>()
        };

        let l2_size = {
            let cache = self.l2_cache.read().expect("lock poisoned");
            cache.len() * std::mem::size_of::<CachedComputation>()
        };

        let l3_size = {
            let cache = self.l3_cache.read().expect("lock poisoned");
            cache.len() * std::mem::size_of::<Vec<(String, f64)>>()
        };

        l1_size + l2_size + l3_size
    }

    pub fn cache_attention_weights(
        &self,
        layer_id: &str,
        input_hash: &str,
        model_id: Uuid,
        attention_weights: Vec<f64>,
        computation_time_us: u64,
    ) {
        let key = ComputationKey {
            operation: format!("attention_weights_{layer_id}"),
            inputs: vec![input_hash.to_string()],
            model_id,
        };

        let result = ComputationResult::AttentionWeights(attention_weights);
        self.put_computation(key, result, computation_time_us);

        debug!(
            "Cached attention weights for layer {} (input: {})",
            layer_id, input_hash
        );
    }

    pub fn get_attention_weights(
        &self,
        layer_id: &str,
        input_hash: &str,
        model_id: Uuid,
    ) -> Option<Vec<f64>> {
        let key = ComputationKey {
            operation: format!("attention_weights_{layer_id}"),
            inputs: vec![input_hash.to_string()],
            model_id,
        };

        match self.get_computation(&key)? {
            ComputationResult::AttentionWeights(weights) => {
                debug!(
                    "Cache hit for attention weights layer {} (input: {})",
                    layer_id, input_hash
                );
                Some(weights)
            }
            _ => None,
        }
    }

    pub fn cache_intermediate_activations(
        &self,
        layer_id: &str,
        input_hash: &str,
        model_id: Uuid,
        activations: Vec<f64>,
        computation_time_us: u64,
    ) {
        let key = ComputationKey {
            operation: format!("intermediate_activations_{layer_id}"),
            inputs: vec![input_hash.to_string()],
            model_id,
        };

        let result = ComputationResult::IntermediateActivations(activations);
        self.put_computation(key, result, computation_time_us);

        debug!(
            "Cached intermediate activations for layer {} (input: {})",
            layer_id, input_hash
        );
    }

    pub fn get_intermediate_activations(
        &self,
        layer_id: &str,
        input_hash: &str,
        model_id: Uuid,
    ) -> Option<Vec<f64>> {
        let key = ComputationKey {
            operation: format!("intermediate_activations_{layer_id}"),
            inputs: vec![input_hash.to_string()],
            model_id,
        };

        match self.get_computation(&key)? {
            ComputationResult::IntermediateActivations(activations) => {
                debug!(
                    "Cache hit for intermediate activations layer {} (input: {})",
                    layer_id, input_hash
                );
                Some(activations)
            }
            _ => None,
        }
    }

    pub fn cache_gradients(
        &self,
        layer_id: &str,
        batch_hash: &str,
        model_id: Uuid,
        gradients: Vec<Vec<f64>>,
        computation_time_us: u64,
    ) {
        let key = ComputationKey {
            operation: format!("gradients_{layer_id}"),
            inputs: vec![batch_hash.to_string()],
            model_id,
        };

        let result = ComputationResult::Gradients(gradients);
        self.put_computation(key, result, computation_time_us);

        debug!(
            "Cached gradients for layer {} (batch: {})",
            layer_id, batch_hash
        );
    }

    pub fn get_gradients(
        &self,
        layer_id: &str,
        batch_hash: &str,
        model_id: Uuid,
    ) -> Option<Vec<Vec<f64>>> {
        let key = ComputationKey {
            operation: format!("gradients_{layer_id}"),
            inputs: vec![batch_hash.to_string()],
            model_id,
        };

        match self.get_computation(&key)? {
            ComputationResult::Gradients(gradients) => {
                debug!(
                    "Cache hit for gradients layer {} (batch: {})",
                    layer_id, batch_hash
                );
                Some(gradients)
            }
            _ => None,
        }
    }

    pub fn cache_model_weights(
        &self,
        model_name: &str,
        checkpoint: &str,
        model_id: Uuid,
        weights: Vec<Vec<f64>>,
        computation_time_us: u64,
    ) {
        let key = ComputationKey {
            operation: "model_weights".to_string(),
            inputs: vec![model_name.to_string(), checkpoint.to_string()],
            model_id,
        };

        let result = ComputationResult::ModelWeights(weights);
        self.put_computation(key, result, computation_time_us);

        info!(
            "Cached model weights for {} (checkpoint: {})",
            model_name, checkpoint
        );
    }

    pub fn get_model_weights(
        &self,
        model_name: &str,
        checkpoint: &str,
        model_id: Uuid,
    ) -> Option<Vec<Vec<f64>>> {
        let key = ComputationKey {
            operation: "model_weights".to_string(),
            inputs: vec![model_name.to_string(), checkpoint.to_string()],
            model_id,
        };

        match self.get_computation(&key)? {
            ComputationResult::ModelWeights(weights) => {
                info!(
                    "Cache hit for model weights {} (checkpoint: {})",
                    model_name, checkpoint
                );
                Some(weights)
            }
            _ => None,
        }
    }

    pub fn cache_feature_vectors(
        &self,
        task_name: &str,
        input_hash: &str,
        model_id: Uuid,
        features: Vec<f64>,
        computation_time_us: u64,
    ) {
        let key = ComputationKey {
            operation: format!("feature_vectors_{task_name}"),
            inputs: vec![input_hash.to_string()],
            model_id,
        };

        let result = ComputationResult::FeatureVectors(features);
        self.put_computation(key, result, computation_time_us);

        debug!(
            "Cached feature vectors for task {} (input: {})",
            task_name, input_hash
        );
    }

    pub fn get_feature_vectors(
        &self,
        task_name: &str,
        input_hash: &str,
        model_id: Uuid,
    ) -> Option<Vec<f64>> {
        let key = ComputationKey {
            operation: format!("feature_vectors_{task_name}"),
            inputs: vec![input_hash.to_string()],
            model_id,
        };

        match self.get_computation(&key)? {
            ComputationResult::FeatureVectors(features) => {
                debug!(
                    "Cache hit for feature vectors task {} (input: {})",
                    task_name, input_hash
                );
                Some(features)
            }
            _ => None,
        }
    }

    pub fn cache_embedding_matrices(
        &self,
        operation: &str,
        batch_hash: &str,
        model_id: Uuid,
        matrices: Vec<Vec<f64>>,
        computation_time_us: u64,
    ) {
        let key = ComputationKey {
            operation: format!("embedding_matrices_{operation}"),
            inputs: vec![batch_hash.to_string()],
            model_id,
        };

        let result = ComputationResult::EmbeddingMatrices(matrices);
        self.put_computation(key, result, computation_time_us);

        debug!(
            "Cached embedding matrices for {} (batch: {})",
            operation, batch_hash
        );
    }

    pub fn get_embedding_matrices(
        &self,
        operation: &str,
        batch_hash: &str,
        model_id: Uuid,
    ) -> Option<Vec<Vec<f64>>> {
        let key = ComputationKey {
            operation: format!("embedding_matrices_{operation}"),
            inputs: vec![batch_hash.to_string()],
            model_id,
        };

        match self.get_computation(&key)? {
            ComputationResult::EmbeddingMatrices(matrices) => {
                debug!(
                    "Cache hit for embedding matrices {} (batch: {})",
                    operation, batch_hash
                );
                Some(matrices)
            }
            _ => None,
        }
    }

    pub fn cache_loss_values(
        &self,
        loss_type: &str,
        epoch_batch: &str,
        model_id: Uuid,
        losses: Vec<f64>,
        computation_time_us: u64,
    ) {
        let key = ComputationKey {
            operation: format!("loss_values_{loss_type}"),
            inputs: vec![epoch_batch.to_string()],
            model_id,
        };

        let result = ComputationResult::LossValues(losses);
        self.put_computation(key, result, computation_time_us);

        debug!(
            "Cached loss values for {} (epoch/batch: {})",
            loss_type, epoch_batch
        );
    }

    pub fn get_loss_values(
        &self,
        loss_type: &str,
        epoch_batch: &str,
        model_id: Uuid,
    ) -> Option<Vec<f64>> {
        let key = ComputationKey {
            operation: format!("loss_values_{loss_type}"),
            inputs: vec![epoch_batch.to_string()],
            model_id,
        };

        match self.get_computation(&key)? {
            ComputationResult::LossValues(losses) => {
                debug!(
                    "Cache hit for loss values {} (epoch/batch: {})",
                    loss_type, epoch_batch
                );
                Some(losses)
            }
            _ => None,
        }
    }

    pub fn cache_generic_result(
        &self,
        operation: &str,
        input_hash: &str,
        model_id: Uuid,
        result: Vec<f64>,
        computation_time_us: u64,
    ) {
        let key = ComputationKey {
            operation: operation.to_string(),
            inputs: vec![input_hash.to_string()],
            model_id,
        };

        let cached_result = ComputationResult::GenericResult(result);
        self.put_computation(key, cached_result, computation_time_us);

        debug!(
            "Cached generic result for {} (input: {})",
            operation, input_hash
        );
    }

    pub fn get_generic_result(
        &self,
        operation: &str,
        input_hash: &str,
        model_id: Uuid,
    ) -> Option<Vec<f64>> {
        let key = ComputationKey {
            operation: operation.to_string(),
            inputs: vec![input_hash.to_string()],
            model_id,
        };

        match self.get_computation(&key)? {
            ComputationResult::GenericResult(result) => {
                debug!(
                    "Cache hit for generic result {} (input: {})",
                    operation, input_hash
                );
                Some(result)
            }
            _ => None,
        }
    }

    pub fn clear_computation_cache(&self, operation_prefix: &str) -> usize {
        let mut removed_count = 0;

        {
            let mut cache = self.l2_cache.write().expect("lock poisoned");
            let keys_to_remove: Vec<_> = cache
                .map
                .keys()
                .filter(|key| key.operation.starts_with(operation_prefix))
                .cloned()
                .collect();

            for key in keys_to_remove {
                cache.remove(&key);
                removed_count += 1;
            }
        }

        info!(
            "Cleared {} cache entries for operation: {}",
            removed_count, operation_prefix
        );
        removed_count
    }

    pub fn get_cache_hit_rates(&self) -> HashMap<String, f64> {
        let mut hit_rates = HashMap::new();
        let cache = self.l2_cache.read().expect("lock poisoned");

        let mut operation_stats = HashMap::new();

        for key in cache.map.keys() {
            let operation_type = key.operation.split('_').next().unwrap_or("unknown");
            let entry = operation_stats
                .entry(operation_type.to_string())
                .or_insert((0u64, 0u64));
            entry.0 += 1;
        }

        for (operation, (total, _hits)) in operation_stats {
            // Per-operation hits are not tracked yet, so report a fixed
            // placeholder rate for any operation that has cached entries.
            let hit_rate = if total > 0 { 0.8 } else { 0.0 };
            hit_rates.insert(operation, hit_rate);
        }

        hit_rates
    }

    pub fn adaptive_resize(&mut self) {
        let stats = self.get_stats();

        // Note: this only adjusts the configured size; already-created LRU
        // caches keep their original capacity until they are rebuilt.
        if stats.l1_stats.hits > stats.l1_stats.misses * 2
            && stats.memory_usage_bytes < self.config.max_memory_mb * 1024 * 1024 / 2
        {
            self.config.l1_max_size = (self.config.l1_max_size as f64 * 1.2) as usize;
            info!("Increased L1 cache size to {}", self.config.l1_max_size);
        } else if stats.l1_stats.misses > stats.l1_stats.hits * 2 {
            self.config.l1_max_size = (self.config.l1_max_size as f64 * 0.8) as usize;
            info!("Decreased L1 cache size to {}", self.config.l1_max_size);
        }
    }

    pub fn batch_cache_computations(&self, computations: Vec<(ComputationKey, ComputationResult)>) {
        let count = computations.len();
        for (key, result) in computations {
            // No original computation time is known for batch entries.
            self.put_computation(key, result, 0);
        }

        info!("Batch cached {count} computation results");
    }

    pub fn get_computation_type_stats(&self) -> HashMap<String, (u64, u64)> {
        // Placeholder: per-type hit/miss counters are not collected yet, so
        // every known computation type reports (0, 0).
        let mut type_stats = HashMap::new();

        type_stats.insert("attention_weights".to_string(), (0, 0));
        type_stats.insert("gradients".to_string(), (0, 0));
        type_stats.insert("model_weights".to_string(), (0, 0));
        type_stats.insert("intermediate_activations".to_string(), (0, 0));
        type_stats.insert("feature_vectors".to_string(), (0, 0));

        type_stats
    }
}

pub struct CachedEmbeddingModel {
    model: Box<dyn EmbeddingModel>,
    cache_manager: Arc<CacheManager>,
}

impl CachedEmbeddingModel {
    pub fn new(model: Box<dyn EmbeddingModel>, cache_manager: Arc<CacheManager>) -> Self {
        Self {
            model,
            cache_manager,
        }
    }

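    // Illustrative wiring (comments only, not compiled): `MyModel` stands in
    // for any concrete `EmbeddingModel` implementation, which this module
    // does not define.
    //
    //     let cache_manager = Arc::new(CacheManager::new(CacheConfig::default()));
    //     let model = CachedEmbeddingModel::new(Box::new(MyModel::default()), cache_manager);
    //     let embedding = model.get_entity_embedding_cached("alice")?;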
    pub fn get_entity_embedding_cached(&self, entity: &str) -> Result<Vector> {
        if let Some(cached) = self.cache_manager.get_embedding(entity) {
            return Ok(cached);
        }

        let embedding = self.model.get_entity_embedding(entity)?;

        self.cache_manager
            .put_embedding(entity.to_string(), embedding.clone());

        Ok(embedding)
    }

    pub fn score_triple_cached(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
        let key = ComputationKey {
            operation: "score_triple".to_string(),
            inputs: vec![
                subject.to_string(),
                predicate.to_string(),
                object.to_string(),
            ],
            model_id: *self.model.model_id(),
        };

        if let Some(ComputationResult::TripleScore(score)) =
            self.cache_manager.get_computation(&key)
        {
            return Ok(score);
        }

        let start = Instant::now();
        let score = self.model.score_triple(subject, predicate, object)?;
        let computation_time = start.elapsed().as_micros() as u64;

        self.cache_manager.put_computation(
            key,
            ComputationResult::TripleScore(score),
            computation_time,
        );

        Ok(score)
    }

    pub fn predict_objects_cached(
        &self,
        subject: &str,
        predicate: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        let key = ComputationKey {
            operation: format!("predict_objects_{k}"),
            inputs: vec![subject.to_string(), predicate.to_string()],
            model_id: *self.model.model_id(),
        };

        if let Some(ComputationResult::PredictionResults(predictions)) =
            self.cache_manager.get_computation(&key)
        {
            return Ok(predictions);
        }

        let start = Instant::now();
        let predictions = self.model.predict_objects(subject, predicate, k)?;
        let computation_time = start.elapsed().as_micros() as u64;

        self.cache_manager.put_computation(
            key,
            ComputationResult::PredictionResults(predictions.clone()),
            computation_time,
        );

        Ok(predictions)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lru_cache_basic() {
        let mut cache = LRUCache::new(3, Duration::from_secs(60));

        cache.put("a".to_string(), 1);
        cache.put("b".to_string(), 2);
        cache.put("c".to_string(), 3);

        assert_eq!(cache.get(&"a".to_string()), Some(1));
        assert_eq!(cache.get(&"b".to_string()), Some(2));
        assert_eq!(cache.get(&"c".to_string()), Some(3));
        assert_eq!(cache.len(), 3);

        cache.put("d".to_string(), 4);
        assert_eq!(cache.len(), 3);
        // "a" was the least recently used entry, so it was evicted.
        assert_eq!(cache.get(&"a".to_string()), None);
        assert_eq!(cache.get(&"d".to_string()), Some(4));
    }
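
    // A small check of the TTL path, relying only on behavior defined above:
    // `get` drops entries whose last access is older than the configured TTL,
    // and `cleanup_expired` sweeps them in bulk.
    #[test]
    fn test_lru_cache_ttl_expiry() {
        let mut cache = LRUCache::new(10, Duration::from_millis(10));
        cache.put("a".to_string(), 1);
        std::thread::sleep(Duration::from_millis(30));

        // The sweep removes the expired entry and reports how many it dropped.
        assert_eq!(cache.cleanup_expired(), 1);
        assert!(cache.is_empty());
        assert_eq!(cache.get(&"a".to_string()), None);
    }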

    #[test]
    fn test_cache_config_default() {
        let config = CacheConfig::default();
        assert_eq!(config.l1_max_size, 10_000);
        assert_eq!(config.l2_max_size, 50_000);
        assert_eq!(config.l3_max_size, 100_000);
        assert_eq!(config.ttl_seconds, 3600);
        assert!(config.enable_warming);
    }

    #[tokio::test]
    async fn test_cache_manager_basic() {
        let config = CacheConfig {
            l1_max_size: 100,
            l2_max_size: 100,
            l3_max_size: 100,
            ..Default::default()
        };

        let cache_manager = CacheManager::new(config);

        let embedding = Vector::new(vec![1.0, 2.0, 3.0]);
        cache_manager.put_embedding("test_entity".to_string(), embedding.clone());

        let cached = cache_manager.get_embedding("test_entity");
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().values, embedding.values);

        let key = ComputationKey {
            operation: "test_op".to_string(),
            inputs: vec!["input1".to_string()],
            model_id: Uuid::new_v4(),
        };

        let result = ComputationResult::TripleScore(0.85);
        cache_manager.put_computation(key.clone(), result, 1000);

        let cached_result = cache_manager.get_computation(&key);
        assert!(cached_result.is_some());

        if let Some(ComputationResult::TripleScore(score)) = cached_result {
            assert_eq!(score, 0.85);
        } else {
            panic!("Expected TripleScore result");
        }
    }
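
    // Round-trip through the L3 similarity cache, using only the public
    // put/get pair defined above.
    #[test]
    fn test_similarity_cache_round_trip() {
        let cache_manager = CacheManager::new(CacheConfig::default());

        assert!(cache_manager.get_similarity_cache("query").is_none());

        let results = vec![("e1".to_string(), 0.9), ("e2".to_string(), 0.7)];
        cache_manager.put_similarity_cache("query".to_string(), results.clone());
        assert_eq!(cache_manager.get_similarity_cache("query"), Some(results));
    }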

    #[test]
    fn test_cache_stats() {
        let config = CacheConfig::default();
        let cache_manager = CacheManager::new(config);

        let stats = cache_manager.get_stats();
        assert_eq!(stats.total_hits, 0);
        assert_eq!(stats.total_misses, 0);

        let result = cache_manager.get_embedding("nonexistent");
        assert!(result.is_none());

        let stats = cache_manager.get_stats();
        assert_eq!(stats.total_misses, 1);

        let embedding = Vector::new(vec![1.0, 2.0, 3.0]);
        cache_manager.put_embedding("test".to_string(), embedding);
        let cached = cache_manager.get_embedding("test");
        assert!(cached.is_some());

        let stats = cache_manager.get_stats();
        assert_eq!(stats.total_hits, 1);
    }
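
    // Prefix-based invalidation: clear_computation_cache removes only the L2
    // entries whose operation starts with the given prefix.
    #[test]
    fn test_clear_computation_cache_by_prefix() {
        let cache_manager = CacheManager::new(CacheConfig::default());
        let model_id = Uuid::new_v4();

        cache_manager.cache_attention_weights("layer0", "in0", model_id, vec![0.5], 100);
        cache_manager.cache_gradients("layer0", "batch0", model_id, vec![vec![0.1]], 100);

        let removed = cache_manager.clear_computation_cache("attention_weights");
        assert_eq!(removed, 1);
        assert!(cache_manager
            .get_attention_weights("layer0", "in0", model_id)
            .is_none());
        assert!(cache_manager
            .get_gradients("layer0", "batch0", model_id)
            .is_some());
    }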

    #[test]
    fn test_computation_key_equality() {
        let key1 = ComputationKey {
            operation: "test".to_string(),
            inputs: vec!["a".to_string(), "b".to_string()],
            model_id: Uuid::new_v4(),
        };

        let key2 = ComputationKey {
            operation: "test".to_string(),
            inputs: vec!["a".to_string(), "b".to_string()],
            model_id: key1.model_id,
        };

        let key3 = ComputationKey {
            operation: "different".to_string(),
            inputs: vec!["a".to_string(), "b".to_string()],
            model_id: key1.model_id,
        };

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }
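
    // Sanity check for the shallow memory estimate: it counts entries times
    // struct size, so it is zero when empty and positive once something is cached.
    #[test]
    fn test_estimate_memory_usage_grows() {
        let cache_manager = CacheManager::new(CacheConfig::default());
        assert_eq!(cache_manager.estimate_memory_usage(), 0);

        cache_manager.put_embedding("e".to_string(), Vector::new(vec![1.0]));
        assert!(cache_manager.estimate_memory_usage() > 0);
    }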
}