cairo_lang_plugins/plugins/
utils.rs1use cairo_lang_syntax::node::db::SyntaxGroup;
2use cairo_lang_syntax::node::helpers::{GenericParamEx, IsDependentType};
3use cairo_lang_syntax::node::{Terminal, TypedSyntaxNode, ast};
4use itertools::{Itertools, chain};
5use smol_str::SmolStr;
6
7pub struct MemberInfo {
9 pub name: SmolStr,
10 pub ty: String,
11 pub attributes: ast::AttributeList,
12 pub is_generics_dependent: bool,
13}
14impl MemberInfo {
15 pub fn impl_name(&self, trt: &str) -> String {
16 if self.is_generics_dependent {
17 let short_name = trt.split("::").last().unwrap_or(trt);
18 format!("__MEMBER_IMPL_{}_{short_name}", self.name)
19 } else {
20 format!("{}::<{}>", trt, self.ty)
21 }
22 }
23 pub fn drop_with(&self) -> String {
24 if self.is_generics_dependent {
25 format!("core::internal::DropWith::<{}, {}>", self.ty, self.impl_name("Drop"))
26 } else {
27 format!("core::internal::InferDrop::<{}>", self.ty)
28 }
29 }
30 pub fn destruct_with(&self) -> String {
31 if self.is_generics_dependent {
32 format!("core::internal::DestructWith::<{}, {}>", self.ty, self.impl_name("Destruct"))
33 } else {
34 format!("core::internal::InferDestruct::<{}>", self.ty)
35 }
36 }
37}
38
/// The kind of item a type description was built from.
///
/// The common value-type traits are derived so the variant can be debug-printed, compared and
/// freely copied; the enum carries no data, so these are free.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TypeVariant {
    /// The described item is an enum.
    Enum,
    /// The described item is a struct.
    Struct,
}
44
45pub struct GenericParamsInfo {
47 pub param_names: Vec<SmolStr>,
49 pub full_params: Vec<String>,
51}
52impl GenericParamsInfo {
53 pub fn new(db: &dyn SyntaxGroup, generic_params: ast::OptionWrappedGenericParamList) -> Self {
55 let ast::OptionWrappedGenericParamList::WrappedGenericParamList(gens) = generic_params
56 else {
57 return Self { param_names: Default::default(), full_params: Default::default() };
58 };
59 let (param_names, full_params) = gens
60 .generic_params(db)
61 .elements(db)
62 .map(|param| {
63 let name = param.name(db).map(|n| n.text(db)).unwrap_or_else(|| "_".into());
64 let full_param = param.as_syntax_node().get_text_without_trivia(db);
65 (name, full_param)
66 })
67 .unzip();
68 Self { param_names, full_params }
69 }
70}
71
72pub struct PluginTypeInfo {
74 pub name: SmolStr,
75 pub attributes: ast::AttributeList,
76 pub generics: GenericParamsInfo,
77 pub members_info: Vec<MemberInfo>,
78 pub type_variant: TypeVariant,
79}
80impl PluginTypeInfo {
81 pub fn new(db: &dyn SyntaxGroup, item_ast: &ast::ModuleItem) -> Option<Self> {
83 match item_ast {
84 ast::ModuleItem::Struct(struct_ast) => {
85 let generics = GenericParamsInfo::new(db, struct_ast.generic_params(db));
86 let members_info = extract_members(
87 db,
88 struct_ast.members(db),
89 &generics.param_names.iter().map(|p| p.as_str()).collect_vec(),
90 );
91 Some(Self {
92 name: struct_ast.name(db).text(db),
93 attributes: struct_ast.attributes(db),
94 generics,
95 members_info,
96 type_variant: TypeVariant::Struct,
97 })
98 }
99 ast::ModuleItem::Enum(enum_ast) => {
100 let generics = GenericParamsInfo::new(db, enum_ast.generic_params(db));
101 let members_info = extract_variants(
102 db,
103 enum_ast.variants(db),
104 &generics.param_names.iter().map(|p| p.as_str()).collect_vec(),
105 );
106 Some(Self {
107 name: enum_ast.name(db).text(db),
108 attributes: enum_ast.attributes(db),
109 generics,
110 members_info,
111 type_variant: TypeVariant::Enum,
112 })
113 }
114 _ => None,
115 }
116 }
117
118 pub fn impl_header(&self, derived_trait: &str, dependent_traits: &[&str]) -> String {
121 let derived_trait_name = derived_trait.split("::").last().unwrap_or(derived_trait);
122 format!(
123 "impl {name}{derived_trait_name}<{generics}> of {derived_trait}::<{full_typename}>",
124 name = self.name,
125 generics =
126 self.impl_generics(dependent_traits, |trt, ty| format!("{trt}<{ty}>")).join(", "),
127 full_typename = self.full_typename(),
128 )
129 }
130
131 pub fn impl_generics(
135 &self,
136 dependent_traits: &[&str],
137 dep_req: fn(&str, &str) -> String,
138 ) -> Vec<String> {
139 chain!(
140 self.generics.full_params.iter().cloned(),
141 self.members_info.iter().filter(|m| m.is_generics_dependent).flat_map(|m| {
142 dependent_traits
143 .iter()
144 .cloned()
145 .map(move |trt| format!("impl {}: {}", m.impl_name(trt), dep_req(trt, &m.ty)))
146 })
147 )
148 .collect()
149 }
150
151 pub fn full_typename(&self) -> String {
153 if self.generics.param_names.is_empty() {
154 self.name.to_string()
155 } else {
156 format!("{}<{}>", self.name, self.generics.param_names.iter().join(", "))
157 }
158 }
159}
160
161fn extract_members(
163 db: &dyn SyntaxGroup,
164 members: ast::MemberList,
165 generics: &[&str],
166) -> Vec<MemberInfo> {
167 members
168 .elements(db)
169 .map(|member| MemberInfo {
170 name: member.name(db).text(db),
171 ty: member.type_clause(db).ty(db).as_syntax_node().get_text_without_trivia(db),
172 attributes: member.attributes(db),
173 is_generics_dependent: member.type_clause(db).ty(db).is_dependent_type(db, generics),
174 })
175 .collect()
176}
177
178fn extract_variants(
180 db: &dyn SyntaxGroup,
181 variants: ast::VariantList,
182 generics: &[&str],
183) -> Vec<MemberInfo> {
184 variants
185 .elements(db)
186 .map(|variant| MemberInfo {
187 name: variant.name(db).text(db),
188 ty: match variant.type_clause(db) {
189 ast::OptionTypeClause::Empty(_) => "()".to_string(),
190 ast::OptionTypeClause::TypeClause(t) => {
191 t.ty(db).as_syntax_node().get_text_without_trivia(db)
192 }
193 },
194 attributes: variant.attributes(db),
195 is_generics_dependent: match variant.type_clause(db) {
196 ast::OptionTypeClause::Empty(_) => false,
197 ast::OptionTypeClause::TypeClause(t) => t.ty(db).is_dependent_type(db, generics),
198 },
199 })
200 .collect()
201}