--- lshape.luacats: render lshape schemas as LuaCATS annotation text
-- (`---@meta`, `---@class` and `---@field` lines).
-- reflect.fields(schema) supplies the per-field entries (name, type,
-- optional, doc) that the renderers below consume.
local reflect = require("lshape.reflect")
local M = {}
--- Convert a snake_case / kebab-case name to PascalCase.
-- Every maximal run of characters other than `_` and `-` is a word. Each
-- word's first byte is upper-cased, and the words are joined with no
-- separator. Empty words (from leading, trailing or doubled separators)
-- disappear, e.g. "user_id" -> "UserId", "--a--b" -> "AB". Only the first
-- byte of each word changes; the rest keeps its original case.
-- @tparam string name
-- @treturn string
local function pascal_case(name)
  local words = {}
  for word in name:gmatch("[^_-]+") do
    local head, tail = word:sub(1, 1), word:sub(2)
    words[#words + 1] = head:upper() .. tail
  end
  return table.concat(words)
end
-- Forward declaration: inline_shape_type (below) calls type_of, and type_of
-- calls inline_shape_type, so the local is declared here and assigned later
-- with `type_of = function ... end`.
local type_of
--- Report whether a rendered type string contains a `|` at nesting depth 0.
-- The `array_of` renderer uses this to decide whether the element type must
-- be parenthesised before `[]` is appended.
-- `{`/`(` raise the depth and `}`/`)` lower it. Double-quoted string
-- literals, as produced by `string.format("%q", ...)` for `one_of` values,
-- are skipped entirely, including backslash escapes. Brackets or pipes
-- inside a literal such as `"("` therefore cannot corrupt the result.
-- `<...>` is not treated as nesting, so `table<K, a|b>` is (harmlessly)
-- reported as having a top-level pipe.
-- @tparam string s rendered type expression
-- @treturn boolean true when a top-level `|` is present
local function has_top_level_pipe(s)
  local depth = 0
  local i, n = 1, #s
  while i <= n do
    local c = s:sub(i, i)
    if c == '"' then
      -- Advance to the closing quote; a backslash escapes the next byte.
      i = i + 1
      while i <= n do
        local q = s:sub(i, i)
        if q == "\\" then
          i = i + 1
        elseif q == '"' then
          break
        end
        i = i + 1
      end
    elseif c == "{" or c == "(" then
      depth = depth + 1
    elseif c == "}" or c == ")" then
      depth = depth - 1
    elseif c == "|" and depth == 0 then
      return true
    end
    i = i + 1
  end
  return false
end
--- Render a shape node as an inline LuaCATS table-literal type, e.g.
-- `{ id: integer, tag?: string }`. A shape with no fields renders as `table`.
-- @param node shape schema node, passed to `reflect.fields`
-- @tparam[opt] string class_prefix forwarded to `type_of` for nested refs
-- @treturn string
local function inline_shape_type(node, class_prefix)
  local fields = reflect.fields(node)
  local count = #fields
  if count == 0 then
    return "table"
  end
  local rendered = {}
  for idx = 1, count do
    local field = fields[idx]
    local marker = ""
    if field.optional then
      marker = "?"
    end
    rendered[idx] = string.format(
      "%s%s: %s", field.name, marker, type_of(field.type, class_prefix))
  end
  return "{ " .. table.concat(rendered, ", ") .. " }"
end
--- Render an lshape schema node as a LuaCATS type expression.
-- `optional` and `described` wrappers are transparent here. A field's
-- optionality is written by the caller as a `?` suffix, and descriptions
-- are not part of a type. Inside `array_of`, an optional element becomes
-- `(T|nil)[]`.
-- @param node schema node; dispatched on `rawget(node, "kind")`
-- @tparam[opt=""] string class_prefix prepended to class names for `ref` nodes
-- @treturn string LuaCATS type expression
-- @raise error "lshape.luacats: unknown kind '...'" for an unrecognised kind
type_of = function(node, class_prefix)
  class_prefix = class_prefix or ""
  local kind = rawget(node, "kind")
  if kind == "prim" then
    return node.prim
  elseif kind == "any" then
    return "any"
  elseif kind == "optional" or kind == "described" then
    return type_of(rawget(node, "inner"), class_prefix)
  elseif kind == "array_of" then
    -- Peel wrappers off the element type. Optionality is remembered so it
    -- can be expressed as `|nil` inside the element type.
    local elem = rawget(node, "elem")
    local had_optional = false
    while true do
      local k = rawget(elem, "kind")
      if k == "optional" then
        had_optional = true
        elem = rawget(elem, "inner")
      elseif k == "described" then
        elem = rawget(elem, "inner")
      else
        break
      end
    end
    local inner_type = type_of(elem, class_prefix)
    if had_optional then
      return "(" .. inner_type .. "|nil)[]"
    end
    -- `a|b[]` would mean `a | b[]`, so unions must be parenthesised.
    if has_top_level_pipe(inner_type) then
      return "(" .. inner_type .. ")[]"
    end
    return inner_type .. "[]"
  elseif kind == "discriminated" then
    -- Sort variant keys so the generated union is deterministic.
    local variants = rawget(node, "variants")
    local keys = {}
    for k in pairs(variants) do keys[#keys + 1] = k end
    table.sort(keys)
    local parts = {}
    for i = 1, #keys do
      parts[i] = inline_shape_type(variants[keys[i]], class_prefix)
    end
    return table.concat(parts, "|")
  elseif kind == "map_of" then
    local k = type_of(rawget(node, "key"), class_prefix)
    local v = type_of(rawget(node, "val"), class_prefix)
    return "table<" .. k .. ", " .. v .. ">"
  elseif kind == "shape" then
    return inline_shape_type(node, class_prefix)
  elseif kind == "ref" then
    return class_prefix .. pascal_case(rawget(node, "name"))
  elseif kind == "any_of" then
    local variants = rawget(node, "variants")
    local parts = {}
    for i = 1, #variants do
      parts[i] = type_of(variants[i], class_prefix)
    end
    return table.concat(parts, "|")
  elseif kind == "pattern" then
    return "string"
  elseif kind == "one_of" then
    -- Render each allowed value as a literal type.
    local vs = rawget(node, "values")
    local parts = {}
    for i = 1, #vs do
      local v = vs[i]
      if type(v) == "string" then
        -- %q escapes a newline as a backslash followed by a *literal*
        -- newline. That would split the one-line annotation, so rewrite it
        -- as the two-character escape `\n`.
        parts[i] = (string.format("%q", v):gsub("\\\n", "\\n"))
      else
        parts[i] = tostring(v)
      end
    end
    return table.concat(parts, "|")
  else
    error("lshape.luacats: unknown kind '" .. tostring(kind) .. "'", 2)
  end
end
--- Render one `---@class` annotation block for a shape schema.
-- Emits `---@class <class_name>`, then one `---@field` line per entry
-- returned by `reflect.fields(schema)`. Optional fields get a `?` suffix.
-- A field's `doc`, when present, is appended as an `@` description.
-- @tparam string class_name non-empty class name, written verbatim
-- @tparam table schema schema node with `kind == "shape"`
-- @tparam[opt=""] string class_prefix prefix for classes referenced via `ref`
-- @treturn string annotation lines joined by "\n", with a trailing "\n"
-- @raise error if class_name is not a non-empty string or schema is not a shape
function M.class_for(class_name, schema, class_prefix)
  if type(class_name) ~= "string" or class_name == "" then
    error("lshape.luacats.class_for: class_name must be non-empty string", 2)
  end
  if type(schema) ~= "table" or rawget(schema, "kind") ~= "shape" then
    error("lshape.luacats.class_for: schema must be kind='shape'", 2)
  end
  class_prefix = class_prefix or ""
  local lines = { "---@class " .. class_name }
  local entries = reflect.fields(schema)
  for i = 1, #entries do
    local e = entries[i]
    local suffix = e.optional and "?" or ""
    local ty = type_of(e.type, class_prefix)
    local line = string.format("---@field %s%s %s", e.name, suffix, ty)
    if e.doc then
      -- The description must stay on the `---@field` line. A raw newline
      -- would leave the rest of the doc on a line without the `---`
      -- prefix, i.e. as Lua code in the generated file. Fold each line
      -- break (and the whitespace around it) into a single space.
      local doc = string.gsub(e.doc, "%s*[\r\n]%s*", " ")
      line = line .. " @" .. doc
    end
    lines[#lines + 1] = line
  end
  return table.concat(lines, "\n") .. "\n"
end
--- Generate a complete LuaCATS meta file for a table of named shapes.
-- Only entries with a string key and a `kind = "shape"` schema value are
-- emitted; everything else is skipped. Output starts with `---@meta`. Each
-- class follows a blank line, in sorted key order, and is named
-- `class_prefix .. pascal_case(key)`.
-- @tparam[opt] table shapes_table map of shape name -> schema; nil yields only `---@meta`
-- @tparam[opt=""] string class_prefix prefix for generated class names
-- @treturn string annotation text ending in "\n"
-- @raise error if shapes_table is neither a table nor nil
function M.gen(shapes_table, class_prefix)
  if shapes_table ~= nil and type(shapes_table) ~= "table" then
    error("lshape.luacats.gen: shapes_table must be table or nil", 2)
  end
  class_prefix = class_prefix or ""
  local names = {}
  if shapes_table ~= nil then
    for name, schema in pairs(shapes_table) do
      -- A non-string key cannot become a class name (pascal_case uses string
      -- methods), and mixed key types would make table.sort fail. Skip such
      -- keys, just like non-shape values.
      if type(name) == "string"
        and type(schema) == "table"
        and rawget(schema, "kind") == "shape"
      then
        names[#names + 1] = name
      end
    end
  end
  -- pairs() order is unspecified; sort for deterministic output.
  table.sort(names)
  local out = { "---@meta" }
  for i = 1, #names do
    local name = names[i]
    local class_name = class_prefix .. pascal_case(name)
    -- class_for ends its text with "\n". Strip it because the pieces are
    -- re-joined with "\n" below.
    local block = (M.class_for(class_name, shapes_table[name], class_prefix):gsub("\n$", ""))
    out[#out + 1] = ""
    out[#out + 1] = block
  end
  return table.concat(out, "\n") .. "\n"
end
-- Expose the private helpers so tests can call them directly. The leading
-- underscore marks this table as internal rather than public API.
local internal = {}
internal.pascal_case = pascal_case
internal.inline_shape_type = inline_shape_type
internal.type_of = type_of
M._internal = internal

return M