1use std::{
2 any::TypeId,
3 collections::HashMap,
4 marker::PhantomData,
5 sync::{atomic::AtomicU64, Arc, RwLock, Weak},
6};
7
8use futures::future::{BoxFuture, LocalBoxFuture};
9
10use crate::{
11 client::{
12 AsMistyClientHandle, AsReadonlyMistyClientHandle, MistyClientAccessor, MistyClientHandle,
13 MistyClientInner, MistyReadonlyClientHandle,
14 },
15 utils::PhantomUnsync,
16};
17
18pub trait IAsyncTaskRuntimeAdapter {
19 fn spawn(&self, future: BoxFuture<'static, ()>) -> u64;
20 fn spawn_local(&self, future: LocalBoxFuture<'static, ()>) -> u64;
21 fn try_abort(&self, task_id: u64);
22}
23
24pub struct MistyAsyncTaskContext {
25 pub(crate) inner: Weak<MistyClientInner>,
26}
27
28pub struct MistyClientAsyncHandleGuard {
29 inner: Option<Arc<MistyClientInner>>,
30 _unsync_marker: PhantomUnsync,
31}
32
/// Bookkeeping record for one spawned task.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct MistyAsyncTask {
    /// Id allocated by this module; also the key under which the task is stored.
    id: u64,
    /// Id returned by the runtime adapter when the task was spawned; passed
    /// back to the adapter's `try_abort` on cancellation.
    host_task_id: u64,
}
38
/// Allocates the next task id.
///
/// Ids come from a single function-local atomic counter, so they are shared
/// by every pool in the process. The first id handed out is `1` and each call
/// returns the previous value plus one.
fn alloc_task_id() -> u64 {
    static ALLOCATED: AtomicU64 = AtomicU64::new(1);
    // `fetch_add` returns the value *before* the increment, which is the id
    // being handed out.
    ALLOCATED.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
}
44
45#[derive(Debug, Default)]
46struct InternalMistyAsyncTaskPool {
47 async_tasks: HashMap<u64, MistyAsyncTask>,
48}
49
50#[derive(Debug, Clone, Default)]
51struct BoxedMistyAsyncTaskPool {
52 pool: Arc<RwLock<InternalMistyAsyncTaskPool>>,
53}
54
55#[derive(Debug)]
56struct MistyAsyncTaskPool<T> {
57 pool: Arc<RwLock<InternalMistyAsyncTaskPool>>,
58 _marker: PhantomData<T>,
59}
60
61type InternalMistyAsyncTaskPools = HashMap<TypeId, BoxedMistyAsyncTaskPool>;
62
63#[derive(Debug)]
64pub(crate) struct MistyAsyncTaskPools {
65 pools: Arc<RwLock<InternalMistyAsyncTaskPools>>,
66}
67
68struct MistyAsyncTaskPoolSpawnCleanupGuard<T>
69where
70 T: MistyAsyncTaskTrait,
71{
72 task_id: u64,
73 marker: PhantomData<T>,
74 pools: Weak<RwLock<InternalMistyAsyncTaskPools>>,
75}
76
77impl<T> Drop for MistyAsyncTaskPoolSpawnCleanupGuard<T>
78where
79 T: MistyAsyncTaskTrait,
80{
81 fn drop(&mut self) {
82 if let Some(pools) = self.pools.upgrade() {
83 let tid = std::any::TypeId::of::<T>();
84 let pool = pools.write().unwrap();
85 let pool = pool.get(&tid);
86 if let Some(pool) = pool {
87 let mut pool = pool.pool.write().unwrap();
88 pool.async_tasks.remove(&self.task_id);
89 }
90 }
91 }
92}
93
94impl MistyAsyncTaskPools {
95 pub fn new() -> Self {
96 Self {
97 pools: Default::default(),
98 }
99 }
100
101 fn get<T: 'static>(&self) -> MistyAsyncTaskPool<T> {
102 let pool = {
103 let mut pools = self.pools.write().unwrap();
104 let pool = pools
105 .entry(std::any::TypeId::of::<T>())
106 .or_default()
107 .clone();
108 pool
109 };
110
111 MistyAsyncTaskPool {
112 pool: pool.pool.clone(),
113 _marker: Default::default(),
114 }
115 }
116
117 pub(crate) fn reset(&self, rt: &dyn IAsyncTaskRuntimeAdapter) {
118 let mut pools = self.pools.write().unwrap();
119
120 for (_, pool) in pools.iter() {
121 let mut pool = pool.pool.write().unwrap();
122 for (_, task) in pool.async_tasks.iter() {
123 rt.try_abort(task.host_task_id);
124 }
125 pool.async_tasks.clear();
126 }
127 pools.clear();
128 }
129}
130
131impl<T> MistyAsyncTaskPool<T>
132where
133 T: MistyAsyncTaskTrait,
134{
135 pub fn spawn<R, E>(
136 &self,
137 handle: MistyReadonlyClientHandle,
138 future_fn: impl (FnOnce(MistyAsyncTaskContext) -> R) + Send + 'static,
139 ) where
140 R: std::future::Future<Output = Result<(), E>> + Send + 'static,
141 E: std::fmt::Display,
142 {
143 let inner = handle.inner.clone();
144 let cloned_inner = inner.clone();
145 let task_id = alloc_task_id();
146
147 let host_task_id = inner.async_task_runtime.spawn(Box::pin(async move {
148 let inner = cloned_inner;
149 let _guard = MistyAsyncTaskPoolSpawnCleanupGuard::<T> {
150 task_id,
151 marker: Default::default(),
152 pools: Arc::downgrade(&inner.async_task_pools.pools),
153 };
154
155 let ctx = MistyAsyncTaskContext::new(Arc::downgrade(&inner));
156 let res = future_fn(ctx).await;
157 if res.is_err() {
158 let e = res.unwrap_err();
159 tracing::error!("spawn error: {}", e);
160 }
161 }));
162
163 let task = MistyAsyncTask {
164 id: task_id,
165 host_task_id,
166 };
167 {
168 let mut pool = self.pool.write().unwrap();
169 pool.async_tasks.insert(task_id, task);
170 }
171 }
172
173 pub fn spawn_local<R, E>(
174 &self,
175 handle: MistyReadonlyClientHandle,
176 future_fn: impl (FnOnce(MistyAsyncTaskContext) -> R) + 'static,
177 ) where
178 R: std::future::Future<Output = Result<(), E>> + 'static,
179 E: std::fmt::Display,
180 {
181 let inner = handle.inner.clone();
182 let cloned_inner = inner.clone();
183 let task_id = alloc_task_id();
184
185 let host_task_id = inner.async_task_runtime.spawn_local(Box::pin(async move {
186 let inner = cloned_inner;
187 let _guard = MistyAsyncTaskPoolSpawnCleanupGuard::<T> {
188 task_id,
189 marker: Default::default(),
190 pools: Arc::downgrade(&inner.async_task_pools.pools),
191 };
192
193 let ctx = MistyAsyncTaskContext::new(Arc::downgrade(&inner));
194 let res = future_fn(ctx).await;
195 if res.is_err() {
196 let e = res.unwrap_err();
197 tracing::error!("spawn error: {}", e);
198 }
199 }));
200
201 let task = MistyAsyncTask {
202 id: task_id,
203 host_task_id,
204 };
205 {
206 let mut pool = self.pool.write().unwrap();
207 pool.async_tasks.insert(task_id, task);
208 }
209 }
210
211 pub fn cancel_all(&self, rt: &dyn IAsyncTaskRuntimeAdapter) {
212 let mut pool = self.pool.write().unwrap();
213
214 for (_, task) in pool.async_tasks.iter() {
215 rt.try_abort(task.host_task_id);
216 }
217 pool.async_tasks.clear();
218 }
219}
220
221impl MistyClientAsyncHandleGuard {
222 pub fn handle(&self) -> MistyReadonlyClientHandle {
223 let inner = self.inner.as_ref().unwrap();
225 MistyReadonlyClientHandle { inner }
226 }
227}
228
229impl MistyAsyncTaskContext {
230 fn new(inner: Weak<MistyClientInner>) -> Self {
231 Self { inner }
232 }
233
234 pub fn handle(&self) -> MistyClientAsyncHandleGuard {
235 let inner = self.inner.upgrade();
236 MistyClientAsyncHandleGuard {
237 inner,
238 _unsync_marker: Default::default(),
239 }
240 }
241
242 pub fn accessor(&self) -> MistyClientAccessor {
243 MistyClientAccessor {
244 inner: self.inner.clone(),
245 }
246 }
247
248 pub fn schedule<E>(
249 &self,
250 handler: impl FnOnce(MistyClientHandle) -> Result<(), E> + Send + Sync + 'static,
251 ) where
252 E: std::fmt::Display,
253 {
254 let client = self.inner.upgrade();
255 if client.is_none() {
256 return;
257 }
258 let inner = client.unwrap();
259
260 if inner.is_destroyed() {
261 tracing::warn!("schedule but client is destroyed");
262 return;
263 }
264 inner
265 .schedule_manager
266 .enqueue(&inner.signal_emitter, handler);
267 }
268}
269
270pub trait MistyAsyncTaskTrait: Sized + Send + Sync + 'static {
271 fn spawn_once<'a, T, E>(
272 cx: impl AsMistyClientHandle<'a>,
273 future_fn: impl (FnOnce(MistyAsyncTaskContext) -> T) + Send + Sync + 'static,
274 ) where
275 T: std::future::Future<Output = Result<(), E>> + Send + 'static,
276 E: std::fmt::Display,
277 {
278 let inner = cx.handle().inner;
279 let pool = inner.async_task_pools.get::<Self>();
280 pool.cancel_all(inner.async_task_runtime.as_ref());
281 pool.spawn(cx.readonly_handle().clone(), future_fn);
282 }
283
284 fn spawn<'a, T, E>(
285 cx: impl AsMistyClientHandle<'a>,
286 future_fn: impl (FnOnce(MistyAsyncTaskContext) -> T) + Send + Sync + 'static,
287 ) where
288 T: std::future::Future<Output = Result<(), E>> + Send + 'static,
289 E: std::fmt::Display,
290 {
291 let pool = cx.handle().inner.async_task_pools.get::<Self>();
292 pool.spawn(cx.readonly_handle(), future_fn);
293 }
294
295 fn spawn_local_once<'a, T, E>(
296 cx: impl AsMistyClientHandle<'a>,
297 future_fn: impl (FnOnce(MistyAsyncTaskContext) -> T) + 'static,
298 ) where
299 T: std::future::Future<Output = Result<(), E>> + 'static,
300 E: std::fmt::Display,
301 {
302 let inner = cx.handle().inner;
303 let pool = inner.async_task_pools.get::<Self>();
304 pool.cancel_all(inner.async_task_runtime.as_ref());
305 pool.spawn_local(cx.readonly_handle().clone(), future_fn);
306 }
307
308 fn spawn_local<'a, T, E>(
309 cx: impl AsMistyClientHandle<'a>,
310 future_fn: impl (FnOnce(MistyAsyncTaskContext) -> T) + 'static,
311 ) where
312 T: std::future::Future<Output = Result<(), E>> + 'static,
313 E: std::fmt::Display,
314 {
315 let pool = cx.handle().inner.async_task_pools.get::<Self>();
316 pool.spawn_local(cx.readonly_handle(), future_fn);
317 }
318
319 fn cancel_all<'a>(cx: impl AsMistyClientHandle<'a>) {
320 let inner = cx.handle().inner;
321 let pool = inner.async_task_pools.get::<Self>();
322 pool.cancel_all(inner.async_task_runtime.as_ref());
323 }
324}