Skip to main content

reifydb_engine/vm/volcano/
patch.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::{mem, sync::Arc};
5
6use reifydb_core::{
7	interface::{evaluate::TargetColumn, resolved::ResolvedColumn},
8	value::column::{ColumnWithName, columns::Columns, headers::ColumnHeaders},
9};
10use reifydb_extension::transform::{Transform, context::TransformContext};
11use reifydb_rql::expression::{Expression, name::display_label};
12use reifydb_transaction::transaction::Transaction;
13use reifydb_type::{fragment::Fragment, util::cowvec::CowVec};
14use tracing::instrument;
15
16use super::NoopNode;
17use crate::{
18	Result,
19	expression::{
20		cast::cast_column_data,
21		compile::{CompiledExpr, compile_expression},
22		context::{CompileContext, EvalContext},
23	},
24	vm::volcano::{
25		query::{QueryContext, QueryNode},
26		udf::{UdfEvalNode, strip_udf_columns},
27	},
28};
29
30/// PatchNode merges assignment values with original row values.
31/// Unlike ExtendNode which adds new columns, PatchNode replaces
32/// columns that have matching names in the assignments.
33pub(crate) struct PatchNode {
34	input: Box<dyn QueryNode>,
35	expressions: Vec<Expression>,
36	udf_names: Vec<String>,
37	headers: Option<ColumnHeaders>,
38	context: Option<(Arc<QueryContext>, Vec<CompiledExpr>)>,
39}
40
41impl PatchNode {
42	pub fn new(input: Box<dyn QueryNode>, expressions: Vec<Expression>) -> Self {
43		Self {
44			input,
45			expressions,
46			udf_names: Vec::new(),
47			headers: None,
48			context: None,
49		}
50	}
51}
52
53impl QueryNode for PatchNode {
54	#[instrument(name = "volcano::patch::initialize", level = "trace", skip_all)]
55	fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
56		let (input, expressions, udf_names) = UdfEvalNode::wrap_if_needed(
57			mem::replace(&mut self.input, Box::new(NoopNode)),
58			&self.expressions,
59			&ctx.symbols,
60		);
61		self.input = input;
62		self.expressions = expressions;
63		self.udf_names = udf_names;
64
65		let compile_ctx = CompileContext {
66			symbols: &ctx.symbols,
67		};
68		let compiled = self
69			.expressions
70			.iter()
71			.map(|e| compile_expression(&compile_ctx, e).expect("compile"))
72			.collect();
73		self.context = Some((Arc::new(ctx.clone()), compiled));
74		self.input.initialize(rx, ctx)?;
75		Ok(())
76	}
77
78	#[instrument(name = "volcano::patch::next", level = "trace", skip_all)]
79	fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
80		debug_assert!(self.context.is_some(), "PatchNode::next() called before initialize()");
81
82		if let Some(columns) = self.input.next(rx, ctx)? {
83			let stored_ctx = &self.context.as_ref().unwrap().0;
84			let transform_ctx = TransformContext {
85				routines: &ctx.services.routines,
86				runtime_context: &stored_ctx.services.runtime_context,
87				params: &stored_ctx.params,
88			};
89			let result = self.apply(&transform_ctx, columns)?;
90
91			if self.headers.is_none() {
92				let result_headers: Vec<Fragment> = result.iter().map(|c| c.name().clone()).collect();
93				self.headers = Some(ColumnHeaders {
94					columns: result_headers,
95				});
96			}
97
98			let mut result = result;
99			strip_udf_columns(&mut result, &self.udf_names);
100			Ok(Some(result))
101		} else {
102			Ok(None)
103		}
104	}
105
106	fn headers(&self) -> Option<ColumnHeaders> {
107		if let Some(ref headers) = self.headers {
108			return Some(headers.clone());
109		}
110
111		let input_headers = self.input.headers()?;
112		let patch_names: Vec<Fragment> = self.expressions.iter().map(display_label).collect();
113
114		let mut result = Vec::new();
115		for col in &input_headers.columns {
116			if let Some(patch_idx) = patch_names.iter().position(|n| n.text() == col.text()) {
117				result.push(patch_names[patch_idx].clone());
118			} else {
119				result.push(col.clone());
120			}
121		}
122
123		for patch_name in &patch_names {
124			if !result.iter().any(|h| h.text() == patch_name.text()) {
125				result.push(patch_name.clone());
126			}
127		}
128
129		Some(ColumnHeaders {
130			columns: result,
131		})
132	}
133}
134
135impl Transform for PatchNode {
136	fn apply(&self, ctx: &TransformContext, input: Columns) -> Result<Columns> {
137		let (stored_ctx, compiled) =
138			self.context.as_ref().expect("PatchNode::apply() called before initialize()");
139
140		let row_count = input.row_count();
141		let row_numbers = input.row_numbers.to_vec();
142		let created_at = input.created_at.clone();
143		let updated_at = input.updated_at.clone();
144
145		let patch_names: Vec<Fragment> = self.expressions.iter().map(display_label).collect();
146
147		let session = EvalContext::from_transform(ctx, stored_ctx);
148		let mut patch_columns = Vec::with_capacity(self.expressions.len());
149		for (expr, compiled_expr) in self.expressions.iter().zip(compiled.iter()) {
150			let mut exec_ctx = session.with_eval(input.clone(), row_count);
151
152			if let (Expression::Alias(alias_expr), Some(source)) = (expr, &stored_ctx.source) {
153				let alias_name = alias_expr.alias.name();
154
155				if let Some(table_column) = source.columns().iter().find(|col| col.name == alias_name) {
156					let column_ident = Fragment::internal(&table_column.name);
157					let resolved_column =
158						ResolvedColumn::new(column_ident, source.clone(), table_column.clone());
159					exec_ctx.target = Some(TargetColumn::Resolved(resolved_column));
160				}
161			}
162
163			let mut column = compiled_expr.execute(&exec_ctx)?;
164
165			if let Some(target_type) = exec_ctx.target.as_ref().map(|t| t.column_type())
166				&& column.data.get_type() != target_type
167			{
168				let data =
169					cast_column_data(&exec_ctx, &column.data, target_type, &expr.lazy_fragment())?;
170				column = ColumnWithName {
171					name: column.name,
172					data,
173				};
174			}
175
176			patch_columns.push(column);
177		}
178
179		let mut result_columns: Vec<ColumnWithName> = Vec::new();
180
181		for (original_name, original_data) in input.names.iter().zip(input.columns.iter()) {
182			let original_name_text = original_name.text();
183
184			if let Some(patch_idx) = patch_names.iter().position(|n| n.text() == original_name_text) {
185				result_columns.push(patch_columns[patch_idx].clone());
186			} else {
187				result_columns.push(ColumnWithName::new(original_name.clone(), original_data.clone()));
188			}
189		}
190
191		for (patch_idx, patch_name) in patch_names.iter().enumerate() {
192			if !result_columns.iter().any(|c| c.name().text() == patch_name.text()) {
193				result_columns.push(patch_columns[patch_idx].clone());
194			}
195		}
196
197		let mut names_vec = Vec::with_capacity(result_columns.len());
198		let mut buffers_vec = Vec::with_capacity(result_columns.len());
199		for c in result_columns {
200			names_vec.push(c.name);
201			buffers_vec.push(c.data);
202		}
203		Ok(Columns {
204			row_numbers: CowVec::new(row_numbers),
205			created_at,
206			updated_at,
207			columns: CowVec::new(buffers_vec),
208			names: CowVec::new(names_vec),
209		})
210	}
211}