thread_cell/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![doc = include_str!("../README.md")]
3
4use std::thread;
5
/// A type-erased job executed on the owning thread with mutable access to the resource.
///
/// Jobs are boxed so closures of different types can travel over one channel, and must be `Send`
/// because they are created on the caller's thread and executed on the owning thread.
type Run<T> = Box<dyn FnOnce(&mut T) + Send + 'static>;
8
9/// Messages sent to the manager thread
10enum ThreadCellMessage<T> {
11    Run(Run<T>),
12    GetSessionSync(crossbeam::channel::Sender<ThreadCellSession<T>>),
13    #[cfg(feature = "tokio")]
14    #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
15    GetSessionAsync(tokio::sync::oneshot::Sender<ThreadCellSession<T>>),
16}
17
18static SESSION_ERROR_MESSAGE: &str = "ThreadCell thread has panicked or was dropped";
19
20/// A session with exclusive access to the resource held by the thread.
21/// While held, this is the only way to access the resource. It is possible to create a "deadlock"
22/// if a `ThreadCellSession` is requested while one is already held.
23pub struct ThreadCellSession<T> {
24    sender: crossbeam::channel::Sender<Run<T>>,
25}
26
27impl<T> ThreadCellSession<T> {
28    pub fn run_blocking<F, R>(&self, f: F) -> R
29    where
30        F: FnOnce(&mut T) -> R + Send + 'static,
31        R: Send + 'static,
32    {
33        let (tx, rx) = crossbeam::channel::bounded(1);
34        self.sender
35            .send(Box::new(move |resource| {
36                let res = f(resource);
37                tx.send(res).unwrap();
38            }))
39            .expect(SESSION_ERROR_MESSAGE);
40        rx.recv().expect(SESSION_ERROR_MESSAGE)
41    }
42
43    #[cfg(feature = "tokio")]
44    #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
45    pub async fn run<F, R>(&self, f: F) -> R
46    where
47        F: FnOnce(&mut T) -> R + Send + 'static,
48        R: Send + 'static,
49    {
50        let (tx, rx) = tokio::sync::oneshot::channel();
51        self.sender
52            .send(Box::new(move |resource| {
53                let res = f(resource);
54                // Outer receiver is waiting
55                tx.send(res).ok().unwrap();
56            }))
57            .expect(SESSION_ERROR_MESSAGE);
58        rx.await.expect(SESSION_ERROR_MESSAGE)
59    }
60}
61
62static THREAD_CELL_ERROR_MESSAGE: &str = "ThreadCell thread has panicked";
63
64/// A cell that holds a value bound to a single thread. Thus `T` can be non-`Send` and/or non-`Sync`,
65/// but `ThreadCell<T>` is always `Send`/`Sync`. Access is provided through message passing, so no
66/// internal locking is used. But a lock-like `ThreadCellSession` can be acquired to gain exclusive
67/// access to the underlying resource while held.
68pub struct ThreadCell<T: 'static> {
69    sender: crossbeam::channel::Sender<ThreadCellMessage<T>>,
70}
71
72impl<T: 'static> Clone for ThreadCell<T> {
73    fn clone(&self) -> Self {
74        Self {
75            sender: self.sender.clone(),
76        }
77    }
78}
79
80impl<T: Send> ThreadCell<T> {
81    /// Creates new
82    pub fn new(resource: T) -> Self {
83        let (tx, rx) = crossbeam::channel::unbounded::<ThreadCellMessage<T>>();
84
85        thread::spawn(move || {
86            sync_handle(rx, resource);
87        });
88
89        Self { sender: tx }
90    }
91}
92
93impl<T> ThreadCell<T> {
94    /// Creates a new when `T` is not `Send` but a function to create `T` is
95    pub fn new_with<F: FnOnce() -> T + Send + 'static>(resource_fn: F) -> Self {
96        let (tx, rx) = crossbeam::channel::unbounded::<ThreadCellMessage<T>>();
97
98        thread::spawn(move || {
99            let resource = resource_fn();
100            sync_handle(rx, resource);
101        });
102
103        Self { sender: tx }
104    }
105
106    pub fn run_blocking<F, R>(&self, f: F) -> R
107    where
108        F: FnOnce(&mut T) -> R + Send + 'static,
109        R: Send + 'static,
110    {
111        let (tx, rx) = crossbeam::channel::bounded(1);
112        self.sender
113            .send(ThreadCellMessage::Run(Box::new(move |resource| {
114                let res = f(resource);
115                // Outer receiver is waiting
116                tx.send(res).ok().unwrap();
117            })))
118            .expect(THREAD_CELL_ERROR_MESSAGE);
119        rx.recv().expect(THREAD_CELL_ERROR_MESSAGE)
120    }
121
122    #[cfg(feature = "tokio")]
123    #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
124    pub async fn run<F, R>(&self, f: F) -> R
125    where
126        F: FnOnce(&mut T) -> R + Send + 'static,
127        R: Send + 'static,
128    {
129        let (tx, rx) = tokio::sync::oneshot::channel();
130        self.sender
131            .send(ThreadCellMessage::Run(Box::new(move |resource| {
132                let res = f(resource);
133                // Outer receiver is waiting
134                tx.send(res).ok().unwrap();
135            })))
136            .expect(THREAD_CELL_ERROR_MESSAGE);
137        rx.await.expect(THREAD_CELL_ERROR_MESSAGE)
138    }
139
140    pub fn session_blocking(&self) -> ThreadCellSession<T> {
141        let (tx, rx) = crossbeam::channel::bounded(1);
142        self.sender
143            .send(ThreadCellMessage::GetSessionSync(tx))
144            .expect(THREAD_CELL_ERROR_MESSAGE);
145        rx.recv().expect(THREAD_CELL_ERROR_MESSAGE)
146    }
147
148    #[cfg(feature = "tokio")]
149    #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
150    pub async fn session(&self) -> ThreadCellSession<T> {
151        let (tx, rx) = tokio::sync::oneshot::channel();
152        self.sender
153            .send(ThreadCellMessage::GetSessionAsync(tx))
154            .expect(THREAD_CELL_ERROR_MESSAGE);
155        rx.await.expect(THREAD_CELL_ERROR_MESSAGE)
156    }
157}
158
159impl<T: Send> ThreadCell<T> {
160    /// Set the resource in a blocking manner
161    pub fn set_blocking(&self, new_value: T) {
162        self.run_blocking(|res| *res = new_value);
163    }
164
165    /// Set the resource in an async manner
166    #[cfg(feature = "tokio")]
167    #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
168    pub async fn set(&self, new_value: T) {
169        self.run(|res| *res = new_value).await;
170    }
171
172    /// Set the resource in a blocking manner, returning the old value
173    pub fn replace_blocking(&self, new_value: T) -> T {
174        self.run_blocking(|res| std::mem::replace(res, new_value))
175    }
176
177    /// Set the resource in an async manner, returning the old value
178    #[cfg(feature = "tokio")]
179    #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
180    pub async fn replace(&self, new_value: T) -> T {
181        self.run(|res| std::mem::replace(res, new_value)).await
182    }
183}
184
185impl<T: Send + Default> ThreadCell<T> {
186    pub fn take_blocking(&self) -> T {
187        self.run_blocking(|res| std::mem::take(res))
188    }
189
190    #[cfg(feature = "tokio")]
191    #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
192    pub async fn take(&self) -> T {
193        self.run(|res| std::mem::take(res)).await
194    }
195}
196
197impl<T: Send + Clone> ThreadCell<T> {
198    /// Get a clone of the resource in a blocking manner
199    pub fn get_blocking(&self) -> T {
200        self.run_blocking(|res| res.clone())
201    }
202
203    /// Get a clone of the resource in an async manner
204    #[cfg(feature = "tokio")]
205    #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
206    pub async fn get(&self) -> T {
207        self.run(|res| res.clone()).await
208    }
209}
210
// Lazily built, per-thread current-thread Tokio runtime that backs `run_local`. Each thread that
// calls `run_local` gets its own runtime, stored in this thread-local.
#[cfg(feature = "tokio")]
thread_local! {
    static RUNTIME: std::cell::OnceCell<tokio::runtime::Runtime> =
        const { std::cell::OnceCell::new() };
}

/// Runs a future to completion on the current thread and returns its output.
///
/// Intended to be called from the top-level closure passed to
/// [`ThreadCell::run_blocking`],
/// [`ThreadCell::run`],
/// [`ThreadCellSession::run_blocking`],
/// or [`ThreadCellSession::run`].
/// The future then runs on the thread that owns the resource, so it may borrow the resource (and
/// other non-`Send` data) directly.
///
/// On first use, a current-thread Tokio runtime with all drivers enabled is built for the calling
/// thread. Later calls on that thread reuse it.
///
/// # Panics
///
/// - Panics if called while another `run_local` call on the same thread is still running (nested
///   use), or from within any Tokio asynchronous context. Both cases are rejected by Tokio's
///   `Runtime::block_on`.
/// - Panics if the runtime cannot be built.
///
/// # Examples
///
/// ```rust
/// # #[cfg(feature = "tokio")] {
/// use thread_cell::ThreadCell;
/// use thread_cell::run_local;
///
/// struct Counter {
///     value: usize,
/// }
///
/// let cell = ThreadCell::new(Counter { value: 0 });
///
/// let result = cell.run_blocking(|counter| {
///     // Increment synchronously
///     counter.value += 1;
///
///     // Run an async block on the `ThreadCell`'s thread
///     run_local(async {
///         tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
///         counter.value += 1;
///         counter.value
///     })
/// });
///
/// assert_eq!(result, 2);
/// # }
/// ```
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
pub fn run_local<F: std::future::Future>(future: F) -> F::Output {
    RUNTIME.with(|cell| {
        let rt = cell.get_or_init(|| {
            tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .expect("failed to build the thread-local Tokio runtime for `run_local`")
        });
        rt.block_on(future)
    })
}
262
263const GET_SESSION_RESPONSE_ERROR_MESSAGE: &str =
264    "A get session request should always be waiting for a response";
265
266fn sync_handle<T>(rx: crossbeam::channel::Receiver<ThreadCellMessage<T>>, mut resource: T) {
267    // #[cfg(feature = "tokio")]
268    // let rt = tokio::runtime::Builder::new_current_thread()
269    //     .enable_all()
270    //     .build()
271    //     .unwrap();
272    // #[cfg(feature = "tokio")]
273    // let guard = rt.enter();
274    while let Ok(msg) = rx.recv() {
275        match msg {
276            ThreadCellMessage::Run(f) => f(&mut resource),
277            ThreadCellMessage::GetSessionSync(responder) => {
278                let (stx, srx) = crossbeam::channel::unbounded::<Run<T>>();
279                responder
280                    .send(ThreadCellSession { sender: stx })
281                    .ok()
282                    .expect(GET_SESSION_RESPONSE_ERROR_MESSAGE);
283                while let Ok(f) = srx.recv() {
284                    f(&mut resource);
285                }
286            }
287            #[cfg(feature = "tokio")]
288            ThreadCellMessage::GetSessionAsync(responder) => {
289                let (stx, srx) = crossbeam::channel::unbounded::<Run<T>>();
290                responder
291                    .send(ThreadCellSession { sender: stx })
292                    .ok()
293                    .expect(GET_SESSION_RESPONSE_ERROR_MESSAGE);
294                while let Ok(f) = srx.recv() {
295                    f(&mut resource);
296                }
297            }
298        }
299    }
300    // #[cfg(feature = "tokio")]
301    // drop(guard);
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use std::rc::Rc;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Simple resource whose counter lets tests observe how many jobs touched it.
    #[derive(Default)]
    struct TestResource {
        counter: usize,
    }

    impl TestResource {
        /// Bumps the counter and returns its new value.
        fn increment(&mut self) -> usize {
            self.counter += 1;
            self.counter
        }
    }

    #[test]
    fn basic_run_blocking_works() {
        let cell = ThreadCell::new(TestResource::default());
        let value = cell.run_blocking(|res| {
            res.increment();
            res.increment()
        });
        assert_eq!(value, 2);

        let value = cell.run_blocking(|res| res.increment());
        assert_eq!(value, 3);
    }

    #[test]
    fn can_be_sent_to_another_thread() {
        let cell = ThreadCell::new(TestResource::default());
        let handle = std::thread::spawn(move || cell.run_blocking(|res| res.increment()));
        let result = handle.join().unwrap();
        assert_eq!(result, 1);
    }

    #[cfg(feature = "tokio")]
    #[tokio::test(flavor = "current_thread")]
    async fn async_run_works() {
        let cell = ThreadCell::new(TestResource::default());
        let result = cell.run(|res| res.increment()).await;
        assert_eq!(result, 1);
    }

    #[test]
    fn session_blocking_gives_mutable_access() {
        let cell = ThreadCell::new(TestResource::default());
        let lock = cell.session_blocking();
        let value = lock.run_blocking(|res| {
            res.increment();
            res.increment()
        });
        assert_eq!(value, 2);
    }

    #[test]
    fn dropping_session_releases_the_cell() {
        let cell = ThreadCell::new(TestResource::default());
        {
            let session = cell.session_blocking();
            assert_eq!(session.run_blocking(|res| res.increment()), 1);
        }
        // The session is gone, so the owning thread serves ordinary jobs again.
        assert_eq!(cell.run_blocking(|res| res.increment()), 2);
    }

    #[cfg(feature = "tokio")]
    #[tokio::test(flavor = "current_thread")]
    async fn async_session_works() {
        let cell = ThreadCell::new(TestResource::default());
        let lock = cell.session().await;
        let value = lock.run(|res| res.increment()).await;
        assert_eq!(value, 1);
    }

    #[test]
    fn can_hold_non_send_type() {
        #[derive(Default)]
        struct NotSend(Rc<()>); // Rc is !Send
        let cell = ThreadCell::new_with(|| NotSend(Rc::new(())));
        let count = cell.run_blocking(|res| Rc::strong_count(&res.0));
        assert_eq!(count, 1);
    }

    #[test]
    fn value_helpers_blocking() {
        let cell = ThreadCell::new(5usize);
        assert_eq!(cell.get_blocking(), 5);
        cell.set_blocking(7);
        assert_eq!(cell.replace_blocking(9), 7);
        assert_eq!(cell.take_blocking(), 9);
        // `take_blocking` leaves `usize::default()` behind.
        assert_eq!(cell.get_blocking(), 0);
    }

    #[test]
    fn concurrent_run_blocking_requests_are_serialized() {
        let cell = ThreadCell::new(TestResource::default());
        let counter = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for _ in 0..10 {
            let cell = cell.clone();
            let counter = counter.clone();
            handles.push(std::thread::spawn(move || {
                cell.run_blocking(move |res| {
                    let val = res.increment();
                    counter.fetch_add(val, Ordering::SeqCst);
                });
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        // The sum of 1..=10 = 55
        assert_eq!(counter.load(Ordering::SeqCst), 55);
    }

    #[test]
    fn dropping_cell_does_not_panic() {
        let cell = ThreadCell::new(TestResource::default());
        drop(cell);
        // no panic = pass
    }

    #[cfg(feature = "tokio")]
    #[tokio::test(flavor = "current_thread")]
    async fn run_local_works() {
        let cell = ThreadCell::new(TestResource::default());
        let mut value_to_move = TestResource::default();
        let value_to_move_returned = cell.run_blocking(|res| {
            res.increment();
            run_local(async move {
                res.increment();
                // A short sleep is enough to force a real suspension point in the future.
                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                value_to_move.increment();
                res.increment();
                value_to_move.increment();
                value_to_move
            })
        });
        assert_eq!(value_to_move_returned.counter, 2);
        let value = cell.run(|res| res.increment()).await;
        assert_eq!(value, 4);
    }
}