1use crate::innerlude::Effect;
2use crate::innerlude::ScopeOrder;
3use crate::innerlude::{remove_future, spawn, Runtime};
4use crate::scope_context::ScopeStatus;
5use crate::scope_context::SuspenseLocation;
6use crate::ScopeId;
7use futures_util::task::ArcWake;
8use slotmap::DefaultKey;
9use std::marker::PhantomData;
10use std::panic;
11use std::sync::Arc;
12use std::task::Waker;
13use std::{cell::Cell, future::Future};
14use std::{cell::RefCell, rc::Rc};
15use std::{pin::Pin, task::Poll};
16
17#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
21#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
22pub struct Task {
23 pub(crate) id: slotmap::DefaultKey,
24 unsend: PhantomData<*const ()>,
26}
27
28impl Task {
29 pub(crate) const fn from_id(id: slotmap::DefaultKey) -> Self {
31 Self {
32 id,
33 unsend: PhantomData,
34 }
35 }
36
37 pub fn new(task: impl Future<Output = ()> + 'static) -> Self {
47 spawn(task)
48 }
49
50 pub fn cancel(self) {
54 remove_future(self);
55 }
56
57 pub fn pause(&self) {
59 self.set_active(false);
60 }
61
62 pub fn resume(&self) {
64 self.set_active(true);
65 }
66
67 pub fn paused(&self) -> bool {
69 Runtime::with(|rt| {
70 if let Some(task) = rt.tasks.borrow().get(self.id) {
71 !task.active.get()
72 } else {
73 false
74 }
75 })
76 .unwrap_or_default()
77 }
78
79 #[track_caller]
81 pub fn wake(&self) {
82 Runtime::with(|rt| {
83 _ = rt
84 .sender
85 .unbounded_send(SchedulerMsg::TaskNotified(self.id))
86 })
87 .unwrap_or_else(|e| panic!("{}", e))
88 }
89
90 #[track_caller]
92 pub fn poll_now(&self) -> Poll<()> {
93 Runtime::with(|rt| rt.handle_task_wakeup(*self)).unwrap_or_else(|e| panic!("{}", e))
94 }
95
96 #[track_caller]
98 pub fn set_active(&self, active: bool) {
99 Runtime::with(|rt| {
100 if let Some(task) = rt.tasks.borrow().get(self.id) {
101 let was_active = task.active.replace(active);
102 if !was_active && active {
103 _ = rt
104 .sender
105 .unbounded_send(SchedulerMsg::TaskNotified(self.id));
106 }
107 }
108 })
109 .unwrap_or_else(|e| panic!("{}", e))
110 }
111}
112
113impl Runtime {
114 pub fn spawn_isomorphic(
141 &self,
142 scope: ScopeId,
143 task: impl Future<Output = ()> + 'static,
144 ) -> Task {
145 self.spawn_task_of_type(scope, task, TaskType::Isomorphic)
146 }
147
148 pub fn spawn(&self, scope: ScopeId, task: impl Future<Output = ()> + 'static) -> Task {
158 self.spawn_task_of_type(scope, task, TaskType::ClientOnly)
159 }
160
161 fn spawn_task_of_type(
162 &self,
163 scope: ScopeId,
164 task: impl Future<Output = ()> + 'static,
165 ty: TaskType,
166 ) -> Task {
167 self.spawn_task_of_type_inner(scope, Box::pin(task), ty)
168 }
169
170 fn spawn_task_of_type_inner(
172 &self,
173 scope: ScopeId,
174 pinned_task: Pin<Box<dyn Future<Output = ()>>>,
175 ty: TaskType,
176 ) -> Task {
177 let (task, task_id) = {
179 let mut tasks = self.tasks.borrow_mut();
180
181 let mut task_id = Task::from_id(DefaultKey::default());
182 let mut local_task = None;
183 tasks.insert_with_key(|key| {
184 task_id = Task::from_id(key);
185
186 let new_task = Rc::new(LocalTask {
187 scope,
188 active: Cell::new(true),
189 parent: self.current_task(),
190 task: RefCell::new(pinned_task),
191 waker: futures_util::task::waker(Arc::new(LocalTaskHandle {
192 id: task_id.id,
193 tx: self.sender.clone(),
194 })),
195 ty: RefCell::new(ty),
196 });
197
198 local_task = Some(new_task.clone());
199
200 new_task
201 });
202
203 (local_task.unwrap(), task_id)
204 };
205
206 debug_assert!(self.tasks.try_borrow_mut().is_ok());
208 debug_assert!(task.task.try_borrow_mut().is_ok());
209
210 self.sender
211 .unbounded_send(SchedulerMsg::TaskNotified(task_id.id))
212 .expect("Scheduler should exist");
213
214 task_id
215 }
216
217 pub(crate) fn queue_effect(&self, id: ScopeId, f: impl FnOnce() + 'static) {
219 let effect = Box::new(f) as Box<dyn FnOnce() + 'static>;
220 let Some(scope) = self.get_state(id) else {
221 return;
222 };
223 let mut status = scope.status.borrow_mut();
224 match &mut *status {
225 ScopeStatus::Mounted => {
226 self.queue_effect_on_mounted_scope(id, effect);
227 }
228 ScopeStatus::Unmounted { effects_queued, .. } => {
229 effects_queued.push(effect);
230 }
231 }
232 }
233
234 pub(crate) fn queue_effect_on_mounted_scope(
236 &self,
237 id: ScopeId,
238 f: Box<dyn FnOnce() + 'static>,
239 ) {
240 let mut effects = self.pending_effects.borrow_mut();
242 let scope_order = ScopeOrder::new(id.height(), id);
243 match effects.get(&scope_order) {
244 Some(effects) => effects.push_back(f),
245 None => {
246 effects.insert(Effect::new(scope_order, f));
247 }
248 }
249 }
250
251 pub fn current_task(&self) -> Option<Task> {
253 self.current_task.get()
254 }
255
256 pub fn parent_task(&self, task: Task) -> Option<Task> {
258 self.tasks.borrow().get(task.id)?.parent
259 }
260
261 pub(crate) fn task_scope(&self, task: Task) -> Option<ScopeId> {
262 self.tasks.borrow().get(task.id).map(|t| t.scope)
263 }
264
265 #[track_caller]
266 pub(crate) fn handle_task_wakeup(&self, id: Task) -> Poll<()> {
267 #[cfg(debug_assertions)]
268 {
269 Runtime::current().unwrap_or_else(|e| panic!("{}", e));
271 }
272
273 let task = self.tasks.borrow().get(id.id).cloned();
274
275 let Some(task) = task else {
277 return Poll::Ready(());
278 };
279
280 if !task.active.get() {
282 return Poll::Pending;
283 }
284
285 let mut cx = std::task::Context::from_waker(&task.waker);
286
287 let poll_result = self.with_scope_on_stack(task.scope, || {
289 self.current_task.set(Some(id));
290
291 let poll_result = task.task.borrow_mut().as_mut().poll(&mut cx);
292
293 if poll_result.is_ready() {
294 self.get_state(task.scope)
296 .unwrap()
297 .spawned_tasks
298 .borrow_mut()
299 .remove(&id);
300
301 self.remove_task(id);
302 }
303
304 poll_result
305 });
306 self.current_task.set(None);
307
308 poll_result
309 }
310
311 pub(crate) fn remove_task(&self, id: Task) -> Option<Rc<LocalTask>> {
315 let task = self.tasks.borrow_mut().remove(id.id);
317
318 if let Some(task) = &task {
319 if let TaskType::Suspended { boundary } = &*task.ty.borrow() {
321 self.suspended_tasks.set(self.suspended_tasks.get() - 1);
322 if let SuspenseLocation::UnderSuspense(boundary) = boundary {
323 boundary.remove_suspended_task(id);
324 }
325 }
326
327 if let Some(scope) = self.get_state(task.scope) {
329 let order = ScopeOrder::new(scope.height(), scope.id);
330 if let Some(dirty_tasks) = self.dirty_tasks.borrow_mut().get(&order) {
331 dirty_tasks.remove(id);
332 }
333 }
334 }
335
336 task
337 }
338
339 pub(crate) fn task_runs_during_suspense(&self, task: Task) -> bool {
341 let borrow = self.tasks.borrow();
342 let task: Option<&LocalTask> = borrow.get(task.id).map(|t| &**t);
343 matches!(task, Some(LocalTask { ty, .. }) if ty.borrow().runs_during_suspense())
344 }
345}
346
347pub(crate) struct LocalTask {
349 scope: ScopeId,
350 parent: Option<Task>,
351 task: RefCell<Pin<Box<dyn Future<Output = ()> + 'static>>>,
352 waker: Waker,
353 ty: RefCell<TaskType>,
354 active: Cell<bool>,
355}
356
357impl LocalTask {
358 pub(crate) fn suspend(&self, boundary: SuspenseLocation) -> bool {
360 let old_type = self.ty.replace(TaskType::Suspended { boundary });
362 matches!(old_type, TaskType::Suspended { .. })
363 }
364}
365
366#[derive(Clone)]
367enum TaskType {
368 ClientOnly,
369 Suspended { boundary: SuspenseLocation },
370 Isomorphic,
371}
372
373impl TaskType {
374 fn runs_during_suspense(&self) -> bool {
375 matches!(self, TaskType::Isomorphic | TaskType::Suspended { .. })
376 }
377}
378
379#[derive(Debug)]
383pub(crate) enum SchedulerMsg {
384 AllDirty,
386
387 Immediate(ScopeId),
389
390 TaskNotified(slotmap::DefaultKey),
392
393 EffectQueued,
395}
396
397struct LocalTaskHandle {
398 id: slotmap::DefaultKey,
399 tx: futures_channel::mpsc::UnboundedSender<SchedulerMsg>,
400}
401
402impl ArcWake for LocalTaskHandle {
403 fn wake_by_ref(arc_self: &Arc<Self>) {
404 _ = arc_self
405 .tx
406 .unbounded_send(SchedulerMsg::TaskNotified(arc_self.id));
407 }
408}