cairo_lang_plugins/plugins/
utils.rs

1use cairo_lang_filesystem::ids::SmolStrId;
2use cairo_lang_syntax::node::helpers::{GenericParamEx, IsDependentType};
3use cairo_lang_syntax::node::{Terminal, TypedSyntaxNode, ast};
4use itertools::{Itertools, chain};
5use salsa::Database;
6
7/// Information on struct members or enum variants.
8pub struct MemberInfo<'a> {
9    pub name: &'a str,
10    pub ty: &'a str,
11    pub attributes: ast::AttributeList<'a>,
12    pub is_generics_dependent: bool,
13}
14impl<'a> MemberInfo<'a> {
15    pub fn impl_name(&self, trt: &str) -> String {
16        if self.is_generics_dependent {
17            let short_name = trt.split("::").last().unwrap_or(trt);
18            format!("__MEMBER_IMPL_{}_{short_name}", self.name)
19        } else {
20            format!("{}::<{}>", trt, self.ty)
21        }
22    }
23    pub fn drop_with(&self) -> String {
24        if self.is_generics_dependent {
25            format!("core::internal::DropWith::<{}, {}>", self.ty, self.impl_name("Drop"))
26        } else {
27            format!("core::internal::InferDrop::<{}>", self.ty)
28        }
29    }
30    pub fn destruct_with(&self) -> String {
31        if self.is_generics_dependent {
32            format!("core::internal::DestructWith::<{}, {}>", self.ty, self.impl_name("Destruct"))
33        } else {
34            format!("core::internal::InferDestruct::<{}>", self.ty)
35        }
36    }
37}
38
/// The kind of type being processed by a plugin.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TypeVariant {
    /// The type is an `enum`; its members are variants.
    Enum,
    /// The type is a `struct`; its members are fields.
    Struct,
}
44
45/// Information on generic params.
46pub struct GenericParamsInfo<'a> {
47    /// All the generic param names, at the original order.
48    pub param_names: Vec<&'a str>,
49    /// The full generic params, including keywords and definitions.
50    pub full_params: Vec<&'a str>,
51}
52impl<'a> GenericParamsInfo<'a> {
53    /// Extracts the information on generic params.
54    pub fn new(
55        db: &'a dyn Database,
56        generic_params: ast::OptionWrappedGenericParamList<'a>,
57    ) -> Self {
58        let ast::OptionWrappedGenericParamList::WrappedGenericParamList(gens) = generic_params
59        else {
60            return Self { param_names: Default::default(), full_params: Default::default() };
61        };
62        let (param_names, full_params) = gens
63            .generic_params(db)
64            .elements(db)
65            .map(|param| {
66                let name =
67                    param.name(db).map(|n| n.text(db).long(db).as_str()).unwrap_or_else(|| "_");
68                let full_param =
69                    param.as_syntax_node().get_text_without_trivia(db).long(db).as_str();
70                (name, full_param)
71            })
72            .unzip();
73        Self { param_names, full_params }
74    }
75}
76
77/// Information for the type being processed by a plugin.
78pub struct PluginTypeInfo<'a> {
79    pub name: &'a str,
80    pub attributes: ast::AttributeList<'a>,
81    pub generics: GenericParamsInfo<'a>,
82    pub members_info: Vec<MemberInfo<'a>>,
83    pub type_variant: TypeVariant,
84}
85impl<'a> PluginTypeInfo<'a> {
86    /// Extracts the information on the type being derived.
87    pub fn new(db: &'a dyn Database, item_ast: &ast::ModuleItem<'a>) -> Option<Self> {
88        match item_ast {
89            ast::ModuleItem::Struct(struct_ast) => {
90                let generics = GenericParamsInfo::new(db, struct_ast.generic_params(db));
91                let interned =
92                    generics.param_names.iter().map(|s| SmolStrId::from(db, *s)).collect_vec();
93                let members_info = extract_members(db, struct_ast.members(db), &interned);
94                Some(Self {
95                    name: struct_ast.name(db).text(db).long(db).as_str(),
96                    attributes: struct_ast.attributes(db),
97                    generics,
98                    members_info,
99                    type_variant: TypeVariant::Struct,
100                })
101            }
102            ast::ModuleItem::Enum(enum_ast) => {
103                let generics = GenericParamsInfo::new(db, enum_ast.generic_params(db));
104                let members_info =
105                    extract_variants(db, enum_ast.variants(db), &generics.param_names);
106                Some(Self {
107                    name: enum_ast.name(db).text(db).long(db).as_str(),
108                    attributes: enum_ast.attributes(db),
109                    generics,
110                    members_info,
111                    type_variant: TypeVariant::Enum,
112                })
113            }
114            _ => None,
115        }
116    }
117
118    /// Returns a full derived impl header - given `derived_trait` - and the `dependent_traits`
119    /// required for all its members.
120    pub fn impl_header(&self, derived_trait: &str, dependent_traits: &[&str]) -> String {
121        let derived_trait_name = derived_trait.split("::").last().unwrap_or(derived_trait);
122        format!(
123            "impl {name}{derived_trait_name}<{generics}> of {derived_trait}::<{full_typename}>",
124            name = self.name,
125            generics =
126                self.impl_generics(dependent_traits, |trt, ty| format!("{trt}<{ty}>")).join(", "),
127            full_typename = self.full_typename(),
128        )
129    }
130
131    /// Returns the expected generics parameters for a derived impl definition.
132    ///
133    /// `dep_req` - is the formatting of a trait and the type as a concrete trait.
134    pub fn impl_generics(
135        &self,
136        dependent_traits: &[&str],
137        dep_req: fn(&str, &str) -> String,
138    ) -> Vec<String> {
139        chain!(
140            self.generics.full_params.iter().map(ToString::to_string),
141            self.members_info.iter().filter(|m| m.is_generics_dependent).flat_map(|m| {
142                dependent_traits
143                    .iter()
144                    .cloned()
145                    .map(move |trt| format!("impl {}: {}", m.impl_name(trt), dep_req(trt, m.ty)))
146            })
147        )
148        .collect()
149    }
150
151    /// Formats the full typename of the type, including generic args.
152    pub fn full_typename(&self) -> String {
153        if self.generics.param_names.is_empty() {
154            self.name.to_string()
155        } else {
156            format!("{}<{}>", self.name, self.generics.param_names.iter().join(", "))
157        }
158    }
159}
160
161/// Extracts the information on the members of the struct.
162fn extract_members<'a>(
163    db: &'a dyn Database,
164    members: ast::MemberList<'a>,
165    generics: &[SmolStrId<'a>],
166) -> Vec<MemberInfo<'a>> {
167    members
168        .elements(db)
169        .map(|member| MemberInfo {
170            name: member.name(db).text(db).long(db).as_str(),
171            ty: member
172                .type_clause(db)
173                .ty(db)
174                .as_syntax_node()
175                .get_text_without_trivia(db)
176                .long(db)
177                .as_str(),
178            attributes: member.attributes(db),
179            is_generics_dependent: member.type_clause(db).ty(db).is_dependent_type(db, generics),
180        })
181        .collect()
182}
183
184/// Extracts the information on the variants of the enum.
185fn extract_variants<'a>(
186    db: &'a dyn Database,
187    variants: ast::VariantList<'a>,
188    generics: &[&str],
189) -> Vec<MemberInfo<'a>> {
190    variants
191        .elements(db)
192        .map(|variant| MemberInfo {
193            name: variant.name(db).text(db).long(db).as_str(),
194            ty: match variant.type_clause(db) {
195                ast::OptionTypeClause::Empty(_) => "()",
196                ast::OptionTypeClause::TypeClause(t) => {
197                    t.ty(db).as_syntax_node().get_text_without_trivia(db).long(db).as_str()
198                }
199            },
200            attributes: variant.attributes(db),
201            is_generics_dependent: match variant.type_clause(db) {
202                ast::OptionTypeClause::Empty(_) => false,
203                ast::OptionTypeClause::TypeClause(t) => {
204                    let interned = generics.iter().map(|s| SmolStrId::from(db, *s)).collect_vec();
205                    t.ty(db).is_dependent_type(db, &interned)
206                }
207            },
208        })
209        .collect()
210}