reifydb_engine/vm/volcano/
extend.rs1use std::{mem, sync::Arc};
5
6use reifydb_core::{
7 error::diagnostic::query::extend_duplicate_column,
8 interface::{evaluate::TargetColumn, resolved::ResolvedColumn},
9 value::column::{ColumnWithName, columns::Columns, headers::ColumnHeaders},
10};
11use reifydb_extension::transform::{Transform, context::TransformContext};
12use reifydb_rql::expression::{Expression, name::display_label};
13use reifydb_transaction::transaction::Transaction;
14use reifydb_type::{fragment::Fragment, return_error, util::cowvec::CowVec};
15use tracing::instrument;
16
17use super::NoopNode;
18use crate::{
19 Result,
20 expression::{
21 cast::cast_column_data,
22 compile::{CompiledExpr, compile_expression},
23 context::{CompileContext, EvalContext},
24 },
25 vm::volcano::{
26 query::{QueryContext, QueryNode},
27 udf::{UdfEvalNode, evaluate_udfs_no_input, strip_udf_columns},
28 },
29};
30
31pub(crate) struct ExtendNode {
32 input: Box<dyn QueryNode>,
33 expressions: Vec<Expression>,
34 udf_names: Vec<String>,
35 headers: Option<ColumnHeaders>,
36 context: Option<(Arc<QueryContext>, Vec<CompiledExpr>)>,
37}
38
39impl ExtendNode {
40 pub fn new(input: Box<dyn QueryNode>, expressions: Vec<Expression>) -> Self {
41 Self {
42 input,
43 expressions,
44 udf_names: Vec::new(),
45 headers: None,
46 context: None,
47 }
48 }
49}
50
51impl QueryNode for ExtendNode {
52 #[instrument(name = "volcano::extend::initialize", level = "trace", skip_all)]
53 fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
54 let (input, expressions, udf_names) = UdfEvalNode::wrap_if_needed(
55 mem::replace(&mut self.input, Box::new(NoopNode)),
56 &self.expressions,
57 &ctx.symbols,
58 );
59 self.input = input;
60 self.expressions = expressions;
61 self.udf_names = udf_names;
62
63 let compile_ctx = CompileContext {
64 symbols: &ctx.symbols,
65 };
66 let compiled = self
67 .expressions
68 .iter()
69 .map(|e| compile_expression(&compile_ctx, e).expect("compile"))
70 .collect();
71 self.context = Some((Arc::new(ctx.clone()), compiled));
72 self.input.initialize(rx, ctx)?;
73 Ok(())
74 }
75
76 #[instrument(name = "volcano::extend::next", level = "trace", skip_all)]
77 fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
78 debug_assert!(self.context.is_some(), "ExtendNode::next() called before initialize()");
79
80 if let Some(columns) = self.input.next(rx, ctx)? {
81 let stored_ctx = &self.context.as_ref().unwrap().0;
82 let transform_ctx = TransformContext {
83 routines: &ctx.services.routines,
84 runtime_context: &stored_ctx.services.runtime_context,
85 params: &stored_ctx.params,
86 };
87 let result = self.apply(&transform_ctx, columns)?;
88
89 if self.headers.is_none() {
90 let mut all_headers = if let Some(input_headers) = self.input.headers() {
91 input_headers.columns.clone()
92 } else {
93 let input_column_count = result.len() - self.expressions.len();
94 result.iter().take(input_column_count).map(|c| c.name().clone()).collect()
95 };
96
97 let new_names: Vec<Fragment> = self.expressions.iter().map(display_label).collect();
98 all_headers.extend(new_names);
99
100 self.headers = Some(ColumnHeaders {
101 columns: all_headers,
102 });
103 }
104
105 let mut result = result;
106 strip_udf_columns(&mut result, &self.udf_names);
107 return Ok(Some(result));
108 }
109 if self.headers.is_none()
110 && let Some(input_headers) = self.input.headers()
111 {
112 let mut all_headers = input_headers.columns.clone();
113 let new_names: Vec<Fragment> = self.expressions.iter().map(display_label).collect();
114
115 for new_name in &new_names {
116 for existing_name in &all_headers {
117 if new_name.text() == existing_name.text() {
118 return_error!(extend_duplicate_column(new_name.text()));
119 }
120 }
121 }
122 for i in 0..new_names.len() {
123 for j in (i + 1)..new_names.len() {
124 if new_names[i].text() == new_names[j].text() {
125 return_error!(extend_duplicate_column(new_names[i].text()));
126 }
127 }
128 }
129
130 all_headers.extend(new_names);
131 self.headers = Some(ColumnHeaders {
132 columns: all_headers,
133 });
134 }
135 Ok(None)
136 }
137
138 fn headers(&self) -> Option<ColumnHeaders> {
139 self.headers.clone().or(self.input.headers())
140 }
141}
142
143impl Transform for ExtendNode {
144 fn apply(&self, ctx: &TransformContext, input: Columns) -> Result<Columns> {
145 let (stored_ctx, compiled) =
146 self.context.as_ref().expect("ExtendNode::apply() called before initialize()");
147
148 let row_count = input.row_count();
149 let row_numbers = input.row_numbers.to_vec();
150 let created_at = input.created_at.clone();
151 let updated_at = input.updated_at.clone();
152
153 let existing_names: Vec<Fragment> = input.iter().map(|c| c.name().clone()).collect();
154
155 let session = EvalContext::from_transform(ctx, stored_ctx);
156 let mut new_columns: Vec<ColumnWithName> = input
157 .names
158 .iter()
159 .zip(input.columns.iter())
160 .map(|(name, data)| ColumnWithName::new(name.clone(), data.clone()))
161 .collect();
162
163 let mut new_names = Vec::with_capacity(compiled.len());
164 for (expr, compiled_expr) in self.expressions.iter().zip(compiled.iter()) {
165 let mut exec_ctx = session.with_eval(Columns::new(new_columns.clone()), row_count);
166
167 if let (Expression::Alias(alias_expr), Some(source)) = (expr, &stored_ctx.source) {
168 let alias_name = alias_expr.alias.name();
169 if let Some(table_column) = source.columns().iter().find(|col| col.name == alias_name) {
170 let column_ident = Fragment::internal(&table_column.name);
171 let resolved_column =
172 ResolvedColumn::new(column_ident, source.clone(), table_column.clone());
173 exec_ctx.target = Some(TargetColumn::Resolved(resolved_column));
174 }
175 }
176
177 let mut column = compiled_expr.execute(&exec_ctx)?;
178
179 if let Some(target_type) = exec_ctx.target.as_ref().map(|t| t.column_type())
180 && column.data.get_type() != target_type
181 {
182 let data =
183 cast_column_data(&exec_ctx, &column.data, target_type, &expr.lazy_fragment())?;
184 column = ColumnWithName {
185 name: column.name,
186 data,
187 };
188 }
189
190 new_columns.push(column);
191 new_names.push(display_label(expr));
192 }
193
194 for new_name in &new_names {
195 for existing_name in &existing_names {
196 if new_name.text() == existing_name.text() {
197 return_error!(extend_duplicate_column(new_name.text()));
198 }
199 }
200 }
201
202 for i in 0..new_names.len() {
203 for j in (i + 1)..new_names.len() {
204 if new_names[i].text() == new_names[j].text() {
205 return_error!(extend_duplicate_column(new_names[i].text()));
206 }
207 }
208 }
209
210 let mut names_vec = Vec::with_capacity(new_columns.len());
211 let mut buffers_vec = Vec::with_capacity(new_columns.len());
212 for c in new_columns {
213 names_vec.push(c.name);
214 buffers_vec.push(c.data);
215 }
216 Ok(Columns {
217 row_numbers: CowVec::new(row_numbers),
218 created_at,
219 updated_at,
220 columns: CowVec::new(buffers_vec),
221 names: CowVec::new(names_vec),
222 })
223 }
224}
225
226pub(crate) struct ExtendWithoutInputNode {
227 expressions: Vec<Expression>,
228 headers: Option<ColumnHeaders>,
229
230 udf_columns: Option<Columns>,
231 context: Option<(Arc<QueryContext>, Vec<CompiledExpr>)>,
232}
233
234impl ExtendWithoutInputNode {
235 pub fn new(expressions: Vec<Expression>) -> Self {
236 Self {
237 expressions,
238 headers: None,
239 udf_columns: None,
240 context: None,
241 }
242 }
243}
244
245impl QueryNode for ExtendWithoutInputNode {
246 #[instrument(name = "volcano::extend::noinput::initialize", level = "trace", skip_all)]
247 fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
248 if let Some((rewritten, udf_cols)) = evaluate_udfs_no_input(&self.expressions, ctx, rx)? {
249 self.expressions = rewritten;
250 self.udf_columns = Some(udf_cols);
251 }
252
253 let compile_ctx = CompileContext {
254 symbols: &ctx.symbols,
255 };
256 let compiled = self
257 .expressions
258 .iter()
259 .map(|e| compile_expression(&compile_ctx, e).expect("compile"))
260 .collect();
261 self.context = Some((Arc::new(ctx.clone()), compiled));
262 Ok(())
263 }
264
265 #[instrument(name = "volcano::extend::noinput::next", level = "trace", skip_all)]
266 fn next<'a>(&mut self, _rx: &mut Transaction<'a>, _ctx: &mut QueryContext) -> Result<Option<Columns>> {
267 debug_assert!(self.context.is_some(), "ExtendWithoutInputNode::next() called before initialize()");
268 let (stored_ctx, compiled) = self.context.as_ref().unwrap();
269
270 if self.headers.is_some() {
271 return Ok(None);
272 }
273
274 let session = EvalContext::from_query(stored_ctx);
275 let mut new_columns = Vec::with_capacity(self.expressions.len());
276
277 for compiled_expr in compiled {
278 let exec_ctx = match &self.udf_columns {
279 Some(udf_cols) => session.with_eval(udf_cols.clone(), 1),
280 None => session.with_eval_empty(),
281 };
282
283 let column = compiled_expr.execute(&exec_ctx)?;
284 new_columns.push(column);
285 }
286
287 let column_names: Vec<Fragment> = self.expressions.iter().map(display_label).collect();
288
289 for i in 0..column_names.len() {
290 for j in (i + 1)..column_names.len() {
291 if column_names[i].text() == column_names[j].text() {
292 return_error!(extend_duplicate_column(column_names[i].text()));
293 }
294 }
295 }
296
297 self.headers = Some(ColumnHeaders {
298 columns: column_names,
299 });
300
301 Ok(Some(Columns::new(new_columns)))
302 }
303
304 fn headers(&self) -> Option<ColumnHeaders> {
305 self.headers.clone()
306 }
307}