proof_of_sql/base/database/
owned_table.rs

1use super::{ColumnField, OwnedColumn, Table};
2use crate::base::{
3    database::ColumnCoercionError, map::IndexMap, polynomial::compute_evaluation_vector,
4    scalar::Scalar,
5};
6use alloc::{vec, vec::Vec};
7use itertools::{EitherOrBoth, Itertools};
8use serde::{Deserialize, Serialize};
9use snafu::Snafu;
10use sqlparser::ast::Ident;
11
12/// An error that occurs when working with tables.
13#[derive(Snafu, Debug, PartialEq, Eq)]
14pub enum OwnedTableError {
15    /// The columns have different lengths.
16    #[snafu(display("Columns have different lengths"))]
17    ColumnLengthMismatch,
18}
19
20/// Errors that can occur when coercing a table.
21#[derive(Snafu, Debug, PartialEq, Eq)]
22pub(crate) enum TableCoercionError {
23    #[snafu(transparent)]
24    ColumnCoercionError { source: ColumnCoercionError },
25    /// Name mismatch between column and field.
26    #[snafu(display("Name mismatch between column and field"))]
27    NameMismatch,
28    /// Column count mismatch.
29    #[snafu(display("Column count mismatch"))]
30    ColumnCountMismatch,
31}
32
33/// A table of data, with schema included. This is simply a map from `Ident` to `OwnedColumn`,
34/// where columns order matters.
35/// This is primarily used as an internal result that is used before
36/// converting to the final result in either Arrow format or JSON.
37/// This is the analog of an arrow [`RecordBatch`](arrow::record_batch::RecordBatch).
38#[derive(Debug, Clone, Eq, Serialize, Deserialize)]
39pub struct OwnedTable<S: Scalar> {
40    table: IndexMap<Ident, OwnedColumn<S>>,
41}
42impl<S: Scalar> OwnedTable<S> {
43    /// Creates a new [`OwnedTable`].
44    pub fn try_new(table: IndexMap<Ident, OwnedColumn<S>>) -> Result<Self, OwnedTableError> {
45        if table.is_empty() {
46            return Ok(Self { table });
47        }
48        let num_rows = table[0].len();
49        if table.values().any(|column| column.len() != num_rows) {
50            Err(OwnedTableError::ColumnLengthMismatch)
51        } else {
52            Ok(Self { table })
53        }
54    }
55    /// Creates a new [`OwnedTable`].
56    pub fn try_from_iter<T: IntoIterator<Item = (Ident, OwnedColumn<S>)>>(
57        iter: T,
58    ) -> Result<Self, OwnedTableError> {
59        Self::try_new(IndexMap::from_iter(iter))
60    }
61
62    #[allow(
63        clippy::missing_panics_doc,
64        reason = "Mapping from one table to another should not result in column mismatch"
65    )]
66    /// Attempts to coerce the columns of the table to match the provided fields.
67    ///
68    /// # Arguments
69    ///
70    /// * `fields` - An iterator of `ColumnField` items that specify the desired schema.
71    ///
72    /// # Errors
73    ///
74    /// Returns a `TableCoercionError` if:
75    /// * The number of columns in the table does not match the number of fields.
76    /// * The name of a column does not match the name of the corresponding field.
77    /// * A column cannot be coerced to the type specified by the corresponding field.
78    pub(crate) fn try_coerce_with_fields<T: IntoIterator<Item = ColumnField>>(
79        self,
80        fields: T,
81    ) -> Result<Self, TableCoercionError> {
82        self.into_inner()
83            .into_iter()
84            .zip_longest(fields)
85            .map(|p| match p {
86                EitherOrBoth::Left(_) | EitherOrBoth::Right(_) => {
87                    Err(TableCoercionError::ColumnCountMismatch)
88                }
89                EitherOrBoth::Both((name, column), field) if name == field.name() => Ok((
90                    name,
91                    column.try_coerce_scalar_to_numeric(field.data_type())?,
92                )),
93                EitherOrBoth::Both(_, _) => Err(TableCoercionError::NameMismatch),
94            })
95            .process_results(|iter| {
96                Self::try_from_iter(iter).expect("Columns should have the same length")
97            })
98    }
99
100    /// Number of columns in the table.
101    #[must_use]
102    pub fn num_columns(&self) -> usize {
103        self.table.len()
104    }
105    /// Number of rows in the table.
106    #[must_use]
107    pub fn num_rows(&self) -> usize {
108        if self.table.is_empty() {
109            0
110        } else {
111            self.table[0].len()
112        }
113    }
114    /// Whether the table has no columns.
115    #[must_use]
116    pub fn is_empty(&self) -> bool {
117        self.table.is_empty()
118    }
119    /// Returns the columns of this table as an `IndexMap`
120    #[must_use]
121    pub fn into_inner(self) -> IndexMap<Ident, OwnedColumn<S>> {
122        self.table
123    }
124    /// Returns the columns of this table as an `IndexMap`
125    #[must_use]
126    pub fn inner_table(&self) -> &IndexMap<Ident, OwnedColumn<S>> {
127        &self.table
128    }
129    /// Returns the columns of this table as an Iterator
130    pub fn column_names(&self) -> impl Iterator<Item = &Ident> {
131        self.table.keys()
132    }
133    /// Returns the column with the given position.
134    #[must_use]
135    pub fn column_by_index(&self, index: usize) -> Option<&OwnedColumn<S>> {
136        self.table.get_index(index).map(|(_, v)| v)
137    }
138
139    pub(crate) fn mle_evaluations(&self, evaluation_point: &[S]) -> Vec<S> {
140        let mut evaluation_vector = vec![S::ZERO; self.num_rows()];
141        compute_evaluation_vector(&mut evaluation_vector, evaluation_point);
142        self.table
143            .values()
144            .map(|column| column.inner_product(&evaluation_vector))
145            .collect()
146    }
147}
148
149// Note: we modify the default PartialEq for IndexMap to also check for column ordering.
150// This is to align with the behaviour of a `RecordBatch`.
151impl<S: Scalar> PartialEq for OwnedTable<S> {
152    fn eq(&self, other: &Self) -> bool {
153        self.table == other.table
154            && self
155                .table
156                .keys()
157                .zip(other.table.keys())
158                .all(|(a, b)| a == b)
159    }
160}
161
// Test-only convenience so a column can be fetched by name as `table["name"]`.
// Panics (via `unwrap`) when no column has the requested name.
#[cfg(test)]
impl<S: Scalar> core::ops::Index<&str> for OwnedTable<S> {
    type Output = OwnedColumn<S>;
    fn index(&self, index: &str) -> &Self::Output {
        let name = Ident::new(index);
        self.table.get(&name).unwrap()
    }
}
169
170impl<'a, S: Scalar> From<&Table<'a, S>> for OwnedTable<S> {
171    fn from(value: &Table<'a, S>) -> Self {
172        OwnedTable::try_from_iter(
173            value
174                .inner_table()
175                .iter()
176                .map(|(name, column)| (name.clone(), OwnedColumn::from(column))),
177        )
178        .expect("Tables should not have columns with differing lengths")
179    }
180}
181
182impl<'a, S: Scalar> From<Table<'a, S>> for OwnedTable<S> {
183    fn from(value: Table<'a, S>) -> Self {
184        OwnedTable::try_from_iter(
185            value
186                .into_inner()
187                .into_iter()
188                .map(|(name, column)| (name, OwnedColumn::from(&column))),
189        )
190        .expect("Tables should not have columns with differing lengths")
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::OwnedTable;
    use crate::base::{
        database::{
            owned_table_utility::*, table_utility::*, ColumnCoercionError, ColumnField,
            ColumnType, Table, TableCoercionError, TableOptions,
        },
        map::indexmap,
        scalar::test_scalar::TestScalar,
    };
    use bumpalo::Bump;
    use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone};

    /// Converting a borrowed `Table` (by reference and by value) yields the
    /// equivalent `OwnedTable` for every column kind exercised here.
    #[test]
    fn test_conversion_from_table_to_owned_table() {
        let alloc = Bump::new();

        let borrowed_table = table::<TestScalar>([
            borrowed_bigint(
                "bigint",
                [0_i64, 1, 2, 3, 4, 5, 6, i64::MIN, i64::MAX],
                &alloc,
            ),
            borrowed_int128(
                "decimal",
                [0_i128, 1, 2, 3, 4, 5, 6, i128::MIN, i128::MAX],
                &alloc,
            ),
            borrowed_varchar(
                "varchar",
                ["0", "1", "2", "3", "4", "5", "6", "7", "8"],
                &alloc,
            ),
            borrowed_scalar("scalar", [0, 1, 2, 3, 4, 5, 6, 7, 8], &alloc),
            borrowed_boolean(
                "boolean",
                [true, false, true, false, true, false, true, false, true],
                &alloc,
            ),
            borrowed_timestamptz(
                "time_stamp",
                PoSQLTimeUnit::Second,
                PoSQLTimeZone::utc(),
                [0_i64, 1, 2, 3, 4, 5, 6, i64::MIN, i64::MAX],
                &alloc,
            ),
        ]);

        let expected_table = owned_table::<TestScalar>([
            bigint("bigint", [0_i64, 1, 2, 3, 4, 5, 6, i64::MIN, i64::MAX]),
            int128("decimal", [0_i128, 1, 2, 3, 4, 5, 6, i128::MIN, i128::MAX]),
            varchar("varchar", ["0", "1", "2", "3", "4", "5", "6", "7", "8"]),
            scalar("scalar", [0, 1, 2, 3, 4, 5, 6, 7, 8]),
            boolean(
                "boolean",
                [true, false, true, false, true, false, true, false, true],
            ),
            timestamptz(
                "time_stamp",
                PoSQLTimeUnit::Second,
                PoSQLTimeZone::utc(),
                [0_i64, 1, 2, 3, 4, 5, 6, i64::MIN, i64::MAX],
            ),
        ]);

        assert_eq!(OwnedTable::from(&borrowed_table), expected_table);
        assert_eq!(OwnedTable::from(borrowed_table), expected_table);
    }

    /// Tables with zero rows, and tables with zero columns, convert as expected.
    #[test]
    fn test_empty_and_no_columns_tables() {
        let alloc = Bump::new();
        // Test with no rows
        let empty_table = table::<TestScalar>([borrowed_bigint("bigint", [0; 0], &alloc)]);
        let expected_empty_table = owned_table::<TestScalar>([bigint("bigint", [0; 0])]);
        assert_eq!(OwnedTable::from(&empty_table), expected_empty_table);
        assert_eq!(OwnedTable::from(empty_table), expected_empty_table);

        // Test with no columns: both row-count options convert to the same
        // column-less owned table.
        let no_columns_table_no_rows =
            Table::try_new_with_options(indexmap! {}, TableOptions::new(Some(0))).unwrap();
        let no_columns_table_two_rows =
            Table::try_new_with_options(indexmap! {}, TableOptions::new(Some(2))).unwrap();
        let expected_no_columns_table = owned_table::<TestScalar>([]);
        assert_eq!(
            OwnedTable::from(&no_columns_table_no_rows),
            expected_no_columns_table
        );
        assert_eq!(
            OwnedTable::from(no_columns_table_no_rows),
            expected_no_columns_table
        );
        assert_eq!(
            OwnedTable::from(&no_columns_table_two_rows),
            expected_no_columns_table
        );
        assert_eq!(
            OwnedTable::from(no_columns_table_two_rows),
            expected_no_columns_table
        );
    }

    /// A scalar column is coerced to the requested integer type when names and
    /// column count line up.
    #[test]
    fn test_try_coerce_with_fields() {
        let table = owned_table::<TestScalar>([
            bigint("bigint", [0_i64, 1, 2, 3, 4, 5, 6, i64::MIN, i64::MAX]),
            scalar("scalar", [0, 1, 2, 3, 4, 5, 6, 7, 8]),
        ]);

        let fields = vec![
            ColumnField::new("bigint".into(), ColumnType::BigInt),
            ColumnField::new("scalar".into(), ColumnType::Int),
        ];

        let coerced_table = table.try_coerce_with_fields(fields).unwrap();

        let expected_table = owned_table::<TestScalar>([
            bigint("bigint", [0_i64, 1, 2, 3, 4, 5, 6, i64::MIN, i64::MAX]),
            int("scalar", [0, 1, 2, 3, 4, 5, 6, 7, 8]),
        ]);

        assert_eq!(coerced_table, expected_table);
    }

    /// A field whose name differs from the column at the same position is rejected.
    #[test]
    fn test_try_coerce_with_fields_name_mismatch() {
        let table = owned_table::<TestScalar>([
            bigint("bigint", [0_i64, 1, 2, 3, 4, 5, 6, i64::MIN, i64::MAX]),
            scalar("scalar", [0, 1, 2, 3, 4, 5, 6, 7, 8]),
        ]);

        let fields = vec![
            ColumnField::new("bigint".into(), ColumnType::BigInt),
            ColumnField::new("mismatch".into(), ColumnType::Int),
        ];

        let result = table.try_coerce_with_fields(fields);

        assert!(matches!(result, Err(TableCoercionError::NameMismatch)));
    }

    /// Fewer fields than columns is reported as a column count mismatch.
    #[test]
    fn test_try_coerce_with_fields_column_count_mismatch() {
        let table = owned_table::<TestScalar>([
            bigint("bigint", [0_i64, 1, 2, 3, 4, 5, 6, i64::MIN, i64::MAX]),
            scalar("scalar", [0, 1, 2, 3, 4, 5, 6, 7, 8]),
        ]);

        let fields = vec![ColumnField::new("bigint".into(), ColumnType::BigInt)];

        let result = table.try_coerce_with_fields(fields);

        assert!(matches!(
            result,
            Err(TableCoercionError::ColumnCountMismatch)
        ));
    }

    /// More fields than columns is reported as a column count mismatch as well.
    #[test]
    fn test_try_coerce_with_fields_more_fields_than_columns() {
        let table = owned_table::<TestScalar>([
            bigint("bigint", [0_i64, 1, 2, 3, 4, 5, 6, i64::MIN, i64::MAX]),
            scalar("scalar", [0, 1, 2, 3, 4, 5, 6, 7, 8]),
        ]);

        let fields = vec![
            ColumnField::new("bigint".into(), ColumnType::BigInt),
            ColumnField::new("scalar".into(), ColumnType::Int),
            ColumnField::new("extra".into(), ColumnType::BigInt),
        ];

        let result = table.try_coerce_with_fields(fields);

        assert!(matches!(
            result,
            Err(TableCoercionError::ColumnCountMismatch)
        ));
    }

    /// A scalar value that does not fit the requested integer type surfaces the
    /// column-level overflow error.
    #[test]
    fn test_try_coerce_with_fields_overflow() {
        let table = owned_table::<TestScalar>([
            bigint("bigint", [0_i64, 1, 2, 3, 4, 5, 6, i64::MIN, i64::MAX]),
            scalar("scalar", [0, 1, 2, 3, 4, 5, 6, 7, i64::MAX]),
        ]);

        let fields = vec![
            ColumnField::new("bigint".into(), ColumnType::BigInt),
            ColumnField::new("scalar".into(), ColumnType::TinyInt),
        ];

        let result = table.try_coerce_with_fields(fields);

        assert!(matches!(
            result,
            Err(TableCoercionError::ColumnCoercionError {
                source: ColumnCoercionError::Overflow
            })
        ));
    }
}
382}