use lru::LruCache;
use parking_lot::RwLock;
use std::num::NonZeroUsize;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

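/// A `Vec<f32>` wrapper that holds a single embedding.
///
/// Note: `#[repr(align(64))]` aligns the struct itself to a 64-byte cache
/// line; the heap buffer owned by the inner `Vec<f32>` keeps the allocator's
/// default alignment for `f32`.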
#[repr(align(64))]
#[derive(Debug, Clone)]
pub struct AlignedVector {
    data: Vec<f32>,
}

impl AlignedVector {
    pub fn new(data: Vec<f32>) -> Self {
        Self { data }
    }

    pub fn zeros(len: usize) -> Self {
        Self {
            data: vec![0.0; len],
        }
    }

    pub fn as_slice(&self) -> &[f32] {
        &self.data
    }

    pub fn as_mut_slice(&mut self) -> &mut [f32] {
        &mut self.data
    }

    pub fn len(&self) -> usize {
        self.data.len()
    }

    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    pub fn into_vec(self) -> Vec<f32> {
        self.data
    }
}

impl From<Vec<f32>> for AlignedVector {
    fn from(data: Vec<f32>) -> Self {
        Self::new(data)
    }
}

impl AsRef<[f32]> for AlignedVector {
    fn as_ref(&self) -> &[f32] {
        &self.data
    }
}

impl AsMut<[f32]> for AlignedVector {
    fn as_mut(&mut self) -> &mut [f32] {
        &mut self.data
    }
}

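/// Per-entry access statistics.
///
/// `time_in_cache` is updated on every access to the span between the first
/// and most recent access; `access_frequency` divides the access count by
/// that span, falling back to the raw count while the span is still zero.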
#[derive(Debug, Clone)]
struct AccessStats {
    access_count: u64,
    last_access: Instant,
    first_access: Instant,
    time_in_cache: Duration,
}

impl AccessStats {
    fn new() -> Self {
        let now = Instant::now();
        Self {
            access_count: 1,
            last_access: now,
            first_access: now,
            time_in_cache: Duration::from_secs(0),
        }
    }

    fn record_access(&mut self) {
        self.access_count += 1;
        self.last_access = Instant::now();
        self.time_in_cache = self.last_access.duration_since(self.first_access);
    }

    fn access_frequency(&self) -> f64 {
        if self.time_in_cache.as_secs_f64() > 0.0 {
            self.access_count as f64 / self.time_in_cache.as_secs_f64()
        } else {
            self.access_count as f64
        }
    }
}

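/// An embedding stored in the hot cache together with its access statistics.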
#[derive(Debug, Clone)]
struct CachedEmbedding {
    vector: AlignedVector,
    stats: AccessStats,
}

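/// An LRU cache of recently used embeddings with hit/miss counters and a
/// queue of prefetch requests.
///
/// `prefetch` only records requested keys; nothing in this type consumes the
/// queue, so a caller is expected to load and `insert` those embeddings.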
pub struct HotEmbeddingCache {
    cache: Arc<RwLock<LruCache<String, CachedEmbedding>>>,
    hits: Arc<AtomicU64>,
    misses: Arc<AtomicU64>,
    capacity: usize,
    prefetch_queue: Arc<RwLock<Vec<String>>>,
}

impl HotEmbeddingCache {
    /// Creates a cache that holds at most `capacity` embeddings.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero.
    pub fn new(capacity: usize) -> Self {
        Self {
            cache: Arc::new(RwLock::new(LruCache::new(
                NonZeroUsize::new(capacity).expect("cache capacity must be non-zero"),
            ))),
            hits: Arc::new(AtomicU64::new(0)),
            misses: Arc::new(AtomicU64::new(0)),
            capacity,
            prefetch_queue: Arc::new(RwLock::new(Vec::new())),
        }
    }

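    /// Looks up an embedding by key, counting a hit or a miss. On a hit the
    /// entry's access stats and LRU position are refreshed and a clone of the
    /// stored vector is returned.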
    pub fn get(&self, key: &str) -> Option<AlignedVector> {
        let mut cache = self.cache.write();
        if let Some(entry) = cache.get_mut(key) {
            entry.stats.record_access();
            self.hits.fetch_add(1, Ordering::Relaxed);
            Some(entry.vector.clone())
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            None
        }
    }

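    /// Inserts an embedding, wrapping it in `AlignedVector` and evicting the
    /// least recently used entry if the cache is full.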
    pub fn insert(&self, key: String, vector: Vec<f32>) {
        let aligned = AlignedVector::new(vector);
        let entry = CachedEmbedding {
            vector: aligned,
            stats: AccessStats::new(),
        };
        self.cache.write().put(key, entry);
    }

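    /// Returns a snapshot of the hit/miss counters, hit rate, and current size.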
    pub fn stats(&self) -> HotCacheStats {
        let hits = self.hits.load(Ordering::Relaxed);
        let misses = self.misses.load(Ordering::Relaxed);
        let total = hits + misses;
        let hit_rate = if total > 0 {
            hits as f64 / total as f64
        } else {
            0.0
        };

        let cache = self.cache.read();
        let size = cache.len();

        HotCacheStats {
            hits,
            misses,
            hit_rate,
            size,
            capacity: self.capacity,
        }
    }

    pub fn clear(&self) {
        self.cache.write().clear();
        self.hits.store(0, Ordering::Relaxed);
        self.misses.store(0, Ordering::Relaxed);
    }

    pub fn len(&self) -> usize {
        self.cache.read().len()
    }

    pub fn is_empty(&self) -> bool {
        self.cache.read().is_empty()
    }

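    /// Queues keys for prefetching. This only records the request; the cache
    /// itself never drains the queue, so whatever consumes it is responsible
    /// for loading and inserting the vectors.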
    pub fn prefetch(&self, keys: Vec<String>) {
        let mut queue = self.prefetch_queue.write();
        queue.extend(keys);
    }

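    /// Returns up to `top_n` keys ranked by access frequency, most frequently
    /// accessed first.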
    pub fn get_hot_keys(&self, top_n: usize) -> Vec<String> {
        let cache = self.cache.read();
        let mut entries: Vec<_> = cache
            .iter()
            .map(|(k, v)| (k.clone(), v.stats.access_frequency()))
            .collect();

        // Sort by descending frequency; total_cmp avoids panicking on NaN.
        entries.sort_by(|a, b| b.1.total_cmp(&a.1));
        entries.into_iter().take(top_n).map(|(k, _)| k).collect()
    }
}

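/// A snapshot of cache counters produced by `HotEmbeddingCache::stats`.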
#[derive(Debug, Clone)]
pub struct HotCacheStats {
    pub hits: u64,
    pub misses: u64,
    pub hit_rate: f64,
    pub size: usize,
    pub capacity: usize,
}

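/// Adjusts a target cache size between `min_size` and `max_size` so that the
/// observed hit rate converges on `target_hit_rate`: the target grows when the
/// hit rate is too low and shrinks when it comfortably exceeds the goal.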
pub struct AdaptiveCacheStrategy {
    target_size: Arc<RwLock<usize>>,
    min_size: usize,
    max_size: usize,
    target_hit_rate: f64,
    adjustment_factor: f64,
}

impl AdaptiveCacheStrategy {
    pub fn new(min_size: usize, max_size: usize, target_hit_rate: f64) -> Self {
        Self {
            target_size: Arc::new(RwLock::new((min_size + max_size) / 2)),
            min_size,
            max_size,
            target_hit_rate,
            adjustment_factor: 1.1,
        }
    }

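    /// Updates and returns the target size for the observed hit rate. The size
    /// grows by `adjustment_factor` when the hit rate is below target and
    /// shrinks when it exceeds the target by more than 5 percentage points,
    /// leaving a small band in between where the size is left unchanged.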
    pub fn adjust(&self, current_hit_rate: f64) -> usize {
        let mut target = self.target_size.write();

        if current_hit_rate < self.target_hit_rate {
            let new_size =
                (*target as f64 * self.adjustment_factor).min(self.max_size as f64) as usize;
            *target = new_size;
        } else if current_hit_rate > self.target_hit_rate + 0.05 {
            let new_size =
                (*target as f64 / self.adjustment_factor).max(self.min_size as f64) as usize;
            *target = new_size;
        }

        *target
    }

    pub fn target_size(&self) -> usize {
        *self.target_size.read()
    }

    pub fn reset(&self) {
        *self.target_size.write() = (self.min_size + self.max_size) / 2;
    }
}

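/// When cached entries should be invalidated: after a fixed TTL, only in
/// response to external events, or never. For `Event` and `Never`,
/// `should_invalidate` always returns `false`.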
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InvalidationPolicy {
    TTL(Duration),
    Event,
    Never,
}

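/// Tracks when a cache was last invalidated and, under the TTL policy, reports
/// when the next invalidation is due.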
pub struct CacheInvalidator {
    policy: InvalidationPolicy,
    last_invalidation: Arc<RwLock<Instant>>,
}

impl CacheInvalidator {
    pub fn new(policy: InvalidationPolicy) -> Self {
        Self {
            policy,
            last_invalidation: Arc::new(RwLock::new(Instant::now())),
        }
    }

    pub fn should_invalidate(&self) -> bool {
        match self.policy {
            InvalidationPolicy::TTL(ttl) => {
                let elapsed = self.last_invalidation.read().elapsed();
                elapsed >= ttl
            }
            InvalidationPolicy::Event | InvalidationPolicy::Never => false,
        }
    }

    pub fn invalidate(&self) {
        *self.last_invalidation.write() = Instant::now();
    }

    pub fn time_since_invalidation(&self) -> Duration {
        self.last_invalidation.read().elapsed()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_aligned_vector_creation() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let aligned = AlignedVector::new(data.clone());

        assert_eq!(aligned.len(), 4);
        assert_eq!(aligned.as_slice(), &data[..]);
    }

    #[test]
    fn test_aligned_vector_alignment() {
        let aligned = AlignedVector::zeros(100);

        assert_eq!(
            std::mem::align_of::<AlignedVector>(),
            64,
            "AlignedVector struct should be aligned to 64 bytes"
        );

        let ptr = aligned.as_slice().as_ptr() as usize;
        assert_eq!(
            ptr % std::mem::align_of::<f32>(),
            0,
            "Data pointer should be properly aligned for f32"
        );
    }

    #[test]
    fn test_hot_cache_basic() {
        let cache = HotEmbeddingCache::new(10);

        cache.insert("key1".to_string(), vec![1.0, 2.0, 3.0]);
        cache.insert("key2".to_string(), vec![4.0, 5.0, 6.0]);

        let vec1 = cache.get("key1").unwrap();
        assert_eq!(vec1.as_slice(), &[1.0, 2.0, 3.0]);

        let vec2 = cache.get("key2").unwrap();
        assert_eq!(vec2.as_slice(), &[4.0, 5.0, 6.0]);

        assert!(cache.get("key3").is_none());
    }

    #[test]
    fn test_hot_cache_stats() {
        let cache = HotEmbeddingCache::new(10);

        cache.insert("key1".to_string(), vec![1.0, 2.0, 3.0]);

        cache.get("key1"); // hit
        cache.get("key2"); // miss

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hit_rate, 0.5);
    }

    #[test]
    fn test_hot_cache_lru() {
        let cache = HotEmbeddingCache::new(2);

        cache.insert("key1".to_string(), vec![1.0]);
        cache.insert("key2".to_string(), vec![2.0]);
        cache.insert("key3".to_string(), vec![3.0]);

        // key1 is the least recently used entry and should have been evicted.
        assert!(cache.get("key1").is_none());
        assert!(cache.get("key2").is_some());
        assert!(cache.get("key3").is_some());
    }

    #[test]
    fn test_adaptive_strategy() {
        let strategy = AdaptiveCacheStrategy::new(100, 1000, 0.8);

        // The initial target is the midpoint of the min/max range.
        let initial_size = strategy.target_size();
        assert_eq!(initial_size, 550);

        // A hit rate below target should grow the cache.
        let new_size = strategy.adjust(0.5);
        assert!(new_size > initial_size);

        strategy.reset();

        // A hit rate well above target should shrink it.
        let new_size = strategy.adjust(0.95);
        assert!(new_size < initial_size);
    }

    #[test]
    fn test_cache_invalidator_ttl() {
        let invalidator =
            CacheInvalidator::new(InvalidationPolicy::TTL(Duration::from_millis(10)));

        assert!(!invalidator.should_invalidate());

        std::thread::sleep(Duration::from_millis(15));

        assert!(invalidator.should_invalidate());
    }

    #[test]
    fn test_cache_invalidator_never() {
        let invalidator = CacheInvalidator::new(InvalidationPolicy::Never);

        std::thread::sleep(Duration::from_millis(10));

        assert!(!invalidator.should_invalidate());
    }

    #[test]
    fn test_hot_keys_tracking() {
        let cache = HotEmbeddingCache::new(10);

        cache.insert("key1".to_string(), vec![1.0]);
        cache.insert("key2".to_string(), vec![2.0]);
        cache.insert("key3".to_string(), vec![3.0]);

        // Let a little time pass so access frequencies are measured over a
        // non-zero window.
        std::thread::sleep(Duration::from_millis(1));

        // key1 is accessed most often, key2 less, key3 only at insertion.
        for _ in 0..5 {
            cache.get("key1");
        }

        for _ in 0..2 {
            cache.get("key2");
        }

        let hot_keys = cache.get_hot_keys(2);
        assert_eq!(hot_keys.len(), 2);
        assert!(hot_keys.contains(&"key1".to_string()));
        assert!(hot_keys.contains(&"key2".to_string()));
    }
}