core_json_derive/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![doc = include_str!("../README.md")]
3#![deny(missing_docs)]
4#![no_std]
5
6use core::{borrow::Borrow, str::FromStr, iter::Peekable};
7
8extern crate alloc;
9use alloc::{
10  vec,
11  vec::Vec,
12  string::{String, ToString},
13  format,
14};
15
16extern crate proc_macro;
17use proc_macro::{Delimiter, Spacing, Punct, TokenTree, TokenStream};
18
// `<` will not open a group, so we use this to take all items within a `< ... >` expression.
//
// If the next item isn't a `<`, nothing is consumed and an empty `TokenStream` is returned.
// Otherwise, every item up to and including the `>` which closes the initial `<` is consumed and
// returned. Nested `< ... >` expressions are tracked by depth. The `>` of a `->` (as within
// `Fn() -> T` or `fn() -> T`) is part of an arrow and does not close anything.
//
// Panics if the iterator terminates before the expression is closed.
fn take_angle_expression(
  iter: &mut Peekable<impl Iterator<Item: Borrow<TokenTree>>>,
) -> TokenStream {
  {
    let Some(peeked) = iter.peek() else { return TokenStream::default() };
    let TokenTree::Punct(punct) = peeked.borrow() else { return TokenStream::default() };
    if punct.as_char() != '<' {
      return TokenStream::default();
    }
  }

  let mut result = vec![];
  let mut depth = 0usize;
  // Whether the prior item was a `-` joined to the following punctuation (potentially a `->`)
  let mut after_joint_minus = false;
  loop {
    let next = iter.next().expect("`TokenTree` unexpectedly terminated when taking `< ... >`");
    let token: &TokenTree = next.borrow();
    let token = token.clone();

    let mut joint_minus = false;
    if let TokenTree::Punct(punct) = &token {
      match punct.as_char() {
        '<' => depth += 1,
        // `depth` is at least `1` here, as the first item was a `<` and we stop upon reaching `0`
        '>' if !after_joint_minus => depth -= 1,
        '-' => joint_minus = punct.spacing() == Spacing::Joint,
        _ => {}
      }
    }
    after_joint_minus = joint_minus;

    result.push(token);
    if depth == 0 {
      break;
    }
  }
  TokenStream::from_iter(result)
}
51
52// Advance the iterator past the next `,` on this depth, if there is one.
53fn skip_comma_delimited(iter: &mut Peekable<impl Iterator<Item: Borrow<TokenTree>>>) {
54  loop {
55    take_angle_expression(iter);
56    let Some(item) = iter.next() else { return };
57    if let TokenTree::Punct(punct) = item.borrow() {
58      if punct.as_char() == ',' {
59        return;
60      }
61    }
62  }
63}
64
// The components of a parsed `struct` needed to generate implementations for it.
struct Struct {
  // The `struct`'s name
  name: String,
  // The generic parameters with their bounds, as written after `impl`
  generic_bounds: String,
  // The generic parameters without their bounds, as written after the `struct`'s name
  generics: String,
  // Each (non-skipped) field as `(field name, key used within the serialization)`
  fields: Vec<(String, String)>,
}
71
72// This is somewhat comparable to `syn::Generics`, especially its `split_for_impl` method.
73fn parse_struct(object: TokenStream) -> Struct {
74  let mut object = object.into_iter().peekable();
75
76  loop {
77    match object.peek() {
78      Some(TokenTree::Punct(punct)) if punct.as_char() == '#' => {
79        let _ = object.next().expect("peeked but not present");
80        let TokenTree::Group(_) = object.next().expect("`#` but no `[ ... ]`") else {
81          panic!("`#` not followed by a `TokenTree::Group` for its `[ ... ]`")
82        };
83      }
84      _ => break,
85    }
86  }
87
88  {
89    let mut vis = false;
90    loop {
91      match object.next() {
92        // Skip visibility modifiers
93        Some(TokenTree::Ident(ident)) if ident.to_string() == "pub" => {
94          if vis {
95            panic!("multiple visibility modifers found");
96          }
97          vis = true;
98        }
99        // This _technically_ allows multiple/invalid arguments re: visibility
100        Some(TokenTree::Group(group)) if vis && (group.delimiter() == Delimiter::Parenthesis) => {}
101        Some(TokenTree::Ident(ident)) if ident.to_string() == "struct" => break,
102        _ => panic!("`JsonDeserialize` wasn't applied to a `struct`"),
103      }
104    }
105  }
106  let name = match object.next() {
107    Some(TokenTree::Ident(ident)) => ident.to_string(),
108    _ => panic!("`JsonDeserialize` wasn't applied to a `struct` with a name"),
109  };
110
111  let generic_bounds_tree = take_angle_expression(&mut object);
112
113  let mut generics_tree = vec![];
114  {
115    let mut iter = generic_bounds_tree.clone().into_iter().peekable();
116    while let Some(component) = iter.next() {
117      // Take until the next colon, used to mark trait bounds
118      if let TokenTree::Punct(punct) = &component {
119        if punct.as_char() == ':' {
120          // Skip the actual bounds
121          skip_comma_delimited(&mut iter);
122          // Add our own comma delimiter and move to the next item
123          generics_tree.push(TokenTree::Punct(Punct::new(',', Spacing::Alone)));
124          continue;
125        }
126      }
127      // Push this component as it isn't part of the bounds
128      generics_tree.push(component);
129    }
130  }
131  // Ensure this is terminated, which it won't be if the last item had bounds yet didn't have a
132  // trailing comma
133  if let Some(last) = generics_tree.last() {
134    match last {
135      TokenTree::Punct(punct) if punct.as_char() == '>' => {}
136      _ => generics_tree.push(TokenTree::Punct(Punct::new('>', Spacing::Alone))),
137    }
138  }
139
140  let generic_bounds = generic_bounds_tree.to_string();
141  let generics = TokenStream::from_iter(generics_tree).to_string();
142
143  // This presumably means we don't support `struct`s defined with `where` bounds
144  let Some(TokenTree::Group(struct_body)) = object.next() else {
145    panic!("`struct`'s name was not followed by its body");
146  };
147  if struct_body.delimiter() != Delimiter::Brace {
148    panic!("`JsonDeserialize` derivation applied to `struct` with anonymous fields");
149  }
150
151  let mut fields = vec![];
152
153  let mut struct_body = struct_body.stream().into_iter().peekable();
154  // Read each field within this `struct`'s body
155  while struct_body.peek().is_some() {
156    // Access the field name
157    let mut serialization_field_name = None;
158    let mut field_name = None;
159    let mut skip = false;
160    for item in &mut struct_body {
161      // Handle the `key` attribute
162      if let TokenTree::Group(group) = &item {
163        if group.delimiter() == Delimiter::Bracket {
164          let mut iter = group.stream().into_iter();
165          let ident = iter.next().and_then(|ident| match ident {
166            TokenTree::Ident(ident) => Some(ident.to_string()),
167            _ => None,
168          });
169          match ident.as_deref() {
170            Some("skip") => skip = true,
171            Some("key") => {
172              let TokenTree::Group(group) = iter.next().expect("`key` attribute without arguments")
173              else {
174                panic!("`key` attribute not followed with `(...)`")
175              };
176              assert_eq!(
177                group.delimiter(),
178                Delimiter::Parenthesis,
179                "`key` attribute with a non-parentheses group"
180              );
181              assert_eq!(
182                group.stream().into_iter().count(),
183                1,
184                "`key` attribute with multiple tokens within parentheses"
185              );
186              let TokenTree::Literal(literal) = group.stream().into_iter().next().unwrap() else {
187                panic!("`key` attribute with a non-literal argument")
188              };
189              let literal = literal.to_string();
190              assert_eq!(literal.chars().next().unwrap(), '"', "literal wasn't a string literal");
191              assert_eq!(literal.chars().last().unwrap(), '"', "literal wasn't a string literal");
192              serialization_field_name =
193                Some(literal.trim_start_matches('"').trim_end_matches('"').to_string());
194            }
195            _ => {}
196          }
197        }
198      }
199
200      if let TokenTree::Ident(ident) = item {
201        let ident = ident.to_string();
202        // Skip the access modifier
203        if ident == "pub" {
204          continue;
205        }
206        field_name = Some(ident);
207        // Use the field's actual name within the serialization, if not renamed
208        serialization_field_name = serialization_field_name.or(field_name.clone());
209        break;
210      }
211    }
212    let field_name = field_name.expect("couldn't find the name of the field within the `struct`");
213    let serialization_field_name =
214      serialization_field_name.expect("`field_name` but no `serialization_field_name`?");
215
216    if !skip {
217      fields.push((field_name, serialization_field_name));
218    }
219
220    // Advance to the next field
221    skip_comma_delimited(&mut struct_body);
222  }
223
224  Struct { generic_bounds, generics, name, fields }
225}
226
227/// Derive an implementation of the `JsonDeserialize` trait.
228///
229/// This _requires_ the `struct` derived for implement `Default`. Fields which aren't present in
230/// the serialization will be left to their `Default` initialization. If you wish to detect if a
231/// field was omitted, please wrap it in `Option`.
232///
233/// Fields may deserialized from a distinct key using the `key` attribute, accepting a string
234/// literal for the key to deserialize from (`key("key")`). Fields may be omitted from
235/// deserialization with the `skip` attribute.
236///
237/// As a procedural macro, this will panic causing a compile-time error on any unexpected input.
238#[proc_macro_derive(JsonDeserialize, attributes(key, skip))]
239pub fn derive_json_deserialize(object: TokenStream) -> TokenStream {
240  let Struct { generic_bounds, generics, name, fields } = parse_struct(object);
241
242  let mut largest_key = 0;
243  let mut fields_deserialization = String::new();
244  for (field_name, serialization_field_name) in &fields {
245    largest_key = largest_key.max(serialization_field_name.len());
246
247    let mut serialization_field_name_array = "&[".to_string();
248    for char in serialization_field_name.chars() {
249      serialization_field_name_array.push('\'');
250      serialization_field_name_array.push_str(&char.escape_unicode().to_string());
251      serialization_field_name_array.push('\'');
252      serialization_field_name_array.push(',');
253    }
254    serialization_field_name_array.push(']');
255
256    fields_deserialization.push_str(&format!(
257      r#"
258      {serialization_field_name_array} => {{
259        result.{field_name} = core_json_traits::JsonDeserialize::deserialize(value)?
260      }},
261      "#
262    ));
263  }
264
265  TokenStream::from_str(&format!(
266    r#"
267    impl{generic_bounds} core_json_traits::JsonDeserialize for {name}{generics}
268      where Self: core::default::Default {{
269      fn deserialize<
270        'read,
271        'parent,
272        B: core_json_traits::Read<'read>,
273        S: core_json_traits::Stack,
274      >(
275        value: core_json_traits::Value<'read, 'parent, B, S>,
276      ) -> Result<Self, core_json_traits::JsonError<'read, B, S>> {{
277        use core::default::Default;
278
279        let mut result = Self::default();
280        if {largest_key} == 0 {{
281          return Ok(result);
282        }}
283
284        let mut key_chars = ['\0'; {largest_key}];
285        let mut object = value.fields()?;
286        'serialized_field: while let Some(field) = object.next() {{
287          let mut field = field?;
288
289          let key = {{
290            let key = field.key();
291            let mut key_len = 0;
292            while let Some(key_char) = key.next() {{
293              key_chars[key_len] = match key_char {{
294                Ok(key_char) => key_char,
295                /*
296                  This occurs when the key specifies an invalid UTF codepoint, which is technically
297                  allowed by RFC 8259. While it means we can't interpret the key, it also means
298                  this isn't a field we're looking for.
299
300                  Continue to the next serialized field accordingly.
301                */
302                Err(core_json_traits::JsonError::InvalidValue) => continue 'serialized_field,
303                // Propagate all other errors.
304                Err(e) => Err(e)?,
305              }};
306              key_len += 1;
307              if key_len == {largest_key} {{
308                break;
309              }}
310            }}
311            match key.next() {{
312              None => {{}},
313              // This key is larger than our largest key
314              Some(Ok(_)) => continue,
315              Some(Err(e)) => Err(e)?,
316            }}
317            &key_chars[.. key_len]
318          }};
319          let value = field.value();
320
321          match key {{
322            {fields_deserialization}
323            // Skip unknown fields
324            _ => {{}}
325          }}
326        }}
327
328        Ok(result)
329      }}
330    }}
331    impl{generic_bounds} core_json_traits::JsonStructure for {name}{generics}
332      where Self: core::default::Default {{}}
333    "#
334  ))
335  .expect("typo in implementation of `JsonDeserialize`")
336}
337
338/// Derive an implementation of the `JsonSerialize` trait.
339///
340/// Fields may serialized with a distinct name using the `key` attribute, accepting a string
341/// literal for the key to serialize as (`key("key")`). Fields may be omitted from serialization
342/// with the `skip` attribute.
343///
344/// As a procedural macro, this will panic causing a compile-time error on any unexpected input.
345#[proc_macro_derive(JsonSerialize, attributes(key, skip))]
346pub fn derive_json_serialize(object: TokenStream) -> TokenStream {
347  let Struct { generic_bounds, generics, name, fields } = parse_struct(object);
348
349  let mut fields_serialization = String::new();
350  for (i, (field_name, serialization_field_name)) in fields.iter().enumerate() {
351    let comma = if (i + 1) == fields.len() { "" } else { r#".chain(core::iter::once(','))"# };
352
353    fields_serialization.push_str(&format!(
354      r#"
355      .chain("{serialization_field_name}".serialize())
356      .chain(core::iter::once(':'))
357      .chain(core_json_traits::JsonSerialize::serialize(&self.{field_name}))
358      {comma}
359      "#
360    ));
361  }
362
363  TokenStream::from_str(&format!(
364    r#"
365    impl{generic_bounds} core_json_traits::JsonSerialize for {name}{generics} {{
366      fn serialize(&self) -> impl Iterator<Item = char> {{
367        core::iter::once('{{')
368        {fields_serialization}
369        .chain(core::iter::once('}}'))
370      }}
371    }}
372    "#
373  ))
374  .expect("typo in implementation of `JsonSerialize`")
375}