Skip to main content

reifydb_engine/vm/volcano/
inline.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::{
5	collections::{BTreeSet, HashMap, HashSet},
6	mem,
7	sync::Arc,
8};
9
10use reifydb_core::{
11	interface::{catalog::sumtype::SumType, evaluate::TargetColumn, resolved::ResolvedShape},
12	value::column::{ColumnWithName, buffer::ColumnBuffer, columns::Columns, headers::ColumnHeaders},
13};
14use reifydb_rql::expression::{AliasExpression, ConstantExpression, Expression, IdentExpression};
15use reifydb_transaction::transaction::Transaction;
16use reifydb_type::{
17	fragment::Fragment,
18	value::{Value, constraint::Constraint, r#type::Type},
19};
20
21use crate::{
22	Result,
23	expression::{cast::cast_column_data, context::EvalContext, eval::evaluate},
24	vm::volcano::query::{QueryContext, QueryNode},
25};
26
27pub(crate) struct InlineDataNode {
28	rows: Vec<Vec<AliasExpression>>,
29	headers: Option<ColumnHeaders>,
30	context: Option<Arc<QueryContext>>,
31	executed: bool,
32}
33
34impl InlineDataNode {
35	pub fn new(rows: Vec<Vec<AliasExpression>>, context: Arc<QueryContext>) -> Self {
36		let cloned_context = context.clone();
37		let headers = cloned_context.source.as_ref().map(|source| {
38			let mut layout = Self::create_columns_layout_from_source(source);
39
40			if matches!(source, ResolvedShape::Series(_)) {
41				let existing: HashSet<String> =
42					layout.columns.iter().map(|c| c.text().to_string()).collect();
43				for row in &rows {
44					for alias in row {
45						let name = alias.alias.0.text().to_string();
46						if !existing.contains(&name) {
47							layout.columns.push(Fragment::internal(&name));
48						}
49					}
50				}
51			}
52			layout
53		});
54
55		Self {
56			rows,
57			headers,
58			context: Some(context),
59			executed: false,
60		}
61	}
62
63	fn create_columns_layout_from_source(source: &ResolvedShape) -> ColumnHeaders {
64		ColumnHeaders {
65			columns: source.columns().iter().map(|col| Fragment::internal(&col.name)).collect(),
66		}
67	}
68
69	fn expand_sumtype_constructors<'a>(&mut self, txn: &mut Transaction<'a>) -> Result<()> {
70		let Some(ctx) = self.context.as_ref().cloned() else {
71			return Ok(());
72		};
73		if !rows_need_sumtype_expansion(&self.rows) {
74			return Ok(());
75		}
76		for row in &mut self.rows {
77			let original = mem::take(row);
78			let mut expanded = Vec::with_capacity(original.len());
79			for alias_expr in original {
80				match alias_expr.expression.as_ref() {
81					Expression::SumTypeConstructor(_) => {
82						expand_sumtype_ctor(&ctx, txn, alias_expr, &mut expanded)?;
83					}
84					Expression::Column(_) => {
85						expand_unit_variant_column(&ctx, txn, alias_expr, &mut expanded)?;
86					}
87					_ => expanded.push(alias_expr),
88				}
89			}
90			*row = expanded;
91		}
92		Ok(())
93	}
94}
95
96#[inline]
97fn rows_need_sumtype_expansion(rows: &[Vec<AliasExpression>]) -> bool {
98	for row in rows {
99		for alias_expr in row {
100			if matches!(
101				alias_expr.expression.as_ref(),
102				Expression::SumTypeConstructor(_) | Expression::Column(_)
103			) {
104				return true;
105			}
106		}
107	}
108	false
109}
110
111fn expand_sumtype_ctor<'a>(
112	ctx: &Arc<QueryContext>,
113	txn: &mut Transaction<'a>,
114	alias_expr: AliasExpression,
115	expanded: &mut Vec<AliasExpression>,
116) -> Result<()> {
117	let col_name = alias_expr.alias.0.text().to_string();
118	let fragment = alias_expr.fragment.clone();
119
120	let Expression::SumTypeConstructor(ctor) = *alias_expr.expression else {
121		unreachable!()
122	};
123
124	let is_unresolved = ctor.namespace.text() == ctor.variant_name.text()
125		&& ctor.sumtype_name.text() == ctor.variant_name.text();
126
127	let sumtype = if is_unresolved {
128		resolve_unresolved_sumtype(ctx, txn, &col_name)?
129	} else {
130		let ns_name = ctor.namespace.text();
131		let ns = ctx.services.catalog.find_namespace_by_name(txn, ns_name)?.unwrap();
132		let sumtype_name = ctor.sumtype_name.text();
133		ctx.services.catalog.find_sumtype_by_name(txn, ns.id(), sumtype_name)?.unwrap()
134	};
135
136	let variant_name_lower = ctor.variant_name.text().to_lowercase();
137	let variant = sumtype.variants.iter().find(|v| v.name == variant_name_lower).unwrap();
138
139	expanded.push(AliasExpression {
140		alias: IdentExpression(Fragment::internal(format!("{}_tag", col_name))),
141		expression: Box::new(Expression::Constant(ConstantExpression::Number {
142			fragment: Fragment::internal(variant.tag.to_string()),
143		})),
144		fragment: fragment.clone(),
145	});
146
147	for (field_name, field_expr) in ctor.columns {
148		let phys_col_name = format!("{}_{}_{}", col_name, variant_name_lower, field_name.text().to_lowercase());
149		expanded.push(AliasExpression {
150			alias: IdentExpression(Fragment::internal(phys_col_name)),
151			expression: Box::new(field_expr),
152			fragment: fragment.clone(),
153		});
154	}
155
156	Ok(())
157}
158
159#[inline]
160fn resolve_unresolved_sumtype<'a>(
161	ctx: &Arc<QueryContext>,
162	txn: &mut Transaction<'a>,
163	col_name: &str,
164) -> Result<SumType> {
165	let tag_col_name = format!("{}_tag", col_name);
166	let source = ctx.source.as_ref().expect("source required for unresolved sumtype");
167
168	if let Some(tag_col) = source.columns().iter().find(|c| c.name == tag_col_name) {
169		let Some(Constraint::SumType(id)) = tag_col.constraint.constraint() else {
170			panic!("expected SumType constraint on tag column")
171		};
172		ctx.services.catalog.get_sumtype(txn, *id)
173	} else if let ResolvedShape::Series(series) = source {
174		let tag_id = series.def().tag.expect("series tag expected");
175		ctx.services.catalog.get_sumtype(txn, tag_id)
176	} else {
177		panic!("tag column not found: {}", tag_col_name)
178	}
179}
180
181fn expand_unit_variant_column<'a>(
182	ctx: &Arc<QueryContext>,
183	txn: &mut Transaction<'a>,
184	alias_expr: AliasExpression,
185	expanded: &mut Vec<AliasExpression>,
186) -> Result<()> {
187	let col_name = alias_expr.alias.0.text().to_string();
188
189	let resolved = if let Some(source) = ctx.source.as_ref() {
190		let Expression::Column(col) = alias_expr.expression.as_ref() else {
191			unreachable!()
192		};
193		try_resolve_unit_variant(ctx, txn, source, &col_name, col.0.name.text())?
194	} else {
195		None
196	};
197
198	let Some((sumtype, tag)) = resolved else {
199		expanded.push(alias_expr);
200		return Ok(());
201	};
202
203	let fragment = alias_expr.fragment.clone();
204	expanded.push(AliasExpression {
205		alias: IdentExpression(Fragment::internal(format!("{}_tag", col_name))),
206		expression: Box::new(Expression::Constant(ConstantExpression::Number {
207			fragment: Fragment::internal(tag.to_string()),
208		})),
209		fragment: fragment.clone(),
210	});
211	for v in &sumtype.variants {
212		for field in &v.fields {
213			let phys_col_name =
214				format!("{}_{}_{}", col_name, v.name.to_lowercase(), field.name.to_lowercase());
215			expanded.push(AliasExpression {
216				alias: IdentExpression(Fragment::internal(phys_col_name)),
217				expression: Box::new(Expression::Constant(ConstantExpression::None {
218					fragment: fragment.clone(),
219				})),
220				fragment: fragment.clone(),
221			});
222		}
223	}
224	Ok(())
225}
226
227#[inline]
228fn try_resolve_unit_variant<'a>(
229	ctx: &Arc<QueryContext>,
230	txn: &mut Transaction<'a>,
231	source: &ResolvedShape,
232	col_name: &str,
233	alias_text: &str,
234) -> Result<Option<(SumType, u8)>> {
235	let tag_col_name = format!("{}_tag", col_name);
236
237	if let Some(tag_col) = source.columns().iter().find(|c| c.name == tag_col_name) {
238		let Some(Constraint::SumType(id)) = tag_col.constraint.constraint() else {
239			return Ok(None);
240		};
241		let sumtype = ctx.services.catalog.get_sumtype(txn, *id)?;
242		let variant_name_lower = alias_text.to_lowercase();
243		let maybe_tag =
244			sumtype.variants.iter().find(|v| v.name.to_lowercase() == variant_name_lower).map(|v| v.tag);
245		return Ok(maybe_tag.map(|tag| (sumtype, tag)));
246	}
247
248	if let ResolvedShape::Series(series) = source
249		&& let Some(tag_id) = series.def().tag
250	{
251		let sumtype = ctx.services.catalog.get_sumtype(txn, tag_id)?;
252		let variant_name_lower = alias_text.to_lowercase();
253		let maybe_tag =
254			sumtype.variants.iter().find(|v| v.name.to_lowercase() == variant_name_lower).map(|v| v.tag);
255		return Ok(maybe_tag.map(|tag| (sumtype, tag)));
256	}
257
258	Ok(None)
259}
260
261impl QueryNode for InlineDataNode {
262	fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, _ctx: &QueryContext) -> Result<()> {
263		self.expand_sumtype_constructors(rx)?;
264		Ok(())
265	}
266
267	fn next<'a>(&mut self, _rx: &mut Transaction<'a>, _ctx: &mut QueryContext) -> Result<Option<Columns>> {
268		debug_assert!(self.context.is_some(), "InlineDataNode::next() called before initialize()");
269		let stored_ctx = self.context.as_ref().unwrap().clone();
270
271		if self.executed {
272			return Ok(None);
273		}
274
275		self.executed = true;
276
277		if self.rows.is_empty() {
278			let columns = Columns::empty();
279			if self.headers.is_none() {
280				self.headers = Some(ColumnHeaders::from_columns(&columns));
281			}
282			return Ok(Some(columns));
283		}
284
285		if self.headers.is_some() {
286			self.next_with_source(&stored_ctx)
287		} else {
288			self.next_infer_namespace(&stored_ctx)
289		}
290	}
291
292	fn headers(&self) -> Option<ColumnHeaders> {
293		self.headers.clone()
294	}
295}
296
297impl InlineDataNode {
298	fn find_optimal_integer_type(column: &ColumnBuffer) -> Type {
299		let mut min_val = i128::MAX;
300		let mut max_val = i128::MIN;
301		let mut has_values = false;
302
303		for value in column.iter() {
304			match value {
305				Value::Int16(v) => {
306					has_values = true;
307					min_val = min_val.min(v);
308					max_val = max_val.max(v);
309				}
310				Value::None {
311					..
312				} => {}
313				_ => {
314					return Type::Int16;
315				}
316			}
317		}
318
319		if !has_values {
320			return Type::Int1;
321		}
322
323		if min_val >= i8::MIN as i128 && max_val <= i8::MAX as i128 {
324			Type::Int1
325		} else if min_val >= i16::MIN as i128 && max_val <= i16::MAX as i128 {
326			Type::Int2
327		} else if min_val >= i32::MIN as i128 && max_val <= i32::MAX as i128 {
328			Type::Int4
329		} else if min_val >= i64::MIN as i128 && max_val <= i64::MAX as i128 {
330			Type::Int8
331		} else {
332			Type::Int16
333		}
334	}
335
336	fn next_infer_namespace(&mut self, ctx: &QueryContext) -> Result<Option<Columns>> {
337		let mut all_columns: BTreeSet<String> = BTreeSet::new();
338
339		for row in &self.rows {
340			for keyed_expr in row {
341				let column_name = keyed_expr.alias.0.text().to_string();
342				all_columns.insert(column_name);
343			}
344		}
345
346		let mut rows_data: Vec<HashMap<String, &AliasExpression>> = Vec::new();
347
348		for row in &self.rows {
349			let mut row_map: HashMap<String, &AliasExpression> = HashMap::new();
350			for alias_expr in row {
351				let column_name = alias_expr.alias.0.text().to_string();
352				row_map.insert(column_name, alias_expr);
353			}
354			rows_data.push(row_map);
355		}
356
357		let session = EvalContext::from_query(ctx);
358
359		let mut columns = Vec::new();
360
361		for column_name in all_columns {
362			let mut all_values = Vec::new();
363			let mut first_value_type: Option<Type> = None;
364			let mut column_fragment: Option<Fragment> = None;
365
366			for row_data in &rows_data {
367				if let Some(alias_expr) = row_data.get(&column_name) {
368					if column_fragment.is_none() {
369						column_fragment = Some(alias_expr.fragment.clone());
370					}
371					let eval_ctx = session.with_eval_empty();
372
373					let evaluated = evaluate(&eval_ctx, &alias_expr.expression)?;
374
375					let mut iter = evaluated.data().iter();
376					if let Some(value) = iter.next() {
377						if first_value_type.is_none() && !matches!(value, Value::None { .. }) {
378							first_value_type = Some(value.get_type());
379						}
380						all_values.push(value);
381					} else {
382						all_values.push(Value::none());
383					}
384				} else {
385					all_values.push(Value::none());
386				}
387			}
388
389			let wide_type = if let Some(ref fvt) = first_value_type {
390				if fvt.is_integer() {
391					Some(Type::Int16)
392				} else if fvt.is_floating_point() {
393					Some(Type::Float8)
394				} else if *fvt == Type::Utf8 {
395					Some(Type::Utf8)
396				} else if *fvt == Type::Boolean {
397					Some(Type::Boolean)
398				} else {
399					None
400				}
401			} else {
402				None
403			};
404
405			let mut column_data = if wide_type.is_none() {
406				ColumnBuffer::none_typed(Type::Boolean, all_values.len())
407			} else {
408				let mut data = ColumnBuffer::with_capacity(wide_type.clone().unwrap(), 0);
409
410				for value in &all_values {
411					if matches!(value, Value::None { .. }) {
412						data.push_none();
413					} else if wide_type.as_ref().is_some_and(|wt| value.get_type() == *wt) {
414						data.push_value(value.clone());
415					} else {
416						let temp_data = ColumnBuffer::from(value.clone());
417						let eval_ctx = session.with_eval_empty();
418
419						match cast_column_data(
420							&eval_ctx,
421							&temp_data,
422							wide_type.clone().unwrap(),
423							Fragment::none,
424						) {
425							Ok(casted) => {
426								if let Some(casted_value) = casted.iter().next() {
427									data.push_value(casted_value);
428								} else {
429									data.push_none();
430								}
431							}
432							Err(_) => {
433								data.push_none();
434							}
435						}
436					}
437				}
438
439				data
440			};
441
442			if wide_type == Some(Type::Int16) {
443				let optimal_type = Self::find_optimal_integer_type(&column_data);
444				if optimal_type != Type::Int16 {
445					let eval_ctx = session.with_eval(Columns::empty(), column_data.len());
446
447					if let Ok(demoted) =
448						cast_column_data(&eval_ctx, &column_data, optimal_type, || {
449							Fragment::none()
450						}) {
451						column_data = demoted;
452					}
453				}
454			}
455
456			columns.push(ColumnWithName::new(
457				column_fragment.unwrap_or_else(|| Fragment::internal(column_name)),
458				column_data,
459			));
460		}
461
462		let columns = Columns::new(columns);
463		self.headers = Some(ColumnHeaders::from_columns(&columns));
464
465		Ok(Some(columns))
466	}
467
468	fn next_with_source(&mut self, ctx: &QueryContext) -> Result<Option<Columns>> {
469		let source = ctx.source.as_ref().unwrap();
470		let headers = self.headers.as_ref().unwrap();
471		let session = EvalContext::from_query(ctx);
472
473		let mut rows_data: Vec<HashMap<String, &AliasExpression>> = Vec::new();
474
475		for row in &self.rows {
476			let mut row_map: HashMap<String, &AliasExpression> = HashMap::new();
477			for alias_expr in row {
478				let column_name = alias_expr.alias.0.text().to_string();
479				row_map.insert(column_name, alias_expr);
480			}
481			rows_data.push(row_map);
482		}
483
484		let mut columns = Vec::new();
485
486		for column_name in &headers.columns {
487			let table_column = source.columns().iter().find(|col| col.name == column_name.text());
488
489			let mut column_data = if let Some(tc) = table_column {
490				ColumnBuffer::none_typed(tc.constraint.get_type(), 0)
491			} else {
492				ColumnBuffer::with_capacity(Type::Int16, 0)
493			};
494			let mut column_fragment: Option<Fragment> = None;
495
496			for row_data in &rows_data {
497				if let Some(alias_expr) = row_data.get(column_name.text()) {
498					if column_fragment.is_none() {
499						column_fragment = Some(alias_expr.fragment.clone());
500					}
501					let mut eval_ctx = session.with_eval_empty();
502					eval_ctx.target = table_column.map(|tc| TargetColumn::Partial {
503						source_name: Some(source.identifier().text().to_string()),
504						column_name: Some(tc.name.clone()),
505						column_type: tc.constraint.get_type(),
506						properties: tc
507							.properties
508							.iter()
509							.map(|cp| cp.property.clone())
510							.collect(),
511					});
512
513					let evaluated = evaluate(&eval_ctx, &alias_expr.expression)?;
514
515					let eval_len = evaluated.data().len();
516					if table_column.is_some() {
517						if eval_len == 1 {
518							column_data.extend(evaluated.data().clone())?;
519						} else if eval_len == 0 {
520							column_data.push_value(Value::none());
521						} else {
522							let first_value =
523								evaluated.data().iter().next().unwrap_or(Value::none());
524							column_data.push_value(first_value);
525						}
526					} else {
527						let value = if eval_len > 0 {
528							evaluated.data().iter().next().unwrap_or(Value::none())
529						} else {
530							Value::none()
531						};
532						match &value {
533							Value::None {
534								..
535							} => column_data.push_none(),
536							Value::Int16(_) => column_data.push_value(value),
537							_ => {
538								let temp = ColumnBuffer::from(value.clone());
539								match cast_column_data(
540									&eval_ctx,
541									&temp,
542									Type::Int16,
543									Fragment::none,
544								) {
545									Ok(casted) => {
546										if let Some(v) = casted.iter().next() {
547											column_data.push_value(v);
548										} else {
549											column_data.push_none();
550										}
551									}
552									Err(_) => column_data.push_value(value),
553								}
554							}
555						}
556					}
557				} else {
558					column_data.push_value(Value::none());
559				}
560			}
561
562			if table_column.is_none() {
563				let optimal_type = Self::find_optimal_integer_type(&column_data);
564				if optimal_type != Type::Int16 {
565					let eval_ctx = session.with_eval(Columns::empty(), column_data.len());
566					if let Ok(demoted) =
567						cast_column_data(&eval_ctx, &column_data, optimal_type, || {
568							Fragment::none()
569						}) {
570						column_data = demoted;
571					}
572				}
573			}
574
575			columns.push(ColumnWithName::new(
576				column_fragment
577					.map(|f| f.with_text(column_name.text()))
578					.unwrap_or_else(|| column_name.clone()),
579				column_data,
580			));
581		}
582
583		let columns = Columns::new(columns);
584
585		Ok(Some(columns))
586	}
587}