reifydb_macro_impl/
parse.rs1use proc_macro2::{Delimiter, Group, Ident, TokenStream, TokenTree};
7
8use crate::generate::compile_error;
9
10pub struct ParsedStruct {
12 pub name: Ident,
13 pub fields: Vec<ParsedField>,
14 pub crate_path: String,
15}
16
17pub struct ParsedField {
19 pub name: Ident,
20 pub ty: Vec<TokenTree>,
21 pub attrs: FieldAttrs,
22}
23
24#[derive(Default)]
26pub struct FieldAttrs {
27 pub column_name: Option<String>,
28 pub optional: bool,
29 pub coerce: bool,
30 pub skip: bool,
31}
32
33pub fn parse_struct(input: TokenStream) -> Result<ParsedStruct, TokenStream> {
35 parse_struct_with_crate(input, "reifydb_type")
36}
37
38pub fn parse_struct_with_crate(input: TokenStream, crate_path: &str) -> Result<ParsedStruct, TokenStream> {
40 let tokens: Vec<TokenTree> = input.into_iter().collect();
41 let mut iter = tokens.iter().peekable();
42 let crate_path = crate_path.to_string();
43
44 while let Some(TokenTree::Punct(p)) = iter.peek() {
46 if p.as_char() == '#' {
47 iter.next(); if let Some(TokenTree::Group(_)) = iter.peek() {
49 iter.next(); }
51 } else {
52 break;
53 }
54 }
55
56 if let Some(TokenTree::Ident(i)) = iter.peek()
58 && *i == "pub"
59 {
60 iter.next();
61 if let Some(TokenTree::Group(g)) = iter.peek()
63 && g.delimiter() == Delimiter::Parenthesis
64 {
65 iter.next();
66 }
67 }
68
69 match iter.next() {
71 Some(TokenTree::Ident(i)) if *i == "struct" => {}
72 _ => return Err(compile_error("FromFrame can only be derived for structs")),
73 }
74
75 let name = match iter.next() {
77 Some(TokenTree::Ident(i)) => i.clone(),
78 _ => return Err(compile_error("expected struct name")),
79 };
80
81 let fields_group = loop {
83 match iter.next() {
84 Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Brace => {
85 break g.clone();
86 }
87 Some(TokenTree::Punct(p)) if p.as_char() == '<' => {
88 let mut depth = 1;
90 while depth > 0 {
91 match iter.next() {
92 Some(TokenTree::Punct(p)) if p.as_char() == '<' => depth += 1,
93 Some(TokenTree::Punct(p)) if p.as_char() == '>' => depth -= 1,
94 None => {
95 return Err(compile_error("unclosed generic parameters"));
96 }
97 _ => {}
98 }
99 }
100 }
101 Some(TokenTree::Ident(i)) if *i == "where" => {
102 continue;
104 }
105 None => return Err(compile_error("expected struct body")),
106 _ => continue,
107 }
108 };
109
110 let fields = parse_fields(fields_group)?;
112
113 for (i, field) in fields.iter().enumerate() {
115 if field.attrs.skip {
116 continue;
117 }
118 let col_name = field.column_name();
119 for other in fields.iter().skip(i + 1) {
120 if other.attrs.skip {
121 continue;
122 }
123 if other.column_name() == col_name {
124 return Err(compile_error(&format!(
125 "duplicate column alias '{}': used by both '{}' and '{}'",
126 col_name,
127 field.safe_name(),
128 other.safe_name()
129 )));
130 }
131 }
132 }
133
134 Ok(ParsedStruct {
135 name,
136 fields,
137 crate_path,
138 })
139}
140
141fn parse_fields(group: Group) -> Result<Vec<ParsedField>, TokenStream> {
143 let tokens: Vec<TokenTree> = group.stream().into_iter().collect();
144 let mut fields = Vec::new();
145 let mut iter = tokens.iter().peekable();
146
147 while iter.peek().is_some() {
148 let mut attrs_tokens = Vec::new();
150 while let Some(TokenTree::Punct(p)) = iter.peek() {
151 if p.as_char() == '#' {
152 iter.next(); if let Some(TokenTree::Group(g)) = iter.next() {
154 attrs_tokens.push(g.clone());
155 }
156 } else {
157 break;
158 }
159 }
160
161 if let Some(TokenTree::Ident(i)) = iter.peek()
163 && *i == "pub"
164 {
165 iter.next();
166 if let Some(TokenTree::Group(g)) = iter.peek()
167 && g.delimiter() == Delimiter::Parenthesis
168 {
169 iter.next();
170 }
171 }
172
173 let field_name = match iter.next() {
175 Some(TokenTree::Ident(i)) => i.clone(),
176 None => break, _ => continue, };
179
180 match iter.next() {
182 Some(TokenTree::Punct(p)) if p.as_char() == ':' => {}
183 _ => return Err(compile_error("expected ':' after field name")),
184 }
185
186 let mut ty_tokens = Vec::new();
188 let mut depth = 0;
189 loop {
190 match iter.peek() {
191 Some(TokenTree::Punct(p)) if p.as_char() == ',' && depth == 0 => {
192 iter.next(); break;
194 }
195 Some(TokenTree::Punct(p)) if p.as_char() == '<' => {
196 depth += 1;
197 ty_tokens.push(iter.next().unwrap().clone());
198 }
199 Some(TokenTree::Punct(p)) if p.as_char() == '>' => {
200 depth -= 1;
201 ty_tokens.push(iter.next().unwrap().clone());
202 }
203 Some(t) => {
204 ty_tokens.push((*t).clone());
205 iter.next();
206 }
207 None => break,
208 }
209 }
210
211 if ty_tokens.is_empty() {
212 return Err(compile_error("expected field type"));
213 }
214
215 let attrs = parse_field_attrs(&attrs_tokens);
216
217 fields.push(ParsedField {
218 name: field_name,
219 ty: ty_tokens,
220 attrs,
221 });
222 }
223
224 Ok(fields)
225}
226
227fn parse_field_attrs(attr_groups: &[Group]) -> FieldAttrs {
229 let mut result = FieldAttrs::default();
230
231 for group in attr_groups {
232 let tokens: Vec<TokenTree> = group.stream().into_iter().collect();
233 let mut iter = tokens.iter().peekable();
234
235 match iter.next() {
237 Some(TokenTree::Ident(i)) if *i == "frame" => {}
238 _ => continue,
239 }
240
241 let inner = match iter.next() {
243 Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Parenthesis => g,
244 _ => continue,
245 };
246
247 let inner_tokens: Vec<TokenTree> = inner.stream().into_iter().collect();
249 let mut inner_iter = inner_tokens.iter().peekable();
250
251 while inner_iter.peek().is_some() {
252 let attr_name = match inner_iter.next() {
254 Some(TokenTree::Ident(i)) => i.to_string(),
255 _ => continue,
256 };
257
258 match attr_name.as_str() {
259 "column" => {
260 if let Some(TokenTree::Punct(p)) = inner_iter.next()
262 && p.as_char() == '=' && let Some(TokenTree::Literal(lit)) =
263 inner_iter.next()
264 {
265 let s = lit.to_string();
266 if s.starts_with('"') && s.ends_with('"') {
268 result.column_name = Some(s[1..s.len() - 1].to_string());
269 }
270 }
271 }
272 "optional" => result.optional = true,
273 "coerce" => result.coerce = true,
274 "skip" => result.skip = true,
275 _ => {}
276 }
277
278 if let Some(TokenTree::Punct(p)) = inner_iter.peek()
280 && p.as_char() == ','
281 {
282 inner_iter.next();
283 }
284 }
285 }
286
287 result
288}
289
290impl ParsedField {
291 pub fn column_name(&self) -> String {
294 if let Some(ref name) = self.attrs.column_name {
295 name.clone()
296 } else {
297 self.safe_name()
298 }
299 }
300
301 pub fn safe_name(&self) -> String {
303 let name = self.name.to_string();
304 name.strip_prefix("r#").unwrap_or(&name).to_string()
305 }
306}