// cyfs_util/util/reenter_caller.rs
use std::cmp::Eq;
use std::collections::{hash_map::Entry, HashMap};
use std::future::Future;
use std::hash::Hash;
use std::sync::{Arc, Mutex};

/// Per-key slot shared by all callers of one in-flight call.
///
/// The caller that actually executes the future may store a clone of its
/// output here so that callers which waited on the same slot can return it.
/// Trait bounds are placed on the `impl` blocks that need them rather than on
/// the type itself.
struct ReenterCaller<T> {
    /// `None` until a finished value is stored for waiting callers.
    result: Option<T>,
}

14impl<T> ReenterCaller<T>
15where
16 T: Send + 'static + Clone,
17{
18 pub fn new() -> Self {
19 Self { result: None }
20 }
21}
22
23#[derive(Clone)]
24pub struct ReenterCallManager<K, T>
25where
26 K: Hash + Eq + ToOwned<Owned = K>,
27 T: Send + 'static + Clone,
28{
29 call_list: Arc<Mutex<HashMap<K, Arc<async_std::sync::Mutex<ReenterCaller<T>>>>>>,
30}
31
32impl<K, T> ReenterCallManager<K, T>
33where
34 K: Hash + Eq + ToOwned<Owned = K>,
35 T: Send + 'static + Clone,
36{
37 pub fn new() -> Self {
38 Self {
39 call_list: Arc::new(Mutex::new(HashMap::new())),
40 }
41 }
42}
43
44impl<K, T> ReenterCallManager<K, T>
45where
46 K: Hash + Eq + ToOwned<Owned = K>,
47 T: Send + 'static + Clone,
48{
49 pub async fn call<F>(&self, key: &K, future: F) -> T
50 where
51 F: Future<Output = T> + Send + 'static,
52 {
53 let caller = {
56 let mut list = self.call_list.lock().unwrap();
57 match list.entry(key.to_owned()) {
58 Entry::Occupied(o) => o.get().clone(),
59 Entry::Vacant(v) => {
60 let caller = ReenterCaller::new();
61 let item = Arc::new(async_std::sync::Mutex::new(caller));
62 v.insert(item.clone());
63 item
64 }
65 }
66 };
67
68 let mut item = caller.lock().await;
70
71 if item.result.is_none() {
73 let ret = future.await;
74
75 {
77 let mut list = self.call_list.lock().unwrap();
78 list.remove(key);
79 }
80
81 let ref_count = Arc::strong_count(&caller);
83 if ref_count > 1 {
84 assert!(item.result.is_none());
85 item.result = Some(ret.clone());
86 }
87 ret
88 } else {
89 assert!(item.result.is_some());
90 item.result.as_ref().unwrap().clone()
91 }
92 }
93}
94
#[cfg(test)]
mod test {
    use super::*;
    use cyfs_base::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Helper whose `call` runs a slow (5 s) operation through the manager and
    /// returns a counter value that increments once per real execution.
    #[derive(Clone)]
    struct TestReenterCaller {
        caller_manager: ReenterCallManager<String, BuckyResult<u32>>,
        next_value: Arc<AtomicU32>,
    }

    impl TestReenterCaller {
        pub fn new() -> Self {
            Self {
                caller_manager: ReenterCallManager::new(),
                next_value: Arc::new(AtomicU32::new(0)),
            }
        }

        pub async fn call(&self, key: &str) -> BuckyResult<u32> {
            let this = self.clone();
            let owned_key = key.to_owned();
            self.caller_manager
                .call(&key.to_owned(), async move {
                    println!(
                        "will exec call... key={}, next={:?}",
                        owned_key, this.next_value
                    );
                    async_std::task::sleep(std::time::Duration::from_secs(5)).await;
                    println!(
                        "end exec call... key={}, next={:?}",
                        owned_key, this.next_value
                    );

                    let v = this.next_value.fetch_add(1, Ordering::SeqCst);
                    Ok(v)
                })
                .await
        }
    }

    /// Ten callers spawned together on one key must all observe the single
    /// execution's result (0).
    #[async_std::test]
    async fn test_enter_caller_once() {
        let tester = TestReenterCaller::new();
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let tester = tester.clone();
                async_std::task::spawn(async move {
                    let ret = tester.call("xxxx").await.unwrap();
                    assert_eq!(ret, 0);
                    println!("caller complete: index={}, ret={}", i, ret);
                })
            })
            .collect();

        // Wait for every caller instead of sleeping for a fixed period.
        for handle in handles {
            handle.await;
        }

        // Checked on the test task itself: the future ran exactly once.
        assert_eq!(tester.next_value.load(Ordering::SeqCst), 1);
    }

    /// Callers of one key arriving a second apart: those arriving while an
    /// execution is in flight share its result, later ones start a new one.
    /// Only completion is checked because the grouping depends on timing.
    #[async_std::test]
    async fn test_enter_caller() {
        let tester = TestReenterCaller::new();
        let handles: Vec<_> = (0..100u64)
            .map(|i| {
                let tester = tester.clone();
                async_std::task::spawn(async move {
                    async_std::task::sleep(std::time::Duration::from_secs(i)).await;
                    let ret = tester.call("xxxx").await.unwrap();
                    println!("caller complete: index={}, ret={}", i, ret);
                })
            })
            .collect();

        // Await all callers; a fixed 100 s sleep ended while the last
        // staggered callers were still running.
        for handle in handles {
            handle.await;
        }
    }
}
163}