1use super::{ColumnField, OwnedColumn, Table};
2use crate::base::{
3 database::ColumnCoercionError, map::IndexMap, polynomial::compute_evaluation_vector,
4 scalar::Scalar,
5};
6use alloc::{vec, vec::Vec};
7use itertools::{EitherOrBoth, Itertools};
8use serde::{Deserialize, Serialize};
9use snafu::Snafu;
10use sqlparser::ast::Ident;
11
12#[derive(Snafu, Debug, PartialEq, Eq)]
14pub enum OwnedTableError {
15 #[snafu(display("Columns have different lengths"))]
17 ColumnLengthMismatch,
18}
19
20#[derive(Snafu, Debug, PartialEq, Eq)]
22pub(crate) enum TableCoercionError {
23 #[snafu(transparent)]
24 ColumnCoercionError { source: ColumnCoercionError },
25 #[snafu(display("Name mismatch between column and field"))]
27 NameMismatch,
28 #[snafu(display("Column count mismatch"))]
30 ColumnCountMismatch,
31}
32
33#[derive(Debug, Clone, Eq, Serialize, Deserialize)]
39pub struct OwnedTable<S: Scalar> {
40 table: IndexMap<Ident, OwnedColumn<S>>,
41}
42impl<S: Scalar> OwnedTable<S> {
43 pub fn try_new(table: IndexMap<Ident, OwnedColumn<S>>) -> Result<Self, OwnedTableError> {
45 if table.is_empty() {
46 return Ok(Self { table });
47 }
48 let num_rows = table[0].len();
49 if table.values().any(|column| column.len() != num_rows) {
50 Err(OwnedTableError::ColumnLengthMismatch)
51 } else {
52 Ok(Self { table })
53 }
54 }
55 pub fn try_from_iter<T: IntoIterator<Item = (Ident, OwnedColumn<S>)>>(
57 iter: T,
58 ) -> Result<Self, OwnedTableError> {
59 Self::try_new(IndexMap::from_iter(iter))
60 }
61
62 #[allow(
63 clippy::missing_panics_doc,
64 reason = "Mapping from one table to another should not result in column mismatch"
65 )]
66 pub(crate) fn try_coerce_with_fields<T: IntoIterator<Item = ColumnField>>(
79 self,
80 fields: T,
81 ) -> Result<Self, TableCoercionError> {
82 self.into_inner()
83 .into_iter()
84 .zip_longest(fields)
85 .map(|p| match p {
86 EitherOrBoth::Left(_) | EitherOrBoth::Right(_) => {
87 Err(TableCoercionError::ColumnCountMismatch)
88 }
89 EitherOrBoth::Both((name, column), field) if name == field.name() => Ok((
90 name,
91 column.try_coerce_scalar_to_numeric(field.data_type())?,
92 )),
93 EitherOrBoth::Both(_, _) => Err(TableCoercionError::NameMismatch),
94 })
95 .process_results(|iter| {
96 Self::try_from_iter(iter).expect("Columns should have the same length")
97 })
98 }
99
100 #[must_use]
102 pub fn num_columns(&self) -> usize {
103 self.table.len()
104 }
105 #[must_use]
107 pub fn num_rows(&self) -> usize {
108 if self.table.is_empty() {
109 0
110 } else {
111 self.table[0].len()
112 }
113 }
114 #[must_use]
116 pub fn is_empty(&self) -> bool {
117 self.table.is_empty()
118 }
119 #[must_use]
121 pub fn into_inner(self) -> IndexMap<Ident, OwnedColumn<S>> {
122 self.table
123 }
124 #[must_use]
126 pub fn inner_table(&self) -> &IndexMap<Ident, OwnedColumn<S>> {
127 &self.table
128 }
129 pub fn column_names(&self) -> impl Iterator<Item = &Ident> {
131 self.table.keys()
132 }
133 #[must_use]
135 pub fn column_by_index(&self, index: usize) -> Option<&OwnedColumn<S>> {
136 self.table.get_index(index).map(|(_, v)| v)
137 }
138
139 pub(crate) fn mle_evaluations(&self, evaluation_point: &[S]) -> Vec<S> {
140 let mut evaluation_vector = vec![S::ZERO; self.num_rows()];
141 compute_evaluation_vector(&mut evaluation_vector, evaluation_point);
142 self.table
143 .values()
144 .map(|column| column.inner_product(&evaluation_vector))
145 .collect()
146 }
147}
148
149impl<S: Scalar> PartialEq for OwnedTable<S> {
152 fn eq(&self, other: &Self) -> bool {
153 self.table == other.table
154 && self
155 .table
156 .keys()
157 .zip(other.table.keys())
158 .all(|(a, b)| a == b)
159 }
160}
161
/// Test-only convenience that allows looking a column up by name with `table["name"]`.
#[cfg(test)]
impl<S: Scalar> core::ops::Index<&str> for OwnedTable<S> {
    type Output = OwnedColumn<S>;

    /// Returns the column called `index`.
    ///
    /// # Panics
    /// Panics if the table has no column with that name.
    fn index(&self, index: &str) -> &Self::Output {
        let name = Ident::new(index);
        self.table.get(&name).unwrap()
    }
}
169
170impl<'a, S: Scalar> From<&Table<'a, S>> for OwnedTable<S> {
171 fn from(value: &Table<'a, S>) -> Self {
172 OwnedTable::try_from_iter(
173 value
174 .inner_table()
175 .iter()
176 .map(|(name, column)| (name.clone(), OwnedColumn::from(column))),
177 )
178 .expect("Tables should not have columns with differing lengths")
179 }
180}
181
182impl<'a, S: Scalar> From<Table<'a, S>> for OwnedTable<S> {
183 fn from(value: Table<'a, S>) -> Self {
184 OwnedTable::try_from_iter(
185 value
186 .into_inner()
187 .into_iter()
188 .map(|(name, column)| (name, OwnedColumn::from(&column))),
189 )
190 .expect("Tables should not have columns with differing lengths")
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::OwnedTable;
    use crate::base::{
        database::{
            owned_table_utility::*, table_utility::*, ColumnCoercionError, ColumnField,
            ColumnType, Table, TableCoercionError, TableOptions,
        },
        map::indexmap,
        scalar::test_scalar::TestScalar,
    };
    use bumpalo::Bump;
    use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone};

    /// Two-column fixture (a `bigint` column and a `scalar` column) shared by the coercion tests.
    fn bigint_and_scalar_table() -> OwnedTable<TestScalar> {
        owned_table([
            bigint("bigint", [0_i64, 1, 2, 3, 4, 5, 6, i64::MIN, i64::MAX]),
            scalar("scalar", [0, 1, 2, 3, 4, 5, 6, 7, 8]),
        ])
    }

    /// Converting a borrowed table, by reference or by value, yields the matching owned table.
    #[test]
    fn test_conversion_from_table_to_owned_table() {
        let alloc = Bump::new();

        let bigints = [0_i64, 1, 2, 3, 4, 5, 6, i64::MIN, i64::MAX];
        let decimals = [0_i128, 1, 2, 3, 4, 5, 6, i128::MIN, i128::MAX];
        let strings = ["0", "1", "2", "3", "4", "5", "6", "7", "8"];
        let booleans = [true, false, true, false, true, false, true, false, true];

        let borrowed_table = table::<TestScalar>([
            borrowed_bigint("bigint", bigints, &alloc),
            borrowed_int128("decimal", decimals, &alloc),
            borrowed_varchar("varchar", strings, &alloc),
            borrowed_scalar("scalar", [0, 1, 2, 3, 4, 5, 6, 7, 8], &alloc),
            borrowed_boolean("boolean", booleans, &alloc),
            borrowed_timestamptz(
                "time_stamp",
                PoSQLTimeUnit::Second,
                PoSQLTimeZone::utc(),
                bigints,
                &alloc,
            ),
        ]);

        let expected_table = owned_table::<TestScalar>([
            bigint("bigint", bigints),
            int128("decimal", decimals),
            varchar("varchar", strings),
            scalar("scalar", [0, 1, 2, 3, 4, 5, 6, 7, 8]),
            boolean("boolean", booleans),
            timestamptz(
                "time_stamp",
                PoSQLTimeUnit::Second,
                PoSQLTimeZone::utc(),
                bigints,
            ),
        ]);

        assert_eq!(OwnedTable::from(&borrowed_table), expected_table);
        assert_eq!(OwnedTable::from(borrowed_table), expected_table);
    }

    /// Zero-row tables and column-less tables convert to the expected owned tables.
    #[test]
    fn test_empty_and_no_columns_tables() {
        let alloc = Bump::new();

        // A table with one column and no rows.
        let empty_table = table::<TestScalar>([borrowed_bigint("bigint", [0; 0], &alloc)]);
        let expected_empty_table = owned_table::<TestScalar>([bigint("bigint", [0; 0])]);
        assert_eq!(OwnedTable::from(&empty_table), expected_empty_table);
        assert_eq!(OwnedTable::from(empty_table), expected_empty_table);

        // Tables with no columns convert to the same owned table whatever their row count.
        let expected_no_columns_table = owned_table::<TestScalar>([]);
        for row_count in [0, 2] {
            let no_columns_table =
                Table::try_new_with_options(indexmap! {}, TableOptions::new(Some(row_count)))
                    .unwrap();
            assert_eq!(
                OwnedTable::from(&no_columns_table),
                expected_no_columns_table
            );
            assert_eq!(
                OwnedTable::from(no_columns_table),
                expected_no_columns_table
            );
        }
    }

    /// A scalar column is coerced to the integer type requested by its field.
    #[test]
    fn test_try_coerce_with_fields() {
        let fields = vec![
            ColumnField::new("bigint".into(), ColumnType::BigInt),
            ColumnField::new("scalar".into(), ColumnType::Int),
        ];

        let coerced_table = bigint_and_scalar_table()
            .try_coerce_with_fields(fields)
            .unwrap();

        let expected_table = owned_table::<TestScalar>([
            bigint("bigint", [0_i64, 1, 2, 3, 4, 5, 6, i64::MIN, i64::MAX]),
            int("scalar", [0, 1, 2, 3, 4, 5, 6, 7, 8]),
        ]);
        assert_eq!(coerced_table, expected_table);
    }

    /// A field whose name differs from the column at the same position is rejected.
    #[test]
    fn test_try_coerce_with_fields_name_mismatch() {
        let fields = vec![
            ColumnField::new("bigint".into(), ColumnType::BigInt),
            ColumnField::new("mismatch".into(), ColumnType::Int),
        ];

        let result = bigint_and_scalar_table().try_coerce_with_fields(fields);

        assert_eq!(result, Err(TableCoercionError::NameMismatch));
    }

    /// Supplying fewer fields than columns is rejected.
    #[test]
    fn test_try_coerce_with_fields_column_count_mismatch() {
        let fields = vec![ColumnField::new("bigint".into(), ColumnType::BigInt)];

        let result = bigint_and_scalar_table().try_coerce_with_fields(fields);

        assert_eq!(result, Err(TableCoercionError::ColumnCountMismatch));
    }

    /// A scalar value that does not fit the requested integer type reports an overflow.
    #[test]
    fn test_try_coerce_with_fields_overflow() {
        let table = owned_table::<TestScalar>([
            bigint("bigint", [0_i64, 1, 2, 3, 4, 5, 6, i64::MIN, i64::MAX]),
            scalar("scalar", [0, 1, 2, 3, 4, 5, 6, 7, i64::MAX]),
        ]);
        let fields = vec![
            ColumnField::new("bigint".into(), ColumnType::BigInt),
            ColumnField::new("scalar".into(), ColumnType::TinyInt),
        ];

        let result = table.try_coerce_with_fields(fields);

        assert_eq!(
            result,
            Err(TableCoercionError::ColumnCoercionError {
                source: ColumnCoercionError::Overflow
            })
        );
    }
}