1use std::{any::Any, collections::HashMap, sync::Arc};
19
20use async_trait::async_trait;
21use datafusion::{
22 arrow::{
23 array::{
24 ArrayBuilder, ArrayRef, Int16Builder, Int32Builder, Int64Builder, Int8Builder,
25 LargeStringBuilder, RecordBatch, StringBuilder, UInt16Builder, UInt32Builder,
26 UInt64Builder, UInt8Builder,
27 },
28 datatypes::{DataType, Schema, SchemaRef},
29 },
30 catalog::{Session, TableProvider},
31 common::{internal_err, project_schema, Constraints, DataFusionError, Result},
32 datasource::TableType,
33 execution::SendableRecordBatchStream,
34 physical_expr::{EquivalenceProperties, LexOrdering},
35 physical_plan::{
36 execution_plan::{Boundedness, EmissionType},
37 memory::MemoryStream,
38 DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
39 },
40 prelude::Expr,
41 scalar::ScalarValue,
42};
43use indexmap::IndexMap;
44use parking_lot::RwLock;
45
/// Boxed, dynamically typed Arrow array builder used when materializing columns.
type ArrayBuilderRef = Box<dyn ArrayBuilder>;

/// Shared, lock-protected table storage: an insertion-ordered map from a
/// primary-key scalar to a row, where each row maps column name -> value.
type MapData = Arc<RwLock<IndexMap<ScalarValue, HashMap<String, ScalarValue>>>>;
52
/// Configuration for a [`MapTable`]: the table's name and the column that acts
/// as the primary key of the backing map.
#[derive(Debug)]
pub struct MapTableConfig {
    // Used in error messages when a row references an unknown column.
    table_name: String,
    // Primary-key column name; currently unused (underscore prefix) — kept for future use.
    _primary_key: String,
}
59
60impl MapTableConfig {
61 pub fn new(table_name: String, primary_key: String) -> Self {
62 Self {
63 table_name,
64 _primary_key: primary_key,
65 }
66 }
67}
68
/// A DataFusion `TableProvider` backed by an in-memory, insertion-ordered map
/// of rows (see [`MapData`]), keyed by a primary-key scalar.
#[derive(Debug)]
pub struct MapTable {
    // Arrow schema describing the table's columns.
    schema: Arc<Schema>,
    // Optional constraints reported to the query planner.
    constraints: Option<Constraints>,
    // Table name (for error messages) plus primary-key metadata.
    config: MapTableConfig,
    // Shared, lock-protected row storage; may be shared with other owners.
    inner: MapData,
}
84
85impl MapTable {
86 pub fn try_new(
87 schema: Arc<Schema>,
88 constraints: Option<Constraints>,
89 config: MapTableConfig,
90 data: Option<MapData>,
91 ) -> Result<Self> {
92 let inner = data.unwrap_or(Arc::new(RwLock::new(IndexMap::new())));
93 Ok(Self {
94 schema,
95 constraints,
96 config,
97 inner,
98 })
99 }
100
101 fn try_create_partitions(&self) -> Result<Vec<Vec<RecordBatch>>> {
102 let guard = self.inner.read();
103 let values = guard.values();
104 let mut builders: IndexMap<String, (ArrayBuilderRef, DataType)> = IndexMap::new();
107 for f in &self.schema.fields {
108 let builder = datatype_to_array_builder(f.data_type())?;
109 builders.insert(f.name().clone(), (builder, f.data_type().clone()));
110 }
111
112 for value in values {
113 for (col, val) in value {
114 if self.schema.fields.find(col).is_some() {
116 if let Some((builder, builder_datatype)) = builders.get_mut(col) {
117 try_append_scalar_to_builder(builder, builder_datatype, val)?;
118 }
119 } else {
120 return Err(datafusion::error::DataFusionError::External(
121 format!(
122 "Column {} for table {} is not in the provided schema",
123 col, self.config.table_name
124 )
125 .into(),
126 ));
127 }
128 }
129 }
130
131 let arrays: Vec<ArrayRef> = builders.values_mut().map(|(b, _)| b.finish()).collect();
132
133 let batch = RecordBatch::try_new(Arc::clone(&self.schema), arrays)?;
134 Ok(vec![vec![batch]])
135 }
136}
137
#[async_trait]
impl TableProvider for MapTable {
    /// Enables downcasting from `dyn TableProvider` back to `MapTable`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the full (unprojected) table schema.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    /// Returns the optional constraints supplied at construction.
    fn constraints(&self) -> Option<&Constraints> {
        self.constraints.as_ref()
    }

    /// This provider exposes base table data (not a view).
    fn table_type(&self) -> TableType {
        TableType::Base
    }

    /// Plans a scan: snapshots the map into batches and wraps them in a
    /// `MapExec`. Filters and limit are ignored here and left to upstream
    /// operators (see the FilterExec/LimitExec nodes in the test plans).
    async fn scan(
        &self,
        _state: &dyn Session,
        projection: Option<&Vec<usize>>,
        _filters: &[Expr],
        _limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let partitions = self.try_create_partitions()?;
        let exec = MapExec::try_new(&partitions, Arc::clone(&self.schema), projection.cloned())?;
        Ok(Arc::new(exec))
    }
}
168
/// Leaf execution plan that streams pre-materialized record batches.
#[derive(Debug)]
struct MapExec {
    // Batches grouped by partition, from `MapTable::try_create_partitions`.
    partitions: Vec<Vec<RecordBatch>>,
    // Optional column indices selected by the query.
    projection: Option<Vec<usize>>,
    // Full table schema; unused after construction (underscore prefix).
    _schema: SchemaRef,
    // Schema after applying `projection` — the schema of emitted batches.
    projected_schema: SchemaRef,
    // Output orderings; always empty — no sort guarantee is advertised.
    _sort_information: Vec<LexOrdering>,
    // Plan properties computed once and returned by `properties()`.
    cache: PlanProperties,
}
185
186impl MapExec {
187 fn try_new(
188 partitions: &[Vec<RecordBatch>],
189 schema: SchemaRef,
190 projection: Option<Vec<usize>>,
191 ) -> Result<Self> {
192 let projected_schema = project_schema(&schema, projection.as_ref())?;
193 let constraints = Constraints::new_unverified(vec![]);
194 let cache =
195 Self::compute_properties(Arc::clone(&projected_schema), &[], constraints, partitions);
196
197 Ok(Self {
198 partitions: partitions.to_vec(),
199 _schema: schema,
200 projected_schema,
201 projection,
202 _sort_information: vec![],
203 cache,
204 })
205 }
206
207 fn compute_properties(
209 schema: SchemaRef,
210 orderings: &[LexOrdering],
211 constraints: Constraints,
212 partitions: &[Vec<RecordBatch>],
213 ) -> PlanProperties {
214 PlanProperties::new(
215 EquivalenceProperties::new_with_orderings(schema, orderings.iter().cloned())
216 .with_constraints(constraints),
217 Partitioning::UnknownPartitioning(partitions.len()),
218 EmissionType::Incremental,
219 Boundedness::Bounded,
220 )
221 }
222}
223
224impl DisplayAs for MapExec {
225 fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
226 match t {
227 DisplayFormatType::Default | DisplayFormatType::Verbose => {
228 write!(
229 f,
230 "MapExec: partitions={}, projection={:?}",
231 self.partitions.len(),
232 self.projection
233 )
234 }
235 DisplayFormatType::TreeRender => {
236 write!(
237 f,
238 "MapExec: partitions={}, projection={:?}",
239 self.partitions.len(),
240 self.projection
241 )
242 }
243 }
244 }
245}
246
247impl ExecutionPlan for MapExec {
248 fn name(&self) -> &str {
249 "MapExec"
250 }
251
252 fn as_any(&self) -> &dyn Any {
253 self
254 }
255
256 fn properties(&self) -> &PlanProperties {
257 &self.cache
258 }
259
260 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
261 vec![]
263 }
264
265 fn with_new_children(
266 self: Arc<Self>,
267 children: Vec<Arc<dyn ExecutionPlan>>,
268 ) -> Result<Arc<dyn ExecutionPlan>> {
269 if children.is_empty() {
271 Ok(self)
272 } else {
273 internal_err!("Children cannot be replaced in {self:?}")
274 }
275 }
276
277 fn execute(
278 &self,
279 partition: usize,
280 _context: Arc<datafusion::execution::TaskContext>,
281 ) -> Result<SendableRecordBatchStream> {
282 Ok(Box::pin(MemoryStream::try_new(
283 self.partitions[partition].clone(),
284 Arc::clone(&self.projected_schema),
285 self.projection.clone(),
286 )?))
287 }
288}
289
290fn datatype_to_array_builder(datatype: &DataType) -> Result<Box<dyn ArrayBuilder>> {
291 match datatype {
292 DataType::Int8 => Ok(Box::new(Int8Builder::new())),
293 DataType::Int16 => Ok(Box::new(Int16Builder::new())),
294 DataType::Int32 => Ok(Box::new(Int32Builder::new())),
295 DataType::Int64 => Ok(Box::new(Int64Builder::new())),
296 DataType::UInt8 => Ok(Box::new(UInt8Builder::new())),
297 DataType::UInt16 => Ok(Box::new(UInt16Builder::new())),
298 DataType::UInt32 => Ok(Box::new(UInt32Builder::new())),
299 DataType::UInt64 => Ok(Box::new(UInt64Builder::new())),
300 DataType::Utf8 => Ok(Box::new(StringBuilder::new())),
301 DataType::LargeUtf8 => Ok(Box::new(LargeStringBuilder::new())),
302
303 _ => Err(DataFusionError::External(
304 "Unsupported column type when constructing batch from Map".into(),
305 )),
306 }
307}
308
/// Appends a primitive `ScalarValue` to its matching typed Arrow builder.
///
/// `$variant` is the expected `ScalarValue` variant and `$builder_type` the
/// concrete builder to downcast to. A `None` payload appends a null.
///
/// A scalar whose variant does not match `$variant` is now reported as an
/// error instead of silently returning `Ok(())` — a silently skipped append
/// would leave that column's array shorter than the others and misalign the
/// resulting batch. (The caller checks data types first, so this branch is
/// unreachable in practice.)
macro_rules! append_primitive_scalar {
    ($scalar:expr, $builder:expr, $variant:ident, $builder_type:ty) => {{
        if let ScalarValue::$variant(val) = $scalar {
            if let Some(b) = $builder.as_any_mut().downcast_mut::<$builder_type>() {
                if let Some(x) = val {
                    b.append_value(*x);
                } else {
                    b.append_null();
                }
                Ok(())
            } else {
                Err(DataFusionError::External(
                    format!("Failed to downcast builder for {}", stringify!($variant)).into(),
                ))
            }
        } else {
            Err(DataFusionError::External(
                format!(
                    "ScalarValue variant mismatch: expected {}",
                    stringify!($variant)
                )
                .into(),
            ))
        }
    }};
}
330
331fn try_append_scalar_to_builder(
332 builder: &mut Box<dyn ArrayBuilder>,
333 builder_datatype: &DataType,
334 scalar: &ScalarValue,
335) -> Result<()> {
336 if builder_datatype == &scalar.data_type() {
337 match scalar {
338 ScalarValue::Int8(_) => append_primitive_scalar!(scalar, builder, Int8, Int8Builder)?,
339 ScalarValue::Int16(_) => {
340 append_primitive_scalar!(scalar, builder, Int16, Int16Builder)?
341 }
342 ScalarValue::Int32(_) => {
343 append_primitive_scalar!(scalar, builder, Int32, Int32Builder)?
344 }
345 ScalarValue::Int64(_) => {
346 append_primitive_scalar!(scalar, builder, Int64, Int64Builder)?
347 }
348 ScalarValue::UInt8(_) => {
349 append_primitive_scalar!(scalar, builder, UInt8, UInt8Builder)?
350 }
351 ScalarValue::UInt16(_) => {
352 append_primitive_scalar!(scalar, builder, UInt16, UInt16Builder)?
353 }
354 ScalarValue::UInt32(_) => {
355 append_primitive_scalar!(scalar, builder, UInt32, UInt32Builder)?
356 }
357 ScalarValue::UInt64(_) => {
358 append_primitive_scalar!(scalar, builder, UInt64, UInt64Builder)?
359 }
360 ScalarValue::Utf8(s) => {
361 if let Some(builder) = builder.as_any_mut().downcast_mut::<StringBuilder>() {
362 if let Some(s) = s {
363 builder.append_value(s.clone())
364 } else {
365 builder.append_null()
366 }
367 }
368 }
369 ScalarValue::LargeUtf8(s) => {
370 if let Some(builder) = builder.as_any_mut().downcast_mut::<LargeStringBuilder>() {
371 if let Some(s) = s {
372 builder.append_value(s.clone())
373 } else {
374 builder.append_null()
375 }
376 }
377 }
378
379 _ => {
380 return Err(DataFusionError::External(
381 format!("Unsupported DataType ({}) for conversion", builder_datatype).into(),
382 ))
383 }
384 };
385 } else {
386 return Err(DataFusionError::External(
387 "Array builder and ScalarValue data types dont match".into(),
388 ));
389 };
390 Ok(())
391}
392
#[cfg(test)]
mod test {
    use std::{collections::HashMap, sync::Arc};

    use datafusion::{
        arrow::datatypes::{DataType, Field, Schema},
        assert_batches_eq,
        prelude::{SessionConfig, SessionContext},
        scalar::ScalarValue,
    };
    use indexmap::IndexMap;
    use parking_lot::RwLock;

    use crate::tables::map_table::{MapTable, MapTableConfig};

    /// Builds a `SessionContext` with a `MapTable` named "test" registered,
    /// holding five rows: (id: Int32, val: Utf8) = (1, "val1") .. (5, "val5").
    fn setup() -> SessionContext {
        let mut data: IndexMap<ScalarValue, HashMap<String, ScalarValue>> = IndexMap::new();
        let ids = vec![1, 2, 3, 4, 5];
        let vals = vec!["val1", "val2", "val3", "val4", "val5"];
        for (id, val) in ids.into_iter().zip(vals) {
            let mut row: HashMap<String, ScalarValue> = HashMap::new();
            row.insert("id".to_string(), ScalarValue::Int32(Some(id)));
            row.insert("val".to_string(), ScalarValue::Utf8(Some(val.to_string())));
            // Keyed by the id column, matching the declared primary key.
            data.insert(ScalarValue::Int32(Some(id)), row);
        }

        let fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Utf8, false),
        ];
        let schema = Schema::new(fields);
        let config = MapTableConfig::new("test".to_string(), "id".to_string());
        let table = MapTable::try_new(
            Arc::new(schema),
            None,
            config,
            Some(Arc::new(RwLock::new(data))),
        )
        .unwrap();
        // Multiple target partitions to exercise repartitioning in plans.
        let config = SessionConfig::new().with_target_partitions(4);
        let ctx = SessionContext::new_with_config(config);
        ctx.register_table("test", Arc::new(table)).unwrap();
        ctx
    }

    /// EXPLAIN of a full scan should bottom out in a single-partition MapExec.
    #[tokio::test]
    async fn test_map_table_plans_correctly() {
        let ctx = setup();
        let batches = ctx
            .sql("EXPLAIN SELECT * FROM test")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let expected = [
            "+---------------+--------------------------------------------------+",
            "| plan_type | plan |",
            "+---------------+--------------------------------------------------+",
            "| logical_plan | TableScan: test projection=[id, val] |",
            "| physical_plan | CooperativeExec |",
            "| | MapExec: partitions=1, projection=Some([0, 1]) |",
            "| | |",
            "+---------------+--------------------------------------------------+",
        ];

        assert_batches_eq!(expected, &batches);
    }

    /// A full scan returns all rows in map insertion order.
    #[tokio::test]
    async fn test_select_star_from_map_table() {
        let ctx = setup();
        let batches = ctx
            .sql("SELECT * FROM test")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let expected = [
            "+----+------+",
            "| id | val |",
            "+----+------+",
            "| 1 | val1 |",
            "| 2 | val2 |",
            "| 3 | val3 |",
            "| 4 | val4 |",
            "| 5 | val5 |",
            "+----+------+",
        ];

        assert_batches_eq!(expected, &batches);
    }

    /// Filters are applied by a FilterExec above the scan (MapTable itself
    /// ignores pushed-down filters); also checks an empty result set.
    #[tokio::test]
    async fn test_select_star_with_filter_from_map_table() {
        let ctx = setup();
        let batches = ctx
            .sql("SELECT * FROM test WHERE id = 1")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let expected = [
            "+----+------+",
            "| id | val |",
            "+----+------+",
            "| 1 | val1 |",
            "+----+------+",
        ];

        assert_batches_eq!(expected, &batches);

        // No row with id = 6 — expect an empty result.
        let batches = ctx
            .sql("SELECT * FROM test WHERE id = 6")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let expected = ["++", "++"];

        assert_batches_eq!(expected, &batches);

        let batches = ctx
            .sql("EXPLAIN SELECT * FROM test WHERE id = 2")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let expected = [
            "+---------------+--------------------------------------------------------------------------+",
            "| plan_type | plan |",
            "+---------------+--------------------------------------------------------------------------+",
            "| logical_plan | Filter: test.id = Int32(2) |",
            "| | TableScan: test projection=[id, val] |",
            "| physical_plan | CoalesceBatchesExec: target_batch_size=8192 |",
            "| | FilterExec: id@0 = 2 |",
            "| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |",
            "| | CooperativeExec |",
            "| | MapExec: partitions=1, projection=Some([0, 1]) |",
            "| | |",
            "+---------------+--------------------------------------------------------------------------+",
        ];

        assert_batches_eq!(expected, &batches);
    }

    /// Column projection and computed expressions over the scan.
    #[tokio::test]
    async fn test_select_star_with_projection_from_map_table() {
        let ctx = setup();
        let batches = ctx
            .sql("SELECT val FROM test WHERE id = 1")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let expected = ["+------+", "| val |", "+------+", "| val1 |", "+------+"];

        assert_batches_eq!(expected, &batches);

        let batches = ctx
            .sql("SELECT id * 2 FROM test")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let expected = [
            "+--------------------+",
            "| test.id * Int64(2) |",
            "+--------------------+",
            "| 2 |",
            "| 4 |",
            "| 6 |",
            "| 8 |",
            "| 10 |",
            "+--------------------+",
        ];

        assert_batches_eq!(expected, &batches);
    }

    /// ORDER BY requires a SortExec, since MapExec advertises no ordering.
    #[tokio::test]
    async fn test_select_star_with_sort_from_map_table() {
        let ctx = setup();
        let batches = ctx
            .sql("SELECT * FROM test ORDER BY id DESC")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let expected = [
            "+----+------+",
            "| id | val |",
            "+----+------+",
            "| 5 | val5 |",
            "| 4 | val4 |",
            "| 3 | val3 |",
            "| 2 | val2 |",
            "| 1 | val1 |",
            "+----+------+",
        ];

        assert_batches_eq!(expected, &batches);

        let batches = ctx
            .sql("EXPLAIN SELECT * FROM test ORDER BY id DESC")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let expected = [
            "+---------------+-----------------------------------------------------------+",
            "| plan_type | plan |",
            "+---------------+-----------------------------------------------------------+",
            "| logical_plan | Sort: test.id DESC NULLS FIRST |",
            "| | TableScan: test projection=[id, val] |",
            "| physical_plan | SortExec: expr=[id@0 DESC], preserve_partitioning=[false] |",
            "| | CooperativeExec |",
            "| | MapExec: partitions=1, projection=Some([0, 1]) |",
            "| | |",
            "+---------------+-----------------------------------------------------------+",
        ];

        assert_batches_eq!(expected, &batches);
    }

    /// LIMIT is handled by GlobalLimitExec; the fetch is also pushed into the
    /// logical TableScan (fetch=2) even though MapTable ignores it.
    #[tokio::test]
    async fn test_select_star_with_limit_from_map_table() {
        let ctx = setup();
        let batches = ctx
            .sql("SELECT * FROM test LIMIT 2")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let expected = [
            "+----+------+",
            "| id | val |",
            "+----+------+",
            "| 1 | val1 |",
            "| 2 | val2 |",
            "+----+------+",
        ];

        assert_batches_eq!(expected, &batches);

        let batches = ctx
            .sql("EXPLAIN SELECT * FROM test LIMIT 2")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let expected = [
            "+---------------+----------------------------------------------------+",
            "| plan_type | plan |",
            "+---------------+----------------------------------------------------+",
            "| logical_plan | Limit: skip=0, fetch=2 |",
            "| | TableScan: test projection=[id, val], fetch=2 |",
            "| physical_plan | GlobalLimitExec: skip=0, fetch=2 |",
            "| | CooperativeExec |",
            "| | MapExec: partitions=1, projection=Some([0, 1]) |",
            "| | |",
            "+---------------+----------------------------------------------------+",
        ];

        assert_batches_eq!(expected, &batches);
    }
}