async_stm/vars.rs
1use parking_lot::{Mutex, RwLock};
2
3use crate::{transaction::with_tx, version::Version, Stm};
4use std::{
5 any::Any,
6 marker::PhantomData,
7 mem,
8 sync::{
9 atomic::{AtomicU64, Ordering},
10 Arc,
11 },
12};
13
/// Identifier of a [TVar]; every variable gets its own value when it is created.
pub type ID = u64;
16
/// Type-erased storage for a variable's value.
///
/// Many threads may read the value concurrently, so it is reference-counted with
/// an `Arc`. It stays dynamic (`dyn Any`) because making [LVar] generic over the
/// value type turned out to be very awkward.
type DynValue = Arc<dyn Any + Sync + Send>;
21
22/// A versioned value. It will only be accessed through a transaction and a [TVar].
23#[derive(Clone)]
24pub struct VVar {
25 pub version: Version,
26 pub value: DynValue,
27}
28
29impl VVar {
30 /// Perform a downcast on a var. Returns an `Arc` that tracks when that variable
31 /// will go out of scope. This avoids cloning on reads, if the value needs to be
32 /// mutated then it can be cloned after being read.
33 pub fn downcast<T: Any + Sync + Send>(&self) -> Arc<T> {
34 match self.value.clone().downcast::<T>() {
35 Ok(s) => s,
36 Err(_) => unreachable!("TVar has wrong type"),
37 }
38 }
39}
40
41/// Using a channel to wake up tasks when a [TVar] they read changed.
42///
43/// Sending `true` means the associated [TVar] has been updated.
44/// Sending `false` means there are too many subscribers on a variable
45/// that doesn't seem to be written to. In this case any listeners still
46/// alive can re-subscribe.
47type Signaler = tokio::sync::mpsc::UnboundedSender<bool>;
48
49pub struct WaitQueue {
50 /// Store the last version which was written to avoid race condition where the notification
51 /// happens before the waiters would subscribe and then there's no further event that would
52 /// unpark them, causing them to wait forever or until they time out.
53 ///
54 /// This can happen if the order of events on thread A and B are:
55 /// 1. [atomically] on A returns `Retry`
56 /// 2. `commit` on B updates the versions
57 /// 3. `notify` on B finds nobody in the wait queues
58 /// 4. `wait` on A adds itself to the wait queues
59 ///
60 /// By having `notify` update the `last_written_version` we make sure that `wait` sees it.
61 pub last_written_version: Version,
62
63 /// Signalers for tasks waiting for the [TVar] to get an update.
64 waiting: Vec<Signaler>,
65
66 /// Highest number of transactions waiting we have encountered so far.
67 max_waiting: usize,
68}
69
70impl WaitQueue {
71 pub fn new() -> Self {
72 WaitQueue {
73 last_written_version: Default::default(),
74 waiting: Vec::new(),
75 max_waiting: 1,
76 }
77 }
78
79 /// Register a transaction as waiting for this [TVar] to be updated.
80 pub fn add(&mut self, s: Signaler) {
81 self.prune();
82 self.waiting.push(s)
83 }
84
85 /// Signal to all waiting transactions that this [TVar] has been updated.
86 pub fn notify_all(&mut self, commit_version: Version) {
87 self.last_written_version = commit_version;
88
89 if self.waiting.is_empty() {
90 return;
91 }
92
93 let waiting = mem::take(&mut self.waiting);
94
95 for tx in waiting {
96 let _ = tx.send(true);
97 }
98 }
99
100 /// Whenever we see the length of the waiting queue hit a new record,
101 /// remove any waiting signalers that already had their receiver closed,
102 /// which means some other [TVar] they subscribed to has been updated.
103 ///
104 /// This is to avoid memory leaks in the wait queue of a [TVar] which
105 /// is frequently read but never written to.
106 fn prune(&mut self) {
107 if self.waiting.len() > self.max_waiting {
108 self.waiting.retain(|tx| tx.send(false).is_ok());
109 self.max_waiting = self.max_waiting.max(self.waiting.len());
110 }
111 }
112}
113
114impl Default for WaitQueue {
115 fn default() -> Self {
116 Self::new()
117 }
118}
119
120/// Sync variable, hold the committed value and the waiting threads.
121pub struct SVar {
122 pub vvar: RwLock<VVar>,
123 pub queue: Mutex<WaitQueue>,
124}
125
126/// A variable in the transaction log that remembers if it has been read and/or written to.
127#[derive(Clone)]
128pub struct LVar {
129 // Hold on the original that we need to commit to.
130 pub svar: Arc<SVar>,
131 // Hold on to the value as it was read or written for MVCC comparison.
132 pub vvar: VVar,
133 /// Remember reads; these are the variables we need to watch if we retry.
134 pub read: bool,
135 /// Remember writes; these are the variables that need to be stored at the
136 /// end of the transaction, but they don't need to be watched if we retry.
137 pub write: bool,
138}
139
140/// [TVar] is our handle to a variable, but reading and writing go through a transaction.
141/// It also tracks which threads are waiting on it.
142#[derive(Clone)]
143pub struct TVar<T> {
144 pub(crate) id: ID,
145 pub(crate) svar: Arc<SVar>,
146 phantom: PhantomData<T>,
147}
148
149impl<T> Default for TVar<T>
150where
151 T: Any + Sync + Send + Clone + Default,
152{
153 fn default() -> Self {
154 Self::new(Default::default())
155 }
156}
157
158impl<T: Any + Sync + Send + Clone> TVar<T> {
159 /// Create a new [TVar]. The initial version is 0, so that if a
160 /// [TVar] is created in the middle of a transaction it will
161 /// not cause any MVCC conflict during the commit.
162 pub fn new(value: T) -> TVar<T> {
163 // This is shared between all [TVar]s.
164 static COUNTER: AtomicU64 = AtomicU64::new(0);
165
166 TVar {
167 id: COUNTER.fetch_add(1, Ordering::Relaxed),
168 svar: Arc::new(SVar {
169 vvar: RwLock::new(VVar {
170 version: Default::default(),
171 value: Arc::new(value),
172 }),
173 queue: Mutex::new(WaitQueue::default()),
174 }),
175 phantom: PhantomData,
176 }
177 }
178
179 /// Read the value of the [TVar] as a clone, for subsequent modification. Only call this inside [atomically].
180 pub fn read_clone(&self) -> Stm<T> {
181 with_tx(|tx| tx.read(self).map(|r| r.as_ref().clone()))
182 }
183
184 /// Read the value of the [TVar]. Only call this inside [atomically].
185 pub fn read(&self) -> Stm<Arc<T>> {
186 with_tx(|tx| tx.read(self))
187 }
188
189 /// Replace the value of the [TVar]. Only call this inside [atomically].
190 pub fn write(&self, value: T) -> Stm<()> {
191 with_tx(move |tx| tx.write(self, value))
192 }
193
194 /// Apply an update on the value of the [TVar]. Only call this inside [atomically].
195 pub fn update<F>(&self, f: F) -> Stm<()>
196 where
197 F: FnOnce(T) -> T,
198 {
199 let v = self.read_clone()?;
200 self.write(f(v))
201 }
202
203 /// Apply an update on the value of the [TVar]. Only call this inside [atomically].
204 pub fn update_mut<F>(&self, f: F) -> Stm<()>
205 where
206 F: FnOnce(&mut T),
207 {
208 let mut v = self.read_clone()?;
209 f(&mut v);
210 self.write(v)
211 }
212
213 /// Apply an update on the value of the [TVar] and return a value. Only call this inside [atomically].
214 pub fn modify<F, R>(&self, f: F) -> Stm<R>
215 where
216 F: FnOnce(T) -> (T, R),
217 {
218 let v = self.read_clone()?;
219 let (w, r) = f(v);
220 self.write(w)?;
221 Ok(r)
222 }
223
224 /// Apply an update on the value of the [TVar] and return a value. Only call this inside [atomically].
225 pub fn modify_mut<F, R>(&self, f: F) -> Stm<R>
226 where
227 F: FnOnce(&mut T) -> R,
228 {
229 let mut v = self.read_clone()?;
230 let r = f(&mut v);
231 self.write(v)?;
232 Ok(r)
233 }
234
235 /// Replace the value of the [TVar] and return the previous value. Only call this inside [atomically].
236 pub fn replace(&self, value: T) -> Stm<Arc<T>> {
237 let v = self.read()?;
238 self.write(value)?;
239 Ok(v)
240 }
241}