core_json_derive/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![doc = include_str!("../README.md")]
3#![deny(missing_docs)]
4#![no_std]
5
6use core::{borrow::Borrow, str::FromStr, iter::Peekable};
7
8extern crate alloc;
9use alloc::{
10  string::{String, ToString},
11  vec, format,
12};
13
14extern crate proc_macro;
15use proc_macro::{Delimiter, Spacing, Punct, TokenTree, TokenStream};
16
// `<` will not open a group, so we use this to take all items within a `< ... >` expression.
//
// If the next token isn't `<`, nothing is consumed and an empty `TokenStream` is returned.
// Otherwise, tokens are taken (and returned) up to and including the `>` which closes the opening
// `<`, with nested `< ... >` expressions tracked by depth. The `>` of an arrow (`->`, as within
// `F: Fn() -> u8`) is part of the arrow, not a closing angle bracket, and is treated accordingly.
//
// Panics if the iterator ends before the opening `<` is closed.
fn take_angle_expression(
  iter: &mut Peekable<impl Iterator<Item: Borrow<TokenTree>>>,
) -> TokenStream {
  {
    let Some(peeked) = iter.peek() else { return TokenStream::default() };
    let TokenTree::Punct(punct) = peeked.borrow() else { return TokenStream::default() };
    if punct.as_char() != '<' {
      return TokenStream::default();
    }
  }

  let mut result = vec![];
  let mut depth = 0usize;
  // If the prior token was a `-` joined to the current token, so the current token is the `>` of
  // an arrow `->`
  let mut after_joint_minus = false;
  loop {
    let item = iter.next().expect("`TokenTree` unexpectedly terminated when taking `< ... >`");
    let tree: &TokenTree = item.borrow();
    result.push(tree.clone());
    let TokenTree::Punct(punct) = tree else {
      after_joint_minus = false;
      continue;
    };
    match punct.as_char() {
      '<' => depth += 1,
      // `depth >= 1` here as the first token taken was the opening `<`
      '>' if !after_joint_minus => depth -= 1,
      _ => {}
    }
    after_joint_minus = (punct.as_char() == '-') && matches!(punct.spacing(), Spacing::Joint);
    if depth == 0 {
      break;
    }
  }
  TokenStream::from_iter(result)
}
49
50// Advance the iterator past the next `,` on this depth, if there is one.
51fn skip_comma_delimited(iter: &mut Peekable<impl Iterator<Item: Borrow<TokenTree>>>) {
52  loop {
53    take_angle_expression(iter);
54    let Some(item) = iter.next() else { return };
55    if let TokenTree::Punct(punct) = item.borrow() {
56      if punct.as_char() == ',' {
57        return;
58      }
59    }
60  }
61}
62
63/// Derive an implementation of the `JsonDeserialize` trait.
64///
65/// This _requires_ the `struct` derived for implement `Default`. Fields which aren't present in
66/// the serialization will be left to their `Default` initialization. If you wish to detect if a
67/// field was omitted, please wrap it in `Option`.
68///
69/// As a procedural macro, this will panic causing a compile-time error on any unexpected input.
70#[proc_macro_derive(JsonDeserialize, attributes(rename))]
71pub fn derive_json_deserialize(object: TokenStream) -> TokenStream {
72  let generic_bounds;
73  let generics;
74  let object_name;
75  let mut largest_key = 0;
76  let mut all_fields = String::new();
77  {
78    let mut object = object.clone().into_iter().peekable();
79
80    loop {
81      match object.peek() {
82        Some(TokenTree::Punct(punct)) if punct.as_char() == '#' => {
83          let _ = object.next().expect("peeked but not present");
84          let TokenTree::Group(_) = object.next().expect("`#` but no `[ ... ]`") else {
85            panic!("`#` not followed by a `TokenTree::Group` for its `[ ... ]`")
86          };
87        }
88        _ => break,
89      }
90    }
91
92    match object.next() {
93      Some(TokenTree::Ident(ident)) if ident.to_string() == "struct" => {}
94      _ => panic!("`JsonDeserialize` wasn't applied to a `struct`"),
95    }
96    object_name = match object.next() {
97      Some(TokenTree::Ident(ident)) => ident.to_string(),
98      _ => panic!("`JsonDeserialize` wasn't applied to a `struct` with a name"),
99    };
100
101    let generic_bounds_tree = take_angle_expression(&mut object);
102
103    let mut generics_tree = vec![];
104    {
105      let mut iter = generic_bounds_tree.clone().into_iter().peekable();
106      while let Some(component) = iter.next() {
107        // Take until the next colon, used to mark trait bounds
108        if let TokenTree::Punct(punct) = &component {
109          if punct.as_char() == ':' {
110            // Skip the actual bounds
111            skip_comma_delimited(&mut iter);
112            // Add our own comma delimiter and move to the next item
113            generics_tree.push(TokenTree::Punct(Punct::new(',', Spacing::Alone)));
114            continue;
115          }
116        }
117        // Push this component as it isn't part of the bounds
118        generics_tree.push(component);
119      }
120    }
121    // Ensure this is terminated, which it won't be if the last item had bounds yet didn't have a
122    // trailing comma
123    if let Some(last) = generics_tree.last() {
124      match last {
125        TokenTree::Punct(punct) if punct.as_char() == '>' => {}
126        _ => generics_tree.push(TokenTree::Punct(Punct::new('>', Spacing::Alone))),
127      }
128    }
129
130    generic_bounds = generic_bounds_tree.to_string();
131    generics = TokenStream::from_iter(generics_tree).to_string();
132
133    // This presumably means we don't support `struct`'s defined with `where` bounds
134    let Some(TokenTree::Group(struct_body)) = object.next() else {
135      panic!("`struct`'s name was not followed by its body");
136    };
137    if struct_body.delimiter() != Delimiter::Brace {
138      panic!("`JsonDeserialize` derivation applied to `struct` with anonymous fields");
139    }
140    let mut struct_body = struct_body.stream().into_iter().peekable();
141    // Read each field within this `struct`'s body
142    while struct_body.peek().is_some() {
143      // Access the field name
144      let mut serialization_field_name = None;
145      let mut field_name = None;
146      for item in &mut struct_body {
147        // Hanlde the `rename` attribute
148        if let TokenTree::Group(group) = &item {
149          if group.delimiter() == Delimiter::Bracket {
150            let mut iter = group.stream().into_iter();
151            if iter.next().and_then(|ident| match ident {
152              TokenTree::Ident(ident) => Some(ident.to_string()),
153              _ => None,
154            }) == Some("rename".to_string())
155            {
156              let TokenTree::Group(group) =
157                iter.next().expect("`rename` attribute without arguments")
158              else {
159                panic!("`rename` attribute not followed with `(...)`")
160              };
161              assert_eq!(
162                group.delimiter(),
163                Delimiter::Parenthesis,
164                "`rename` attribute with a non-parentheses group"
165              );
166              assert_eq!(
167                group.stream().into_iter().count(),
168                1,
169                "`rename` attribute with multiple tokens within parentheses"
170              );
171              let TokenTree::Literal(literal) = group.stream().into_iter().next().unwrap() else {
172                panic!("`rename` attribute with a non-literal argument")
173              };
174              let literal = literal.to_string();
175              assert_eq!(literal.chars().next().unwrap(), '"', "literal wasn't a string literal");
176              assert_eq!(literal.chars().last().unwrap(), '"', "literal wasn't a string literal");
177              serialization_field_name =
178                Some(literal.trim_start_matches('"').trim_end_matches('"').to_string());
179            }
180          }
181        }
182
183        if let TokenTree::Ident(ident) = item {
184          let ident = ident.to_string();
185          // Skip the access modifier
186          if ident == "pub" {
187            continue;
188          }
189          field_name = Some(ident);
190          // Use the field's actual name within the serialization, if not renamed
191          serialization_field_name = serialization_field_name.or(field_name.clone());
192          break;
193        }
194      }
195      let field_name = field_name.expect("couldn't find the name of the field within the `struct`");
196      let serialization_field_name =
197        serialization_field_name.expect("`field_name` but no `serialization_field_name`?");
198      largest_key = largest_key.max(serialization_field_name.len());
199
200      let mut serialization_field_name_array = "&[".to_string();
201      for char in serialization_field_name.chars() {
202        serialization_field_name_array.push('\'');
203        serialization_field_name_array.push_str(&char.escape_unicode().to_string());
204        serialization_field_name_array.push('\'');
205        serialization_field_name_array.push(',');
206      }
207      serialization_field_name_array.push(']');
208
209      all_fields.push_str(&format!(
210        r#"
211        {serialization_field_name_array} => {{
212          result.{field_name} = core_json_traits::JsonDeserialize::deserialize(value)?
213        }},
214      "#
215      ));
216
217      // Advance to the next field
218      skip_comma_delimited(&mut struct_body);
219    }
220  }
221
222  TokenStream::from_str(&format!(
223    r#"
224    impl{generic_bounds} core_json_traits::JsonDeserialize for {object_name}{generics}
225      where Self: core::default::Default {{
226      fn deserialize<
227        'bytes,
228        'parent,
229        B: core_json_traits::BytesLike<'bytes>,
230        S: core_json_traits::Stack,
231      >(
232        value: core_json_traits::Value<'bytes, 'parent, B, S>,
233      ) -> Result<Self, core_json_traits::JsonError<'bytes, B, S>> {{
234        use core::default::Default;
235
236        let mut result = Self::default();
237        if {largest_key} == 0 {{
238          return Ok(result);
239        }}
240
241        let mut key_chars = ['\0'; {largest_key}];
242        let mut object = value.fields()?;
243        'serialized_field: while let Some(field) = object.next() {{
244          let (mut key, value) = field?;
245
246          let key = {{
247            let mut key_len = 0;
248            while let Some(key_char) = key.next() {{
249              key_chars[key_len] = match key_char {{
250                Ok(key_char) => key_char,
251                /*
252                  This occurs when the key specifies an invalid UTF codepoint, which is technically
253                  allowed by RFC 8259. While it means we can't interpret the key, it also means
254                  this isn't a field we're looking for.
255
256                  Continue to the next serialized field accordingly.
257                */
258                Err(core_json_traits::JsonError::InvalidValue) => continue 'serialized_field,
259                // Propagate all other errors.
260                Err(e) => Err(e)?,
261              }};
262              key_len += 1;
263              if key_len == {largest_key} {{
264                break;
265              }}
266            }}
267            match key.next() {{
268              None => {{}},
269              // This key is larger than our largest key
270              Some(Ok(_)) => continue,
271              Some(Err(e)) => Err(e)?,
272            }}
273            &key_chars[.. key_len]
274          }};
275
276          match key {{
277            {all_fields}
278            // Skip unknown fields
279            _ => {{}}
280          }}
281        }}
282
283        Ok(result)
284      }}
285    }}
286    impl{generic_bounds} core_json_traits::JsonStructure for {object_name}{generics}
287      where Self: core::default::Default {{}}
288    "#
289  ))
290  .expect("typo in implementation of `JsonDeserialize`")
291}