use fnv::FnvHashMap;
use rspirv::mr::Builder;
use spirv_headers as spirv;
use crate::ast;
use crate::verify;
#[cfg(test)]
mod tests;
/// What a name in the symbol table resolves to.
///
/// Both fields are SPIR-V words handed out by the module builder.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct VarBinding {
/// Word the name refers to: `compile_function` stores the function's or
/// parameter's word here, and a `let` expression stores the word of the
/// compiled value.
pub address: spirv::Word,
/// Word of the type associated with the binding, as passed to `bind_var`.
pub typedef: spirv::Word,
}
/// Compilation context: the SPIR-V module under construction plus the
/// lookup tables used while lowering the verified program into it.
pub struct CContext {
/// Builder that accumulates the emitted SPIR-V module.
pub b: Builder,
/// Type definitions that have already been declared, mapped to their word.
/// Filled by `add_type`, read by `get_type`.
pub typetable: FnvHashMap<verify::TypeDef, spirv::Word>,
/// Literals that have already been defined, mapped to their word.
/// Filled by `define_const`.
pub consts: FnvHashMap<ast::Lit, spirv::Word>,
/// Stack of lexical scopes; the last element is the innermost scope.
/// `new` starts it with a single root scope.
pub symtable: Vec<FnvHashMap<String, VarBinding>>,
}
impl CContext {
/// Creates a context with an empty module header already configured:
/// SPIR-V version 1.0, the `Shader` capability, and the
/// `Logical` / `Simple` addressing and memory models.
///
/// The type and constant caches start empty; the symbol table starts
/// with one (root) scope.
pub fn new() -> Self {
    let mut builder = Builder::new();
    builder.set_version(1, 0);
    builder.capability(spirv::Capability::Shader);
    builder.memory_model(spirv::AddressingModel::Logical, spirv::MemoryModel::Simple);

    Self {
        b: builder,
        typetable: FnvHashMap::default(),
        consts: FnvHashMap::default(),
        symtable: vec![FnvHashMap::default()],
    }
}
pub fn push_scope(&mut self) {
self.symtable.push(Default::default());
}
/// Binds `name` in the innermost scope, replacing any binding of the same
/// name that scope already holds.
///
/// # Panics
///
/// Panics if the scope stack is empty.
pub fn bind_var(&mut self, name: &str, address: spirv::Word, typedef: spirv::Word) {
    let binding = VarBinding { address, typedef };
    let innermost = self
        .symtable
        .last_mut()
        .expect("No scope in variable binding!");
    innermost.insert(name.to_owned(), binding);
}
/// Resolves `name` by searching the scopes from innermost to outermost and
/// returning the first binding found.
///
/// # Panics
///
/// Panics if the scope stack is empty, or if no scope binds `name`.
pub fn lookup_var(&self, name: &str) -> &VarBinding {
    // The root scope created by `new` is a perfectly good scope to search
    // (`bind_var` writes into it), so only a truly empty stack is an error.
    assert!(!self.symtable.is_empty(), "No scope for variable lookup!");
    self.symtable
        .iter()
        .rev()
        .find_map(|scope| scope.get(name))
        // Name the missing variable instead of dumping the whole table.
        .unwrap_or_else(|| panic!("Variable not found! (looking for `{}`)", name))
}
pub fn pop_scope(&mut self) {
self.symtable
.pop()
.expect("Tried to pop empty scope stack!");
}
/// Returns the word for the literal `vl`, emitting the constant into the
/// module the first time a given literal is seen and reusing the cached
/// word afterwards.
///
/// Only the type the literal actually needs is looked up, and only when
/// the constant has to be emitted; a cache hit touches neither the type
/// table nor the builder.
///
/// # Panics
///
/// Panics (via `get_type`) if a float literal is defined before `F32` has
/// been registered, or a bool literal before `Bool` has been registered.
pub fn define_const(&mut self, vl: ast::Lit) -> spirv::Word {
    if let Some(&cached) = self.consts.get(&vl) {
        return cached;
    }
    let word = match &vl {
        ast::Lit::F32(f) => {
            let type_float = self.get_type(&verify::TypeDef::F32);
            self.b.constant_f32(type_float, *f)
        }
        ast::Lit::Bool(bl) => {
            let type_bool = self.get_type(&verify::TypeDef::Bool);
            if *bl {
                self.b.constant_true(type_bool)
            } else {
                self.b.constant_false(type_bool)
            }
        }
        // NOTE(review): 0 is a hard-coded placeholder, not a word allocated
        // by the builder, so it must not be used as an instruction operand.
        ast::Lit::Unit => 0,
    };
    self.consts.insert(vl, word);
    word
}
/// Looks up the word of a type that has already been declared with
/// `add_type`. Never declares anything itself.
///
/// # Panics
///
/// Panics if `t` is not in the type table.
pub fn get_type(&self, t: &verify::TypeDef) -> spirv::Word {
    let declared = self.typetable.get(t).cloned();
    declared.expect("Could not get type!")
}
/// Returns the word for `typedef`, declaring it in the module first if it
/// has not been declared yet.
///
/// Composite types declare their components recursively: a struct declares
/// each field type in field order, a function declares its return type and
/// then each parameter type in order. The resulting word is cached in the
/// type table so later calls with an equal definition return the same word.
pub fn add_type(&mut self, typedef: &verify::TypeDef) -> spirv::Word {
    use verify::TypeDef;

    // Already declared: hand back the cached word without touching the builder.
    if let Some(&known) = self.typetable.get(typedef) {
        return known;
    }

    let declared = match typedef {
        TypeDef::F32 => self.b.type_float(32),
        TypeDef::Bool => self.b.type_bool(),
        TypeDef::Unit => self.b.type_void(),
        TypeDef::Struct(fields) => {
            let member_words: Vec<spirv::Word> = fields
                .iter()
                .map(|(_name, member)| self.add_type(member))
                .collect();
            self.b.type_struct(member_words)
        }
        TypeDef::Function(params, returns) => {
            // Return type first, then parameters, matching declaration order.
            let return_word = self.add_type(returns);
            let argument_words: Vec<spirv::Word> =
                params.iter().map(|param| self.add_type(param)).collect();
            self.b.type_function(return_word, argument_words)
        }
    };

    self.typetable.insert(typedef.clone(), declared);
    declared
}
pub fn compile_expr(
&mut self,
e: &ast::Expr,
ctx: &verify::VContext,
) -> Result<spirv::Word, crate::Error> {
let result_word = match e {
ast::Expr::Var(name) => {
let binding = self.lookup_var(name);
let addr = binding.address;
addr
}
ast::Expr::Let(pattern, typ, vl) => {
let value_word = self.compile_expr(vl, ctx)?;
let var_typedef = ctx.get_defined_type(&typ);
let var_typeword = self.get_type(var_typedef);
match pattern {
ast::Pattern::FreeVar(varname) => {
self.bind_var(varname, value_word, var_typeword)
}
}
value_word
}
ast::Expr::Literal(lit) => {
self.define_const(lit.clone())
}
ast::Expr::FunCall(fname, params) => {
let param_words: Result<Vec<spirv::Word>, _> = params
.clone()
.iter()
.map(|param| self.compile_expr(param, ctx))
.collect();
let param_words = param_words?;
let binding = *self.lookup_var(&fname);
let functiondef = ctx
.functions
.get(fname)
.expect("Function does not exist for funcall");
if let verify::TypeDef::Function(ref _params, ref rettype) =
functiondef.functiontype
{
let rettype_word = self.get_type(rettype);
self.b
.function_call(rettype_word, None, binding.address, param_words)?
} else {
unreachable!("Function type is not TypeDef::Function")
}
}
ast::Expr::Block(exprs) => {
assert!(exprs.len() > 0, "Blocks with no expressions are verboten!");
self.push_scope();
let last_word: spirv::Word = exprs
.iter()
.map(|e| self.compile_expr(e, ctx))
.last()
.expect("Can't happen")?;
self.pop_scope();
last_word
}
ast::Expr::BinOp(op, e1, e2) => {
let t1 = ctx.get_defined_type(&ctx.type_of_expr(e1)?);
let t2 = ctx.get_defined_type(&ctx.type_of_expr(e2)?);
let w1 = self.compile_expr(e1, ctx)?;
let w2 = self.compile_expr(e2, ctx)?;
self.compile_inferred_binop(*op, t1, t2, w1, w2)
}
ast::Expr::UniOp(op, e) => {
let t = ctx.get_defined_type(&ctx.type_of_expr(e)?);
let w = self.compile_expr(e, ctx)?;
self.compile_inferred_uniop(*op, t, w)
}
ast::Expr::Structure(_name, _vals) => {
unimplemented!()
}
_ => unimplemented!(),
};
Ok(result_word)
}
pub fn compile_inferred_binop(
&mut self,
op: ast::Op,
t1: &verify::TypeDef,
t2: &verify::TypeDef,
e1: spirv::Word,
e2: spirv::Word,
) -> spirv::Word {
let f_word = self.get_type(&verify::TypeDef::F32);
let b_word = self.get_type(&verify::TypeDef::Bool);
match (op, t1, t2) {
(ast::Op::Add, verify::TypeDef::F32, verify::TypeDef::F32) => {
self.b.fadd(f_word, None, e1, e2).expect("???")
}
(ast::Op::Sub, verify::TypeDef::F32, verify::TypeDef::F32) => {
self.b.fsub(f_word, None, e1, e2).expect("???")
}
(ast::Op::Mul, verify::TypeDef::F32, verify::TypeDef::F32) => {
self.b.fmul(f_word, None, e1, e2).expect("???")
}
(ast::Op::Div, verify::TypeDef::F32, verify::TypeDef::F32) => {
self.b.fdiv(f_word, None, e1, e2).expect("???")
}
(ast::Op::Gt, verify::TypeDef::F32, verify::TypeDef::F32) => {
self.b.ford_greater_than(b_word, None, e1, e2).expect("???")
}
(ast::Op::Lt, verify::TypeDef::F32, verify::TypeDef::F32) => {
self.b.ford_less_than(b_word, None, e1, e2).expect("???")
}
(ast::Op::Gte, verify::TypeDef::F32, verify::TypeDef::F32) => self
.b
.ford_greater_than_equal(b_word, None, e1, e2)
.expect("???"),
(ast::Op::Lte, verify::TypeDef::F32, verify::TypeDef::F32) => self
.b
.ford_less_than_equal(b_word, None, e1, e2)
.expect("???"),
(ast::Op::Eq, verify::TypeDef::F32, verify::TypeDef::F32) => {
self.b.ford_equal(b_word, None, e1, e2).expect("???")
}
(ast::Op::Neq, verify::TypeDef::F32, verify::TypeDef::F32) => {
self.b.ford_not_equal(b_word, None, e1, e2).expect("???")
}
(ast::Op::And, verify::TypeDef::Bool, verify::TypeDef::Bool) => {
self.b.logical_and(b_word, None, e1, e2).expect("???")
}
(ast::Op::Or, verify::TypeDef::Bool, verify::TypeDef::Bool) => {
self.b.logical_or(b_word, None, e1, e2).expect("???")
}
_ => {
let msg = format!("Invalid binary op type: {:?} {:?} {:?}", op, t1, t2);
panic!(msg)
}
}
}
pub fn compile_inferred_uniop(
&mut self,
op: ast::UOp,
t: &verify::TypeDef,
e: spirv::Word,
) -> spirv::Word {
let f_word = self.get_type(&verify::TypeDef::F32);
let b_word = self.get_type(&verify::TypeDef::Bool);
match (op, t) {
(ast::UOp::Negate, verify::TypeDef::F32) => {
self.b.fnegate(f_word, None, e).expect("???")
}
(ast::UOp::Not, verify::TypeDef::Bool) => {
self.b.logical_not(b_word, None, e).expect("???")
}
_ => {
let msg = format!("Invalid unary op type: {:?} {:?}", op, t);
panic!(msg)
}
}
}
/// Compiles one verified function definition into the module.
///
/// The function's name is bound in the current scope before the body is
/// compiled. Parameters are bound in a fresh scope that is popped once the
/// function has been emitted. The value of the last body expression is
/// returned from the function.
///
/// # Errors
///
/// Returns any error reported by the builder or produced while compiling
/// a body expression.
///
/// # Panics
///
/// Panics if the body is empty, if the declared return type is missing
/// from `ctx.types`, or if a required type was never registered with
/// `add_type`.
pub fn compile_function(
    &mut self,
    ctx: &verify::VContext,
    def: &verify::FunctionDef,
) -> Result<(), crate::Error> {
    assert!(def.decl.body.len() > 0, "Empty function body!");
    let function_returns: spirv::Word =
        self.get_type(ctx.types.get(&def.decl.returns).unwrap());
    let ftype = self.add_type(&def.functiontype);
    let f_word = self.b.begin_function(
        function_returns,
        None,
        spirv::FunctionControl::DONT_INLINE
            | spirv::FunctionControl::CONST
            | spirv::FunctionControl::PURE,
        ftype,
    )?;
    // Bound outside the parameter scope so the binding outlives this call.
    self.bind_var(&def.decl.name, f_word, ftype);
    self.push_scope();
    for p in def.decl.params.iter() {
        let param_typedef = ctx.get_defined_type(&p.typ);
        let param_typeword = self.get_type(param_typedef);
        let word = self.b.function_parameter(param_typeword)?;
        self.bind_var(&p.name, word, param_typeword);
    }
    self.b.begin_basic_block(None)?;
    // Propagate the first failing expression to the caller rather than
    // panicking: this function already reports errors through `Result`.
    let mut last_expr_val: Option<spirv::Word> = None;
    for expr in def.decl.body.iter() {
        last_expr_val = Some(self.compile_expr(expr, ctx)?);
    }
    let last_expr_val = last_expr_val.expect("Empty function body! Try returning ()");
    // NOTE(review): a function returning Unit still goes through
    // `ret_value`, and the Unit literal compiles to the placeholder word 0;
    // confirm whether such functions should emit a plain `ret` instead.
    self.b.ret_value(last_expr_val)?;
    self.b.end_function()?;
    self.pop_scope();
    self.b.name(f_word, &def.decl.name);
    Ok(())
}
fn mongle_entry_points(&mut self, _ctx: &verify::VContext) -> Result<(), crate::Error> {
let ftype = self.add_type(&verify::TypeDef::Function(
vec![],
Box::new(verify::TypeDef::Unit),
));
let void = self.add_type(&verify::TypeDef::Unit);
let v_word = self.b.begin_function(
void,
None,
spirv::FunctionControl::DONT_INLINE | spirv::FunctionControl::CONST,
ftype,
)?;
self.b.begin_basic_block(None)?;
self.b.ret()?;
self.b.end_function()?;
let f_word = self.b.begin_function(
void,
None,
spirv::FunctionControl::DONT_INLINE | spirv::FunctionControl::CONST,
ftype,
)?;
self.b.begin_basic_block(None)?;
self.b.ret()?;
self.b.end_function()?;
let vert_name = "_vertex_entry";
let frag_name = "_fragment_entry";
self.b.name(v_word, vert_name);
self.b.name(f_word, frag_name);
self.b
.entry_point(spirv::ExecutionModel::Vertex, v_word, vert_name, []);
self.b
.entry_point(spirv::ExecutionModel::Fragment, f_word, frag_name, []);
self.b
.execution_mode(f_word, spirv::ExecutionMode::OriginUpperLeft, []);
Ok(())
}
}
/// Compiles a verified program into a fresh [`CContext`].
///
/// Every type known to `ctx` is declared first, then every function is
/// compiled, and finally the entry points are emitted.
///
/// # Errors
///
/// Returns the first error raised while compiling a function or emitting
/// the entry points.
pub fn compile(ctx: &verify::VContext) -> Result<CContext, crate::Error> {
    let mut cc = CContext::new();

    // Declare all types up front so later `get_type` lookups succeed.
    ctx.types.iter().for_each(|(_name, def)| {
        cc.add_type(def);
    });

    // NOTE(review): functions are compiled in whatever order `ctx.functions`
    // iterates; a call to a function compiled later would presumably fail
    // the variable lookup - confirm against the verifier's guarantees.
    ctx.functions
        .iter()
        .try_for_each(|(_name, def)| cc.compile_function(ctx, def))?;

    cc.mongle_entry_points(ctx)?;
    Ok(cc)
}