Skip to main content

harddrive_party/
shares.rs

1//! Index shared directories
2use crate::{
3    subtree_names::{DIRS, FILES, SHARE_NAMES},
4    wire_messages::{Entry, LsResponse},
5};
6use async_walkdir::WalkDir;
7use futures::stream::StreamExt;
8use log::{debug, info, warn};
9use sled::IVec;
10use std::path::{Path, PathBuf, MAIN_SEPARATOR};
11use thiserror::Error;
12
/// Upper bound on how many query results are packed into a single response message
pub const MAX_ENTRIES_PER_MESSAGE: usize = 64;
15
16/// The share index
17#[derive(Clone)]
18pub struct Shares {
19    /// Filepaths mapped to their size in bytes
20    files: sled::Tree,
21    /// Directory paths mapped to their size in bytes
22    dirs: sled::Tree,
23    /// The displayed names of shared directories mapped to their actual path on disk
24    share_names: sled::Tree, // or should this be an in-memory hashmap?
25}
26
27impl Shares {
28    /// Setup share index giving a path to use for persistant storage
29    pub async fn new(db: sled::Db, share_dirs: Vec<String>) -> Result<Self, CreateSharesError> {
30        let files = db.open_tree(FILES)?;
31        let dirs = db.open_tree(DIRS)?;
32        dirs.set_merge_operator(addition_merge);
33        let share_names = db.open_tree(SHARE_NAMES)?;
34
35        let mut shares = Shares {
36            files,
37            dirs,
38            share_names,
39        };
40
41        for share_dir in share_dirs {
42            shares.scan(&share_dir).await?;
43        }
44
45        Ok(shares)
46    }
47
48    /// Index a given directory and return the number of entries added to the database
49    // TODO #16 handle share name collisions
50    pub async fn scan(&mut self, root: &str) -> Result<u32, ScanDirError> {
51        let mut added_entries = 0;
52        let path = PathBuf::from(root);
53        let path_clone = &path.clone();
54
55        // share_name is what we refer to the shared dir by
56        let path_clone_2 = path.clone();
57        let share_name = path_clone_2
58            .file_name()
59            .ok_or(ScanDirError::GetParentError)?
60            .to_str()
61            .ok_or(ScanDirError::OsStringError())?;
62
63        let path_os_str = path.clone().into_os_string();
64        let path_str = path_os_str.to_str().ok_or(ScanDirError::OsStringError())?;
65        self.share_names.insert(share_name, path_str)?;
66
67        // Remove existing entries before beginning
68        if let Err(err) = self.remove_share_dir(share_name) {
69            match err {
70                // Ignore the error if it didn't exist
71                ScanDirError::NoShare => {}
72                _ => return Err(err),
73            }
74        };
75
76        let mut entries = WalkDir::new(path);
77        loop {
78            match entries.next().await {
79                Some(Ok(entry)) => {
80                    let metadata = entry.metadata().await?;
81                    if !metadata.is_dir() {
82                        // Remove the 'path' portion of the entry, and join it with share_name
83                        let ep = entry.path();
84                        let entry_path = ep.strip_prefix(path_clone)?;
85                        let sn = path_clone.file_name().ok_or(ScanDirError::GetParentError)?;
86                        let entry_path_with_share_name = Path::new(sn).join(entry_path);
87                        let filepath = entry_path_with_share_name
88                            .to_str()
89                            .ok_or(ScanDirError::OsStringError())?;
90
91                        let size = metadata.len().to_le_bytes();
92
93                        // For each component of the path, add the size into the directory sizes index
94                        for sub_path in entry_path_with_share_name
95                            .parent()
96                            .ok_or(ScanDirError::GetParentError)?
97                            .ancestors()
98                        {
99                            let sub_path_bytes = sub_path
100                                .to_str()
101                                .ok_or(ScanDirError::OsStringError())?
102                                .as_bytes();
103                            self.dirs.merge(sub_path_bytes, size)?;
104                        }
105                        self.files.insert(filepath.as_bytes(), &size)?;
106                        info!("{:?} {:?}", entry.path(), entry.metadata().await?.is_file());
107                        added_entries += 1;
108                    }
109                }
110                Some(Err(e)) => {
111                    warn!("Error {e}");
112                    return Err(ScanDirError::IOError(e));
113                }
114                None => break,
115            };
116        }
117        Ok(added_entries)
118    }
119
120    /// ls or search query
121    pub fn query(
122        &self,
123        path_option: Option<String>,
124        searchterm: Option<String>,
125        recursive: bool,
126    ) -> Result<Box<dyn Iterator<Item = LsResponse> + Send>, EntryParseError> {
127        let path = path_option.unwrap_or_default();
128
129        // Check that the given subdir / file exists
130        if let Ok(None) = self.dirs.get(&path) {
131            if let Ok(None) = self.files.get(&path) {
132                return Err(EntryParseError::PathNotFound);
133            }
134        }
135
136        let path_len = path.len();
137        let searchterm = searchterm.map(|s| s.to_lowercase());
138        let searchterm_clone = searchterm.clone();
139
140        let dirs_iter = self.dirs.scan_prefix(&path).filter_map(move |kv_result| {
141            kv_filter_map(kv_result, true, recursive, path_len, &searchterm)
142        });
143
144        let files_iter = self.files.scan_prefix(&path).filter_map(move |kv_result| {
145            kv_filter_map(kv_result, false, recursive, path_len, &searchterm_clone)
146        });
147
148        let entries_iter = dirs_iter.chain(files_iter);
149
150        let chunked = Chunker {
151            inner: Box::new(entries_iter),
152            chunk_size: MAX_ENTRIES_PER_MESSAGE,
153        };
154
155        let response_iter = chunked.map(LsResponse::Success);
156
157        Ok(Box::new(response_iter))
158    }
159
160    /// Resolve a path from a request by looking up the absolute path associated with its share name
161    /// component
162    pub fn resolve_path(&self, input_path: String) -> Result<(PathBuf, u64), ResolvePathError> {
163        info!("Resolving path {input_path}");
164
165        let size = match self.files.get(&input_path)? {
166            Some(size_buf) => u64::from_le_bytes(
167                size_buf
168                    .to_vec()
169                    .try_into()
170                    .map_err(|_| ResolvePathError::BadShareName)?,
171            ),
172            None => {
173                return Err(ResolvePathError::FileNotFound);
174            }
175        };
176
177        let input_path_path_buf = PathBuf::from(input_path);
178        let mut input_path_iter = input_path_path_buf.iter();
179        let share_name = input_path_iter
180            .next()
181            .ok_or(ResolvePathError::MissingFirstComponent)?;
182
183        let sub_path: PathBuf = input_path_iter.collect();
184
185        let share_name_bytes = share_name
186            .to_str()
187            .ok_or(ResolvePathError::MissingFirstComponent)?
188            .as_bytes();
189
190        let actual_path_bytes = self
191            .share_names
192            .get(share_name_bytes)?
193            .ok_or(ResolvePathError::BadShareName)?;
194
195        let actual_path = PathBuf::from(std::str::from_utf8(&actual_path_bytes)?);
196        Ok((actual_path.join(sub_path), size))
197    }
198
199    /// Stop sharing a directory by removing related entries from the database
200    pub fn remove_share_dir(&mut self, share_name: &str) -> Result<(), ScanDirError> {
201        // First find the old total size of share dir and subtract it from the "" entry
202        if let Some(existing_size) = self.get_dir_size(share_name) {
203            self.dirs
204                .fetch_and_update("", |root_size_option: Option<&[u8]>| {
205                    let new_size = match root_size_option {
206                        Some(root_size_buf) => match root_size_buf.to_vec().try_into() {
207                            Ok(root_size_arr) => {
208                                let root_size = u64::from_le_bytes(root_size_arr);
209                                root_size - existing_size
210                            }
211                            Err(_) => 0,
212                        },
213                        None => 0,
214                    };
215                    Some(new_size.to_le_bytes().to_vec())
216                })?;
217
218            for (entry, _) in self.dirs.scan_prefix(share_name).flatten() {
219                debug!("Deleting existing entry {entry:?}");
220                self.dirs.remove(entry)?;
221            }
222            for (entry, _) in self.files.scan_prefix(share_name).flatten() {
223                debug!("Deleting existing entry {entry:?}");
224                self.files.remove(entry)?;
225            }
226            Ok(())
227        } else {
228            Err(ScanDirError::NoShare)
229        }
230    }
231
232    pub async fn flush(&self) {
233        let _ = self.files.flush_async().await;
234        let _ = self.dirs.flush_async().await;
235        let _ = self.share_names.flush_async().await;
236    }
237
238    fn get_dir_size(&mut self, dir_name: &str) -> Option<u64> {
239        let existing_ivec = self.dirs.get(dir_name).ok()??;
240        Some(u64::from_le_bytes(existing_ivec.to_vec().try_into().ok()?))
241    }
242}
243
244/// Filter a key/value database entry based on query and if selected convert to a struct
245fn kv_filter_map(
246    kv_result: Result<(IVec, IVec), sled::Error>,
247    is_dir: bool,
248    recursive: bool,
249    path_len: usize,
250    searchterm: &Option<String>,
251) -> Option<Entry> {
252    let (name, size) = kv_result.ok()?;
253    let name = std::str::from_utf8(&name).ok()?;
254
255    if !recursive {
256        // TODO should we use pathbuf for this?
257        let full_suffix = &name[path_len..];
258        let suffix = if full_suffix.starts_with(MAIN_SEPARATOR) {
259            &full_suffix[1..]
260        } else {
261            full_suffix
262        };
263        if suffix.contains(MAIN_SEPARATOR) {
264            return None;
265        }
266    }
267
268    if let Some(search) = searchterm {
269        if !name.to_lowercase().contains(search) {
270            return None;
271        };
272    }
273
274    let size = u64::from_le_bytes(size.to_vec().try_into().ok()?);
275    Some(Entry {
276        name: name.to_string(),
277        size,
278        is_dir,
279    })
280}
281
/// Adapter which groups the items of an iterator into vectors of at most `chunk_size` items
pub struct Chunker<T> {
    /// The iterator whose items are being grouped
    pub inner: Box<dyn Iterator<Item = T> + Send>,
    /// The number of items after which a chunk is considered full
    pub chunk_size: usize,
}

impl<T> Iterator for Chunker<T> {
    type Item = Vec<T>;

    /// Produce the next chunk, or `None` once the inner iterator has nothing left to give.
    /// The final chunk may hold fewer than `chunk_size` items.
    fn next(&mut self) -> Option<Self::Item> {
        let chunk_size = self.chunk_size;
        let mut chunk = Vec::new();

        for item in self.inner.by_ref() {
            chunk.push(item);
            if chunk.len() == chunk_size {
                break;
            }
        }

        if chunk.is_empty() {
            None
        } else {
            Some(chunk)
        }
    }
}
305
/// Merge operator used to build cumulative directory sizes by adding the size of each
/// contained file
///
/// Both the existing value and the merged value are read as little-endian `u64`s. A missing
/// existing value, or a value which is not exactly eight bytes long, counts as zero. The sum
/// saturates at `u64::MAX` rather than overflowing. Always returns `Some`, so the key is
/// never removed by a merge.
fn addition_merge(_key: &[u8], old_value: Option<&[u8]>, merged_bytes: &[u8]) -> Option<Vec<u8>> {
    let old_size = match old_value {
        Some(v) => u64::from_le_bytes(v.try_into().unwrap_or([0; 8])),
        None => 0,
    };
    let to_add = u64::from_le_bytes(merged_bytes.try_into().unwrap_or([0; 8]));
    let new_size = old_size.saturating_add(to_add);
    Some(new_size.to_le_bytes().to_vec())
}
316
317/// Error when creating a Shares struct
318#[derive(Error, Debug)]
319pub enum CreateSharesError {
320    #[error(transparent)]
321    IOError(#[from] sled::Error),
322    #[error(transparent)]
323    ScanDirError(#[from] ScanDirError),
324}
325
326/// Error when indexing a dir
327#[derive(Error, Debug)]
328pub enum ScanDirError {
329    #[error(transparent)]
330    IOError(#[from] std::io::Error),
331    #[error("Cannot parse OsString")]
332    OsStringError(),
333    #[error("Unable to merge db record")]
334    DbMergeError(#[from] sled::Error),
335    #[error("Cannot get parent of given dir")]
336    GetParentError,
337    #[error("Got entry which does not appear to be a child of the given directory")]
338    PrefixError(#[from] std::path::StripPrefixError),
339    #[error("Error converting database value to u64")]
340    U64ConversionError,
341    #[error("Share dir does not exist in DB")]
342    NoShare,
343}
344
345/// Error when parsing a Db entry
346#[derive(Error, Debug)]
347pub enum EntryParseError {
348    #[error("Db error")]
349    DbError(#[from] sled::Error),
350    #[error("Error parsing UTF8")]
351    Utf8Error(#[from] std::str::Utf8Error),
352    #[error("Error converting database value to u64")]
353    U64ConversionError(),
354    #[error("Path not found")]
355    PathNotFound,
356}
357
358/// Error when resolving a path from a request
359#[derive(Error, Debug)]
360pub enum ResolvePathError {
361    #[error("Db error")]
362    DbError(#[from] sled::Error),
363    #[error("Cannot get share name")]
364    MissingFirstComponent,
365    #[error("Cannot find share name in db")]
366    BadShareName,
367    #[error("Error parsing UTF8")]
368    Utf8Error(#[from] std::str::Utf8Error),
369    #[error("File does not exist in db")]
370    FileNotFound,
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Every entry a recursive query over the whole index is expected to return once
    /// `tests/test-data` has been scanned
    fn create_test_entries() -> Vec<Entry> {
        let entry = |name: &str, size: u64, is_dir: bool| Entry {
            name: name.to_string(),
            size,
            is_dir,
        };
        vec![
            entry("", 17, true),
            entry("test-data", 17, true),
            entry("test-data/subdir", 12, true),
            entry("test-data/subdir/subsubdir", 6, true),
            entry("test-data/somefile", 5, false),
            entry("test-data/subdir/anotherfile", 6, false),
            entry("test-data/subdir/subsubdir/yetanotherfile", 6, false),
        ]
    }

    /// Query the whole index recursively and check that it yields exactly the expected
    /// entries, each one exactly once
    fn assert_index_matches_expected(shares: &Shares) {
        let mut test_entries = create_test_entries();
        let responses = shares.query(None, None, true).unwrap();
        for res in responses {
            match res {
                LsResponse::Success(entries) => {
                    for entry in entries {
                        let i = test_entries.iter().position(|e| e == &entry).unwrap();
                        test_entries.remove(i);
                    }
                }
                LsResponse::Err(err) => {
                    panic!("Got error response {:?}", err);
                }
            }
        }
        // Make sure we found every entry
        assert_eq!(test_entries.len(), 0);
    }

    #[tokio::test]
    async fn share_query() {
        let storage = TempDir::new().unwrap();
        let db_dir = storage.path().join("db");
        let db = sled::open(db_dir).expect("open");

        let mut shares = Shares::new(db.clone(), Vec::new()).await.unwrap();
        let added = shares.scan("tests/test-data").await.unwrap();
        assert_eq!(added, 3);
        assert_index_matches_expected(&shares);

        // Try resolving a path name
        let (resolved, _size) = shares
            .resolve_path("test-data/subdir/anotherfile".to_string())
            .unwrap();
        assert_eq!(
            resolved,
            PathBuf::from("tests/test-data/subdir/anotherfile")
        );

        // Repeat the process with a new shares instance using the same db, to simulate
        // restarting the program
        let mut shares_2 = Shares::new(db, Vec::new()).await.unwrap();
        let added = shares_2.scan("tests/test-data").await.unwrap();
        assert_eq!(added, 3);
        assert_index_matches_expected(&shares_2);
    }
}
481}