local derivers = require "derivers"
local pretty_printer = require "pretty-printer"
local traits = require "traits"
local U = require "alicorn-utils"
local _ = require "lua-ext"
local math_floor, select, type = math.floor, select, type
local s = pretty_printer.s
-- Predicate behind `builtin_integer`: true iff `val` is an integer number.
-- With `math.type` available (Lua 5.3+) only the integer subtype passes, so
-- floats such as 1.0 are rejected. Without it, any number equal to its own
-- floor passes.
local builtin_integer_value_check
do
	local math_type = math.type
	if math_type == nil then
		-- No integer subtype on this runtime: compare the number with its floor.
		builtin_integer_value_check = function(val)
			if type(val) ~= "number" then
				return false
			end
			return math_floor(val) == val
		end
	else
		builtin_integer_value_check = function(val)
			return math_type(val) == "integer"
		end
	end
end
-- Returns `s` with a stack traceback appended when the global `debug` library
-- is present; otherwise returns `s` untouched.
local function attempt_traceback(s)
	if not debug then
		return s
	end
	return debug.traceback(s)
end
-- Returns a value_check predicate accepting exactly the values whose metatable
-- is `mt`. Raises immediately when `mt` is not a table, which the message
-- flags as a possible typo at the definition site.
local function metatable_equality(mt)
	if type(mt) == "table" then
		return function(val)
			return getmetatable(val) == mt
		end
	end
	local template = "trying to define metatable equality to something that isn't a metatable (possible typo?): %s"
	error(template:format(attempt_traceback(tostring(mt))))
end
-- Splits a flat { name1, type1, name2, type2, ... } list into two parallel
-- sequences: parameter names and their types. A trailing name without a type
-- leaves the matching type slot nil.
local function parse_params_with_types(params_with_types)
	local params, params_types = {}, {}
	for pos, entry in ipairs(params_with_types) do
		if pos % 2 == 1 then
			-- Odd positions are names: append a new parameter slot.
			params[#params + 1] = entry
		else
			-- Even positions are the type of the most recently added name.
			params_types[#params] = entry
		end
	end
	return params, params_types
end
-- Validates a constructor's parameter list. Raises unless:
--   * every parameter has a type (a table with a `value_check` function),
--   * parameter names are unique,
--   * there is at least one parameter (otherwise it should be a unit).
local function validate_params_types(kind, params, params_types)
	local seen = {}
	local count = 0
	for index, name in ipairs(params) do
		count = count + 1
		local ty = params_types[index]
		local is_type = type(ty) == "table" and type(ty.value_check) == "function"
		if not is_type then
			local msg = ("trying to set a parameter type to something that isn't a type, in constructor %s, parameter %q (possible typo?)"):format(
				kind,
				name
			)
			error(attempt_traceback(msg))
		end
		if seen[name] then
			error(("constructor %s must have unique parameter names (%q was given more than once)"):format(kind, name))
		end
		seen[name] = true
	end
	if count == 0 then
		error(("constructor %s must take at least one parameter, or be changed to a unit"):format(kind))
	end
end
-- Builds the constructor for one record kind: a standalone record type, or a
-- record-shaped enum variant.
--
-- self: metatable given to every constructed value (the record type itself,
--   or the enum type for variants).
-- cons: table made callable as the constructor. setmetatable installs a
--   `__call` metamethod on it, replacing any metatable `cons` already had.
-- kind: kind string stored on each value and used in error messages.
-- params_with_types: flat { name1, type1, name2, type2, ... } list.
-- Returns derive_info = { kind, params, params_types } for the derivers.
local function gen_record(self, cons, kind, params_with_types)
local params, params_types = parse_params_with_types(params_with_types)
validate_params_types(kind, params, params_types)
-- Type-checks each positional argument, then builds the value. Fields live in
-- `_record`, keyed by parameter name; `kind` and "{ID}" are raw fields.
local function build_record(...)
local args = table.pack(...)
local val = {
kind = kind,
_record = {},
}
for i, v in ipairs(params) do
local param = args[i]
local param_type = params_types[i]
if param_type.value_check(param) ~= true then
error(
attempt_traceback(
("wrong argument type passed to constructor %s, parameter %q\nexpected type of parameter %q is: %s\nvalue of parameter %q: (follows)\n%s"):format(
kind,
v,
v,
param_type,
v,
s(param)
)
)
)
end
val._record[v] = param
end
-- Trace capture is disabled (dead branch kept as a toggle).
-- NOTE(review): attempt_traceback takes a single argument, so the `2` is ignored.
if false then
val["{TRACE}"] = attempt_traceback("", 2)
end
val["{ID}"] = U.debug_id()
setmetatable(val, self)
return val
end
-- Memoized. Presumably identical (already frozen) argument lists return the
-- same value, which is the "hash-consing" the warning below refers to.
build_record = U.memoize(build_record, false)
-- Freezes each argument through the freeze trait of its declared type, then
-- calls the memoized builder. An argument whose type has no freeze impl is
-- passed through unchanged, with a warning.
local function build_record_freeze_wrapper(...)
local args = { ... }
for i, v in ipairs(params) do
local argi = args[i]
local freeze_impl = traits.freeze:get(params_types[i])
if freeze_impl then
argi = freeze_impl.freeze(params_types[i], argi)
else
print(
("WARNING: while constructing %s, can't freeze param %s (type %s)"):format(
kind,
v,
tostring(params_types[i])
)
)
print("this may lead to suboptimal hash-consing")
end
args[i] = argi
end
-- Pass exactly #params arguments: extras are dropped and missing ones become
-- nil, so the memoized builder always sees the same arity.
return build_record(table.unpack(args, 1, #params))
end
-- Calling `cons(...)` runs the freeze wrapper; the callee itself is discarded.
setmetatable(cons, {
__call = function(_, ...)
return build_record_freeze_wrapper(...)
end,
})
local derive_info = {
kind = kind,
params = params,
params_types = params_types,
}
return derive_info
end
-- __tostring for record *types*. define_record installs it on the type's
-- metatable, so the type prints as "terms-gen record: <kind>".
local function record_tostring(self)
	return string.format("terms-gen record: %s", self._kind)
end
-- Turns `self` (a fresh table or a forward-declared type) into a record type
-- named `kind`. `params_with_types` is a flat { name1, type1, ... } list.
-- Afterwards `self(...)` builds values whose metatable is `self`, using the
-- type-checked, argument-freezing, memoized constructor from gen_record.
-- Returns `self`.
local function define_record(self, kind, params_with_types)
local derive_info = gen_record(self, self, kind, params_with_types)
-- gen_record just gave `self` a fresh metatable carrying `__call`; the type's
-- own __tostring is hung on that same metatable.
getmetatable(self).__tostring = record_tostring
self.value_check = metatable_equality(self)
function self:derive(deriver, ...)
return deriver.record(self, derive_info, ...)
end
self._kind = kind
-- Instance field lookup, in order:
--   1. methods;
--   2. the debug keys "{TRACE}"/"{ID}", read from `_record`;
--   3. only `name` may be read directly from `_record`.
-- Any other key is an error telling the caller to use unwrap instead.
self.__index = function(t, key)
local method = self.methods[key]
if method then
return method
end
if key == "{TRACE}" or key == "{ID}" then
return t._record[key]
end
if key ~= "name" then
error(attempt_traceback(("use unwrap instead for: %s"):format(key)))
end
if t._record[key] then
return t._record[key]
end
error(("Tried to access nonexistent key: %s"):format(key))
end
-- __index reads self.methods on every lookup, so methods added to this table
-- later (presumably by the derivers below) are visible to instances.
self.methods = {
pretty_preprint = pretty_printer.pretty_preprint,
pretty_print = pretty_printer.pretty_print,
default_print = pretty_printer.default_print,
}
-- Instances are immutable. gen_record sets their fields before attaching the
-- metatable, so construction does not hit this guard.
self.__newindex = function()
error("records are immutable!")
end
traits.value_name:implement_on(self, {
value_name = function()
return kind
end,
})
-- Derive equality, unwrap accessors, diffing and freezing. These run against
-- `self` as configured above.
self:derive(derivers.eq)
self:derive(derivers.unwrap)
self:derive(derivers.diff)
self:derive(derivers.freeze)
return self
end
-- Creates a fresh table and defines it as a record type in one step.
local function declare_record(kind, params_with_types)
	local fresh = {}
	return define_record(fresh, kind, params_with_types)
end
-- Builds the singleton value for a parameterless enum variant, together with
-- its derive info. The value's metatable is the enum type `self`.
local function gen_unit(self, kind)
	local derive_info = { kind = kind }
	local unit_value = setmetatable({ kind = kind }, self)
	return unit_value, derive_info
end
-- Metatable shared by all enum *types*; it only controls how the type prints.
local enum_type_mt = {}
enum_type_mt.__tostring = function(self)
	return string.format("terms-gen enum: %s", self._name)
end
-- Turns `self` (a fresh table or a forward-declared type) into an enum type
-- named `name`.
--
-- `variants` is a list of { vname, params_with_types_or_nil } pairs:
--   * with a parameter list, `self[vname]` becomes the constructor for a
--     record-shaped variant (see gen_record);
--   * without one, `self[vname]` is the variant's single unit value.
-- Every variant value has `self` as its metatable and kind "<name>.<vname>".
-- Raises if a variant name is defined twice. Returns `self`.
local function define_enum(self, name, variants)
	setmetatable(self, enum_type_mt)
	self.value_check = metatable_equality(self)
	-- Ordered variant names (array part) plus per-name derive info (hash part).
	local derive_variants = {}
	for i, v in ipairs(variants) do
		local vname = v[1]
		local vparams_with_types = v[2]
		local vkind = name .. "." .. vname
		if self[vname] then
			error(("enum variant %s is defined multiple times"):format(vkind))
		end
		derive_variants[i] = vname
		if vparams_with_types then
			local record_cons = {}
			local record_info = gen_record(self, record_cons, vkind, vparams_with_types)
			self[vname] = record_cons
			derive_variants[vname] = {
				type = derivers.EnumDeriveInfoVariantKind.Record,
				info = record_info,
			}
		else
			local unit_val, unit_info = gen_unit(self, vkind)
			self[vname] = unit_val
			derive_variants[vname] = {
				type = derivers.EnumDeriveInfoVariantKind.Unit,
				info = unit_info,
			}
		end
	end
	local derive_info = {
		name = name,
		variants = derive_variants,
	}
	function self:derive(deriver, ...)
		return deriver.enum(self, derive_info, ...)
	end
	self._name = name
	-- Instance lookup: methods first, then the debug keys. Variant fields are
	-- never exposed directly; callers go through the derived accessors
	-- (unwrap_/as_).
	self.__index = function(t, key)
		local method = self.methods[key]
		if method then
			return method
		end
		if key == "{TRACE}" or key == "{ID}" then
			-- Unit variants (built by gen_unit) carry no `_record`. rawget avoids
			-- re-entering this __index, which would raise a misleading
			-- "use unwrap instead for: _record".
			local record = rawget(t, "_record")
			if record == nil then
				return nil
			end
			return record[key]
		end
		error(attempt_traceback(("use unwrap instead for: %s"):format(key)))
	end
	-- __index reads self.methods on every lookup, so methods the derivers add
	-- later are visible to instances.
	self.methods = {
		pretty_preprint = pretty_printer.pretty_preprint,
		pretty_print = pretty_printer.pretty_print,
		default_print = pretty_printer.default_print,
	}
	self.__newindex = function()
		error("enums are immutable!")
	end
	traits.value_name:implement_on(self, {
		value_name = function()
			return name
		end,
	})
	self:derive(derivers.eq)
	self:derive(derivers.is)
	self:derive(derivers.unwrap)
	self:derive(derivers.as)
	self:derive(derivers.diff)
	self:derive(derivers.freeze)
	return self
end
-- Creates a fresh table and defines it as an enum type in one step.
local function declare_enum(name, variants)
	local fresh = {}
	return define_enum(fresh, name, variants)
end
-- Splits `s` on any of the characters in `delim`, dropping empty pieces.
-- `delim` is a literal character set: Lua pattern magic characters in it
-- (e.g. "%", "]", "^", "-") are escaped before use.
-- Example: split_delim("name$tag", "$") --> { "name", "tag" }
local function split_delim(s, delim)
	-- Escape every non-alphanumeric character. "%x" is always a literal x for
	-- non-alphanumeric x, so this is safe for any delimiter. Only gsub's first
	-- result (the string) is kept.
	local escaped = delim:gsub("%W", "%%%0")
	local subs = {}
	for sub in s:gmatch("[^" .. escaped .. "]+") do
		subs[#subs + 1] = sub
	end
	return subs
end
-- Defines a "multi-enum": a family of sub-enums (one per entry of `types`)
-- plus a `flex` enum that wraps a value of any of them.
--
-- flex, flex_name: the wrapper enum type object and its name.
-- fn_replace(tag, ty): optional. Rewrites each parameter type of a
--   "flex"-tagged variant for the sub-enum `tag`.
-- fn_specify(args, params_types) -> tag, unified_args: picks which sub-enum a
--   flex-tagged constructor call builds, and with which arguments.
-- fn_unify(packed) -> list: post-processes results of the flex unwrap_/as_
--   accessors. It receives a table.pack'ed result and must return an
--   unpackable list.
-- types: tag -> type object; each is defined here as a sub-enum named names[tag].
-- variants: list of { "vname$tag", params_with_types }. Tag "flex" puts the
--   variant in every sub-enum. Any other tag must be a key of `types` and
--   puts the variant only in that sub-enum.
-- fn_sub(types): optional hook, run after the sub-enums are defined and
--   before `flex` is.
--
-- `flex` gets one wrapper variant per tag (a single field named after the
-- tag) plus every listed variant. Its constructors and is_/unwrap_/as_
-- accessors are then rewired to build and inspect the wrapped sub-enum values.
local function define_multi_enum(flex, flex_name, fn_replace, fn_specify, fn_unify, types, names, variants, fn_sub)
	-- keyed_variants[tag] collects the variant list of sub-enum `tag`;
	-- flex_variants is the variant list of the flex enum itself.
	local keyed_variants = {}
	local flex_variants = {}
	for _, k, v in U.table_stable_pairs(types) do
		keyed_variants[k] = {}
		-- Wrapper variant: flex.<k> holds a single field `k` of type `v`.
		table.insert(flex_variants, { k, { k, v } })
	end
	local flex_tags = {}
	for _, v in ipairs(variants) do
		local vname, vtag = table.unpack(split_delim(v[1], "$"))
		local vparams_with_types = v[2]
		if vtag == nil then
			error(("Missing tag on %s"):format(vname))
		end
		table.insert(flex_variants, { vname, vparams_with_types })
		flex_tags[vname] = vtag
		if vtag == "flex" then
			-- Present in every sub-enum. Parameter types (even positions of the
			-- flat list) may be rewritten per sub-enum by fn_replace.
			for _, k, _ in U.table_stable_pairs(types) do
				local fix_variants = {}
				for i, ty in ipairs(vparams_with_types) do
					if (i % 2) == 0 and fn_replace then
						table.insert(fix_variants, fn_replace(k, ty))
					else
						table.insert(fix_variants, ty)
					end
				end
				table.insert(keyed_variants[k], { vname, fix_variants })
			end
		else
			if keyed_variants[vtag] == nil then
				error(("Unknown tag: %s"):format(vtag))
			end
			table.insert(keyed_variants[vtag], { vname, vparams_with_types })
		end
	end
	for _, k, v in U.table_stable_pairs(types) do
		v:define_enum(names[k], keyed_variants[k])
	end
	if fn_sub then
		fn_sub(types)
	end
	flex:define_enum(flex_name, flex_variants)
	local unify_passthrough = function(ok, ...)
		return ok, table.unpack(fn_unify(table.pack(...)))
	end
	-- Maps a wrapper kind ("<flex_name>.<tag>") to the derived unwrap_<tag>
	-- method that extracts the inner sub-enum value. Computed once: the loop
	-- below never overrides these methods. Wrapper variants have no flex tag,
	-- and define_enum rejects duplicate names, so no tagged variant can share
	-- a wrapper's name.
	local unwrapper = {}
	for _, k, _ in U.table_stable_pairs(types) do
		unwrapper[flex_name .. "." .. k] = flex.methods["unwrap_" .. k]
	end
	local accessor_prefixes = { "is_", "unwrap_", "as_" }
	for _, pair in ipairs(flex_variants) do
		local k = pair[1]
		local tag = flex_tags[k]
		if tag == "flex" then
			-- Constructor: type-check, let fn_specify pick the sub-enum, build
			-- the inner value there, and wrap it.
			local vkind = flex_name .. "." .. k
			local params, params_types = parse_params_with_types(pair[2])
			validate_params_types(vkind, params, params_types)
			flex[k] = function(...)
				local args = table.pack(...)
				for i, v in ipairs(params) do
					local param = args[i]
					local param_type = params_types[i]
					if param_type.value_check(param) ~= true then
						error(
							attempt_traceback(
								("wrong argument type passed to constructor %s, parameter %q\nexpected type of parameter %q is: %s\nvalue of parameter %q: (follows)\n%s"):format(
									vkind,
									v,
									v,
									param_type,
									v,
									s(param)
								)
							)
						)
					end
				end
				local specified_tag, unified_args = fn_specify(args, params_types)
				local subtype = types[specified_tag]
				local inner = subtype[k](table.unpack(unified_args))
				return flex[specified_tag](inner)
			end
		elseif tag ~= nil then
			-- Single-tag variant: build it in its sub-enum and wrap the result.
			local subtype = types[tag]
			local inner = subtype[k]
			if not pair[2] then
				flex[k] = flex[tag](inner)
			else
				flex[k] = function(...)
					return flex[tag](inner(...))
				end
			end
		end
		for _, prefix in ipairs(accessor_prefixes) do
			local key = prefix .. k
			if tag == "flex" then
				-- Delegate to the wrapped sub-enum value, unifying results.
				if prefix == "is_" then
					flex.methods[key] = function(self, ...)
						local inner = unwrapper[self.kind](self)
						return inner[key](inner, ...)
					end
				elseif prefix == "unwrap_" then
					flex.methods[key] = function(self, ...)
						local inner = unwrapper[self.kind](self)
						return table.unpack(fn_unify(table.pack(inner[key](inner, ...))))
					end
				elseif prefix == "as_" then
					flex.methods[key] = function(self, ...)
						local inner = unwrapper[self.kind](self)
						return unify_passthrough(inner[key](inner, ...))
					end
				end
			elseif tag ~= nil then
				local base = flex.methods[key]
				if not base then
					error("Trying to override nonexistent function " .. key)
				end
				if prefix == "is_" or prefix == "as_" then
					-- Not wrapping sub-enum `tag` at all: report false.
					flex.methods[key] = function(self, ...)
						local ok, inner = flex.methods["as_" .. tag](self)
						if not ok then
							return false
						end
						return inner[key](inner, ...)
					end
				elseif prefix == "unwrap_" then
					flex.methods[key] = function(self, ...)
						local inner = flex.methods[prefix .. tag](self)
						return inner[key](inner, ...)
					end
				end
			end
		end
	end
end
-- Metatable shared by all foreign types; it only controls how the type prints.
local foreign_type_mt = {}
foreign_type_mt.__tostring = function(self)
	return string.format("terms-gen foreign: %s", self.lsp_type)
end
-- Turns `self` into a foreign type. Membership is decided by the arbitrary
-- predicate `value_check`. `lsp_type` names the type: it is used by
-- __tostring and by the value_name trait.
local function define_foreign(self, value_check, lsp_type)
	setmetatable(self, foreign_type_mt)
	self.value_check = value_check
	self.lsp_type = lsp_type
	local function value_name()
		return lsp_type
	end
	traits.value_name:implement_on(self, { value_name = value_name })
	return self
end
-- Creates a fresh table and defines it as a foreign type in one step.
local function declare_foreign(value_check, lsp_type)
	local fresh = {}
	return define_foreign(fresh, value_check, lsp_type)
end
-- Stores `key -> value` into the raw table `map` after:
--   1. type-checking key and value against the map's declared types
--      (diagnostics via `p`, then raise);
--   2. freezing each through its type's freeze trait, when one exists, so
--      equal values share identity. A missing trait only produces a warning.
-- `self` is the map type and is used only in warning text.
local function map_set_value(self, key_type, value_type, map, key, value)
	if key_type.value_check(key) ~= true then
		p("map-set", key_type, value_type)
		p(key)
		error("wrong key type passed to map:set")
	end
	if value_type.value_check(value) ~= true then
		p("map-set", key_type, value_type)
		p(value)
		error("wrong value type passed to map:set")
	end

	local key_freezer = traits.freeze:get(key_type)
	if not key_freezer then
		print(("WARNING: while setting %s, can't freeze key (type %s)"):format(tostring(self), tostring(key_type)))
		print("this may lead to suboptimal hash-consing")
	else
		key = key_freezer.freeze(key_type, key)
	end

	local value_freezer = traits.freeze:get(value_type)
	if not value_freezer then
		print(("WARNING: while setting %s, can't freeze value (type %s)"):format(tostring(self), tostring(value_type)))
		print("this may lead to suboptimal hash-consing")
	else
		value = value_freezer.freeze(value_type, value)
	end

	map[key] = value
end
-- Metatable shared by all map *types*.
local map_type_mt = {}

-- Calling a map type builds a map value from alternating arguments:
-- MapT(k1, v1, k2, v2, ...). Each pair goes through map_set_value.
function map_type_mt.__call(self, ...)
	local storage = {}
	local instance = setmetatable({ _map = storage, is_frozen = false }, self)
	local packed = table.pack(...)
	local key_type, value_type = self.key_type, self.value_type
	for i = 1, packed.n, 2 do
		map_set_value(self, key_type, value_type, storage, packed[i], packed[i + 1])
	end
	return instance
end

-- Two map types are equal when their key and value types are equal.
function map_type_mt.__eq(left, right)
	return left.key_type == right.key_type and left.value_type == right.value_type
end

function map_type_mt.__tostring(self)
	return string.format("terms-gen map key:<%s> val:<%s>", tostring(self.key_type), tostring(self.value_type))
end
-- Builds the method table for instances of map type `self` (installed as the
-- type's __index). Keys and values are type-checked against `key_type` and
-- `value_type`; mutators refuse to touch frozen maps.
local function gen_map_methods(self, key_type, value_type)
	return {
		-- Inserts or replaces `key -> value`. map_set_value performs the type
		-- checks and freezing.
		set = function(val, key, value)
			if val.is_frozen then
				error("trying to modify a frozen map")
			end
			map_set_value(self, key_type, value_type, val._map, key, value)
		end,
		-- Removes `key` (type-checked).
		reset = function(val, key)
			if val.is_frozen then
				error("trying to modify a frozen map")
			end
			if key_type.value_check(key) ~= true then
				p("map-reset", key_type, value_type)
				p(key)
				error("wrong key type passed to map:reset")
			end
			val._map[key] = nil
		end,
		-- Returns the value stored under `key` (type-checked), or nil.
		get = function(val, key)
			if key_type.value_check(key) ~= true then
				p("map-get", key_type, value_type)
				p(key)
				error("wrong key type passed to map:get")
			end
			return val._map[key]
		end,
		-- Iterates raw (key, value) pairs in unspecified order.
		pairs = function(val)
			return pairs(val._map)
		end,
		-- copy(val): returns an unfrozen shallow copy of `val`.
		-- copy(val, onto, conflict): merges `val` into `onto`, which must have the
		--   same map type. For keys already present in `onto`, the stored value
		--   becomes conflict(old, new). Returns `onto`.
		copy = function(val, onto, conflict)
			if not onto then
				local map = {}
				for key, value in pairs(val._map) do
					map[key] = value
				end
				return self:unchecked_new(map)
			end
			if not conflict then
				error("map:copy onto requires a conflict resolution function")
			end
			local rt = getmetatable(onto)
			if self ~= rt then
				error("map:copy must be passed maps of the same type")
			end
			for k, v in val:pairs() do
				local old = onto:get(k)
				-- Test presence with nil, not truthiness: a stored `false` is a real
				-- entry and must still go through `conflict`.
				if old ~= nil then
					onto:set(k, conflict(old, v))
				else
					onto:set(k, v)
				end
			end
			return onto
		end,
		-- New map with every entry of both maps. `conflict(left_value,
		-- right_value)` resolves keys present in both.
		union = function(left, right, conflict)
			local rt = getmetatable(right)
			if self ~= rt then
				error("map:union must be passed maps of the same type")
			end
			local new = left:copy()
			right:copy(new, conflict)
			return new
		end,
		pretty_preprint = pretty_printer.pretty_preprint,
		pretty_print = pretty_printer.pretty_print,
		default_print = pretty_printer.default_print,
	}
end
-- Wraps the raw table `map` as an unfrozen map value of type `self`. No type
-- checking or freezing is done; the caller is responsible for the contents.
local function map_unchecked_new_fn(self, map)
	local instance = { is_frozen = false }
	instance._map = map
	return setmetatable(instance, self)
end
-- Builds an unfrozen map value of type `self` from a plain Lua table. Every
-- entry is type-checked and frozen via map_set_value.
local function map_new_fn(self, map)
	local checked = {}
	local key_type = self.key_type
	local value_type = self.value_type
	for k, v in pairs(map) do
		map_set_value(self, key_type, value_type, checked, k, v)
	end
	local instance = { _map = checked, is_frozen = false }
	return setmetatable(instance, self)
end
-- __newindex for map values: direct assignment is rejected; use :set().
local function map_newindex()
	local message = "index-assignment of maps is no longer allowed. use :set()"
	error(message)
end
-- pretty_print trait impl for maps: prints the raw backing table.
local function map_pretty_print(self, pp, ...)
	local backing = self._map
	return pp:table(backing, ...)
end
-- Constructs a map of type `t` from alternating key/value arguments and marks
-- it frozen. Wrapped with U.memoize, presumably so that equal argument lists
-- return the same frozen map.
local map_freeze_helper_2 = U.memoize(function(t, ...)
	local frozen = t(...)
	frozen.is_frozen = true
	return frozen
end, false)
-- Repeatedly pops the last key of `keys` and prepends (key, map[key]) to the
-- vararg list. The final call to map_freeze_helper_2 therefore sees the pairs
-- in the same order as `keys`. `keys` is emptied in the process.
local function map_freeze_helper(t, keys, map, ...)
	if #keys == 0 then
		return map_freeze_helper_2(t, ...)
	end
	local key = table.remove(keys)
	return map_freeze_helper(t, keys, map, key, map[key], ...)
end
-- Freeze trait impl for map types. Returns `val` if it is already frozen;
-- otherwise returns a frozen map built from its entries in sorted key order.
-- The key type needs an order trait for that; without one this warns and
-- returns `val` unfrozen.
local function map_freeze(t, val)
	if val.is_frozen then
		return val
	end
	local order_impl = traits.order:get(t.key_type)
	if order_impl then
		local sorted_keys = {}
		for k in pairs(val._map) do
			table.insert(sorted_keys, k)
		end
		table.sort(sorted_keys, order_impl.compare)
		return map_freeze_helper(t, sorted_keys, val._map)
	end
	print(("WARNING: can't freeze %s"):format(tostring(t)))
	return val
end
-- Turns `self` into a map type from `key_type` to `value_type`. It installs
-- constructors (`self(...)`, :new, :unchecked_new), instance methods, and the
-- pretty_print, value_name and freeze traits. Returns `self`.
-- The function is wrapped with U.memoize after its definition.
local function define_map(self, key_type, value_type)
	local function is_type(x)
		return type(x) == "table" and type(x.value_check) == "function"
	end
	if not (is_type(key_type) and is_type(value_type)) then
		error("trying to set the key or value type to something that isn't a type (possible typo?)")
	end
	setmetatable(self, map_type_mt)
	self.unchecked_new = map_unchecked_new_fn
	self.new = map_new_fn
	self.value_check = metatable_equality(self)
	self.key_type = key_type
	self.value_type = value_type
	local methods = gen_map_methods(self, key_type, value_type)
	self.__index = methods
	self.__newindex = map_newindex
	self.__pairs = methods.pairs
	self.__tostring = methods.pretty_print
	traits.pretty_print:implement_on(self, {
		pretty_print = map_pretty_print,
		default_print = map_pretty_print,
	})
	traits.value_name:implement_on(self, {
		value_name = function()
			local key_name = traits.value_name:get(key_type).value_name()
			local value_name = traits.value_name:get(value_type).value_name()
			return ("MapValue<%s, %s>"):format(key_name, value_name)
		end,
	})
	traits.freeze:implement_on(self, { freeze = map_freeze })
	return self
end
define_map = U.memoize(define_map, false)
-- Declares an anonymous map type. Wrapped with U.memoize, presumably so that
-- identical (key_type, value_type) pairs yield the same type object.
local declare_map = U.memoize(function(key_type, value_type)
	return define_map({}, key_type, value_type)
end, false)
-- Metatable shared by all set *types*.
local set_type_mt = {}

-- Calling a set type builds a set containing every argument; each one is
-- checked and frozen by :put.
function set_type_mt.__call(self, ...)
	local instance = setmetatable({ _set = {}, is_frozen = false }, self)
	local members = table.pack(...)
	for idx = 1, members.n do
		instance:put(members[idx])
	end
	return instance
end

-- Two set types are equal when their key types are equal.
function set_type_mt.__eq(left, right)
	return left.key_type == right.key_type
end

function set_type_mt.__tostring(self)
	return string.format("terms-gen set key:<%s>", tostring(self.key_type))
end
-- Builds the method table for instances of set type `self` (installed as the
-- type's __index). Members are type-checked against `key_type`; mutators
-- refuse to touch frozen sets.
local function gen_set_methods(self, key_type)
	-- Raises with `message` (after dumping diagnostics under `tag` via `p`)
	-- unless `key` is a `key_type`.
	local function check_key(key, tag, message)
		if key_type.value_check(key) ~= true then
			p(tag, key_type)
			p(key)
			error(message)
		end
	end
	local function check_mutable(val)
		if val.is_frozen then
			error("trying to modify a frozen set")
		end
	end
	-- Raises with `message` unless `other` is a set of this same type.
	local function check_same_type(other, message)
		local rt = getmetatable(other)
		if self ~= rt then
			error(message)
		end
	end

	local methods = {}

	-- Adds `key`, frozen through the key type's freeze trait when one exists.
	function methods.put(val, key)
		check_mutable(val)
		check_key(key, "set-put", "wrong key type passed to set:put")
		local freeze_impl_key = traits.freeze:get(key_type)
		if freeze_impl_key then
			key = freeze_impl_key.freeze(key_type, key)
		else
			print(("WARNING: while putting %s, can't freeze key (type %s)"):format(tostring(self), tostring(key_type)))
			print("this may lead to suboptimal hash-consing")
		end
		val._set[key] = true
	end

	function methods.remove(val, key)
		check_mutable(val)
		check_key(key, "set-remove", "wrong key type passed to set:remove")
		val._set[key] = nil
	end

	-- Returns true when `key` is a member, nil otherwise.
	function methods.test(val, key)
		check_key(key, "set-test", "wrong key type passed to set:test")
		return val._set[key]
	end

	-- Iterates (member, true) pairs in unspecified order.
	function methods.pairs(val)
		return pairs(val._set)
	end

	-- Adds every member of `val` to `onto` (a new empty set by default) and
	-- returns `onto`.
	function methods.copy(val, onto)
		onto = onto or self()
		check_same_type(onto, "set:copy must be passed sets of the same type")
		for k in val:pairs() do
			onto:put(k)
		end
		return onto
	end

	function methods.union(left, right)
		check_same_type(right, "set:union must be passed sets of the same type")
		local result = left:copy()
		right:copy(result)
		return result
	end

	function methods.subtract(left, right)
		check_same_type(right, "set:subtract must be passed sets of the same type")
		local result = left:copy()
		for k in right:pairs() do
			result:remove(k)
		end
		return result
	end

	-- True when every member of `right` is also in `left`.
	function methods.superset(left, right)
		check_same_type(right, "set:superset must be passed sets of the same type")
		for k in right:pairs() do
			if not left:test(k) then
				return false
			end
		end
		return true
	end

	methods.pretty_preprint = pretty_printer.pretty_preprint
	methods.pretty_print = pretty_printer.pretty_print
	methods.default_print = pretty_printer.default_print
	return methods
end
-- pretty_print trait impl for sets: prints the raw membership table.
local function set_pretty_print(self, pp, ...)
	local backing = self._set
	return pp:table(backing, ...)
end
-- Constructs a set of type `t` from the given members and marks it frozen.
-- Wrapped with U.memoize, presumably so that equal member lists return the
-- same frozen set.
local set_freeze_helper_2 = U.memoize(function(t, ...)
	local frozen = t(...)
	frozen.is_frozen = true
	return frozen
end, false)
-- Repeatedly pops the last element of `keys` and prepends it to the vararg
-- list. The final call to set_freeze_helper_2 therefore sees the members in
-- the same order as `keys`. `keys` is emptied in the process.
local function set_freeze_helper(t, keys, ...)
	if #keys == 0 then
		return set_freeze_helper_2(t, ...)
	end
	local key = table.remove(keys)
	return set_freeze_helper(t, keys, key, ...)
end
-- Freeze trait impl for set types. Returns `val` if it is already frozen;
-- otherwise returns a frozen set built from its members in sorted order.
-- The key type needs an order trait for that; without one this warns and
-- returns `val` unfrozen.
local function set_freeze(t, val)
	if val.is_frozen then
		return val
	end
	local order_impl = traits.order:get(t.key_type)
	if order_impl then
		local sorted_members = {}
		for member in pairs(val._set) do
			table.insert(sorted_members, member)
		end
		table.sort(sorted_members, order_impl.compare)
		return set_freeze_helper(t, sorted_members)
	end
	print(("WARNING: can't freeze %s"):format(tostring(t)))
	return val
end
-- Turns `self` into a set type over `key_type`. It installs instance methods
-- and the pretty_print, value_name and freeze traits. Returns `self`.
-- The function is wrapped with U.memoize after its definition.
local function define_set(self, key_type)
	local is_type = type(key_type) == "table" and type(key_type.value_check) == "function"
	if not is_type then
		error("trying to set the key or value type to something that isn't a type (possible typo?)")
	end
	setmetatable(self, set_type_mt)
	self.value_check = metatable_equality(self)
	self.key_type = key_type
	local methods = gen_set_methods(self, key_type)
	self.__index = methods
	self.__pairs = methods.pairs
	self.__tostring = methods.pretty_print
	traits.pretty_print:implement_on(self, {
		pretty_print = set_pretty_print,
		default_print = set_pretty_print,
	})
	traits.value_name:implement_on(self, {
		value_name = function()
			local inner_name = traits.value_name:get(key_type).value_name()
			return ("SetValue<%s>"):format(inner_name)
		end,
	})
	traits.freeze:implement_on(self, { freeze = set_freeze })
	return self
end
define_set = U.memoize(define_set, false)
-- Declares an anonymous set type. Wrapped with U.memoize, presumably so that
-- identical key types yield the same type object.
local declare_set = U.memoize(function(key_type)
	return define_set({}, key_type)
end, false)
-- Wraps `array` (elements 1..n) as an unfrozen array value of type `self`.
-- No type checking is done; the caller is responsible for the elements.
local function array_unchecked_new_fn(self, array, n)
	local instance = { is_frozen = false }
	instance.n = n
	instance.array = array
	return setmetatable(instance, self)
end
-- Metatable shared by all array *types*.
local array_type_mt = {}

-- Calling an array type builds an array of its arguments (trailing nils
-- included), each type-checked against the element type.
function array_type_mt.__call(self, ...)
	local value_type = self.value_type
	local count = select("#", ...)
	local items = { ... }
	for index = 1, count do
		local item = items[index]
		if value_type.value_check(item) ~= true then
			error(
				attempt_traceback(
					("wrong value type passed to array creation: expected [%s] of type %s but got %s"):format(
						s(index),
						s(value_type),
						s(item)
					)
				)
			)
		end
	end
	return array_unchecked_new_fn(self, items, count)
end

-- Two array types are equal when their element types are equal.
function array_type_mt.__eq(left, right)
	return left.value_type == right.value_type
end

function array_type_mt.__tostring(self)
	return string.format("terms-gen array val:<%s>", tostring(self.value_type))
end
-- Builds an array of type `self` from elements first..last of `array`,
-- type-checking each one. `first` defaults to 1; `last` defaults to
-- `array.n`, falling back to `#array`.
local function array_new_fn(self, array, first, last)
	local value_type = self.value_type
	local start = first
	if start == nil then
		start = 1
	end
	local stop = last
	if stop == nil then
		stop = array.n
	end
	if stop == nil then
		stop = #array
	end
	local out, count = {}, 0
	for source_index = start, stop do
		count = count + 1
		local element = array[source_index]
		if value_type.value_check(element) ~= true then
			error(
				attempt_traceback(
					("wrong value type passed to array creation: expected [%s] of type %s but got %s"):format(
						s(count),
						s(value_type),
						s(element)
					)
				)
			)
		end
		out[count] = element
	end
	return array_unchecked_new_fn(self, out, count)
end
-- Stateless iterator used by array:ipairs(). Yields (i, state[i]) for
-- i = control + 1 while i <= state:len().
local function array_next(state, control)
	local i = control + 1
	if i <= state:len() then
		return i, state[i]
	end
	return nil
end
-- Builds the method table for instances of array type `self`. The table is
-- stored as `self.methods` and consulted first by the array __index. Array
-- values keep their elements in `array` and their length in `n`.
local function gen_array_methods(self, value_type)
	return {
		-- ipairs-style iteration over 1..n (see array_next).
		ipairs = function(val)
			return array_next, val, 0
		end,
		-- Number of elements; define_array also installs it as __len.
		len = function(val)
			return val.n
		end,
		-- Appends `value` at index n + 1. The value is type-checked against the
		-- element type, like construction and index-assignment already do.
		-- Raises on a frozen array.
		append = function(val, value)
			if val.is_frozen then
				error("trying to modify a frozen array")
			end
			local n = val.n + 1
			if value_type.value_check(value) ~= true then
				error(
					attempt_traceback(
						("wrong value type passed to array append: expected [%s] of type %s but got %s"):format(
							s(n),
							s(value_type),
							s(value)
						)
					)
				)
			end
			val.array[n], val.n = value, n
		end,
		-- Unfrozen copy of elements first..last (defaults 1..n), not re-checked.
		copy = function(val, first, last)
			first, last = first or 1, last or val.n
			local array, new_array = val.array, {}
			local i = 0
			for j = first, last do
				i = i + 1
				new_array[i] = array[j]
			end
			return self:unchecked_new(new_array, i)
		end,
		-- All elements as multiple return values.
		unpack = function(val)
			return table.unpack(val.array, 1, val.n)
		end,
		-- New array of type `to` holding fn(element) for each element. Each
		-- result is checked against `to.value_type`.
		map = function(val, to, fn)
			local to_value_type = to.value_type
			local array, new_array, n = val.array, {}, val.n
			for i = 1, n do
				local value = fn(array[i])
				if to_value_type.value_check(value) ~= true then
					error(
						attempt_traceback(
							("wrong value type resulting from array mapping: expected [%s] of type %s but got %s"):format(
								s(i),
								s(to_value_type),
								s(value)
							)
						)
					)
				end
				new_array[i] = value
			end
			return to:unchecked_new(new_array, n)
		end,
		-- Raw element read, with no bounds or key checks.
		get = function(val, key)
			return val.array[key]
		end,
		pretty_preprint = pretty_printer.pretty_preprint,
		pretty_print = pretty_printer.pretty_print,
		default_print = pretty_printer.default_print,
	}
end
-- __eq for array values: equal when both have the same array type, the same
-- length, and elements that are pairwise `==`.
local function array_eq_fn(left, right)
	if getmetatable(left) ~= getmetatable(right) then
		return false
	end
	local len = left:len()
	if len ~= right:len() then
		return false
	end
	for i = 1, len do
		if left[i] ~= right[i] then
			return false
		end
	end
	return true
end
-- Builds the __index and __newindex metamethods for instances of array type
-- `self`, whose elements have type `value_type`.
local function gen_array_index_fns(self, value_type)
	-- Read access. Method names resolve through self.methods; any other key
	-- must be an integer-valued number in [1, n]. Violations dump diagnostics
	-- via `p` and raise.
	local function index(val, key)
		local method = self.methods[key]
		if method then
			return method
		end
		if type(key) ~= "number" then
			p("array-index", value_type)
			p(key)
			error("wrong key type passed to array indexing")
		end
		if key ~= math_floor(key) then
			p(key)
			error("key passed to array indexing is not an integer")
		end
		if key < 1 or key > val.n then
			p(key, val.n)
			local message = ("key passed to array indexing is out of bounds (read code comment above): %s is not within [1,%s]"):format(
				tostring(key),
				tostring(val.n)
			)
			error(message)
		end
		return val.array[key]
	end

	-- Write access. Allowed only on unfrozen arrays, for integer keys in
	-- [1, n + 1] (n + 1 appends), and for values of `value_type`. The value is
	-- frozen via its type's freeze trait when one exists.
	local function newindex(val, key, value)
		if val.is_frozen then
			error(("trying to set %s on a frozen array to %s: %s"):format(s(key), s(value), s(val)))
		end
		if not builtin_integer_value_check(key) then
			error(("key passed to array index-assignment is not an integer: %s"):format(s(key)))
		end
		if key < 1 or key > val.n + 1 then
			error(("key %s passed to array index-assignment is out of bounds: %s"):format(s(key), s(val.n)))
		end
		if value_type.value_check(value) ~= true then
			local message = ("wrong value type passed to array index-assignment: expected [%s] of type %s but got %s"):format(
				s(key),
				s(value_type),
				s(value)
			)
			error(attempt_traceback(message))
		end
		local freezer = traits.freeze:get(value_type)
		if not freezer then
			print(
				("WARNING: while setting %s, can't freeze value (type %s)\nthis may lead to suboptimal hash-consing"):format(
					tostring(self),
					tostring(value_type)
				)
			)
		else
			value = freezer.freeze(value_type, value)
		end
		val.array[key] = value
		if key > val.n then
			val.n = key
		end
	end

	return index, newindex
end
-- pretty_print trait impl for arrays: prints the raw backing sequence.
local function array_pretty_print(self, pp, ...)
	local backing = self.array
	return pp:array(backing, ...)
end
-- Builds the diff trait impl for array type `self`. The returned function
-- prints where `left` and `right` differ. When exactly one element differs, it
-- recurses into that element through the element type's diff trait and
-- returns that result.
local function gen_array_diff_fn(self, value_type)
	return function(left, right)
		print(("diffing array with value_type: %s"):format(tostring(value_type)))
		local rt = getmetatable(right)
		if self ~= rt then
			print("unequal types!")
			print(self)
			print(rt)
			print("stopping diff")
			return
		end
		local len_left, len_right = left:len(), right:len()
		if len_left ~= len_right then
			print("unequal lengths!")
			print(len_left)
			print(len_right)
			print("stopping diff")
			return
		end
		local differing = {}
		for i = 1, len_left do
			if left[i] ~= right[i] then
				differing[#differing + 1] = i
			end
		end
		local count = #differing
		if count == 0 then
			print("no difference")
			print("stopping diff")
			return
		end
		if count > 1 then
			print("difference in multiple elements:")
			for _, idx in ipairs(differing) do
				print(idx)
			end
			print("stopping diff")
			return
		end
		local d = differing[1]
		print(("difference in element: %s"):format(tostring(d)))
		local diff_impl = traits.diff:get(value_type)
		if not diff_impl then
			print("stopping diff (missing diff impl)")
			print("value_type:", value_type)
			return
		end
		return diff_impl.diff(left[d], right[d])
	end
end
-- Returns a function that turns a raw element table into a frozen array
-- value of type `t` with `n` elements. The outer function is memoized per
-- (t, n). The returned function is itself wrapped with U.memoize(..., true),
-- presumably keyed on the element table's contents so that equal element
-- lists share one frozen value.
local array_freeze_helper = U.memoize(function(t, n)
	return U.memoize(function(array)
		local frozen = t:new(array, 1, n)
		frozen.is_frozen = true
		return frozen
	end, true)
end, false)
-- Freeze trait impl for array types. Returns `val` if it is already frozen;
-- otherwise returns the canonical frozen array with the same elements.
local function array_freeze(t, val)
	if not val.is_frozen then
		local freezer = array_freeze_helper(t, val.n)
		return freezer(val.array)
	end
	return val
end
-- Turns `self` into an array type with elements of `value_type`. It installs
-- constructors (`self(...)`, :new, :unchecked_new), methods, metamethods
-- (__eq, __index, __newindex, __ipairs, __len, __tostring) and the
-- pretty_print, diff, value_name and freeze traits. Returns `self`.
-- The function is wrapped with U.memoize after its definition.
local function define_array(self, value_type)
	local valid = type(value_type) == "table" and type(value_type.value_check) == "function"
	if not valid then
		error(
			("trying to set the value type to something that isn't a type (possible typo?): %s"):format(
				attempt_traceback(tostring(value_type))
			)
		)
	end
	setmetatable(self, array_type_mt)
	self.unchecked_new = array_unchecked_new_fn
	self.new = array_new_fn
	self.value_check = metatable_equality(self)
	self.value_type = value_type
	local methods = gen_array_methods(self, value_type)
	self.methods = methods
	self.__eq = array_eq_fn
	self.__index, self.__newindex = gen_array_index_fns(self, value_type)
	self.__ipairs = methods.ipairs
	self.__len = methods.len
	self.__tostring = methods.pretty_print
	traits.pretty_print:implement_on(self, {
		pretty_print = array_pretty_print,
		default_print = array_pretty_print,
	})
	traits.diff:implement_on(self, {
		diff = gen_array_diff_fn(self, value_type),
	})
	traits.value_name:implement_on(self, {
		value_name = function()
			local element_name = traits.value_name:get(value_type).value_name()
			return ("ArrayValue<%s>"):format(element_name)
		end,
	})
	traits.freeze:implement_on(self, { freeze = array_freeze })
	return self
end
define_array = U.memoize(define_array, false)
-- Declares an anonymous array type. Wrapped with U.memoize, presumably so
-- that identical element types yield the same type object.
local declare_array = U.memoize(function(value_type)
	return define_array({}, value_type)
end, false)
-- Metatable for declared-but-not-yet-defined types (see define_type). It
-- exposes the define_* functions as methods, so a forward declaration can be
-- completed in place, e.g. `T:define_enum(...)`.
local type_mt = { __index = {} }
type_mt.__index.define_record = define_record
type_mt.__index.define_enum = define_enum
type_mt.__index.define_foreign = define_foreign
type_mt.__index.define_map = define_map
type_mt.__index.define_set = define_set
type_mt.__index.define_array = define_array
-- value_check placeholder for declared-but-undefined types: always raises.
local function undefined_value_check(_)
	local message = "trying to typecheck a value against a type that has been declared but not defined"
	error(message)
end
-- Marks `self` as a forward-declared type. Its value_check raises until one of
-- the define_* methods (provided through type_mt) completes it.
local function define_type(self)
	self.value_check = undefined_value_check
	return setmetatable(self, type_mt)
end
-- Creates a fresh forward-declared type.
local function declare_type()
	local fresh = {}
	return define_type(fresh)
end
-- Defines `self` as a foreign type that accepts values whose Lua type()
-- equals `typename`.
local function define_builtin(self, typename)
	local function has_builtin_type(val)
		return type(val) == typename
	end
	return define_foreign(self, has_builtin_type, typename)
end
-- Creates a fresh builtin type for the Lua type name `typename`.
local function declare_builtin(typename)
	local fresh = {}
	return define_builtin(fresh, typename)
end
-- Public module table: the declare_* entry points, helpers, the shared type
-- metatables, and ready-made types for Lua primitives.
local terms_gen = {}
terms_gen.declare_record = declare_record
terms_gen.declare_enum = declare_enum
terms_gen.declare_foreign = declare_foreign
terms_gen.declare_map = declare_map
terms_gen.declare_set = declare_set
terms_gen.declare_array = declare_array
terms_gen.declare_type = declare_type
terms_gen.metatable_equality = metatable_equality
-- Primitive types, created in the same order as before.
terms_gen.builtin_number = declare_builtin("number")
terms_gen.builtin_integer = declare_foreign(builtin_integer_value_check, "integer")
terms_gen.builtin_string = declare_builtin("string")
terms_gen.builtin_function = declare_builtin("function")
terms_gen.builtin_table = declare_builtin("table")
terms_gen.array_type_mt = array_type_mt
terms_gen.map_type_mt = map_type_mt
terms_gen.define_multi_enum = define_multi_enum
-- Accepts every Lua value.
terms_gen.any_lua_type = declare_foreign(function(_val)
	return true
end, "any")
-- Freeze impl for values that need no canonicalization: returns `val` as-is.
local function freeze_trivial(t, val)
	return val
end
-- Order impl using Lua's built-in `<`.
local function compare_trivial(left, right)
	return left < right
end
-- Gives `ty` its own identity-freeze implementation table.
local function implement_trivial_freeze(ty)
	traits.freeze:implement_on(ty, { freeze = freeze_trivial })
end
-- Integers and strings get identity freezing plus `<` ordering. The order
-- impl is what lets maps and sets keyed by them be frozen (see map_freeze /
-- set_freeze).
for _, ty in ipairs({ terms_gen.builtin_integer, terms_gen.builtin_string }) do
	implement_trivial_freeze(ty)
	traits.order:implement_on(ty, { compare = compare_trivial })
end
-- Plain tables and arbitrary Lua values are frozen by identity only; they get
-- no order impl.
implement_trivial_freeze(terms_gen.builtin_table)
implement_trivial_freeze(terms_gen.any_lua_type)
-- Diff trait impl for `any_lua_type`: prints a human-readable account of how
-- `left` and `right` differ.
-- Values of different primitive types stop immediately. Non-tables are
-- compared with `==`. For tables, when exactly one field differs, it recurses
-- into that field through the diff trait of the field's metatable and returns
-- that result.
local function any_lua_type_diff_fn(left, right)
	if type(left) ~= type(right) then
		print("different primitive lua types!")
		print(type(left))
		print(type(right))
		print("stopping diff")
		return
	end

	-- Compares two non-table values of the same primitive type with `==`.
	-- `header` is printed first; `different` is printed when they differ.
	local function diff_scalar(header, different)
		print(header)
		if left ~= right then
			print(different)
			print(left)
			print(right)
			print("stopping diff")
			return
		end
		print("no difference")
		print("stopping diff")
	end

	local function diff_table()
		print("diffing lua tables")
		if left == right then
			print("physically equal")
			print("stopping diff")
			return
		end
		local n = 0
		local diff_elems = {}
		-- Keys on the left whose value differs on the right (including missing).
		for k, lval in pairs(left) do
			local rval = right[k]
			if lval ~= rval then
				n = n + 1
				diff_elems[n] = k
			end
		end
		-- Keys present only on the right. Test with nil, not truthiness: a
		-- `false` on the left is a real value, and if it differed it was already
		-- counted above.
		for k in pairs(right) do
			if left[k] == nil then
				n = n + 1
				diff_elems[n] = k
			end
		end
		if n == 0 then
			print("no elements different")
			print("stopping diff")
			return
		elseif n == 1 then
			local d = diff_elems[1]
			print(("difference in element: %s"):format(tostring(d)))
			local mtl = getmetatable(left[d])
			local mtr = getmetatable(right[d])
			if mtl ~= mtr then
				print("stopping diff (different metatables)")
				return
			end
			local diff_impl = traits.diff:get(mtl)
			if diff_impl then
				return diff_impl.diff(left[d], right[d])
			end
			print("stopping diff (missing diff impl)")
			print("mt:", mtl)
			return
		end
		print("difference in multiple elements:")
		for i = 1, n do
			print(diff_elems[i])
		end
		print("stopping diff")
	end

	local dispatch = {
		["nil"] = function()
			print("diffing lua nils")
			print("no difference")
			print("stopping diff")
		end,
		number = function()
			return diff_scalar("diffing lua numbers", "different numbers")
		end,
		string = function()
			return diff_scalar("diffing lua strings", "different strings")
		end,
		boolean = function()
			return diff_scalar("diffing lua booleans", "different booleans")
		end,
		table = diff_table,
		["function"] = function()
			return diff_scalar("diffing lua functions", "different functions")
		end,
		thread = function()
			return diff_scalar("diffing lua threads", "different threads")
		end,
		userdata = function()
			return diff_scalar("diffing lua userdatas", "different userdata")
		end,
	}
	-- Propagate the nested diff result, as gen_array_diff_fn does.
	return dispatch[type(left)]()
end
traits.diff:implement_on(terms_gen.any_lua_type, { diff = any_lua_type_diff_fn })
-- Publish this module on the shared internals-interface table, then return it.
require("internals-interface").terms_gen = terms_gen
return terms_gen