1use crate::innerlude::Effect;
2use crate::innerlude::ScopeOrder;
3use crate::innerlude::{remove_future, spawn, Runtime};
4use crate::scope_context::ScopeStatus;
5use crate::scope_context::SuspenseLocation;
6use crate::ScopeId;
7use futures_util::task::ArcWake;
8use slotmap::DefaultKey;
9use std::marker::PhantomData;
10use std::sync::Arc;
11use std::task::Waker;
12use std::{cell::Cell, future::Future};
13use std::{cell::RefCell, rc::Rc};
14use std::{pin::Pin, task::Poll};
15
16#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
20#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
21pub struct Task {
22 pub(crate) id: TaskId,
23 unsend: PhantomData<*const ()>,
25}
26
27pub(crate) type TaskId = slotmap::DefaultKey;
28
29impl Task {
30 pub(crate) const fn from_id(id: slotmap::DefaultKey) -> Self {
32 Self {
33 id,
34 unsend: PhantomData,
35 }
36 }
37
38 pub fn new(task: impl Future<Output = ()> + 'static) -> Self {
48 spawn(task)
49 }
50
51 pub fn cancel(self) {
53 remove_future(self);
54 }
55
56 pub fn pause(&self) {
58 self.set_active(false);
59 }
60
61 pub fn resume(&self) {
63 self.set_active(true);
64 }
65
66 pub fn paused(&self) -> bool {
68 Runtime::with(|rt| {
69 if let Some(task) = rt.tasks.borrow().get(self.id) {
70 !task.active.get()
71 } else {
72 false
73 }
74 })
75 }
76
77 #[track_caller]
79 pub fn wake(&self) {
80 Runtime::with(|rt| {
81 _ = rt
82 .sender
83 .unbounded_send(SchedulerMsg::TaskNotified(self.id))
84 })
85 }
86
87 #[track_caller]
89 pub fn poll_now(&self) -> Poll<()> {
90 Runtime::with(|rt| rt.handle_task_wakeup(*self))
91 }
92
93 #[track_caller]
95 pub fn set_active(&self, active: bool) {
96 Runtime::with(|rt| {
97 if let Some(task) = rt.tasks.borrow().get(self.id) {
98 let was_active = task.active.replace(active);
99 if !was_active && active {
100 _ = rt
101 .sender
102 .unbounded_send(SchedulerMsg::TaskNotified(self.id));
103 }
104 }
105 })
106 }
107}
108
109impl Runtime {
110 pub fn spawn_isomorphic(
137 &self,
138 scope: ScopeId,
139 task: impl Future<Output = ()> + 'static,
140 ) -> Task {
141 self.spawn_task_of_type(scope, task, TaskType::Isomorphic)
142 }
143
144 pub fn spawn(&self, scope: ScopeId, task: impl Future<Output = ()> + 'static) -> Task {
154 self.spawn_task_of_type(scope, task, TaskType::ClientOnly)
155 }
156
157 fn spawn_task_of_type(
158 &self,
159 scope: ScopeId,
160 task: impl Future<Output = ()> + 'static,
161 ty: TaskType,
162 ) -> Task {
163 self.spawn_task_of_type_inner(scope, Box::pin(task), ty)
164 }
165
166 fn spawn_task_of_type_inner(
168 &self,
169 scope: ScopeId,
170 pinned_task: Pin<Box<dyn Future<Output = ()>>>,
171 ty: TaskType,
172 ) -> Task {
173 let (task, task_id) = {
175 let mut tasks = self.tasks.borrow_mut();
176
177 let mut task_id = Task::from_id(DefaultKey::default());
178 let mut local_task = None;
179 tasks.insert_with_key(|key| {
180 task_id = Task::from_id(key);
181
182 let new_task = Rc::new(LocalTask {
183 scope,
184 active: Cell::new(true),
185 parent: self.current_task(),
186 task: RefCell::new(pinned_task),
187 waker: futures_util::task::waker(Arc::new(LocalTaskHandle {
188 id: task_id.id,
189 tx: self.sender.clone(),
190 })),
191 ty: RefCell::new(ty),
192 });
193
194 local_task = Some(new_task.clone());
195
196 new_task
197 });
198
199 (local_task.unwrap(), task_id)
200 };
201
202 debug_assert!(self.tasks.try_borrow_mut().is_ok());
204 debug_assert!(task.task.try_borrow_mut().is_ok());
205
206 self.sender
207 .unbounded_send(SchedulerMsg::TaskNotified(task_id.id))
208 .expect("Scheduler should exist");
209
210 task_id
211 }
212
213 pub(crate) fn queue_effect(&self, id: ScopeId, f: impl FnOnce() + 'static) {
215 let effect = Box::new(f) as Box<dyn FnOnce() + 'static>;
216 let Some(scope) = self.try_get_state(id) else {
217 return;
218 };
219 let mut status = scope.status.borrow_mut();
220 match &mut *status {
221 ScopeStatus::Mounted => {
222 self.queue_effect_on_mounted_scope(id, effect);
223 }
224 ScopeStatus::Unmounted { effects_queued, .. } => {
225 effects_queued.push(effect);
226 }
227 }
228 }
229
230 pub(crate) fn queue_effect_on_mounted_scope(
232 &self,
233 id: ScopeId,
234 f: Box<dyn FnOnce() + 'static>,
235 ) {
236 let mut effects = self.pending_effects.borrow_mut();
238 let height = self.get_state(id).height();
239 let scope_order = ScopeOrder::new(height, id);
240 match effects.get(&scope_order) {
241 Some(effects) => effects.push_back(f),
242 None => {
243 effects.insert(Effect::new(scope_order, f));
244 }
245 }
246 }
247
248 pub fn current_task(&self) -> Option<Task> {
250 self.current_task.get()
251 }
252
253 pub fn parent_task(&self, task: Task) -> Option<Task> {
255 self.tasks.borrow().get(task.id)?.parent
256 }
257
258 pub(crate) fn task_scope(&self, task: Task) -> Option<ScopeId> {
259 self.tasks.borrow().get(task.id).map(|t| t.scope)
260 }
261
262 #[track_caller]
263 pub(crate) fn handle_task_wakeup(&self, id: Task) -> Poll<()> {
264 #[cfg(debug_assertions)]
265 {
266 Runtime::current();
268 }
269
270 let task = self.tasks.borrow().get(id.id).cloned();
271
272 let Some(task) = task else {
274 return Poll::Ready(());
275 };
276
277 if !task.active.get() {
279 return Poll::Pending;
280 }
281
282 let mut cx = std::task::Context::from_waker(&task.waker);
283
284 let poll_result = self.with_scope_on_stack(task.scope, || {
286 self.current_task.set(Some(id));
287
288 let poll_result = task.task.borrow_mut().as_mut().poll(&mut cx);
289
290 if poll_result.is_ready() {
291 self.get_state(task.scope)
293 .spawned_tasks
294 .borrow_mut()
295 .remove(&id);
296
297 self.remove_task(id);
298 }
299
300 poll_result
301 });
302 self.current_task.set(None);
303
304 poll_result
305 }
306
307 pub(crate) fn remove_task(&self, id: Task) -> Option<Rc<LocalTask>> {
311 let task = self.tasks.borrow_mut().remove(id.id);
313
314 if let Some(task) = &task {
315 if let TaskType::Suspended { boundary } = &*task.ty.borrow() {
317 self.suspended_tasks.set(self.suspended_tasks.get() - 1);
318 if let SuspenseLocation::UnderSuspense(boundary) = boundary {
319 boundary.remove_suspended_task(id);
320 }
321 }
322
323 if let Some(scope) = self.try_get_state(task.scope) {
325 let order = ScopeOrder::new(scope.height(), scope.id);
326 if let Some(dirty_tasks) = self.dirty_tasks.borrow_mut().get(&order) {
327 dirty_tasks.remove(id);
328 }
329 }
330 }
331
332 task
333 }
334
335 pub(crate) fn task_runs_during_suspense(&self, task: Task) -> bool {
337 let borrow = self.tasks.borrow();
338 let task: Option<&LocalTask> = borrow.get(task.id).map(|t| &**t);
339 matches!(task, Some(LocalTask { ty, .. }) if ty.borrow().runs_during_suspense())
340 }
341}
342
343pub(crate) struct LocalTask {
345 scope: ScopeId,
346 parent: Option<Task>,
347 task: RefCell<Pin<Box<dyn Future<Output = ()> + 'static>>>,
348 waker: Waker,
349 ty: RefCell<TaskType>,
350 active: Cell<bool>,
351}
352
353impl LocalTask {
354 pub(crate) fn suspend(&self, boundary: SuspenseLocation) -> bool {
356 let old_type = self.ty.replace(TaskType::Suspended { boundary });
358 matches!(old_type, TaskType::Suspended { .. })
359 }
360}
361
362#[derive(Clone)]
363enum TaskType {
364 ClientOnly,
365 Suspended { boundary: SuspenseLocation },
366 Isomorphic,
367}
368
369impl TaskType {
370 fn runs_during_suspense(&self) -> bool {
371 matches!(self, TaskType::Isomorphic | TaskType::Suspended { .. })
372 }
373}
374
375#[derive(Debug)]
379pub(crate) enum SchedulerMsg {
380 #[allow(unused)]
382 AllDirty,
383
384 Immediate(ScopeId),
386
387 TaskNotified(slotmap::DefaultKey),
389
390 EffectQueued,
392}
393
394struct LocalTaskHandle {
395 id: slotmap::DefaultKey,
396 tx: futures_channel::mpsc::UnboundedSender<SchedulerMsg>,
397}
398
399impl ArcWake for LocalTaskHandle {
400 fn wake_by_ref(arc_self: &Arc<Self>) {
401 _ = arc_self
402 .tx
403 .unbounded_send(SchedulerMsg::TaskNotified(arc_self.id));
404 }
405}