// datafusion_catalog/memory/table.rs

use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;

use crate::TableProvider;

use arrow::array::{
    Array, ArrayRef, BooleanArray, RecordBatch as ArrowRecordBatch, UInt64Array,
};
use arrow::compute::kernels::zip::zip;
use arrow::compute::{and, filter_record_batch};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::error::Result;
use datafusion_common::{Constraints, DFSchema, SchemaExt, not_impl_err, plan_err};
use datafusion_common_runtime::JoinSet;
use datafusion_datasource::memory::{MemSink, MemorySourceConfig};
use datafusion_datasource::sink::DataSinkExec;
use datafusion_datasource::source::DataSourceExec;
use datafusion_expr::dml::InsertOp;
use datafusion_expr::{Expr, SortExpr, TableType};
use datafusion_physical_expr::{
    LexOrdering, create_physical_expr, create_physical_sort_exprs,
};
use datafusion_physical_plan::repartition::RepartitionExec;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::{
    DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning,
    PlanProperties, common,
};
use datafusion_session::Session;

use async_trait::async_trait;
use futures::StreamExt;
use log::debug;
use parking_lot::Mutex;
use tokio::sync::RwLock;

pub use datafusion_datasource::memory::PartitionData;

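/// An in-memory table provider: each partition's data is a `Vec<RecordBatch>`
/// behind an async `RwLock`, served to queries via [`MemorySourceConfig`] and
/// mutated in place by the DML methods below.
///
/// A minimal construction sketch (the column name and values are illustrative
/// only; `try_new` is defined later in this file):
///
/// ```ignore
/// use std::sync::Arc;
/// use arrow::array::Int32Array;
/// use arrow::datatypes::{DataType, Field, Schema};
/// use arrow::record_batch::RecordBatch;
///
/// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
/// let batch = RecordBatch::try_new(
///     Arc::clone(&schema),
///     vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
/// )?;
/// // One partition holding one batch
/// let table = MemTable::try_new(schema, vec![vec![batch]])?;
/// ```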
#[derive(Debug)]
pub struct MemTable {
    schema: SchemaRef,
    pub batches: Vec<PartitionData>,
    constraints: Constraints,
    column_defaults: HashMap<String, Expr>,
    pub sort_order: Arc<Mutex<Vec<Vec<SortExpr>>>>,
}

impl MemTable {
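    /// Create a new in-memory table from the provided schema and partitioned
    /// record batches. Fails if no partitions are provided or if any batch's
    /// schema is not contained in `schema`.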
    pub fn try_new(schema: SchemaRef, partitions: Vec<Vec<RecordBatch>>) -> Result<Self> {
        if partitions.is_empty() {
            return plan_err!("No partitions provided, expected at least one partition");
        }

        for batches in partitions.iter().flatten() {
            let batches_schema = batches.schema();
            if !schema.contains(&batches_schema) {
                debug!(
                    "mem table schema does not contain batches schema. \
                     Target_schema: {schema:?}. Batches Schema: {batches_schema:?}"
                );
                return plan_err!("Mismatch between schema and batches");
            }
        }

        Ok(Self {
            schema,
            batches: partitions
                .into_iter()
                .map(|e| Arc::new(RwLock::new(e)))
                .collect::<Vec<_>>(),
            constraints: Constraints::default(),
            column_defaults: HashMap::new(),
            sort_order: Arc::new(Mutex::new(vec![])),
        })
    }

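    /// Assign [`Constraints`] (e.g. primary key or unique) to the table.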
    pub fn with_constraints(mut self, constraints: Constraints) -> Self {
        self.constraints = constraints;
        self
    }

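    /// Assign per-column default expressions, used when an `INSERT` omits
    /// those columns.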
    pub fn with_column_defaults(
        mut self,
        column_defaults: HashMap<String, Expr>,
    ) -> Self {
        self.column_defaults = column_defaults;
        self
    }

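    /// Declare a pre-known sort order for the stored data. The order is not
    /// verified against the batches; if the data is not actually sorted this
    /// way, queries may produce incorrect results. Each inner `Vec<SortExpr>`
    /// is one equivalent ordering.
    ///
    /// A hedged sketch (assumes `datafusion_expr::col` and `Expr::sort` are in
    /// scope; `"a"` is an illustrative column name):
    ///
    /// ```ignore
    /// // declare that the data is sorted by column "a", ascending, nulls last
    /// let table = table.with_sort_order(vec![vec![col("a").sort(true, false)]]);
    /// ```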
    pub fn with_sort_order(self, mut sort_order: Vec<Vec<SortExpr>>) -> Self {
        std::mem::swap(self.sort_order.lock().as_mut(), &mut sort_order);
        self
    }

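    /// Materialize another [`TableProvider`] into a new `MemTable`: scan `t`,
    /// collect every partition into memory in parallel, and optionally
    /// round-robin repartition the result into `output_partitions`.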
    pub async fn load(
        t: Arc<dyn TableProvider>,
        output_partitions: Option<usize>,
        state: &dyn Session,
    ) -> Result<Self> {
        let schema = t.schema();
        let constraints = t.constraints();
        let exec = t.scan(state, None, &[], None).await?;
        let partition_count = exec.output_partitioning().partition_count();

        let mut join_set = JoinSet::new();

        // Collect each input partition concurrently
        for part_idx in 0..partition_count {
            let task = state.task_ctx();
            let exec = Arc::clone(&exec);
            join_set.spawn(async move {
                let stream = exec.execute(part_idx, task)?;
                common::collect(stream).await
            });
        }

        let mut data: Vec<Vec<RecordBatch>> =
            Vec::with_capacity(exec.output_partitioning().partition_count());

        while let Some(result) = join_set.join_next().await {
            match result {
                Ok(res) => data.push(res?),
                Err(e) => {
                    // Propagate panics from the spawned tasks; the tasks are
                    // never cancelled, so no other join error is possible
                    if e.is_panic() {
                        std::panic::resume_unwind(e.into_panic());
                    } else {
                        unreachable!();
                    }
                }
            }
        }

        let mut exec = DataSourceExec::new(Arc::new(MemorySourceConfig::try_new(
            &data,
            Arc::clone(&schema),
            None,
        )?));
        if let Some(cons) = constraints {
            exec = exec.with_constraints(cons.clone());
        }

        if let Some(num_partitions) = output_partitions {
            let exec = RepartitionExec::try_new(
                Arc::new(exec),
                Partitioning::RoundRobinBatch(num_partitions),
            )?;

            let mut output_partitions = vec![];
            // Execute the repartitioned plan and collect the batches of each
            // output partition
            for i in 0..exec.properties().output_partitioning().partition_count() {
                let task_ctx = state.task_ctx();
                let mut stream = exec.execute(i, task_ctx)?;
                let mut batches = vec![];
                while let Some(result) = stream.next().await {
                    batches.push(result?);
                }
                output_partitions.push(batches);
            }

            return MemTable::try_new(Arc::clone(&schema), output_partitions);
        }
        MemTable::try_new(Arc::clone(&schema), data)
    }
}

#[async_trait]
impl TableProvider for MemTable {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn constraints(&self) -> Option<&Constraints> {
        Some(&self.constraints)
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

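    /// Scan the table: snapshot every partition's batches under a read lock
    /// and wrap them in a [`DataSourceExec`], attaching any declared sort
    /// order as output ordering information.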
    async fn scan(
        &self,
        state: &dyn Session,
        projection: Option<&Vec<usize>>,
        _filters: &[Expr],
        _limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let mut partitions = vec![];
        for arc_inner_vec in self.batches.iter() {
            let inner_vec = arc_inner_vec.read().await;
            partitions.push(inner_vec.clone())
        }

        let mut source =
            MemorySourceConfig::try_new(&partitions, self.schema(), projection.cloned())?;

        let show_sizes = state.config_options().explain.show_sizes;
        source = source.with_show_sizes(show_sizes);

        // Add sort information if present
        let sort_order = self.sort_order.lock();
        if !sort_order.is_empty() {
            let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?;

            let eqp = state.execution_props();
            let mut file_sort_order = vec![];
            for sort_exprs in sort_order.iter() {
                let physical_exprs =
                    create_physical_sort_exprs(sort_exprs, &df_schema, eqp)?;
                file_sort_order.extend(LexOrdering::new(physical_exprs));
            }
            source = source.try_with_sort_information(file_sort_order)?;
        }

        Ok(DataSourceExec::from_data_source(source))
    }

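    /// Returns an [`ExecutionPlan`] that appends the results of executing
    /// `input` to this table. The input schema must be logically equivalent
    /// to the table schema; only [`InsertOp::Append`] is supported.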
    async fn insert_into(
        &self,
        _state: &dyn Session,
        input: Arc<dyn ExecutionPlan>,
        insert_op: InsertOp,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Appended data may not follow the declared sort order, so drop it
        *self.sort_order.lock() = vec![];

        // Verify the input schema matches this table's schema
        self.schema()
            .logically_equivalent_names_and_types(&input.schema())?;

        if insert_op != InsertOp::Append {
            return not_impl_err!("{insert_op} not implemented for MemoryTable yet");
        }
        let sink = MemSink::try_new(self.batches.clone(), Arc::clone(&self.schema))?;
        Ok(Arc::new(DataSinkExec::new(input, Arc::new(sink), None)))
    }

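    /// Return the default expression for `column`, if one was configured via
    /// [`MemTable::with_column_defaults`].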
    fn get_column_default(&self, column: &str) -> Option<&Expr> {
        self.column_defaults.get(column)
    }

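    /// Delete rows matching `filters` in place. Rows are removed by rewriting
    /// each batch with a keep-mask; the returned plan yields a single `count`
    /// row with the number of rows deleted.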
    async fn delete_from(
        &self,
        state: &dyn Session,
        filters: Vec<Expr>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if self.batches.is_empty() {
            // Nothing to delete from an empty table
            return Ok(Arc::new(DmlResultExec::new(0)));
        }

        // Deleting rows may invalidate the declared sort order
        *self.sort_order.lock() = vec![];

        let mut total_deleted: u64 = 0;
        let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?;

        for partition_data in &self.batches {
            let mut partition = partition_data.write().await;
            let mut new_batches = Vec::with_capacity(partition.len());

            for batch in partition.iter() {
                if batch.num_rows() == 0 {
                    continue;
                }

                let filter_mask = evaluate_filters_to_mask(
                    &filters,
                    batch,
                    &df_schema,
                    state.execution_props(),
                )?;

                let (delete_count, keep_mask) = match filter_mask {
                    Some(mask) => {
                        // Rows where the predicate is true are deleted; rows
                        // where it is false or NULL are kept (SQL semantics)
                        let count = mask.iter().filter(|v| v == &Some(true)).count();
                        let keep: BooleanArray =
                            mask.iter().map(|v| Some(v != Some(true))).collect();
                        (count, keep)
                    }
                    None => {
                        // No filters: delete every row in the batch
                        (
                            batch.num_rows(),
                            BooleanArray::from(vec![false; batch.num_rows()]),
                        )
                    }
                };

                total_deleted += delete_count as u64;

                let filtered_batch = filter_record_batch(batch, &keep_mask)?;
                if filtered_batch.num_rows() > 0 {
                    new_batches.push(filtered_batch);
                }
            }

            *partition = new_batches;
        }

        Ok(Arc::new(DmlResultExec::new(total_deleted)))
    }

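    /// Update rows matching `filters` in place by evaluating the `assignments`
    /// against each batch and splicing new values in with [`zip`]. The
    /// returned plan yields a single `count` row with the number of rows
    /// updated.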
    async fn update(
        &self,
        state: &dyn Session,
        assignments: Vec<(String, Expr)>,
        filters: Vec<Expr>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if self.batches.is_empty() {
            // Nothing to update in an empty table
            return Ok(Arc::new(DmlResultExec::new(0)));
        }

        // Reject assignments that reference non-existent columns
        let available_columns: Vec<&str> = self
            .schema
            .fields()
            .iter()
            .map(|f| f.name().as_str())
            .collect();
        for (column_name, _) in &assignments {
            if self.schema.field_with_name(column_name).is_err() {
                return plan_err!(
                    "UPDATE failed: column '{}' does not exist. Available columns: {}",
                    column_name,
                    available_columns.join(", ")
                );
            }
        }

        let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?;

        // Compile each assignment expression once, up front
        let physical_assignments: HashMap<
            String,
            Arc<dyn datafusion_physical_plan::PhysicalExpr>,
        > = assignments
            .iter()
            .map(|(name, expr)| {
                let physical_expr =
                    create_physical_expr(expr, &df_schema, state.execution_props())?;
                Ok((name.clone(), physical_expr))
            })
            .collect::<Result<_>>()?;

        // Updated values may invalidate the declared sort order
        *self.sort_order.lock() = vec![];

        let mut total_updated: u64 = 0;

        for partition_data in &self.batches {
            let mut partition = partition_data.write().await;
            let mut new_batches = Vec::with_capacity(partition.len());

            for batch in partition.iter() {
                if batch.num_rows() == 0 {
                    continue;
                }

                let filter_mask = evaluate_filters_to_mask(
                    &filters,
                    batch,
                    &df_schema,
                    state.execution_props(),
                )?;

                let (update_count, update_mask) = match filter_mask {
                    Some(mask) => {
                        // Only rows where the predicate is true are updated;
                        // NULLs are normalized to false
                        let count = mask.iter().filter(|v| v == &Some(true)).count();
                        let normalized: BooleanArray =
                            mask.iter().map(|v| Some(v == Some(true))).collect();
                        (count, normalized)
                    }
                    None => {
                        // No filters: update every row in the batch
                        (
                            batch.num_rows(),
                            BooleanArray::from(vec![true; batch.num_rows()]),
                        )
                    }
                };

                total_updated += update_count as u64;

                if update_count == 0 {
                    new_batches.push(batch.clone());
                    continue;
                }

                let mut new_columns: Vec<ArrayRef> =
                    Vec::with_capacity(batch.num_columns());

                for field in self.schema.fields() {
                    let column_name = field.name();
                    let original_column =
                        batch.column_by_name(column_name).ok_or_else(|| {
                            datafusion_common::DataFusionError::Internal(format!(
                                "Column '{column_name}' not found in batch"
                            ))
                        })?;

                    let new_column = if let Some(physical_expr) =
                        physical_assignments.get(column_name.as_str())
                    {
                        // Evaluate the assignment only for selected rows, then
                        // take the new value where the mask is true and the
                        // original value elsewhere
                        let new_values =
                            physical_expr.evaluate_selection(batch, &update_mask)?;
                        let new_array = new_values.into_array(batch.num_rows())?;

                        let new_arr: &dyn Array = new_array.as_ref();
                        let orig_arr: &dyn Array = original_column.as_ref();
                        zip(&update_mask, &new_arr, &orig_arr)?
                    } else {
                        Arc::clone(original_column)
                    };

                    new_columns.push(new_column);
                }

                let updated_batch =
                    ArrowRecordBatch::try_new(Arc::clone(&self.schema), new_columns)?;
                new_batches.push(updated_batch);
            }

            *partition = new_batches;
        }

        Ok(Arc::new(DmlResultExec::new(total_updated)))
    }
}

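/// Evaluate `filters` against `batch` and AND the results into a single
/// boolean mask. Returns `Ok(None)` when there are no filters, meaning every
/// row matches.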
fn evaluate_filters_to_mask(
    filters: &[Expr],
    batch: &RecordBatch,
    df_schema: &DFSchema,
    execution_props: &datafusion_expr::execution_props::ExecutionProps,
) -> Result<Option<BooleanArray>> {
    if filters.is_empty() {
        return Ok(None);
    }

    let mut combined_mask: Option<BooleanArray> = None;

    for filter_expr in filters {
        let physical_expr =
            create_physical_expr(filter_expr, df_schema, execution_props)?;

        let result = physical_expr.evaluate(batch)?;
        let array = result.into_array(batch.num_rows())?;
        let bool_array = array
            .as_any()
            .downcast_ref::<BooleanArray>()
            .ok_or_else(|| {
                datafusion_common::DataFusionError::Internal(
                    "Filter did not evaluate to boolean".to_string(),
                )
            })?
            .clone();

        combined_mask = Some(match combined_mask {
            Some(existing) => and(&existing, &bool_array)?,
            None => bool_array,
        });
    }

    Ok(combined_mask)
}

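/// Leaf [`ExecutionPlan`] that reports the outcome of a DML operation as a
/// single-row, single-column (`count: UInt64`) batch.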
#[derive(Debug)]
struct DmlResultExec {
    rows_affected: u64,
    schema: SchemaRef,
    properties: PlanProperties,
}

impl DmlResultExec {
    fn new(rows_affected: u64) -> Self {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "count",
            DataType::UInt64,
            false,
        )]));

        let properties = PlanProperties::new(
            datafusion_physical_expr::EquivalenceProperties::new(Arc::clone(&schema)),
            Partitioning::UnknownPartitioning(1),
            datafusion_physical_plan::execution_plan::EmissionType::Final,
            datafusion_physical_plan::execution_plan::Boundedness::Bounded,
        );

        Self {
            rows_affected,
            schema,
            properties,
        }
    }
}

impl DisplayAs for DmlResultExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default
            | DisplayFormatType::Verbose
            | DisplayFormatType::TreeRender => {
                write!(f, "DmlResultExec: rows_affected={}", self.rows_affected)
            }
        }
    }
}

impl ExecutionPlan for DmlResultExec {
    fn name(&self) -> &str {
        "DmlResultExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        _children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(self)
    }

    fn execute(
        &self,
        _partition: usize,
        _context: Arc<datafusion_execution::TaskContext>,
    ) -> Result<datafusion_execution::SendableRecordBatchStream> {
        // Build the single-row result batch containing the affected-row count
        let count_array = UInt64Array::from(vec![self.rows_affected]);
        let batch = ArrowRecordBatch::try_new(
            Arc::clone(&self.schema),
            vec![Arc::new(count_array) as ArrayRef],
        )?;

        // Emit it as a one-batch stream
        let stream = futures::stream::iter(vec![Ok(batch)]);
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            Arc::clone(&self.schema),
            stream,
        )))
    }
}