reifydb_engine/vm/volcano/
patch.rs1use std::{mem, sync::Arc};
5
6use reifydb_core::{
7 interface::{evaluate::TargetColumn, resolved::ResolvedColumn},
8 value::column::{ColumnWithName, columns::Columns, headers::ColumnHeaders},
9};
10use reifydb_extension::transform::{Transform, context::TransformContext};
11use reifydb_rql::expression::{Expression, name::display_label};
12use reifydb_transaction::transaction::Transaction;
13use reifydb_type::{fragment::Fragment, util::cowvec::CowVec};
14use tracing::instrument;
15
16use super::NoopNode;
17use crate::{
18 Result,
19 expression::{
20 cast::cast_column_data,
21 compile::{CompiledExpr, compile_expression},
22 context::{CompileContext, EvalContext},
23 },
24 vm::volcano::{
25 query::{QueryContext, QueryNode},
26 udf::{UdfEvalNode, strip_udf_columns},
27 },
28};
29
30pub(crate) struct PatchNode {
31 input: Box<dyn QueryNode>,
32 expressions: Vec<Expression>,
33 udf_names: Vec<String>,
34 headers: Option<ColumnHeaders>,
35 context: Option<(Arc<QueryContext>, Vec<CompiledExpr>)>,
36}
37
38impl PatchNode {
39 pub fn new(input: Box<dyn QueryNode>, expressions: Vec<Expression>) -> Self {
40 Self {
41 input,
42 expressions,
43 udf_names: Vec::new(),
44 headers: None,
45 context: None,
46 }
47 }
48}
49
50impl QueryNode for PatchNode {
51 #[instrument(name = "volcano::patch::initialize", level = "trace", skip_all)]
52 fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
53 let (input, expressions, udf_names) = UdfEvalNode::wrap_if_needed(
54 mem::replace(&mut self.input, Box::new(NoopNode)),
55 &self.expressions,
56 &ctx.symbols,
57 );
58 self.input = input;
59 self.expressions = expressions;
60 self.udf_names = udf_names;
61
62 let compile_ctx = CompileContext {
63 symbols: &ctx.symbols,
64 };
65 let compiled = self
66 .expressions
67 .iter()
68 .map(|e| compile_expression(&compile_ctx, e).expect("compile"))
69 .collect();
70 self.context = Some((Arc::new(ctx.clone()), compiled));
71 self.input.initialize(rx, ctx)?;
72 Ok(())
73 }
74
75 #[instrument(name = "volcano::patch::next", level = "trace", skip_all)]
76 fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
77 debug_assert!(self.context.is_some(), "PatchNode::next() called before initialize()");
78
79 if let Some(columns) = self.input.next(rx, ctx)? {
80 let stored_ctx = &self.context.as_ref().unwrap().0;
81 let transform_ctx = TransformContext {
82 routines: &ctx.services.routines,
83 runtime_context: &stored_ctx.services.runtime_context,
84 params: &stored_ctx.params,
85 };
86 let result = self.apply(&transform_ctx, columns)?;
87
88 if self.headers.is_none() {
89 let result_headers: Vec<Fragment> = result.iter().map(|c| c.name().clone()).collect();
90 self.headers = Some(ColumnHeaders {
91 columns: result_headers,
92 });
93 }
94
95 let mut result = result;
96 strip_udf_columns(&mut result, &self.udf_names);
97 Ok(Some(result))
98 } else {
99 Ok(None)
100 }
101 }
102
103 fn headers(&self) -> Option<ColumnHeaders> {
104 if let Some(ref headers) = self.headers {
105 return Some(headers.clone());
106 }
107
108 let input_headers = self.input.headers()?;
109 let patch_names: Vec<Fragment> = self.expressions.iter().map(display_label).collect();
110
111 let mut result = Vec::new();
112 for col in &input_headers.columns {
113 if let Some(patch_idx) = patch_names.iter().position(|n| n.text() == col.text()) {
114 result.push(patch_names[patch_idx].clone());
115 } else {
116 result.push(col.clone());
117 }
118 }
119
120 for patch_name in &patch_names {
121 if !result.iter().any(|h| h.text() == patch_name.text()) {
122 result.push(patch_name.clone());
123 }
124 }
125
126 Some(ColumnHeaders {
127 columns: result,
128 })
129 }
130}
131
132impl Transform for PatchNode {
133 fn apply(&self, ctx: &TransformContext, input: Columns) -> Result<Columns> {
134 let (stored_ctx, compiled) =
135 self.context.as_ref().expect("PatchNode::apply() called before initialize()");
136
137 let row_count = input.row_count();
138 let row_numbers = input.row_numbers.to_vec();
139 let created_at = input.created_at.clone();
140 let updated_at = input.updated_at.clone();
141
142 let patch_names: Vec<Fragment> = self.expressions.iter().map(display_label).collect();
143
144 let session = EvalContext::from_transform(ctx, stored_ctx);
145 let mut patch_columns = Vec::with_capacity(self.expressions.len());
146 for (expr, compiled_expr) in self.expressions.iter().zip(compiled.iter()) {
147 let mut exec_ctx = session.with_eval(input.clone(), row_count);
148
149 if let (Expression::Alias(alias_expr), Some(source)) = (expr, &stored_ctx.source) {
150 let alias_name = alias_expr.alias.name();
151
152 if let Some(table_column) = source.columns().iter().find(|col| col.name == alias_name) {
153 let column_ident = Fragment::internal(&table_column.name);
154 let resolved_column =
155 ResolvedColumn::new(column_ident, source.clone(), table_column.clone());
156 exec_ctx.target = Some(TargetColumn::Resolved(resolved_column));
157 }
158 }
159
160 let mut column = compiled_expr.execute(&exec_ctx)?;
161
162 if let Some(target_type) = exec_ctx.target.as_ref().map(|t| t.column_type())
163 && column.data.get_type() != target_type
164 {
165 let data =
166 cast_column_data(&exec_ctx, &column.data, target_type, &expr.lazy_fragment())?;
167 column = ColumnWithName {
168 name: column.name,
169 data,
170 };
171 }
172
173 patch_columns.push(column);
174 }
175
176 let mut result_columns: Vec<ColumnWithName> = Vec::new();
177
178 for (original_name, original_data) in input.names.iter().zip(input.columns.iter()) {
179 let original_name_text = original_name.text();
180
181 if let Some(patch_idx) = patch_names.iter().position(|n| n.text() == original_name_text) {
182 result_columns.push(patch_columns[patch_idx].clone());
183 } else {
184 result_columns.push(ColumnWithName::new(original_name.clone(), original_data.clone()));
185 }
186 }
187
188 for (patch_idx, patch_name) in patch_names.iter().enumerate() {
189 if !result_columns.iter().any(|c| c.name().text() == patch_name.text()) {
190 result_columns.push(patch_columns[patch_idx].clone());
191 }
192 }
193
194 let mut names_vec = Vec::with_capacity(result_columns.len());
195 let mut buffers_vec = Vec::with_capacity(result_columns.len());
196 for c in result_columns {
197 names_vec.push(c.name);
198 buffers_vec.push(c.data);
199 }
200 Ok(Columns {
201 row_numbers: CowVec::new(row_numbers),
202 created_at,
203 updated_at,
204 columns: CowVec::new(buffers_vec),
205 names: CowVec::new(names_vec),
206 })
207 }
208}