// polars_async/executor/mod.rs

1#![allow(clippy::disallowed_types)]
2
3mod park_group;
4mod task;
5
6use std::cell::{Cell, UnsafeCell};
7use std::future::Future;
8use std::marker::PhantomData;
9use std::panic::{AssertUnwindSafe, Location};
10use std::pin::Pin;
11use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
12use std::sync::{Arc, OnceLock, Weak};
13use std::task::{Context, Poll};
14use std::time::{Duration, Instant};
15
16use crossbeam_channel::{Receiver, Sender};
17use crossbeam_deque::{Injector, Steal, Stealer, Worker as WorkQueue};
18use crossbeam_utils::CachePadded;
19use park_group::ParkGroup;
20use parking_lot::Mutex;
21use polars_utils::relaxed_cell::RelaxedCell;
22use polars_utils::with_drop::WithDrop;
23use rand::rngs::SmallRng;
24use rand::{Rng, SeedableRng};
25use slotmap::SlotMap;
26use task::{Cancellable, DynTask, Runnable};
27
thread_local! {
    /// Index of the executor task list this thread currently owns, or
    /// `usize::MAX` when the thread holds no executor identity.
    static TLS_THREAD_ID: Cell<usize> = const { Cell::new(usize::MAX) };

    /// Defaults to `true`; executor runner threads set it to `false` on start.
    pub static ALLOW_RAYON_THREADS: Cell<bool> = const { Cell::new(true) };
    /// Defaults to `false`; executor runner threads set it to `true` on start.
    pub static THREAD_SPAWNED_BY_POLARS_EXECUTOR: Cell<bool> = const { Cell::new(false) };
}

/// Reports whether the calling thread currently holds an executor identity,
/// i.e. is actively used for scheduling tasks.
pub fn is_scheduling_polars_executor_thread() -> bool {
    let owned_slot = TLS_THREAD_ID.get();
    owned_slot < usize::MAX
}
40
/// Global switch read on every spawn to decide whether the new task gets a
/// [`TaskMetrics`] instance attached.
static TRACK_METRICS: RelaxedCell<bool> = RelaxedCell::new_bool(false);

/// Enables or disables metrics collection for tasks spawned after this call.
///
/// Tasks that were already spawned keep the metrics setting they were created
/// with, since the decision is stored in their metadata at spawn time.
pub fn track_task_metrics(should_track: bool) {
    TRACK_METRICS.store(should_track);
}

/// Lazily-initialized process-wide executor; created by `Executor::global`.
static GLOBAL_SCHEDULER: OnceLock<Executor> = OnceLock::new();

// Key type identifying a task in a `TaskScope`'s cancel-handle map.
slotmap::new_key_type! {
    struct TaskKey;
}
52
/// High priority tasks are scheduled preferentially over low priority tasks.
///
/// The derived `Ord` follows declaration order, so `Low < High`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum TaskPriority {
    Low,
    High,
}

/// Metadata that lets a task spawned through [`TaskScope::spawn_task`] report
/// its completion back to the owning scope, so the scope can clean it up.
struct ScopedTaskMetadata {
    /// Key of this task's entry in the owning scope's `cancel_handles` map.
    task_key: TaskKey,
    /// The owning scope's list of finished task keys. It is held weakly so the
    /// task does not keep the list alive; the key is pushed into it when the
    /// `TaskMetadata` is dropped, provided the list still exists.
    completed_tasks: Weak<Mutex<Vec<TaskKey>>>,
}

/// Per-task execution statistics, attached only when metrics tracking was
/// enabled at spawn time (see [`track_task_metrics`]).
///
/// NOTE(review): the 128-byte alignment presumably keeps each instance on its
/// own cache line(s) to avoid false sharing between tasks — confirm.
#[derive(Default)]
#[repr(align(128))]
pub struct TaskMetrics {
    /// Number of times an executor thread ran the task.
    pub total_polls: RelaxedCell<u64>,
    /// Number of those runs where the task was not taken from the running
    /// thread's own LIFO slot or local high-priority queue.
    pub total_stolen_polls: RelaxedCell<u64>,
    /// Sum of the elapsed time (measured with `Instant`) of all runs, in ns.
    pub total_poll_time_ns: RelaxedCell<u64>,
    /// Longest single run, in ns.
    pub max_poll_time_ns: RelaxedCell<u64>,
    /// Set to `true` when the task's metadata is dropped.
    pub done: RelaxedCell<bool>,
}

/// Metadata stored alongside every task spawned on the executor.
struct TaskMetadata {
    /// Where the task was spawned (captured through `#[track_caller]`).
    spawn_location: &'static Location<'static>,
    priority: TaskPriority,
    /// `true` until the task is first scheduled. A freshly spawned task is
    /// always pushed to a global queue rather than a thread-local one.
    freshly_spawned: AtomicBool,
    /// `Some` for tasks spawned within a [`TaskScope`].
    scoped: Option<ScopedTaskMetadata>,
    /// `Some` if metrics tracking was enabled when the task was spawned.
    metrics: Option<Arc<TaskMetrics>>,
}
83
84impl Drop for TaskMetadata {
85    fn drop(&mut self) {
86        if let Some(metrics) = self.metrics.as_ref() {
87            metrics.done.store(true);
88        }
89
90        if let Some(scoped) = &self.scoped {
91            if let Some(completed_tasks) = scoped.completed_tasks.upgrade() {
92                completed_tasks.lock().push(scoped.task_key);
93            }
94        }
95    }
96}
97
98pub struct JoinHandle<T>(Arc<dyn DynTask<T, TaskMetadata>>);
99pub struct CancelHandle(Weak<dyn Cancellable>);
100
101impl<T> JoinHandle<T> {
102    pub fn metrics(&self) -> Option<&Arc<TaskMetrics>> {
103        self.0.metadata().metrics.as_ref()
104    }
105
106    #[allow(unused)]
107    pub fn spawn_location(&self) -> &'static Location<'static> {
108        self.0.metadata().spawn_location
109    }
110
111    pub fn cancel_handle(&self) -> CancelHandle {
112        let coerce: Weak<dyn DynTask<T, TaskMetadata>> = Arc::downgrade(&self.0);
113        CancelHandle(coerce)
114    }
115}
116
117impl<T> Future for JoinHandle<T> {
118    type Output = T;
119
120    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
121        self.0.poll_join(ctx)
122    }
123}
124
125impl CancelHandle {
126    pub fn cancel(&self) {
127        if let Some(t) = self.0.upgrade() {
128            t.cancel();
129        }
130    }
131}
132
133pub struct AbortOnDropHandle<T> {
134    join_handle: JoinHandle<T>,
135    cancel_handle: CancelHandle,
136}
137
138impl<T> AbortOnDropHandle<T> {
139    pub fn new(join_handle: JoinHandle<T>) -> Self {
140        let cancel_handle = join_handle.cancel_handle();
141        Self {
142            join_handle,
143            cancel_handle,
144        }
145    }
146}
147
148impl<T> Future for AbortOnDropHandle<T> {
149    type Output = T;
150
151    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
152        Pin::new(&mut self.join_handle).poll(cx)
153    }
154}
155
156impl<T> Drop for AbortOnDropHandle<T> {
157    fn drop(&mut self) {
158        self.cancel_handle.cancel();
159    }
160}
161
/// A task ready to run.
type ReadyTask = Arc<dyn Runnable<TaskMetadata>>;

/// A per-thread task list, indexed by executor thread identity (`TLS_THREAD_ID`).
struct ThreadLocalTaskList {
    // May be used from any thread: other runners batch-steal through it.
    high_prio_tasks_stealer: Stealer<ReadyTask>,

    // SAFETY: these may only be used on the thread this task list belongs to.
    high_prio_tasks: WorkQueue<ReadyTask>,
    // Single-task LIFO slot that is checked before `high_prio_tasks`.
    local_slot: UnsafeCell<Option<ReadyTask>>,
}

// SAFETY: `high_prio_tasks_stealer` may be shared across threads. The other two
// fields are only accessed through `Executor::schedule_task`,
// `Executor::try_steal_task` and `Executor::runner`, each of which looks the
// list up using the calling thread's own TLS_THREAD_ID. This relies on the
// invariant that an identity is held by at most one thread at a time. That
// invariant is upheld by the handoff in `block_in_place`, which resets
// TLS_THREAD_ID first, and by the atomic swap in `acquire_thread_identity`.
unsafe impl Sync for ThreadLocalTaskList {}
176
/// The global work-stealing task executor.
struct Executor {
    /// Parks idle runner threads and wakes one when work is pushed to a shared queue.
    park_group: ParkGroup,
    /// One task list per executor thread identity, indexed by `TLS_THREAD_ID`.
    thread_task_lists: Vec<CachePadded<ThreadLocalTaskList>>,
    /// Shared queues for tasks that are freshly spawned or scheduled from a thread
    /// without an identity, plus (low priority only) tasks not kept locally.
    global_high_prio_task_queue: Injector<ReadyTask>,
    global_low_prio_task_queue: Injector<ReadyTask>,
    /// Channel through which `block_in_place` hands off a thread identity. The
    /// identity is claimed by swapping the atomic to `usize::MAX`; whichever side
    /// swaps first (a runner without identity or the original owner) gets it.
    thread_id_send: Sender<Arc<AtomicUsize>>,
    thread_id_recv: Receiver<Arc<AtomicUsize>>,
    /// Next numeric suffix for `async-executor-{n}` thread names; starts at the
    /// number of initial executor threads.
    thread_name_idx: AtomicUsize,
    /// Number of runners that hold no identity and are (or are about to start)
    /// waiting in `acquire_thread_identity`.
    num_runners_without_identity: AtomicUsize,
}
187
188impl Executor {
189    fn schedule_task(&self, task: ReadyTask) {
190        let thread = TLS_THREAD_ID.get();
191        let meta = task.metadata();
192        let opt_ttl = self.thread_task_lists.get(thread);
193
194        let mut use_global_queue = opt_ttl.is_none();
195        if meta.freshly_spawned.load(Ordering::Relaxed) {
196            use_global_queue = true;
197            meta.freshly_spawned.store(false, Ordering::Relaxed);
198        }
199
200        if use_global_queue {
201            // Scheduled from an unknown thread, add to global queue.
202            if meta.priority == TaskPriority::High {
203                self.global_high_prio_task_queue.push(task);
204            } else {
205                self.global_low_prio_task_queue.push(task);
206            }
207            self.park_group.unpark_one();
208        } else {
209            let ttl = opt_ttl.unwrap();
210            // SAFETY: this slot may only be accessed from the local thread, which we are.
211            let slot = unsafe { &mut *ttl.local_slot.get() };
212
213            if meta.priority == TaskPriority::High {
214                // Insert new task into thread local slot, taking out the old task.
215                let Some(task) = slot.replace(task) else {
216                    // We pushed a task into our local slot which was empty. Since
217                    // we are already awake, no need to notify anyone.
218                    return;
219                };
220
221                ttl.high_prio_tasks.push(task);
222                self.park_group.unpark_one();
223            } else {
224                // Optimization: while this is a low priority task we have no
225                // high priority tasks on this thread so we'll execute this one.
226                if ttl.high_prio_tasks.is_empty() && slot.is_none() {
227                    *slot = Some(task);
228                } else {
229                    self.global_low_prio_task_queue.push(task);
230                    self.park_group.unpark_one();
231                }
232            }
233        }
234    }
235
236    fn try_steal_task<R: Rng>(&self, thread: usize, rng: &mut R) -> Option<ReadyTask> {
237        // Try to get a global task.
238        loop {
239            match self.global_high_prio_task_queue.steal() {
240                Steal::Empty => break,
241                Steal::Success(task) => return Some(task),
242                Steal::Retry => std::hint::spin_loop(),
243            }
244        }
245
246        loop {
247            match self.global_low_prio_task_queue.steal() {
248                Steal::Empty => break,
249                Steal::Success(task) => return Some(task),
250                Steal::Retry => std::hint::spin_loop(),
251            }
252        }
253
254        // Try to steal tasks.
255        let ttl = &self.thread_task_lists[thread];
256        for _ in 0..4 {
257            let mut retry = true;
258            while retry {
259                retry = false;
260
261                for idx in random_permutation(self.thread_task_lists.len() as u32, rng) {
262                    let foreign_ttl = &self.thread_task_lists[idx as usize];
263                    match foreign_ttl
264                        .high_prio_tasks_stealer
265                        .steal_batch_and_pop(&ttl.high_prio_tasks)
266                    {
267                        Steal::Empty => {},
268                        Steal::Success(task) => return Some(task),
269                        Steal::Retry => retry = true,
270                    }
271                }
272
273                std::hint::spin_loop()
274            }
275        }
276
277        None
278    }
279
280    fn runner(&self, initial_thread_id: Option<usize>) {
281        TLS_THREAD_ID.set(initial_thread_id.unwrap_or(usize::MAX));
282        ALLOW_RAYON_THREADS.set(false);
283        THREAD_SPAWNED_BY_POLARS_EXECUTOR.set(true);
284
285        let mut rng = SmallRng::from_rng(&mut rand::rng());
286        let mut worker = self.park_group.new_worker();
287
288        loop {
289            // If we're a runner without an assigned thread id, get one.
290            let mut thread_id = TLS_THREAD_ID.get();
291            if thread_id == usize::MAX {
292                if let Some(tid) = self.acquire_thread_identity() {
293                    TLS_THREAD_ID.set(tid);
294                    thread_id = tid;
295                } else {
296                    return;
297                }
298            }
299
300            let ttl = &self.thread_task_lists[thread_id];
301            let mut local = true;
302            let task = (|| {
303                // Try to get a task from LIFO slot.
304                if let Some(task) = unsafe { (*ttl.local_slot.get()).take() } {
305                    return Some(task);
306                }
307
308                // Try to get a local high-priority task.
309                if let Some(task) = ttl.high_prio_tasks.pop() {
310                    return Some(task);
311                }
312
313                // Try to steal a task.
314                local = false;
315                if let Some(task) = self.try_steal_task(thread_id, &mut rng) {
316                    return Some(task);
317                }
318
319                // Prepare to park, then try one more steal attempt.
320                let park = worker.prepare_park();
321                if let Some(task) = self.try_steal_task(thread_id, &mut rng) {
322                    return Some(task);
323                }
324
325                park.park();
326                None
327            })();
328
329            if let Some(task) = task {
330                worker.recruit_next();
331                if let Some(metrics) = task.metadata().metrics.clone() {
332                    let start = Instant::now();
333                    task.run();
334                    let elapsed_ns = start.elapsed().as_nanos() as u64;
335                    metrics.total_polls.fetch_add(1);
336                    if !local {
337                        metrics.total_stolen_polls.fetch_add(1);
338                    }
339                    metrics.total_poll_time_ns.fetch_add(elapsed_ns);
340                    metrics.max_poll_time_ns.fetch_max(elapsed_ns);
341                } else {
342                    task.run();
343                }
344            }
345        }
346    }
347
348    fn spawn_runner_without_identity(&self) {
349        self.num_runners_without_identity
350            .fetch_add(1, Ordering::AcqRel);
351        let t = self.thread_name_idx.fetch_add(1, Ordering::Relaxed);
352        std::thread::Builder::new()
353            .name(format!("async-executor-{t}"))
354            .spawn(move || Self::global().runner(None))
355            .unwrap();
356    }
357
358    fn acquire_thread_identity(&self) -> Option<usize> {
359        loop {
360            match self.thread_id_recv.recv_timeout(Duration::from_secs(10)) {
361                Ok(tid_msg) => {
362                    let thread_id = tid_msg.swap(usize::MAX, Ordering::AcqRel);
363                    if thread_id != usize::MAX {
364                        // Important: we check queue again after reducing count.
365                        let num_left = self
366                            .num_runners_without_identity
367                            .fetch_sub(1, Ordering::AcqRel)
368                            - 1;
369                        if num_left == 0 && !self.thread_id_recv.is_empty() {
370                            self.spawn_runner_without_identity();
371                        }
372                        return Some(thread_id);
373                    }
374                },
375                Err(_) => {
376                    // Important: we check queue again after reducing count.
377                    self.num_runners_without_identity
378                        .fetch_sub(1, Ordering::AcqRel);
379                    if self.thread_id_recv.is_empty() {
380                        return None;
381                    }
382                    self.num_runners_without_identity
383                        .fetch_add(1, Ordering::AcqRel);
384                },
385            }
386        }
387    }
388
389    fn ensure_runner_without_identity_exists(&self) {
390        if self
391            .num_runners_without_identity
392            .fetch_add(0, Ordering::AcqRel)
393            == 0
394        {
395            self.spawn_runner_without_identity();
396        }
397    }
398
399    fn global() -> &'static Executor {
400        GLOBAL_SCHEDULER.get_or_init(|| {
401            let n_threads = polars_config::config().max_threads();
402            let thread_task_lists = (0..n_threads)
403                .map(|t| {
404                    std::thread::Builder::new()
405                        .name(format!("async-executor-{t}"))
406                        .spawn(move || Self::global().runner(Some(t)))
407                        .unwrap();
408
409                    let high_prio_tasks = WorkQueue::new_lifo();
410                    CachePadded::new(ThreadLocalTaskList {
411                        high_prio_tasks_stealer: high_prio_tasks.stealer(),
412                        high_prio_tasks,
413                        local_slot: UnsafeCell::new(None),
414                    })
415                })
416                .collect();
417            let (thread_id_send, thread_id_recv) = crossbeam_channel::unbounded();
418            Self {
419                park_group: ParkGroup::new(),
420                thread_task_lists,
421                global_high_prio_task_queue: Injector::new(),
422                global_low_prio_task_queue: Injector::new(),
423                thread_id_send,
424                thread_id_recv,
425                thread_name_idx: AtomicUsize::new(n_threads),
426                num_runners_without_identity: AtomicUsize::new(0),
427            }
428        })
429    }
430}
431
/// Scope handed to the closure passed to [`task_scope`]. Futures spawned
/// through it only need to live for `'scope`. Every task still tracked when the
/// scope ends is cancelled.
pub struct TaskScope<'scope, 'env: 'scope> {
    // Keep track of in-progress tasks so we can forcibly cancel them
    // when the scope ends, to ensure the lifetimes are respected.
    // Tasks push their own key into completed_tasks when their TaskMetadata
    // is dropped, so we can reclaim the memory used by the cancel_handles.
    cancel_handles: Mutex<SlotMap<TaskKey, CancelHandle>>,
    completed_tasks: Arc<Mutex<Vec<TaskKey>>>,

    // Copied from std::thread::scope. Necessary to prevent unsoundness.
    scope: PhantomData<&'scope mut &'scope ()>,
    env: PhantomData<&'env mut &'env ()>,
}
444
445impl<'scope> TaskScope<'scope, '_> {
446    // Not Drop because that extends lifetimes.
447    fn destroy(&self) {
448        // Make sure all tasks are cancelled.
449        for (_, t) in self.cancel_handles.lock().drain() {
450            t.cancel();
451        }
452    }
453
454    fn clear_completed_tasks(&self) {
455        let mut cancel_handles = self.cancel_handles.lock();
456        for t in self.completed_tasks.lock().drain(..) {
457            cancel_handles.remove(t);
458        }
459    }
460
461    #[track_caller]
462    pub fn spawn_task<F: Future + Send + 'scope>(
463        &self,
464        priority: TaskPriority,
465        fut: F,
466    ) -> JoinHandle<F::Output>
467    where
468        <F as Future>::Output: Send + 'static,
469    {
470        let spawn_location = Location::caller();
471        self.clear_completed_tasks();
472
473        let mut runnable = None;
474        let mut join_handle = None;
475        self.cancel_handles.lock().insert_with_key(|task_key| {
476            let metrics = TRACK_METRICS.load().then(Arc::default);
477            let dyn_task = unsafe {
478                // SAFETY: we make sure to cancel this task before 'scope ends.
479                let executor = Executor::global();
480                let on_wake = move |task| executor.schedule_task(task);
481                task::spawn_with_lifetime(
482                    fut,
483                    on_wake,
484                    TaskMetadata {
485                        spawn_location,
486                        priority,
487                        freshly_spawned: AtomicBool::new(true),
488                        scoped: Some(ScopedTaskMetadata {
489                            task_key,
490                            completed_tasks: Arc::downgrade(&self.completed_tasks),
491                        }),
492                        metrics,
493                    },
494                )
495            };
496            runnable = Some(Arc::clone(&dyn_task));
497            let jh = JoinHandle(dyn_task);
498            let cancel_handle = jh.cancel_handle();
499            join_handle = Some(jh);
500            cancel_handle
501        });
502        runnable.unwrap().schedule();
503        join_handle.unwrap()
504    }
505}
506
507pub fn task_scope<'env, F, T>(f: F) -> T
508where
509    F: for<'scope> FnOnce(&'scope TaskScope<'scope, 'env>) -> T,
510{
511    // By having this local variable inaccessible to anyone we guarantee
512    // that either abort is called killing the entire process, or that this
513    // executor is properly destroyed.
514    let scope = TaskScope {
515        cancel_handles: Mutex::default(),
516        completed_tasks: Arc::new(Mutex::default()),
517        scope: PhantomData,
518        env: PhantomData,
519    };
520
521    let result = std::panic::catch_unwind(AssertUnwindSafe(|| f(&scope)));
522
523    // Make sure all tasks are properly destroyed.
524    scope.destroy();
525
526    match result {
527        Err(e) => std::panic::resume_unwind(e),
528        Ok(result) => result,
529    }
530}
531
532#[track_caller]
533pub fn spawn<F: Future + Send + 'static>(priority: TaskPriority, fut: F) -> JoinHandle<F::Output>
534where
535    <F as Future>::Output: Send + 'static,
536{
537    let spawn_location = Location::caller();
538    let executor = Executor::global();
539    let on_wake = move |task| executor.schedule_task(task);
540    let metrics = TRACK_METRICS.load().then(Arc::default);
541    let dyn_task = task::spawn(
542        fut,
543        on_wake,
544        TaskMetadata {
545            spawn_location,
546            priority,
547            freshly_spawned: AtomicBool::new(true),
548            scoped: None,
549            metrics,
550        },
551    );
552    Arc::clone(&dyn_task).schedule();
553    JoinHandle(dyn_task)
554}
555
556/// Runs the given function on this thread while allowing another thread to take
557/// over this thread's task execution duties.
558///
559/// Simply directly calls f() if this thread is not an async executor thread.
560pub fn block_in_place<R, F: FnOnce() -> R>(f: F) -> R {
561    let thread_id = TLS_THREAD_ID.replace(usize::MAX);
562    if thread_id == usize::MAX {
563        return f();
564    }
565
566    // Send off our thread id to another runner, we just become an ordinary thread.
567    let executor = Executor::global();
568    let msg = Arc::new(AtomicUsize::new(thread_id));
569    executor.thread_id_send.send(msg.clone()).unwrap();
570    executor.ensure_runner_without_identity_exists(); // Important: *after* sending in channel.
571
572    // Try to steal our thread id back afterwards, even if f panics. If we can't
573    // steal our thread id back we become a runner without identity.
574    let _restore_identity = WithDrop::new(msg, |msg| {
575        let thread_id = msg.swap(usize::MAX, Ordering::AcqRel);
576        if thread_id != usize::MAX {
577            TLS_THREAD_ID.set(thread_id);
578        } else {
579            executor
580                .num_runners_without_identity
581                .fetch_add(1, Ordering::AcqRel);
582        }
583    });
584
585    f()
586}
587
588fn random_permutation<R: Rng>(len: u32, rng: &mut R) -> impl Iterator<Item = u32> {
589    let modulus = len.next_power_of_two();
590    let halfwidth = modulus.trailing_zeros() / 2;
591    let mask = modulus - 1;
592    let displace_zero = rng.random::<u32>();
593    let odd1 = rng.random::<u32>() | 1;
594    let odd2 = rng.random::<u32>() | 1;
595    let uniform_first = ((rng.random::<u32>() as u64 * len as u64) >> 32) as u32;
596
597    (0..modulus)
598        .map(move |mut i| {
599            // Invertible permutation on [0, modulus).
600            i = i.wrapping_add(displace_zero);
601            i = i.wrapping_mul(odd1);
602            i ^= (i & mask) >> halfwidth;
603            i = i.wrapping_mul(odd2);
604            i & mask
605        })
606        .filter(move |i| *i < len)
607        .map(move |mut i| {
608            i += uniform_first;
609            if i >= len {
610                i -= len;
611            }
612            i
613        })
614}