hdp 0.9.0

All Herodotus Data Processor
Documentation
//! The module registry is a service that resolves the compiled class of a module.
//! It loads the CASM contract class either from a local JSON file or from the
//! Herodotus program registry, keyed by the module's program hash.

use cairo_lang_starknet_classes::casm_contract_class::{
    CasmContractClass, StarknetSierraCompilationError,
};
use starknet_crypto::Felt;

use crate::{
    constant::HERODOTUS_PROGRAM_REGISTRY_URL,
    primitives::task::{module::Module, ExtendedModule},
};
use reqwest::Client;
use std::path::PathBuf;
use thiserror::Error;
use tracing::info;

/// Errors that can occur while resolving a module's compiled class.
#[derive(Error, Debug)]
pub enum ModuleRegistryError {
    /// The class could not be obtained from its source (local file or registry).
    #[error("Module class source error: {0}")]
    ClassSourceError(String),

    /// A value could not be converted into the required type.
    #[error("Type conversion error: {0}")]
    TypeConversionError(String),

    /// A Cairo 1 module's class was expected to be Sierra but was not.
    #[error("Cairo1 module should have sierra as class")]
    SierraNotFound,

    /// JSON (de)serialization of a class failed.
    #[error("Serialize error: {0}")]
    SerializeError(#[from] serde_json::Error),

    /// Compiling a Sierra class to CASM failed.
    #[error("StarkNet error: {0}")]
    StarkNetSierraCompileError(#[from] StarknetSierraCompilationError),

    /// Wraps a [`starknet::providers::ProviderError`].
    #[error("StarkNet Provider error: {0}")]
    StarkNetProviderError(#[from] starknet::providers::ProviderError),

    /// A spawned tokio task failed to join.
    #[error("Tokio join error: {0}")]
    TokioJoinError(#[from] tokio::task::JoinError),
}

/// Resolves module classes from a local file or the Herodotus program registry.
///
/// Cloning is cheap: the inner [`reqwest::Client`] shares its connection pool
/// through an internal `Arc`, so clones reuse the same pool.
#[derive(Debug, Clone)]
pub struct ModuleRegistry {
    /// HTTP client used to query the program registry.
    client: Client,
}

impl Default for ModuleRegistry {
    fn default() -> Self {
        Self::new()
    }
}

impl ModuleRegistry {
    /// Creates a registry backed by a freshly constructed HTTP client.
    pub fn new() -> Self {
        Self {
            client: Client::new(),
        }
    }

    /// Resolves the compiled class for `module` and pairs it with the module.
    ///
    /// If `module.local_class_path` is set, the class is read from that local
    /// JSON file. Otherwise it is fetched from the Herodotus program registry
    /// using `module.program_hash`.
    ///
    /// # Errors
    ///
    /// Returns [`ModuleRegistryError::ClassSourceError`] when the class cannot be
    /// read or fetched. Returns [`ModuleRegistryError::SerializeError`] when the
    /// payload does not deserialize into a [`CasmContractClass`].
    pub async fn get_extended_module(
        &self,
        module: Module,
    ) -> Result<ExtendedModule, ModuleRegistryError> {
        let module_class = match &module.local_class_path {
            Some(path) => self.get_module_class_from_local_path(path).await?,
            None => self.get_module_class_from_program_hash(module.program_hash).await?,
        };

        Ok(ExtendedModule {
            task: module,
            module_class,
        })
    }

    /// Reads a [`CasmContractClass`] from a JSON file on the local filesystem.
    ///
    /// NOTE(review): this uses blocking `std::fs` I/O inside an async fn, which
    /// blocks the executor thread for the duration of the read. Consider
    /// `tokio::fs` or `spawn_blocking` if class files are large.
    ///
    /// # Errors
    ///
    /// Returns `ClassSourceError` (with the path and I/O cause) if the file
    /// cannot be read, and `SerializeError` if its contents are not valid class JSON.
    async fn get_module_class_from_local_path(
        &self,
        local_class_path: &PathBuf,
    ) -> Result<CasmContractClass, ModuleRegistryError> {
        // Report the actual I/O failure (missing file, permissions, ...).
        // JSON validity errors are reported separately by the `?` on `from_str`.
        let contents = std::fs::read_to_string(local_class_path).map_err(|err| {
            ModuleRegistryError::ClassSourceError(format!(
                "failed to read local class file {}: {}",
                local_class_path.display(),
                err
            ))
        })?;
        let casm: CasmContractClass = serde_json::from_str(&contents)?;

        info!(
            "contract class fetched successfully from local path: {:?}",
            local_class_path
        );
        Ok(casm)
    }

    /// Fetches a [`CasmContractClass`] from the Herodotus program registry.
    ///
    /// # Errors
    ///
    /// Returns `ClassSourceError` if the request fails, the registry answers
    /// with a non-success HTTP status, or the response body cannot be read.
    /// Returns `SerializeError` if the body is not valid class JSON.
    async fn get_module_class_from_program_hash(
        &self,
        program_hash: Felt,
    ) -> Result<CasmContractClass, ModuleRegistryError> {
        let program_hash_hex = format!("{:#x}", program_hash);

        info!(
            "fetching contract class from module registry... program_hash: {}",
            program_hash_hex
        );

        // NOTE(review): assumes the registry URL constant ends with a query key,
        // so `=<hash>` completes it; verify against the constant's definition.
        let api_url = format!("{}={}", HERODOTUS_PROGRAM_REGISTRY_URL, program_hash_hex);

        // Network failures are recoverable conditions: surface them as errors
        // instead of panicking inside library code.
        let response = self
            .client
            .get(&api_url)
            .header("User-Agent", "request")
            .send()
            .await
            .map_err(|err| {
                ModuleRegistryError::ClassSourceError(format!(
                    "failed to request contract class from {}: {}",
                    api_url, err
                ))
            })?;

        let status = response.status();
        if !status.is_success() {
            return Err(ModuleRegistryError::ClassSourceError(format!(
                "failed to fetch contract class: registry returned HTTP {}",
                status
            )));
        }

        let response_text = response.text().await.map_err(|err| {
            ModuleRegistryError::ClassSourceError(format!(
                "failed to read contract class response body: {}",
                err
            ))
        })?;
        let casm: CasmContractClass = serde_json::from_str(&response_text)?;

        info!(
            "contract class fetched successfully from program_hash: {:?}",
            program_hash
        );
        Ok(casm)
    }
}

#[cfg(test)]
mod tests {
    use starknet_crypto::Felt;

    use super::*;

    /// Program hash of the test module class published in the registry.
    const TEST_PROGRAM_HASH_HEX: &str =
        "0x64041a339b1edd10de83cf031cfa938645450f971d2527c90d4c2ce68d7d412";

    /// Returns a fresh registry together with the parsed test program hash.
    fn init() -> (ModuleRegistry, Felt) {
        let registry = ModuleRegistry::new();
        let program_hash = Felt::from_hex(TEST_PROGRAM_HASH_HEX).unwrap();
        (registry, program_hash)
    }

    /// Fetching the class by program hash must succeed and deserialize.
    #[tokio::test]
    async fn test_get_module() {
        let (registry, program_hash) = init();
        let _fetched_class = registry
            .get_module_class_from_program_hash(program_hash)
            .await
            .unwrap();
    }

    /// Resolving an extended module must keep the task's hash and inputs intact.
    #[tokio::test]
    async fn test_get_multiple_module_classes() {
        let (registry, program_hash) = init();
        let module = Module::new(program_hash, vec![], None);

        let extended = registry.get_extended_module(module).await.unwrap();

        let expected_hash = Felt::from_hex(TEST_PROGRAM_HASH_HEX).unwrap();
        assert_eq!(extended.task.program_hash, expected_hash);
        assert_eq!(extended.task.inputs, vec![]);
    }
}