--- lshape: structural shape checking for plain Lua values.
local M = {
  -- Registry of {name = schema} used to resolve `ref` schemas and string
  -- shape names when no opts.registry is supplied. Left unbound here; the
  -- host is expected to assign it at load time.
  default_registry = nil,
}
--- Return the host-bound default registry (M.default_registry).
-- Errors are raised at level 3, i.e. attributed to the caller of the
-- function that called this helper.
-- @treturn table the registry table
-- @raise when M.default_registry is nil or not a table
local function resolve_default_registry()
  local registry = M.default_registry
  local registry_type = type(registry)
  if registry_type == "table" then
    return registry
  end
  if registry == nil then
    error(
      "lshape.check: default registry not bound. "
        .. "Set M.default_registry = <{name → schema}> at host "
        .. "load time, or pass opts.registry explicitly at the "
        .. "call site.", 3)
  end
  error(
    "lshape.check: M.default_registry must be a plain table "
      .. "of {name = schema} (closures NG); got " .. registry_type, 3)
end
--- Return the Lua type name of `v`; a thin indirection over builtin `type`.
local lua_type_of = function(v)
  local type_name = type(v)
  return type_name
end
-- Dispatch table: schema `kind` string -> handler(value, schema, path, ctx).
local handlers = {}
-- Forward declaration so handlers can recurse through check_node before
-- it is assigned.
local check_node
--- `any`: accepts every value, nil included.
function handlers.any(_value, _schema, _path)
  return true
end
--- `prim`: the value's Lua type name must equal schema.prim exactly.
-- @return true, or false plus a violation message
handlers.prim = function(value, schema, path)
  local want = schema.prim
  local have = lua_type_of(value)
  if have == want then
    return true
  end
  return false, string.format(
    "shape violation at %s: expected %s, got %s",
    path, want, have)
end
--- `optional`: nil always passes; any other value must satisfy schema.inner.
handlers.optional = function(value, schema, path, ctx)
  if value ~= nil then
    return check_node(value, schema.inner, path, ctx)
  end
  return true
end
--- `described`: validation defers entirely to schema.inner; any other
-- fields on the node (presumably descriptive metadata) are not consulted.
handlers.described = function(value, schema, path, ctx)
  local inner = schema.inner
  return check_node(value, inner, path, ctx)
end
--- `array_of`: value must be a table whose elements 1..#value each satisfy
-- schema.elem; the first failing element's reason is returned.
-- NOTE(review): bounded by `#value`, so for tables with nil holes the
-- checked range depends on which border `#` reports.
handlers.array_of = function(value, schema, path, ctx)
  local value_type = type(value)
  if value_type ~= "table" then
    return false, string.format(
      "shape violation at %s: expected table (array), got %s",
      path, value_type)
  end
  local elem_schema = schema.elem
  local count = #value
  local idx = 1
  while idx <= count do
    local elem_path = path .. "[" .. idx .. "]"
    local ok, reason = check_node(value[idx], elem_schema, elem_path, ctx)
    if not ok then
      return false, reason
    end
    idx = idx + 1
  end
  return true
end
--- `shape`: value must be a table. Each field declared in schema.fields is
-- checked in sorted name order, so the first reported violation is
-- deterministic. When schema.open == false, string keys not declared in
-- schema.fields are rejected, reporting the lexicographically smallest one.
-- @return true, or false plus a violation message
handlers.shape = function(value, schema, path, ctx)
  if type(value) ~= "table" then
    return false, string.format(
      "shape violation at %s: expected table, got %s",
      path, type(value))
  end

  local fields = schema.fields

  local declared = {}
  for field_name in pairs(fields) do
    declared[#declared + 1] = field_name
  end
  table.sort(declared)

  for _, field_name in ipairs(declared) do
    local field_schema = fields[field_name]
    -- "$" .. "." .. name is already "$.name", so one form covers the root too.
    local field_path = path .. "." .. field_name
    local ok, reason = check_node(value[field_name], field_schema, field_path, ctx)
    if not ok then
      return false, reason
    end
  end

  if schema.open ~= false then
    return true
  end

  -- Closed shape: locate the smallest undeclared string key, if any.
  local first_extra = nil
  for key in pairs(value) do
    if type(key) == "string" and fields[key] == nil
        and (first_extra == nil or key < first_extra) then
      first_extra = key
    end
  end
  if first_extra == nil then
    return true
  end
  return false, string.format(
    "shape violation at %s: unexpected field", path .. "." .. first_extra)
end
--- Sort comparator for discriminant keys of arbitrary Lua type: orders by
-- type name first, then natively for strings/numbers, else by tostring.
local function discriminant_key_less(a, b)
  local ta, tb = type(a), type(b)
  if ta ~= tb then return ta < tb end
  if ta == "string" or ta == "number" then return a < b end
  return tostring(a) < tostring(b)
end

--- `discriminated`: a tagged union. value[schema.tag] selects an entry of
-- schema.variants, and the value is then validated by the `shape` handler
-- against that entry (so each variant is expected to carry a `fields`
-- table, as handlers.shape requires).
-- @return true, or false plus a violation message
handlers.discriminated = function(value, schema, path, ctx)
  if type(value) ~= "table" then
    return false, string.format(
      "shape violation at %s: expected table, got %s",
      path, type(value))
  end
  local tag = schema.tag
  local tag_val = value[tag]
  if tag_val == nil then
    return false, string.format(
      "shape violation at %s: missing discriminant field '%s'",
      path, tostring(tag))
  end
  local variant = schema.variants[tag_val]
  if variant ~= nil then
    return handlers.shape(value, variant, path, ctx)
  end

  -- Unknown discriminant: list the accepted ones. Keys may be booleans or
  -- of mixed types, which a comparator-less table.sort cannot order and
  -- which %q rejects before Lua 5.4, so sort type-aware and quote only
  -- string keys (the same rendering handlers.one_of uses).
  local keys = {}
  for k in pairs(schema.variants) do keys[#keys + 1] = k end
  table.sort(keys, discriminant_key_less)
  local parts = {}
  for i = 1, #keys do
    local k = keys[i]
    if type(k) == "string" then
      parts[i] = string.format("%q", k)
    else
      parts[i] = tostring(k)
    end
  end
  return false, string.format(
    "shape violation at %s: discriminant '%s' = %q not in [%s]",
    path, tostring(tag), tostring(tag_val), table.concat(parts, ", "))
end
--- `any_of`: passes as soon as one schema in schema.variants (tried in
-- order) accepts the value; otherwise reports every variant's reason, one
-- per line.
handlers.any_of = function(value, schema, path, ctx)
  local candidates = schema.variants
  local lines = {}
  for idx = 1, #candidates do
    local ok, why = check_node(value, candidates[idx], path, ctx)
    if ok then
      return true
    end
    lines[#lines + 1] = string.format(" variant %d: %s", idx, why or "(no reason)")
  end
  return false, string.format(
    "shape violation at %s: no variant matched\n%s",
    path, table.concat(lines, "\n"))
end
--- `map_of`: value must be a table whose every key satisfies schema.key and
-- every value satisfies schema.val. Stops at the first failure; since
-- iteration uses `pairs`, which failure is reported first may vary.
handlers.map_of = function(value, schema, path, ctx)
  local value_type = type(value)
  if value_type ~= "table" then
    return false, string.format(
      "shape violation at %s: expected table (map), got %s",
      path, value_type)
  end
  local key_schema, val_schema = schema.key, schema.val
  for k, v in pairs(value) do
    local key_path = path .. "[key=" .. tostring(k) .. "]"
    local ok, reason = check_node(k, key_schema, key_path, ctx)
    if not ok then
      return false, reason
    end
    local entry_path = path .. "[" .. tostring(k) .. "]"
    ok, reason = check_node(v, val_schema, entry_path, ctx)
    if not ok then
      return false, reason
    end
  end
  return true
end
--- `ref`: look up schema.name in the registry (via rawget, ignoring any
-- registry metatable) and validate the value against the result.
-- The registry is ctx.registry when set; otherwise the default registry is
-- resolved (raising if unbound) and cached on ctx for subsequent refs.
-- @return true, or false plus a violation message
handlers.ref = function(value, schema, path, ctx)
  local name = schema.name
  local registry = ctx.registry
  if registry == nil then
    registry = resolve_default_registry()
    ctx.registry = registry
  end
  local resolved = rawget(registry, name)
  if type(resolved) ~= "table" or rawget(resolved, "kind") == nil then
    -- tostring: a malformed ref may carry a nil/non-string name, which
    -- string.format's %s rejects on Lua 5.1 / LuaJIT.
    return false, string.format(
      "shape violation at %s: unresolved ref '%s'", path, tostring(name))
  end
  return check_node(value, resolved, path, ctx)
end
--- `pattern`: value must be a string matching the Lua pattern in
-- schema.pattern (unanchored unless the pattern itself uses ^ or $).
handlers.pattern = function(value, schema, path, _ctx)
  local value_type = type(value)
  if value_type ~= "string" then
    return false, string.format(
      "shape violation at %s: expected string (pattern), got %s",
      path, value_type)
  end
  local pat = schema.pattern
  if string.match(value, pat) ~= nil then
    return true
  end
  return false, string.format(
    "shape violation at %s: string %q does not match pattern %q",
    path, value, pat)
end
--- `one_of`: value must compare equal (==) to some entry of schema.values.
-- The failure message lists the allowed values, quoting string entries.
handlers.one_of = function(value, schema, path, _ctx)
  local allowed = schema.values
  for idx = 1, #allowed do
    if value == allowed[idx] then
      return true
    end
  end
  local rendered = {}
  for idx = 1, #allowed do
    local candidate = allowed[idx]
    rendered[idx] = (type(candidate) == "string")
        and string.format("%q", candidate)
        or tostring(candidate)
  end
  return false, string.format(
    "shape violation at %s: expected one of [%s], got %s",
    path, table.concat(rendered, ", "), tostring(value))
end
-- Upper bound on nested check_node calls within one check. As the error
-- text below suggests, exceeding it usually means a cycle in the schema
-- (e.g. a `ref` resolving back to itself) or in the value.
local MAX_CHECK_DEPTH = 256

--- Validate `value` against one schema node by dispatching on
-- rawget(schema, "kind") to the matching entry in `handlers`.
-- A nil schema accepts any value. ctx.depth counts the current nesting and
-- is decremented again when the handler returns normally.
-- @param value any value
-- @param schema schema table, or nil
-- @param path string location used in violation messages ("$" is the root)
-- @param ctx per-check context table; ctx.depth is read and written here
-- @return true, or false plus a violation message
-- @raise when schema is not a table, lacks `kind`, has an unknown `kind`,
--   or nesting exceeds MAX_CHECK_DEPTH
check_node = function(value, schema, path, ctx)
  if schema == nil then return true end
  -- Reject non-table schemas explicitly (e.g. a shape *name* passed where a
  -- schema table is expected); rawget below would otherwise fail with an
  -- opaque "bad argument #1 to 'rawget'" error and no location.
  if type(schema) ~= "table" then
    error(string.format(
      "lshape.check: schema at %s must be a table (got %s)",
      tostring(path), type(schema)), 2)
  end
  ctx.depth = (ctx.depth or 0) + 1
  if ctx.depth > MAX_CHECK_DEPTH then
    error(string.format(
      "lshape.check: recursion depth exceeded at %s " ..
      "(> %d; cycle in schema or value?)",
      path, MAX_CHECK_DEPTH), 2)
  end
  local kind = rawget(schema, "kind")
  if kind == nil then
    error("lshape.check: schema missing 'kind' field", 2)
  end
  local h = handlers[kind]
  if h == nil then
    error("lshape.check: unknown kind '" .. tostring(kind) .. "'", 2)
  end
  local ok, reason = h(value, schema, path, ctx)
  ctx.depth = ctx.depth - 1
  return ok, reason
end
--- Build the per-call context table used by M.check.
-- @param opts nil, or a table whose optional `registry` field is a table
-- @return a fresh table { registry = opts.registry } (registry may be nil)
-- @raise at level 3 when opts or opts.registry has the wrong type
local function build_ctx(opts)
  local registry = nil
  if opts ~= nil then
    local opts_type = type(opts)
    if opts_type ~= "table" then
      error("lshape.check: opts must be a table or nil (got "
        .. opts_type .. ")", 3)
    end
    registry = opts.registry
    if registry ~= nil then
      local registry_type = type(registry)
      if registry_type ~= "table" then
        error("lshape.check: opts.registry must be a plain table "
          .. "of {name = schema} (closures NG); got " .. registry_type, 3)
      end
    end
  end
  return { registry = registry }
end
--- Check `value` against `schema`, returning a result instead of raising on
-- a mismatch. A nil schema passes unconditionally (opts is not validated
-- in that case).
-- @param value any value
-- @param schema schema table, or nil
-- @param opts optional table; opts.registry is used instead of
--   M.default_registry when resolving refs
-- @return true, or false plus a violation message
function M.check(value, schema, opts)
  if schema ~= nil then
    local ctx = build_ctx(opts)
    return check_node(value, schema, "$", ctx)
  end
  return true
end
--- Append a caller-supplied context hint to a violation message.
-- A nil or empty-string hint leaves the message unchanged.
local function compose_msg(reason, ctx_hint)
  local no_hint = (ctx_hint == nil) or (ctx_hint == "")
  if no_hint then
    return reason
  end
  local suffix = " (ctx: " .. tostring(ctx_hint) .. ")"
  return reason .. suffix
end
--- Validate `value` and raise on mismatch; on success return `value`
-- unchanged so the call can be used inline.
-- @param value any value
-- @param schema_or_name schema table, or a registry name string; the name
--   "any" returns `value` immediately without consulting a registry
-- @param ctx_hint optional extra context appended to the error message
-- @param opts optional table; opts.registry is used instead of
--   M.default_registry for name lookup and ref resolution
-- @return value
-- @raise (attributed to the caller) on a shape violation or bad arguments
function M.assert(value, schema_or_name, ctx_hint, opts)
  local schema
  if schema_or_name == nil then
    error("lshape.assert: schema_or_name must not be nil. "
      .. "Use S.check(v, nil) for silent pass, or pass a schema / "
      .. "name / \"any\" explicitly.", 2)
  elseif type(schema_or_name) == "string" then
    if schema_or_name == "any" then return value end
    -- Validate opts before indexing it. Otherwise a number/boolean opts
    -- raises a raw "attempt to index" error, and a string opts is silently
    -- ignored (its metatable indexes `string`) and falls through to the
    -- default registry.
    if opts ~= nil and type(opts) ~= "table" then
      error("lshape.assert: opts must be a table or nil (got "
        .. type(opts) .. ")", 2)
    end
    local registry = (opts and opts.registry) or resolve_default_registry()
    if type(registry) ~= "table" then
      error("lshape.assert: opts.registry must be a plain table "
        .. "of {name = schema} (closures NG); got " .. type(registry), 2)
    end
    schema = rawget(registry, schema_or_name)
    if type(schema) ~= "table" or rawget(schema, "kind") == nil then
      error("lshape.assert: unknown shape name '" .. schema_or_name .. "'", 2)
    end
  elseif type(schema_or_name) == "table" then
    schema = schema_or_name
  else
    error("lshape.assert: schema_or_name must be nil, string, or table (got "
      .. type(schema_or_name) .. ")", 2)
  end
  local ok, reason = M.check(value, schema, opts)
  if not ok then
    error(compose_msg(reason, ctx_hint), 2)
  end
  return value
end
--- Same as M.assert when M.is_dev_mode() is true; otherwise returns
-- `value` without performing any check.
function M.assert_dev(value, schema_or_name, ctx_hint, opts)
  if M.is_dev_mode() then
    return M.assert(value, schema_or_name, ctx_hint, opts)
  end
  return value
end
-- Name of the environment variable that enables dev-mode checks. It is
-- looked up on every is_dev_mode() call, so reassigning it takes effect
-- immediately.
M.dev_env_var = "LSHAPE_CHECK"

--- True when the environment variable named by M.dev_env_var is exactly "1".
function M.is_dev_mode()
  local flag = os.getenv(M.dev_env_var)
  return flag == "1"
end
-- Expose module-private pieces under M._internal; the leading underscore
-- marks them as outside the public API.
local internal = {}
internal.handlers = handlers
internal.check_node = check_node
internal.compose_msg = compose_msg
internal.build_ctx = build_ctx
internal.resolve_default_registry = resolve_default_registry
M._internal = internal
return M