Skip to main content

mlua_lspec/
doubles.rs

1//! Spy / stub / mock test doubles implemented as Lua `UserData`.
2//!
3//! Provides [`register`] to inject a `test_doubles` global table into
4//! a Lua VM with the following factory functions:
5//!
6//! ```lua
7//! -- spy: transparently records calls while delegating to the original
8//! local s = test_doubles.spy(function(x) return x * 2 end)
9//! s(5)
10//! assert(s:call_count() == 1)
11//! assert(s:was_called_with(5))
12//!
13//! -- stub: returns fixed values without calling any original
14//! local st = test_doubles.stub()
15//! st:returns(42)
16//! assert(st() == 42)
17//!
18//! -- spy_on: replaces a table method with a spy, supports revert()
19//! local obj = { greet = function(name) return "hello " .. name end }
20//! local s = test_doubles.spy_on(obj, "greet")
21//! obj.greet("world")
22//! assert(s:call_count() == 1)
23//! s:revert()  -- restore the original method
24//! ```
25
26use std::sync::{Arc, Mutex};
27
28use mlua::prelude::*;
29
30// ── Value comparison ────────────────────────────────────────────
31
32/// Compare two Lua values for equality (primitive types only).
33///
34/// Uses exact comparison semantics consistent with Lua 5.4's `==`
35/// operator and the lust framework's `eq()` helper (which defaults
36/// to `eps = 0`).
37///
38/// For Integer↔Number cross-type comparison, a round-trip check
39/// guards against precision loss when `i64` values exceed 2^53
40/// (the limit of exact integer representation in `f64`).
41///
42/// Tables, functions, and userdata are compared by identity
43/// (always `false` here — use `call_args()` for complex assertions).
44fn values_match(a: &LuaValue, b: &LuaValue) -> bool {
45    use LuaValue::*;
46    match (a, b) {
47        (Nil, Nil) => true,
48        (Boolean(a), Boolean(b)) => a == b,
49        (Integer(a), Integer(b)) => a == b,
50        (Number(a), Number(b)) => a == b,
51        (Integer(a), Number(b)) => {
52            let f = *a as f64;
53            f == *b && f as i64 == *a
54        }
55        (Number(a), Integer(b)) => {
56            let f = *b as f64;
57            *a == f && *a as i64 == *b
58        }
59        (String(a), String(b)) => a.as_bytes() == b.as_bytes(),
60        _ => false,
61    }
62}
63
64// ── Shared state ────────────────────────────────────────────────
65
66struct RevertInfo {
67    target: LuaTable,
68    key: String,
69    original: LuaFunction,
70}
71
72struct DoubleState {
73    /// Recorded calls — each entry is the argument list of one invocation.
74    calls: Vec<Vec<LuaValue>>,
75    /// When set, `__call` returns these values instead of calling through.
76    return_values: Option<Vec<LuaValue>>,
77    /// The original function (for call-through spies).
78    original: Option<LuaFunction>,
79    /// Whether `__call` should delegate to `original`.
80    call_through: bool,
81    /// Present only for `spy_on` — stores what to restore on `revert()`.
82    revert_info: Option<RevertInfo>,
83}
84
85// ── UserData ────────────────────────────────────────────────────
86
87/// A test double (spy or stub) exposed to Lua as `UserData`.
88///
89/// Created via the `test_doubles.spy()`, `test_doubles.stub()`, or
90/// `test_doubles.spy_on()` factory functions.
91pub(crate) struct LuaDouble {
92    state: Arc<Mutex<DoubleState>>,
93}
94
95impl LuaDouble {
96    fn new_spy(original: Option<LuaFunction>) -> Self {
97        Self {
98            state: Arc::new(Mutex::new(DoubleState {
99                calls: Vec::new(),
100                return_values: None,
101                original,
102                call_through: true,
103                revert_info: None,
104            })),
105        }
106    }
107
108    fn new_stub() -> Self {
109        Self {
110            state: Arc::new(Mutex::new(DoubleState {
111                calls: Vec::new(),
112                return_values: None,
113                original: None,
114                call_through: false,
115                revert_info: None,
116            })),
117        }
118    }
119
120    fn new_table_spy(original: LuaFunction, target: LuaTable, key: String) -> Self {
121        let call_fn = original.clone();
122        Self {
123            state: Arc::new(Mutex::new(DoubleState {
124                calls: Vec::new(),
125                return_values: None,
126                original: Some(call_fn),
127                call_through: true,
128                revert_info: Some(RevertInfo {
129                    target,
130                    key,
131                    original,
132                }),
133            })),
134        }
135    }
136
137    fn lock(&self) -> LuaResult<std::sync::MutexGuard<'_, DoubleState>> {
138        self.state
139            .lock()
140            .map_err(|e| LuaError::runtime(format!("spy state poisoned: {e}")))
141    }
142}
143
144impl LuaUserData for LuaDouble {
145    fn add_methods<M: LuaUserDataMethods<Self>>(methods: &mut M) {
146        // ── __call: record args, then delegate or stub ──────────
147
148        methods.add_meta_method(LuaMetaMethod::Call, |_, this, args: LuaMultiValue| {
149            let args_vec: Vec<LuaValue> = args.into_vec();
150
151            let mut state = this.lock()?;
152            state.calls.push(args_vec.clone());
153
154            // Stub: return fixed values.
155            if let Some(ref rv) = state.return_values {
156                return Ok(LuaMultiValue::from_vec(rv.clone()));
157            }
158
159            // Call-through: delegate to original.
160            if state.call_through {
161                if let Some(ref original) = state.original {
162                    let f = original.clone();
163                    // Drop the lock before calling into Lua to avoid
164                    // deadlock if the called function re-enters the spy.
165                    drop(state);
166                    return f.call(LuaMultiValue::from_vec(args_vec));
167                }
168            }
169
170            // No original and no stub values — return nothing.
171            Ok(LuaMultiValue::new())
172        });
173
174        // ── __len: call count (lust-compatible #spy) ────────────
175
176        methods.add_meta_method(LuaMetaMethod::Len, |_, this, ()| {
177            Ok(this.lock()?.calls.len())
178        });
179
180        // ── Inspection methods ──────────────────────────────────
181
182        methods.add_method("call_count", |_, this, ()| Ok(this.lock()?.calls.len()));
183
184        methods.add_method("call_args", |lua, this, n: usize| {
185            let state = this.lock()?;
186            let idx = n
187                .checked_sub(1)
188                .ok_or_else(|| LuaError::runtime("call index must be >= 1"))?;
189            let call = state
190                .calls
191                .get(idx)
192                .ok_or_else(|| LuaError::runtime(format!("no call at index {n}")))?;
193            let table = lua.create_table()?;
194            for (i, arg) in call.iter().enumerate() {
195                table.set(i + 1, arg.clone())?;
196            }
197            Ok(table)
198        });
199
200        methods.add_method("was_called_with", |_, this, args: LuaMultiValue| {
201            let expected: Vec<LuaValue> = args.into_vec();
202            let state = this.lock()?;
203            for call in &state.calls {
204                if call.len() == expected.len()
205                    && call
206                        .iter()
207                        .zip(expected.iter())
208                        .all(|(a, b)| values_match(a, b))
209                {
210                    return Ok(true);
211                }
212            }
213            Ok(false)
214        });
215
216        // ── Mutation methods ────────────────────────────────────
217
218        methods.add_method("returns", |_, this, args: LuaMultiValue| {
219            let mut state = this.lock()?;
220            state.return_values = Some(args.into_vec());
221            state.call_through = false;
222            Ok(())
223        });
224
225        // Clear recorded call history.
226        //
227        // Only resets `calls`; `return_values` set via `returns()` are
228        // preserved.  This mirrors common test-double conventions
229        // (e.g. Sinon.js `resetHistory`) where resetting history and
230        // resetting behaviour are separate operations.
231        methods.add_method("reset", |_, this, ()| {
232            let mut state = this.lock()?;
233            state.calls.clear();
234            Ok(())
235        });
236
237        methods.add_method("revert", |_, this, ()| {
238            let state = this.lock()?;
239            if let Some(ref info) = state.revert_info {
240                info.target.set(info.key.as_str(), info.original.clone())?;
241            }
242            Ok(())
243        });
244    }
245}
246
247// ── Registration ────────────────────────────────────────────────
248
249/// Register the `test_doubles` global table into the given Lua VM.
250///
251/// Provides `test_doubles.spy(fn)`, `test_doubles.stub()`, and
252/// `test_doubles.spy_on(table, key)`.
253pub fn register(lua: &Lua) -> LuaResult<()> {
254    let doubles = lua.create_table()?;
255
256    // test_doubles.spy(fn?) → LuaDouble
257    doubles.set(
258        "spy",
259        lua.create_function(|_, func: Option<LuaFunction>| Ok(LuaDouble::new_spy(func)))?,
260    )?;
261
262    // test_doubles.stub() → LuaDouble
263    doubles.set(
264        "stub",
265        lua.create_function(|_, ()| Ok(LuaDouble::new_stub()))?,
266    )?;
267
268    // test_doubles.spy_on(table, key) → LuaDouble
269    //
270    // Replaces `table[key]` with a spy that calls through to the
271    // original.  Call `spy:revert()` to restore.
272    doubles.set(
273        "spy_on",
274        lua.create_function(|lua, (table, key): (LuaTable, String)| {
275            let original: LuaFunction = table.get(key.as_str())?;
276            let spy = LuaDouble::new_table_spy(original, table.clone(), key.clone());
277            let ud = lua.create_userdata(spy)?;
278            table.set(key.as_str(), ud.clone())?;
279            Ok(ud)
280        })?,
281    )?;
282
283    lua.globals().set("test_doubles", doubles)?;
284    Ok(())
285}