core_json_derive/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![doc = include_str!("../README.md")]
3#![deny(missing_docs)]
4#![no_std]
5
6use core::{borrow::Borrow, str::FromStr, iter::Peekable};
7
8extern crate alloc;
9use alloc::{
10  string::{String, ToString},
11  vec, format,
12};
13
14extern crate proc_macro;
15use proc_macro::{Delimiter, Spacing, Punct, TokenTree, TokenStream};
16
// Take the `< ... >` expression at the front of `iter`, including both of its angle brackets.
//
// `<` will not open a `TokenTree::Group`, so the angle brackets are counted manually. A `>` which
// completes a `->` (as in `F: Fn() -> u8`) isn't a closing bracket and is taken as-is. If `iter`
// doesn't begin with `<`, nothing is consumed and an empty `TokenStream` is returned.
//
// Panics if `iter` ends before the expression is closed.
fn take_angle_expression(
  iter: &mut Peekable<impl Iterator<Item: Borrow<TokenTree>>>,
) -> TokenStream {
  {
    let Some(peeked) = iter.peek() else { return TokenStream::default() };
    let TokenTree::Punct(punct) = peeked.borrow() else { return TokenStream::default() };
    if punct.as_char() != '<' {
      return TokenStream::default();
    }
  }

  let mut result = vec![];
  // How many `<` are currently unclosed
  let mut depth: usize = 0;
  // If the prior token was a `-` joined to the current token, making a `>` here part of `->`
  let mut after_joint_minus = false;
  loop {
    let next = iter.next().expect("`TokenTree` unexpectedly terminated when taking `< ... >`");
    let token: &TokenTree = next.borrow();

    let mut joint_minus = false;
    if let TokenTree::Punct(punct) = token {
      match punct.as_char() {
        '<' => depth += 1,
        // proc_macro splits `->` into a `-` with `Spacing::Joint` followed by a `>`
        '>' if !after_joint_minus => depth -= 1,
        '-' => joint_minus = punct.spacing() == Spacing::Joint,
        _ => {}
      }
    }
    after_joint_minus = joint_minus;

    result.push(token.clone());
    // The first token is the opening `<`, so this is only reached once it's been closed
    if depth == 0 {
      break;
    }
  }
  TokenStream::from_iter(result)
}
49
50// Advance the iterator past the next `,` on this depth, if there is one.
51fn skip_comma_delimited(iter: &mut Peekable<impl Iterator<Item: Borrow<TokenTree>>>) {
52  loop {
53    take_angle_expression(iter);
54    let Some(item) = iter.next() else { return };
55    if let TokenTree::Punct(punct) = item.borrow() {
56      if punct.as_char() == ',' {
57        return;
58      }
59    }
60  }
61}
62
63/// Derive an implementation of the `JsonDeserialize` trait.
64///
65/// This _requires_ the `struct` derived for implement `Default`. Fields which aren't present in
66/// the serialization will be left to their `Default` initialization. If you wish to detect if a
67/// field was omitted, please wrap it in `Option`.
68///
69/// As a procedural macro, this will panic causing a compile-time error on any unexpected input.
70#[proc_macro_derive(JsonDeserialize, attributes(rename))]
71pub fn derive_json_deserialize(object: TokenStream) -> TokenStream {
72  let generic_bounds;
73  let generics;
74  let object_name;
75  let mut largest_key = 0;
76  let mut all_fields = String::new();
77  {
78    let mut object = object.clone().into_iter().peekable();
79
80    loop {
81      match object.peek() {
82        Some(TokenTree::Punct(punct)) if punct.as_char() == '#' => {
83          let _ = object.next().expect("peeked but not present");
84          let TokenTree::Group(_) = object.next().expect("`#` but no `[ ... ]`") else {
85            panic!("`#` not followed by a `TokenTree::Group` for its `[ ... ]`")
86          };
87        }
88        _ => break,
89      }
90    }
91
92    match object.next() {
93      Some(TokenTree::Ident(ident)) if ident.to_string() == "struct" => {}
94      _ => panic!("`JsonDeserialize` wasn't applied to a `struct`"),
95    }
96    object_name = match object.next() {
97      Some(TokenTree::Ident(ident)) => ident.to_string(),
98      _ => panic!("`JsonDeserialize` wasn't applied to a `struct` with a name"),
99    };
100
101    let generic_bounds_tree = take_angle_expression(&mut object);
102
103    let mut generics_tree = vec![];
104    {
105      let mut iter = generic_bounds_tree.clone().into_iter().peekable();
106      while let Some(component) = iter.next() {
107        // Take until the next colon, used to mark trait bounds
108        if let TokenTree::Punct(punct) = &component {
109          if punct.as_char() == ':' {
110            // Skip the actual bounds
111            skip_comma_delimited(&mut iter);
112            // Add our own comma delimiter and move to the next item
113            generics_tree.push(TokenTree::Punct(Punct::new(',', Spacing::Alone)));
114            continue;
115          }
116        }
117        // Push this component as it isn't part of the bounds
118        generics_tree.push(component);
119      }
120    }
121    // Ensure this is terminated, which it won't be if the last item had bounds yet didn't have a
122    // trailing comma
123    if let Some(last) = generics_tree.last() {
124      match last {
125        TokenTree::Punct(punct) if punct.as_char() == '>' => {}
126        _ => generics_tree.push(TokenTree::Punct(Punct::new('>', Spacing::Alone))),
127      }
128    }
129
130    generic_bounds = generic_bounds_tree.to_string();
131    generics = TokenStream::from_iter(generics_tree).to_string();
132
133    // This presumably means we don't support `struct`'s defined with `where` bounds
134    let Some(TokenTree::Group(struct_body)) = object.next() else {
135      panic!("`struct`'s name was not followed by its body");
136    };
137    if struct_body.delimiter() != Delimiter::Brace {
138      panic!("`JsonDeserialize` derivation applied to `struct` with anonymous fields");
139    }
140    let mut struct_body = struct_body.stream().into_iter().peekable();
141    // Read each field within this `struct`'s body
142    while struct_body.peek().is_some() {
143      // Access the field name
144      let mut serialization_field_name = None;
145      let mut field_name = None;
146      for item in &mut struct_body {
147        // Hanlde the `rename` attribute
148        if let TokenTree::Group(group) = &item {
149          if group.delimiter() == Delimiter::Bracket {
150            let mut iter = group.stream().into_iter();
151            if iter.next().and_then(|ident| match ident {
152              TokenTree::Ident(ident) => Some(ident.to_string()),
153              _ => None,
154            }) == Some("rename".to_string())
155            {
156              let TokenTree::Group(group) =
157                iter.next().expect("`rename` attribute without arguments")
158              else {
159                panic!("`rename` attribute not followed with `(...)`")
160              };
161              assert_eq!(
162                group.delimiter(),
163                Delimiter::Parenthesis,
164                "`rename` attribute with a non-parentheses group"
165              );
166              assert_eq!(
167                group.stream().into_iter().count(),
168                1,
169                "`rename` attribute with multiple tokens within parentheses"
170              );
171              let TokenTree::Literal(literal) = group.stream().into_iter().next().unwrap() else {
172                panic!("`rename` attribute with a non-literal argument")
173              };
174              let literal = literal.to_string();
175              assert_eq!(literal.chars().next().unwrap(), '"', "literal wasn't a string literal");
176              assert_eq!(literal.chars().last().unwrap(), '"', "literal wasn't a string literal");
177              serialization_field_name =
178                Some(literal.trim_start_matches('"').trim_end_matches('"').to_string());
179            }
180          }
181        }
182
183        if let TokenTree::Ident(ident) = item {
184          let ident = ident.to_string();
185          // Skip the access modifier
186          if ident == "pub" {
187            continue;
188          }
189          field_name = Some(ident);
190          // Use the field's actual name within the serialization, if not renamed
191          serialization_field_name = serialization_field_name.or(field_name.clone());
192          break;
193        }
194      }
195      let field_name = field_name.expect("couldn't find the name of the field within the `struct`");
196      let serialization_field_name =
197        serialization_field_name.expect("`field_name` but no `serialization_field_name`?");
198      largest_key = largest_key.max(serialization_field_name.len());
199
200      for char in serialization_field_name.chars() {
201        if !(char.is_ascii_alphanumeric() || (char == '_')) {
202          panic!(
203            "character in name of field wasn't supported (`[A-Za-z0-9_]+` required): `{char}`"
204          );
205        }
206      }
207
208      all_fields.push_str(&format!(
209        r#"
210        b"{serialization_field_name}" => {{
211          result.{field_name} = core_json_traits::JsonDeserialize::deserialize(value)?
212        }},
213      "#
214      ));
215
216      // Advance to the next field
217      skip_comma_delimited(&mut struct_body);
218    }
219  }
220
221  TokenStream::from_str(&format!(
222    r#"
223    impl{generic_bounds} core_json_traits::JsonDeserialize for {object_name}{generics}
224      where Self: core::default::Default {{
225      fn deserialize<
226        'bytes,
227        'parent,
228        B: core_json_traits::BytesLike<'bytes>,
229        S: core_json_traits::Stack,
230      >(
231        value: core_json_traits::Value<'bytes, 'parent, B, S>,
232      ) -> Result<Self, core_json_traits::JsonError<'bytes, B, S>> {{
233        use core::default::Default;
234
235        let mut result = Self::default();
236
237        let mut key_bytes = [0; {largest_key}];
238        let mut object = value.fields()?;
239        while let Some(field) = object.next() {{
240          let (mut key, value) = field?;
241
242          if key.len() > {largest_key} {{
243            continue;
244          }}
245          let key = {{
246            let key_len = key.len();
247            key
248              .consume()
249              .read_into_slice(&mut key_bytes[.. key_len])
250              .map_err(core_json_traits::JsonError::BytesError)?;
251            &key_bytes[.. key_len]
252          }};
253
254          match key {{
255            {all_fields}
256            // Skip unknown fields
257            _ => {{}}
258          }}
259        }}
260
261        Ok(result)
262      }}
263    }}
264    impl{generic_bounds} core_json_traits::JsonObject for {object_name}{generics}
265      where Self: core::default::Default {{}}
266    "#
267  ))
268  .expect("typo in implementation of `JsonDeserialize`")
269}