// cyfs_util/util/reenter_caller.rs

1use std::cmp::Eq;
2use std::collections::{hash_map::Entry, HashMap};
3use std::future::Future;
4use std::hash::Hash;
5use std::sync::{Arc, Mutex};
6
/// Per-key slot shared between the caller that runs a call and the callers
/// waiting on the same key.
///
/// No trait bounds are required here: the slot only stores an `Option<T>`. The
/// bounds that matter live on the `ReenterCallManager` impls that use it.
struct ReenterCaller<T> {
    /// Result of the completed call, cloned by callers that were waiting on
    /// this slot. `None` until the running caller stores a value.
    result: Option<T>,
}

impl<T> ReenterCaller<T> {
    /// Creates an empty slot with no cached result.
    pub fn new() -> Self {
        Self { result: None }
    }
}
22
23#[derive(Clone)]
24pub struct ReenterCallManager<K, T>
25where
26    K: Hash + Eq + ToOwned<Owned = K>,
27    T: Send + 'static + Clone,
28{
29    call_list: Arc<Mutex<HashMap<K, Arc<async_std::sync::Mutex<ReenterCaller<T>>>>>>,
30}
31
32impl<K, T> ReenterCallManager<K, T>
33where
34    K: Hash + Eq + ToOwned<Owned = K>,
35    T: Send + 'static + Clone,
36{
37    pub fn new() -> Self {
38        Self {
39            call_list: Arc::new(Mutex::new(HashMap::new())),
40        }
41    }
42}
43
44impl<K, T> ReenterCallManager<K, T>
45where
46    K: Hash + Eq + ToOwned<Owned = K>,
47    T: Send + 'static + Clone,
48{
49    pub async fn call<F>(&self, key: &K, future: F) -> T
50    where
51        F: Future<Output = T> + Send + 'static,
52    {
53        // debug!("will reenter call: key={}", key);
54
55        let caller = {
56            let mut list = self.call_list.lock().unwrap();
57            match list.entry(key.to_owned()) {
58                Entry::Occupied(o) => o.get().clone(),
59                Entry::Vacant(v) => {
60                    let caller = ReenterCaller::new();
61                    let item = Arc::new(async_std::sync::Mutex::new(caller));
62                    v.insert(item.clone());
63                    item
64                }
65            }
66        };
67
68        // 这里必须使用异步锁,来保证调用中不重入
69        let mut item = caller.lock().await;
70
71        // 第一个进来的result一定为空,需要锁住并执行目标闭包
72        if item.result.is_none() {
73            let ret = future.await;
74
75            // 移除
76            {
77                let mut list = self.call_list.lock().unwrap();
78                list.remove(key);
79            }
80
81            // 如果引用计数>1, 说明有重入的操作在等待,需要缓存闭包的返回值
82            let ref_count = Arc::strong_count(&caller);
83            if ref_count > 1 {
84                assert!(item.result.is_none());
85                item.result = Some(ret.clone());
86            }
87            ret
88        } else {
89            assert!(item.result.is_some());
90            item.result.as_ref().unwrap().clone()
91        }
92    }
93}
94
#[cfg(test)]
mod test {
    use super::*;
    use cyfs_base::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    /// How long one real execution of the underlying call takes.
    const CALL_DURATION: Duration = Duration::from_millis(500);

    /// Delay between successive callers in `test_enter_caller`.
    /// Keeps the original 1:5 ratio of stagger to call duration.
    const STAGGER: Duration = Duration::from_millis(100);

    /// Test helper whose `call` returns a counter that is bumped once per real
    /// execution of the underlying future.
    #[derive(Clone)]
    struct TestReenterCaller {
        caller_manager: ReenterCallManager<String, BuckyResult<u32>>,
        next_value: Arc<AtomicU32>,
    }

    impl TestReenterCaller {
        pub fn new() -> Self {
            Self {
                caller_manager: ReenterCallManager::new(),
                next_value: Arc::new(AtomicU32::new(0)),
            }
        }

        /// Number of times the underlying future has actually run to completion.
        fn exec_count(&self) -> u32 {
            self.next_value.load(Ordering::SeqCst)
        }

        pub async fn call(&self, key: &str) -> BuckyResult<u32> {
            let this = self.clone();
            let owned_key = key.to_owned();
            self.caller_manager
                .call(&key.to_owned(), async move {
                    println!(
                        "will exec call... key={}, next={:?}",
                        owned_key, this.next_value
                    );
                    async_std::task::sleep(CALL_DURATION).await;
                    println!(
                        "end exec call... key={}, next={:?}",
                        owned_key, this.next_value
                    );

                    let v = this.next_value.fetch_add(1, Ordering::SeqCst);
                    Ok(v)
                })
                .await
        }
    }

    #[async_std::test]
    async fn test_enter_caller_once() {
        let tester = TestReenterCaller::new();
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let tester = tester.clone();
                async_std::task::spawn(async move {
                    let ret = tester.call("xxxx").await.unwrap();
                    assert_eq!(ret, 0);
                    println!("caller complete: index={}, ret={}", i, ret);
                })
            })
            .collect();

        // Wait for every task to finish instead of sleeping for a fixed time.
        for handle in handles {
            handle.await;
        }

        // All concurrent callers must have been served by a single execution.
        assert_eq!(tester.exec_count(), 1);
    }

    #[async_std::test]
    async fn test_enter_caller() {
        let tester = TestReenterCaller::new();
        let handles: Vec<_> = (0..100u32)
            .map(|i| {
                let tester = tester.clone();
                async_std::task::spawn(async move {
                    async_std::task::sleep(STAGGER * i).await;
                    let ret = tester.call("xxxx").await.unwrap();
                    println!("caller complete: index={}, ret={}", i, ret);
                    ret
                })
            })
            .collect();

        // Collect every result. A fixed sleep would cut off the last callers,
        // which start near the end of the stagger window.
        let mut results = Vec::with_capacity(handles.len());
        for handle in handles {
            results.push(handle.await);
        }

        let execs = tester.exec_count();
        // Overlapping callers share executions, so there must be fewer runs than callers.
        assert!(execs < 100, "no calls were coalesced: execs={}", execs);
        // Every returned value comes from an execution that actually happened.
        assert!(results.iter().all(|&ret| ret < execs));
    }
}