1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![doc = include_str!("../README.md")]
3#![deny(missing_docs)]
4#![no_std]
5
6use core::{borrow::Borrow, str::FromStr, iter::Peekable};
7
8extern crate alloc;
9use alloc::{
10 vec,
11 vec::Vec,
12 string::{String, ToString},
13 format,
14};
15
16extern crate proc_macro;
17use proc_macro::{Delimiter, Spacing, Punct, TokenTree, TokenStream};
18
/// Take a `< ... >` expression, such as a list of generics, from the front of `iter`.
///
/// If the next token isn't `<`, nothing is consumed and an empty `TokenStream` is returned.
/// Otherwise, tokens are consumed through the `>` matching the opening `<` (accounting for nested
/// `< ... >`) and returned, including the outer `<` and `>`. The `>` of an `->` (as in
/// `Fn(u8) -> u8` or `fn() -> u8`) does not close a `< ... >`.
///
/// # Panics
///
/// Panics if `iter` ends before the opening `<` is closed.
fn take_angle_expression(
  iter: &mut Peekable<impl Iterator<Item: Borrow<TokenTree>>>,
) -> TokenStream {
  {
    let Some(peeked) = iter.peek() else { return TokenStream::default() };
    let TokenTree::Punct(punct) = peeked.borrow() else { return TokenStream::default() };
    if punct.as_char() != '<' {
      return TokenStream::default();
    }
  }

  let mut result = vec![];
  // The amount of `<` which have yet to be closed
  let mut depth = 0usize;
  // If the prior token was a `-` joined to the token after it, making a following `>` part of `->`
  let mut prior_was_joint_minus = false;
  loop {
    let item = iter.next().expect("`TokenTree` unexpectedly terminated when taking `< ... >`");
    let tree: &TokenTree = item.borrow();
    result.push(tree.clone());

    let TokenTree::Punct(punct) = tree else {
      prior_was_joint_minus = false;
      continue;
    };
    let is_arrow = prior_was_joint_minus;
    prior_was_joint_minus = (punct.as_char() == '-') && (punct.spacing() == Spacing::Joint);
    match punct.as_char() {
      '<' => depth += 1,
      // The `>` of `->` isn't a closing bracket
      '>' if !is_arrow => depth -= 1,
      _ => {}
    }
    // The first token was `<`, so this is only reached once that `<` has been closed
    if depth == 0 {
      break;
    }
  }
  TokenStream::from_iter(result)
}
51
52fn skip_comma_delimited(iter: &mut Peekable<impl Iterator<Item: Borrow<TokenTree>>>) {
54 loop {
55 take_angle_expression(iter);
56 let Some(item) = iter.next() else { return };
57 if let TokenTree::Punct(punct) = item.borrow() {
58 if punct.as_char() == ',' {
59 return;
60 }
61 }
62 }
63}
64
/// A `struct` definition, as parsed from a derive macro's input.
struct Struct {
  /// The `struct`'s name.
  name: String,
  /// The `struct`'s `< ... >` generic parameters, as written after `impl` (empty if none).
  generic_bounds: String,
  /// The `struct`'s generic arguments, as written after its name to refer to it (empty if none).
  generics: String,
  /// The fields to (de)serialize, as `(field name, JSON key)` pairs.
  fields: Vec<(String, String)>,
}
71
72fn parse_struct(object: TokenStream) -> Struct {
74 let mut object = object.into_iter().peekable();
75
76 loop {
77 match object.peek() {
78 Some(TokenTree::Punct(punct)) if punct.as_char() == '#' => {
79 let _ = object.next().expect("peeked but not present");
80 let TokenTree::Group(_) = object.next().expect("`#` but no `[ ... ]`") else {
81 panic!("`#` not followed by a `TokenTree::Group` for its `[ ... ]`")
82 };
83 }
84 _ => break,
85 }
86 }
87
88 match object.next() {
89 Some(TokenTree::Ident(ident)) if ident.to_string() == "struct" => {}
90 _ => panic!("`JsonDeserialize` wasn't applied to a `struct`"),
91 }
92 let name = match object.next() {
93 Some(TokenTree::Ident(ident)) => ident.to_string(),
94 _ => panic!("`JsonDeserialize` wasn't applied to a `struct` with a name"),
95 };
96
97 let generic_bounds_tree = take_angle_expression(&mut object);
98
99 let mut generics_tree = vec![];
100 {
101 let mut iter = generic_bounds_tree.clone().into_iter().peekable();
102 while let Some(component) = iter.next() {
103 if let TokenTree::Punct(punct) = &component {
105 if punct.as_char() == ':' {
106 skip_comma_delimited(&mut iter);
108 generics_tree.push(TokenTree::Punct(Punct::new(',', Spacing::Alone)));
110 continue;
111 }
112 }
113 generics_tree.push(component);
115 }
116 }
117 if let Some(last) = generics_tree.last() {
120 match last {
121 TokenTree::Punct(punct) if punct.as_char() == '>' => {}
122 _ => generics_tree.push(TokenTree::Punct(Punct::new('>', Spacing::Alone))),
123 }
124 }
125
126 let generic_bounds = generic_bounds_tree.to_string();
127 let generics = TokenStream::from_iter(generics_tree).to_string();
128
129 let Some(TokenTree::Group(struct_body)) = object.next() else {
131 panic!("`struct`'s name was not followed by its body");
132 };
133 if struct_body.delimiter() != Delimiter::Brace {
134 panic!("`JsonDeserialize` derivation applied to `struct` with anonymous fields");
135 }
136
137 let mut fields = vec![];
138
139 let mut struct_body = struct_body.stream().into_iter().peekable();
140 while struct_body.peek().is_some() {
142 let mut serialization_field_name = None;
144 let mut field_name = None;
145 let mut skip = false;
146 for item in &mut struct_body {
147 if let TokenTree::Group(group) = &item {
149 if group.delimiter() == Delimiter::Bracket {
150 let mut iter = group.stream().into_iter();
151 let ident = iter.next().and_then(|ident| match ident {
152 TokenTree::Ident(ident) => Some(ident.to_string()),
153 _ => None,
154 });
155 match ident.as_deref() {
156 Some("skip") => skip = true,
157 Some("key") => {
158 let TokenTree::Group(group) = iter.next().expect("`key` attribute without arguments")
159 else {
160 panic!("`key` attribute not followed with `(...)`")
161 };
162 assert_eq!(
163 group.delimiter(),
164 Delimiter::Parenthesis,
165 "`key` attribute with a non-parentheses group"
166 );
167 assert_eq!(
168 group.stream().into_iter().count(),
169 1,
170 "`key` attribute with multiple tokens within parentheses"
171 );
172 let TokenTree::Literal(literal) = group.stream().into_iter().next().unwrap() else {
173 panic!("`key` attribute with a non-literal argument")
174 };
175 let literal = literal.to_string();
176 assert_eq!(literal.chars().next().unwrap(), '"', "literal wasn't a string literal");
177 assert_eq!(literal.chars().last().unwrap(), '"', "literal wasn't a string literal");
178 serialization_field_name =
179 Some(literal.trim_start_matches('"').trim_end_matches('"').to_string());
180 }
181 _ => {}
182 }
183 }
184 }
185
186 if let TokenTree::Ident(ident) = item {
187 let ident = ident.to_string();
188 if ident == "pub" {
190 continue;
191 }
192 field_name = Some(ident);
193 serialization_field_name = serialization_field_name.or(field_name.clone());
195 break;
196 }
197 }
198 let field_name = field_name.expect("couldn't find the name of the field within the `struct`");
199 let serialization_field_name =
200 serialization_field_name.expect("`field_name` but no `serialization_field_name`?");
201
202 if !skip {
203 fields.push((field_name, serialization_field_name));
204 }
205
206 skip_comma_delimited(&mut struct_body);
208 }
209
210 Struct { generic_bounds, generics, name, fields }
211}
212
213#[proc_macro_derive(JsonDeserialize, attributes(key, skip))]
225pub fn derive_json_deserialize(object: TokenStream) -> TokenStream {
226 let Struct { generic_bounds, generics, name, fields } = parse_struct(object);
227
228 let mut largest_key = 0;
229 let mut fields_deserialization = String::new();
230 for (field_name, serialization_field_name) in &fields {
231 largest_key = largest_key.max(serialization_field_name.len());
232
233 let mut serialization_field_name_array = "&[".to_string();
234 for char in serialization_field_name.chars() {
235 serialization_field_name_array.push('\'');
236 serialization_field_name_array.push_str(&char.escape_unicode().to_string());
237 serialization_field_name_array.push('\'');
238 serialization_field_name_array.push(',');
239 }
240 serialization_field_name_array.push(']');
241
242 fields_deserialization.push_str(&format!(
243 r#"
244 {serialization_field_name_array} => {{
245 result.{field_name} = core_json_traits::JsonDeserialize::deserialize(value)?
246 }},
247 "#
248 ));
249 }
250
251 TokenStream::from_str(&format!(
252 r#"
253 impl{generic_bounds} core_json_traits::JsonDeserialize for {name}{generics}
254 where Self: core::default::Default {{
255 fn deserialize<
256 'bytes,
257 'parent,
258 B: core_json_traits::BytesLike<'bytes>,
259 S: core_json_traits::Stack,
260 >(
261 value: core_json_traits::Value<'bytes, 'parent, B, S>,
262 ) -> Result<Self, core_json_traits::JsonError<'bytes, B, S>> {{
263 use core::default::Default;
264
265 let mut result = Self::default();
266 if {largest_key} == 0 {{
267 return Ok(result);
268 }}
269
270 let mut key_chars = ['\0'; {largest_key}];
271 let mut object = value.fields()?;
272 'serialized_field: while let Some(field) = object.next() {{
273 let (mut key, value) = field?;
274
275 let key = {{
276 let mut key_len = 0;
277 while let Some(key_char) = key.next() {{
278 key_chars[key_len] = match key_char {{
279 Ok(key_char) => key_char,
280 /*
281 This occurs when the key specifies an invalid UTF codepoint, which is technically
282 allowed by RFC 8259. While it means we can't interpret the key, it also means
283 this isn't a field we're looking for.
284
285 Continue to the next serialized field accordingly.
286 */
287 Err(core_json_traits::JsonError::InvalidValue) => continue 'serialized_field,
288 // Propagate all other errors.
289 Err(e) => Err(e)?,
290 }};
291 key_len += 1;
292 if key_len == {largest_key} {{
293 break;
294 }}
295 }}
296 match key.next() {{
297 None => {{}},
298 // This key is larger than our largest key
299 Some(Ok(_)) => continue,
300 Some(Err(e)) => Err(e)?,
301 }}
302 &key_chars[.. key_len]
303 }};
304
305 match key {{
306 {fields_deserialization}
307 // Skip unknown fields
308 _ => {{}}
309 }}
310 }}
311
312 Ok(result)
313 }}
314 }}
315 impl{generic_bounds} core_json_traits::JsonStructure for {name}{generics}
316 where Self: core::default::Default {{}}
317 "#
318 ))
319 .expect("typo in implementation of `JsonDeserialize`")
320}
321
322#[proc_macro_derive(JsonSerialize, attributes(key, skip))]
330pub fn derive_json_serialize(object: TokenStream) -> TokenStream {
331 let Struct { generic_bounds, generics, name, fields } = parse_struct(object);
332
333 let mut fields_serialization = String::new();
334 for (i, (field_name, serialization_field_name)) in fields.iter().enumerate() {
335 let comma = if (i + 1) == fields.len() { "" } else { r#".chain(core::iter::once(','))"# };
336
337 fields_serialization.push_str(&format!(
338 r#"
339 .chain("{serialization_field_name}".serialize())
340 .chain(core::iter::once(':'))
341 .chain(core_json_traits::JsonSerialize::serialize(&self.{field_name}))
342 {comma}
343 "#
344 ));
345 }
346
347 TokenStream::from_str(&format!(
348 r#"
349 impl{generic_bounds} core_json_traits::JsonSerialize for {name}{generics} {{
350 fn serialize(&self) -> impl Iterator<Item = char> {{
351 core::iter::once('{{')
352 {fields_serialization}
353 .chain(core::iter::once('}}'))
354 }}
355 }}
356 "#
357 ))
358 .expect("typo in implementation of `JsonSerialize`")
359}