Skip to main content

reifydb_engine/vm/volcano/
extend.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::{mem, sync::Arc};
5
6use reifydb_core::{
7	error::diagnostic::query::extend_duplicate_column,
8	interface::{evaluate::TargetColumn, resolved::ResolvedColumn},
9	value::column::{ColumnWithName, columns::Columns, headers::ColumnHeaders},
10};
11use reifydb_extension::transform::{Transform, context::TransformContext};
12use reifydb_rql::expression::{Expression, name::display_label};
13use reifydb_transaction::transaction::Transaction;
14use reifydb_type::{fragment::Fragment, return_error, util::cowvec::CowVec};
15use tracing::instrument;
16
17use super::NoopNode;
18use crate::{
19	Result,
20	expression::{
21		cast::cast_column_data,
22		compile::{CompiledExpr, compile_expression},
23		context::{CompileContext, EvalContext},
24	},
25	vm::volcano::{
26		query::{QueryContext, QueryNode},
27		udf::{UdfEvalNode, evaluate_udfs_no_input, strip_udf_columns},
28	},
29};
30
31pub(crate) struct ExtendNode {
32	input: Box<dyn QueryNode>,
33	expressions: Vec<Expression>,
34	udf_names: Vec<String>,
35	headers: Option<ColumnHeaders>,
36	context: Option<(Arc<QueryContext>, Vec<CompiledExpr>)>,
37}
38
39impl ExtendNode {
40	pub fn new(input: Box<dyn QueryNode>, expressions: Vec<Expression>) -> Self {
41		Self {
42			input,
43			expressions,
44			udf_names: Vec::new(),
45			headers: None,
46			context: None,
47		}
48	}
49}
50
51impl QueryNode for ExtendNode {
52	#[instrument(name = "volcano::extend::initialize", level = "trace", skip_all)]
53	fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
54		let (input, expressions, udf_names) = UdfEvalNode::wrap_if_needed(
55			mem::replace(&mut self.input, Box::new(NoopNode)),
56			&self.expressions,
57			&ctx.symbols,
58		);
59		self.input = input;
60		self.expressions = expressions;
61		self.udf_names = udf_names;
62
63		let compile_ctx = CompileContext {
64			symbols: &ctx.symbols,
65		};
66		let compiled = self
67			.expressions
68			.iter()
69			.map(|e| compile_expression(&compile_ctx, e).expect("compile"))
70			.collect();
71		self.context = Some((Arc::new(ctx.clone()), compiled));
72		self.input.initialize(rx, ctx)?;
73		Ok(())
74	}
75
76	#[instrument(name = "volcano::extend::next", level = "trace", skip_all)]
77	fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
78		debug_assert!(self.context.is_some(), "ExtendNode::next() called before initialize()");
79
80		if let Some(columns) = self.input.next(rx, ctx)? {
81			let stored_ctx = &self.context.as_ref().unwrap().0;
82			let transform_ctx = TransformContext {
83				routines: &ctx.services.routines,
84				runtime_context: &stored_ctx.services.runtime_context,
85				params: &stored_ctx.params,
86			};
87			let result = self.apply(&transform_ctx, columns)?;
88
89			if self.headers.is_none() {
90				let mut all_headers = if let Some(input_headers) = self.input.headers() {
91					input_headers.columns.clone()
92				} else {
93					let input_column_count = result.len() - self.expressions.len();
94					result.iter().take(input_column_count).map(|c| c.name().clone()).collect()
95				};
96
97				let new_names: Vec<Fragment> = self.expressions.iter().map(display_label).collect();
98				all_headers.extend(new_names);
99
100				self.headers = Some(ColumnHeaders {
101					columns: all_headers,
102				});
103			}
104
105			let mut result = result;
106			strip_udf_columns(&mut result, &self.udf_names);
107			return Ok(Some(result));
108		}
109		if self.headers.is_none()
110			&& let Some(input_headers) = self.input.headers()
111		{
112			let mut all_headers = input_headers.columns.clone();
113			let new_names: Vec<Fragment> = self.expressions.iter().map(display_label).collect();
114
115			for new_name in &new_names {
116				for existing_name in &all_headers {
117					if new_name.text() == existing_name.text() {
118						return_error!(extend_duplicate_column(new_name.text()));
119					}
120				}
121			}
122			for i in 0..new_names.len() {
123				for j in (i + 1)..new_names.len() {
124					if new_names[i].text() == new_names[j].text() {
125						return_error!(extend_duplicate_column(new_names[i].text()));
126					}
127				}
128			}
129
130			all_headers.extend(new_names);
131			self.headers = Some(ColumnHeaders {
132				columns: all_headers,
133			});
134		}
135		Ok(None)
136	}
137
138	fn headers(&self) -> Option<ColumnHeaders> {
139		self.headers.clone().or(self.input.headers())
140	}
141}
142
143impl Transform for ExtendNode {
144	fn apply(&self, ctx: &TransformContext, input: Columns) -> Result<Columns> {
145		let (stored_ctx, compiled) =
146			self.context.as_ref().expect("ExtendNode::apply() called before initialize()");
147
148		let row_count = input.row_count();
149		let row_numbers = input.row_numbers.to_vec();
150		let created_at = input.created_at.clone();
151		let updated_at = input.updated_at.clone();
152
153		let existing_names: Vec<Fragment> = input.iter().map(|c| c.name().clone()).collect();
154
155		let session = EvalContext::from_transform(ctx, stored_ctx);
156		let mut new_columns: Vec<ColumnWithName> = input
157			.names
158			.iter()
159			.zip(input.columns.iter())
160			.map(|(name, data)| ColumnWithName::new(name.clone(), data.clone()))
161			.collect();
162
163		let mut new_names = Vec::with_capacity(compiled.len());
164		for (expr, compiled_expr) in self.expressions.iter().zip(compiled.iter()) {
165			let mut exec_ctx = session.with_eval(Columns::new(new_columns.clone()), row_count);
166
167			if let (Expression::Alias(alias_expr), Some(source)) = (expr, &stored_ctx.source) {
168				let alias_name = alias_expr.alias.name();
169				if let Some(table_column) = source.columns().iter().find(|col| col.name == alias_name) {
170					let column_ident = Fragment::internal(&table_column.name);
171					let resolved_column =
172						ResolvedColumn::new(column_ident, source.clone(), table_column.clone());
173					exec_ctx.target = Some(TargetColumn::Resolved(resolved_column));
174				}
175			}
176
177			let mut column = compiled_expr.execute(&exec_ctx)?;
178
179			if let Some(target_type) = exec_ctx.target.as_ref().map(|t| t.column_type())
180				&& column.data.get_type() != target_type
181			{
182				let data =
183					cast_column_data(&exec_ctx, &column.data, target_type, &expr.lazy_fragment())?;
184				column = ColumnWithName {
185					name: column.name,
186					data,
187				};
188			}
189
190			new_columns.push(column);
191			new_names.push(display_label(expr));
192		}
193
194		for new_name in &new_names {
195			for existing_name in &existing_names {
196				if new_name.text() == existing_name.text() {
197					return_error!(extend_duplicate_column(new_name.text()));
198				}
199			}
200		}
201
202		for i in 0..new_names.len() {
203			for j in (i + 1)..new_names.len() {
204				if new_names[i].text() == new_names[j].text() {
205					return_error!(extend_duplicate_column(new_names[i].text()));
206				}
207			}
208		}
209
210		let mut names_vec = Vec::with_capacity(new_columns.len());
211		let mut buffers_vec = Vec::with_capacity(new_columns.len());
212		for c in new_columns {
213			names_vec.push(c.name);
214			buffers_vec.push(c.data);
215		}
216		Ok(Columns {
217			row_numbers: CowVec::new(row_numbers),
218			created_at,
219			updated_at,
220			columns: CowVec::new(buffers_vec),
221			names: CowVec::new(names_vec),
222		})
223	}
224}
225
226pub(crate) struct ExtendWithoutInputNode {
227	expressions: Vec<Expression>,
228	headers: Option<ColumnHeaders>,
229
230	udf_columns: Option<Columns>,
231	context: Option<(Arc<QueryContext>, Vec<CompiledExpr>)>,
232}
233
234impl ExtendWithoutInputNode {
235	pub fn new(expressions: Vec<Expression>) -> Self {
236		Self {
237			expressions,
238			headers: None,
239			udf_columns: None,
240			context: None,
241		}
242	}
243}
244
245impl QueryNode for ExtendWithoutInputNode {
246	#[instrument(name = "volcano::extend::noinput::initialize", level = "trace", skip_all)]
247	fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
248		if let Some((rewritten, udf_cols)) = evaluate_udfs_no_input(&self.expressions, ctx, rx)? {
249			self.expressions = rewritten;
250			self.udf_columns = Some(udf_cols);
251		}
252
253		let compile_ctx = CompileContext {
254			symbols: &ctx.symbols,
255		};
256		let compiled = self
257			.expressions
258			.iter()
259			.map(|e| compile_expression(&compile_ctx, e).expect("compile"))
260			.collect();
261		self.context = Some((Arc::new(ctx.clone()), compiled));
262		Ok(())
263	}
264
265	#[instrument(name = "volcano::extend::noinput::next", level = "trace", skip_all)]
266	fn next<'a>(&mut self, _rx: &mut Transaction<'a>, _ctx: &mut QueryContext) -> Result<Option<Columns>> {
267		debug_assert!(self.context.is_some(), "ExtendWithoutInputNode::next() called before initialize()");
268		let (stored_ctx, compiled) = self.context.as_ref().unwrap();
269
270		if self.headers.is_some() {
271			return Ok(None);
272		}
273
274		let session = EvalContext::from_query(stored_ctx);
275		let mut new_columns = Vec::with_capacity(self.expressions.len());
276
277		for compiled_expr in compiled {
278			let exec_ctx = match &self.udf_columns {
279				Some(udf_cols) => session.with_eval(udf_cols.clone(), 1),
280				None => session.with_eval_empty(),
281			};
282
283			let column = compiled_expr.execute(&exec_ctx)?;
284			new_columns.push(column);
285		}
286
287		let column_names: Vec<Fragment> = self.expressions.iter().map(display_label).collect();
288
289		for i in 0..column_names.len() {
290			for j in (i + 1)..column_names.len() {
291				if column_names[i].text() == column_names[j].text() {
292					return_error!(extend_duplicate_column(column_names[i].text()));
293				}
294			}
295		}
296
297		self.headers = Some(ColumnHeaders {
298			columns: column_names,
299		});
300
301		Ok(Some(Columns::new(new_columns)))
302	}
303
304	fn headers(&self) -> Option<ColumnHeaders> {
305		self.headers.clone()
306	}
307}