1use std::{ops::Deref, path::PathBuf, sync::Arc};
2
3use faststr::FastStr;
4use itertools::Itertools;
5
6use super::traits::CodegenBackend;
7use crate::{
8 db::RirDatabase,
9 middle::{
10 context::{Context, Mode},
11 rir::{self, Enum, Field, Message, Method, NewType, Service},
12 },
13 rir::EnumVariant,
14 symbol::{DefId, EnumRepr, Symbol},
15 tags::thrift::EntryMessage,
16 ty::TyKind,
17};
18
19mod ty;
20
21pub use self::decode_helper::DecodeHelper;
22
23mod decode_helper;
24
25#[derive(Clone)]
26pub struct ThriftBackend {
27 cx: Context,
28}
29
30impl ThriftBackend {
31 pub fn new(cx: Context) -> Self {
32 ThriftBackend { cx }
33 }
34}
35
36impl Deref for ThriftBackend {
37 type Target = Context;
38
39 fn deref(&self) -> &Self::Target {
40 &self.cx
41 }
42}
43
44impl ThriftBackend {
45 fn codegen_encode_fields_size<'a>(
46 &'a self,
47 fields: &'a [Arc<rir::Field>],
48 ) -> impl Iterator<Item = FastStr> + 'a {
49 fields.iter().map(|f| {
50 let field_name = self.rust_name(f.did);
51 let is_optional = f.is_optional();
52 let field_id = f.id as i16;
53 let write_field = if is_optional {
54 self.codegen_field_size(&f.ty, field_id, "value".into())
55 } else {
56 self.codegen_field_size(&f.ty, field_id, format!("&self.{field_name}").into())
57 };
58
59 if is_optional {
60 format!("self.{field_name}.as_ref().map_or(0, |value| {write_field})").into()
61 } else {
62 write_field
63 }
64 })
65 }
66
67 fn codegen_encode_fields_size_with_field_mask<'a>(
68 &'a self,
69 fields: &'a [Arc<rir::Field>],
70 ) -> impl Iterator<Item = FastStr> + 'a {
71 fields.iter().map(|f| {
72 let field_name = self.rust_name(f.did);
73 let is_optional = f.is_optional();
74 let field_id = f.id as i16;
75 let write_field = if is_optional {
76 let write_field_size_with_field_mask =
77 self.codegen_field_size_with_field_mask(&f.ty, field_id, "value".into());
78 format! {
79 r#"{{
80 let (field_fm, exist) = struct_fm.field({field_id});
81 if exist {{
82 {write_field_size_with_field_mask}
83 }} else {{
84 0
85 }}
86 }}"#
87 }
88 .into()
89 } else {
90 let write_field_size_with_field_mask = self.codegen_field_size_with_field_mask(
91 &f.ty,
92 field_id,
93 format!("&self.{field_name}").into(),
94 );
95 format! {
96 r#"{{
97 let (field_fm, exist) = struct_fm.field({field_id});
98 if exist {{
99 {write_field_size_with_field_mask}
100 }} else {{
101 0
102 }}
103 }}"#
104 }
105 .into()
106 };
107
108 if is_optional {
109 format!("self.{field_name}.as_ref().map_or(0, |value| {write_field})").into()
110 } else {
111 write_field
112 }
113 })
114 }
115
116 fn codegen_encode_fields<'a>(
117 &'a self,
118 fields: &'a [Arc<rir::Field>],
119 ) -> impl Iterator<Item = FastStr> + 'a {
120 fields.iter().map(|f| {
121 let field_name = self.rust_name(f.did);
122 let field_id = f.id as i16;
123 let is_optional = f.is_optional();
124 let write_field = if is_optional {
125 self.codegen_encode_field(field_id, &f.ty, "value".into())
126 } else {
127 self.codegen_encode_field(field_id, &f.ty, format!("&self.{field_name}").into())
128 };
129
130 if is_optional {
131 format! {
132 r#"if let Some(value) = self.{field_name}.as_ref() {{
133 {write_field}
134 }}"#
135 }
136 .into()
137 } else {
138 write_field
139 }
140 })
141 }
142
143 fn codegen_encode_fields_with_field_mask<'a>(
144 &'a self,
145 fields: &'a [Arc<rir::Field>],
146 ) -> impl Iterator<Item = FastStr> + 'a {
147 fields.iter().map(|f| {
148 let field_name = self.rust_name(f.did);
149 let field_id = f.id as i16;
150 let is_optional = f.is_optional();
151 let write_field = if is_optional {
152 let write_field_with_field_mask =
153 self.codegen_encode_field_with_field_mask(field_id, &f.ty, "value".into());
154 format! {
155 r#"let (field_fm, exist) = struct_fm.field({field_id});
156 if exist {{
157 {write_field_with_field_mask}
158 }}"#
159 }
160 .into()
161 } else {
162 let write_field_with_field_mask = self.codegen_encode_field_with_field_mask(
163 field_id,
164 &f.ty,
165 format!("&self.{field_name}").into(),
166 );
167 format! {
168 r#"let (field_fm, exist) = struct_fm.field({field_id});
169 if exist {{
170 {write_field_with_field_mask}
171 }}"#
172 }
173 .into()
174 };
175
176 if is_optional {
177 format! {
178 r#"if let Some(value) = self.{field_name}.as_ref() {{
179 {write_field}
180 }}"#
181 }
182 .into()
183 } else {
184 write_field
185 }
186 })
187 }
188
189 fn codegen_impl_message(
190 &self,
191 _def_id: DefId,
192 name: Symbol,
193 encode: String,
194 size: String,
195 decode: String,
196 decode_async: String,
197 ) -> String {
198 let decode_async_fn = format!(
220 r#"fn decode_async<'a, T: ::pilota::thrift::TAsyncInputProtocol>(
221 __protocol: &'a mut T,
222 ) -> ::std::pin::Pin<::std::boxed::Box<dyn ::std::future::Future<Output = ::std::result::Result<Self, ::pilota::thrift::ThriftException>> + Send + 'a>> {{
223 ::std::boxed::Box::pin(async move {{
224 {decode_async}
225 }})
226 }}"#
227 );
228 format! {r#"
229 impl ::pilota::thrift::Message for {name} {{
230 fn encode<T: ::pilota::thrift::TOutputProtocol>(
231 &self,
232 __protocol: &mut T,
233 ) -> ::std::result::Result<(),::pilota::thrift::ThriftException> {{
234 #[allow(unused_imports)]
235 use ::pilota::thrift::TOutputProtocolExt;
236 {encode}
237 }}
238
239 fn decode<T: ::pilota::thrift::TInputProtocol>(
240 __protocol: &mut T,
241 ) -> ::std::result::Result<Self,::pilota::thrift::ThriftException> {{
242 #[allow(unused_imports)]
243 use ::pilota::{{thrift::TLengthProtocolExt, Buf}};
244 {decode}
245 }}
246
247 {decode_async_fn}
248
249 fn size<T: ::pilota::thrift::TLengthProtocol>(&self, __protocol: &mut T) -> usize {{
250 #[allow(unused_imports)]
251 use ::pilota::thrift::TLengthProtocolExt;
252 {size}
253 }}
254 }}"#}
255 }
256
257 fn codegen_impl_message_with_helper<F: Fn(&DecodeHelper) -> String>(
258 &self,
259 def_id: DefId,
260 name: Symbol,
261 encode: String,
262 size: String,
263 decode: F,
264 ) -> String {
265 let decode_stream = decode(&DecodeHelper::new(false));
266 let decode_async_stream = decode(&DecodeHelper::new(true));
267 self.codegen_impl_message(
268 def_id,
269 name,
270 encode,
271 size,
272 decode_stream,
273 decode_async_stream,
274 )
275 }
276
277 fn codegen_decode(
278 &self,
279 helper: &DecodeHelper,
280 s: &rir::Message,
281 name: Symbol,
282 keep: bool,
283 is_arg: bool,
284 ) -> String {
285 let def_fields_num = if keep && is_arg && !helper.is_async {
286 "let mut __pilota_fields_num = 0;"
287 } else {
288 ""
289 };
290
291 let mut def_fields = s
292 .fields
293 .iter()
294 .map(|f| {
295 let field_name = f.local_var_name();
296 let mut v = "None".into();
297
298 if let Some((default, is_const)) = self.cx.default_val(f) {
299 if is_const {
300 v = default;
301
302 if f.is_optional() {
303 v = format!("Some({v})").into()
304 }
305 }
306 };
307
308 let mut s = format!("let mut {field_name} = {v};");
309 if keep && is_arg && !helper.is_async {
310 s.push_str("__pilota_fields_num += 1;");
311 }
312 s
313 })
314 .join("");
315
316 if keep && !helper.is_async {
317 def_fields.push_str("let mut _unknown_fields = ::pilota::BytesVec::new();");
318 }
319
320 let set_default_fields = s
321 .fields
322 .iter()
323 .filter_map(|f| {
324 let field_name = f.local_var_name();
325 match self.cx.default_val(f) { Some((default, is_const)) => {
326 if !is_const {
327 if f.is_optional() {
328 Some(format! {
329 r#"if {field_name}.is_none() {{
330 {field_name} = Some({default});
331 }}"#
332 })
333 } else {
334 Some(format!(
335 r#"let {field_name} = {field_name}.unwrap_or_else(|| {default});"#
336 ))
337 }
338 } else {
339 None
340 }
341 } _ => {
342 None
343 }}
344 })
345 .join("\n");
346
347 let read_struct_begin = helper.codegen_read_struct_begin();
348 let read_struct_end = helper.codegen_read_struct_end();
349 let read_fields = self.codegen_decode_fields(helper, &s.fields, keep, is_arg);
350
351 let required_without_default_fields = s
352 .fields
353 .iter()
354 .filter(|f| !f.is_optional() && self.default_val(f).is_none())
355 .map(|f| (self.rust_name(f.did), f.local_var_name()))
356 .collect_vec();
357
358 let verify_required_fields = required_without_default_fields
359 .iter()
360 .map(|(s, v)| {
361 format!(
362 r#"let Some({v}) = {v} else {{
363 return ::std::result::Result::Err(
364 ::pilota::thrift::new_protocol_exception(
365 ::pilota::thrift::ProtocolExceptionKind::InvalidData,
366 "field {s} is required".to_string()
367 )
368 )
369 }}; "#
370 )
371 })
372 .join("\n");
373
374 let read_fields = if helper.is_async {
375 format! {
376 r#"async {{
377 {read_fields}
378 ::std::result::Result::Ok::<_, ::pilota::thrift::ThriftException>(())
379 }}.await"#
380 }
381 } else {
382 format! {
383 r#"(|| {{
384 {read_fields}
385 ::std::result::Result::Ok::<_, ::pilota::thrift::ThriftException>(())
386 }})()"#
387 }
388 };
389
390 let format_msg = format!("decode struct `{name}` field(#{{}}) failed");
391
392 let mut fields = s
393 .fields
394 .iter()
395 .map(|f| format!("{}: {}", self.rust_name(f.did), f.local_var_name()))
396 .join(",");
397
398 if keep {
399 if !fields.is_empty() {
400 fields.push_str(", ");
401 }
402 if !helper.is_async {
403 fields.push_str("_unknown_fields");
404 } else {
405 fields.push_str("_unknown_fields: ::pilota::BytesVec::new()");
406 }
407 }
408
409 if !s.is_wrapper && self.with_field_mask {
410 if !fields.is_empty() {
411 fields.push_str(", ");
412 }
413 fields.push_str("_field_mask: ::std::option::Option::None");
414 }
415
416 format! {
417 r#"
418 {def_fields_num}
419 {def_fields}
420
421 let mut __pilota_decoding_field_id = None;
422
423 {read_struct_begin};
424 if let ::std::result::Result::Err(mut err) = {read_fields} {{
425 if let Some(field_id) = __pilota_decoding_field_id {{
426 err.prepend_msg(&format!("{format_msg}, caused by: ", field_id));
427 }}
428 return ::std::result::Result::Err(err);
429 }};
430 {read_struct_end};
431
432 {verify_required_fields}
433
434 {set_default_fields}
435
436 let data = Self {{
437 {fields}
438 }};
439 ::std::result::Result::Ok(data)
440 "#
441 }
442 }
443
444 #[inline]
445 fn field_is_box(&self, f: &Field) -> bool {
446 self.with_adjust(f.did, |adj| match adj {
447 Some(a) => a.boxed(),
448 None => false,
449 })
450 }
451
452 fn codegen_entry_enum(&self, _def_id: DefId, _stream: &mut str, _e: &rir::Enum) {
453 }
455
456 fn codegen_decode_fields<'a>(
457 &'a self,
458 helper: &DecodeHelper,
459 fields: &'a [Arc<Field>],
460 keep: bool,
461 is_arg: bool,
462 ) -> String {
463 let record_ptr = if keep && !helper.is_async {
464 r#"let mut __pilota_offset = 0;
465 let __pilota_begin_ptr = __protocol.buf().chunk().as_ptr();"#
466 } else {
467 ""
468 };
469 let read_field_begin = helper.codegen_read_field_begin();
470 let field_begin_len = helper.codegen_field_begin_len(keep);
471 let match_fields = fields
472 .iter()
473 .map(|f| {
474 let field_ident = f.local_var_name();
475 let ttype = self.ttype(&f.ty);
476 let mut read_field = self.codegen_decode_ty(helper, &f.ty);
477 let field_id = f.id as i16;
478 if self.field_is_box(f) {
479 read_field = format!("::std::boxed::Box::new({read_field})").into();
480 };
481
482 if f.is_optional() || {
483 match self.cx.default_val(f) {
484 Some((_, is_const)) => !is_const,
485 _ => true,
486 }
487 } {
488 read_field = format!("Some({read_field})").into();
489 }
490
491 let fields_num = if keep && !helper.is_async && is_arg {
492 "__pilota_fields_num -= 1;"
493 } else {
494 ""
495 };
496
497 format!(
498 r#"Some({field_id}) if field_ident.field_type == {ttype} => {{
499 {field_ident} = {read_field};
500 {fields_num}
501 }},"#
502 )
503 })
504 .join("");
505 let mut skip_ttype = helper.codegen_skip_ttype("field_ident.field_type".into());
506 if keep && !helper.is_async {
507 skip_ttype = format!("__pilota_offset += {skip_ttype}")
508 }
509
510 let write_unknown_field = if keep && !helper.is_async {
511 "_unknown_fields.push_back(__protocol.get_bytes(Some(__pilota_begin_ptr), __pilota_offset)?);"
512 } else {
513 ""
514 };
515
516 let read_field_end = helper.codegen_read_field_end();
517 let field_end_len = helper.codegen_field_end_len(keep);
518 let field_stop_len = helper.codegen_field_stop_len(keep);
519
520 let skip_all = if keep && !helper.is_async && is_arg {
521 "if __pilota_fields_num == 0 {
522 let __pilota_remaining = __protocol.buf().remaining();
523 _unknown_fields.push_back(__protocol.get_bytes(None, __pilota_remaining - 2)?);
524 break;
525 }"
526 } else {
527 ""
528 };
529
530 format! {
531 r#"loop {{
532 {skip_all}
533 {record_ptr}
534 let field_ident = {read_field_begin};
535 if field_ident.field_type == ::pilota::thrift::TType::Stop {{
536 {field_stop_len}
537 break;
538 }} else {{
539 {field_begin_len}
540 }}
541 __pilota_decoding_field_id = field_ident.id;
542 match field_ident.id {{
543 {match_fields}
544 _ => {{
545 {skip_ttype};
546 {write_unknown_field}
547 }},
548 }}
549
550 {read_field_end};
551 {field_end_len}
552
553 }};"#
554 }
555 }
556}
557
558impl CodegenBackend for ThriftBackend {
559 const PROTOCOL: &'static str = "thrift";
560
561 fn codegen_struct_impl(&self, def_id: DefId, stream: &mut String, s: &Message) {
562 let filename = self
563 .cx
564 .file_paths()
565 .get(&self.cx.node(def_id).unwrap().file_id)
566 .unwrap()
567 .file_stem()
568 .unwrap()
569 .to_string_lossy()
570 .replace(".", "_");
571 let filename_lower = filename.to_lowercase();
572 let keep = self.keep_unknown_fields.contains(&def_id);
573 let name = self.cx.rust_name(def_id);
574 let mut encode_fields = self.codegen_encode_fields(&s.fields).join("");
575 let mut encode_fields_with_field_mask = self
576 .codegen_encode_fields_with_field_mask(&s.fields)
577 .join("");
578 if keep {
579 encode_fields.push_str(
580 r#"for bytes in self._unknown_fields.list.iter() {
581 __protocol.write_bytes_without_len(bytes.clone());
582 }"#,
583 );
584 encode_fields_with_field_mask.push_str(
585 r#"for bytes in self._unknown_fields.list.iter() {
586 __protocol.write_bytes_without_len(bytes.clone());
587 }"#,
588 );
589 }
590 let mut encode_fields_size = self
591 .codegen_encode_fields_size(&s.fields)
592 .map(|s| format!("{s} +"))
593 .join("");
594 let mut encode_fields_size_with_field_mask = self
595 .codegen_encode_fields_size_with_field_mask(&s.fields)
596 .map(|s| format!("{s} +"))
597 .join("");
598
599 if keep {
600 encode_fields_size.push_str("self._unknown_fields.size() +");
601 encode_fields_size_with_field_mask.push_str("self._unknown_fields.size() +");
602 }
603
604 if s.is_wrapper || !self.with_field_mask {
605 stream.push_str(&self.codegen_impl_message_with_helper(
606 def_id,
607 name.clone(),
608 format! {
609 r#"let struct_ident =::pilota::thrift::TStructIdentifier {{
610 name: "{name}",
611 }};
612
613 __protocol.write_struct_begin(&struct_ident)?;
614 {encode_fields}
615 __protocol.write_field_stop()?;
616 __protocol.write_struct_end()?;
617 ::std::result::Result::Ok(())
618 "#
619 },
620 format! {
621 r#"__protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier {{
622 name: "{name}",
623 }}) + {encode_fields_size} __protocol.field_stop_len() + __protocol.struct_end_len()"#
624 },
625 |helper| self.codegen_decode(helper, s, name.clone(), keep, self.is_arg(def_id)),
626 ));
627
628 if !s.is_wrapper && self.with_descriptor {
629 stream.push_str(&format! {
630 r#"impl {name} {{
631 pub fn get_descriptor() -> &'static ::pilota_thrift_reflect::thrift_reflection::StructDescriptor {{
632 let file_descriptor = get_file_descriptor_{filename_lower}();
633 file_descriptor.find_struct_by_name("{name}").unwrap()
634 }}
635 }}"#
636 });
637 }
638 return;
639 }
640
641 if self.with_field_mask {
642 stream.push_str(&self.codegen_impl_message_with_helper(
643 def_id,
644 name.clone(),
645 format! {
646 r#"if let Some(struct_fm) = self._field_mask.as_ref() {{
647 if !struct_fm.exist() {{
648 ::std::result::Result::Ok(())
649 }} else {{
650 let struct_ident =::pilota::thrift::TStructIdentifier {{
651 name: "{name}",
652 }};
653 __protocol.write_struct_begin(&struct_ident)?;
654 {encode_fields_with_field_mask}
655 __protocol.write_field_stop()?;
656 __protocol.write_struct_end()?;
657 ::std::result::Result::Ok(())
658 }}
659 }} else {{
660 let struct_ident =::pilota::thrift::TStructIdentifier {{
661 name: "{name}",
662 }};
663
664 __protocol.write_struct_begin(&struct_ident)?;
665 {encode_fields}
666 __protocol.write_field_stop()?;
667 __protocol.write_struct_end()?;
668 ::std::result::Result::Ok(())
669 }}"#
670 },
671 format! {
672 r#"if let Some(struct_fm) = self._field_mask.as_ref() {{
673 if !struct_fm.exist() {{
674 0
675 }} else {{
676 __protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier {{
677 name: "{name}",
678 }}) + {encode_fields_size_with_field_mask} __protocol.field_stop_len() + __protocol.struct_end_len()
679 }}
680 }} else {{
681 __protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier {{
682 name: "{name}",
683 }}) + {encode_fields_size} __protocol.field_stop_len() + __protocol.struct_end_len()
684 }}"#
685 },
686 |helper| self.codegen_decode(helper, s, name.clone(), keep, self.is_arg(def_id)),
687 ));
688 }
689
690 let set_inner_field_mask = s
691 .fields
692 .iter()
693 .filter(|f| self.need_field_mask(&f.ty))
694 .map(|f| {
695 let field_name = self.rust_name(f.did);
696 let field_id = f.id as i16;
697 let is_optional = f.is_optional();
698 let field_mask = if is_optional {
699 self.codegen_struct_field_mask(
700 field_id,
701 &f.ty,
702 "value".into(),
703 "field_mask".into(),
704 )
705 } else {
706 self.codegen_struct_field_mask(
707 field_id,
708 &f.ty,
709 format!("self.{field_name}").into(),
710 "field_mask".into(),
711 )
712 };
713
714 if is_optional {
715 format! {
716 r#"if let Some(value) = &mut self.{field_name} {{
717 {field_mask}
718 }}"#
719 }
720 } else {
721 field_mask.to_string()
722 }
723 })
724 .join("");
725
726 stream.push_str(&format! {
727 r#"impl {name} {{
728 pub fn get_descriptor() -> &'static ::pilota_thrift_reflect::thrift_reflection::StructDescriptor {{
729 let file_descriptor = get_file_descriptor_{filename_lower}();
730 file_descriptor.find_struct_by_name("{name}").unwrap()
731 }}
732
733 pub fn set_field_mask(&mut self, field_mask: ::pilota_thrift_fieldmask::FieldMask) {{
734 self._field_mask = Some(field_mask.clone());
735 {set_inner_field_mask}
736 }}
737 }}"#
738 });
739 }
740
741 fn codegen_service_impl(&self, _def_id: DefId, _stream: &mut String, _s: &Service) {}
742
743 fn codegen_service_method(&self, _service_def_id: DefId, _m: &Method) -> String {
744 Default::default()
745 }
746
747 fn codegen_enum_impl(&self, def_id: DefId, stream: &mut String, e: &Enum) {
748 let keep = self.keep_unknown_fields.contains(&def_id);
749 let name = self.rust_name(def_id);
750 let is_entry_message = self.node_contains_tag::<EntryMessage>(def_id);
751 let v = "self.inner()";
752 match e.repr {
753 Some(EnumRepr::I32) => stream.push_str(&self.codegen_impl_message_with_helper(
754 def_id,
755 name.clone(),
756 format! {
757 r#"__protocol.write_i32({v})?;
758 ::std::result::Result::Ok(())
759 "#
760 },
761 format!("__protocol.i32_len({v})"),
762 |helper| {
763 let read_i32 = helper.codegen_read_i32();
764 let err_msg_tmpl = format!("invalid enum value for {name}, value: {{}}");
765 format! {
766 r#"let value = {read_i32};
767 ::std::result::Result::Ok(::std::convert::TryFrom::try_from(value).map_err(|err|
768 ::pilota::thrift::new_protocol_exception(
769 ::pilota::thrift::ProtocolExceptionKind::InvalidData,
770 format!("{err_msg_tmpl}", value)
771 ))?)"#
772 }
773 },
774 )),
775 None if is_entry_message => self.codegen_entry_enum(def_id, stream, e),
776 None => {
777 let name = self.rust_name(def_id);
778 let mut encode_variants = e
779 .variants
780 .iter()
781 .map(|v| {
782 let variant_name = self.rust_name(v.did);
783 assert_eq!(v.fields.len(), 1);
784 let variant_id = v.id.unwrap() as i16;
785 let encode =
786 self.codegen_encode_field(variant_id, &v.fields[0], "value".into());
787 format! {
788 r#"{name}::{variant_name}(value) => {{
789 {encode}
790 }},"#
791 }
792 })
793 .join("");
794 if keep {
795 encode_variants.push_str(&format! {
796 "{name}::_UnknownFields(value) => {{
797 for bytes in value.list.iter() {{
798 __protocol.write_bytes_without_len(bytes.clone());
799 }}
800 }}",
801 });
802 }
803
804 if e.variants.is_empty() {
805 encode_variants.push_str("_ => {},");
806 }
807
808 let mut variants_size = e
809 .variants
810 .iter()
811 .map(|v| {
812 let variant_name = self.rust_name(v.did);
813 let variant_id = v.id.unwrap() as i16;
814 let size =
815 self.codegen_field_size(&v.fields[0], variant_id, "value".into());
816
817 format! {
818 r#"{name}::{variant_name}(value) => {{
819 {size}
820 }},"#
821 }
822 })
823 .join("");
824 if keep {
825 variants_size.push_str(&format! {
826 "{name}::_UnknownFields(value) => {{
827 value.size()
828 }}",
829 })
830 }
831
832 if e.variants.is_empty() {
833 variants_size.push_str("_ => 0,");
834 }
835
836 let variant_is_void = |v: &EnumVariant| {
837 &*v.name.sym == "Ok" && v.fields.len() == 1 && v.fields[0].kind == TyKind::Void
838 };
839
840 stream.push_str(&self.codegen_impl_message_with_helper(def_id,
841 name.clone(),
842 format! {
843 r#"__protocol.write_struct_begin(&::pilota::thrift::TStructIdentifier {{
844 name: "{name}",
845 }})?;
846 match self {{
847 {encode_variants}
848 }}
849 __protocol.write_field_stop()?;
850 __protocol.write_struct_end()?;
851 ::std::result::Result::Ok(())"#
852 },
853 format! {
854 r#"__protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier {{
855 name: "{name}",
856 }}) + match self {{
857 {variants_size}
858 }} + __protocol.field_stop_len() + __protocol.struct_end_len()"#
859 },
860 |helper| {
861 let record_ptr = if keep && !helper.is_async {
862 r#"let mut __pilota_offset = 0;
863 let __pilota_begin_ptr = __protocol.buf().chunk().as_ptr();"#
864 } else {
865 ""
866 };
867 let read_struct_begin = helper.codegen_read_struct_begin();
868 let read_field_begin = helper.codegen_read_field_begin();
869 let field_begin_len = helper.codegen_field_begin_len(keep);
870 let read_field_end = helper.codegen_read_field_end();
871 let field_stop_len = helper.codegen_field_stop_len(keep);
872 let read_struct_end = helper.codegen_read_struct_end();
873 let mut skip = helper.codegen_skip_ttype("field_ident.field_type".into());
874 if keep && !helper.is_async {
875 skip = format!("__pilota_offset += {skip}")
876 }
877 let fields = e
878 .variants
879 .iter()
880 .flat_map(|v| {
881 if variant_is_void(v) {
882 None
883 } else {
884 let variant_name = self.cx.rust_name(v.did);
885 assert_eq!(v.fields.len(), 1);
886 let variant_id = v.id.unwrap() as i16;
887 let decode = self.codegen_decode_ty(helper, &v.fields[0]);
888 let decode_len = if helper.is_async {
889 Default::default()
890 } else {
891 let size = self.codegen_ty_size(
892 &v.fields[0],
893 "&field_ident".into()
894 );
895 if keep {
896 format!(
897 "__pilota_offset += {size};",
898 )
899 } else {
900 format!(
901 "{size};",
902 )
903 }
904 };
905 Some(format! {
906 r#"Some({variant_id}) => {{
907 if ret.is_none() {{
908 let field_ident = {decode};
909 {decode_len}
910 ret = Some({name}::{variant_name}(field_ident));
911 }} else {{
912 return ::std::result::Result::Err(::pilota::thrift::new_protocol_exception(
913 ::pilota::thrift::ProtocolExceptionKind::InvalidData,
914 "received multiple fields for union from remote Message"
915 ));
916 }}
917 }},"#
918 })
919 }
920 })
921 .join("");
922 let write_unknown_field = if keep && !helper.is_async {
923 format!(
924 r#"if ret.is_none() {{
925 unsafe {{
926 let mut __pilota_linked_bytes = ::pilota::BytesVec::new();
927 __pilota_linked_bytes.push_back(__protocol.get_bytes(Some(__pilota_begin_ptr), __pilota_offset)?);
928 ret = Some({name}::_UnknownFields(__pilota_linked_bytes));
929 }}
930 }} else {{
931 return ::std::result::Result::Err(::pilota::thrift::new_protocol_exception(
932 ::pilota::thrift::ProtocolExceptionKind::InvalidData,
933 "received multiple fields for union from remote Message"
934 ));
935 }}"#
936 )
937 } else {
938 Default::default()
939 };
940
941 let handle_none_ret: FastStr =
942 if e.variants.first().filter(|v| variant_is_void(v)).is_some() {
943 format!("::std::result::Result::Ok({name}::Ok(()))").into()
944 } else {
945 r#"::std::result::Result::Err(::pilota::thrift::new_protocol_exception(
946 ::pilota::thrift::ProtocolExceptionKind::InvalidData,
947 "received empty union from remote Message")
948 )"#.into()
949 };
950
951 format! {
952 r#"let mut ret = None;
953 {read_struct_begin};
954 loop {{
955 {record_ptr}
956 let field_ident = {read_field_begin};
957 if field_ident.field_type == ::pilota::thrift::TType::Stop {{
958 {field_stop_len}
959 break;
960 }} else {{
961 {field_begin_len}
962 }}
963 match field_ident.id {{
964 {fields}
965 _ => {{
966 {skip};
967 {write_unknown_field}
968 }},
969 }}
970 }}
971 {read_field_end};
972 {read_struct_end};
973 if let Some(ret) = ret {{
974 ::std::result::Result::Ok(ret)
975 }} else {{
976 {handle_none_ret}
977 }}"#
978 }
979 },
980 ))
981 }
982 #[allow(unreachable_patterns)]
983 _ => {}
984 }
985 }
986
987 fn codegen_newtype_impl(&self, def_id: DefId, stream: &mut String, t: &NewType) {
988 let name = self.rust_name(def_id);
989 let encode = self.codegen_encode_ty(&t.ty, "(&**self)".into());
990 let encode_size = self.codegen_ty_size(&t.ty, "&**self".into());
991
992 if !self.with_field_mask {
993 stream.push_str(&self.codegen_impl_message_with_helper(
994 def_id,
995 name.clone(),
996 format! {
997 r#"{encode}
998 ::std::result::Result::Ok(())"#
999 },
1000 format!("{encode_size}"),
1001 |helper| {
1002 let decode = self.codegen_decode_ty(helper, &t.ty);
1003 format!("::std::result::Result::Ok({name}({decode}))")
1004 },
1005 ));
1006 return;
1007 }
1008
1009 match &t.ty.kind {
1010 TyKind::Path(p) if !self.is_enum(p.did) => {
1011 let inner_name = self.rust_name(p.did);
1012 let encode_with_field_mask =
1013 self.codegen_encode_ty_with_field_mask(&t.ty, "(&**self)".into());
1014 let encode_size_with_field_mask =
1015 self.codegen_ty_size_with_field_mask(&t.ty, "&**self".into());
1016
1017 stream.push_str(&self.codegen_impl_message_with_helper(
1018 def_id,
1019 name.clone(),
1020 format! {
1021 r#"{encode_with_field_mask}
1022 ::std::result::Result::Ok(())"#
1023 },
1024 encode_size_with_field_mask.into(),
1025 |helper| {
1026 let decode = self.codegen_decode_ty(helper, &t.ty);
1027 format!("::std::result::Result::Ok({name}({decode}))")
1028 },
1029 ));
1030
1031 stream.push_str(&format!(
1032 r#"impl {name} {{
1033 pub fn get_descriptor() -> &'static ::pilota_thrift_reflect::thrift_reflection::StructDescriptor {{
1034 {inner_name}::get_descriptor()
1035 }}
1036 pub fn set_field_mask(&mut self, field_mask: ::pilota_thrift_fieldmask::FieldMask) {{
1037 self.0.set_field_mask(field_mask);
1038 }}
1039 }}"#
1040 ));
1041 }
1042 _ => {
1043 stream.push_str(&self.codegen_impl_message_with_helper(
1044 def_id,
1045 name.clone(),
1046 format! {
1047 r#"{encode}
1048 ::std::result::Result::Ok(())"#
1049 },
1050 format!("{encode_size}"),
1051 |helper| {
1052 let decode = self.codegen_decode_ty(helper, &t.ty);
1053 format!("::std::result::Result::Ok({name}({decode}))")
1054 },
1055 ));
1056
1057 stream.push_str(&format!(
1058 r#"impl {name} {{
1059 pub fn set_field_mask(&mut self, _: ::pilota_thrift_fieldmask::FieldMask) {{
1060 }}
1061 }}"#
1062 ));
1063 }
1064 }
1065 }
1066
1067 fn cx(&self) -> &Context {
1068 &self.cx
1069 }
1070
1071 fn codegen_file_descriptor(&self, stream: &mut String, f: &rir::File, has_direct: bool) {
1072 let filename = self
1073 .file_paths()
1074 .get(&f.file_id)
1075 .unwrap()
1076 .file_stem()
1077 .unwrap()
1078 .to_string_lossy()
1079 .replace(".", "_");
1080 let filename_upper = filename.to_uppercase();
1081 let filename_lower = filename.to_lowercase();
1082 if has_direct {
1083 let descriptor = &f.descriptor;
1084 let super_mod = match &*self.mode {
1085 Mode::Workspace(_) => "crate::".to_string(),
1086 Mode::SingleFile { .. } => "super::".repeat(f.package.len()),
1087 };
1088 stream.push_str(&format!(
1089 r#"
1090static FILE_DESCRIPTOR_BYTES_{filename_upper}: ::pilota::Bytes = ::pilota::Bytes::from_static({descriptor:?});
1091static FILE_DESCRIPTOR_{filename_upper}: ::std::sync::LazyLock<::pilota_thrift_reflect::thrift_reflection::FileDescriptor> = ::std::sync::LazyLock::new(|| {{
1092 let descriptor = ::pilota_thrift_reflect::thrift_reflection::FileDescriptor::deserialize(FILE_DESCRIPTOR_BYTES_{filename_upper}.clone())
1093 .expect("Failed to decode file descriptor");
1094 ::pilota_thrift_reflect::service::Register::register(
1095 descriptor.filepath.clone(),
1096 descriptor.clone(),
1097 );
1098
1099 for (key, include) in descriptor.includes.iter() {{
1100 let path = include.as_str();
1101 if ::pilota_thrift_reflect::service::Register::contains(path) {{
1102 continue;
1103 }}
1104 let include_file_descriptor =
1105 {super_mod}find_mod_file_descriptor(path).expect("include file descriptor must exist");
1106 ::pilota_thrift_reflect::service::Register::register(
1107 include_file_descriptor.filepath.clone(),
1108 include_file_descriptor.clone(),
1109 );
1110 }}
1111 descriptor
1112}});
1113pub fn get_file_descriptor_{filename_lower}() -> &'static ::pilota_thrift_reflect::thrift_reflection::FileDescriptor {{
1114 &*FILE_DESCRIPTOR_{filename_upper}
1115}}"#));
1116 } else {
1117 match &*self.mode {
1118 Mode::Workspace(_) => {
1119 let mod_prefix = f.package.iter().join("::");
1121 let common_crate_name = &self.common_crate_name;
1122 stream.push_str(&format!(
1123 r#"
1124 pub use ::{common_crate_name}::{mod_prefix}::get_file_descriptor_{filename_lower};
1125 "#
1126 ));
1127 }
1128 Mode::SingleFile { .. } => {}
1129 };
1130 }
1131 }
1132
1133 fn codegen_register_mod_file_descriptor(
1134 &self,
1135 stream: &mut String,
1136 mods: &[(Arc<[FastStr]>, Arc<PathBuf>)],
1137 ) {
1138 stream.push_str(r#"
1139 pub fn find_mod_file_descriptor(path: &str) -> Option<&'static ::pilota_thrift_reflect::thrift_reflection::FileDescriptor> {
1140 match path {
1141 "#);
1142
1143 for (p, path) in mods {
1144 let filename = path
1145 .file_stem()
1146 .expect("must_exist")
1147 .to_string_lossy()
1148 .to_lowercase()
1149 .replace(".", "_");
1150 let path = path.display();
1151 stream.push_str(&format!(
1152 r#"
1153 r"{path}" => Some(
1154 "#
1155 ));
1156
1157 let prefix = p.iter().map(|s| Symbol::from(s.clone())).join("::");
1158 if !prefix.is_empty() {
1159 stream.push_str(&format!(r#"{prefix}::"#));
1160 }
1161 stream.push_str(&format!("get_file_descriptor_{filename}()),"));
1162 }
1163
1164 stream.push_str(
1165 r#"
1166 _ => None,
1167 }
1168 }
1169 "#,
1170 );
1171 }
1172}