use casper_wasm::elements::{
    self, External, Instruction, Internal, MemorySection, Module, Section, TableType, Type,
};
use casper_wasm_utils::{self, stack_height};
use thiserror::Error;
use super::wasm_config::WasmConfig;
use crate::core::execution;
/// Module name handed to `casper_wasm_utils::inject_gas_counter` in [`preprocess`]; also the
/// module half of the import that `ensure_valid_imports` rejects.
const DEFAULT_GAS_MODULE_NAME: &str = "env";
/// Field name which, combined with [`DEFAULT_GAS_MODULE_NAME`], identifies the import that
/// `ensure_valid_imports` rejects.
const INTERNAL_GAS_FUNCTION_NAME: &str = "gas";
/// Table size limit that [`preprocess`] applies to both the initial and the maximum size of a
/// table.
pub const DEFAULT_MAX_TABLE_SIZE: u32 = 4096;
/// Largest number of entries that [`preprocess`] accepts in the table of a single `br_table`
/// instruction.
pub const DEFAULT_BR_TABLE_MAX_SIZE: u32 = 256;
/// Largest number of global section entries that [`preprocess`] accepts.
pub const DEFAULT_MAX_GLOBALS: u32 = 256;
/// Largest number of parameters that [`preprocess`] accepts on any function type.
pub const DEFAULT_MAX_PARAMETER_COUNT: u32 = 256;
/// Reasons a deserialized Wasm module is rejected by the validation checks in this file.
#[derive(Debug, Clone, Error)]
#[non_exhaustive]
pub enum WasmValidationError {
    /// The initial size declared by the table exceeds the limit.
    #[error("initial table size of {actual} exceeds allowed limit of {max}")]
    InitialTableSizeExceeded {
        /// Allowed limit.
        max: u32,
        /// Initial size declared by the module.
        actual: u32,
    },
    /// The maximum size declared by the table exceeds the limit.
    #[error("maximum table size of {actual} exceeds allowed limit of {max}")]
    MaxTableSizeExceeded {
        /// Allowed limit.
        max: u32,
        /// Maximum size declared by the module.
        actual: u32,
    },
    /// The table section holds more than one entry.
    #[error("the number of tables must be at most one")]
    MoreThanOneTable,
    /// A `br_table` instruction carries more table entries than allowed.
    #[error("maximum br_table size of {actual} exceeds allowed limit of {max}")]
    BrTableSizeExceeded {
        /// Allowed limit.
        max: u32,
        /// Number of entries in the offending `br_table`.
        actual: usize,
    },
    /// The global section holds more entries than allowed.
    #[error("declared number of globals ({actual}) exceeds allowed limit of {max}")]
    TooManyGlobals {
        /// Allowed limit.
        max: u32,
        /// Number of entries in the global section.
        actual: usize,
    },
    /// A function type in the type section declares more parameters than allowed.
    #[error("use of a function type with too many parameters (limit of {max} but function declares {actual})")]
    TooManyParameters {
        /// Allowed limit.
        max: u32,
        /// Number of parameters declared by the offending function type.
        actual: usize,
    },
    /// Returned by `ensure_valid_imports` when the module imports the gas function
    /// (`DEFAULT_GAS_MODULE_NAME` / `INTERNAL_GAS_FUNCTION_NAME`).
    #[error("module imports a non-existent function")]
    MissingHostFunction,
    /// A `GetGlobal` / `SetGlobal` instruction uses an index that is not below the number of
    /// global section entries.
    #[error("opcode for a global access refers to non-existing global index {index}")]
    IncorrectGlobalOperation {
        /// The offending global index.
        index: u32,
    },
    /// A function index (start section, export, `Call` instruction or element segment member) is
    /// not below the number of functions counted in the module.
    #[error("missing function index {index}")]
    MissingFunctionIndex {
        /// The offending function index.
        index: u32,
    },
    /// An imported or declared function refers to a type index that is not below the number of
    /// entries in the type section.
    #[error("missing type index {index}")]
    MissingFunctionType {
        /// The offending type index.
        index: u32,
    },
}
/// Errors reported while deserializing and preprocessing a Wasm module.
#[derive(Debug, Clone, Error)]
#[non_exhaustive]
pub enum PreprocessingError {
    /// The module bytes could not be deserialized; carries the string form of the underlying
    /// `elements::Error`.
    #[error("Deserialization error: {0}")]
    Deserialize(String),
    /// `casper_wasm_utils::inject_gas_counter` failed while instrumenting the module.
    #[error(
        "Encountered operation forbidden by gas rules. Consult instruction -> metering config map"
    )]
    OperationForbiddenByGasRules,
    /// `stack_height::inject_limiter` failed while instrumenting the module.
    #[error("Stack limiter error")]
    StackLimiter,
    /// The module has no memory section, or its first memory section has no entries.
    #[error("Memory section should exist")]
    MissingMemorySection,
    /// Missing module. Not constructed anywhere in the visible part of this file.
    #[error("Missing module")]
    MissingModule,
    /// One of the validation checks rejected the module.
    #[error("Wasm validation error: {0}")]
    WasmValidation(#[from] WasmValidationError),
}
/// Lets `?` turn a Wasm deserialization failure into a [`PreprocessingError`].
impl From<elements::Error> for PreprocessingError {
    /// Stores the error's `Display` output in [`PreprocessingError::Deserialize`].
    fn from(error: elements::Error) -> Self {
        let message = error.to_string();
        Self::Deserialize(message)
    }
}
/// Verifies that every type, function and global index used by `module` is in range.
///
/// The function count is the number of function imports plus the number of function section
/// entries. Each of those must reference a type index below the length of the type section.
/// Against that count the following function indices are checked, in this order: the start
/// section, function exports, `Call` instructions and element segment members. `GetGlobal` /
/// `SetGlobal` indices are checked against the number of entries in the global section only.
///
/// # Errors
///
/// * [`WasmValidationError::MissingFunctionType`] for an out-of-range type index.
/// * [`WasmValidationError::MissingFunctionIndex`] for an out-of-range function index.
/// * [`WasmValidationError::IncorrectGlobalOperation`] for an out-of-range global index.
fn ensure_valid_access(module: &Module) -> Result<(), WasmValidationError> {
    let type_count = module
        .type_section()
        .map_or(0, |section| section.types().len());

    // Rejects a signature reference that does not point into the type section.
    let check_type_index = |type_index: u32| -> Result<(), WasmValidationError> {
        if (type_index as usize) < type_count {
            Ok(())
        } else {
            Err(WasmValidationError::MissingFunctionType { index: type_index })
        }
    };

    // Imported functions are counted first, followed by the module's own functions.
    let mut total_functions = 0_u32;
    for import_entry in module
        .import_section()
        .into_iter()
        .flat_map(|section| section.entries())
    {
        if let External::Function(type_index) = import_entry.external() {
            check_type_index(*type_index)?;
            total_functions = total_functions.saturating_add(1);
        }
    }
    for function_entry in module
        .function_section()
        .into_iter()
        .flat_map(|section| section.entries())
    {
        check_type_index(function_entry.type_ref())?;
        total_functions = total_functions.saturating_add(1);
    }
    let function_count = total_functions;

    if let Some(start_index) = module.start_section() {
        ensure_valid_function_index(start_index, function_count)?;
    }

    for export_entry in module
        .export_section()
        .into_iter()
        .flat_map(|section| section.entries())
    {
        if let Internal::Function(exported_index) = export_entry.internal() {
            ensure_valid_function_index(*exported_index, function_count)?;
        }
    }

    if let Some(code_section) = module.code_section() {
        let defined_globals = module
            .global_section()
            .map_or(0, |section| section.entries().len());

        for body in code_section.bodies() {
            for instruction in body.code().elements() {
                match instruction {
                    Instruction::Call(callee) => {
                        ensure_valid_function_index(*callee, function_count)?;
                    }
                    Instruction::GetGlobal(global_index)
                    | Instruction::SetGlobal(global_index) => {
                        if *global_index as usize >= defined_globals {
                            return Err(WasmValidationError::IncorrectGlobalOperation {
                                index: *global_index,
                            });
                        }
                    }
                    _ => {}
                }
            }
        }
    }

    module
        .elements_section()
        .into_iter()
        .flat_map(|section| section.entries())
        .flat_map(|segment| segment.members())
        .try_for_each(|&member| ensure_valid_function_index(member, function_count))
}
/// Succeeds when `index` is strictly below `function_count`.
///
/// # Errors
///
/// Returns [`WasmValidationError::MissingFunctionIndex`] carrying `index` otherwise.
fn ensure_valid_function_index(index: u32, function_count: u32) -> Result<(), WasmValidationError> {
    if index < function_count {
        Ok(())
    } else {
        Err(WasmValidationError::MissingFunctionIndex { index })
    }
}
/// Returns the module's memory section, provided it has at least one entry.
///
/// Only the first `Section::Memory` found is considered: if that one has no entries the result
/// is `None`, as it is when the module has no memory section at all.
fn memory_section(module: &Module) -> Option<&MemorySection> {
    module
        .sections()
        .iter()
        .find_map(|section| match section {
            Section::Memory(memory) => Some(memory),
            _ => None,
        })
        .filter(|memory| !memory.entries().is_empty())
}
/// Enforces `limit` on the module's table and hands the (possibly rewritten) module back.
///
/// A module with no table section, or with an empty one, is returned unchanged. When the single
/// table declares no maximum, the entry is replaced by one with the same initial size and
/// `limit` as its maximum.
///
/// # Errors
///
/// * [`WasmValidationError::MoreThanOneTable`] if the table section has more than one entry.
/// * [`WasmValidationError::InitialTableSizeExceeded`] if the initial size is above `limit`.
/// * [`WasmValidationError::MaxTableSizeExceeded`] if the declared maximum is above `limit`.
fn ensure_table_size_limit(mut module: Module, limit: u32) -> Result<Module, WasmValidationError> {
    if let Some(table_section) = module.table_section_mut() {
        match &mut table_section.entries_mut()[..] {
            [] => {}
            [table] => {
                let (initial, declared_maximum) = {
                    let limits = table.limits();
                    (limits.initial(), limits.maximum())
                };
                if initial > limit {
                    return Err(WasmValidationError::InitialTableSizeExceeded {
                        max: limit,
                        actual: initial,
                    });
                }
                match declared_maximum {
                    Some(maximum) if maximum > limit => {
                        return Err(WasmValidationError::MaxTableSizeExceeded {
                            max: limit,
                            actual: maximum,
                        });
                    }
                    Some(_) => {}
                    // No maximum declared: rewrite the entry so the table is capped at `limit`.
                    None => *table = TableType::new(initial, Some(limit)),
                }
            }
            _ => return Err(WasmValidationError::MoreThanOneTable),
        }
    }
    Ok(module)
}
/// Checks that no `br_table` instruction in the code section has more than `limit` table entries.
///
/// A module without a code section passes trivially.
///
/// # Errors
///
/// Returns [`WasmValidationError::BrTableSizeExceeded`] for the first offending instruction.
fn ensure_br_table_size_limit(module: &Module, limit: u32) -> Result<(), WasmValidationError> {
    let bodies = match module.code_section() {
        Some(code_section) => code_section.bodies(),
        None => return Ok(()),
    };
    let max_entries = limit as usize;

    for body in bodies {
        for instruction in body.code().elements() {
            if let Instruction::BrTable(br_table) = instruction {
                let actual = br_table.table.len();
                if actual > max_entries {
                    return Err(WasmValidationError::BrTableSizeExceeded { max: limit, actual });
                }
            }
        }
    }
    Ok(())
}
/// Checks that the global section, if present, has at most `limit` entries.
///
/// # Errors
///
/// Returns [`WasmValidationError::TooManyGlobals`] when the entry count exceeds `limit`.
fn ensure_global_variable_limit(module: &Module, limit: u32) -> Result<(), WasmValidationError> {
    let declared = module
        .global_section()
        .map(|section| section.entries().len());
    match declared {
        Some(actual) if actual > limit as usize => {
            Err(WasmValidationError::TooManyGlobals { max: limit, actual })
        }
        _ => Ok(()),
    }
}
/// Checks that no function type in the type section declares more than `limit` parameters.
///
/// A module without a type section passes trivially.
///
/// # Errors
///
/// Returns [`WasmValidationError::TooManyParameters`] for the first offending function type.
fn ensure_parameter_limit(module: &Module, limit: u32) -> Result<(), WasmValidationError> {
    let max_parameters = limit as usize;
    let first_oversized = module
        .type_section()
        .into_iter()
        .flat_map(|section| section.types())
        .map(|ty| {
            let Type::Function(signature) = ty;
            signature.params().len()
        })
        .find(|&parameter_count| parameter_count > max_parameters);

    match first_oversized {
        Some(actual) => Err(WasmValidationError::TooManyParameters { max: limit, actual }),
        None => Ok(()),
    }
}
/// Rejects modules that import the gas function themselves.
///
/// An import matches when its module name equals `DEFAULT_GAS_MODULE_NAME` and its field name
/// equals `INTERNAL_GAS_FUNCTION_NAME`. A module without an import section passes.
///
/// # Errors
///
/// Returns [`WasmValidationError::MissingHostFunction`] if such an import is present.
fn ensure_valid_imports(module: &Module) -> Result<(), WasmValidationError> {
    let imports_gas_function = module
        .import_section()
        .into_iter()
        .flat_map(|section| section.entries())
        .any(|entry| {
            entry.module() == DEFAULT_GAS_MODULE_NAME
                && entry.field() == INTERNAL_GAS_FUNCTION_NAME
        });

    if imports_gas_function {
        Err(WasmValidationError::MissingHostFunction)
    } else {
        Ok(())
    }
}
/// Deserializes, validates and instruments a Wasm module.
///
/// Steps, in order:
/// 1. deserialize `module_bytes`;
/// 2. validate type, function and global indices;
/// 3. require a non-empty memory section;
/// 4. enforce the table, `br_table`, global and parameter limits (the `DEFAULT_*` constants) and
///    reject an import of the gas function;
/// 5. externalize memory using `wasm_config.max_memory`;
/// 6. inject the gas counter using `wasm_config.opcode_costs()`;
/// 7. inject the stack height limiter using `wasm_config.max_stack_height`.
///
/// # Errors
///
/// * [`PreprocessingError::Deserialize`] if the bytes cannot be deserialized.
/// * [`PreprocessingError::WasmValidation`] if any validation check fails.
/// * [`PreprocessingError::MissingMemorySection`] if no non-empty memory section exists.
/// * [`PreprocessingError::OperationForbiddenByGasRules`] if gas counter injection fails.
/// * [`PreprocessingError::StackLimiter`] if stack limiter injection fails.
pub fn preprocess(
    wasm_config: WasmConfig,
    module_bytes: &[u8],
) -> Result<Module, PreprocessingError> {
    let parsed = deserialize(module_bytes)?;

    ensure_valid_access(&parsed)?;

    // NOTE(review): presumably checked up front because `externalize_mem` below needs a memory
    // entry (see the `should_not_panic_on_empty_memory` test) — confirm.
    memory_section(&parsed).ok_or(PreprocessingError::MissingMemorySection)?;

    let validated = ensure_table_size_limit(parsed, DEFAULT_MAX_TABLE_SIZE)?;
    ensure_br_table_size_limit(&validated, DEFAULT_BR_TABLE_MAX_SIZE)?;
    ensure_global_variable_limit(&validated, DEFAULT_MAX_GLOBALS)?;
    ensure_parameter_limit(&validated, DEFAULT_MAX_PARAMETER_COUNT)?;
    ensure_valid_imports(&validated)?;

    let externalized =
        casper_wasm_utils::externalize_mem(validated, None, wasm_config.max_memory);

    let opcode_costs = wasm_config.opcode_costs();
    let metered = casper_wasm_utils::inject_gas_counter(
        externalized,
        &opcode_costs,
        DEFAULT_GAS_MODULE_NAME,
    )
    .map_err(|_| PreprocessingError::OperationForbiddenByGasRules)?;

    stack_height::inject_limiter(metered, wasm_config.max_stack_height)
        .map_err(|_| PreprocessingError::StackLimiter)
}
/// Deserializes `module_bytes` into a [`Module`] without any validation or instrumentation.
///
/// # Errors
///
/// Returns [`PreprocessingError::Deserialize`] if `casper_wasm` cannot deserialize the buffer.
pub fn deserialize(module_bytes: &[u8]) -> Result<Module, PreprocessingError> {
    let module = casper_wasm::deserialize_buffer::<Module>(module_bytes)?;
    Ok(module)
}
pub fn get_module_from_entry_points(
    entry_point_names: Vec<&str>,
    mut module: Module,
) -> Result<Vec<u8>, execution::Error> {
    let export_section = module.export_section().ok_or_else(|| {
        execution::Error::FunctionNotFound(String::from("Missing Export Section"))
    })?;
    let maybe_missing_name: Option<String> = entry_point_names
        .iter()
        .find(|name| {
            !export_section
                .entries()
                .iter()
                .any(|export_entry| export_entry.field() == **name)
        })
        .map(|s| String::from(*s));
    match maybe_missing_name {
        Some(missing_name) => Err(execution::Error::FunctionNotFound(missing_name)),
        None => {
            casper_wasm_utils::optimize(&mut module, entry_point_names)?;
            casper_wasm::serialize(module).map_err(execution::Error::ParityWasm)
        }
    }
}
// Tests for `preprocess`: each one feeds it a malformed or unsupported module and asserts the
// exact error variant that comes back.
#[cfg(test)]
mod tests {
    use casper_types::contracts::DEFAULT_ENTRY_POINT_NAME;
    use casper_wasm::{
        builder,
        elements::{CodeSection, Instructions},
    };
    use walrus::{FunctionBuilder, ModuleConfig, ValType};
    use super::*;
    /// A memory section with no entries must be reported as `MissingMemorySection`.
    #[test]
    fn should_not_panic_on_empty_memory() {
        // Raw module bytes. The memory section (id 0x05) is present but declares zero entries
        // (`0x05, 0x01, 0x00`).
        const MODULE_BYTES_WITH_EMPTY_MEMORY: [u8; 61] = [
            0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, 0x01, 0x09, 0x02, 0x60, 0x01, 0x7f,
            0x01, 0x7f, 0x60, 0x00, 0x00, 0x03, 0x03, 0x02, 0x00, 0x01, 0x05, 0x01, 0x00, 0x08,
            0x01, 0x01, 0x0a, 0x1d, 0x02, 0x18, 0x00, 0x20, 0x00, 0x41, 0x80, 0x80, 0x82, 0x80,
            0x78, 0x70, 0x41, 0x80, 0x82, 0x80, 0x80, 0x7e, 0x4f, 0x22, 0x00, 0x1a, 0x20, 0x00,
            0x0f, 0x0b, 0x02, 0x00, 0x0b,
        ];
        match preprocess(WasmConfig::default(), &MODULE_BYTES_WITH_EMPTY_MEMORY).unwrap_err() {
            PreprocessingError::MissingMemorySection => (),
            error => panic!("expected MissingMemorySection, got {:?}", error),
        }
    }
    /// An export that points at function index `u32::MAX` must be rejected with
    /// `MissingFunctionIndex`.
    #[test]
    fn should_not_overflow_in_export_section() {
        let module = builder::module()
            .function()
            .signature()
            .build()
            .body()
            .with_instructions(Instructions::new(vec![Instruction::Nop, Instruction::End]))
            .build()
            .build()
            .export()
            .field(DEFAULT_ENTRY_POINT_NAME)
            .internal()
            // Export refers to a function index that does not exist.
            .func(u32::MAX)
            .build()
            // `preprocess` rejects modules without a memory section, so one is always added.
            .memory()
            .build()
            .build();
        let module_bytes = casper_wasm::serialize(module).expect("should serialize");
        let error = preprocess(WasmConfig::default(), &module_bytes)
            .expect_err("should fail with an error");
        assert!(
            matches!(
                &error,
                PreprocessingError::WasmValidation(WasmValidationError::MissingFunctionIndex { index: missing_index })
                if *missing_index == u32::MAX
            ),
            "{:?}",
            error,
        );
    }
    /// An element segment listing function index `u32::MAX` must be rejected with
    /// `MissingFunctionIndex`, even though the export itself is valid.
    #[test]
    fn should_not_overflow_in_element_section() {
        const CALL_FN_IDX: u32 = 0;
        let module = builder::module()
            .function()
            .signature()
            .build()
            .body()
            .with_instructions(Instructions::new(vec![Instruction::Nop, Instruction::End]))
            .build()
            .build()
            .export()
            .field(DEFAULT_ENTRY_POINT_NAME)
            .internal()
            .func(CALL_FN_IDX)
            .build()
            .table()
            // The segment's only member is a function index that does not exist.
            .with_element(u32::MAX, vec![u32::MAX])
            .build()
            // `preprocess` rejects modules without a memory section, so one is always added.
            .memory()
            .build()
            .build();
        let module_bytes = casper_wasm::serialize(module).expect("should serialize");
        let error = preprocess(WasmConfig::default(), &module_bytes)
            .expect_err("should fail with an error");
        assert!(
            matches!(
                &error,
                PreprocessingError::WasmValidation(WasmValidationError::MissingFunctionIndex { index: missing_index })
                if *missing_index == u32::MAX
            ),
            "{:?}",
            error,
        );
    }
    /// A `Call` instruction targeting function index `u32::MAX` must be rejected with
    /// `MissingFunctionIndex`.
    #[test]
    fn should_not_overflow_in_call_opcode() {
        let module = builder::module()
            .function()
            .signature()
            .build()
            .body()
            .with_instructions(Instructions::new(vec![
                // Call to a function index that does not exist.
                Instruction::Call(u32::MAX),
                Instruction::End,
            ]))
            .build()
            .build()
            .export()
            .field(DEFAULT_ENTRY_POINT_NAME)
            .build()
            // `preprocess` rejects modules without a memory section, so one is always added.
            .memory()
            .build()
            .build();
        let module_bytes = casper_wasm::serialize(module).expect("should serialize");
        let error = preprocess(WasmConfig::default(), &module_bytes)
            .expect_err("should fail with an error");
        assert!(
            matches!(
                &error,
                PreprocessingError::WasmValidation(WasmValidationError::MissingFunctionIndex { index: missing_index })
                if *missing_index == u32::MAX
            ),
            "{:?}",
            error,
        );
    }
    /// A start section naming function index `u32::MAX` must be rejected with
    /// `MissingFunctionIndex` when the module has no code section at all.
    #[test]
    fn should_not_overflow_in_start_section_without_code_section() {
        let module = builder::module()
            .with_section(Section::Start(u32::MAX))
            // `preprocess` rejects modules without a memory section, so one is always added.
            .memory()
            .build()
            .build();
        let module_bytes = casper_wasm::serialize(module).expect("should serialize");
        let error = preprocess(WasmConfig::default(), &module_bytes)
            .expect_err("should fail with an error");
        assert!(
            matches!(
                &error,
                PreprocessingError::WasmValidation(WasmValidationError::MissingFunctionIndex { index: missing_index })
                if *missing_index == u32::MAX
            ),
            "{:?}",
            error,
        );
    }
    /// Same as the previous test, but with a code section present that contains no bodies.
    #[test]
    fn should_not_overflow_in_start_section_with_code() {
        let module = builder::module()
            .with_section(Section::Start(u32::MAX))
            .with_section(Section::Code(CodeSection::with_bodies(Vec::new())))
            // `preprocess` rejects modules without a memory section, so one is always added.
            .memory()
            .build()
            .build();
        let module_bytes = casper_wasm::serialize(module).expect("should serialize");
        let error = preprocess(WasmConfig::default(), &module_bytes)
            .expect_err("should fail with an error");
        assert!(
            matches!(
                &error,
                PreprocessingError::WasmValidation(WasmValidationError::MissingFunctionIndex { index: missing_index })
                if *missing_index == u32::MAX
            ),
            "{:?}",
            error,
        );
    }
    /// A module containing a function with two result types must fail at deserialization with
    /// the multi-value error message.
    #[test]
    fn should_not_accept_multi_value_proposal_wasm() {
        // The module is built with `walrus` and emitted as raw Wasm bytes.
        let module_bytes = {
            let mut module = walrus::Module::with_config(ModuleConfig::new());
            let _memory_id = module.memories.add_local(false, 11, None);
            // Function declared with two value types (`I32`, `I64`) in its last type list; the
            // asserted error message identifies these as function results.
            let mut func_with_locals =
                FunctionBuilder::new(&mut module.types, &[], &[ValType::I32, ValType::I64]);
            func_with_locals.func_body().i64_const(0).i32_const(1);
            let func_with_locals = func_with_locals.finish(vec![], &mut module.funcs);
            // Exported entry point that calls the function above.
            let mut call_func = FunctionBuilder::new(&mut module.types, &[], &[]);
            call_func.func_body().call(func_with_locals);
            let call = call_func.finish(Vec::new(), &mut module.funcs);
            module.exports.add(DEFAULT_ENTRY_POINT_NAME, call);
            module.emit_wasm()
        };
        let error = preprocess(WasmConfig::default(), &module_bytes)
            .expect_err("should fail with an error");
        assert!(
            matches!(&error, PreprocessingError::Deserialize(msg)
            if msg == "Enable the multi_value feature to deserialize more than one function result"),
            "{:?}",
            error,
        );
    }
}