core_json_derive/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![doc = include_str!("../README.md")]
3#![deny(missing_docs)]
4#![no_std]
5
6use core::{borrow::Borrow, str::FromStr, iter::Peekable};
7
8extern crate alloc;
9use alloc::{
10  string::{String, ToString},
11  vec, format,
12};
13
14extern crate proc_macro;
15use proc_macro::{Delimiter, Spacing, Punct, TokenTree, TokenStream};
16
/// Take a `< ... >` expression from the iterator, if one is next.
///
/// `<` will not open a `TokenTree::Group`, so this balances the angle brackets itself, returning
/// every token from the opening `<` through its matching `>` (inclusive). If the next item isn't a
/// `<`, nothing is consumed and an empty `TokenStream` is returned.
///
/// The `>` of an `->` (as within `Fn() -> T`) is not considered to close an angle bracket.
///
/// Panics if the iterator ends before the opening `<` is closed.
fn take_angle_expression(
  iter: &mut Peekable<impl Iterator<Item: Borrow<TokenTree>>>,
) -> TokenStream {
  {
    let Some(peeked) = iter.peek() else { return TokenStream::default() };
    let TokenTree::Punct(punct) = peeked.borrow() else { return TokenStream::default() };
    if punct.as_char() != '<' {
      return TokenStream::default();
    }
  }

  let mut result = vec![];
  // The first token is the peeked `<`, so `depth` is at least 1 until its matching `>`
  let mut depth = 0usize;
  // If the prior token was a `-` joined with the following token, as the start of an `->`
  let mut after_joint_minus = false;
  loop {
    let item = iter.next().expect("`TokenTree` unexpectedly terminated when taking `< ... >`");
    let item = item.borrow().clone();
    let mut joint_minus = false;
    if let TokenTree::Punct(punct) = &item {
      match punct.as_char() {
        '<' => depth += 1,
        // The `>` of an `->` is part of a return type, not the end of a `< ... >`
        '>' if !after_joint_minus => depth -= 1,
        '-' => joint_minus = matches!(punct.spacing(), Spacing::Joint),
        _ => {}
      }
    }
    after_joint_minus = joint_minus;
    result.push(item);
    if depth == 0 {
      break;
    }
  }
  TokenStream::from_iter(result)
}
49
50// Advance the iterator past the next `,` on this depth, if there is one.
51fn skip_comma_delimited(iter: &mut Peekable<impl Iterator<Item: Borrow<TokenTree>>>) {
52  loop {
53    take_angle_expression(iter);
54    let Some(item) = iter.next() else { return };
55    if let TokenTree::Punct(punct) = item.borrow() {
56      if punct.as_char() == ',' {
57        return;
58      }
59    }
60  }
61}
62
63/// Derive an implementation of the `JsonDeserialize` trait.
64///
65/// This _requires_ the `struct` derived for implement `Default`. Fields which aren't present in
66/// the serialization will be left to their `Default` initialization. If you wish to detect if a
67/// field was omitted, please wrap it in `Option`.
68///
69/// As a procedural macro, this will panic causing a compile-time error on any unexpected input.
70#[proc_macro_derive(JsonDeserialize)]
71pub fn derive_json_deserialize(object: TokenStream) -> TokenStream {
72  let generic_bounds;
73  let generics;
74  let object_name;
75  let mut largest_key = 0;
76  let mut all_fields = String::new();
77  {
78    let mut object = object.clone().into_iter().peekable();
79
80    loop {
81      match object.peek() {
82        Some(TokenTree::Punct(punct)) if punct.as_char() == '#' => {
83          let _ = object.next().expect("peeked but not present");
84          let TokenTree::Group(_) = object.next().expect("`#` but no `[ ... ]`") else {
85            panic!("`#` not followed by a `TokenTree::Group` for its `[ ... ]`")
86          };
87        }
88        _ => break,
89      }
90    }
91
92    match object.next() {
93      Some(TokenTree::Ident(ident)) if ident.to_string() == "struct" => {}
94      _ => panic!("`JsonDeserialize` wasn't applied to a `struct`"),
95    }
96    object_name = match object.next() {
97      Some(TokenTree::Ident(ident)) => ident.to_string(),
98      _ => panic!("`JsonDeserialize` wasn't applied to a `struct` with a name"),
99    };
100
101    let generic_bounds_tree = take_angle_expression(&mut object);
102
103    let mut generics_tree = vec![];
104    {
105      let mut iter = generic_bounds_tree.clone().into_iter().peekable();
106      while let Some(component) = iter.next() {
107        // Take until the next colon, used to mark trait bounds
108        if let TokenTree::Punct(punct) = &component {
109          if punct.as_char() == ':' {
110            // Skip the actual bounds
111            skip_comma_delimited(&mut iter);
112            // Add our own comma delimiter and move to the next item
113            generics_tree.push(TokenTree::Punct(Punct::new(',', Spacing::Alone)));
114            continue;
115          }
116        }
117        // Push this component as it isn't part of the bounds
118        generics_tree.push(component);
119      }
120    }
121    // Ensure this is terminated, which it won't be if the last item had bounds yet didn't have a
122    // trailing comma
123    if let Some(last) = generics_tree.last() {
124      match last {
125        TokenTree::Punct(punct) if punct.as_char() == '>' => {}
126        _ => generics_tree.push(TokenTree::Punct(Punct::new('>', Spacing::Alone))),
127      }
128    }
129
130    generic_bounds = generic_bounds_tree.to_string();
131    generics = TokenStream::from_iter(generics_tree).to_string();
132
133    // This presumably means we don't support `struct`'s defined with `where` bounds
134    let Some(TokenTree::Group(struct_body)) = object.next() else {
135      panic!("`struct`'s name was not followed by its body");
136    };
137    if struct_body.delimiter() != Delimiter::Brace {
138      panic!("`JsonDeserialize` derivation applied to `struct` with anonymous fields");
139    }
140    let mut struct_body = struct_body.stream().into_iter().peekable();
141    // Read each field within this `struct`'s body
142    while struct_body.peek().is_some() {
143      // Access the field name
144      let mut field_name = None;
145      // This loop will ignore attributes successfully
146      for item in &mut struct_body {
147        if let TokenTree::Ident(ident) = item {
148          let ident = ident.to_string();
149          // Skip the access modifier
150          if ident == "pub" {
151            continue;
152          }
153          field_name = Some(ident);
154          break;
155        }
156      }
157      let field_name = field_name.expect("couldn't find the name of the field within the `struct`");
158      largest_key = largest_key.max(field_name.len());
159
160      for char in field_name.chars() {
161        if !(char.is_ascii_alphanumeric() || (char == '_')) {
162          panic!("character in name of field wasn't supported (A-Za-z0-9_)");
163        }
164      }
165
166      all_fields.push_str(&format!(
167        r#"
168        b"{field_name}" => {{
169          result.{field_name} = core_json_traits::JsonDeserialize::deserialize(value)?
170        }},
171      "#
172      ));
173
174      // Advance to the next field
175      skip_comma_delimited(&mut struct_body);
176    }
177  }
178
179  TokenStream::from_str(&format!(
180    r#"
181    impl{generic_bounds} core_json_traits::JsonDeserialize for {object_name}{generics}
182      where Self: core::default::Default {{
183      fn deserialize<
184        'bytes,
185        'parent,
186        B: core_json_traits::BytesLike<'bytes>,
187        S: core_json_traits::Stack,
188      >(
189        value: core_json_traits::Value<'bytes, 'parent, B, S>,
190      ) -> Result<Self, core_json_traits::JsonError<'bytes, B, S>> {{
191        use core::default::Default;
192
193        let mut result = Self::default();
194
195        let mut key_bytes = [0; {largest_key}];
196        let mut object = value.fields()?;
197        while let Some(field) = object.next() {{
198          let (mut key, value) = field?;
199
200          if key.len() > {largest_key} {{
201            continue;
202          }}
203          let key = {{
204            let key_len = key.len();
205            key
206              .consume()
207              .read_into_slice(&mut key_bytes[.. key_len])
208              .map_err(core_json_traits::JsonError::BytesError)?;
209            &key_bytes[.. key_len]
210          }};
211
212          match key {{
213            {all_fields}
214            // Skip unknown fields
215            _ => {{}}
216          }}
217        }}
218
219        Ok(result)
220      }}
221    }}
222    impl{generic_bounds} core_json_traits::JsonObject for {object_name}{generics}
223      where Self: core::default::Default {{}}
224    "#
225  ))
226  .expect("typo in implementation of `JsonDeserialize`")
227}