Skip to main content

pilota_build/plugin/
mod.rs

1use std::{collections::HashSet, ops::DerefMut, sync::Arc};
2
3use faststr::FastStr;
4use itertools::Itertools;
5use rustc_hash::FxHashMap;
6
7use crate::{
8    Context,
9    db::RirDatabase,
10    middle::context::tls::CUR_ITEM,
11    rir::{EnumVariant, Field, Item, NodeKind},
12    symbol::DefId,
13    ty::{self, Ty, Visitor},
14};
15
16mod serde;
17mod workspace;
18
19pub use self::serde::SerdePlugin;
20
21pub trait Plugin: Sync + Send {
22    fn on_codegen_uint(&mut self, cx: &Context, items: &[DefId]) {
23        walk_codegen_uint(self, cx, items)
24    }
25
26    fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
27        walk_item(self, cx, def_id, item)
28    }
29
30    fn on_field(&mut self, cx: &Context, def_id: DefId, f: Arc<Field>) {
31        walk_field(self, cx, def_id, f)
32    }
33
34    fn on_variant(&mut self, cx: &Context, def_id: DefId, variant: Arc<EnumVariant>) {
35        walk_variant(self, cx, def_id, variant)
36    }
37
38    fn on_emit(&mut self, _cx: &Context) {}
39}
40
41pub trait ClonePlugin: Plugin {
42    fn clone_box(&self) -> Box<dyn ClonePlugin>;
43}
44
45pub struct BoxClonePlugin(Box<dyn ClonePlugin>);
46
47impl BoxClonePlugin {
48    pub fn new<P: ClonePlugin + 'static>(p: P) -> Self {
49        Self(Box::new(p))
50    }
51}
52
53impl Clone for BoxClonePlugin {
54    fn clone(&self) -> Self {
55        Self(self.0.clone_box())
56    }
57}
58
59impl Plugin for BoxClonePlugin {
60    fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
61        self.0.on_item(cx, def_id, item)
62    }
63
64    fn on_field(&mut self, cx: &Context, def_id: DefId, f: Arc<Field>) {
65        self.0.on_field(cx, def_id, f)
66    }
67
68    fn on_variant(&mut self, cx: &Context, def_id: DefId, variant: Arc<EnumVariant>) {
69        self.0.on_variant(cx, def_id, variant)
70    }
71
72    fn on_emit(&mut self, cx: &Context) {
73        self.0.on_emit(cx)
74    }
75}
76
77impl<T> ClonePlugin for T
78where
79    T: Plugin + Clone + 'static,
80{
81    fn clone_box(&self) -> Box<dyn ClonePlugin> {
82        Box::new(self.clone())
83    }
84}
85
86impl<T> Plugin for &mut T
87where
88    T: Plugin,
89{
90    fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
91        (*self).on_item(cx, def_id, item)
92    }
93
94    fn on_field(&mut self, cx: &Context, def_id: DefId, f: Arc<Field>) {
95        (*self).on_field(cx, def_id, f)
96    }
97
98    fn on_variant(&mut self, cx: &Context, def_id: DefId, variant: Arc<EnumVariant>) {
99        (*self).on_variant(cx, def_id, variant)
100    }
101
102    fn on_emit(&mut self, cx: &Context) {
103        (*self).on_emit(cx)
104    }
105}
106
107#[allow(clippy::single_match)]
108pub fn walk_item<P: Plugin + ?Sized>(p: &mut P, cx: &Context, _def_id: DefId, item: Arc<Item>) {
109    match &*item {
110        Item::Message(s) => s
111            .fields
112            .iter()
113            .for_each(|f| p.on_field(cx, f.did, f.clone())),
114        Item::Enum(e) => e
115            .variants
116            .iter()
117            .for_each(|v| p.on_variant(cx, v.did, v.clone())),
118        _ => {}
119    }
120}
121
122pub fn walk_codegen_uint<P: Plugin + ?Sized>(p: &mut P, cx: &Context, items: &[DefId]) {
123    items.iter().for_each(|def_id| {
124        CUR_ITEM.set(def_id, || {
125            let node = cx.node(*def_id).unwrap();
126            if let NodeKind::Item(item) = &node.kind {
127                p.on_item(cx, *def_id, item.clone())
128            }
129        });
130    });
131}
132
133pub fn walk_field<P: Plugin + ?Sized>(
134    _p: &mut P,
135    _cx: &Context,
136    _def_id: DefId,
137    _field: Arc<Field>,
138) {
139}
140
141pub fn walk_variant<P: Plugin + ?Sized>(
142    _p: &mut P,
143    _cx: &Context,
144    _def_id: DefId,
145    _variant: Arc<EnumVariant>,
146) {
147}
148
149pub struct BoxedPlugin;
150
151impl Plugin for BoxedPlugin {
152    fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
153        if let Item::Message(s) = &*item {
154            s.fields.iter().for_each(|f| {
155                if let ty::Path(p) = &f.ty.kind {
156                    if cx.type_graph().is_nested(p.did, def_id) {
157                        cx.with_adjust_mut(f.did, |adj| adj.set_boxed())
158                    }
159                }
160            })
161        }
162        walk_item(self, cx, def_id, item)
163    }
164}
165
166pub struct AutoDerivePlugin<F> {
167    can_derive: FxHashMap<DefId, CanDerive>,
168    predicate: F,
169    attrs: Arc<[FastStr]>,
170}
171
172impl<F> AutoDerivePlugin<F>
173where
174    F: Fn(&Ty) -> PredicateResult,
175{
176    pub fn new(attrs: Arc<[FastStr]>, f: F) -> Self {
177        Self {
178            can_derive: FxHashMap::default(),
179            predicate: f,
180            attrs,
181        }
182    }
183}
184
/// Outcome of the derivability check for one item.
///
/// Variant order determines the derived `Ord`: `Yes < No < Delay`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum CanDerive {
    /// Every contained type allows the derive.
    Yes,
    /// Some contained or referenced type rules the derive out.
    No,
    /// Undecided: the item was reached again while still being evaluated
    /// (a reference cycle) or depends on such an item; delay to next pass.
    Delay,
}
191
/// Verdict of a derive predicate for a single contained type.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PredicateResult {
    /// The type cannot derive.
    No,
    /// The type itself does not forbid deriving; referenced items still need checking.
    GoOn,
}
196
197#[derive(Default)]
198pub struct PathCollector {
199    paths: Vec<crate::rir::Path>,
200}
201
202impl super::ty::Visitor for PathCollector {
203    fn visit_path(&mut self, path: &crate::rir::Path) {
204        self.paths.push(path.clone())
205    }
206}
207
208impl<F> AutoDerivePlugin<F>
209where
210    F: Fn(&Ty) -> PredicateResult,
211{
212    fn can_derive(
213        &mut self,
214        cx: &Context,
215        def_id: DefId,
216        visiting: &mut HashSet<DefId>,
217        delayed: &mut HashSet<DefId>,
218    ) -> CanDerive {
219        if let Some(b) = self.can_derive.get(&def_id) {
220            return *b;
221        }
222        if visiting.contains(&def_id) {
223            return CanDerive::Delay;
224        }
225        visiting.insert(def_id);
226        let item = cx.expect_item(def_id);
227        let deps = match &*item {
228            Item::Message(s) => s.fields.iter().map(|f| &f.ty).collect::<Vec<_>>(),
229            Item::Enum(e) => e
230                .variants
231                .iter()
232                .flat_map(|v| &v.fields)
233                .collect::<Vec<_>>(),
234            Item::Service(_) => return CanDerive::No,
235            Item::NewType(t) => vec![&t.ty],
236            Item::Const(_) => return CanDerive::No,
237            Item::Mod(_) => return CanDerive::No,
238        };
239
240        let can_derive = if deps
241            .iter()
242            .any(|t| matches!((self.predicate)(t), PredicateResult::No))
243        {
244            CanDerive::No
245        } else {
246            let paths = deps.iter().flat_map(|t| {
247                let mut visitor = PathCollector::default();
248                visitor.visit(t);
249                visitor.paths
250            });
251            let paths_can_derive = paths
252                .map(|p| (p.did, self.can_derive(cx, p.did, visiting, delayed)))
253                .collect::<Vec<_>>();
254
255            let delayed_count = paths_can_derive
256                .iter()
257                .filter(|(_, p)| *p == CanDerive::Delay)
258                .count();
259
260            if paths_can_derive.iter().any(|(_, p)| *p == CanDerive::No) {
261                delayed.iter().for_each(|delayed_def_id| {
262                    if cx.workspace_graph().is_nested(*delayed_def_id, def_id) {
263                        self.can_derive.insert(*delayed_def_id, CanDerive::No);
264                    }
265                });
266                CanDerive::No
267            } else if delayed_count > 0 {
268                delayed.insert(def_id);
269                CanDerive::Delay
270            } else {
271                CanDerive::Yes
272            }
273        };
274
275        self.can_derive.insert(def_id, can_derive);
276        visiting.remove(&def_id);
277
278        can_derive
279    }
280}
281
282impl<F> Plugin for AutoDerivePlugin<F>
283where
284    F: Fn(&Ty) -> PredicateResult + Send + Sync,
285{
286    fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
287        self.can_derive(cx, def_id, &mut HashSet::default(), &mut HashSet::default());
288        walk_item(self, cx, def_id, item)
289    }
290
291    fn on_emit(&mut self, cx: &Context) {
292        self.can_derive.iter().for_each(|(def_id, can_derive)| {
293            if !matches!(can_derive, CanDerive::No) {
294                cx.with_adjust_mut(*def_id, |adj| adj.add_attrs(&self.attrs));
295            }
296        })
297    }
298}
299
300impl<T> Plugin for Box<T>
301where
302    T: Plugin + ?Sized,
303{
304    fn on_codegen_uint(&mut self, cx: &Context, items: &[DefId]) {
305        self.deref_mut().on_codegen_uint(cx, items)
306    }
307
308    fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
309        self.deref_mut().on_item(cx, def_id, item)
310    }
311
312    fn on_field(&mut self, cx: &Context, def_id: DefId, f: Arc<Field>) {
313        self.deref_mut().on_field(cx, def_id, f)
314    }
315
316    fn on_emit(&mut self, cx: &Context) {
317        self.deref_mut().on_emit(cx)
318    }
319}
320
321pub struct WithAttrsPlugin(pub Arc<[FastStr]>);
322
323impl Plugin for WithAttrsPlugin {
324    fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
325        match &*item {
326            Item::Message(_) | Item::Enum(_) | Item::NewType(_) => {
327                cx.with_adjust_mut(def_id, |adj| adj.add_attrs(&self.0))
328            }
329            _ => {}
330        }
331        walk_item(self, cx, def_id, item)
332    }
333}
334
335pub struct ImplDefaultPlugin;
336
337impl Plugin for ImplDefaultPlugin {
338    fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
339        match &*item {
340            Item::Message(m) => {
341                let name = cx.rust_name(def_id);
342
343                if m.fields.iter().all(|f| cx.default_val(f).is_none()) {
344                    cx.with_adjust_mut(def_id, |adj| adj.add_attrs(&["#[derive(Default)]".into()]));
345                } else {
346                    #[allow(unused_mut)]
347                    let mut fields = m
348                        .fields
349                        .iter()
350                        .map(|f| {
351                            let name = cx.rust_name(f.did);
352                            let default = cx.default_val(f).map(|v| v.0);
353                            match default {
354                                Some(default) => {
355                                    let mut val = default;
356                                    if f.is_optional() {
357                                        val = format!("Some({val})").into()
358                                    }
359                                    format!("{name}: {val}")
360                                }
361                                _ => {
362                                    format!("{name}: ::std::default::Default::default()")
363                                }
364                            }
365                        })
366                        .join(",\n");
367
368                    if cx.cache.keep_unknown_fields.contains(&def_id) {
369                        if !fields.is_empty() {
370                            fields.push_str(",\n");
371                        }
372                        fields.push_str("_unknown_fields: ::pilota::BytesVec::new()");
373                    }
374
375                    if !m.is_wrapper && cx.config.with_field_mask {
376                        if !fields.is_empty() {
377                            fields.push_str(",\n");
378                        }
379                        fields.push_str("_field_mask: ::std::option::Option::None");
380                    }
381
382                    cx.with_adjust_mut(def_id, |adj| {
383                        adj.add_nested_item(
384                            format!(
385                                r#"
386                                impl ::std::default::Default for {name} {{
387                                    fn default() -> Self {{
388                                        {name} {{
389                                            {fields}
390                                        }}
391                                    }}
392                                }}
393                            "#
394                            )
395                            .into(),
396                        )
397                    });
398                };
399            }
400            Item::NewType(_) => {
401                cx.with_adjust_mut(def_id, |adj| adj.add_attrs(&["#[derive(Default)]".into()]))
402            }
403            Item::Enum(e) => {
404                if let Some(first_variant) = e.variants.first() {
405                    let is_unit_variant = first_variant.fields.is_empty();
406                    if is_unit_variant {
407                        cx.with_adjust_mut(def_id, |adj| {
408                            adj.add_attrs(&["#[derive(Default)]".into()]);
409                        });
410
411                        if let Some(v) = e.variants.first() {
412                            cx.with_adjust_mut(v.did, |adj| {
413                                adj.add_attrs(&["#[default]".into()]);
414                            })
415                        }
416                    } else {
417                        // for non unit variant, we need to impl Default for the enum
418                        let enum_name = cx.rust_name(def_id);
419                        let variant_name = cx.rust_name(first_variant.did);
420                        let fields = first_variant
421                            .fields
422                            .iter()
423                            .map(|_| "::std::default::Default::default()".to_string())
424                            .join(",\n");
425
426                        cx.with_adjust_mut(def_id, |adj| {
427                            adj.add_nested_item(
428                                format!(
429                                    r#"
430                                    impl ::std::default::Default for {enum_name} {{
431                                        fn default() -> Self {{
432                                            {enum_name}::{variant_name} ({fields})
433                                        }}
434                                    }}
435                                "#
436                                )
437                                .into(),
438                            )
439                        });
440                    }
441                }
442            }
443            _ => {}
444        }
445        walk_item(self, cx, def_id, item)
446    }
447}