1#![allow(clippy::disallowed_types)]
2
3mod park_group;
4mod task;
5
6use std::cell::{Cell, UnsafeCell};
7use std::future::Future;
8use std::marker::PhantomData;
9use std::panic::{AssertUnwindSafe, Location};
10use std::pin::Pin;
11use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
12use std::sync::{Arc, OnceLock, Weak};
13use std::task::{Context, Poll};
14use std::time::{Duration, Instant};
15
16use crossbeam_channel::{Receiver, Sender};
17use crossbeam_deque::{Injector, Steal, Stealer, Worker as WorkQueue};
18use crossbeam_utils::CachePadded;
19use park_group::ParkGroup;
20use parking_lot::Mutex;
21use polars_utils::relaxed_cell::RelaxedCell;
22use polars_utils::with_drop::WithDrop;
23use rand::rngs::SmallRng;
24use rand::{Rng, SeedableRng};
25use slotmap::SlotMap;
26use task::{Cancellable, DynTask, Runnable};
27
thread_local! {
    /// Set to `false` on executor runner threads (see `Executor::runner`);
    /// presumably consulted elsewhere to decide whether rayon may be used.
    pub static ALLOW_RAYON_THREADS: Cell<bool> = const { Cell::new(true) };
}

thread_local! {
    /// Set to `true` on every runner thread started by this executor.
    pub static THREAD_SPAWNED_BY_POLARS_EXECUTOR: Cell<bool> = const { Cell::new(false) };

    /// Executor thread identity currently held by this thread (an index into
    /// `Executor::thread_task_lists`), or `usize::MAX` when it holds none.
    static TLS_THREAD_ID: Cell<usize> = const { Cell::new(usize::MAX) };
}
36pub fn is_scheduling_polars_executor_thread() -> bool {
38 TLS_THREAD_ID.get() != usize::MAX
39}
40
// When set, tasks spawned afterwards get a `TaskMetrics` allocation
// (read once at spawn time by `spawn` and `TaskScope::spawn_task`).
static TRACK_METRICS: RelaxedCell<bool> = RelaxedCell::new_bool(false);

/// Enables or disables metrics collection for tasks spawned after this call.
/// Already-spawned tasks keep whatever they were spawned with.
pub fn track_task_metrics(should_track: bool) {
    TRACK_METRICS.store(should_track);
}

// The process-wide executor, lazily created by `Executor::global`.
static GLOBAL_SCHEDULER: OnceLock<Executor> = OnceLock::new();
48
// Key of a task's cancel handle inside a `TaskScope`'s `cancel_handles` slot map.
slotmap::new_key_type! {
    struct TaskKey;
}
52
/// Scheduling priority of a task.
///
/// The variant order is significant: the derived ordering ranks `Low` below
/// `High`.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum TaskPriority {
    /// Scheduled through the global low-priority queue unless it can take the
    /// running thread's empty local slot.
    Low,
    /// Scheduled through the local slot / per-thread stealable queue when
    /// possible, otherwise through the global high-priority queue.
    High,
}
59
/// Links a task spawned through a `TaskScope` back to that scope, so that when
/// the task's metadata is dropped its key can be reported as completed.
struct ScopedTaskMetadata {
    // Key of this task's cancel handle in the scope's `cancel_handles`.
    task_key: TaskKey,
    // The scope's completed-task list. Weak: if the list is already gone,
    // `TaskMetadata::drop` simply skips reporting.
    completed_tasks: Weak<Mutex<Vec<TaskKey>>>,
}
65
/// Per-task statistics, allocated only when metrics tracking was enabled
/// (`track_task_metrics(true)`) at spawn time.
///
/// The poll counters are updated in `Executor::runner`; `done` is set when the
/// task's metadata is dropped.
// NOTE(review): the 128-byte alignment is presumably to avoid false sharing
// between metrics of different tasks — confirm.
#[derive(Default)]
#[repr(align(128))]
pub struct TaskMetrics {
    /// Number of times a runner ran the task.
    pub total_polls: RelaxedCell<u64>,
    /// Runs where the task was obtained from a global queue or by stealing,
    /// rather than from the runner's own local slot/queue.
    pub total_stolen_polls: RelaxedCell<u64>,
    /// Sum of wall-clock time spent running the task, in nanoseconds.
    pub total_poll_time_ns: RelaxedCell<u64>,
    /// Longest single run of the task, in nanoseconds.
    pub max_poll_time_ns: RelaxedCell<u64>,
    /// Set to `true` once the task's metadata has been dropped.
    pub done: RelaxedCell<bool>,
}
75
/// Bookkeeping stored with every spawned task.
struct TaskMetadata {
    // Caller location of `spawn` / `TaskScope::spawn_task` (`#[track_caller]`).
    spawn_location: &'static Location<'static>,
    priority: TaskPriority,
    // `true` until the task is scheduled for the first time; while set,
    // `Executor::schedule_task` routes the task through a global queue.
    freshly_spawned: AtomicBool,
    // Present only for tasks spawned through a `TaskScope`.
    scoped: Option<ScopedTaskMetadata>,
    // Present only if metrics tracking was enabled when the task was spawned.
    metrics: Option<Arc<TaskMetrics>>,
}
83
84impl Drop for TaskMetadata {
85 fn drop(&mut self) {
86 if let Some(metrics) = self.metrics.as_ref() {
87 metrics.done.store(true);
88 }
89
90 if let Some(scoped) = &self.scoped {
91 if let Some(completed_tasks) = scoped.completed_tasks.upgrade() {
92 completed_tasks.lock().push(scoped.task_key);
93 }
94 }
95 }
96}
97
/// Owning handle to a spawned task; awaiting it yields the task's output.
pub struct JoinHandle<T>(Arc<dyn DynTask<T, TaskMetadata>>);
/// Weak handle that can cancel a task without keeping the task alive.
pub struct CancelHandle(Weak<dyn Cancellable>);
100
101impl<T> JoinHandle<T> {
102 pub fn metrics(&self) -> Option<&Arc<TaskMetrics>> {
103 self.0.metadata().metrics.as_ref()
104 }
105
106 #[allow(unused)]
107 pub fn spawn_location(&self) -> &'static Location<'static> {
108 self.0.metadata().spawn_location
109 }
110
111 pub fn cancel_handle(&self) -> CancelHandle {
112 let coerce: Weak<dyn DynTask<T, TaskMetadata>> = Arc::downgrade(&self.0);
113 CancelHandle(coerce)
114 }
115}
116
117impl<T> Future for JoinHandle<T> {
118 type Output = T;
119
120 fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
121 self.0.poll_join(ctx)
122 }
123}
124
125impl CancelHandle {
126 pub fn cancel(&self) {
127 if let Some(t) = self.0.upgrade() {
128 t.cancel();
129 }
130 }
131}
132
/// A [`JoinHandle`] wrapper that cancels its task when dropped.
pub struct AbortOnDropHandle<T> {
    // Polled to obtain the task output.
    join_handle: JoinHandle<T>,
    // Used by `Drop` to cancel the task.
    cancel_handle: CancelHandle,
}
137
138impl<T> AbortOnDropHandle<T> {
139 pub fn new(join_handle: JoinHandle<T>) -> Self {
140 let cancel_handle = join_handle.cancel_handle();
141 Self {
142 join_handle,
143 cancel_handle,
144 }
145 }
146}
147
148impl<T> Future for AbortOnDropHandle<T> {
149 type Output = T;
150
151 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
152 Pin::new(&mut self.join_handle).poll(cx)
153 }
154}
155
156impl<T> Drop for AbortOnDropHandle<T> {
157 fn drop(&mut self) {
158 self.cancel_handle.cancel();
159 }
160}
161
// A task that is ready to run, type-erased over its future and output.
type ReadyTask = Arc<dyn Runnable<TaskMetadata>>;

/// Task storage belonging to one executor thread identity.
struct ThreadLocalTaskList {
    // Stealing end of `high_prio_tasks`, used by other runners in
    // `Executor::try_steal_task`.
    high_prio_tasks_stealer: Stealer<ReadyTask>,

    // LIFO queue (created in `Executor::global`) pushed to and popped from only
    // by the thread currently holding this identity.
    high_prio_tasks: WorkQueue<ReadyTask>,
    // "Run next" slot, checked first by the runner; accessed without
    // synchronization through `UnsafeCell`.
    local_slot: UnsafeCell<Option<ReadyTask>>,
}

// NOTE(review): manual `Sync` because `local_slot` (and, presumably, the
// crossbeam `Worker`) are not `Sync`. Soundness relies on only the thread that
// holds this list's identity touching those fields — keep in mind when changing
// the identity hand-off in `block_in_place`.
unsafe impl Sync for ThreadLocalTaskList {}
176
/// State of the global work-stealing executor (see `Executor::global`).
struct Executor {
    // Parks idle runners and wakes them when new work is queued.
    park_group: ParkGroup,
    // One task list per executor thread identity, indexed by that identity.
    thread_task_lists: Vec<CachePadded<ThreadLocalTaskList>>,
    // Used for freshly spawned tasks and tasks scheduled from threads without
    // an identity.
    global_high_prio_task_queue: Injector<ReadyTask>,
    // As above, plus low-priority tasks that could not take the local slot.
    global_low_prio_task_queue: Injector<ReadyTask>,
    // Channel over which `block_in_place` offers thread identities to runners
    // without identity. The shared cell holds the identity, or `usize::MAX`
    // once someone has claimed it.
    thread_id_send: Sender<Arc<AtomicUsize>>,
    thread_id_recv: Receiver<Arc<AtomicUsize>>,
    // Next number used when naming extra runner threads.
    thread_name_idx: AtomicUsize,
    // Number of runner threads currently without an identity.
    num_runners_without_identity: AtomicUsize,
}
187
impl Executor {
    /// Routes a task that became ready (freshly spawned or woken) to a queue.
    ///
    /// - Tasks scheduled from a thread without an identity, and tasks scheduled
    ///   for the first time after spawning, go to the global queue for their
    ///   priority, and one parked worker is woken.
    /// - Otherwise, a high-priority task takes this thread's local slot. If the
    ///   slot was occupied, the previous occupant is pushed onto this thread's
    ///   stealable queue and a worker is woken.
    /// - A low-priority task takes the local slot only if the slot and this
    ///   thread's high-priority queue are both empty. Otherwise it goes to the
    ///   global low-priority queue and a worker is woken.
    fn schedule_task(&self, task: ReadyTask) {
        let thread = TLS_THREAD_ID.get();
        let meta = task.metadata();
        // `usize::MAX` (no identity) is out of bounds and yields `None`.
        let opt_ttl = self.thread_task_lists.get(thread);

        let mut use_global_queue = opt_ttl.is_none();
        // The first scheduling of a task always goes through a global queue.
        if meta.freshly_spawned.load(Ordering::Relaxed) {
            use_global_queue = true;
            meta.freshly_spawned.store(false, Ordering::Relaxed);
        }

        if use_global_queue {
            if meta.priority == TaskPriority::High {
                self.global_high_prio_task_queue.push(task);
            } else {
                self.global_low_prio_task_queue.push(task);
            }
            self.park_group.unpark_one();
        } else {
            let ttl = opt_ttl.unwrap();
            // SAFETY (review): only the thread holding this identity may touch
            // `local_slot`, and TLS_THREAD_ID says that is us.
            let slot = unsafe { &mut *ttl.local_slot.get() };

            if meta.priority == TaskPriority::High {
                let Some(task) = slot.replace(task) else {
                    // Slot was empty: the runner checks the slot first, so this
                    // thread will pick the task up itself; nobody to wake.
                    return;
                };

                // Displaced occupant goes to the stealable queue.
                ttl.high_prio_tasks.push(task);
                self.park_group.unpark_one();
            } else {
                if ttl.high_prio_tasks.is_empty() && slot.is_none() {
                    // Nothing else is queued locally: run it here next.
                    *slot = Some(task);
                } else {
                    self.global_low_prio_task_queue.push(task);
                    self.park_group.unpark_one();
                }
            }
        }
    }

    /// Looks for a task outside this thread's own slot/queue.
    ///
    /// Tries the global high-priority queue, then the global low-priority
    /// queue. Then it makes up to 4 rounds of batch-stealing from the
    /// high-priority queues of all thread task lists, visited in random order.
    /// Each round is repeated while any stealer reports `Steal::Retry`. A
    /// stolen batch is moved into `thread`'s own queue and one task from it is
    /// returned. Returns `None` if everything came up empty.
    fn try_steal_task<R: Rng>(&self, thread: usize, rng: &mut R) -> Option<ReadyTask> {
        loop {
            match self.global_high_prio_task_queue.steal() {
                Steal::Empty => break,
                Steal::Success(task) => return Some(task),
                Steal::Retry => std::hint::spin_loop(),
            }
        }

        loop {
            match self.global_low_prio_task_queue.steal() {
                Steal::Empty => break,
                Steal::Success(task) => return Some(task),
                Steal::Retry => std::hint::spin_loop(),
            }
        }

        let ttl = &self.thread_task_lists[thread];
        for _ in 0..4 {
            let mut retry = true;
            while retry {
                retry = false;

                for idx in random_permutation(self.thread_task_lists.len() as u32, rng) {
                    let foreign_ttl = &self.thread_task_lists[idx as usize];
                    match foreign_ttl
                        .high_prio_tasks_stealer
                        .steal_batch_and_pop(&ttl.high_prio_tasks)
                    {
                        Steal::Empty => {},
                        Steal::Success(task) => return Some(task),
                        Steal::Retry => retry = true,
                    }
                }

                std::hint::spin_loop()
            }
        }

        None
    }

    /// Main loop of an executor worker thread.
    ///
    /// `initial_thread_id` is `Some(t)` for the threads started in `global()`
    /// and `None` for runners started by `spawn_runner_without_identity`.
    /// Each iteration first makes sure this thread holds an identity; it
    /// returns if none can be acquired. It then takes a task from, in order:
    /// - the local slot;
    /// - its own high-priority queue;
    /// - the global queues or other threads' queues;
    /// - otherwise it parks.
    ///
    /// When metrics are enabled for a task, the run time and whether the task
    /// was found locally are recorded.
    fn runner(&self, initial_thread_id: Option<usize>) {
        TLS_THREAD_ID.set(initial_thread_id.unwrap_or(usize::MAX));
        ALLOW_RAYON_THREADS.set(false);
        THREAD_SPAWNED_BY_POLARS_EXECUTOR.set(true);

        let mut rng = SmallRng::from_rng(&mut rand::rng());
        let mut worker = self.park_group.new_worker();

        loop {
            let mut thread_id = TLS_THREAD_ID.get();
            // No identity: either this runner started without one, or
            // `block_in_place` handed it off while running the previous task.
            // The identity acquired here may differ from the one given away.
            if thread_id == usize::MAX {
                if let Some(tid) = self.acquire_thread_identity() {
                    TLS_THREAD_ID.set(tid);
                    thread_id = tid;
                } else {
                    return;
                }
            }

            let ttl = &self.thread_task_lists[thread_id];
            // Whether the task came from this thread's own slot/queue.
            let mut local = true;
            let task = (|| {
                // SAFETY (review): we hold `thread_id`, so we are the only thread
                // allowed to touch this list's `local_slot`.
                if let Some(task) = unsafe { (*ttl.local_slot.get()).take() } {
                    return Some(task);
                }

                if let Some(task) = ttl.high_prio_tasks.pop() {
                    return Some(task);
                }

                local = false;
                if let Some(task) = self.try_steal_task(thread_id, &mut rng) {
                    return Some(task);
                }

                // Announce the intent to park, then check once more before
                // parking. NOTE(review): presumably so that work queued between
                // the first check and parking is not missed — see `ParkGroup`.
                let park = worker.prepare_park();
                if let Some(task) = self.try_steal_task(thread_id, &mut rng) {
                    return Some(task);
                }

                park.park();
                None
            })();

            if let Some(task) = task {
                // NOTE(review): presumably lets the park group wake another
                // worker now that this one is busy — see `ParkGroup`.
                worker.recruit_next();
                if let Some(metrics) = task.metadata().metrics.clone() {
                    let start = Instant::now();
                    task.run();
                    let elapsed_ns = start.elapsed().as_nanos() as u64;
                    metrics.total_polls.fetch_add(1);
                    if !local {
                        metrics.total_stolen_polls.fetch_add(1);
                    }
                    metrics.total_poll_time_ns.fetch_add(elapsed_ns);
                    metrics.max_poll_time_ns.fetch_max(elapsed_ns);
                } else {
                    task.run();
                }
            }
        }
    }

    /// Starts an extra runner thread that has no identity yet.
    ///
    /// The counter is bumped before spawning so the new thread is counted from
    /// the start. Names continue after the initial `async-executor-{t}` threads
    /// (`thread_name_idx` starts at the initial thread count).
    fn spawn_runner_without_identity(&self) {
        self.num_runners_without_identity
            .fetch_add(1, Ordering::AcqRel);
        let t = self.thread_name_idx.fetch_add(1, Ordering::Relaxed);
        std::thread::Builder::new()
            .name(format!("async-executor-{t}"))
            .spawn(move || Self::global().runner(None))
            .unwrap();
    }

    /// Waits for a thread identity offered by `block_in_place`.
    ///
    /// Messages whose identity was already claimed or reclaimed (cell holds
    /// `usize::MAX`) are skipped. On success this runner stops counting as
    /// identity-less. If no identity-less runner is left while more identities
    /// are waiting, another runner is spawned.
    ///
    /// Each wait times out after 10 seconds. On timeout the runner uncounts
    /// itself and returns `None` (the runner exits) if nothing is pending.
    /// Otherwise it re-counts itself and keeps waiting.
    fn acquire_thread_identity(&self) -> Option<usize> {
        loop {
            match self.thread_id_recv.recv_timeout(Duration::from_secs(10)) {
                Ok(tid_msg) => {
                    // Claim the identity; `usize::MAX` means someone else did.
                    let thread_id = tid_msg.swap(usize::MAX, Ordering::AcqRel);
                    if thread_id != usize::MAX {
                        let num_left = self
                            .num_runners_without_identity
                            .fetch_sub(1, Ordering::AcqRel)
                            - 1;
                        if num_left == 0 && !self.thread_id_recv.is_empty() {
                            self.spawn_runner_without_identity();
                        }
                        return Some(thread_id);
                    }
                },
                Err(_) => {
                    // NOTE(review): decrement-then-`is_empty` here pairs with
                    // `send`-then-`fetch_add(0)` in `block_in_place` /
                    // `ensure_runner_without_identity_exists`. Presumably either
                    // this runner sees the new message, or the sender sees a
                    // zero count and spawns a runner. Keep the order.
                    self.num_runners_without_identity
                        .fetch_sub(1, Ordering::AcqRel);
                    if self.thread_id_recv.is_empty() {
                        return None;
                    }
                    self.num_runners_without_identity
                        .fetch_add(1, Ordering::AcqRel);
                },
            }
        }
    }

    /// Spawns a runner without identity if the counter reads zero.
    ///
    /// Concurrent callers may both see zero and spawn one each; an extra runner
    /// exits after timing out in `acquire_thread_identity`.
    fn ensure_runner_without_identity_exists(&self) {
        // NOTE(review): a read-modify-write (`fetch_add(0, AcqRel)`) instead of a
        // plain `load`, presumably so it synchronizes with the decrement in
        // `acquire_thread_identity`. Do not replace with `load` without checking.
        if self
            .num_runners_without_identity
            .fetch_add(0, Ordering::AcqRel)
            == 0
        {
            self.spawn_runner_without_identity();
        }
    }

    /// Returns the process-wide executor, creating it on first use.
    ///
    /// Initialization starts `max_threads()` runner threads with fixed
    /// identities `0..n_threads`. Each runner calls `Self::global()` itself and
    /// thus waits on the `OnceLock` until this initialization has finished.
    fn global() -> &'static Executor {
        GLOBAL_SCHEDULER.get_or_init(|| {
            let n_threads = polars_config::config().max_threads();
            let thread_task_lists = (0..n_threads)
                .map(|t| {
                    std::thread::Builder::new()
                        .name(format!("async-executor-{t}"))
                        .spawn(move || Self::global().runner(Some(t)))
                        .unwrap();

                    let high_prio_tasks = WorkQueue::new_lifo();
                    CachePadded::new(ThreadLocalTaskList {
                        high_prio_tasks_stealer: high_prio_tasks.stealer(),
                        high_prio_tasks,
                        local_slot: UnsafeCell::new(None),
                    })
                })
                .collect();
            let (thread_id_send, thread_id_recv) = crossbeam_channel::unbounded();
            Self {
                park_group: ParkGroup::new(),
                thread_task_lists,
                global_high_prio_task_queue: Injector::new(),
                global_low_prio_task_queue: Injector::new(),
                thread_id_send,
                thread_id_recv,
                thread_name_idx: AtomicUsize::new(n_threads),
                num_runners_without_identity: AtomicUsize::new(0),
            }
        })
    }
}
431
/// A scope in which spawned tasks may borrow data living for `'scope`.
///
/// Created by [`task_scope`]. Tasks spawned through [`TaskScope::spawn_task`]
/// are registered here, and any still registered are cancelled when the scope
/// is destroyed.
pub struct TaskScope<'scope, 'env: 'scope> {
    // Cancel handles of this scope's tasks, keyed by `TaskKey`.
    cancel_handles: Mutex<SlotMap<TaskKey, CancelHandle>>,
    // Keys of tasks whose metadata has been dropped. Pushed by
    // `TaskMetadata::drop` through a `Weak`, drained by `clear_completed_tasks`.
    completed_tasks: Arc<Mutex<Vec<TaskKey>>>,

    // `&'a mut &'a ()` makes the type invariant in `'scope` (the same trick as
    // `std::thread::Scope`).
    scope: PhantomData<&'scope mut &'scope ()>,
    // Likewise invariant in `'env`.
    env: PhantomData<&'env mut &'env ()>,
}
444
impl<'scope> TaskScope<'scope, '_> {
    /// Cancels every task still registered in this scope and empties the
    /// registry. The lock guard lives for the whole `for` loop, so cancellation
    /// happens while `cancel_handles` is locked.
    fn destroy(&self) {
        for (_, t) in self.cancel_handles.lock().drain() {
            t.cancel();
        }
    }

    /// Removes the cancel handles of tasks reported as completed by
    /// `TaskMetadata::drop`, presumably to keep the registry from growing.
    /// Lock order: `cancel_handles`, then `completed_tasks`.
    fn clear_completed_tasks(&self) {
        let mut cancel_handles = self.cancel_handles.lock();
        for t in self.completed_tasks.lock().drain(..) {
            cancel_handles.remove(t);
        }
    }

    /// Spawns `fut` on the global executor as a task tied to this scope.
    ///
    /// Unlike [`spawn`], `fut` only has to live for `'scope`. The task's cancel
    /// handle is registered in this scope, so `task_scope` can cancel it before
    /// the scope ends.
    #[track_caller]
    pub fn spawn_task<F: Future + Send + 'scope>(
        &self,
        priority: TaskPriority,
        fut: F,
    ) -> JoinHandle<F::Output>
    where
        <F as Future>::Output: Send + 'static,
    {
        let spawn_location = Location::caller();
        self.clear_completed_tasks();

        let mut runnable = None;
        let mut join_handle = None;
        // The task is created inside `insert_with_key` so its metadata can
        // carry its own slot-map key.
        self.cancel_handles.lock().insert_with_key(|task_key| {
            let metrics = TRACK_METRICS.load().then(Arc::default);
            // NOTE(review): `spawn_with_lifetime` is unsafe, presumably because
            // `fut` may borrow `'scope` data. Soundness relies on `task_scope`
            // cancelling all registered tasks before those borrows end — confirm
            // against its safety contract.
            let dyn_task = unsafe {
                let executor = Executor::global();
                let on_wake = move |task| executor.schedule_task(task);
                task::spawn_with_lifetime(
                    fut,
                    on_wake,
                    TaskMetadata {
                        spawn_location,
                        priority,
                        freshly_spawned: AtomicBool::new(true),
                        scoped: Some(ScopedTaskMetadata {
                            task_key,
                            completed_tasks: Arc::downgrade(&self.completed_tasks),
                        }),
                        metrics,
                    },
                )
            };
            runnable = Some(Arc::clone(&dyn_task));
            let jh = JoinHandle(dyn_task);
            let cancel_handle = jh.cancel_handle();
            join_handle = Some(jh);
            cancel_handle
        });
        // Scheduled only after the `cancel_handles` guard (a temporary of the
        // statement above) has been released.
        runnable.unwrap().schedule();
        join_handle.unwrap()
    }
}
506
507pub fn task_scope<'env, F, T>(f: F) -> T
508where
509 F: for<'scope> FnOnce(&'scope TaskScope<'scope, 'env>) -> T,
510{
511 let scope = TaskScope {
515 cancel_handles: Mutex::default(),
516 completed_tasks: Arc::new(Mutex::default()),
517 scope: PhantomData,
518 env: PhantomData,
519 };
520
521 let result = std::panic::catch_unwind(AssertUnwindSafe(|| f(&scope)));
522
523 scope.destroy();
525
526 match result {
527 Err(e) => std::panic::resume_unwind(e),
528 Ok(result) => result,
529 }
530}
531
532#[track_caller]
533pub fn spawn<F: Future + Send + 'static>(priority: TaskPriority, fut: F) -> JoinHandle<F::Output>
534where
535 <F as Future>::Output: Send + 'static,
536{
537 let spawn_location = Location::caller();
538 let executor = Executor::global();
539 let on_wake = move |task| executor.schedule_task(task);
540 let metrics = TRACK_METRICS.load().then(Arc::default);
541 let dyn_task = task::spawn(
542 fut,
543 on_wake,
544 TaskMetadata {
545 spawn_location,
546 priority,
547 freshly_spawned: AtomicBool::new(true),
548 scoped: None,
549 metrics,
550 },
551 );
552 Arc::clone(&dyn_task).schedule();
553 JoinHandle(dyn_task)
554}
555
/// Runs the (presumably blocking) closure `f` on the current thread while the
/// executor keeps making progress.
///
/// If the current thread holds an executor thread identity:
/// - The identity is offered to a runner without identity for the duration of
///   `f`. That runner takes over this identity's local slot and queue. A new
///   runner is spawned if none exists.
/// - Afterwards the identity is reclaimed if nobody took it.
/// - Otherwise this thread continues without an identity and counts itself as
///   a runner without identity, so `Executor::runner` will try to acquire one.
///
/// On threads without an identity, `f` is simply called.
pub fn block_in_place<R, F: FnOnce() -> R>(f: F) -> R {
    let thread_id = TLS_THREAD_ID.replace(usize::MAX);
    if thread_id == usize::MAX {
        return f();
    }

    let executor = Executor::global();
    // Shared cell: whoever swaps the id out first (a runner, or us below) owns it.
    let msg = Arc::new(AtomicUsize::new(thread_id));
    // NOTE(review): send before checking the runner count; this order is
    // presumably what guarantees someone will pick the message up.
    executor.thread_id_send.send(msg.clone()).unwrap();
    executor.ensure_runner_without_identity_exists();
    // NOTE(review): presumably `WithDrop` runs this closure when dropped, so the
    // identity is restored even if `f` unwinds — confirm.
    let _restore_identity = WithDrop::new(msg, |msg| {
        let thread_id = msg.swap(usize::MAX, Ordering::AcqRel);
        if thread_id != usize::MAX {
            // Nobody claimed the identity: take it back. The now-stale message
            // stays in the channel and will be skipped by the receiver.
            TLS_THREAD_ID.set(thread_id);
        } else {
            // A runner took the identity; this thread is now identity-less.
            executor
                .num_runners_without_identity
                .fetch_add(1, Ordering::AcqRel);
        }
    });

    f()
}
587
588fn random_permutation<R: Rng>(len: u32, rng: &mut R) -> impl Iterator<Item = u32> {
589 let modulus = len.next_power_of_two();
590 let halfwidth = modulus.trailing_zeros() / 2;
591 let mask = modulus - 1;
592 let displace_zero = rng.random::<u32>();
593 let odd1 = rng.random::<u32>() | 1;
594 let odd2 = rng.random::<u32>() | 1;
595 let uniform_first = ((rng.random::<u32>() as u64 * len as u64) >> 32) as u32;
596
597 (0..modulus)
598 .map(move |mut i| {
599 i = i.wrapping_add(displace_zero);
601 i = i.wrapping_mul(odd1);
602 i ^= (i & mask) >> halfwidth;
603 i = i.wrapping_mul(odd2);
604 i & mask
605 })
606 .filter(move |i| *i < len)
607 .map(move |mut i| {
608 i += uniform_first;
609 if i >= len {
610 i -= len;
611 }
612 i
613 })
614}