// cairo_lang_plugins/plugins/derive/mod.rs

1use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
2use cairo_lang_defs::plugin::{
3    MacroPlugin, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult,
4};
5use cairo_lang_filesystem::ids::SmolStrId;
6use cairo_lang_syntax::attribute::structured::{
7    AttributeArg, AttributeArgVariant, AttributeStructurize,
8};
9use cairo_lang_syntax::node::helpers::QueryAttrs;
10use cairo_lang_syntax::node::{TypedSyntaxNode, ast};
11use salsa::Database;
12
13use super::utils::PluginTypeInfo;
14use crate::plugins::DOC_ATTR;
15
16mod clone;
17mod debug;
18mod default;
19mod destruct;
20mod hash;
21mod panic_destruct;
22mod partial_eq;
23mod serde;
24
/// Macro plugin that expands `#[derive(...)]` attributes on module items into generated impls.
///
/// Applying `derive` to an item that is not a `struct` or an `enum` yields an error diagnostic.
#[non_exhaustive]
#[derive(Default, Debug)]
pub struct DerivePlugin;

/// Name of the attribute expanded by [`DerivePlugin`].
const DERIVE_ATTR: &str = "derive";
30
31impl MacroPlugin for DerivePlugin {
32    fn generate_code<'db>(
33        &self,
34        db: &'db dyn Database,
35        item_ast: ast::ModuleItem<'db>,
36        metadata: &MacroPluginMetadata<'_>,
37    ) -> PluginResult<'db> {
38        generate_derive_code_for_type(
39            db,
40            metadata,
41            match PluginTypeInfo::new(db, &item_ast) {
42                Some(info) => info,
43                None => {
44                    let maybe_error = item_ast.find_attr(db, DERIVE_ATTR).map(|derive_attr| {
45                        vec![PluginDiagnostic::error(
46                            derive_attr.as_syntax_node().stable_ptr(db),
47                            "`derive` may only be applied to `struct`s and `enum`s".to_string(),
48                        )]
49                    });
50
51                    return PluginResult {
52                        diagnostics: maybe_error.unwrap_or_default(),
53                        ..PluginResult::default()
54                    };
55                }
56            },
57        )
58    }
59
60    fn declared_attributes<'db>(&self, db: &'db dyn Database) -> Vec<SmolStrId<'db>> {
61        vec![SmolStrId::from(db, DERIVE_ATTR), SmolStrId::from(db, default::DEFAULT_ATTR)]
62    }
63
64    fn declared_derives<'db>(&self, db: &'db dyn Database) -> Vec<SmolStrId<'db>> {
65        vec![
66            SmolStrId::from(db, "Copy"),
67            SmolStrId::from(db, "Drop"),
68            SmolStrId::from(db, "Clone"),
69            SmolStrId::from(db, "Debug"),
70            SmolStrId::from(db, "Default"),
71            SmolStrId::from(db, "Destruct"),
72            SmolStrId::from(db, "Hash"),
73            SmolStrId::from(db, "PanicDestruct"),
74            SmolStrId::from(db, "PartialEq"),
75            SmolStrId::from(db, "Serde"),
76        ]
77    }
78}
79
80/// Adds an implementation for all requested derives for the type.
81fn generate_derive_code_for_type<'db>(
82    db: &'db dyn Database,
83    metadata: &MacroPluginMetadata<'_>,
84    info: PluginTypeInfo<'db>,
85) -> PluginResult<'db> {
86    let mut diagnostics = vec![];
87    let mut builder = PatchBuilder::new(db, &info.attributes);
88    for attr in info.attributes.query_attr(db, DERIVE_ATTR) {
89        let attr = attr.structurize(db);
90
91        if attr.args.is_empty() {
92            diagnostics
93                .push(PluginDiagnostic::error(attr.args_stable_ptr, "Expected args.".into()));
94            continue;
95        }
96        let member_access_desnaps = metadata.edition.member_access_desnaps();
97
98        for arg in attr.args {
99            let AttributeArg {
100                variant: AttributeArgVariant::Unnamed(ast::Expr::Path(derived_path)),
101                ..
102            } = arg
103            else {
104                diagnostics
105                    .push(PluginDiagnostic::error(arg.arg.stable_ptr(db), "Expected path.".into()));
106                continue;
107            };
108
109            let derived = derived_path.as_syntax_node().get_text_without_trivia(db);
110            let derived_text = derived.long(db).as_str();
111            if let Some(mut code) = match derived_text {
112                "Copy" | "Drop" => Some(get_empty_impl(derived_text, &info)),
113                "Clone" => Some(clone::handle_clone(&info, member_access_desnaps)),
114                "Debug" => Some(debug::handle_debug(&info, member_access_desnaps)),
115                "Default" => default::handle_default(db, &info, &derived_path, &mut diagnostics),
116                "Destruct" => Some(destruct::handle_destruct(&info)),
117                "Hash" => Some(hash::handle_hash(&info)),
118                "PanicDestruct" => Some(panic_destruct::handle_panic_destruct(&info)),
119                "PartialEq" => Some(partial_eq::handle_partial_eq(&info, member_access_desnaps)),
120                "Serde" => Some(serde::handle_serde(&info, member_access_desnaps)),
121                _ => {
122                    if !metadata.declared_derives.contains(&derived) {
123                        diagnostics.push(PluginDiagnostic::error(
124                            derived_path.stable_ptr(db),
125                            format!(
126                                "Unknown derive `{derived_text}` - a plugin might be missing.",
127                            ),
128                        ));
129                    }
130                    None
131                }
132            } {
133                if let Some(doc_attr) = info.attributes.find_attr(db, DOC_ATTR) {
134                    code = format!(
135                        "{}\n{code}",
136                        doc_attr.as_syntax_node().get_text_without_trivia(db).long(db)
137                    )
138                }
139                builder.add_modified(RewriteNode::mapped_text(code, db, &derived_path));
140            }
141        }
142    }
143    let (content, code_mappings) = builder.build();
144    PluginResult {
145        code: (!content.is_empty()).then(|| PluginGeneratedFile {
146            name: "impls".into(),
147            code_mappings,
148            content,
149            aux_data: None,
150            diagnostics_note: Default::default(),
151            is_unhygienic: false,
152        }),
153        diagnostics,
154        remove_original_item: false,
155    }
156}
157
158fn get_empty_impl(derived_trait: &str, info: &PluginTypeInfo<'_>) -> String {
159    let derive_trait = format!("core::traits::{derived_trait}");
160    format!("{};\n", info.impl_header(&derive_trait, &[&derive_trait]))
161}