core_json_derive/lib.rs

#![cfg_attr(docsrs, feature(doc_cfg))]
#![doc = include_str!("../README.md")]
#![deny(missing_docs)]
#![no_std]

use core::{borrow::Borrow, str::FromStr, iter::Peekable};

extern crate alloc;
use alloc::{
  vec,
  vec::Vec,
  string::{String, ToString},
  format,
};

extern crate proc_macro;
use proc_macro::{Delimiter, Spacing, Punct, TokenTree, TokenStream};

// `<` will not open a group, so we use this to take all items within a `< ... >` expression.
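//
// For example, for the tokens `<T: Into<U>, U> rest`, this consumes and returns the tokens of
// `<T: Into<U>, U>` (tracking nesting via a depth counter over `<`/`>`), leaving the iterator at
// `rest`. If the next token isn't `<`, nothing is consumed and an empty stream is returned.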
fn take_angle_expression(
  iter: &mut Peekable<impl Iterator<Item: Borrow<TokenTree>>>,
) -> TokenStream {
  {
    let Some(peeked) = iter.peek() else { return TokenStream::default() };
    let TokenTree::Punct(punct) = peeked.borrow() else { return TokenStream::default() };
    if punct.as_char() != '<' {
      return TokenStream::default();
    }
  }

  let mut result = vec![];
  let mut count = 0;
  loop {
    let item = iter.next().expect("`TokenTree` unexpectedly terminated when taking `< ... >`");
    result.push(item.borrow().clone());
    if let TokenTree::Punct(punct) = item.borrow() {
      let punct = punct.as_char();
      if punct == '<' {
        count += 1;
      }
      if punct == '>' {
        count -= 1;
      }
      if count == 0 {
        break;
      }
    }
  }
  TokenStream::from_iter(result)
}

// Advance the iterator past the next `,` on this depth, if there is one.
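//
// For example, for the tokens `Vec<u8>, usize`, this consumes `Vec<u8>,` (the `< ... >` portion
// is taken as a unit, so any commas nested within it are not treated as delimiters) and leaves
// the iterator at `usize`.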
fn skip_comma_delimited(iter: &mut Peekable<impl Iterator<Item: Borrow<TokenTree>>>) {
  loop {
    take_angle_expression(iter);
    let Some(item) = iter.next() else { return };
    if let TokenTree::Punct(punct) = item.borrow() {
      if punct.as_char() == ',' {
        return;
      }
    }
  }
}

struct Struct {
  generic_bounds: String,
  generics: String,
  name: String,
  fields: Vec<(String, String)>,
}

// This is somewhat comparable to `syn::Generics`, especially its `split_for_impl` method.
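//
// For example, `struct Foo<T: Clone, U> { ... }` yields `generic_bounds` covering `<T: Clone, U>`
// (emitted after `impl`) and `generics` covering `<T, U>` (emitted after the type's name), so the
// generated code reads as `impl<T: Clone, U> Trait for Foo<T, U>`.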
fn parse_struct(object: TokenStream) -> Struct {
  let mut object = object.into_iter().peekable();

  loop {
    match object.peek() {
      Some(TokenTree::Punct(punct)) if punct.as_char() == '#' => {
        let _ = object.next().expect("peeked but not present");
        let TokenTree::Group(_) = object.next().expect("`#` but no `[ ... ]`") else {
          panic!("`#` not followed by a `TokenTree::Group` for its `[ ... ]`")
        };
      }
      _ => break,
    }
  }

  match object.next() {
    Some(TokenTree::Ident(ident)) if ident.to_string() == "struct" => {}
    _ => panic!("`JsonDeserialize` wasn't applied to a `struct`"),
  }
  let name = match object.next() {
    Some(TokenTree::Ident(ident)) => ident.to_string(),
    _ => panic!("`JsonDeserialize` wasn't applied to a `struct` with a name"),
  };

  let generic_bounds_tree = take_angle_expression(&mut object);

  let mut generics_tree = vec![];
  {
    let mut iter = generic_bounds_tree.clone().into_iter().peekable();
    while let Some(component) = iter.next() {
      // Take until the next colon, used to mark trait bounds
      if let TokenTree::Punct(punct) = &component {
        if punct.as_char() == ':' {
          // Skip the actual bounds
          skip_comma_delimited(&mut iter);
          // Add our own comma delimiter and move to the next item
          generics_tree.push(TokenTree::Punct(Punct::new(',', Spacing::Alone)));
          continue;
        }
      }
      // Push this component as it isn't part of the bounds
      generics_tree.push(component);
    }
  }
  // Ensure this is terminated, which it won't be if the last item had bounds yet didn't have a
  // trailing comma
  if let Some(last) = generics_tree.last() {
    match last {
      TokenTree::Punct(punct) if punct.as_char() == '>' => {}
      _ => generics_tree.push(TokenTree::Punct(Punct::new('>', Spacing::Alone))),
    }
  }

  let generic_bounds = generic_bounds_tree.to_string();
  let generics = TokenStream::from_iter(generics_tree).to_string();

  // This presumably means we don't support `struct`s defined with `where` bounds
  let Some(TokenTree::Group(struct_body)) = object.next() else {
    panic!("`struct`'s name was not followed by its body");
  };
  if struct_body.delimiter() != Delimiter::Brace {
    panic!("`JsonDeserialize` derivation applied to `struct` with anonymous fields");
  }

  let mut fields = vec![];

  let mut struct_body = struct_body.stream().into_iter().peekable();
  // Read each field within this `struct`'s body
  while struct_body.peek().is_some() {
    // Access the field name
    let mut serialization_field_name = None;
    let mut field_name = None;
    let mut skip = false;
    for item in &mut struct_body {
      // Handle the `key` attribute
      if let TokenTree::Group(group) = &item {
        if group.delimiter() == Delimiter::Bracket {
          let mut iter = group.stream().into_iter();
          let ident = iter.next().and_then(|ident| match ident {
            TokenTree::Ident(ident) => Some(ident.to_string()),
            _ => None,
          });
          match ident.as_deref() {
            Some("skip") => skip = true,
            Some("key") => {
              let TokenTree::Group(group) = iter.next().expect("`key` attribute without arguments")
              else {
                panic!("`key` attribute not followed with `(...)`")
              };
              assert_eq!(
                group.delimiter(),
                Delimiter::Parenthesis,
                "`key` attribute with a non-parentheses group"
              );
              assert_eq!(
                group.stream().into_iter().count(),
                1,
                "`key` attribute with multiple tokens within parentheses"
              );
              let TokenTree::Literal(literal) = group.stream().into_iter().next().unwrap() else {
                panic!("`key` attribute with a non-literal argument")
              };
              let literal = literal.to_string();
              assert_eq!(literal.chars().next().unwrap(), '"', "literal wasn't a string literal");
              assert_eq!(literal.chars().last().unwrap(), '"', "literal wasn't a string literal");
              serialization_field_name =
                Some(literal.trim_start_matches('"').trim_end_matches('"').to_string());
            }
            _ => {}
          }
        }
      }

      if let TokenTree::Ident(ident) = item {
        let ident = ident.to_string();
        // Skip the access modifier
        if ident == "pub" {
          continue;
        }
        field_name = Some(ident);
        // Use the field's actual name within the serialization, if not renamed
        serialization_field_name = serialization_field_name.or(field_name.clone());
        break;
      }
    }
    let field_name = field_name.expect("couldn't find the name of the field within the `struct`");
    let serialization_field_name =
      serialization_field_name.expect("`field_name` but no `serialization_field_name`?");

    if !skip {
      fields.push((field_name, serialization_field_name));
    }

    // Advance to the next field
    skip_comma_delimited(&mut struct_body);
  }

  Struct { generic_bounds, generics, name, fields }
}

/// Derive an implementation of the `JsonDeserialize` trait.
///
/// This _requires_ the `struct` it's derived for to implement `Default`. Fields which aren't
/// present in the serialization will be left at their `Default` initialization. If you wish to
/// detect whether a field was omitted, please wrap it in `Option`.
///
/// Fields may be deserialized from a distinct key using the `key` attribute, which accepts a
/// string literal for the key to deserialize from (`key("key")`). Fields may be omitted from
/// deserialization with the `skip` attribute.
///
/// As a procedural macro, this will panic on any unexpected input, causing a compile-time error.
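///
/// ```ignore
/// // A minimal sketch of usage. `Example` and its fields are illustrative; the field types must
/// // themselves implement `JsonDeserialize`.
/// #[derive(Default, JsonDeserialize)]
/// struct Example {
///   #[key("renamedField")]
///   renamed: u64,
///   #[skip]
///   not_deserialized: bool,
/// }
/// ```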
#[proc_macro_derive(JsonDeserialize, attributes(key, skip))]
pub fn derive_json_deserialize(object: TokenStream) -> TokenStream {
  let Struct { generic_bounds, generics, name, fields } = parse_struct(object);

  let mut largest_key = 0;
  let mut fields_deserialization = String::new();
  for (field_name, serialization_field_name) in &fields {
    largest_key = largest_key.max(serialization_field_name.len());

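    // Emit the expected key as an array of unicode-escaped `char` literals (e.g. `id` becomes
    // `&['\u{69}','\u{64}',]`), so the generated `match` can compare it against the `&[char]`
    // slice of key characters read from the serialization.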
    let mut serialization_field_name_array = "&[".to_string();
    for char in serialization_field_name.chars() {
      serialization_field_name_array.push('\'');
      serialization_field_name_array.push_str(&char.escape_unicode().to_string());
      serialization_field_name_array.push('\'');
      serialization_field_name_array.push(',');
    }
    serialization_field_name_array.push(']');

    fields_deserialization.push_str(&format!(
      r#"
      {serialization_field_name_array} => {{
        result.{field_name} = core_json_traits::JsonDeserialize::deserialize(value)?
      }},
      "#
    ));
  }

  TokenStream::from_str(&format!(
    r#"
    impl{generic_bounds} core_json_traits::JsonDeserialize for {name}{generics}
      where Self: core::default::Default {{
      fn deserialize<
        'bytes,
        'parent,
        B: core_json_traits::BytesLike<'bytes>,
        S: core_json_traits::Stack,
      >(
        value: core_json_traits::Value<'bytes, 'parent, B, S>,
      ) -> Result<Self, core_json_traits::JsonError<'bytes, B, S>> {{
        use core::default::Default;

        let mut result = Self::default();
        if {largest_key} == 0 {{
          return Ok(result);
        }}

        let mut key_chars = ['\0'; {largest_key}];
        let mut object = value.fields()?;
        'serialized_field: while let Some(field) = object.next() {{
          let (mut key, value) = field?;

          let key = {{
            let mut key_len = 0;
            while let Some(key_char) = key.next() {{
              key_chars[key_len] = match key_char {{
                Ok(key_char) => key_char,
                /*
                  This occurs when the key specifies an invalid UTF codepoint, which is technically
                  allowed by RFC 8259. While it means we can't interpret the key, it also means
                  this isn't a field we're looking for.

                  Continue to the next serialized field accordingly.
                */
                Err(core_json_traits::JsonError::InvalidValue) => continue 'serialized_field,
                // Propagate all other errors.
                Err(e) => Err(e)?,
              }};
              key_len += 1;
              if key_len == {largest_key} {{
                break;
              }}
            }}
            match key.next() {{
              None => {{}},
              // This key is larger than our largest key
              Some(Ok(_)) => continue,
              Some(Err(e)) => Err(e)?,
            }}
            &key_chars[.. key_len]
          }};

          match key {{
            {fields_deserialization}
            // Skip unknown fields
            _ => {{}}
          }}
        }}

        Ok(result)
      }}
    }}
    impl{generic_bounds} core_json_traits::JsonStructure for {name}{generics}
      where Self: core::default::Default {{}}
    "#
  ))
  .expect("typo in implementation of `JsonDeserialize`")
}

/// Derive an implementation of the `JsonSerialize` trait.
///
/// Fields may be serialized with a distinct name using the `key` attribute, which accepts a
/// string literal for the key to serialize as (`key("key")`). Fields may be omitted from
/// serialization with the `skip` attribute.
///
/// As a procedural macro, this will panic on any unexpected input, causing a compile-time error.
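///
/// ```ignore
/// // A minimal sketch of usage. `Example` and its field are illustrative; the field type must
/// // itself implement `JsonSerialize`.
/// use core_json_traits::JsonSerialize;
///
/// #[derive(JsonSerialize)]
/// struct Example {
///   #[key("renamedField")]
///   renamed: u64,
/// }
///
/// let json: String = Example { renamed: 5 }.serialize().collect();
/// ```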
#[proc_macro_derive(JsonSerialize, attributes(key, skip))]
pub fn derive_json_serialize(object: TokenStream) -> TokenStream {
  let Struct { generic_bounds, generics, name, fields } = parse_struct(object);

  let mut fields_serialization = String::new();
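  // Each field contributes `"key":value` to the generated iterator by chaining the serialized
  // key, a `:`, and the field's own `JsonSerialize::serialize` output, with a `,` chained after
  // every field except the last.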
  for (i, (field_name, serialization_field_name)) in fields.iter().enumerate() {
    let comma = if (i + 1) == fields.len() { "" } else { r#".chain(core::iter::once(','))"# };

    fields_serialization.push_str(&format!(
      r#"
      .chain("{serialization_field_name}".serialize())
      .chain(core::iter::once(':'))
      .chain(core_json_traits::JsonSerialize::serialize(&self.{field_name}))
      {comma}
      "#
    ));
  }

  TokenStream::from_str(&format!(
    r#"
    impl{generic_bounds} core_json_traits::JsonSerialize for {name}{generics} {{
      fn serialize(&self) -> impl Iterator<Item = char> {{
        core::iter::once('{{')
        {fields_serialization}
        .chain(core::iter::once('}}'))
      }}
    }}
    "#
  ))
  .expect("typo in implementation of `JsonSerialize`")
}