1use mlua::{Lua, Result, Table, Value};
28
29pub const LSHAPE_SOURCES: &[(&str, &str)] = &[
35 ("lshape", include_str!("../lua/lshape/init.lua")),
36 ("lshape.t", include_str!("../lua/lshape/t.lua")),
37 ("lshape.check", include_str!("../lua/lshape/check.lua")),
38 ("lshape.reflect", include_str!("../lua/lshape/reflect.lua")),
39 ("lshape.luacats", include_str!("../lua/lshape/luacats.lua")),
40];
41
42pub fn install(lua: &Lua) -> Result<()> {
49 let package: Table = lua.globals().get("package")?;
50 let preload: Table = package.get("preload")?;
51
52 for (name, src) in LSHAPE_SOURCES {
53 let name_owned = (*name).to_owned();
54 let src_owned = (*src).to_owned();
55 let loader = lua.create_function(move |lua, ()| -> Result<Value> {
56 let chunk = lua
57 .load(&src_owned)
58 .set_name(&format!("@{}", name_owned));
59 chunk.eval::<Value>()
60 })?;
61 preload.set(*name, loader)?;
62 }
63
64 Ok(())
65}
66
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a fresh Lua state, installs lshape into it, and executes `script`.
    ///
    /// Any failure (install or Lua error) panics via `unwrap`, failing the test.
    /// `#[track_caller]` makes those panics report the calling test's location.
    #[track_caller]
    fn run_with_lshape(script: &str) {
        let lua = Lua::new();
        install(&lua).unwrap();
        lua.load(script).exec().unwrap();
    }

    /// The root module loads and exposes each submodule as a table field.
    #[test]
    fn install_and_require_lshape() {
        run_with_lshape(
            r#"
            local lshape = require("lshape")
            assert(type(lshape) == "table")
            assert(type(lshape.t) == "table")
            assert(type(lshape.check) == "table")
            assert(type(lshape.reflect) == "table")
            assert(type(lshape.luacats) == "table")
            "#,
        );
    }

    /// A conforming value passes `check`; a non-conforming one fails with a
    /// message mentioning a shape violation.
    #[test]
    fn check_validates_shape() {
        run_with_lshape(
            r#"
            local lshape = require("lshape")
            local T = lshape.t
            local Voted = T.shape({ answer = T.string })
            local ok, _ = lshape.check.check({ answer = "42" }, Voted)
            assert(ok)
            local ok2, why = lshape.check.check({ answer = 42 }, Voted)
            assert(not ok2)
            assert(why:find("shape violation"))
            "#,
        );
    }

    /// Pins the vendored version and smoke-tests parts of its type-constructor
    /// surface (`any_of`, `pattern`, `partial`, `literal`).
    #[test]
    fn vendored_version_matches() {
        run_with_lshape(
            r#"
            local lshape = require("lshape")
            assert(lshape._VERSION == "0.1.0",
              "expected lshape._VERSION == '0.1.0', got " .. tostring(lshape._VERSION))

            local T = lshape.t
            -- v0.1.0 surface smoke: any_of, pattern, partial, literal
            local U = T.any_of({ T.string, T.number })
            assert(lshape.check.check("x", U))
            assert(lshape.check.check(1, U))
            assert(not (lshape.check.check(true, U)))

            local Slug = T.pattern("^[a-z]+$")
            assert(lshape.check.check("abc", Slug))
            assert(not (lshape.check.check("ABC", Slug)))

            local P = T.partial({ a = T.string, b = T.number })
            assert(lshape.check.check({}, P))
            assert(lshape.check.check({ a = "x" }, P))

            local L = T.literal("yes")
            assert(lshape.check.check("yes", L))
            assert(not (lshape.check.check("no", L)))
            "#,
        );
    }
}