use std::collections::HashMap;
use std::iter::FromIterator;
use crate::{
parse_tree::{AsmOp, AsmRegister, LazyOp, Literal, Visibility},
semantic_analysis::{ast_node::TypedCodeBlock, ast_node::*, *},
type_engine::*,
};
use sway_types::{ident::Ident, span::Span};
use sway_ir::*;
/// Lower a type-checked Sway parse tree into a verified IR [`Context`].
///
/// Scripts and contracts are lowered. Predicates and libraries are not supported yet; for those
/// an `Err` is returned instead of panicking, so callers can report the limitation through their
/// normal error path.
///
/// # Errors
/// Returns any error raised while compiling the tree, or by `Context::verify` on the result.
pub(crate) fn compile_ast(ast: TypedParseTree) -> Result<Context, String> {
    let mut ctx = Context::default();
    match ast {
        TypedParseTree::Script {
            namespace,
            main_function,
            declarations,
            all_nodes: _,
        } => compile_script(&mut ctx, main_function, namespace, declarations),
        TypedParseTree::Predicate {
            namespace: _,
            main_function: _,
            declarations: _,
            all_nodes: _,
        } => Err("Compiling predicates to IR is not yet supported.".into()),
        TypedParseTree::Contract {
            abi_entries,
            namespace: _,
            declarations,
            all_nodes: _,
        } => compile_contract(&mut ctx, abi_entries, declarations),
        TypedParseTree::Library {
            namespace: _,
            all_nodes: _,
        } => Err("Compiling libraries to IR is not yet supported.".into()),
    }?;
    ctx.verify()?;
    Ok(ctx)
}
/// Build the IR module for a script.
///
/// Emission happens in three steps, all into a single `Script` module:
/// 1. the constants visible through `namespace`;
/// 2. the script's top-level declarations;
/// 3. `main_function`.
///
/// The first failing step aborts the build.
fn compile_script(
    context: &mut Context,
    main_function: TypedFunctionDeclaration,
    namespace: Namespace,
    declarations: Vec<TypedDeclaration>,
) -> Result<Module, String> {
    let script_module = Module::new(context, Kind::Script, "script");
    compile_constants(context, script_module, &namespace, false)
        .and_then(|()| compile_declarations(context, script_module, declarations))
        .and_then(|()| compile_function(context, script_module, main_function))
        .map(|()| script_module)
}
/// Build the IR module for a contract.
///
/// Top-level declarations are compiled into the `Contract` module first. Every ABI entry is
/// then compiled as a selector-tagged function, stopping at the first error.
fn compile_contract(
    context: &mut Context,
    abi_entries: Vec<TypedFunctionDeclaration>,
    declarations: Vec<TypedDeclaration>,
) -> Result<Module, String> {
    let contract_module = Module::new(context, Kind::Contract, "contract");
    compile_declarations(context, contract_module, declarations)?;
    abi_entries
        .into_iter()
        .try_for_each(|abi_fn| compile_abi_method(context, contract_module, abi_fn))?;
    Ok(contract_module)
}
/// Register, as global constants of `module`, the constants declared in `namespace`.
///
/// The search recurses into every module that `namespace` imports. When `public_only` is set,
/// only `pub` constants are emitted. The recursive calls for imported modules always set it,
/// so private constants of imported modules are skipped.
fn compile_constants(
    context: &mut Context,
    module: Module,
    namespace: &Namespace,
    public_only: bool,
) -> Result<(), String> {
    for decl in namespace.get_all_declared_symbols() {
        match decl {
            TypedDeclaration::ConstantDeclaration(TypedConstantDeclaration {
                name,
                value,
                visibility,
            }) if !public_only || matches!(visibility, Visibility::Public) => {
                let const_val = compile_constant_expression(context, value)?;
                module.add_global_constant(context, name.as_str().to_owned(), const_val);
            }
            _ => (),
        }
    }
    for imported in namespace.get_all_imported_modules() {
        compile_constants(context, module, imported, true)?;
    }
    Ok(())
}
/// Turn a constant's initialiser into an IR constant value.
///
/// Only literal initialisers are supported; any other expression yields an error.
fn compile_constant_expression(
    context: &mut Context,
    const_expr: &TypedExpression,
) -> Result<Value, String> {
    match &const_expr.expression {
        TypedExpressionVariant::Literal(literal) => Ok(convert_literal_to_value(context, literal)),
        _ => Err("Unsupported constant declaration type.".into()),
    }
}
/// Compile the module-level `declarations` into `module`, in order.
///
/// Each kind of declaration is handled as follows:
/// - constants become global constants;
/// - functions and `impl` methods become IR functions;
/// - every other kind is skipped here (struct/enum aggregates are created on demand by
///   `convert_resolved_type`).
///
/// Stops at the first error.
fn compile_declarations(
    context: &mut Context,
    module: Module,
    declarations: Vec<TypedDeclaration>,
) -> Result<(), String> {
    declarations
        .into_iter()
        .try_for_each(|declaration| -> Result<(), String> {
            match declaration {
                TypedDeclaration::ConstantDeclaration(const_decl) => {
                    let const_val = compile_constant_expression(context, &const_decl.value)?;
                    module.add_global_constant(
                        context,
                        const_decl.name.as_str().to_owned(),
                        const_val,
                    );
                    Ok(())
                }
                TypedDeclaration::FunctionDeclaration(fn_decl) => {
                    compile_function(context, module, fn_decl)
                }
                TypedDeclaration::ImplTrait {
                    methods,
                    type_implementing_for,
                    ..
                } => compile_impl(context, module, type_implementing_for, methods),
                // No module-level IR is emitted for these.
                TypedDeclaration::StructDeclaration(_)
                | TypedDeclaration::TraitDeclaration(_)
                | TypedDeclaration::EnumDeclaration(_)
                | TypedDeclaration::VariableDeclaration(_)
                | TypedDeclaration::Reassignment(_)
                | TypedDeclaration::AbiDeclaration(_)
                | TypedDeclaration::GenericTypeForFunctionScope { .. }
                | TypedDeclaration::ErrorRecovery => Ok(()),
            }
        })
}
/// Create a named IR struct aggregate for a Sway struct and register its field names.
///
/// Each field name is mapped to its positional index (0, 1, ...) in the context. Later
/// field lookups can then use `Context::get_aggregate_index`.
fn create_struct_aggregate(
    context: &mut Context,
    name: String,
    fields: Vec<OwnedTypedStructField>,
) -> Result<Aggregate, String> {
    // Every field type is converted before any conversion error is reported.
    let mut converted_types = Vec::with_capacity(fields.len());
    let mut field_names = Vec::with_capacity(fields.len());
    for field in fields {
        converted_types.push(convert_resolved_typeid_no_span(context, &field.r#type));
        field_names.push(field.name);
    }
    let field_types = converted_types
        .into_iter()
        .collect::<Result<Vec<_>, String>>()?;
    let aggregate = Aggregate::new_struct(context, Some(name), field_types);
    let symbols: HashMap<_, _> = field_names.into_iter().zip(0u64..).collect();
    context.add_aggregate_symbols(aggregate, symbols)?;
    Ok(aggregate)
}
fn compile_enum_decl(
context: &mut Context,
enum_decl: TypedEnumDeclaration,
) -> Result<Aggregate, String> {
let TypedEnumDeclaration {
name,
type_parameters,
variants,
.. } = enum_decl;
if !type_parameters.is_empty() {
return Err("Unable to compile generic enums.".into());
}
create_enum_aggregate(
context,
name.as_str().to_owned(),
variants
.into_iter()
.map(|tev| tev.as_owned_typed_enum_variant())
.collect(),
)
}
/// Create the IR representation of an enum.
///
/// The variant payload types go into an aggregate named `<name>_union`, which is used as a
/// `Type::Union`. Variant names are registered as indices into it. The enum itself is a
/// named struct of `{ u64 tag, union payload }`.
fn create_enum_aggregate(
    context: &mut Context,
    name: String,
    variants: Vec<OwnedTypedEnumVariant>,
) -> Result<Aggregate, String> {
    // Every variant type is converted before any conversion error is reported.
    let mut converted_types = Vec::with_capacity(variants.len());
    let mut variant_names = Vec::with_capacity(variants.len());
    for variant in variants {
        converted_types.push(convert_resolved_typeid_no_span(context, &variant.r#type));
        variant_names.push(variant.name);
    }
    let field_types = converted_types
        .into_iter()
        .collect::<Result<Vec<_>, String>>()?;
    let union_aggregate =
        Aggregate::new_struct(context, Some(format!("{}_union", name)), field_types);
    let symbols: HashMap<_, _> = variant_names.into_iter().zip(0u64..).collect();
    context.add_aggregate_symbols(union_aggregate, symbols)?;
    let tagged_fields = vec![Type::Uint(64), Type::Union(union_aggregate)];
    Ok(Aggregate::new_struct(context, Some(name), tagged_fields))
}
/// Build an anonymous IR struct aggregate whose fields have the given types, in order.
///
/// Stops at the first type that cannot be converted.
fn create_tuple_aggregate(context: &mut Context, fields: Vec<TypeId>) -> Result<Aggregate, String> {
    let mut field_types = Vec::with_capacity(fields.len());
    for field_type_id in &fields {
        field_types.push(convert_resolved_typeid_no_span(context, field_type_id)?);
    }
    Ok(Aggregate::new_struct(context, None, field_types))
}
/// Compile a free-standing function declaration into `module`.
///
/// Functions with type parameters are skipped: nothing is emitted and `Ok(())` is returned.
///
/// # Errors
/// Returns an error if a parameter type cannot be converted to an IR type, or if compiling
/// the function body fails.
fn compile_function(
    context: &mut Context,
    module: Module,
    ast_fn_decl: TypedFunctionDeclaration,
) -> Result<(), String> {
    // NOTE(review): generic functions are silently ignored here; presumably they are expected
    // to be handled (monomorphised) elsewhere — confirm.
    if !ast_fn_decl.type_parameters.is_empty() {
        return Ok(());
    }
    let args = ast_fn_decl
        .parameters
        .iter()
        .map(|param| {
            convert_resolved_typeid(context, &param.r#type, &param.type_span)
                .map(|ty| (param.name.as_str().into(), ty))
        })
        .collect::<Result<Vec<(String, Type)>, String>>()?;
    compile_fn_with_args(context, module, ast_fn_decl, args, None)
}
/// Create an IR function for `ast_fn_decl` using the already-converted `args`.
///
/// The function's body is compiled and terminated with a `ret` of the body's value.
/// `selector` is forwarded to `Function::new` (ABI methods supply one). The function is
/// marked public exactly when the declaration's visibility is `Visibility::Public`.
fn compile_fn_with_args(
    context: &mut Context,
    module: Module,
    ast_fn_decl: TypedFunctionDeclaration,
    args: Vec<(String, Type)>,
    selector: Option<[u8; 4]>,
) -> Result<(), String> {
    let TypedFunctionDeclaration {
        name: fn_name,
        body: fn_body,
        return_type: ret_type_id,
        return_type_span: ret_type_span,
        visibility: fn_visibility,
        ..
    } = ast_fn_decl;
    let ir_ret_type = convert_resolved_typeid(context, &ret_type_id, &ret_type_span)?;
    let is_public = fn_visibility == Visibility::Public;
    let function = Function::new(
        context,
        module,
        fn_name.as_str().to_owned(),
        args,
        ir_ret_type,
        selector,
        is_public,
    );
    let mut fn_compiler = FnCompiler::new(context, module, function);
    let body_value = fn_compiler.compile_code_block(context, fn_body)?;
    fn_compiler
        .current_block
        .ins(context)
        .ret(body_value, ir_ret_type);
    Ok(())
}
/// Compile every method of an `impl` block into `module`.
///
/// A parameter named `self` gets the IR type of `self_type`. All other parameters have
/// their declared types converted as usual.
///
/// # Errors
/// Stops at, and returns, the first parameter-conversion or method-compilation error.
fn compile_impl(
    context: &mut Context,
    module: Module,
    self_type: TypeInfo,
    ast_methods: Vec<TypedFunctionDeclaration>,
) -> Result<(), String> {
    // NOTE(review): unlike `compile_function`, methods with type parameters are not skipped
    // here — confirm whether that is intended.
    for method in ast_methods {
        let args = method
            .parameters
            .iter()
            .map(|param| {
                let param_ty = if param.name.as_str() == "self" {
                    convert_resolved_type(context, &self_type)
                } else {
                    convert_resolved_typeid(context, &param.r#type, &param.type_span)
                };
                param_ty.map(|ty| (param.name.as_str().into(), ty))
            })
            .collect::<Result<Vec<(String, Type)>, String>>()?;
        compile_fn_with_args(context, module, method, args, None)?;
    }
    Ok(())
}
/// Compile a contract ABI method into `module`, tagging the IR function with the method's
/// 4-byte selector.
///
/// # Errors
/// Returns an error if any of these fails:
/// - computing the selector;
/// - converting a parameter type;
/// - compiling the body.
fn compile_abi_method(
    context: &mut Context,
    module: Module,
    ast_fn_decl: TypedFunctionDeclaration,
) -> Result<(), String> {
    // Only build the error message when the selector is actually missing.
    let selector = ast_fn_decl.to_fn_selector_value().value.ok_or_else(|| {
        format!(
            "Cannot generate selector for ABI method: {}",
            ast_fn_decl.name.as_str()
        )
    })?;
    let args = ast_fn_decl
        .parameters
        .iter()
        .map(|param| {
            convert_resolved_typeid(context, &param.r#type, &param.type_span)
                .map(|ty| (param.name.as_str().into(), ty))
        })
        .collect::<Result<Vec<(String, Type)>, String>>()?;
    compile_fn_with_args(context, module, ast_fn_decl, args, Some(selector))
}
/// State used while lowering a single function body to IR.
struct FnCompiler {
// Module containing `function`. Used to look up global constants and as the target
// module for functions synthesised by `compile_fn_call`.
module: Module,
// The IR function being built.
function: Function,
// Block that new instructions are appended to. It is reassigned whenever control flow
// (if/while/lazy ops/return) starts a new block.
current_block: Block,
// Maps each Sway variable name to the IR local/argument name it currently denotes.
// Shadowing `let`s are given a fresh local name by `compile_var_decl`.
symbol_map: HashMap<String, String>,
}
impl FnCompiler {
/// Start compiling `function`.
///
/// Instructions are appended to its entry block, and every argument name initially maps to
/// itself in the symbol map.
fn new(context: &mut Context, module: Module, function: Function) -> Self {
    let symbol_map: HashMap<String, String> = function
        .args_iter(context)
        .map(|(arg_name, _)| (arg_name.clone(), arg_name.clone()))
        .collect();
    let current_block = function.get_entry_block(context);
    FnCompiler {
        module,
        function,
        current_block,
        symbol_map,
    }
}
/// Lower every node of `ast_block` in order and return the value of the last node.
///
/// An empty block yields the unit constant. Nodes are handled as follows:
/// - enum declarations build their aggregate and yield unit;
/// - `impl` blocks are ignored and yield unit;
/// - other nested item declarations and side-effect nodes are reported as errors.
///
/// Lowering stops at the first error.
fn compile_code_block(
    &mut self,
    context: &mut Context,
    ast_block: TypedCodeBlock,
) -> Result<Value, String> {
    let mut last_value: Option<Value> = None;
    for ast_node in ast_block.contents {
        let node_value = match ast_node.content {
            TypedAstNodeContent::ReturnStatement(ret_stmt) => {
                self.compile_return_statement(context, ret_stmt.expr)
            }
            TypedAstNodeContent::Declaration(decl) => match decl {
                TypedDeclaration::VariableDeclaration(var_decl) => {
                    self.compile_var_decl(context, var_decl)
                }
                TypedDeclaration::ConstantDeclaration(const_decl) => {
                    self.compile_const_decl(context, const_decl)
                }
                TypedDeclaration::FunctionDeclaration(_) => Err("func decl".into()),
                TypedDeclaration::TraitDeclaration(_) => Err("trait decl".into()),
                TypedDeclaration::StructDeclaration(_) => Err("struct decl".into()),
                TypedDeclaration::EnumDeclaration(enum_decl) => {
                    compile_enum_decl(context, enum_decl).map(|_| Constant::get_unit(context))
                }
                TypedDeclaration::Reassignment(reassignment) => {
                    self.compile_reassignment(context, reassignment)
                }
                TypedDeclaration::ImplTrait { .. } => Ok(Constant::get_unit(context)),
                TypedDeclaration::AbiDeclaration(_) => Err("abi decl".into()),
                TypedDeclaration::GenericTypeForFunctionScope { .. } => {
                    Err("gen ty for fn scope".into())
                }
                TypedDeclaration::ErrorRecovery { .. } => Err("error recovery".into()),
            },
            TypedAstNodeContent::Expression(expr)
            | TypedAstNodeContent::ImplicitReturnExpression(expr) => {
                self.compile_expression(context, expr)
            }
            TypedAstNodeContent::WhileLoop(while_loop) => {
                self.compile_while_loop(context, while_loop)
            }
            TypedAstNodeContent::SideEffect => Err("code block side effect".into()),
        }?;
        last_value = Some(node_value);
    }
    match last_value {
        Some(value) => Ok(value),
        None => Ok(Constant::get_unit(context)),
    }
}
/// Lower one typed expression and return the IR value it evaluates to.
///
/// Any instructions are appended to `self.current_block`. Each variant is dispatched to a
/// dedicated `compile_*` helper, except:
/// - function-parameter expressions and enum argument accesses produce an `Err`;
/// - ABI casts currently lower to the unit constant.
fn compile_expression(
    &mut self,
    context: &mut Context,
    ast_expr: TypedExpression,
) -> Result<Value, String> {
    match ast_expr.expression {
        TypedExpressionVariant::Literal(lit) => Ok(convert_literal_to_value(context, &lit)),
        TypedExpressionVariant::FunctionApplication {
            name: call_path,
            arguments: call_args,
            function_body: callee_body,
            ..
        } => self.compile_fn_call(
            context,
            call_path.suffix.as_str(),
            call_args,
            Some(callee_body),
        ),
        TypedExpressionVariant::LazyOperator { op, lhs, rhs, .. } => {
            self.compile_lazy_op(context, op, *lhs, *rhs)
        }
        TypedExpressionVariant::VariableExpression { name: var_name } => {
            self.compile_var_expr(context, var_name.as_str())
        }
        TypedExpressionVariant::Array { contents: elems } => {
            self.compile_array_expr(context, elems)
        }
        TypedExpressionVariant::ArrayIndex {
            prefix: array,
            index,
        } => self.compile_array_index(context, *array, *index),
        TypedExpressionVariant::StructExpression {
            struct_name,
            fields,
        } => self.compile_struct_expr(context, struct_name.as_str(), fields),
        TypedExpressionVariant::CodeBlock(block) => self.compile_code_block(context, block),
        TypedExpressionVariant::FunctionParameter => Err("expr func param".into()),
        TypedExpressionVariant::IfExp {
            condition,
            then,
            r#else,
        } => self.compile_if(context, *condition, *then, r#else),
        TypedExpressionVariant::AsmExpression {
            registers,
            body,
            returns,
            ..
        } => self.compile_asm_expr(context, registers, body, returns),
        TypedExpressionVariant::StructFieldAccess {
            prefix: struct_expr,
            field_to_access: field,
            resolved_type_of_parent: parent_type,
            ..
        } => self.compile_struct_field_expr(context, *struct_expr, field, parent_type),
        TypedExpressionVariant::EnumInstantiation {
            enum_decl,
            tag,
            contents,
            ..
        } => self.compile_enum_expr(context, enum_decl, tag, contents),
        TypedExpressionVariant::EnumArgAccess { .. } => Err("enum arg access".into()),
        TypedExpressionVariant::Tuple { fields } => self.compile_tuple_expr(context, fields),
        TypedExpressionVariant::TupleElemAccess {
            prefix: tuple_expr,
            elem_to_access_num: elem_idx,
            elem_to_access_span: elem_span,
            resolved_type_of_parent: tuple_type,
        } => self.compile_tuple_elem_expr(context, *tuple_expr, tuple_type, elem_idx, elem_span),
        TypedExpressionVariant::AbiCast { .. } => Ok(Constant::get_unit(context)),
    }
}
/// Lower `return expr`: emit a `ret` of the expression's value and type.
///
/// Lowering continues in a fresh block, and the statement itself yields unit.
fn compile_return_statement(
    &mut self,
    context: &mut Context,
    ast_expr: TypedExpression,
) -> Result<Value, String> {
    let ret_value = self.compile_expression(context, ast_expr)?;
    let ret_ty = ret_value.get_type(context).ok_or_else(|| {
        String::from("Unable to determine type for return statement expression.")
    })?;
    self.current_block.ins(context).ret(ret_value, ret_ty);
    // Anything emitted after the `return` goes into a new block that nothing branches to.
    self.current_block = self.function.create_block(context, None);
    Ok(Constant::get_unit(context))
}
/// Lower a short-circuiting `&&` / `||` and return the merged result.
///
/// `lhs` is evaluated in the current block. That block then conditionally branches, with
/// `Some(lhs_val)` as the branch's phi value:
/// - for `&&`, a true `lhs` goes to the rhs block and a false one to the final block;
/// - for `||`, a true `lhs` goes to the final block and a false one to the rhs block.
///
/// The rhs block evaluates `rhs` and branches to the final block with its value. Lowering
/// continues in the final block, whose phi is returned.
fn compile_lazy_op(
&mut self,
context: &mut Context,
ast_op: LazyOp,
ast_lhs: TypedExpression,
ast_rhs: TypedExpression,
) -> Result<Value, String> {
let lhs_val = self.compile_expression(context, ast_lhs)?;
let rhs_block = self.function.create_block(context, None);
let final_block = self.function.create_block(context, None);
let cond_builder = self.current_block.ins(context);
match ast_op {
// `&&`: only evaluate rhs when lhs is true.
LazyOp::And => {
cond_builder.conditional_branch(lhs_val, rhs_block, final_block, Some(lhs_val))
}
// `||`: only evaluate rhs when lhs is false.
LazyOp::Or => {
cond_builder.conditional_branch(lhs_val, final_block, rhs_block, Some(lhs_val))
}
};
self.current_block = rhs_block;
let rhs_val = self.compile_expression(context, ast_rhs)?;
// `current_block` may no longer be `rhs_block` if `rhs` created blocks of its own.
self.current_block
.ins(context)
.branch(final_block, Some(rhs_val));
self.current_block = final_block;
Ok(final_block.get_phi(context))
}
/// Emit a call to the function named `ast_name`, passing the compiled `ast_args`.
///
/// The callee is searched for by name across every module in `context`. If it is not
/// found:
/// - without `callee_body`, an error is returned;
/// - with `callee_body`, a private function with a unique generated name is synthesised,
///   compiled into `self.module`, and then called by that new name.
///
/// For the synthesised function:
/// - parameters are named after the argument identifiers and typed by the argument
///   expressions' types;
/// - the return type comes from `get_codeblock_return_type`, falling back to unit.
fn compile_fn_call(
&mut self,
context: &mut Context,
ast_name: &str,
ast_args: Vec<(Ident, TypedExpression)>,
callee_body: Option<TypedCodeBlock>,
) -> Result<Value, String> {
// NOTE(review): the first function with a matching name in any module wins; a same-named
// function in another module would shadow the intended callee — confirm this is acceptable.
match context
.module_iter()
.flat_map(|module| module.function_iter(context))
.find(|function| function.get_name(context) == ast_name)
{
Some(callee) => {
// Arguments are compiled left to right; the first error aborts the call.
let args = ast_args
.into_iter()
.map(|(_, expr)| self.compile_expression(context, expr))
.collect::<Result<Vec<Value>, String>>()?;
Ok(self.current_block.ins(context).call(callee, &args))
}
None if callee_body.is_none() => Err(format!("function not found: {}", ast_name)),
None => {
// Synthesise a function from the supplied body under a fresh unique name.
let callee_name = context.get_unique_name();
let callee_name_len = callee_name.len();
// The identifier's span covers the whole generated name, so `as_str()` yields it.
let callee_ident = Ident::new(crate::span::Span {
span: pest::Span::new(
std::sync::Arc::from(callee_name.clone()),
0,
callee_name_len,
)
.unwrap(),
path: None,
});
// Parameter spans are empty placeholders since there is no real source location.
let parameters = ast_args
.iter()
.map(|(name, expr)| TypedFunctionParameter {
name: name.clone(),
r#type: expr.return_type,
type_span: crate::span::Span {
span: pest::Span::new(" ".into(), 0, 0).unwrap(),
path: None,
},
})
.collect();
// Safe: the `None if callee_body.is_none()` arm above handled the missing case.
let callee_body = callee_body.unwrap();
let return_type =
Self::get_codeblock_return_type(&callee_body).unwrap_or_else(||
insert_type(TypeInfo::Tuple(Vec::new())));
let callee_fn_decl = TypedFunctionDeclaration {
name: callee_ident,
body: callee_body,
parameters,
span: crate::span::Span {
span: pest::Span::new(" ".into(), 0, 0).unwrap(),
path: None,
},
return_type,
type_parameters: Vec::new(),
return_type_span: crate::span::Span {
span: pest::Span::new(" ".into(), 0, 0).unwrap(),
path: None,
},
visibility: Visibility::Private,
is_contract_call: false,
purity: Default::default(),
};
compile_function(context, self.module, callee_fn_decl)?;
// Retry with no body: the lookup now finds the function just compiled.
self.compile_fn_call(context, &callee_name, ast_args, None)
}
}
}
/// Determine a code block's result type by inspecting its top-level nodes, without
/// compiling it.
///
/// The rules are:
/// - an empty block is unit;
/// - otherwise, the type of the first top-level `return` statement or implicit-return
///   expression is used;
/// - `None` if there is neither.
fn get_codeblock_return_type(codeblock: &TypedCodeBlock) -> Option<TypeId> {
    if codeblock.contents.is_empty() {
        return Some(insert_type(TypeInfo::Tuple(Vec::new())));
    }
    for node in codeblock.contents.iter() {
        match &node.content {
            TypedAstNodeContent::ReturnStatement(ret_stmt) => {
                return Some(ret_stmt.expr.return_type)
            }
            TypedAstNodeContent::ImplicitReturnExpression(expr) => return Some(expr.return_type),
            _ => {}
        }
    }
    None
}
/// Lower an `if`/`else` expression; its value is the phi of a new merge block.
///
/// Lowering proceeds as follows:
/// - The condition is evaluated in the current (entry) block.
/// - Each arm is compiled starting in a freshly created block. An arm may create blocks of
///   its own, so the block that is current when the arm finishes is the one that branches to
///   the merge block.
/// - The entry block's conditional branch is emitted only after both arms have been compiled.
/// - A missing `else` contributes the unit constant.
///
/// Lowering continues in the merge block.
fn compile_if(
&mut self,
context: &mut Context,
ast_condition: TypedExpression,
ast_then: TypedExpression,
ast_else: Option<Box<TypedExpression>>,
) -> Result<Value, String> {
let cond_value = self.compile_expression(context, ast_condition)?;
let entry_block = self.current_block;
let true_block_begin = self.function.create_block(context, None);
self.current_block = true_block_begin;
let true_value = self.compile_expression(context, ast_then)?;
// Remember where the `then` arm ended; it may differ from `true_block_begin`.
let true_block_end = self.current_block;
let false_block_begin = self.function.create_block(context, None);
self.current_block = false_block_begin;
let false_value = match ast_else {
None => Constant::get_unit(context),
Some(expr) => self.compile_expression(context, *expr)?,
};
let false_block_end = self.current_block;
// Now that both arm entry blocks exist, terminate the entry block.
entry_block.ins(context).conditional_branch(
cond_value,
true_block_begin,
false_block_begin,
None,
);
let merge_block = self.function.create_block(context, None);
true_block_end
.ins(context)
.branch(merge_block, Some(true_value));
false_block_end
.ins(context)
.branch(merge_block, Some(false_value));
self.current_block = merge_block;
Ok(merge_block.get_phi(context))
}
/// Lower a `while` loop; the loop expression itself evaluates to unit.
///
/// Block layout:
/// - the current block jumps unconditionally to a `while` block;
/// - `while` evaluates the condition and branches to `while_body` or `end_while`;
/// - the end of the body jumps back to `while`.
///
/// The body is compiled before the condition, but the condition's instructions are placed in
/// the `while` block. Lowering continues in `end_while`.
fn compile_while_loop(
&mut self,
context: &mut Context,
ast_while_loop: TypedWhileLoop,
) -> Result<Value, String> {
// NOTE(review): blocks are created in the order while / while_body / end_while; the printed
// IR (and the .ir test expectations) presumably depend on this order — keep it.
let cond_block = self.function.create_block(context, Some("while".into()));
self.current_block.ins(context).branch(cond_block, None);
let body_block = self
.function
.create_block(context, Some("while_body".into()));
self.current_block = body_block;
self.compile_code_block(context, ast_while_loop.body)?;
// Loop back from wherever the body finished (it may have created further blocks).
self.current_block.ins(context).branch(cond_block, None);
let final_block = self
.function
.create_block(context, Some("end_while".into()));
// Emit the condition into the `while` block now that both branch targets exist.
self.current_block = cond_block;
let cond_value = self.compile_expression(context, ast_while_loop.condition)?;
self.current_block.ins(context).conditional_branch(
cond_value,
body_block,
final_block,
None,
);
self.current_block = final_block;
Ok(Constant::get_unit(context))
}
/// Produce the value of the variable `name`.
///
/// Lookup order:
/// 1. function locals, via `symbol_map`, so a shadowed name resolves to its latest
///    declaration;
/// 2. function arguments;
/// 3. the module's global constants.
///
/// A local for which `is_struct_ptr` holds yields a `get_ptr` of its storage; any other local
/// is `load`ed.
///
/// # Errors
/// Returns an error if `name` resolves to none of the above.
fn compile_var_expr(&mut self, context: &mut Context, name: &str) -> Result<Value, String> {
    // `and_then` instead of `map(..).flatten()` (clippy::map_flatten).
    let local_ptr = self
        .symbol_map
        .get(name)
        .and_then(|local_name| self.function.get_local_ptr(context, local_name));
    if let Some(ptr) = local_ptr {
        Ok(if ptr.is_struct_ptr(context) {
            self.current_block.ins(context).get_ptr(ptr)
        } else {
            self.current_block.ins(context).load(ptr)
        })
    } else if let Some(val) = self.function.get_arg(context, name) {
        Ok(val)
    } else if let Some(const_val) = self.module.get_global_constant(context, name) {
        Ok(const_val)
    } else {
        Err(format!("Unable to resolve variable '{}'.", name))
    }
}
/// Lower `let name = body;` into a new local initialised with the value of `body`.
///
/// If `name` is already mapped in `symbol_map`, the new local is named after the mapped name
/// plus a trailing `_`, so shadowing declarations get distinct IR locals. Returns the
/// initialiser value.
fn compile_var_decl(
    &mut self,
    context: &mut Context,
    ast_var_decl: TypedVariableDeclaration,
) -> Result<Value, String> {
    let TypedVariableDeclaration {
        name,
        body,
        is_mutable,
        ..
    } = ast_var_decl;
    let local_type = convert_resolved_typeid(context, &body.return_type, &body.span)?;
    // Evaluate the initialiser before rebinding, so it still sees the variable it shadows.
    let init_val = self.compile_expression(context, body)?;
    let local_name = self.symbol_map.get(name.as_str()).map_or_else(
        || name.as_str().to_owned(),
        |shadowed_name| format!("{}_", shadowed_name),
    );
    self.symbol_map
        .insert(name.as_str().to_owned(), local_name.clone());
    let ptr = self
        .function
        .new_local_ptr(context, local_name, local_type, is_mutable, None)?;
    self.current_block.ins(context).store(ptr, init_val);
    Ok(init_val)
}
/// Lower a function-local `const` into an immutable local initialised with its literal.
///
/// Only literal initialisers are supported. The constant maps to a local of the same name in
/// `symbol_map`, and the declaration yields unit.
fn compile_const_decl(
    &mut self,
    context: &mut Context,
    ast_const_decl: TypedConstantDeclaration,
) -> Result<Value, String> {
    let TypedConstantDeclaration { name, value, .. } = ast_const_decl;
    let literal = match &value.expression {
        TypedExpressionVariant::Literal(literal) => literal,
        _ => return Err("Unsupported constant declaration type.".into()),
    };
    let initialiser = convert_literal_to_constant(literal);
    let local_type = convert_resolved_typeid(context, &value.return_type, &value.span)?;
    let const_name = name.as_str().to_owned();
    self.function.new_local_ptr(
        context,
        const_name.clone(),
        local_type,
        false,
        Some(initialiser),
    )?;
    self.symbol_map.insert(const_name.clone(), const_name);
    Ok(Constant::get_unit(context))
}
/// Lower `lhs = rhs`, storing the value of `rhs` into the variable named by `lhs[0]`.
///
/// - With a single accessor, the whole local is overwritten with `store`.
/// - With a field path (`a.b.c = v`), the nested field indices are resolved against the
///   local's struct type, and the value is written with `insert_value` through a `get_ptr`
///   of the local.
///
/// The variable is resolved through `symbol_map`, the same way `compile_var_expr` reads it.
/// Reassigning a shadowed variable therefore updates its most recent declaration, not the
/// shadowed one.
///
/// Returns the assigned value.
fn compile_reassignment(
    &mut self,
    context: &mut Context,
    ast_reassignment: TypedReassignment,
) -> Result<Value, String> {
    let name = ast_reassignment
        .lhs
        .first()
        .ok_or_else(|| "Reassignment without a target variable.".to_owned())?
        .name
        .as_str();
    // `compile_var_decl` gives shadowing declarations a fresh local name; follow that
    // mapping. Names absent from the map are looked up verbatim.
    let local_name = self
        .symbol_map
        .get(name)
        .map(String::as_str)
        .unwrap_or(name);
    let ptr_val = self
        .function
        .get_local_ptr(context, local_name)
        .ok_or_else(|| format!("variable not found: {}", name))?;
    let reassign_val = self.compile_expression(context, ast_reassignment.rhs)?;
    if ast_reassignment.lhs.len() == 1 {
        self.current_block.ins(context).store(ptr_val, reassign_val);
        return Ok(reassign_val);
    }
    // A field path requires the variable itself to be a struct.
    let aggregate = match ptr_val.get_type(context) {
        Type::Struct(aggregate) => *aggregate,
        _otherwise => {
            return Err("Reassignment with multiple accessors to non-aggregate.".into())
        }
    };
    // Walk the accessor chain, collecting the field index at each nesting level.
    let mut field_idcs = Vec::with_capacity(ast_reassignment.lhs.len() - 1);
    let mut field_ty = Type::Struct(aggregate);
    for field_name in &ast_reassignment.lhs[1..] {
        let field_aggregate = match field_ty {
            Type::Struct(field_aggregate) => field_aggregate,
            _otherwise => {
                return Err("Reassignment with multiple accessors to non-aggregate.".into())
            }
        };
        let field_idx = context
            .get_aggregate_index(&field_aggregate, field_name.name.as_str())
            .ok_or_else(|| {
                format!(
                    "Unknown field name {} for struct ???",
                    field_name.name.as_str()
                )
            })?;
        field_ty = context.aggregates[field_aggregate.0].field_types()[field_idx as usize];
        field_idcs.push(field_idx);
    }
    let get_ptr_val = self.current_block.ins(context).get_ptr(ptr_val);
    self.current_block
        .ins(context)
        .insert_value(get_ptr_val, aggregate, reassign_val, field_idcs);
    Ok(reassign_val)
}
/// Build a fixed-size array value from `contents`.
///
/// The element type is taken from the first element, and the length is the element count.
/// Starting from an `undef` array, each element is compiled in order and written at its
/// index with `insert_element`. Empty arrays are rejected.
fn compile_array_expr(
    &mut self,
    context: &mut Context,
    contents: Vec<TypedExpression>,
) -> Result<Value, String> {
    let first_elem = match contents.first() {
        Some(elem) => elem,
        None => return Err("Unable to create zero sized static arrays.".into()),
    };
    let elem_type = convert_resolved_typeid_no_span(context, &first_elem.return_type)?;
    let array_aggregate = Aggregate::new_array(context, elem_type, contents.len() as u64);
    let undef_array = Constant::get_undef(context, Type::Array(array_aggregate));
    contents.into_iter().enumerate().try_fold(
        undef_array,
        |array_value, (idx, elem_expr)| -> Result<Value, String> {
            let index_val = Constant::get_uint(context, 64, idx as u64);
            let elem_value = self.compile_expression(context, elem_expr)?;
            Ok(self.current_block.ins(context).insert_element(
                array_value,
                array_aggregate,
                elem_value,
                index_val,
            ))
        },
    )
}
/// Extract the element at `index_expr` from the array produced by `array_expr`.
///
/// The array aggregate is recovered from the compiled value: either an instruction that
/// reports its aggregate, or an array-typed function argument. A `u64` literal index is
/// bounds-checked against the array length at compile time.
fn compile_array_index(
    &mut self,
    context: &mut Context,
    array_expr: TypedExpression,
    index_expr: TypedExpression,
) -> Result<Value, String> {
    let array_val = self.compile_expression(context, array_expr)?;
    let aggregate = match &context.values[array_val.0] {
        ValueContent::Instruction(instruction) => match instruction.get_aggregate(context) {
            Some(aggregate) => aggregate,
            None => {
                return Err(format!(
                    "Unsupported instruction as array value for index expression. {:?}",
                    instruction
                ))
            }
        },
        ValueContent::Argument(Type::Array(aggregate)) => *aggregate,
        otherwise => {
            return Err(format!(
                "Unsupported array value for index expression: {:?}",
                otherwise
            ))
        }
    };
    let (_, array_len) = context.aggregates[aggregate.0].array_type();
    let array_len = *array_len;
    if let TypedExpressionVariant::Literal(Literal::U64(index)) = &index_expr.expression {
        if *index >= array_len {
            return Err(format!(
                "Array index out of bounds; the length is {} but the index is {}.",
                array_len, index
            ));
        }
    }
    let index_val = self.compile_expression(context, index_expr)?;
    Ok(self
        .current_block
        .ins(context)
        .extract_element(array_val, aggregate, index_val))
}
/// Build a struct value of the named aggregate from its field initialisers.
///
/// All field initialisers are compiled, in source order, before the aggregate is assembled.
/// Assembly starts from an `undef` value and writes each field with `insert_value`.
fn compile_struct_expr(
    &mut self,
    context: &mut Context,
    struct_name: &str,
    fields: Vec<TypedStructExpressionField>,
) -> Result<Value, String> {
    let aggregate = match context.get_aggregate_by_name(struct_name) {
        Some(aggregate) => aggregate,
        None => return Err(format!("Unknown aggregate {}", struct_name)),
    };
    let mut field_inits = Vec::with_capacity(fields.len());
    for field in fields {
        let field_name = field.name.as_str();
        let insert_val = self.compile_expression(context, field.value)?;
        let insert_idx = match context.get_aggregate_index(&aggregate, field_name) {
            Some(idx) => idx,
            None => {
                return Err(format!(
                    "Unknown field name {} for aggregate {}",
                    field_name, struct_name
                ))
            }
        };
        field_inits.push((insert_val, insert_idx));
    }
    let mut agg_value = Constant::get_undef(context, Type::Struct(aggregate));
    for (insert_val, insert_idx) in field_inits {
        agg_value = self.current_block.ins(context).insert_value(
            agg_value,
            aggregate,
            insert_val,
            vec![insert_idx],
        );
    }
    Ok(agg_value)
}
/// Read one field (`ast_field`) out of the struct value produced by `ast_struct_expr`.
///
/// The aggregate is recovered from the compiled value: either an instruction that reports
/// its aggregate, or a struct-typed argument. The parent type parameter is currently unused.
fn compile_struct_field_expr(
    &mut self,
    context: &mut Context,
    ast_struct_expr: TypedExpression,
    ast_field: OwnedTypedStructField,
    _ast_parent_type: TypeId,
) -> Result<Value, String> {
    let struct_val = self.compile_expression(context, ast_struct_expr)?;
    let aggregate = match &context.values[struct_val.0] {
        ValueContent::Instruction(instruction) => match instruction.get_aggregate(context) {
            Some(aggregate) => aggregate,
            None => {
                return Err(format!(
                    "Unsupported instruction as struct value for field expression. {:?}",
                    instruction
                ))
            }
        },
        ValueContent::Argument(Type::Struct(aggregate)) => *aggregate,
        otherwise => {
            return Err(format!(
                "Unsupported struct value for field expression: {:?}",
                otherwise
            ))
        }
    };
    let field_idx = match context.get_aggregate_index(&aggregate, &ast_field.name) {
        Some(idx) => idx,
        None => return Err(format!("Unknown field name {} in struct ???", ast_field.name)),
    };
    Ok(self
        .current_block
        .ins(context)
        .extract_value(struct_val, aggregate, vec![field_idx]))
}
/// Build an enum value: the variant `tag` in field 0 and, if present, `contents` in field 1.
///
/// The enum's aggregate is reused when one with the enum's name already exists. Otherwise
/// it is created from `enum_decl`.
fn compile_enum_expr(
    &mut self,
    context: &mut Context,
    enum_decl: TypedEnumDeclaration,
    tag: usize,
    contents: Option<Box<TypedExpression>>,
) -> Result<Value, String> {
    let aggregate = if let Some(existing) = context.get_aggregate_by_name(enum_decl.name.as_str())
    {
        existing
    } else {
        compile_enum_decl(context, enum_decl)?
    };
    let tag_value = Constant::get_uint(context, 64, tag as u64);
    let undef_value = Constant::get_undef(context, Type::Struct(aggregate));
    let tagged_value =
        self.current_block
            .ins(context)
            .insert_value(undef_value, aggregate, tag_value, vec![0]);
    if let Some(payload_expr) = contents {
        let payload_value = self.compile_expression(context, *payload_expr)?;
        Ok(self.current_block.ins(context).insert_value(
            tagged_value,
            aggregate,
            payload_value,
            vec![1],
        ))
    } else {
        Ok(tagged_value)
    }
}
/// Build a tuple value; the empty tuple is the unit constant.
///
/// For each element, its type is converted first and then its expression is compiled, in
/// order. A fresh anonymous aggregate is then created and filled from an `undef` value.
fn compile_tuple_expr(
    &mut self,
    context: &mut Context,
    fields: Vec<TypedExpression>,
) -> Result<Value, String> {
    if fields.is_empty() {
        return Ok(Constant::get_unit(context));
    }
    let mut init_values = Vec::with_capacity(fields.len());
    let mut init_types = Vec::with_capacity(fields.len());
    for field_expr in fields {
        let init_type = convert_resolved_typeid_no_span(context, &field_expr.return_type)?;
        init_values.push(self.compile_expression(context, field_expr)?);
        init_types.push(init_type);
    }
    let aggregate = Aggregate::new_struct(context, None, init_types);
    let mut agg_value = Constant::get_undef(context, Type::Struct(aggregate));
    for (insert_idx, insert_val) in init_values.into_iter().enumerate() {
        agg_value = self.current_block.ins(context).insert_value(
            agg_value,
            aggregate,
            insert_val,
            vec![insert_idx as u64],
        );
    }
    Ok(agg_value)
}
/// Read element `idx` of the tuple produced by `tuple`, whose resolved type is `tuple_type`.
///
/// # Errors
/// Returns an error if `tuple_type` does not convert to an IR struct.
fn compile_tuple_elem_expr(
    &mut self,
    context: &mut Context,
    tuple: TypedExpression,
    tuple_type: TypeId,
    idx: usize,
    span: Span,
) -> Result<Value, String> {
    let tuple_value = self.compile_expression(context, tuple)?;
    match convert_resolved_typeid(context, &tuple_type, &span)? {
        Type::Struct(aggregate) => Ok(self.current_block.ins(context).extract_value(
            tuple_value,
            aggregate,
            vec![idx as u64],
        )),
        _ => Err("Invalid (non-aggregate?) tuple type for TupleElemAccess?".into()),
    }
}
fn compile_asm_expr(
&mut self,
context: &mut Context,
registers: Vec<TypedAsmRegisterDeclaration>,
body: Vec<AsmOp>,
returns: Option<(AsmRegister, Span)>,
) -> Result<Value, String> {
let registers = registers
.into_iter()
.map(
|TypedAsmRegisterDeclaration {
initializer, name, ..
}| {
initializer
.map(|init_expr| self.compile_expression(context, init_expr))
.transpose()
.map(|init| AsmArg {
name,
initializer: init,
})
},
)
.collect::<Result<Vec<AsmArg>, String>>()?;
let body = body
.into_iter()
.map(
|AsmOp {
op_name,
op_args,
immediate,
..
}| AsmInstruction {
name: op_name,
args: op_args,
immediate,
},
)
.collect();
let returns = returns.as_ref().map(|(asm_reg, _)| {
Ident::new(Span {
span: pest::Span::new(asm_reg.name.as_str().into(), 0, asm_reg.name.len()).unwrap(),
path: None,
})
});
Ok(self
.current_block
.ins(context)
.asm_block(registers, body, returns))
}
}
/// Create an IR constant value in `context` for a Sway literal.
///
/// `u8` and `byte` literals both become 8-bit unsigned integers. The other integer widths map
/// one-to-one.
fn convert_literal_to_value(context: &mut Context, ast_literal: &Literal) -> Value {
    match ast_literal {
        Literal::U8(byte) | Literal::Byte(byte) => {
            Constant::get_uint(context, 8, u64::from(*byte))
        }
        Literal::U16(half) => Constant::get_uint(context, 16, u64::from(*half)),
        Literal::U32(word) => Constant::get_uint(context, 32, u64::from(*word)),
        Literal::U64(dword) => Constant::get_uint(context, 64, *dword),
        Literal::String(text) => Constant::get_string(context, text.as_str().to_owned()),
        Literal::Boolean(flag) => Constant::get_bool(context, *flag),
        Literal::B256(bytes) => Constant::get_b256(context, *bytes),
    }
}
/// Build a standalone IR `Constant` (not yet interned in a context) for a Sway literal.
///
/// This mirrors `convert_literal_to_value`: `u8` and `byte` both become 8-bit unsigned
/// integers.
fn convert_literal_to_constant(ast_literal: &Literal) -> Constant {
    match ast_literal {
        Literal::U8(byte) | Literal::Byte(byte) => Constant::new_uint(8, u64::from(*byte)),
        Literal::U16(half) => Constant::new_uint(16, u64::from(*half)),
        Literal::U32(word) => Constant::new_uint(32, u64::from(*word)),
        Literal::U64(dword) => Constant::new_uint(64, *dword),
        Literal::String(text) => Constant::new_string(text.as_str().to_owned()),
        Literal::Boolean(flag) => Constant::new_bool(*flag),
        Literal::B256(bytes) => Constant::new_b256(*bytes),
    }
}
/// Resolve `ast_type` with the type engine and convert the result to an IR type.
///
/// `span` is passed along to `resolve_type`. A resolution error is rendered with `{:?}` into
/// the returned `String`.
fn convert_resolved_typeid(
    context: &mut Context,
    ast_type: &TypeId,
    span: &Span,
) -> Result<Type, String> {
    let resolved = match resolve_type(*ast_type, span) {
        Ok(resolved) => resolved,
        Err(ty_err) => return Err(format!("{:?}", ty_err)),
    };
    convert_resolved_type(context, &resolved)
}
/// Like `convert_resolved_typeid`, for callers that have no source span.
///
/// A placeholder span is used: zero length, at offset 0 of the string `" "`.
fn convert_resolved_typeid_no_span(
    context: &mut Context,
    ast_type: &TypeId,
) -> Result<Type, String> {
    convert_resolved_typeid(
        context,
        ast_type,
        &crate::span::Span {
            span: pest::Span::new(" ".into(), 0, 0).unwrap(),
            path: None,
        },
    )
}
fn convert_resolved_type(context: &mut Context, ast_type: &TypeInfo) -> Result<Type, String> {
Ok(match ast_type {
TypeInfo::UnsignedInteger(nbits) => {
let nbits = match nbits {
IntegerBits::Eight => 8,
IntegerBits::Sixteen => 16,
IntegerBits::ThirtyTwo => 32,
IntegerBits::SixtyFour => 64,
};
Type::Uint(nbits)
}
TypeInfo::Boolean => Type::Bool,
TypeInfo::Byte => Type::Uint(8), TypeInfo::B256 => Type::B256,
TypeInfo::Str(n) => Type::String(*n),
TypeInfo::Struct { name, fields } => match context.get_aggregate_by_name(name) {
Some(existing_aggregate) => Type::Struct(existing_aggregate),
None => {
create_struct_aggregate(context, name.clone(), fields.clone()).map(&Type::Struct)?
}
},
TypeInfo::Enum {
name,
variant_types,
} => {
match context.get_aggregate_by_name(name) {
Some(existing_aggregate) => Type::Struct(existing_aggregate),
None => {
create_enum_aggregate(context, name.clone(), variant_types.clone())
.map(&Type::Struct)?
}
}
}
TypeInfo::Array(elem_type_id, count) => {
let elem_type = convert_resolved_typeid_no_span(context, elem_type_id)?;
Type::Array(Aggregate::new_array(context, elem_type, *count as u64))
}
TypeInfo::Tuple(fields) => {
if fields.is_empty() {
Type::Unit
} else {
create_tuple_aggregate(context, fields.clone()).map(Type::Struct)?
}
}
TypeInfo::Custom { .. } => return Err("can't do custom types yet".into()),
TypeInfo::SelfType { .. } => return Err("can't do self types yet".into()),
TypeInfo::Contract => Type::Contract,
TypeInfo::ContractCaller { abi_name, address } => Type::ContractCaller(AbiInstance::new(
context,
abi_name.prefixes.clone(),
abi_name.suffix.clone(),
address.clone(),
)),
TypeInfo::Unknown => return Err("unknown type found in AST..?".into()),
TypeInfo::UnknownGeneric { .. } => return Err("unknowngeneric type found in AST..?".into()),
TypeInfo::Numeric => return Err("'numeric' type found in AST..?".into()),
TypeInfo::Ref(_) => return Err("ref type found in AST..?".into()),
TypeInfo::ErrorRecovery => return Err("error recovery type found in AST..?".into()),
})
}
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use crate::{
        control_flow_analysis::{ControlFlowGraph, Graph},
        parser::{Rule, SwayParser},
        semantic_analysis::{TreeType, TypedParseTree},
    };
    use pest::Parser;

    /// Compiles every `tests/sway_to_ir/*.sw` file to IR and compares the printed IR with the
    /// sibling `.ir` file.
    ///
    /// `.ir` and `.disabled` files are skipped; any other file fails the test.
    #[test]
    fn sway_to_ir_tests() {
        let manifest_dir = env!("CARGO_MANIFEST_DIR");
        let dir: PathBuf = format!("{}/tests/sway_to_ir", manifest_dir).into();
        for entry in std::fs::read_dir(dir).unwrap() {
            let path = entry.unwrap().path();
            // A missing or non-UTF-8 extension reaches the descriptive panic below instead of
            // an anonymous `unwrap()` failure.
            match path.extension().and_then(|ext| ext.to_str()) {
                Some("sw") => {
                    println!("---- Sway To IR: {:?} ----", path);
                    test_sway_to_ir(path);
                }
                Some("ir") | Some("disabled") => (),
                _ => panic!(
                    "File with invalid extension in tests dir: {:?}",
                    path.file_name().unwrap_or(path.as_os_str())
                ),
            }
        }
    }

    /// Compiles the Sway source at `path` and asserts that the printed IR equals the `.ir`
    /// file with the same stem. On mismatch, a line diff is printed first.
    fn test_sway_to_ir(mut path: PathBuf) {
        let input_bytes = std::fs::read(&path).unwrap();
        let input = String::from_utf8_lossy(&input_bytes);
        path.set_extension("ir");
        let expected_bytes = std::fs::read(&path).unwrap();
        let expected = String::from_utf8_lossy(&expected_bytes);
        let typed_ast = parse_to_typed_ast(&input);
        let ir = super::compile_ast(typed_ast).unwrap();
        let output = sway_ir::printer::to_string(&ir);
        if output != expected {
            println!("{}", prettydiff::diff_lines(&expected, &output));
        }
        assert_eq!(output, expected);
    }

    /// Parses every `tests/sway_to_ir/*.ir` file and checks that printing it reproduces the
    /// input exactly (a print/parse round trip).
    #[test]
    fn ir_printer_parser_tests() {
        let manifest_dir = env!("CARGO_MANIFEST_DIR");
        let dir: PathBuf = format!("{}/tests/sway_to_ir", manifest_dir).into();
        for entry in std::fs::read_dir(dir).unwrap() {
            let path = entry.unwrap().path();
            match path.extension().and_then(|ext| ext.to_str()) {
                Some("ir") => {
                    println!("---- IR Print and Parse Test: {:?} ----", path);
                    test_printer_parser(path);
                }
                Some("sw") | Some("disabled") => (),
                _ => panic!(
                    "File with invalid extension in tests dir: {:?}",
                    path.file_name().unwrap_or(path.as_os_str())
                ),
            }
        }
    }

    /// Parses the IR file at `path`, prints it back, and asserts the text is unchanged.
    fn test_printer_parser(path: PathBuf) {
        let input_bytes = std::fs::read(&path).unwrap();
        let input = String::from_utf8_lossy(&input_bytes);
        let parsed_ctx = match sway_ir::parser::parse(&input) {
            Ok(p) => p,
            Err(e) => panic!("{}: {}", path.display(), e),
        };
        let printed = sway_ir::printer::to_string(&parsed_ctx);
        if printed != input {
            println!("{}", prettydiff::diff_lines(&input, &printed));
        }
        assert_eq!(input, printed);
    }

    /// Parses and type-checks `input` as a script and returns the typed parse tree.
    ///
    /// Warnings and errors are collected into local vectors but not inspected. The
    /// `unwrap(&mut warnings, &mut errors)` calls are presumed to panic when no value was
    /// produced — TODO confirm.
    fn parse_to_typed_ast(input: &str) -> TypedParseTree {
        let mut parsed =
            SwayParser::parse(Rule::program, std::sync::Arc::from(input)).expect("parse_tree");
        let mut warnings = vec![];
        let mut errors = vec![];
        let parse_tree = crate::parse_root_from_pairs(parsed.next().unwrap().into_inner(), None)
            .unwrap(&mut warnings, &mut errors);
        let mut dead_code_graph = ControlFlowGraph {
            graph: Graph::new(),
            entry_points: vec![],
            namespace: Default::default(),
        };
        let build_config = crate::build_config::BuildConfig {
            file_name: std::sync::Arc::new("test.sw".into()),
            dir_of_code: std::sync::Arc::new("tests".into()),
            manifest_path: std::sync::Arc::new(".".into()),
            use_ir: false,
            print_intermediate_asm: false,
            print_finalized_asm: false,
            print_ir: false,
        };
        TypedParseTree::type_check(
            parse_tree.tree,
            Default::default(),
            &TreeType::Script,
            &build_config,
            &mut dead_code_graph,
            &mut std::collections::HashMap::new(),
        )
        .unwrap(&mut warnings, &mut errors)
    }
}