datafusion_catalog/memory/table.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`MemTable`] for querying `Vec<RecordBatch>` by DataFusion.

use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;

use crate::TableProvider;

use arrow::array::{
    Array, ArrayRef, BooleanArray, RecordBatch as ArrowRecordBatch, UInt64Array,
};
use arrow::compute::kernels::zip::zip;
use arrow::compute::{and, filter_record_batch};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::error::Result;
use datafusion_common::{Constraints, DFSchema, SchemaExt, not_impl_err, plan_err};
use datafusion_common_runtime::JoinSet;
use datafusion_datasource::memory::{MemSink, MemorySourceConfig};
use datafusion_datasource::sink::DataSinkExec;
use datafusion_datasource::source::DataSourceExec;
use datafusion_expr::dml::InsertOp;
use datafusion_expr::{Expr, SortExpr, TableType};
use datafusion_physical_expr::{
    LexOrdering, create_physical_expr, create_physical_sort_exprs,
};
use datafusion_physical_plan::repartition::RepartitionExec;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::{
    DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning,
    PlanProperties, common,
};
use datafusion_session::Session;

use async_trait::async_trait;
use futures::StreamExt;
use log::debug;
use parking_lot::Mutex;
use tokio::sync::RwLock;

// backward compatibility
pub use datafusion_datasource::memory::PartitionData;

/// In-memory table that presents a `Vec<RecordBatch>` as a data source
/// queryable by DataFusion. This allows data to be pre-loaded into memory
/// and then repeatedly queried without incurring additional file I/O
/// overhead.
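///
/// A minimal usage sketch (hedged: the `arrow` paths shown are assumed and
/// may differ by version):
///
/// ```rust,ignore
/// use std::sync::Arc;
/// use arrow::array::Int32Array;
/// use arrow::datatypes::{DataType, Field, Schema};
/// use arrow::record_batch::RecordBatch;
///
/// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
/// let batch = RecordBatch::try_new(
///     Arc::clone(&schema),
///     vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
/// )?;
/// // One partition holding a single batch
/// let table = MemTable::try_new(schema, vec![vec![batch]])?;
/// ```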
#[derive(Debug)]
pub struct MemTable {
    schema: SchemaRef,
    // `batches` was previously pub(crate); it is public because the tests
    // need direct access to it.
    pub batches: Vec<PartitionData>,
    constraints: Constraints,
    column_defaults: HashMap<String, Expr>,
    /// Optional pre-known sort order(s). Must be `SortExpr`s.
    /// Inserting data into this table removes the order.
    pub sort_order: Arc<Mutex<Vec<Vec<SortExpr>>>>,
}

impl MemTable {
    /// Create a new in-memory table from the provided schema and record batches.
    ///
    /// Requires at least one partition. To construct an empty `MemTable`, pass
    /// `vec![vec![]]` as the `partitions` argument; this represents one partition
    /// with no batches.
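    ///
    /// For example, to build an empty table with a single partition (a
    /// minimal sketch; `schema` is assumed to be a `SchemaRef` built
    /// elsewhere):
    ///
    /// ```rust,ignore
    /// // One outer Vec entry per partition; each inner Vec holds that
    /// // partition's batches.
    /// let empty = MemTable::try_new(schema, vec![vec![]])?;
    /// ```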
    pub fn try_new(schema: SchemaRef, partitions: Vec<Vec<RecordBatch>>) -> Result<Self> {
        if partitions.is_empty() {
            return plan_err!("No partitions provided, expected at least one partition");
        }

        for batch in partitions.iter().flatten() {
            let batch_schema = batch.schema();
            if !schema.contains(&batch_schema) {
                debug!(
                    "mem table schema does not contain batch schema. \
                        Target schema: {schema:?}. Batch schema: {batch_schema:?}"
                );
                return plan_err!("Mismatch between schema and batches");
            }
        }

        Ok(Self {
            schema,
            batches: partitions
                .into_iter()
                .map(|e| Arc::new(RwLock::new(e)))
                .collect::<Vec<_>>(),
            constraints: Constraints::default(),
            column_defaults: HashMap::new(),
            sort_order: Arc::new(Mutex::new(vec![])),
        })
    }

    /// Assign constraints
    pub fn with_constraints(mut self, constraints: Constraints) -> Self {
        self.constraints = constraints;
        self
    }

    /// Assign column defaults
    pub fn with_column_defaults(
        mut self,
        column_defaults: HashMap<String, Expr>,
    ) -> Self {
        self.column_defaults = column_defaults;
        self
    }

    /// Specify an optional pre-known sort order(s). Must be `SortExpr`s.
    ///
    /// If the data is not sorted by this order, DataFusion may produce
    /// incorrect results.
    ///
    /// DataFusion may take advantage of this ordering to omit sorts
    /// or use more efficient algorithms.
    ///
    /// Note that multiple sort orders are supported, if some are known to be
    /// equivalent.
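    ///
    /// A minimal sketch, assuming a column named "a" and the `col` helper
    /// from `datafusion_expr`:
    ///
    /// ```rust,ignore
    /// use datafusion_expr::col;
    ///
    /// // Declare that the data is already sorted by "a" ascending, nulls first
    /// let table = table.with_sort_order(vec![vec![col("a").sort(true, true)]]);
    /// ```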
    pub fn with_sort_order(self, mut sort_order: Vec<Vec<SortExpr>>) -> Self {
        std::mem::swap(self.sort_order.lock().as_mut(), &mut sort_order);
        self
    }

    /// Create a `MemTable` by reading from another data source.
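    ///
    /// A minimal sketch, assuming a `SessionContext` named `ctx` (from the
    /// `datafusion` facade crate) with a table `"source"` already registered:
    ///
    /// ```rust,ignore
    /// // Materialize the source into memory, repartitioned four ways
    /// let provider = ctx.table_provider("source").await?;
    /// let mem = MemTable::load(provider, Some(4), &ctx.state()).await?;
    /// ```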
    pub async fn load(
        t: Arc<dyn TableProvider>,
        output_partitions: Option<usize>,
        state: &dyn Session,
    ) -> Result<Self> {
        let schema = t.schema();
        let constraints = t.constraints();
        let exec = t.scan(state, None, &[], None).await?;
        let partition_count = exec.output_partitioning().partition_count();

        let mut join_set = JoinSet::new();

        for part_idx in 0..partition_count {
            let task = state.task_ctx();
            let exec = Arc::clone(&exec);
            join_set.spawn(async move {
                let stream = exec.execute(part_idx, task)?;
                common::collect(stream).await
            });
        }

        let mut data: Vec<Vec<RecordBatch>> = Vec::with_capacity(partition_count);

        while let Some(result) = join_set.join_next().await {
            match result {
                Ok(res) => data.push(res?),
                Err(e) => {
                    if e.is_panic() {
                        std::panic::resume_unwind(e.into_panic());
                    } else {
                        unreachable!();
                    }
                }
            }
        }

        let mut exec = DataSourceExec::new(Arc::new(MemorySourceConfig::try_new(
            &data,
            Arc::clone(&schema),
            None,
        )?));
        if let Some(cons) = constraints {
            exec = exec.with_constraints(cons.clone());
        }

        if let Some(num_partitions) = output_partitions {
            let exec = RepartitionExec::try_new(
                Arc::new(exec),
                Partitioning::RoundRobinBatch(num_partitions),
            )?;

            // execute and collect results
            let mut output_partitions = vec![];
            for i in 0..exec.properties().output_partitioning().partition_count() {
                // execute this *output* partition and collect all batches
                let task_ctx = state.task_ctx();
                let mut stream = exec.execute(i, task_ctx)?;
                let mut batches = vec![];
                while let Some(result) = stream.next().await {
                    batches.push(result?);
                }
                output_partitions.push(batches);
            }

            return MemTable::try_new(Arc::clone(&schema), output_partitions);
        }
        MemTable::try_new(Arc::clone(&schema), data)
    }
}

#[async_trait]
impl TableProvider for MemTable {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn constraints(&self) -> Option<&Constraints> {
        Some(&self.constraints)
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

    async fn scan(
        &self,
        state: &dyn Session,
        projection: Option<&Vec<usize>>,
        _filters: &[Expr],
        _limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let mut partitions = vec![];
        for arc_inner_vec in self.batches.iter() {
            let inner_vec = arc_inner_vec.read().await;
            partitions.push(inner_vec.clone())
        }

        let mut source =
            MemorySourceConfig::try_new(&partitions, self.schema(), projection.cloned())?;

        let show_sizes = state.config_options().explain.show_sizes;
        source = source.with_show_sizes(show_sizes);

        // add sort information if present
        let sort_order = self.sort_order.lock();
        if !sort_order.is_empty() {
            let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?;

            let eqp = state.execution_props();
            let mut file_sort_order = vec![];
            for sort_exprs in sort_order.iter() {
                let physical_exprs =
                    create_physical_sort_exprs(sort_exprs, &df_schema, eqp)?;
                file_sort_order.extend(LexOrdering::new(physical_exprs));
            }
            source = source.try_with_sort_information(file_sort_order)?;
        }

        Ok(DataSourceExec::from_data_source(source))
    }

    /// Returns an ExecutionPlan that inserts the execution results of a given [`ExecutionPlan`] into this [`MemTable`].
    ///
    /// The [`ExecutionPlan`] must have the same schema as this [`MemTable`].
    ///
    /// # Arguments
    ///
    /// * `state` - The [`SessionState`] containing the context for executing the plan.
    /// * `input` - The [`ExecutionPlan`] to execute and insert.
    ///
    /// # Returns
    ///
    /// * A plan that returns the number of rows written.
    ///
    /// [`SessionState`]: https://docs.rs/datafusion/latest/datafusion/execution/session_state/struct.SessionState.html
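    ///
    /// In practice this is invoked by the engine when an `INSERT INTO`
    /// statement targets the table. A minimal sketch, assuming a
    /// `SessionContext` named `ctx` and a `MemTable` value `mem_table`:
    ///
    /// ```rust,ignore
    /// ctx.register_table("t", Arc::new(mem_table))?;
    /// // Appends one row; the result reports the number of rows written
    /// ctx.sql("INSERT INTO t VALUES (1)").await?.collect().await?;
    /// ```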
    async fn insert_into(
        &self,
        _state: &dyn Session,
        input: Arc<dyn ExecutionPlan>,
        insert_op: InsertOp,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Inserting into the table may invalidate any pre-known sort order,
        // so reset it here
        *self.sort_order.lock() = vec![];

        // Check that the schema of the input plan matches the schema of this table.
        self.schema()
            .logically_equivalent_names_and_types(&input.schema())?;

        if insert_op != InsertOp::Append {
            return not_impl_err!("{insert_op} not implemented for MemTable yet");
        }
        let sink = MemSink::try_new(self.batches.clone(), Arc::clone(&self.schema))?;
        Ok(Arc::new(DataSinkExec::new(input, Arc::new(sink), None)))
    }

    fn get_column_default(&self, column: &str) -> Option<&Expr> {
        self.column_defaults.get(column)
    }

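    /// Deletes all rows matching `filters` from the in-memory batches and
    /// returns a plan that yields a single row with the number of rows
    /// deleted.
    ///
    /// Rows for which a predicate evaluates to `false` or `NULL` are kept
    /// (SQL three-valued logic); an empty filter list deletes every row.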
    async fn delete_from(
        &self,
        state: &dyn Session,
        filters: Vec<Expr>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Early exit if table has no partitions
        if self.batches.is_empty() {
            return Ok(Arc::new(DmlResultExec::new(0)));
        }

        *self.sort_order.lock() = vec![];

        let mut total_deleted: u64 = 0;
        let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?;

        for partition_data in &self.batches {
            let mut partition = partition_data.write().await;
            let mut new_batches = Vec::with_capacity(partition.len());

            for batch in partition.iter() {
                if batch.num_rows() == 0 {
                    continue;
                }

                // Evaluate filters - None means "match all rows"
                let filter_mask = evaluate_filters_to_mask(
                    &filters,
                    batch,
                    &df_schema,
                    state.execution_props(),
                )?;

                let (delete_count, keep_mask) = match filter_mask {
                    Some(mask) => {
                        // Count rows where mask is true (will be deleted)
                        let count = mask.iter().filter(|v| v == &Some(true)).count();
                        // Keep rows where predicate is false or NULL (SQL three-valued logic)
                        let keep: BooleanArray =
                            mask.iter().map(|v| Some(v != Some(true))).collect();
                        (count, keep)
                    }
                    None => {
                        // No filters = delete all rows
                        (
                            batch.num_rows(),
                            BooleanArray::from(vec![false; batch.num_rows()]),
                        )
                    }
                };

                total_deleted += delete_count as u64;

                let filtered_batch = filter_record_batch(batch, &keep_mask)?;
                if filtered_batch.num_rows() > 0 {
                    new_batches.push(filtered_batch);
                }
            }

            *partition = new_batches;
        }

        Ok(Arc::new(DmlResultExec::new(total_deleted)))
    }

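    /// Updates rows matching `filters` by evaluating the column
    /// `assignments`, returning a plan that yields a single row with the
    /// number of rows updated.
    ///
    /// Assignment expressions are evaluated only on matching rows, so errors
    /// such as divide-by-zero on non-matching rows are avoided; an empty
    /// filter list updates every row.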
    async fn update(
        &self,
        state: &dyn Session,
        assignments: Vec<(String, Expr)>,
        filters: Vec<Expr>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Early exit if table has no partitions
        if self.batches.is_empty() {
            return Ok(Arc::new(DmlResultExec::new(0)));
        }

        // Validate column names upfront with clear error messages
        let available_columns: Vec<&str> = self
            .schema
            .fields()
            .iter()
            .map(|f| f.name().as_str())
            .collect();
        for (column_name, _) in &assignments {
            if self.schema.field_with_name(column_name).is_err() {
                return plan_err!(
                    "UPDATE failed: column '{}' does not exist. Available columns: {}",
                    column_name,
                    available_columns.join(", ")
                );
            }
        }

        let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?;

        // Create physical expressions for assignments upfront (outside batch loop)
        let physical_assignments: HashMap<
            String,
            Arc<dyn datafusion_physical_plan::PhysicalExpr>,
        > = assignments
            .iter()
            .map(|(name, expr)| {
                let physical_expr =
                    create_physical_expr(expr, &df_schema, state.execution_props())?;
                Ok((name.clone(), physical_expr))
            })
            .collect::<Result<_>>()?;

        *self.sort_order.lock() = vec![];

        let mut total_updated: u64 = 0;

        for partition_data in &self.batches {
            let mut partition = partition_data.write().await;
            let mut new_batches = Vec::with_capacity(partition.len());

            for batch in partition.iter() {
                if batch.num_rows() == 0 {
                    continue;
                }

                // Evaluate filters - None means "match all rows"
                let filter_mask = evaluate_filters_to_mask(
                    &filters,
                    batch,
                    &df_schema,
                    state.execution_props(),
                )?;

                let (update_count, update_mask) = match filter_mask {
                    Some(mask) => {
                        // Count rows where mask is true (will be updated)
                        let count = mask.iter().filter(|v| v == &Some(true)).count();
                        // Normalize mask: only true (not NULL) triggers update
                        let normalized: BooleanArray =
                            mask.iter().map(|v| Some(v == Some(true))).collect();
                        (count, normalized)
                    }
                    None => {
                        // No filters = update all rows
                        (
                            batch.num_rows(),
                            BooleanArray::from(vec![true; batch.num_rows()]),
                        )
                    }
                };

                total_updated += update_count as u64;

                if update_count == 0 {
                    new_batches.push(batch.clone());
                    continue;
                }

                let mut new_columns: Vec<ArrayRef> =
                    Vec::with_capacity(batch.num_columns());

                for field in self.schema.fields() {
                    let column_name = field.name();
                    let original_column =
                        batch.column_by_name(column_name).ok_or_else(|| {
                            datafusion_common::DataFusionError::Internal(format!(
                                "Column '{column_name}' not found in batch"
                            ))
                        })?;

                    let new_column = if let Some(physical_expr) =
                        physical_assignments.get(column_name.as_str())
                    {
                        // Use evaluate_selection to only evaluate on matching rows.
                        // This avoids errors (e.g., divide-by-zero) on rows that won't
                        // be updated. The result is scattered back with nulls for
                        // non-matching rows, which zip() will replace with originals.
                        let new_values =
                            physical_expr.evaluate_selection(batch, &update_mask)?;
                        let new_array = new_values.into_array(batch.num_rows())?;

                        // Convert to &dyn Array which implements Datum
                        let new_arr: &dyn Array = new_array.as_ref();
                        let orig_arr: &dyn Array = original_column.as_ref();
                        zip(&update_mask, &new_arr, &orig_arr)?
                    } else {
                        Arc::clone(original_column)
                    };

                    new_columns.push(new_column);
                }

                let updated_batch =
                    ArrowRecordBatch::try_new(Arc::clone(&self.schema), new_columns)?;
                new_batches.push(updated_batch);
            }

            *partition = new_batches;
        }

        Ok(Arc::new(DmlResultExec::new(total_updated)))
    }
}

/// Evaluate filter expressions against a batch and return a combined boolean mask.
/// Returns `None` if `filters` is empty (meaning "match all rows").
/// The returned mask has `true` for rows that match the filter predicates.
fn evaluate_filters_to_mask(
    filters: &[Expr],
    batch: &RecordBatch,
    df_schema: &DFSchema,
    execution_props: &datafusion_expr::execution_props::ExecutionProps,
) -> Result<Option<BooleanArray>> {
    if filters.is_empty() {
        return Ok(None);
    }

    let mut combined_mask: Option<BooleanArray> = None;

    for filter_expr in filters {
        let physical_expr =
            create_physical_expr(filter_expr, df_schema, execution_props)?;

        let result = physical_expr.evaluate(batch)?;
        let array = result.into_array(batch.num_rows())?;
        let bool_array = array
            .as_any()
            .downcast_ref::<BooleanArray>()
            .ok_or_else(|| {
                datafusion_common::DataFusionError::Internal(
                    "Filter did not evaluate to boolean".to_string(),
                )
            })?
            .clone();

        // Combine successive predicates with AND; a NULL in either input
        // yields NULL, which callers treat as "not matched".
        combined_mask = Some(match combined_mask {
            Some(existing) => and(&existing, &bool_array)?,
            None => bool_array,
        });
    }

    Ok(combined_mask)
}

/// Returns a single row with the count of affected rows.
#[derive(Debug)]
struct DmlResultExec {
    rows_affected: u64,
    schema: SchemaRef,
    properties: PlanProperties,
}

impl DmlResultExec {
    fn new(rows_affected: u64) -> Self {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "count",
            DataType::UInt64,
            false,
        )]));

        let properties = PlanProperties::new(
            datafusion_physical_expr::EquivalenceProperties::new(Arc::clone(&schema)),
            Partitioning::UnknownPartitioning(1),
            datafusion_physical_plan::execution_plan::EmissionType::Final,
            datafusion_physical_plan::execution_plan::Boundedness::Bounded,
        );

        Self {
            rows_affected,
            schema,
            properties,
        }
    }
}

impl DisplayAs for DmlResultExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default
            | DisplayFormatType::Verbose
            | DisplayFormatType::TreeRender => {
                write!(f, "DmlResultExec: rows_affected={}", self.rows_affected)
            }
        }
    }
}

impl ExecutionPlan for DmlResultExec {
    fn name(&self) -> &str {
        "DmlResultExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        _children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(self)
    }

    fn execute(
        &self,
        _partition: usize,
        _context: Arc<datafusion_execution::TaskContext>,
    ) -> Result<datafusion_execution::SendableRecordBatchStream> {
        // Create a single batch with the count
        let count_array = UInt64Array::from(vec![self.rows_affected]);
        let batch = ArrowRecordBatch::try_new(
            Arc::clone(&self.schema),
            vec![Arc::new(count_array) as ArrayRef],
        )?;

        // Create a stream that yields just this one batch
        let stream = futures::stream::iter(vec![Ok(batch)]);
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            Arc::clone(&self.schema),
            stream,
        )))
    }
}