// reifydb-engine 0.4.12
//
// Query execution and processing engine for ReifyDB
// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2025 ReifyDB

use std::{collections::HashMap, sync::Arc};

use reifydb_core::{
	encoded::{row::EncodedRow, shape::RowShape},
	error::diagnostic::{
		catalog::{namespace_not_found, table_not_found},
		index::primary_key_violation,
	},
	interface::{
		catalog::{
			id::IndexId,
			policy::{DataOp, PolicyTargetType},
		},
		resolved::{ResolvedColumn, ResolvedNamespace, ResolvedShape, ResolvedTable},
	},
	internal_error,
	key::{EncodableKey, index_entry::IndexEntryKey},
	value::column::columns::Columns,
};
use reifydb_rql::nodes::InsertTableNode;
use reifydb_transaction::transaction::Transaction;
use reifydb_type::{
	fragment::Fragment,
	params::Params,
	return_error,
	value::{Value, identity::IdentityId, row_number::RowNumber, r#type::Type},
};
use tracing::instrument;

use super::{
	primary_key,
	returning::{decode_rows_to_columns, evaluate_returning},
	shape::get_or_create_table_shape,
};
use crate::{
	Result,
	policy::PolicyEvaluator,
	transaction::operation::{dictionary::DictionaryOperations, table::TableOperations},
	vm::{
		instruction::dml::coerce::coerce_value_to_column_type,
		services::Services,
		stack::SymbolTable,
		volcano::{
			compile::compile,
			query::{QueryContext, QueryNode},
		},
	},
};

#[instrument(name = "mutate::table::insert", level = "trace", skip_all)]
pub(crate) fn insert_table(
	services: &Arc<Services>,
	txn: &mut Transaction<'_>,
	plan: InsertTableNode,
	symbols: &mut SymbolTable,
) -> Result<Columns> {
	let namespace_name = plan.target.namespace().name();

	let Some(namespace) = services.catalog.find_namespace_by_name(txn, namespace_name)? else {
		return_error!(namespace_not_found(Fragment::internal(namespace_name), namespace_name));
	};

	let table_name = plan.target.name();
	let Some(table) = services.catalog.find_table_by_name(txn, namespace.id(), table_name)? else {
		let fragment = plan.target.identifier().clone();
		return_error!(table_not_found(fragment.clone(), namespace_name, table_name,));
	};

	// Get or create shape with proper field names and constraints
	let shape = get_or_create_table_shape(&services.catalog, &table, txn)?;

	// Create resolved source for the table
	let namespace_ident = Fragment::internal(namespace.name());
	let resolved_namespace = ResolvedNamespace::new(namespace_ident, namespace.clone());

	let table_ident = Fragment::internal(table.name.clone());
	let resolved_table = ResolvedTable::new(table_ident, resolved_namespace, table.clone());
	let resolved_source = Some(ResolvedShape::Table(resolved_table));

	let execution_context = Arc::new(QueryContext {
		services: services.clone(),
		source: resolved_source,
		batch_size: 1024,
		params: Params::None,
		symbols: symbols.clone(),
		identity: IdentityId::root(),
	});

	let mut input_node = compile(*plan.input, txn, execution_context.clone());

	// Initialize the operator before execution
	input_node.initialize(txn, &execution_context)?;

	// PASS 1: Validate and encode all rows first, before allocating any row numbers
	// This ensures we only allocate row numbers for valid rows (fail-fast on validation errors)
	let mut validated_rows: Vec<EncodedRow> = Vec::new();
	let mut mutable_context = (*execution_context).clone();

	while let Some(columns) = input_node.next(txn, &mut mutable_context)? {
		// Enforce write policies before processing rows
		PolicyEvaluator::new(services, symbols).enforce_write_policies(
			txn,
			namespace_name,
			table_name,
			DataOp::Insert,
			&columns,
			PolicyTargetType::Table,
		)?;

		let row_count = columns.row_count();

		let mut column_map: HashMap<&str, usize> = HashMap::new();
		for (idx, col) in columns.iter().enumerate() {
			column_map.insert(col.name().text(), idx);
		}

		for row_numberx in 0..row_count {
			let mut row = shape.allocate();

			// For each table column, find if it exists in the input columns
			for (table_idx, table_column) in table.columns.iter().enumerate() {
				let mut value = if let Some(&input_idx) = column_map.get(table_column.name.as_str()) {
					columns[input_idx].data().get_value(row_numberx)
				} else {
					Value::none()
				};

				// Handle auto-increment columns
				if table_column.auto_increment && matches!(value, Value::None { .. }) {
					value = services.catalog.column_sequence_next_value(
						txn,
						table.id,
						table_column.id,
					)?;
				}

				// Create ResolvedColumn for this column
				let column_ident = column_map
					.get(table_column.name.as_str())
					.map(|&idx| columns[idx].name().clone())
					.unwrap_or_else(|| Fragment::internal(table_column.name.clone()));
				let resolved_column = ResolvedColumn::new(
					column_ident.clone(),
					execution_context.source.clone().unwrap(),
					table_column.clone(),
				);

				value = coerce_value_to_column_type(
					value,
					table_column.constraint.get_type(),
					resolved_column,
					&execution_context,
				)?;

				// Validate the value against the column's constraint
				if let Err(mut e) = table_column.constraint.validate(&value) {
					e.0.fragment = column_ident.clone();
					return Err(e);
				}

				// Dictionary encoding: if column has a dictionary binding, encode the value
				let value = if let Some(dict_id) = table_column.dictionary_id {
					let dictionary =
						services.catalog.find_dictionary(txn, dict_id)?.ok_or_else(|| {
							internal_error!(
								"Dictionary {:?} not found for column {}",
								dict_id,
								table_column.name
							)
						})?;
					let entry_id = txn.insert_into_dictionary(&dictionary, &value)?;
					entry_id.to_value()
				} else {
					value
				};

				shape.set_value(&mut row, table_idx, &value);
			}

			let now_nanos = services.runtime_context.clock.now_nanos();
			row.set_timestamps(now_nanos, now_nanos);

			validated_rows.push(row);
		}
	}

	// BATCH ALLOCATION: Now that all rows are validated, allocate row numbers in one batch
	let total_rows = validated_rows.len();
	if total_rows == 0 {
		// No rows to insert, return early
		return Ok(Columns::single_row([
			("namespace", Value::Utf8(namespace.name().to_string())),
			("table", Value::Utf8(table.name)),
			("inserted", Value::Uint8(0)),
		]));
	}

	let row_numbers = services.catalog.next_row_number_batch(txn, table.id, total_rows as u64)?;

	assert_eq!(row_numbers.len(), validated_rows.len());

	// Hoist loop-invariant computations out of PASS 2
	let pk_def = primary_key::get_primary_key(&services.catalog, txn, &table)?;
	let row_number_shape = if pk_def.is_some() {
		Some(RowShape::testing(&[Type::Uint8]))
	} else {
		None
	};

	// PASS 2: Insert all validated rows using the pre-allocated row numbers
	let mut returned_rows: Vec<(RowNumber, EncodedRow)> = if plan.returning.is_some() {
		Vec::with_capacity(total_rows)
	} else {
		Vec::new()
	};

	for (row, &row_number) in validated_rows.iter().zip(row_numbers.iter()) {
		// Insert the row directly into storage
		let stored_row = txn.insert_table(&table, &shape, row.clone(), row_number)?;

		if plan.returning.is_some() {
			returned_rows.push((row_number, stored_row));
		}

		// Store primary key index entry if table has one
		if let Some(ref pk_def) = pk_def {
			let index_key = primary_key::encode_primary_key(pk_def, row, &table, &shape)?;

			// Check if primary key already exists
			let index_entry_key =
				IndexEntryKey::new(table.id, IndexId::primary(pk_def.id), index_key.clone());
			if txn.contains_key(&index_entry_key.encode())? {
				let key_columns = pk_def.columns.iter().map(|c| c.name.clone()).collect();
				return_error!(primary_key_violation(
					plan.target.identifier().clone(),
					table.name.clone(),
					key_columns,
				));
			}

			// Store the index entry with the row number as value
			let rns = row_number_shape.as_ref().unwrap();
			let mut row_number_encoded = rns.allocate();
			rns.set_u64(&mut row_number_encoded, 0, u64::from(row_number));

			txn.set(&index_entry_key.encode(), row_number_encoded)?;
		}
	}

	// If RETURNING clause is present, evaluate expressions against inserted rows
	if let Some(returning_exprs) = &plan.returning {
		let columns = decode_rows_to_columns(&shape, &returned_rows);
		return evaluate_returning(services, symbols, returning_exprs, columns);
	}

	// Return summary columns
	Ok(Columns::single_row([
		("namespace", Value::Utf8(namespace.name().to_string())),
		("table", Value::Utf8(table.name)),
		("inserted", Value::Uint8(total_rows as u64)),
	]))
}