Skip to main content

proof_of_sql/base/database/
data_accessor_impl.rs

1#[cfg(feature = "arrow")]
2use crate::base::database::{ArrayRefExt, ArrowArrayToColumnConversionError};
3use crate::base::{
4    database::{Column, DataAccessor, MetadataAccessor, TableRef},
5    scalar::Scalar,
6    IndexMap,
7};
8#[cfg(feature = "arrow")]
9use arrow::array::RecordBatch;
10use bumpalo::Bump;
11use sqlparser::ast::Ident;
12
13/// The canonical implementation for the `DataAccessor` trait
14pub struct DataAccessorImpl<'a, S: Scalar> {
15    data_lookup: IndexMap<TableRef, TableDataAccessor<'a, S>>,
16}
17
18impl<'a, S: Scalar> DataAccessorImpl<'a, S> {
19    /// Creates a new instance of `DataAccessorImpl`
20    #[must_use]
21    pub fn new(data_lookup: IndexMap<TableRef, TableDataAccessor<'a, S>>) -> Self {
22        Self { data_lookup }
23    }
24}
25
26/// An intermediate type for use by `DataAccessorImpl`
27pub struct TableDataAccessor<'a, S: Scalar> {
28    offset: usize,
29    table_data: IndexMap<Ident, Column<'a, S>>,
30}
31
32impl<'a, S: Scalar> TableDataAccessor<'a, S> {
33    /// Creates a new instance of `TableDataAccessor`
34    #[must_use]
35    pub fn new(offset: usize, table_data: IndexMap<Ident, Column<'a, S>>) -> Self {
36        Self { offset, table_data }
37    }
38
39    /// Creates a new instance of `TableDataAccessor` using a `RecordBatch`
40    #[cfg(feature = "arrow")]
41    pub fn try_from_record_batch(
42        record_batch: &'a RecordBatch,
43        offset: usize,
44        alloc: &'a Bump,
45    ) -> Result<Self, ArrowArrayToColumnConversionError> {
46        let range = 0..record_batch.num_rows();
47        let columns = record_batch
48            .schema()
49            .fields()
50            .iter()
51            .zip(record_batch.columns())
52            .map(|(f, col)| {
53                col.to_column::<S>(alloc, &range, None)
54                    .map(|col| (f.name().as_str().into(), col))
55            })
56            // Use collect to transform Iterator<Result<T, E>> into Result<Collection<T>, E>
57            .collect::<Result<IndexMap<_, _>, _>>()?;
58        Ok(Self {
59            offset,
60            table_data: columns,
61        })
62    }
63}
64
65impl<S: Scalar> MetadataAccessor for DataAccessorImpl<'_, S> {
66    fn get_length(&self, table_ref: &TableRef) -> usize {
67        self.data_lookup
68            .get(table_ref)
69            .expect("table does not exist")
70            .table_data
71            .len()
72    }
73
74    fn get_offset(&self, table_ref: &TableRef) -> usize {
75        self.data_lookup
76            .get(table_ref)
77            .expect("table does not exist")
78            .offset
79    }
80}
81
82impl<S: Scalar> DataAccessor<S> for DataAccessorImpl<'_, S> {
83    fn get_column(&self, table_ref: &TableRef, column_id: &Ident) -> Column<'_, S> {
84        *self
85            .data_lookup
86            .get(table_ref)
87            .expect("table does not exist")
88            .table_data
89            .get(column_id)
90            .expect("column does not exist")
91    }
92}
93
94#[cfg(test)]
95mod tests {
96    use crate::base::{
97        database::{
98            Column, DataAccessor, DataAccessorImpl, MetadataAccessor, TableDataAccessor, TableRef,
99        },
100        scalar::test_scalar::TestScalar,
101    };
102    use alloc::sync::Arc;
103    #[cfg(feature = "arrow")]
104    use arrow::array::{ArrayRef, BooleanArray, RecordBatch};
105    use bumpalo::Bump;
106    use sqlparser::ast::Ident;
107    use std::str::FromStr;
108
109    #[test]
110    fn we_can_get_offset_and_length() {
111        let column_id = Ident::from("test");
112        let column = Column::<TestScalar>::BigInt(&[3i64]);
113        let table_data_accessor =
114            TableDataAccessor::new(2, [(column_id.clone(), column)].into_iter().collect());
115        let table_ref = TableRef::from_names(Some("test"), "table");
116        let data_accessor = DataAccessorImpl::new(
117            [(table_ref.clone(), table_data_accessor)]
118                .into_iter()
119                .collect(),
120        );
121        assert_eq!(data_accessor.get_length(&table_ref), 1);
122        assert_eq!(data_accessor.get_offset(&table_ref), 2);
123        assert_eq!(data_accessor.get_column(&table_ref, &column_id), column);
124    }
125
126    #[cfg(feature = "arrow")]
127    #[test]
128    fn we_can_get_data_accessor_from_record_batch() {
129        let rb = RecordBatch::try_from_iter([(
130            "BOOLS",
131            Arc::new(BooleanArray::from(vec![true, false])) as ArrayRef,
132        )])
133        .unwrap();
134
135        let alloc = Bump::new();
136        let table_ref = TableRef::from_str("test.table").unwrap();
137        let table_data_accessor =
138            TableDataAccessor::<TestScalar>::try_from_record_batch(&rb, 1, &alloc).unwrap();
139        let data_accessor_impl = DataAccessorImpl::new(
140            [(table_ref.clone(), table_data_accessor)]
141                .into_iter()
142                .collect(),
143        );
144
145        assert_eq!(data_accessor_impl.get_length(&table_ref), 1);
146        assert_eq!(data_accessor_impl.get_offset(&table_ref), 1);
147        assert_eq!(
148            data_accessor_impl.get_column(&table_ref, &Ident::new("BOOLS")),
149            Column::Boolean(&[true, false])
150        );
151    }
152}