use std::collections::HashMap;
use std::sync::Arc;
use tatara_lisp::{Atom, Span, Spanned, SpannedForm};
/// A type in the checker's annotation language.
///
/// Every variant has a keyword surface syntax; see [`StaticType::render`] (type -> text)
/// and [`StaticType::from_spanned`] (annotation form -> type).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum StaticType {
    /// `:any` — conforms to every type and is accepted by every type.
    Any,
    /// `:nil`
    Nil,
    /// `:bool`
    Bool,
    /// `:int`
    Int,
    /// `:float`
    Float,
    /// `:number` — an `:int` or a `:float` (see [`StaticType::conforms_to`] for the
    /// lenient rules that apply to it).
    Number,
    /// `:string`
    Str,
    /// `:symbol`
    Symbol,
    /// `:keyword`
    Keyword,
    /// `(:list-of T)`; the bare spec `:list` means `(:list-of :any)`.
    List(Box<StaticType>),
    /// `(:map-of K V)`; the bare spec `:map` means `(:map-of :any :any)`.
    Map(Box<StaticType>, Box<StaticType>),
    /// `:procedure`, also spelled `:fn` or `(:fn ...)`.
    Procedure,
    /// `:promise`
    Promise,
    /// `:error`
    Error,
    /// `(:union T...)` — a value of any one of the listed types.
    Union(Vec<StaticType>),
}
impl StaticType {
    /// Renders the type in the same keyword syntax accepted by [`StaticType::from_spanned`].
    pub fn render(&self) -> String {
        match self {
            Self::Any => ":any".into(),
            Self::Nil => ":nil".into(),
            Self::Bool => ":bool".into(),
            Self::Int => ":int".into(),
            Self::Float => ":float".into(),
            Self::Number => ":number".into(),
            Self::Str => ":string".into(),
            Self::Symbol => ":symbol".into(),
            Self::Keyword => ":keyword".into(),
            Self::List(t) => format!("(:list-of {})", t.render()),
            Self::Map(k, v) => format!("(:map-of {} {})", k.render(), v.render()),
            Self::Procedure => ":procedure".into(),
            Self::Promise => ":promise".into(),
            Self::Error => ":error".into(),
            Self::Union(branches) => {
                let parts: Vec<String> = branches.iter().map(Self::render).collect();
                format!("(:union {})", parts.join(" "))
            }
        }
    }
    /// Returns whether a value of type `self` may be used where `expected` is required.
    ///
    /// Rules, in order:
    /// - `:any` on either side always conforms.
    /// - `:int`/`:float` conform to `:number`. Conversely, `:number` is accepted where an
    ///   `:int` or `:float` is expected, because the arithmetic primitives are typed
    ///   `:number` and `(the :int (+ 1 2))` should pass.
    /// - A union on the left conforms iff *every* branch conforms.
    /// - Otherwise, a union on the right accepts `self` iff *some* branch does.
    /// - Lists and maps are compared element-wise; anything else requires equality.
    pub fn conforms_to(&self, expected: &StaticType) -> bool {
        if matches!(self, Self::Any) || matches!(expected, Self::Any) {
            return true;
        }
        if matches!(expected, Self::Number) && matches!(self, Self::Int | Self::Float) {
            return true;
        }
        if matches!(self, Self::Number) && matches!(expected, Self::Int | Self::Float) {
            return true;
        }
        // The left-hand union must be split first. Splitting the right-hand one first
        // would ask whether the *whole* left union fits a single right branch, so
        // `(:union :int :string)` would fail to conform to itself.
        if let Self::Union(branches) = self {
            return branches.iter().all(|b| b.conforms_to(expected));
        }
        if let Self::Union(branches) = expected {
            return branches.iter().any(|b| self.conforms_to(b));
        }
        match (self, expected) {
            (Self::List(a), Self::List(b)) => a.conforms_to(b),
            (Self::Map(ak, av), Self::Map(bk, bv)) => ak.conforms_to(bk) && av.conforms_to(bv),
            _ => self == expected,
        }
    }
    /// Parses a type annotation form.
    ///
    /// Accepts a bare keyword (`:int`, `:list`, `:fn`, ...) or a list headed by one of
    /// `:list-of` (one argument), `:map-of` (two arguments), `:union` (at least one
    /// argument) or `:fn` (any arguments, all ignored). Returns `None` for anything else,
    /// including a malformed nested spec.
    pub fn from_spanned(form: &Spanned) -> Option<Self> {
        match &form.form {
            SpannedForm::Atom(Atom::Keyword(k)) => Some(match k.as_str() {
                "any" => Self::Any,
                "nil" => Self::Nil,
                "bool" => Self::Bool,
                "int" => Self::Int,
                "float" => Self::Float,
                "number" => Self::Number,
                "string" => Self::Str,
                "symbol" => Self::Symbol,
                "keyword" => Self::Keyword,
                "procedure" | "fn" => Self::Procedure,
                "promise" => Self::Promise,
                "error" => Self::Error,
                "list" => Self::List(Box::new(Self::Any)),
                "map" => Self::Map(Box::new(Self::Any), Box::new(Self::Any)),
                _ => return None,
            }),
            SpannedForm::List(items) if !items.is_empty() => {
                let head = items[0].as_keyword()?;
                match head {
                    "list-of" if items.len() == 2 => {
                        Some(Self::List(Box::new(Self::from_spanned(&items[1])?)))
                    }
                    "map-of" if items.len() == 3 => Some(Self::Map(
                        Box::new(Self::from_spanned(&items[1])?),
                        Box::new(Self::from_spanned(&items[2])?),
                    )),
                    // An empty `(:union)` would match nothing but `:any`, which is
                    // never what the author meant, so it is reported as a bad spec.
                    "union" if items.len() >= 2 => items[1..]
                        .iter()
                        .map(Self::from_spanned)
                        .collect::<Option<Vec<_>>>()
                        .map(Self::Union),
                    "fn" => Some(Self::Procedure),
                    _ => None,
                }
            }
            _ => None,
        }
    }
}
/// One problem reported by the static checker, anchored at a source span.
#[derive(Debug, Clone)]
pub struct TypeDiagnostic {
    /// Location of the offending form in the source.
    pub span: Span,
    /// What went wrong.
    pub kind: TypeDiagnosticKind,
}
/// The category (and payload) of a [`TypeDiagnostic`].
#[derive(Debug, Clone)]
pub enum TypeDiagnosticKind {
    /// An inferred type did not conform to the annotated or previously bound type.
    Mismatch {
        /// The type that was required.
        expected: StaticType,
        /// The type that was inferred.
        got: StaticType,
        /// Short description of where the check happened (e.g. `the-form`, `define x`).
        context: String,
    },
    /// A malformed annotation: an unrecognized type spec, or a `declare` whose name is
    /// not a symbol.
    BadTypeSpec(String),
}
impl TypeDiagnostic {
    /// Formats the diagnostic as `type:LINE:COL: message`, resolving the span's start
    /// offset against `src`.
    pub fn render(&self, src: &str) -> String {
        let (line, col) = Span::line_col(src, self.span.start);
        let message = match &self.kind {
            TypeDiagnosticKind::BadTypeSpec(msg) => format!("bad type spec — {msg}"),
            TypeDiagnosticKind::Mismatch {
                expected,
                got,
                context,
            } => format!(
                "type mismatch in {context}: expected {}, got {}",
                expected.render(),
                got.render()
            ),
        };
        format!("type:{line}:{col}: {message}")
    }
}
/// Type-checks a sequence of top-level forms and returns every diagnostic found.
///
/// All forms share one environment, checked left to right, so a `declare` or `define`
/// influences every form that follows it. Diagnostics appear in the order they are found.
pub fn check_program(forms: &[Spanned]) -> Vec<TypeDiagnostic> {
    let mut diagnostics = Vec::new();
    let mut env = TypeEnv::default();
    forms
        .iter()
        .for_each(|form| check_form(form, &mut env, &mut diagnostics));
    diagnostics
}
/// A single flat map from names to their static types, shared by the whole program.
#[derive(Default)]
struct TypeEnv {
    bindings: HashMap<Arc<str>, StaticType>,
}
impl TypeEnv {
    /// Returns the type bound to `name`; unbound names are treated as `:any`.
    fn lookup(&self, name: &str) -> StaticType {
        match self.bindings.get(name) {
            Some(ty) => ty.clone(),
            None => StaticType::Any,
        }
    }
    /// Binds `name` to `ty`, replacing any previous binding.
    fn define(&mut self, name: impl Into<Arc<str>>, ty: StaticType) {
        let key: Arc<str> = name.into();
        self.bindings.insert(key, ty);
    }
}
/// Walks one form, dispatching annotation forms to their checkers.
///
/// - `(the T expr)`, `(declare name T)` and `(define ...)` go to their dedicated checkers.
/// - `(quote ...)` is skipped: its contents are data, not code, so an annotation-shaped
///   datum such as `(quote (the :int "x"))` must not be reported. This matches `infer`,
///   which also treats `quote` as data.
/// - Any other list is walked element by element. Non-list forms contain no annotations.
fn check_form(form: &Spanned, env: &mut TypeEnv, diags: &mut Vec<TypeDiagnostic>) {
    let SpannedForm::List(items) = &form.form else {
        return;
    };
    match items.first().and_then(Spanned::as_symbol) {
        Some("the") if items.len() == 3 => check_the(&items[1], &items[2], env, diags),
        Some("declare") if items.len() == 3 => check_declare(&items[1], &items[2], env, diags),
        Some("define") if items.len() >= 3 => check_define(items, env, diags),
        Some("quote") => {}
        _ => {
            for item in items {
                check_form(item, env, diags);
            }
        }
    }
}
/// Checks `(the T expr)`: the inferred type of `expr` must conform to the spec `T`.
///
/// An unrecognized spec is reported as `BadTypeSpec`. In every case `expr` is then walked
/// for nested annotations, so one malformed outer spec does not hide unrelated errors
/// inside the expression.
fn check_the(
    type_form: &Spanned,
    expr: &Spanned,
    env: &mut TypeEnv,
    diags: &mut Vec<TypeDiagnostic>,
) {
    match StaticType::from_spanned(type_form) {
        Some(expected) => {
            let got = infer(expr, env);
            if !got.conforms_to(&expected) {
                diags.push(TypeDiagnostic {
                    span: expr.span,
                    kind: TypeDiagnosticKind::Mismatch {
                        expected,
                        got,
                        context: "the-form".into(),
                    },
                });
            }
        }
        None => diags.push(TypeDiagnostic {
            span: type_form.span,
            kind: TypeDiagnosticKind::BadTypeSpec(format!(
                "unrecognized type spec: {}",
                render_form_brief(type_form)
            )),
        }),
    }
    check_form(expr, env, diags);
}
/// Handles `(declare name T)`: binds `name` to the type spec `T` in the environment.
///
/// A non-symbol name or an unrecognized spec produces a `BadTypeSpec` diagnostic and
/// leaves the environment untouched.
fn check_declare(
    name_form: &Spanned,
    type_form: &Spanned,
    env: &mut TypeEnv,
    diags: &mut Vec<TypeDiagnostic>,
) {
    let name = match name_form.as_symbol() {
        Some(symbol) => symbol,
        None => {
            diags.push(TypeDiagnostic {
                span: name_form.span,
                kind: TypeDiagnosticKind::BadTypeSpec("declare: name must be a symbol".into()),
            });
            return;
        }
    };
    match StaticType::from_spanned(type_form) {
        Some(ty) => env.define(name, ty),
        None => diags.push(TypeDiagnostic {
            span: type_form.span,
            kind: TypeDiagnosticKind::BadTypeSpec(format!(
                "unrecognized type spec: {}",
                render_form_brief(type_form)
            )),
        }),
    }
}
fn check_define(items: &[Spanned], env: &mut TypeEnv, diags: &mut Vec<TypeDiagnostic>) {
match &items[1].form {
SpannedForm::Atom(Atom::Symbol(name)) => {
let expected = env.lookup(name).clone();
let got = infer(&items[2], env);
if !got.conforms_to(&expected) {
diags.push(TypeDiagnostic {
span: items[2].span,
kind: TypeDiagnosticKind::Mismatch {
expected,
got: got.clone(),
context: format!("define {name}"),
},
});
}
env.define(name.as_str(), got);
check_form(&items[2], env, diags);
}
SpannedForm::List(head) if !head.is_empty() => {
if let Some(name) = head[0].as_symbol() {
env.define(name, StaticType::Procedure);
}
for body_form in &items[2..] {
check_form(body_form, env, diags);
}
}
_ => {}
}
}
fn infer(form: &Spanned, env: &TypeEnv) -> StaticType {
match &form.form {
SpannedForm::Nil => StaticType::Nil,
SpannedForm::Atom(a) => match a {
Atom::Bool(_) => StaticType::Bool,
Atom::Int(_) => StaticType::Int,
Atom::Float(_) => StaticType::Float,
Atom::Str(_) => StaticType::Str,
Atom::Keyword(_) => StaticType::Keyword,
Atom::Symbol(s) => env.lookup(s),
},
SpannedForm::List(items) if !items.is_empty() => {
if let Some(head) = items[0].as_symbol() {
if head == "the" && items.len() == 3 {
return StaticType::from_spanned(&items[1]).unwrap_or(StaticType::Any);
}
if head == "quote" {
return infer_quoted(&items[1]);
}
if head == "list" {
return infer_list_ctor(&items[1..], env);
}
if let Some(t) = primitive_return_type(head) {
return t;
}
}
StaticType::Any
}
SpannedForm::Quote(inner) => infer_quoted(inner),
SpannedForm::Quasiquote(_)
| SpannedForm::Unquote(_)
| SpannedForm::UnquoteSplice(_) => StaticType::Any,
_ => StaticType::Any,
}
}
/// Infers the type of a quoted datum.
///
/// Unlike in `infer`, a quoted symbol is a `:symbol` value rather than a variable
/// reference. Any quoted list is a list of unknown elements.
fn infer_quoted(form: &Spanned) -> StaticType {
    match &form.form {
        SpannedForm::Atom(atom) => match atom {
            Atom::Symbol(_) => StaticType::Symbol,
            Atom::Keyword(_) => StaticType::Keyword,
            Atom::Str(_) => StaticType::Str,
            Atom::Int(_) => StaticType::Int,
            Atom::Float(_) => StaticType::Float,
            Atom::Bool(_) => StaticType::Bool,
        },
        SpannedForm::Nil => StaticType::Nil,
        SpannedForm::List(_) => StaticType::List(Box::new(StaticType::Any)),
        _ => StaticType::Any,
    }
}
/// Infers `(list args...)` as a list whose element type joins every argument's type.
///
/// An empty call gives `(:list-of :any)`. Once the join reaches `:any`, the remaining
/// arguments cannot change it, so they are not inferred.
fn infer_list_ctor(args: &[Spanned], env: &TypeEnv) -> StaticType {
    let element = match args.split_first() {
        None => StaticType::Any,
        Some((first, rest)) => {
            let mut joined = infer(first, env);
            for arg in rest {
                joined = least_upper_bound(joined, infer(arg, env));
                if joined == StaticType::Any {
                    break;
                }
            }
            joined
        }
    };
    StaticType::List(Box::new(element))
}
/// Joins two inferred types into one type that covers both.
///
/// - Identical types join to themselves.
/// - `:any` absorbs everything.
/// - `:int` joined with `:float` widens to `:number`.
/// - Anything else becomes a `:union`. Nested unions are flattened and duplicate branches
///   dropped, so repeated joins (e.g. over the elements of a `(list ...)`) produce
///   `(:union :int :string)` rather than `(:union (:union :int :string) :int)`. Flattening
///   and de-duplication do not change which types conform to or from the result.
fn least_upper_bound(a: StaticType, b: StaticType) -> StaticType {
    if a == b {
        return a;
    }
    if matches!(a, StaticType::Any) || matches!(b, StaticType::Any) {
        return StaticType::Any;
    }
    if matches!(
        (&a, &b),
        (StaticType::Int, StaticType::Float) | (StaticType::Float, StaticType::Int)
    ) {
        return StaticType::Number;
    }
    let mut branches = Vec::new();
    push_union_branch(&mut branches, a);
    push_union_branch(&mut branches, b);
    if branches.len() == 1 {
        // e.g. `(:union :int)` joined with `:int`: the union wrapper was redundant.
        return branches.remove(0);
    }
    StaticType::Union(branches)
}

/// Appends `ty` to `branches`, splicing in the members of (nested) unions and skipping
/// types already present.
fn push_union_branch(branches: &mut Vec<StaticType>, ty: StaticType) {
    match ty {
        StaticType::Union(members) => {
            for member in members {
                push_union_branch(branches, member);
            }
        }
        other => {
            if !branches.contains(&other) {
                branches.push(other);
            }
        }
    }
}
/// Returns the statically known result type of a built-in primitive, or `None` when
/// `name` is not in any of the tables below.
fn primitive_return_type(name: &str) -> Option<StaticType> {
    const NUMERIC: &[&str] = &[
        "+", "-", "*", "/", "abs", "min", "max", "modulo", "expt", "sqrt", "floor", "ceiling",
        "round", "truncate", "gcd", "lcm", "sin", "cos", "tan", "log", "exp", "inc", "dec",
    ];
    const BOOLEAN: &[&str] = &[
        "=", "<", ">", "<=", ">=", "not=", "null?", "pair?", "list?", "symbol?", "string?",
        "integer?", "number?", "boolean?", "procedure?", "foreign?", "atom?", "keyword?",
        "even?", "odd?", "zero?", "positive?", "negative?", "empty?", "not-empty?", "any?",
        "every?", "member?", "is?", "hash-map?", "hash-map-empty?", "hash-map-has?", "chan?",
        "chan-closed?", "promise?", "error?",
    ];
    const LIST_VALUED: &[&str] = &[
        "list", "cons", "reverse", "append", "take", "drop", "range", "map", "filter",
        "remove", "concat", "distinct", "flatten", "zip", "partition", "scan-left", "iterate",
        "repeatedly", "drain!", "hash-map-keys", "hash-map-values", "hash-map-entries",
        "read-all",
    ];
    const MAP_VALUED: &[&str] = &[
        "hash-map",
        "hash-map-set",
        "hash-map-remove",
        "hash-map-merge",
        "hash-map-update",
    ];
    const STRING_VALUED: &[&str] = &[
        "string-append",
        "string",
        "pr-str",
        "symbol->string",
        "keyword->string",
        "error-message",
    ];
    const INT_VALUED: &[&str] = &[
        "length",
        "count-if",
        "find-index",
        "position",
        "compare",
        "string-length",
        "hash-map-count",
        "chan-len",
    ];
    const KEYWORD_VALUED: &[&str] = &["type-of", "error-tag"];

    if NUMERIC.contains(&name) {
        Some(StaticType::Number)
    } else if BOOLEAN.contains(&name) {
        Some(StaticType::Bool)
    } else if LIST_VALUED.contains(&name) {
        Some(StaticType::List(Box::new(StaticType::Any)))
    } else if MAP_VALUED.contains(&name) {
        Some(StaticType::Map(
            Box::new(StaticType::Any),
            Box::new(StaticType::Any),
        ))
    } else if STRING_VALUED.contains(&name) {
        Some(StaticType::Str)
    } else if INT_VALUED.contains(&name) {
        Some(StaticType::Int)
    } else if KEYWORD_VALUED.contains(&name) {
        Some(StaticType::Keyword)
    } else {
        None
    }
}
/// Renders a short, one-line preview of a form for use in diagnostic messages.
///
/// Atoms are printed in full; a list is abbreviated to `(...)`, and any other form
/// shape to `?`.
fn render_form_brief(form: &Spanned) -> String {
    let atom = match &form.form {
        SpannedForm::Atom(atom) => atom,
        SpannedForm::Nil => return String::from("()"),
        SpannedForm::List(_) => return String::from("(...)"),
        _ => return String::from("?"),
    };
    match atom {
        Atom::Symbol(sym) => sym.to_string(),
        Atom::Keyword(kw) => format!(":{kw}"),
        Atom::Str(text) => format!("{text:?}"),
        Atom::Int(i) => i.to_string(),
        Atom::Float(x) => x.to_string(),
        Atom::Bool(true) => String::from("#t"),
        Atom::Bool(false) => String::from("#f"),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tatara_lisp::read_spanned;

    /// Parses `src` and runs the checker over all of its top-level forms.
    fn check(src: &str) -> Vec<TypeDiagnostic> {
        let forms = read_spanned(src).unwrap();
        check_program(&forms)
    }

    /// Asserts that `src` produces no diagnostics.
    fn assert_clean(src: &str) {
        let diags = check(src);
        assert!(diags.is_empty(), "unexpected diagnostics for {src:?}: {diags:?}");
    }

    /// Asserts that `src` produces exactly one diagnostic, which must be a mismatch,
    /// and returns its `(expected, got)` pair.
    fn sole_mismatch(src: &str) -> (StaticType, StaticType) {
        let mut diags = check(src);
        assert_eq!(diags.len(), 1, "expected exactly one diagnostic for {src:?}");
        match diags.remove(0).kind {
            TypeDiagnosticKind::Mismatch { expected, got, .. } => (expected, got),
            other => panic!("{other:?}"),
        }
    }

    #[test]
    fn no_annotations_no_diagnostics() {
        assert_clean("(define x 42) (+ 1 2)");
    }

    #[test]
    fn the_with_correct_atom_passes() {
        for src in ["(the :int 42)", "(the :string \"hi\")", "(the :bool #t)"] {
            assert_clean(src);
        }
    }

    #[test]
    fn the_with_wrong_atom_flags() {
        let (expected, got) = sole_mismatch("(the :int \"oops\")");
        assert_eq!(expected, StaticType::Int);
        assert_eq!(got, StaticType::Str);
    }

    #[test]
    fn declare_then_define_match_passes() {
        assert_clean("(declare counter :int) (define counter 0)");
    }

    #[test]
    fn declare_then_define_mismatch_flags() {
        let (expected, _) = sole_mismatch("(declare counter :int) (define counter \"oops\")");
        assert_eq!(expected, StaticType::Int);
    }

    #[test]
    fn list_ctor_infers_homogeneous_element_type() {
        assert_clean("(the (:list-of :int) (list 1 2 3))");
    }

    #[test]
    fn list_ctor_heterogeneous_widens_to_any_or_union() {
        assert_eq!(check("(the (:list-of :int) (list 1 \"x\" 3))").len(), 1);
    }

    #[test]
    fn bad_type_spec_diagnoses() {
        let diags = check("(the :nonsense 1)");
        assert_eq!(diags.len(), 1);
        assert!(matches!(diags[0].kind, TypeDiagnosticKind::BadTypeSpec(_)));
    }

    #[test]
    fn primitive_return_type_drives_inference() {
        assert_eq!(check("(the :int (string-append \"a\" \"b\"))").len(), 1);
    }

    #[test]
    fn arithmetic_returns_number_so_conforms_to_int_or_float() {
        assert_clean("(the :int (+ 1 2))");
        assert_clean("(the :float (+ 1.0 2.0))");
    }

    #[test]
    fn union_type_admits_any_branch() {
        assert_clean("(the (:union :int :string) 42)");
        assert_clean("(the (:union :int :string) \"hi\")");
        assert_eq!(check("(the (:union :int :string) #t)").len(), 1);
    }

    #[test]
    fn nested_list_inference() {
        assert_clean("(the (:list-of (:list-of :int)) (list (list 1 2) (list 3)))");
    }

    #[test]
    fn conforms_to_total_for_any() {
        let int_or_str = StaticType::Union(vec![StaticType::Int, StaticType::Str]);
        assert!(StaticType::Any.conforms_to(&StaticType::Int));
        assert!(StaticType::Int.conforms_to(&StaticType::Any));
        assert!(int_or_str.conforms_to(&StaticType::Any));
    }

    #[test]
    fn render_round_trips_canonical_forms() {
        let cases = [
            (StaticType::Int, ":int"),
            (
                StaticType::List(Box::new(StaticType::Str)),
                "(:list-of :string)",
            ),
            (
                StaticType::Map(Box::new(StaticType::Keyword), Box::new(StaticType::Int)),
                "(:map-of :keyword :int)",
            ),
            (
                StaticType::Union(vec![StaticType::Int, StaticType::Str]),
                "(:union :int :string)",
            ),
        ];
        for (ty, rendered) in cases {
            assert_eq!(ty.render(), rendered);
        }
    }
}