1use std::{collections::HashSet, ops::DerefMut, sync::Arc};
2
3use faststr::FastStr;
4use fxhash::FxHashMap;
5use itertools::Itertools;
6use quote::quote;
7
8use crate::{
9 Context,
10 db::RirDatabase,
11 rir::{EnumVariant, Field, Item},
12 symbol::{DefId, EnumRepr},
13 tags::EnumMode,
14 ty::{self, Ty, Visitor},
15};
16
17pub use self::serde::SerdePlugin;
18
19mod serde;
20
21pub trait Plugin: Sync + Send {
22 fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
23 walk_item(self, cx, def_id, item)
24 }
25
26 fn on_field(&mut self, cx: &Context, def_id: DefId, f: Arc<Field>) {
27 walk_field(self, cx, def_id, f)
28 }
29
30 fn on_variant(&mut self, cx: &Context, def_id: DefId, variant: Arc<EnumVariant>) {
31 walk_variant(self, cx, def_id, variant)
32 }
33
34 fn on_emit(&mut self, _cx: &Context) {}
35}
36
37pub trait ClonePlugin: Plugin {
38 fn clone_box(&self) -> Box<dyn ClonePlugin>;
39}
40
41pub struct BoxClonePlugin(Box<dyn ClonePlugin>);
42
43impl BoxClonePlugin {
44 pub fn new<P: ClonePlugin + 'static>(p: P) -> Self {
45 Self(Box::new(p))
46 }
47}
48
49impl Clone for BoxClonePlugin {
50 fn clone(&self) -> Self {
51 Self(self.0.clone_box())
52 }
53}
54
55impl Plugin for BoxClonePlugin {
56 fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
57 self.0.on_item(cx, def_id, item)
58 }
59
60 fn on_field(&mut self, cx: &Context, def_id: DefId, f: Arc<Field>) {
61 self.0.on_field(cx, def_id, f)
62 }
63
64 fn on_variant(&mut self, cx: &Context, def_id: DefId, variant: Arc<EnumVariant>) {
65 self.0.on_variant(cx, def_id, variant)
66 }
67
68 fn on_emit(&mut self, cx: &Context) {
69 self.0.on_emit(cx)
70 }
71}
72
73impl<T> ClonePlugin for T
74 where
75 T: Plugin + Clone + 'static,
76{
77 fn clone_box(&self) -> Box<dyn ClonePlugin> {
78 Box::new(self.clone())
79 }
80}
81
82impl<T> Plugin for &mut T
83 where
84 T: Plugin,
85{
86 fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
87 (*self).on_item(cx, def_id, item)
88 }
89
90 fn on_field(&mut self, cx: &Context, def_id: DefId, f: Arc<Field>) {
91 (*self).on_field(cx, def_id, f)
92 }
93
94 fn on_variant(&mut self, cx: &Context, def_id: DefId, variant: Arc<EnumVariant>) {
95 (*self).on_variant(cx, def_id, variant)
96 }
97
98 fn on_emit(&mut self, cx: &Context) {
99 (*self).on_emit(cx)
100 }
101}
102
103#[allow(clippy::single_match)]
104pub fn walk_item<P: Plugin + ?Sized>(p: &mut P, cx: &Context, _def_id: DefId, item: Arc<Item>) {
105 match &*item {
106 Item::Message(s) => s
107 .fields
108 .iter()
109 .for_each(|f| p.on_field(cx, f.did, f.clone())),
110 Item::Enum(e) => e
111 .variants
112 .iter()
113 .for_each(|v| p.on_variant(cx, v.did, v.clone())),
114 _ => {}
115 }
116}
117
118pub fn walk_field<P: Plugin + ?Sized>(
119 _p: &mut P,
120 _cx: &Context,
121 _def_id: DefId,
122 _field: Arc<Field>,
123) {}
124
125pub fn walk_variant<P: Plugin + ?Sized>(
126 _p: &mut P,
127 _cx: &Context,
128 _def_id: DefId,
129 _variant: Arc<EnumVariant>,
130) {}
131
132pub struct BoxedPlugin;
133
134impl Plugin for BoxedPlugin {
135 fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
136 if let Item::Message(s) = &*item {
137 s.fields.iter().for_each(|f| {
138 if let ty::Path(p) = &f.ty.kind {
139 if cx.type_graph().is_nested(p.did, def_id) {
140 cx.with_adjust_mut(f.did, |adj| adj.set_boxed())
141 }
142 }
143 })
144 }
145 walk_item(self, cx, def_id, item)
146 }
147}
148
149pub struct AutoDerivePlugin<F> {
150 can_derive: FxHashMap<DefId, CanDerive>,
151 predicate: F,
152 attrs: Arc<[FastStr]>,
153}
154
155impl<F> AutoDerivePlugin<F>
156 where
157 F: Fn(&Ty) -> PredicateResult,
158{
159 pub fn new(attrs: Arc<[FastStr]>, f: F) -> Self {
160 Self {
161 can_derive: FxHashMap::default(),
162 predicate: f,
163 attrs,
164 }
165 }
166}
167
/// Verdict of the derivability analysis performed by `AutoDerivePlugin`.
///
/// Variant order matters: the derived `Ord` ranks `Yes < No < Delay`.
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum CanDerive {
    /// The attributes can be attached.
    Yes,
    /// The attributes must not be attached.
    No,
    /// Undecided because the analysis reached an item that was still being
    /// analyzed (a reference cycle). Items left in this state are treated
    /// like `Yes` when attributes are emitted.
    Delay,
}
174
/// Answer of the user-supplied predicate for a single type.
pub enum PredicateResult {
    /// The type rules out deriving for the item that contains it.
    No,
    /// No objection; the analysis continues with the items the type refers to.
    GoOn,
}
180
181#[derive(Default)]
182pub struct PathCollector {
183 paths: Vec<crate::rir::Path>,
184}
185
186impl super::ty::Visitor for PathCollector {
187 fn visit_path(&mut self, path: &crate::rir::Path) {
188 self.paths.push(path.clone())
189 }
190}
191
192impl<F> AutoDerivePlugin<F>
193 where
194 F: Fn(&Ty) -> PredicateResult,
195{
196 fn can_derive(
197 &mut self,
198 cx: &Context,
199 def_id: DefId,
200 visiting: &mut HashSet<DefId>,
201 delayed: &mut HashSet<DefId>,
202 ) -> CanDerive {
203 if let Some(b) = self.can_derive.get(&def_id) {
204 return *b;
205 }
206 if visiting.contains(&def_id) {
207 return CanDerive::Delay;
208 }
209 visiting.insert(def_id);
210 let item = cx.expect_item(def_id);
211 let deps = match &*item {
212 Item::Message(s) => s.fields.iter().map(|f| &f.ty).collect::<Vec<_>>(),
213 Item::Enum(e) => e
214 .variants
215 .iter()
216 .flat_map(|v| &v.fields)
217 .collect::<Vec<_>>(),
218 Item::Service(_) => return CanDerive::No,
219 Item::NewType(t) => vec![&t.ty],
220 Item::Const(_) => return CanDerive::No,
221 Item::Mod(_) => return CanDerive::No,
222 };
223
224 let can_derive = if deps
225 .iter()
226 .any(|t| matches!((self.predicate)(t), PredicateResult::No))
227 {
228 CanDerive::No
229 } else {
230 let paths = deps.iter().flat_map(|t| {
231 let mut visitor = PathCollector::default();
232 visitor.visit(t);
233 visitor.paths
234 });
235 let paths_can_derive = paths
236 .map(|p| (p.did, self.can_derive(cx, p.did, visiting, delayed)))
237 .collect::<Vec<_>>();
238
239 let delayed_count = paths_can_derive
240 .iter()
241 .filter(|(_, p)| *p == CanDerive::Delay)
242 .count();
243
244 if paths_can_derive.iter().any(|(_, p)| *p == CanDerive::No) {
245 delayed.iter().for_each(|def_id| {
246 self.can_derive.insert(*def_id, CanDerive::No);
247 });
248
249 CanDerive::No
250 } else if delayed_count > 0 {
251 delayed.insert(def_id);
252 CanDerive::Delay
253 } else {
254 CanDerive::Yes
255 }
256 };
257
258 self.can_derive.insert(def_id, can_derive);
259 visiting.remove(&def_id);
260
261 can_derive
262 }
263}
264
265impl<F> Plugin for AutoDerivePlugin<F>
266 where
267 F: Fn(&Ty) -> PredicateResult + Send + Sync,
268{
269 fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
270 self.can_derive(cx, def_id, &mut HashSet::default(), &mut HashSet::default());
271 walk_item(self, cx, def_id, item)
272 }
273
274 fn on_emit(&mut self, cx: &Context) {
275 self.can_derive.iter().for_each(|(def_id, can_derive)| {
276 if !matches!(can_derive, CanDerive::No) {
277 cx.with_adjust_mut(*def_id, |adj| adj.add_attrs(&self.attrs));
278 }
279 })
280 }
281}
282
283impl<T> Plugin for Box<T>
284 where
285 T: Plugin + ?Sized,
286{
287 fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
288 self.deref_mut().on_item(cx, def_id, item)
289 }
290
291 fn on_field(&mut self, cx: &Context, def_id: DefId, f: Arc<Field>) {
292 self.deref_mut().on_field(cx, def_id, f)
293 }
294
295 fn on_emit(&mut self, cx: &Context) {
296 self.deref_mut().on_emit(cx)
297 }
298}
299
300pub struct WithAttrsPlugin(pub Arc<[FastStr]>);
301
302impl Plugin for WithAttrsPlugin {
303 fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
304 match &*item {
305 Item::Message(_) | Item::Enum(_) | Item::NewType(_) => {
306 cx.with_adjust_mut(def_id, |adj| adj.add_attrs(&self.0))
307 }
308 _ => {}
309 }
310 walk_item(self, cx, def_id, item)
311 }
312}
313
314pub struct ImplDefaultPlugin;
315
316impl Plugin for ImplDefaultPlugin {
317 fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
318 match &*item {
319 Item::Message(m) => {
320 let name = cx.rust_name(def_id);
321
322 if m.fields.iter().all(|f| cx.default_val(f).is_none()) {
323 cx.with_adjust_mut(def_id, |adj| adj.add_attrs(&["#[derive(Default)]".into()]));
324 } else {
325 let fields = m
326 .fields
327 .iter()
328 .map(|f| {
329 let name = cx.rust_name(f.did);
330 let default = cx.default_val(f).map(|v| v.0);
331
332 if let Some(default) = default {
333 let mut val = default;
334 if f.is_optional() {
335 val = format!("Some({val})").into()
336 }
337 format!("{name}: {val}")
338 } else {
339 format!("{name}: Default::default()")
340 }
341 })
342 .join(",\n");
343
344 cx.with_adjust_mut(def_id, |adj| {
345 adj.add_nested_item(
346 format!(
347 r#"
348 impl Default for {name} {{
349 fn default() -> Self {{
350 {name} {{
351 {fields}
352 }}
353 }}
354 }}
355 "#
356 )
357 .into(),
358 )
359 });
360 };
361 }
362 Item::NewType(_) => {
363 cx.with_adjust_mut(def_id, |adj| adj.add_attrs(&["#[derive(Default)]".into()]))
364 }
365 Item::Enum(e) => {
366 if !e.variants.is_empty() {
367 cx.with_adjust_mut(def_id, |adj| {
368 adj.add_attrs(&[
369 "#[derive(::pilota::derivative::Derivative)]".into(),
370 "#[derivative(Default)]".into(),
371 ]);
372 });
373
374 if let Some(v) = e.variants.first() {
375 cx.with_adjust_mut(v.did, |adj| {
376 adj.add_attrs(&["#[derivative(Default)]".into()]);
377 })
378 }
379 }
380 }
381 _ => {}
382 }
383 walk_item(self, cx, def_id, item)
384 }
385}
386
387pub struct EnumNumPlugin;
388
389impl Plugin for EnumNumPlugin {
390 fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
391 match &*item {
392 Item::Enum(e)
393 if e.repr.is_some()
394 && cx
395 .node_tags(def_id)
396 .unwrap()
397 .get::<EnumMode>()
398 .copied()
399 .unwrap_or(EnumMode::Enum)
400 == EnumMode::Enum =>
401 {
402 let name_str = &*cx.rust_name(def_id);
403 let name = name_str;
404 let num_ty = match e.repr {
405 Some(EnumRepr::I32) => quote!(i32),
406 None => return,
407 };
408 let variants = e
409 .variants
410 .iter()
411 .map(|v| {
412 let variant_name_str = cx.rust_name(v.did);
413 let variant_name = variant_name_str;
414 format!(
415 "{variant_name} => ::std::result::Result::Ok({name}::{variant_name}), \n"
416 )
417 })
418 .join("");
419
420 let nums = e
421 .variants
422 .iter()
423 .map(|v| {
424 let variant_name_str = cx.rust_name(v.did);
425 let variant_name = variant_name_str;
426 format!(
427 "const {variant_name}: {num_ty} = {name}::{variant_name} as {num_ty};"
428 )
429 })
430 .join("\n");
431
432 cx.with_adjust_mut(def_id, |adj| {
433 adj.add_nested_item(format!(r#"
434 impl ::std::convert::From<{name}> for {num_ty} {{
435 fn from(e: {name}) -> Self {{
436 e as _
437 }}
438 }}
439
440 impl ::std::convert::TryFrom<{num_ty}> for {name} {{
441 type Error = ::pilota::EnumConvertError<{num_ty}>;
442
443 #[allow(non_upper_case_globals)]
444 fn try_from(v: i32) -> ::std::result::Result<Self, ::pilota::EnumConvertError<{num_ty}>> {{
445 {nums}
446 match v {{
447 {variants}
448 _ => ::std::result::Result::Err(::pilota::EnumConvertError::InvalidNum(v, "{name_str}")),
449 }}
450 }}
451 }}"#).into(),
452 )
453 });
454 }
455 _ => {}
456 }
457 walk_item(self, cx, def_id, item)
458 }
459}