Skip to main content

datafusion_catalog/memory/
table.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`MemTable`] for querying `Vec<RecordBatch>` by DataFusion.
19
20use std::collections::HashMap;
21use std::fmt::Debug;
22use std::sync::Arc;
23
24use crate::TableProvider;
25
26use arrow::array::{
27    Array, ArrayRef, BooleanArray, RecordBatch as ArrowRecordBatch, UInt64Array,
28};
29use arrow::compute::kernels::zip::zip;
30use arrow::compute::{and, filter_record_batch};
31use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
32use arrow::record_batch::RecordBatch;
33use datafusion_common::error::Result;
34use datafusion_common::{Constraints, DFSchema, SchemaExt, not_impl_err, plan_err};
35use datafusion_common_runtime::JoinSet;
36use datafusion_datasource::memory::{MemSink, MemorySourceConfig};
37use datafusion_datasource::sink::DataSinkExec;
38use datafusion_datasource::source::DataSourceExec;
39use datafusion_expr::dml::InsertOp;
40use datafusion_expr::{Expr, SortExpr, TableType};
41use datafusion_physical_expr::{
42    LexOrdering, PhysicalExpr, create_physical_expr, create_physical_sort_exprs,
43};
44use datafusion_physical_plan::repartition::RepartitionExec;
45use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
46use datafusion_physical_plan::{
47    DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning,
48    PlanProperties, common,
49};
50use datafusion_session::Session;
51
52use async_trait::async_trait;
53use futures::StreamExt;
54use log::debug;
55use parking_lot::Mutex;
56use tokio::sync::RwLock;
57
58// backward compatibility
59pub use datafusion_datasource::memory::PartitionData;
60
61/// In-memory data source for presenting a `Vec<RecordBatch>` as a
62/// data source that can be queried by DataFusion. This allows data to
63/// be pre-loaded into memory and then repeatedly queried without
64/// incurring additional file I/O overhead.
65#[derive(Debug)]
66pub struct MemTable {
67    schema: SchemaRef,
68    // batches used to be pub(crate), but it's needed to be public for the tests
69    pub batches: Vec<PartitionData>,
70    constraints: Constraints,
71    column_defaults: HashMap<String, Expr>,
72    /// Optional pre-known sort order(s). Must be `SortExpr`s.
73    /// inserting data into this table removes the order
74    pub sort_order: Arc<Mutex<Vec<Vec<SortExpr>>>>,
75}
76
77impl MemTable {
78    /// Create a new in-memory table from the provided schema and record batches.
79    ///
80    /// Requires at least one partition. To construct an empty `MemTable`, pass
81    /// `vec![vec![]]` as the `partitions` argument, this represents one partition with
82    /// no batches.
83    pub fn try_new(schema: SchemaRef, partitions: Vec<Vec<RecordBatch>>) -> Result<Self> {
84        if partitions.is_empty() {
85            return plan_err!("No partitions provided, expected at least one partition");
86        }
87
88        for batches in partitions.iter().flatten() {
89            let batches_schema = batches.schema();
90            if !schema.contains(&batches_schema) {
91                debug!(
92                    "mem table schema does not contain batches schema. \
93                        Target_schema: {schema:?}. Batches Schema: {batches_schema:?}"
94                );
95                return plan_err!("Mismatch between schema and batches");
96            }
97        }
98
99        Ok(Self {
100            schema,
101            batches: partitions
102                .into_iter()
103                .map(|e| Arc::new(RwLock::new(e)))
104                .collect::<Vec<_>>(),
105            constraints: Constraints::default(),
106            column_defaults: HashMap::new(),
107            sort_order: Arc::new(Mutex::new(vec![])),
108        })
109    }
110
111    /// Assign constraints
112    pub fn with_constraints(mut self, constraints: Constraints) -> Self {
113        self.constraints = constraints;
114        self
115    }
116
117    /// Assign column defaults
118    pub fn with_column_defaults(
119        mut self,
120        column_defaults: HashMap<String, Expr>,
121    ) -> Self {
122        self.column_defaults = column_defaults;
123        self
124    }
125
126    /// Specify an optional pre-known sort order(s). Must be `SortExpr`s.
127    ///
128    /// If the data is not sorted by this order, DataFusion may produce
129    /// incorrect results.
130    ///
131    /// DataFusion may take advantage of this ordering to omit sorts
132    /// or use more efficient algorithms.
133    ///
134    /// Note that multiple sort orders are supported, if some are known to be
135    /// equivalent,
136    pub fn with_sort_order(self, mut sort_order: Vec<Vec<SortExpr>>) -> Self {
137        std::mem::swap(self.sort_order.lock().as_mut(), &mut sort_order);
138        self
139    }
140
141    /// Create a mem table by reading from another data source
142    pub async fn load(
143        t: Arc<dyn TableProvider>,
144        output_partitions: Option<usize>,
145        state: &dyn Session,
146    ) -> Result<Self> {
147        let schema = t.schema();
148        let constraints = t.constraints();
149        let exec = t.scan(state, None, &[], None).await?;
150        let partition_count = exec.output_partitioning().partition_count();
151
152        let mut join_set = JoinSet::new();
153
154        for part_idx in 0..partition_count {
155            let task = state.task_ctx();
156            let exec = Arc::clone(&exec);
157            join_set.spawn(async move {
158                let stream = exec.execute(part_idx, task)?;
159                common::collect(stream).await
160            });
161        }
162
163        let mut data: Vec<Vec<RecordBatch>> =
164            Vec::with_capacity(exec.output_partitioning().partition_count());
165
166        while let Some(result) = join_set.join_next().await {
167            match result {
168                Ok(res) => data.push(res?),
169                Err(e) => {
170                    if e.is_panic() {
171                        std::panic::resume_unwind(e.into_panic());
172                    } else {
173                        unreachable!();
174                    }
175                }
176            }
177        }
178
179        let mut exec = DataSourceExec::new(Arc::new(MemorySourceConfig::try_new(
180            &data,
181            Arc::clone(&schema),
182            None,
183        )?));
184        if let Some(cons) = constraints {
185            exec = exec.with_constraints(cons.clone());
186        }
187
188        if let Some(num_partitions) = output_partitions {
189            let exec = RepartitionExec::try_new(
190                Arc::new(exec),
191                Partitioning::RoundRobinBatch(num_partitions),
192            )?;
193
194            // execute and collect results
195            let mut output_partitions = vec![];
196            for i in 0..exec.properties().output_partitioning().partition_count() {
197                // execute this *output* partition and collect all batches
198                let task_ctx = state.task_ctx();
199                let mut stream = exec.execute(i, task_ctx)?;
200                let mut batches = vec![];
201                while let Some(result) = stream.next().await {
202                    batches.push(result?);
203                }
204                output_partitions.push(batches);
205            }
206
207            return MemTable::try_new(Arc::clone(&schema), output_partitions);
208        }
209        MemTable::try_new(Arc::clone(&schema), data)
210    }
211}
212
213#[async_trait]
214impl TableProvider for MemTable {
215    fn schema(&self) -> SchemaRef {
216        Arc::clone(&self.schema)
217    }
218
219    fn constraints(&self) -> Option<&Constraints> {
220        Some(&self.constraints)
221    }
222
223    fn table_type(&self) -> TableType {
224        TableType::Base
225    }
226
227    async fn scan(
228        &self,
229        state: &dyn Session,
230        projection: Option<&Vec<usize>>,
231        _filters: &[Expr],
232        _limit: Option<usize>,
233    ) -> Result<Arc<dyn ExecutionPlan>> {
234        let mut partitions = vec![];
235        for arc_inner_vec in self.batches.iter() {
236            let inner_vec = arc_inner_vec.read().await;
237            partitions.push(inner_vec.clone())
238        }
239
240        let mut source =
241            MemorySourceConfig::try_new(&partitions, self.schema(), projection.cloned())?;
242
243        let show_sizes = state.config_options().explain.show_sizes;
244        source = source.with_show_sizes(show_sizes);
245
246        // add sort information if present
247        let sort_order = self.sort_order.lock();
248        if !sort_order.is_empty() {
249            let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?;
250
251            let eqp = state.execution_props();
252            let mut file_sort_order = vec![];
253            for sort_exprs in sort_order.iter() {
254                let physical_exprs =
255                    create_physical_sort_exprs(sort_exprs, &df_schema, eqp)?;
256                file_sort_order.extend(LexOrdering::new(physical_exprs));
257            }
258            source = source.try_with_sort_information(file_sort_order)?;
259        }
260
261        Ok(DataSourceExec::from_data_source(source))
262    }
263
264    /// Returns an ExecutionPlan that inserts the execution results of a given [`ExecutionPlan`] into this [`MemTable`].
265    ///
266    /// The [`ExecutionPlan`] must have the same schema as this [`MemTable`].
267    ///
268    /// # Arguments
269    ///
270    /// * `state` - The [`SessionState`] containing the context for executing the plan.
271    /// * `input` - The [`ExecutionPlan`] to execute and insert.
272    ///
273    /// # Returns
274    ///
275    /// * A plan that returns the number of rows written.
276    ///
277    /// [`SessionState`]: https://docs.rs/datafusion/latest/datafusion/execution/session_state/struct.SessionState.html
278    async fn insert_into(
279        &self,
280        _state: &dyn Session,
281        input: Arc<dyn ExecutionPlan>,
282        insert_op: InsertOp,
283    ) -> Result<Arc<dyn ExecutionPlan>> {
284        // If we are inserting into the table, any sort order may be messed up so reset it here
285        *self.sort_order.lock() = vec![];
286
287        // Create a physical plan from the logical plan.
288        // Check that the schema of the plan matches the schema of this table.
289        self.schema()
290            .logically_equivalent_names_and_types(&input.schema())?;
291
292        if insert_op != InsertOp::Append {
293            return not_impl_err!("{insert_op} not implemented for MemoryTable yet");
294        }
295        let sink = MemSink::try_new(self.batches.clone(), Arc::clone(&self.schema))?;
296        Ok(Arc::new(DataSinkExec::new(input, Arc::new(sink), None)))
297    }
298
299    fn get_column_default(&self, column: &str) -> Option<&Expr> {
300        self.column_defaults.get(column)
301    }
302
303    async fn delete_from(
304        &self,
305        state: &dyn Session,
306        filters: Vec<Expr>,
307    ) -> Result<Arc<dyn ExecutionPlan>> {
308        // Early exit if table has no partitions
309        if self.batches.is_empty() {
310            return Ok(Arc::new(DmlResultExec::new(0)));
311        }
312
313        *self.sort_order.lock() = vec![];
314
315        let mut total_deleted: u64 = 0;
316        let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?;
317
318        for partition_data in &self.batches {
319            let mut partition = partition_data.write().await;
320            let mut new_batches = Vec::with_capacity(partition.len());
321
322            for batch in partition.iter() {
323                if batch.num_rows() == 0 {
324                    continue;
325                }
326
327                // Evaluate filters - None means "match all rows"
328                let filter_mask = evaluate_filters_to_mask(
329                    &filters,
330                    batch,
331                    &df_schema,
332                    state.execution_props(),
333                )?;
334
335                let (delete_count, keep_mask) = match filter_mask {
336                    Some(mask) => {
337                        // Count rows where mask is true (will be deleted)
338                        let count = mask.iter().filter(|v| v == &Some(true)).count();
339                        // Keep rows where predicate is false or NULL (SQL three-valued logic)
340                        let keep: BooleanArray =
341                            mask.iter().map(|v| Some(v != Some(true))).collect();
342                        (count, keep)
343                    }
344                    None => {
345                        // No filters = delete all rows
346                        (
347                            batch.num_rows(),
348                            BooleanArray::from(vec![false; batch.num_rows()]),
349                        )
350                    }
351                };
352
353                total_deleted += delete_count as u64;
354
355                let filtered_batch = filter_record_batch(batch, &keep_mask)?;
356                if filtered_batch.num_rows() > 0 {
357                    new_batches.push(filtered_batch);
358                }
359            }
360
361            *partition = new_batches;
362        }
363
364        Ok(Arc::new(DmlResultExec::new(total_deleted)))
365    }
366
367    async fn update(
368        &self,
369        state: &dyn Session,
370        assignments: Vec<(String, Expr)>,
371        filters: Vec<Expr>,
372    ) -> Result<Arc<dyn ExecutionPlan>> {
373        // Early exit if table has no partitions
374        if self.batches.is_empty() {
375            return Ok(Arc::new(DmlResultExec::new(0)));
376        }
377
378        // Validate column names upfront with clear error messages
379        let available_columns: Vec<&str> = self
380            .schema
381            .fields()
382            .iter()
383            .map(|f| f.name().as_str())
384            .collect();
385        for (column_name, _) in &assignments {
386            if self.schema.field_with_name(column_name).is_err() {
387                return plan_err!(
388                    "UPDATE failed: column '{}' does not exist. Available columns: {}",
389                    column_name,
390                    available_columns.join(", ")
391                );
392            }
393        }
394
395        let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?;
396
397        // Create physical expressions for assignments upfront (outside batch loop)
398        let physical_assignments: HashMap<String, Arc<dyn PhysicalExpr>> = assignments
399            .iter()
400            .map(|(name, expr)| {
401                let physical_expr =
402                    create_physical_expr(expr, &df_schema, state.execution_props())?;
403                Ok((name.clone(), physical_expr))
404            })
405            .collect::<Result<_>>()?;
406
407        *self.sort_order.lock() = vec![];
408
409        let mut total_updated: u64 = 0;
410
411        for partition_data in &self.batches {
412            let mut partition = partition_data.write().await;
413            let mut new_batches = Vec::with_capacity(partition.len());
414
415            for batch in partition.iter() {
416                if batch.num_rows() == 0 {
417                    continue;
418                }
419
420                // Evaluate filters - None means "match all rows"
421                let filter_mask = evaluate_filters_to_mask(
422                    &filters,
423                    batch,
424                    &df_schema,
425                    state.execution_props(),
426                )?;
427
428                let (update_count, update_mask) = match filter_mask {
429                    Some(mask) => {
430                        // Count rows where mask is true (will be updated)
431                        let count = mask.iter().filter(|v| v == &Some(true)).count();
432                        // Normalize mask: only true (not NULL) triggers update
433                        let normalized: BooleanArray =
434                            mask.iter().map(|v| Some(v == Some(true))).collect();
435                        (count, normalized)
436                    }
437                    None => {
438                        // No filters = update all rows
439                        (
440                            batch.num_rows(),
441                            BooleanArray::from(vec![true; batch.num_rows()]),
442                        )
443                    }
444                };
445
446                total_updated += update_count as u64;
447
448                if update_count == 0 {
449                    new_batches.push(batch.clone());
450                    continue;
451                }
452
453                let mut new_columns: Vec<ArrayRef> =
454                    Vec::with_capacity(batch.num_columns());
455
456                for field in self.schema.fields() {
457                    let column_name = field.name();
458                    let original_column =
459                        batch.column_by_name(column_name).ok_or_else(|| {
460                            datafusion_common::DataFusionError::Internal(format!(
461                                "Column '{column_name}' not found in batch"
462                            ))
463                        })?;
464
465                    let new_column = if let Some(physical_expr) =
466                        physical_assignments.get(column_name.as_str())
467                    {
468                        // Use evaluate_selection to only evaluate on matching rows.
469                        // This avoids errors (e.g., divide-by-zero) on rows that won't
470                        // be updated. The result is scattered back with nulls for
471                        // non-matching rows, which zip() will replace with originals.
472                        let new_values =
473                            physical_expr.evaluate_selection(batch, &update_mask)?;
474                        let new_array = new_values.into_array(batch.num_rows())?;
475
476                        // Convert to &dyn Array which implements Datum
477                        let new_arr: &dyn Array = new_array.as_ref();
478                        let orig_arr: &dyn Array = original_column.as_ref();
479                        zip(&update_mask, &new_arr, &orig_arr)?
480                    } else {
481                        Arc::clone(original_column)
482                    };
483
484                    new_columns.push(new_column);
485                }
486
487                let updated_batch =
488                    ArrowRecordBatch::try_new(Arc::clone(&self.schema), new_columns)?;
489                new_batches.push(updated_batch);
490            }
491
492            *partition = new_batches;
493        }
494
495        Ok(Arc::new(DmlResultExec::new(total_updated)))
496    }
497}
498
499/// Evaluate filter expressions against a batch and return a combined boolean mask.
500/// Returns None if filters is empty (meaning "match all rows").
501/// The returned mask has true for rows that match the filter predicates.
502fn evaluate_filters_to_mask(
503    filters: &[Expr],
504    batch: &RecordBatch,
505    df_schema: &DFSchema,
506    execution_props: &datafusion_expr::execution_props::ExecutionProps,
507) -> Result<Option<BooleanArray>> {
508    if filters.is_empty() {
509        return Ok(None);
510    }
511
512    let mut combined_mask: Option<BooleanArray> = None;
513
514    for filter_expr in filters {
515        let physical_expr =
516            create_physical_expr(filter_expr, df_schema, execution_props)?;
517
518        let result = physical_expr.evaluate(batch)?;
519        let array = result.into_array(batch.num_rows())?;
520        let bool_array = array
521            .as_any()
522            .downcast_ref::<BooleanArray>()
523            .ok_or_else(|| {
524                datafusion_common::DataFusionError::Internal(
525                    "Filter did not evaluate to boolean".to_string(),
526                )
527            })?
528            .clone();
529
530        combined_mask = Some(match combined_mask {
531            Some(existing) => and(&existing, &bool_array)?,
532            None => bool_array,
533        });
534    }
535
536    Ok(combined_mask)
537}
538
539/// Returns a single row with the count of affected rows.
540#[derive(Debug)]
541struct DmlResultExec {
542    rows_affected: u64,
543    schema: SchemaRef,
544    properties: Arc<PlanProperties>,
545}
546
547impl DmlResultExec {
548    fn new(rows_affected: u64) -> Self {
549        let schema = Arc::new(Schema::new(vec![Field::new(
550            "count",
551            DataType::UInt64,
552            false,
553        )]));
554
555        let properties = PlanProperties::new(
556            datafusion_physical_expr::EquivalenceProperties::new(Arc::clone(&schema)),
557            Partitioning::UnknownPartitioning(1),
558            datafusion_physical_plan::execution_plan::EmissionType::Final,
559            datafusion_physical_plan::execution_plan::Boundedness::Bounded,
560        );
561
562        Self {
563            rows_affected,
564            schema,
565            properties: Arc::new(properties),
566        }
567    }
568}
569
570impl DisplayAs for DmlResultExec {
571    fn fmt_as(
572        &self,
573        t: DisplayFormatType,
574        f: &mut std::fmt::Formatter,
575    ) -> std::fmt::Result {
576        match t {
577            DisplayFormatType::Default
578            | DisplayFormatType::Verbose
579            | DisplayFormatType::TreeRender => {
580                write!(f, "DmlResultExec: rows_affected={}", self.rows_affected)
581            }
582        }
583    }
584}
585
586impl ExecutionPlan for DmlResultExec {
587    fn name(&self) -> &str {
588        "DmlResultExec"
589    }
590
591    fn schema(&self) -> SchemaRef {
592        Arc::clone(&self.schema)
593    }
594
595    fn properties(&self) -> &Arc<PlanProperties> {
596        &self.properties
597    }
598
599    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
600        vec![]
601    }
602
603    fn with_new_children(
604        self: Arc<Self>,
605        _children: Vec<Arc<dyn ExecutionPlan>>,
606    ) -> Result<Arc<dyn ExecutionPlan>> {
607        Ok(self)
608    }
609
610    fn execute(
611        &self,
612        _partition: usize,
613        _context: Arc<datafusion_execution::TaskContext>,
614    ) -> Result<datafusion_execution::SendableRecordBatchStream> {
615        // Create a single batch with the count
616        let count_array = UInt64Array::from(vec![self.rows_affected]);
617        let batch = ArrowRecordBatch::try_new(
618            Arc::clone(&self.schema),
619            vec![Arc::new(count_array) as ArrayRef],
620        )?;
621
622        // Create a stream that yields just this one batch
623        let stream = futures::stream::iter(vec![Ok(batch)]);
624        Ok(Box::pin(RecordBatchStreamAdapter::new(
625            Arc::clone(&self.schema),
626            stream,
627        )))
628    }
629}