cairo-native 0.9.0-rc.5

A compiler that converts Cairo's Sierra intermediate representation (IR) to MLIR and executes it.
use super::{drop_overrides::DropOverridesMeta, MetadataStorage};
use crate::{
    error::{Error, Result},
    utils::ProgramRegistryExt,
};
use cairo_lang_sierra::{
    extensions::core::{CoreLibfunc, CoreType},
    ids::ConcreteTypeId,
    program_registry::ProgramRegistry,
};
use melior::{
    dialect::llvm,
    helpers::{BuiltinBlockExt, LlvmBlockExt},
    ir::{
        attribute::{FlatSymbolRefAttribute, StringAttribute, TypeAttribute},
        Attribute, Block, BlockLike, Identifier, Location, Module, Region,
    },
    Context,
};
use std::collections::{hash_map::Entry, HashMap};

/// Cache of the LLVM drop functions generated for overridden item types.
///
/// Keys are the concrete Sierra type ids; values are the symbol names of the drop
/// functions that have been emitted into the module for those types.
#[derive(Debug, Default, Clone)]
pub struct Felt252DictOverrides {
    /// Concrete type id -> symbol name of its generated drop function.
    drop_overrides: HashMap<ConcreteTypeId, String>,
}

impl Felt252DictOverrides {
    pub fn get_drop_fn(&self, type_id: &ConcreteTypeId) -> Option<&str> {
        self.drop_overrides.get(type_id).map(String::as_str)
    }

    pub fn build_drop_fn<'ctx>(
        &mut self,
        context: &'ctx Context,
        module: &Module<'ctx>,
        registry: &ProgramRegistry<CoreType, CoreLibfunc>,
        metadata: &mut MetadataStorage,
        type_id: &ConcreteTypeId,
    ) -> Result<Option<FlatSymbolRefAttribute<'ctx>>> {
        let location = Location::unknown(context);

        let inner_ty = registry.build_type(context, module, metadata, type_id)?;
        Ok(if DropOverridesMeta::is_overriden(metadata, type_id) {
            let drop_fn_symbol = format!("drop${}$item", type_id.id);
            let flat_symbol_ref = FlatSymbolRefAttribute::new(context, &drop_fn_symbol);

            if let Entry::Vacant(entry) = self.drop_overrides.entry(type_id.clone()) {
                let drop_fn_symbol = entry.insert(drop_fn_symbol);

                let region = Region::new();
                let entry = region
                    .append_block(Block::new(&[(llvm::r#type::pointer(context, 0), location)]));

                let value = entry.load(context, location, entry.arg(0)?, inner_ty)?;
                DropOverridesMeta::invoke_override(
                    context, registry, module, &entry, &entry, location, metadata, type_id, value,
                )?;

                entry.append_operation(llvm::r#return(None, location));

                module.body().append_operation(llvm::func(
                    context,
                    StringAttribute::new(context, drop_fn_symbol),
                    TypeAttribute::new(llvm::r#type::function(
                        llvm::r#type::void(context),
                        &[llvm::r#type::pointer(context, 0)],
                        false,
                    )),
                    region,
                    &[
                        (
                            Identifier::new(context, "sym_visibility"),
                            StringAttribute::new(context, "public").into(),
                        ),
                        (
                            Identifier::new(context, "llvm.linkage"),
                            Attribute::parse(context, "#llvm.linkage<private>")
                                .ok_or(Error::ParseAttributeError)?,
                        ),
                    ],
                    location,
                ));
            }

            Some(flat_symbol_ref)
        } else {
            None
        })
    }
}