misty_vm/
async_task.rs

1use std::{
2    any::TypeId,
3    collections::HashMap,
4    marker::PhantomData,
5    sync::{atomic::AtomicU64, Arc, RwLock, Weak},
6};
7
8use futures::future::{BoxFuture, LocalBoxFuture};
9
10use crate::{
11    client::{
12        AsMistyClientHandle, AsReadonlyMistyClientHandle, MistyClientAccessor, MistyClientHandle,
13        MistyClientInner, MistyReadonlyClientHandle,
14    },
15    utils::PhantomUnsync,
16};
17
18pub trait IAsyncTaskRuntimeAdapter {
19    fn spawn(&self, future: BoxFuture<'static, ()>) -> u64;
20    fn spawn_local(&self, future: LocalBoxFuture<'static, ()>) -> u64;
21    fn try_abort(&self, task_id: u64);
22}
23
24pub struct MistyAsyncTaskContext {
25    pub(crate) inner: Weak<MistyClientInner>,
26}
27
28pub struct MistyClientAsyncHandleGuard {
29    inner: Option<Arc<MistyClientInner>>,
30    _unsync_marker: PhantomUnsync,
31}
32
/// Bookkeeping record for one spawned task.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct MistyAsyncTask {
    /// Id allocated by this module; also the key under which the record is
    /// stored in its pool.
    id: u64,
    /// Id returned by the host runtime adapter; passed to `try_abort`.
    host_task_id: u64,
}
38
/// Returns a fresh task id.
///
/// Ids come from a single process-wide atomic counter that starts at 1 and
/// is bumped by one on every call, so concurrent callers never receive the
/// same value (until the `u64` counter wraps around).
fn alloc_task_id() -> u64 {
    static NEXT_TASK_ID: AtomicU64 = AtomicU64::new(1);
    // `fetch_add` yields the value *before* the increment, i.e. the id to use.
    NEXT_TASK_ID.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
}
44
45#[derive(Debug, Default)]
46struct InternalMistyAsyncTaskPool {
47    async_tasks: HashMap<u64, MistyAsyncTask>,
48}
49
50#[derive(Debug, Clone, Default)]
51struct BoxedMistyAsyncTaskPool {
52    pool: Arc<RwLock<InternalMistyAsyncTaskPool>>,
53}
54
55#[derive(Debug)]
56struct MistyAsyncTaskPool<T> {
57    pool: Arc<RwLock<InternalMistyAsyncTaskPool>>,
58    _marker: PhantomData<T>,
59}
60
61type InternalMistyAsyncTaskPools = HashMap<TypeId, BoxedMistyAsyncTaskPool>;
62
63#[derive(Debug)]
64pub(crate) struct MistyAsyncTaskPools {
65    pools: Arc<RwLock<InternalMistyAsyncTaskPools>>,
66}
67
68struct MistyAsyncTaskPoolSpawnCleanupGuard<T>
69where
70    T: MistyAsyncTaskTrait,
71{
72    task_id: u64,
73    marker: PhantomData<T>,
74    pools: Weak<RwLock<InternalMistyAsyncTaskPools>>,
75}
76
77impl<T> Drop for MistyAsyncTaskPoolSpawnCleanupGuard<T>
78where
79    T: MistyAsyncTaskTrait,
80{
81    fn drop(&mut self) {
82        if let Some(pools) = self.pools.upgrade() {
83            let tid = std::any::TypeId::of::<T>();
84            let pool = pools.write().unwrap();
85            let pool = pool.get(&tid);
86            if let Some(pool) = pool {
87                let mut pool = pool.pool.write().unwrap();
88                pool.async_tasks.remove(&self.task_id);
89            }
90        }
91    }
92}
93
94impl MistyAsyncTaskPools {
95    pub fn new() -> Self {
96        Self {
97            pools: Default::default(),
98        }
99    }
100
101    fn get<T: 'static>(&self) -> MistyAsyncTaskPool<T> {
102        let pool = {
103            let mut pools = self.pools.write().unwrap();
104            let pool = pools
105                .entry(std::any::TypeId::of::<T>())
106                .or_default()
107                .clone();
108            pool
109        };
110
111        MistyAsyncTaskPool {
112            pool: pool.pool.clone(),
113            _marker: Default::default(),
114        }
115    }
116
117    pub(crate) fn reset(&self, rt: &dyn IAsyncTaskRuntimeAdapter) {
118        let mut pools = self.pools.write().unwrap();
119
120        for (_, pool) in pools.iter() {
121            let mut pool = pool.pool.write().unwrap();
122            for (_, task) in pool.async_tasks.iter() {
123                rt.try_abort(task.host_task_id);
124            }
125            pool.async_tasks.clear();
126        }
127        pools.clear();
128    }
129}
130
131impl<T> MistyAsyncTaskPool<T>
132where
133    T: MistyAsyncTaskTrait,
134{
135    pub fn spawn<R, E>(
136        &self,
137        handle: MistyReadonlyClientHandle,
138        future_fn: impl (FnOnce(MistyAsyncTaskContext) -> R) + Send + 'static,
139    ) where
140        R: std::future::Future<Output = Result<(), E>> + Send + 'static,
141        E: std::fmt::Display,
142    {
143        let inner = handle.inner.clone();
144        let cloned_inner = inner.clone();
145        let task_id = alloc_task_id();
146
147        let host_task_id = inner.async_task_runtime.spawn(Box::pin(async move {
148            let inner = cloned_inner;
149            let _guard = MistyAsyncTaskPoolSpawnCleanupGuard::<T> {
150                task_id,
151                marker: Default::default(),
152                pools: Arc::downgrade(&inner.async_task_pools.pools),
153            };
154
155            let ctx = MistyAsyncTaskContext::new(Arc::downgrade(&inner));
156            let res = future_fn(ctx).await;
157            if res.is_err() {
158                let e = res.unwrap_err();
159                tracing::error!("spawn error: {}", e);
160            }
161        }));
162
163        let task = MistyAsyncTask {
164            id: task_id,
165            host_task_id,
166        };
167        {
168            let mut pool = self.pool.write().unwrap();
169            pool.async_tasks.insert(task_id, task);
170        }
171    }
172
173    pub fn spawn_local<R, E>(
174        &self,
175        handle: MistyReadonlyClientHandle,
176        future_fn: impl (FnOnce(MistyAsyncTaskContext) -> R) + 'static,
177    ) where
178        R: std::future::Future<Output = Result<(), E>> + 'static,
179        E: std::fmt::Display,
180    {
181        let inner = handle.inner.clone();
182        let cloned_inner = inner.clone();
183        let task_id = alloc_task_id();
184
185        let host_task_id = inner.async_task_runtime.spawn_local(Box::pin(async move {
186            let inner = cloned_inner;
187            let _guard = MistyAsyncTaskPoolSpawnCleanupGuard::<T> {
188                task_id,
189                marker: Default::default(),
190                pools: Arc::downgrade(&inner.async_task_pools.pools),
191            };
192
193            let ctx = MistyAsyncTaskContext::new(Arc::downgrade(&inner));
194            let res = future_fn(ctx).await;
195            if res.is_err() {
196                let e = res.unwrap_err();
197                tracing::error!("spawn error: {}", e);
198            }
199        }));
200
201        let task = MistyAsyncTask {
202            id: task_id,
203            host_task_id,
204        };
205        {
206            let mut pool = self.pool.write().unwrap();
207            pool.async_tasks.insert(task_id, task);
208        }
209    }
210
211    pub fn cancel_all(&self, rt: &dyn IAsyncTaskRuntimeAdapter) {
212        let mut pool = self.pool.write().unwrap();
213
214        for (_, task) in pool.async_tasks.iter() {
215            rt.try_abort(task.host_task_id);
216        }
217        pool.async_tasks.clear();
218    }
219}
220
221impl MistyClientAsyncHandleGuard {
222    pub fn handle(&self) -> MistyReadonlyClientHandle {
223        // SAFETY: spawned task will be aborted when client destroyed
224        let inner = self.inner.as_ref().unwrap();
225        MistyReadonlyClientHandle { inner }
226    }
227}
228
229impl MistyAsyncTaskContext {
230    fn new(inner: Weak<MistyClientInner>) -> Self {
231        Self { inner }
232    }
233
234    pub fn handle(&self) -> MistyClientAsyncHandleGuard {
235        let inner = self.inner.upgrade();
236        MistyClientAsyncHandleGuard {
237            inner,
238            _unsync_marker: Default::default(),
239        }
240    }
241
242    pub fn accessor(&self) -> MistyClientAccessor {
243        MistyClientAccessor {
244            inner: self.inner.clone(),
245        }
246    }
247
248    pub fn schedule<E>(
249        &self,
250        handler: impl FnOnce(MistyClientHandle) -> Result<(), E> + Send + Sync + 'static,
251    ) where
252        E: std::fmt::Display,
253    {
254        let client = self.inner.upgrade();
255        if client.is_none() {
256            return;
257        }
258        let inner = client.unwrap();
259
260        if inner.is_destroyed() {
261            tracing::warn!("schedule but client is destroyed");
262            return;
263        }
264        inner
265            .schedule_manager
266            .enqueue(&inner.signal_emitter, handler);
267    }
268}
269
270pub trait MistyAsyncTaskTrait: Sized + Send + Sync + 'static {
271    fn spawn_once<'a, T, E>(
272        cx: impl AsMistyClientHandle<'a>,
273        future_fn: impl (FnOnce(MistyAsyncTaskContext) -> T) + Send + Sync + 'static,
274    ) where
275        T: std::future::Future<Output = Result<(), E>> + Send + 'static,
276        E: std::fmt::Display,
277    {
278        let inner = cx.handle().inner;
279        let pool = inner.async_task_pools.get::<Self>();
280        pool.cancel_all(inner.async_task_runtime.as_ref());
281        pool.spawn(cx.readonly_handle().clone(), future_fn);
282    }
283
284    fn spawn<'a, T, E>(
285        cx: impl AsMistyClientHandle<'a>,
286        future_fn: impl (FnOnce(MistyAsyncTaskContext) -> T) + Send + Sync + 'static,
287    ) where
288        T: std::future::Future<Output = Result<(), E>> + Send + 'static,
289        E: std::fmt::Display,
290    {
291        let pool = cx.handle().inner.async_task_pools.get::<Self>();
292        pool.spawn(cx.readonly_handle(), future_fn);
293    }
294
295    fn spawn_local_once<'a, T, E>(
296        cx: impl AsMistyClientHandle<'a>,
297        future_fn: impl (FnOnce(MistyAsyncTaskContext) -> T) + 'static,
298    ) where
299        T: std::future::Future<Output = Result<(), E>> + 'static,
300        E: std::fmt::Display,
301    {
302        let inner = cx.handle().inner;
303        let pool = inner.async_task_pools.get::<Self>();
304        pool.cancel_all(inner.async_task_runtime.as_ref());
305        pool.spawn_local(cx.readonly_handle().clone(), future_fn);
306    }
307
308    fn spawn_local<'a, T, E>(
309        cx: impl AsMistyClientHandle<'a>,
310        future_fn: impl (FnOnce(MistyAsyncTaskContext) -> T) + 'static,
311    ) where
312        T: std::future::Future<Output = Result<(), E>> + 'static,
313        E: std::fmt::Display,
314    {
315        let pool = cx.handle().inner.async_task_pools.get::<Self>();
316        pool.spawn_local(cx.readonly_handle(), future_fn);
317    }
318
319    fn cancel_all<'a>(cx: impl AsMistyClientHandle<'a>) {
320        let inner = cx.handle().inner;
321        let pool = inner.async_task_pools.get::<Self>();
322        pool.cancel_all(inner.async_task_runtime.as_ref());
323    }
324}