hook/
lib.rs

use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, PoisonError, RwLock};

use once_cell::sync::Lazy;
6
/// Type-erased filter callback: receives the boxed current value and returns the boxed
/// replacement. `add_filter` wraps each typed `Fn(T) -> T` in one of these.
type FilterCallback = Box<dyn Fn(Box<dyn Any>) -> Box<dyn Any> + Sync + Send>;

/// A single filter registered on a hook.
struct Filter {
    /// Unique ID handed back by `add_filter`; `remove_filter` matches on it.
    id: u64,
    /// `TypeId` of the value type `T` the filter was registered for; `apply_filters`
    /// compares it against the type being filtered before invoking `callback`.
    type_id: TypeId,
    /// Ordering key within a hook; lower values run first.
    priority: i32,
    /// The type-erased callback.
    callback: FilterCallback,
}
15
16static FILTERS: Lazy<RwLock<HashMap<String, Vec<Filter>>>> = Lazy::new(|| {
17    RwLock::new(HashMap::new())
18});
19
20static FILTER_ID_COUNTER: AtomicU64 = AtomicU64::new(1);
21
22/// Registers a filter callback for the given hook name.
23/// Returns an ID that can be used to remove the filter.
24pub fn add_filter<T: 'static + Send + Sync>(
25    hook: &str,
26    priority: i32,
27    callback: impl Fn(T) -> T + 'static + Send + Sync,
28) -> u64 {
29    let id = FILTER_ID_COUNTER.fetch_add(1, Ordering::SeqCst);
30    let filter = Filter {
31        id,
32        priority,
33        callback: Box::new(move |value: Box<dyn Any>| {
34            let value = *value.downcast::<T>().expect("Type mismatch in filter");
35            let new_value = callback(value);
36            Box::new(new_value)
37        }),
38        type_id: TypeId::of::<T>(),
39    };
40
41    let mut filters = FILTERS.write().unwrap();
42    let entry = filters.entry(hook.to_string()).or_insert_with(Vec::new);
43    entry.push(filter);
44    entry.sort_by_key(|f| f.priority);
45
46    id
47}
48
49/// Applies all filter callbacks registered for the given hook to `value`.
50pub fn apply_filters<T: 'static + Send + Sync>(hook: &str, value: T) -> T {
51    let filters = FILTERS.read().unwrap();
52    let filter_list = match filters.get(hook) {
53        Some(list) => list,
54        None => return value,
55    };
56
57    let mut result: Box<dyn Any> = Box::new(value);
58    for filter in filter_list {
59        if filter.type_id == TypeId::of::<T>() {
60            result = (filter.callback)(result);
61        } else {
62            panic!("Type mismatch for filter hook '{}'", hook);
63        }
64    }
65
66    *result.downcast::<T>().expect("Type mismatch in final value")
67}
68
69/// Removes the filter with the specified ID from the given hook.
70/// Returns `true` if a filter was removed.
71pub fn remove_filter(hook: &str, id: u64) -> bool {
72    let mut filters = FILTERS.write().unwrap();
73    if let Some(list) = filters.get_mut(hook) {
74        let orig_len = list.len();
75        list.retain(|f| f.id != id);
76
77        return list.len() != orig_len;
78    }
79
80    false
81}
82
83/// Removes all filters for the given hook.
84pub fn remove_all_filters(hook: &str) {
85    let mut filters = FILTERS.write().unwrap();
86    filters.remove(hook);
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    // The registry is global and tests run in parallel, so every test uses its own hook name.

    #[test]
    fn test_filters_i32() {
        let hook = "modify_int";
        add_filter(hook, 10, |v: i32| v + 1);
        add_filter(hook, 20, |v: i32| v * 3);
        let result = apply_filters(hook, 4);
        assert_eq!(result, 15);
    }

    #[test]
    fn test_filters_string() {
        let hook = "modify_string";
        add_filter(hook, 10, |s: String| format!("Hello, {}", s));
        add_filter(hook, 20, |s: String| s.to_uppercase());
        let result = apply_filters(hook, "world".to_string());
        assert_eq!(result, "HELLO, WORLD");
    }

    #[test]
    fn test_priority_beats_registration_order() {
        let hook = "priority_order";
        add_filter(hook, 20, |v: i32| v * 3);
        add_filter(hook, 10, |v: i32| v + 1);
        assert_eq!(apply_filters(hook, 4), 15);
    }

    #[test]
    fn test_equal_priority_runs_in_registration_order() {
        let hook = "equal_priority";
        add_filter(hook, 5, |s: String| s + "a");
        add_filter(hook, 5, |s: String| s + "b");
        assert_eq!(apply_filters(hook, String::new()), "ab");
    }

    #[test]
    fn test_unregistered_hook_returns_input() {
        assert_eq!(apply_filters("no_such_hook", 42u8), 42);
    }

    #[test]
    fn test_remove_filter() {
        let hook = "remove_one";
        let keep = add_filter(hook, 0, |v: i32| v + 1);
        let removed = add_filter(hook, 1, |v: i32| v * 10);
        assert!(remove_filter(hook, removed));
        assert!(!remove_filter(hook, removed));
        assert!(!remove_filter("remove_missing_hook", keep));
        assert_eq!(apply_filters(hook, 1), 2);
    }

    #[test]
    fn test_remove_all_filters() {
        let hook = "remove_all";
        add_filter(hook, 0, |v: i32| v + 1);
        add_filter(hook, 1, |v: i32| v + 1);
        remove_all_filters(hook);
        assert_eq!(apply_filters(hook, 1), 1);
    }

    #[test]
    #[should_panic(expected = "Type mismatch for filter hook 'mismatch'")]
    fn test_type_mismatch_panics() {
        add_filter("mismatch", 0, |v: i32| v);
        apply_filters("mismatch", 1u64);
    }
}