pilota_build2/plugin/
mod.rs

1use std::{collections::HashSet, ops::DerefMut, sync::Arc};
2
3use faststr::FastStr;
4use fxhash::FxHashMap;
5use itertools::Itertools;
6use quote::quote;
7
8use crate::{
9    Context,
10    db::RirDatabase,
11    rir::{EnumVariant, Field, Item},
12    symbol::{DefId, EnumRepr},
13    tags::EnumMode,
14    ty::{self, Ty, Visitor},
15};
16
17pub use self::serde::SerdePlugin;
18
19mod serde;
20
21pub trait Plugin: Sync + Send {
22    fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
23        walk_item(self, cx, def_id, item)
24    }
25
26    fn on_field(&mut self, cx: &Context, def_id: DefId, f: Arc<Field>) {
27        walk_field(self, cx, def_id, f)
28    }
29
30    fn on_variant(&mut self, cx: &Context, def_id: DefId, variant: Arc<EnumVariant>) {
31        walk_variant(self, cx, def_id, variant)
32    }
33
34    fn on_emit(&mut self, _cx: &Context) {}
35}
36
37pub trait ClonePlugin: Plugin {
38    fn clone_box(&self) -> Box<dyn ClonePlugin>;
39}
40
41pub struct BoxClonePlugin(Box<dyn ClonePlugin>);
42
43impl BoxClonePlugin {
44    pub fn new<P: ClonePlugin + 'static>(p: P) -> Self {
45        Self(Box::new(p))
46    }
47}
48
49impl Clone for BoxClonePlugin {
50    fn clone(&self) -> Self {
51        Self(self.0.clone_box())
52    }
53}
54
55impl Plugin for BoxClonePlugin {
56    fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
57        self.0.on_item(cx, def_id, item)
58    }
59
60    fn on_field(&mut self, cx: &Context, def_id: DefId, f: Arc<Field>) {
61        self.0.on_field(cx, def_id, f)
62    }
63
64    fn on_variant(&mut self, cx: &Context, def_id: DefId, variant: Arc<EnumVariant>) {
65        self.0.on_variant(cx, def_id, variant)
66    }
67
68    fn on_emit(&mut self, cx: &Context) {
69        self.0.on_emit(cx)
70    }
71}
72
73impl<T> ClonePlugin for T
74    where
75        T: Plugin + Clone + 'static,
76{
77    fn clone_box(&self) -> Box<dyn ClonePlugin> {
78        Box::new(self.clone())
79    }
80}
81
82impl<T> Plugin for &mut T
83    where
84        T: Plugin,
85{
86    fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
87        (*self).on_item(cx, def_id, item)
88    }
89
90    fn on_field(&mut self, cx: &Context, def_id: DefId, f: Arc<Field>) {
91        (*self).on_field(cx, def_id, f)
92    }
93
94    fn on_variant(&mut self, cx: &Context, def_id: DefId, variant: Arc<EnumVariant>) {
95        (*self).on_variant(cx, def_id, variant)
96    }
97
98    fn on_emit(&mut self, cx: &Context) {
99        (*self).on_emit(cx)
100    }
101}
102
103#[allow(clippy::single_match)]
104pub fn walk_item<P: Plugin + ?Sized>(p: &mut P, cx: &Context, _def_id: DefId, item: Arc<Item>) {
105    match &*item {
106        Item::Message(s) => s
107            .fields
108            .iter()
109            .for_each(|f| p.on_field(cx, f.did, f.clone())),
110        Item::Enum(e) => e
111            .variants
112            .iter()
113            .for_each(|v| p.on_variant(cx, v.did, v.clone())),
114        _ => {}
115    }
116}
117
118pub fn walk_field<P: Plugin + ?Sized>(
119    _p: &mut P,
120    _cx: &Context,
121    _def_id: DefId,
122    _field: Arc<Field>,
123) {}
124
125pub fn walk_variant<P: Plugin + ?Sized>(
126    _p: &mut P,
127    _cx: &Context,
128    _def_id: DefId,
129    _variant: Arc<EnumVariant>,
130) {}
131
132pub struct BoxedPlugin;
133
134impl Plugin for BoxedPlugin {
135    fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
136        if let Item::Message(s) = &*item {
137            s.fields.iter().for_each(|f| {
138                if let ty::Path(p) = &f.ty.kind {
139                    if cx.type_graph().is_nested(p.did, def_id) {
140                        cx.with_adjust_mut(f.did, |adj| adj.set_boxed())
141                    }
142                }
143            })
144        }
145        walk_item(self, cx, def_id, item)
146    }
147}
148
149pub struct AutoDerivePlugin<F> {
150    can_derive: FxHashMap<DefId, CanDerive>,
151    predicate: F,
152    attrs: Arc<[FastStr]>,
153}
154
155impl<F> AutoDerivePlugin<F>
156    where
157        F: Fn(&Ty) -> PredicateResult,
158{
159    pub fn new(attrs: Arc<[FastStr]>, f: F) -> Self {
160        Self {
161            can_derive: FxHashMap::default(),
162            predicate: f,
163            attrs,
164        }
165    }
166}
167
/// Verdict of `AutoDerivePlugin` for a single item.
///
/// The enum derives `PartialOrd`/`Ord`, so the declaration order
/// (`Yes < No < Delay`) is observable; keep it stable.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum CanDerive {
    /// The attributes can be applied to the item.
    Yes,
    /// The attributes cannot be applied to the item.
    No,
    /// Not decided within the current traversal: the item (transitively)
    /// references an item that was still being evaluated, i.e. a reference
    /// cycle. When emitting, `AutoDerivePlugin` treats this like `Yes`.
    Delay,
}

/// Result of the per-type predicate given to `AutoDerivePlugin`.
pub enum PredicateResult {
    /// The type cannot derive, so the item containing it is rejected.
    No,
    /// The type itself can derive, but the items referenced by paths inside
    /// it still have to be checked.
    GoOn,
}
180
181#[derive(Default)]
182pub struct PathCollector {
183    paths: Vec<crate::rir::Path>,
184}
185
186impl super::ty::Visitor for PathCollector {
187    fn visit_path(&mut self, path: &crate::rir::Path) {
188        self.paths.push(path.clone())
189    }
190}
191
192impl<F> AutoDerivePlugin<F>
193    where
194        F: Fn(&Ty) -> PredicateResult,
195{
196    fn can_derive(
197        &mut self,
198        cx: &Context,
199        def_id: DefId,
200        visiting: &mut HashSet<DefId>,
201        delayed: &mut HashSet<DefId>,
202    ) -> CanDerive {
203        if let Some(b) = self.can_derive.get(&def_id) {
204            return *b;
205        }
206        if visiting.contains(&def_id) {
207            return CanDerive::Delay;
208        }
209        visiting.insert(def_id);
210        let item = cx.expect_item(def_id);
211        let deps = match &*item {
212            Item::Message(s) => s.fields.iter().map(|f| &f.ty).collect::<Vec<_>>(),
213            Item::Enum(e) => e
214                .variants
215                .iter()
216                .flat_map(|v| &v.fields)
217                .collect::<Vec<_>>(),
218            Item::Service(_) => return CanDerive::No,
219            Item::NewType(t) => vec![&t.ty],
220            Item::Const(_) => return CanDerive::No,
221            Item::Mod(_) => return CanDerive::No,
222        };
223
224        let can_derive = if deps
225            .iter()
226            .any(|t| matches!((self.predicate)(t), PredicateResult::No))
227        {
228            CanDerive::No
229        } else {
230            let paths = deps.iter().flat_map(|t| {
231                let mut visitor = PathCollector::default();
232                visitor.visit(t);
233                visitor.paths
234            });
235            let paths_can_derive = paths
236                .map(|p| (p.did, self.can_derive(cx, p.did, visiting, delayed)))
237                .collect::<Vec<_>>();
238
239            let delayed_count = paths_can_derive
240                .iter()
241                .filter(|(_, p)| *p == CanDerive::Delay)
242                .count();
243
244            if paths_can_derive.iter().any(|(_, p)| *p == CanDerive::No) {
245                delayed.iter().for_each(|def_id| {
246                    self.can_derive.insert(*def_id, CanDerive::No);
247                });
248
249                CanDerive::No
250            } else if delayed_count > 0 {
251                delayed.insert(def_id);
252                CanDerive::Delay
253            } else {
254                CanDerive::Yes
255            }
256        };
257
258        self.can_derive.insert(def_id, can_derive);
259        visiting.remove(&def_id);
260
261        can_derive
262    }
263}
264
265impl<F> Plugin for AutoDerivePlugin<F>
266    where
267        F: Fn(&Ty) -> PredicateResult + Send + Sync,
268{
269    fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
270        self.can_derive(cx, def_id, &mut HashSet::default(), &mut HashSet::default());
271        walk_item(self, cx, def_id, item)
272    }
273
274    fn on_emit(&mut self, cx: &Context) {
275        self.can_derive.iter().for_each(|(def_id, can_derive)| {
276            if !matches!(can_derive, CanDerive::No) {
277                cx.with_adjust_mut(*def_id, |adj| adj.add_attrs(&self.attrs));
278            }
279        })
280    }
281}
282
283impl<T> Plugin for Box<T>
284    where
285        T: Plugin + ?Sized,
286{
287    fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
288        self.deref_mut().on_item(cx, def_id, item)
289    }
290
291    fn on_field(&mut self, cx: &Context, def_id: DefId, f: Arc<Field>) {
292        self.deref_mut().on_field(cx, def_id, f)
293    }
294
295    fn on_emit(&mut self, cx: &Context) {
296        self.deref_mut().on_emit(cx)
297    }
298}
299
300pub struct WithAttrsPlugin(pub Arc<[FastStr]>);
301
302impl Plugin for WithAttrsPlugin {
303    fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
304        match &*item {
305            Item::Message(_) | Item::Enum(_) | Item::NewType(_) => {
306                cx.with_adjust_mut(def_id, |adj| adj.add_attrs(&self.0))
307            }
308            _ => {}
309        }
310        walk_item(self, cx, def_id, item)
311    }
312}
313
314pub struct ImplDefaultPlugin;
315
316impl Plugin for ImplDefaultPlugin {
317    fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
318        match &*item {
319            Item::Message(m) => {
320                let name = cx.rust_name(def_id);
321
322                if m.fields.iter().all(|f| cx.default_val(f).is_none()) {
323                    cx.with_adjust_mut(def_id, |adj| adj.add_attrs(&["#[derive(Default)]".into()]));
324                } else {
325                    let fields = m
326                        .fields
327                        .iter()
328                        .map(|f| {
329                            let name = cx.rust_name(f.did);
330                            let default = cx.default_val(f).map(|v| v.0);
331
332                            if let Some(default) = default {
333                                let mut val = default;
334                                if f.is_optional() {
335                                    val = format!("Some({val})").into()
336                                }
337                                format!("{name}: {val}")
338                            } else {
339                                format!("{name}: Default::default()")
340                            }
341                        })
342                        .join(",\n");
343
344                    cx.with_adjust_mut(def_id, |adj| {
345                        adj.add_nested_item(
346                            format!(
347                                r#"
348                                impl Default for {name} {{
349                                    fn default() -> Self {{
350                                        {name} {{
351                                            {fields}
352                                        }}
353                                    }}
354                                }}
355                            "#
356                            )
357                                .into(),
358                        )
359                    });
360                };
361            }
362            Item::NewType(_) => {
363                cx.with_adjust_mut(def_id, |adj| adj.add_attrs(&["#[derive(Default)]".into()]))
364            }
365            Item::Enum(e) => {
366                if !e.variants.is_empty() {
367                    cx.with_adjust_mut(def_id, |adj| {
368                        adj.add_attrs(&[
369                            "#[derive(::pilota::derivative::Derivative)]".into(),
370                            "#[derivative(Default)]".into(),
371                        ]);
372                    });
373
374                    if let Some(v) = e.variants.first() {
375                        cx.with_adjust_mut(v.did, |adj| {
376                            adj.add_attrs(&["#[derivative(Default)]".into()]);
377                        })
378                    }
379                }
380            }
381            _ => {}
382        }
383        walk_item(self, cx, def_id, item)
384    }
385}
386
387pub struct EnumNumPlugin;
388
389impl Plugin for EnumNumPlugin {
390    fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
391        match &*item {
392            Item::Enum(e)
393            if e.repr.is_some()
394                && cx
395                .node_tags(def_id)
396                .unwrap()
397                .get::<EnumMode>()
398                .copied()
399                .unwrap_or(EnumMode::Enum)
400                == EnumMode::Enum =>
401                {
402                    let name_str = &*cx.rust_name(def_id);
403                    let name = name_str;
404                    let num_ty = match e.repr {
405                        Some(EnumRepr::I32) => quote!(i32),
406                        None => return,
407                    };
408                    let variants = e
409                        .variants
410                        .iter()
411                        .map(|v| {
412                            let variant_name_str = cx.rust_name(v.did);
413                            let variant_name = variant_name_str;
414                            format!(
415                                "{variant_name} => ::std::result::Result::Ok({name}::{variant_name}), \n"
416                            )
417                        })
418                        .join("");
419
420                    let nums = e
421                        .variants
422                        .iter()
423                        .map(|v| {
424                            let variant_name_str = cx.rust_name(v.did);
425                            let variant_name = variant_name_str;
426                            format!(
427                                "const {variant_name}: {num_ty} = {name}::{variant_name} as {num_ty};"
428                            )
429                        })
430                        .join("\n");
431
432                    cx.with_adjust_mut(def_id, |adj| {
433                        adj.add_nested_item(format!(r#"
434                        impl ::std::convert::From<{name}> for {num_ty} {{
435                            fn from(e: {name}) -> Self {{
436                                e as _
437                            }}
438                        }}
439
440                        impl ::std::convert::TryFrom<{num_ty}> for {name} {{
441                            type Error = ::pilota::EnumConvertError<{num_ty}>;
442
443                            #[allow(non_upper_case_globals)]
444                            fn try_from(v: i32) -> ::std::result::Result<Self, ::pilota::EnumConvertError<{num_ty}>> {{
445                                {nums}
446                                match v {{
447                                    {variants}
448                                    _ => ::std::result::Result::Err(::pilota::EnumConvertError::InvalidNum(v, "{name_str}")),
449                                }}
450                            }}
451                        }}"#).into(),
452                        )
453                    });
454                }
455            _ => {}
456        }
457        walk_item(self, cx, def_id, item)
458    }
459}