1use crate::innerlude::Effect;
2use crate::innerlude::ScopeOrder;
3use crate::innerlude::{remove_future, spawn, Runtime};
4use crate::scope_context::ScopeStatus;
5use crate::scope_context::SuspenseLocation;
6use crate::ScopeId;
7use futures_util::task::ArcWake;
8use slotmap::DefaultKey;
9use std::marker::PhantomData;
10use std::panic;
11use std::sync::Arc;
12use std::task::Waker;
13use std::{cell::Cell, future::Future};
14use std::{cell::RefCell, rc::Rc};
15use std::{pin::Pin, task::Poll};
16
/// A lightweight, copyable handle to a future spawned on the runtime.
///
/// The handle only stores the key of the task's entry in the runtime's task table.
/// Copying it does not duplicate the underlying future.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub struct Task {
    /// Key of this task's entry in the runtime's `tasks` slotmap.
    pub(crate) id: slotmap::DefaultKey,
    // `*const ()` is neither `Send` nor `Sync`, so this marker opts `Task` out of both.
    unsend: PhantomData<*const ()>,
}
27
28impl Task {
29 pub(crate) const fn from_id(id: slotmap::DefaultKey) -> Self {
31 Self {
32 id,
33 unsend: PhantomData,
34 }
35 }
36
37 pub fn new(task: impl Future<Output = ()> + 'static) -> Self {
47 spawn(task)
48 }
49
50 pub fn cancel(self) {
54 remove_future(self);
55 }
56
57 pub fn pause(&self) {
59 self.set_active(false);
60 }
61
62 pub fn resume(&self) {
64 self.set_active(true);
65 }
66
67 pub fn paused(&self) -> bool {
69 Runtime::with(|rt| {
70 if let Some(task) = rt.tasks.borrow().get(self.id) {
71 !task.active.get()
72 } else {
73 false
74 }
75 })
76 .unwrap_or_default()
77 }
78
79 #[track_caller]
81 pub fn wake(&self) {
82 Runtime::with(|rt| {
83 _ = rt
84 .sender
85 .unbounded_send(SchedulerMsg::TaskNotified(self.id))
86 })
87 .unwrap_or_else(|e| panic!("{}", e))
88 }
89
90 #[track_caller]
92 pub fn poll_now(&self) -> Poll<()> {
93 Runtime::with(|rt| rt.handle_task_wakeup(*self)).unwrap_or_else(|e| panic!("{}", e))
94 }
95
96 #[track_caller]
98 pub fn set_active(&self, active: bool) {
99 Runtime::with(|rt| {
100 if let Some(task) = rt.tasks.borrow().get(self.id) {
101 let was_active = task.active.replace(active);
102 if !was_active && active {
103 _ = rt
104 .sender
105 .unbounded_send(SchedulerMsg::TaskNotified(self.id));
106 }
107 }
108 })
109 .unwrap_or_else(|e| panic!("{}", e))
110 }
111}
112
113impl Runtime {
114 pub fn spawn_isomorphic(
140 &self,
141 scope: ScopeId,
142 task: impl Future<Output = ()> + 'static,
143 ) -> Task {
144 self.spawn_task_of_type(scope, task, TaskType::Isomorphic)
145 }
146
147 pub fn spawn(&self, scope: ScopeId, task: impl Future<Output = ()> + 'static) -> Task {
157 self.spawn_task_of_type(scope, task, TaskType::ClientOnly)
158 }
159
160 fn spawn_task_of_type(
161 &self,
162 scope: ScopeId,
163 task: impl Future<Output = ()> + 'static,
164 ty: TaskType,
165 ) -> Task {
166 self.spawn_task_of_type_inner(scope, Box::pin(task), ty)
167 }
168
169 fn spawn_task_of_type_inner(
171 &self,
172 scope: ScopeId,
173 pinned_task: Pin<Box<dyn Future<Output = ()>>>,
174 ty: TaskType,
175 ) -> Task {
176 let (task, task_id) = {
178 let mut tasks = self.tasks.borrow_mut();
179
180 let mut task_id = Task::from_id(DefaultKey::default());
181 let mut local_task = None;
182 tasks.insert_with_key(|key| {
183 task_id = Task::from_id(key);
184
185 let new_task = Rc::new(LocalTask {
186 scope,
187 active: Cell::new(true),
188 parent: self.current_task(),
189 task: RefCell::new(pinned_task),
190 waker: futures_util::task::waker(Arc::new(LocalTaskHandle {
191 id: task_id.id,
192 tx: self.sender.clone(),
193 })),
194 ty: RefCell::new(ty),
195 });
196
197 local_task = Some(new_task.clone());
198
199 new_task
200 });
201
202 (local_task.unwrap(), task_id)
203 };
204
205 debug_assert!(self.tasks.try_borrow_mut().is_ok());
207 debug_assert!(task.task.try_borrow_mut().is_ok());
208
209 self.sender
210 .unbounded_send(SchedulerMsg::TaskNotified(task_id.id))
211 .expect("Scheduler should exist");
212
213 task_id
214 }
215
216 pub(crate) fn queue_effect(&self, id: ScopeId, f: impl FnOnce() + 'static) {
218 let effect = Box::new(f) as Box<dyn FnOnce() + 'static>;
219 let Some(scope) = self.get_state(id) else {
220 return;
221 };
222 let mut status = scope.status.borrow_mut();
223 match &mut *status {
224 ScopeStatus::Mounted => {
225 self.queue_effect_on_mounted_scope(id, effect);
226 }
227 ScopeStatus::Unmounted { effects_queued, .. } => {
228 effects_queued.push(effect);
229 }
230 }
231 }
232
233 pub(crate) fn queue_effect_on_mounted_scope(
235 &self,
236 id: ScopeId,
237 f: Box<dyn FnOnce() + 'static>,
238 ) {
239 let mut effects = self.pending_effects.borrow_mut();
241 let scope_order = ScopeOrder::new(id.height(), id);
242 match effects.get(&scope_order) {
243 Some(effects) => effects.push_back(f),
244 None => {
245 effects.insert(Effect::new(scope_order, f));
246 }
247 }
248 }
249
250 pub fn current_task(&self) -> Option<Task> {
252 self.current_task.get()
253 }
254
255 pub fn parent_task(&self, task: Task) -> Option<Task> {
257 self.tasks.borrow().get(task.id)?.parent
258 }
259
260 pub(crate) fn task_scope(&self, task: Task) -> Option<ScopeId> {
261 self.tasks.borrow().get(task.id).map(|t| t.scope)
262 }
263
264 #[track_caller]
265 pub(crate) fn handle_task_wakeup(&self, id: Task) -> Poll<()> {
266 #[cfg(debug_assertions)]
267 {
268 Runtime::current().unwrap_or_else(|e| panic!("{}", e));
270 }
271
272 let task = self.tasks.borrow().get(id.id).cloned();
273
274 let Some(task) = task else {
276 return Poll::Ready(());
277 };
278
279 if !task.active.get() {
281 return Poll::Pending;
282 }
283
284 let mut cx = std::task::Context::from_waker(&task.waker);
285
286 let poll_result = self.with_scope_on_stack(task.scope, || {
288 self.current_task.set(Some(id));
289
290 let poll_result = task.task.borrow_mut().as_mut().poll(&mut cx);
291
292 if poll_result.is_ready() {
293 self.get_state(task.scope)
295 .unwrap()
296 .spawned_tasks
297 .borrow_mut()
298 .remove(&id);
299
300 self.remove_task(id);
301 }
302
303 poll_result
304 });
305 self.current_task.set(None);
306
307 poll_result
308 }
309
310 pub(crate) fn remove_task(&self, id: Task) -> Option<Rc<LocalTask>> {
314 let task = self.tasks.borrow_mut().remove(id.id);
316
317 if let Some(task) = &task {
318 if let TaskType::Suspended { boundary } = &*task.ty.borrow() {
320 self.suspended_tasks.set(self.suspended_tasks.get() - 1);
321 if let SuspenseLocation::UnderSuspense(boundary) = boundary {
322 boundary.remove_suspended_task(id);
323 }
324 }
325
326 if let Some(scope) = self.get_state(task.scope) {
328 let order = ScopeOrder::new(scope.height(), scope.id);
329 if let Some(dirty_tasks) = self.dirty_tasks.borrow_mut().get(&order) {
330 dirty_tasks.remove(id);
331 }
332 }
333 }
334
335 task
336 }
337
338 pub(crate) fn task_runs_during_suspense(&self, task: Task) -> bool {
340 let borrow = self.tasks.borrow();
341 let task: Option<&LocalTask> = borrow.get(task.id).map(|t| &**t);
342 matches!(task, Some(LocalTask { ty, .. }) if ty.borrow().runs_during_suspense())
343 }
344}
345
/// The runtime's bookkeeping for one spawned future.
pub(crate) struct LocalTask {
    /// Scope the task was spawned in; the future is polled with this scope on the stack.
    scope: ScopeId,
    /// Task that was current when this one was spawned, if any.
    parent: Option<Task>,
    /// The boxed, pinned future itself.
    task: RefCell<Pin<Box<dyn Future<Output = ()> + 'static>>>,
    /// Waker passed to the future on every poll; waking it sends `TaskNotified` for this task.
    waker: Waker,
    /// The task's type; `suspend` can change it after spawning.
    ty: RefCell<TaskType>,
    /// `false` while the task is paused; paused tasks are not polled on wake-up.
    active: Cell<bool>,
}
355
356impl LocalTask {
357 pub(crate) fn suspend(&self, boundary: SuspenseLocation) -> bool {
359 let old_type = self.ty.replace(TaskType::Suspended { boundary });
361 matches!(old_type, TaskType::Suspended { .. })
362 }
363}
364
/// The kind of a spawned task, which decides how it is treated while suspense is active.
#[derive(Clone)]
enum TaskType {
    /// Spawned with `Runtime::spawn`; `runs_during_suspense` returns `false`.
    ClientOnly,
    /// The task has been suspended via `LocalTask::suspend` on the given boundary location.
    Suspended { boundary: SuspenseLocation },
    /// Spawned with `Runtime::spawn_isomorphic`; `runs_during_suspense` returns `true`.
    Isomorphic,
}
371
372impl TaskType {
373 fn runs_during_suspense(&self) -> bool {
374 matches!(self, TaskType::Isomorphic | TaskType::Suspended { .. })
375 }
376}
377
/// Messages sent to the scheduler over the runtime's unbounded channel.
#[derive(Debug)]
pub(crate) enum SchedulerMsg {
    // NOTE(review): not constructed in this file; presumably marks every scope dirty — confirm.
    AllDirty,

    // NOTE(review): not constructed in this file; presumably requests an immediate
    // re-render of the given scope — confirm.
    Immediate(ScopeId),

    /// The task with this key should be polled. It is sent on spawn, on `Task::wake`,
    /// on a paused-to-active transition, and by the task's waker.
    TaskNotified(slotmap::DefaultKey),

    // NOTE(review): not constructed in this file; presumably signals that an effect was
    // queued — confirm.
    EffectQueued,
}
395
/// Waker payload for a single task.
struct LocalTaskHandle {
    /// Key of the task in the runtime's task table.
    id: slotmap::DefaultKey,
    /// Sender used to post `TaskNotified(id)` to the scheduler when woken.
    tx: futures_channel::mpsc::UnboundedSender<SchedulerMsg>,
}
400
401impl ArcWake for LocalTaskHandle {
402 fn wake_by_ref(arc_self: &Arc<Self>) {
403 _ = arc_self
404 .tx
405 .unbounded_send(SchedulerMsg::TaskNotified(arc_self.id));
406 }
407}