Skip to main content

mlua_lspec/
framework.rs

1use mlua::prelude::*;
2
3use crate::types::{TestResult, TestSummary};
4
5const LUST_LUA: &str = include_str!("../lua/lust.lua");
6
7/// Register the lust test framework into the given Lua VM.
8///
9/// After this call, `lust` is available as a global table with
10/// `describe`, `it`, `expect`, `before`, `after`, `spy`, and
11/// `get_results`.
12pub fn register(lua: &Lua) -> LuaResult<()> {
13    let lust: LuaTable = lua.load(LUST_LUA).set_name("lust.lua").eval()?;
14    lua.globals().set("lust", lust)?;
15    Ok(())
16}
17
18/// Collect structured test results from the lust framework.
19///
20/// Call this after all `describe`/`it` blocks have executed.
21pub fn collect_results(lua: &Lua) -> LuaResult<TestSummary> {
22    let lust: LuaTable = lua.globals().get("lust")?;
23    let get_results: LuaFunction = lust.get("get_results")?;
24    let results: LuaTable = get_results.call(())?;
25
26    let passed: usize = results.get("passed")?;
27    let failed: usize = results.get("failed")?;
28    let total: usize = results.get("total")?;
29
30    let tests_table: LuaTable = results.get("tests")?;
31    let mut tests = Vec::with_capacity(total);
32
33    for pair in tests_table.pairs::<usize, LuaTable>() {
34        let (_, t) = pair?;
35        tests.push(TestResult {
36            suite: t.get::<String>("suite")?,
37            name: t.get::<String>("name")?,
38            passed: t.get::<bool>("passed")?,
39            error: t.get::<Option<String>>("error")?,
40        });
41    }
42
43    Ok(TestSummary {
44        passed,
45        failed,
46        total,
47        tests,
48    })
49}
50
51/// Prepend entries to Lua's `package.path` so that `require` can find
52/// modules in project-specific directories.
53///
54/// For each directory in `search_paths`, two patterns are added:
55/// `<dir>/?.lua` and `<dir>/?/init.lua`.
56fn prepend_search_paths(lua: &Lua, search_paths: &[&str]) -> Result<(), String> {
57    if search_paths.is_empty() {
58        return Ok(());
59    }
60    let package: LuaTable = lua
61        .globals()
62        .get("package")
63        .map_err(|e| format!("Failed to get package table: {e}"))?;
64    let current: String = package
65        .get("path")
66        .map_err(|e| format!("Failed to get package.path: {e}"))?;
67
68    let mut prefix = String::new();
69    for dir in search_paths {
70        let dir = dir.trim_end_matches('/');
71        prefix.push_str(dir);
72        prefix.push_str("/?.lua;");
73        prefix.push_str(dir);
74        prefix.push_str("/?/init.lua;");
75    }
76    prefix.push_str(&current);
77
78    package
79        .set("path", prefix)
80        .map_err(|e| format!("Failed to set package.path: {e}"))?;
81    Ok(())
82}
83
84/// Run Lua test code with the lust framework pre-loaded.
85///
86/// Creates a fresh Lua VM, registers lust and test doubles,
87/// executes `code`, and returns the structured test summary.
88///
89/// `search_paths` is prepended to `package.path` so that
90/// `require` can resolve project-specific modules.  Pass `&[]`
91/// when no extra paths are needed.
92///
93/// Lua's `print` is replaced with a no-op to prevent stdout
94/// pollution.  Callers who need console output should use
95/// [`register`] on their own `Lua` instance where `print`
96/// remains intact.
97pub fn run_tests(
98    code: &str,
99    chunk_name: &str,
100    search_paths: &[&str],
101) -> Result<TestSummary, String> {
102    let lua = Lua::new();
103
104    register(&lua).map_err(|e| format!("Failed to register test framework: {e}"))?;
105    crate::doubles::register(&lua).map_err(|e| format!("Failed to register test doubles: {e}"))?;
106
107    prepend_search_paths(&lua, search_paths)?;
108
109    // Suppress lust's print() output so that callers using stdio
110    // transports (e.g. MCP servers) are not polluted.
111    lua.globals()
112        .set(
113            "print",
114            lua.create_function(|_, _: mlua::Variadic<LuaValue>| Ok(()))
115                .map_err(|e| format!("Failed to override print: {e}"))?,
116        )
117        .map_err(|e| format!("Failed to override print: {e}"))?;
118
119    lua.load(code)
120        .set_name(chunk_name)
121        .exec()
122        .map_err(|e| format!("Test execution error: {e}"))?;
123
124    collect_results(&lua).map_err(|e| format!("Failed to collect results: {e}"))
125}