use cairo_lang_starknet_classes::casm_contract_class::{
CasmContractClass, StarknetSierraCompilationError,
};
use starknet_crypto::Felt;
use crate::{
constant::HERODOTUS_PROGRAM_REGISTRY_URL,
primitives::task::{module::Module, ExtendedModule},
};
use reqwest::Client;
use std::path::PathBuf;
use thiserror::Error;
use tracing::info;
#[derive(Error, Debug)]
pub enum ModuleRegistryError {
#[error("Serialize error: {0}")]
SerializeError(#[from] serde_json::Error),
#[error("StarkNet error: {0}")]
StarkNetSierraCompileError(#[from] StarknetSierraCompilationError),
#[error("StarkNet Provider error: {0}")]
StarkNetProviderError(#[from] starknet::providers::ProviderError),
#[error("Cairo1 module should have sierra as class")]
SierraNotFound,
#[error("Tokio join error: {0}")]
TokioJoinError(#[from] tokio::task::JoinError),
#[error("Module class source error: {0}")]
ClassSourceError(String),
#[error("Type conversion error: {0}")]
TypeConversionError(String),
}
/// Resolves compiled module classes, either from a local JSON file or from the
/// Herodotus program registry over HTTP.
///
/// Cloning is cheap: the inner `reqwest::Client` shares its connection pool
/// between clones.
#[derive(Debug, Clone)]
pub struct ModuleRegistry {
    /// HTTP client reused for every registry request.
    client: Client,
}
impl Default for ModuleRegistry {
fn default() -> Self {
Self::new()
}
}
impl ModuleRegistry {
    /// Creates a registry backed by a new `reqwest::Client`.
    pub fn new() -> Self {
        let client = Client::new();
        Self { client }
    }

    /// Resolves the compiled (CASM) class for `module` and pairs it with the module.
    ///
    /// When `module.local_class_path` is set, the class is read from that JSON file;
    /// otherwise it is downloaded from the program registry using `module.program_hash`.
    ///
    /// # Errors
    ///
    /// - [`ModuleRegistryError::ClassSourceError`] if the local file cannot be read,
    ///   the HTTP request fails, or the registry answers with a non-success status.
    /// - [`ModuleRegistryError::SerializeError`] if the payload is not a valid
    ///   `CasmContractClass` JSON document.
    pub async fn get_extended_module(
        &self,
        module: Module,
    ) -> Result<ExtendedModule, ModuleRegistryError> {
        let casm = if let Some(ref local_class_path) = module.local_class_path {
            self.get_module_class_from_local_path(local_class_path)
                .await?
        } else {
            self.get_module_class_from_program_hash(module.program_hash)
                .await?
        };
        Ok(ExtendedModule {
            task: module,
            module_class: casm,
        })
    }

    /// Reads and deserializes a `CasmContractClass` from a JSON file on disk.
    ///
    /// # Errors
    ///
    /// - [`ModuleRegistryError::ClassSourceError`] if the file cannot be read
    ///   (the message includes the path and the underlying IO error).
    /// - [`ModuleRegistryError::SerializeError`] if the contents are not valid class JSON.
    async fn get_module_class_from_local_path(
        &self,
        local_class_path: &PathBuf,
    ) -> Result<CasmContractClass, ModuleRegistryError> {
        // NOTE(review): `std::fs` blocks the executor thread; consider `tokio::fs`
        // (requires tokio's `fs` feature) if large class files are common.
        let class_json = std::fs::read_to_string(local_class_path).map_err(|err| {
            // A read failure is not a JSON problem; report what actually went wrong.
            ModuleRegistryError::ClassSourceError(format!(
                "failed to read local class file {}: {}",
                local_class_path.display(),
                err
            ))
        })?;
        let casm: CasmContractClass = serde_json::from_str(&class_json)?;
        info!(
            "contract class fetched successfully from local path: {:?}",
            local_class_path
        );
        Ok(casm)
    }

    /// Downloads and deserializes the `CasmContractClass` registered under `program_hash`.
    ///
    /// The request URL is `HERODOTUS_PROGRAM_REGISTRY_URL` followed by `=` and the
    /// `0x`-prefixed hex form of `program_hash`.
    ///
    /// # Errors
    ///
    /// - [`ModuleRegistryError::ClassSourceError`] if the request cannot be sent, the
    ///   registry returns a non-success status, or the body cannot be read.
    /// - [`ModuleRegistryError::SerializeError`] if the body is not valid class JSON.
    async fn get_module_class_from_program_hash(
        &self,
        program_hash: Felt,
    ) -> Result<CasmContractClass, ModuleRegistryError> {
        let program_hash_hex = format!("{:#x}", program_hash);
        info!(
            "fetching contract class from module registry... program_hash: {}",
            program_hash_hex
        );
        let api_url = format!("{}={}", HERODOTUS_PROGRAM_REGISTRY_URL, program_hash_hex);
        // Network failures are recoverable conditions, so surface them as errors
        // rather than panicking.
        let response = self
            .client
            .get(&api_url)
            .header("User-Agent", "request")
            .send()
            .await
            .map_err(|err| {
                ModuleRegistryError::ClassSourceError(format!(
                    "failed to fetch contract class: {}",
                    err
                ))
            })?;

        let status = response.status();
        if !status.is_success() {
            return Err(ModuleRegistryError::ClassSourceError(format!(
                "failed to fetch contract class: registry responded with HTTP {}",
                status
            )));
        }

        let response_text = response.text().await.map_err(|err| {
            ModuleRegistryError::ClassSourceError(format!(
                "failed to fetch contract class: cannot read response body: {}",
                err
            ))
        })?;
        let casm: CasmContractClass = serde_json::from_str(&response_text)?;
        info!(
            "contract class fetched successfully from program_hash: {:?}",
            program_hash
        );
        Ok(casm)
    }
}
#[cfg(test)]
mod tests {
    //! These tests call `get_module_class_from_program_hash`, which sends a real HTTP
    //! GET to `HERODOTUS_PROGRAM_REGISTRY_URL`, so they require network access.
    use super::*;
    use starknet_crypto::Felt;

    /// Program hash whose class the tests expect the registry to serve
    /// (both tests unwrap a successful fetch).
    const PROGRAM_HASH_HEX: &str =
        "0x64041a339b1edd10de83cf031cfa938645450f971d2527c90d4c2ce68d7d412";

    /// Parses [`PROGRAM_HASH_HEX`] into a `Felt`.
    fn expected_program_hash() -> Felt {
        Felt::from_hex(PROGRAM_HASH_HEX).unwrap()
    }

    /// Returns a fresh registry together with the test program hash.
    fn init() -> (ModuleRegistry, Felt) {
        (ModuleRegistry::new(), expected_program_hash())
    }

    #[tokio::test]
    async fn test_get_module() {
        let (registry, hash) = init();
        let _casm = registry
            .get_module_class_from_program_hash(hash)
            .await
            .unwrap();
    }

    // NOTE(review): despite its name, this resolves a single module end to end.
    #[tokio::test]
    async fn test_get_multiple_module_classes() {
        let (registry, hash) = init();
        let extended = registry
            .get_extended_module(Module::new(hash, vec![], None))
            .await
            .unwrap();
        assert_eq!(extended.task.program_hash, expected_program_hash());
        assert_eq!(extended.task.inputs, Vec::new());
    }
}