irontide_engine/
blocking_spawner.rs1use std::sync::Arc;
5
6use tokio::sync::Semaphore;
7
8#[derive(Clone, Debug)]
14pub struct BlockingSpawner {
15 allow_block_in_place: bool,
16 semaphore: Arc<Semaphore>,
17}
18
19impl BlockingSpawner {
20 #[must_use]
25 pub fn new(max_blocking: usize) -> Self {
26 let flavor = tokio::runtime::Handle::current().runtime_flavor();
27 let allow_block_in_place = matches!(flavor, tokio::runtime::RuntimeFlavor::MultiThread);
28
29 Self {
30 allow_block_in_place,
31 semaphore: Arc::new(Semaphore::new(max_blocking)),
32 }
33 }
34
35 pub(crate) async fn block_in_place<F, R>(&self, f: F) -> R
40 where
41 F: FnOnce() -> R,
42 {
43 let _permit = self
45 .semaphore
46 .acquire()
47 .await
48 .expect("BlockingSpawner semaphore closed");
49
50 if self.allow_block_in_place {
51 tokio::task::block_in_place(f)
52 } else {
53 f()
54 }
55 }
56
57 pub(crate) fn block_in_place_sync<F, R>(&self, f: F) -> R
62 where
63 F: FnOnce() -> R,
64 {
65 if self.allow_block_in_place {
66 tokio::task::block_in_place(f)
67 } else {
68 f()
69 }
70 }
71}
72
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn blocking_spawner_limits_concurrency() {
        let spawner = BlockingSpawner::new(2);
        let concurrent = Arc::new(AtomicUsize::new(0));
        let max_observed = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for _ in 0..4 {
            let s = spawner.clone();
            let c = Arc::clone(&concurrent);
            let m = Arc::clone(&max_observed);
            handles.push(tokio::spawn(async move {
                s.block_in_place(|| {
                    let current = c.fetch_add(1, Ordering::SeqCst) + 1;
                    m.fetch_max(current, Ordering::SeqCst);
                    std::thread::sleep(Duration::from_millis(50));
                    c.fetch_sub(1, Ordering::SeqCst);
                })
                .await;
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let max = max_observed.load(Ordering::SeqCst);
        assert!(
            max <= 2,
            "expected at most 2 concurrent ops, observed {max}"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn blocking_spawner_semaphore_backpressure() {
        let spawner = BlockingSpawner::new(1);
        let order = Arc::new(parking_lot::Mutex::new(Vec::new()));

        // Signalled from inside the first closure, i.e. while its permit is
        // held. Waiting on this handshake (rather than sleeping for a fixed
        // time) guarantees the second task only starts contending for the
        // single permit after the first task already owns it.
        let (started_tx, started_rx) = tokio::sync::oneshot::channel::<()>();

        let s1 = spawner.clone();
        let o1 = Arc::clone(&order);
        let h1 = tokio::spawn(async move {
            s1.block_in_place(move || {
                o1.lock().push("first-start");
                started_tx
                    .send(())
                    .expect("start signal receiver dropped");
                std::thread::sleep(Duration::from_millis(80));
                o1.lock().push("first-end");
            })
            .await;
        });

        started_rx
            .await
            .expect("first task finished without signalling start");

        let s2 = spawner.clone();
        let o2 = Arc::clone(&order);
        let h2 = tokio::spawn(async move {
            s2.block_in_place(|| {
                o2.lock().push("second-start");
            })
            .await;
        });

        h1.await.unwrap();
        h2.await.unwrap();

        let log = order.lock();
        let first_end = log.iter().position(|s| *s == "first-end").unwrap();
        let second_start = log.iter().position(|s| *s == "second-start").unwrap();
        assert!(
            first_end < second_start,
            "expected first-end before second-start, got: {log:?}"
        );
    }

    #[test]
    fn blocking_spawner_single_threaded_runtime() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        rt.block_on(async {
            let spawner = BlockingSpawner::new(2);
            let result = spawner.block_in_place(|| 42).await;
            assert_eq!(result, 42);

            let sync_result = spawner.block_in_place_sync(|| 99);
            assert_eq!(sync_result, 99);
        });
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn blocking_spawner_sync_on_multi_thread_runtime() {
        // Exercises the `tokio::task::block_in_place` branch of the sync path.
        let spawner = BlockingSpawner::new(1);
        assert_eq!(spawner.block_in_place_sync(|| 7), 7);
    }
}