Skip to main content

irontide_engine/
blocking_spawner.rs

1//! Bounded blocking-task spawner that uses `block_in_place` instead of
2//! `spawn_blocking`, eliminating per-call `JoinHandle` heap allocations.
3
4use std::sync::Arc;
5
6use tokio::sync::Semaphore;
7
8/// Executes blocking closures on the current thread (via [`tokio::task::block_in_place`])
9/// with bounded concurrency from a semaphore.
10///
11/// On a `CurrentThread` runtime, `block_in_place` panics, so the spawner falls
12/// back to calling the closure directly.
13#[derive(Clone, Debug)]
14pub struct BlockingSpawner {
15    allow_block_in_place: bool,
16    semaphore: Arc<Semaphore>,
17}
18
19impl BlockingSpawner {
20    /// Create a new spawner with the given concurrency limit.
21    ///
22    /// Detects the current tokio runtime flavor to decide whether
23    /// `block_in_place` is safe.
24    #[must_use]
25    pub fn new(max_blocking: usize) -> Self {
26        let flavor = tokio::runtime::Handle::current().runtime_flavor();
27        let allow_block_in_place = matches!(flavor, tokio::runtime::RuntimeFlavor::MultiThread);
28
29        Self {
30            allow_block_in_place,
31            semaphore: Arc::new(Semaphore::new(max_blocking)),
32        }
33    }
34
35    /// Run a blocking closure, waiting for a semaphore permit first.
36    ///
37    /// On multi-thread runtimes this uses `block_in_place`; on single-thread
38    /// runtimes it calls `f` directly.
39    pub(crate) async fn block_in_place<F, R>(&self, f: F) -> R
40    where
41        F: FnOnce() -> R,
42    {
43        // acquire_owned so the permit lives across the blocking call
44        let _permit = self
45            .semaphore
46            .acquire()
47            .await
48            .expect("BlockingSpawner semaphore closed");
49
50        if self.allow_block_in_place {
51            tokio::task::block_in_place(f)
52        } else {
53            f()
54        }
55    }
56
57    /// Synchronous variant for non-async contexts (e.g. deferred write paths).
58    ///
59    /// Does **not** acquire the semaphore — intended for fallback paths where
60    /// blocking is unavoidable and already bounded by the caller.
61    pub(crate) fn block_in_place_sync<F, R>(&self, f: F) -> R
62    where
63        F: FnOnce() -> R,
64    {
65        if self.allow_block_in_place {
66            tokio::task::block_in_place(f)
67        } else {
68            f()
69        }
70    }
71}
72
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// With a limit of 2, four concurrently spawned tasks must never have more
    /// than two closures running at the same time.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn blocking_spawner_limits_concurrency() {
        let spawner = BlockingSpawner::new(2);
        let concurrent = Arc::new(AtomicUsize::new(0));
        let max_observed = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for _ in 0..4 {
            let s = spawner.clone();
            let c = Arc::clone(&concurrent);
            let m = Arc::clone(&max_observed);
            handles.push(tokio::spawn(async move {
                s.block_in_place(|| {
                    let prev = c.fetch_add(1, Ordering::SeqCst);
                    // Update max observed concurrency
                    let current = prev + 1;
                    m.fetch_max(current, Ordering::SeqCst);
                    std::thread::sleep(Duration::from_millis(50));
                    c.fetch_sub(1, Ordering::SeqCst);
                })
                .await;
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let max = max_observed.load(Ordering::SeqCst);
        assert!(
            max <= 2,
            "expected at most 2 concurrent ops, observed {max}"
        );
        // Sanity: the closures actually ran and all of them finished.
        assert!(max >= 1, "expected at least one op to run, observed {max}");
        assert_eq!(concurrent.load(Ordering::SeqCst), 0);
    }

    /// With a limit of 1, a second call must not start its closure until the
    /// first call's closure has finished and released the permit.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn blocking_spawner_semaphore_backpressure() {
        let spawner = BlockingSpawner::new(1);
        let order = Arc::new(parking_lot::Mutex::new(Vec::new()));
        // Fired from inside the first closure, i.e. while it holds the only
        // permit. This replaces a fixed sleep, which raced with scheduling and
        // could let the second task grab the permit first.
        let (started_tx, started_rx) = tokio::sync::oneshot::channel();

        let s1 = spawner.clone();
        let o1 = Arc::clone(&order);
        let h1 = tokio::spawn(async move {
            s1.block_in_place(move || {
                o1.lock().push("first-start");
                let _ = started_tx.send(());
                std::thread::sleep(Duration::from_millis(80));
                o1.lock().push("first-end");
            })
            .await;
        });

        // Do not spawn the second task until the first one provably holds the permit.
        started_rx
            .await
            .expect("first closure never signalled that it started");

        let s2 = spawner.clone();
        let o2 = Arc::clone(&order);
        let h2 = tokio::spawn(async move {
            s2.block_in_place(|| {
                o2.lock().push("second-start");
            })
            .await;
        });

        h1.await.unwrap();
        h2.await.unwrap();

        let log = order.lock();
        // first-end must come before second-start (serialized by semaphore)
        let first_end = log.iter().position(|s| *s == "first-end").unwrap();
        let second_start = log.iter().position(|s| *s == "second-start").unwrap();
        assert!(
            first_end < second_start,
            "expected first-end before second-start, got: {log:?}"
        );
    }

    /// On a current-thread runtime both variants must run the closure inline
    /// instead of panicking inside `tokio::task::block_in_place`.
    #[test]
    fn blocking_spawner_single_threaded_runtime() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        rt.block_on(async {
            let spawner = BlockingSpawner::new(2);
            // Must not panic on CurrentThread runtime
            let result = spawner.block_in_place(|| 42).await;
            assert_eq!(result, 42);

            // Sync variant also works
            let sync_result = spawner.block_in_place_sync(|| 99);
            assert_eq!(sync_result, 99);
        });
    }
}