reifydb_engine/vm/volcano/
extend.rs1use std::{mem, sync::Arc};
5
6use reifydb_core::{
7 error::diagnostic::query::extend_duplicate_column,
8 interface::{evaluate::TargetColumn, resolved::ResolvedColumn},
9 value::column::{ColumnWithName, columns::Columns, headers::ColumnHeaders},
10};
11use reifydb_extension::transform::{Transform, context::TransformContext};
12use reifydb_rql::expression::{Expression, name::display_label};
13use reifydb_transaction::transaction::Transaction;
14use reifydb_type::{fragment::Fragment, return_error, util::cowvec::CowVec};
15use tracing::instrument;
16
17use super::NoopNode;
18use crate::{
19 Result,
20 expression::{
21 cast::cast_column_data,
22 compile::{CompiledExpr, compile_expression},
23 context::{CompileContext, EvalContext},
24 },
25 vm::volcano::{
26 query::{QueryContext, QueryNode},
27 udf::{UdfEvalNode, evaluate_udfs_no_input, strip_udf_columns},
28 },
29};
30
/// Query node that evaluates `expressions` against every batch produced by
/// `input` and appends the results as additional columns.
pub(crate) struct ExtendNode {
    /// Upstream node supplying the batches to extend; replaced in `initialize`
    /// by whatever `UdfEvalNode::wrap_if_needed` returns.
    input: Box<dyn QueryNode>,
    /// Expressions evaluated per batch; each one contributes one appended column.
    expressions: Vec<Expression>,
    /// Names handed to `strip_udf_columns` before a batch is returned; set in `initialize`.
    udf_names: Vec<String>,
    /// Cached output headers; `None` until `next` has populated them.
    headers: Option<ColumnHeaders>,
    /// Snapshot of the query context plus the compiled form of `expressions`;
    /// `None` until `initialize` has run.
    context: Option<(Arc<QueryContext>, Vec<CompiledExpr>)>,
}
38
39impl ExtendNode {
40 pub fn new(input: Box<dyn QueryNode>, expressions: Vec<Expression>) -> Self {
41 Self {
42 input,
43 expressions,
44 udf_names: Vec::new(),
45 headers: None,
46 context: None,
47 }
48 }
49}
50
51impl QueryNode for ExtendNode {
52 #[instrument(name = "volcano::extend::initialize", level = "trace", skip_all)]
53 fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
54 let (input, expressions, udf_names) = UdfEvalNode::wrap_if_needed(
55 mem::replace(&mut self.input, Box::new(NoopNode)),
56 &self.expressions,
57 &ctx.symbols,
58 );
59 self.input = input;
60 self.expressions = expressions;
61 self.udf_names = udf_names;
62
63 let compile_ctx = CompileContext {
64 symbols: &ctx.symbols,
65 };
66 let compiled = self
67 .expressions
68 .iter()
69 .map(|e| compile_expression(&compile_ctx, e).expect("compile"))
70 .collect();
71 self.context = Some((Arc::new(ctx.clone()), compiled));
72 self.input.initialize(rx, ctx)?;
73 Ok(())
74 }
75
76 #[instrument(name = "volcano::extend::next", level = "trace", skip_all)]
77 fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
78 debug_assert!(self.context.is_some(), "ExtendNode::next() called before initialize()");
79
80 if let Some(columns) = self.input.next(rx, ctx)? {
81 let stored_ctx = &self.context.as_ref().unwrap().0;
82 let transform_ctx = TransformContext {
83 routines: &ctx.services.routines,
84 runtime_context: &stored_ctx.services.runtime_context,
85 params: &stored_ctx.params,
86 };
87 let result = self.apply(&transform_ctx, columns)?;
88
89 if self.headers.is_none() {
90 let mut all_headers = if let Some(input_headers) = self.input.headers() {
91 input_headers.columns.clone()
92 } else {
93 let input_column_count = result.len() - self.expressions.len();
94 result.iter().take(input_column_count).map(|c| c.name().clone()).collect()
95 };
96
97 let new_names: Vec<Fragment> = self.expressions.iter().map(display_label).collect();
98 all_headers.extend(new_names);
99
100 self.headers = Some(ColumnHeaders {
101 columns: all_headers,
102 });
103 }
104
105 let mut result = result;
106 strip_udf_columns(&mut result, &self.udf_names);
107 return Ok(Some(result));
108 }
109 if self.headers.is_none()
110 && let Some(input_headers) = self.input.headers()
111 {
112 let mut all_headers = input_headers.columns.clone();
113 let new_names: Vec<Fragment> = self.expressions.iter().map(display_label).collect();
114
115 for new_name in &new_names {
116 for existing_name in &all_headers {
117 if new_name.text() == existing_name.text() {
118 return_error!(extend_duplicate_column(new_name.text()));
119 }
120 }
121 }
122 for i in 0..new_names.len() {
123 for j in (i + 1)..new_names.len() {
124 if new_names[i].text() == new_names[j].text() {
125 return_error!(extend_duplicate_column(new_names[i].text()));
126 }
127 }
128 }
129
130 all_headers.extend(new_names);
131 self.headers = Some(ColumnHeaders {
132 columns: all_headers,
133 });
134 }
135 Ok(None)
136 }
137
138 fn headers(&self) -> Option<ColumnHeaders> {
139 self.headers.clone().or(self.input.headers())
140 }
141}
142
143impl Transform for ExtendNode {
144 fn apply(&self, ctx: &TransformContext, input: Columns) -> Result<Columns> {
145 let (stored_ctx, compiled) =
146 self.context.as_ref().expect("ExtendNode::apply() called before initialize()");
147
148 let row_count = input.row_count();
149 let row_numbers = input.row_numbers.to_vec();
150 let created_at = input.created_at.clone();
151 let updated_at = input.updated_at.clone();
152
153 let existing_names: Vec<Fragment> = input.iter().map(|c| c.name().clone()).collect();
155
156 let session = EvalContext::from_transform(ctx, stored_ctx);
157 let mut new_columns: Vec<ColumnWithName> = input
158 .names
159 .iter()
160 .zip(input.columns.iter())
161 .map(|(name, data)| ColumnWithName::new(name.clone(), data.clone()))
162 .collect();
163
164 let mut new_names = Vec::with_capacity(compiled.len());
165 for (expr, compiled_expr) in self.expressions.iter().zip(compiled.iter()) {
166 let mut exec_ctx = session.with_eval(Columns::new(new_columns.clone()), row_count);
167
168 if let (Expression::Alias(alias_expr), Some(source)) = (expr, &stored_ctx.source) {
169 let alias_name = alias_expr.alias.name();
170 if let Some(table_column) = source.columns().iter().find(|col| col.name == alias_name) {
171 let column_ident = Fragment::internal(&table_column.name);
172 let resolved_column =
173 ResolvedColumn::new(column_ident, source.clone(), table_column.clone());
174 exec_ctx.target = Some(TargetColumn::Resolved(resolved_column));
175 }
176 }
177
178 let mut column = compiled_expr.execute(&exec_ctx)?;
179
180 if let Some(target_type) = exec_ctx.target.as_ref().map(|t| t.column_type())
181 && column.data.get_type() != target_type
182 {
183 let data =
184 cast_column_data(&exec_ctx, &column.data, target_type, &expr.lazy_fragment())?;
185 column = ColumnWithName {
186 name: column.name,
187 data,
188 };
189 }
190
191 new_columns.push(column);
192 new_names.push(display_label(expr));
193 }
194
195 for new_name in &new_names {
197 for existing_name in &existing_names {
198 if new_name.text() == existing_name.text() {
199 return_error!(extend_duplicate_column(new_name.text()));
200 }
201 }
202 }
203
204 for i in 0..new_names.len() {
206 for j in (i + 1)..new_names.len() {
207 if new_names[i].text() == new_names[j].text() {
208 return_error!(extend_duplicate_column(new_names[i].text()));
209 }
210 }
211 }
212
213 let mut names_vec = Vec::with_capacity(new_columns.len());
214 let mut buffers_vec = Vec::with_capacity(new_columns.len());
215 for c in new_columns {
216 names_vec.push(c.name);
217 buffers_vec.push(c.data);
218 }
219 Ok(Columns {
220 row_numbers: CowVec::new(row_numbers),
221 created_at,
222 updated_at,
223 columns: CowVec::new(buffers_vec),
224 names: CowVec::new(names_vec),
225 })
226 }
227}
228
/// Query node that evaluates `expressions` without any upstream input and
/// emits the results as a single batch.
pub(crate) struct ExtendWithoutInputNode {
    /// Expressions to evaluate; each one yields one output column.
    expressions: Vec<Expression>,
    /// Output headers; set once the single batch has been produced, after
    /// which `next` returns `None`.
    headers: Option<ColumnHeaders>,
    /// Columns returned by `evaluate_udfs_no_input` during `initialize`, if any;
    /// used as the evaluation input for the expressions.
    udf_columns: Option<Columns>,
    /// Snapshot of the query context plus the compiled form of `expressions`;
    /// `None` until `initialize` has run.
    context: Option<(Arc<QueryContext>, Vec<CompiledExpr>)>,
}
236
237impl ExtendWithoutInputNode {
238 pub fn new(expressions: Vec<Expression>) -> Self {
239 Self {
240 expressions,
241 headers: None,
242 udf_columns: None,
243 context: None,
244 }
245 }
246}
247
248impl QueryNode for ExtendWithoutInputNode {
249 #[instrument(name = "volcano::extend::noinput::initialize", level = "trace", skip_all)]
250 fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
251 if let Some((rewritten, udf_cols)) = evaluate_udfs_no_input(&self.expressions, ctx, rx)? {
253 self.expressions = rewritten;
254 self.udf_columns = Some(udf_cols);
255 }
256
257 let compile_ctx = CompileContext {
258 symbols: &ctx.symbols,
259 };
260 let compiled = self
261 .expressions
262 .iter()
263 .map(|e| compile_expression(&compile_ctx, e).expect("compile"))
264 .collect();
265 self.context = Some((Arc::new(ctx.clone()), compiled));
266 Ok(())
267 }
268
269 #[instrument(name = "volcano::extend::noinput::next", level = "trace", skip_all)]
270 fn next<'a>(&mut self, _rx: &mut Transaction<'a>, _ctx: &mut QueryContext) -> Result<Option<Columns>> {
271 debug_assert!(self.context.is_some(), "ExtendWithoutInputNode::next() called before initialize()");
272 let (stored_ctx, compiled) = self.context.as_ref().unwrap();
273
274 if self.headers.is_some() {
275 return Ok(None);
276 }
277
278 let session = EvalContext::from_query(stored_ctx);
279 let mut new_columns = Vec::with_capacity(self.expressions.len());
280
281 for compiled_expr in compiled {
282 let exec_ctx = match &self.udf_columns {
284 Some(udf_cols) => session.with_eval(udf_cols.clone(), 1),
285 None => session.with_eval_empty(),
286 };
287
288 let column = compiled_expr.execute(&exec_ctx)?;
289 new_columns.push(column);
290 }
291
292 let column_names: Vec<Fragment> = self.expressions.iter().map(display_label).collect();
293
294 for i in 0..column_names.len() {
296 for j in (i + 1)..column_names.len() {
297 if column_names[i].text() == column_names[j].text() {
298 return_error!(extend_duplicate_column(column_names[i].text()));
299 }
300 }
301 }
302
303 self.headers = Some(ColumnHeaders {
304 columns: column_names,
305 });
306
307 Ok(Some(Columns::new(new_columns)))
308 }
309
310 fn headers(&self) -> Option<ColumnHeaders> {
311 self.headers.clone()
312 }
313}