use std::collections::hash_map::Entry;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::fmt;
use std::vec;
use super::ast::Type::*;
use super::ast::*;
use super::error::*;
use super::util::{join, SymbolGenerator};
use fnv;
pub mod optimizations;
/// Index of a basic block within its function's `blocks` vector (see `SirFunction::add_block`).
pub type BasicBlockId = usize;
/// Index of a function within `SirProgram::funcs` (see `SirProgram::add_func`).
pub type FunctionId = usize;
/// The operation performed by a single SIR `Statement`.
///
/// The symbols each variant reads are enumerated by `StatementKind::children`.
/// `Hash`/`Eq` are derived so that `StatementTracker` can use a kind as a map key
/// to reuse the result of an identical statement in the same block.
#[derive(Clone, PartialEq, Eq, Hash)]
pub enum StatementKind {
/// Copies the value of another symbol (emitted for `let` bindings).
Assign(Symbol),
/// Assertion on the given symbol (rendered `assert(sym)`).
Assert(Symbol),
/// Materializes a literal value.
AssignLiteral(LiteralKind),
/// Applies binary operator `op` to `left` and `right`.
BinOp {
op: BinOpKind,
left: Symbol,
right: Symbol,
},
/// Broadcast of the given symbol (rendered `broadcast(sym)`).
Broadcast(Symbol),
/// Converts the symbol to the given type.
Cast(Symbol, Type),
/// Call to the external function named `symbol_name` with `args`.
/// `StatementTracker` never deduplicates these.
CUDF {
symbol_name: String,
args: Vec<Symbol>,
},
/// Reads field number `index` of the struct held in `value`.
GetField {
value: Symbol,
index: u32,
},
/// Tests for `key` in `child` (rendered `keyexists(child, key)`).
KeyExists {
child: Symbol,
key: Symbol,
},
/// Length of the given symbol (rendered `len(sym)`).
Length(Symbol),
/// Looks up `index` in `child`.
Lookup {
child: Symbol,
index: Symbol,
},
/// Optional lookup of `index` in `child` (rendered `optlookup(child, index)`).
OptLookup {
child: Symbol,
index: Symbol,
},
/// Builds a struct from the given field symbols, in order.
MakeStruct(Vec<Symbol>),
/// Builds a vector from the given element symbols, in order.
MakeVector(Vec<Symbol>),
/// Merges `value` into `builder`; `gen_expr` emits this with no output symbol.
Merge {
builder: Symbol,
value: Symbol,
},
/// Arithmetic negation (rendered `-sym`).
Negate(Symbol),
/// Logical not (rendered `!sym`).
Not(Symbol),
/// Creates a new builder of type `ty`, optionally from `arg`.
NewBuilder {
arg: Option<Symbol>,
ty: Type,
},
/// A parallel loop; see `ParallelForData`.
ParallelFor(ParallelForData),
/// Result of a builder (rendered `result(builder)`).
Res(Symbol),
/// Chooses between `on_true` and `on_false` based on `cond`.
Select {
cond: Symbol,
on_true: Symbol,
on_false: Symbol,
},
/// Slice of `child` described by `index` and `size`.
Slice {
child: Symbol,
index: Symbol,
size: Symbol,
},
/// Sorts `child` using the comparison function with id `cmpfunc` in `SirProgram::funcs`.
Sort {
child: Symbol,
cmpfunc: FunctionId,
},
/// Serializes the given symbol.
Serialize(Symbol),
/// Deserializes the given symbol.
Deserialize(Symbol),
/// Converts the given symbol to a vector (rendered `toVec(sym)`).
ToVec(Symbol),
/// Applies unary operator `op` to `child`.
UnaryOp {
op: UnaryOpKind,
child: Symbol,
},
}
/// One iterator feeding a `ParallelFor` loop.
///
/// `StatementKind::children` unwraps `end` and `stride` whenever `start` is
/// present. When `shape` is present, it also unwraps `start`, `end`, `stride`,
/// and `strides`. Inconsistent combinations therefore panic there.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct ParallelForIter {
/// Collection being iterated.
pub data: Symbol,
/// Optional start bound.
pub start: Option<Symbol>,
/// Optional end bound.
pub end: Option<Symbol>,
/// Optional stride.
pub stride: Option<Symbol>,
/// Iterator flavor (scalar, SIMD, fringe, nd, range).
pub kind: IterKind,
/// Optional strides; expected alongside `shape`.
pub strides: Option<Symbol>,
/// Optional shape; when present the iterator is printed as `kind(data, start, shape, strides)`.
pub shape: Option<Symbol>,
}
/// Payload of a `StatementKind::ParallelFor` loop.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct ParallelForData {
/// The iterators the loop consumes.
pub data: Vec<ParallelForIter>,
/// Builder symbol passed into the loop (an operand of the statement).
pub builder: Symbol,
/// Symbol the body uses for the current element — presumably a parameter of `body`; confirm in codegen.
pub data_arg: Symbol,
/// Symbol the body uses for the builder — presumably a parameter of `body`; confirm in codegen.
pub builder_arg: Symbol,
/// Symbol the body uses for the iteration index — presumably a parameter of `body`; confirm in codegen.
pub idx_arg: Symbol,
/// Function id of the loop body in `SirProgram::funcs`.
pub body: FunctionId,
/// Flag printed in the loop's textual form; presumably marks loops with no nested loop — TODO confirm.
pub innermost: bool,
}
impl StatementKind {
    /// Returns the symbols this statement reads, in a fixed order.
    ///
    /// For `ParallelFor` the builder comes first. Each iterator then contributes its data
    /// symbol followed by its bound symbols (`start`, `end`, `stride`, then `shape`,
    /// `strides` when shaped). The body's own argument symbols (`data_arg`, `builder_arg`,
    /// `idx_arg`) are not included.
    ///
    /// # Panics
    ///
    /// Panics if a `ParallelFor` iterator has `shape` or `start` set without the other
    /// bound symbols that accompany them.
    pub fn children(&self) -> vec::IntoIter<&Symbol> {
        use self::StatementKind::*;
        let mut operands: Vec<&Symbol> = Vec::new();
        match self {
            AssignLiteral(_) => {}
            UnaryOp { child, .. }
            | Cast(child, _)
            | Negate(child)
            | Not(child)
            | Assert(child)
            | Broadcast(child)
            | Serialize(child)
            | Deserialize(child)
            | Sort { child, .. }
            | ToVec(child)
            | Length(child)
            | Assign(child)
            | Res(child)
            | GetField { value: child, .. } => operands.push(child),
            BinOp { left, right, .. }
            | Lookup {
                child: left,
                index: right,
            }
            | OptLookup {
                child: left,
                index: right,
            }
            | KeyExists {
                child: left,
                key: right,
            }
            | Merge {
                builder: left,
                value: right,
            } => {
                operands.push(left);
                operands.push(right);
            }
            Slice {
                child: first,
                index: second,
                size: third,
            }
            | Select {
                cond: first,
                on_true: second,
                on_false: third,
            } => {
                operands.push(first);
                operands.push(second);
                operands.push(third);
            }
            NewBuilder { arg, .. } => operands.extend(arg.iter()),
            MakeStruct(elems) | MakeVector(elems) | CUDF { args: elems, .. } => {
                operands.extend(elems.iter())
            }
            ParallelFor(pf) => {
                operands.push(&pf.builder);
                for it in &pf.data {
                    operands.push(&it.data);
                    // A shaped iterator carries the full start/end/stride triple plus its
                    // shape and strides; an unshaped one may carry only the triple.
                    let has_shape = it.shape.is_some();
                    if has_shape || it.start.is_some() {
                        operands.push(it.start.as_ref().unwrap());
                        operands.push(it.end.as_ref().unwrap());
                        operands.push(it.stride.as_ref().unwrap());
                        if has_shape {
                            operands.push(it.shape.as_ref().unwrap());
                            operands.push(it.strides.as_ref().unwrap());
                        }
                    }
                }
            }
        }
        operands.into_iter()
    }

    /// Mutable counterpart of `children`: yields the same symbols, in the same order, as
    /// mutable references so callers can rewrite operands in place.
    ///
    /// # Panics
    ///
    /// Same conditions as `children`.
    pub fn children_mut(&mut self) -> vec::IntoIter<&mut Symbol> {
        use self::StatementKind::*;
        let mut operands: Vec<&mut Symbol> = Vec::new();
        match self {
            AssignLiteral(_) => {}
            UnaryOp { child, .. }
            | Cast(child, _)
            | Negate(child)
            | Not(child)
            | Assert(child)
            | Broadcast(child)
            | Serialize(child)
            | Deserialize(child)
            | Sort { child, .. }
            | ToVec(child)
            | Length(child)
            | Assign(child)
            | Res(child)
            | GetField { value: child, .. } => operands.push(child),
            BinOp { left, right, .. }
            | Lookup {
                child: left,
                index: right,
            }
            | OptLookup {
                child: left,
                index: right,
            }
            | KeyExists {
                child: left,
                key: right,
            }
            | Merge {
                builder: left,
                value: right,
            } => {
                operands.push(left);
                operands.push(right);
            }
            Slice {
                child: first,
                index: second,
                size: third,
            }
            | Select {
                cond: first,
                on_true: second,
                on_false: third,
            } => {
                operands.push(first);
                operands.push(second);
                operands.push(third);
            }
            NewBuilder { arg, .. } => operands.extend(arg.iter_mut()),
            MakeStruct(elems) | MakeVector(elems) | CUDF { args: elems, .. } => {
                operands.extend(elems.iter_mut())
            }
            ParallelFor(pf) => {
                operands.push(&mut pf.builder);
                for it in pf.data.iter_mut() {
                    operands.push(&mut it.data);
                    let has_shape = it.shape.is_some();
                    if has_shape || it.start.is_some() {
                        operands.push(it.start.as_mut().unwrap());
                        operands.push(it.end.as_mut().unwrap());
                        operands.push(it.stride.as_mut().unwrap());
                        if has_shape {
                            operands.push(it.shape.as_mut().unwrap());
                            operands.push(it.strides.as_mut().unwrap());
                        }
                    }
                }
            }
        }
        operands.into_iter()
    }
}
/// A single SIR statement: an operation plus the symbol (if any) that receives its result.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Statement {
/// Symbol receiving the result, or `None` for statements emitted only for effect
/// (e.g. `Merge` as emitted by `gen_expr`).
pub output: Option<Symbol>,
/// The operation performed.
pub kind: StatementKind,
}
impl Statement {
    /// Creates a statement that stores the result of `kind` into `output` (when present).
    pub fn new(output: Option<Symbol>, kind: StatementKind) -> Statement {
        Statement { kind, output }
    }

    /// Replaces every operand equal to `target` with a clone of `with`.
    ///
    /// Only the symbols the statement reads (as yielded by `StatementKind::children_mut`)
    /// are rewritten; `output` is left untouched.
    pub fn substitute_symbol(&mut self, target: &Symbol, with: &Symbol) {
        self.kind
            .children_mut()
            .filter(|operand| **operand == *target)
            .for_each(|operand| *operand = with.clone());
    }
}
/// A (function, basic block) location, used by `StatementTracker` as a map key.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct ProgramSite(FunctionId, BasicBlockId);
/// For one site: each statement kind already emitted there, mapped to the symbol holding its result.
type SiteSymbolMap = fnv::FnvHashMap<StatementKind, Symbol>;
/// Records statements emitted during lowering, per basic block, so that an identical
/// statement requested again in the same block reuses the earlier result symbol.
struct StatementTracker {
/// Per-site record of emitted statement kinds and their output symbols.
generated: fnv::FnvHashMap<ProgramSite, SiteSymbolMap>,
}
impl StatementTracker {
pub fn new() -> StatementTracker {
StatementTracker {
generated: fnv::FnvHashMap::default(),
}
}
fn symbol_for_statement(
&mut self,
prog: &mut SirProgram,
func: FunctionId,
block: BasicBlockId,
sym_ty: &Type,
kind: StatementKind,
) -> Symbol {
use crate::sir::StatementKind::CUDF;
let site = ProgramSite(func, block);
let map = self
.generated
.entry(site)
.or_insert_with(fnv::FnvHashMap::default);
if let CUDF { .. } = kind {
let res_sym = prog.add_local(sym_ty, func);
prog.funcs[func].blocks[block]
.add_statement(Statement::new(Some(res_sym.clone()), kind));
return res_sym;
}
match map.entry(kind.clone()) {
Entry::Occupied(ent) => ent.get().clone(),
Entry::Vacant(ent) => {
let res_sym = prog.add_local(sym_ty, func);
prog.funcs[func].blocks[block]
.add_statement(Statement::new(Some(res_sym.clone()), kind));
ent.insert(res_sym.clone());
res_sym
}
}
}
fn named_symbol_for_statement(
&mut self,
prog: &mut SirProgram,
func: FunctionId,
block: BasicBlockId,
sym_ty: &Type,
kind: StatementKind,
named_sym: Symbol,
) {
let site = ProgramSite(func, block);
let map = self
.generated
.entry(site)
.or_insert_with(fnv::FnvHashMap::default);
prog.add_local_named(sym_ty, &named_sym, func);
prog.funcs[func].blocks[block]
.add_statement(Statement::new(Some(named_sym.clone()), kind.clone()));
map.insert(kind, named_sym);
}
}
/// How control leaves a basic block.
#[derive(Clone, PartialEq, Eq, Hash)]
pub enum Terminator {
/// Conditional jump on `cond` to one of two blocks in the same function.
Branch {
cond: Symbol,
on_true: BasicBlockId,
on_false: BasicBlockId,
},
/// Unconditional jump to a block in the same function.
JumpBlock(BasicBlockId),
/// Returns the symbol as the program's result (set by `ast_to_sir`).
ProgramReturn(Symbol),
/// Ends a non-top-level function, yielding the symbol (e.g. a sort comparator's result).
EndFunction(Symbol),
/// Placeholder terminator of a newly added block (see `SirFunction::add_block`).
Crash,
}
impl Terminator {
pub fn children(&self) -> vec::IntoIter<&Symbol> {
use self::Terminator::*;
let mut vars = vec![];
match *self {
Branch { ref cond, .. } => {
vars.push(cond);
}
ProgramReturn(ref sym) => {
vars.push(sym);
}
EndFunction(ref sym) => vars.push(&sym),
Crash => (),
JumpBlock(_) => (),
};
vars.into_iter()
}
pub fn children_mut(&mut self) -> vec::IntoIter<&mut Symbol> {
use self::Terminator::*;
let mut vars = vec![];
match *self {
Branch { ref mut cond, .. } => {
vars.push(cond);
}
ProgramReturn(ref mut sym) => {
vars.push(sym);
}
EndFunction(ref mut sym) => vars.push(sym),
Crash => (),
JumpBlock(_) => (),
};
vars.into_iter()
}
}
/// A straight-line sequence of statements ending in a single terminator.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct BasicBlock {
/// Index of this block in its function's `blocks`.
pub id: BasicBlockId,
/// Statements, in execution order.
pub statements: Vec<Statement>,
/// Control transfer at the end of the block.
pub terminator: Terminator,
}
/// A SIR function: typed parameters and locals plus a list of basic blocks.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct SirFunction {
/// Index of this function in `SirProgram::funcs`.
pub id: FunctionId,
/// Parameters and their types; lowering and `sir_param_correction` add captured symbols here.
pub params: BTreeMap<Symbol, Type>,
/// Locals and their types.
pub locals: BTreeMap<Symbol, Type>,
/// Symbols registered via `SirProgram::add_loop_variable`.
pub loop_variables: Vec<Symbol>,
/// Basic blocks, indexed by `BasicBlockId`.
pub blocks: Vec<BasicBlock>,
/// `Unknown` until filled in by `assign_return_types`.
pub return_type: Type,
/// Set for functions created as `For` loop bodies.
pub loop_body: bool,
/// Not set anywhere in the visible lowering code; presumably marks innermost loop bodies — TODO confirm.
pub innermost_loop: bool,
}
impl SirFunction {
    /// Looks up the type of `sym`, consulting locals before parameters.
    ///
    /// # Errors
    ///
    /// Returns a compile error if `sym` is neither a local nor a parameter of this function.
    pub fn symbol_type(&self, sym: &Symbol) -> WeldResult<&Type> {
        match self.locals.get(sym).or_else(|| self.params.get(sym)) {
            Some(ty) => Ok(ty),
            None => compile_err!("Can't find symbol {}", sym.to_string()),
        }
    }
}
/// A whole SIR program. `funcs[0]` is the top-level function created by `SirProgram::new`.
pub struct SirProgram {
/// All functions, indexed by `FunctionId`.
pub funcs: Vec<SirFunction>,
/// Result type of the program.
pub ret_ty: Type,
/// Parameters of the top-level lambda.
pub top_params: Vec<Parameter>,
/// Source of fresh symbol names for `add_local`.
sym_gen: SymbolGenerator,
}
impl SirProgram {
    /// Creates a program with the given result type and top-level parameters, containing
    /// a single empty function (id 0).
    pub fn new(ret_ty: &Type, top_params: &[Parameter]) -> SirProgram {
        let mut program = SirProgram {
            funcs: Vec::new(),
            ret_ty: ret_ty.clone(),
            top_params: top_params.to_vec(),
            sym_gen: SymbolGenerator::new(),
        };
        program.add_func();
        program
    }

    /// Appends an empty function with `Unknown` return type and returns its id.
    pub fn add_func(&mut self) -> FunctionId {
        let id = self.funcs.len();
        self.funcs.push(SirFunction {
            id,
            params: BTreeMap::new(),
            locals: BTreeMap::new(),
            loop_variables: Vec::new(),
            blocks: Vec::new(),
            return_type: Unknown,
            loop_body: false,
            innermost_loop: false,
        });
        id
    }

    /// Generates a fresh symbol (prefixed `fn<func>_tmp`), registers it as a local of
    /// type `ty` in `func`, and returns it.
    pub fn add_local(&mut self, ty: &Type, func: FunctionId) -> Symbol {
        let prefix = format!("fn{}_tmp", func);
        let sym = self.sym_gen.new_symbol(prefix.as_str());
        self.add_local_named(ty, &sym, func);
        sym
    }

    /// Registers `sym` as a local of type `ty` in `func`, replacing any previous type.
    pub fn add_local_named(&mut self, ty: &Type, sym: &Symbol, func: FunctionId) {
        let locals = &mut self.funcs[func].locals;
        locals.insert(sym.clone(), ty.clone());
    }

    /// Appends `sym` to the loop-variable list of `func`.
    pub fn add_loop_variable(&mut self, sym: &Symbol, func: FunctionId) {
        let loop_vars = &mut self.funcs[func].loop_variables;
        loop_vars.push(sym.clone());
    }
}
impl SirFunction {
    /// Appends an empty block whose terminator is `Crash` until the caller sets a real
    /// one, and returns the new block's id.
    pub fn add_block(&mut self) -> BasicBlockId {
        let id = self.blocks.len();
        self.blocks.push(BasicBlock {
            id,
            statements: Vec::new(),
            terminator: Terminator::Crash,
        });
        id
    }
}
impl BasicBlock {
    /// Appends `statement` to the end of the block.
    pub fn add_statement(&mut self, statement: Statement) {
        self.statements.push(statement);
    }

    /// Replaces every read of `target` in this block, in all statements and in the
    /// terminator, with a clone of `with`. Statement outputs are not rewritten.
    pub fn substitute_symbol(&mut self, target: &Symbol, with: &Symbol) {
        self.statements
            .iter_mut()
            .for_each(|stmt| stmt.substitute_symbol(target, with));
        for operand in self.terminator.children_mut().filter(|s| **s == *target) {
            *operand = with.clone();
        }
    }
}
impl fmt::Display for StatementKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use self::StatementKind::*;
match *self {
Assign(ref value) => write!(f, "{}", value),
AssignLiteral(ref value) => write!(f, "{}", value),
BinOp {
ref op,
ref left,
ref right,
} => write!(f, "{} {} {}", op, left, right),
Broadcast(ref child) => write!(f, "broadcast({})", child),
Serialize(ref child) => write!(f, "serialize({})", child),
Deserialize(ref child) => write!(f, "deserialize({})", child),
Cast(ref child, ref ty) => write!(f, "cast({}, {})", child, ty),
CUDF {
ref symbol_name,
ref args,
} => write!(
f,
"cudf[{}]{}",
symbol_name,
join("(", ", ", ")", args.iter().map(|e| format!("{}", e)))
),
GetField { ref value, index } => write!(f, "{}.${}", value, index),
KeyExists { ref child, ref key } => write!(f, "keyexists({}, {})", child, key),
Length(ref child) => write!(f, "len({})", child),
MakeStruct(ref elems) => write!(
f,
"{}",
join("{", ",", "}", elems.iter().map(|e| format!("{}", e)))
),
MakeVector(ref elems) => write!(
f,
"{}",
join("[", ", ", "]", elems.iter().map(|e| format!("{}", e)))
),
Merge {
ref builder,
ref value,
} => write!(f, "merge({}, {})", builder, value),
Negate(ref child) => write!(f, "-{}", child),
Not(ref child) => write!(f, "!{}", child),
Assert(ref child) => write!(f, "assert({})", child),
NewBuilder { ref arg, ref ty } => {
let arg_str = if let Some(ref a) = *arg {
a.to_string()
} else {
"".to_string()
};
write!(f, "new {}({})", ty, arg_str)
}
Lookup {
ref child,
ref index,
} => write!(f, "lookup({}, {})", child, index),
OptLookup {
ref child,
ref index,
} => write!(f, "optlookup({}, {})", child, index),
ParallelFor(ref pf) => {
write!(f, "for [")?;
for iter in &pf.data {
write!(f, "{}, ", iter)?;
}
write!(f, "] ")?;
write!(
f,
"{} {} {} {} F{} {}",
pf.builder, pf.builder_arg, pf.idx_arg, pf.data_arg, pf.body, pf.innermost
)?;
Ok(())
}
Res(ref builder) => write!(f, "result({})", builder),
Select {
ref cond,
ref on_true,
ref on_false,
} => write!(f, "select({}, {}, {})", cond, on_true, on_false),
Slice {
ref child,
ref index,
ref size,
} => write!(f, "slice({}, {}, {})", child, index, size),
Sort { ref child, .. } => write!(f, "sort({})", child),
ToVec(ref child) => write!(f, "toVec({})", child),
UnaryOp { ref op, ref child } => write!(f, "{}({})", op, child),
}
}
}
impl fmt::Display for Statement {
    /// Renders `output = kind`, or just `kind` when the statement has no output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Some(target) = &self.output {
            write!(f, "{} = ", target)?;
        }
        write!(f, "{}", self.kind)
    }
}
impl fmt::Display for Terminator {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use self::Terminator::*;
match *self {
Branch {
ref cond,
ref on_true,
ref on_false,
} => write!(f, "branch {} B{} B{}", cond, on_true, on_false),
JumpBlock(block) => write!(f, "jump B{}", block),
ProgramReturn(ref sym) => write!(f, "return {}", sym),
EndFunction(ref sym) => write!(f, "end {}", sym),
Crash => write!(f, "crash"),
}
}
}
impl fmt::Display for ParallelForIter {
    /// Renders the iterator as `kind(data, ...)`, or as the bare data symbol for a plain
    /// scalar iterator with no bounds.
    ///
    /// # Panics
    ///
    /// Panics if `shape` is set without `start` and `strides`, or if `start` is set without
    /// `end` and `stride`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self.kind {
            IterKind::ScalarIter => "iter",
            IterKind::SimdIter => "simditer",
            IterKind::FringeIter => "fringeiter",
            IterKind::NdIter => "nditer",
            IterKind::RangeIter => "rangeiter",
        };
        if self.shape.is_some() {
            // Shaped iterators print start, shape and strides (end/stride are omitted).
            let start = self.start.as_ref().unwrap();
            let shape = self.shape.as_ref().unwrap();
            let strides = self.strides.as_ref().unwrap();
            return write!(f, "{}({}, {}, {}, {})", name, self.data, start, shape, strides);
        }
        if let Some(start) = &self.start {
            let end = self.end.as_ref().unwrap();
            let stride = self.stride.as_ref().unwrap();
            return write!(f, "{}({}, {}, {}, {})", name, self.data, start, end, stride);
        }
        if self.kind == IterKind::ScalarIter {
            write!(f, "{}", self.data)
        } else {
            write!(f, "{}({})", name, self.data)
        }
    }
}
impl fmt::Display for BasicBlock {
    /// Renders a `B<id>:` header followed by one indented line per statement and a final
    /// indented line for the terminator.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "B{}:", self.id)?;
        self.statements
            .iter()
            .try_for_each(|stmt| writeln!(f, " {}", stmt))?;
        writeln!(f, " {}", self.terminator)
    }
}
impl fmt::Display for SirFunction {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let loopbody = if self.loop_body { " (loopbody)" } else { "" };
writeln!(f, "F{} -> {}{}:", self.id, &self.return_type, loopbody)?;
writeln!(f, "Params:")?;
let params_sorted: BTreeMap<&Symbol, &Type> = self.params.iter().collect();
for (name, ty) in params_sorted {
writeln!(f, " {}: {}", name, ty)?;
}
writeln!(f, "Locals:")?;
let locals_sorted: BTreeMap<&Symbol, &Type> = self.locals.iter().collect();
for (name, ty) in locals_sorted {
writeln!(f, " {}: {}", name, ty)?;
}
for block in &self.blocks {
write!(f, "{}", block)?;
}
Ok(())
}
}
impl fmt::Display for SirProgram {
    /// Renders every function in id order, each followed by a blank line.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.funcs
            .iter()
            .try_for_each(|func| writeln!(f, "{}", func))
    }
}
fn sir_param_correction_helper(
prog: &mut SirProgram,
func_id: FunctionId,
env: &mut HashMap<Symbol, Type>,
closure: &mut HashSet<Symbol>,
visited: &mut HashSet<FunctionId>,
) {
for name in prog.funcs[func_id].params.keys() {
closure.insert(name.clone());
}
if !visited.insert(func_id) {
return;
}
for (name, ty) in &prog.funcs[func_id].params {
env.insert(name.clone(), ty.clone());
}
for (name, ty) in &prog.funcs[func_id].locals {
env.insert(name.clone(), ty.clone());
}
for block in prog.funcs[func_id].blocks.clone() {
let mut vars = vec![];
for statement in &block.statements {
vars.extend(statement.kind.children().cloned());
}
vars.extend(block.terminator.children().cloned());
for var in &vars {
if prog.funcs[func_id].locals.get(&var) == None {
prog.funcs[func_id]
.params
.insert(var.clone(), env.get(&var).unwrap().clone());
closure.insert(var.clone());
}
}
let mut inner_closure = HashSet::new();
for statement in &block.statements {
use self::StatementKind::ParallelFor;
if let ParallelFor(ref pf) = statement.kind {
sir_param_correction_helper(prog, pf.body, env, &mut inner_closure, visited);
}
}
for var in inner_closure {
if prog.funcs[func_id].locals.get(&var) == None {
prog.funcs[func_id]
.params
.insert(var.clone(), env.get(&var).unwrap().clone());
closure.insert(var.clone());
}
}
}
}
fn assign_return_types_helper(prog: &mut SirProgram, func: FunctionId) -> WeldResult<Type> {
use crate::sir::Terminator::*;
if prog.funcs[func].return_type != Unknown {
return Ok(prog.funcs[func].return_type.clone());
}
let mut return_symbol = None;
{
let function = &prog.funcs[func];
for block in function.blocks.iter() {
match block.terminator {
Branch { .. } => (),
JumpBlock(_) => (),
ProgramReturn(ref sym) | EndFunction(ref sym) => {
return_symbol = Some(sym.clone());
}
Crash => (),
}
}
}
if let Some(symbol) = return_symbol {
let return_type = prog.funcs[func].symbol_type(&symbol)?.clone();
prog.funcs[func].return_type = return_type.clone();
Ok(return_type)
} else {
unreachable!()
}
}
/// Fills in the return type of every function, stopping at the first error.
fn assign_return_types(prog: &mut SirProgram) -> WeldResult<()> {
    (0..prog.funcs.len())
        .try_for_each(|id| assign_return_types_helper(prog, id).map(|_| ()))
}
/// Propagates captured symbols into function parameter lists, starting from function 0.
///
/// # Errors
///
/// Returns an "Unbound symbol" compile error if some symbol needed by the program is not
/// a parameter of the top-level function afterwards.
fn sir_param_correction(prog: &mut SirProgram) -> WeldResult<()> {
    let mut env = HashMap::new();
    let mut closure = HashSet::new();
    let mut visited = HashSet::new();
    sir_param_correction_helper(prog, 0, &mut env, &mut closure, &mut visited);
    let top_params = &prog.funcs[0].params;
    match closure.iter().find(|name| !top_params.contains_key(*name)) {
        Some(name) => compile_err!("Unbound symbol {}", name.to_string()),
        None => Ok(()),
    }
}
/// Lowers a top-level lambda expression to a SIR program.
///
/// The lambda's parameters become parameters of function 0, its body is lowered starting
/// in a fresh block of function 0, and the result is returned via `ProgramReturn`.
/// Captured symbols are then propagated into parameter lists, and return types are assigned.
///
/// # Errors
///
/// Returns a compile error if `expr` is not a `Lambda`, or if lowering or any later pass fails.
pub fn ast_to_sir(expr: &Expr) -> WeldResult<SirProgram> {
    let (params, body) = match expr.kind {
        ExprKind::Lambda {
            ref params,
            ref body,
        } => (params, body),
        _ => return compile_err!("Expression passed to ast_to_sir was not a Lambda"),
    };
    let mut prog = SirProgram::new(&body.ty, params);
    prog.sym_gen = SymbolGenerator::from_expression(expr);
    prog.funcs[0]
        .params
        .extend(params.iter().map(|p| (p.name.clone(), p.ty.clone())));
    let entry_block = prog.funcs[0].add_block();
    let mut tracker = StatementTracker::new();
    let (ret_func, ret_block, ret_sym) = gen_expr(body, &mut prog, 0, entry_block, &mut tracker)?;
    prog.funcs[ret_func].blocks[ret_block].terminator = Terminator::ProgramReturn(ret_sym);
    // The parameter-correction pass runs twice. The reason is not evident here; presumably
    // a second pass propagates captures found by the first. TODO confirm before removing.
    sir_param_correction(&mut prog)?;
    sir_param_correction(&mut prog)?;
    assign_return_types(&mut prog)?;
    Ok(prog)
}
/// Lowers an optional iterator bound expression (e.g. start/end/stride).
///
/// When `opt` is `Some`, the expression is lowered at `(*cur_func, *cur_block)`, and both
/// cursors are advanced to where lowering finished. The resulting symbol is also registered
/// as a parameter of `body_func` with the expression's type. Returns `Ok(None)` when `opt`
/// is `None`.
fn get_iter_sym(
    opt: &Option<Box<Expr>>,
    prog: &mut SirProgram,
    cur_func: &mut FunctionId,
    cur_block: &mut BasicBlockId,
    tracker: &mut StatementTracker,
    body_func: FunctionId,
) -> WeldResult<Option<Symbol>> {
    let bound_expr = match *opt {
        Some(ref e) => e,
        None => return Ok(None),
    };
    let (next_func, next_block, sym) =
        gen_expr(bound_expr, prog, *cur_func, *cur_block, tracker)?;
    *cur_func = next_func;
    *cur_block = next_block;
    prog.funcs[body_func]
        .params
        .insert(sym.clone(), bound_expr.ty.clone());
    Ok(Some(sym))
}
fn gen_expr(
expr: &Expr,
prog: &mut SirProgram,
cur_func: FunctionId,
cur_block: BasicBlockId,
tracker: &mut StatementTracker,
) -> WeldResult<(FunctionId, BasicBlockId, Symbol)> {
use self::StatementKind::*;
use self::Terminator::*;
match expr.kind {
ExprKind::Ident(ref sym) => Ok((cur_func, cur_block, sym.clone())),
ExprKind::Literal(ref lit) => {
let kind = AssignLiteral(lit.clone());
let res_sym = tracker.symbol_for_statement(prog, cur_func, cur_block, &expr.ty, kind);
Ok((cur_func, cur_block, res_sym))
}
ExprKind::Let {
ref name,
ref value,
ref body,
} => {
let (cur_func, cur_block, val_sym) =
gen_expr(value, prog, cur_func, cur_block, tracker)?;
let kind = Assign(val_sym);
tracker.named_symbol_for_statement(
prog,
cur_func,
cur_block,
&value.ty,
kind,
name.clone(),
);
let (cur_func, cur_block, res_sym) =
gen_expr(body, prog, cur_func, cur_block, tracker)?;
Ok((cur_func, cur_block, res_sym))
}
ExprKind::BinOp {
kind,
ref left,
ref right,
} => {
let (cur_func, cur_block, left_sym) =
gen_expr(left, prog, cur_func, cur_block, tracker)?;
let (cur_func, cur_block, right_sym) =
gen_expr(right, prog, cur_func, cur_block, tracker)?;
let kind = BinOp {
op: kind,
left: left_sym,
right: right_sym,
};
let res_sym = tracker.symbol_for_statement(prog, cur_func, cur_block, &expr.ty, kind);
Ok((cur_func, cur_block, res_sym))
}
ExprKind::UnaryOp { kind, ref value } => {
let (cur_func, cur_block, value_sym) =
gen_expr(value, prog, cur_func, cur_block, tracker)?;
let kind = UnaryOp {
op: kind,
child: value_sym,
};
let res_sym = tracker.symbol_for_statement(prog, cur_func, cur_block, &expr.ty, kind);
Ok((cur_func, cur_block, res_sym))
}
ExprKind::Negate(ref child_expr) => {
let (cur_func, cur_block, child_sym) =
gen_expr(child_expr, prog, cur_func, cur_block, tracker)?;
let kind = Negate(child_sym);
let res_sym = tracker.symbol_for_statement(prog, cur_func, cur_block, &expr.ty, kind);
Ok((cur_func, cur_block, res_sym))
}
ExprKind::Not(ref child_expr) => {
let (cur_func, cur_block, child_sym) =
gen_expr(child_expr, prog, cur_func, cur_block, tracker)?;
let kind = Not(child_sym);
let res_sym = tracker.symbol_for_statement(prog, cur_func, cur_block, &expr.ty, kind);
Ok((cur_func, cur_block, res_sym))
}
ExprKind::Assert(ref child_expr) => {
let (cur_func, cur_block, child_sym) =
gen_expr(child_expr, prog, cur_func, cur_block, tracker)?;
let kind = Assert(child_sym);
let res_sym = tracker.symbol_for_statement(prog, cur_func, cur_block, &expr.ty, kind);
Ok((cur_func, cur_block, res_sym))
}
ExprKind::Broadcast(ref child_expr) => {
let (cur_func, cur_block, child_sym) =
gen_expr(child_expr, prog, cur_func, cur_block, tracker)?;
let kind = Broadcast(child_sym);
let res_sym = tracker.symbol_for_statement(prog, cur_func, cur_block, &expr.ty, kind);
Ok((cur_func, cur_block, res_sym))
}
ExprKind::Serialize(ref child_expr) => {
let (cur_func, cur_block, child_sym) =
gen_expr(child_expr, prog, cur_func, cur_block, tracker)?;
let kind = Serialize(child_sym);
let res_sym = tracker.symbol_for_statement(prog, cur_func, cur_block, &expr.ty, kind);
Ok((cur_func, cur_block, res_sym))
}
ExprKind::Deserialize { ref value, .. } => {
let (cur_func, cur_block, child_sym) =
gen_expr(value, prog, cur_func, cur_block, tracker)?;
let kind = Deserialize(child_sym);
let res_sym = tracker.symbol_for_statement(prog, cur_func, cur_block, &expr.ty, kind);
Ok((cur_func, cur_block, res_sym))
}
ExprKind::Cast { ref child_expr, .. } => {
let (cur_func, cur_block, child_sym) =
gen_expr(child_expr, prog, cur_func, cur_block, tracker)?;
let kind = Cast(child_sym, expr.ty.clone());
let res_sym = tracker.symbol_for_statement(prog, cur_func, cur_block, &expr.ty, kind);
Ok((cur_func, cur_block, res_sym))
}
ExprKind::Lookup {
ref data,
ref index,
} => {
let (cur_func, cur_block, data_sym) =
gen_expr(data, prog, cur_func, cur_block, tracker)?;
let (cur_func, cur_block, index_sym) =
gen_expr(index, prog, cur_func, cur_block, tracker)?;
let kind = Lookup {
child: data_sym,
index: index_sym,
};
let res_sym = tracker.symbol_for_statement(prog, cur_func, cur_block, &expr.ty, kind);
Ok((cur_func, cur_block, res_sym))
}
ExprKind::OptLookup {
ref data,
ref index,
} => {
let (cur_func, cur_block, data_sym) =
gen_expr(data, prog, cur_func, cur_block, tracker)?;
let (cur_func, cur_block, index_sym) =
gen_expr(index, prog, cur_func, cur_block, tracker)?;
let kind = OptLookup {
child: data_sym,
index: index_sym,
};
let res_sym = tracker.symbol_for_statement(prog, cur_func, cur_block, &expr.ty, kind);
Ok((cur_func, cur_block, res_sym))
}
ExprKind::KeyExists { ref data, ref key } => {
let (cur_func, cur_block, data_sym) =
gen_expr(data, prog, cur_func, cur_block, tracker)?;
let (cur_func, cur_block, key_sym) = gen_expr(key, prog, cur_func, cur_block, tracker)?;
let kind = KeyExists {
child: data_sym,
key: key_sym,
};
let res_sym = tracker.symbol_for_statement(prog, cur_func, cur_block, &expr.ty, kind);
Ok((cur_func, cur_block, res_sym))
}
ExprKind::Slice {
ref data,
ref index,
ref size,
} => {
let (cur_func, cur_block, data_sym) =
gen_expr(data, prog, cur_func, cur_block, tracker)?;
let (cur_func, cur_block, index_sym) =
gen_expr(index, prog, cur_func, cur_block, tracker)?;
let (cur_func, cur_block, size_sym) =
gen_expr(size, prog, cur_func, cur_block, tracker)?;
let kind = Slice {
child: data_sym,
index: index_sym,
size: size_sym,
};
let res_sym = tracker.symbol_for_statement(prog, cur_func, cur_block, &expr.ty, kind);
Ok((cur_func, cur_block, res_sym))
}
ExprKind::Sort {
ref data,
ref cmpfunc,
} => {
if let ExprKind::Lambda {
ref params,
ref body,
} = cmpfunc.kind
{
let cmpfunc_id = prog.add_func();
let cmpblock = prog.funcs[cmpfunc_id].add_block();
let (cmpfunc_id, cmpblock, cmp_sym) =
gen_expr(body, prog, cmpfunc_id, cmpblock, tracker)?;
prog.funcs[cmpfunc_id]
.params
.insert(params[0].name.clone(), params[0].ty.clone());
prog.funcs[cmpfunc_id]
.params
.insert(params[1].name.clone(), params[1].ty.clone());
prog.funcs[cmpfunc_id].blocks[cmpblock].terminator =
Terminator::EndFunction(cmp_sym);
let (cur_func, cur_block, data_sym) =
gen_expr(data, prog, cur_func, cur_block, tracker)?;
let kind = Sort {
child: data_sym,
cmpfunc: cmpfunc_id,
};
let res_sym =
tracker.symbol_for_statement(prog, cur_func, cur_block, &expr.ty, kind);
Ok((cur_func, cur_block, res_sym))
} else {
compile_err!(
"Sort comparison function expected lambda type, instead {:?} provided",
cmpfunc.ty
)
}
}
ExprKind::Select {
ref cond,
ref on_true,
ref on_false,
} => {
let (cur_func, cur_block, cond_sym) =
gen_expr(cond, prog, cur_func, cur_block, tracker)?;
let (cur_func, cur_block, true_sym) =
gen_expr(on_true, prog, cur_func, cur_block, tracker)?;
let (cur_func, cur_block, false_sym) =
gen_expr(on_false, prog, cur_func, cur_block, tracker)?;
let kind = Select {
cond: cond_sym,
on_true: true_sym,
on_false: false_sym,
};
let res_sym = tracker.symbol_for_statement(prog, cur_func, cur_block, &expr.ty, kind);
Ok((cur_func, cur_block, res_sym))
}
ExprKind::ToVec { ref child_expr } => {
let (cur_func, cur_block, child_sym) =
gen_expr(child_expr, prog, cur_func, cur_block, tracker)?;
let kind = ToVec(child_sym);
let res_sym = tracker.symbol_for_statement(prog, cur_func, cur_block, &expr.ty, kind);
Ok((cur_func, cur_block, res_sym))
}
ExprKind::Length { ref data } => {
let (cur_func, cur_block, child_sym) =
gen_expr(data, prog, cur_func, cur_block, tracker)?;
let kind = Length(child_sym);
let res_sym = tracker.symbol_for_statement(prog, cur_func, cur_block, &expr.ty, kind);
Ok((cur_func, cur_block, res_sym))
}
ExprKind::If {
ref cond,
ref on_true,
ref on_false,
} => {
let (cur_func, cur_block, cond_sym) =
gen_expr(cond, prog, cur_func, cur_block, tracker)?;
let true_block = prog.funcs[cur_func].add_block();
let false_block = prog.funcs[cur_func].add_block();
prog.funcs[cur_func].blocks[cur_block].terminator = Branch {
cond: cond_sym,
on_true: true_block,
on_false: false_block,
};
let (true_func, true_block, true_sym) =
gen_expr(on_true, prog, cur_func, true_block, tracker)?;
let (false_func, false_block, false_sym) =
gen_expr(on_false, prog, cur_func, false_block, tracker)?;
let res_sym = prog.add_local(&expr.ty, true_func);
prog.funcs[true_func].blocks[true_block]
.add_statement(Statement::new(Some(res_sym.clone()), Assign(true_sym)));
prog.funcs[false_func].blocks[false_block]
.add_statement(Statement::new(Some(res_sym.clone()), Assign(false_sym)));
let cont_block = prog.funcs[cur_func].add_block();
prog.funcs[true_func].blocks[true_block].terminator = JumpBlock(cont_block);
prog.funcs[false_func].blocks[false_block].terminator = JumpBlock(cont_block);
Ok((cur_func, cont_block, res_sym))
}
ExprKind::Iterate {
ref initial,
ref update_func,
} => {
let (cur_func, cur_block, initial_sym) =
gen_expr(initial, prog, cur_func, cur_block, tracker)?;
let argument_sym;
let func_body;
match update_func.kind {
ExprKind::Lambda {
ref params,
ref body,
} if params.len() == 1 => {
argument_sym = ¶ms[0].name;
func_body = body;
if params[0].ty != initial.ty {
return compile_err!("Wrong argument type for body of Iterate");
}
if func_body.ty != Struct(vec![initial.ty.clone(), Scalar(ScalarKind::Bool)]) {
return compile_err!("Wrong return type for body of Iterate");
}
prog.add_local_named(¶ms[0].ty, argument_sym, cur_func);
}
_ => return compile_err!("Argument of Iterate was not a Lambda"),
}
prog.funcs[cur_func].blocks[cur_block].add_statement(Statement::new(
Some(argument_sym.clone()),
Assign(initial_sym),
));
let body_start_block = prog.funcs[cur_func].add_block();
prog.funcs[cur_func].blocks[cur_block].terminator = JumpBlock(body_start_block);
let (body_end_func, body_end_block, result_sym) =
gen_expr(func_body, prog, cur_func, body_start_block, tracker)?;
let continue_sym = prog.add_local(&Scalar(ScalarKind::Bool), body_end_func);
prog.funcs[body_end_func].blocks[body_end_block].add_statement(Statement::new(
Some(argument_sym.clone()),
GetField {
value: result_sym.clone(),
index: 0,
},
));
prog.funcs[body_end_func].blocks[body_end_block].add_statement(Statement::new(
Some(continue_sym.clone()),
GetField {
value: result_sym,
index: 1,
},
));
let repeat_block = prog.funcs[body_end_func].add_block();
let finish_block = prog.funcs[body_end_func].add_block();
prog.funcs[body_end_func].blocks[body_end_block].terminator = Branch {
cond: continue_sym,
on_true: repeat_block,
on_false: finish_block,
};
assert!(body_end_func == cur_func);
prog.funcs[body_end_func].blocks[repeat_block].terminator = JumpBlock(body_start_block);
Ok((body_end_func, finish_block, argument_sym.clone()))
}
ExprKind::Merge {
ref builder,
ref value,
} => {
let (cur_func, cur_block, builder_sym) =
gen_expr(builder, prog, cur_func, cur_block, tracker)?;
let (cur_func, cur_block, elem_sym) =
gen_expr(value, prog, cur_func, cur_block, tracker)?;
prog.funcs[cur_func].blocks[cur_block].add_statement(Statement::new(
None,
Merge {
builder: builder_sym.clone(),
value: elem_sym,
},
));
Ok((cur_func, cur_block, builder_sym))
}
ExprKind::Res { ref builder } => {
let (cur_func, cur_block, builder_sym) =
gen_expr(builder, prog, cur_func, cur_block, tracker)?;
let kind = Res(builder_sym);
let res_sym = tracker.symbol_for_statement(prog, cur_func, cur_block, &expr.ty, kind);
Ok((cur_func, cur_block, res_sym))
}
ExprKind::NewBuilder(ref arg) => {
let (cur_func, cur_block, arg_sym) = if let Some(ref a) = *arg {
let (cur_func, cur_block, arg_sym) =
gen_expr(a, prog, cur_func, cur_block, tracker)?;
(cur_func, cur_block, Some(arg_sym))
} else {
(cur_func, cur_block, None)
};
let res_sym = prog.add_local(&expr.ty, cur_func);
prog.funcs[cur_func].blocks[cur_block].add_statement(Statement::new(
Some(res_sym.clone()),
NewBuilder {
arg: arg_sym,
ty: expr.ty.clone(),
},
));
Ok((cur_func, cur_block, res_sym))
}
ExprKind::MakeStruct { ref elems } => {
let mut syms = vec![];
let (mut cur_func, mut cur_block, mut sym) =
gen_expr(&elems[0], prog, cur_func, cur_block, tracker)?;
syms.push(sym);
for elem in elems.iter().skip(1) {
let r = gen_expr(elem, prog, cur_func, cur_block, tracker)?;
cur_func = r.0;
cur_block = r.1;
sym = r.2;
syms.push(sym);
}
let kind = MakeStruct(syms);
let res_sym = tracker.symbol_for_statement(prog, cur_func, cur_block, &expr.ty, kind);
Ok((cur_func, cur_block, res_sym))
}
ExprKind::MakeVector { ref elems } => {
let mut syms = vec![];
let mut cur_func = cur_func;
let mut cur_block = cur_block;
for elem in elems.iter() {
let r = gen_expr(elem, prog, cur_func, cur_block, tracker)?;
cur_func = r.0;
cur_block = r.1;
let sym = r.2;
syms.push(sym);
}
let kind = MakeVector(syms);
let res_sym = tracker.symbol_for_statement(prog, cur_func, cur_block, &expr.ty, kind);
Ok((cur_func, cur_block, res_sym))
}
ExprKind::CUDF {
ref sym_name,
ref args,
..
} => {
let mut syms = vec![];
let mut cur_func = cur_func;
let mut cur_block = cur_block;
for arg in args.iter() {
let r = gen_expr(arg, prog, cur_func, cur_block, tracker)?;
cur_func = r.0;
cur_block = r.1;
let sym = r.2;
syms.push(sym);
}
let kind = CUDF {
args: syms,
symbol_name: sym_name.clone(),
};
let res_sym = tracker.symbol_for_statement(prog, cur_func, cur_block, &expr.ty, kind);
Ok((cur_func, cur_block, res_sym))
}
ExprKind::GetField { ref expr, index } => {
let (cur_func, cur_block, struct_sym) =
gen_expr(expr, prog, cur_func, cur_block, tracker)?;
let field_ty = match expr.ty {
super::ast::Type::Struct(ref v) => &v[index as usize],
_ => compile_err!("Internal error: tried to get field of type {}", &expr.ty)?,
};
let kind = GetField {
value: struct_sym,
index,
};
let res_sym = tracker.symbol_for_statement(prog, cur_func, cur_block, &field_ty, kind);
Ok((cur_func, cur_block, res_sym))
}
ExprKind::For {
ref iters,
ref builder,
ref func,
} => {
if let ExprKind::Lambda {
ref params,
ref body,
} = func.kind
{
let (cur_func, cur_block, builder_sym) =
gen_expr(builder, prog, cur_func, cur_block, tracker)?;
let body_func = prog.add_func();
prog.funcs[body_func].loop_body = true;
let body_block = prog.funcs[body_func].add_block();
prog.add_local_named(¶ms[0].ty, ¶ms[0].name, body_func);
prog.add_local_named(¶ms[1].ty, ¶ms[1].name, body_func);
prog.add_local_named(¶ms[2].ty, ¶ms[2].name, body_func);
prog.add_loop_variable(¶ms[0].name, body_func);
prog.add_loop_variable(¶ms[1].name, body_func);
prog.add_loop_variable(¶ms[2].name, body_func);
prog.funcs[body_func]
.params
.insert(builder_sym.clone(), builder.ty.clone());
let mut cur_func = cur_func;
let mut cur_block = cur_block;
let mut pf_iters: Vec<ParallelForIter> = Vec::new();
for iter in iters.iter() {
let data_res = gen_expr(&iter.data, prog, cur_func, cur_block, tracker)?;
cur_func = data_res.0;
cur_block = data_res.1;
prog.funcs[body_func]
.params
.insert(data_res.2.clone(), iter.data.ty.clone());
let start_sym = get_iter_sym(
&iter.start,
prog,
&mut cur_func,
&mut cur_block,
tracker,
body_func,
)?;
let end_sym = get_iter_sym(
&iter.end,
prog,
&mut cur_func,
&mut cur_block,
tracker,
body_func,
)?;
let stride_sym = get_iter_sym(
&iter.stride,
prog,
&mut cur_func,
&mut cur_block,
tracker,
body_func,
)?;
let shape_sym = get_iter_sym(
&iter.shape,
prog,
&mut cur_func,
&mut cur_block,
tracker,
body_func,
)?;
let strides_sym = get_iter_sym(
&iter.strides,
prog,
&mut cur_func,
&mut cur_block,
tracker,
body_func,
)?;
pf_iters.push(ParallelForIter {
data: data_res.2,
start: start_sym,
end: end_sym,
stride: stride_sym,
kind: iter.kind.clone(),
shape: shape_sym,
strides: strides_sym,
});
}
let (body_end_func, body_end_block, result_sym) =
gen_expr(body, prog, body_func, body_block, tracker)?;
prog.funcs[body_end_func].blocks[body_end_block].terminator =
EndFunction(result_sym);
let mut is_innermost = true;
body.traverse(&mut |ref e| {
if let ExprKind::For { .. } = e.kind {
is_innermost = false;
}
});
prog.funcs[body_end_func].innermost_loop = is_innermost;
let kind = ParallelFor(ParallelForData {
data: pf_iters,
builder: builder_sym,
builder_arg: params[0].name.clone(),
idx_arg: params[1].name.clone(),
data_arg: params[2].name.clone(),
body: body_func,
innermost: is_innermost,
});
let res_sym =
tracker.symbol_for_statement(prog, cur_func, cur_block, &builder.ty, kind);
Ok((cur_func, cur_block, res_sym))
} else {
compile_err!("Argument to For was not a Lambda: {}", func.pretty_print())
}
}
_ => compile_err!("Unsupported expression: {}", expr.pretty_print()),
}
}