Skip to main content

reifydb_engine/vm/volcano/
extend.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::{mem, sync::Arc};
5
6use reifydb_core::{
7	error::diagnostic::query::extend_duplicate_column,
8	interface::{evaluate::TargetColumn, resolved::ResolvedColumn},
9	value::column::{ColumnWithName, columns::Columns, headers::ColumnHeaders},
10};
11use reifydb_extension::transform::{Transform, context::TransformContext};
12use reifydb_rql::expression::{Expression, name::display_label};
13use reifydb_transaction::transaction::Transaction;
14use reifydb_type::{fragment::Fragment, return_error, util::cowvec::CowVec};
15use tracing::instrument;
16
17use super::NoopNode;
18use crate::{
19	Result,
20	expression::{
21		cast::cast_column_data,
22		compile::{CompiledExpr, compile_expression},
23		context::{CompileContext, EvalContext},
24	},
25	vm::volcano::{
26		query::{QueryContext, QueryNode},
27		udf::{UdfEvalNode, evaluate_udfs_no_input, strip_udf_columns},
28	},
29};
30
31pub(crate) struct ExtendNode {
32	input: Box<dyn QueryNode>,
33	expressions: Vec<Expression>,
34	udf_names: Vec<String>,
35	headers: Option<ColumnHeaders>,
36	context: Option<(Arc<QueryContext>, Vec<CompiledExpr>)>,
37}
38
39impl ExtendNode {
40	pub fn new(input: Box<dyn QueryNode>, expressions: Vec<Expression>) -> Self {
41		Self {
42			input,
43			expressions,
44			udf_names: Vec::new(),
45			headers: None,
46			context: None,
47		}
48	}
49}
50
51impl QueryNode for ExtendNode {
52	#[instrument(name = "volcano::extend::initialize", level = "trace", skip_all)]
53	fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
54		let (input, expressions, udf_names) = UdfEvalNode::wrap_if_needed(
55			mem::replace(&mut self.input, Box::new(NoopNode)),
56			&self.expressions,
57			&ctx.symbols,
58		);
59		self.input = input;
60		self.expressions = expressions;
61		self.udf_names = udf_names;
62
63		let compile_ctx = CompileContext {
64			symbols: &ctx.symbols,
65		};
66		let compiled = self
67			.expressions
68			.iter()
69			.map(|e| compile_expression(&compile_ctx, e).expect("compile"))
70			.collect();
71		self.context = Some((Arc::new(ctx.clone()), compiled));
72		self.input.initialize(rx, ctx)?;
73		Ok(())
74	}
75
76	#[instrument(name = "volcano::extend::next", level = "trace", skip_all)]
77	fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
78		debug_assert!(self.context.is_some(), "ExtendNode::next() called before initialize()");
79
80		if let Some(columns) = self.input.next(rx, ctx)? {
81			let stored_ctx = &self.context.as_ref().unwrap().0;
82			let transform_ctx = TransformContext {
83				routines: &ctx.services.routines,
84				runtime_context: &stored_ctx.services.runtime_context,
85				params: &stored_ctx.params,
86			};
87			let result = self.apply(&transform_ctx, columns)?;
88
89			if self.headers.is_none() {
90				let mut all_headers = if let Some(input_headers) = self.input.headers() {
91					input_headers.columns.clone()
92				} else {
93					let input_column_count = result.len() - self.expressions.len();
94					result.iter().take(input_column_count).map(|c| c.name().clone()).collect()
95				};
96
97				let new_names: Vec<Fragment> = self.expressions.iter().map(display_label).collect();
98				all_headers.extend(new_names);
99
100				self.headers = Some(ColumnHeaders {
101					columns: all_headers,
102				});
103			}
104
105			let mut result = result;
106			strip_udf_columns(&mut result, &self.udf_names);
107			return Ok(Some(result));
108		}
109		if self.headers.is_none()
110			&& let Some(input_headers) = self.input.headers()
111		{
112			let mut all_headers = input_headers.columns.clone();
113			let new_names: Vec<Fragment> = self.expressions.iter().map(display_label).collect();
114
115			for new_name in &new_names {
116				for existing_name in &all_headers {
117					if new_name.text() == existing_name.text() {
118						return_error!(extend_duplicate_column(new_name.text()));
119					}
120				}
121			}
122			for i in 0..new_names.len() {
123				for j in (i + 1)..new_names.len() {
124					if new_names[i].text() == new_names[j].text() {
125						return_error!(extend_duplicate_column(new_names[i].text()));
126					}
127				}
128			}
129
130			all_headers.extend(new_names);
131			self.headers = Some(ColumnHeaders {
132				columns: all_headers,
133			});
134		}
135		Ok(None)
136	}
137
138	fn headers(&self) -> Option<ColumnHeaders> {
139		self.headers.clone().or(self.input.headers())
140	}
141}
142
143impl Transform for ExtendNode {
144	fn apply(&self, ctx: &TransformContext, input: Columns) -> Result<Columns> {
145		let (stored_ctx, compiled) =
146			self.context.as_ref().expect("ExtendNode::apply() called before initialize()");
147
148		let row_count = input.row_count();
149		let row_numbers = input.row_numbers.to_vec();
150		let created_at = input.created_at.clone();
151		let updated_at = input.updated_at.clone();
152
153		// Collect existing column names for duplicate checking
154		let existing_names: Vec<Fragment> = input.iter().map(|c| c.name().clone()).collect();
155
156		let session = EvalContext::from_transform(ctx, stored_ctx);
157		let mut new_columns: Vec<ColumnWithName> = input
158			.names
159			.iter()
160			.zip(input.columns.iter())
161			.map(|(name, data)| ColumnWithName::new(name.clone(), data.clone()))
162			.collect();
163
164		let mut new_names = Vec::with_capacity(compiled.len());
165		for (expr, compiled_expr) in self.expressions.iter().zip(compiled.iter()) {
166			let mut exec_ctx = session.with_eval(Columns::new(new_columns.clone()), row_count);
167
168			if let (Expression::Alias(alias_expr), Some(source)) = (expr, &stored_ctx.source) {
169				let alias_name = alias_expr.alias.name();
170				if let Some(table_column) = source.columns().iter().find(|col| col.name == alias_name) {
171					let column_ident = Fragment::internal(&table_column.name);
172					let resolved_column =
173						ResolvedColumn::new(column_ident, source.clone(), table_column.clone());
174					exec_ctx.target = Some(TargetColumn::Resolved(resolved_column));
175				}
176			}
177
178			let mut column = compiled_expr.execute(&exec_ctx)?;
179
180			if let Some(target_type) = exec_ctx.target.as_ref().map(|t| t.column_type())
181				&& column.data.get_type() != target_type
182			{
183				let data =
184					cast_column_data(&exec_ctx, &column.data, target_type, &expr.lazy_fragment())?;
185				column = ColumnWithName {
186					name: column.name,
187					data,
188				};
189			}
190
191			new_columns.push(column);
192			new_names.push(display_label(expr));
193		}
194
195		// Validate no duplicate column names against existing columns
196		for new_name in &new_names {
197			for existing_name in &existing_names {
198				if new_name.text() == existing_name.text() {
199					return_error!(extend_duplicate_column(new_name.text()));
200				}
201			}
202		}
203
204		// Validate no duplicates within new columns
205		for i in 0..new_names.len() {
206			for j in (i + 1)..new_names.len() {
207				if new_names[i].text() == new_names[j].text() {
208					return_error!(extend_duplicate_column(new_names[i].text()));
209				}
210			}
211		}
212
213		let mut names_vec = Vec::with_capacity(new_columns.len());
214		let mut buffers_vec = Vec::with_capacity(new_columns.len());
215		for c in new_columns {
216			names_vec.push(c.name);
217			buffers_vec.push(c.data);
218		}
219		Ok(Columns {
220			row_numbers: CowVec::new(row_numbers),
221			created_at,
222			updated_at,
223			columns: CowVec::new(buffers_vec),
224			names: CowVec::new(names_vec),
225		})
226	}
227}
228
229pub(crate) struct ExtendWithoutInputNode {
230	expressions: Vec<Expression>,
231	headers: Option<ColumnHeaders>,
232	/// When UDFs are present, stores the pre-computed UDF result columns.
233	udf_columns: Option<Columns>,
234	context: Option<(Arc<QueryContext>, Vec<CompiledExpr>)>,
235}
236
237impl ExtendWithoutInputNode {
238	pub fn new(expressions: Vec<Expression>) -> Self {
239		Self {
240			expressions,
241			headers: None,
242			udf_columns: None,
243			context: None,
244		}
245	}
246}
247
248impl QueryNode for ExtendWithoutInputNode {
249	#[instrument(name = "volcano::extend::noinput::initialize", level = "trace", skip_all)]
250	fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
251		// Extract and evaluate UDFs if present
252		if let Some((rewritten, udf_cols)) = evaluate_udfs_no_input(&self.expressions, ctx, rx)? {
253			self.expressions = rewritten;
254			self.udf_columns = Some(udf_cols);
255		}
256
257		let compile_ctx = CompileContext {
258			symbols: &ctx.symbols,
259		};
260		let compiled = self
261			.expressions
262			.iter()
263			.map(|e| compile_expression(&compile_ctx, e).expect("compile"))
264			.collect();
265		self.context = Some((Arc::new(ctx.clone()), compiled));
266		Ok(())
267	}
268
269	#[instrument(name = "volcano::extend::noinput::next", level = "trace", skip_all)]
270	fn next<'a>(&mut self, _rx: &mut Transaction<'a>, _ctx: &mut QueryContext) -> Result<Option<Columns>> {
271		debug_assert!(self.context.is_some(), "ExtendWithoutInputNode::next() called before initialize()");
272		let (stored_ctx, compiled) = self.context.as_ref().unwrap();
273
274		if self.headers.is_some() {
275			return Ok(None);
276		}
277
278		let session = EvalContext::from_query(stored_ctx);
279		let mut new_columns = Vec::with_capacity(self.expressions.len());
280
281		for compiled_expr in compiled {
282			// If we have UDF result columns, include them so __udf_N column refs resolve
283			let exec_ctx = match &self.udf_columns {
284				Some(udf_cols) => session.with_eval(udf_cols.clone(), 1),
285				None => session.with_eval_empty(),
286			};
287
288			let column = compiled_expr.execute(&exec_ctx)?;
289			new_columns.push(column);
290		}
291
292		let column_names: Vec<Fragment> = self.expressions.iter().map(display_label).collect();
293
294		// Check for duplicate column names within the new columns
295		for i in 0..column_names.len() {
296			for j in (i + 1)..column_names.len() {
297				if column_names[i].text() == column_names[j].text() {
298					return_error!(extend_duplicate_column(column_names[i].text()));
299				}
300			}
301		}
302
303		self.headers = Some(ColumnHeaders {
304			columns: column_names,
305		});
306
307		Ok(Some(Columns::new(new_columns)))
308	}
309
310	fn headers(&self) -> Option<ColumnHeaders> {
311		self.headers.clone()
312	}
313}