1use std::collections::HashMap;
8use std::path::PathBuf;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Instant;
12
13use tokio::fs;
14use tokio::io::AsyncWriteExt;
15use tokio::sync::RwLock;
16use tracing::{debug, info, warn};
17
18use apiary_core::error::ApiaryError;
19use apiary_core::storage::StorageBackend;
20use apiary_core::Result;
21
/// Bookkeeping for one object from the storage backend that has been copied into the local
/// cache directory.
#[derive(Debug, Clone)]
pub struct CacheEntry {
    /// Key of the object in the backing storage; also the key of this entry in the cache map.
    pub storage_key: String,

    /// Path of the cached copy inside the cache directory.
    pub local_path: PathBuf,

    /// Size of the cached copy in bytes (length of the data fetched from storage).
    pub size: u64,

    /// When the entry was last inserted or served; the least recent entries are evicted first.
    pub last_accessed: Instant,
}
37
/// On-disk LRU cache of objects ("cells") fetched from a [`StorageBackend`].
///
/// Objects are downloaded on first access into `cache_dir`. When the tracked total size grows
/// past `max_size` bytes, the least-recently-accessed entries are deleted from disk.
pub struct CellCache {
    /// Directory that holds the cached files.
    cache_dir: PathBuf,

    /// Size budget in bytes; eviction starts when the tracked total exceeds it.
    max_size: u64,

    /// Running total of bytes accounted to cached entries.
    current_size: Arc<AtomicU64>,

    /// Tracked entries, keyed by storage key.
    entries: Arc<RwLock<HashMap<String, CacheEntry>>>,

    /// Backend that cache misses are fetched from.
    storage: Arc<dyn StorageBackend>,
}
64
65impl CellCache {
66 pub async fn new(
78 cache_dir: PathBuf,
79 max_size: u64,
80 storage: Arc<dyn StorageBackend>,
81 ) -> Result<Self> {
82 fs::create_dir_all(&cache_dir)
84 .await
85 .map_err(|e| ApiaryError::Storage {
86 message: format!("Failed to create cache directory: {:?}", cache_dir),
87 source: Some(Box::new(e)),
88 })?;
89
90 info!(
91 cache_dir = ?cache_dir,
92 max_size_mb = max_size / (1024 * 1024),
93 "Cell cache initialized"
94 );
95
96 Ok(Self {
97 cache_dir,
98 max_size,
99 current_size: Arc::new(AtomicU64::new(0)),
100 entries: Arc::new(RwLock::new(HashMap::new())),
101 storage,
102 })
103 }
104
105 pub async fn get(&self, storage_key: &str) -> Result<PathBuf> {
121 {
123 let entries = self.entries.read().await;
124 if let Some(entry) = entries.get(storage_key) {
125 let path = entry.local_path.clone();
126 drop(entries); let mut entries_write = self.entries.write().await;
130 if let Some(entry) = entries_write.get_mut(storage_key) {
131 entry.last_accessed = Instant::now();
132 }
133
134 debug!(storage_key, "Cache hit");
135 return Ok(path);
136 }
137 }
138
139 debug!(storage_key, "Cache miss - fetching from storage");
141
142 let sanitized = storage_key.replace('/', "_");
144 let local_path = self.cache_dir.join(&sanitized);
145
146 let data = self.storage.get(storage_key).await?;
148 let size = data.len() as u64;
149
150 let mut file = fs::File::create(&local_path)
152 .await
153 .map_err(|e| ApiaryError::Storage {
154 message: format!("Failed to create cache file: {:?}", local_path),
155 source: Some(Box::new(e)),
156 })?;
157
158 file.write_all(&data)
159 .await
160 .map_err(|e| ApiaryError::Storage {
161 message: format!("Failed to write cache file: {:?}", local_path),
162 source: Some(Box::new(e)),
163 })?;
164
165 {
167 let mut entries = self.entries.write().await;
168 entries.insert(
169 storage_key.to_string(),
170 CacheEntry {
171 storage_key: storage_key.to_string(),
172 local_path: local_path.clone(),
173 size,
174 last_accessed: Instant::now(),
175 },
176 );
177 }
178
179 self.current_size.fetch_add(size, Ordering::SeqCst);
181
182 self.evict_if_needed().await?;
184
185 debug!(storage_key, size, "Cell cached");
186 Ok(local_path)
187 }
188
189 async fn evict_if_needed(&self) -> Result<()> {
200 let current = self.current_size.load(Ordering::SeqCst);
201 if current <= self.max_size {
202 return Ok(());
203 }
204
205 info!(
206 current_mb = current / (1024 * 1024),
207 max_mb = self.max_size / (1024 * 1024),
208 "Cache size exceeded - starting LRU eviction"
209 );
210
211 let mut entries = self.entries.write().await;
212
213 let mut sorted: Vec<_> = entries.values().cloned().collect();
216 sorted.sort_by_key(|e| e.last_accessed);
217
218 let mut freed = 0u64;
220 for entry in sorted {
221 if self.current_size.load(Ordering::SeqCst) <= self.max_size {
222 break;
223 }
224
225 entries.remove(&entry.storage_key);
227
228 if let Err(e) = fs::remove_file(&entry.local_path).await {
230 warn!(
231 path = ?entry.local_path,
232 error = %e,
233 "Failed to delete evicted cache file"
234 );
235 } else {
236 freed += entry.size;
237 self.current_size.fetch_sub(entry.size, Ordering::SeqCst);
238 debug!(
239 storage_key = entry.storage_key,
240 size = entry.size,
241 "Evicted cache entry"
242 );
243 }
244 }
245
246 info!(
247 freed_mb = freed / (1024 * 1024),
248 remaining_mb = self.current_size.load(Ordering::SeqCst) / (1024 * 1024),
249 "Cache eviction complete"
250 );
251
252 Ok(())
253 }
254
255 pub fn size(&self) -> u64 {
257 self.current_size.load(Ordering::SeqCst)
258 }
259
260 pub async fn list_cached_cells(&self) -> HashMap<String, u64> {
264 let entries = self.entries.read().await;
265 entries
266 .iter()
267 .map(|(key, entry)| (key.clone(), entry.size))
268 .collect()
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use apiary_storage::local::LocalBackend;
    use tempfile::TempDir;

    const TEN_MIB: u64 = 10 * 1024 * 1024;

    /// A cache backed by a throwaway local storage directory.
    ///
    /// `cache` is declared first so it is dropped before the temp directories it points into.
    struct Fixture {
        cache: CellCache,
        _storage_dir: TempDir,
        _cache_dir: TempDir,
    }

    /// Seeds storage with `cells` (key, bytes), then builds a cache with the given budget.
    async fn fixture(cells: Vec<(&str, Vec<u8>)>, max_size: u64) -> Fixture {
        let storage_dir = TempDir::new().unwrap();
        let cache_dir = TempDir::new().unwrap();

        let backend = Arc::new(LocalBackend::new(storage_dir.path()).await.unwrap());
        for (key, bytes) in cells {
            backend.put(key, bytes.into()).await.unwrap();
        }

        let cache = CellCache::new(cache_dir.path().to_path_buf(), max_size, backend)
            .await
            .unwrap();

        Fixture {
            cache,
            _storage_dir: storage_dir,
            _cache_dir: cache_dir,
        }
    }

    #[tokio::test]
    async fn test_cache_miss_fetches_from_storage() {
        let payload = b"test cell data".to_vec();
        let fx = fixture(vec![("cells/test.parquet", payload.clone())], TEN_MIB).await;

        let cached_path = fx.cache.get("cells/test.parquet").await.unwrap();
        assert!(cached_path.exists());

        let on_disk = fs::read(&cached_path).await.unwrap();
        assert_eq!(on_disk, payload);
        assert_eq!(fx.cache.size(), payload.len() as u64);
    }

    #[tokio::test]
    async fn test_cache_hit_returns_cached_path() {
        let fx = fixture(vec![("cells/test.parquet", b"data".to_vec())], TEN_MIB).await;

        let first = fx.cache.get("cells/test.parquet").await.unwrap();
        let second = fx.cache.get("cells/test.parquet").await.unwrap();
        assert_eq!(first, second);
    }

    #[tokio::test]
    async fn test_lru_eviction() {
        let fx = fixture(
            vec![
                ("cells/cell1.parquet", vec![1u8; 1024]),
                ("cells/cell2.parquet", vec![2u8; 1024]),
                ("cells/cell3.parquet", vec![3u8; 1024]),
            ],
            2500,
        )
        .await;
        // Spacing the accesses out guarantees distinct `last_accessed` timestamps.
        let pause = || tokio::time::sleep(tokio::time::Duration::from_millis(10));

        let oldest = fx.cache.get("cells/cell1.parquet").await.unwrap();
        assert!(oldest.exists());
        pause().await;

        let middle = fx.cache.get("cells/cell2.parquet").await.unwrap();
        assert!(middle.exists());
        pause().await;

        let newest = fx.cache.get("cells/cell3.parquet").await.unwrap();
        assert!(newest.exists());

        // 3 x 1024 bytes exceeds the 2500-byte budget, so only the oldest entry goes.
        assert!(!oldest.exists());
        assert!(middle.exists());
        assert!(newest.exists());
        assert!(fx.cache.size() <= 2500);
    }

    #[tokio::test]
    async fn test_list_cached_cells() {
        let fx = fixture(
            vec![
                ("cells/cell1.parquet", b"data1".to_vec()),
                ("cells/cell2.parquet", b"data22".to_vec()),
            ],
            TEN_MIB,
        )
        .await;

        for key in ["cells/cell1.parquet", "cells/cell2.parquet"] {
            fx.cache.get(key).await.unwrap();
        }

        let listing = fx.cache.list_cached_cells().await;
        assert_eq!(listing.len(), 2);
        assert_eq!(listing.get("cells/cell1.parquet"), Some(&5u64));
        assert_eq!(listing.get("cells/cell2.parquet"), Some(&6u64));
    }
}