Skip to main content

specta_util/
remapper.rs

1use specta::{
2    Types,
3    datatype::{DataType, Fields, NamedReferenceType, Reference},
4};
5
6/// Recursively replaces [`DataType`]s within a [`DataType`] structure from a set of remap rules.
7///
8/// `Remapper` is useful when a type should be represented differently for export
9/// without changing the original Rust type or derive output. It performs [`DataType`]
10/// equality checks while walking the [`DataType`] structure applying the transformations.
11///
12/// Rules are applied in the order they are registered. For each visited
13/// [`DataType`], every matching rule is applied, with each rule seeing the
14/// result of the previous matching rule.
15///
16/// <div class="warning">
17///
18/// **WARNING:** This is an advanced API!
19///
20/// It needs to be used carefully as it can easily break the safety of the end to end type contract.
21///
22/// You must ensure you have Serde applying the same transformations to the runtime data for it to be sound.
23///
24/// </div>
25///
26/// # Examples
27///
28/// Remap `u32` to `str` and `i32` to `bool`:
29///
30/// ```rust
31/// use specta::{Types, datatype::{DataType, Field, List, NamedDataType, Primitive, Struct}};
32/// use specta_util::Remapper;
33///
34/// let remapper = Remapper::new()
35///     .rule(Primitive::u32.into(), Primitive::str.into())
36///     .rule(Primitive::i32.into(), Primitive::bool.into());
37///
38/// // For a single `DataType`
39/// assert_eq!(
40///     remapper.remap_dt(DataType::List(List::new(Primitive::u32.into()))),
41///     DataType::List(List::new(Primitive::str.into()))
42/// );
43///
44/// // For a whole collection of types
45/// # #[allow(unused)]
46/// let types: Types = remapper.remap_types(Types::default());
47/// ```
48#[derive(Debug, Clone, Default)]
49pub struct Remapper {
50    rules: Vec<(DataType, DataType)>,
51}
52
53impl Remapper {
54    /// Creates a remapper with no rules.
55    pub fn new() -> Self {
56        Self::default()
57    }
58
59    /// Registers a rule that replaces exact matches of `from` with `to`.
60    ///
61    /// Rules are checked in the order they are registered.
62    pub fn rule(mut self, from: DataType, to: DataType) -> Self {
63        self.rules.push((from, to));
64        self
65    }
66
67    /// Applies the remap operation to a datatype, returning the remapped datatype.
68    pub fn remap_dt(&self, mut dt: DataType) -> DataType {
69        self.remap_internal(&mut dt);
70        dt
71    }
72
73    /// Applies the remap operation to every datatype in a [`Types`] collection, returning the remapped collection.
74    pub fn remap_types(&self, types: Types) -> Types {
75        types.map(|mut ndt| {
76            ndt.generics.to_mut().iter_mut().for_each(|generic| {
77                if let Some(dt) = &mut generic.default {
78                    self.remap_internal(dt);
79                }
80            });
81            if let Some(dt) = &mut ndt.ty {
82                self.remap_internal(dt);
83            }
84            ndt
85        })
86    }
87
88    fn remap_internal(&self, dt: &mut DataType) {
89        self.remap_rules(dt);
90
91        match dt {
92            DataType::Primitive(_) | DataType::Generic(_) => {}
93            DataType::List(list) => self.remap_internal(&mut list.ty),
94            DataType::Map(map) => {
95                self.remap_internal(map.key_ty_mut());
96                self.remap_internal(map.value_ty_mut());
97            }
98            DataType::Struct(s) => self.remap_fields(&mut s.fields),
99            DataType::Enum(e) => {
100                for (_, variant) in &mut e.variants {
101                    self.remap_fields(&mut variant.fields);
102                }
103            }
104            DataType::Tuple(tuple) => {
105                for dt in &mut tuple.elements {
106                    self.remap_internal(dt);
107                }
108            }
109            DataType::Nullable(dt) => self.remap_internal(dt),
110            DataType::Intersection(dts) => {
111                for dt in dts {
112                    self.remap_internal(dt);
113                }
114            }
115            DataType::Reference(r) => self.remap_reference(r),
116        }
117    }
118
119    fn remap_rules(&self, dt: &mut DataType) {
120        for (from, to) in &self.rules {
121            if *dt == *from {
122                *dt = to.clone();
123            }
124        }
125    }
126
127    fn remap_fields(&self, fields: &mut Fields) {
128        match fields {
129            Fields::Unit => {}
130            Fields::Unnamed(fields) => {
131                for field in &mut fields.fields {
132                    if let Some(dt) = &mut field.ty {
133                        self.remap_internal(dt);
134                    }
135                }
136            }
137            Fields::Named(fields) => {
138                for (_, field) in &mut fields.fields {
139                    if let Some(dt) = &mut field.ty {
140                        self.remap_internal(dt);
141                    }
142                }
143            }
144        }
145    }
146
147    fn remap_reference(&self, reference: &mut Reference) {
148        let Reference::Named(reference) = reference else {
149            return;
150        };
151
152        match &mut reference.inner {
153            NamedReferenceType::Recursive(_) => {}
154            NamedReferenceType::Inline { dt, .. } => self.remap_internal(dt),
155            NamedReferenceType::Reference { generics, .. } => {
156                for (_, dt) in generics {
157                    self.remap_internal(dt);
158                }
159            }
160        }
161    }
162}
163
#[cfg(test)]
mod tests {
    use specta::{
        Types,
        datatype::{DataType, Field, List, NamedDataType, Primitive, Struct, Tuple},
    };

    use super::Remapper;

    #[test]
    fn empty_remapper_is_identity() {
        let dt = DataType::Tuple(Tuple::new(vec![
            Primitive::u32.into(),
            DataType::List(List::new(Primitive::i32.into())),
        ]));

        assert_eq!(Remapper::new().remap_dt(dt.clone()), dt);
    }

    #[test]
    fn non_matching_rules_leave_type_unchanged() {
        let remapped = Remapper::new()
            .rule(Primitive::u32.into(), Primitive::str.into())
            .remap_dt(Primitive::i32.into());

        assert_eq!(remapped, Primitive::i32.into());
    }

    #[test]
    fn remaps_multiple_rules_in_one_crawl() {
        let dt = DataType::Tuple(Tuple::new(vec![
            Primitive::u32.into(),
            Primitive::i32.into(),
        ]));

        let remapped = Remapper::new()
            .rule(Primitive::u32.into(), Primitive::str.into())
            .rule(Primitive::i32.into(), Primitive::bool.into())
            .remap_dt(dt);

        assert_eq!(
            remapped,
            DataType::Tuple(Tuple::new(vec![
                Primitive::str.into(),
                Primitive::bool.into()
            ]))
        );
    }

    #[test]
    fn remaps_deeply_nested_types() {
        let dt = DataType::Tuple(Tuple::new(vec![
            DataType::List(List::new(Primitive::u32.into())),
            DataType::Tuple(Tuple::new(vec![Primitive::u32.into()])),
        ]));

        let remapped = Remapper::new()
            .rule(Primitive::u32.into(), Primitive::str.into())
            .remap_dt(dt);

        assert_eq!(
            remapped,
            DataType::Tuple(Tuple::new(vec![
                DataType::List(List::new(Primitive::str.into())),
                DataType::Tuple(Tuple::new(vec![Primitive::str.into()])),
            ]))
        );
    }

    #[test]
    fn remaps_struct_fields() {
        let user = |id: DataType| Struct::named().field("id", Field::new(id)).build();

        let remapped = Remapper::new()
            .rule(Primitive::u32.into(), Primitive::str.into())
            .remap_dt(user(Primitive::u32.into()));

        assert_eq!(remapped, user(Primitive::str.into()));
    }

    #[test]
    fn rules_are_piped_in_registration_order() {
        let remapped = Remapper::new()
            .rule(Primitive::u32.into(), Primitive::i32.into())
            .rule(Primitive::i32.into(), Primitive::bool.into())
            .remap_dt(Primitive::u32.into());

        assert_eq!(remapped, Primitive::bool.into());
    }

    #[test]
    fn replacement_is_recrawled() {
        let remapped = Remapper::new()
            .rule(
                Primitive::u32.into(),
                DataType::List(List::new(Primitive::i32.into())),
            )
            .rule(Primitive::i32.into(), Primitive::bool.into())
            .remap_dt(Primitive::u32.into());

        assert_eq!(remapped, DataType::List(List::new(Primitive::bool.into())));
    }

    #[test]
    fn remaps_named_type_bodies() {
        let mut types = Types::default();
        NamedDataType::new("User", &mut types, |_, ty| {
            ty.ty = Some(
                Struct::named()
                    .field("id", Field::new(Primitive::u32.into()))
                    .field("active", Field::new(Primitive::i32.into()))
                    .build(),
            );
        });

        let types = Remapper::new()
            .rule(Primitive::u32.into(), Primitive::str.into())
            .rule(Primitive::i32.into(), Primitive::bool.into())
            .remap_types(types);

        let debug = format!("{types:?}");
        assert!(debug.contains("Primitive(str)"));
        assert!(debug.contains("Primitive(bool)"));
        // The originals must be gone, not merely accompanied by the replacements.
        assert!(!debug.contains("Primitive(u32)"));
        assert!(!debug.contains("Primitive(i32)"));
    }
}