Skip to main content

pilota_build/parser/thrift/
mod.rs

1use core::panic;
2use std::{path::PathBuf, str::FromStr, sync::Arc};
3
4use faststr::FastStr;
5use heck::ToUpperCamelCase;
6use itertools::Itertools;
7use normpath::PathExt;
8use pilota_thrift_parser::{self as thrift_parser};
9use pilota_thrift_reflect::thrift_reflection;
10use rustc_hash::{FxHashMap, FxHashSet};
11use thrift_parser::Annotations;
12
13use crate::{
14    IdentName,
15    index::Idx,
16    ir::{
17        self, Arg, Enum, EnumVariant, FieldKind, File, Item, ItemKind, Path,
18        ext::{self, FileExts},
19    },
20    symbol::{EnumRepr, FileId, Ident},
21    tags::{Annotation, PilotaName, RustWrapperArc, Tags},
22    util::error_abort,
23};
24
25fn generate_short_uuid() -> FastStr {
26    let uuid: [u8; 4] = rand::random();
27    FastStr::new(hex::encode(uuid))
28}
29
30#[salsa::db]
31#[derive(Default, Clone)]
32struct ThriftSourceDatabase {
33    storage: salsa::Storage<Self>,
34}
35
36#[salsa::db]
37impl salsa::Database for ThriftSourceDatabase {}
38
39impl ThriftSourceDatabase {
40    fn file_text(&self, path: PathBuf) -> Arc<str> {
41        Arc::from(unsafe { String::from_utf8_unchecked(std::fs::read(path).unwrap()) })
42    }
43
44    fn parse(&self, path: PathBuf) -> Arc<thrift_parser::File> {
45        let text = self.file_text(path.clone());
46        let res = thrift_parser::FileParser::new(
47            thrift_parser::FileSource::new_with_path(path.clone(), text.as_ref()).unwrap(),
48        )
49        .parse();
50
51        if res.is_err() {
52            eprintln!("{}", res.err().unwrap());
53            std::process::exit(1);
54        }
55
56        let mut ast = res.unwrap();
57        ast.path = Arc::from(path);
58        ast.uuid = generate_short_uuid();
59        let descriptor = thrift_reflection::FileDescriptor::from(&ast);
60        ast.descriptor = descriptor.serialize();
61        Arc::from(ast)
62    }
63}
64
65#[derive(Debug)]
66pub struct LowerResult {
67    pub files: Vec<Arc<File>>,
68    pub file_ids_map: FxHashMap<Arc<PathBuf>, FileId>,
69    pub file_paths: FxHashMap<FileId, Arc<PathBuf>>,
70    pub file_names: FxHashMap<FileId, FastStr>,
71}
72
73pub trait Lower<Ast> {
74    fn lower(&mut self, file: Ast) -> FileId;
75
76    fn finish(self) -> LowerResult;
77}
78
79pub struct ThriftLower {
80    cur_file: Option<Arc<thrift_parser::File>>,
81    next_file_id: FileId,
82    db: ThriftSourceDatabase,
83    files: FxHashMap<FileId, Arc<File>>,
84    file_ids_map: FxHashMap<Arc<PathBuf>, FileId>,
85    file_paths: FxHashMap<FileId, Arc<PathBuf>>,
86    file_names: FxHashMap<FileId, FastStr>,
87    include_dirs: Vec<PathBuf>,
88    packages: FxHashMap<Path, Vec<Arc<PathBuf>>>,
89    service_name_duplicates: FxHashSet<String>,
90}
91
92impl ThriftLower {
93    fn new(db: ThriftSourceDatabase, include_dirs: Vec<PathBuf>) -> Self {
94        ThriftLower {
95            cur_file: None,
96            next_file_id: FileId::from_u32(0),
97            db,
98            files: FxHashMap::default(),
99            file_ids_map: FxHashMap::default(),
100            file_paths: FxHashMap::default(),
101            file_names: FxHashMap::default(),
102            include_dirs,
103            packages: Default::default(),
104            service_name_duplicates: Default::default(),
105        }
106    }
107
108    pub fn with_cur_file<F>(&mut self, file: Arc<thrift_parser::File>, f: F) -> Arc<File>
109    where
110        F: FnOnce(&mut Self) -> ir::File,
111    {
112        let old_file = self.cur_file.clone();
113        self.cur_file = Some(file);
114
115        let f = Arc::from(f(self));
116        self.cur_file = old_file;
117        self.files.insert(f.id, f.clone());
118        f
119    }
120
121    fn lower_path(&self, path: &thrift_parser::Path) -> ir::Path {
122        Path {
123            segments: Arc::from_iter(path.segments.iter().map(|i| self.lower_ident(i))),
124        }
125    }
126
127    fn mk_item(&self, kind: ItemKind, tags: Arc<Tags>) -> ir::Item {
128        ir::Item {
129            kind,
130            tags,
131            related_items: Default::default(),
132        }
133    }
134
135    fn lower_service(&self, service: &thrift_parser::Service) -> Vec<ir::Item> {
136        let service_name = if self
137            .service_name_duplicates
138            .contains(&service.name.to_upper_camel_case())
139        {
140            service.name.to_string()
141        } else {
142            service.name.to_upper_camel_case()
143        };
144
145        let service_tags = self.extract_tags(&service.annotations);
146        let arc_wrapper = service_tags.get::<RustWrapperArc>().is_some_and(|v| v.0);
147
148        let mut function_names: FxHashMap<FastStr, Vec<String>> = FxHashMap::default();
149        service.functions.iter().for_each(|func| {
150            let name = self
151                .extract_tags(&func.annotations)
152                .get::<PilotaName>()
153                .map(|name| name.0.to_string())
154                .unwrap_or_else(|| func.name.to_string());
155            function_names
156                .entry(name.as_str().upper_camel_ident())
157                .or_default()
158                .push(name);
159        });
160        let function_name_duplicates = function_names
161            .iter()
162            .filter(|(_, v)| v.len() > 1)
163            .map(|(k, _)| k.as_str())
164            .collect::<FxHashSet<_>>();
165
166        let kind = ir::ItemKind::Service(ir::Service {
167            leading_comments: service.leading_comments.clone(),
168            trailing_comments: service.trailing_comments.clone(),
169            name: self.lower_ident(&service.name),
170            extend: service
171                .extends
172                .as_ref()
173                .into_iter()
174                .map(|e| self.lower_path(e))
175                .collect(),
176            methods: service
177                .functions
178                .iter()
179                .map(|f| {
180                    self.lower_method(&service_name, f, &function_name_duplicates, arc_wrapper)
181                })
182                .collect(),
183            item_exts: ext::ItemExts::Thrift,
184        });
185        let mut service_item = self.mk_item(kind, Default::default());
186        let mut result = vec![];
187
188        let mut related_items = Vec::default();
189
190        service.functions.iter().for_each(|f| {
191            let exception = f
192                .throws
193                .iter()
194                .map(|f| ir::EnumVariant {
195                    leading_comments: f.leading_comments.clone(),
196                    trailing_comments: f.trailing_comments.clone(),
197                    id: Some(f.id),
198                    name: if f.name.is_empty() {
199                        match &f.ty.0 {
200                            thrift_parser::Ty::Path(p) => {
201                                self.lower_ident(p.segments.last().unwrap())
202                            }
203                            _ => panic!(""),
204                        }
205                    } else {
206                        self.lower_ident(&f.name)
207                    },
208                    tags: Default::default(),
209                    discr: None,
210                    fields: vec![self.lower_ty(&f.ty)],
211                    item_exts: ext::ItemExts::Thrift,
212                })
213                .collect::<Vec<_>>();
214
215            let tags = self.extract_tags(&f.annotations);
216            let name = tags
217                .get::<PilotaName>()
218                .map(|name| name.0.clone())
219                .unwrap_or_else(|| FastStr::new(f.name.0.clone()));
220
221            let upper_camel_ident = name.as_str().upper_camel_ident();
222            let method_name = if function_name_duplicates.contains(upper_camel_ident.as_str()) {
223                name
224            } else {
225                upper_camel_ident
226            };
227
228            let name: Ident = format!("{}{}ResultRecv", service_name, method_name).into();
229            let kind = ir::ItemKind::Enum(ir::Enum {
230                leading_comments: f.leading_comments.clone(),
231                trailing_comments: f.trailing_comments.clone(),
232                name: name.clone(),
233                variants: std::iter::once(ir::EnumVariant {
234                    leading_comments: f.leading_comments.clone(),
235                    trailing_comments: f.trailing_comments.clone(),
236                    id: Some(0),
237                    name: "Ok".into(),
238                    tags: Default::default(),
239                    discr: None,
240                    fields: vec![self.lower_method_relative_ty(&f.result_type, arc_wrapper)],
241                    item_exts: ext::ItemExts::Thrift,
242                })
243                .chain(exception.clone())
244                .collect(),
245                repr: None,
246                item_exts: ext::ItemExts::Thrift,
247            });
248            related_items.push(name.clone());
249            let mut tags = Tags::default();
250            tags.insert(crate::tags::KeepUnknownFields(false));
251            tags.insert(crate::tags::PilotaName(name.raw_str()));
252            result.push(self.mk_item(kind, tags.into()));
253
254            let name: Ident = format!("{service_name}{method_name}ResultSend").into();
255            let kind = ir::ItemKind::Enum(ir::Enum {
256                leading_comments: f.leading_comments.clone(),
257                trailing_comments: f.trailing_comments.clone(),
258                name: name.clone(),
259                variants: std::iter::once(ir::EnumVariant {
260                    leading_comments: f.leading_comments.clone(),
261                    trailing_comments: f.trailing_comments.clone(),
262                    id: Some(0),
263                    name: "Ok".into(),
264                    tags: Default::default(),
265                    discr: None,
266                    fields: vec![self.lower_method_relative_ty(&f.result_type, arc_wrapper)],
267                    item_exts: ext::ItemExts::Thrift,
268                })
269                .chain(exception.clone())
270                .collect(),
271                repr: None,
272                item_exts: ext::ItemExts::Thrift,
273            });
274            related_items.push(name.clone());
275            let mut tags = Tags::default();
276            tags.insert(crate::tags::KeepUnknownFields(false));
277            tags.insert(crate::tags::PilotaName(name.raw_str()));
278            result.push(self.mk_item(kind, tags.into()));
279
280            if !exception.is_empty() {
281                let name: Ident = format!("{service_name}{method_name}Exception").into();
282                let kind = ir::ItemKind::Enum(ir::Enum {
283                    leading_comments: f.leading_comments.clone(),
284                    trailing_comments: f.trailing_comments.clone(),
285                    name: name.clone(),
286                    variants: exception,
287                    repr: None,
288                    item_exts: ext::ItemExts::Thrift,
289                });
290                related_items.push(name.clone());
291                let mut tags = Tags::default();
292                tags.insert(crate::tags::KeepUnknownFields(false));
293                tags.insert(crate::tags::PilotaName(name.raw_str()));
294                result.push(self.mk_item(kind, tags.into()));
295            }
296
297            let name: Ident = format!("{service_name}{method_name}ArgsSend").into();
298            let kind = ir::ItemKind::Message(ir::Message {
299                leading_comments: f.leading_comments.clone(),
300                trailing_comments: f.trailing_comments.clone(),
301                name: name.clone(),
302                fields: f
303                    .arguments
304                    .iter()
305                    .map(|a| self.lower_method_arg_field(a, arc_wrapper))
306                    .collect(),
307                is_wrapper: true,
308                item_exts: ext::ItemExts::Thrift,
309            });
310            related_items.push(name.clone());
311            let mut tags = Tags::default();
312            tags.insert(crate::tags::KeepUnknownFields(false));
313            tags.insert(crate::tags::PilotaName(name.raw_str()));
314            result.push(self.mk_item(kind, tags.into()));
315
316            let name: Ident = format!("{service_name}{method_name}ArgsRecv").into();
317            let kind = ir::ItemKind::Message(ir::Message {
318                leading_comments: f.leading_comments.clone(),
319                trailing_comments: f.trailing_comments.clone(),
320                name: name.clone(),
321                fields: f
322                    .arguments
323                    .iter()
324                    .map(|a| self.lower_method_arg_field(a, arc_wrapper))
325                    .collect(),
326                is_wrapper: true,
327                item_exts: ext::ItemExts::Thrift,
328            });
329            related_items.push(name.clone());
330            let mut tags: Tags = Tags::default();
331            tags.insert(crate::tags::KeepUnknownFields(false));
332            tags.insert(crate::tags::PilotaName(name.raw_str()));
333            result.push(self.mk_item(kind, tags.into()));
334        });
335
336        service_item.related_items = related_items;
337        result.push(service_item);
338        result
339    }
340
341    fn lower_method(
342        &self,
343        service_name: &String,
344        method: &thrift_parser::Function,
345        function_name_duplicates: &FxHashSet<&str>,
346        arc_wrapper: bool,
347    ) -> ir::Method {
348        let tags = self.extract_tags(&method.annotations);
349        let name = tags
350            .get::<PilotaName>()
351            .map(|name| name.0.clone())
352            .unwrap_or_else(|| FastStr::new(method.name.0.clone()));
353
354        let upper_camel_ident = name.as_str().upper_camel_ident();
355        let method_name = if function_name_duplicates.contains(upper_camel_ident.as_str()) {
356            name
357        } else {
358            upper_camel_ident
359        };
360
361        ir::Method {
362            leading_comments: method.leading_comments.clone(),
363            trailing_comments: method.trailing_comments.clone(),
364            name: self.lower_ident(&method.name),
365            args: method
366                .arguments
367                .iter()
368                .map(|a| {
369                    let mut tags = self.extract_tags(&a.annotations);
370                    if arc_wrapper && !tags.contains::<RustWrapperArc>() {
371                        tags.insert(RustWrapperArc(true));
372                    }
373                    Arg {
374                        ty: self.lower_method_relative_ty(&a.ty, arc_wrapper),
375                        id: a.id,
376                        name: self.lower_ident(&a.name),
377                        tags: Arc::new(tags),
378                        attribute: match a.attribute {
379                            pilota_thrift_parser::Attribute::Required => FieldKind::Required,
380                            pilota_thrift_parser::Attribute::Optional
381                            | pilota_thrift_parser::Attribute::Default => FieldKind::Optional,
382                        },
383                    }
384                })
385                .collect(),
386            ret: self.lower_method_relative_ty(&method.result_type, arc_wrapper),
387            oneway: method.oneway,
388            tags: tags.into(),
389            exceptions: if method.throws.is_empty() {
390                None
391            } else {
392                Some(Path {
393                    segments: Arc::from([Ident::from(format!(
394                        "{service_name}{method_name}Exception",
395                    ))]),
396                })
397            },
398            item_exts: ext::ItemExts::Thrift,
399        }
400    }
401
402    fn lower_enum(&self, e: &thrift_parser::Enum) -> ir::Enum {
403        ir::Enum {
404            leading_comments: e.leading_comments.clone(),
405            trailing_comments: e.trailing_comments.clone(),
406            name: self.lower_ident(&e.name),
407            variants: e
408                .values
409                .iter()
410                .map(|v| ir::EnumVariant {
411                    leading_comments: v.leading_comments.clone(),
412                    trailing_comments: v.trailing_comments.clone(),
413                    id: None,
414                    name: self.lower_ident(&v.name),
415                    discr: v.value.map(|v| v.0),
416                    fields: vec![],
417                    tags: self.extract_tags(&v.annotations).into(),
418                    item_exts: ext::ItemExts::Thrift,
419                })
420                .collect(),
421            repr: Some(EnumRepr::I32),
422            item_exts: ext::ItemExts::Thrift,
423        }
424    }
425
426    fn lower_lit(&self, l: &thrift_parser::ConstValue) -> ir::Literal {
427        match &l {
428            thrift_parser::ConstValue::Bool(b) => ir::Literal::Bool(*b),
429            thrift_parser::ConstValue::Path(p) => ir::Literal::Path(self.lower_path(p)),
430            thrift_parser::ConstValue::String(s) => ir::Literal::String(Arc::from(s.0.as_str())),
431            thrift_parser::ConstValue::Int(i) => ir::Literal::Int(i.0),
432            thrift_parser::ConstValue::Double(d) => ir::Literal::Float(d.0.clone()),
433            thrift_parser::ConstValue::List(inner) => {
434                ir::Literal::List(inner.iter().map(|i| self.lower_lit(i)).collect())
435            }
436            thrift_parser::ConstValue::Map(kvs) => ir::Literal::Map(
437                kvs.iter()
438                    .map(|(k, v)| (self.lower_lit(k), self.lower_lit(v)))
439                    .collect(),
440            ),
441        }
442    }
443
444    fn lower_const(&self, c: &thrift_parser::Constant) -> ir::Const {
445        ir::Const {
446            leading_comments: c.leading_comments.clone(),
447            trailing_comments: c.trailing_comments.clone(),
448            name: self.lower_ident(&c.name),
449            ty: self.lower_ty(&c.r#type),
450            lit: self.lower_lit(&c.value),
451        }
452    }
453
454    fn lower_typedef(&self, t: &thrift_parser::Typedef) -> ir::NewType {
455        ir::NewType {
456            leading_comments: t.leading_comments.clone(),
457            trailing_comments: t.trailing_comments.clone(),
458            name: self.lower_ident(&t.alias),
459            ty: self.lower_ty(&t.r#type),
460        }
461    }
462
463    fn lower_item(&self, item: &thrift_parser::Item) -> Vec<ir::Item> {
464        let single = match item {
465            thrift_parser::Item::Typedef(t) => ir::ItemKind::NewType(self.lower_typedef(t)),
466            thrift_parser::Item::Constant(c) => ir::ItemKind::Const(self.lower_const(c)),
467            thrift_parser::Item::Enum(e) => ir::ItemKind::Enum(self.lower_enum(e)),
468            thrift_parser::Item::Struct(s) => ir::ItemKind::Message(self.lower_struct(s)),
469            thrift_parser::Item::Union(u) => ir::ItemKind::Enum(self.lower_union(u)),
470            thrift_parser::Item::Exception(e) => ir::ItemKind::Message(self.lower_exception(e)),
471            thrift_parser::Item::Service(s) => return self.lower_service(s),
472            _ => return vec![],
473        };
474
475        let empty_annotations = Annotations::default();
476
477        let annotations = match item {
478            thrift_parser::Item::Typedef(t) => &t.annotations,
479            thrift_parser::Item::Constant(c) => &c.annotations,
480            thrift_parser::Item::Enum(e) => &e.annotations,
481            thrift_parser::Item::Struct(s) => &s.annotations,
482            thrift_parser::Item::Union(u) => &u.annotations,
483            thrift_parser::Item::Exception(e) => &e.annotations,
484            thrift_parser::Item::Service(s) => &s.annotations,
485            _ => &empty_annotations,
486        };
487
488        let tags = self.extract_tags(annotations);
489
490        vec![self.mk_item(single, tags.into())]
491    }
492
493    fn lower_union(&self, union: &thrift_parser::Union) -> Enum {
494        Enum {
495            leading_comments: union.leading_comments.clone(),
496            trailing_comments: union.trailing_comments.clone(),
497            name: self.lower_ident(&union.name),
498            variants: union
499                .fields
500                .iter()
501                .map(|f| EnumVariant {
502                    leading_comments: f.leading_comments.clone(),
503                    trailing_comments: f.trailing_comments.clone(),
504                    id: Some(f.id),
505                    name: self.lower_ident(&f.name),
506                    discr: None,
507                    fields: vec![self.lower_ty(&f.ty)],
508                    tags: Default::default(),
509                    item_exts: ext::ItemExts::Thrift,
510                })
511                .collect(),
512            repr: None,
513            item_exts: ext::ItemExts::Thrift,
514        }
515    }
516
517    fn lower_ident(&self, s: &thrift_parser::Ident) -> Ident {
518        Ident::from(s.0.clone())
519    }
520
521    fn extract_tags_with_arc_wrapper(&self, annotations: &Annotations, arc_wrapper: bool) -> Tags {
522        let mut tags = self.extract_tags(annotations);
523        if arc_wrapper && !tags.contains::<RustWrapperArc>() {
524            tags.insert(RustWrapperArc(true));
525        }
526        tags
527    }
528
529    fn lower_method_relative_ty(&self, ty: &thrift_parser::Type, arc_wrapper: bool) -> ir::Ty {
530        self.lower_ty_with_tags(ty, self.extract_tags_with_arc_wrapper(&ty.1, arc_wrapper))
531    }
532
533    fn lower_ty(&self, ty: &thrift_parser::Type) -> ir::Ty {
534        let tags = self.extract_tags(&ty.1);
535        self.lower_ty_with_tags(ty, tags)
536    }
537
538    fn lower_ty_with_tags(&self, ty: &thrift_parser::Type, tags: Tags) -> ir::Ty {
539        let kind = match &ty.0 {
540            thrift_parser::Ty::String => ir::TyKind::String,
541            thrift_parser::Ty::Void => ir::TyKind::Void,
542            thrift_parser::Ty::Byte => ir::TyKind::I8,
543            thrift_parser::Ty::Bool => ir::TyKind::Bool,
544            thrift_parser::Ty::Binary => ir::TyKind::Bytes,
545            thrift_parser::Ty::I8 => ir::TyKind::I8,
546            thrift_parser::Ty::I16 => ir::TyKind::I16,
547            thrift_parser::Ty::I32 => ir::TyKind::I32,
548            thrift_parser::Ty::I64 => ir::TyKind::I64,
549            thrift_parser::Ty::Double => ir::TyKind::F64,
550            thrift_parser::Ty::Uuid => ir::TyKind::Uuid,
551            thrift_parser::Ty::List { value, .. } => ir::TyKind::Vec(self.lower_ty(value).into()),
552            thrift_parser::Ty::Set { value, .. } => ir::TyKind::Set(self.lower_ty(value).into()),
553            thrift_parser::Ty::Map { key, value, .. } => {
554                ir::TyKind::Map(self.lower_ty(key).into(), self.lower_ty(value).into())
555            }
556            thrift_parser::Ty::Path(path) => ir::TyKind::Path(self.lower_path(path)),
557        };
558
559        ir::Ty {
560            kind,
561            tags: tags.into(),
562        }
563    }
564
565    fn lower_method_arg_field(&self, f: &thrift_parser::Field, arc_wrapper: bool) -> ir::Field {
566        let mut tags = self.extract_tags(&f.annotations);
567        if arc_wrapper && !tags.contains::<RustWrapperArc>() {
568            tags.insert(RustWrapperArc(true));
569        }
570
571        ir::Field {
572            leading_comments: f.leading_comments.clone(),
573            trailing_comments: f.trailing_comments.clone(),
574            name: self.lower_ident(&f.name),
575            id: f.id,
576            ty: self.lower_method_relative_ty(&f.ty, arc_wrapper),
577            kind: match f.attribute {
578                pilota_thrift_parser::Attribute::Required => FieldKind::Required,
579                _ => FieldKind::Optional,
580            },
581            tags: tags.into(),
582            default: f.default.as_ref().map(|c| self.lower_lit(c)),
583            item_exts: ext::ItemExts::Thrift,
584        }
585    }
586
587    fn lower_field(&self, f: &thrift_parser::Field) -> ir::Field {
588        let tags = self.extract_tags(&f.annotations);
589        self.lower_field_with_tags(f, tags)
590    }
591
592    fn lower_field_with_tags(&self, f: &thrift_parser::Field, tags: Tags) -> ir::Field {
593        ir::Field {
594            leading_comments: f.leading_comments.clone(),
595            trailing_comments: f.trailing_comments.clone(),
596            name: self.lower_ident(&f.name),
597            id: f.id,
598            ty: self.lower_ty(&f.ty),
599            kind: match f.attribute {
600                pilota_thrift_parser::Attribute::Required => FieldKind::Required,
601                _ => FieldKind::Optional,
602            },
603            default: f.default.as_ref().map(|c| self.lower_lit(c)),
604            tags: tags.into(),
605            item_exts: ext::ItemExts::Thrift,
606        }
607    }
608
609    fn extract_tags(&self, annotations: &Annotations) -> Tags {
610        let mut tags = Tags::default();
611        macro_rules! with_tags {
612            ($annotation: tt -> $($key: ty)|+) => {
613                match $annotation.key.as_str()  {
614                    $(<$key>::KEY => {
615                        tags.insert(<$key>::from_str(&$annotation.value).unwrap());
616                    }),+
617                    _ => {},
618                }
619            };
620        }
621
622        annotations.iter().for_each(
623            |annotation| with_tags!(annotation -> crate::tags::PilotaName | crate::tags::RustType | crate::tags::RustWrapperArc | crate::tags::SerdeAttribute),
624        );
625
626        tags
627    }
628
629    fn lower_struct(&self, s: &thrift_parser::Struct) -> ir::Message {
630        let mut seen_ids = FxHashSet::default();
631        for field in &s.fields {
632            if !seen_ids.insert(field.id) {
633                panic!(
634                    "duplicate ID `{}` in struct `{}`",
635                    field.id,
636                    self.lower_ident(&s.name),
637                );
638            }
639        }
640        ir::Message {
641            leading_comments: s.leading_comments.clone(),
642            trailing_comments: s.trailing_comments.clone(),
643            name: self.lower_ident(&s.name),
644            fields: s.fields.iter().map(|f| self.lower_field(f)).collect(),
645            is_wrapper: false,
646            item_exts: ext::ItemExts::Thrift,
647        }
648    }
649
650    fn lower_exception(&self, e: &thrift_parser::Exception) -> ir::Message {
651        let mut seen_ids = FxHashSet::default();
652        for field in &e.fields {
653            if !seen_ids.insert(field.id) {
654                panic!(
655                    "duplicate ID `{}` in struct `{}`",
656                    field.id,
657                    self.lower_ident(&e.name),
658                );
659            }
660        }
661        ir::Message {
662            leading_comments: e.leading_comments.clone(),
663            trailing_comments: e.trailing_comments.clone(),
664            name: self.lower_ident(&e.name),
665            fields: e.fields.iter().map(|f| self.lower_field(f)).collect(),
666            is_wrapper: false,
667            item_exts: ext::ItemExts::Thrift,
668        }
669    }
670
671    fn lower_include(&mut self, s: &thrift_parser::Include) -> ir::Use {
672        // add current file's dir to include dirs
673        let current_dir = self.cur_file.as_ref().unwrap().path.parent().unwrap();
674        let mut include_dirs = vec![current_dir.to_path_buf()];
675        include_dirs.extend_from_slice(&self.include_dirs);
676
677        // search for the first existing include path
678        let target_dir = include_dirs.into_iter().find(|p| {
679            let path = p.join(&s.path.0);
680            path.exists()
681        });
682        let target_path = match target_dir {
683            Some(dir) => dir.join(&s.path.0),
684            None => {
685                error_abort(format!("{}: include file not found", s.path.0));
686            }
687        };
688
689        let ast = self
690            .db
691            .parse(target_path.normalize().unwrap().into_path_buf());
692
693        let file_id = self.lower(ast);
694
695        ir::Use { file: file_id }
696    }
697}
698
699impl Lower<Arc<thrift_parser::File>> for ThriftLower {
700    fn lower(&mut self, f: Arc<thrift_parser::File>) -> FileId {
701        if let Some(file_id) = self.file_ids_map.get(&f.path) {
702            return *file_id;
703        }
704
705        println!("cargo:rerun-if-changed={}", f.path.display());
706
707        let file_id = self.next_file_id.inc_one();
708        self.file_ids_map.insert(f.path.clone(), file_id);
709        self.file_paths.insert(file_id, f.path.clone());
710        self.file_names.insert(
711            file_id,
712            FastStr::new(f.path.file_stem().unwrap().to_string_lossy()),
713        );
714
715        let file = self.with_cur_file(f.clone(), |this| {
716            let include_files = f
717                .items
718                .iter()
719                .filter_map(|item| {
720                    if let thrift_parser::Item::Include(i) = item {
721                        Some(i)
722                    } else {
723                        None
724                    }
725                })
726                .map(|i| {
727                    (
728                        i.path
729                            .0
730                            .split('/')
731                            .next_back()
732                            .unwrap()
733                            .trim_end_matches(".thrift")
734                            .split('.')
735                            .map(FastStr::new)
736                            .map(Ident::from)
737                            .collect_vec(),
738                        this.lower_include(i),
739                    )
740                })
741                .collect::<Vec<_>>();
742
743            let includes = include_files
744                .iter()
745                .map(|(_, file)| Item {
746                    related_items: Default::default(),
747                    kind: ir::ItemKind::Use(ir::Use { file: file.file }),
748                    tags: Default::default(),
749                })
750                .collect::<Vec<_>>();
751
752            let uses = include_files
753                .into_iter()
754                .map(|(name, u)| {
755                    (
756                        Path {
757                            segments: name.into(),
758                        },
759                        u.file,
760                    )
761                })
762                .collect::<Vec<(_, FileId)>>();
763
764            let file_package = f
765                .package
766                .as_ref()
767                .map(|p| this.lower_path(p))
768                .unwrap_or_else(|| Path {
769                    segments: Arc::from([f
770                        .path
771                        .file_stem()
772                        .unwrap()
773                        .to_str()
774                        .unwrap()
775                        .replace('.', "_")
776                        .into()]),
777                });
778
779            this.packages
780                .entry(file_package.clone())
781                .or_default()
782                .push(f.path.clone());
783
784            let mut service_names: FxHashMap<String, Vec<String>> = FxHashMap::default();
785            f.items.iter().for_each(|item| {
786                if let thrift_parser::Item::Service(service) = item {
787                    service_names
788                        .entry(service.name.to_upper_camel_case())
789                        .or_default()
790                        .push(service.name.to_string());
791                }
792            });
793            this.service_name_duplicates.extend(
794                service_names
795                    .into_iter()
796                    .filter(|(_, v)| v.len() > 1)
797                    .map(|(k, _)| k),
798            );
799
800            let ret = ir::File {
801                package: file_package,
802                items: f
803                    .items
804                    .iter()
805                    .flat_map(|i| this.lower_item(i))
806                    .chain(includes)
807                    .map(Arc::from)
808                    .collect(),
809                id: file_id,
810                uses,
811                descriptor: f.descriptor.clone(),
812                extensions: FileExts::Thrift,
813                comments: f.comments.clone(),
814            };
815
816            this.service_name_duplicates.clear();
817            ret
818        });
819
820        file.id
821    }
822
823    fn finish(self) -> LowerResult {
824        self.packages.iter().for_each(|(k, v)| {
825            if v.len() > 1 {
826                println!(
827                    "cargo:warning={:?} has the same namespace `{}`, you may need to set namespace for these file \n",
828                    v,
829                    k.segments.iter().join(".")
830                )
831            }
832        });
833        LowerResult {
834            files: self.files.into_values().collect::<Vec<_>>(),
835            file_ids_map: self.file_ids_map,
836            file_paths: self.file_paths,
837            file_names: self.file_names,
838        }
839    }
840}
841
842#[derive(Default)]
843pub struct ThriftParser {
844    files: Vec<PathBuf>,
845    db: ThriftSourceDatabase,
846    include_dirs: Vec<PathBuf>,
847}
848
849impl super::Parser for ThriftParser {
850    fn input<P: AsRef<std::path::Path>>(&mut self, path: P) {
851        self.files.push(path.as_ref().into())
852    }
853
854    fn include_dirs(&mut self, dirs: Vec<PathBuf>) {
855        self.include_dirs.extend(dirs);
856    }
857
858    fn parse(self) -> super::ParseResult {
859        let db = self.db.clone();
860        let mut lower = ThriftLower::new(self.db, self.include_dirs.clone());
861        let mut input_files = Vec::default();
862
863        self.files.iter().for_each(|f| {
864            input_files.push(
865                lower.lower(
866                    db.parse(
867                        f.to_path_buf()
868                            .normalize()
869                            .unwrap_or_else(|_| panic!("normalize path failed: {}", f.display()))
870                            .into_path_buf(),
871                    ),
872                ),
873            );
874        });
875
876        let result = lower.finish();
877
878        super::ParseResult {
879            files: result.files,
880            input_files,
881            file_ids_map: result.file_ids_map,
882            file_paths: result.file_paths,
883            file_names: result.file_names,
884        }
885    }
886}