1use bytes::Bytes;
7use memmap2::Mmap;
8use parking_lot::RwLock;
9use std::fs::File;
10use std::path::{Path, PathBuf};
11use std::sync::Arc;
12use crate::{Error, Result};
13
14#[derive(Clone)]
16pub struct MmapReader {
17 inner: Arc<MmapReaderInner>,
18}
19
20struct MmapReaderInner {
21 path: PathBuf,
22 mmap: RwLock<Option<Mmap>>,
23}
24
25impl MmapReader {
26 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
28 let path = path.as_ref().to_path_buf();
29 let file = File::open(&path)?;
30
31 let mmap = unsafe { Mmap::map(&file)? };
33
34 Ok(Self {
35 inner: Arc::new(MmapReaderInner {
36 path,
37 mmap: RwLock::new(Some(mmap)),
38 }),
39 })
40 }
41
42 pub fn read(&self, offset: u64, len: usize) -> Result<Bytes> {
44 let mmap_guard = self.inner.mmap.read();
45 let mmap = mmap_guard.as_ref()
46 .ok_or_else(|| Error::Internal("Mmap has been closed".to_string()))?;
47
48 let offset = offset as usize;
49 let end = offset.checked_add(len)
50 .ok_or_else(|| Error::InvalidArgument("Read would overflow".to_string()))?;
51
52 if end > mmap.len() {
53 return Err(Error::InvalidArgument(format!(
54 "Read beyond file bounds: {} + {} > {}",
55 offset, len, mmap.len()
56 )));
57 }
58
59 Ok(Bytes::copy_from_slice(&mmap[offset..end]))
60 }
61
62 pub fn read_block(&self, offset: u64) -> Result<Bytes> {
64 self.read(offset, crate::layout::BLOCK_SIZE)
65 }
66
67 pub fn len(&self) -> usize {
69 let mmap_guard = self.inner.mmap.read();
70 mmap_guard.as_ref().map(|m| m.len()).unwrap_or(0)
71 }
72
73 pub fn is_empty(&self) -> bool {
75 self.len() == 0
76 }
77
78 pub fn path(&self) -> &Path {
80 &self.inner.path
81 }
82
83 pub fn close(&self) {
85 let mut mmap_guard = self.inner.mmap.write();
86 *mmap_guard = None;
87 }
88}
89
90pub struct MmapPool {
94 readers: RwLock<std::collections::HashMap<PathBuf, MmapReader>>,
95}
96
97impl MmapPool {
98 pub fn new() -> Self {
100 Self {
101 readers: RwLock::new(std::collections::HashMap::new()),
102 }
103 }
104
105 pub fn get_or_open(&self, path: impl AsRef<Path>) -> Result<MmapReader> {
107 let path = path.as_ref();
108
109 {
111 let readers = self.readers.read();
112 if let Some(reader) = readers.get(path) {
113 return Ok(reader.clone());
114 }
115 }
116
117 let reader = MmapReader::open(path)?;
119
120 let mut readers = self.readers.write();
121 readers.insert(path.to_path_buf(), reader.clone());
122
123 Ok(reader)
124 }
125
126 pub fn remove(&self, path: impl AsRef<Path>) {
128 let mut readers = self.readers.write();
129 if let Some(reader) = readers.remove(path.as_ref()) {
130 reader.close();
131 }
132 }
133
134 pub fn clear(&self) {
136 let mut readers = self.readers.write();
137 for (_, reader) in readers.drain() {
138 reader.close();
139 }
140 }
141
142 pub fn len(&self) -> usize {
144 let readers = self.readers.read();
145 readers.len()
146 }
147
148 pub fn is_empty(&self) -> bool {
150 self.len() == 0
151 }
152}
153
154impl Default for MmapPool {
155 fn default() -> Self {
156 Self::new()
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a temporary file holding `contents`, flushed so it can be mapped.
    fn temp_file_with(contents: &[u8]) -> NamedTempFile {
        let mut tmp = NamedTempFile::new().unwrap();
        tmp.write_all(contents).unwrap();
        tmp.flush().unwrap();
        tmp
    }

    #[test]
    fn test_mmap_reader_basic() {
        let tmp = temp_file_with(b"hello world");
        let reader = MmapReader::open(tmp.path()).unwrap();

        assert_eq!(&reader.read(0, 5).unwrap()[..], b"hello");
        assert_eq!(&reader.read(6, 5).unwrap()[..], b"world");
        assert_eq!(reader.len(), 11);
        assert!(!reader.is_empty());
    }

    #[test]
    fn test_mmap_reader_out_of_bounds() {
        let tmp = temp_file_with(b"test");
        let reader = MmapReader::open(tmp.path()).unwrap();

        // Range runs past the end of the file.
        assert!(matches!(reader.read(0, 100), Err(Error::InvalidArgument(_))));
        // Offset itself is past the end of the file.
        assert!(matches!(reader.read(100, 1), Err(Error::InvalidArgument(_))));
        // A range ending exactly at EOF is allowed.
        assert_eq!(&reader.read(2, 2).unwrap()[..], b"st");
    }

    #[test]
    fn test_mmap_reader_overflowing_range() {
        let tmp = temp_file_with(b"test");
        let reader = MmapReader::open(tmp.path()).unwrap();

        assert!(matches!(reader.read(u64::MAX, 2), Err(Error::InvalidArgument(_))));
        assert!(matches!(reader.read(1, usize::MAX), Err(Error::InvalidArgument(_))));
    }

    #[test]
    fn test_mmap_reader_close() {
        let tmp = temp_file_with(b"data");
        let reader = MmapReader::open(tmp.path()).unwrap();
        let clone = reader.clone();

        let data = reader.read(0, 4).unwrap();
        assert_eq!(&data[..], b"data");

        reader.close();

        assert!(matches!(reader.read(0, 1), Err(Error::Internal(_))));
        // Clones share the mapping, so they observe the close too.
        assert!(matches!(clone.read(0, 1), Err(Error::Internal(_))));
        assert_eq!(reader.len(), 0);
        assert!(reader.is_empty());
        // Bytes read before the close were copied out and remain valid.
        assert_eq!(&data[..], b"data");
    }

    #[test]
    fn test_mmap_pool_basic() {
        let tmp = temp_file_with(b"pool test");

        let pool = MmapPool::new();
        assert!(pool.is_empty());

        let reader1 = pool.get_or_open(tmp.path()).unwrap();
        assert_eq!(pool.len(), 1);

        // A second lookup for the same path reuses the cached entry.
        let reader2 = pool.get_or_open(tmp.path()).unwrap();
        assert_eq!(pool.len(), 1);

        let data1 = reader1.read(0, 4).unwrap();
        let data2 = reader2.read(0, 4).unwrap();
        assert_eq!(data1, data2);
        assert_eq!(&data1[..], b"pool");
    }

    #[test]
    fn test_mmap_pool_remove() {
        let tmp = temp_file_with(b"test");

        let pool = MmapPool::new();
        let reader = pool.get_or_open(tmp.path()).unwrap();
        assert_eq!(pool.len(), 1);

        pool.remove(tmp.path());
        assert_eq!(pool.len(), 0);
        // Removing closes readers that were already handed out.
        assert!(matches!(reader.read(0, 1), Err(Error::Internal(_))));

        // Removing a path that is not cached is a no-op.
        pool.remove(tmp.path());
        assert!(pool.is_empty());
    }

    #[test]
    fn test_mmap_pool_clear() {
        let tmp1 = temp_file_with(b"file1");
        let tmp2 = temp_file_with(b"file2");

        let pool = MmapPool::new();
        pool.get_or_open(tmp1.path()).unwrap();
        pool.get_or_open(tmp2.path()).unwrap();
        assert_eq!(pool.len(), 2);

        pool.clear();
        assert_eq!(pool.len(), 0);

        // After clearing, the pool reopens files on demand.
        let reopened = pool.get_or_open(tmp1.path()).unwrap();
        assert_eq!(&reopened.read(0, 5).unwrap()[..], b"file1");
        assert_eq!(pool.len(), 1);
    }

    #[test]
    fn test_mmap_read_block() {
        let data = vec![0xAB; crate::layout::BLOCK_SIZE];
        let tmp = temp_file_with(&data);

        let reader = MmapReader::open(tmp.path()).unwrap();
        let block = reader.read_block(0).unwrap();

        assert_eq!(block.len(), crate::layout::BLOCK_SIZE);
        assert_eq!(&block[..], &data[..]);
    }
}