1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![doc = include_str!("../README.md")]
3
4use std::thread;
5
/// A unit of work shipped to the worker thread: a boxed closure that is given
/// `&mut` access to the resource. The `'static` bound is spelled out here, but it
/// is the same default object lifetime that `Box<dyn ...>` implies anyway.
type Run<T> = Box<dyn FnOnce(&mut T) + Send + 'static>;
8
9enum ThreadCellMessage<T> {
11 Run(Run<T>),
12 GetSessionSync(crossbeam::channel::Sender<ThreadCellSession<T>>),
13 #[cfg(feature = "tokio")]
14 #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
15 GetSessionAsync(tokio::sync::oneshot::Sender<ThreadCellSession<T>>),
16}
17
/// Panic message used when a `ThreadCellSession` can no longer reach the worker
/// thread serving it (e.g. because a closure run on that thread panicked).
const SESSION_ERROR_MESSAGE: &str = "ThreadCell thread has panicked or was dropped";
19
20pub struct ThreadCellSession<T> {
24 sender: crossbeam::channel::Sender<Run<T>>,
25}
26
27impl<T> ThreadCellSession<T> {
28 pub fn run_blocking<F, R>(&self, f: F) -> R
29 where
30 F: FnOnce(&mut T) -> R + Send + 'static,
31 R: Send + 'static,
32 {
33 let (tx, rx) = crossbeam::channel::bounded(1);
34 self.sender
35 .send(Box::new(move |resource| {
36 let res = f(resource);
37 tx.send(res).unwrap();
38 }))
39 .expect(SESSION_ERROR_MESSAGE);
40 rx.recv().expect(SESSION_ERROR_MESSAGE)
41 }
42
43 #[cfg(feature = "tokio")]
44 #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
45 pub async fn run<F, R>(&self, f: F) -> R
46 where
47 F: FnOnce(&mut T) -> R + Send + 'static,
48 R: Send + 'static,
49 {
50 let (tx, rx) = tokio::sync::oneshot::channel();
51 self.sender
52 .send(Box::new(move |resource| {
53 let res = f(resource);
54 tx.send(res).ok().unwrap();
56 }))
57 .expect(SESSION_ERROR_MESSAGE);
58 rx.await.expect(SESSION_ERROR_MESSAGE)
59 }
60}
61
/// Panic message used when a `ThreadCell` handle can no longer reach its worker
/// thread (e.g. because a closure run on that thread panicked).
const THREAD_CELL_ERROR_MESSAGE: &str = "ThreadCell thread has panicked";
63
64pub struct ThreadCell<T: 'static> {
69 sender: crossbeam::channel::Sender<ThreadCellMessage<T>>,
70}
71
72impl<T: 'static> Clone for ThreadCell<T> {
73 fn clone(&self) -> Self {
74 Self {
75 sender: self.sender.clone(),
76 }
77 }
78}
79
80impl<T: Send> ThreadCell<T> {
81 pub fn new(resource: T) -> Self {
83 let (tx, rx) = crossbeam::channel::unbounded::<ThreadCellMessage<T>>();
84
85 thread::spawn(move || {
86 sync_handle(rx, resource);
87 });
88
89 Self { sender: tx }
90 }
91}
92
93impl<T> ThreadCell<T> {
94 pub fn new_with<F: FnOnce() -> T + Send + 'static>(resource_fn: F) -> Self {
96 let (tx, rx) = crossbeam::channel::unbounded::<ThreadCellMessage<T>>();
97
98 thread::spawn(move || {
99 let resource = resource_fn();
100 sync_handle(rx, resource);
101 });
102
103 Self { sender: tx }
104 }
105
106 pub fn run_blocking<F, R>(&self, f: F) -> R
107 where
108 F: FnOnce(&mut T) -> R + Send + 'static,
109 R: Send + 'static,
110 {
111 let (tx, rx) = crossbeam::channel::bounded(1);
112 self.sender
113 .send(ThreadCellMessage::Run(Box::new(move |resource| {
114 let res = f(resource);
115 tx.send(res).ok().unwrap();
117 })))
118 .expect(THREAD_CELL_ERROR_MESSAGE);
119 rx.recv().expect(THREAD_CELL_ERROR_MESSAGE)
120 }
121
122 #[cfg(feature = "tokio")]
123 #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
124 pub async fn run<F, R>(&self, f: F) -> R
125 where
126 F: FnOnce(&mut T) -> R + Send + 'static,
127 R: Send + 'static,
128 {
129 let (tx, rx) = tokio::sync::oneshot::channel();
130 self.sender
131 .send(ThreadCellMessage::Run(Box::new(move |resource| {
132 let res = f(resource);
133 tx.send(res).ok().unwrap();
135 })))
136 .expect(THREAD_CELL_ERROR_MESSAGE);
137 rx.await.expect(THREAD_CELL_ERROR_MESSAGE)
138 }
139
140 pub fn session_blocking(&self) -> ThreadCellSession<T> {
141 let (tx, rx) = crossbeam::channel::bounded(1);
142 self.sender
143 .send(ThreadCellMessage::GetSessionSync(tx))
144 .expect(THREAD_CELL_ERROR_MESSAGE);
145 rx.recv().expect(THREAD_CELL_ERROR_MESSAGE)
146 }
147
148 #[cfg(feature = "tokio")]
149 #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
150 pub async fn session(&self) -> ThreadCellSession<T> {
151 let (tx, rx) = tokio::sync::oneshot::channel();
152 self.sender
153 .send(ThreadCellMessage::GetSessionAsync(tx))
154 .expect(THREAD_CELL_ERROR_MESSAGE);
155 rx.await.expect(THREAD_CELL_ERROR_MESSAGE)
156 }
157}
158
159impl<T: Send> ThreadCell<T> {
160 pub fn set_blocking(&self, new_value: T) {
162 self.run_blocking(|res| *res = new_value);
163 }
164
165 #[cfg(feature = "tokio")]
167 #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
168 pub async fn set(&self, new_value: T) {
169 self.run(|res| *res = new_value).await;
170 }
171
172 pub fn replace_blocking(&self, new_value: T) -> T {
174 self.run_blocking(|res| std::mem::replace(res, new_value))
175 }
176
177 #[cfg(feature = "tokio")]
179 #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
180 pub async fn replace(&self, new_value: T) -> T {
181 self.run(|res| std::mem::replace(res, new_value)).await
182 }
183}
184
185impl<T: Send + Default> ThreadCell<T> {
186 pub fn take_blocking(&self) -> T {
187 self.run_blocking(|res| std::mem::take(res))
188 }
189
190 #[cfg(feature = "tokio")]
191 #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
192 pub async fn take(&self) -> T {
193 self.run(|res| std::mem::take(res)).await
194 }
195}
196
197impl<T: Send + Clone> ThreadCell<T> {
198 pub fn get_blocking(&self) -> T {
200 self.run_blocking(|res| res.clone())
201 }
202
203 #[cfg(feature = "tokio")]
205 #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
206 pub async fn get(&self) -> T {
207 self.run(|res| res.clone()).await
208 }
209}
210
#[cfg(feature = "tokio")]
thread_local! {
    /// Per-thread slot for the single-threaded tokio runtime used by `run_local`.
    /// It starts empty and is filled on first use.
    static RUNTIME: std::cell::OnceCell<tokio::runtime::Runtime> =
        const { std::cell::OnceCell::new() };
}
215
/// Drives `future` to completion on the calling thread and returns its output.
///
/// The future runs on a thread-local, current-thread tokio runtime with all
/// drivers enabled. The runtime is built on first use and reused by later calls
/// on the same thread. `future` has no `Send` or `'static` bound, so it may
/// borrow the `&mut` resource handed to a `ThreadCell` closure.
///
/// # Panics
///
/// - If the tokio runtime cannot be built.
/// - If called from within an asynchronous execution context, including from
///   inside another `run_local` future (tokio's `Runtime::block_on` panics in
///   that case).
/// - If `future` itself panics.
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
pub fn run_local<F: std::future::Future>(future: F) -> F::Output {
    RUNTIME.with(|cell| {
        let runtime = cell.get_or_init(|| {
            tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .expect("failed to build the thread-local tokio runtime for run_local")
        });
        runtime.block_on(future)
    })
}
262
/// Panic message used by the worker thread when it cannot hand a freshly
/// opened session back to the requester.
const GET_SESSION_RESPONSE_ERROR_MESSAGE: &str =
    "A get session request should always be waiting for a response";
265
266fn sync_handle<T>(rx: crossbeam::channel::Receiver<ThreadCellMessage<T>>, mut resource: T) {
267 while let Ok(msg) = rx.recv() {
275 match msg {
276 ThreadCellMessage::Run(f) => f(&mut resource),
277 ThreadCellMessage::GetSessionSync(responder) => {
278 let (stx, srx) = crossbeam::channel::unbounded::<Run<T>>();
279 responder
280 .send(ThreadCellSession { sender: stx })
281 .ok()
282 .expect(GET_SESSION_RESPONSE_ERROR_MESSAGE);
283 while let Ok(f) = srx.recv() {
284 f(&mut resource);
285 }
286 }
287 #[cfg(feature = "tokio")]
288 ThreadCellMessage::GetSessionAsync(responder) => {
289 let (stx, srx) = crossbeam::channel::unbounded::<Run<T>>();
290 responder
291 .send(ThreadCellSession { sender: stx })
292 .ok()
293 .expect(GET_SESSION_RESPONSE_ERROR_MESSAGE);
294 while let Ok(f) = srx.recv() {
295 f(&mut resource);
296 }
297 }
298 }
299 }
300 }
303
#[cfg(test)]
mod tests {
    use super::*;
    use std::rc::Rc;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Minimal resource whose counter makes the execution order observable.
    #[derive(Default)]
    struct TestResource {
        counter: usize,
    }

    impl TestResource {
        /// Increments the counter and returns its new value.
        fn increment(&mut self) -> usize {
            self.counter += 1;
            self.counter
        }
    }

    #[test]
    fn basic_run_blocking_works() {
        let cell = ThreadCell::new(TestResource::default());
        let value = cell.run_blocking(|res| {
            res.increment();
            res.increment()
        });
        assert_eq!(value, 2);

        let value = cell.run_blocking(|res| res.increment());
        assert_eq!(value, 3);
    }

    #[test]
    fn can_be_sent_to_another_thread() {
        let cell = ThreadCell::new(TestResource::default());
        let handle = std::thread::spawn(move || cell.run_blocking(|res| res.increment()));
        let result = handle.join().unwrap();
        assert_eq!(result, 1);
    }

    #[cfg(feature = "tokio")]
    #[tokio::test(flavor = "current_thread")]
    async fn async_run_works() {
        let cell = ThreadCell::new(TestResource::default());
        let result = cell.run(|res| res.increment()).await;
        assert_eq!(result, 1);
    }

    #[test]
    fn session_blocking_gives_mutable_access() {
        let cell = ThreadCell::new(TestResource::default());
        let lock = cell.session_blocking();
        let value = lock.run_blocking(|res| {
            res.increment();
            res.increment()
        });
        assert_eq!(value, 2);
    }

    #[test]
    fn session_blocks_other_requests_until_dropped() {
        let cell = ThreadCell::new(TestResource::default());
        let session = cell.session_blocking();
        assert_eq!(session.run_blocking(|res| res.increment()), 1);

        let other = cell.clone();
        let handle = std::thread::spawn(move || other.run_blocking(|res| res.increment()));

        // While the session is alive the worker serves only the session, so this
        // runs before the other thread's request regardless of scheduling.
        assert_eq!(session.run_blocking(|res| res.increment()), 2);
        drop(session);

        assert_eq!(handle.join().unwrap(), 3);
    }

    #[cfg(feature = "tokio")]
    #[tokio::test(flavor = "current_thread")]
    async fn async_session_works() {
        let cell = ThreadCell::new(TestResource::default());
        let lock = cell.session().await;
        let value = lock.run(|res| res.increment()).await;
        assert_eq!(value, 1);
    }

    #[test]
    fn can_hold_non_send_type() {
        // `Rc` is `!Send`; building it inside `new_with` keeps it on the worker thread.
        struct NotSend(Rc<()>);

        let cell = ThreadCell::new_with(|| NotSend(Rc::new(())));
        let count = cell.run_blocking(|res| Rc::strong_count(&res.0));
        assert_eq!(count, 1);
    }

    #[test]
    fn value_helpers_work() {
        let cell = ThreadCell::new(5_i32);
        assert_eq!(cell.get_blocking(), 5);
        cell.set_blocking(7);
        assert_eq!(cell.replace_blocking(9), 7);
        assert_eq!(cell.take_blocking(), 9);
        assert_eq!(cell.get_blocking(), 0);
    }

    #[test]
    fn concurrent_run_blocking_requests_are_serialized() {
        let cell = ThreadCell::new(TestResource::default());
        let counter = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for _ in 0..10 {
            let cell = cell.clone();
            let counter = counter.clone();
            handles.push(std::thread::spawn(move || {
                cell.run_blocking(move |res| {
                    let val = res.increment();
                    counter.fetch_add(val, Ordering::SeqCst);
                });
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        // Serialized increments observe 1..=10 exactly once each: 1 + 2 + ... + 10 = 55.
        assert_eq!(counter.load(Ordering::SeqCst), 55);
    }

    #[test]
    fn dropping_cell_does_not_panic() {
        let cell = ThreadCell::new(TestResource::default());
        drop(cell);
    }

    #[cfg(feature = "tokio")]
    #[tokio::test(flavor = "current_thread")]
    async fn run_local_works() {
        let cell = ThreadCell::new(TestResource::default());
        let mut value_to_move = TestResource::default();
        let value_to_move_returned = cell.run_blocking(|res| {
            res.increment();
            run_local(async move {
                res.increment();
                // A short sleep is enough to prove the timer driver works on the
                // thread-local runtime without slowing the test suite down.
                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                value_to_move.increment();
                res.increment();
                value_to_move.increment();
                value_to_move
            })
        });
        assert_eq!(value_to_move_returned.counter, 2);
        let value = cell.run(|res| res.increment()).await;
        assert_eq!(value, 4);
    }
}