reflex/cache/l2/
loader.rs

1use crate::storage::mmap::{AlignedMmapBuilder, MmapFileHandle};
2use crate::storage::{CacheEntry, StorageError, StorageWriter};
3
4/// Loads cached entries (typically from disk) given a storage key.
5pub trait StorageLoader: Send + Sync {
6    /// Loads a cache entry for `tenant_id`, returning `None` on missing/mismatch.
7    fn load(
8        &self,
9        storage_key: &str,
10        tenant_id: u64,
11    ) -> impl std::future::Future<Output = Option<CacheEntry>> + Send;
12}
13
/// In-memory [`StorageLoader`] used by tests and examples.
///
/// The entry map sits behind an `Arc<RwLock<_>>`, so every clone of the
/// loader reads and writes the same underlying map.
#[cfg(any(test, feature = "mock"))]
#[derive(Clone, Default)]
pub struct MockStorageLoader {
    /// Stored entries, keyed by storage key.
    entries: std::sync::Arc<std::sync::RwLock<std::collections::HashMap<String, CacheEntry>>>,
}
20
#[cfg(any(test, feature = "mock"))]
impl MockStorageLoader {
    /// Builds a loader that holds no entries.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `entry` under `key`, replacing any entry already stored there.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn insert(&self, key: &str, entry: CacheEntry) {
        let mut entries = self.entries.write().expect("lock poisoned");
        entries.insert(key.to_owned(), entry);
    }

    /// Reports how many entries are currently stored.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn len(&self) -> usize {
        let entries = self.entries.read().expect("lock poisoned");
        entries.len()
    }

    /// Reports whether no entries are stored.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn is_empty(&self) -> bool {
        let entries = self.entries.read().expect("lock poisoned");
        entries.is_empty()
    }
}
46
#[cfg(any(test, feature = "mock"))]
impl StorageLoader for MockStorageLoader {
    /// Returns a clone of the entry stored under `storage_key`, if any.
    ///
    /// `_tenant_id` is ignored: the mock performs no tenant check, so an
    /// entry is returned regardless of which tenant asks for it.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    async fn load(&self, storage_key: &str, _tenant_id: u64) -> Option<CacheEntry> {
        let entries = self.entries.read().expect("lock poisoned");
        entries.get(storage_key).cloned()
    }
}
57
#[cfg(any(test, feature = "mock"))]
impl StorageWriter for MockStorageLoader {
    /// Writes `data` to a fresh temporary file and returns an mmap handle to it.
    ///
    /// When `data` deserializes as a [`CacheEntry`], the entry is additionally
    /// recorded under `key` so a later `load` can find it; bytes that do not
    /// deserialize are written to the file but not recorded.
    ///
    /// # Errors
    ///
    /// Returns [`StorageError::WriteFailed`] if the temporary file cannot be
    /// created, written, flushed, or opened as an mmap handle.
    fn write(&self, key: &str, data: &[u8]) -> Result<MmapFileHandle, StorageError> {
        use std::io::Write;

        // Every failure in this method is reported as `WriteFailed` carrying
        // the text of the underlying error.
        fn write_failed<E: ToString>(err: E) -> StorageError {
            StorageError::WriteFailed(err.to_string())
        }

        // Stage the bytes in a temporary file.
        let mut scratch = tempfile::NamedTempFile::new().map_err(write_failed)?;
        scratch.write_all(data).map_err(write_failed)?;
        scratch.flush().map_err(write_failed)?;

        // Mirror the entry into the in-memory map for later retrieval.
        if let Ok(entry) = rkyv::from_bytes::<CacheEntry, rkyv::rancor::Error>(data) {
            self.insert(key, entry);
        }

        // NOTE(review): `scratch` is dropped when this function returns, and
        // `NamedTempFile` removes its file on drop, so the returned handle
        // outlives the path. Confirm this is acceptable on all target platforms.
        MmapFileHandle::open(scratch.path()).map_err(write_failed)
    }
}
87
/// NVMe-backed storage loader that reads `rkyv`-serialized entries via mmap.
#[derive(Clone, Debug)]
pub struct NvmeStorageLoader {
    /// Directory under which every storage key is resolved.
    storage_path: std::path::PathBuf,
}

impl NvmeStorageLoader {
    /// Builds a loader whose keys are resolved relative to `storage_path`.
    pub fn new(storage_path: std::path::PathBuf) -> Self {
        NvmeStorageLoader { storage_path }
    }

    /// Borrows the root directory this loader was created with.
    pub fn storage_path(&self) -> &std::path::Path {
        self.storage_path.as_path()
    }
}
105
/// Converts `storage_key` into a relative path that is safe to join onto a
/// storage root.
///
/// Normal path segments are kept in order and `.` segments are dropped.
/// Returns `None` when the key contains a `..` segment, a root, or a path
/// prefix, or when nothing remains after filtering (for example `""` or `"."`).
fn sanitize_storage_key(storage_key: &str) -> Option<std::path::PathBuf> {
    use std::path::{Component, Path, PathBuf};

    // Fold the components into a fresh path, bailing out with `None` on the
    // first component that could escape the storage root.
    let sanitized = Path::new(storage_key).components().try_fold(
        PathBuf::new(),
        |mut acc, component| match component {
            Component::Normal(segment) => {
                acc.push(segment);
                Some(acc)
            }
            Component::CurDir => Some(acc),
            Component::ParentDir | Component::RootDir | Component::Prefix(_) => None,
        },
    )?;

    // An empty key yields no components, so this also rejects `""`.
    Some(sanitized).filter(|path| !path.as_os_str().is_empty())
}
130
131impl StorageWriter for NvmeStorageLoader {
132    fn write(&self, key: &str, data: &[u8]) -> Result<MmapFileHandle, StorageError> {
133        let rel = sanitize_storage_key(key).ok_or_else(|| {
134            StorageError::Io(format!("Invalid storage key (path traversal?): {}", key))
135        })?;
136        let file_path = self.storage_path.join(rel);
137
138        if let Some(parent) = file_path.parent() {
139            std::fs::create_dir_all(parent)
140                .map_err(|e| StorageError::Io(format!("Failed to create directory: {}", e)))?;
141        }
142
143        let builder = AlignedMmapBuilder::new(file_path);
144        builder
145            .write_readonly(data)
146            .map_err(|e| StorageError::WriteFailed(format!("Failed to write file: {}", e)))
147    }
148}
149
150impl StorageLoader for NvmeStorageLoader {
151    async fn load(&self, storage_key: &str, tenant_id: u64) -> Option<CacheEntry> {
152        use crate::storage::mmap::MmapFileHandle;
153        use rkyv::from_bytes;
154        use rkyv::rancor::Error;
155
156        let storage_path = self.storage_path.clone();
157        let storage_key = storage_key.to_string();
158
159        tokio::task::spawn_blocking(move || {
160            let rel = match sanitize_storage_key(&storage_key) {
161                Some(r) => r,
162                None => {
163                    tracing::warn!(
164                        storage_key = %storage_key,
165                        "Rejected invalid storage_key (path traversal?)"
166                    );
167                    return None;
168                }
169            };
170
171            let file_path = storage_path.join(rel);
172            let handle = match MmapFileHandle::open(&file_path) {
173                Ok(h) => h,
174                Err(_) => return None,
175            };
176            let bytes = handle.as_slice();
177
178            let entry: CacheEntry = match from_bytes::<CacheEntry, Error>(bytes) {
179                Ok(e) => e,
180                Err(e) => {
181                    tracing::warn!(
182                        "Failed to deserialize cache entry at {:?}: {}",
183                        file_path,
184                        e
185                    );
186                    return None;
187                }
188            };
189
190            if entry.tenant_id != tenant_id {
191                tracing::warn!(
192                    "Tenant ID mismatch for key {}: expected {}, found {}",
193                    storage_key,
194                    tenant_id,
195                    entry.tenant_id
196                );
197                return None;
198            }
199
200            Some(entry)
201        })
202        .await
203        .ok()
204        .flatten()
205    }
206}