1#![cfg_attr(docsrs, feature(doc_cfg))]
96#![doc(
97 html_logo_url = "https://github.com/specta-rs/specta/raw/main/.github/logo-128.png",
98 html_favicon_url = "https://github.com/specta-rs/specta/raw/main/.github/logo-128.png"
99)]
100
101use std::{
102 borrow::Cow,
103 collections::{HashMap, HashSet, VecDeque},
104};
105
106use specta::{
107 ResolvedTypes, Types,
108 datatype::{
109 DataType, Enum, Field, Fields, NamedDataType, Primitive, Reference, Struct, Tuple,
110 UnnamedFields, Variant,
111 },
112};
113
114mod error;
115mod inflection;
116mod parser;
117mod phased;
118mod repr;
119mod validate;
120
121use inflection::RenameRule;
122use parser::{SerdeContainerAttrs, SerdeFieldAttrs, SerdeVariantAttrs};
123use phased::PhasedTy;
124
125pub use error::Error;
126pub use phased::{Phased, phased};
127
128#[derive(Debug, Clone, Copy, PartialEq, Eq)]
130pub enum Phase {
131 Serialize,
133 Deserialize,
135}
136
137#[doc(hidden)]
138pub mod internal {
139 pub use crate::error::Result;
140 pub use crate::inflection::RenameRule;
141 pub use crate::parser::{
142 ConversionType, SerdeContainerAttrs, SerdeFieldAttrs, SerdeVariantAttrs,
143 };
144}
145
146use error::Result;
147use repr::EnumRepr;
148
149pub fn validate(dt: &DataType, types: &ResolvedTypes) -> Result<()> {
157 validate::validate_datatype_for_mode(dt, types.as_types(), validate::ApplyMode::Unified)
158}
159
160pub fn apply(types: Types) -> Result<ResolvedTypes> {
176 validate::validate_for_mode(&types, validate::ApplyMode::Unified)?;
177
178 let mut out = types.clone();
179 let generated = HashMap::<TypeIdentity, SplitGeneratedTypes>::new();
180 let split_types = HashSet::<TypeIdentity>::new();
181 let mut rewrite_err = None;
182
183 out.iter_mut(|ndt| {
184 if rewrite_err.is_some() {
185 return;
186 }
187
188 let ndt_name = ndt.name().to_string();
189
190 if let Err(err) = rewrite_datatype_for_phase(
191 ndt.ty_mut(),
192 PhaseRewrite::Unified,
193 &types,
194 &generated,
195 &split_types,
196 Some(ndt_name.as_str()),
197 ) {
198 rewrite_err = Some(err);
199 }
200 });
201
202 if let Some(err) = rewrite_err {
203 return Err(err);
204 }
205
206 Ok(ResolvedTypes::from_resolved_types(out))
207}
208
209pub fn apply_phases(types: Types) -> Result<ResolvedTypes> {
223 validate::validate_for_mode(&types, validate::ApplyMode::Phases)?;
224
225 let originals = types.into_unsorted_iter().collect::<Vec<_>>();
226 let mut dependencies = HashMap::<TypeIdentity, HashSet<TypeIdentity>>::new();
227 let mut reverse_dependencies = HashMap::<TypeIdentity, HashSet<TypeIdentity>>::new();
228
229 for original in &originals {
230 let key = TypeIdentity::from_ndt(original);
231 let mut deps = HashSet::new();
232 collect_dependencies(original.ty(), &types, &mut deps)?;
233 for dep in &deps {
234 reverse_dependencies
235 .entry(dep.clone())
236 .or_default()
237 .insert(key.clone());
238 }
239 dependencies.insert(key, deps);
240 }
241
242 let mut split_types = HashSet::new();
243 for ndt in &originals {
244 if has_local_phase_difference(ndt.ty())? {
245 split_types.insert(TypeIdentity::from_ndt(ndt));
246 }
247 }
248
249 let mut queue = VecDeque::from_iter(split_types.iter().cloned());
250 while let Some(key) = queue.pop_front() {
251 if let Some(dependents) = reverse_dependencies.get(&key) {
252 for dependent in dependents {
253 if split_types.insert(dependent.clone()) {
254 queue.push_back(dependent.clone());
255 }
256 }
257 }
258 }
259
260 let mut out = types.clone();
261 let mut generated = HashMap::<TypeIdentity, SplitGeneratedTypes>::new();
262 let mut generated_types = HashSet::<TypeIdentity>::new();
263
264 for original in &originals {
265 let key = TypeIdentity::from_ndt(original);
266
267 if split_types.contains(&key) {
268 let serialize_ndt = build_from_original(
269 original,
270 format!("{}_Serialize", original.name()),
271 original.generics().to_vec(),
272 original.ty().clone(),
273 &types,
274 );
275
276 let deserialize_ndt = build_from_original(
277 original,
278 format!("{}_Deserialize", original.name()),
279 original.generics().to_vec(),
280 original.ty().clone(),
281 &types,
282 );
283
284 generated.insert(
285 key,
286 SplitGeneratedTypes {
287 serialize: serialize_ndt,
288 deserialize: Box::new(deserialize_ndt),
289 },
290 );
291 }
292 }
293
294 for original in &originals {
295 let key = TypeIdentity::from_ndt(original);
296
297 if !split_types.contains(&key) {
298 continue;
299 }
300
301 let Some(mut generated_types_for_phase) = generated.get(&key).cloned() else {
302 continue;
303 };
304
305 rewrite_datatype_for_phase(
306 generated_types_for_phase.serialize.ty_mut(),
307 PhaseRewrite::Serialize,
308 &types,
309 &generated,
310 &split_types,
311 Some(original.name().as_ref()),
312 )?;
313
314 rewrite_datatype_for_phase(
315 generated_types_for_phase.deserialize.ty_mut(),
316 PhaseRewrite::Deserialize,
317 &types,
318 &generated,
319 &split_types,
320 Some(original.name().as_ref()),
321 )?;
322
323 generated.insert(key, generated_types_for_phase);
324 }
325
326 for generated_types_for_phase in generated.values() {
327 generated_types.insert(TypeIdentity::from_ndt(&generated_types_for_phase.serialize));
328 generated_types.insert(TypeIdentity::from_ndt(
329 &generated_types_for_phase.deserialize,
330 ));
331 generated_types_for_phase.serialize.register(&mut out);
332 generated_types_for_phase.deserialize.register(&mut out);
333 }
334
335 let mut rewrite_err = None;
336 out.iter_mut(|ndt| {
337 if rewrite_err.is_some() {
338 return;
339 }
340
341 let ndt_name = ndt.name().to_string();
342 let key = TypeIdentity::from_ndt(ndt);
343
344 if split_types.contains(&key) || generated_types.contains(&key) {
345 return;
346 }
347
348 if let Err(err) = rewrite_datatype_for_phase(
349 ndt.ty_mut(),
350 PhaseRewrite::Unified,
351 &types,
352 &generated,
353 &split_types,
354 Some(ndt_name.as_str()),
355 ) {
356 rewrite_err = Some(err);
357 }
358 });
359
360 if let Some(err) = rewrite_err {
361 return Err(err);
362 }
363
364 out.iter_mut(|ndt| {
365 let key = TypeIdentity::from_ndt(ndt);
366 if !split_types.contains(&key) {
367 return;
368 }
369
370 let Some(SplitGeneratedTypes {
371 serialize,
372 deserialize,
373 }) = generated.get(&key)
374 else {
375 return;
376 };
377
378 let generic_args = ndt
379 .generics()
380 .iter()
381 .map(|(generic, _)| (generic.clone(), generic.clone().into()))
382 .collect::<Vec<_>>();
383
384 let mut serialize_variant = Variant::unnamed().build();
385 if let Fields::Unnamed(fields) = serialize_variant.fields_mut() {
386 fields
387 .fields_mut()
388 .push(Field::new(serialize.reference(generic_args.clone()).into()));
389 }
390
391 let mut deserialize_variant = Variant::unnamed().build();
392 if let Fields::Unnamed(fields) = deserialize_variant.fields_mut() {
393 fields
394 .fields_mut()
395 .push(Field::new(deserialize.reference(generic_args).into()));
396 }
397
398 let mut wrapper = Enum::new();
399 wrapper
400 .variants_mut()
401 .push((Cow::Borrowed("Serialize"), serialize_variant));
402 wrapper
403 .variants_mut()
404 .push((Cow::Borrowed("Deserialize"), deserialize_variant));
405
406 ndt.set_ty(DataType::Enum(wrapper));
407 });
408 Ok(ResolvedTypes::from_resolved_types(out))
409}
410
411pub fn select_phase_datatype(dt: &DataType, types: &ResolvedTypes, phase: Phase) -> DataType {
462 let mut dt = dt.clone();
463 select_phase_datatype_inner(&mut dt, types.as_types(), phase);
464 dt
465}
466
467#[derive(Debug, Clone, Copy, PartialEq, Eq)]
468enum PhaseRewrite {
469 Unified,
470 Serialize,
471 Deserialize,
472}
473
474fn select_phase_datatype_inner(ty: &mut DataType, types: &Types, phase: Phase) {
475 if let Some(resolved) = select_explicit_phased_type(ty, phase) {
476 *ty = resolved;
477 select_phase_datatype_inner(ty, types, phase);
478 return;
479 }
480
481 match ty {
482 DataType::Struct(s) => select_phase_fields(s.fields_mut(), types, phase),
483 DataType::Enum(e) => {
484 for (_, variant) in e.variants_mut() {
485 select_phase_fields(variant.fields_mut(), types, phase);
486 }
487 }
488 DataType::Tuple(tuple) => {
489 for ty in tuple.elements_mut() {
490 select_phase_datatype_inner(ty, types, phase);
491 }
492 }
493 DataType::List(list) => select_phase_datatype_inner(list.ty_mut(), types, phase),
494 DataType::Map(map) => {
495 select_phase_datatype_inner(map.key_ty_mut(), types, phase);
496 select_phase_datatype_inner(map.value_ty_mut(), types, phase);
497 }
498 DataType::Nullable(inner) => select_phase_datatype_inner(inner, types, phase),
499 DataType::Reference(Reference::Named(reference)) => {
500 let Some(referenced_ndt) = reference.get(types) else {
501 return;
502 };
503
504 let generics = reference
505 .generics()
506 .iter()
507 .map(|(generic, dt)| {
508 let mut dt = dt.clone();
509 select_phase_datatype_inner(&mut dt, types, phase);
510 (generic.clone(), dt)
511 })
512 .collect::<Vec<_>>();
513
514 let target_ndt =
515 select_split_type_variant(referenced_ndt, types, phase).unwrap_or(referenced_ndt);
516
517 let mut new_reference = target_ndt.reference(generics);
518 if reference.inline() {
519 new_reference = new_reference.inline();
520 }
521
522 *ty = DataType::Reference(new_reference);
523 }
524 DataType::Reference(Reference::Generic(_))
525 | DataType::Reference(Reference::Opaque(_))
526 | DataType::Primitive(_) => {}
527 }
528}
529
530fn select_phase_fields(fields: &mut Fields, types: &Types, phase: Phase) {
531 match fields {
532 Fields::Unit => {}
533 Fields::Unnamed(fields) => {
534 for field in fields.fields_mut() {
535 if let Some(ty) = field.ty_mut() {
536 select_phase_datatype_inner(ty, types, phase);
537 }
538 }
539 }
540 Fields::Named(fields) => {
541 for (_, field) in fields.fields_mut() {
542 if let Some(ty) = field.ty_mut() {
543 select_phase_datatype_inner(ty, types, phase);
544 }
545 }
546 }
547 }
548}
549
550fn select_explicit_phased_type(ty: &DataType, phase: Phase) -> Option<DataType> {
551 let DataType::Reference(Reference::Opaque(reference)) = ty else {
552 return None;
553 };
554 let phased = reference.downcast_ref::<PhasedTy>()?;
555
556 Some(match phase {
557 Phase::Serialize => phased.serialize.clone(),
558 Phase::Deserialize => phased.deserialize.clone(),
559 })
560}
561
562fn select_split_type_variant<'a>(
563 ndt: &'a NamedDataType,
564 types: &'a Types,
565 phase: Phase,
566) -> Option<&'a NamedDataType> {
567 let DataType::Enum(wrapper) = ndt.ty() else {
568 return None;
569 };
570
571 let variant_name = match phase {
572 Phase::Serialize => "Serialize",
573 Phase::Deserialize => "Deserialize",
574 };
575
576 let (_, variant) = wrapper
577 .variants()
578 .iter()
579 .find(|(name, _)| name == variant_name)?;
580 let Fields::Unnamed(fields) = variant.fields() else {
581 return None;
582 };
583 let field = fields.fields().first()?;
584 let Some(DataType::Reference(Reference::Named(reference))) = field.ty() else {
585 return None;
586 };
587
588 reference.get(types)
589}
590
591#[derive(Debug, Clone)]
592struct SplitGeneratedTypes {
593 serialize: NamedDataType,
594 deserialize: Box<NamedDataType>,
595}
596
597#[derive(Debug, Clone, PartialEq, Eq, Hash)]
598struct TypeIdentity {
599 name: String,
600 module_path: String,
601 file: &'static str,
602 line: u32,
603 column: u32,
604}
605
606impl TypeIdentity {
607 fn from_ndt(ty: &specta::datatype::NamedDataType) -> Self {
608 let location = ty.location();
609 Self {
610 name: ty.name().to_string(),
611 module_path: ty.module_path().to_string(),
612 file: location.file(),
613 line: location.line(),
614 column: location.column(),
615 }
616 }
617}
618
619fn rewrite_datatype_for_phase(
620 ty: &mut DataType,
621 mode: PhaseRewrite,
622 original_types: &Types,
623 generated: &HashMap<TypeIdentity, SplitGeneratedTypes>,
624 split_types: &HashSet<TypeIdentity>,
625 container_name: Option<&str>,
626) -> Result<()> {
627 if let Some(resolved) = resolve_phased_type(ty, mode, "type")? {
628 *ty = resolved;
629 }
630
631 if let Some(converted) = conversion_datatype_for_mode(ty, mode)?
632 && converted != *ty
633 {
634 *ty = converted;
635 return rewrite_datatype_for_phase(
636 ty,
637 mode,
638 original_types,
639 generated,
640 split_types,
641 container_name,
642 );
643 }
644
645 match ty {
646 DataType::Struct(s) => {
647 let container_default = SerdeContainerAttrs::from_attributes(s.attributes())?
648 .is_some_and(|attrs| attrs.default);
649 let container_rename_all = container_rename_all_rule(
650 s.attributes(),
651 mode,
652 "struct rename_all",
653 container_name.unwrap_or("<anonymous struct>"),
654 )?;
655
656 rewrite_fields_for_phase(
657 s.fields_mut(),
658 mode,
659 original_types,
660 generated,
661 split_types,
662 container_rename_all,
663 container_default,
664 false,
665 )?;
666 rewrite_struct_repr_for_phase(s, mode, container_name)?;
667 }
668 DataType::Enum(e) => {
669 filter_enum_variants_for_phase(e, mode)?;
670 let container_attrs = SerdeContainerAttrs::from_attributes(e.attributes())?;
671
672 for (variant_name, variant) in e.variants_mut() {
673 let rename_rule =
674 enum_variant_field_rename_rule(&container_attrs, variant, mode, variant_name)?;
675
676 rewrite_fields_for_phase(
677 variant.fields_mut(),
678 mode,
679 original_types,
680 generated,
681 split_types,
682 rename_rule,
683 false,
684 true,
685 )?;
686 }
687
688 if rewrite_identifier_enum_for_phase(e, mode, original_types, generated, split_types)? {
689 return Ok(());
690 }
691
692 rewrite_enum_repr_for_phase(e, mode, original_types)?;
693 }
694 DataType::Tuple(tuple) => {
695 for ty in tuple.elements_mut() {
696 rewrite_datatype_for_phase(ty, mode, original_types, generated, split_types, None)?;
697 }
698 }
699 DataType::List(list) => rewrite_datatype_for_phase(
700 list.ty_mut(),
701 mode,
702 original_types,
703 generated,
704 split_types,
705 None,
706 )?,
707 DataType::Map(map) => {
708 rewrite_datatype_for_phase(
709 map.key_ty_mut(),
710 mode,
711 original_types,
712 generated,
713 split_types,
714 None,
715 )?;
716 rewrite_datatype_for_phase(
717 map.value_ty_mut(),
718 mode,
719 original_types,
720 generated,
721 split_types,
722 None,
723 )?;
724 }
725 DataType::Nullable(inner) => {
726 rewrite_datatype_for_phase(inner, mode, original_types, generated, split_types, None)?
727 }
728 DataType::Reference(Reference::Named(reference)) => {
729 let Some(referenced_ndt) = reference.get(original_types) else {
730 return Ok(());
731 };
732 let key = TypeIdentity::from_ndt(referenced_ndt);
733
734 let mut generics = Vec::with_capacity(reference.generics().len());
735 for (generic, dt) in reference.generics() {
736 let mut dt = dt.clone();
737 rewrite_datatype_for_phase(
738 &mut dt,
739 mode,
740 original_types,
741 generated,
742 split_types,
743 None,
744 )?;
745 generics.push((generic.clone(), dt));
746 }
747
748 if !split_types.contains(&key) {
749 let mut new_reference = referenced_ndt.reference(generics);
750 if reference.inline() {
751 new_reference = new_reference.inline();
752 }
753 *ty = DataType::Reference(new_reference);
754 return Ok(());
755 }
756
757 let Some(target) = generated.get(&key) else {
758 return Ok(());
759 };
760
761 let mut new_reference = match mode {
762 PhaseRewrite::Unified => {
763 unreachable!("unified mode should not reference split types")
764 }
765 PhaseRewrite::Serialize => target.serialize.reference(generics),
766 PhaseRewrite::Deserialize => target.deserialize.reference(generics),
767 };
768
769 if reference.inline() {
770 new_reference = new_reference.inline();
771 }
772
773 *ty = DataType::Reference(new_reference);
774 }
775 DataType::Reference(Reference::Generic(_))
776 | DataType::Reference(Reference::Opaque(_))
777 | DataType::Primitive(_) => {}
778 }
779
780 Ok(())
781}
782
783fn rewrite_fields_for_phase(
784 fields: &mut Fields,
785 mode: PhaseRewrite,
786 original_types: &Types,
787 generated: &HashMap<TypeIdentity, SplitGeneratedTypes>,
788 split_types: &HashSet<TypeIdentity>,
789 rename_all_rule: Option<RenameRule>,
790 container_default: bool,
791 preserve_skipped_unnamed_fields: bool,
792) -> Result<()> {
793 match fields {
794 Fields::Unit => {}
795 Fields::Unnamed(unnamed) => {
796 for field in unnamed.fields_mut() {
797 if should_skip_field_for_mode(field, mode)? {
798 if preserve_skipped_unnamed_fields {
799 *field = skipped_field_marker(field);
800 }
801
802 continue;
803 }
804
805 apply_field_attrs(field, mode, container_default)?;
806 rewrite_field_for_phase(field, mode, original_types, generated, split_types)?;
807 }
808
809 if !preserve_skipped_unnamed_fields {
810 unnamed.fields_mut().retain(|field| field.ty().is_some());
811 }
812 }
813 Fields::Named(named) => {
814 let mut skip_err = None;
815 named
816 .fields_mut()
817 .retain(|(_, field)| match should_skip_field_for_mode(field, mode) {
818 Ok(skip) => !skip,
819 Err(err) => {
820 skip_err = Some(err);
821 true
822 }
823 });
824 if let Some(err) = skip_err {
825 return Err(err);
826 }
827
828 for (name, field) in named.fields_mut() {
829 apply_field_attrs(field, mode, container_default)?;
830
831 if let Some(serde_attrs) = SerdeFieldAttrs::from_attributes(field.attributes())? {
832 let rename = match mode {
833 PhaseRewrite::Serialize => serde_attrs.rename_serialize.as_deref(),
834 PhaseRewrite::Deserialize => serde_attrs.rename_deserialize.as_deref(),
835 PhaseRewrite::Unified => serde_attrs
836 .rename_serialize
837 .as_deref()
838 .or(serde_attrs.rename_deserialize.as_deref()),
839 };
840
841 if let Some(rename) = rename {
842 *name = Cow::Owned(rename.to_string());
843 } else if let Some(rule) = rename_all_rule {
844 *name = Cow::Owned(rule.apply_to_field(name));
845 }
846 } else if let Some(rule) = rename_all_rule {
847 *name = Cow::Owned(rule.apply_to_field(name));
848 }
849
850 rewrite_field_for_phase(field, mode, original_types, generated, split_types)?;
851 }
852 }
853 }
854
855 Ok(())
856}
857
858fn rewrite_field_for_phase(
859 field: &mut Field,
860 mode: PhaseRewrite,
861 original_types: &Types,
862 generated: &HashMap<TypeIdentity, SplitGeneratedTypes>,
863 split_types: &HashSet<TypeIdentity>,
864) -> Result<()> {
865 if let Some(attrs) = SerdeFieldAttrs::from_attributes(field.attributes())?
866 && let PhaseRewrite::Serialize = mode
867 && attrs.skip_serializing_if.is_some()
868 {
869 field.set_optional(true);
870 }
871
872 if let Some(ty) = field.ty().cloned()
873 && let Some(resolved) = resolve_phased_type(&ty, mode, "field")?
874 {
875 field.set_ty(resolved);
876 }
877
878 if let Some(ty) = field.ty_mut() {
879 rewrite_datatype_for_phase(ty, mode, original_types, generated, split_types, None)?;
880 }
881
882 Ok(())
883}
884
885fn rewrite_struct_repr_for_phase(
886 strct: &mut Struct,
887 mode: PhaseRewrite,
888 container_name: Option<&str>,
889) -> Result<()> {
890 let Some((tag, rename_serialize, rename_deserialize)) =
891 SerdeContainerAttrs::from_attributes(strct.attributes())?.map(|attrs| {
892 (
893 attrs.tag.clone(),
894 attrs.rename_serialize.clone(),
895 attrs.rename_deserialize.clone(),
896 )
897 })
898 else {
899 return Ok(());
900 };
901
902 let Some(tag) = tag.as_deref() else {
903 return Ok(());
904 };
905
906 let serialized_name = match select_phase_string(
907 mode,
908 rename_serialize.as_deref(),
909 rename_deserialize.as_deref(),
910 "struct rename",
911 container_name.unwrap_or("<anonymous struct>"),
912 )? {
913 Some(rename) => rename.to_string(),
914 None => container_name
915 .map(str::to_owned)
916 .ok_or_else(|| {
917 Error::invalid_phased_type_usage(
918 "<anonymous struct>",
919 "`#[serde(tag = ...)]` on structs requires either a named type or `#[serde(rename = ...)]`",
920 )
921 })?,
922 };
923
924 let Fields::Named(named) = strct.fields_mut() else {
925 return Ok(());
926 };
927
928 named.fields_mut().insert(
929 0,
930 (
931 Cow::Owned(tag.to_string()),
932 Field::new(string_literal_datatype(serialized_name)),
933 ),
934 );
935
936 Ok(())
937}
938
939fn should_skip_field_for_mode(field: &Field, mode: PhaseRewrite) -> Result<bool> {
940 let Some(attrs) = SerdeFieldAttrs::from_attributes(field.attributes())? else {
941 return Ok(false);
942 };
943
944 Ok(match mode {
945 PhaseRewrite::Serialize => attrs.skip_serializing,
946 PhaseRewrite::Deserialize => attrs.skip_deserializing,
947 PhaseRewrite::Unified => attrs.skip_serializing || attrs.skip_deserializing,
948 })
949}
950
951fn skipped_field_marker(field: &Field) -> Field {
952 let mut skipped = Field::default();
953 skipped.set_optional(field.optional());
954 skipped.set_flatten(field.flatten());
955 skipped.set_deprecated(field.deprecated().cloned());
956 skipped.set_docs(field.docs().clone());
957 skipped.set_inline(field.inline());
958 skipped.set_type_overridden(field.type_overridden());
959 skipped.set_attributes(field.attributes().clone());
960 skipped
961}
962
963fn unnamed_live_fields(unnamed: &UnnamedFields) -> impl Iterator<Item = &Field> {
964 unnamed.fields().iter().filter(|field| field.ty().is_some())
965}
966
967fn unnamed_live_field_count(unnamed: &UnnamedFields) -> usize {
968 unnamed_live_fields(unnamed).count()
969}
970
971fn unnamed_has_effective_payload(unnamed: &UnnamedFields) -> bool {
972 unnamed_live_field_count(unnamed) != 0
973}
974
975fn unnamed_fields_all_skipped(unnamed: &UnnamedFields) -> bool {
976 !unnamed.fields().is_empty() && !unnamed_has_effective_payload(unnamed)
977}
978
979fn rewrite_enum_repr_for_phase(
980 e: &mut Enum,
981 mode: PhaseRewrite,
982 original_types: &Types,
983) -> Result<()> {
984 let repr = enum_repr_from_attrs(e.attributes())?;
985 if matches!(repr, EnumRepr::Untagged) {
986 return Ok(());
987 }
988
989 let container_attrs = SerdeContainerAttrs::from_attributes(e.attributes())?;
990 let variants = std::mem::take(e.variants_mut());
991 let mut transformed = Vec::with_capacity(variants.len());
992 for (variant_name, variant) in variants {
993 if variant.skip() {
994 continue;
995 }
996
997 let variant_attrs = SerdeVariantAttrs::from_attributes(variant.attributes())?;
998 if variant_attrs
999 .as_ref()
1000 .is_some_and(|attrs| variant_is_skipped_for_mode(attrs, mode))
1001 {
1002 continue;
1003 }
1004
1005 if variant_attrs.as_ref().is_some_and(|attrs| attrs.untagged) {
1006 transformed.push((
1007 Cow::Owned(variant_name.into_owned()),
1008 transform_untagged_variant(&variant)?,
1009 ));
1010 continue;
1011 }
1012
1013 let serialized_name =
1014 serialized_variant_name(&variant_name, &variant, &container_attrs, mode)?;
1015 let widen_tag =
1016 mode == PhaseRewrite::Deserialize && variant_attrs.is_some_and(|attrs| attrs.other);
1017 let transformed_variant = match &repr {
1018 EnumRepr::External => transform_external_variant(serialized_name.clone(), &variant)?,
1019 EnumRepr::Internal { tag } => transform_internal_variant(
1020 serialized_name.clone(),
1021 tag.as_ref(),
1022 &variant,
1023 original_types,
1024 widen_tag,
1025 )?,
1026 EnumRepr::Adjacent { tag, content } => {
1027 if tag == content {
1028 return Err(Error::invalid_enum_representation(
1029 "serde adjacent tagging requires distinct `tag` and `content` field names",
1030 ));
1031 }
1032
1033 transform_adjacent_variant(
1034 serialized_name.clone(),
1035 tag.as_ref(),
1036 content.as_ref(),
1037 &variant,
1038 widen_tag,
1039 )?
1040 }
1041 EnumRepr::Untagged => unreachable!(),
1042 };
1043
1044 transformed.push((Cow::Owned(serialized_name), transformed_variant));
1045 }
1046
1047 *e.variants_mut() = transformed;
1048
1049 Ok(())
1050}
1051
1052fn rewrite_identifier_enum_for_phase(
1053 e: &mut Enum,
1054 mode: PhaseRewrite,
1055 original_types: &Types,
1056 generated: &HashMap<TypeIdentity, SplitGeneratedTypes>,
1057 split_types: &HashSet<TypeIdentity>,
1058) -> Result<bool> {
1059 let Some(attrs) = SerdeContainerAttrs::from_attributes(e.attributes())? else {
1060 return Ok(false);
1061 };
1062
1063 if !attrs.variant_identifier && !attrs.field_identifier {
1064 return Ok(false);
1065 }
1066
1067 if mode != PhaseRewrite::Deserialize {
1068 return Ok(false);
1069 }
1070
1071 let container_attrs = SerdeContainerAttrs::from_attributes(e.attributes())?;
1072 let mut variants = Vec::new();
1073 let mut seen = HashSet::new();
1074
1075 for (variant_name, variant) in e.variants().iter() {
1076 let serialized_name = serialized_variant_name(
1077 variant_name,
1078 variant,
1079 &container_attrs,
1080 PhaseRewrite::Deserialize,
1081 )?;
1082
1083 if seen.insert(serialized_name.clone()) {
1084 variants.push((
1085 Cow::Owned(serialized_name.clone()),
1086 identifier_union_variant(string_literal_datatype(serialized_name)),
1087 ));
1088 }
1089
1090 if let Some(variant_attrs) = SerdeVariantAttrs::from_attributes(variant.attributes())? {
1091 for alias in &variant_attrs.aliases {
1092 if seen.insert(alias.clone()) {
1093 variants.push((
1094 Cow::Owned(alias.clone()),
1095 identifier_union_variant(string_literal_datatype(alias.clone())),
1096 ));
1097 }
1098 }
1099 }
1100 }
1101
1102 variants.push((
1103 Cow::Borrowed("__specta_identifier_index"),
1104 identifier_union_variant(DataType::Primitive(specta::datatype::Primitive::u32)),
1105 ));
1106
1107 if attrs.field_identifier
1108 && let Some((_, fallback)) = e.variants().last()
1109 && let Fields::Unnamed(unnamed) = fallback.fields()
1110 && let Some(field) = unnamed.fields().first()
1111 && let Some(ty) = field.ty()
1112 {
1113 let mut fallback_ty = ty.clone();
1114 rewrite_datatype_for_phase(
1115 &mut fallback_ty,
1116 mode,
1117 original_types,
1118 generated,
1119 split_types,
1120 None,
1121 )?;
1122 variants.push((
1123 Cow::Borrowed("__specta_identifier_other"),
1124 identifier_union_variant(fallback_ty),
1125 ));
1126 }
1127
1128 *e.variants_mut() = variants;
1129 Ok(true)
1130}
1131
1132fn container_rename_all_rule(
1133 attrs: &specta::datatype::Attributes,
1134 mode: PhaseRewrite,
1135 context: &str,
1136 container_name: &str,
1137) -> Result<Option<RenameRule>> {
1138 let attrs = SerdeContainerAttrs::from_attributes(attrs)?;
1139
1140 select_phase_rule(
1141 mode,
1142 attrs.as_ref().and_then(|attrs| attrs.rename_all_serialize),
1143 attrs
1144 .as_ref()
1145 .and_then(|attrs| attrs.rename_all_deserialize),
1146 context,
1147 container_name,
1148 )
1149}
1150
1151fn enum_variant_field_rename_rule(
1152 container_attrs: &Option<SerdeContainerAttrs>,
1153 variant: &Variant,
1154 mode: PhaseRewrite,
1155 variant_name: &str,
1156) -> Result<Option<RenameRule>> {
1157 let variant_attrs = SerdeVariantAttrs::from_attributes(variant.attributes())?;
1158
1159 let variant_rule = select_phase_rule(
1160 mode,
1161 variant_attrs
1162 .as_ref()
1163 .and_then(|attrs| attrs.rename_all_serialize),
1164 variant_attrs
1165 .as_ref()
1166 .and_then(|attrs| attrs.rename_all_deserialize),
1167 "enum variant rename_all",
1168 variant_name,
1169 )?;
1170
1171 if variant_rule.is_some() {
1172 return Ok(variant_rule);
1173 }
1174
1175 select_phase_rule(
1176 mode,
1177 container_attrs
1178 .as_ref()
1179 .and_then(|attrs| attrs.rename_all_fields_serialize),
1180 container_attrs
1181 .as_ref()
1182 .and_then(|attrs| attrs.rename_all_fields_deserialize),
1183 "enum rename_all_fields",
1184 variant_name,
1185 )
1186}
1187
1188fn identifier_union_variant(ty: DataType) -> Variant {
1189 let mut variant = Variant::unnamed().build();
1190 if let Fields::Unnamed(fields) = variant.fields_mut() {
1191 fields.fields_mut().push(Field::new(ty));
1192 }
1193 variant
1194}
1195
1196fn transform_untagged_variant(variant: &Variant) -> Result<Variant> {
1197 let payload = variant_payload_field(variant)
1198 .ok_or_else(|| Error::invalid_external_tagged_variant("<untagged variant>"))?;
1199 Ok(clone_variant_with_unnamed_fields(variant, vec![payload]))
1200}
1201
1202fn filter_enum_variants_for_phase(e: &mut Enum, mode: PhaseRewrite) -> Result<()> {
1203 let mut filter_err = None;
1204 e.variants_mut().retain(|(_, variant)| {
1205 if variant.skip() {
1206 return false;
1207 }
1208
1209 match SerdeVariantAttrs::from_attributes(variant.attributes()) {
1210 Ok(Some(attrs)) => !variant_is_skipped_for_mode(&attrs, mode),
1211 Ok(None) => true,
1212 Err(err) => {
1213 filter_err = Some(err);
1214 true
1215 }
1216 }
1217 });
1218
1219 if let Some(err) = filter_err {
1220 return Err(err);
1221 }
1222
1223 Ok(())
1224}
1225
1226fn variant_is_skipped_for_mode(attrs: &SerdeVariantAttrs, mode: PhaseRewrite) -> bool {
1227 match mode {
1228 PhaseRewrite::Serialize => attrs.skip_serializing,
1229 PhaseRewrite::Deserialize => attrs.skip_deserializing,
1230 PhaseRewrite::Unified => attrs.skip_serializing || attrs.skip_deserializing,
1231 }
1232}
1233
1234fn enum_repr_from_attrs(attrs: &specta::datatype::Attributes) -> Result<EnumRepr> {
1235 let Some(container_attrs) = SerdeContainerAttrs::from_attributes(attrs)? else {
1236 return Ok(EnumRepr::External);
1237 };
1238
1239 if container_attrs.untagged {
1240 return Ok(EnumRepr::Untagged);
1241 }
1242
1243 Ok(
1244 match (
1245 container_attrs.tag.as_deref(),
1246 container_attrs.content.as_deref(),
1247 ) {
1248 (Some(tag), Some(content)) => EnumRepr::Adjacent {
1249 tag: Cow::Owned(tag.to_string()),
1250 content: Cow::Owned(content.to_string()),
1251 },
1252 (Some(tag), None) => EnumRepr::Internal {
1253 tag: Cow::Owned(tag.to_string()),
1254 },
1255 (None, Some(_)) => {
1256 return Err(Error::invalid_enum_representation(
1257 "`content` is set without `tag`",
1258 ));
1259 }
1260 (None, None) => EnumRepr::External,
1261 },
1262 )
1263}
1264
1265fn serialized_variant_name(
1266 variant_name: &str,
1267 variant: &Variant,
1268 container_attrs: &Option<SerdeContainerAttrs>,
1269 mode: PhaseRewrite,
1270) -> Result<String> {
1271 let variant_attrs = SerdeVariantAttrs::from_attributes(variant.attributes())?;
1272
1273 if let Some(rename) = select_phase_string(
1274 mode,
1275 variant_attrs
1276 .as_ref()
1277 .and_then(|attrs| attrs.rename_serialize.as_deref()),
1278 variant_attrs
1279 .as_ref()
1280 .and_then(|attrs| attrs.rename_deserialize.as_deref()),
1281 "enum variant rename",
1282 variant_name,
1283 )? {
1284 return Ok(rename.to_string());
1285 }
1286
1287 Ok(select_phase_rule(
1288 mode,
1289 container_attrs
1290 .as_ref()
1291 .and_then(|attrs| attrs.rename_all_serialize),
1292 container_attrs
1293 .as_ref()
1294 .and_then(|attrs| attrs.rename_all_deserialize),
1295 "enum rename_all",
1296 variant_name,
1297 )?
1298 .map_or_else(
1299 || variant_name.to_string(),
1300 |rule| rule.apply_to_variant(variant_name),
1301 ))
1302}
1303
1304fn select_phase_string<'a>(
1305 mode: PhaseRewrite,
1306 serialize: Option<&'a str>,
1307 deserialize: Option<&'a str>,
1308 context: &str,
1309 name: &str,
1310) -> Result<Option<&'a str>> {
1311 Ok(match mode {
1312 PhaseRewrite::Serialize => serialize,
1313 PhaseRewrite::Deserialize => deserialize,
1314 PhaseRewrite::Unified => match (serialize, deserialize) {
1315 (Some(serialize), Some(deserialize)) if serialize != deserialize => {
1316 return Err(Error::incompatible_rename(
1317 context.to_string(),
1318 name,
1319 Some(serialize.to_string()),
1320 Some(deserialize.to_string()),
1321 ));
1322 }
1323 (serialize, deserialize) => serialize.or(deserialize),
1324 },
1325 })
1326}
1327
1328fn select_phase_rule(
1329 mode: PhaseRewrite,
1330 serialize: Option<RenameRule>,
1331 deserialize: Option<RenameRule>,
1332 context: &str,
1333 name: &str,
1334) -> Result<Option<RenameRule>> {
1335 Ok(match mode {
1336 PhaseRewrite::Serialize => serialize,
1337 PhaseRewrite::Deserialize => deserialize,
1338 PhaseRewrite::Unified => match (serialize, deserialize) {
1339 (Some(serialize), Some(deserialize)) if serialize != deserialize => {
1340 return Err(Error::incompatible_rename(
1341 context.to_string(),
1342 name,
1343 Some(format!("{serialize:?}")),
1344 Some(format!("{deserialize:?}")),
1345 ));
1346 }
1347 (serialize, deserialize) => serialize.or(deserialize),
1348 },
1349 })
1350}
1351
1352fn resolve_phased_type(ty: &DataType, mode: PhaseRewrite, path: &str) -> Result<Option<DataType>> {
1353 let DataType::Reference(Reference::Opaque(reference)) = ty else {
1354 return Ok(None);
1355 };
1356 let Some(phased) = reference.downcast_ref::<PhasedTy>() else {
1357 return Ok(None);
1358 };
1359
1360 Ok(match mode {
1361 PhaseRewrite::Unified => {
1363 return Err(Error::invalid_phased_type_usage(
1364 path,
1365 "`specta_serde::Phased<Serialize, Deserialize>` requires `apply_phases`",
1366 ));
1367 }
1368 PhaseRewrite::Serialize => Some(phased.serialize.clone()),
1369 PhaseRewrite::Deserialize => Some(phased.deserialize.clone()),
1370 })
1371}
1372
1373fn conversion_datatype_for_mode(ty: &DataType, mode: PhaseRewrite) -> Result<Option<DataType>> {
1374 let attrs = match ty {
1375 DataType::Struct(s) => s.attributes(),
1376 DataType::Enum(e) => e.attributes(),
1377 _ => return Ok(None),
1378 };
1379
1380 select_conversion_target(attrs, mode)
1381}
1382
1383fn select_conversion_target(
1384 attrs: &specta::datatype::Attributes,
1385 mode: PhaseRewrite,
1386) -> Result<Option<DataType>> {
1387 let parsed = SerdeContainerAttrs::from_attributes(attrs)?;
1388 let resolved = parsed.as_ref();
1389
1390 let serialize_target = resolved.and_then(|v| v.resolved_into.as_ref());
1391 let deserialize_target =
1392 resolved.and_then(|v| v.resolved_from.as_ref().or(v.resolved_try_from.as_ref()));
1393
1394 match mode {
1395 PhaseRewrite::Serialize => Ok(serialize_target.cloned()),
1396 PhaseRewrite::Deserialize => Ok(deserialize_target.cloned()),
1397 PhaseRewrite::Unified => match (serialize_target, deserialize_target) {
1398 (None, None) => Ok(None),
1399 (Some(serialize), Some(deserialize)) if serialize == deserialize => {
1400 Ok(Some(serialize.clone()))
1401 }
1402 _ => Err(Error::incompatible_conversion(
1403 "container conversion",
1404 conversion_name(attrs)?,
1405 serialize_conversion_name(parsed.as_ref()),
1406 deserialize_conversion_name(parsed.as_ref()),
1407 )),
1408 },
1409 }
1410}
1411
1412fn conversion_name(attrs: &specta::datatype::Attributes) -> Result<String> {
1413 Ok(SerdeContainerAttrs::from_attributes(attrs)?
1414 .and_then(|attrs| {
1415 attrs
1416 .into
1417 .as_ref()
1418 .map(|v| format!("into({})", v.type_src))
1419 .or_else(|| attrs.from.as_ref().map(|v| format!("from({})", v.type_src)))
1420 .or_else(|| {
1421 attrs
1422 .try_from
1423 .as_ref()
1424 .map(|v| format!("try_from({})", v.type_src))
1425 })
1426 })
1427 .unwrap_or_else(|| "<container>".to_string()))
1428}
1429
1430fn serialize_conversion_name(attrs: Option<&SerdeContainerAttrs>) -> Option<String> {
1431 attrs.and_then(|attrs| attrs.into.as_ref().map(|v| v.type_src.clone()))
1432}
1433
1434fn deserialize_conversion_name(attrs: Option<&SerdeContainerAttrs>) -> Option<String> {
1435 attrs.and_then(|attrs| {
1436 attrs.from.as_ref().map(|v| v.type_src.clone()).or_else(|| {
1437 attrs
1438 .try_from
1439 .as_ref()
1440 .map(|v| format!("try_from({})", v.type_src))
1441 })
1442 })
1443}
1444
1445fn transform_external_variant(serialized_name: String, variant: &Variant) -> Result<Variant> {
1446 let skipped_only_unnamed = match variant.fields() {
1447 Fields::Unnamed(unnamed) => unnamed_fields_all_skipped(unnamed),
1448 Fields::Unit | Fields::Named(_) => false,
1449 };
1450
1451 Ok(match variant.fields() {
1452 Fields::Unit => clone_variant_with_unnamed_fields(
1453 variant,
1454 vec![Field::new(string_literal_datatype(serialized_name))],
1455 ),
1456 _ if skipped_only_unnamed => clone_variant_with_unnamed_fields(
1457 variant,
1458 vec![Field::new(string_literal_datatype(serialized_name))],
1459 ),
1460 _ => {
1461 let payload = variant_payload_field(variant)
1462 .ok_or_else(|| Error::invalid_external_tagged_variant(serialized_name.clone()))?;
1463
1464 clone_variant_with_named_fields(variant, vec![(Cow::Owned(serialized_name), payload)])
1465 }
1466 })
1467}
1468
1469fn transform_adjacent_variant(
1470 serialized_name: String,
1471 tag: &str,
1472 content: &str,
1473 variant: &Variant,
1474 widen_tag: bool,
1475) -> Result<Variant> {
1476 let mut fields = vec![(
1477 Cow::Owned(tag.to_string()),
1478 Field::new(if widen_tag {
1479 DataType::Primitive(Primitive::str)
1480 } else {
1481 string_literal_datatype(serialized_name.clone())
1482 }),
1483 )];
1484
1485 if variant_has_effective_payload(variant) {
1486 let payload = variant_payload_field(variant)
1487 .ok_or_else(|| Error::invalid_adjacent_tagged_variant(serialized_name.clone()))?;
1488 fields.push((Cow::Owned(content.to_string()), payload));
1489 }
1490
1491 Ok(clone_variant_with_named_fields(variant, fields))
1492}
1493
1494fn transform_internal_variant(
1495 serialized_name: String,
1496 tag: &str,
1497 variant: &Variant,
1498 original_types: &Types,
1499 widen_tag: bool,
1500) -> Result<Variant> {
1501 let mut fields = vec![(
1502 Cow::Owned(tag.to_string()),
1503 Field::new(if widen_tag {
1504 DataType::Primitive(Primitive::str)
1505 } else {
1506 string_literal_datatype(serialized_name.clone())
1507 }),
1508 )];
1509
1510 match variant.fields() {
1511 Fields::Unit => {}
1512 Fields::Named(named) => {
1513 fields.extend(named.fields().iter().cloned());
1514 }
1515 Fields::Unnamed(unnamed) => {
1516 let live_field_count = unnamed_live_field_count(unnamed);
1517
1518 if live_field_count == 0 {
1519 return Ok(clone_variant_with_named_fields(variant, fields));
1520 }
1521
1522 let non_skipped = unnamed_live_fields(unnamed).collect::<Vec<_>>();
1523
1524 if live_field_count != 1 {
1525 return Err(Error::invalid_internally_tagged_variant(
1526 serialized_name,
1527 "tuple variant must have exactly one non-skipped field",
1528 ));
1529 }
1530
1531 let payload_field = non_skipped
1532 .into_iter()
1533 .next()
1534 .expect("checked above")
1535 .clone();
1536 let payload_ty = payload_field.ty().cloned().expect("checked above");
1537 let Some(payload_is_effectively_empty) = internal_tag_payload_compatibility(
1538 &payload_ty,
1539 original_types,
1540 &mut HashSet::new(),
1541 )?
1542 else {
1543 return Err(Error::invalid_internally_tagged_variant(
1544 serialized_name,
1545 "payload cannot be merged with a tag",
1546 ));
1547 };
1548
1549 if !payload_is_effectively_empty {
1550 let mut flattened = payload_field;
1551 flattened.set_flatten(true);
1552 fields.push((Cow::Borrowed("__specta_internal_payload"), flattened));
1553 }
1554 }
1555 }
1556
1557 Ok(clone_variant_with_named_fields(variant, fields))
1558}
1559
1560fn string_literal_datatype(value: String) -> DataType {
1561 let mut value_enum = Enum::new();
1562 value_enum
1563 .variants_mut()
1564 .push((Cow::Owned(value), Variant::unit()));
1565 DataType::Enum(value_enum)
1566}
1567
1568fn variant_has_effective_payload(variant: &Variant) -> bool {
1569 match variant.fields() {
1570 Fields::Unit => false,
1571 Fields::Named(named) => !named.fields().is_empty(),
1572 Fields::Unnamed(unnamed) => unnamed_has_effective_payload(unnamed),
1573 }
1574}
1575
1576fn variant_payload_field(variant: &Variant) -> Option<Field> {
1577 match variant.fields() {
1578 Fields::Unit => Some(Field::new(DataType::Tuple(Tuple::new(vec![])))),
1579 Fields::Named(named) => {
1580 let mut out = Struct::named();
1581 for (name, field) in named.fields().iter().cloned() {
1582 out.field_mut(name, field);
1583 }
1584 Some(Field::new(out.build()))
1585 }
1586 Fields::Unnamed(unnamed) => {
1587 let original_unnamed_len = unnamed.fields().len();
1588
1589 let non_skipped = unnamed_live_fields(unnamed).collect::<Vec<_>>();
1590
1591 match non_skipped.as_slice() {
1592 [] => Some(Field::new(DataType::Tuple(Tuple::new(vec![])))),
1593 [single] if original_unnamed_len == 1 => Some((*single).clone()),
1594 _ => Some(Field::new(DataType::Tuple(Tuple::new(
1595 non_skipped
1596 .iter()
1597 .filter_map(|field| field.ty().cloned())
1598 .collect(),
1599 )))),
1600 }
1601 }
1602 }
1603}
1604
1605fn clone_variant_with_named_fields(
1606 original: &Variant,
1607 fields: Vec<(Cow<'static, str>, Field)>,
1608) -> Variant {
1609 let mut builder = Variant::named();
1610 for (name, field) in fields {
1611 builder = builder.field(name, field);
1612 }
1613
1614 let mut transformed = builder.build();
1615 transformed.set_skip(original.skip());
1616 transformed.set_docs(original.docs().clone());
1617 transformed.set_deprecated(original.deprecated().cloned());
1618 transformed.set_type_overridden(original.type_overridden());
1619 *transformed.attributes_mut() = original.attributes().clone();
1620 transformed
1621}
1622
1623fn clone_variant_with_unnamed_fields(original: &Variant, fields: Vec<Field>) -> Variant {
1624 let mut builder = Variant::unnamed();
1625 for field in fields {
1626 builder = builder.field(field);
1627 }
1628
1629 let mut transformed = builder.build();
1630 transformed.set_skip(original.skip());
1631 transformed.set_docs(original.docs().clone());
1632 transformed.set_deprecated(original.deprecated().cloned());
1633 transformed.set_type_overridden(original.type_overridden());
1634 *transformed.attributes_mut() = original.attributes().clone();
1635 transformed
1636}
1637
1638fn internal_tag_payload_compatibility(
1639 ty: &DataType,
1640 original_types: &Types,
1641 seen: &mut HashSet<TypeIdentity>,
1642) -> Result<Option<bool>> {
1643 match ty {
1644 DataType::Map(_) => Ok(Some(false)),
1645 DataType::Struct(strct) => {
1646 if SerdeContainerAttrs::from_attributes(strct.attributes())?
1647 .is_some_and(|attrs| attrs.transparent)
1648 {
1649 let payload_fields = match strct.fields() {
1650 Fields::Unit => return Ok(Some(true)),
1651 Fields::Unnamed(unnamed) => unnamed
1652 .fields()
1653 .iter()
1654 .filter_map(Field::ty)
1655 .collect::<Vec<_>>(),
1656 Fields::Named(named) => named
1657 .fields()
1658 .iter()
1659 .filter_map(|(_, field)| field.ty())
1660 .collect::<Vec<_>>(),
1661 };
1662
1663 let [inner_ty] = payload_fields.as_slice() else {
1664 if payload_fields.is_empty() {
1665 return Ok(Some(true));
1666 }
1667
1668 return Ok(None);
1669 };
1670
1671 return internal_tag_payload_compatibility(inner_ty, original_types, seen);
1672 }
1673
1674 Ok(match strct.fields() {
1675 Fields::Named(named) => {
1676 Some(named.fields().iter().all(|(_, field)| field.ty().is_none()))
1677 }
1678 Fields::Unit | Fields::Unnamed(_) => None,
1679 })
1680 }
1681 DataType::Tuple(tuple) => Ok(tuple.elements().is_empty().then_some(true)),
1682 DataType::Reference(Reference::Named(reference)) => {
1683 let Some(referenced) = reference.get(original_types) else {
1684 return Ok(None);
1685 };
1686
1687 let key = TypeIdentity::from_ndt(referenced);
1688 if !seen.insert(key.clone()) {
1689 return Ok(Some(false));
1690 }
1691
1692 let compatible =
1693 internal_tag_payload_compatibility(referenced.ty(), original_types, seen);
1694 seen.remove(&key);
1695 compatible
1696 }
1697 DataType::Enum(enm) => match enum_repr_from_attrs(enm.attributes()) {
1698 Ok(EnumRepr::Untagged) => {
1699 let mut is_effectively_empty = true;
1700 for (_, variant) in enm.variants() {
1701 let Some(variant_empty) =
1702 internal_tag_variant_payload_compatibility(variant, original_types, seen)?
1703 else {
1704 return Ok(None);
1705 };
1706
1707 is_effectively_empty &= variant_empty;
1708 }
1709
1710 Ok(Some(is_effectively_empty))
1711 }
1712 Ok(EnumRepr::External | EnumRepr::Internal { .. } | EnumRepr::Adjacent { .. }) => {
1713 Ok(Some(false))
1714 }
1715 Err(_) => Ok(None),
1716 },
1717 DataType::Primitive(_)
1718 | DataType::List(_)
1719 | DataType::Nullable(_)
1720 | DataType::Reference(Reference::Generic(_))
1721 | DataType::Reference(Reference::Opaque(_)) => Ok(None),
1722 }
1723}
1724
1725fn internal_tag_variant_payload_compatibility(
1726 variant: &Variant,
1727 original_types: &Types,
1728 seen: &mut HashSet<TypeIdentity>,
1729) -> Result<Option<bool>> {
1730 match variant.fields() {
1731 Fields::Unit => Ok(Some(true)),
1732 Fields::Named(named) => Ok(Some(
1733 named.fields().iter().all(|(_, field)| field.ty().is_none()),
1734 )),
1735 Fields::Unnamed(unnamed) => {
1736 if unnamed.fields().len() != 1 {
1737 return Ok(None);
1738 }
1739
1740 unnamed
1741 .fields()
1742 .iter()
1743 .find_map(|field| field.ty())
1744 .map_or(Ok(None), |ty| {
1745 internal_tag_payload_compatibility(ty, original_types, seen)
1746 })
1747 }
1748 }
1749}
1750
1751fn has_local_phase_difference(dt: &DataType) -> Result<bool> {
1752 match dt {
1753 DataType::Struct(s) => Ok(container_has_local_difference(s.attributes())?
1754 || fields_have_local_difference(s.fields())?),
1755 DataType::Enum(e) => Ok(container_has_local_difference(e.attributes())?
1756 || e.variants()
1757 .iter()
1758 .try_fold(false, |has_difference, (_, variant)| {
1759 if has_difference {
1760 return Ok(true);
1761 }
1762
1763 Ok(variant_has_local_difference(variant)?
1764 || fields_have_local_difference(variant.fields())?)
1765 })?),
1766 DataType::Tuple(tuple) => tuple
1767 .elements()
1768 .iter()
1769 .try_fold(false, |has_difference, ty| {
1770 if has_difference {
1771 return Ok(true);
1772 }
1773
1774 has_local_phase_difference(ty)
1775 }),
1776 DataType::List(list) => has_local_phase_difference(list.ty()),
1777 DataType::Map(map) => Ok(has_local_phase_difference(map.key_ty())?
1778 || has_local_phase_difference(map.value_ty())?),
1779 DataType::Nullable(inner) => has_local_phase_difference(inner),
1780 DataType::Reference(Reference::Opaque(reference)) => {
1781 Ok(reference.downcast_ref::<PhasedTy>().is_some())
1782 }
1783 DataType::Primitive(_)
1784 | DataType::Reference(Reference::Named(_))
1785 | DataType::Reference(Reference::Generic(_)) => Ok(false),
1786 }
1787}
1788
1789fn container_has_local_difference(attrs: &specta::datatype::Attributes) -> Result<bool> {
1790 let Some(conversions) = SerdeContainerAttrs::from_attributes(attrs)? else {
1791 return Ok(false);
1792 };
1793
1794 Ok(conversions.resolved_into.as_ref()
1795 != conversions
1796 .resolved_from
1797 .as_ref()
1798 .or(conversions.resolved_try_from.as_ref())
1799 || conversions.rename_serialize != conversions.rename_deserialize
1800 || conversions.rename_all_serialize != conversions.rename_all_deserialize
1801 || conversions.rename_all_fields_serialize != conversions.rename_all_fields_deserialize
1802 || conversions.variant_identifier
1803 || conversions.field_identifier)
1804}
1805
1806fn fields_have_local_difference(fields: &Fields) -> Result<bool> {
1807 match fields {
1808 Fields::Unit => Ok(false),
1809 Fields::Unnamed(unnamed) => {
1810 unnamed
1811 .fields()
1812 .iter()
1813 .try_fold(false, |has_difference, field| {
1814 if has_difference {
1815 return Ok(true);
1816 }
1817
1818 field.ty().map_or(Ok(false), has_local_phase_difference)
1819 })
1820 }
1821 Fields::Named(named) => {
1822 named
1823 .fields()
1824 .iter()
1825 .try_fold(false, |has_difference, (_, field)| {
1826 if has_difference {
1827 return Ok(true);
1828 }
1829
1830 Ok(field_has_local_difference(field)?
1831 || field.ty().map_or(Ok(false), has_local_phase_difference)?)
1832 })
1833 }
1834 }
1835}
1836
1837fn field_has_local_difference(field: &Field) -> Result<bool> {
1838 Ok(SerdeFieldAttrs::from_attributes(field.attributes())?
1839 .map(|attrs| {
1840 attrs.rename_serialize.as_deref() != attrs.rename_deserialize.as_deref()
1841 || attrs.skip_serializing != attrs.skip_deserializing
1842 || attrs.skip_serializing_if.is_some()
1843 || attrs.has_serialize_with
1844 || attrs.has_deserialize_with
1845 || attrs.has_with
1846 })
1847 .unwrap_or_default())
1848}
1849
1850fn variant_has_local_difference(variant: &Variant) -> Result<bool> {
1851 Ok(SerdeVariantAttrs::from_attributes(variant.attributes())?
1852 .map(|attrs| {
1853 attrs.rename_serialize.as_deref() != attrs.rename_deserialize.as_deref()
1854 || attrs.rename_all_serialize != attrs.rename_all_deserialize
1855 || attrs.skip_serializing != attrs.skip_deserializing
1856 || attrs.has_serialize_with
1857 || attrs.has_deserialize_with
1858 || attrs.has_with
1859 || attrs.other
1860 })
1861 .unwrap_or_default())
1862}
1863
1864fn collect_dependencies(
1865 dt: &DataType,
1866 types: &Types,
1867 deps: &mut HashSet<TypeIdentity>,
1868) -> Result<()> {
1869 match dt {
1870 DataType::Struct(s) => {
1871 collect_conversion_dependencies(s.attributes(), types, deps)?;
1872 collect_fields_dependencies(s.fields(), types, deps)?;
1873 }
1874 DataType::Enum(e) => {
1875 collect_conversion_dependencies(e.attributes(), types, deps)?;
1876 for (_, variant) in e.variants() {
1877 collect_fields_dependencies(variant.fields(), types, deps)?;
1878 }
1879 }
1880 DataType::Tuple(tuple) => {
1881 for ty in tuple.elements() {
1882 collect_dependencies(ty, types, deps)?;
1883 }
1884 }
1885 DataType::List(list) => collect_dependencies(list.ty(), types, deps)?,
1886 DataType::Map(map) => {
1887 collect_dependencies(map.key_ty(), types, deps)?;
1888 collect_dependencies(map.value_ty(), types, deps)?;
1889 }
1890 DataType::Nullable(inner) => collect_dependencies(inner, types, deps)?,
1891 DataType::Reference(Reference::Named(reference)) => {
1892 if let Some(referenced) = reference.get(types) {
1893 deps.insert(TypeIdentity::from_ndt(referenced));
1894 }
1895
1896 for (_, generic) in reference.generics() {
1897 collect_dependencies(generic, types, deps)?;
1898 }
1899 }
1900 DataType::Reference(Reference::Opaque(_)) => {
1901 if let DataType::Reference(Reference::Opaque(reference)) = dt
1902 && let Some(phased) = reference.downcast_ref::<PhasedTy>()
1903 {
1904 collect_dependencies(&phased.serialize, types, deps)?;
1905 collect_dependencies(&phased.deserialize, types, deps)?;
1906 }
1907 }
1908 DataType::Primitive(_) | DataType::Reference(Reference::Generic(_)) => {}
1909 }
1910
1911 Ok(())
1912}
1913
1914fn collect_conversion_dependencies(
1915 attrs: &specta::datatype::Attributes,
1916 types: &Types,
1917 deps: &mut HashSet<TypeIdentity>,
1918) -> Result<()> {
1919 let Some(conversions) = SerdeContainerAttrs::from_attributes(attrs)? else {
1920 return Ok(());
1921 };
1922
1923 for conversion in [
1924 conversions.resolved_into.as_ref(),
1925 conversions.resolved_from.as_ref(),
1926 conversions.resolved_try_from.as_ref(),
1927 ]
1928 .into_iter()
1929 .flatten()
1930 {
1931 collect_dependencies(conversion, types, deps)?;
1932 }
1933
1934 Ok(())
1935}
1936
1937fn collect_fields_dependencies(
1938 fields: &Fields,
1939 types: &Types,
1940 deps: &mut HashSet<TypeIdentity>,
1941) -> Result<()> {
1942 match fields {
1943 Fields::Unit => {}
1944 Fields::Unnamed(unnamed) => {
1945 for field in unnamed.fields() {
1946 if let Some(ty) = field.ty() {
1947 collect_dependencies(ty, types, deps)?;
1948 }
1949 }
1950 }
1951 Fields::Named(named) => {
1952 for (_, field) in named.fields() {
1953 if let Some(ty) = field.ty() {
1954 collect_dependencies(ty, types, deps)?;
1955 }
1956 }
1957 }
1958 }
1959
1960 Ok(())
1961}
1962
1963fn build_from_original(
1964 original: &NamedDataType,
1965 name: impl Into<Cow<'static, str>>,
1966 generics: Vec<(specta::datatype::GenericReference, Cow<'static, str>)>,
1967 ty: DataType,
1968 types: &Types,
1969) -> NamedDataType {
1970 let mut ndt = if original.requires_reference(types) {
1971 NamedDataType::new(name, generics, ty)
1972 } else {
1973 NamedDataType::new_inline(name, generics, ty)
1974 };
1975
1976 ndt.set_docs(original.docs().clone());
1977 ndt.set_location(original.location());
1978 ndt.set_module_path(original.module_path().clone());
1979 ndt.set_deprecated(original.deprecated().cloned());
1980
1981 ndt
1982}
1983
1984fn apply_field_attrs(field: &mut Field, mode: PhaseRewrite, container_default: bool) -> Result<()> {
1985 let mut flatten = field.flatten();
1986 let mut optional = field.optional();
1987 if let Some(attrs) = SerdeFieldAttrs::from_attributes(field.attributes())? {
1988 flatten = attrs.flatten;
1989 if field_is_optional_for_mode(Some(&attrs), container_default, mode) {
1990 optional = true;
1991 }
1992 } else if field_is_optional_for_mode(None, container_default, mode) {
1993 optional = true;
1994 }
1995 field.set_flatten(flatten);
1996 field.set_optional(optional);
1997
1998 Ok(())
1999}
2000
2001fn field_is_optional_for_mode(
2002 attrs: Option<&SerdeFieldAttrs>,
2003 container_default: bool,
2004 mode: PhaseRewrite,
2005) -> bool {
2006 match mode {
2007 PhaseRewrite::Serialize => false,
2008 PhaseRewrite::Deserialize | PhaseRewrite::Unified => {
2009 container_default
2010 || attrs.is_some_and(|attrs| attrs.default || attrs.skip_deserializing)
2011 }
2012 }
2013}
2014
2015#[cfg(test)]
2016mod tests {
2017 use serde::{Deserialize, Serialize};
2018 use specta::{ResolvedTypes, Type, datatype::DataType};
2019
2020 use super::{Phase, Phased, apply_phases, select_phase_datatype};
2021
2022 #[derive(Type, Serialize, Deserialize)]
2023 #[serde(untagged)]
2024 enum OneOrManyString {
2025 One(String),
2026 Many(Vec<String>),
2027 }
2028
2029 #[derive(Type, Serialize, Deserialize)]
2030 struct Filters {
2031 #[specta(type = Phased<Vec<String>, OneOrManyString>)]
2032 tags: Vec<String>,
2033 }
2034
2035 #[derive(Type, Serialize, Deserialize)]
2036 struct FilterList {
2037 items: Vec<Filters>,
2038 }
2039
2040 #[derive(Type, Serialize, Deserialize)]
2041 struct Plain {
2042 name: String,
2043 }
2044
2045 #[test]
2046 fn selects_split_named_reference_for_each_phase() {
2047 let mut types = specta::Types::default();
2048 let dt = Filters::definition(&mut types);
2049 let resolved = apply_phases(types).expect("apply_phases should succeed");
2050
2051 let serialize = select_phase_datatype(&dt, &resolved, Phase::Serialize);
2052 let deserialize = select_phase_datatype(&dt, &resolved, Phase::Deserialize);
2053
2054 assert_named_reference(&serialize, &resolved, "Filters_Serialize");
2055 assert_named_reference(&deserialize, &resolved, "Filters_Deserialize");
2056 }
2057
2058 #[test]
2059 fn rewrites_nested_generics_for_each_phase() {
2060 let mut types = specta::Types::default();
2061 let dt = FilterList::definition(&mut types);
2062 let resolved = apply_phases(types).expect("apply_phases should succeed");
2063
2064 let serialize = select_phase_datatype(&dt, &resolved, Phase::Serialize);
2065 let deserialize = select_phase_datatype(&dt, &resolved, Phase::Deserialize);
2066
2067 assert_named_reference(&serialize, &resolved, "FilterList_Serialize");
2068 assert_named_reference(&deserialize, &resolved, "FilterList_Deserialize");
2069
2070 let serialize_inner = named_field_type(&serialize, &resolved, "items");
2071 let deserialize_inner = named_field_type(&deserialize, &resolved, "items");
2072
2073 assert_named_reference(
2074 first_generic_type(serialize_inner),
2075 &resolved,
2076 "Filters_Serialize",
2077 );
2078 assert_named_reference(
2079 first_generic_type(deserialize_inner),
2080 &resolved,
2081 "Filters_Deserialize",
2082 );
2083 }
2084
2085 #[test]
2086 fn preserves_unsplit_types() {
2087 let mut types = specta::Types::default();
2088 let dt = Plain::definition(&mut types);
2089 let resolved = apply_phases(types).expect("apply_phases should succeed");
2090
2091 let serialize = select_phase_datatype(&dt, &resolved, Phase::Serialize);
2092 let deserialize = select_phase_datatype(&dt, &resolved, Phase::Deserialize);
2093
2094 assert_named_reference(&serialize, &resolved, "Plain");
2095 assert_named_reference(&deserialize, &resolved, "Plain");
2096 }
2097
2098 #[test]
2099 fn resolves_explicit_phased_datatypes_without_named_types() {
2100 let mut types = specta::Types::default();
2101 let dt = <Phased<String, Vec<String>>>::definition(&mut types);
2102 let resolved = apply_phases(types).expect("apply_phases should succeed");
2103
2104 let serialize = select_phase_datatype(&dt, &resolved, Phase::Serialize);
2105 let deserialize = select_phase_datatype(&dt, &resolved, Phase::Deserialize);
2106
2107 assert_named_reference(&serialize, &resolved, "String");
2108 assert_named_reference(first_generic_type(&deserialize), &resolved, "String");
2109 }
2110
2111 fn assert_named_reference(dt: &DataType, types: &ResolvedTypes, expected_name: &str) {
2112 let DataType::Reference(specta::datatype::Reference::Named(reference)) = dt else {
2113 panic!("expected named reference");
2114 };
2115
2116 let actual = reference
2117 .get(types.as_types())
2118 .expect("reference should resolve")
2119 .name();
2120
2121 assert_eq!(actual, expected_name);
2122 }
2123
2124 fn named_field_type<'a>(
2125 dt: &'a DataType,
2126 types: &'a ResolvedTypes,
2127 field_name: &str,
2128 ) -> &'a DataType {
2129 let DataType::Reference(specta::datatype::Reference::Named(reference)) = dt else {
2130 panic!("expected named reference");
2131 };
2132
2133 let named = reference
2134 .get(types.as_types())
2135 .expect("reference should resolve");
2136 let DataType::Struct(strct) = named.ty() else {
2137 panic!("expected struct type");
2138 };
2139 let specta::datatype::Fields::Named(fields) = strct.fields() else {
2140 panic!("expected named fields");
2141 };
2142
2143 fields
2144 .fields()
2145 .iter()
2146 .find_map(|(name, field)| (name == field_name).then(|| field.ty()).flatten())
2147 .expect("field should exist")
2148 }
2149
2150 fn first_generic_type(dt: &DataType) -> &DataType {
2151 let DataType::Reference(specta::datatype::Reference::Named(reference)) = dt else {
2152 panic!("expected named reference with generics");
2153 };
2154
2155 reference
2156 .generics()
2157 .first()
2158 .map(|(_, dt)| dt)
2159 .expect("expected first generic type")
2160 }
2161}