Skip to main content

pilota_build/
resolve.rs

1use std::{ptr::NonNull, sync::Arc};
2
3use ahash::AHashMap;
4use itertools::Itertools;
5use rustc_hash::{FxHashMap, FxHashSet};
6
7use crate::{
8    errors,
9    index::Idx,
10    ir::{self, visit::Visitor},
11    middle::{
12        self,
13        ext::{
14            FileExts, ItemExts, ModExts,
15            pb::{self, Extendee, ExtendeeIndex, ExtendeeType, Extendees, UsedOptions},
16        },
17        rir::{
18            Arg, Const, DefKind, Enum, EnumVariant, Field, FieldKind, File, Item, ItemPath,
19            Literal, Message, Method, MethodSource, NewType, Node, NodeKind, Path, Service,
20        },
21        ty::{self, Ty},
22    },
23    rir::Mod,
24    symbol::{DefId, EnumRepr, FileId, Ident, Symbol},
25    tags::{RustType, RustWrapperArc, TagId, Tags, protobuf::OptionalRepeated},
26    ty::{Folder, TyKind},
27};
28
/// Scope data for one definition collected by [`CollectDef`].
///
/// `CollectDef::def_item` creates one of these for every item it defines, keyed by the item's
/// `DefId` in `Resolver::def_modules`.
struct ModuleData {
    /// Names declared directly inside this definition (nested items, and enum variants via
    /// `CollectDef::def_sym`), split by namespace.
    resolutions: SymbolTable,
    /// Coarse kind of the definition; stored but not read anywhere in the visible code.
    _kind: DefKind,
}
33
/// The scope that [`CollectDef`] currently registers new definitions into.
#[derive(Clone, Copy)]
enum ModuleId {
    /// Top level of a source file; names go into `Resolver::file_sym_map`.
    File(FileId),
    /// Inside another definition; names go into that definition's `ModuleData::resolutions`.
    Node(DefId),
}
39
/// Independent name spaces of a scope; the same symbol can be bound once in each.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Namespace {
    /// Constants and enum variants.
    Value,
    /// Messages, enums, services and new-types.
    Ty,
    /// Nested modules.
    Mod,
}
46
/// First resolution pass: an IR visitor that assigns a `DefId` to every named item (and
/// every enum variant) and binds its name in the enclosing scope's symbol table.
pub struct CollectDef<'a> {
    /// Resolver whose tables and counters are populated.
    resolver: &'a mut Resolver,
    /// Scope currently being populated; `None` outside of `visit_file`.
    parent: Option<ModuleId>,
}
51
52impl<'a> CollectDef<'a> {
53    pub fn new(resolver: &'a mut Resolver) -> Self {
54        CollectDef {
55            resolver,
56            parent: None,
57        }
58    }
59}
60
61impl CollectDef<'_> {
62    fn def_item(&mut self, item: &ir::Item, ns: Namespace) -> DefId {
63        let parent = self.parent.as_ref().unwrap();
64        let did = self.resolver.did_counter.inc_one();
65        let table = match parent {
66            ModuleId::File(file_id) => self.resolver.file_sym_map.entry(*file_id).or_default(),
67            ModuleId::Node(def_id) => {
68                &mut self
69                    .resolver
70                    .def_modules
71                    .get_mut(def_id)
72                    .unwrap()
73                    .resolutions
74            }
75        };
76
77        let name = item.name();
78
79        tracing::debug!("def {} with DefId({:?})", name, did);
80
81        if match ns {
82            Namespace::Value => table.value.insert(name.clone(), did),
83            Namespace::Ty => table.ty.insert(name.clone(), did),
84            Namespace::Mod => table.mods.insert(name.clone(), did),
85        }
86        .is_some()
87        {
88            self.resolver
89                .errors
90                .emit_error(format!("duplicate definition of `{name}`"));
91        };
92
93        self.resolver.def_modules.insert(
94            did,
95            ModuleData {
96                resolutions: Default::default(),
97                _kind: match &item.kind {
98                    ir::ItemKind::Message(_)
99                    | ir::ItemKind::Enum(_)
100                    | ir::ItemKind::Service(_)
101                    | ir::ItemKind::NewType(_) => DefKind::Type,
102                    ir::ItemKind::Const(_) => DefKind::Value,
103                    ir::ItemKind::Mod(_) => DefKind::Mod,
104                    ir::ItemKind::Use(_) => unreachable!(),
105                },
106            },
107        );
108
109        did
110    }
111
112    fn def_sym(&mut self, ns: Namespace, sym: Symbol) {
113        let parent = match self.parent.unwrap() {
114            ModuleId::File(_) => panic!(),
115            ModuleId::Node(def_id) => def_id,
116        };
117
118        tracing::debug!("def {} for {:?} in {:?}", sym, parent, ns);
119
120        let table = match ns {
121            Namespace::Value => {
122                &mut self
123                    .resolver
124                    .def_modules
125                    .get_mut(&parent)
126                    .unwrap()
127                    .resolutions
128                    .value
129            }
130            Namespace::Ty => {
131                &mut self
132                    .resolver
133                    .def_modules
134                    .get_mut(&parent)
135                    .unwrap()
136                    .resolutions
137                    .ty
138            }
139            Namespace::Mod => {
140                &mut self
141                    .resolver
142                    .def_modules
143                    .get_mut(&parent)
144                    .unwrap()
145                    .resolutions
146                    .mods
147            }
148        };
149        let def_id = self.resolver.did_counter.inc_one();
150        table.insert(sym, def_id);
151    }
152}
153
154impl ir::visit::Visitor for CollectDef<'_> {
155    fn visit_file(&mut self, file: Arc<ir::File>) {
156        self.parent = Some(ModuleId::File(file.id));
157        ir::visit::walk_file(self, file);
158        self.parent = None;
159    }
160
161    fn visit_item(&mut self, item: Arc<ir::Item>) {
162        if let Some(did) = match &item.kind {
163            ir::ItemKind::Message(_)
164            | ir::ItemKind::Enum(_)
165            | ir::ItemKind::Service(_)
166            | ir::ItemKind::NewType(_) => Some(self.def_item(&item, Namespace::Ty)),
167            ir::ItemKind::Const(_) => Some(self.def_item(&item, Namespace::Value)),
168            ir::ItemKind::Mod(_) => Some(self.def_item(&item, Namespace::Mod)),
169            ir::ItemKind::Use(_) => None,
170        } {
171            let prev_parent = self.parent.replace(ModuleId::Node(did));
172            if let ir::ItemKind::Enum(e) = &item.kind {
173                e.variants.iter().for_each(|e| {
174                    self.def_sym(Namespace::Value, (*e.name).clone());
175                })
176            }
177            ir::visit::walk_item(self, item);
178            self.parent = prev_parent;
179        }
180    }
181}
182
/// The name → `DefId` bindings of one scope, kept separately per [`Namespace`].
#[derive(Default, Debug)]
pub struct SymbolTable {
    /// Value namespace (constants, enum variants).
    pub(crate) value: AHashMap<Symbol, DefId>,
    /// Type namespace (messages, enums, services, new-types).
    pub(crate) ty: AHashMap<Symbol, DefId>,
    /// Module namespace.
    pub(crate) mods: AHashMap<Symbol, DefId>,
}
189
/// Name resolver that lowers parsed IR files into RIR.
///
/// `resolve_files` first runs [`CollectDef`] over every file to assign `DefId`s, then lowers
/// each file and resolves paths against the collected symbol tables.
pub struct Resolver {
    /// Source of fresh `DefId`s (definitions, enum variants, fields, methods, arguments).
    pub(crate) did_counter: DefId,
    /// Top-level bindings of each file.
    pub(crate) file_sym_map: FxHashMap<FileId, SymbolTable>,
    /// Scope data of every definition registered by `CollectDef::def_item`.
    def_modules: FxHashMap<DefId, ModuleData>,
    /// Stack of scope tables searched by `lower_path`, most recently pushed first.
    ///
    /// The pointers refer to tables owned by this resolver (e.g. pushed by `lower_mod`).
    /// NOTE(review): soundness relies on those tables not being moved or mutated while a
    /// pointer is on the stack.
    blocks: Vec<NonNull<SymbolTable>>,
    /// Definition currently being lowered; recorded as the parent of newly created nodes.
    parent_node: Option<DefId>,
    /// Nodes registered during lowering, by `DefId`.
    nodes: FxHashMap<DefId, Node>,
    /// Source of fresh `TagId`s.
    tags_id_counter: TagId,
    /// Tags recorded during lowering, by `TagId`.
    tags: FxHashMap<TagId, Arc<Tags>>,
    /// File currently being lowered.
    cur_file: Option<FileId>,
    /// IR of every input file, by id.
    ir_files: FxHashMap<FileId, Arc<ir::File>>,
    /// Diagnostics sink; `resolve_files` calls `abort_if_errors` after each pass.
    errors: errors::Handler,
    /// Definitions referenced from method argument types, return types or exceptions.
    args: FxHashSet<DefId>,
    /// Protobuf extendees by index, for collecting pb option references.
    pb_ext_indexes: FxHashMap<ExtendeeIndex, Arc<Extendee>>,
    /// Extendee indexes referenced by lowered `UsedOptions`.
    pb_ext_indexes_used: FxHashSet<ExtendeeIndex>,
}
207
208impl Default for Resolver {
209    fn default() -> Self {
210        Resolver {
211            tags_id_counter: TagId::from_usize(0),
212            tags: Default::default(),
213            blocks: Default::default(),
214            def_modules: Default::default(),
215            did_counter: DefId::from_usize(0),
216            file_sym_map: Default::default(),
217            nodes: Default::default(),
218            ir_files: Default::default(),
219            errors: Default::default(),
220            cur_file: None,
221            parent_node: None,
222            args: Default::default(),
223            pb_ext_indexes: Default::default(),
224            pb_ext_indexes_used: Default::default(),
225        }
226    }
227}
228
/// Output of [`Resolver::resolve_files`].
pub struct ResolveResult {
    /// Lowered files, by id.
    pub files: FxHashMap<FileId, Arc<File>>,
    /// All nodes registered during lowering, by `DefId`.
    pub nodes: FxHashMap<DefId, Node>,
    /// All tags recorded during lowering, by `TagId`.
    pub tags: FxHashMap<TagId, Arc<Tags>>,
    /// Definitions referenced from method signatures (arguments, returns, exceptions).
    pub args: FxHashSet<DefId>,
    /// Protobuf extendees by index.
    pub pb_ext_indexes: FxHashMap<ExtendeeIndex, Arc<Extendee>>,
    /// Extendee indexes referenced by some `UsedOptions`.
    pub pb_ext_indexes_used: FxHashSet<ExtendeeIndex>,
}
237
/// Candidate definitions matched so far while walking a path in
/// `Resolver::find_path_in_table`, with one list per namespace.
pub struct ResolvedSymbols {
    /// Matches in the type namespace.
    ty: Vec<DefId>,
    /// Matches in the value namespace.
    value: Vec<DefId>,
    /// Matches in the module namespace.
    r#mod: Vec<DefId>,
}
243
244impl Resolver {
245    fn get_def_id(&self, ns: Namespace, sym: &Symbol) -> DefId {
246        if let Some(parent) = self.parent_node {
247            *match ns {
248                Namespace::Value => self.def_modules[&parent].resolutions.value.get(sym),
249                Namespace::Ty => self.def_modules[&parent].resolutions.ty.get(sym),
250                Namespace::Mod => self.def_modules[&parent].resolutions.mods.get(sym),
251            }
252            .unwrap()
253        } else {
254            let cur_file = &self.file_sym_map[&self.cur_file.unwrap()];
255            *match ns {
256                Namespace::Value => cur_file.value.get(sym),
257                Namespace::Ty => cur_file.ty.get(sym),
258                Namespace::Mod => cur_file.mods.get(sym),
259            }
260            .unwrap()
261        }
262    }
263
264    pub fn resolve_files(mut self, files: &[Arc<ir::File>]) -> ResolveResult {
265        files.iter().for_each(|f| {
266            let mut collect = CollectDef::new(&mut self);
267            collect.visit_file(f.clone());
268            self.ir_files.insert(f.id, f.clone());
269        });
270
271        self.errors.abort_if_errors();
272
273        let files = files
274            .iter()
275            .map(|f| (f.id, Arc::from(self.lower_file(f))))
276            .collect::<FxHashMap<_, _>>();
277
278        self.errors.abort_if_errors();
279
280        ResolveResult {
281            tags: self.tags,
282            files,
283            nodes: self.nodes,
284            args: self.args,
285            pb_ext_indexes: self.pb_ext_indexes,
286            pb_ext_indexes_used: self.pb_ext_indexes_used,
287        }
288    }
289
290    fn modify_ty_by_tags(&mut self, mut ty: Ty, tags: &Tags) -> Ty {
291        match ty.kind {
292            ty::FastStr
293                if tags
294                    .get::<RustType>()
295                    .map(|repr| repr == "string")
296                    .unwrap_or(false) =>
297            {
298                ty.kind = ty::String;
299            }
300            ty::Bytes
301                if tags
302                    .get::<RustType>()
303                    .map(|repr| repr == "vec")
304                    .unwrap_or(false) =>
305            {
306                ty.kind = ty::BytesVec;
307            }
308            _ => {}
309        }
310
311        if let Some(repr) = tags.get::<RustType>() {
312            if repr == "btree" {
313                struct BTreeFolder;
314                impl Folder for BTreeFolder {
315                    fn fold_ty(&mut self, ty: &Ty) -> Ty {
316                        let kind = match &ty.kind {
317                            TyKind::Vec(inner) => {
318                                TyKind::Vec(Arc::new(self.fold_ty(inner.as_ref())))
319                            }
320                            TyKind::Set(inner) => {
321                                TyKind::BTreeSet(Arc::new(self.fold_ty(inner.as_ref())))
322                            }
323                            TyKind::Map(k, v) => TyKind::BTreeMap(
324                                Arc::new(self.fold_ty(k.as_ref())),
325                                Arc::new(self.fold_ty(v.as_ref())),
326                            ),
327                            kind => kind.clone(),
328                        };
329                        Ty {
330                            kind,
331                            tags_id: ty.tags_id,
332                        }
333                    }
334                }
335                ty = BTreeFolder.fold_ty(&ty);
336            } else if repr == "ordered_f64" {
337                ty.kind = ty::OrderedF64;
338            }
339        };
340
341        if let Some(RustWrapperArc(true)) = tags.get::<RustWrapperArc>() {
342            struct ArcFolder;
343            impl Folder for ArcFolder {
344                fn fold_ty(&mut self, ty: &Ty) -> Ty {
345                    let kind = match &ty.kind {
346                        TyKind::Vec(inner) => TyKind::Vec(Arc::new(self.fold_ty(inner.as_ref()))),
347                        TyKind::Set(inner) => TyKind::Set(Arc::new(self.fold_ty(inner.as_ref()))),
348                        TyKind::BTreeSet(inner) => {
349                            TyKind::BTreeSet(Arc::new(self.fold_ty(inner.as_ref())))
350                        }
351                        TyKind::Map(k, v) => {
352                            TyKind::Map(k.clone(), Arc::new(self.fold_ty(v.as_ref())))
353                        }
354                        TyKind::BTreeMap(k, v) => {
355                            TyKind::BTreeMap(k.clone(), Arc::new(self.fold_ty(v.as_ref())))
356                        }
357                        TyKind::Path(_) | TyKind::String | TyKind::BytesVec => {
358                            TyKind::Arc(Arc::new(ty.clone()))
359                        }
360                        _ => panic!("ty: {ty:?} is unnecessary to be wrapped by Arc"),
361                    };
362                    Ty {
363                        kind,
364                        tags_id: ty.tags_id,
365                    }
366                }
367            }
368            ArcFolder.fold_ty(&ty)
369        } else {
370            ty
371        }
372    }
373
374    fn lower_pb_extendees(&mut self, extendees: &ir::ext::pb::Extendees) -> Extendees {
375        Extendees(
376            extendees
377                .0
378                .iter()
379                .map(|e| self.lower_pb_extendee(e))
380                .collect::<Vec<_>>(),
381        )
382    }
383
384    fn lower_file_exts(&mut self, exts: &ir::ext::FileExts) -> FileExts {
385        match exts {
386            ir::ext::FileExts::Pb(exts) => FileExts::Pb(pb::FileExts {
387                well_known_file_name: exts.well_known_file_name,
388                extendees: self.lower_pb_extendees(&exts.extendees),
389                used_options: self.lower_used_options(&exts.used_options),
390            }),
391            ir::ext::FileExts::Thrift => FileExts::Thrift,
392        }
393    }
394
395    fn lower_mod_exts(&mut self, exts: &ir::ext::ModExts) -> ModExts {
396        match exts {
397            ir::ext::ModExts::Pb(exts) => ModExts::Pb(pb::ModExts {
398                extendees: self.lower_pb_extendees(&exts.extendees),
399            }),
400            ir::ext::ModExts::Thrift => ModExts::Thrift,
401        }
402    }
403
404    fn lower_item_exts(&mut self, exts: &ir::ext::ItemExts) -> ItemExts {
405        match exts {
406            ir::ext::ItemExts::Pb(exts) => ItemExts::Pb(pb::ItemExts {
407                used_options: self.lower_used_options(&exts.used_options),
408                parent: exts
409                    .parent
410                    .as_ref()
411                    .map(|p| self.lower_path(p, Namespace::Ty, false)),
412            }),
413            ir::ext::ItemExts::Thrift => ItemExts::Thrift,
414        }
415    }
416
417    fn lower_used_options(&mut self, exts: &ir::ext::pb::UsedOptions) -> UsedOptions {
418        UsedOptions(
419            exts.0
420                .iter()
421                .map(|index| {
422                    let idx = (*index).into();
423                    self.pb_ext_indexes_used.insert(idx);
424                    idx
425                })
426                .collect(),
427        )
428    }
429
430    #[tracing::instrument(level = "debug", skip_all, fields(name = &**f.name))]
431    fn lower_field(&mut self, f: &ir::Field) -> Arc<Field> {
432        tracing::info!("lower filed {}, ty: {:?}", f.name, f.ty.kind);
433        let did = self.did_counter.inc_one();
434        let tags_id = self.tags_id_counter.inc_one();
435        self.tags.insert(tags_id, f.tags.clone());
436        let ty = self.lower_type(&f.ty, false);
437        let ty = self.modify_ty_by_tags(ty, &f.tags);
438
439        let mut kind = match f.kind {
440            ir::FieldKind::Required => FieldKind::Required,
441            ir::FieldKind::Optional => FieldKind::Optional,
442        };
443        if let Some(OptionalRepeated(true)) = f.tags.get::<OptionalRepeated>() {
444            kind = FieldKind::Optional;
445        }
446
447        let f = Arc::from(Field {
448            leading_comments: f.leading_comments.clone(),
449            trailing_comments: f.trailing_comments.clone(),
450            did,
451            id: f.id,
452            kind,
453            name: f.name.clone(),
454            ty,
455            tags_id,
456            default: f.default.as_ref().map(|d| self.lower_lit(d)),
457            item_exts: self.lower_item_exts(&f.item_exts),
458        });
459
460        self.nodes
461            .insert(did, self.mk_node(NodeKind::Field(f.clone()), tags_id));
462
463        f
464    }
465
466    fn mk_node(&self, kind: NodeKind, tags: TagId) -> Node {
467        Node {
468            related_nodes: Default::default(),
469            tags,
470            parent: self.parent_node,
471            file_id: self.cur_file.unwrap(),
472            kind,
473        }
474    }
475
476    fn lower_type(&mut self, ty: &ir::Ty, is_args: bool) -> Ty {
477        let kind = match &ty.kind {
478            ir::TyKind::String => ty::FastStr,
479            ir::TyKind::Void => ty::Void,
480            ir::TyKind::U8 => ty::U8,
481            ir::TyKind::Bool => ty::Bool,
482            ir::TyKind::Bytes => ty::Bytes,
483            ir::TyKind::I8 => ty::I8,
484            ir::TyKind::I16 => ty::I16,
485            ir::TyKind::I32 => ty::I32,
486            ir::TyKind::I64 => ty::I64,
487            ir::TyKind::F64 => ty::F64,
488            ir::TyKind::Uuid => ty::Uuid,
489            ir::TyKind::Vec(ty) => ty::Vec(Arc::from(self.lower_type(ty, false))),
490            ir::TyKind::Set(ty) => ty::Set(Arc::from(self.lower_type_for_hash_key(ty, false))),
491            ir::TyKind::Map(k, v) => ty::Map(
492                Arc::from(self.lower_type_for_hash_key(k, false)),
493                Arc::from(self.lower_type(v, false)),
494            ),
495            ir::TyKind::Path(p) => ty::Path(self.lower_path(p, Namespace::Ty, is_args)),
496            ir::TyKind::UInt64 => ty::UInt64,
497            ir::TyKind::UInt32 => ty::UInt32,
498            ir::TyKind::F32 => ty::F32,
499        };
500        let tags_id = self.tags_id_counter.inc_one();
501
502        self.tags.insert(tags_id, ty.tags.clone());
503
504        Ty { kind, tags_id }
505    }
506
507    fn lower_type_for_hash_key(&mut self, ty: &ir::Ty, is_args: bool) -> Ty {
508        let kind = match &ty.kind {
509            ir::TyKind::String => ty::FastStr,
510            ir::TyKind::Void => ty::Void,
511            ir::TyKind::U8 => ty::U8,
512            ir::TyKind::Bool => ty::Bool,
513            ir::TyKind::Bytes => ty::Bytes,
514            ir::TyKind::I8 => ty::I8,
515            ir::TyKind::I16 => ty::I16,
516            ir::TyKind::I32 => ty::I32,
517            ir::TyKind::I64 => ty::I64,
518            ir::TyKind::F64 => ty::OrderedF64,
519            ir::TyKind::Uuid => ty::Uuid,
520            ir::TyKind::Vec(ty) => ty::Vec(Arc::from(self.lower_type_for_hash_key(ty, false))),
521            ir::TyKind::Set(ty) => ty::Set(Arc::from(self.lower_type_for_hash_key(ty, false))),
522            ir::TyKind::Map(k, v) => ty::Map(
523                Arc::from(self.lower_type_for_hash_key(k, false)),
524                Arc::from(self.lower_type(v, false)),
525            ),
526            ir::TyKind::Path(p) => ty::Path(self.lower_path(p, Namespace::Ty, is_args)),
527            ir::TyKind::UInt64 => ty::UInt64,
528            ir::TyKind::UInt32 => ty::UInt32,
529            ir::TyKind::F32 => ty::F32,
530        };
531        let tags_id = self.tags_id_counter.inc_one();
532
533        self.tags.insert(tags_id, ty.tags.clone());
534
535        Ty { kind, tags_id }
536    }
537
538    fn find_path_in_table(
539        &self,
540        path: &[Ident],
541        ns: Namespace,
542        table: &SymbolTable,
543    ) -> Option<DefId> {
544        assert!(!path.is_empty());
545        let mut status: ResolvedSymbols = ResolvedSymbols {
546            ty: table
547                .ty
548                .get(&path[0].sym)
549                .map_or_else(Default::default, |s| vec![*s]),
550            value: table
551                .value
552                .get(&path[0].sym)
553                .map_or_else(Default::default, |s| vec![*s]),
554            r#mod: table
555                .mods
556                .get(&path[0].sym)
557                .map_or_else(Default::default, |s| vec![*s]),
558        };
559
560        path[1..].iter().for_each(|i| {
561            status = ResolvedSymbols {
562                ty: [&status.ty, &status.value, &status.r#mod]
563                    .into_iter()
564                    .flatten()
565                    .flat_map(|def_id| {
566                        self.def_modules
567                            .get(def_id)
568                            .and_then(|module| module.resolutions.ty.get(&i.sym))
569                    })
570                    .copied()
571                    .collect(),
572                value: [&status.ty, &status.value, &status.r#mod]
573                    .into_iter()
574                    .flatten()
575                    .flat_map(|def_id| {
576                        self.def_modules
577                            .get(def_id)
578                            .and_then(|module| module.resolutions.value.get(&i.sym))
579                    })
580                    .copied()
581                    .collect(),
582                r#mod: [&status.ty, &status.value, &status.r#mod]
583                    .into_iter()
584                    .flatten()
585                    .flat_map(|def_id| {
586                        self.def_modules
587                            .get(def_id)
588                            .and_then(|module| module.resolutions.mods.get(&i.sym))
589                    })
590                    .copied()
591                    .collect_vec(),
592            };
593        });
594
595        assert!(status.value.len() <= 1);
596        assert!(status.ty.len() <= 1);
597        assert!(status.r#mod.len() <= 1);
598
599        match ns {
600            Namespace::Value => status.value.first(),
601            Namespace::Ty => status.ty.first(),
602            Namespace::Mod => status.r#mod.first(),
603        }
604        .copied()
605    }
606
    /// Resolves an IR path to the definition it names in namespace `ns`.
    ///
    /// Lookup order:
    /// 1. The scope tables on `self.blocks`, most recently pushed first. The path is matched
    ///    with the current file's package prefix removed, when that prefix is present.
    /// 2. Otherwise, the current file's `uses`. For each `(prefix, file_id)` entry whose
    ///    prefix starts the full path, the remaining segments are resolved in that file's
    ///    top-level symbol table.
    ///
    /// If `is_args` is true, the resolved `DefId` is also recorded in `self.args`.
    ///
    /// # Panics
    /// Panics if there is no current file, if `ns` is `Namespace::Mod`, or if the path cannot
    /// be resolved.
    fn lower_path(&mut self, path: &ir::Path, ns: Namespace, is_args: bool) -> Path {
        let segs = &path.segments;
        let cur_file = self.ir_files.get(self.cur_file.as_ref().unwrap()).unwrap();
        let path_kind = match ns {
            Namespace::Value => DefKind::Value,
            Namespace::Ty => DefKind::Type,
            // Visible callers only pass `Value` or `Ty`.
            Namespace::Mod => unreachable!(),
        };
        {
            // A path may be qualified with the current file's own package; strip it so the
            // remainder can be matched against local scopes.
            let segs = match segs.strip_prefix(&*cur_file.package.segments) {
                Some(segs) => segs,
                _ => segs,
            };

            // Innermost scope wins: search the block stack from the top down.
            let def_id = self.blocks.iter().rev().find_map(|b| {
                // SAFETY: relies on every pointer in `blocks` referring to a `SymbolTable`
                // owned by `self` that is neither moved nor mutated while the pointer is on
                // the stack (see `lower_mod`, which pushes and pops around lowering).
                // NOTE(review): not enforced by the type system — confirm for any other code
                // that pushes onto `blocks`.
                let b = unsafe { b.as_ref() };
                self.find_path_in_table(segs, ns, b)
            });

            if let Some(def_id) = def_id {
                if is_args {
                    self.args.insert(def_id);
                }
                return Path {
                    kind: path_kind,
                    did: def_id,
                };
            }
        }
        // Fall back to the files imported by the current file. Here the full path (including
        // any package prefix) is matched against each `use` prefix.
        let def_id = cur_file
            .uses
            .iter()
            .find_map(|f| match path.segments.strip_prefix(&*f.0.segments) {
                Some(rest) => {
                    let file = &self.file_sym_map[&f.1];
                    self.find_path_in_table(rest, ns, file)
                }
                _ => None,
            })
            .unwrap_or_else(|| {
                panic!(
                    "can not find path {} in file symbols {:?}, {:?}",
                    path,
                    self.file_sym_map.get(&self.cur_file.unwrap()),
                    cur_file.uses,
                )
            });

        if is_args {
            self.args.insert(def_id);
        }
        Path {
            kind: path_kind,
            did: def_id,
        }
    }
663
664    #[tracing::instrument(level = "debug", skip(self, s), fields(name = &**s.name))]
665    fn lower_message(&mut self, s: &ir::Message) -> Message {
666        Message {
667            leading_comments: s.leading_comments.clone(),
668            trailing_comments: s.trailing_comments.clone(),
669            name: s.name.clone(),
670            fields: s.fields.iter().map(|f| self.lower_field(f)).collect(),
671            is_wrapper: s.is_wrapper,
672            item_exts: self.lower_item_exts(&s.item_exts),
673        }
674    }
675
676    fn lower_enum(&mut self, e: &ir::Enum) -> Enum {
677        let mut next_discr = 0;
678        Enum {
679            leading_comments: e.leading_comments.clone(),
680            trailing_comments: e.trailing_comments.clone(),
681            name: e.name.clone(),
682            variants: e
683                .variants
684                .iter()
685                .map(|v| {
686                    let tags_id = self.tags_id_counter.inc_one();
687                    let did = self.get_def_id(Namespace::Value, &v.name);
688                    if !v.tags.is_empty() {
689                        self.tags.insert(tags_id, v.tags.clone());
690                    }
691                    let discr = v.discr.unwrap_or(next_discr);
692                    let e = Arc::from(EnumVariant {
693                        leading_comments: v.leading_comments.clone(),
694                        trailing_comments: v.trailing_comments.clone(),
695                        id: v.id,
696                        did,
697                        name: v.name.clone(),
698                        discr: if e.repr == Some(EnumRepr::I32) {
699                            Some(discr)
700                        } else {
701                            None
702                        },
703                        fields: v
704                            .fields
705                            .iter()
706                            .map(|p| {
707                                let ty = self.lower_type(p, false);
708                                self.modify_ty_by_tags(ty, &p.tags)
709                            })
710                            .collect(),
711                        item_exts: self.lower_item_exts(&v.item_exts),
712                    });
713                    next_discr = discr + 1;
714                    self.nodes
715                        .insert(did, self.mk_node(NodeKind::Variant(e.clone()), tags_id));
716                    e
717                })
718                .collect(),
719            repr: e.repr,
720            item_exts: self.lower_item_exts(&e.item_exts),
721        }
722    }
723
    /// Lowers a service.
    ///
    /// Each method gets a fresh `DefId` and `TagId` and a `NodeKind::Method` node. Each
    /// argument gets a fresh `TagId` and `DefId` and a `NodeKind::Arg` node.
    ///
    /// While a method's signature is lowered, `parent_node` points at the method, so argument
    /// nodes are parented to it. The method node itself is created after `parent_node` has
    /// been restored, so it is parented to the enclosing definition.
    ///
    /// Argument types, the return type and the exception path are lowered with
    /// `is_args = true`, which records the referenced definitions in `self.args`. `extend`
    /// paths are resolved in the type namespace.
    fn lower_service(&mut self, s: &ir::Service) -> Service {
        Service {
            leading_comments: s.leading_comments.clone(),
            trailing_comments: s.trailing_comments.clone(),
            name: s.name.clone(),
            methods: s
                .methods
                .iter()
                .map(|m| {
                    let def_id = self.did_counter.inc_one();
                    let tags_id = self.tags_id_counter.inc_one();
                    self.tags.insert(tags_id, m.tags.clone());
                    // Make the method the parent of the argument nodes created below.
                    let old_parent = self.parent_node.replace(def_id);
                    let method = Arc::from(Method {
                        leading_comments: m.leading_comments.clone(),
                        trailing_comments: m.trailing_comments.clone(),
                        def_id,
                        source: MethodSource::Own,
                        name: m.name.clone(),
                        args: m
                            .args
                            .iter()
                            .map(|a| {
                                // Note: for arguments the TagId is allocated before the DefId.
                                let tags_id = self.tags_id_counter.inc_one();
                                self.tags.insert(tags_id, a.tags.clone());
                                let def_id = self.did_counter.inc_one();
                                let arg = Arc::new(Arg {
                                    def_id,
                                    ty: self.lower_type(&a.ty, true),
                                    name: a.name.clone(),
                                    id: a.id,
                                    tags_id,
                                    kind: match a.attribute {
                                        ir::FieldKind::Required => FieldKind::Required,
                                        ir::FieldKind::Optional => FieldKind::Optional,
                                    },
                                });
                                self.nodes.insert(
                                    def_id,
                                    self.mk_node(NodeKind::Arg(arg.clone()), tags_id),
                                );
                                arg
                            })
                            .collect(),
                        ret: self.lower_type(&m.ret, true),
                        oneway: m.oneway,
                        exceptions: m
                            .exceptions
                            .as_ref()
                            .map(|p| self.lower_path(p, Namespace::Ty, true)),
                        item_exts: self.lower_item_exts(&m.item_exts),
                    });
                    // Restore before creating the method node so it is parented to the
                    // enclosing definition rather than to itself.
                    self.parent_node = old_parent;
                    self.nodes.insert(
                        def_id,
                        self.mk_node(NodeKind::Method(method.clone()), tags_id),
                    );

                    method
                })
                .collect(),
            extend: s
                .extend
                .iter()
                .map(|p| self.lower_path(p, Namespace::Ty, false))
                .collect(),
            item_exts: self.lower_item_exts(&s.item_exts),
        }
    }
793
794    fn lower_type_alias(&mut self, t: &ir::NewType, tags: &Tags) -> NewType {
795        NewType {
796            leading_comments: t.leading_comments.clone(),
797            trailing_comments: t.trailing_comments.clone(),
798            name: t.name.clone(),
799            ty: {
800                let ty = self.lower_type(&t.ty, false);
801                self.modify_ty_by_tags(ty, tags)
802            },
803        }
804    }
805
806    fn lower_lit(&mut self, l: &ir::Literal) -> Literal {
807        match l {
808            ir::Literal::Bool(b) => Literal::Bool(*b),
809            ir::Literal::Path(p) => Literal::Path(self.lower_path(p, Namespace::Value, false)),
810            ir::Literal::String(s) => Literal::String(s.clone()),
811            ir::Literal::Int(i) => Literal::Int(*i),
812            ir::Literal::Float(f) => Literal::Float(f.clone()),
813            ir::Literal::List(l) => Literal::List(l.iter().map(|l| self.lower_lit(l)).collect()),
814            ir::Literal::Map(l) => Literal::Map(
815                l.iter()
816                    .map(|(k, v)| (self.lower_lit(k), self.lower_lit(v)))
817                    .collect(),
818            ),
819        }
820    }
821
822    fn lower_const(&mut self, c: &ir::Const, tags: &Tags) -> Const {
823        Const {
824            leading_comments: c.leading_comments.clone(),
825            trailing_comments: c.trailing_comments.clone(),
826            name: c.name.clone(),
827            ty: {
828                let ty = self.lower_type(&c.ty, false);
829                self.modify_ty_by_tags(ty, tags)
830            },
831            lit: self.lower_lit(&c.lit),
832        }
833    }
834
835    fn lower_mod(&mut self, m: &ir::Mod, def_id: DefId) -> Mod {
836        self.blocks.push(NonNull::from(
837            &self.def_modules.get(&def_id).unwrap().resolutions,
838        ));
839
840        let items = m
841            .items
842            .iter()
843            .filter_map(|i| self.lower_item(i))
844            .collect::<Vec<_>>();
845
846        self.blocks.pop();
847
848        Mod {
849            name: m.name.clone(),
850            items,
851            extensions: self.lower_mod_exts(&m.extensions),
852        }
853    }
854
855    fn lower_item(&mut self, item: &ir::Item) -> Option<DefId> {
856        if let ir::ItemKind::Use(_) = &item.kind {
857            return None;
858        }
859
860        let name = item.name();
861        let tags = &item.tags;
862
863        let def_id = self.get_def_id(
864            match &item.kind {
865                ir::ItemKind::Const(_) => Namespace::Value,
866                ir::ItemKind::Mod(_) => Namespace::Mod,
867                _ => Namespace::Ty,
868            },
869            &name,
870        );
871
872        let old_parent = self.parent_node.replace(def_id);
873        let related_items = &item.related_items;
874
875        let item = Arc::new(match &item.kind {
876            ir::ItemKind::Message(s) => Item::Message(self.lower_message(s)),
877            ir::ItemKind::Enum(e) => Item::Enum(self.lower_enum(e)),
878            ir::ItemKind::Service(s) => Item::Service(self.lower_service(s)),
879            ir::ItemKind::NewType(t) => Item::NewType(self.lower_type_alias(t, tags)),
880            ir::ItemKind::Const(c) => Item::Const(self.lower_const(c, tags)),
881            ir::ItemKind::Mod(m) => Item::Mod(self.lower_mod(m, def_id)),
882            ir::ItemKind::Use(_) => unreachable!(),
883        });
884
885        self.parent_node = old_parent;
886
887        let tags_id = self.tags_id_counter.inc_one();
888        self.tags.insert(tags_id, tags.clone());
889
890        let mut node = self.mk_node(NodeKind::Item(item), tags_id);
891        node.related_nodes = related_items
892            .iter()
893            .map(|i| {
894                self.lower_path(
895                    &ir::Path {
896                        segments: Arc::from([i.clone()]),
897                    },
898                    Namespace::Ty,
899                    false,
900                )
901                .did
902            })
903            .collect();
904
905        self.nodes.insert(def_id, node);
906
907        Some(def_id)
908    }
909
910    fn lower_file(&mut self, file: &ir::File) -> File {
911        let old_file = self.cur_file.replace(file.id);
912        let should_pop = self
913            .file_sym_map
914            .get(&file.id)
915            .map(|block| self.blocks.push(NonNull::from(block)))
916            .is_some();
917
918        let f = File {
919            items: file
920                .items
921                .iter()
922                .filter_map(|item| self.lower_item(item))
923                .collect(),
924
925            file_id: file.id,
926            package: ItemPath::from(
927                file.package
928                    .segments
929                    .iter()
930                    .map(|i| i.sym.clone())
931                    .collect::<Vec<_>>(),
932            ),
933            uses: file.uses.iter().map(|(_, f)| *f).collect(),
934            descriptor: file.descriptor.clone(),
935            extensions: self.lower_file_exts(&file.extensions),
936            comments: file.comments.clone(),
937        };
938
939        if should_pop {
940            self.blocks.pop();
941        }
942
943        self.cur_file = old_file;
944        f
945    }
946
947    fn lower_pb_extendee(&mut self, e: &ir::ext::pb::Extendee) -> Arc<middle::ext::pb::Extendee> {
948        let extendee_index = e.index.into();
949        let extendee = Arc::new(middle::ext::pb::Extendee {
950            name: e.name.clone(),
951            index: extendee_index,
952            extendee_ty: ExtendeeType {
953                field_ty: e.extendee_ty.field_ty.into(),
954                item_ty: self.lower_type(&e.extendee_ty.item_ty, false),
955            },
956        });
957
958        self.pb_ext_indexes.insert(extendee_index, extendee.clone());
959        extendee
960    }
961}
962
#[cfg(test)]
mod tests {
    use std::{str::FromStr as _, sync::Arc};

    use super::*;

    /// Builds a `Ty` of the given kind carrying `TagId` 0.
    fn mk_ty(kind: TyKind) -> Ty {
        Ty {
            kind,
            tags_id: TagId::from_usize(0),
        }
    }

    /// Wraps `inner` in a `TyKind::Vec`.
    fn vec_ty(inner: Ty) -> Ty {
        mk_ty(TyKind::Vec(Arc::new(inner)))
    }

    /// Wraps `inner` in a `TyKind::Set`.
    fn set_ty(inner: Ty) -> Ty {
        mk_ty(TyKind::Set(Arc::new(inner)))
    }

    /// Builds a `TyKind::Map` from `key` to `value`.
    fn map_ty(key: Ty, value: Ty) -> Ty {
        mk_ty(TyKind::Map(Arc::new(key), Arc::new(value)))
    }

    /// Returns a `Tags` set holding only the `RustType` parsed from `name`.
    fn rust_type_tags(name: &str) -> Tags {
        let mut tags = Tags::default();
        tags.insert(RustType::from_str(name).unwrap());
        tags
    }

    /// Applies `tags` to `ty` using a fresh default `Resolver`.
    fn apply_tags(ty: Ty, tags: &Tags) -> Ty {
        Resolver::default().modify_ty_by_tags(ty, tags)
    }

    #[test]
    fn converts_faststr_to_string_with_string_tag() {
        let result = apply_tags(mk_ty(TyKind::FastStr), &rust_type_tags("string"));

        assert!(matches!(result.kind, TyKind::String));
    }

    #[test]
    fn converts_bytes_to_bytes_vec_with_vec_tag() {
        let result = apply_tags(mk_ty(TyKind::Bytes), &rust_type_tags("vec"));

        assert!(matches!(result.kind, TyKind::BytesVec));
    }

    #[test]
    fn converts_collections_to_btree_variants() {
        let ty = map_ty(
            set_ty(mk_ty(TyKind::FastStr)),
            vec_ty(mk_ty(TyKind::FastStr)),
        );
        let result = apply_tags(ty, &rust_type_tags("btree"));

        let TyKind::BTreeMap(key, value) = &result.kind else {
            panic!("expected BTreeMap, got {:?}", result.kind);
        };
        let TyKind::BTreeSet(inner) = &key.kind else {
            panic!("expected BTreeSet key, got {:?}", key.kind);
        };
        assert!(matches!(inner.kind, TyKind::FastStr));
        // The btree tag must leave the Vec value untouched.
        assert!(matches!(value.kind, TyKind::Vec(_)));
    }

    #[test]
    fn converts_f64_to_ordered_f64_with_ordered_tag() {
        let result = apply_tags(mk_ty(TyKind::F64), &rust_type_tags("ordered_f64"));

        assert!(matches!(result.kind, TyKind::OrderedF64));
    }

    #[test]
    fn wraps_collection_elements_with_arc_when_tagged() {
        let mut tags = Tags::default();
        tags.insert(RustWrapperArc(true));

        let result = apply_tags(vec_ty(mk_ty(TyKind::String)), &tags);

        let TyKind::Vec(inner) = &result.kind else {
            panic!("expected Vec, got {:?}", result.kind);
        };
        let TyKind::Arc(arc_inner) = &inner.kind else {
            panic!("expected Arc, got {:?}", inner.kind);
        };
        assert!(matches!(arc_inner.kind, TyKind::String));
    }
}