proof_of_sql/base/database/
table.rs1use super::{Column, ColumnField};
2use crate::base::{map::IndexMap, scalar::Scalar};
3use alloc::vec::Vec;
4use bumpalo::Bump;
5use snafu::Snafu;
6use sqlparser::ast::Ident;
7
/// Options controlling how a `Table` is constructed.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct TableOptions {
    /// The number of rows the table must have.
    ///
    /// When `None`, the row count is inferred from the columns, which requires the table to
    /// have at least one column. Specifying it is the only way to build a table without columns.
    pub row_count: Option<usize>,
}

impl TableOptions {
    /// Creates a new [`TableOptions`] with the given row count.
    ///
    /// This is a `const fn`, so options can also be built in constant contexts.
    #[must_use]
    pub const fn new(row_count: Option<usize>) -> Self {
        Self { row_count }
    }
}
23
24#[derive(Snafu, Debug, PartialEq, Eq)]
26pub enum TableError {
27    #[snafu(display("Columns have different lengths"))]
29    ColumnLengthMismatch,
30
31    #[snafu(display("Column has length different from the provided row count"))]
33    ColumnLengthMismatchWithSpecifiedRowCount,
34
35    #[snafu(display("Table is empty and no row count is specified"))]
37    EmptyTableWithoutSpecifiedRowCount,
38}
39#[derive(Debug, Clone, Eq)]
45pub struct Table<'a, S: Scalar> {
46    table: IndexMap<Ident, Column<'a, S>>,
47    row_count: usize,
48}
49impl<'a, S: Scalar> Table<'a, S> {
50    pub fn try_new(table: IndexMap<Ident, Column<'a, S>>) -> Result<Self, TableError> {
52        Self::try_new_with_options(table, TableOptions::default())
53    }
54
55    pub fn try_new_with_options(
57        table: IndexMap<Ident, Column<'a, S>>,
58        options: TableOptions,
59    ) -> Result<Self, TableError> {
60        match (table.is_empty(), options.row_count) {
61            (true, None) => Err(TableError::EmptyTableWithoutSpecifiedRowCount),
62            (true, Some(row_count)) => Ok(Self { table, row_count }),
63            (false, None) => {
64                let row_count = table[0].len();
65                if table.values().any(|column| column.len() != row_count) {
66                    Err(TableError::ColumnLengthMismatch)
67                } else {
68                    Ok(Self { table, row_count })
69                }
70            }
71            (false, Some(row_count)) => {
72                if table.values().any(|column| column.len() != row_count) {
73                    Err(TableError::ColumnLengthMismatchWithSpecifiedRowCount)
74                } else {
75                    Ok(Self { table, row_count })
76                }
77            }
78        }
79    }
80
81    pub fn try_from_iter<T: IntoIterator<Item = (Ident, Column<'a, S>)>>(
83        iter: T,
84    ) -> Result<Self, TableError> {
85        Self::try_from_iter_with_options(iter, TableOptions::default())
86    }
87
88    pub fn try_from_iter_with_options<T: IntoIterator<Item = (Ident, Column<'a, S>)>>(
90        iter: T,
91        options: TableOptions,
92    ) -> Result<Self, TableError> {
93        Self::try_new_with_options(IndexMap::from_iter(iter), options)
94    }
95
96    #[must_use]
98    pub fn num_columns(&self) -> usize {
99        self.table.len()
100    }
101    #[must_use]
103    pub fn num_rows(&self) -> usize {
104        self.row_count
105    }
106    #[must_use]
108    pub fn is_empty(&self) -> bool {
109        self.table.is_empty()
110    }
111    #[must_use]
113    pub fn into_inner(self) -> IndexMap<Ident, Column<'a, S>> {
114        self.table
115    }
116    #[must_use]
118    pub fn inner_table(&self) -> &IndexMap<Ident, Column<'a, S>> {
119        &self.table
120    }
121    #[must_use]
123    pub fn schema(&self) -> Vec<ColumnField> {
124        self.table
125            .iter()
126            .map(|(name, column)| ColumnField::new(name.clone(), column.column_type()))
127            .collect()
128    }
129    pub fn column_names(&self) -> impl Iterator<Item = &Ident> {
131        self.table.keys()
132    }
133    pub fn columns(&self) -> impl Iterator<Item = &Column<'a, S>> {
135        self.table.values()
136    }
137    #[must_use]
139    pub fn column(&self, index: usize) -> Option<&Column<'a, S>> {
140        self.table.values().nth(index)
141    }
142    #[must_use]
144    pub fn add_rho_column(mut self, alloc: &'a Bump) -> Self {
145        self.table
146            .insert(Ident::new("rho"), Column::rho(self.row_count, alloc));
147        self
148    }
149}
150
151impl<S: Scalar> PartialEq for Table<'_, S> {
154    fn eq(&self, other: &Self) -> bool {
155        self.table == other.table
156            && self
157                .table
158                .keys()
159                .zip(other.table.keys())
160                .all(|(a, b)| a == b)
161    }
162}
163
/// Test-only convenience: look up a column by name with `table["name"]`.
///
/// # Panics
/// Panics if the table has no column named `index`.
#[cfg(test)]
impl<'a, S: Scalar> core::ops::Index<&str> for Table<'a, S> {
    type Output = Column<'a, S>;
    fn index(&self, index: &str) -> &Self::Output {
        // Name the missing column so a failing test points straight at the typo/omission.
        self.table
            .get(&Ident::new(index))
            .unwrap_or_else(|| panic!("table has no column named `{index}`"))
    }
}