1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![doc = include_str!("../README.md")]
3#![deny(missing_docs)]
4#![no_std]
5
6use core::{borrow::Borrow, str::FromStr, iter::Peekable};
7
8extern crate alloc;
9use alloc::{
10 vec,
11 vec::Vec,
12 string::{String, ToString},
13 format,
14};
15
16extern crate proc_macro;
17use proc_macro::{Delimiter, Spacing, Punct, TokenTree, TokenStream};
18
/// Consumes a `< ... >` angle-bracketed expression from `iter`, if one starts there.
///
/// Returns an empty `TokenStream` when the next token is not `<`, without
/// advancing the iterator. Nested angle brackets are balanced, and the `>` of
/// a `->` arrow (as appears in bounds such as `Fn() -> R`) is not treated as a
/// closing bracket.
///
/// Panics if the iterator ends before the opening `<` is balanced.
fn take_angle_expression(
  iter: &mut Peekable<impl Iterator<Item: Borrow<TokenTree>>>,
) -> TokenStream {
  {
    let Some(peeked) = iter.peek() else { return TokenStream::default() };
    let TokenTree::Punct(punct) = peeked.borrow() else { return TokenStream::default() };
    if punct.as_char() != '<' {
      return TokenStream::default();
    }
  }

  let mut result = vec![];
  let mut depth = 0usize;
  // Whether the previous token was a `-` joined to the following punct, meaning
  // a `>` seen now is the tail of a `->` arrow, not a closing angle bracket.
  let mut prev_joint_minus = false;
  loop {
    let item = iter.next().expect("`TokenTree` unexpectedly terminated when taking `< ... >`");
    result.push(item.borrow().clone());
    let mut joint_minus = false;
    if let TokenTree::Punct(punct) = item.borrow() {
      let c = punct.as_char();
      joint_minus = (c == '-') && (punct.spacing() == Spacing::Joint);
      if c == '<' {
        depth += 1;
      }
      // Ignore the `>` within `->` so arrows inside bounds don't unbalance us.
      if (c == '>') && !prev_joint_minus {
        depth -= 1;
      }
      if depth == 0 {
        break;
      }
    }
    prev_joint_minus = joint_minus;
  }
  TokenStream::from_iter(result)
}
51
52fn skip_comma_delimited(iter: &mut Peekable<impl Iterator<Item: Borrow<TokenTree>>>) {
54 loop {
55 take_angle_expression(iter);
56 let Some(item) = iter.next() else { return };
57 if let TokenTree::Punct(punct) = item.borrow() {
58 if punct.as_char() == ',' {
59 return;
60 }
61 }
62 }
63}
64
/// The components of a parsed `struct` definition, as needed to derive the
/// JSON traits for it.
struct Struct {
  /// The generic parameters with their bounds, e.g. `<T: Clone>`.
  generic_bounds: String,
  /// The generic parameters without their bounds, e.g. `<T>`.
  generics: String,
  /// The name of the `struct`.
  name: String,
  /// Each (non-skipped) field's name, paired with the key used for it within
  /// the JSON object (the field name itself unless overridden via `#[key(...)]`).
  fields: Vec<(String, String)>,
}
71
/// Parses the `TokenStream` of a `struct` definition into its components.
///
/// Accepts leading outer attributes (`#[...]`), an optional visibility
/// modifier (`pub`, optionally restricted as in `pub(crate)`), the `struct`
/// keyword, the name, any generics, and a brace-delimited body of named
/// fields.
///
/// Recognized field attributes:
/// - `#[skip]` omits the field from the returned `fields`.
/// - `#[key("...")]` overrides the JSON key used for the field.
///
/// Panics when the tokens do not form a named `struct` with named fields, or
/// when a recognized attribute is malformed.
fn parse_struct(object: TokenStream) -> Struct {
  let mut object = object.into_iter().peekable();

  // Skip any outer attributes: a `#` punct followed by its `[ ... ]` group.
  loop {
    match object.peek() {
      Some(TokenTree::Punct(punct)) if punct.as_char() == '#' => {
        let _ = object.next().expect("peeked but not present");
        let TokenTree::Group(_) = object.next().expect("`#` but no `[ ... ]`") else {
          panic!("`#` not followed by a `TokenTree::Group` for its `[ ... ]`")
        };
      }
      _ => break,
    }
  }

  // Consume an optional visibility modifier, then require the `struct` keyword.
  {
    let mut vis = false;
    loop {
      match object.next() {
        Some(TokenTree::Ident(ident)) if ident.to_string() == "pub" => {
          if vis {
            panic!("multiple visibility modifers found");
          }
          vis = true;
        }
        // The `( ... )` of a restricted visibility such as `pub(crate)`.
        Some(TokenTree::Group(group)) if vis && (group.delimiter() == Delimiter::Parenthesis) => {}
        Some(TokenTree::Ident(ident)) if ident.to_string() == "struct" => break,
        _ => panic!("`JsonDeserialize` wasn't applied to a `struct`"),
      }
    }
  }
  let name = match object.next() {
    Some(TokenTree::Ident(ident)) => ident.to_string(),
    _ => panic!("`JsonDeserialize` wasn't applied to a `struct` with a name"),
  };

  // The generics with their bounds, e.g. `<T: Clone, U>` (empty if non-generic).
  let generic_bounds_tree = take_angle_expression(&mut object);

  // Build the bound-less generics (e.g. `<T, U>`) by dropping everything from
  // each `:` through the next top-level `,` (or the end of the list).
  let mut generics_tree = vec![];
  {
    let mut iter = generic_bounds_tree.clone().into_iter().peekable();
    while let Some(component) = iter.next() {
      if let TokenTree::Punct(punct) = &component {
        if punct.as_char() == ':' {
          skip_comma_delimited(&mut iter);
          // Re-insert the `,` that `skip_comma_delimited` consumed.
          generics_tree.push(TokenTree::Punct(Punct::new(',', Spacing::Alone)));
          continue;
        }
      }
      generics_tree.push(component);
    }
  }
  // If the last bound ran to the end of the list, `skip_comma_delimited` will
  // have consumed the closing `>` as well; restore it.
  if let Some(last) = generics_tree.last() {
    match last {
      TokenTree::Punct(punct) if punct.as_char() == '>' => {}
      _ => generics_tree.push(TokenTree::Punct(Punct::new('>', Spacing::Alone))),
    }
  }

  let generic_bounds = generic_bounds_tree.to_string();
  let generics = TokenStream::from_iter(generics_tree).to_string();

  let Some(TokenTree::Group(struct_body)) = object.next() else {
    panic!("`struct`'s name was not followed by its body");
  };
  if struct_body.delimiter() != Delimiter::Brace {
    panic!("`JsonDeserialize` derivation applied to `struct` with anonymous fields");
  }

  let mut fields = vec![];

  // Walk each field: read its attributes, find its name, then skip past its type.
  let mut struct_body = struct_body.stream().into_iter().peekable();
  while struct_body.peek().is_some() {
    // The JSON key for this field, if overridden via `#[key("...")]`.
    let mut serialization_field_name = None;
    let mut field_name = None;
    let mut skip = false;
    for item in &mut struct_body {
      // A `[ ... ]` group before the field name is a field attribute's body
      // (the preceding `#` punct falls through both `if let`s below).
      if let TokenTree::Group(group) = &item {
        if group.delimiter() == Delimiter::Bracket {
          let mut iter = group.stream().into_iter();
          let ident = iter.next().and_then(|ident| match ident {
            TokenTree::Ident(ident) => Some(ident.to_string()),
            _ => None,
          });
          match ident.as_deref() {
            Some("skip") => skip = true,
            Some("key") => {
              let TokenTree::Group(group) = iter.next().expect("`key` attribute without arguments")
              else {
                panic!("`key` attribute not followed with `(...)`")
              };
              assert_eq!(
                group.delimiter(),
                Delimiter::Parenthesis,
                "`key` attribute with a non-parentheses group"
              );
              assert_eq!(
                group.stream().into_iter().count(),
                1,
                "`key` attribute with multiple tokens within parentheses"
              );
              let TokenTree::Literal(literal) = group.stream().into_iter().next().unwrap() else {
                panic!("`key` attribute with a non-literal argument")
              };
              // Validate it's a string literal, then strip the surrounding quotes.
              let literal = literal.to_string();
              assert_eq!(literal.chars().next().unwrap(), '"', "literal wasn't a string literal");
              assert_eq!(literal.chars().last().unwrap(), '"', "literal wasn't a string literal");
              serialization_field_name =
                Some(literal.trim_start_matches('"').trim_end_matches('"').to_string());
            }
            // Unrecognized attributes (from other derives) are ignored.
            _ => {}
          }
        }
      }

      if let TokenTree::Ident(ident) = item {
        let ident = ident.to_string();
        if ident == "pub" {
          continue;
        }
        // The first non-`pub` ident is the field's name; default the JSON key
        // to it when no `#[key(...)]` override was seen.
        field_name = Some(ident);
        serialization_field_name = serialization_field_name.or(field_name.clone());
        break;
      }
    }
    let field_name = field_name.expect("couldn't find the name of the field within the `struct`");
    let serialization_field_name =
      serialization_field_name.expect("`field_name` but no `serialization_field_name`?");

    if !skip {
      fields.push((field_name, serialization_field_name));
    }

    // Skip the field's `: Type,` (handling commas nested within generics).
    skip_comma_delimited(&mut struct_body);
  }

  Struct { generic_bounds, generics, name, fields }
}
226
/// Derives `core_json_traits::JsonDeserialize` (and `JsonStructure`) for a
/// `struct` with named fields.
///
/// The generated implementation starts from `Self::default()` and assigns each
/// recognized JSON key's value into the corresponding field, skipping unknown
/// keys. Supports `#[key("...")]` to override a field's JSON key and `#[skip]`
/// to leave a field at its default.
#[proc_macro_derive(JsonDeserialize, attributes(key, skip))]
pub fn derive_json_deserialize(object: TokenStream) -> TokenStream {
  let Struct { generic_bounds, generics, name, fields } = parse_struct(object);

  // The length of the longest JSON key, used to size the buffer keys are read into.
  let mut largest_key = 0;
  let mut fields_deserialization = String::new();
  for (field_name, serialization_field_name) in &fields {
    largest_key = largest_key.max(serialization_field_name.len());

    // Render the key as a `char`-array pattern (e.g. `&['\u{6b}','\u{65}','\u{79}',]`)
    // so it can be matched against the `&[char]` slice the key is read into.
    // `escape_unicode` keeps the emitted literals valid for any character.
    let mut serialization_field_name_array = "&[".to_string();
    for char in serialization_field_name.chars() {
      serialization_field_name_array.push('\'');
      serialization_field_name_array.push_str(&char.escape_unicode().to_string());
      serialization_field_name_array.push('\'');
      serialization_field_name_array.push(',');
    }
    serialization_field_name_array.push(']');

    // One match arm per field, deserializing the field's value in place.
    fields_deserialization.push_str(&format!(
      r#"
      {serialization_field_name_array} => {{
        result.{field_name} = core_json_traits::JsonDeserialize::deserialize(value)?
      }},
      "#
    ));
  }

  TokenStream::from_str(&format!(
    r#"
    impl{generic_bounds} core_json_traits::JsonDeserialize for {name}{generics}
      where Self: core::default::Default {{
      fn deserialize<
        'read,
        'parent,
        B: core_json_traits::Read<'read>,
        S: core_json_traits::Stack,
      >(
        value: core_json_traits::Value<'read, 'parent, B, S>,
      ) -> Result<Self, core_json_traits::JsonError<'read, B, S>> {{
        use core::default::Default;

        let mut result = Self::default();
        if {largest_key} == 0 {{
          return Ok(result);
        }}

        let mut key_chars = ['\0'; {largest_key}];
        let mut object = value.fields()?;
        'serialized_field: while let Some(field) = object.next() {{
          let mut field = field?;

          let key = {{
            let key = field.key();
            let mut key_len = 0;
            while let Some(key_char) = key.next() {{
              key_chars[key_len] = match key_char {{
                Ok(key_char) => key_char,
                /*
                  This occurs when the key specifies an invalid UTF codepoint, which is technically
                  allowed by RFC 8259. While it means we can't interpret the key, it also means
                  this isn't a field we're looking for.

                  Continue to the next serialized field accordingly.
                */
                Err(core_json_traits::JsonError::InvalidValue) => continue 'serialized_field,
                // Propagate all other errors.
                Err(e) => Err(e)?,
              }};
              key_len += 1;
              if key_len == {largest_key} {{
                break;
              }}
            }}
            match key.next() {{
              None => {{}},
              // This key is larger than our largest key
              Some(Ok(_)) => continue,
              Some(Err(e)) => Err(e)?,
            }}
            &key_chars[.. key_len]
          }};
          let value = field.value();

          match key {{
            {fields_deserialization}
            // Skip unknown fields
            _ => {{}}
          }}
        }}

        Ok(result)
      }}
    }}
    impl{generic_bounds} core_json_traits::JsonStructure for {name}{generics}
      where Self: core::default::Default {{}}
    "#
  ))
  .expect("typo in implementation of `JsonDeserialize`")
}
337
338#[proc_macro_derive(JsonSerialize, attributes(key, skip))]
346pub fn derive_json_serialize(object: TokenStream) -> TokenStream {
347 let Struct { generic_bounds, generics, name, fields } = parse_struct(object);
348
349 let mut fields_serialization = String::new();
350 for (i, (field_name, serialization_field_name)) in fields.iter().enumerate() {
351 let comma = if (i + 1) == fields.len() { "" } else { r#".chain(core::iter::once(','))"# };
352
353 fields_serialization.push_str(&format!(
354 r#"
355 .chain("{serialization_field_name}".serialize())
356 .chain(core::iter::once(':'))
357 .chain(core_json_traits::JsonSerialize::serialize(&self.{field_name}))
358 {comma}
359 "#
360 ));
361 }
362
363 TokenStream::from_str(&format!(
364 r#"
365 impl{generic_bounds} core_json_traits::JsonSerialize for {name}{generics} {{
366 fn serialize(&self) -> impl Iterator<Item = char> {{
367 core::iter::once('{{')
368 {fields_serialization}
369 .chain(core::iter::once('}}'))
370 }}
371 }}
372 "#
373 ))
374 .expect("typo in implementation of `JsonSerialize`")
375}