use crate::error::Error::{InvalidOperation, UnknownId};
use crate::ir::function::FunctionModifier;
use crate::ir::id::{FunctionID, ImportsID, LocalID, TypeID};
use crate::ir::module::side_effects::{InjectType, Injection};
use crate::ir::module::{AsVec, GetID, LocalOrImport};
use crate::ir::types;
use crate::ir::types::{
Body, FuncInstrFlag, HasInjectTag, InjectTag, InstrumentationMode, Tag, TagUtils,
};
use crate::DataType;
use log::warn;
use std::collections::HashMap;
use wasmparser::Operator;
/// A single entry in the module's function table: either a function defined
/// in this module (`FuncKind::Local`) or an imported one (`FuncKind::Import`).
#[derive(Clone, Debug)]
pub struct Function<'a> {
    /// Local definition or import descriptor for this function.
    pub(crate) kind: FuncKind<'a>,
    /// Optional name attached to this function (may be `None`).
    name: Option<String>,
    /// Set by `delete()` and cleared again by `set_kind()`. Deleted functions
    /// stay in the table; this flag only marks them.
    pub(crate) deleted: bool,
}
impl GetID for Function<'_> {
    /// Returns the raw numeric function ID, regardless of whether the
    /// function is local or imported.
    fn get_id(&self) -> u32 {
        let fid = match &self.kind {
            FuncKind::Local(local) => local.func_id,
            FuncKind::Import(import) => import.import_fn_id,
        };
        *fid
    }
}
impl LocalOrImport for Function<'_> {
    /// `true` when the function is defined in this module.
    fn is_local(&self) -> bool {
        match self.kind {
            FuncKind::Local(_) => true,
            FuncKind::Import(_) => false,
        }
    }

    /// `true` when the function is imported from another module.
    fn is_import(&self) -> bool {
        match self.kind {
            FuncKind::Import(_) => true,
            FuncKind::Local(_) => false,
        }
    }

    /// `true` once the function has been marked deleted.
    fn is_deleted(&self) -> bool {
        self.deleted
    }
}
impl<'a> Function<'a> {
pub fn new(kind: FuncKind<'a>, name: Option<String>) -> Self {
Function {
kind,
name,
deleted: false,
}
}
pub fn get_type_id(&self) -> TypeID {
self.kind.get_type()
}
pub(crate) fn set_kind(&mut self, kind: FuncKind<'a>) {
self.kind = kind;
self.deleted = false;
}
pub fn kind(&self) -> &FuncKind<'a> {
&self.kind
}
pub fn unwrap_local(&self) -> types::Result<&LocalFunction<'a>> {
self.kind.unwrap_local()
}
pub fn unwrap_local_mut(&mut self) -> types::Result<&mut LocalFunction<'a>> {
self.kind.unwrap_local_mut()
}
pub(crate) fn delete(&mut self) {
self.deleted = true;
}
}
/// Distinguishes functions defined in this module from imported ones.
#[derive(Clone, Debug)]
pub enum FuncKind<'a> {
    /// A function with a body in this module. Boxed — presumably to keep the
    /// enum small relative to `ImportedFunction`; confirm before unboxing.
    Local(Box<LocalFunction<'a>>),
    /// A function provided by an import.
    Import(ImportedFunction),
}
impl<'a> FuncKind<'a> {
    /// Error returned when a caller expects a local function but holds an import.
    fn imported_as_local_err<T>() -> types::Result<T> {
        Err(InvalidOperation(
            "Attempting to unwrap an imported function as a local!!".to_string(),
        ))
    }

    /// Borrows the local definition.
    ///
    /// # Errors
    /// Returns `InvalidOperation` for `FuncKind::Import`.
    pub fn unwrap_local(&self) -> types::Result<&LocalFunction<'a>> {
        if let FuncKind::Local(local) = self {
            Ok(local)
        } else {
            Self::imported_as_local_err()
        }
    }

    /// Mutably borrows the local definition.
    ///
    /// # Errors
    /// Returns `InvalidOperation` for `FuncKind::Import`.
    pub fn unwrap_local_mut(&mut self) -> types::Result<&mut LocalFunction<'a>> {
        if let FuncKind::Local(local) = self {
            Ok(local)
        } else {
            Self::imported_as_local_err()
        }
    }

    /// Returns the type ID of the function's signature, for either variant.
    pub fn get_type(&self) -> TypeID {
        match self {
            FuncKind::Import(import) => import.ty_id,
            FuncKind::Local(local) => local.ty_id,
        }
    }
}
impl PartialEq for FuncKind<'_> {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(FuncKind::Import(i1), FuncKind::Import(i2)) => i1.ty_id == i2.ty_id,
(FuncKind::Local(l1), FuncKind::Local(l2)) => l1.ty_id == l2.ty_id,
_ => false,
}
}
}
impl Eq for FuncKind<'_> {}
/// A function defined (with a body) in this module.
#[derive(Clone, Debug)]
pub struct LocalFunction<'a> {
    /// Type ID of the function's signature.
    pub ty_id: TypeID,
    /// This function's ID in the function table.
    pub func_id: FunctionID,
    /// Instrumentation state; while its `current_mode` is set, `add_instr` and
    /// related methods route instructions here instead of into `body`.
    pub instr_flag: FuncInstrFlag<'a>,
    /// Locals, instructions and (optional) name of the function.
    pub body: Body<'a>,
    /// Local IDs of the parameters (`0..num_args` as assigned by `new`).
    pub args: Vec<LocalID>,
    /// Optional injection tag attached to this function.
    tag: InjectTag,
}
impl TagUtils for LocalFunction<'_> {
fn get_or_create_tag(&mut self) -> &mut Tag {
self.tag.get_or_insert_default()
}
fn get_tag(&self) -> &Option<Tag> {
&self.tag
}
}
impl HasInjectTag for LocalFunction<'_> {}
impl<'a> LocalFunction<'a> {
    /// Creates a local function with the given type, ID, body and injection tag.
    ///
    /// The first `num_args` local indices (`0..num_args`) are recorded as the
    /// function's parameters, and the instrumentation flag starts at its default.
    pub fn new(
        type_id: TypeID,
        function_id: FunctionID,
        body: Body<'a>,
        num_args: usize,
        tag: InjectTag,
    ) -> Self {
        // Parameters occupy the first locals, numbered consecutively from 0.
        // `LocalID` wraps a u32; `num_args` is assumed to fit.
        let args = (0..num_args).map(|arg| LocalID(arg as u32)).collect();
        LocalFunction {
            ty_id: type_id,
            func_id: function_id,
            instr_flag: FuncInstrFlag::default(),
            body,
            args,
            tag,
        }
    }

    /// Declares a new local of type `ty` and returns its ID.
    ///
    /// The ID is numbered after all parameters and previously declared locals.
    pub fn add_local(&mut self, ty: DataType) -> LocalID {
        add_local(
            ty,
            self.args.len(),
            &mut self.body.num_locals,
            &mut self.body.locals,
        )
    }

    /// Adds `instr` at `instr_idx`.
    ///
    /// While an instrumentation mode is active (`instr_flag.current_mode` is
    /// set), the instruction goes to the instrumentation flag; otherwise it is
    /// added to the body, and the body's "special instruction" result is
    /// accumulated into `instr_flag.has_special_instr`.
    pub fn add_instr(&mut self, instr: Operator<'a>, instr_idx: usize) {
        if self.instr_flag.current_mode.is_some() {
            self.instr_flag.add_instr(instr);
        } else {
            let is_special = self.body.instructions.add_instr(instr_idx, instr);
            self.instr_flag.has_special_instr |= is_special;
        }
    }

    /// Returns the instruction length reported by the instrumentation flag
    /// when a mode is active, otherwise by the body at `instr_idx`.
    pub fn instr_len_at(&self, instr_idx: usize) -> usize {
        if self.instr_flag.current_mode.is_some() {
            self.instr_flag.instr_len()
        } else {
            self.body.instructions.instr_len(instr_idx)
        }
    }

    /// Appends `data` to a tag: the instrumentation flag's tag when a mode is
    /// active, otherwise the tag of the body instruction at `instr_idx`.
    pub fn append_instr_tag_at(&mut self, data: Vec<u8>, instr_idx: usize) {
        if self.instr_flag.current_mode.is_some() {
            self.instr_flag.append_to_tag(data);
        } else {
            self.body.instructions.append_to_tag(instr_idx, data);
        }
    }

    /// Clears the instrumentation for `mode` on the body instruction at
    /// `instr_idx`. Unlike the methods above, this ignores `current_mode`.
    pub fn clear_instr_at(&mut self, instr_idx: usize, mode: InstrumentationMode) {
        self.body.instructions.clear_instr(instr_idx, mode);
    }

    /// Forwards to `instr_flag.add_injections`, passing the function, global
    /// and memory ID mappings and the side-effects map to be extended.
    ///
    /// # Errors
    /// Propagates any error returned by `add_injections`.
    pub(crate) fn add_corrected_special_injections(
        &mut self,
        rel_fid: u32,
        func_mapping: &HashMap<u32, u32>,
        global_mapping: &HashMap<u32, u32>,
        memory_mapping: &HashMap<u32, u32>,
        side_effects: &mut HashMap<InjectType, Vec<Injection<'a>>>,
    ) -> types::Result<()> {
        self.instr_flag.add_injections(
            rel_fid,
            func_mapping,
            global_mapping,
            memory_mapping,
            side_effects,
        )
    }

    /// For every per-instruction flag in the body (if the body has any),
    /// records its injections into `side_effects`, passing the instruction's
    /// index alongside `rel_fid`.
    pub(crate) fn add_opcode_injections(
        &self,
        rel_fid: u32,
        side_effects: &mut HashMap<InjectType, Vec<Injection<'a>>>,
    ) {
        if let Some(flags) = self.body.instructions.get_flags() {
            for (idx, instr_flag) in flags.iter().enumerate() {
                instr_flag.add_injections(rel_fid, idx as u32, side_effects);
            }
        }
    }

    /// Looks up the PC offset for `instr_idx` via the body's instruction list.
    pub fn lookup_pc_offset_for(&self, instr_idx: usize) -> Option<usize> {
        self.body.instructions.lookup_pc_offset_for(instr_idx)
    }
}
/// Declares one more local of type `ty` and returns its ID.
///
/// Locals are stored run-length encoded as `(count, type)` pairs: when `ty`
/// matches the type of the last run, that run is extended; otherwise a new
/// run of length 1 is appended. The returned ID comes after all `num_params`
/// parameters and all previously declared locals, and `num_locals` is bumped.
pub(crate) fn add_local(
    ty: DataType,
    num_params: usize,
    num_locals: &mut u32,
    locals: &mut Vec<(u32, DataType)>,
) -> LocalID {
    let new_index = num_params + *num_locals as usize;
    *num_locals += 1;
    match locals.last_mut() {
        Some(run) if run.1 == ty => run.0 += 1,
        _ => locals.push((1, ty)),
    }
    LocalID(new_index as u32)
}
/// Declares one local per entry of `types`, in order, via `add_local`.
/// The IDs of the new locals are not returned.
pub(crate) fn add_locals(
    types: &[DataType],
    num_params: usize,
    num_locals: &mut u32,
    locals: &mut Vec<(u32, DataType)>,
) {
    types.iter().copied().for_each(|ty| {
        add_local(ty, num_params, num_locals, locals);
    });
}
/// A function provided by an import rather than defined in this module.
#[derive(Clone, Debug)]
pub struct ImportedFunction {
    /// ID of the import entry this function comes from.
    pub import_id: ImportsID,
    /// This function's ID in the function table.
    pub(crate) import_fn_id: FunctionID,
    /// Type ID of the function's signature.
    pub ty_id: TypeID,
}
impl ImportedFunction {
pub fn new(id: ImportsID, type_id: TypeID, function_id: FunctionID) -> Self {
ImportedFunction {
import_id: id,
ty_id: type_id,
import_fn_id: function_id,
}
}
}
/// The module's function table, indexed by `FunctionID`.
#[derive(Clone, Debug, Default)]
pub struct Functions<'a> {
    /// All functions, local and imported; deleted ones remain in place.
    functions: Vec<Function<'a>>,
    /// Set by `delete`, `add_local_func` and `add_import_func`. Presumably tells
    /// consumers that function IDs must be recomputed — confirm at use sites.
    pub(crate) recalculate_ids: bool,
}
impl<'a> Functions<'a> {
    /// Iterates over all functions (including deleted ones) in ID order.
    pub fn iter(&self) -> impl Iterator<Item = &Function<'a>> {
        self.functions.as_slice().iter()
    }

    /// Mutably iterates over all functions (including deleted ones) in ID order.
    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Function<'a>> {
        self.functions.as_mut_slice().iter_mut()
    }
}
impl<'a> AsVec<Function<'a>> for Functions<'a> {
fn as_vec(&self) -> &Vec<Function<'a>> {
&self.functions
}
fn as_vec_mut(&mut self) -> &mut Vec<Function<'a>> {
&mut self.functions
}
}
impl<'a> Functions<'a> {
    /// Creates a function table; a function's position in `functions` is its ID.
    pub fn new(functions: Vec<Function<'a>>) -> Self {
        Functions {
            functions,
            recalculate_ids: false,
        }
    }

    /// Returns the function with `function_id`, or `None` if it is out of range.
    pub fn get_fn_by_id(&self, function_id: FunctionID) -> Option<&Function<'a>> {
        self.functions.get(*function_id as usize)
    }

    /// Returns `true` if the table holds no functions (deleted entries still count).
    pub fn is_empty(&self) -> bool {
        self.functions.is_empty()
    }

    /// Borrows the kind of the function with `function_id`.
    ///
    /// # Panics
    /// Panics if `function_id` is out of range.
    pub fn get_kind(&self, function_id: FunctionID) -> &FuncKind<'a> {
        &self.functions[*function_id as usize].kind
    }

    /// Mutably borrows the kind of the function with `function_id`.
    ///
    /// # Panics
    /// Panics if `function_id` is out of range.
    pub fn get_kind_mut(&mut self, function_id: FunctionID) -> &mut FuncKind<'a> {
        &mut self.functions[*function_id as usize].kind
    }

    /// Returns the optional name of the function with `function_id`.
    ///
    /// # Panics
    /// Panics if `function_id` is out of range.
    pub fn get_name(&self, function_id: FunctionID) -> &Option<String> {
        &self.functions[*function_id as usize].name
    }

    /// Whether the function with `function_id` is defined in this module.
    ///
    /// # Panics
    /// Panics if `function_id` is out of range.
    pub fn is_local(&self, function_id: FunctionID) -> bool {
        self.functions[*function_id as usize].is_local()
    }

    /// Whether the function with `function_id` is imported.
    ///
    /// # Panics
    /// Panics if `function_id` is out of range.
    pub fn is_import(&self, function_id: FunctionID) -> bool {
        self.functions[*function_id as usize].is_import()
    }

    /// Returns the signature type ID of the function with `id`.
    ///
    /// # Panics
    /// Panics if `id` is out of range.
    pub fn get_type_id(&self, id: FunctionID) -> TypeID {
        self.functions[*id as usize].get_type_id()
    }

    /// Whether the function with `function_id` has been marked deleted.
    ///
    /// # Panics
    /// Panics if `function_id` is out of range.
    pub fn is_deleted(&self, function_id: FunctionID) -> bool {
        self.functions[*function_id as usize].is_deleted()
    }

    /// Borrows the function with `function_id`.
    ///
    /// # Panics
    /// Panics if `function_id` is out of range; see `get_fn_by_id` for a
    /// non-panicking lookup.
    pub fn get(&self, function_id: FunctionID) -> &Function<'a> {
        &self.functions[*function_id as usize]
    }

    /// Mutably borrows the function with `function_id`.
    ///
    /// # Panics
    /// Panics if `function_id` is out of range.
    pub fn get_mut(&mut self, function_id: FunctionID) -> &mut Function<'a> {
        &mut self.functions[*function_id as usize]
    }

    /// Borrows the local definition of the function with `function_id`.
    ///
    /// # Errors
    /// Returns `InvalidOperation` if the function is imported.
    ///
    /// # Panics
    /// Panics if `function_id` is out of range.
    pub fn unwrap_local(&self, function_id: FunctionID) -> types::Result<&LocalFunction<'a>> {
        self.functions[*function_id as usize].unwrap_local()
    }

    /// Mutably borrows the local definition of the function with `function_id`.
    ///
    /// # Errors
    /// Returns `InvalidOperation` if the function is imported.
    ///
    /// # Panics
    /// Panics if `function_id` is out of range.
    pub fn unwrap_local_mut(
        &mut self,
        function_id: FunctionID,
    ) -> types::Result<&mut LocalFunction<'a>> {
        self.functions[*function_id as usize].unwrap_local_mut()
    }

    /// Returns the ID of the first local function whose body name equals `name`.
    ///
    /// Imported functions are never matched; deleted functions are not skipped.
    pub fn get_local_fid_by_name(&self, name: &str) -> Option<FunctionID> {
        self.functions
            .iter()
            .position(|func| {
                matches!(&func.kind, FuncKind::Local(l) if l.body.name.as_deref() == Some(name))
            })
            .map(|idx| FunctionID(idx as u32))
    }

    /// Returns a `FunctionModifier` over the local function with `func_id`.
    ///
    /// `finish_instr` is called on the function's instrumentation flag before
    /// the modifier is created.
    ///
    /// # Errors
    /// Returns `UnknownId` if `func_id` is out of range, or `InvalidOperation`
    /// if the function is imported.
    pub fn get_fn_modifier<'b>(
        &'b mut self,
        func_id: FunctionID,
    ) -> types::Result<FunctionModifier<'b, 'a>> {
        let Some(func) = self.functions.get_mut(*func_id as usize) else {
            return Err(UnknownId(format!(
                "Could not find function with ID: {func_id:?}"
            )));
        };
        match &mut func.kind {
            FuncKind::Local(l) => {
                l.instr_flag.finish_instr();
                Ok(FunctionModifier::init(
                    &mut l.instr_flag,
                    &mut l.body,
                    &mut l.args,
                ))
            }
            FuncKind::Import(_) => Err(InvalidOperation(
                "Cannot modify a non-local function".to_string(),
            )),
        }
    }

    /// Marks the function with `id` as deleted (a no-op if `id` is out of
    /// range). The entry stays in the table. `recalculate_ids` is set in
    /// either case.
    pub(crate) fn delete(&mut self, id: FunctionID) {
        self.recalculate_ids = true;
        if let Some(func) = self.functions.get_mut(*id as usize) {
            func.delete();
        }
    }

    /// The ID the next appended function will receive.
    fn next_id(&self) -> FunctionID {
        FunctionID(self.functions.len() as u32)
    }

    /// Appends `local_function` under the next free ID and returns that ID.
    ///
    /// The function's `func_id` is overwritten with the new ID. When `name` is
    /// given, it is stored both on the function and in its body.
    pub(crate) fn add_local_func(
        &mut self,
        mut local_function: LocalFunction<'a>,
        name: Option<String>,
    ) -> FunctionID {
        self.recalculate_ids = true;
        let id = self.next_id();
        local_function.func_id = id;
        if let Some(name) = &name {
            // Keep the body's name in sync with the function-level name.
            local_function.body.name = Some(name.clone());
        }
        self.functions.push(Function::new(
            FuncKind::Local(Box::new(local_function)),
            name,
        ));
        id
    }

    /// Appends an imported function. `imp_fn_id` must equal the next free ID
    /// (checked only in debug builds).
    pub(crate) fn add_import_func(
        &mut self,
        imp_id: ImportsID,
        ty_id: TypeID,
        name: Option<String>,
        imp_fn_id: u32,
    ) {
        self.recalculate_ids = true;
        debug_assert_eq!(*self.next_id(), imp_fn_id);
        self.functions.push(Function::new(
            FuncKind::Import(ImportedFunction::new(imp_id, ty_id, FunctionID(imp_fn_id))),
            name,
        ));
    }

    /// Declares a new local of type `ty` in the function with `func_idx`.
    ///
    /// # Errors
    /// Returns `InvalidOperation` if the function is imported.
    ///
    /// # Panics
    /// Panics if `func_idx` is out of range.
    pub(crate) fn add_local(
        &mut self,
        func_idx: FunctionID,
        ty: DataType,
    ) -> types::Result<LocalID> {
        let local_func = self.functions[*func_idx as usize].unwrap_local_mut()?;
        Ok(local_func.add_local(ty))
    }

    /// Names the local function with `func_idx`, on both the function and its
    /// body. Returns `false` (and logs a warning) if the function is imported.
    ///
    /// # Panics
    /// Panics if `func_idx` is out of range.
    pub fn set_local_fn_name(&mut self, func_idx: FunctionID, name: String) -> bool {
        let func = &mut self.functions[*func_idx as usize];
        match &mut func.kind {
            FuncKind::Import(_) => {
                warn!("is an imported function!");
                return false;
            }
            FuncKind::Local(l) => l.body.name = Some(name.clone()),
        }
        func.name = Some(name);
        true
    }

    /// Names the imported function with `func_idx`. Returns `false` (and logs
    /// a warning) if the function is local.
    ///
    /// # Panics
    /// Panics if `func_idx` is out of range.
    pub(crate) fn set_imported_fn_name(&mut self, func_idx: FunctionID, name: String) -> bool {
        let func = &mut self.functions[*func_idx as usize];
        if func.is_local() {
            warn!("is a local function!");
            return false;
        }
        func.name = Some(name);
        true
    }
}