datafusion_catalog/memory/
table.rs1use std::collections::HashMap;
21use std::fmt::Debug;
22use std::sync::Arc;
23
24use crate::TableProvider;
25
26use arrow::array::{
27 Array, ArrayRef, BooleanArray, RecordBatch as ArrowRecordBatch, UInt64Array,
28};
29use arrow::compute::kernels::zip::zip;
30use arrow::compute::{and, filter_record_batch};
31use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
32use arrow::record_batch::RecordBatch;
33use datafusion_common::error::Result;
34use datafusion_common::{Constraints, DFSchema, SchemaExt, not_impl_err, plan_err};
35use datafusion_common_runtime::JoinSet;
36use datafusion_datasource::memory::{MemSink, MemorySourceConfig};
37use datafusion_datasource::sink::DataSinkExec;
38use datafusion_datasource::source::DataSourceExec;
39use datafusion_expr::dml::InsertOp;
40use datafusion_expr::{Expr, SortExpr, TableType};
41use datafusion_physical_expr::{
42 LexOrdering, PhysicalExpr, create_physical_expr, create_physical_sort_exprs,
43};
44use datafusion_physical_plan::repartition::RepartitionExec;
45use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
46use datafusion_physical_plan::{
47 DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning,
48 PlanProperties, common,
49};
50use datafusion_session::Session;
51
52use async_trait::async_trait;
53use futures::StreamExt;
54use log::debug;
55use parking_lot::Mutex;
56use tokio::sync::RwLock;
57
58pub use datafusion_datasource::memory::PartitionData;
60
61#[derive(Debug)]
66pub struct MemTable {
67 schema: SchemaRef,
68 pub batches: Vec<PartitionData>,
70 constraints: Constraints,
71 column_defaults: HashMap<String, Expr>,
72 pub sort_order: Arc<Mutex<Vec<Vec<SortExpr>>>>,
75}
76
77impl MemTable {
78 pub fn try_new(schema: SchemaRef, partitions: Vec<Vec<RecordBatch>>) -> Result<Self> {
84 if partitions.is_empty() {
85 return plan_err!("No partitions provided, expected at least one partition");
86 }
87
88 for batches in partitions.iter().flatten() {
89 let batches_schema = batches.schema();
90 if !schema.contains(&batches_schema) {
91 debug!(
92 "mem table schema does not contain batches schema. \
93 Target_schema: {schema:?}. Batches Schema: {batches_schema:?}"
94 );
95 return plan_err!("Mismatch between schema and batches");
96 }
97 }
98
99 Ok(Self {
100 schema,
101 batches: partitions
102 .into_iter()
103 .map(|e| Arc::new(RwLock::new(e)))
104 .collect::<Vec<_>>(),
105 constraints: Constraints::default(),
106 column_defaults: HashMap::new(),
107 sort_order: Arc::new(Mutex::new(vec![])),
108 })
109 }
110
111 pub fn with_constraints(mut self, constraints: Constraints) -> Self {
113 self.constraints = constraints;
114 self
115 }
116
117 pub fn with_column_defaults(
119 mut self,
120 column_defaults: HashMap<String, Expr>,
121 ) -> Self {
122 self.column_defaults = column_defaults;
123 self
124 }
125
126 pub fn with_sort_order(self, mut sort_order: Vec<Vec<SortExpr>>) -> Self {
137 std::mem::swap(self.sort_order.lock().as_mut(), &mut sort_order);
138 self
139 }
140
141 pub async fn load(
143 t: Arc<dyn TableProvider>,
144 output_partitions: Option<usize>,
145 state: &dyn Session,
146 ) -> Result<Self> {
147 let schema = t.schema();
148 let constraints = t.constraints();
149 let exec = t.scan(state, None, &[], None).await?;
150 let partition_count = exec.output_partitioning().partition_count();
151
152 let mut join_set = JoinSet::new();
153
154 for part_idx in 0..partition_count {
155 let task = state.task_ctx();
156 let exec = Arc::clone(&exec);
157 join_set.spawn(async move {
158 let stream = exec.execute(part_idx, task)?;
159 common::collect(stream).await
160 });
161 }
162
163 let mut data: Vec<Vec<RecordBatch>> =
164 Vec::with_capacity(exec.output_partitioning().partition_count());
165
166 while let Some(result) = join_set.join_next().await {
167 match result {
168 Ok(res) => data.push(res?),
169 Err(e) => {
170 if e.is_panic() {
171 std::panic::resume_unwind(e.into_panic());
172 } else {
173 unreachable!();
174 }
175 }
176 }
177 }
178
179 let mut exec = DataSourceExec::new(Arc::new(MemorySourceConfig::try_new(
180 &data,
181 Arc::clone(&schema),
182 None,
183 )?));
184 if let Some(cons) = constraints {
185 exec = exec.with_constraints(cons.clone());
186 }
187
188 if let Some(num_partitions) = output_partitions {
189 let exec = RepartitionExec::try_new(
190 Arc::new(exec),
191 Partitioning::RoundRobinBatch(num_partitions),
192 )?;
193
194 let mut output_partitions = vec![];
196 for i in 0..exec.properties().output_partitioning().partition_count() {
197 let task_ctx = state.task_ctx();
199 let mut stream = exec.execute(i, task_ctx)?;
200 let mut batches = vec![];
201 while let Some(result) = stream.next().await {
202 batches.push(result?);
203 }
204 output_partitions.push(batches);
205 }
206
207 return MemTable::try_new(Arc::clone(&schema), output_partitions);
208 }
209 MemTable::try_new(Arc::clone(&schema), data)
210 }
211}
212
213#[async_trait]
214impl TableProvider for MemTable {
215 fn schema(&self) -> SchemaRef {
216 Arc::clone(&self.schema)
217 }
218
219 fn constraints(&self) -> Option<&Constraints> {
220 Some(&self.constraints)
221 }
222
223 fn table_type(&self) -> TableType {
224 TableType::Base
225 }
226
227 async fn scan(
228 &self,
229 state: &dyn Session,
230 projection: Option<&Vec<usize>>,
231 _filters: &[Expr],
232 _limit: Option<usize>,
233 ) -> Result<Arc<dyn ExecutionPlan>> {
234 let mut partitions = vec![];
235 for arc_inner_vec in self.batches.iter() {
236 let inner_vec = arc_inner_vec.read().await;
237 partitions.push(inner_vec.clone())
238 }
239
240 let mut source =
241 MemorySourceConfig::try_new(&partitions, self.schema(), projection.cloned())?;
242
243 let show_sizes = state.config_options().explain.show_sizes;
244 source = source.with_show_sizes(show_sizes);
245
246 let sort_order = self.sort_order.lock();
248 if !sort_order.is_empty() {
249 let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?;
250
251 let eqp = state.execution_props();
252 let mut file_sort_order = vec![];
253 for sort_exprs in sort_order.iter() {
254 let physical_exprs =
255 create_physical_sort_exprs(sort_exprs, &df_schema, eqp)?;
256 file_sort_order.extend(LexOrdering::new(physical_exprs));
257 }
258 source = source.try_with_sort_information(file_sort_order)?;
259 }
260
261 Ok(DataSourceExec::from_data_source(source))
262 }
263
264 async fn insert_into(
279 &self,
280 _state: &dyn Session,
281 input: Arc<dyn ExecutionPlan>,
282 insert_op: InsertOp,
283 ) -> Result<Arc<dyn ExecutionPlan>> {
284 *self.sort_order.lock() = vec![];
286
287 self.schema()
290 .logically_equivalent_names_and_types(&input.schema())?;
291
292 if insert_op != InsertOp::Append {
293 return not_impl_err!("{insert_op} not implemented for MemoryTable yet");
294 }
295 let sink = MemSink::try_new(self.batches.clone(), Arc::clone(&self.schema))?;
296 Ok(Arc::new(DataSinkExec::new(input, Arc::new(sink), None)))
297 }
298
299 fn get_column_default(&self, column: &str) -> Option<&Expr> {
300 self.column_defaults.get(column)
301 }
302
303 async fn delete_from(
304 &self,
305 state: &dyn Session,
306 filters: Vec<Expr>,
307 ) -> Result<Arc<dyn ExecutionPlan>> {
308 if self.batches.is_empty() {
310 return Ok(Arc::new(DmlResultExec::new(0)));
311 }
312
313 *self.sort_order.lock() = vec![];
314
315 let mut total_deleted: u64 = 0;
316 let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?;
317
318 for partition_data in &self.batches {
319 let mut partition = partition_data.write().await;
320 let mut new_batches = Vec::with_capacity(partition.len());
321
322 for batch in partition.iter() {
323 if batch.num_rows() == 0 {
324 continue;
325 }
326
327 let filter_mask = evaluate_filters_to_mask(
329 &filters,
330 batch,
331 &df_schema,
332 state.execution_props(),
333 )?;
334
335 let (delete_count, keep_mask) = match filter_mask {
336 Some(mask) => {
337 let count = mask.iter().filter(|v| v == &Some(true)).count();
339 let keep: BooleanArray =
341 mask.iter().map(|v| Some(v != Some(true))).collect();
342 (count, keep)
343 }
344 None => {
345 (
347 batch.num_rows(),
348 BooleanArray::from(vec![false; batch.num_rows()]),
349 )
350 }
351 };
352
353 total_deleted += delete_count as u64;
354
355 let filtered_batch = filter_record_batch(batch, &keep_mask)?;
356 if filtered_batch.num_rows() > 0 {
357 new_batches.push(filtered_batch);
358 }
359 }
360
361 *partition = new_batches;
362 }
363
364 Ok(Arc::new(DmlResultExec::new(total_deleted)))
365 }
366
367 async fn update(
368 &self,
369 state: &dyn Session,
370 assignments: Vec<(String, Expr)>,
371 filters: Vec<Expr>,
372 ) -> Result<Arc<dyn ExecutionPlan>> {
373 if self.batches.is_empty() {
375 return Ok(Arc::new(DmlResultExec::new(0)));
376 }
377
378 let available_columns: Vec<&str> = self
380 .schema
381 .fields()
382 .iter()
383 .map(|f| f.name().as_str())
384 .collect();
385 for (column_name, _) in &assignments {
386 if self.schema.field_with_name(column_name).is_err() {
387 return plan_err!(
388 "UPDATE failed: column '{}' does not exist. Available columns: {}",
389 column_name,
390 available_columns.join(", ")
391 );
392 }
393 }
394
395 let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?;
396
397 let physical_assignments: HashMap<String, Arc<dyn PhysicalExpr>> = assignments
399 .iter()
400 .map(|(name, expr)| {
401 let physical_expr =
402 create_physical_expr(expr, &df_schema, state.execution_props())?;
403 Ok((name.clone(), physical_expr))
404 })
405 .collect::<Result<_>>()?;
406
407 *self.sort_order.lock() = vec![];
408
409 let mut total_updated: u64 = 0;
410
411 for partition_data in &self.batches {
412 let mut partition = partition_data.write().await;
413 let mut new_batches = Vec::with_capacity(partition.len());
414
415 for batch in partition.iter() {
416 if batch.num_rows() == 0 {
417 continue;
418 }
419
420 let filter_mask = evaluate_filters_to_mask(
422 &filters,
423 batch,
424 &df_schema,
425 state.execution_props(),
426 )?;
427
428 let (update_count, update_mask) = match filter_mask {
429 Some(mask) => {
430 let count = mask.iter().filter(|v| v == &Some(true)).count();
432 let normalized: BooleanArray =
434 mask.iter().map(|v| Some(v == Some(true))).collect();
435 (count, normalized)
436 }
437 None => {
438 (
440 batch.num_rows(),
441 BooleanArray::from(vec![true; batch.num_rows()]),
442 )
443 }
444 };
445
446 total_updated += update_count as u64;
447
448 if update_count == 0 {
449 new_batches.push(batch.clone());
450 continue;
451 }
452
453 let mut new_columns: Vec<ArrayRef> =
454 Vec::with_capacity(batch.num_columns());
455
456 for field in self.schema.fields() {
457 let column_name = field.name();
458 let original_column =
459 batch.column_by_name(column_name).ok_or_else(|| {
460 datafusion_common::DataFusionError::Internal(format!(
461 "Column '{column_name}' not found in batch"
462 ))
463 })?;
464
465 let new_column = if let Some(physical_expr) =
466 physical_assignments.get(column_name.as_str())
467 {
468 let new_values =
473 physical_expr.evaluate_selection(batch, &update_mask)?;
474 let new_array = new_values.into_array(batch.num_rows())?;
475
476 let new_arr: &dyn Array = new_array.as_ref();
478 let orig_arr: &dyn Array = original_column.as_ref();
479 zip(&update_mask, &new_arr, &orig_arr)?
480 } else {
481 Arc::clone(original_column)
482 };
483
484 new_columns.push(new_column);
485 }
486
487 let updated_batch =
488 ArrowRecordBatch::try_new(Arc::clone(&self.schema), new_columns)?;
489 new_batches.push(updated_batch);
490 }
491
492 *partition = new_batches;
493 }
494
495 Ok(Arc::new(DmlResultExec::new(total_updated)))
496 }
497}
498
499fn evaluate_filters_to_mask(
503 filters: &[Expr],
504 batch: &RecordBatch,
505 df_schema: &DFSchema,
506 execution_props: &datafusion_expr::execution_props::ExecutionProps,
507) -> Result<Option<BooleanArray>> {
508 if filters.is_empty() {
509 return Ok(None);
510 }
511
512 let mut combined_mask: Option<BooleanArray> = None;
513
514 for filter_expr in filters {
515 let physical_expr =
516 create_physical_expr(filter_expr, df_schema, execution_props)?;
517
518 let result = physical_expr.evaluate(batch)?;
519 let array = result.into_array(batch.num_rows())?;
520 let bool_array = array
521 .as_any()
522 .downcast_ref::<BooleanArray>()
523 .ok_or_else(|| {
524 datafusion_common::DataFusionError::Internal(
525 "Filter did not evaluate to boolean".to_string(),
526 )
527 })?
528 .clone();
529
530 combined_mask = Some(match combined_mask {
531 Some(existing) => and(&existing, &bool_array)?,
532 None => bool_array,
533 });
534 }
535
536 Ok(combined_mask)
537}
538
539#[derive(Debug)]
541struct DmlResultExec {
542 rows_affected: u64,
543 schema: SchemaRef,
544 properties: Arc<PlanProperties>,
545}
546
547impl DmlResultExec {
548 fn new(rows_affected: u64) -> Self {
549 let schema = Arc::new(Schema::new(vec![Field::new(
550 "count",
551 DataType::UInt64,
552 false,
553 )]));
554
555 let properties = PlanProperties::new(
556 datafusion_physical_expr::EquivalenceProperties::new(Arc::clone(&schema)),
557 Partitioning::UnknownPartitioning(1),
558 datafusion_physical_plan::execution_plan::EmissionType::Final,
559 datafusion_physical_plan::execution_plan::Boundedness::Bounded,
560 );
561
562 Self {
563 rows_affected,
564 schema,
565 properties: Arc::new(properties),
566 }
567 }
568}
569
570impl DisplayAs for DmlResultExec {
571 fn fmt_as(
572 &self,
573 t: DisplayFormatType,
574 f: &mut std::fmt::Formatter,
575 ) -> std::fmt::Result {
576 match t {
577 DisplayFormatType::Default
578 | DisplayFormatType::Verbose
579 | DisplayFormatType::TreeRender => {
580 write!(f, "DmlResultExec: rows_affected={}", self.rows_affected)
581 }
582 }
583 }
584}
585
586impl ExecutionPlan for DmlResultExec {
587 fn name(&self) -> &str {
588 "DmlResultExec"
589 }
590
591 fn schema(&self) -> SchemaRef {
592 Arc::clone(&self.schema)
593 }
594
595 fn properties(&self) -> &Arc<PlanProperties> {
596 &self.properties
597 }
598
599 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
600 vec![]
601 }
602
603 fn with_new_children(
604 self: Arc<Self>,
605 _children: Vec<Arc<dyn ExecutionPlan>>,
606 ) -> Result<Arc<dyn ExecutionPlan>> {
607 Ok(self)
608 }
609
610 fn execute(
611 &self,
612 _partition: usize,
613 _context: Arc<datafusion_execution::TaskContext>,
614 ) -> Result<datafusion_execution::SendableRecordBatchStream> {
615 let count_array = UInt64Array::from(vec![self.rows_affected]);
617 let batch = ArrowRecordBatch::try_new(
618 Arc::clone(&self.schema),
619 vec![Arc::new(count_array) as ArrayRef],
620 )?;
621
622 let stream = futures::stream::iter(vec![Ok(batch)]);
624 Ok(Box::pin(RecordBatchStreamAdapter::new(
625 Arc::clone(&self.schema),
626 stream,
627 )))
628 }
629}