mod error;
pub use error::TypeCheckError;
use std::collections::HashMap;
use crate::common::Span;
use crate::parser::{
Aggregation, AggregationOperator, Arithmetic, ArithmeticOperator, Atom, AtomArg,
ComparisonExpr, ComparisonOperator, ConstType, DataType, Factor, FlowLogRule, FnCall, HeadArg,
Predicate, Program,
};
pub fn check_program(program: &mut Program) -> Result<(), TypeCheckError> {
let decls: DeclTypes = program
.relations()
.iter()
.map(|r| (r.name().to_string(), r.data_type()))
.collect();
let udfs: UdfSigs = program
.udfs()
.iter()
.map(|u| {
(
u.name().to_string(),
(
u.params()
.iter()
.map(|p| (p.name().to_string(), *p.data_type()))
.collect(),
u.ret_type(),
),
)
})
.collect();
for segment in program.segments_mut() {
for rule in segment.as_rules_mut() {
check_rule(rule, &decls, &udfs)?;
}
if let Some(block) = segment.as_loop_mut() {
for rule in block.rules_mut() {
check_rule(rule, &decls, &udfs)?;
}
}
}
check_and_pin_facts(program.facts_mut(), &decls)
}
/// Column types of every relation, keyed by relation name (built from `program.relations()`).
type DeclTypes = HashMap<String, Vec<DataType>>;
/// Signature of every UDF, keyed by name: `(list of (param name, param type), return type)`.
type UdfSigs = HashMap<String, (Vec<(String, DataType)>, DataType)>;
/// Per-rule variable typing: each variable's type plus the span of the positive body atom
/// that first bound it (kept for mismatch diagnostics).
type Bindings = HashMap<String, (DataType, Span)>;
/// Type of an expression as seen by the checker before literals are pinned.
///
/// `ConstType::Int` / `ConstType::Float` literals stay polymorphic (`IntLit` / `FloatLit`)
/// until context fixes their width. Everything else carries a
/// [`Concrete`](LitKind::Concrete) type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LitKind {
    /// A polymorphic integer literal; fits any integer type.
    IntLit,
    /// A polymorphic float literal; fits any float type.
    FloatLit,
    /// A value whose type is already fixed.
    Concrete(DataType),
}

impl LitKind {
    /// Returns whether a value of this kind may occupy a slot declared as `expected`.
    ///
    /// Polymorphic literals fit any type of their own family. Concrete kinds must match exactly.
    fn fits(self, expected: DataType) -> bool {
        match self {
            Self::Concrete(actual) => actual == expected,
            Self::IntLit => expected.is_integer(),
            Self::FloatLit => expected.is_float(),
        }
    }

    /// Returns whether this kind is numeric. Polymorphic literals always are.
    fn is_numeric(self) -> bool {
        if let Self::Concrete(ty) = self {
            ty.is_numeric()
        } else {
            true
        }
    }

    /// Returns the concrete type shown in diagnostics for this kind.
    ///
    /// Integer literals report as `Int32` and float literals as `Float32`. Some callers also
    /// use this value as a pin target.
    fn report_ty(self) -> DataType {
        match self {
            Self::Concrete(ty) => ty,
            Self::IntLit => DataType::Int32,
            Self::FloatLit => DataType::Float32,
        }
    }
}
/// Unifies two expression kinds, returning the common kind or `None` if they are incompatible.
///
/// - Identical kinds unify to themselves.
/// - A polymorphic literal unifies with a concrete type of its own family (integer or float)
///   and adopts that concrete type.
/// - Anything else is a mismatch: two distinct concrete types, or an integer literal paired
///   with a float literal.
fn merge_lit(a: LitKind, b: LitKind) -> Option<LitKind> {
    if a == b {
        return Some(a);
    }
    // Move the concrete side (if any) into `concrete` and keep the other side in `other`.
    let (other, concrete) = match (a, b) {
        (LitKind::Concrete(t), other) | (other, LitKind::Concrete(t)) => (other, t),
        _ => return None,
    };
    let compatible = match other {
        LitKind::IntLit => concrete.is_integer(),
        LitKind::FloatLit => concrete.is_float(),
        // Two different concrete types never unify.
        LitKind::Concrete(_) => false,
    };
    compatible.then_some(LitKind::Concrete(concrete))
}
fn check_rule(
rule: &mut FlowLogRule,
decls: &DeclTypes,
udfs: &UdfSigs,
) -> Result<(), TypeCheckError> {
let mut bindings: Bindings = HashMap::new();
for predicate in rule.rhs() {
if let Predicate::PositiveAtom(atom) = predicate {
bind_atom_vars(atom, decls, &mut bindings)?;
}
}
for predicate in rule.rhs_mut() {
match predicate {
Predicate::PositiveAtom(atom) => pin_atom_consts(atom, decls)?,
Predicate::NegativeAtom(atom) => {
check_atom_uses(atom, decls, &bindings)?;
pin_atom_consts(atom, decls)?;
}
Predicate::Compare(cmp) => check_comparison(cmp, &bindings, udfs)?,
Predicate::FnCall(fc) => {
infer_fn_call_type(fc, &bindings, udfs)?;
pin_fn_call_args(fc, udfs)?;
}
}
}
check_head(rule, decls, udfs, &bindings)
}
/// Records the column type of every variable in a positive atom and validates its literals.
///
/// The first occurrence of a variable fixes its type, together with this atom's span. Any
/// later occurrence must have the same column type.
///
/// # Errors
/// - [`TypeCheckError::TypeMismatch`] if a variable was already bound at a different type.
/// - [`TypeCheckError::LiteralColumnMismatch`] if a literal does not fit its column.
/// - Internal errors from [`resolve_atom_column`].
fn bind_atom_vars(
    atom: &Atom,
    decls: &DeclTypes,
    bindings: &mut Bindings,
) -> Result<(), TypeCheckError> {
    let here = atom.span();
    for (col, arg) in atom.arguments().iter().enumerate() {
        let col_ty = resolve_atom_column(atom, col, decls)?;
        match arg {
            AtomArg::Placeholder => {}
            AtomArg::Const(lit) => {
                if !lit_kind(lit)?.fits(col_ty) {
                    return Err(TypeCheckError::LiteralColumnMismatch {
                        span: here,
                        literal: lit.to_string(),
                        expected: col_ty,
                    });
                }
            }
            AtomArg::Var(name) => {
                if let Some(&(first_ty, first_span)) = bindings.get(name) {
                    if first_ty != col_ty {
                        return Err(TypeCheckError::TypeMismatch {
                            var: name.clone(),
                            first_ty,
                            first_span,
                            later_ty: col_ty,
                            later_span: here,
                        });
                    }
                } else {
                    bindings.insert(name.clone(), (col_ty, here));
                }
            }
        }
    }
    Ok(())
}
/// Validates a negated atom against the variable types already bound by positive atoms.
///
/// Variables that are not bound in `bindings` are not checked here. Literals must fit their
/// column.
///
/// # Errors
/// - [`TypeCheckError::TypeMismatch`] if a bound variable is used at another column type.
/// - [`TypeCheckError::LiteralColumnMismatch`] if a literal does not fit its column.
/// - Internal errors from [`resolve_atom_column`].
fn check_atom_uses(
    atom: &Atom,
    decls: &DeclTypes,
    bindings: &Bindings,
) -> Result<(), TypeCheckError> {
    let here = atom.span();
    for (col, arg) in atom.arguments().iter().enumerate() {
        let col_ty = resolve_atom_column(atom, col, decls)?;
        let problem = match arg {
            AtomArg::Placeholder => None,
            AtomArg::Var(name) => bindings
                .get(name)
                .filter(|entry| entry.0 != col_ty)
                .map(|&(bound_ty, bound_span)| TypeCheckError::TypeMismatch {
                    var: name.clone(),
                    first_ty: bound_ty,
                    first_span: bound_span,
                    later_ty: col_ty,
                    later_span: here,
                }),
            AtomArg::Const(lit) => (!lit_kind(lit)?.fits(col_ty)).then(|| {
                TypeCheckError::LiteralColumnMismatch {
                    span: here,
                    literal: lit.to_string(),
                    expected: col_ty,
                }
            }),
        };
        if let Some(err) = problem {
            return Err(err);
        }
    }
    Ok(())
}
/// Pins every polymorphic literal argument of `atom` to the declared type of its column.
///
/// Arguments and declared columns are zipped, so iteration stops at the shorter of the two
/// lists.
///
/// # Errors
/// Returns an internal error if `atom`'s relation has no declaration in `decls`.
fn pin_atom_consts(atom: &mut Atom, decls: &DeclTypes) -> Result<(), TypeCheckError> {
    // `decls` is borrowed independently of `atom`. The column list can therefore be read in
    // place while the arguments are mutated, so cloning it is unnecessary.
    let Some(col_types) = decls.get(atom.name()) else {
        return Err(TypeCheckError::internal(format!(
            "atom `{}` not declared",
            atom.name()
        )));
    };
    for (arg, &col_ty) in atom.arguments_mut().iter_mut().zip(col_types) {
        if let AtomArg::Const(c) = arg
            && c.is_polymorphic()
        {
            c.pin(col_ty);
        }
    }
    Ok(())
}
/// Returns the declared type of column `i` of `atom`'s relation.
///
/// # Errors
/// Returns an internal error if the relation is undeclared, or if `i` is past the declared
/// arity.
fn resolve_atom_column(
    atom: &Atom,
    i: usize,
    decls: &DeclTypes,
) -> Result<DataType, TypeCheckError> {
    let Some(columns) = decls.get(atom.name()) else {
        return Err(TypeCheckError::internal(format!(
            "atom `{}` not declared",
            atom.name()
        )));
    };
    match columns.get(i) {
        Some(&ty) => Ok(ty),
        None => Err(TypeCheckError::internal(format!(
            "atom `{}` has {} arguments but `.decl` has {}",
            atom.name(),
            atom.arguments().len(),
            columns.len(),
        ))),
    }
}
/// Type-checks a body comparison and pins its literals to the unified operand type.
///
/// The checks run in this order:
/// 1. Both sides, when typed, must unify via [`merge_lit`].
/// 2. Ordering operators (anything other than `Equal` / `NotEqual`) additionally need a
///    numeric or `String` operand.
/// 3. Literals on both sides are pinned to the unified type, if there is one.
///
/// # Errors
/// - [`TypeCheckError::ComparisonTypeMismatch`] if the two sides do not unify.
/// - [`TypeCheckError::ComparisonOpNotAllowed`] for an ordering operator on an unordered type.
/// - Errors from expression inference and pinning.
fn check_comparison(
    cmp: &mut ComparisonExpr,
    bindings: &Bindings,
    udfs: &UdfSigs,
) -> Result<(), TypeCheckError> {
    let lhs = infer_expr_type(cmp.left(), bindings, udfs)?;
    let rhs = infer_expr_type(cmp.right(), bindings, udfs)?;
    let op = cmp.operator().clone();
    let span = cmp.span();

    // Unify the two sides. `None` means neither side could be typed.
    let unified = match (lhs, rhs) {
        (Some(l), Some(r)) => {
            let Some(merged) = merge_lit(l, r) else {
                return Err(TypeCheckError::ComparisonTypeMismatch {
                    span,
                    op,
                    left: l.report_ty(),
                    right: r.report_ty(),
                });
            };
            Some(merged)
        }
        (Some(only), None) | (None, Some(only)) => Some(only),
        (None, None) => None,
    };

    // Ordering operators need an ordered type. The left side's kind is preferred for the report.
    let is_equality = matches!(op, ComparisonOperator::Equal | ComparisonOperator::NotEqual);
    let ordered_kind = if is_equality { None } else { lhs.or(rhs) };
    if let Some(kind) = ordered_kind
        && !(kind.is_numeric() || matches!(kind, LitKind::Concrete(DataType::String)))
    {
        return Err(TypeCheckError::ComparisonOpNotAllowed {
            span,
            op,
            ty: kind.report_ty(),
        });
    }

    if let Some(kind) = unified {
        let target = kind.report_ty();
        pin_arith_literals(cmp.left_mut(), target, udfs)?;
        pin_arith_literals(cmp.right_mut(), target, udfs)?;
    }
    Ok(())
}
/// Checks the head of `rule` against its relation's declaration and pins head literals.
///
/// Each head column is handled by kind:
/// - A variable bound in the body must have exactly the declared type. Unbound variables are
///   not checked here.
/// - An arithmetic expression's inferred kind must fit the declared type.
/// - An aggregation is delegated to [`check_aggregation`].
///
/// # Errors
/// - An internal error if the head relation has no declaration.
/// - [`TypeCheckError::HeadArity`] if the head arity differs from the declaration.
/// - [`TypeCheckError::HeadColumnType`] or [`TypeCheckError::LiteralColumnMismatch`] for an
///   ill-typed column.
/// - Errors from expression and aggregation checking.
fn check_head(
    rule: &mut FlowLogRule,
    decls: &DeclTypes,
    udfs: &UdfSigs,
    bindings: &Bindings,
) -> Result<(), TypeCheckError> {
    let head = rule.head_mut();
    let (rel_name, arity, head_span) = (head.name().to_string(), head.arity(), head.span());
    // `decls` is borrowed independently of `rule`. The declared columns can therefore be read
    // in place while the head arguments are mutated below; no clone is needed.
    let Some(col_types) = decls.get(rel_name.as_str()) else {
        return Err(TypeCheckError::internal(format!(
            "head relation `{rel_name}` not declared"
        )));
    };
    if arity != col_types.len() {
        return Err(TypeCheckError::HeadArity {
            span: head_span,
            rel: rel_name,
            expected: col_types.len(),
            found: arity,
        });
    }
    for (col, (arg, expected)) in head
        .head_arguments_mut()
        .iter_mut()
        .zip(col_types.iter().copied())
        .enumerate()
    {
        match arg {
            HeadArg::Aggregation(agg) => check_aggregation(agg, expected, udfs, bindings)?,
            HeadArg::Var(v) => {
                if let Some(&(found, _)) = bindings.get(v)
                    && found != expected
                {
                    return Err(TypeCheckError::HeadColumnType {
                        span: head_span,
                        rel: rel_name.clone(),
                        col,
                        expected,
                        found,
                    });
                }
            }
            HeadArg::Arith(a) => {
                if let Some(kind) = infer_expr_type(a, bindings, udfs)?
                    && !kind.fits(expected)
                {
                    return Err(head_or_literal_mismatch(a, &rel_name, col, expected, kind));
                }
                pin_arith_literals(a, expected, udfs)?;
            }
        }
    }
    Ok(())
}
/// Builds the error for a head arithmetic column whose kind does not fit `expected`.
///
/// A bare literal gets the more specific [`TypeCheckError::LiteralColumnMismatch`]. Any other
/// expression gets [`TypeCheckError::HeadColumnType`], reporting `kind` as the found type.
fn head_or_literal_mismatch(
    a: &Arithmetic,
    rel: &str,
    col: usize,
    expected: DataType,
    kind: LitKind,
) -> TypeCheckError {
    let span = a.span();
    match bare_const(a) {
        Some(lit) => TypeCheckError::LiteralColumnMismatch {
            span,
            literal: lit.to_string(),
            expected,
        },
        None => TypeCheckError::HeadColumnType {
            span,
            rel: rel.to_owned(),
            col,
            expected,
            found: kind.report_ty(),
        },
    }
}
fn check_aggregation(
agg: &mut Aggregation,
declared: DataType,
udfs: &UdfSigs,
bindings: &Bindings,
) -> Result<(), TypeCheckError> {
let op = *agg.operator();
let span = agg.span();
let arg_kind = infer_expr_type(agg.arithmetic(), bindings, udfs)?;
if matches!(op, AggregationOperator::Count) {
if !declared.is_numeric() {
return Err(TypeCheckError::AggregationOutputType { span, op, declared });
}
if let Some(k) = arg_kind {
pin_arith_literals(agg.arithmetic_mut(), k.report_ty(), udfs)?;
}
return Ok(());
}
if let Some(kind) = arg_kind {
if !kind.is_numeric() {
return Err(TypeCheckError::AggregationInputNotNumeric {
span,
op,
ty: kind.report_ty(),
});
}
if !kind.fits(declared) {
return Err(TypeCheckError::AggregationOutputType { span, op, declared });
}
}
pin_arith_literals(agg.arithmetic_mut(), declared, udfs)?;
Ok(())
}
/// Infers the kind of an arithmetic expression, or `None` if no factor could be typed.
///
/// Factors are unified left to right with [`merge_lit`]. Untyped factors (unbound variables)
/// are skipped. After each step, the operator is checked against the kind accumulated so far.
///
/// # Errors
/// - [`TypeCheckError::ArithmeticTypeMismatch`] if two factors do not unify.
/// - Errors from [`check_arith_op`] and factor inference.
fn infer_expr_type(
    expr: &Arithmetic,
    bindings: &Bindings,
    udfs: &UdfSigs,
) -> Result<Option<LitKind>, TypeCheckError> {
    let span = expr.span();
    let mut acc = infer_factor_type(expr.init(), bindings, udfs)?;
    for (op, factor) in expr.rest() {
        let next = infer_factor_type(factor, bindings, udfs)?;
        acc = match (acc, next) {
            (Some(existing), Some(k)) => match merge_lit(existing, k) {
                Some(merged) => Some(merged),
                None => {
                    return Err(TypeCheckError::ArithmeticTypeMismatch {
                        span,
                        left: existing.report_ty(),
                        right: k.report_ty(),
                    });
                }
            },
            (None, k) => k,
            (existing, None) => existing,
        };
        if let Some(k) = acc {
            check_arith_op(k, op, span)?;
        }
    }
    Ok(acc)
}
/// Infers the kind of a single factor.
///
/// Each factor kind is handled differently:
/// - A variable takes its bound type, or yields `None` when unbound.
/// - A literal yields its [`lit_kind`].
/// - A UDF call yields its declared return type.
fn infer_factor_type(
    factor: &Factor,
    bindings: &Bindings,
    udfs: &UdfSigs,
) -> Result<Option<LitKind>, TypeCheckError> {
    match factor {
        Factor::Const(lit) => lit_kind(lit).map(Some),
        Factor::FnCall(call) => {
            infer_fn_call_type(call, bindings, udfs).map(|ret| Some(LitKind::Concrete(ret)))
        }
        Factor::Var(name) => Ok(bindings.get(name).map(|entry| LitKind::Concrete(entry.0))),
    }
}
fn infer_fn_call_type(
fc: &FnCall,
bindings: &Bindings,
udfs: &UdfSigs,
) -> Result<DataType, TypeCheckError> {
let (param_types, ret_ty) =
udfs.get(fc.name())
.ok_or_else(|| TypeCheckError::UndeclaredUdf {
span: fc.span(),
name: fc.name().to_string(),
})?;
if fc.args().len() != param_types.len() {
return Err(TypeCheckError::UdfArity {
span: fc.span(),
name: fc.name().to_string(),
expected: param_types.len(),
found: fc.args().len(),
});
}
for (arg, (param_name, expected)) in fc.args().iter().zip(param_types.iter()) {
let Some(kind) = infer_expr_type(arg, bindings, udfs)? else {
continue;
};
if !kind.fits(*expected) {
return Err(TypeCheckError::UdfArgType {
span: arg.span(),
name: fc.name().to_string(),
param: param_name.clone(),
expected: *expected,
found: kind.report_ty(),
});
}
}
Ok(*ret_ty)
}
/// Classifies a literal.
///
/// `Int` and `Float` literals are polymorphic. Every other literal is concrete, with the type
/// reported by `data_type()`.
///
/// # Errors
/// Returns an internal error if a non-`Int`/`Float` literal reports no concrete type.
fn lit_kind(c: &ConstType) -> Result<LitKind, TypeCheckError> {
    let kind = match c {
        ConstType::Int(_) => LitKind::IntLit,
        ConstType::Float(_) => LitKind::FloatLit,
        other => match other.data_type() {
            Some(ty) => LitKind::Concrete(ty),
            None => {
                return Err(TypeCheckError::internal(format!(
                    "lit_kind: polymorphic literal {c:?} escaped Int/Float arms"
                )));
            }
        },
    };
    Ok(kind)
}
/// Returns the literal if `a` is nothing but a single constant (per `Arithmetic::is_const`).
fn bare_const(a: &Arithmetic) -> Option<&ConstType> {
    if !a.is_const() {
        return None;
    }
    if let Factor::Const(c) = a.init() {
        Some(c)
    } else {
        None
    }
}
/// Checks that arithmetic operator `op` is allowed on operands of kind `kind`.
///
/// - `Bool` admits no operator at all.
/// - `String` admits only `Cat`.
/// - Every other kind, including polymorphic literals, admits every operator except `Cat`.
///
/// # Errors
/// Returns [`TypeCheckError::ArithmeticOpNotAllowed`] when the combination is rejected.
fn check_arith_op(
    kind: LitKind,
    op: &ArithmeticOperator,
    span: Span,
) -> Result<(), TypeCheckError> {
    let permitted = match (kind, matches!(op, ArithmeticOperator::Cat)) {
        (LitKind::Concrete(DataType::Bool), _) => false,
        (LitKind::Concrete(DataType::String), concat) => concat,
        (_, concat) => !concat,
    };
    if !permitted {
        return Err(TypeCheckError::ArithmeticOpNotAllowed {
            span,
            op: op.clone(),
            ty: kind.report_ty(),
        });
    }
    Ok(())
}
fn pin_arith_literals(
a: &mut Arithmetic,
target: DataType,
udfs: &UdfSigs,
) -> Result<(), TypeCheckError> {
pin_factor(a.init_mut(), target, udfs)?;
for (_, f) in a.rest_mut() {
pin_factor(f, target, udfs)?;
}
Ok(())
}
/// Pins one factor.
///
/// - A polymorphic literal is pinned to `target`.
/// - A UDF call has its arguments pinned to the UDF's parameter types.
/// - Variables and already-concrete literals are left alone.
fn pin_factor(factor: &mut Factor, target: DataType, udfs: &UdfSigs) -> Result<(), TypeCheckError> {
    match factor {
        Factor::FnCall(call) => pin_fn_call_args(call, udfs),
        Factor::Const(lit) if lit.is_polymorphic() => {
            lit.pin(target);
            Ok(())
        }
        Factor::Const(_) | Factor::Var(_) => Ok(()),
    }
}
/// Pins the polymorphic literals in each argument of `fc` to that argument's declared
/// parameter type.
///
/// Nested calls are handled recursively through [`pin_arith_literals`]. Arguments and
/// parameters are zipped, so any surplus on either side is ignored; arity mismatches are
/// reported by [`infer_fn_call_type`].
///
/// # Errors
/// Returns an internal error if `fc` names a UDF missing from `udfs`.
fn pin_fn_call_args(fc: &mut FnCall, udfs: &UdfSigs) -> Result<(), TypeCheckError> {
    // The signature lives in `udfs`, which is borrowed independently of `fc`. It can be read
    // in place while the arguments are mutated, so there is no need to copy the parameter
    // types into a fresh Vec on every (recursive) call.
    let (params, _) = udfs.get(fc.name()).ok_or_else(|| {
        TypeCheckError::internal(format!(
            "pin_fn_call_args: UDF `{}` not declared",
            fc.name()
        ))
    })?;
    for (arg, (_, param_ty)) in fc.args_mut().iter_mut().zip(params) {
        pin_arith_literals(arg, *param_ty, udfs)?;
    }
    Ok(())
}
/// Validates every ground fact against its relation's declaration and pins polymorphic
/// literals to the declared column types.
///
/// Relations are visited in ascending name order, and tuples within a relation in their
/// stored order. When several facts are ill-typed, the same one is therefore reported on
/// every run.
///
/// # Errors
/// - An internal error if a fact names an undeclared relation.
/// - [`TypeCheckError::HeadArity`] if a tuple's width differs from the declaration.
/// - [`TypeCheckError::LiteralColumnMismatch`] if a literal does not fit its column.
fn check_and_pin_facts(
    facts: &mut HashMap<String, Vec<(Span, Vec<ConstType>)>>,
    decls: &DeclTypes,
) -> Result<(), TypeCheckError> {
    // `HashMap` iteration order is randomised per process. Iterating the map directly would
    // make the first reported error (and thus diagnostics and test output) nondeterministic.
    let mut relations: Vec<(&String, &mut Vec<(Span, Vec<ConstType>)>)> =
        facts.iter_mut().collect();
    relations.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));

    for (rel_name, tuples) in relations {
        let Some(col_types) = decls.get(rel_name) else {
            return Err(TypeCheckError::internal(format!(
                "fact references undeclared relation `{rel_name}`"
            )));
        };
        for (span, tuple) in tuples.iter_mut() {
            if tuple.len() != col_types.len() {
                return Err(TypeCheckError::HeadArity {
                    span: *span,
                    rel: rel_name.clone(),
                    expected: col_types.len(),
                    found: tuple.len(),
                });
            }
            for (c, col_ty) in tuple.iter_mut().zip(col_types.iter()) {
                if !lit_kind(c)?.fits(*col_ty) {
                    return Err(TypeCheckError::LiteralColumnMismatch {
                        span: *span,
                        literal: c.to_string(),
                        expected: *col_ty,
                    });
                }
                if c.is_polymorphic() {
                    c.pin(*col_ty);
                }
            }
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    // Unit tests for the type checker. The end-to-end tests write a FlowLog source to a temp
    // file, parse it, run `check_program`, and inspect how literals were pinned. The table
    // tests exercise `merge_lit`, `check_arith_op` and `LitKind::report_ty` directly.
    use super::*;
    use crate::common::SourceMap;
    use std::io::Write;

    /// Writes `src` to a temporary file, parses it and type-checks it.
    /// Panics on any I/O, parse or type error.
    fn parse_and_check(src: &str) -> Program {
        let mut tmp = tempfile::NamedTempFile::new().expect("tempfile");
        tmp.write_all(src.as_bytes()).expect("write");
        let mut sm = SourceMap::new();
        let mut program =
            Program::parse(&tmp.path().to_string_lossy(), true, &mut sm).expect("parse failed");
        check_program(&mut program).expect("check failed");
        program
    }

    // A literal in a positive body atom is pinned to its column's declared width (int8).
    #[test]
    fn body_atom_const_pinned_to_declared_column_width() {
        let src = "\
            .decl Item(x: int8)\n\
            .decl Flag(x: int8)\n\
            .decl Out(x: int8)\n\
            .input Item(IO=\"file\", filename=\"Item.csv\", delimiter=\",\")\n\
            .input Flag(IO=\"file\", filename=\"Flag.csv\", delimiter=\",\")\n\
            .output Out\n\
            Out(x) :- Item(x), Flag(5).\n";
        let program = parse_and_check(src);
        let rule = program.rules()[0];
        let flag_atom = match &rule.rhs()[1] {
            Predicate::PositiveAtom(a) => a,
            other => panic!("expected Flag atom, got {other:?}"),
        };
        match &flag_atom.arguments()[0] {
            AtomArg::Const(c) => assert_eq!(c, &ConstType::Int8(5)),
            other => panic!("expected Const, got {other:?}"),
        }
    }

    // A literal compared against a bound variable adopts the variable's type (int16).
    #[test]
    fn comparison_literal_pinned_to_variable_type() {
        let src = "\
            .decl Item(x: int16)\n\
            .decl Big(x: int16)\n\
            .input Item(IO=\"file\", filename=\"Item.csv\", delimiter=\",\")\n\
            .output Big\n\
            Big(x) :- Item(x), x > 100.\n";
        let program = parse_and_check(src);
        let rule = program.rules()[0];
        let cmp = match &rule.rhs()[1] {
            Predicate::Compare(c) => c,
            other => panic!("expected comparison, got {other:?}"),
        };
        match cmp.right().init() {
            Factor::Const(c) => assert_eq!(c, &ConstType::Int16(100)),
            other => panic!("expected Const, got {other:?}"),
        }
    }

    // A literal passed to a UDF is pinned to the UDF parameter type (int8), not to the type
    // of the surrounding head expression (int64).
    #[test]
    fn nested_udf_arg_pinned_to_param_type_not_outer_target() {
        let src = "\
            .decl Item(x: int64)\n\
            .decl Flag(x: int64)\n\
            .input Item(IO=\"file\", filename=\"Item.csv\", delimiter=\",\")\n\
            .output Flag\n\
            .extern fn f(a: int8) -> int64\n\
            Flag(f(1) + x) :- Item(x).\n";
        let program = parse_and_check(src);
        let rule = program.rules()[0];
        let head_arith = match &rule.head().head_arguments()[0] {
            HeadArg::Arith(a) => a,
            other => panic!("expected Arith head arg, got {other:?}"),
        };
        let fc = match head_arith.init() {
            Factor::FnCall(fc) => fc,
            other => panic!("expected FnCall factor, got {other:?}"),
        };
        match fc.args()[0].init() {
            Factor::Const(c) => assert_eq!(
                c,
                &ConstType::Int8(1),
                "UDF arg must pin to param type (Int8), not outer target (Int64)"
            ),
            other => panic!("expected Const, got {other:?}"),
        }
    }

    // A literal in a ground fact is pinned to the declared column width (uint64).
    // NOTE(review): the fact is looked up under "p" although the relation is declared as `P`;
    // presumably the parser lowercases relation names. Confirm.
    #[test]
    fn fact_tuple_const_pinned_to_declared_column_width() {
        let src = "\
            .decl P(x: uint64)\n\
            .decl Out(x: uint64)\n\
            .output Out\n\
            P(5).\n\
            Out(x) :- P(x).\n";
        let program = parse_and_check(src);
        let p_facts = program.facts().get("p").expect("p facts");
        let (_, tuple) = &p_facts[0];
        assert_eq!(tuple[0], ConstType::UInt64(5));
    }

    // Unification table: literals adopt a same-family concrete type. Distinct concrete types,
    // int vs float, and Bool vs an integer literal never unify.
    #[test]
    fn merge_lit_table() {
        use DataType::*;
        use LitKind::*;
        assert_eq!(merge_lit(IntLit, IntLit), Some(IntLit));
        assert_eq!(merge_lit(FloatLit, FloatLit), Some(FloatLit));
        assert_eq!(merge_lit(IntLit, Concrete(Int8)), Some(Concrete(Int8)));
        assert_eq!(merge_lit(Concrete(UInt16), IntLit), Some(Concrete(UInt16)));
        assert_eq!(
            merge_lit(FloatLit, Concrete(Float64)),
            Some(Concrete(Float64))
        );
        assert_eq!(merge_lit(Concrete(Int8), Concrete(Int16)), None);
        assert_eq!(merge_lit(Concrete(Float32), Concrete(Float64)), None);
        assert_eq!(merge_lit(IntLit, FloatLit), None);
        assert_eq!(merge_lit(Concrete(Int32), Concrete(Float32)), None);
        assert_eq!(merge_lit(Concrete(Bool), IntLit), None);
    }

    // Operator table: numbers take arithmetic but not `Cat`, strings take only `Cat`, and
    // bools take nothing.
    #[test]
    fn check_arith_op_table() {
        use ArithmeticOperator::*;
        use DataType::*;
        use LitKind::Concrete;
        let span = Span::DUMMY;
        assert!(check_arith_op(Concrete(Int32), &Plus, span).is_ok());
        assert!(check_arith_op(Concrete(Float64), &Multiply, span).is_ok());
        assert!(check_arith_op(Concrete(String), &Cat, span).is_ok());
        assert!(check_arith_op(Concrete(Int32), &Cat, span).is_err());
        assert!(check_arith_op(Concrete(String), &Plus, span).is_err());
        assert!(check_arith_op(Concrete(Bool), &Plus, span).is_err());
        assert!(check_arith_op(Concrete(Bool), &Cat, span).is_err());
    }

    // Polymorphic literals are reported as Int32 / Float32 in diagnostics.
    #[test]
    fn report_ty_polymorphic_defaults() {
        assert_eq!(LitKind::IntLit.report_ty(), DataType::Int32);
        assert_eq!(LitKind::FloatLit.report_ty(), DataType::Float32);
    }
}