use std::collections::HashMap;
use std::sync::Arc;
use std::vec;
use cairo_lang_defs::plugin::{
DynGeneratedFileAuxData, GeneratedFileAuxData, MacroPlugin, PluginDiagnostic,
PluginGeneratedFile, PluginResult,
};
use cairo_lang_diagnostics::DiagnosticEntry;
use cairo_lang_semantic::db::SemanticGroup;
use cairo_lang_semantic::patcher::{ModifiedNode, PatchBuilder, Patches, RewriteNode};
use cairo_lang_semantic::plugin::{
AsDynGeneratedFileAuxData, AsDynMacroPlugin, DiagnosticMapper, DynDiagnosticMapper,
PluginMappedDiagnostic, SemanticPlugin,
};
use cairo_lang_semantic::SemanticDiagnostic;
use cairo_lang_syntax::node::ast::{
FunctionWithBody, MaybeModuleBody, MaybeTraitBody, Modifier, OptionReturnTypeClause,
OptionWrappedGenericParamList, Param,
};
use cairo_lang_syntax::node::db::SyntaxGroup;
use cairo_lang_syntax::node::helpers::QueryAttrs;
use cairo_lang_syntax::node::{ast, Terminal, TypedSyntaxNode};
use cairo_lang_utils::try_extract_matches;
use indoc::formatdoc;
use crate::contract::starknet_keccak;
/// Attribute marking a trait as a contract ABI; such traits get a generated dispatcher module.
const ABI_ATTR: &str = "abi";
/// Attribute marking a module as a StarkNet contract.
const CONTRACT_ATTR: &str = "contract";
/// Attribute marking a contract function as an external entry point.
const EXTERNAL_ATTR: &str = "external";
/// Attribute marking a contract function as a constructor entry point.
const CONSTRUCTOR_ATTR: &str = "constructor";
/// Attribute marking a contract function as a view entry point.
pub const VIEW_ATTR: &str = "view";
/// Attribute marking a contract function (or ABI trait function) as an event.
pub const EVENT_ATTR: &str = "event";
/// Attribute placed on the module generated from a `#[contract]` module.
pub const GENERATED_CONTRACT_ATTR: &str = "generated_contract";
/// Name of the trait, inside the generated contract module, that declares the contract's ABI.
pub const ABI_TRAIT: &str = "__abi";
/// Name of the generated sub-module holding wrappers of external and view entry points.
pub const EXTERNAL_MODULE: &str = "__external";
/// Name of the generated sub-module holding wrappers of constructor entry points.
pub const CONSTRUCTOR_MODULE: &str = "__constructor";
/// Auxiliary data attached to files generated by this plugin.
///
/// Holds the `Patches` recorded while building the generated code. `map_diag` uses them to
/// translate diagnostic spans in the generated file back to spans in the original source.
#[derive(Debug, PartialEq, Eq)]
pub struct DiagnosticRemapper {
patches: Patches,
}
/// Type-erased access and equality for `DiagnosticRemapper`, as required by the plugin framework.
impl GeneratedFileAuxData for DiagnosticRemapper {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Two aux-data objects are equal only if `other` is also a `DiagnosticRemapper` with equal
    /// patches; any other concrete aux-data type compares unequal.
    fn eq(&self, other: &dyn GeneratedFileAuxData) -> bool {
        match other.as_any().downcast_ref::<Self>() {
            Some(other_remapper) => self == other_remapper,
            None => false,
        }
    }
}
// Exposes a `DiagnosticRemapper` as a `dyn GeneratedFileAuxData` trait object.
impl AsDynGeneratedFileAuxData for DiagnosticRemapper {
fn as_dyn_macro_token(&self) -> &(dyn GeneratedFileAuxData + 'static) {
self
}
}
/// Maps semantic diagnostics reported inside a generated file back to the user's source.
impl DiagnosticMapper for DiagnosticRemapper {
    /// Returns `None` when `diag` is not a `SemanticDiagnostic`, or when its span cannot be
    /// translated through the recorded patches.
    fn map_diag(
        &self,
        db: &(dyn SemanticGroup + 'static),
        diag: &dyn std::any::Any,
    ) -> Option<PluginMappedDiagnostic> {
        let semantic_diag = diag.downcast_ref::<SemanticDiagnostic>()?;
        let location = semantic_diag.stable_location.diagnostic_location(db.upcast());
        let span = self.patches.translate(db.upcast(), location.span)?;
        let message = semantic_diag.format(db);
        Some(PluginMappedDiagnostic { span, message })
    }
}
// Unit tests for this plugin; the module source lives in `plugin_test.rs`.
#[cfg(test)]
#[path = "plugin_test.rs"]
mod test;
/// Macro plugin that expands StarkNet-specific items: modules marked `#[contract]` and traits
/// marked `#[abi]` (see `generate_code`).
#[derive(Debug)]
pub struct StarkNetPlugin {}
/// Routes each item to its handler: traits go to `handle_trait`, modules go to `handle_mod`, and
/// every other item kind yields an empty (default) result.
impl MacroPlugin for StarkNetPlugin {
    fn generate_code(&self, db: &dyn SyntaxGroup, item_ast: ast::Item) -> PluginResult {
        match item_ast {
            ast::Item::Trait(trait_item) => handle_trait(db, trait_item),
            ast::Item::Module(module_item) => handle_mod(db, module_item),
            _ => PluginResult::default(),
        }
    }
}
// Lets an `Arc<StarkNetPlugin>` be used as an `Arc<dyn MacroPlugin>`.
impl AsDynMacroPlugin for StarkNetPlugin {
fn as_dyn_macro_plugin<'a>(self: Arc<Self>) -> Arc<dyn MacroPlugin + 'a>
where
Self: 'a,
{
self
}
}
// No semantic-level hooks are overridden; the trait's default methods are used.
impl SemanticPlugin for StarkNetPlugin {}
fn handle_trait(db: &dyn SyntaxGroup, trait_ast: ast::ItemTrait) -> PluginResult {
let attrs = trait_ast.attributes(db).elements(db);
if !attrs.iter().any(|attr| attr.attr(db).text(db) == ABI_ATTR) {
return PluginResult::default();
}
let body = match trait_ast.body(db) {
MaybeTraitBody::Some(body) => body,
MaybeTraitBody::None(empty_body) => {
return PluginResult {
code: None,
diagnostics: vec![PluginDiagnostic {
message: "ABIs without body are not supported.".to_string(),
stable_ptr: empty_body.stable_ptr().untyped(),
}],
remove_original_item: false,
};
}
};
let mut diagnostics = vec![];
let mut functions = vec![];
for item_ast in body.items(db).elements(db) {
match item_ast {
ast::TraitItem::Function(func) => {
if func.has_attr(db, EVENT_ATTR) {
continue;
}
let declaration = func.declaration(db);
let mut skip_generation = false;
let mut serialization_code = vec![];
let signature = declaration.signature(db);
for param in signature.parameters(db).elements(db) {
if is_ref_param(db, ¶m) {
skip_generation = true;
diagnostics.push(PluginDiagnostic {
message: "`ref` parameters are not supported in the ABI of a contract."
.to_string(),
stable_ptr: param.modifiers(db).stable_ptr().untyped(),
})
}
let param_type = param.type_clause(db).ty(db);
let type_name = ¶m_type.as_syntax_node().get_text(db);
serialization_code.push(RewriteNode::interpolate_patched(
&formatdoc!(
" serde::Serde::<{type_name}>::serialize(ref calldata, \
$arg_name$);\n"
),
HashMap::from([(
"arg_name".to_string(),
RewriteNode::Trimmed(param.name(db).as_syntax_node()),
)]),
));
}
if skip_generation {
continue;
}
let ret_decode = match signature.ret_ty(db) {
OptionReturnTypeClause::Empty(_) => "".to_string(),
OptionReturnTypeClause::ReturnTypeClause(ty) => {
let ret_type_ast = ty.ty(db);
let type_name = ret_type_ast.as_syntax_node().get_text(db);
format!(
"
serde::Serde::<{type_name}>::deserialize(ref ret_data).expect(
'Returned data too short')"
)
}
};
let mut func_declaration = RewriteNode::from_ast(&declaration);
func_declaration
.modify_child(db, ast::FunctionDeclaration::INDEX_SIGNATURE)
.modify_child(db, ast::FunctionSignature::INDEX_PARAMETERS)
.modify(db)
.children
.splice(
0..0,
[
RewriteNode::Text("contract_address: ContractAddress".to_string()),
RewriteNode::Text(", ".to_string()),
],
);
functions.push(RewriteNode::interpolate_patched(
"$func_decl$ {
let mut calldata = array_new();
$serialization_code$
let mut ret_data = starknet::call_contract_syscall(
contract_address,
calldata,
).unwrap_syscall();
$deserialization_code$
}
",
HashMap::from([
("func_decl".to_string(), func_declaration),
(
"serialization_code".to_string(),
RewriteNode::Modified(ModifiedNode { children: serialization_code }),
),
("deserialization_code".to_string(), RewriteNode::Text(ret_decode)),
]),
));
}
}
}
let mut builder = PatchBuilder::new(db);
let dispatcher_name = format!("{}Dispatcher", trait_ast.name(db).text(db));
builder.add_modified(RewriteNode::interpolate_patched(
&formatdoc!(
"mod {dispatcher_name} {{
use starknet::SyscallResultTrait;
use starknet::SyscallResultTraitImpl;
use option::OptionTrait;
use option::OptionTraitImpl;
$body$
}}",
),
HashMap::from([(
"body".to_string(),
RewriteNode::Modified(ModifiedNode { children: functions }),
)]),
));
PluginResult {
code: Some(PluginGeneratedFile {
name: dispatcher_name.into(),
content: builder.code,
aux_data: DynGeneratedFileAuxData::new(DynDiagnosticMapper::new(DiagnosticRemapper {
patches: builder.patches,
})),
}),
diagnostics,
remove_original_item: false,
}
}
/// Expands a module marked `#[contract]` into a generated contract module.
///
/// The generated module (marked `#[generated_contract]`) contains:
/// - all original items except `#[event]` functions and the `Storage` struct,
/// - the code produced by `handle_storage_struct` for the `Storage` struct,
/// - the rewritten `#[event]` functions,
/// - the ABI trait (`__abi`) declaring every entry point and event,
/// - the `__external` and `__constructor` modules holding the entry-point wrappers.
///
/// Modules without the `contract` attribute produce an empty result. On success the original
/// module is removed (`remove_original_item: true`).
fn handle_mod(db: &dyn SyntaxGroup, module_ast: ast::ItemModule) -> PluginResult {
    if !module_ast.has_attr(db, CONTRACT_ATTR) {
        return PluginResult::default();
    }
    let body = match module_ast.body(db) {
        MaybeModuleBody::Some(body) => body,
        MaybeModuleBody::None(empty_body) => {
            return PluginResult {
                code: None,
                diagnostics: vec![PluginDiagnostic {
                    message: "Contracts without body are not supported.".to_string(),
                    stable_ptr: empty_body.stable_ptr().untyped(),
                }],
                remove_original_item: false,
            };
        }
    };
    let mut diagnostics = vec![];
    let mut generated_external_functions = Vec::new();
    let mut generated_constructor_functions = Vec::new();
    let mut storage_code = RewriteNode::Text("".to_string());
    let mut original_items = Vec::new();
    let mut abi_functions = Vec::new();
    let mut event_functions = Vec::new();
    let mut abi_events = Vec::new();
    for item in body.items(db).elements(db) {
        let keep_original = match &item {
            ast::Item::FreeFunction(item_function)
                if item_function.has_attr(db, EXTERNAL_ATTR)
                    || item_function.has_attr(db, VIEW_ATTR)
                    || item_function.has_attr(db, CONSTRUCTOR_ATTR) =>
            {
                // When several entry-point attributes are present, the first of
                // external / view / constructor wins. The chosen attribute decides both the ABI
                // declaration and the module the wrapper goes to, so the two always agree.
                let attr = if item_function.has_attr(db, EXTERNAL_ATTR) {
                    EXTERNAL_ATTR
                } else if item_function.has_attr(db, VIEW_ATTR) {
                    VIEW_ATTR
                } else {
                    CONSTRUCTOR_ATTR
                };
                let declaration = item_function.declaration(db);
                if let OptionWrappedGenericParamList::WrappedGenericParamList(generic_params) =
                    declaration.generic_params(db)
                {
                    diagnostics.push(PluginDiagnostic {
                        message: "Contract entry points cannot have generic arguments".to_string(),
                        stable_ptr: generic_params.stable_ptr().untyped(),
                    })
                }
                abi_functions.push(RewriteNode::Modified(ModifiedNode {
                    children: vec![
                        RewriteNode::Text(format!("#[{attr}]\n ")),
                        RewriteNode::Trimmed(declaration.as_syntax_node()),
                        RewriteNode::Text(";\n ".to_string()),
                    ],
                }));
                match generate_entry_point_wrapper(db, item_function) {
                    Ok(generated_function) => {
                        let generated = if attr == CONSTRUCTOR_ATTR {
                            &mut generated_constructor_functions
                        } else {
                            &mut generated_external_functions
                        };
                        generated.push(generated_function);
                        generated.push(RewriteNode::Text("\n ".to_string()));
                    }
                    Err(entry_point_diagnostics) => {
                        diagnostics.extend(entry_point_diagnostics);
                    }
                }
                true
            }
            ast::Item::FreeFunction(item_function) if item_function.has_attr(db, EVENT_ATTR) => {
                let (rewrite_nodes, event_diagnostics) = handle_event(db, item_function.clone());
                if let Some((event_function_rewrite, abi_event_rewrite)) = rewrite_nodes {
                    event_functions.push(event_function_rewrite);
                    abi_events.push(abi_event_rewrite);
                }
                diagnostics.extend(event_diagnostics);
                false
            }
            ast::Item::Struct(item_struct) if item_struct.name(db).text(db) == "Storage" => {
                let (storage_rewrite_node, storage_diagnostics) =
                    handle_storage_struct(db, item_struct.clone());
                storage_code = storage_rewrite_node;
                diagnostics.extend(storage_diagnostics);
                false
            }
            _ => true,
        };
        if keep_original {
            original_items.push(RewriteNode::Copied(item.as_syntax_node()));
        }
    }
    let generated_contract_mod = RewriteNode::interpolate_patched(
        formatdoc!(
"
#[{GENERATED_CONTRACT_ATTR}]
mod $contract_name$ {{
use starknet::SyscallResultTrait;
use starknet::SyscallResultTraitImpl;
$original_items$
$storage_code$
$event_functions$
trait {ABI_TRAIT} {{
$abi_functions$
$abi_events$
}}
mod {EXTERNAL_MODULE} {{
$generated_external_functions$
}}
mod {CONSTRUCTOR_MODULE} {{
$generated_constructor_functions$
}}
}}
"
        )
        .as_str(),
        HashMap::from([
            (
                "contract_name".to_string(),
                RewriteNode::Trimmed(module_ast.name(db).as_syntax_node()),
            ),
            (
                "original_items".to_string(),
                RewriteNode::Modified(ModifiedNode { children: original_items }),
            ),
            ("storage_code".to_string(), storage_code),
            (
                "abi_functions".to_string(),
                RewriteNode::Modified(ModifiedNode { children: abi_functions }),
            ),
            (
                "abi_events".to_string(),
                RewriteNode::Modified(ModifiedNode { children: abi_events }),
            ),
            (
                "generated_external_functions".to_string(),
                RewriteNode::Modified(ModifiedNode { children: generated_external_functions }),
            ),
            (
                "generated_constructor_functions".to_string(),
                RewriteNode::Modified(ModifiedNode { children: generated_constructor_functions }),
            ),
            (
                "event_functions".to_string(),
                RewriteNode::Modified(ModifiedNode { children: event_functions }),
            ),
        ]),
    );
    let mut builder = PatchBuilder::new(db);
    builder.add_modified(generated_contract_mod);
    PluginResult {
        code: Some(PluginGeneratedFile {
            name: "contract".into(),
            content: builder.code,
            aux_data: DynGeneratedFileAuxData::new(DynDiagnosticMapper::new(DiagnosticRemapper {
                patches: builder.patches,
            })),
        }),
        diagnostics,
        remove_original_item: true,
    }
}
/// Processes a contract function marked `#[event]`.
///
/// On success, returns a pair of rewrite nodes:
/// 1. The event function, keeping its attributes and declaration. Its body builds `keys`
///    (holding the `starknet_keccak` of the function name) and `data` (the serialized
///    parameters), then calls `starknet::emit_event_syscall`.
/// 2. The `#[event]` declaration to be placed in the generated ABI trait.
///
/// An event function must not be generic, must not return a value, must not take `ref`
/// parameters, and must have an empty body. If any rule is violated, `None` is returned together
/// with all the collected diagnostics.
fn handle_event(
    db: &dyn SyntaxGroup,
    function_ast: ast::FunctionWithBody,
) -> (Option<(RewriteNode, RewriteNode)>, Vec<PluginDiagnostic>) {
    let mut diagnostics = vec![];
    let declaration = function_ast.declaration(db);
    if let OptionWrappedGenericParamList::WrappedGenericParamList(generic_params) =
        declaration.generic_params(db)
    {
        diagnostics.push(PluginDiagnostic {
            message: "Event functions cannot have generic arguments".to_string(),
            stable_ptr: generic_params.stable_ptr().untyped(),
        })
    }
    let signature = declaration.signature(db);
    let ret_ty = signature.ret_ty(db);
    if matches!(ret_ty, OptionReturnTypeClause::ReturnTypeClause(_)) {
        diagnostics.push(PluginDiagnostic {
            stable_ptr: ret_ty.stable_ptr().untyped(),
            message: "Event functions must not return a value.".to_string(),
        });
    }
    let mut param_serializations = Vec::new();
    for param in signature.parameters(db).elements(db) {
        if is_ref_param(db, &param) {
            diagnostics.push(PluginDiagnostic {
                stable_ptr: param.modifiers(db).stable_ptr().untyped(),
                message: "`ref` parameters are not supported in contract events.".to_string(),
            });
            continue;
        }
        // Trivia must not leak into the generated generic argument; this matches
        // `generate_entry_point_wrapper`.
        let type_name = param.type_clause(db).ty(db).as_syntax_node().get_text_without_trivia(db);
        param_serializations.push(RewriteNode::interpolate_patched(
            &format!("serde::Serde::<{type_name}>::serialize(ref data, $param_name$);\n "),
            HashMap::from([(
                "param_name".to_string(),
                RewriteNode::Trimmed(param.name(db).as_syntax_node()),
            )]),
        ));
    }
    let statements = function_ast.body(db).statements(db);
    if !statements.elements(db).is_empty() {
        diagnostics.push(PluginDiagnostic {
            stable_ptr: statements.stable_ptr().untyped(),
            message: "Event function body must be empty.".to_string(),
        });
    }
    if !diagnostics.is_empty() {
        return (None, diagnostics);
    }
    let name = declaration.name(db).text(db);
    let event_key = format!("0x{:x}", starknet_keccak(name.as_bytes()));
    let event_function = RewriteNode::interpolate_patched(
        &format!(
"
$attrs$
$declaration$ {{
let mut keys = array_new();
array_append(ref keys, {event_key});
let mut data = array_new();
$param_serializations$
starknet::emit_event_syscall(keys, data).unwrap_syscall()
}}
"
        ),
        HashMap::from([
            (
                "attrs".to_string(),
                RewriteNode::Trimmed(function_ast.attributes(db).as_syntax_node()),
            ),
            ("declaration".to_string(), RewriteNode::Trimmed(declaration.as_syntax_node())),
            (
                "param_serializations".to_string(),
                RewriteNode::Modified(ModifiedNode { children: param_serializations }),
            ),
        ]),
    );
    let abi_event = RewriteNode::Modified(ModifiedNode {
        children: vec![
            RewriteNode::Text("#[event]\n ".to_string()),
            RewriteNode::Trimmed(declaration.as_syntax_node()),
            RewriteNode::Text(";\n ".to_string()),
        ],
    });
    (Some((event_function, abi_event)), diagnostics)
}
/// Generates one sub-module per member of the contract's `Storage` struct.
///
/// Each member's storage address is `starknet_keccak` of the member name, rendered as hex.
/// Members typed `Map<K, V>` use the mapping template; all others use the simple template.
/// No diagnostics are currently produced, so the returned vector is always empty.
fn handle_storage_struct(
    db: &dyn SyntaxGroup,
    struct_ast: ast::ItemStruct,
) -> (RewriteNode, Vec<PluginDiagnostic>) {
    let members_code: Vec<RewriteNode> = struct_ast
        .members(db)
        .elements(db)
        .into_iter()
        .map(|member| {
            let member_name = member.name(db);
            let address = format!("0x{:x}", starknet_keccak(member_name.text(db).as_bytes()));
            let name_node = RewriteNode::Trimmed(member_name.as_syntax_node());
            let type_ast = member.type_clause(db).ty(db);
            match try_extract_mapping_types(db, &type_ast) {
                Some((key_type_ast, value_type_ast)) => RewriteNode::interpolate_patched(
                    handle_mapping_storage_var(&address).as_str(),
                    HashMap::from([
                        ("storage_var_name".to_string(), name_node),
                        (
                            "key_type".to_string(),
                            RewriteNode::Trimmed(key_type_ast.as_syntax_node()),
                        ),
                        (
                            "value_type".to_string(),
                            RewriteNode::Trimmed(value_type_ast.as_syntax_node()),
                        ),
                    ]),
                ),
                None => RewriteNode::interpolate_patched(
                    handle_simple_storage_var(&address).as_str(),
                    HashMap::from([
                        ("storage_var_name".to_string(), name_node),
                        ("type_name".to_string(), RewriteNode::Trimmed(type_ast.as_syntax_node())),
                    ]),
                ),
            }
        })
        .collect();
    (RewriteNode::Modified(ModifiedNode { children: members_code }), vec![])
}
/// If `type_ast` is exactly a single-segment path `Map<K, V>`, returns `(K, V)`.
///
/// Returns `None` for any other shape:
/// - a non-path expression,
/// - a multi-segment path or a segment without generic arguments,
/// - a segment name other than `Map`,
/// - a generic argument count other than two.
fn try_extract_mapping_types(
    db: &dyn SyntaxGroup,
    type_ast: &ast::Expr,
) -> Option<(ast::Expr, ast::Expr)> {
    let path = try_extract_matches!(type_ast, ast::Expr::Path)?;
    let segments = path.elements(db);
    let [ast::PathSegment::WithGenericArgs(segment)] = segments.as_slice() else {
        return None;
    };
    if segment.ident(db).text(db) != "Map" {
        return None;
    }
    let generic_args = segment.generic_args(db).generic_args(db).elements(db);
    let [key_ty, value_ty] = <[ast::Expr; 2]>::try_from(generic_args).ok()?;
    Some((key_ty, value_ty))
}
/// Returns the Cairo code template for a storage variable that is not a `Map`.
///
/// `address` is formatted into `starknet::storage_base_address_const::<...>()`. The returned
/// text still contains the `$storage_var_name$` and `$type_name$` placeholders, which the caller
/// fills via `RewriteNode::interpolate_patched`. The generated `read`/`write` functions always use
/// address domain 0.
fn handle_simple_storage_var(address: &str) -> String {
format!(
"
mod $storage_var_name$ {{
use starknet::SyscallResultTrait;
use starknet::SyscallResultTraitImpl;
fn address() -> starknet::StorageBaseAddress {{
starknet::storage_base_address_const::<{address}>()
}}
fn read() -> $type_name$ {{
// Only address_domain 0 is currently supported.
let address_domain = 0;
starknet::StorageAccess::<$type_name$>::read(
address_domain,
address(),
).unwrap_syscall()
}}
fn write(value: $type_name$) {{
// Only address_domain 0 is currently supported.
let address_domain = 0;
starknet::StorageAccess::<$type_name$>::write(
address_domain,
address(),
value,
).unwrap_syscall()
}}
}}"
)
}
/// Returns the Cairo code template for a storage variable of type `Map<K, V>`.
///
/// The storage address of each entry is derived from the key with
/// `hash::LegacyHash::<K>::hash(address, key)`. The returned text still contains the
/// `$storage_var_name$`, `$key_type$` and `$value_type$` placeholders, which the caller fills via
/// `RewriteNode::interpolate_patched`. The generated `read`/`write` functions always use address
/// domain 0.
fn handle_mapping_storage_var(address: &str) -> String {
format!(
"
mod $storage_var_name$ {{
use starknet::SyscallResultTrait;
use starknet::SyscallResultTraitImpl;
fn address(key: $key_type$) -> starknet::StorageBaseAddress {{
starknet::storage_base_address_from_felt(
hash::LegacyHash::<$key_type$>::hash({address}, key))
}}
fn read(key: $key_type$) -> $value_type$ {{
// Only address_domain 0 is currently supported.
let address_domain = 0;
starknet::StorageAccess::<$value_type$>::read(
address_domain,
address(key),
).unwrap_syscall()
}}
fn write(key: $key_type$, value: $value_type$) {{
// Only address_domain 0 is currently supported.
let address_domain = 0;
starknet::StorageAccess::<$value_type$>::write(
address_domain,
address(key),
value,
).unwrap_syscall()
}}
}}"
)
}
/// Generates the wrapper for a contract entry point (`#[external]`, `#[view]` or `#[constructor]`).
///
/// The wrapper has the same name as the wrapped function and the signature
/// `fn <name>(mut data: Array::<felt>) -> Array::<felt>`. It:
/// 1. Checks for gas with `get_gas()`, panicking with `'Out of gas'` otherwise.
/// 2. Deserializes each parameter from `data` in order, panicking with
///    `'Input too short for arguments'` if the data runs out.
/// 3. Panics with `'Input too long for arguments'` if `data` is not fully consumed.
/// 4. Checks for gas again with `get_gas_all(get_builtin_costs())`.
/// 5. Calls `super::<name>` with the deserialized arguments; `ref` parameters are passed as `ref`.
/// 6. Serializes the final values of the `ref` parameters, then the return value (if any), into
///    the returned array.
///
/// Currently this always returns `Ok`. The `Err` variant is kept so callers can report
/// entry-point diagnostics uniformly.
fn generate_entry_point_wrapper(
    db: &dyn SyntaxGroup,
    function: &FunctionWithBody,
) -> Result<RewriteNode, Vec<PluginDiagnostic>> {
    let declaration = function.declaration(db);
    let sig = declaration.signature(db);
    let params = sig.parameters(db).elements(db);
    let mut arg_names = Vec::new();
    let mut arg_definitions = Vec::new();
    let mut ref_appends = Vec::new();
    let input_data_short_err = "'Input too short for arguments'";
    for param in params {
        let arg_name = format!("__arg_{}", param.name(db).text(db));
        let arg_type_ast = param.type_clause(db).ty(db);
        let type_name = arg_type_ast.as_syntax_node().get_text_without_trivia(db);
        let is_ref = is_ref_param(db, &param);
        let ref_modifier = if is_ref { "ref " } else { "" };
        arg_names.push(format!("{ref_modifier}{arg_name}"));
        // `ref` arguments must be bound mutably so they can be passed by reference.
        let mut_modifier = if is_ref { "mut " } else { "" };
        let arg_definition = format!(
"
let {mut_modifier}{arg_name} =
match serde::Serde::<{type_name}>::deserialize(ref data) {{
Option::Some(x) => x,
Option::None(()) => {{
let mut err_data = array_new();
array_append(ref err_data, {input_data_short_err});
panic(err_data)
}},
}};"
        );
        arg_definitions.push(arg_definition);
        if is_ref {
            ref_appends.push(RewriteNode::Text(format!(
                "\n serde::Serde::<{type_name}>::serialize(ref arr, {arg_name});"
            )));
        }
    }
    let arg_names_str = arg_names.join(", ");
    let function_name = RewriteNode::Trimmed(declaration.name(db).as_syntax_node());
    let wrapped_name = RewriteNode::interpolate_patched(
        "super::$function_name$",
        HashMap::from([("function_name".to_string(), function_name.clone())]),
    );
    let (let_res, append_res) = match sig.ret_ty(db) {
        OptionReturnTypeClause::Empty(_) => ("", "".to_string()),
        OptionReturnTypeClause::ReturnTypeClause(ty) => {
            let ret_type_ast = ty.ty(db);
            let ret_type_name = ret_type_ast.as_syntax_node().get_text_without_trivia(db);
            (
                "\n let res = ",
                format!("\n serde::Serde::<{ret_type_name}>::serialize(ref arr, res)"),
            )
        }
    };
    let oog_err = "'Out of gas'";
    let input_data_long_err = "'Input too long for arguments'";
    let arg_definitions = arg_definitions.join("\n");
    Ok(RewriteNode::interpolate_patched(
        format!(
            "fn $function_name$(mut data: Array::<felt>) -> Array::<felt> {{
match get_gas() {{
Option::Some(_) => {{
}},
Option::None(_) => {{
let mut err_data = array_new();
array_append(ref err_data, {oog_err});
panic(err_data)
}},
}}
{arg_definitions}
if array_len(ref data) != 0_u128 {{
// Force the inclusion of `System` in the list of implicits.
starknet::use_system_implicit();
let mut err_data = array_new();
array_append(ref err_data, {input_data_long_err});
panic(err_data);
}}
match get_gas_all(get_builtin_costs()) {{
Option::Some(_) => {{
}},
Option::None(_) => {{
let mut err_data = array_new();
array_append(ref err_data, {oog_err});
panic(err_data)
}},
}}
{let_res}$wrapped_name$({arg_names_str});
let mut arr = array_new();
// References.$ref_appends$
// Result.{append_res}
arr
}}"
        )
        .as_str(),
        HashMap::from([
            ("function_name".to_string(), function_name),
            ("wrapped_name".to_string(), wrapped_name),
            (
                "ref_appends".to_string(),
                RewriteNode::Modified(ModifiedNode { children: ref_appends }),
            ),
        ]),
    ))
}
/// Returns `true` when the parameter carries exactly one modifier and that modifier is `ref`.
fn is_ref_param(db: &dyn SyntaxGroup, param: &Param) -> bool {
    matches!(param.modifiers(db).elements(db).as_slice(), [Modifier::Ref(_)])
}