use alloc::boxed::Box;
use alloc::collections::BTreeMap;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use lazy_static::lazy_static;
use spin::Mutex;
use super::types::{EventMask, EventType, FsEvent, NotifyError, WatchDescriptor, WatchOptions};
/// Callback run for each event delivered to a [`Watch`].
pub type WatchCallback = Box<dyn Fn(&FsEvent) + Send + Sync>;

/// A single registered watch: which events to observe and what to run when
/// one is delivered.
pub struct Watch {
    /// Identifier handed back to the caller at registration time.
    pub descriptor: WatchDescriptor,
    /// Whether the watch can still fire; cleared by `deactivate` or after a
    /// one-shot delivery in `invoke`.
    pub active: bool,
    /// Dataset the watch is bound to; events from any other dataset never match.
    pub dataset: String,
    /// Watched path inside `dataset`.
    pub path: String,
    /// Event types this watch accepts.
    pub mask: EventMask,
    /// Behaviour flags consulted by `matches` / `invoke`.
    pub options: WatchOptions,
    /// Number of events delivered through `invoke` so far.
    pub event_count: u64,
    /// Optional user callback run on every delivery.
    callback: Option<WatchCallback>,
}
impl Watch {
    /// Creates an active watch with no callback and an event count of zero.
    ///
    /// `dataset` and `path` are copied into owned `String`s.
    pub fn new(
        descriptor: WatchDescriptor,
        dataset: &str,
        path: &str,
        mask: EventMask,
        options: WatchOptions,
    ) -> Self {
        Self {
            descriptor,
            dataset: dataset.into(),
            path: path.into(),
            mask,
            options,
            callback: None,
            event_count: 0,
            active: true,
        }
    }

    /// Attaches `callback` (replacing any existing one) and returns the watch,
    /// builder style.
    pub fn with_callback(mut self, callback: WatchCallback) -> Self {
        self.callback = Some(callback);
        self
    }

    /// Returns `true` if `event` should be delivered to this watch.
    ///
    /// An event matches when the watch is active, the datasets are equal, the
    /// event type is in `mask`, and the event path is in scope:
    ///
    /// * the watched path itself, unless `options.exclude_self` is set
    ///   (honoured for both recursive and non-recursive watches);
    /// * non-recursive: a direct child of the watched path;
    /// * recursive: any descendant of the watched path.
    ///
    /// Paths are compared on `/` component boundaries, so a watch on `/a/b`
    /// does not match the sibling `/a/bc`. Trailing `/` characters are
    /// ignored, which also lets a watch on `/` cover the root.
    pub fn matches(&self, event: &FsEvent) -> bool {
        if !self.active {
            return false;
        }
        if self.dataset != event.dataset {
            return false;
        }
        if !self.mask.contains(event.event_type) {
            return false;
        }

        // Normalise away trailing separators: "/a/b/" behaves like "/a/b",
        // and a watch on "/" becomes "" (the prefix of every absolute path).
        let base = self.path.trim_end_matches('/');
        let target = event.path.trim_end_matches('/');

        if target == base {
            return !self.options.exclude_self;
        }

        // The event must lie strictly below the watched path and the remainder
        // must start at a component boundary ("/a/b" must not match "/a/bc").
        match target
            .strip_prefix(base)
            .and_then(|rest| rest.strip_prefix('/'))
        {
            // Recursive watches accept any depth; others only direct children.
            Some(relative) => self.options.recursive || !relative.contains('/'),
            None => false,
        }
    }

    /// Delivers `event` to this watch. It runs the callback (if any) and
    /// increments `event_count`. A one-shot watch is then deactivated.
    ///
    /// Does not consult [`Watch::matches`]; callers are expected to filter first.
    pub fn invoke(&mut self, event: &FsEvent) {
        if let Some(callback) = &self.callback {
            callback(event);
        }
        self.event_count += 1;
        if self.options.oneshot {
            self.active = false;
        }
    }

    /// Marks the watch inactive so that [`Watch::matches`] always returns `false`.
    pub fn deactivate(&mut self) {
        self.active = false;
    }
}
lazy_static! {
    /// Process-wide registry backing the free functions in this module
    /// (`add_watch`, `remove_watch`, `get_watch_info`, `dispatch_to_watches`, ...).
    ///
    /// Guarded by a `spin::Mutex`, which is not reentrant: code that already
    /// holds this lock must not call another function that takes it.
    static ref WATCHES: Mutex<WatchRegistry> = Mutex::new(WatchRegistry::new());
}
/// Owns every registered [`Watch`], keyed by the numeric id of its descriptor.
pub struct WatchRegistry {
    /// Watches keyed by `WatchDescriptor::id()`; the `BTreeMap` keeps
    /// iteration in ascending id (registration) order.
    watches: BTreeMap<u64, Watch>,
    /// Id given to the next registered watch; starts at 1 and only increases,
    /// so ids are never reused.
    next_id: u64,
    /// Upper bound on the number of stored watches (active or not) enforced by
    /// [`WatchRegistry::add`].
    max_watches: usize,
}

impl WatchRegistry {
    /// Watch limit used by [`WatchRegistry::new`].
    pub const DEFAULT_MAX_WATCHES: usize = 65536;

    /// Creates an empty registry limited to [`Self::DEFAULT_MAX_WATCHES`] watches.
    pub fn new() -> Self {
        Self {
            watches: BTreeMap::new(),
            next_id: 1,
            max_watches: Self::DEFAULT_MAX_WATCHES,
        }
    }

    /// Sets the maximum number of stored watches.
    ///
    /// Lowering the limit below the current count evicts nothing; it only
    /// makes [`WatchRegistry::add`] fail until enough watches are removed.
    pub fn set_max_watches(&mut self, max: usize) {
        self.max_watches = max;
    }

    /// Registers a new watch and returns its freshly allocated descriptor.
    ///
    /// # Errors
    /// [`NotifyError::TooManyWatches`] if the registry already holds
    /// `max_watches` entries (inactive, not-yet-cleaned watches included).
    pub fn add(
        &mut self,
        dataset: &str,
        path: &str,
        mask: EventMask,
        options: WatchOptions,
        callback: Option<WatchCallback>,
    ) -> Result<WatchDescriptor, NotifyError> {
        if self.watches.len() >= self.max_watches {
            return Err(NotifyError::TooManyWatches);
        }
        let id = self.next_id;
        self.next_id += 1;
        let descriptor = WatchDescriptor::new(id);
        let mut watch = Watch::new(descriptor, dataset, path, mask, options);
        if let Some(cb) = callback {
            watch = watch.with_callback(cb);
        }
        self.watches.insert(id, watch);
        Ok(descriptor)
    }

    /// Removes the watch identified by `descriptor`.
    ///
    /// # Errors
    /// [`NotifyError::WatchNotFound`] carrying the id if no such watch exists.
    pub fn remove(&mut self, descriptor: WatchDescriptor) -> Result<(), NotifyError> {
        let id = descriptor.id();
        match self.watches.remove(&id) {
            Some(_) => Ok(()),
            None => Err(NotifyError::WatchNotFound(id)),
        }
    }

    /// Borrows the watch identified by `descriptor`, if registered.
    pub fn get(&self, descriptor: WatchDescriptor) -> Option<&Watch> {
        self.watches.get(&descriptor.id())
    }

    /// Mutably borrows the watch identified by `descriptor`, if registered.
    pub fn get_mut(&mut self, descriptor: WatchDescriptor) -> Option<&mut Watch> {
        self.watches.get_mut(&descriptor.id())
    }

    /// Descriptors of all watches whose [`Watch::matches`] accepts `event`,
    /// in ascending id order.
    pub fn matching_watches(&self, event: &FsEvent) -> Vec<WatchDescriptor> {
        self.watches
            .values()
            .filter(|w| w.matches(event))
            .map(|w| w.descriptor)
            .collect()
    }

    /// Number of stored watches, including inactive ones not yet cleaned up.
    pub fn count(&self) -> usize {
        self.watches.len()
    }

    /// Number of stored watches whose `active` flag is set.
    pub fn active_count(&self) -> usize {
        self.watches.values().filter(|w| w.active).count()
    }

    /// Descriptors of all watches bound to `dataset`, in ascending id order.
    pub fn list_by_dataset(&self, dataset: &str) -> Vec<WatchDescriptor> {
        self.watches
            .values()
            .filter(|w| w.dataset == dataset)
            .map(|w| w.descriptor)
            .collect()
    }

    /// Removes every watch bound to `dataset` and returns how many were removed.
    pub fn remove_by_dataset(&mut self, dataset: &str) -> usize {
        let before = self.watches.len();
        self.watches.retain(|_, w| w.dataset != dataset);
        before - self.watches.len()
    }

    /// Removes every inactive watch and returns how many were removed.
    pub fn cleanup_inactive(&mut self) -> usize {
        let before = self.watches.len();
        self.watches.retain(|_, w| w.active);
        before - self.watches.len()
    }

    /// Removes all watches. The id counter is not reset, so ids stay unique.
    pub fn clear(&mut self) {
        self.watches.clear();
    }
}

impl Default for WatchRegistry {
    /// Same as [`WatchRegistry::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Registers a watch in the global registry and returns its descriptor.
///
/// # Errors
/// Forwards errors from `WatchRegistry::add`, e.g. `NotifyError::TooManyWatches`
/// when the registry is at its limit.
pub fn add_watch(
    dataset: &str,
    path: &str,
    mask: EventMask,
    options: WatchOptions,
    callback: Option<WatchCallback>,
) -> Result<WatchDescriptor, NotifyError> {
    WATCHES
        .lock()
        .add(dataset, path, mask, options, callback)
}
/// Registers a watch with `WatchOptions::default()` on `path` for the listed
/// `events`. Each matching event is passed to `callback`.
///
/// # Errors
/// Propagates any error from [`add_watch`].
pub fn watch(
    dataset: &str,
    path: &str,
    events: &[EventType],
    callback: WatchCallback,
) -> Result<WatchDescriptor, NotifyError> {
    add_watch(
        dataset,
        path,
        EventMask::from_events(events),
        WatchOptions::default(),
        Some(callback),
    )
}
/// Registers a watch with `WatchOptions::recursive()` on `path` for the listed
/// `events`. Each matching event is passed to `callback`.
///
/// # Errors
/// Propagates any error from [`add_watch`].
pub fn watch_recursive(
    dataset: &str,
    path: &str,
    events: &[EventType],
    callback: WatchCallback,
) -> Result<WatchDescriptor, NotifyError> {
    add_watch(
        dataset,
        path,
        EventMask::from_events(events),
        WatchOptions::recursive(),
        Some(callback),
    )
}
/// Unregisters the watch identified by `descriptor` from the global registry.
///
/// # Errors
/// `NotifyError::WatchNotFound` if no such watch is registered.
pub fn remove_watch(descriptor: WatchDescriptor) -> Result<(), NotifyError> {
    WATCHES.lock().remove(descriptor)
}
/// Returns an owned snapshot of the watch identified by `descriptor`, or
/// `None` if it is not registered. The global lock is released on return.
pub fn get_watch_info(descriptor: WatchDescriptor) -> Option<WatchInfo> {
    let guard = WATCHES.lock();
    let found = guard.get(descriptor)?;
    Some(WatchInfo {
        descriptor: found.descriptor,
        dataset: found.dataset.clone(),
        path: found.path.clone(),
        mask: found.mask,
        recursive: found.options.recursive,
        active: found.active,
        event_count: found.event_count,
    })
}
/// Number of watches in the global registry, including inactive ones that
/// have not been cleaned up yet.
pub fn watch_count() -> usize {
    WATCHES.lock().count()
}
/// Owned snapshot of a watch's state, as returned by `get_watch_info`.
///
/// All fields are copies, so the snapshot can be inspected without holding
/// the global registry lock.
#[derive(Debug, Clone)]
pub struct WatchInfo {
    /// Descriptor identifying the watch.
    pub descriptor: WatchDescriptor,
    /// Dataset the watch is bound to.
    pub dataset: String,
    /// Watched path within the dataset.
    pub path: String,
    /// Event types the watch accepts.
    pub mask: EventMask,
    /// Copy of the watch's `options.recursive` flag.
    pub recursive: bool,
    /// Whether the watch could still fire when the snapshot was taken.
    pub active: bool,
    /// Events delivered to the watch when the snapshot was taken.
    pub event_count: u64,
}
/// Delivers `event` to every registered watch that matches it. Afterwards it
/// drops every watch that is no longer active, e.g. a one-shot watch that
/// just fired.
///
/// Callbacks run while the global `WATCHES` lock is held. `spin::Mutex` is
/// not reentrant, so a callback that calls back into this module's
/// lock-taking functions (`add_watch`, `remove_watch`, ...) will spin forever.
///
/// The cleanup step also removes watches deactivated by other means, not only
/// the ones that fired here.
pub(crate) fn dispatch_to_watches(event: &FsEvent) {
    let mut registry = WATCHES.lock();
    // `invoke` only mutates the watch it is called on, so filtering lazily
    // while iterating is equivalent to snapshotting the matches first.
    for watch in registry.watches.values_mut() {
        if watch.matches(event) {
            watch.invoke(event);
        }
    }
    registry.cleanup_inactive();
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::sync::Arc;
    use core::sync::atomic::{AtomicU64, Ordering};

    /// Builds a standalone watch with descriptor 1.
    fn make_watch(dataset: &str, path: &str, mask: EventMask, options: WatchOptions) -> Watch {
        Watch::new(WatchDescriptor::new(1), dataset, path, mask, options)
    }

    /// Registers an all-events, default-options watch without a callback.
    fn add_default(
        registry: &mut WatchRegistry,
        dataset: &str,
        path: &str,
    ) -> Result<WatchDescriptor, NotifyError> {
        registry.add(dataset, path, EventMask::ALL, WatchOptions::default(), None)
    }

    #[test]
    fn test_watch_matches_exact() {
        let watch = make_watch("tank/data", "/path/to", EventMask::ALL, WatchOptions::default());
        let event = FsEvent::new(EventType::Create, "tank/data", "/path/to/file.txt");
        assert!(watch.matches(&event));
        let event = FsEvent::new(EventType::Create, "tank/data", "/path/to");
        assert!(watch.matches(&event));
        let event = FsEvent::new(EventType::Create, "other/pool", "/path/to/file.txt");
        assert!(!watch.matches(&event));
        let event = FsEvent::new(EventType::Create, "tank/data", "/path/to/sub/file.txt");
        assert!(!watch.matches(&event));
    }

    #[test]
    fn test_watch_matches_recursive() {
        let watch = make_watch("tank/data", "/path/to", EventMask::ALL, WatchOptions::recursive());
        let event = FsEvent::new(EventType::Create, "tank/data", "/path/to/sub/deep/file.txt");
        assert!(watch.matches(&event));
    }

    #[test]
    fn test_watch_mask_filter() {
        let watch = make_watch(
            "tank",
            "/path",
            EventMask::from_events(&[EventType::Create, EventType::Delete]),
            WatchOptions::default(),
        );
        let create = FsEvent::new(EventType::Create, "tank", "/path/file.txt");
        assert!(watch.matches(&create));
        let modify = FsEvent::new(EventType::Modify, "tank", "/path/file.txt");
        assert!(!watch.matches(&modify));
    }

    #[test]
    fn test_watch_oneshot() {
        let mut watch = make_watch("tank", "/path", EventMask::ALL, WatchOptions::oneshot());
        let event = FsEvent::new(EventType::Create, "tank", "/path/file.txt");
        assert!(watch.matches(&event));
        watch.invoke(&event);
        assert!(!watch.active);
        assert!(!watch.matches(&event));
    }

    #[test]
    fn test_registry_add_remove() {
        let mut registry = WatchRegistry::new();
        let wd = add_default(&mut registry, "tank", "/path").unwrap();
        assert_eq!(registry.count(), 1);
        registry.remove(wd).unwrap();
        assert_eq!(registry.count(), 0);
    }

    #[test]
    fn test_registry_remove_missing() {
        let mut registry = WatchRegistry::new();
        let wd = add_default(&mut registry, "tank", "/path").unwrap();
        registry.remove(wd).unwrap();
        assert!(matches!(
            registry.remove(wd),
            Err(NotifyError::WatchNotFound(_))
        ));
    }

    #[test]
    fn test_registry_max_watches() {
        let mut registry = WatchRegistry::new();
        registry.set_max_watches(2);
        add_default(&mut registry, "tank", "/path1").unwrap();
        add_default(&mut registry, "tank", "/path2").unwrap();
        let result = add_default(&mut registry, "tank", "/path3");
        assert!(matches!(result, Err(NotifyError::TooManyWatches)));
    }

    #[test]
    fn test_registry_matching_watches() {
        let mut registry = WatchRegistry::new();
        let create_only = EventMask::from_events(&[EventType::Create]);
        let wd1 = registry
            .add("tank", "/path1", create_only, WatchOptions::default(), None)
            .unwrap();
        let wd2 = registry
            .add(
                "tank",
                "/path2",
                EventMask::from_events(&[EventType::Create]),
                WatchOptions::default(),
                None,
            )
            .unwrap();
        let event = FsEvent::new(EventType::Create, "tank", "/path1/file.txt");
        let matches = registry.matching_watches(&event);
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0], wd1);
        assert!(!matches.contains(&wd2));
    }

    #[test]
    fn test_registry_remove_by_dataset() {
        let mut registry = WatchRegistry::new();
        add_default(&mut registry, "tank", "/path1").unwrap();
        add_default(&mut registry, "tank", "/path2").unwrap();
        add_default(&mut registry, "other", "/path3").unwrap();
        let removed = registry.remove_by_dataset("tank");
        assert_eq!(removed, 2);
        assert_eq!(registry.count(), 1);
    }

    #[test]
    fn test_registry_cleanup_inactive() {
        let mut registry = WatchRegistry::new();
        let keep = add_default(&mut registry, "tank", "/keep").unwrap();
        let dropped = add_default(&mut registry, "tank", "/drop").unwrap();
        registry.get_mut(dropped).unwrap().deactivate();
        assert_eq!(registry.active_count(), 1);
        assert_eq!(registry.cleanup_inactive(), 1);
        assert_eq!(registry.count(), 1);
        assert!(registry.get(keep).is_some());
        assert!(registry.get(dropped).is_none());
    }

    #[test]
    fn test_watch_with_callback() {
        let counter = Arc::new(AtomicU64::new(0));
        let counter_clone = counter.clone();
        let callback: WatchCallback = Box::new(move |_event| {
            counter_clone.fetch_add(1, Ordering::SeqCst);
        });
        let mut watch = make_watch("tank", "/path", EventMask::ALL, WatchOptions::default())
            .with_callback(callback);
        let event = FsEvent::new(EventType::Create, "tank", "/path/file.txt");
        watch.invoke(&event);
        watch.invoke(&event);
        assert_eq!(counter.load(Ordering::SeqCst), 2);
        assert_eq!(watch.event_count, 2);
    }
}