1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![doc = include_str!("../README.md")]
3#![deny(missing_docs)]
4#![no_std]
5
6use core::{borrow::Borrow, str::FromStr, iter::Peekable};
7
8extern crate alloc;
9use alloc::{
10 string::{String, ToString},
11 vec, format,
12};
13
14extern crate proc_macro;
15use proc_macro::{Delimiter, Spacing, Punct, TokenTree, TokenStream};
16
/// Take a `< ... >` expression from the front of `iter`, if one is present.
///
/// If the next token isn't a `<`, nothing is consumed and an empty `TokenStream` is returned.
/// Otherwise, every token up to and including the `>` matching that `<` is consumed and returned.
/// Nested `< ... >` are balanced, and the `>` of a `->` (as in `Fn() -> T` or `fn() -> T`) is not
/// treated as closing an angle bracket.
///
/// # Panics
///
/// Panics if `iter` ends before the opening `<` is closed.
fn take_angle_expression(
  iter: &mut Peekable<impl Iterator<Item: Borrow<TokenTree>>>,
) -> TokenStream {
  {
    let Some(peeked) = iter.peek() else { return TokenStream::default() };
    let TokenTree::Punct(punct) = peeked.borrow() else { return TokenStream::default() };
    if punct.as_char() != '<' {
      return TokenStream::default();
    }
  }

  let mut result = vec![];
  // How many `<` are currently unclosed
  let mut depth = 0usize;
  // Whether the prior token was a `-` joined to the current one, making a `>` half of a `->`
  let mut after_joint_minus = false;
  loop {
    let item = iter.next().expect("`TokenTree` unexpectedly terminated when taking `< ... >`");
    let token: &TokenTree = item.borrow();
    result.push(token.clone());

    let TokenTree::Punct(punct) = token else {
      after_joint_minus = false;
      continue;
    };
    match punct.as_char() {
      '<' => depth += 1,
      // `->` is lexed as a `-` (with `Spacing::Joint`) followed by a `>`, which closes nothing
      '>' if !after_joint_minus => depth -= 1,
      _ => {}
    }
    after_joint_minus = (punct.as_char() == '-') && matches!(punct.spacing(), Spacing::Joint);

    // The first token was the opening `<`, so this is only reached once it has been matched
    if depth == 0 {
      break;
    }
  }
  TokenStream::from_iter(result)
}
49
50fn skip_comma_delimited(iter: &mut Peekable<impl Iterator<Item: Borrow<TokenTree>>>) {
52 loop {
53 take_angle_expression(iter);
54 let Some(item) = iter.next() else { return };
55 if let TokenTree::Punct(punct) = item.borrow() {
56 if punct.as_char() == ',' {
57 return;
58 }
59 }
60 }
61}
62
63#[proc_macro_derive(JsonDeserialize, attributes(rename))]
71pub fn derive_json_deserialize(object: TokenStream) -> TokenStream {
72 let generic_bounds;
73 let generics;
74 let object_name;
75 let mut largest_key = 0;
76 let mut all_fields = String::new();
77 {
78 let mut object = object.clone().into_iter().peekable();
79
80 loop {
81 match object.peek() {
82 Some(TokenTree::Punct(punct)) if punct.as_char() == '#' => {
83 let _ = object.next().expect("peeked but not present");
84 let TokenTree::Group(_) = object.next().expect("`#` but no `[ ... ]`") else {
85 panic!("`#` not followed by a `TokenTree::Group` for its `[ ... ]`")
86 };
87 }
88 _ => break,
89 }
90 }
91
92 match object.next() {
93 Some(TokenTree::Ident(ident)) if ident.to_string() == "struct" => {}
94 _ => panic!("`JsonDeserialize` wasn't applied to a `struct`"),
95 }
96 object_name = match object.next() {
97 Some(TokenTree::Ident(ident)) => ident.to_string(),
98 _ => panic!("`JsonDeserialize` wasn't applied to a `struct` with a name"),
99 };
100
101 let generic_bounds_tree = take_angle_expression(&mut object);
102
103 let mut generics_tree = vec![];
104 {
105 let mut iter = generic_bounds_tree.clone().into_iter().peekable();
106 while let Some(component) = iter.next() {
107 if let TokenTree::Punct(punct) = &component {
109 if punct.as_char() == ':' {
110 skip_comma_delimited(&mut iter);
112 generics_tree.push(TokenTree::Punct(Punct::new(',', Spacing::Alone)));
114 continue;
115 }
116 }
117 generics_tree.push(component);
119 }
120 }
121 if let Some(last) = generics_tree.last() {
124 match last {
125 TokenTree::Punct(punct) if punct.as_char() == '>' => {}
126 _ => generics_tree.push(TokenTree::Punct(Punct::new('>', Spacing::Alone))),
127 }
128 }
129
130 generic_bounds = generic_bounds_tree.to_string();
131 generics = TokenStream::from_iter(generics_tree).to_string();
132
133 let Some(TokenTree::Group(struct_body)) = object.next() else {
135 panic!("`struct`'s name was not followed by its body");
136 };
137 if struct_body.delimiter() != Delimiter::Brace {
138 panic!("`JsonDeserialize` derivation applied to `struct` with anonymous fields");
139 }
140 let mut struct_body = struct_body.stream().into_iter().peekable();
141 while struct_body.peek().is_some() {
143 let mut serialization_field_name = None;
145 let mut field_name = None;
146 for item in &mut struct_body {
147 if let TokenTree::Group(group) = &item {
149 if group.delimiter() == Delimiter::Bracket {
150 let mut iter = group.stream().into_iter();
151 if iter.next().and_then(|ident| match ident {
152 TokenTree::Ident(ident) => Some(ident.to_string()),
153 _ => None,
154 }) == Some("rename".to_string())
155 {
156 let TokenTree::Group(group) =
157 iter.next().expect("`rename` attribute without arguments")
158 else {
159 panic!("`rename` attribute not followed with `(...)`")
160 };
161 assert_eq!(
162 group.delimiter(),
163 Delimiter::Parenthesis,
164 "`rename` attribute with a non-parentheses group"
165 );
166 assert_eq!(
167 group.stream().into_iter().count(),
168 1,
169 "`rename` attribute with multiple tokens within parentheses"
170 );
171 let TokenTree::Literal(literal) = group.stream().into_iter().next().unwrap() else {
172 panic!("`rename` attribute with a non-literal argument")
173 };
174 let literal = literal.to_string();
175 assert_eq!(literal.chars().next().unwrap(), '"', "literal wasn't a string literal");
176 assert_eq!(literal.chars().last().unwrap(), '"', "literal wasn't a string literal");
177 serialization_field_name =
178 Some(literal.trim_start_matches('"').trim_end_matches('"').to_string());
179 }
180 }
181 }
182
183 if let TokenTree::Ident(ident) = item {
184 let ident = ident.to_string();
185 if ident == "pub" {
187 continue;
188 }
189 field_name = Some(ident);
190 serialization_field_name = serialization_field_name.or(field_name.clone());
192 break;
193 }
194 }
195 let field_name = field_name.expect("couldn't find the name of the field within the `struct`");
196 let serialization_field_name =
197 serialization_field_name.expect("`field_name` but no `serialization_field_name`?");
198 largest_key = largest_key.max(serialization_field_name.len());
199
200 let mut serialization_field_name_array = "&[".to_string();
201 for char in serialization_field_name.chars() {
202 serialization_field_name_array.push('\'');
203 serialization_field_name_array.push_str(&char.escape_unicode().to_string());
204 serialization_field_name_array.push('\'');
205 serialization_field_name_array.push(',');
206 }
207 serialization_field_name_array.push(']');
208
209 all_fields.push_str(&format!(
210 r#"
211 {serialization_field_name_array} => {{
212 result.{field_name} = core_json_traits::JsonDeserialize::deserialize(value)?
213 }},
214 "#
215 ));
216
217 skip_comma_delimited(&mut struct_body);
219 }
220 }
221
222 TokenStream::from_str(&format!(
223 r#"
224 impl{generic_bounds} core_json_traits::JsonDeserialize for {object_name}{generics}
225 where Self: core::default::Default {{
226 fn deserialize<
227 'bytes,
228 'parent,
229 B: core_json_traits::BytesLike<'bytes>,
230 S: core_json_traits::Stack,
231 >(
232 value: core_json_traits::Value<'bytes, 'parent, B, S>,
233 ) -> Result<Self, core_json_traits::JsonError<'bytes, B, S>> {{
234 use core::default::Default;
235
236 let mut result = Self::default();
237 if {largest_key} == 0 {{
238 return Ok(result);
239 }}
240
241 let mut key_chars = ['\0'; {largest_key}];
242 let mut object = value.fields()?;
243 'serialized_field: while let Some(field) = object.next() {{
244 let (mut key, value) = field?;
245
246 let key = {{
247 let mut key_len = 0;
248 while let Some(key_char) = key.next() {{
249 key_chars[key_len] = match key_char {{
250 Ok(key_char) => key_char,
251 /*
252 This occurs when the key specifies an invalid UTF codepoint, which is technically
253 allowed by RFC 8259. While it means we can't interpret the key, it also means
254 this isn't a field we're looking for.
255
256 Continue to the next serialized field accordingly.
257 */
258 Err(core_json_traits::JsonError::InvalidValue) => continue 'serialized_field,
259 // Propagate all other errors.
260 Err(e) => Err(e)?,
261 }};
262 key_len += 1;
263 if key_len == {largest_key} {{
264 break;
265 }}
266 }}
267 match key.next() {{
268 None => {{}},
269 // This key is larger than our largest key
270 Some(Ok(_)) => continue,
271 Some(Err(e)) => Err(e)?,
272 }}
273 &key_chars[.. key_len]
274 }};
275
276 match key {{
277 {all_fields}
278 // Skip unknown fields
279 _ => {{}}
280 }}
281 }}
282
283 Ok(result)
284 }}
285 }}
286 impl{generic_bounds} core_json_traits::JsonStructure for {object_name}{generics}
287 where Self: core::default::Default {{}}
288 "#
289 ))
290 .expect("typo in implementation of `JsonDeserialize`")
291}