use std::cmp::Eq;
use std::collections::hash_map;
use std::fmt::Debug;
use std::hash::Hash;
use fnv::FnvHashMap;
use crate::ast::*;
use crate::Error;
/// A resolved type known to the verifier.
///
/// `PartialEq`/`Eq`/`Hash` are derived, so equality is structural. The derived
/// `Debug` rendering is also significant: `verify` uses it as the registered
/// name of a function's signature type.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum TypeDef {
/// Floating-point scalar; `VContext::default` registers it under the name `F32`.
F32,
/// Boolean; registered under the name `Bool`.
Bool,
/// Unit type; registered under the name `()`.
Unit,
/// Record type as an ordered list of `(field name, field type)` pairs.
/// Field order participates in the derived equality and hashing.
Struct(Vec<(String, TypeDef)>),
/// Function signature: parameter types in declaration order, then the return type.
Function(Vec<TypeDef>, Box<TypeDef>),
}
/// A function registered in a `VContext`: its source declaration plus its resolved signature.
#[derive(Clone, Debug)]
pub struct FunctionDef {
/// The declaration as it appears in the program passed to `verify` (cloned from it).
pub decl: FunctionDecl,
/// The resolved signature; `verify` builds this as a `TypeDef::Function`.
pub functiontype: TypeDef,
}
/// Symbol tables accumulated while verifying a program.
#[derive(Clone, Debug)]
pub struct VContext {
/// Functions defined so far, keyed by function name.
pub functions: FnvHashMap<String, FunctionDef>,
/// Types defined so far, keyed by type name; `VContext::default` pre-populates the built-ins.
pub types: FnvHashMap<Type, TypeDef>,
}
impl Default for VContext {
fn default() -> Self {
let functions = FnvHashMap::default();
let mut types = FnvHashMap::default();
{
types.insert("F32".into(), TypeDef::F32);
types.insert("Bool".into(), TypeDef::Bool);
types.insert("()".into(), TypeDef::Unit);
types.insert(
"Vec4F".into(),
TypeDef::Struct(vec![
("x".into(), TypeDef::F32),
("y".into(), TypeDef::F32),
("z".into(), TypeDef::F32),
("w".into(), TypeDef::F32),
]),
);
}
Self { functions, types }
}
}
fn hashtbl_insert_with_if_vacant<K, V>(
tbl: &mut FnvHashMap<K, V>,
name: K,
vl: V,
) -> Result<(), Error>
where
K: Debug + Hash + Eq,
{
let symbol_name = format!("{:?}", &name);
match tbl.entry(name) {
hash_map::Entry::Occupied(_) => Err(Error::SymbolExists(symbol_name.into())),
hash_map::Entry::Vacant(e) => {
e.insert(vl);
Ok(())
}
}
}
impl VContext {
fn function_exists(&self, name: &str) -> bool {
self.functions.contains_key(name)
}
fn type_exists(&self, name: &str) -> bool {
self.types.contains_key(&Type(name.to_string()))
}
fn check_type_mismatch(&self, e: &Expr, expected: &Type) -> Result<Type, Error> {
let tt = self.type_of_expr(e)?;
if &tt != expected {
Err(Error::TypeMismatch(expected.clone(), tt))
} else {
Ok(tt)
}
}
pub fn type_of_expr(&self, e: &Expr) -> Result<Type, Error> {
match e {
Expr::Let(_, _, _) => Ok(Type("()".into())),
Expr::Literal(l) => Ok(l.type_of()),
Expr::If(test, ifpart, elsepart) => {
let _ = self.check_type_mismatch(test, &Type("Bool".into()))?;
let it = self.type_of_exprs(ifpart)?;
let et = self.type_of_exprs(elsepart)?;
if it != et {
Err(Error::TypeMismatch(it, et))
} else {
Ok(it)
}
}
Expr::Block(es) => self.type_of_exprs(es),
_ => unimplemented!(),
}
}
pub fn type_of_exprs(&self, e: &[Expr]) -> Result<Type, Error> {
let e = e.last().unwrap_or(&Expr::Literal(Lit::Unit));
self.type_of_expr(e)
}
fn define_function(&mut self, name: &str, def: FunctionDef) -> Result<(), Error> {
hashtbl_insert_with_if_vacant(&mut self.functions, name.into(), def)
}
fn define_type(&mut self, name: &str, def: TypeDef) -> Result<(), Error> {
hashtbl_insert_with_if_vacant(&mut self.types, Type(name.into()), def)
}
pub fn get_defined_type(&self, name: &Type) -> &TypeDef {
self.types.get(name).unwrap()
}
}
/// Expression-level verification hook; currently an unfinished stub.
///
/// # Panics
/// Always panics: with the message "verify expr" for a binary operation, and
/// via `unreachable!()` for every other expression kind.
pub fn verify_expr(_ctx: &VContext, e: &Expr) -> Result<(), Error> {
    if let Expr::BinOp(..) = e {
        panic!("verify expr");
    }
    unreachable!()
}
pub fn verify(program: Vec<Decl>) -> Result<VContext, Error> {
let mut ctx = VContext::default();
for decl in program.iter() {
match decl {
Decl::Function(f) => {
let return_type = ctx.get_defined_type(&f.returns);
let param_types: Vec<TypeDef> = f
.params
.iter()
.map(|p| ctx.get_defined_type(&p.typ).clone())
.collect();
let functiontype = TypeDef::Function(param_types, Box::new(return_type.clone()));
let function_type_name = format!("{:?}", functiontype);
if !ctx.type_exists(&function_type_name) {
ctx.define_type(&function_type_name, functiontype.clone())?;
}
let def = FunctionDef {
decl: f.clone(),
functiontype,
};
ctx.define_function(f.name.as_str(), def)?;
}
Decl::Structure(_s) => {
}
}
}
verify_program(&mut ctx, &program)?;
Ok(ctx)
}
pub fn verify_program(ctx: &mut VContext, _program: &[Decl]) -> Result<(), Error> {
if !ctx.function_exists("vertex") {
Err(Error::Validation(
"Required function `vertex` doesn't exist.".into(),
))
} else if !ctx.function_exists("fragment") {
Err(Error::Validation(
"Required function `fragment` doesn't exist.".into(),
))
} else {
Ok(())
}
}
#[cfg(test)]
mod tests {
    use assert_matches::*;
    use crate::ast::*;
    use crate::verify;
    use crate::Error;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    #[test]
    fn required_functions_exist() {
        let vert = Decl::Function(FunctionDecl {
            name: String::from("vertex"),
            params: vec![],
            returns: Type("Vec4F".into()),
            body: vec![Expr::Structure(
                String::from("Vec4F"),
                vec![
                    (String::from("x"), Expr::Literal(Lit::F32(1.0))),
                    (String::from("y"), Expr::Literal(Lit::F32(1.0))),
                    (String::from("z"), Expr::Literal(Lit::F32(1.0))),
                    (String::from("w"), Expr::Literal(Lit::F32(1.0))),
                ],
            )],
        });
        let frag = Decl::Function(FunctionDecl {
            name: String::from("fragment"),
            params: vec![Param {
                name: "input".into(),
                typ: Type("Vec4F".into()),
            }],
            returns: Type("Vec4F".into()),
            body: vec![Expr::Var("input".into())],
        });
        assert_matches!(verify::verify(vec![]), Err(Error::Validation(_)));
        assert_matches!(
            verify::verify(vec![frag.clone()]),
            Err(Error::Validation(_))
        );
        assert_matches!(
            verify::verify(vec![vert.clone()]),
            Err(Error::Validation(_))
        );
        assert_matches!(verify::verify(vec![frag, vert]), Ok(_));
    }

    #[test]
    fn duplicate_functions_invalid() {
        let f = Decl::Function(FunctionDecl {
            name: String::from("foo"),
            params: vec![],
            returns: Type("Vec4F".into()),
            body: vec![Expr::Structure(
                String::from("Vec4F"),
                vec![
                    (String::from("x"), Expr::Literal(Lit::F32(1.0))),
                    (String::from("y"), Expr::Literal(Lit::F32(1.0))),
                    (String::from("z"), Expr::Literal(Lit::F32(1.0))),
                    (String::from("w"), Expr::Literal(Lit::F32(1.0))),
                ],
            )],
        });
        assert_matches!(
            verify::verify(vec![f.clone(), f.clone()]),
            Err(Error::SymbolExists(_))
        );
    }

    /// Computes the type of each expression in a default context and checks it
    /// against the paired expected result.
    fn test_type_of_exprs(es: &[(Expr, Result<Type, Error>)]) {
        let c = verify::VContext::default();
        for (e, t) in es {
            let computed_type = c.type_of_expr(e);
            assert_eq!(t, &computed_type);
        }
    }

    #[test]
    fn test_type_of_exprs_lit() {
        let es = vec![
            (Expr::Literal(Lit::F32(1.0)), Ok(Type("F32".into()))),
            (Expr::Literal(Lit::Bool(false)), Ok(Type("Bool".into()))),
            (Expr::Literal(Lit::Unit), Ok(Type("()".into()))),
        ];
        test_type_of_exprs(&es);
    }

    #[test]
    fn test_type_of_exprs_if() {
        let es = vec![
            (
                Expr::If(
                    Box::new(Expr::Literal(Lit::F32(1.0))),
                    vec![Expr::Literal(Lit::F32(2.0))],
                    vec![Expr::Literal(Lit::F32(3.0))],
                ),
                Err(Error::TypeMismatch(Type("Bool".into()), Type("F32".into()))),
            ),
            (
                Expr::If(
                    Box::new(Expr::Literal(Lit::Bool(false))),
                    vec![Expr::Literal(Lit::F32(2.0))],
                    vec![Expr::Literal(Lit::F32(3.0))],
                ),
                Ok(Type("F32".into())),
            ),
            (
                Expr::If(
                    Box::new(Expr::Literal(Lit::Bool(false))),
                    vec![Expr::Literal(Lit::F32(2.0))],
                    vec![Expr::Literal(Lit::Bool(true))],
                ),
                Err(Error::TypeMismatch(Type("F32".into()), Type("Bool".into()))),
            ),
        ];
        test_type_of_exprs(&es);
    }

    /// Returns the `DefaultHasher` digest of `value`, using a fresh hasher so
    /// digests of different values are independent.
    fn hash_of<T: Hash>(value: &T) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn test_struct_structural_equality() {
        let def1 = verify::TypeDef::Struct(vec![
            ("thing1".into(), verify::TypeDef::F32),
            ("thing2".into(), verify::TypeDef::F32),
            ("thing3".into(), verify::TypeDef::Bool),
        ]);
        let def2 = verify::TypeDef::Struct(vec![
            ("thing1".into(), def1.clone()),
            ("thing2".into(), verify::TypeDef::F32),
            ("thing3".into(), def1.clone()),
        ]);
        let def3 = verify::TypeDef::Struct(vec![
            ("thing1".into(), def1.clone()),
            ("thing2".into(), verify::TypeDef::F32),
            ("thing3".into(), def1.clone()),
        ]);
        assert_eq!(def2, def3);
        assert_ne!(def1, def2);
        // `Hash::hash` returns `()`, so equal digests must be checked via
        // `Hasher::finish` on independent hashers rather than comparing the
        // return values of `hash` itself.
        assert_eq!(hash_of(&def2), hash_of(&def3));
    }
}