async_stm/
vars.rs

1use parking_lot::{Mutex, RwLock};
2
3use crate::{transaction::with_tx, version::Version, Stm};
4use std::{
5    any::Any,
6    marker::PhantomData,
7    mem,
8    sync::{
9        atomic::{AtomicU64, Ordering},
10        Arc,
11    },
12};
13
/// Identifier that is unique to each [TVar].
pub type ID = u64;

/// Values may be shared between many threads, so they live behind an `Arc`.
/// The concrete type is erased (`dyn Any`) because making [LVar] generic over
/// the value type turned out to be far too unwieldy.
type DynValue = Arc<dyn Any + Sync + Send>;
21
22/// A versioned value. It will only be accessed through a transaction and a [TVar].
23#[derive(Clone)]
24pub struct VVar {
25    pub version: Version,
26    pub value: DynValue,
27}
28
29impl VVar {
30    /// Perform a downcast on a var. Returns an `Arc` that tracks when that variable
31    /// will go out of scope. This avoids cloning on reads, if the value needs to be
32    /// mutated then it can be cloned after being read.
33    pub fn downcast<T: Any + Sync + Send>(&self) -> Arc<T> {
34        match self.value.clone().downcast::<T>() {
35            Ok(s) => s,
36            Err(_) => unreachable!("TVar has wrong type"),
37        }
38    }
39}
40
41/// Using a channel to wake up tasks when a [TVar] they read changed.
42///
43/// Sending `true` means the associated [TVar] has been updated.
44/// Sending `false` means there are too many subscribers on a variable
45/// that doesn't seem to be written to. In this case any listeners still
46/// alive can re-subscribe.
47type Signaler = tokio::sync::mpsc::UnboundedSender<bool>;
48
49pub struct WaitQueue {
50    /// Store the last version which was written to avoid race condition where the notification
51    /// happens before the waiters would subscribe and then there's no further event that would
52    /// unpark them, causing them to wait forever or until they time out.
53    ///
54    /// This can happen if the order of events on thread A and B are:
55    /// 1. [atomically] on A returns `Retry`
56    /// 2. `commit` on B updates the versions
57    /// 3. `notify` on B finds nobody in the wait queues
58    /// 4. `wait` on A adds itself to the wait queues
59    ///
60    /// By having `notify` update the `last_written_version` we make sure that `wait` sees it.
61    pub last_written_version: Version,
62
63    /// Signalers for tasks waiting for the [TVar] to get an update.
64    waiting: Vec<Signaler>,
65
66    /// Highest number of transactions waiting we have encountered so far.
67    max_waiting: usize,
68}
69
70impl WaitQueue {
71    pub fn new() -> Self {
72        WaitQueue {
73            last_written_version: Default::default(),
74            waiting: Vec::new(),
75            max_waiting: 1,
76        }
77    }
78
79    /// Register a transaction as waiting for this [TVar] to be updated.
80    pub fn add(&mut self, s: Signaler) {
81        self.prune();
82        self.waiting.push(s)
83    }
84
85    /// Signal to all waiting transactions that this [TVar] has been updated.
86    pub fn notify_all(&mut self, commit_version: Version) {
87        self.last_written_version = commit_version;
88
89        if self.waiting.is_empty() {
90            return;
91        }
92
93        let waiting = mem::take(&mut self.waiting);
94
95        for tx in waiting {
96            let _ = tx.send(true);
97        }
98    }
99
100    /// Whenever we see the length of the waiting queue hit a new record,
101    /// remove any waiting signalers that already had their receiver closed,
102    /// which means some other [TVar] they subscribed to has been updated.
103    ///
104    /// This is to avoid memory leaks in the wait queue of a [TVar] which
105    /// is frequently read but never written to.
106    fn prune(&mut self) {
107        if self.waiting.len() > self.max_waiting {
108            self.waiting.retain(|tx| tx.send(false).is_ok());
109            self.max_waiting = self.max_waiting.max(self.waiting.len());
110        }
111    }
112}
113
114impl Default for WaitQueue {
115    fn default() -> Self {
116        Self::new()
117    }
118}
119
120/// Sync variable, hold the committed value and the waiting threads.
121pub struct SVar {
122    pub vvar: RwLock<VVar>,
123    pub queue: Mutex<WaitQueue>,
124}
125
126/// A variable in the transaction log that remembers if it has been read and/or written to.
127#[derive(Clone)]
128pub struct LVar {
129    // Hold on the original that we need to commit to.
130    pub svar: Arc<SVar>,
131    // Hold on to the value as it was read or written for MVCC comparison.
132    pub vvar: VVar,
133    /// Remember reads; these are the variables we need to watch if we retry.
134    pub read: bool,
135    /// Remember writes; these are the variables that need to be stored at the
136    /// end of the transaction, but they don't need to be watched if we retry.
137    pub write: bool,
138}
139
140/// [TVar] is our handle to a variable, but reading and writing go through a transaction.
141/// It also tracks which threads are waiting on it.
142#[derive(Clone)]
143pub struct TVar<T> {
144    pub(crate) id: ID,
145    pub(crate) svar: Arc<SVar>,
146    phantom: PhantomData<T>,
147}
148
149impl<T> Default for TVar<T>
150where
151    T: Any + Sync + Send + Clone + Default,
152{
153    fn default() -> Self {
154        Self::new(Default::default())
155    }
156}
157
158impl<T: Any + Sync + Send + Clone> TVar<T> {
159    /// Create a new [TVar]. The initial version is 0, so that if a
160    /// [TVar] is created in the middle of a transaction it will
161    /// not cause any MVCC conflict during the commit.
162    pub fn new(value: T) -> TVar<T> {
163        // This is shared between all [TVar]s.
164        static COUNTER: AtomicU64 = AtomicU64::new(0);
165
166        TVar {
167            id: COUNTER.fetch_add(1, Ordering::Relaxed),
168            svar: Arc::new(SVar {
169                vvar: RwLock::new(VVar {
170                    version: Default::default(),
171                    value: Arc::new(value),
172                }),
173                queue: Mutex::new(WaitQueue::default()),
174            }),
175            phantom: PhantomData,
176        }
177    }
178
179    /// Read the value of the [TVar] as a clone, for subsequent modification. Only call this inside [atomically].
180    pub fn read_clone(&self) -> Stm<T> {
181        with_tx(|tx| tx.read(self).map(|r| r.as_ref().clone()))
182    }
183
184    /// Read the value of the [TVar]. Only call this inside [atomically].
185    pub fn read(&self) -> Stm<Arc<T>> {
186        with_tx(|tx| tx.read(self))
187    }
188
189    /// Replace the value of the [TVar]. Only call this inside [atomically].
190    pub fn write(&self, value: T) -> Stm<()> {
191        with_tx(move |tx| tx.write(self, value))
192    }
193
194    /// Apply an update on the value of the [TVar]. Only call this inside [atomically].
195    pub fn update<F>(&self, f: F) -> Stm<()>
196    where
197        F: FnOnce(T) -> T,
198    {
199        let v = self.read_clone()?;
200        self.write(f(v))
201    }
202
203    /// Apply an update on the value of the [TVar]. Only call this inside [atomically].
204    pub fn update_mut<F>(&self, f: F) -> Stm<()>
205    where
206        F: FnOnce(&mut T),
207    {
208        let mut v = self.read_clone()?;
209        f(&mut v);
210        self.write(v)
211    }
212
213    /// Apply an update on the value of the [TVar] and return a value. Only call this inside [atomically].
214    pub fn modify<F, R>(&self, f: F) -> Stm<R>
215    where
216        F: FnOnce(T) -> (T, R),
217    {
218        let v = self.read_clone()?;
219        let (w, r) = f(v);
220        self.write(w)?;
221        Ok(r)
222    }
223
224    /// Apply an update on the value of the [TVar] and return a value. Only call this inside [atomically].
225    pub fn modify_mut<F, R>(&self, f: F) -> Stm<R>
226    where
227        F: FnOnce(&mut T) -> R,
228    {
229        let mut v = self.read_clone()?;
230        let r = f(&mut v);
231        self.write(v)?;
232        Ok(r)
233    }
234
235    /// Replace the value of the [TVar] and return the previous value. Only call this inside [atomically].
236    pub fn replace(&self, value: T) -> Stm<Arc<T>> {
237        let v = self.read()?;
238        self.write(value)?;
239        Ok(v)
240    }
241}