1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
use std::{error::Error, mem};

use crate::auxtx::*;
use crate::{
    transaction::{with_tx, Transaction, TX},
    StmDynError, StmDynResult, StmError, StmResult,
};

/// Give up on the current attempt; the transaction is run again once one of the
/// variables it has read so far has changed.
pub fn retry<T>() -> StmResult<T> {
    let signal = StmError::Retry;
    Err(signal)
}

/// Continue only when `cond` holds; otherwise request a [`retry`] so the
/// transaction is attempted again after its inputs change.
pub fn guard(cond: bool) -> StmResult<()> {
    if !cond {
        return retry();
    }
    Ok(())
}

/// Abandon the transaction permanently instead of retrying it.
///
/// The error `e` is boxed into `StmDynError::Abort`. The `atomically_or_err*`
/// runners hand it back to their caller as `Err`.
pub fn abort<T, E: Error + Send + Sync + 'static>(e: E) -> StmDynResult<T> {
    let boxed: Box<dyn Error + Send + Sync> = Box::new(e);
    Err(StmDynError::Abort(boxed))
}

/// Try `f`, falling back to `g` if `f` asks to retry.
///
/// Before `g` runs, the transaction is restored to its state from before `f`,
/// so `g` does not observe anything `f` did. If `g` also asks to retry, `f`'s
/// log entries are merged back in. Wherever `g` itself read the same variable,
/// `g`'s entry is kept. The combined retry therefore reacts to a change in
/// anything either alternative read.
///
/// Any other outcome of `f` is returned as-is without running `g`. That covers
/// a value, or a `Failure`, which can be retried right away.
pub fn or<F, G, T>(f: F, g: G) -> StmResult<T>
where
    F: FnOnce() -> StmResult<T>,
    G: FnOnce() -> StmResult<T>,
{
    // Copy of the transaction as it was before either alternative ran.
    let mut saved = with_tx(|tx| tx.clone());

    let first = f();
    if !matches!(first, Err(StmError::Retry)) {
        return first;
    }

    // Put the pre-`f` state back in place; `saved` now holds what `f` left behind.
    with_tx(|tx| mem::swap(tx, &mut saved));

    let second = g();
    if matches!(second, Err(StmError::Retry)) {
        // Both sides want to retry: also track everything `f` read.
        with_tx(|tx| {
            for (id, entry) in saved.log {
                let read_by_g = matches!(tx.log.get(&id), Some(existing) if existing.read);
                if !read_by_g {
                    tx.log.insert(id, entry);
                }
            }
        });
    }
    second
}

/// Create a new transaction and run `f` until it returns a successful result and
/// can be committed without running into version conflicts.
///
/// Make sure `f` is free of any side effects, because it can be called repeatedly.
pub async fn atomically<F, T>(f: F) -> T
where
    F: Fn() -> StmResult<T>,
{
    atomically_aux(|| NoAux, |_| f()).await
}

/// Variant of `atomically` that also drives an auxiliary transaction system.
/// The auxiliary transaction is committed or rolled back together with the
/// STM transaction.
///
/// # Panics
/// Panics if an abort error somehow surfaces. Use `atomically_or_err_aux`
/// when aborting is needed.
pub async fn atomically_aux<F, T, A, X>(aux: A, f: F) -> T
where
    X: Aux,
    A: Fn() -> X,
    F: Fn(&mut X) -> StmResult<T>,
{
    let outcome = atomically_or_err_aux(aux, |atx| match f(atx) {
        Ok(value) => Ok(value),
        Err(control) => Err(StmDynError::Control(control)),
    })
    .await;
    outcome.expect("Didn't expect `abort`. Use `atomically_or_err` instead.")
}

/// Create a new transaction and run `f` until it returns a successful result and
/// can be committed without running into version conflicts, or until it returns
/// an `Abort` in which case the contained error is returned.
///
/// Make sure `f` is free of any side effects, becuase it can be called repeatedly
/// and also be aborted.
pub async fn atomically_or_err<F, T>(f: F) -> Result<T, Box<dyn Error + Send + Sync>>
where
    F: Fn() -> StmDynResult<T>,
{
    atomically_or_err_aux(|| NoAux, |_| f()).await
}

/// Like `atomically_or_err`, but this version also takes an auxiliary transaction system.
///
/// `aux` is called once per attempt to create a fresh auxiliary transaction. That
/// transaction is passed to the STM commit when `f` succeeds, and rolled back when
/// `f` returns an error.
///
/// Aux is passed explicitly to the closure as it's more important to see which methods
/// use it and which don't, because it is ultimately an external dependency that needs to
/// be carefully managed.
///
/// For example the method might need only read-only access, in which case a more
/// permissive transaction can be constructed, than if we need write access to arbitrary
/// data managed by that system.
///
/// # Errors
/// Returns the error carried by `StmDynError::Abort` if `f` aborts; that attempt is
/// not retried.
///
/// # Panics
/// Panics if this thread is already running an atomic transaction. Nesting is not
/// supported; use `or` instead. If `aux` or `f` panics, the thread-local transaction
/// slot is cleared before the panic propagates, so the thread can run transactions again.
pub async fn atomically_or_err_aux<F, T, A, X>(
    aux: A,
    f: F,
) -> Result<T, Box<dyn Error + Send + Sync>>
where
    X: Aux,
    A: Fn() -> X,
    F: Fn(&mut X) -> StmDynResult<T>,
{
    /// Empties the thread-local transaction slot when dropped.
    ///
    /// Armed while user code (`aux` and `f`) runs. Without it, a panic there would
    /// leave the transaction installed. On a thread that survives the panic (e.g. one
    /// whose runtime catches task panics), every later transaction would then fail
    /// with "Already running in an atomic transaction!".
    struct ClearTxOnDrop;

    impl Drop for ClearTxOnDrop {
        fn drop(&mut self) {
            // Non-panicking accessors: this may run during unwinding, where a
            // second panic would abort the process.
            let _ = TX.try_with(|tref| {
                if let Ok(mut slot) = tref.try_borrow_mut() {
                    *slot = None;
                }
            });
        }
    }

    loop {
        // Install a new transaction into the thread local context.
        TX.with(|tref| {
            let mut t = tref.borrow_mut();
            if t.is_some() {
                // Nesting is not supported. Use `or` instead.
                panic!("Already running in an atomic transaction!")
            }
            *t = Some(Transaction::new());
        });

        // Armed only after a successful install, so the nesting panic above can
        // never clear a transaction that belongs to someone else.
        let clear_on_panic = ClearTxOnDrop;

        // Create a new auxiliary transaction and run one attempt of the atomic operation.
        let mut atx = aux();
        let result = f(&mut atx);

        // Take the transaction from the thread local, leaving it empty.
        let tx = TX.with(|tref| tref.borrow_mut().take().unwrap());
        // The slot is already empty here, so disarming is a no-op.
        drop(clear_on_panic);

        match result {
            Ok(value) => {
                if let Some(version) = tx.commit(atx) {
                    tx.notify(version);
                    return Ok(value);
                }
                // Commit was rejected (e.g. a version conflict): run another attempt.
            }
            Err(err) => {
                atx.rollback();
                match err {
                    // We can retry straight away.
                    StmDynError::Control(StmError::Failure) => {}
                    // Wait until there's a change.
                    StmDynError::Control(StmError::Retry) => {
                        tx.wait().await;
                    }
                    // Don't retry, return the error to the caller.
                    StmDynError::Abort(e) => return Err(e),
                }
            }
        }
    }
}