effect_rs/
stm.rs

1use crate::core::{Cause, Effect, Exit};
2use std::any::Any;
3use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5
6type STMRun<E, A> = dyn Fn(&mut Journal) -> STMResult<E, A> + Send + Sync;
7
8/// STM Effect
9pub struct STM<E, A> {
10    pub(crate) run: Arc<STMRun<E, A>>,
11}
12
13impl<E, A> Clone for STM<E, A> {
14    fn clone(&self) -> Self {
15        Self {
16            run: self.run.clone(),
17        }
18    }
19}
20
/// Outcome of running a transaction body once against a `Journal`.
pub enum STMResult<E, A> {
    /// The attempt produced a value; its journal is then validated and committed.
    Success(A),
    /// The attempt aborted with an error; none of its writes are committed.
    Failure(E),
    /// The attempt asked to wait until something it read changes, then run again.
    Retry,
}
26
27/// Abstract wrapper around a TRef state to allow locking in the Journal
28#[allow(dead_code)]
29trait AbstractEntry: Send + Sync {
30    fn id(&self) -> u64;
31    fn version(&self) -> u64;
32    // We need to be able to clone the inner Arc<RwLock> but type erased?
33    // Actually, we just need to be able to lock it.
34    // The Journal needs to hold `dyn AbstractEntry`?
35    // If we hold `Arc<dyn AbstractEntry>`, we can't lock easily because RwLock is generic.
36    // Let's invert: TRef<A> implements Entry which holds Arc<RwLock<State<A>>>.
37    // We need a non-generic trait to store in Journal list.
38
39    // Lock for reading version
40    fn read_version(&self) -> u64;
41    // Lock for writing (update value and version)
42    fn commit(&self, value: Box<dyn Any + Send + Sync>);
43
44    fn notify(&self);
45    fn listen(&self) -> Arc<tokio::sync::Notify>;
46}
47
48struct EntryImpl<A> {
49    #[allow(dead_code)]
50    id: u64,
51    inner: Arc<RwLock<TRefState<A>>>,
52}
53
54impl<A> AbstractEntry for EntryImpl<A>
55where
56    A: Send + Sync + 'static,
57{
58    fn id(&self) -> u64 {
59        self.id
60    }
61    fn version(&self) -> u64 {
62        self.inner.read().unwrap().version
63    }
64
65    fn read_version(&self) -> u64 {
66        self.inner.read().unwrap().version
67    }
68
69    fn commit(&self, value: Box<dyn Any + Send + Sync>) {
70        if let Ok(typed) = value.downcast::<A>() {
71            let mut guard = self.inner.write().unwrap();
72            guard.value = *typed;
73            guard.version += 1;
74            guard.notify.notify_waiters();
75        } else {
76            panic!("STM Commit Type Mismatch");
77        }
78    }
79
80    fn notify(&self) {
81        self.inner.read().unwrap().notify.notify_waiters();
82    }
83
84    fn listen(&self) -> Arc<tokio::sync::Notify> {
85        self.inner.read().unwrap().notify.clone()
86    }
87}
88
89/// Transactional Reference
90pub struct TRef<A> {
91    id: u64,
92    // Store type-erased entry for Journal interactions
93    entry: Arc<EntryImpl<A>>,
94}
95
96impl<A> Clone for TRef<A> {
97    fn clone(&self) -> Self {
98        Self {
99            id: self.id,
100            entry: self.entry.clone(),
101        }
102    }
103}
104
105struct TRefState<A> {
106    version: u64,
107    value: A,
108    notify: Arc<tokio::sync::Notify>,
109}
110
111type ReadEntry = (u64, Arc<dyn AbstractEntry>);
112type WriteEntry = (Box<dyn Any + Send + Sync>, Arc<dyn AbstractEntry>);
113
114/// Transaction Journal
115pub struct Journal {
116    // Reads: Map TRefId -> (Version, Entry)
117    reads: HashMap<u64, ReadEntry>,
118    // Writes: Map TRefId -> (NewValue, Entry)
119    writes: HashMap<u64, WriteEntry>,
120}
121
122impl Default for Journal {
123    fn default() -> Self {
124        Self::new()
125    }
126}
127
128impl Journal {
129    pub fn new() -> Self {
130        Self {
131            reads: HashMap::new(),
132            writes: HashMap::new(),
133        }
134    }
135}
136
137static TREF_ID_COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
138
139impl<A> TRef<A>
140where
141    A: Send + Sync + Clone + 'static,
142{
143    pub fn new(value: A) -> Self {
144        let id = TREF_ID_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
145        let notify = Arc::new(tokio::sync::Notify::new());
146        let inner = Arc::new(RwLock::new(TRefState {
147            version: 0,
148            value,
149            notify,
150        }));
151        Self {
152            id,
153            entry: Arc::new(EntryImpl { id, inner }),
154        }
155    }
156
157    pub fn get(&self) -> STM<(), A> {
158        let entry = self.entry.clone();
159        let id = self.id;
160        STM {
161            run: Arc::new(move |journal| {
162                if let Some((val, _)) = journal.writes.get(&id) {
163                    if let Some(typed_val) = val.downcast_ref::<A>() {
164                        return STMResult::Success(typed_val.clone());
165                    } else {
166                        panic!("TRef type mismatch");
167                    }
168                }
169
170                // Read via Entry
171                // We need to access value directly. Entry trait only exposes generic ops?
172                // We shouldn't lock via trait if we have typed access here.
173                // But we must assume 'entry' is sufficient.
174                // Actually, `self.entry.inner` is accessible here because we are in `get` closure
175                // which captures `self.entry` (EntryImpl<A>).
176                let guard = entry.inner.read().unwrap();
177                journal
178                    .reads
179                    .entry(id)
180                    .or_insert((guard.version, entry.clone())); // Store generic entry
181                STMResult::Success(guard.value.clone())
182            }),
183        }
184    }
185
186    pub fn set(&self, value: A) -> STM<(), ()> {
187        let entry = self.entry.clone();
188        let id = self.id;
189        STM {
190            run: Arc::new(move |journal| {
191                journal
192                    .writes
193                    .insert(id, (Box::new(value.clone()), entry.clone()));
194                STMResult::Success(())
195            }),
196        }
197    }
198}
199
200// STM Combinators
201
202impl<E, A> STM<E, A>
203where
204    E: 'static,
205    A: 'static,
206{
207    pub fn succeed(a: A) -> STM<E, A>
208    where
209        A: Clone + Send + Sync,
210    {
211        STM {
212            run: Arc::new(move |_| STMResult::Success(a.clone())),
213        }
214    }
215
216    pub fn fail(e: E) -> STM<E, A>
217    where
218        E: Clone + Send + Sync,
219    {
220        STM {
221            run: Arc::new(move |_| STMResult::Failure(e.clone())),
222        }
223    }
224
225    pub fn retry() -> STM<E, A> {
226        STM {
227            run: Arc::new(|_| STMResult::Retry),
228        }
229    }
230
231    pub fn map<B>(self, f: impl Fn(A) -> B + Send + Sync + 'static) -> STM<E, B>
232    where
233        A: Send + Sync,
234        B: Send + Sync,
235    {
236        STM {
237            run: Arc::new(move |journal| match (self.run)(journal) {
238                STMResult::Success(a) => STMResult::Success(f(a)),
239                STMResult::Failure(e) => STMResult::Failure(e),
240                STMResult::Retry => STMResult::Retry,
241            }),
242        }
243    }
244
245    pub fn flat_map<B>(self, f: impl Fn(A) -> STM<E, B> + Send + Sync + 'static) -> STM<E, B>
246    where
247        A: Send + Sync,
248        B: Send + Sync,
249    {
250        STM {
251            run: Arc::new(move |journal| match (self.run)(journal) {
252                STMResult::Success(a) => {
253                    let next = f(a);
254                    (next.run)(journal)
255                }
256                STMResult::Failure(e) => STMResult::Failure(e),
257                STMResult::Retry => STMResult::Retry,
258            }),
259        }
260    }
261
262    pub fn map_error<E2>(self, f: impl Fn(E) -> E2 + Send + Sync + 'static) -> STM<E2, A>
263    where
264        E2: 'static,
265    {
266        STM {
267            run: Arc::new(move |journal| match (self.run)(journal) {
268                STMResult::Success(a) => STMResult::Success(a),
269                STMResult::Failure(e) => STMResult::Failure(f(e)),
270                STMResult::Retry => STMResult::Retry,
271            }),
272        }
273    }
274}
275
276// Convert STM to Effect
277impl<E, A> STM<E, A>
278where
279    E: Send + Sync + Clone + 'static,
280    A: Send + Sync + Clone + 'static,
281{
282    pub fn commit<R>(self) -> Effect<R, E, A>
283    where
284        R: Send + Sync + 'static + Clone,
285    {
286        Effect {
287            inner: Arc::new(move |_, _| {
288                // R is ignored for now
289                let stm = self.clone();
290                Box::pin(async move {
291                    loop {
292                        let mut journal = Journal::new();
293                        let result = (stm.run)(&mut journal);
294
295                        match result {
296                            STMResult::Success(val) => {
297                                // 1. Gather all unique IDs involved (reads + writes)
298                                // But for validation we only care about Reads.
299                                // For locking, typical STM locks writes.
300                                // Optimistic: Validate Reads, Lock Writes.
301
302                                // Lock Implementation (Simplified Global ordered lock on TRefs?)
303                                // We need to lock ALL modified TRefs to prevent write skews.
304                                // Basic Strat: Lock everything involved in arbitrary order (risk deadlock)
305                                // Or sort by ID to prevent deadlock.
306
307                                let mut write_ids: Vec<u64> =
308                                    journal.writes.keys().cloned().collect();
309                                write_ids.sort(); // Lock in order
310
311                                // We need to get the Entries.
312                                // Actually, we can just grab entries from map.
313
314                                // Lock Phase (Only abstract entries, we can't really "lock" generic traits easily without exposing Mutex)
315                                // The AbstractEntry uses RwLock.
316                                // We can't hold multiple guards easily in a generic list in implementation.
317                                // This is tricky in Rust.
318
319                                // Alternative: Global STM Lock (Mutex).
320                                // Very slow but safe.
321
322                                // Let's implement Optimistic Validation logic without holding locks across all operations if possible?
323                                // No, commit must be atomic.
324
325                                // Let's use a "Try Commit" approach.
326                                // 1. Validate all Reads (check versions == current).
327                                //    If any fail, RETRY loop.
328                                // 2. If valid, try to Lock all Writes.
329                                // 3. Re-validate Reads (because between 1 and 2, things changed).
330                                // 4. Update.
331
332                                // Since `AbstractEntry` hides `RwLock`, we implemented `commit` which locks internally.
333                                // But we need to lock multiple at once.
334
335                                // Hack: Global STM Mutex to serialize Commits.
336                                // Reads can run concurrent. Commits are serialized.
337                                // This is a common simplified STM strategy.
338                                // Let's do that for now.
339
340                                {
341                                    // GLOBAL_COMMIT_LOCK
342                                    let _guard = GLOBAL_STM_LOCK.lock().unwrap();
343
344                                    // Validate Reads
345                                    let mut valid = true;
346                                    for (ver, entry) in journal.reads.values() {
347                                        if entry.read_version() != *ver {
348                                            valid = false;
349                                            break;
350                                        }
351                                    }
352
353                                    if !valid {
354                                        continue; // Retry loop
355                                    }
356
357                                    // Commit Writes
358                                    for (_, (val, entry)) in journal.writes {
359                                        entry.commit(val);
360                                    }
361                                }
362
363                                return Exit::Success(val);
364                            }
365                            STMResult::Failure(e) => return Exit::Failure(Cause::Fail(e)),
366                            STMResult::Retry => {
367                                // Collect notifications
368                                let notifiers: Vec<Arc<tokio::sync::Notify>> = journal
369                                    .reads
370                                    .values()
371                                    .map(|(_, entry)| entry.listen())
372                                    .collect();
373
374                                if notifiers.is_empty() {
375                                    // Retry called but no reads? Deadlock/Infinite loop.
376                                    // Just yield or fail.
377                                    return Exit::Failure(Cause::Die(Arc::new(
378                                        "STM Retry with no dependencies".to_string(),
379                                    )));
380                                }
381
382                                // Wait for ANY notification
383                                // Basic: Wait for first one.
384                                // Optimized: Select all.
385                                // Tokio Notify doesn't support select_all trivially on a Vec.
386                                // We can spawn tasks or just wait on one?
387                                // If we wait on just one, we might miss others? No, we just need *one* change to retry.
388                                // But if we pick the wrong one (that doesn't change), we sleep forever?
389                                // We need to wait until ANY of them fires.
390
391                                // We need to wait until ANY of them fires.
392
393                                // Pin the futures on the heap so they are Unpin for select_all
394                                let futures: Vec<_> =
395                                    notifiers.iter().map(|n| Box::pin(n.notified())).collect();
396
397                                futures::future::select_all(futures).await;
398
399                                continue;
400                            }
401                        }
402                    }
403                })
404            }),
405        }
406    }
407}
408
/// Process-wide lock that serializes STM commits (read validation plus publication of
/// writes).
///
/// `Mutex::new` is a `const fn`, so a plain `static` suffices. No lazy initialization
/// or macro is needed, and call sites (`GLOBAL_STM_LOCK.lock()`) are unchanged.
static GLOBAL_STM_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());