-- `M` is the public module table returned at the end of the file.
-- `_stack` is a LIFO list of restore frames: `M.install` and `alc.spy` push
-- onto it, and `M.restore` pops from the top.
local M, _stack = {}, {}
--- Record the current `alc` value of each key so it can be put back later.
-- @tparam table keys sequence of `alc` keys (walked like `ipairs`, stopping at the first nil)
-- @treturn table sequence of `{ key = k, prev = <old value>, had = <old value ~= nil> }`
--   entries, in the same order as `keys`
local function _snapshot(keys)
  local entries = {}
  local idx = 1
  local k = keys[idx]
  while k ~= nil do
    local entry = { key = k }
    entry.prev = alc[k]
    entry.had = alc[k] ~= nil
    entries[idx] = entry
    idx = idx + 1
    k = keys[idx]
  end
  return entries
end
--- Write every `overrides` entry into `alc` after snapshotting the old values.
-- All keys are snapshotted before any assignment happens.
-- @tparam table overrides map of `alc` key -> new value
-- @treturn table restore frame (as built by `_snapshot`) that undoes this call
local function _apply(overrides)
  local names, count = {}, 0
  for name in pairs(overrides) do
    count = count + 1
    names[count] = name
  end
  local undo = _snapshot(names)
  for name, value in pairs(overrides) do
    alc[name] = value
  end
  return undo
end
--- Undo a restore frame produced by `_snapshot` / `_apply`.
-- A key that held a value gets that value back. A key that was absent
-- (`had == false`) is set to nil, i.e. removed from `alc`.
-- @tparam table frame sequence of `{ key, prev, had }` entries
local function _restore_frame(frame)
  for _, saved in ipairs(frame) do
    local value = nil
    if saved.had then
      value = saved.prev
    end
    alc[saved.key] = value
  end
end
--- Finish a `with_alc` call: restore `frame`, then re-raise the failure or
-- pass through every value `fn` returned (trailing nils included).
local function _finish_with_alc(frame, ok, ...)
  _restore_frame(frame)
  if not ok then
    -- `(...)` is pcall's error value. Level 0 re-raises it without adding
    -- position information.
    error((...), 0)
  end
  return ...
end

--- Run `fn` with `alc` temporarily overridden. The previous values are
-- restored afterwards, even when `fn` raises. Keys that were absent from
-- `alc` before the call are removed again on exit.
-- NOTE(review): only this call's own overrides are undone. Frames pushed onto
-- the shared restore stack inside `fn` (via `alc_mock.install` or `alc.spy`)
-- stay on the stack. If they cover the same keys, a later
-- `alc_mock.restore()` can put this call's override values back.
-- @tparam table overrides map of `alc` key -> temporary value
-- @tparam function fn called with no arguments while the overrides are active
-- @return every value returned by `fn`
-- @raise if `overrides` is not a table or `fn` is not a function; an error
--   raised by `fn` is re-raised unchanged after `alc` has been restored
function with_alc(overrides, fn)
  if type(overrides) ~= "table" then
    error("with_alc: overrides must be a table", 2)
  end
  if type(fn) ~= "function" then
    error("with_alc: fn must be a function", 2)
  end
  local frame = _apply(overrides)
  -- Forward pcall's full result list so no return value of `fn` is dropped.
  return _finish_with_alc(frame, pcall(fn))
end
--- Apply `overrides` to `alc` and push the matching restore frame onto the
-- shared stack, so a later `M.restore()` undoes it.
-- @tparam table overrides map of `alc` key -> replacement value
-- @raise if `overrides` is not a table
function M.install(overrides)
  if type(overrides) == "table" then
    local pushed = _apply(overrides)
    _stack[#_stack + 1] = pushed
    return
  end
  error("alc_mock.install: overrides must be a table", 2)
end
--- Pop the most recent restore frame and undo it. Does nothing when the stack is empty.
function M.restore()
  local top = #_stack
  if top > 0 then
    local latest = _stack[top]
    _stack[top] = nil
    _restore_frame(latest)
  end
end
--- Undo every pushed frame, newest first, by calling `M.restore` until the
-- stack is empty.
function M.restore_all()
  local depth = #_stack
  while depth > 0 do
    M.restore()
    depth = #_stack
  end
end
--- Replace `alc[name]` with a spy that records each call and forwards it.
-- Calls go to `default_fn` when one is given. Otherwise they go to whatever
-- `alc[name]` held when the spy was created. If that target is not a
-- function, the spy returns nil. The replacement is pushed onto the shared
-- restore stack, so `alc_mock.restore()` / `alc_mock.restore_all()` remove it.
-- @tparam string name key in `alc` to replace
-- @tparam[opt] function default_fn implementation to forward calls to
-- @treturn table handle with fields:
--   `call_count` (number of calls so far),
--   `calls` (one `{ args = {...}, n = <argument count> }` entry per call),
--   `reset()` (clears both; works as `handle:reset()` or `handle.reset()`)
-- @raise if `name` is not a string
function alc.spy(name, default_fn)
  if type(name) ~= "string" then
    error("alc.spy: name must be a string", 2)
  end
  local impl = default_fn
  if impl == nil then
    impl = alc[name]
  end
  -- Declared before assignment so the `reset` closure can capture it.
  local handle
  handle = {
    call_count = 0,
    calls = {},
    reset = function(self)
      -- Fall back to the captured handle so a dot-call `handle.reset()`
      -- works instead of indexing nil.
      local target = self or handle
      target.call_count = 0
      target.calls = {}
    end,
  }
  local wrapper = function(...)
    handle.call_count = handle.call_count + 1
    -- `#args` is unreliable when arguments contain or end in nil, so the
    -- real argument count is stored alongside them in `n`.
    handle.calls[#handle.calls + 1] = { args = { ... }, n = select("#", ...) }
    if type(impl) == "function" then
      return impl(...)
    end
    return nil
  end
  local frame = _apply({ [name] = wrapper })
  _stack[#_stack + 1] = frame
  return handle
end
-- Also publish the module under the global name `alc_mock`; the error
-- messages in `M.install` use that name. Then return it as the chunk's result.
alc_mock = M
return M