rust_query/
async_db.rs

1use std::{
2    future,
3    sync::{Arc, Mutex},
4    task::{Poll, Waker},
5};
6
7use crate::{Database, Transaction, migrate::Schema};
8
9/// This is an async wrapper for [Database].
10///
11/// You can easily achieve the same thing with `tokio::task::spawn_blocking`,
12/// but this wrapper is a little bit more efficient while also being runtime agnostic.
13pub struct DatabaseAsync<S> {
14    inner: Arc<Database<S>>,
15}
16
17impl<S> Clone for DatabaseAsync<S> {
18    fn clone(&self) -> Self {
19        Self {
20            inner: self.inner.clone(),
21        }
22    }
23}
24
25impl<S: 'static + Send + Sync + Schema> DatabaseAsync<S> {
26    /// Create an async wrapper for the [Database].
27    ///
28    /// The database is wrapped in an [Arc] as it needs to be shared with any thread
29    /// executing a transaction. These threads can live longer than the future that
30    /// started the transaction.
31    ///
32    /// By accepting an [Arc], you can keep your own clone of the [Arc] and use
33    /// the database synchronously and asynchronously at the same time!
34    pub fn new(db: Arc<Database<S>>) -> Self {
35        DatabaseAsync { inner: db }
36    }
37
38    /// This is a lot like [Database::transaction], the only difference is that the async function
39    /// does not block the runtime and requires the closure to be `'static`.
40    /// The static requirement is because the future may be canceled, but the transaction can not
41    /// be canceled.
42    pub async fn transaction<R: 'static + Send>(
43        &self,
44        f: impl 'static + Send + FnOnce(&'static Transaction<S>) -> R,
45    ) -> R {
46        let db = self.inner.clone();
47        async_run(move || db.transaction_local(f)).await
48    }
49
50    /// This is a lot like [Database::transaction_mut], the only difference is that the async function
51    /// does not block the runtime and requires the closure to be `'static`.
52    /// The static requirement is because the future may be canceled, but the transaction can not
53    /// be canceled.
54    pub async fn transaction_mut<O: 'static + Send, E: 'static + Send>(
55        &self,
56        f: impl 'static + Send + FnOnce(&'static mut Transaction<S>) -> Result<O, E>,
57    ) -> Result<O, E> {
58        let db = self.inner.clone();
59        async_run(move || db.transaction_mut_local(f)).await
60    }
61
62    /// This is a lot like [Database::transaction_mut_ok], the only difference is that the async function
63    /// does not block the runtime and requires the closure to be `'static`.
64    /// The static requirement is because the future may be canceled, but the transaction can not
65    /// be canceled.
66    pub async fn transaction_mut_ok<R: 'static + Send>(
67        &self,
68        f: impl 'static + Send + FnOnce(&'static mut Transaction<S>) -> R,
69    ) -> R {
70        self.transaction_mut(|txn| Ok::<R, std::convert::Infallible>(f(txn)))
71            .await
72            .unwrap()
73    }
74}
75
/// Runs `f` on a freshly spawned OS thread and resolves once that thread is done.
///
/// Nothing happens until the returned future is first polled. After that the
/// thread is independent of the future: dropping the future detaches the thread,
/// which still runs `f` to completion.
///
/// If `f` panics, its payload is re-raised in the awaiting task with
/// [std::panic::resume_unwind].
async fn async_run<R: 'static + Send>(f: impl 'static + Send + FnOnce() -> R) -> R {
    // Shared between the worker thread (strong ref) and the waiting future (weak
    // ref). Whichever side drops the last strong reference runs `Drop`, which wakes
    // the most recently registered waker.
    struct Notifier {
        slot: Mutex<Waker>,
    }

    impl Drop for Notifier {
        fn drop(&mut self) {
            self.slot.lock().unwrap().wake_by_ref();
        }
    }

    // Placeholder waker; the first poll below replaces it with the real one.
    let notifier = Arc::new(Notifier {
        slot: Mutex::new(Waker::noop().clone()),
    });
    // The future keeps only a weak reference, so a failed upgrade means the worker
    // has already released its strong one.
    let watcher = Arc::downgrade(&notifier);

    let worker = std::thread::spawn(move || {
        // Bound to a named variable so it lives until `f` returns. It is also dropped
        // while unwinding, so the waiter is woken even if `f` panics.
        let _notify_on_exit = notifier;
        f()
    });

    // Wait, without blocking the executor, until the worker releases its `Notifier`.
    future::poll_fn(|cx| match watcher.upgrade() {
        None => Poll::Ready(()),
        Some(live) => {
            // If the worker finishes while we hold `live`, dropping `live` at the end
            // of this arm runs `Notifier::drop` right here and wakes the waker just
            // stored, so the task is polled again and observes `None`.
            live.slot.lock().unwrap().clone_from(cx.waker());
            Poll::Pending
        }
    })
    .await;

    // The worker has dropped its `Notifier`, so at most it is wrapping up; `join`
    // only waits for that final stretch.
    worker
        .join()
        .unwrap_or_else(|payload| std::panic::resume_unwind(payload))
}