Skip to main content

reifydb_routine/procedure/rql/
ast.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::sync::LazyLock;
5
6use bumpalo::Bump;
7use reifydb_core::value::column::{ColumnWithName, columns::Columns};
8use reifydb_rql::{
9	ast::{
10		ast::{Ast, AstAlter, AstCreate, AstFrom, AstJoin},
11		identifier::UnqualifiedIdentifier,
12		parse_str,
13	},
14	bump::BumpFragment,
15	token::token::{Token, TokenKind},
16};
17use reifydb_type::value::r#type::Type;
18
19use crate::{
20	procedure::rql::extract_query,
21	routine::{Routine, RoutineInfo, context::ProcedureContext, error::RoutineError},
22};
23
24static INFO: LazyLock<RoutineInfo> = LazyLock::new(|| RoutineInfo::new("rql::ast"));
25
/// Stateless procedure that parses an RQL query and reports its syntax tree,
/// one row per node.
#[derive(Default)]
pub struct RqlAst;

impl RqlAst {
	/// Creates the procedure; equivalent to `RqlAst::default()`.
	pub fn new() -> Self {
		RqlAst
	}
}
39
40impl<'a, 'tx> Routine<ProcedureContext<'a, 'tx>> for RqlAst {
41	fn info(&self) -> &RoutineInfo {
42		&INFO
43	}
44
45	fn return_type(&self, _input_types: &[Type]) -> Type {
46		Type::Any
47	}
48
49	fn attaches_row_metadata(&self) -> bool {
50		false
51	}
52
53	fn execute(&self, ctx: &mut ProcedureContext<'a, 'tx>, _args: &Columns) -> Result<Columns, RoutineError> {
54		let query = extract_query(ctx.params, "rql::ast")?;
55
56		let bump = Bump::new();
57		let statements = parse_str(&bump, query.as_str())?;
58
59		let mut walker = AstWalker::default();
60		for statement in &statements {
61			for node in &statement.nodes {
62				walker.walk(node, 0);
63			}
64		}
65
66		Ok(walker.into_columns())
67	}
68}
69
70#[derive(Default)]
71struct AstWalker {
72	idx: Vec<i32>,
73	depth: Vec<i32>,
74	kind: Vec<String>,
75	detail: Vec<String>,
76}
77
78impl AstWalker {
79	fn emit(&mut self, depth: i32, kind: &str, detail: String) {
80		let next = self.idx.len() as i32;
81		self.idx.push(next);
82		self.depth.push(depth);
83		self.kind.push(kind.to_string());
84		self.detail.push(detail);
85	}
86
87	fn into_columns(self) -> Columns {
88		Columns::new(vec![
89			ColumnWithName::int4("idx", self.idx),
90			ColumnWithName::int4("depth", self.depth),
91			ColumnWithName::utf8("kind", self.kind),
92			ColumnWithName::utf8("detail", self.detail),
93		])
94	}
95
96	fn walk(&mut self, ast: &Ast<'_>, depth: i32) {
97		let token = ast.token();
98		let fragment = &token.fragment;
99		let kind = ast_kind(ast);
100		let description = ast_description(ast, kind);
101		let detail = format!(
102			"{} @ line {}, column {} - \"{}\"",
103			description,
104			fragment.line().0,
105			fragment.column().0,
106			fragment.text()
107		);
108		self.emit(depth, kind, detail);
109
110		let mut owned_children: Vec<Ast<'_>> = vec![];
111		let mut ref_children: Vec<&Ast<'_>> = vec![];
112
113		match ast {
114			Ast::Tuple(t) => {
115				for node in &t.nodes {
116					ref_children.push(node);
117				}
118			}
119			Ast::Prefix(p) => ref_children.push(&p.node),
120			Ast::Cast(c) => {
121				for node in &c.tuple.nodes {
122					ref_children.push(node);
123				}
124			}
125			Ast::Filter(f) => ref_children.push(&f.node),
126			Ast::Gate(f) => ref_children.push(&f.node),
127			Ast::From(from) => match from {
128				AstFrom::Source {
129					source,
130					index_name,
131					..
132				} => {
133					let source_token = Token {
134						kind: TokenKind::Identifier,
135						fragment: source.name,
136					};
137					owned_children.push(Ast::Identifier(UnqualifiedIdentifier::new(source_token)));
138
139					if let Some(index) = index_name {
140						let index_token = Token {
141							kind: TokenKind::Identifier,
142							fragment: *index,
143						};
144						owned_children
145							.push(Ast::Identifier(UnqualifiedIdentifier::new(index_token)));
146					}
147				}
148				AstFrom::Inline {
149					list: query,
150					..
151				} => {
152					for node in &query.nodes {
153						ref_children.push(node);
154					}
155				}
156				AstFrom::Generator(generator_func) => {
157					for node in &generator_func.nodes {
158						ref_children.push(node);
159					}
160				}
161				AstFrom::Variable {
162					variable,
163					..
164				} => {
165					let variable_token = Token {
166						kind: TokenKind::Variable,
167						fragment: variable.token.fragment,
168					};
169					owned_children
170						.push(Ast::Identifier(UnqualifiedIdentifier::new(variable_token)));
171				}
172				AstFrom::Environment {
173					..
174				} => {
175					let env_token = Token {
176						kind: TokenKind::Variable,
177						fragment: BumpFragment::Internal {
178							text: "env",
179						},
180					};
181					owned_children.push(Ast::Identifier(UnqualifiedIdentifier::new(env_token)));
182				}
183			},
184			Ast::Aggregate(a) => {
185				if !a.map.is_empty() {
186					self.emit(depth + 1, "AggregateMap", "Aggregate Map".to_string());
187					for child in a.map.iter() {
188						self.walk(child, depth + 2);
189					}
190				}
191				if !a.by.is_empty() {
192					self.emit(depth + 1, "AggregateBy", "Aggregate By".to_string());
193					for child in a.by.iter() {
194						self.walk(child, depth + 2);
195					}
196				} else if a.map.is_empty() {
197					self.emit(depth + 1, "AggregateBy", "Aggregate By".to_string());
198				}
199				return;
200			}
201			Ast::Insert(_) => {
202				return;
203			}
204			Ast::Join(AstJoin::LeftJoin {
205				with,
206				using_clause,
207				..
208			}) => {
209				for node in &with.statement.nodes {
210					ref_children.push(node);
211				}
212				for pair in &using_clause.pairs {
213					ref_children.push(&pair.first);
214					ref_children.push(&pair.second);
215				}
216			}
217			Ast::Map(s) => {
218				for node in &s.nodes {
219					ref_children.push(node);
220				}
221			}
222			Ast::Generator(s) => {
223				for node in &s.nodes {
224					ref_children.push(node);
225				}
226			}
227			Ast::Sort(_o) => {}
228			Ast::Inline(r) => {
229				for field in &r.keyed_values {
230					owned_children.push(Ast::Identifier(field.key));
231					ref_children.push(&field.value);
232				}
233			}
234			Ast::Infix(i) => {
235				ref_children.push(&i.left);
236				ref_children.push(&i.right);
237			}
238			Ast::Alter(_) => {
239				return;
240			}
241			Ast::Patch(p) => {
242				for node in &p.assignments {
243					ref_children.push(node);
244				}
245			}
246			Ast::Assert(a) => {
247				if let Some(ref node) = a.node {
248					ref_children.push(node);
249				}
250			}
251			Ast::SubQuery(sq) => {
252				for node in &sq.statement.nodes {
253					ref_children.push(node);
254				}
255			}
256			_ => {}
257		}
258
259		if !owned_children.is_empty() && !ref_children.is_empty() {
260			let total = owned_children.len() + ref_children.len();
261			let mut oi = 0;
262			let mut ri = 0;
263			for i in 0..total {
264				if i % 2 == 0 && oi < owned_children.len() {
265					self.walk(&owned_children[oi], depth + 1);
266					oi += 1;
267				} else if ri < ref_children.len() {
268					self.walk(ref_children[ri], depth + 1);
269					ri += 1;
270				} else if oi < owned_children.len() {
271					self.walk(&owned_children[oi], depth + 1);
272					oi += 1;
273				}
274			}
275		} else if !owned_children.is_empty() {
276			for child in owned_children.iter() {
277				self.walk(child, depth + 1);
278			}
279		} else {
280			for child in ref_children.iter() {
281				self.walk(child, depth + 1);
282			}
283		}
284	}
285}
286
287fn ast_kind(ast: &Ast<'_>) -> &'static str {
288	match ast {
289		Ast::Aggregate(_) => "Aggregate",
290		Ast::Between(_) => "Between",
291		Ast::Block(_) => "Block",
292		Ast::Break(_) => "Break",
293		Ast::CallFunction(_) => "CallFunction",
294		Ast::Continue(_) => "Continue",
295		Ast::Inline(_) => "Row",
296		Ast::Cast(_) => "Cast",
297		Ast::Create(_) => "Create",
298		Ast::Alter(_) => "Alter",
299		Ast::Drop(_) => "Drop",
300		Ast::Describe(_) => "Describe",
301		Ast::Filter(_) => "Filter",
302		Ast::Gate(_) => "Gate",
303		Ast::For(_) => "For",
304		Ast::From(_) => "From",
305		Ast::Identifier(_) => "Identifier",
306		Ast::If(_) => "If",
307		Ast::Infix(_) => "Infix",
308		Ast::Let(_) => "Let",
309		Ast::Loop(_) => "Loop",
310		Ast::Delete(_) => "Delete",
311		Ast::Insert(_) => "Insert",
312		Ast::Update(_) => "Update",
313		Ast::Join(_) => "Join",
314		Ast::List(_) => "List",
315		Ast::Literal(_) => "Literal",
316		Ast::Nop => "Nop",
317		Ast::Sort(_) => "Sort",
318		Ast::Prefix(_) => "Prefix",
319		Ast::Map(_) => "Map",
320		Ast::Generator(_) => "Generator",
321		Ast::Extend(_) => "Extend",
322		Ast::Patch(_) => "Patch",
323		Ast::Take(_) => "Take",
324		Ast::Tuple(_) => "Tuple",
325		Ast::While(_) => "While",
326		Ast::Wildcard(_) => "Wildcard",
327		Ast::Variable(_) => "Variable",
328		Ast::Distinct(_) => "Distinct",
329		Ast::Apply(_) => "Apply",
330		Ast::Call(_) => "Call",
331		Ast::SubQuery(_) => "SubQuery",
332		Ast::Window(_) => "Window",
333		Ast::StatementExpression(_) => "StatementExpression",
334		Ast::Environment(_) => "Environment",
335		Ast::Rownum(_) => "Rownum",
336		Ast::SystemColumn(_) => "SystemColumn",
337		Ast::DefFunction(_) => "DefFunction",
338		Ast::Return(_) => "Return",
339		Ast::Append(_) => "Append",
340		Ast::Assert(_) => "Assert",
341		Ast::SumTypeConstructor(_) => "SumTypeConstructor",
342		Ast::IsVariant(_) => "IsVariant",
343		Ast::Match(_) => "Match",
344		Ast::Closure(_) => "Closure",
345		Ast::Dispatch(_) => "Dispatch",
346		Ast::Grant(_) => "Grant",
347		Ast::Revoke(_) => "Revoke",
348		Ast::Identity(_) => "Identity",
349		Ast::Require(_) => "Require",
350		Ast::Migrate(_) => "Migrate",
351		Ast::RollbackMigration(_) => "RollbackMigration",
352		Ast::RunTests(_) => "RunTests",
353	}
354}
355
356fn ast_description(ast: &Ast<'_>, kind: &str) -> String {
357	match ast {
358		Ast::Inline(r) => {
359			let field_names: Vec<&str> = r.keyed_values.iter().map(|f| f.key.text()).collect();
360			format!("{} ({} fields: {})", kind, r.keyed_values.len(), field_names.join(", "))
361		}
362		Ast::Alter(alter) => match alter {
363			AstAlter::Sequence(s) => {
364				let namespace = s
365					.sequence
366					.namespace
367					.first()
368					.map(|sch| format!("{}.", sch.text()))
369					.unwrap_or_default();
370				format!("ALTER SEQUENCE {}{}.{}", namespace, s.sequence.name.text(), s.column.text())
371			}
372			AstAlter::Policy(sp) => {
373				format!("ALTER {:?} POLICY {}", sp.target_type, sp.name.text())
374			}
375			AstAlter::Table(t) => {
376				let namespace =
377					t.table.namespace.first().map(|s| format!("{}.", s.text())).unwrap_or_default();
378				format!("ALTER TABLE {}{}", namespace, t.table.name.text())
379			}
380			AstAlter::RemoteNamespace(ns) => {
381				format!(
382					"ALTER REMOTE NAMESPACE {}",
383					ns.namespace.segments.iter().map(|s| s.text()).collect::<Vec<_>>().join("::")
384				)
385			}
386		},
387		Ast::Create(create) => match create {
388			AstCreate::PrimaryKey(pk) => {
389				let namespace =
390					pk.table.namespace
391						.first()
392						.map(|s| format!("{}::", s.text()))
393						.unwrap_or_default();
394				format!("CREATE PRIMARY KEY ON {}{}", namespace, pk.table.name.text())
395			}
396			AstCreate::ColumnProperty(p) => {
397				format!("CREATE COLUMN POLICY ON {}", p.column.name.text())
398			}
399			_ => kind.to_string(),
400		},
401		_ => kind.to_string(),
402	}
403}