pilota_build/codegen/thrift/
mod.rs

1use std::{ops::Deref, path::PathBuf, sync::Arc};
2
3use faststr::FastStr;
4use itertools::Itertools;
5
6use super::traits::CodegenBackend;
7use crate::{
8    db::RirDatabase,
9    middle::{
10        context::{Context, Mode},
11        rir::{self, Enum, Field, Message, Method, NewType, Service},
12    },
13    rir::EnumVariant,
14    symbol::{DefId, EnumRepr, Symbol},
15    tags::thrift::EntryMessage,
16    ty::TyKind,
17};
18
19mod ty;
20
21pub use self::decode_helper::DecodeHelper;
22
23mod decode_helper;
24
25#[derive(Clone)]
26pub struct ThriftBackend {
27    cx: Context,
28}
29
30impl ThriftBackend {
31    pub fn new(cx: Context) -> Self {
32        ThriftBackend { cx }
33    }
34}
35
36impl Deref for ThriftBackend {
37    type Target = Context;
38
39    fn deref(&self) -> &Self::Target {
40        &self.cx
41    }
42}
43
44impl ThriftBackend {
45    fn codegen_encode_fields_size<'a>(
46        &'a self,
47        fields: &'a [Arc<rir::Field>],
48    ) -> impl Iterator<Item = FastStr> + 'a {
49        fields.iter().map(|f| {
50            let field_name = self.rust_name(f.did);
51            let is_optional = f.is_optional();
52            let field_id = f.id as i16;
53            let write_field = if is_optional {
54                self.codegen_field_size(&f.ty, field_id, "value".into())
55            } else {
56                self.codegen_field_size(&f.ty, field_id, format!("&self.{field_name}").into())
57            };
58
59            if is_optional {
60                format!("self.{field_name}.as_ref().map_or(0, |value| {write_field})").into()
61            } else {
62                write_field
63            }
64        })
65    }
66
67    fn codegen_encode_fields_size_with_field_mask<'a>(
68        &'a self,
69        fields: &'a [Arc<rir::Field>],
70    ) -> impl Iterator<Item = FastStr> + 'a {
71        fields.iter().map(|f| {
72            let field_name = self.rust_name(f.did);
73            let is_optional = f.is_optional();
74            let field_id = f.id as i16;
75            let write_field = if is_optional {
76                let write_field_size_with_field_mask =
77                    self.codegen_field_size_with_field_mask(&f.ty, field_id, "value".into());
78                format! {
79                    r#"{{
80                        let (field_fm, exist) = struct_fm.field({field_id});
81                        if exist {{
82                            {write_field_size_with_field_mask}
83                        }} else {{
84                            0
85                        }}
86                    }}"#
87                }
88                .into()
89            } else {
90                let write_field_size_with_field_mask = self.codegen_field_size_with_field_mask(
91                    &f.ty,
92                    field_id,
93                    format!("&self.{field_name}").into(),
94                );
95                format! {
96                    r#"{{
97                        let (field_fm, exist) = struct_fm.field({field_id});
98                        if exist {{
99                            {write_field_size_with_field_mask}
100                        }} else {{
101                            0
102                        }}
103                    }}"#
104                }
105                .into()
106            };
107
108            if is_optional {
109                format!("self.{field_name}.as_ref().map_or(0, |value| {write_field})").into()
110            } else {
111                write_field
112            }
113        })
114    }
115
116    fn codegen_encode_fields<'a>(
117        &'a self,
118        fields: &'a [Arc<rir::Field>],
119    ) -> impl Iterator<Item = FastStr> + 'a {
120        fields.iter().map(|f| {
121            let field_name = self.rust_name(f.did);
122            let field_id = f.id as i16;
123            let is_optional = f.is_optional();
124            let write_field = if is_optional {
125                self.codegen_encode_field(field_id, &f.ty, "value".into())
126            } else {
127                self.codegen_encode_field(field_id, &f.ty, format!("&self.{field_name}").into())
128            };
129
130            if is_optional {
131                format! {
132                    r#"if let Some(value) = self.{field_name}.as_ref() {{
133                        {write_field}
134                    }}"#
135                }
136                .into()
137            } else {
138                write_field
139            }
140        })
141    }
142
143    fn codegen_encode_fields_with_field_mask<'a>(
144        &'a self,
145        fields: &'a [Arc<rir::Field>],
146    ) -> impl Iterator<Item = FastStr> + 'a {
147        fields.iter().map(|f| {
148            let field_name = self.rust_name(f.did);
149            let field_id = f.id as i16;
150            let is_optional = f.is_optional();
151            let write_field = if is_optional {
152                let write_field_with_field_mask =
153                    self.codegen_encode_field_with_field_mask(field_id, &f.ty, "value".into());
154                format! {
155                    r#"let (field_fm, exist) = struct_fm.field({field_id});
156                    if exist {{
157                        {write_field_with_field_mask}
158                    }}"#
159                }
160                .into()
161            } else {
162                let write_field_with_field_mask = self.codegen_encode_field_with_field_mask(
163                    field_id,
164                    &f.ty,
165                    format!("&self.{field_name}").into(),
166                );
167                format! {
168                r#"let (field_fm, exist) = struct_fm.field({field_id});
169                if exist {{
170                    {write_field_with_field_mask}
171                }}"#
172                }
173                .into()
174            };
175
176            if is_optional {
177                format! {
178                    r#"if let Some(value) = self.{field_name}.as_ref() {{
179                        {write_field}
180                    }}"#
181                }
182                .into()
183            } else {
184                write_field
185            }
186        })
187    }
188
189    fn codegen_impl_message(
190        &self,
191        _def_id: DefId,
192        name: Symbol,
193        encode: String,
194        size: String,
195        decode: String,
196        decode_async: String,
197    ) -> String {
198        // FIXME: here we will encounter problems when the type is indirect recursive
199        // such as `struct A { a: Vec<A> }`.
200        // Just use the boxed future for now.
201        // let decode_async_fn = if self.cx().db.type_graph().is_cycled(def_id) {
202        //     format!(
203        //         r#"fn decode_async<'a, T: ::pilota::thrift::TAsyncInputProtocol>(
204        //         protocol: &'a mut T,
205        //     ) -> ::std::pin::Pin<::std::boxed::Box<dyn ::std::future::Future<Output =
206        //       ::std::result::Result<Self, ::pilota::thrift::ThriftException>> + Send
207        // + 'a>> {{ ::std::boxed::Box::pin(async move {{ {decode_async}
208        // }})     }}"#
209        //     )
210        // } else {
211        //     format!(
212        //         r#"async fn decode_async<T: ::pilota::thrift::TAsyncInputProtocol>(
213        //         protocol: &mut T,
214        //     ) -> ::std::result::Result<Self,::pilota::thrift::ThriftException> {{
215        //       {decode_async}
216        //     }}"#
217        //     )
218        // };
219        let decode_async_fn = format!(
220            r#"fn decode_async<'a, T: ::pilota::thrift::TAsyncInputProtocol>(
221            __protocol: &'a mut T,
222        ) -> ::std::pin::Pin<::std::boxed::Box<dyn ::std::future::Future<Output = ::std::result::Result<Self, ::pilota::thrift::ThriftException>> + Send + 'a>> {{
223            ::std::boxed::Box::pin(async move {{
224                {decode_async}
225            }})
226        }}"#
227        );
228        format! {r#"
229            impl ::pilota::thrift::Message for {name} {{
230                fn encode<T: ::pilota::thrift::TOutputProtocol>(
231                    &self,
232                    __protocol: &mut T,
233                ) -> ::std::result::Result<(),::pilota::thrift::ThriftException> {{
234                    #[allow(unused_imports)]
235                    use ::pilota::thrift::TOutputProtocolExt;
236                    {encode}
237                }}
238
239                fn decode<T: ::pilota::thrift::TInputProtocol>(
240                    __protocol: &mut T,
241                ) -> ::std::result::Result<Self,::pilota::thrift::ThriftException>  {{
242                    #[allow(unused_imports)]
243                    use ::pilota::{{thrift::TLengthProtocolExt, Buf}};
244                    {decode}
245                }}
246
247                {decode_async_fn}
248
249                fn size<T: ::pilota::thrift::TLengthProtocol>(&self, __protocol: &mut T) -> usize {{
250                    #[allow(unused_imports)]
251                    use ::pilota::thrift::TLengthProtocolExt;
252                    {size}
253                }}
254            }}"#}
255    }
256
257    fn codegen_impl_message_with_helper<F: Fn(&DecodeHelper) -> String>(
258        &self,
259        def_id: DefId,
260        name: Symbol,
261        encode: String,
262        size: String,
263        decode: F,
264    ) -> String {
265        let decode_stream = decode(&DecodeHelper::new(false));
266        let decode_async_stream = decode(&DecodeHelper::new(true));
267        self.codegen_impl_message(
268            def_id,
269            name,
270            encode,
271            size,
272            decode_stream,
273            decode_async_stream,
274        )
275    }
276
277    fn codegen_decode(
278        &self,
279        helper: &DecodeHelper,
280        s: &rir::Message,
281        name: Symbol,
282        keep: bool,
283        is_arg: bool,
284    ) -> String {
285        let def_fields_num = if keep && is_arg && !helper.is_async {
286            "let mut __pilota_fields_num = 0;"
287        } else {
288            ""
289        };
290
291        let mut def_fields = s
292            .fields
293            .iter()
294            .map(|f| {
295                let field_name = f.local_var_name();
296                let mut v = "None".into();
297
298                if let Some((default, is_const)) = self.cx.default_val(f) {
299                    if is_const {
300                        v = default;
301
302                        if f.is_optional() {
303                            v = format!("Some({v})").into()
304                        }
305                    }
306                };
307
308                let mut s = format!("let mut {field_name} = {v};");
309                if keep && is_arg && !helper.is_async {
310                    s.push_str("__pilota_fields_num += 1;");
311                }
312                s
313            })
314            .join("");
315
316        if keep && !helper.is_async {
317            def_fields.push_str("let mut _unknown_fields = ::pilota::BytesVec::new();");
318        }
319
320        let set_default_fields = s
321            .fields
322            .iter()
323            .filter_map(|f| {
324                let field_name = f.local_var_name();
325                match self.cx.default_val(f) { Some((default, is_const)) => {
326                    if !is_const {
327                        if f.is_optional() {
328                            Some(format! {
329                                r#"if {field_name}.is_none() {{
330                                {field_name} = Some({default});
331                            }}"#
332                            })
333                        } else {
334                            Some(format!(
335                                r#"let {field_name} = {field_name}.unwrap_or_else(|| {default});"#
336                            ))
337                        }
338                    } else {
339                        None
340                    }
341                } _ => {
342                    None
343                }}
344            })
345            .join("\n");
346
347        let read_struct_begin = helper.codegen_read_struct_begin();
348        let read_struct_end = helper.codegen_read_struct_end();
349        let read_fields = self.codegen_decode_fields(helper, &s.fields, keep, is_arg);
350
351        let required_without_default_fields = s
352            .fields
353            .iter()
354            .filter(|f| !f.is_optional() && self.default_val(f).is_none())
355            .map(|f| (self.rust_name(f.did), f.local_var_name()))
356            .collect_vec();
357
358        let verify_required_fields = required_without_default_fields
359            .iter()
360            .map(|(s, v)| {
361                format!(
362                    r#"let Some({v}) = {v} else {{
363                return ::std::result::Result::Err(
364                    ::pilota::thrift::new_protocol_exception(
365                        ::pilota::thrift::ProtocolExceptionKind::InvalidData,
366                            "field {s} is required".to_string()
367                    )
368                )
369            }}; "#
370                )
371            })
372            .join("\n");
373
374        let read_fields = if helper.is_async {
375            format! {
376                r#"async {{
377                    {read_fields}
378                    ::std::result::Result::Ok::<_, ::pilota::thrift::ThriftException>(())
379                }}.await"#
380            }
381        } else {
382            format! {
383                r#"(|| {{
384                    {read_fields}
385                    ::std::result::Result::Ok::<_, ::pilota::thrift::ThriftException>(())
386                }})()"#
387            }
388        };
389
390        let format_msg = format!("decode struct `{name}` field(#{{}}) failed");
391
392        let mut fields = s
393            .fields
394            .iter()
395            .map(|f| format!("{}: {}", self.rust_name(f.did), f.local_var_name()))
396            .join(",");
397
398        if keep {
399            if !fields.is_empty() {
400                fields.push_str(", ");
401            }
402            if !helper.is_async {
403                fields.push_str("_unknown_fields");
404            } else {
405                fields.push_str("_unknown_fields: ::pilota::BytesVec::new()");
406            }
407        }
408
409        if !s.is_wrapper && self.with_field_mask {
410            if !fields.is_empty() {
411                fields.push_str(", ");
412            }
413            fields.push_str("_field_mask: ::std::option::Option::None");
414        }
415
416        format! {
417            r#"
418            {def_fields_num}
419            {def_fields}
420
421            let mut __pilota_decoding_field_id = None;
422
423            {read_struct_begin};
424            if let ::std::result::Result::Err(mut err) = {read_fields} {{
425                if let Some(field_id) = __pilota_decoding_field_id {{
426                    err.prepend_msg(&format!("{format_msg}, caused by: ", field_id));
427                }}
428                return ::std::result::Result::Err(err);
429            }};
430            {read_struct_end};
431
432            {verify_required_fields}
433
434            {set_default_fields}
435
436            let data = Self {{
437                {fields}
438            }};
439            ::std::result::Result::Ok(data)
440            "#
441        }
442    }
443
444    #[inline]
445    fn field_is_box(&self, f: &Field) -> bool {
446        self.with_adjust(f.did, |adj| match adj {
447            Some(a) => a.boxed(),
448            None => false,
449        })
450    }
451
452    fn codegen_entry_enum(&self, _def_id: DefId, _stream: &mut str, _e: &rir::Enum) {
453        // TODO
454    }
455
456    fn codegen_decode_fields<'a>(
457        &'a self,
458        helper: &DecodeHelper,
459        fields: &'a [Arc<Field>],
460        keep: bool,
461        is_arg: bool,
462    ) -> String {
463        let record_ptr = if keep && !helper.is_async {
464            r#"let mut __pilota_offset = 0;
465            let __pilota_begin_ptr = __protocol.buf().chunk().as_ptr();"#
466        } else {
467            ""
468        };
469        let read_field_begin = helper.codegen_read_field_begin();
470        let field_begin_len = helper.codegen_field_begin_len(keep);
471        let match_fields = fields
472            .iter()
473            .map(|f| {
474                let field_ident = f.local_var_name();
475                let ttype = self.ttype(&f.ty);
476                let mut read_field = self.codegen_decode_ty(helper, &f.ty);
477                let field_id = f.id as i16;
478                if self.field_is_box(f) {
479                    read_field = format!("::std::boxed::Box::new({read_field})").into();
480                };
481
482                if f.is_optional() || {
483                    match self.cx.default_val(f) {
484                        Some((_, is_const)) => !is_const,
485                        _ => true,
486                    }
487                } {
488                    read_field = format!("Some({read_field})").into();
489                }
490
491                let fields_num = if keep && !helper.is_async && is_arg {
492                    "__pilota_fields_num -= 1;"
493                } else {
494                    ""
495                };
496
497                format!(
498                    r#"Some({field_id}) if field_ident.field_type == {ttype}  => {{
499                    {field_ident} = {read_field};
500                    {fields_num}
501                }},"#
502                )
503            })
504            .join("");
505        let mut skip_ttype = helper.codegen_skip_ttype("field_ident.field_type".into());
506        if keep && !helper.is_async {
507            skip_ttype = format!("__pilota_offset += {skip_ttype}")
508        }
509
510        let write_unknown_field = if keep && !helper.is_async {
511            "_unknown_fields.push_back(__protocol.get_bytes(Some(__pilota_begin_ptr), __pilota_offset)?);"
512        } else {
513            ""
514        };
515
516        let read_field_end = helper.codegen_read_field_end();
517        let field_end_len = helper.codegen_field_end_len(keep);
518        let field_stop_len = helper.codegen_field_stop_len(keep);
519
520        let skip_all = if keep && !helper.is_async && is_arg {
521            "if __pilota_fields_num == 0 {
522                let __pilota_remaining = __protocol.buf().remaining();
523                _unknown_fields.push_back(__protocol.get_bytes(None, __pilota_remaining - 2)?);
524                break;
525            }"
526        } else {
527            ""
528        };
529
530        format! {
531            r#"loop {{
532                {skip_all}
533                {record_ptr}
534                let field_ident = {read_field_begin};
535                if field_ident.field_type == ::pilota::thrift::TType::Stop {{
536                    {field_stop_len}
537                    break;
538                }} else {{
539                    {field_begin_len}
540                }}
541                __pilota_decoding_field_id = field_ident.id;
542                match field_ident.id {{
543                    {match_fields}
544                    _ => {{
545                        {skip_ttype};
546                        {write_unknown_field}
547                    }},
548                }}
549
550                {read_field_end};
551                {field_end_len}
552
553            }};"#
554        }
555    }
556}
557
558impl CodegenBackend for ThriftBackend {
559    const PROTOCOL: &'static str = "thrift";
560
561    fn codegen_struct_impl(&self, def_id: DefId, stream: &mut String, s: &Message) {
562        let filename = self
563            .cx
564            .file_paths()
565            .get(&self.cx.node(def_id).unwrap().file_id)
566            .unwrap()
567            .file_stem()
568            .unwrap()
569            .to_string_lossy()
570            .replace(".", "_");
571        let filename_lower = filename.to_lowercase();
572        let keep = self.keep_unknown_fields.contains(&def_id);
573        let name = self.cx.rust_name(def_id);
574        let mut encode_fields = self.codegen_encode_fields(&s.fields).join("");
575        let mut encode_fields_with_field_mask = self
576            .codegen_encode_fields_with_field_mask(&s.fields)
577            .join("");
578        if keep {
579            encode_fields.push_str(
580                r#"for bytes in self._unknown_fields.list.iter() {
581                                __protocol.write_bytes_without_len(bytes.clone());
582                            }"#,
583            );
584            encode_fields_with_field_mask.push_str(
585                r#"for bytes in self._unknown_fields.list.iter() {
586                                __protocol.write_bytes_without_len(bytes.clone());
587                            }"#,
588            );
589        }
590        let mut encode_fields_size = self
591            .codegen_encode_fields_size(&s.fields)
592            .map(|s| format!("{s} +"))
593            .join("");
594        let mut encode_fields_size_with_field_mask = self
595            .codegen_encode_fields_size_with_field_mask(&s.fields)
596            .map(|s| format!("{s} +"))
597            .join("");
598
599        if keep {
600            encode_fields_size.push_str("self._unknown_fields.size() +");
601            encode_fields_size_with_field_mask.push_str("self._unknown_fields.size() +");
602        }
603
604        if s.is_wrapper || !self.with_field_mask {
605            stream.push_str(&self.codegen_impl_message_with_helper(
606                def_id,
607                name.clone(),
608                format! {
609                    r#"let struct_ident =::pilota::thrift::TStructIdentifier {{
610                        name: "{name}",
611                    }};
612    
613                    __protocol.write_struct_begin(&struct_ident)?;
614                    {encode_fields}
615                    __protocol.write_field_stop()?;
616                    __protocol.write_struct_end()?;
617                    ::std::result::Result::Ok(())
618                    "#
619                },
620                format! {
621                    r#"__protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier {{
622                        name: "{name}",
623                    }}) + {encode_fields_size} __protocol.field_stop_len() + __protocol.struct_end_len()"#
624                },
625                |helper| self.codegen_decode(helper, s, name.clone(), keep, self.is_arg(def_id)),
626            ));
627
628            if !s.is_wrapper && self.with_descriptor {
629                stream.push_str(&format! {
630                    r#"impl {name} {{
631                        pub fn get_descriptor() -> &'static ::pilota_thrift_reflect::thrift_reflection::StructDescriptor {{
632                            let file_descriptor = get_file_descriptor_{filename_lower}();
633                            file_descriptor.find_struct_by_name("{name}").unwrap()
634                        }}
635                    }}"#
636                });
637            }
638            return;
639        }
640
641        if self.with_field_mask {
642            stream.push_str(&self.codegen_impl_message_with_helper(
643                def_id,
644                name.clone(),
645                format! {
646                    r#"if let Some(struct_fm) = self._field_mask.as_ref() {{
647                        if !struct_fm.exist() {{
648                            ::std::result::Result::Ok(())
649                        }} else {{
650                            let struct_ident =::pilota::thrift::TStructIdentifier {{
651                                    name: "{name}",
652                                }};
653                                __protocol.write_struct_begin(&struct_ident)?;
654                                {encode_fields_with_field_mask}
655                                __protocol.write_field_stop()?;
656                                __protocol.write_struct_end()?;
657                                ::std::result::Result::Ok(())
658                        }}
659                    }} else {{
660                        let struct_ident =::pilota::thrift::TStructIdentifier {{
661                            name: "{name}",
662                        }};
663        
664                        __protocol.write_struct_begin(&struct_ident)?;
665                        {encode_fields}
666                        __protocol.write_field_stop()?;
667                        __protocol.write_struct_end()?;
668                        ::std::result::Result::Ok(())
669                    }}"#
670                },
671                format! {
672                    r#"if let Some(struct_fm) = self._field_mask.as_ref() {{
673                        if !struct_fm.exist() {{
674                            0
675                        }} else {{
676                            __protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier {{
677                                name: "{name}",
678                            }}) + {encode_fields_size_with_field_mask} __protocol.field_stop_len() + __protocol.struct_end_len()
679                        }}
680                    }} else {{
681                        __protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier {{
682                            name: "{name}",
683                        }}) + {encode_fields_size} __protocol.field_stop_len() + __protocol.struct_end_len()
684                    }}"#
685                },
686                |helper| self.codegen_decode(helper, s, name.clone(), keep, self.is_arg(def_id)),
687            ));
688        }
689
690        let set_inner_field_mask = s
691            .fields
692            .iter()
693            .filter(|f| self.need_field_mask(&f.ty))
694            .map(|f| {
695                let field_name = self.rust_name(f.did);
696                let field_id = f.id as i16;
697                let is_optional = f.is_optional();
698                let field_mask = if is_optional {
699                    self.codegen_struct_field_mask(
700                        field_id,
701                        &f.ty,
702                        "value".into(),
703                        "field_mask".into(),
704                    )
705                } else {
706                    self.codegen_struct_field_mask(
707                        field_id,
708                        &f.ty,
709                        format!("self.{field_name}").into(),
710                        "field_mask".into(),
711                    )
712                };
713
714                if is_optional {
715                    format! {
716                        r#"if let Some(value) = &mut self.{field_name} {{
717                        {field_mask}
718                    }}"#
719                    }
720                } else {
721                    field_mask.to_string()
722                }
723            })
724            .join("");
725
726        stream.push_str(&format! {
727            r#"impl {name} {{
728                pub fn get_descriptor() -> &'static ::pilota_thrift_reflect::thrift_reflection::StructDescriptor {{
729                    let file_descriptor = get_file_descriptor_{filename_lower}();
730                    file_descriptor.find_struct_by_name("{name}").unwrap()
731                }}
732
733                pub fn set_field_mask(&mut self, field_mask: ::pilota_thrift_fieldmask::FieldMask) {{
734                    self._field_mask = Some(field_mask.clone());
735                    {set_inner_field_mask}
736                }}
737            }}"#
738        });
739    }
740
741    fn codegen_service_impl(&self, _def_id: DefId, _stream: &mut String, _s: &Service) {}
742
743    fn codegen_service_method(&self, _service_def_id: DefId, _m: &Method) -> String {
744        Default::default()
745    }
746
747    fn codegen_enum_impl(&self, def_id: DefId, stream: &mut String, e: &Enum) {
748        let keep = self.keep_unknown_fields.contains(&def_id);
749        let name = self.rust_name(def_id);
750        let is_entry_message = self.node_contains_tag::<EntryMessage>(def_id);
751        let v = "self.inner()";
752        match e.repr {
753            Some(EnumRepr::I32) => stream.push_str(&self.codegen_impl_message_with_helper(
754                def_id,
755                name.clone(),
756                format! {
757                    r#"__protocol.write_i32({v})?;
758                    ::std::result::Result::Ok(())
759                    "#
760                },
761                format!("__protocol.i32_len({v})"),
762                |helper| {
763                    let read_i32 = helper.codegen_read_i32();
764                    let err_msg_tmpl = format!("invalid enum value for {name}, value: {{}}");
765                    format! {
766                        r#"let value = {read_i32};
767                        ::std::result::Result::Ok(::std::convert::TryFrom::try_from(value).map_err(|err|
768                            ::pilota::thrift::new_protocol_exception(
769                                ::pilota::thrift::ProtocolExceptionKind::InvalidData,
770                                format!("{err_msg_tmpl}", value)
771                            ))?)"#
772                    }
773                },
774            )),
775            None if is_entry_message => self.codegen_entry_enum(def_id, stream, e),
776            None => {
777                let name = self.rust_name(def_id);
778                let mut encode_variants = e
779                    .variants
780                    .iter()
781                    .map(|v| {
782                        let variant_name = self.rust_name(v.did);
783                        assert_eq!(v.fields.len(), 1);
784                        let variant_id = v.id.unwrap() as i16;
785                        let encode =
786                            self.codegen_encode_field(variant_id, &v.fields[0], "value".into());
787                        format! {
788                            r#"{name}::{variant_name}(value) => {{
789                            {encode}
790                        }},"#
791                        }
792                    })
793                    .join("");
794                if keep {
795                    encode_variants.push_str(&format! {
796                        "{name}::_UnknownFields(value) => {{
797                                        for bytes in value.list.iter() {{
798                                            __protocol.write_bytes_without_len(bytes.clone());
799                                        }}
800                                    }}",
801                    });
802                }
803
804                if e.variants.is_empty() {
805                    encode_variants.push_str("_ => {},");
806                }
807
808                let mut variants_size = e
809                    .variants
810                    .iter()
811                    .map(|v| {
812                        let variant_name = self.rust_name(v.did);
813                        let variant_id = v.id.unwrap() as i16;
814                        let size =
815                            self.codegen_field_size(&v.fields[0], variant_id, "value".into());
816
817                        format! {
818                            r#"{name}::{variant_name}(value) => {{
819                            {size}
820                        }},"#
821                        }
822                    })
823                    .join("");
824                if keep {
825                    variants_size.push_str(&format! {
826                        "{name}::_UnknownFields(value) => {{
827                                value.size()
828                            }}",
829                    })
830                }
831
832                if e.variants.is_empty() {
833                    variants_size.push_str("_ => 0,");
834                }
835
836                let variant_is_void = |v: &EnumVariant| {
837                    &*v.name.sym == "Ok" && v.fields.len() == 1 && v.fields[0].kind == TyKind::Void
838                };
839
840                stream.push_str(&self.codegen_impl_message_with_helper(def_id,
841                    name.clone(),
842                    format! {
843                        r#"__protocol.write_struct_begin(&::pilota::thrift::TStructIdentifier {{
844                            name: "{name}",
845                        }})?;
846                        match self {{
847                            {encode_variants}
848                        }}
849                        __protocol.write_field_stop()?;
850                        __protocol.write_struct_end()?;
851                        ::std::result::Result::Ok(())"#
852                    },
853                        format! {
854                            r#"__protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier {{
855                                name: "{name}",
856                            }}) + match self {{
857                                {variants_size}
858                            }} +  __protocol.field_stop_len() + __protocol.struct_end_len()"#
859                        },
860                    |helper| {
861                        let record_ptr = if keep && !helper.is_async {
862                            r#"let mut __pilota_offset = 0;
863                            let __pilota_begin_ptr = __protocol.buf().chunk().as_ptr();"#
864                        } else {
865                            ""
866                        };
867                        let read_struct_begin = helper.codegen_read_struct_begin();
868                        let read_field_begin = helper.codegen_read_field_begin();
869                        let field_begin_len = helper.codegen_field_begin_len(keep);
870                        let read_field_end = helper.codegen_read_field_end();
871                        let field_stop_len = helper.codegen_field_stop_len(keep);
872                        let read_struct_end = helper.codegen_read_struct_end();
873                        let mut skip = helper.codegen_skip_ttype("field_ident.field_type".into());
874                        if keep && !helper.is_async {
875                            skip = format!("__pilota_offset += {skip}")
876                        }
877                        let fields = e
878                            .variants
879                            .iter()
880                            .flat_map(|v| {
881                                if variant_is_void(v) {
882                                    None
883                                } else {
884                                    let variant_name = self.cx.rust_name(v.did);
885                                    assert_eq!(v.fields.len(), 1);
886                                    let variant_id = v.id.unwrap() as i16;
887                                    let decode = self.codegen_decode_ty(helper, &v.fields[0]);
888                                    let decode_len =  if helper.is_async {
889                                        Default::default()
890                                    } else {
891                                        let size = self.codegen_ty_size(
892                                            &v.fields[0],
893                                            "&field_ident".into()
894                                        );
895                                        if keep {
896                                            format!(
897                                                "__pilota_offset += {size};",
898                                            )
899                                        } else {
900                                            format!(
901                                                "{size};",
902                                            )
903                                        }
904                                    };
905                                    Some(format! {
906                                        r#"Some({variant_id}) => {{
907                                    if ret.is_none() {{
908                                        let field_ident = {decode};
909                                        {decode_len}
910                                        ret = Some({name}::{variant_name}(field_ident));
911                                    }} else {{
912                                        return ::std::result::Result::Err(::pilota::thrift::new_protocol_exception(
913                                            ::pilota::thrift::ProtocolExceptionKind::InvalidData,
914                                            "received multiple fields for union from remote Message"
915                                        ));
916                                    }}
917                                }},"#
918                                    })
919                                }
920                            })
921                            .join("");
922                        let write_unknown_field = if keep && !helper.is_async {
923                            format!(
924                                r#"if ret.is_none() {{
925                                unsafe {{
926                                    let mut __pilota_linked_bytes = ::pilota::BytesVec::new();
927                                    __pilota_linked_bytes.push_back(__protocol.get_bytes(Some(__pilota_begin_ptr), __pilota_offset)?);
928                                    ret = Some({name}::_UnknownFields(__pilota_linked_bytes));
929                                }}
930                            }} else {{
931                                return ::std::result::Result::Err(::pilota::thrift::new_protocol_exception(
932                                    ::pilota::thrift::ProtocolExceptionKind::InvalidData,
933                                    "received multiple fields for union from remote Message"
934                                ));
935                            }}"#
936                            )
937                        } else {
938                            Default::default()
939                        };
940
941                        let handle_none_ret: FastStr =
942                            if e.variants.first().filter(|v| variant_is_void(v)).is_some() {
943                                format!("::std::result::Result::Ok({name}::Ok(()))").into()
944                            } else {
945                                r#"::std::result::Result::Err(::pilota::thrift::new_protocol_exception(
946                                    ::pilota::thrift::ProtocolExceptionKind::InvalidData,
947                                    "received empty union from remote Message")
948                                )"#.into()
949                            };
950
951                        format! {
952                            r#"let mut ret = None;
953                            {read_struct_begin};
954                            loop {{
955                                {record_ptr}
956                                let field_ident = {read_field_begin};
957                                if field_ident.field_type == ::pilota::thrift::TType::Stop {{
958                                    {field_stop_len}
959                                    break;
960                                }} else {{
961                                    {field_begin_len}
962                                }}
963                                match field_ident.id {{
964                                    {fields}
965                                    _ => {{
966                                        {skip};
967                                        {write_unknown_field}
968                                    }},
969                                }}
970                            }}
971                            {read_field_end};
972                            {read_struct_end};
973                            if let Some(ret) = ret {{
974                                ::std::result::Result::Ok(ret)
975                            }} else {{
976                                {handle_none_ret}
977                            }}"#
978                        }
979                    },
980                ))
981            }
982            #[allow(unreachable_patterns)]
983            _ => {}
984        }
985    }
986
987    fn codegen_newtype_impl(&self, def_id: DefId, stream: &mut String, t: &NewType) {
988        let name = self.rust_name(def_id);
989        let encode = self.codegen_encode_ty(&t.ty, "(&**self)".into());
990        let encode_size = self.codegen_ty_size(&t.ty, "&**self".into());
991
992        if !self.with_field_mask {
993            stream.push_str(&self.codegen_impl_message_with_helper(
994                def_id,
995                name.clone(),
996                format! {
997                    r#"{encode}
998                    ::std::result::Result::Ok(())"#
999                },
1000                format!("{encode_size}"),
1001                |helper| {
1002                    let decode = self.codegen_decode_ty(helper, &t.ty);
1003                    format!("::std::result::Result::Ok({name}({decode}))")
1004                },
1005            ));
1006            return;
1007        }
1008
1009        match &t.ty.kind {
1010            TyKind::Path(p) if !self.is_enum(p.did) => {
1011                let inner_name = self.rust_name(p.did);
1012                let encode_with_field_mask =
1013                    self.codegen_encode_ty_with_field_mask(&t.ty, "(&**self)".into());
1014                let encode_size_with_field_mask =
1015                    self.codegen_ty_size_with_field_mask(&t.ty, "&**self".into());
1016
1017                stream.push_str(&self.codegen_impl_message_with_helper(
1018                    def_id,
1019                    name.clone(),
1020                    format! {
1021                        r#"{encode_with_field_mask}
1022                        ::std::result::Result::Ok(())"#
1023                    },
1024                    encode_size_with_field_mask.into(),
1025                    |helper| {
1026                        let decode = self.codegen_decode_ty(helper, &t.ty);
1027                        format!("::std::result::Result::Ok({name}({decode}))")
1028                    },
1029                ));
1030
1031                stream.push_str(&format!(
1032                        r#"impl {name} {{
1033                            pub fn get_descriptor() -> &'static ::pilota_thrift_reflect::thrift_reflection::StructDescriptor {{
1034                                {inner_name}::get_descriptor()
1035                            }}
1036                            pub fn set_field_mask(&mut self, field_mask: ::pilota_thrift_fieldmask::FieldMask) {{
1037                                self.0.set_field_mask(field_mask);
1038                            }}
1039                    }}"#
1040                    ));
1041            }
1042            _ => {
1043                stream.push_str(&self.codegen_impl_message_with_helper(
1044                    def_id,
1045                    name.clone(),
1046                    format! {
1047                        r#"{encode}
1048                        ::std::result::Result::Ok(())"#
1049                    },
1050                    format!("{encode_size}"),
1051                    |helper| {
1052                        let decode = self.codegen_decode_ty(helper, &t.ty);
1053                        format!("::std::result::Result::Ok({name}({decode}))")
1054                    },
1055                ));
1056
1057                stream.push_str(&format!(
1058                        r#"impl {name} {{
1059                            pub fn set_field_mask(&mut self, _: ::pilota_thrift_fieldmask::FieldMask) {{
1060                            }}
1061                        }}"#
1062                    ));
1063            }
1064        }
1065    }
1066
1067    fn cx(&self) -> &Context {
1068        &self.cx
1069    }
1070
1071    fn codegen_file_descriptor(&self, stream: &mut String, f: &rir::File, has_direct: bool) {
1072        let filename = self
1073            .file_paths()
1074            .get(&f.file_id)
1075            .unwrap()
1076            .file_stem()
1077            .unwrap()
1078            .to_string_lossy()
1079            .replace(".", "_");
1080        let filename_upper = filename.to_uppercase();
1081        let filename_lower = filename.to_lowercase();
1082        if has_direct {
1083            let descriptor = &f.descriptor;
1084            let super_mod = match &*self.mode {
1085                Mode::Workspace(_) => "crate::".to_string(),
1086                Mode::SingleFile { .. } => "super::".repeat(f.package.len()),
1087            };
1088            stream.push_str(&format!(
1089            r#"
1090static FILE_DESCRIPTOR_BYTES_{filename_upper}: ::pilota::Bytes = ::pilota::Bytes::from_static({descriptor:?});
1091static FILE_DESCRIPTOR_{filename_upper}: ::std::sync::LazyLock<::pilota_thrift_reflect::thrift_reflection::FileDescriptor> = ::std::sync::LazyLock::new(|| {{
1092    let descriptor = ::pilota_thrift_reflect::thrift_reflection::FileDescriptor::deserialize(FILE_DESCRIPTOR_BYTES_{filename_upper}.clone())
1093        .expect("Failed to decode file descriptor");
1094    ::pilota_thrift_reflect::service::Register::register(
1095        descriptor.filepath.clone(),
1096        descriptor.clone(),
1097    );
1098    
1099    for (key, include) in descriptor.includes.iter() {{
1100        let path = include.as_str();
1101        if ::pilota_thrift_reflect::service::Register::contains(path) {{
1102            continue;
1103        }}
1104        let include_file_descriptor =
1105            {super_mod}find_mod_file_descriptor(path).expect("include file descriptor must exist");
1106        ::pilota_thrift_reflect::service::Register::register(
1107            include_file_descriptor.filepath.clone(),
1108            include_file_descriptor.clone(),
1109        );
1110    }}
1111    descriptor
1112}});
1113pub fn get_file_descriptor_{filename_lower}() -> &'static ::pilota_thrift_reflect::thrift_reflection::FileDescriptor {{
1114    &*FILE_DESCRIPTOR_{filename_upper}
1115}}"#));
1116        } else {
1117            match &*self.mode {
1118                Mode::Workspace(_) => {
1119                    // 使用 pub use 的形式,从 common crate 中引入
1120                    let mod_prefix = f.package.iter().join("::");
1121                    let common_crate_name = &self.common_crate_name;
1122                    stream.push_str(&format!(
1123                        r#"
1124                        pub use ::{common_crate_name}::{mod_prefix}::get_file_descriptor_{filename_lower};
1125                        "#
1126                    ));
1127                }
1128                Mode::SingleFile { .. } => {}
1129            };
1130        }
1131    }
1132
1133    fn codegen_register_mod_file_descriptor(
1134        &self,
1135        stream: &mut String,
1136        mods: &[(Arc<[FastStr]>, Arc<PathBuf>)],
1137    ) {
1138        stream.push_str(r#"
1139                pub fn find_mod_file_descriptor(path: &str) -> Option<&'static ::pilota_thrift_reflect::thrift_reflection::FileDescriptor> {
1140                    match path {
1141            "#);
1142
1143        for (p, path) in mods {
1144            let filename = path
1145                .file_stem()
1146                .expect("must_exist")
1147                .to_string_lossy()
1148                .to_lowercase()
1149                .replace(".", "_");
1150            let path = path.display();
1151            stream.push_str(&format!(
1152                r#"
1153                r"{path}" => Some(
1154            "#
1155            ));
1156
1157            let prefix = p.iter().map(|s| Symbol::from(s.clone())).join("::");
1158            if !prefix.is_empty() {
1159                stream.push_str(&format!(r#"{prefix}::"#));
1160            }
1161            stream.push_str(&format!("get_file_descriptor_{filename}()),"));
1162        }
1163
1164        stream.push_str(
1165            r#"
1166                _ => None,
1167            }
1168        }
1169        "#,
1170        );
1171    }
1172}