Skip to main content

cairo_lang_plugins/plugins/derive/
mod.rs

1use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
2use cairo_lang_defs::plugin::{
3    MacroPlugin, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult,
4};
5use cairo_lang_syntax::attribute::structured::{
6    AttributeArg, AttributeArgVariant, AttributeStructurize,
7};
8use cairo_lang_syntax::node::db::SyntaxGroup;
9use cairo_lang_syntax::node::helpers::QueryAttrs;
10use cairo_lang_syntax::node::{TypedSyntaxNode, ast};
11
12use super::utils::PluginTypeInfo;
13use crate::plugins::DOC_ATTR;
14
15mod clone;
16mod debug;
17mod default;
18mod destruct;
19mod hash;
20mod panic_destruct;
21mod partial_eq;
22mod serde;
23
/// Macro plugin that expands `#[derive(...)]` attributes on `struct`s and `enum`s into trait
/// implementations.
///
/// Construct it with `DerivePlugin::default()`; `#[non_exhaustive]` keeps the unit-struct literal
/// private to this crate.
#[non_exhaustive]
#[derive(Default, Debug)]
pub struct DerivePlugin;

/// Name of the attribute this plugin expands, i.e. the `derive` in `#[derive(...)]`.
const DERIVE_ATTR: &str = "derive";
29
30impl MacroPlugin for DerivePlugin {
31    fn generate_code(
32        &self,
33        db: &dyn SyntaxGroup,
34        item_ast: ast::ModuleItem,
35        metadata: &MacroPluginMetadata<'_>,
36    ) -> PluginResult {
37        generate_derive_code_for_type(
38            db,
39            metadata,
40            match PluginTypeInfo::new(db, &item_ast) {
41                Some(info) => info,
42                None => {
43                    let maybe_error = item_ast.find_attr(db, DERIVE_ATTR).map(|derive_attr| {
44                        vec![PluginDiagnostic::error(
45                            derive_attr.as_syntax_node().stable_ptr(db),
46                            "`derive` may only be applied to `struct`s and `enum`s".to_string(),
47                        )]
48                    });
49
50                    return PluginResult {
51                        diagnostics: maybe_error.unwrap_or_default(),
52                        ..PluginResult::default()
53                    };
54                }
55            },
56        )
57    }
58
59    fn declared_attributes(&self) -> Vec<String> {
60        vec![DERIVE_ATTR.to_string(), default::DEFAULT_ATTR.to_string()]
61    }
62
63    fn declared_derives(&self) -> Vec<String> {
64        vec![
65            "Copy".to_string(),
66            "Drop".to_string(),
67            "Clone".to_string(),
68            "Debug".to_string(),
69            "Default".to_string(),
70            "Destruct".to_string(),
71            "Hash".to_string(),
72            "PanicDestruct".to_string(),
73            "PartialEq".to_string(),
74            "Serde".to_string(),
75        ]
76    }
77}
78
79/// Adds an implementation for all requested derives for the type.
80fn generate_derive_code_for_type(
81    db: &dyn SyntaxGroup,
82    metadata: &MacroPluginMetadata<'_>,
83    info: PluginTypeInfo,
84) -> PluginResult {
85    let mut diagnostics = vec![];
86    let mut builder = PatchBuilder::new(db, &info.attributes);
87    for attr in info.attributes.query_attr(db, DERIVE_ATTR) {
88        let attr = attr.structurize(db);
89
90        if attr.args.is_empty() {
91            diagnostics
92                .push(PluginDiagnostic::error(attr.args_stable_ptr, "Expected args.".into()));
93            continue;
94        }
95
96        for arg in attr.args {
97            let AttributeArg {
98                variant: AttributeArgVariant::Unnamed(ast::Expr::Path(derived_path)),
99                ..
100            } = arg
101            else {
102                diagnostics
103                    .push(PluginDiagnostic::error(arg.arg.stable_ptr(db), "Expected path.".into()));
104                continue;
105            };
106
107            let derived = derived_path.as_syntax_node().get_text_without_trivia(db);
108            if let Some(mut code) = match derived.as_str() {
109                "Copy" | "Drop" => Some(get_empty_impl(&derived, &info)),
110                "Clone" => Some(clone::handle_clone(&info)),
111                "Debug" => Some(debug::handle_debug(&info)),
112                "Default" => default::handle_default(db, &info, &derived_path, &mut diagnostics),
113                "Destruct" => Some(destruct::handle_destruct(&info)),
114                "Hash" => Some(hash::handle_hash(&info)),
115                "PanicDestruct" => Some(panic_destruct::handle_panic_destruct(&info)),
116                "PartialEq" => Some(partial_eq::handle_partial_eq(&info)),
117                "Serde" => Some(serde::handle_serde(&info)),
118                _ => {
119                    if !metadata.declared_derives.contains(&derived) {
120                        diagnostics.push(PluginDiagnostic::error(
121                            derived_path.stable_ptr(db),
122                            format!("Unknown derive `{derived}` - a plugin might be missing."),
123                        ));
124                    }
125                    None
126                }
127            } {
128                if let Some(doc_attr) = info.attributes.find_attr(db, DOC_ATTR) {
129                    code =
130                        format!("{}\n{code}", doc_attr.as_syntax_node().get_text_without_trivia(db))
131                }
132                builder.add_modified(RewriteNode::mapped_text(code, db, &derived_path));
133            }
134        }
135    }
136    let (content, code_mappings) = builder.build();
137    PluginResult {
138        code: (!content.is_empty()).then(|| PluginGeneratedFile {
139            name: "impls".into(),
140            code_mappings,
141            content,
142            aux_data: None,
143            diagnostics_note: Default::default(),
144            is_unhygienic: false,
145        }),
146        diagnostics,
147        remove_original_item: false,
148    }
149}
150
151fn get_empty_impl(derived_trait: &str, info: &PluginTypeInfo) -> String {
152    let derive_trait = format!("core::traits::{derived_trait}");
153    format!("{};\n", info.impl_header(&derive_trait, &[&derive_trait]))
154}