1use std::{collections::HashSet, ops::DerefMut, sync::Arc};
2
3use faststr::FastStr;
4use itertools::Itertools;
5use rustc_hash::FxHashMap;
6
7use crate::{
8 Context,
9 db::RirDatabase,
10 middle::context::tls::CUR_ITEM,
11 rir::{EnumVariant, Field, Item, NodeKind},
12 symbol::DefId,
13 ty::{self, Ty, Visitor},
14};
15
16mod serde;
17mod workspace;
18
19pub use self::serde::SerdePlugin;
20
21pub trait Plugin: Sync + Send {
22 fn on_codegen_uint(&mut self, cx: &Context, items: &[DefId]) {
23 walk_codegen_uint(self, cx, items)
24 }
25
26 fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
27 walk_item(self, cx, def_id, item)
28 }
29
30 fn on_field(&mut self, cx: &Context, def_id: DefId, f: Arc<Field>) {
31 walk_field(self, cx, def_id, f)
32 }
33
34 fn on_variant(&mut self, cx: &Context, def_id: DefId, variant: Arc<EnumVariant>) {
35 walk_variant(self, cx, def_id, variant)
36 }
37
38 fn on_emit(&mut self, _cx: &Context) {}
39}
40
41pub trait ClonePlugin: Plugin {
42 fn clone_box(&self) -> Box<dyn ClonePlugin>;
43}
44
45pub struct BoxClonePlugin(Box<dyn ClonePlugin>);
46
47impl BoxClonePlugin {
48 pub fn new<P: ClonePlugin + 'static>(p: P) -> Self {
49 Self(Box::new(p))
50 }
51}
52
53impl Clone for BoxClonePlugin {
54 fn clone(&self) -> Self {
55 Self(self.0.clone_box())
56 }
57}
58
59impl Plugin for BoxClonePlugin {
60 fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
61 self.0.on_item(cx, def_id, item)
62 }
63
64 fn on_field(&mut self, cx: &Context, def_id: DefId, f: Arc<Field>) {
65 self.0.on_field(cx, def_id, f)
66 }
67
68 fn on_variant(&mut self, cx: &Context, def_id: DefId, variant: Arc<EnumVariant>) {
69 self.0.on_variant(cx, def_id, variant)
70 }
71
72 fn on_emit(&mut self, cx: &Context) {
73 self.0.on_emit(cx)
74 }
75}
76
77impl<T> ClonePlugin for T
78where
79 T: Plugin + Clone + 'static,
80{
81 fn clone_box(&self) -> Box<dyn ClonePlugin> {
82 Box::new(self.clone())
83 }
84}
85
86impl<T> Plugin for &mut T
87where
88 T: Plugin,
89{
90 fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
91 (*self).on_item(cx, def_id, item)
92 }
93
94 fn on_field(&mut self, cx: &Context, def_id: DefId, f: Arc<Field>) {
95 (*self).on_field(cx, def_id, f)
96 }
97
98 fn on_variant(&mut self, cx: &Context, def_id: DefId, variant: Arc<EnumVariant>) {
99 (*self).on_variant(cx, def_id, variant)
100 }
101
102 fn on_emit(&mut self, cx: &Context) {
103 (*self).on_emit(cx)
104 }
105}
106
107#[allow(clippy::single_match)]
108pub fn walk_item<P: Plugin + ?Sized>(p: &mut P, cx: &Context, _def_id: DefId, item: Arc<Item>) {
109 match &*item {
110 Item::Message(s) => s
111 .fields
112 .iter()
113 .for_each(|f| p.on_field(cx, f.did, f.clone())),
114 Item::Enum(e) => e
115 .variants
116 .iter()
117 .for_each(|v| p.on_variant(cx, v.did, v.clone())),
118 _ => {}
119 }
120}
121
122pub fn walk_codegen_uint<P: Plugin + ?Sized>(p: &mut P, cx: &Context, items: &[DefId]) {
123 items.iter().for_each(|def_id| {
124 CUR_ITEM.set(def_id, || {
125 let node = cx.node(*def_id).unwrap();
126 if let NodeKind::Item(item) = &node.kind {
127 p.on_item(cx, *def_id, item.clone())
128 }
129 });
130 });
131}
132
133pub fn walk_field<P: Plugin + ?Sized>(
134 _p: &mut P,
135 _cx: &Context,
136 _def_id: DefId,
137 _field: Arc<Field>,
138) {
139}
140
141pub fn walk_variant<P: Plugin + ?Sized>(
142 _p: &mut P,
143 _cx: &Context,
144 _def_id: DefId,
145 _variant: Arc<EnumVariant>,
146) {
147}
148
149pub struct BoxedPlugin;
150
151impl Plugin for BoxedPlugin {
152 fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
153 if let Item::Message(s) = &*item {
154 s.fields.iter().for_each(|f| {
155 if let ty::Path(p) = &f.ty.kind {
156 if cx.type_graph().is_nested(p.did, def_id) {
157 cx.with_adjust_mut(f.did, |adj| adj.set_boxed())
158 }
159 }
160 })
161 }
162 walk_item(self, cx, def_id, item)
163 }
164}
165
166pub struct AutoDerivePlugin<F> {
167 can_derive: FxHashMap<DefId, CanDerive>,
168 predicate: F,
169 attrs: Arc<[FastStr]>,
170}
171
172impl<F> AutoDerivePlugin<F>
173where
174 F: Fn(&Ty) -> PredicateResult,
175{
176 pub fn new(attrs: Arc<[FastStr]>, f: F) -> Self {
177 Self {
178 can_derive: FxHashMap::default(),
179 predicate: f,
180 attrs,
181 }
182 }
183}
184
/// Verdict computed by `AutoDerivePlugin` for one item.
///
/// The variant order is significant: `Ord`/`PartialOrd` are derived from it.
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum CanDerive {
    /// No dependency rejects the derive.
    Yes,
    /// The item or something it references rejects the derive.
    No,
    /// The answer depends on an item still being evaluated (a reference cycle).
    /// It is treated as derivable unless later downgraded to `No`.
    Delay,
}
191
/// Result of the user predicate for a single dependency type.
pub enum PredicateResult {
    /// This type rules out the derive for the enclosing item.
    No,
    /// This type is acceptable; continue with the items it references.
    GoOn,
}
196
197#[derive(Default)]
198pub struct PathCollector {
199 paths: Vec<crate::rir::Path>,
200}
201
202impl super::ty::Visitor for PathCollector {
203 fn visit_path(&mut self, path: &crate::rir::Path) {
204 self.paths.push(path.clone())
205 }
206}
207
208impl<F> AutoDerivePlugin<F>
209where
210 F: Fn(&Ty) -> PredicateResult,
211{
212 fn can_derive(
213 &mut self,
214 cx: &Context,
215 def_id: DefId,
216 visiting: &mut HashSet<DefId>,
217 delayed: &mut HashSet<DefId>,
218 ) -> CanDerive {
219 if let Some(b) = self.can_derive.get(&def_id) {
220 return *b;
221 }
222 if visiting.contains(&def_id) {
223 return CanDerive::Delay;
224 }
225 visiting.insert(def_id);
226 let item = cx.expect_item(def_id);
227 let deps = match &*item {
228 Item::Message(s) => s.fields.iter().map(|f| &f.ty).collect::<Vec<_>>(),
229 Item::Enum(e) => e
230 .variants
231 .iter()
232 .flat_map(|v| &v.fields)
233 .collect::<Vec<_>>(),
234 Item::Service(_) => return CanDerive::No,
235 Item::NewType(t) => vec![&t.ty],
236 Item::Const(_) => return CanDerive::No,
237 Item::Mod(_) => return CanDerive::No,
238 };
239
240 let can_derive = if deps
241 .iter()
242 .any(|t| matches!((self.predicate)(t), PredicateResult::No))
243 {
244 CanDerive::No
245 } else {
246 let paths = deps.iter().flat_map(|t| {
247 let mut visitor = PathCollector::default();
248 visitor.visit(t);
249 visitor.paths
250 });
251 let paths_can_derive = paths
252 .map(|p| (p.did, self.can_derive(cx, p.did, visiting, delayed)))
253 .collect::<Vec<_>>();
254
255 let delayed_count = paths_can_derive
256 .iter()
257 .filter(|(_, p)| *p == CanDerive::Delay)
258 .count();
259
260 if paths_can_derive.iter().any(|(_, p)| *p == CanDerive::No) {
261 delayed.iter().for_each(|delayed_def_id| {
262 if cx.workspace_graph().is_nested(*delayed_def_id, def_id) {
263 self.can_derive.insert(*delayed_def_id, CanDerive::No);
264 }
265 });
266 CanDerive::No
267 } else if delayed_count > 0 {
268 delayed.insert(def_id);
269 CanDerive::Delay
270 } else {
271 CanDerive::Yes
272 }
273 };
274
275 self.can_derive.insert(def_id, can_derive);
276 visiting.remove(&def_id);
277
278 can_derive
279 }
280}
281
282impl<F> Plugin for AutoDerivePlugin<F>
283where
284 F: Fn(&Ty) -> PredicateResult + Send + Sync,
285{
286 fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
287 self.can_derive(cx, def_id, &mut HashSet::default(), &mut HashSet::default());
288 walk_item(self, cx, def_id, item)
289 }
290
291 fn on_emit(&mut self, cx: &Context) {
292 self.can_derive.iter().for_each(|(def_id, can_derive)| {
293 if !matches!(can_derive, CanDerive::No) {
294 cx.with_adjust_mut(*def_id, |adj| adj.add_attrs(&self.attrs));
295 }
296 })
297 }
298}
299
300impl<T> Plugin for Box<T>
301where
302 T: Plugin + ?Sized,
303{
304 fn on_codegen_uint(&mut self, cx: &Context, items: &[DefId]) {
305 self.deref_mut().on_codegen_uint(cx, items)
306 }
307
308 fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
309 self.deref_mut().on_item(cx, def_id, item)
310 }
311
312 fn on_field(&mut self, cx: &Context, def_id: DefId, f: Arc<Field>) {
313 self.deref_mut().on_field(cx, def_id, f)
314 }
315
316 fn on_emit(&mut self, cx: &Context) {
317 self.deref_mut().on_emit(cx)
318 }
319}
320
321pub struct WithAttrsPlugin(pub Arc<[FastStr]>);
322
323impl Plugin for WithAttrsPlugin {
324 fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
325 match &*item {
326 Item::Message(_) | Item::Enum(_) | Item::NewType(_) => {
327 cx.with_adjust_mut(def_id, |adj| adj.add_attrs(&self.0))
328 }
329 _ => {}
330 }
331 walk_item(self, cx, def_id, item)
332 }
333}
334
335pub struct ImplDefaultPlugin;
336
337impl Plugin for ImplDefaultPlugin {
338 fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
339 match &*item {
340 Item::Message(m) => {
341 let name = cx.rust_name(def_id);
342
343 if m.fields.iter().all(|f| cx.default_val(f).is_none()) {
344 cx.with_adjust_mut(def_id, |adj| adj.add_attrs(&["#[derive(Default)]".into()]));
345 } else {
346 #[allow(unused_mut)]
347 let mut fields = m
348 .fields
349 .iter()
350 .map(|f| {
351 let name = cx.rust_name(f.did);
352 let default = cx.default_val(f).map(|v| v.0);
353 match default {
354 Some(default) => {
355 let mut val = default;
356 if f.is_optional() {
357 val = format!("Some({val})").into()
358 }
359 format!("{name}: {val}")
360 }
361 _ => {
362 format!("{name}: ::std::default::Default::default()")
363 }
364 }
365 })
366 .join(",\n");
367
368 if cx.cache.keep_unknown_fields.contains(&def_id) {
369 if !fields.is_empty() {
370 fields.push_str(",\n");
371 }
372 fields.push_str("_unknown_fields: ::pilota::BytesVec::new()");
373 }
374
375 if !m.is_wrapper && cx.config.with_field_mask {
376 if !fields.is_empty() {
377 fields.push_str(",\n");
378 }
379 fields.push_str("_field_mask: ::std::option::Option::None");
380 }
381
382 cx.with_adjust_mut(def_id, |adj| {
383 adj.add_nested_item(
384 format!(
385 r#"
386 impl ::std::default::Default for {name} {{
387 fn default() -> Self {{
388 {name} {{
389 {fields}
390 }}
391 }}
392 }}
393 "#
394 )
395 .into(),
396 )
397 });
398 };
399 }
400 Item::NewType(_) => {
401 cx.with_adjust_mut(def_id, |adj| adj.add_attrs(&["#[derive(Default)]".into()]))
402 }
403 Item::Enum(e) => {
404 if let Some(first_variant) = e.variants.first() {
405 let is_unit_variant = first_variant.fields.is_empty();
406 if is_unit_variant {
407 cx.with_adjust_mut(def_id, |adj| {
408 adj.add_attrs(&["#[derive(Default)]".into()]);
409 });
410
411 if let Some(v) = e.variants.first() {
412 cx.with_adjust_mut(v.did, |adj| {
413 adj.add_attrs(&["#[default]".into()]);
414 })
415 }
416 } else {
417 let enum_name = cx.rust_name(def_id);
419 let variant_name = cx.rust_name(first_variant.did);
420 let fields = first_variant
421 .fields
422 .iter()
423 .map(|_| "::std::default::Default::default()".to_string())
424 .join(",\n");
425
426 cx.with_adjust_mut(def_id, |adj| {
427 adj.add_nested_item(
428 format!(
429 r#"
430 impl ::std::default::Default for {enum_name} {{
431 fn default() -> Self {{
432 {enum_name}::{variant_name} ({fields})
433 }}
434 }}
435 "#
436 )
437 .into(),
438 )
439 });
440 }
441 }
442 }
443 _ => {}
444 }
445 walk_item(self, cx, def_id, item)
446 }
447}