effect_rs/stm.rs
1use crate::core::{Cause, Effect, Exit};
2use std::any::Any;
3use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5
6type STMRun<E, A> = dyn Fn(&mut Journal) -> STMResult<E, A> + Send + Sync;
7
8/// STM Effect
9pub struct STM<E, A> {
10 pub(crate) run: Arc<STMRun<E, A>>,
11}
12
13impl<E, A> Clone for STM<E, A> {
14 fn clone(&self) -> Self {
15 Self {
16 run: self.run.clone(),
17 }
18 }
19}
20
21pub enum STMResult<E, A> {
22 Success(A),
23 Failure(E),
24 Retry,
25}
26
27/// Abstract wrapper around a TRef state to allow locking in the Journal
28#[allow(dead_code)]
29trait AbstractEntry: Send + Sync {
30 fn id(&self) -> u64;
31 fn version(&self) -> u64;
32 // We need to be able to clone the inner Arc<RwLock> but type erased?
33 // Actually, we just need to be able to lock it.
34 // The Journal needs to hold `dyn AbstractEntry`?
35 // If we hold `Arc<dyn AbstractEntry>`, we can't lock easily because RwLock is generic.
36 // Let's invert: TRef<A> implements Entry which holds Arc<RwLock<State<A>>>.
37 // We need a non-generic trait to store in Journal list.
38
39 // Lock for reading version
40 fn read_version(&self) -> u64;
41 // Lock for writing (update value and version)
42 fn commit(&self, value: Box<dyn Any + Send + Sync>);
43
44 fn notify(&self);
45 fn listen(&self) -> Arc<tokio::sync::Notify>;
46}
47
48struct EntryImpl<A> {
49 #[allow(dead_code)]
50 id: u64,
51 inner: Arc<RwLock<TRefState<A>>>,
52}
53
54impl<A> AbstractEntry for EntryImpl<A>
55where
56 A: Send + Sync + 'static,
57{
58 fn id(&self) -> u64 {
59 self.id
60 }
61 fn version(&self) -> u64 {
62 self.inner.read().unwrap().version
63 }
64
65 fn read_version(&self) -> u64 {
66 self.inner.read().unwrap().version
67 }
68
69 fn commit(&self, value: Box<dyn Any + Send + Sync>) {
70 if let Ok(typed) = value.downcast::<A>() {
71 let mut guard = self.inner.write().unwrap();
72 guard.value = *typed;
73 guard.version += 1;
74 guard.notify.notify_waiters();
75 } else {
76 panic!("STM Commit Type Mismatch");
77 }
78 }
79
80 fn notify(&self) {
81 self.inner.read().unwrap().notify.notify_waiters();
82 }
83
84 fn listen(&self) -> Arc<tokio::sync::Notify> {
85 self.inner.read().unwrap().notify.clone()
86 }
87}
88
89/// Transactional Reference
90pub struct TRef<A> {
91 id: u64,
92 // Store type-erased entry for Journal interactions
93 entry: Arc<EntryImpl<A>>,
94}
95
96impl<A> Clone for TRef<A> {
97 fn clone(&self) -> Self {
98 Self {
99 id: self.id,
100 entry: self.entry.clone(),
101 }
102 }
103}
104
105struct TRefState<A> {
106 version: u64,
107 value: A,
108 notify: Arc<tokio::sync::Notify>,
109}
110
111type ReadEntry = (u64, Arc<dyn AbstractEntry>);
112type WriteEntry = (Box<dyn Any + Send + Sync>, Arc<dyn AbstractEntry>);
113
114/// Transaction Journal
115pub struct Journal {
116 // Reads: Map TRefId -> (Version, Entry)
117 reads: HashMap<u64, ReadEntry>,
118 // Writes: Map TRefId -> (NewValue, Entry)
119 writes: HashMap<u64, WriteEntry>,
120}
121
122impl Default for Journal {
123 fn default() -> Self {
124 Self::new()
125 }
126}
127
128impl Journal {
129 pub fn new() -> Self {
130 Self {
131 reads: HashMap::new(),
132 writes: HashMap::new(),
133 }
134 }
135}
136
137static TREF_ID_COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
138
139impl<A> TRef<A>
140where
141 A: Send + Sync + Clone + 'static,
142{
143 pub fn new(value: A) -> Self {
144 let id = TREF_ID_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
145 let notify = Arc::new(tokio::sync::Notify::new());
146 let inner = Arc::new(RwLock::new(TRefState {
147 version: 0,
148 value,
149 notify,
150 }));
151 Self {
152 id,
153 entry: Arc::new(EntryImpl { id, inner }),
154 }
155 }
156
157 pub fn get(&self) -> STM<(), A> {
158 let entry = self.entry.clone();
159 let id = self.id;
160 STM {
161 run: Arc::new(move |journal| {
162 if let Some((val, _)) = journal.writes.get(&id) {
163 if let Some(typed_val) = val.downcast_ref::<A>() {
164 return STMResult::Success(typed_val.clone());
165 } else {
166 panic!("TRef type mismatch");
167 }
168 }
169
170 // Read via Entry
171 // We need to access value directly. Entry trait only exposes generic ops?
172 // We shouldn't lock via trait if we have typed access here.
173 // But we must assume 'entry' is sufficient.
174 // Actually, `self.entry.inner` is accessible here because we are in `get` closure
175 // which captures `self.entry` (EntryImpl<A>).
176 let guard = entry.inner.read().unwrap();
177 journal
178 .reads
179 .entry(id)
180 .or_insert((guard.version, entry.clone())); // Store generic entry
181 STMResult::Success(guard.value.clone())
182 }),
183 }
184 }
185
186 pub fn set(&self, value: A) -> STM<(), ()> {
187 let entry = self.entry.clone();
188 let id = self.id;
189 STM {
190 run: Arc::new(move |journal| {
191 journal
192 .writes
193 .insert(id, (Box::new(value.clone()), entry.clone()));
194 STMResult::Success(())
195 }),
196 }
197 }
198}
199
200// STM Combinators
201
202impl<E, A> STM<E, A>
203where
204 E: 'static,
205 A: 'static,
206{
207 pub fn succeed(a: A) -> STM<E, A>
208 where
209 A: Clone + Send + Sync,
210 {
211 STM {
212 run: Arc::new(move |_| STMResult::Success(a.clone())),
213 }
214 }
215
216 pub fn fail(e: E) -> STM<E, A>
217 where
218 E: Clone + Send + Sync,
219 {
220 STM {
221 run: Arc::new(move |_| STMResult::Failure(e.clone())),
222 }
223 }
224
225 pub fn retry() -> STM<E, A> {
226 STM {
227 run: Arc::new(|_| STMResult::Retry),
228 }
229 }
230
231 pub fn map<B>(self, f: impl Fn(A) -> B + Send + Sync + 'static) -> STM<E, B>
232 where
233 A: Send + Sync,
234 B: Send + Sync,
235 {
236 STM {
237 run: Arc::new(move |journal| match (self.run)(journal) {
238 STMResult::Success(a) => STMResult::Success(f(a)),
239 STMResult::Failure(e) => STMResult::Failure(e),
240 STMResult::Retry => STMResult::Retry,
241 }),
242 }
243 }
244
245 pub fn flat_map<B>(self, f: impl Fn(A) -> STM<E, B> + Send + Sync + 'static) -> STM<E, B>
246 where
247 A: Send + Sync,
248 B: Send + Sync,
249 {
250 STM {
251 run: Arc::new(move |journal| match (self.run)(journal) {
252 STMResult::Success(a) => {
253 let next = f(a);
254 (next.run)(journal)
255 }
256 STMResult::Failure(e) => STMResult::Failure(e),
257 STMResult::Retry => STMResult::Retry,
258 }),
259 }
260 }
261
262 pub fn map_error<E2>(self, f: impl Fn(E) -> E2 + Send + Sync + 'static) -> STM<E2, A>
263 where
264 E2: 'static,
265 {
266 STM {
267 run: Arc::new(move |journal| match (self.run)(journal) {
268 STMResult::Success(a) => STMResult::Success(a),
269 STMResult::Failure(e) => STMResult::Failure(f(e)),
270 STMResult::Retry => STMResult::Retry,
271 }),
272 }
273 }
274}
275
276// Convert STM to Effect
277impl<E, A> STM<E, A>
278where
279 E: Send + Sync + Clone + 'static,
280 A: Send + Sync + Clone + 'static,
281{
282 pub fn commit<R>(self) -> Effect<R, E, A>
283 where
284 R: Send + Sync + 'static + Clone,
285 {
286 Effect {
287 inner: Arc::new(move |_, _| {
288 // R is ignored for now
289 let stm = self.clone();
290 Box::pin(async move {
291 loop {
292 let mut journal = Journal::new();
293 let result = (stm.run)(&mut journal);
294
295 match result {
296 STMResult::Success(val) => {
297 // 1. Gather all unique IDs involved (reads + writes)
298 // But for validation we only care about Reads.
299 // For locking, typical STM locks writes.
300 // Optimistic: Validate Reads, Lock Writes.
301
302 // Lock Implementation (Simplified Global ordered lock on TRefs?)
303 // We need to lock ALL modified TRefs to prevent write skews.
304 // Basic Strat: Lock everything involved in arbitrary order (risk deadlock)
305 // Or sort by ID to prevent deadlock.
306
307 let mut write_ids: Vec<u64> =
308 journal.writes.keys().cloned().collect();
309 write_ids.sort(); // Lock in order
310
311 // We need to get the Entries.
312 // Actually, we can just grab entries from map.
313
314 // Lock Phase (Only abstract entries, we can't really "lock" generic traits easily without exposing Mutex)
315 // The AbstractEntry uses RwLock.
316 // We can't hold multiple guards easily in a generic list in implementation.
317 // This is tricky in Rust.
318
319 // Alternative: Global STM Lock (Mutex).
320 // Very slow but safe.
321
322 // Let's implement Optimistic Validation logic without holding locks across all operations if possible?
323 // No, commit must be atomic.
324
325 // Let's use a "Try Commit" approach.
326 // 1. Validate all Reads (check versions == current).
327 // If any fail, RETRY loop.
328 // 2. If valid, try to Lock all Writes.
329 // 3. Re-validate Reads (because between 1 and 2, things changed).
330 // 4. Update.
331
332 // Since `AbstractEntry` hides `RwLock`, we implemented `commit` which locks internally.
333 // But we need to lock multiple at once.
334
335 // Hack: Global STM Mutex to serialize Commits.
336 // Reads can run concurrent. Commits are serialized.
337 // This is a common simplified STM strategy.
338 // Let's do that for now.
339
340 {
341 // GLOBAL_COMMIT_LOCK
342 let _guard = GLOBAL_STM_LOCK.lock().unwrap();
343
344 // Validate Reads
345 let mut valid = true;
346 for (ver, entry) in journal.reads.values() {
347 if entry.read_version() != *ver {
348 valid = false;
349 break;
350 }
351 }
352
353 if !valid {
354 continue; // Retry loop
355 }
356
357 // Commit Writes
358 for (_, (val, entry)) in journal.writes {
359 entry.commit(val);
360 }
361 }
362
363 return Exit::Success(val);
364 }
365 STMResult::Failure(e) => return Exit::Failure(Cause::Fail(e)),
366 STMResult::Retry => {
367 // Collect notifications
368 let notifiers: Vec<Arc<tokio::sync::Notify>> = journal
369 .reads
370 .values()
371 .map(|(_, entry)| entry.listen())
372 .collect();
373
374 if notifiers.is_empty() {
375 // Retry called but no reads? Deadlock/Infinite loop.
376 // Just yield or fail.
377 return Exit::Failure(Cause::Die(Arc::new(
378 "STM Retry with no dependencies".to_string(),
379 )));
380 }
381
382 // Wait for ANY notification
383 // Basic: Wait for first one.
384 // Optimized: Select all.
385 // Tokio Notify doesn't support select_all trivially on a Vec.
386 // We can spawn tasks or just wait on one?
387 // If we wait on just one, we might miss others? No, we just need *one* change to retry.
388 // But if we pick the wrong one (that doesn't change), we sleep forever?
389 // We need to wait until ANY of them fires.
390
391 // We need to wait until ANY of them fires.
392
393 // Pin the futures on the heap so they are Unpin for select_all
394 let futures: Vec<_> =
395 notifiers.iter().map(|n| Box::pin(n.notified())).collect();
396
397 futures::future::select_all(futures).await;
398
399 continue;
400 }
401 }
402 }
403 })
404 }),
405 }
406 }
407}
408
/// Serializes STM commits: a transaction's read validation and the
/// application of its writes happen while this lock is held.
///
/// `Mutex::new` is a `const fn` (Rust 1.63+), so no lazy initialisation
/// (`lazy_static!`) is needed.
static GLOBAL_STM_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());