1use crate::error::{FsError, FsResult};
2use rand::RngCore;
3use std::collections::HashMap;
4use std::fs::{File, OpenOptions};
5use std::io::{Seek, SeekFrom, Write};
6use std::os::unix::fs::FileExt;
7use std::sync::Mutex;
8
9pub trait BlockStore: Send + Sync {
12 fn block_size(&self) -> usize;
14
15 fn total_blocks(&self) -> u64;
17
18 fn read_block(&self, block_id: u64) -> FsResult<Vec<u8>>;
20
21 fn write_block(&self, block_id: u64, data: &[u8]) -> FsResult<()>;
23
24 fn sync(&self) -> FsResult<()> {
26 Ok(())
27 }
28}
29
/// Volatile, in-process block store backed by a hash map.
///
/// Only blocks that have been written occupy memory; every other block in
/// `0..total_blocks` reads back as zeroes.
pub struct MemoryBlockStore {
    /// Written blocks keyed by block id, behind a mutex so `&self` methods can mutate them.
    blocks: Mutex<HashMap<u64, Vec<u8>>>,
    /// Length in bytes of each block.
    block_size: usize,
    /// Number of addressable blocks.
    total_blocks: u64,
}
36
37impl MemoryBlockStore {
38 pub fn new(block_size: usize, total_blocks: u64) -> Self {
39 Self {
40 block_size,
41 total_blocks,
42 blocks: Mutex::new(HashMap::new()),
43 }
44 }
45}
46
47impl BlockStore for MemoryBlockStore {
48 fn block_size(&self) -> usize {
49 self.block_size
50 }
51
52 fn total_blocks(&self) -> u64 {
53 self.total_blocks
54 }
55
56 fn read_block(&self, block_id: u64) -> FsResult<Vec<u8>> {
57 if block_id >= self.total_blocks {
58 return Err(FsError::BlockOutOfRange(block_id));
59 }
60 let blocks = self
61 .blocks
62 .lock()
63 .map_err(|e| FsError::Internal(e.to_string()))?;
64 match blocks.get(&block_id) {
65 Some(data) => Ok(data.clone()),
66 None => {
67 Ok(vec![0u8; self.block_size])
69 }
70 }
71 }
72
73 fn write_block(&self, block_id: u64, data: &[u8]) -> FsResult<()> {
74 if block_id >= self.total_blocks {
75 return Err(FsError::BlockOutOfRange(block_id));
76 }
77 if data.len() != self.block_size {
78 return Err(FsError::BlockSizeMismatch {
79 expected: self.block_size,
80 got: data.len(),
81 });
82 }
83 let mut blocks = self
84 .blocks
85 .lock()
86 .map_err(|e| FsError::Internal(e.to_string()))?;
87 blocks.insert(block_id, data.to_vec());
88 Ok(())
89 }
90}
91
/// Block store backed by a regular image file.
///
/// Block `i` occupies the bytes `[i * block_size, (i + 1) * block_size)` of
/// the file. All I/O is positional, so the file cursor is never relied on.
pub struct DiskBlockStore {
    /// Length in bytes of each block.
    block_size: usize,
    /// Number of addressable blocks in the image.
    total_blocks: u64,
    /// The image file, opened for reading and writing.
    file: File,
}
101
102impl DiskBlockStore {
103 pub fn open(path: &str, block_size: usize, total_blocks: u64) -> FsResult<Self> {
108 let file = OpenOptions::new()
109 .read(true)
110 .write(true)
111 .open(path)
112 .map_err(|e| FsError::Internal(format!("open {path}: {e}")))?;
113
114 let file_len = file
115 .metadata()
116 .map_err(|e| FsError::Internal(format!("stat {path}: {e}")))?
117 .len();
118
119 let total_blocks = if total_blocks == 0 {
120 file_len / block_size as u64
121 } else {
122 total_blocks
123 };
124
125 let required = total_blocks * block_size as u64;
126 if file_len < required {
127 return Err(FsError::Internal(format!(
128 "file too small: {file_len} bytes, need {required}"
129 )));
130 }
131
132 Ok(Self {
133 file,
134 block_size,
135 total_blocks,
136 })
137 }
138
139 pub fn create(path: &str, block_size: usize, total_blocks: u64) -> FsResult<Self> {
144 let mut file = OpenOptions::new()
145 .read(true)
146 .write(true)
147 .create_new(true)
148 .open(path)
149 .map_err(|e| FsError::Internal(format!("create {path}: {e}")))?;
150
151 let mut rng = rand::thread_rng();
153 let mut buf = vec![0u8; block_size];
154 for _ in 0..total_blocks {
155 rng.fill_bytes(&mut buf);
156 file.write_all(&buf)
157 .map_err(|e| FsError::Internal(format!("write {path}: {e}")))?;
158 }
159 file.sync_all()
160 .map_err(|e| FsError::Internal(format!("sync {path}: {e}")))?;
161
162 Ok(Self {
163 file,
164 block_size,
165 total_blocks,
166 })
167 }
168}
169
170impl BlockStore for DiskBlockStore {
171 fn block_size(&self) -> usize {
172 self.block_size
173 }
174
175 fn total_blocks(&self) -> u64 {
176 self.total_blocks
177 }
178
179 fn read_block(&self, block_id: u64) -> FsResult<Vec<u8>> {
180 if block_id >= self.total_blocks {
181 return Err(FsError::BlockOutOfRange(block_id));
182 }
183 let offset = block_id * self.block_size as u64;
184 let mut buf = vec![0u8; self.block_size];
185 self.file
186 .read_exact_at(&mut buf, offset)
187 .map_err(|e| FsError::Internal(format!("read block {block_id}: {e}")))?;
188 Ok(buf)
189 }
190
191 fn write_block(&self, block_id: u64, data: &[u8]) -> FsResult<()> {
192 if block_id >= self.total_blocks {
193 return Err(FsError::BlockOutOfRange(block_id));
194 }
195 if data.len() != self.block_size {
196 return Err(FsError::BlockSizeMismatch {
197 expected: self.block_size,
198 got: data.len(),
199 });
200 }
201 let offset = block_id * self.block_size as u64;
202 self.file
203 .write_all_at(data, offset)
204 .map_err(|e| FsError::Internal(format!("write block {block_id}: {e}")))?;
205 Ok(())
206 }
207
208 fn sync(&self) -> FsResult<()> {
209 self.file
210 .sync_all()
211 .map_err(|e| FsError::Internal(format!("fsync: {e}")))
212 }
213}
214
/// Block store on a device node (or any file) opened for reading and writing.
///
/// Block `i` lives at byte offset `i * block_size`. All block I/O is
/// positional, so the file cursor left behind by size detection does not matter.
pub struct DeviceBlockStore {
    /// Length in bytes of each block.
    block_size: usize,
    /// Number of addressable blocks on the device.
    total_blocks: u64,
    /// Handle to the opened device.
    file: File,
}
230
231impl DeviceBlockStore {
232 pub fn open(path: &str, block_size: usize, total_blocks: u64) -> FsResult<Self> {
236 let mut file = OpenOptions::new()
237 .read(true)
238 .write(true)
239 .open(path)
240 .map_err(|e| FsError::Internal(format!("open device {path}: {e}")))?;
241
242 let device_size = file
243 .seek(SeekFrom::End(0))
244 .map_err(|e| FsError::Internal(format!("seek device {path}: {e}")))?;
245
246 let total_blocks = if total_blocks == 0 {
247 device_size / block_size as u64
248 } else {
249 total_blocks
250 };
251
252 let required = total_blocks * block_size as u64;
253 if device_size < required {
254 return Err(FsError::Internal(format!(
255 "device too small: {device_size} bytes, need {required}"
256 )));
257 }
258
259 Ok(Self {
260 file,
261 block_size,
262 total_blocks,
263 })
264 }
265
266 pub fn initialize(path: &str, block_size: usize, total_blocks: u64) -> FsResult<Self> {
274 let mut file = OpenOptions::new()
275 .read(true)
276 .write(true)
277 .open(path)
278 .map_err(|e| FsError::Internal(format!("open device {path}: {e}")))?;
279
280 let device_size = file
281 .seek(SeekFrom::End(0))
282 .map_err(|e| FsError::Internal(format!("seek device {path}: {e}")))?;
283
284 let total_blocks = if total_blocks == 0 {
285 device_size / block_size as u64
286 } else {
287 total_blocks
288 };
289
290 let required = total_blocks * block_size as u64;
291 if device_size < required {
292 return Err(FsError::Internal(format!(
293 "device too small: {device_size} bytes, need {required}"
294 )));
295 }
296
297 file.seek(SeekFrom::Start(0))
299 .map_err(|e| FsError::Internal(format!("seek device {path}: {e}")))?;
300
301 let mut rng = rand::thread_rng();
302 let mut buf = vec![0u8; block_size];
303 for _ in 0..total_blocks {
304 rng.fill_bytes(&mut buf);
305 file.write_all(&buf)
306 .map_err(|e| FsError::Internal(format!("write device {path}: {e}")))?;
307 }
308 file.sync_all()
309 .map_err(|e| FsError::Internal(format!("sync device {path}: {e}")))?;
310
311 Ok(Self {
312 file,
313 block_size,
314 total_blocks,
315 })
316 }
317}
318
319impl BlockStore for DeviceBlockStore {
320 fn block_size(&self) -> usize {
321 self.block_size
322 }
323
324 fn total_blocks(&self) -> u64 {
325 self.total_blocks
326 }
327
328 fn read_block(&self, block_id: u64) -> FsResult<Vec<u8>> {
329 if block_id >= self.total_blocks {
330 return Err(FsError::BlockOutOfRange(block_id));
331 }
332 let offset = block_id * self.block_size as u64;
333 let mut buf = vec![0u8; self.block_size];
334 self.file
335 .read_exact_at(&mut buf, offset)
336 .map_err(|e| FsError::Internal(format!("read block {block_id}: {e}")))?;
337 Ok(buf)
338 }
339
340 fn write_block(&self, block_id: u64, data: &[u8]) -> FsResult<()> {
341 if block_id >= self.total_blocks {
342 return Err(FsError::BlockOutOfRange(block_id));
343 }
344 if data.len() != self.block_size {
345 return Err(FsError::BlockSizeMismatch {
346 expected: self.block_size,
347 got: data.len(),
348 });
349 }
350 let offset = block_id * self.block_size as u64;
351 self.file
352 .write_all_at(data, offset)
353 .map_err(|e| FsError::Internal(format!("write block {block_id}: {e}")))?;
354 Ok(())
355 }
356
357 fn sync(&self) -> FsResult<()> {
358 self.file
359 .sync_all()
360 .map_err(|e| FsError::Internal(format!("fsync: {e}")))
361 }
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::{Path, PathBuf};

    /// A temp-dir image path that is removed on drop. Test files are cleaned
    /// up even when an assertion panics mid-test.
    struct TempImage {
        path: PathBuf,
    }

    impl TempImage {
        /// Builds `doublecrypt_test_{tag}{pid}.img` in the temp dir and removes
        /// any stale file an earlier run left at that path.
        fn new(tag: &str) -> Self {
            let path = std::env::temp_dir().join(format!(
                "doublecrypt_test_{tag}{}.img",
                std::process::id()
            ));
            let _ = std::fs::remove_file(&path);
            Self { path }
        }

        fn path(&self) -> &Path {
            &self.path
        }

        fn path_str(&self) -> &str {
            self.path.to_str().expect("temp dir path is valid UTF-8")
        }
    }

    impl Drop for TempImage {
        fn drop(&mut self) {
            // Best effort: some tests never create the file.
            let _ = std::fs::remove_file(&self.path);
        }
    }

    #[test]
    fn test_memory_block_store_roundtrip() {
        let store = MemoryBlockStore::new(64, 10);
        let data = vec![0xAB; 64];
        store.write_block(0, &data).unwrap();
        assert_eq!(store.read_block(0).unwrap(), data);
    }

    #[test]
    fn test_unwritten_block_returns_zeroes() {
        let store = MemoryBlockStore::new(64, 10);
        assert_eq!(store.read_block(5).unwrap(), vec![0u8; 64]);
    }

    #[test]
    fn test_out_of_range_read() {
        let store = MemoryBlockStore::new(64, 10);
        assert!(store.read_block(10).is_err());
    }

    #[test]
    fn test_out_of_range_write() {
        let store = MemoryBlockStore::new(64, 10);
        assert!(store.write_block(10, &[0u8; 64]).is_err());
    }

    #[test]
    fn test_block_size_mismatch() {
        let store = MemoryBlockStore::new(64, 10);
        assert!(store.write_block(0, &[0u8; 32]).is_err());
    }

    #[test]
    fn test_disk_block_store_roundtrip() {
        // Declared first so it is dropped (and the file removed) after `store`.
        let image = TempImage::new("");

        let store = DiskBlockStore::create(image.path_str(), 512, 16).unwrap();
        let data = vec![0xAB; 512];
        store.write_block(0, &data).unwrap();
        store.sync().unwrap();
        assert_eq!(store.read_block(0).unwrap(), data);

        // `create` fills the image with random bytes, so an unwritten block is
        // full-size and, with overwhelming probability, not all zeroes.
        let unwritten = store.read_block(10).unwrap();
        assert_eq!(unwritten.len(), 512);
        assert!(unwritten.iter().any(|&b| b != 0));

        assert!(store.read_block(16).is_err());
        assert!(store.write_block(16, &data).is_err());
        assert!(store.write_block(0, &[0u8; 64]).is_err());
    }

    #[test]
    fn test_disk_block_store_open_existing() {
        let image = TempImage::new("open_");

        {
            let store = DiskBlockStore::create(image.path_str(), 256, 8).unwrap();
            store.write_block(3, &[0xCD; 256]).unwrap();
            store.sync().unwrap();
        }

        {
            let store = DiskBlockStore::open(image.path_str(), 256, 8).unwrap();
            assert_eq!(store.read_block(3).unwrap(), vec![0xCD; 256]);
        }

        {
            // `total_blocks == 0` asks `open` to infer the count from the file length.
            let store = DiskBlockStore::open(image.path_str(), 256, 0).unwrap();
            assert_eq!(store.total_blocks(), 8);
        }
    }

    #[test]
    fn test_disk_block_store_file_too_small() {
        let image = TempImage::new("small_");
        std::fs::write(image.path(), [0u8; 100]).unwrap();
        assert!(DiskBlockStore::open(image.path_str(), 256, 8).is_err());
    }

    #[test]
    fn test_disk_block_store_open_missing_file() {
        let image = TempImage::new("missing_");
        assert!(DiskBlockStore::open(image.path_str(), 256, 8).is_err());
    }

    #[test]
    fn test_disk_block_store_create_refuses_existing_file() {
        let image = TempImage::new("exists_");
        std::fs::write(image.path(), [0u8; 512]).unwrap();
        assert!(DiskBlockStore::create(image.path_str(), 512, 1).is_err());
        // `create_new` must leave the pre-existing contents untouched.
        assert_eq!(std::fs::read(image.path()).unwrap(), vec![0u8; 512]);
    }

    #[test]
    fn test_device_block_store_initialize_and_open() {
        let image = TempImage::new("device_");
        // `DeviceBlockStore` never creates its target, so a pre-sized regular
        // file stands in for a device.
        std::fs::write(image.path(), [0u8; 256 * 8]).unwrap();

        {
            let store = DeviceBlockStore::initialize(image.path_str(), 256, 0).unwrap();
            assert_eq!(store.total_blocks(), 8);
            // `initialize` replaced the zero-filled contents with random bytes.
            assert!(store.read_block(7).unwrap().iter().any(|&b| b != 0));
            store.write_block(3, &[0xEF; 256]).unwrap();
            store.sync().unwrap();
        }

        {
            let store = DeviceBlockStore::open(image.path_str(), 256, 0).unwrap();
            assert_eq!(store.total_blocks(), 8);
            assert_eq!(store.read_block(3).unwrap(), vec![0xEF; 256]);
            assert!(store.read_block(8).is_err());
        }

        assert!(DeviceBlockStore::open(image.path_str(), 256, 16).is_err());
    }

    #[test]
    fn test_device_block_store_too_small() {
        let image = TempImage::new("device_small_");
        std::fs::write(image.path(), [0u8; 100]).unwrap();
        assert!(DeviceBlockStore::open(image.path_str(), 256, 8).is_err());
        assert!(DeviceBlockStore::initialize(image.path_str(), 256, 8).is_err());
        // A rejected `initialize` must not have written anything.
        assert_eq!(std::fs::read(image.path()).unwrap(), vec![0u8; 100]);
    }
}