Skip to main content

dynamo_runtime/compute/
thread_local.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Thread-local storage for compute resources
5//!
6//! This module provides thread-local access to compute resources (Rayon pool and semaphore)
7//! for Tokio worker threads. This eliminates the need to pass Runtime or ComputePool
8//! references through async function calls.
9
10use super::ComputePool;
11use std::cell::RefCell;
12use std::sync::Arc;
13use tokio::sync::Semaphore;
14
15thread_local! {
16    /// Thread-local compute context available on Tokio worker threads
17    static COMPUTE_CONTEXT: RefCell<Option<ComputeContext>> = const { RefCell::new(None) };
18}
19
20/// Compute resources available to a Tokio worker thread
21#[derive(Clone)]
22pub struct ComputeContext {
23    /// The Rayon compute pool
24    pub pool: Arc<ComputePool>,
25    /// Semaphore for block_in_place permits
26    pub block_in_place_permits: Arc<Semaphore>,
27}
28
29/// Initialize the thread-local compute context
30///
31/// This should be called from the Tokio runtime's `on_thread_start` callback
32pub fn initialize_context(pool: Arc<ComputePool>, permits: Arc<Semaphore>) {
33    COMPUTE_CONTEXT.with(|ctx| {
34        *ctx.borrow_mut() = Some(ComputeContext {
35            pool,
36            block_in_place_permits: permits,
37        });
38    });
39}
40
41/// Access the thread-local compute context
42///
43/// Returns None if called from a non-worker thread or if context wasn't initialized
44pub fn with_context<F, R>(f: F) -> Option<R>
45where
46    F: FnOnce(&ComputeContext) -> R,
47{
48    COMPUTE_CONTEXT.with(|ctx| ctx.borrow().as_ref().map(f))
49}
50
51/// Try to acquire a block_in_place permit from thread-local context
52///
53/// Returns Ok(permit) if successful, Err if no context or no permits available
54pub fn try_acquire_block_permit() -> Result<tokio::sync::OwnedSemaphorePermit, &'static str> {
55    with_context(|ctx| {
56        ctx.block_in_place_permits
57            .clone()
58            .try_acquire_owned()
59            .map_err(|_| "No permits available")
60    })
61    .ok_or("No compute context on this thread")?
62}
63
64/// Get the compute pool from thread-local context
65///
66/// Returns None if called from a non-worker thread
67pub fn get_pool() -> Option<Arc<ComputePool>> {
68    with_context(|ctx| ctx.pool.clone())
69}
70
71/// Check if the current thread has compute context initialized
72///
73/// Returns true if the thread-local context is initialized with a compute pool
74/// and semaphore permits, meaning the compute macros will offload work.
75/// Returns false if macros would fall back to inline execution.
76pub fn has_compute_context() -> bool {
77    with_context(|_| ()).is_some()
78}
79
80/// Assert that the current thread has compute context initialized
81///
82/// Panics if the thread-local context is not initialized.
83/// Use this to ensure compute macros will offload work rather than run inline.
84pub fn assert_compute_context() {
85    if !has_compute_context() {
86        panic!(
87            "Thread-local compute context not initialized! \
88             Compute macros will fall back to inline execution. \
89             Call Runtime::initialize_thread_local() on worker threads."
90        );
91    }
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_uninitialized_context() {
        // A plain test thread never runs `initialize_context`, so every
        // accessor must report that no context is present.
        assert!(!has_compute_context());
        assert!(get_pool().is_none());
        assert!(with_context(|_| ()).is_none());
        // Pin the specific error so a missing context cannot be confused with
        // permit exhaustion or a closed semaphore.
        assert!(matches!(
            try_acquire_block_permit(),
            Err("No compute context on this thread")
        ));
    }

    #[test]
    #[should_panic(expected = "Thread-local compute context not initialized")]
    fn test_assert_compute_context_panics() {
        // Should panic when context not initialized
        assert_compute_context();
    }
}