reifydb_engine/vm/volcano/
patch.rs1use std::{mem, sync::Arc};
5
6use reifydb_core::{
7 interface::{evaluate::TargetColumn, resolved::ResolvedColumn},
8 value::column::{ColumnWithName, columns::Columns, headers::ColumnHeaders},
9};
10use reifydb_extension::transform::{Transform, context::TransformContext};
11use reifydb_rql::expression::{Expression, name::display_label};
12use reifydb_transaction::transaction::Transaction;
13use reifydb_type::{fragment::Fragment, util::cowvec::CowVec};
14use tracing::instrument;
15
16use super::NoopNode;
17use crate::{
18 Result,
19 expression::{
20 cast::cast_column_data,
21 compile::{CompiledExpr, compile_expression},
22 context::{CompileContext, EvalContext},
23 },
24 vm::volcano::{
25 query::{QueryContext, QueryNode},
26 udf::{UdfEvalNode, strip_udf_columns},
27 },
28};
29
/// Volcano operator that evaluates patch expressions on every batch pulled
/// from its child and merges the results into that batch: columns whose name
/// matches a patch label are replaced, remaining patch results are appended.
pub(crate) struct PatchNode {
	/// Child operator. After `initialize()` this is whatever
	/// `UdfEvalNode::wrap_if_needed` returned (possibly a UDF-evaluating wrapper
	/// around the original child).
	input: Box<dyn QueryNode>,
	/// Patch expressions; replaced in `initialize()` by the expressions returned
	/// from `UdfEvalNode::wrap_if_needed`.
	expressions: Vec<Expression>,
	/// Helper-column names handed to `strip_udf_columns` so they are removed from
	/// the batches returned by `next()`; empty until `initialize()` runs.
	udf_names: Vec<String>,
	/// Output headers cached from the first batch produced by `next()`.
	headers: Option<ColumnHeaders>,
	/// Clone of the query context plus one compiled expression per entry in
	/// `expressions`; `None` until `initialize()` runs.
	context: Option<(Arc<QueryContext>, Vec<CompiledExpr>)>,
}
40
41impl PatchNode {
42 pub fn new(input: Box<dyn QueryNode>, expressions: Vec<Expression>) -> Self {
43 Self {
44 input,
45 expressions,
46 udf_names: Vec::new(),
47 headers: None,
48 context: None,
49 }
50 }
51}
52
53impl QueryNode for PatchNode {
54 #[instrument(name = "volcano::patch::initialize", level = "trace", skip_all)]
55 fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
56 let (input, expressions, udf_names) = UdfEvalNode::wrap_if_needed(
57 mem::replace(&mut self.input, Box::new(NoopNode)),
58 &self.expressions,
59 &ctx.symbols,
60 );
61 self.input = input;
62 self.expressions = expressions;
63 self.udf_names = udf_names;
64
65 let compile_ctx = CompileContext {
66 symbols: &ctx.symbols,
67 };
68 let compiled = self
69 .expressions
70 .iter()
71 .map(|e| compile_expression(&compile_ctx, e).expect("compile"))
72 .collect();
73 self.context = Some((Arc::new(ctx.clone()), compiled));
74 self.input.initialize(rx, ctx)?;
75 Ok(())
76 }
77
78 #[instrument(name = "volcano::patch::next", level = "trace", skip_all)]
79 fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
80 debug_assert!(self.context.is_some(), "PatchNode::next() called before initialize()");
81
82 if let Some(columns) = self.input.next(rx, ctx)? {
83 let stored_ctx = &self.context.as_ref().unwrap().0;
84 let transform_ctx = TransformContext {
85 routines: &ctx.services.routines,
86 runtime_context: &stored_ctx.services.runtime_context,
87 params: &stored_ctx.params,
88 };
89 let result = self.apply(&transform_ctx, columns)?;
90
91 if self.headers.is_none() {
92 let result_headers: Vec<Fragment> = result.iter().map(|c| c.name().clone()).collect();
93 self.headers = Some(ColumnHeaders {
94 columns: result_headers,
95 });
96 }
97
98 let mut result = result;
99 strip_udf_columns(&mut result, &self.udf_names);
100 Ok(Some(result))
101 } else {
102 Ok(None)
103 }
104 }
105
106 fn headers(&self) -> Option<ColumnHeaders> {
107 if let Some(ref headers) = self.headers {
108 return Some(headers.clone());
109 }
110
111 let input_headers = self.input.headers()?;
112 let patch_names: Vec<Fragment> = self.expressions.iter().map(display_label).collect();
113
114 let mut result = Vec::new();
115 for col in &input_headers.columns {
116 if let Some(patch_idx) = patch_names.iter().position(|n| n.text() == col.text()) {
117 result.push(patch_names[patch_idx].clone());
118 } else {
119 result.push(col.clone());
120 }
121 }
122
123 for patch_name in &patch_names {
124 if !result.iter().any(|h| h.text() == patch_name.text()) {
125 result.push(patch_name.clone());
126 }
127 }
128
129 Some(ColumnHeaders {
130 columns: result,
131 })
132 }
133}
134
135impl Transform for PatchNode {
136 fn apply(&self, ctx: &TransformContext, input: Columns) -> Result<Columns> {
137 let (stored_ctx, compiled) =
138 self.context.as_ref().expect("PatchNode::apply() called before initialize()");
139
140 let row_count = input.row_count();
141 let row_numbers = input.row_numbers.to_vec();
142 let created_at = input.created_at.clone();
143 let updated_at = input.updated_at.clone();
144
145 let patch_names: Vec<Fragment> = self.expressions.iter().map(display_label).collect();
146
147 let session = EvalContext::from_transform(ctx, stored_ctx);
148 let mut patch_columns = Vec::with_capacity(self.expressions.len());
149 for (expr, compiled_expr) in self.expressions.iter().zip(compiled.iter()) {
150 let mut exec_ctx = session.with_eval(input.clone(), row_count);
151
152 if let (Expression::Alias(alias_expr), Some(source)) = (expr, &stored_ctx.source) {
153 let alias_name = alias_expr.alias.name();
154
155 if let Some(table_column) = source.columns().iter().find(|col| col.name == alias_name) {
156 let column_ident = Fragment::internal(&table_column.name);
157 let resolved_column =
158 ResolvedColumn::new(column_ident, source.clone(), table_column.clone());
159 exec_ctx.target = Some(TargetColumn::Resolved(resolved_column));
160 }
161 }
162
163 let mut column = compiled_expr.execute(&exec_ctx)?;
164
165 if let Some(target_type) = exec_ctx.target.as_ref().map(|t| t.column_type())
166 && column.data.get_type() != target_type
167 {
168 let data =
169 cast_column_data(&exec_ctx, &column.data, target_type, &expr.lazy_fragment())?;
170 column = ColumnWithName {
171 name: column.name,
172 data,
173 };
174 }
175
176 patch_columns.push(column);
177 }
178
179 let mut result_columns: Vec<ColumnWithName> = Vec::new();
180
181 for (original_name, original_data) in input.names.iter().zip(input.columns.iter()) {
182 let original_name_text = original_name.text();
183
184 if let Some(patch_idx) = patch_names.iter().position(|n| n.text() == original_name_text) {
185 result_columns.push(patch_columns[patch_idx].clone());
186 } else {
187 result_columns.push(ColumnWithName::new(original_name.clone(), original_data.clone()));
188 }
189 }
190
191 for (patch_idx, patch_name) in patch_names.iter().enumerate() {
192 if !result_columns.iter().any(|c| c.name().text() == patch_name.text()) {
193 result_columns.push(patch_columns[patch_idx].clone());
194 }
195 }
196
197 let mut names_vec = Vec::with_capacity(result_columns.len());
198 let mut buffers_vec = Vec::with_capacity(result_columns.len());
199 for c in result_columns {
200 names_vec.push(c.name);
201 buffers_vec.push(c.data);
202 }
203 Ok(Columns {
204 row_numbers: CowVec::new(row_numbers),
205 created_at,
206 updated_at,
207 columns: CowVec::new(buffers_vec),
208 names: CowVec::new(names_vec),
209 })
210 }
211}