dynamo_runtime/compute/thread_local.rs
1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Thread-local storage for compute resources
5//!
6//! This module provides thread-local access to compute resources (Rayon pool and semaphore)
7//! for Tokio worker threads. This eliminates the need to pass Runtime or ComputePool
8//! references through async function calls.
9
10use super::ComputePool;
11use std::cell::RefCell;
12use std::sync::Arc;
13use tokio::sync::Semaphore;
14
15thread_local! {
16 /// Thread-local compute context available on Tokio worker threads
17 static COMPUTE_CONTEXT: RefCell<Option<ComputeContext>> = const { RefCell::new(None) };
18}
19
20/// Compute resources available to a Tokio worker thread
21#[derive(Clone)]
22pub struct ComputeContext {
23 /// The Rayon compute pool
24 pub pool: Arc<ComputePool>,
25 /// Semaphore for block_in_place permits
26 pub block_in_place_permits: Arc<Semaphore>,
27}
28
29/// Initialize the thread-local compute context
30///
31/// This should be called from the Tokio runtime's `on_thread_start` callback
32pub fn initialize_context(pool: Arc<ComputePool>, permits: Arc<Semaphore>) {
33 COMPUTE_CONTEXT.with(|ctx| {
34 *ctx.borrow_mut() = Some(ComputeContext {
35 pool,
36 block_in_place_permits: permits,
37 });
38 });
39}
40
41/// Access the thread-local compute context
42///
43/// Returns None if called from a non-worker thread or if context wasn't initialized
44pub fn with_context<F, R>(f: F) -> Option<R>
45where
46 F: FnOnce(&ComputeContext) -> R,
47{
48 COMPUTE_CONTEXT.with(|ctx| ctx.borrow().as_ref().map(f))
49}
50
51/// Try to acquire a block_in_place permit from thread-local context
52///
53/// Returns Ok(permit) if successful, Err if no context or no permits available
54pub fn try_acquire_block_permit() -> Result<tokio::sync::OwnedSemaphorePermit, &'static str> {
55 with_context(|ctx| {
56 ctx.block_in_place_permits
57 .clone()
58 .try_acquire_owned()
59 .map_err(|_| "No permits available")
60 })
61 .ok_or("No compute context on this thread")?
62}
63
64/// Get the compute pool from thread-local context
65///
66/// Returns None if called from a non-worker thread
67pub fn get_pool() -> Option<Arc<ComputePool>> {
68 with_context(|ctx| ctx.pool.clone())
69}
70
71/// Check if the current thread has compute context initialized
72///
73/// Returns true if the thread-local context is initialized with a compute pool
74/// and semaphore permits, meaning the compute macros will offload work.
75/// Returns false if macros would fall back to inline execution.
76pub fn has_compute_context() -> bool {
77 with_context(|_| ()).is_some()
78}
79
80/// Assert that the current thread has compute context initialized
81///
82/// Panics if the thread-local context is not initialized.
83/// Use this to ensure compute macros will offload work rather than run inline.
84pub fn assert_compute_context() {
85 if !has_compute_context() {
86 panic!(
87 "Thread-local compute context not initialized! \
88 Compute macros will fall back to inline execution. \
89 Call Runtime::initialize_thread_local() on worker threads."
90 );
91 }
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    /// On a thread where `initialize_context` was never called, every
    /// accessor must report the missing context.
    #[test]
    fn test_uninitialized_context() {
        assert!(get_pool().is_none());
        assert!(with_context(|_| ()).is_none());
        assert!(!has_compute_context());

        // Pin the exact error so that a "no permits" failure cannot be
        // mistaken for the "no context" case this test is about.
        assert_eq!(
            try_acquire_block_permit().err(),
            Some("No compute context on this thread")
        );
    }

    /// `assert_compute_context` must panic when no context is installed.
    #[test]
    #[should_panic(expected = "Thread-local compute context not initialized")]
    fn test_assert_compute_context_panics() {
        assert_compute_context();
    }
}