// datafusion_app/tables/map_table.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::{any::Any, collections::HashMap, sync::Arc};
19
20use async_trait::async_trait;
21use datafusion::{
22    arrow::{
23        array::{
24            ArrayBuilder, ArrayRef, Int16Builder, Int32Builder, Int64Builder, Int8Builder,
25            LargeStringBuilder, RecordBatch, StringBuilder, UInt16Builder, UInt32Builder,
26            UInt64Builder, UInt8Builder,
27        },
28        datatypes::{DataType, Schema, SchemaRef},
29    },
30    catalog::{Session, TableProvider},
31    common::{internal_err, project_schema, Constraints, DataFusionError, Result},
32    datasource::TableType,
33    execution::SendableRecordBatchStream,
34    physical_expr::{EquivalenceProperties, LexOrdering},
35    physical_plan::{
36        execution_plan::{Boundedness, EmissionType},
37        memory::MemoryStream,
38        DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
39    },
40    prelude::Expr,
41    scalar::ScalarValue,
42};
43use indexmap::IndexMap;
44use parking_lot::RwLock;
45
/// Boxed dynamically-typed Arrow array builder, one per column.
type ArrayBuilderRef = Box<dyn ArrayBuilder>;

// The outer key is the row's primary key (a dynamically typed ScalarValue) and is meant to
// provide O(1) lookup. The inner HashMap is for holding arbitrary column and value pairs - the
// key is the column name and we use DataFusion's ScalarValue to provide dynamic typing for the
// column values.
type MapData = Arc<RwLock<IndexMap<ScalarValue, HashMap<String, ScalarValue>>>>;
52
/// Configuration for a [`MapTable`]: the name it is registered under and the
/// column acting as the primary key.
#[derive(Debug)]
pub struct MapTableConfig {
    // Table name, used in error messages when a column is missing from the schema
    table_name: String,
    /// Column name of the primary key
    // NOTE(review): currently unused (hence the underscore) — reserved for the
    // planned primary-key filter pushdown described on `MapTable`.
    _primary_key: String,
}
59
60impl MapTableConfig {
61    pub fn new(table_name: String, primary_key: String) -> Self {
62        Self {
63            table_name,
64            _primary_key: primary_key,
65        }
66    }
67}
68
/// Table for tracking observability information. Data is held in an IndexMap, which maintains
/// insertion order, while the app is running and is serialized on app shutdown.
///
/// TODO: Add filter pushdown on the primary key and use `get` on that for O(1)
/// TODO: Add filter pushdown on non primary key and use `binary_search_by` / `range` (whatever
///       method the underlying map provides) to search values
/// TODO: Add projection pushdown to only read keys from HashMap that are projected
#[derive(Debug)]
pub struct MapTable {
    /// Full table schema; batches are always built against this (no projection pushdown yet)
    schema: Arc<Schema>,
    /// Optional table constraints reported back to DataFusion
    constraints: Option<Constraints>,
    /// Table name and primary-key configuration
    config: MapTableConfig,
    // TODO: This will be based on a Trait so you can use IndexMap, DashMap, BTreeMap, etc...
    inner: MapData,
}
84
85impl MapTable {
86    pub fn try_new(
87        schema: Arc<Schema>,
88        constraints: Option<Constraints>,
89        config: MapTableConfig,
90        data: Option<MapData>,
91    ) -> Result<Self> {
92        let inner = data.unwrap_or(Arc::new(RwLock::new(IndexMap::new())));
93        Ok(Self {
94            schema,
95            constraints,
96            config,
97            inner,
98        })
99    }
100
101    fn try_create_partitions(&self) -> Result<Vec<Vec<RecordBatch>>> {
102        let guard = self.inner.read();
103        let values = guard.values();
104        // We use IndexMap, which has order defined on insertion order to have our builders align
105        // with the order of the fields in the Schema.
106        let mut builders: IndexMap<String, (ArrayBuilderRef, DataType)> = IndexMap::new();
107        for f in &self.schema.fields {
108            let builder = datatype_to_array_builder(f.data_type())?;
109            builders.insert(f.name().clone(), (builder, f.data_type().clone()));
110        }
111
112        for value in values {
113            for (col, val) in value {
114                // Check that the column is in the tables schema
115                if self.schema.fields.find(col).is_some() {
116                    if let Some((builder, builder_datatype)) = builders.get_mut(col) {
117                        try_append_scalar_to_builder(builder, builder_datatype, val)?;
118                    }
119                } else {
120                    return Err(datafusion::error::DataFusionError::External(
121                        format!(
122                            "Column {} for table {} is not in the provided schema",
123                            col, self.config.table_name
124                        )
125                        .into(),
126                    ));
127                }
128            }
129        }
130
131        let arrays: Vec<ArrayRef> = builders.values_mut().map(|(b, _)| b.finish()).collect();
132
133        let batch = RecordBatch::try_new(Arc::clone(&self.schema), arrays)?;
134        Ok(vec![vec![batch]])
135    }
136}
137
138#[async_trait]
139impl TableProvider for MapTable {
140    fn as_any(&self) -> &dyn Any {
141        self
142    }
143
144    fn schema(&self) -> SchemaRef {
145        Arc::clone(&self.schema)
146    }
147
148    fn constraints(&self) -> Option<&Constraints> {
149        self.constraints.as_ref()
150    }
151
152    fn table_type(&self) -> TableType {
153        TableType::Base
154    }
155
156    async fn scan(
157        &self,
158        _state: &dyn Session,
159        projection: Option<&Vec<usize>>,
160        _filters: &[Expr],
161        _limit: Option<usize>,
162    ) -> Result<Arc<dyn ExecutionPlan>> {
163        let partitions = self.try_create_partitions()?;
164        let exec = MapExec::try_new(&partitions, Arc::clone(&self.schema), projection.cloned())?;
165        Ok(Arc::new(exec))
166    }
167}
168
/// Execution plan for converting Map data into in-memory record batches and then reading from
/// them
#[derive(Debug)]
struct MapExec {
    /// The partitions to query
    partitions: Vec<Vec<RecordBatch>>,
    /// Optional projection (column indices into `_schema`)
    projection: Option<Vec<usize>>,
    /// Schema representing the data before projection
    _schema: SchemaRef,
    /// Schema representing the data after the optional projection is applied
    projected_schema: SchemaRef,
    /// Sort information: one or more equivalent orderings (currently always empty)
    _sort_information: Vec<LexOrdering>,
    /// Pre-computed plan properties served by `ExecutionPlan::properties`
    cache: PlanProperties,
}
185
186impl MapExec {
187    fn try_new(
188        partitions: &[Vec<RecordBatch>],
189        schema: SchemaRef,
190        projection: Option<Vec<usize>>,
191    ) -> Result<Self> {
192        let projected_schema = project_schema(&schema, projection.as_ref())?;
193        let constraints = Constraints::new_unverified(vec![]);
194        let cache =
195            Self::compute_properties(Arc::clone(&projected_schema), &[], constraints, partitions);
196
197        Ok(Self {
198            partitions: partitions.to_vec(),
199            _schema: schema,
200            projected_schema,
201            projection,
202            _sort_information: vec![],
203            cache,
204        })
205    }
206
207    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
208    fn compute_properties(
209        schema: SchemaRef,
210        orderings: &[LexOrdering],
211        constraints: Constraints,
212        partitions: &[Vec<RecordBatch>],
213    ) -> PlanProperties {
214        PlanProperties::new(
215            EquivalenceProperties::new_with_orderings(schema, orderings.iter().cloned())
216                .with_constraints(constraints),
217            Partitioning::UnknownPartitioning(partitions.len()),
218            EmissionType::Incremental,
219            Boundedness::Bounded,
220        )
221    }
222}
223
224impl DisplayAs for MapExec {
225    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
226        match t {
227            DisplayFormatType::Default | DisplayFormatType::Verbose => {
228                write!(
229                    f,
230                    "MapExec: partitions={}, projection={:?}",
231                    self.partitions.len(),
232                    self.projection
233                )
234            }
235            DisplayFormatType::TreeRender => {
236                write!(
237                    f,
238                    "MapExec: partitions={}, projection={:?}",
239                    self.partitions.len(),
240                    self.projection
241                )
242            }
243        }
244    }
245}
246
247impl ExecutionPlan for MapExec {
248    fn name(&self) -> &str {
249        "MapExec"
250    }
251
252    fn as_any(&self) -> &dyn Any {
253        self
254    }
255
256    fn properties(&self) -> &PlanProperties {
257        &self.cache
258    }
259
260    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
261        // This is a leaf node and has no children
262        vec![]
263    }
264
265    fn with_new_children(
266        self: Arc<Self>,
267        children: Vec<Arc<dyn ExecutionPlan>>,
268    ) -> Result<Arc<dyn ExecutionPlan>> {
269        // MapExec has no children
270        if children.is_empty() {
271            Ok(self)
272        } else {
273            internal_err!("Children cannot be replaced in {self:?}")
274        }
275    }
276
277    fn execute(
278        &self,
279        partition: usize,
280        _context: Arc<datafusion::execution::TaskContext>,
281    ) -> Result<SendableRecordBatchStream> {
282        Ok(Box::pin(MemoryStream::try_new(
283            self.partitions[partition].clone(),
284            Arc::clone(&self.projected_schema),
285            self.projection.clone(),
286        )?))
287    }
288}
289
290fn datatype_to_array_builder(datatype: &DataType) -> Result<Box<dyn ArrayBuilder>> {
291    match datatype {
292        DataType::Int8 => Ok(Box::new(Int8Builder::new())),
293        DataType::Int16 => Ok(Box::new(Int16Builder::new())),
294        DataType::Int32 => Ok(Box::new(Int32Builder::new())),
295        DataType::Int64 => Ok(Box::new(Int64Builder::new())),
296        DataType::UInt8 => Ok(Box::new(UInt8Builder::new())),
297        DataType::UInt16 => Ok(Box::new(UInt16Builder::new())),
298        DataType::UInt32 => Ok(Box::new(UInt32Builder::new())),
299        DataType::UInt64 => Ok(Box::new(UInt64Builder::new())),
300        DataType::Utf8 => Ok(Box::new(StringBuilder::new())),
301        DataType::LargeUtf8 => Ok(Box::new(LargeStringBuilder::new())),
302
303        _ => Err(DataFusionError::External(
304            "Unsupported column type when constructing batch from Map".into(),
305        )),
306    }
307}
308
/// Appends a primitive `ScalarValue` to its matching concrete Arrow builder.
///
/// `$variant` names the `ScalarValue` variant (e.g. `Int32`) and
/// `$builder_type` the builder it pairs with (e.g. `Int32Builder`).
/// A `None` payload appends a null; a failed builder downcast is an error;
/// a scalar of a *different* variant is deliberately a no-op so the caller's
/// `match` stays in control of dispatch.
macro_rules! append_primitive_scalar {
    ($scalar:expr, $builder:expr, $variant:ident, $builder_type:ty) => {{
        if let ScalarValue::$variant(val) = $scalar {
            if let Some(b) = $builder.as_any_mut().downcast_mut::<$builder_type>() {
                if let Some(x) = val {
                    b.append_value(*x);
                } else {
                    b.append_null();
                }
                Ok(())
            } else {
                Err(DataFusionError::External(
                    format!("Failed to downcast builder for {}", stringify!($variant)).into(),
                ))
            }
        } else {
            // If the scalar is not of the expected variant, do nothing.
            Ok(())
        }
    }};
}
330
331fn try_append_scalar_to_builder(
332    builder: &mut Box<dyn ArrayBuilder>,
333    builder_datatype: &DataType,
334    scalar: &ScalarValue,
335) -> Result<()> {
336    if builder_datatype == &scalar.data_type() {
337        match scalar {
338            ScalarValue::Int8(_) => append_primitive_scalar!(scalar, builder, Int8, Int8Builder)?,
339            ScalarValue::Int16(_) => {
340                append_primitive_scalar!(scalar, builder, Int16, Int16Builder)?
341            }
342            ScalarValue::Int32(_) => {
343                append_primitive_scalar!(scalar, builder, Int32, Int32Builder)?
344            }
345            ScalarValue::Int64(_) => {
346                append_primitive_scalar!(scalar, builder, Int64, Int64Builder)?
347            }
348            ScalarValue::UInt8(_) => {
349                append_primitive_scalar!(scalar, builder, UInt8, UInt8Builder)?
350            }
351            ScalarValue::UInt16(_) => {
352                append_primitive_scalar!(scalar, builder, UInt16, UInt16Builder)?
353            }
354            ScalarValue::UInt32(_) => {
355                append_primitive_scalar!(scalar, builder, UInt32, UInt32Builder)?
356            }
357            ScalarValue::UInt64(_) => {
358                append_primitive_scalar!(scalar, builder, UInt64, UInt64Builder)?
359            }
360            ScalarValue::Utf8(s) => {
361                if let Some(builder) = builder.as_any_mut().downcast_mut::<StringBuilder>() {
362                    if let Some(s) = s {
363                        builder.append_value(s.clone())
364                    } else {
365                        builder.append_null()
366                    }
367                }
368            }
369            ScalarValue::LargeUtf8(s) => {
370                if let Some(builder) = builder.as_any_mut().downcast_mut::<LargeStringBuilder>() {
371                    if let Some(s) = s {
372                        builder.append_value(s.clone())
373                    } else {
374                        builder.append_null()
375                    }
376                }
377            }
378
379            _ => {
380                return Err(DataFusionError::External(
381                    format!("Unsupported DataType ({}) for conversion", builder_datatype).into(),
382                ))
383            }
384        };
385    } else {
386        return Err(DataFusionError::External(
387            "Array builder and ScalarValue data types dont match".into(),
388        ));
389    };
390    Ok(())
391}
392
#[cfg(test)]
mod test {
    use std::{collections::HashMap, sync::Arc};

    use datafusion::{
        arrow::datatypes::{DataType, Field, Schema},
        assert_batches_eq,
        prelude::{SessionConfig, SessionContext},
        scalar::ScalarValue,
    };
    use indexmap::IndexMap;
    use parking_lot::RwLock;

    use crate::tables::map_table::{MapTable, MapTableConfig};

    /// Registers a `MapTable` named "test" with five (id, val) rows and
    /// returns a context configured with four target partitions.
    fn setup() -> SessionContext {
        let mut data: IndexMap<ScalarValue, HashMap<String, ScalarValue>> = IndexMap::new();
        for (id, val) in [(1, "val1"), (2, "val2"), (3, "val3"), (4, "val4"), (5, "val5")] {
            let row = HashMap::from([
                ("id".to_string(), ScalarValue::Int32(Some(id))),
                ("val".to_string(), ScalarValue::Utf8(Some(val.to_string()))),
            ]);
            data.insert(ScalarValue::Int32(Some(id)), row);
        }

        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Utf8, false),
        ]);
        let table = MapTable::try_new(
            Arc::new(schema),
            None,
            MapTableConfig::new("test".to_string(), "id".to_string()),
            Some(Arc::new(RwLock::new(data))),
        )
        .unwrap();

        let ctx =
            SessionContext::new_with_config(SessionConfig::new().with_target_partitions(4));
        ctx.register_table("test", Arc::new(table)).unwrap();
        ctx
    }

    /// Runs `sql` against `ctx` and collects every result batch.
    async fn run(
        ctx: &SessionContext,
        sql: &str,
    ) -> Vec<datafusion::arrow::array::RecordBatch> {
        ctx.sql(sql).await.unwrap().collect().await.unwrap()
    }

    #[tokio::test]
    async fn test_map_table_plans_correctly() {
        // TODO UPDATE ROOT KEY, WHICH IS THE PRIMARY KEY, TO BE OF TYPE SCALARVALUE
        let ctx = setup();
        let batches = run(&ctx, "EXPLAIN SELECT * FROM test").await;

        let expected = [
            "+---------------+--------------------------------------------------+",
            "| plan_type     | plan                                             |",
            "+---------------+--------------------------------------------------+",
            "| logical_plan  | TableScan: test projection=[id, val]             |",
            "| physical_plan | CooperativeExec                                  |",
            "|               |   MapExec: partitions=1, projection=Some([0, 1]) |",
            "|               |                                                  |",
            "+---------------+--------------------------------------------------+",
        ];

        assert_batches_eq!(expected, &batches);
    }

    #[tokio::test]
    async fn test_select_star_from_map_table() {
        let ctx = setup();
        let batches = run(&ctx, "SELECT * FROM test").await;

        let expected = [
            "+----+------+",
            "| id | val  |",
            "+----+------+",
            "| 1  | val1 |",
            "| 2  | val2 |",
            "| 3  | val3 |",
            "| 4  | val4 |",
            "| 5  | val5 |",
            "+----+------+",
        ];

        assert_batches_eq!(expected, &batches);
    }

    #[tokio::test]
    async fn test_select_star_with_filter_from_map_table() {
        let ctx = setup();

        // A matching filter returns exactly the matching row.
        let batches = run(&ctx, "SELECT * FROM test WHERE id = 1").await;
        let expected = [
            "+----+------+",
            "| id | val  |",
            "+----+------+",
            "| 1  | val1 |",
            "+----+------+",
        ];
        assert_batches_eq!(expected, &batches);

        // A non-matching filter returns no rows at all.
        let batches = run(&ctx, "SELECT * FROM test WHERE id = 6").await;
        let expected = ["++", "++"];
        assert_batches_eq!(expected, &batches);

        // The filter is applied above the scan (no pushdown yet).
        let batches = run(&ctx, "EXPLAIN SELECT * FROM test WHERE id = 2").await;
        let expected = [
            "+---------------+--------------------------------------------------------------------------+",
            "| plan_type     | plan                                                                     |",
            "+---------------+--------------------------------------------------------------------------+",
            "| logical_plan  | Filter: test.id = Int32(2)                                               |",
            "|               |   TableScan: test projection=[id, val]                                   |",
            "| physical_plan | CoalesceBatchesExec: target_batch_size=8192                              |",
            "|               |   FilterExec: id@0 = 2                                                   |",
            "|               |     RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |",
            "|               |       CooperativeExec                                                    |",
            "|               |         MapExec: partitions=1, projection=Some([0, 1])                   |",
            "|               |                                                                          |",
            "+---------------+--------------------------------------------------------------------------+",
        ];
        assert_batches_eq!(expected, &batches);
    }

    #[tokio::test]
    async fn test_select_star_with_projection_from_map_table() {
        let ctx = setup();

        // Projecting a single column works together with a filter.
        let batches = run(&ctx, "SELECT val FROM test WHERE id = 1").await;
        let expected = ["+------+", "| val  |", "+------+", "| val1 |", "+------+"];
        assert_batches_eq!(expected, &batches);

        // Computed projections over the scanned column also work.
        let batches = run(&ctx, "SELECT id * 2 FROM test").await;
        let expected = [
            "+--------------------+",
            "| test.id * Int64(2) |",
            "+--------------------+",
            "| 2                  |",
            "| 4                  |",
            "| 6                  |",
            "| 8                  |",
            "| 10                 |",
            "+--------------------+",
        ];
        assert_batches_eq!(expected, &batches);
    }

    #[tokio::test]
    async fn test_select_star_with_sort_from_map_table() {
        let ctx = setup();

        // Rows come back in descending id order.
        let batches = run(&ctx, "SELECT * FROM test ORDER BY id DESC").await;
        let expected = [
            "+----+------+",
            "| id | val  |",
            "+----+------+",
            "| 5  | val5 |",
            "| 4  | val4 |",
            "| 3  | val3 |",
            "| 2  | val2 |",
            "| 1  | val1 |",
            "+----+------+",
        ];
        assert_batches_eq!(expected, &batches);

        // The sort sits above the scan in the physical plan.
        let batches = run(&ctx, "EXPLAIN SELECT * FROM test ORDER BY id DESC").await;
        let expected = [
            "+---------------+-----------------------------------------------------------+",
            "| plan_type     | plan                                                      |",
            "+---------------+-----------------------------------------------------------+",
            "| logical_plan  | Sort: test.id DESC NULLS FIRST                            |",
            "|               |   TableScan: test projection=[id, val]                    |",
            "| physical_plan | SortExec: expr=[id@0 DESC], preserve_partitioning=[false] |",
            "|               |   CooperativeExec                                         |",
            "|               |     MapExec: partitions=1, projection=Some([0, 1])        |",
            "|               |                                                           |",
            "+---------------+-----------------------------------------------------------+",
        ];
        assert_batches_eq!(expected, &batches);
    }

    #[tokio::test]
    async fn test_select_star_with_limit_from_map_table() {
        let ctx = setup();

        // LIMIT returns the first rows in insertion order.
        let batches = run(&ctx, "SELECT * FROM test LIMIT 2").await;
        let expected = [
            "+----+------+",
            "| id | val  |",
            "+----+------+",
            "| 1  | val1 |",
            "| 2  | val2 |",
            "+----+------+",
        ];
        assert_batches_eq!(expected, &batches);

        // The fetch is pushed into the logical scan, with a GlobalLimitExec above.
        let batches = run(&ctx, "EXPLAIN SELECT * FROM test LIMIT 2").await;
        let expected = [
            "+---------------+----------------------------------------------------+",
            "| plan_type     | plan                                               |",
            "+---------------+----------------------------------------------------+",
            "| logical_plan  | Limit: skip=0, fetch=2                             |",
            "|               |   TableScan: test projection=[id, val], fetch=2    |",
            "| physical_plan | GlobalLimitExec: skip=0, fetch=2                   |",
            "|               |   CooperativeExec                                  |",
            "|               |     MapExec: partitions=1, projection=Some([0, 1]) |",
            "|               |                                                    |",
            "+---------------+----------------------------------------------------+",
        ];
        assert_batches_eq!(expected, &batches);
    }
}