Skip to main content

reifydb_macro_impl/
parse.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4//! Token parsing for derive macro input.
5
6use proc_macro2::{Delimiter, Group, Ident, TokenStream, TokenTree};
7
8use crate::generate::compile_error;
9
10/// Parsed struct information.
11pub struct ParsedStruct {
12	pub name: Ident,
13	pub fields: Vec<ParsedField>,
14	pub crate_path: String,
15}
16
17/// Parsed field information.
18pub struct ParsedField {
19	pub name: Ident,
20	pub ty: Vec<TokenTree>,
21	pub attrs: FieldAttrs,
22}
23
/// Field attributes from #[frame(...)].
///
/// The default value (no `frame` attribute present) has no column override and
/// every flag cleared.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FieldAttrs {
	/// Column name from `column = "..."`; `None` means the field name is used.
	pub column_name: Option<String>,
	/// Set by the `optional` flag.
	pub optional: bool,
	/// Set by the `coerce` flag.
	pub coerce: bool,
	/// Set by the `skip` flag; skipped fields are excluded from duplicate-column checks.
	pub skip: bool,
}
32
33/// Parse a derive macro input into a ParsedStruct.
34pub fn parse_struct(input: TokenStream) -> Result<ParsedStruct, TokenStream> {
35	parse_struct_with_crate(input, "reifydb_type")
36}
37
38/// Parse a derive macro input into a ParsedStruct with a specific crate path.
39pub fn parse_struct_with_crate(input: TokenStream, crate_path: &str) -> Result<ParsedStruct, TokenStream> {
40	let tokens: Vec<TokenTree> = input.into_iter().collect();
41	let mut iter = tokens.iter().peekable();
42	let crate_path = crate_path.to_string();
43
44	// Skip attributes on the struct itself
45	while let Some(TokenTree::Punct(p)) = iter.peek() {
46		if p.as_char() == '#' {
47			iter.next(); // #
48			if let Some(TokenTree::Group(_)) = iter.peek() {
49				iter.next(); // [...]
50			}
51		} else {
52			break;
53		}
54	}
55
56	// Skip visibility (pub, pub(crate), etc.)
57	if let Some(TokenTree::Ident(i)) = iter.peek()
58		&& *i == "pub"
59	{
60		iter.next();
61		// Handle pub(crate), pub(super), etc.
62		if let Some(TokenTree::Group(g)) = iter.peek()
63			&& g.delimiter() == Delimiter::Parenthesis
64		{
65			iter.next();
66		}
67	}
68
69	// Expect "struct"
70	match iter.next() {
71		Some(TokenTree::Ident(i)) if *i == "struct" => {}
72		_ => return Err(compile_error("FromFrame can only be derived for structs")),
73	}
74
75	// Get struct name
76	let name = match iter.next() {
77		Some(TokenTree::Ident(i)) => i.clone(),
78		_ => return Err(compile_error("expected struct name")),
79	};
80
81	// Find the fields group (skip generics if present)
82	let fields_group = loop {
83		match iter.next() {
84			Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Brace => {
85				break g.clone();
86			}
87			Some(TokenTree::Punct(p)) if p.as_char() == '<' => {
88				// Skip generics - for now we don't support generic structs
89				let mut depth = 1;
90				while depth > 0 {
91					match iter.next() {
92						Some(TokenTree::Punct(p)) if p.as_char() == '<' => depth += 1,
93						Some(TokenTree::Punct(p)) if p.as_char() == '>' => depth -= 1,
94						None => {
95							return Err(compile_error("unclosed generic parameters"));
96						}
97						_ => {}
98					}
99				}
100			}
101			Some(TokenTree::Ident(i)) if *i == "where" => {
102				// Skip where clause until we hit the brace
103				continue;
104			}
105			None => return Err(compile_error("expected struct body")),
106			_ => continue,
107		}
108	};
109
110	// Parse fields
111	let fields = parse_fields(fields_group)?;
112
113	// Check for duplicate column aliases
114	for (i, field) in fields.iter().enumerate() {
115		if field.attrs.skip {
116			continue;
117		}
118		let col_name = field.column_name();
119		for other in fields.iter().skip(i + 1) {
120			if other.attrs.skip {
121				continue;
122			}
123			if other.column_name() == col_name {
124				return Err(compile_error(&format!(
125					"duplicate column alias '{}': used by both '{}' and '{}'",
126					col_name,
127					field.safe_name(),
128					other.safe_name()
129				)));
130			}
131		}
132	}
133
134	Ok(ParsedStruct {
135		name,
136		fields,
137		crate_path,
138	})
139}
140
141/// Parse the fields from a struct body group.
142fn parse_fields(group: Group) -> Result<Vec<ParsedField>, TokenStream> {
143	let tokens: Vec<TokenTree> = group.stream().into_iter().collect();
144	let mut fields = Vec::new();
145	let mut iter = tokens.iter().peekable();
146
147	while iter.peek().is_some() {
148		// Collect attributes
149		let mut attrs_tokens = Vec::new();
150		while let Some(TokenTree::Punct(p)) = iter.peek() {
151			if p.as_char() == '#' {
152				iter.next(); // #
153				if let Some(TokenTree::Group(g)) = iter.next() {
154					attrs_tokens.push(g.clone());
155				}
156			} else {
157				break;
158			}
159		}
160
161		// Skip visibility
162		if let Some(TokenTree::Ident(i)) = iter.peek()
163			&& *i == "pub"
164		{
165			iter.next();
166			if let Some(TokenTree::Group(g)) = iter.peek()
167				&& g.delimiter() == Delimiter::Parenthesis
168			{
169				iter.next();
170			}
171		}
172
173		// Get field name
174		let field_name = match iter.next() {
175			Some(TokenTree::Ident(i)) => i.clone(),
176			None => break, // End of fields
177			_ => continue, // Skip unexpected tokens
178		};
179
180		// Expect colon
181		match iter.next() {
182			Some(TokenTree::Punct(p)) if p.as_char() == ':' => {}
183			_ => return Err(compile_error("expected ':' after field name")),
184		}
185
186		// Collect type tokens until comma or end
187		let mut ty_tokens = Vec::new();
188		let mut depth = 0;
189		loop {
190			match iter.peek() {
191				Some(TokenTree::Punct(p)) if p.as_char() == ',' && depth == 0 => {
192					iter.next(); // consume comma
193					break;
194				}
195				Some(TokenTree::Punct(p)) if p.as_char() == '<' => {
196					depth += 1;
197					ty_tokens.push(iter.next().unwrap().clone());
198				}
199				Some(TokenTree::Punct(p)) if p.as_char() == '>' => {
200					depth -= 1;
201					ty_tokens.push(iter.next().unwrap().clone());
202				}
203				Some(t) => {
204					ty_tokens.push((*t).clone());
205					iter.next();
206				}
207				None => break,
208			}
209		}
210
211		if ty_tokens.is_empty() {
212			return Err(compile_error("expected field type"));
213		}
214
215		let attrs = parse_field_attrs(&attrs_tokens);
216
217		fields.push(ParsedField {
218			name: field_name,
219			ty: ty_tokens,
220			attrs,
221		});
222	}
223
224	Ok(fields)
225}
226
227/// Parse #[frame(...)] attributes from a list of attribute groups.
228fn parse_field_attrs(attr_groups: &[Group]) -> FieldAttrs {
229	let mut result = FieldAttrs::default();
230
231	for group in attr_groups {
232		let tokens: Vec<TokenTree> = group.stream().into_iter().collect();
233		let mut iter = tokens.iter().peekable();
234
235		// Check if this is a #[frame(...)] attribute
236		match iter.next() {
237			Some(TokenTree::Ident(i)) if *i == "frame" => {}
238			_ => continue,
239		}
240
241		// Get the inner group
242		let inner = match iter.next() {
243			Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Parenthesis => g,
244			_ => continue,
245		};
246
247		// Parse inner tokens
248		let inner_tokens: Vec<TokenTree> = inner.stream().into_iter().collect();
249		let mut inner_iter = inner_tokens.iter().peekable();
250
251		while inner_iter.peek().is_some() {
252			// Get attribute name
253			let attr_name = match inner_iter.next() {
254				Some(TokenTree::Ident(i)) => i.to_string(),
255				_ => continue,
256			};
257
258			match attr_name.as_str() {
259				"column" => {
260					// Expect = "value"
261					if let Some(TokenTree::Punct(p)) = inner_iter.next()
262						&& p.as_char() == '=' && let Some(TokenTree::Literal(lit)) =
263						inner_iter.next()
264					{
265						let s = lit.to_string();
266						// Remove quotes
267						if s.starts_with('"') && s.ends_with('"') {
268							result.column_name = Some(s[1..s.len() - 1].to_string());
269						}
270					}
271				}
272				"optional" => result.optional = true,
273				"coerce" => result.coerce = true,
274				"skip" => result.skip = true,
275				_ => {}
276			}
277
278			// Skip comma if present
279			if let Some(TokenTree::Punct(p)) = inner_iter.peek()
280				&& p.as_char() == ','
281			{
282				inner_iter.next();
283			}
284		}
285	}
286
287	result
288}
289
290impl ParsedField {
291	/// Get the column name, using the field name if not explicitly set.
292	/// Strips r# prefix from raw identifiers.
293	pub fn column_name(&self) -> String {
294		if let Some(ref name) = self.attrs.column_name {
295			name.clone()
296		} else {
297			self.safe_name()
298		}
299	}
300
301	/// Get a safe variable name (strips r# prefix from raw identifiers).
302	pub fn safe_name(&self) -> String {
303		let name = self.name.to_string();
304		name.strip_prefix("r#").unwrap_or(&name).to_string()
305	}
306}