reifydb_macro_impl/
parse.rs

1// Copyright (c) reifydb.com 2025
2// This file is licensed under the MIT, see license.md file
3
4//! Token parsing for derive macro input.
5
6use proc_macro2::{Delimiter, Group, Ident, TokenStream, TokenTree};
7
8use crate::generate::compile_error;
9
/// Parsed struct information.
///
/// Produced by [`parse_struct`] / [`parse_struct_with_crate`] from the token
/// stream handed to the derive macro.
pub struct ParsedStruct {
	/// Identifier of the struct the derive was applied to.
	pub name: Ident,
	/// Named fields in declaration order, including those marked
	/// `#[frame(skip)]`.
	pub fields: Vec<ParsedField>,
	/// Crate path supplied by the caller (`"reifydb_type"` when parsed via
	/// [`parse_struct`]). NOTE(review): presumably used by the code generator
	/// to qualify paths in the emitted code — confirm in `generate`.
	pub crate_path: String,
}
16
/// Parsed field information.
pub struct ParsedField {
	/// Field identifier as written; may be a raw identifier such as `r#type`
	/// (see [`ParsedField::safe_name`] for the prefix-stripped form).
	pub name: Ident,
	/// The field's type as raw tokens, collected up to the next top-level
	/// comma (the separating comma itself is not included).
	pub ty: Vec<TokenTree>,
	/// Options collected from the field's `#[frame(...)]` attributes.
	pub attrs: FieldAttrs,
}
23
/// Field attributes from #[frame(...)].
///
/// Every option is unset/`false` by default, i.e. when the field carries no
/// `#[frame(...)]` attribute.
#[derive(Default)]
pub struct FieldAttrs {
	/// Explicit column name from `column = "..."`. When `None`, the field name
	/// (without any `r#` prefix) is used as the column name.
	pub column_name: Option<String>,
	/// Set by the bare `optional` key. Its effect is decided by the code
	/// generator — NOTE(review): presumably tolerates a missing column; confirm.
	pub optional: bool,
	/// Set by the bare `coerce` key. Interpreted by the code generator —
	/// NOTE(review): presumably enables value type coercion; confirm.
	pub coerce: bool,
	/// Set by the bare `skip` key. Skipped fields are excluded from the
	/// duplicate-column check at parse time; any other handling is up to the
	/// code generator.
	pub skip: bool,
}
32
33/// Parse a derive macro input into a ParsedStruct.
34pub fn parse_struct(input: TokenStream) -> Result<ParsedStruct, TokenStream> {
35	parse_struct_with_crate(input, "reifydb_type")
36}
37
38/// Parse a derive macro input into a ParsedStruct with a specific crate path.
39pub fn parse_struct_with_crate(input: TokenStream, crate_path: &str) -> Result<ParsedStruct, TokenStream> {
40	let tokens: Vec<TokenTree> = input.into_iter().collect();
41	let mut iter = tokens.iter().peekable();
42	let crate_path = crate_path.to_string();
43
44	// Skip attributes on the struct itself
45	while let Some(TokenTree::Punct(p)) = iter.peek() {
46		if p.as_char() == '#' {
47			iter.next(); // #
48			if let Some(TokenTree::Group(_)) = iter.peek() {
49				iter.next(); // [...]
50			}
51		} else {
52			break;
53		}
54	}
55
56	// Skip visibility (pub, pub(crate), etc.)
57	if let Some(TokenTree::Ident(i)) = iter.peek() {
58		if i.to_string() == "pub" {
59			iter.next();
60			// Handle pub(crate), pub(super), etc.
61			if let Some(TokenTree::Group(g)) = iter.peek() {
62				if g.delimiter() == Delimiter::Parenthesis {
63					iter.next();
64				}
65			}
66		}
67	}
68
69	// Expect "struct"
70	match iter.next() {
71		Some(TokenTree::Ident(i)) if i.to_string() == "struct" => {}
72		_ => return Err(compile_error("FromFrame can only be derived for structs")),
73	}
74
75	// Get struct name
76	let name = match iter.next() {
77		Some(TokenTree::Ident(i)) => i.clone(),
78		_ => return Err(compile_error("expected struct name")),
79	};
80
81	// Find the fields group (skip generics if present)
82	let fields_group = loop {
83		match iter.next() {
84			Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Brace => {
85				break g.clone();
86			}
87			Some(TokenTree::Punct(p)) if p.as_char() == '<' => {
88				// Skip generics - for now we don't support generic structs
89				let mut depth = 1;
90				while depth > 0 {
91					match iter.next() {
92						Some(TokenTree::Punct(p)) if p.as_char() == '<' => depth += 1,
93						Some(TokenTree::Punct(p)) if p.as_char() == '>' => depth -= 1,
94						None => {
95							return Err(compile_error("unclosed generic parameters"));
96						}
97						_ => {}
98					}
99				}
100			}
101			Some(TokenTree::Ident(i)) if i.to_string() == "where" => {
102				// Skip where clause until we hit the brace
103				continue;
104			}
105			None => return Err(compile_error("expected struct body")),
106			_ => continue,
107		}
108	};
109
110	// Parse fields
111	let fields = parse_fields(fields_group)?;
112
113	// Check for duplicate column aliases
114	for (i, field) in fields.iter().enumerate() {
115		if field.attrs.skip {
116			continue;
117		}
118		let col_name = field.column_name();
119		for other in fields.iter().skip(i + 1) {
120			if other.attrs.skip {
121				continue;
122			}
123			if other.column_name() == col_name {
124				return Err(compile_error(&format!(
125					"duplicate column alias '{}': used by both '{}' and '{}'",
126					col_name,
127					field.safe_name(),
128					other.safe_name()
129				)));
130			}
131		}
132	}
133
134	Ok(ParsedStruct {
135		name,
136		fields,
137		crate_path,
138	})
139}
140
141/// Parse the fields from a struct body group.
142fn parse_fields(group: Group) -> Result<Vec<ParsedField>, TokenStream> {
143	let tokens: Vec<TokenTree> = group.stream().into_iter().collect();
144	let mut fields = Vec::new();
145	let mut iter = tokens.iter().peekable();
146
147	while iter.peek().is_some() {
148		// Collect attributes
149		let mut attrs_tokens = Vec::new();
150		while let Some(TokenTree::Punct(p)) = iter.peek() {
151			if p.as_char() == '#' {
152				iter.next(); // #
153				if let Some(TokenTree::Group(g)) = iter.next() {
154					attrs_tokens.push(g.clone());
155				}
156			} else {
157				break;
158			}
159		}
160
161		// Skip visibility
162		if let Some(TokenTree::Ident(i)) = iter.peek() {
163			if i.to_string() == "pub" {
164				iter.next();
165				if let Some(TokenTree::Group(g)) = iter.peek() {
166					if g.delimiter() == Delimiter::Parenthesis {
167						iter.next();
168					}
169				}
170			}
171		}
172
173		// Get field name
174		let field_name = match iter.next() {
175			Some(TokenTree::Ident(i)) => i.clone(),
176			None => break, // End of fields
177			_ => continue, // Skip unexpected tokens
178		};
179
180		// Expect colon
181		match iter.next() {
182			Some(TokenTree::Punct(p)) if p.as_char() == ':' => {}
183			_ => return Err(compile_error("expected ':' after field name")),
184		}
185
186		// Collect type tokens until comma or end
187		let mut ty_tokens = Vec::new();
188		let mut depth = 0;
189		loop {
190			match iter.peek() {
191				Some(TokenTree::Punct(p)) if p.as_char() == ',' && depth == 0 => {
192					iter.next(); // consume comma
193					break;
194				}
195				Some(TokenTree::Punct(p)) if p.as_char() == '<' => {
196					depth += 1;
197					ty_tokens.push(iter.next().unwrap().clone());
198				}
199				Some(TokenTree::Punct(p)) if p.as_char() == '>' => {
200					depth -= 1;
201					ty_tokens.push(iter.next().unwrap().clone());
202				}
203				Some(t) => {
204					ty_tokens.push((*t).clone());
205					iter.next();
206				}
207				None => break,
208			}
209		}
210
211		if ty_tokens.is_empty() {
212			return Err(compile_error("expected field type"));
213		}
214
215		let attrs = parse_field_attrs(&attrs_tokens);
216
217		fields.push(ParsedField {
218			name: field_name,
219			ty: ty_tokens,
220			attrs,
221		});
222	}
223
224	Ok(fields)
225}
226
227/// Parse #[frame(...)] attributes from a list of attribute groups.
228fn parse_field_attrs(attr_groups: &[Group]) -> FieldAttrs {
229	let mut result = FieldAttrs::default();
230
231	for group in attr_groups {
232		let tokens: Vec<TokenTree> = group.stream().into_iter().collect();
233		let mut iter = tokens.iter().peekable();
234
235		// Check if this is a #[frame(...)] attribute
236		match iter.next() {
237			Some(TokenTree::Ident(i)) if i.to_string() == "frame" => {}
238			_ => continue,
239		}
240
241		// Get the inner group
242		let inner = match iter.next() {
243			Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Parenthesis => g,
244			_ => continue,
245		};
246
247		// Parse inner tokens
248		let inner_tokens: Vec<TokenTree> = inner.stream().into_iter().collect();
249		let mut inner_iter = inner_tokens.iter().peekable();
250
251		while inner_iter.peek().is_some() {
252			// Get attribute name
253			let attr_name = match inner_iter.next() {
254				Some(TokenTree::Ident(i)) => i.to_string(),
255				_ => continue,
256			};
257
258			match attr_name.as_str() {
259				"column" => {
260					// Expect = "value"
261					if let Some(TokenTree::Punct(p)) = inner_iter.next() {
262						if p.as_char() == '=' {
263							if let Some(TokenTree::Literal(lit)) = inner_iter.next() {
264								let s = lit.to_string();
265								// Remove quotes
266								if s.starts_with('"') && s.ends_with('"') {
267									result.column_name =
268										Some(s[1..s.len() - 1].to_string());
269								}
270							}
271						}
272					}
273				}
274				"optional" => result.optional = true,
275				"coerce" => result.coerce = true,
276				"skip" => result.skip = true,
277				_ => {}
278			}
279
280			// Skip comma if present
281			if let Some(TokenTree::Punct(p)) = inner_iter.peek() {
282				if p.as_char() == ',' {
283					inner_iter.next();
284				}
285			}
286		}
287	}
288
289	result
290}
291
292impl ParsedField {
293	/// Get the column name, using the field name if not explicitly set.
294	/// Strips r# prefix from raw identifiers.
295	pub fn column_name(&self) -> String {
296		if let Some(ref name) = self.attrs.column_name {
297			name.clone()
298		} else {
299			self.safe_name()
300		}
301	}
302
303	/// Get a safe variable name (strips r# prefix from raw identifiers).
304	pub fn safe_name(&self) -> String {
305		let name = self.name.to_string();
306		name.strip_prefix("r#").unwrap_or(&name).to_string()
307	}
308}