// proof_of_sql/sql/proof/provable_query_result.rs

1use super::{decode_and_convert, decode_multiple_elements, ProvableResultColumn, QueryError};
2use crate::base::{
3    database::{Column, ColumnField, ColumnType, OwnedColumn, OwnedTable, Table},
4    polynomial::compute_evaluation_vector,
5    scalar::{Scalar, ScalarExt},
6};
7use alloc::{vec, vec::Vec};
8use num_traits::Zero;
9use serde::{Deserialize, Serialize};
10
/// An intermediate form of a query result that can be transformed
/// to either the finalized query result form or a query error
///
/// Column values are serialized back-to-back into a single byte buffer
/// (see `ProvableQueryResult::new`, which delegates the per-column encoding
/// to `ProvableResultColumn::write`). Because instances are deserialized
/// from untrusted data, no invariant relating `num_columns`, `table_length`,
/// and `data` can be assumed to hold; consumers must validate on decode.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct ProvableQueryResult {
    // Number of serialized columns contained in `data`.
    num_columns: u64,
    // Number of rows in each column.
    pub(crate) table_length: u64,
    // Back-to-back encoded column values; untrusted until fully decoded.
    data: Vec<u8>,
}
19
20// TODO: Handle truncation properly. The `allow(clippy::cast_possible_truncation)` is a temporary fix and should be replaced with proper logic to manage possible truncation scenarios.
21impl ProvableQueryResult {
22    #[allow(clippy::cast_possible_truncation)]
23    /// The number of columns in the result
24    #[must_use]
25    pub fn num_columns(&self) -> usize {
26        self.num_columns as usize
27    }
28    /// A mutable reference to the number of columns in the result. Because the struct is deserialized from untrusted data, it
29    /// cannot maintain any invariant on its data members; hence, this function is available to allow for easy manipulation for testing.
30    #[cfg(test)]
31    pub fn num_columns_mut(&mut self) -> &mut u64 {
32        &mut self.num_columns
33    }
34
35    #[allow(clippy::cast_possible_truncation)]
36    /// The number of rows in the result
37    #[must_use]
38    pub fn table_length(&self) -> usize {
39        self.table_length as usize
40    }
41    /// A mutable reference to the underlying encoded data of the result. Because the struct is deserialized from untrusted data, it
42    /// cannot maintain any invariant on its data members; hence, this function is available to allow for easy manipulation for testing.
43    #[cfg(test)]
44    pub fn data_mut(&mut self) -> &mut Vec<u8> {
45        &mut self.data
46    }
47    /// This function is available to allow for easy creation for testing.
48    #[cfg(test)]
49    #[must_use]
50    pub fn new_from_raw_data(num_columns: u64, table_length: u64, data: Vec<u8>) -> Self {
51        Self {
52            num_columns,
53            table_length,
54            data,
55        }
56    }
57
58    /// Form intermediate query result from index rows and result columns
59    /// # Panics
60    ///
61    /// Will panic if `table_length` is somehow larger than the length of some column
62    /// which should never happen.
63    #[must_use]
64    pub fn new<'a, S: Scalar>(table_length: u64, columns: &'a [Column<'a, S>]) -> Self {
65        assert!(columns
66            .iter()
67            .all(|column| table_length == column.len() as u64));
68        let mut sz = 0;
69        for col in columns {
70            sz += col.num_bytes(table_length);
71        }
72        let mut data = vec![0u8; sz];
73        let mut sz = 0;
74        for col in columns {
75            sz += col.write(&mut data[sz..], table_length);
76        }
77        ProvableQueryResult {
78            num_columns: columns.len() as u64,
79            table_length,
80            data,
81        }
82    }
83
84    #[allow(clippy::cast_possible_truncation)]
85    #[allow(
86        clippy::missing_panics_doc,
87        reason = "Assertions ensure preconditions are met, eliminating the possibility of panic."
88    )]
89    /// Given an evaluation vector, compute the evaluation of the intermediate result
90    /// columns as spare multilinear extensions
91    ///
92    /// # Panics
93    /// This function will panic if the length of `evaluation_point` does not match `self.num_columns`.
94    /// It will also panic if the `data` array is not properly formatted for the expected column types.
95    pub fn evaluate<S: Scalar>(
96        &self,
97        evaluation_point: &[S],
98        output_length: usize,
99        column_result_fields: &[ColumnField],
100    ) -> Result<Vec<S>, QueryError> {
101        if self.num_columns as usize != column_result_fields.len() {
102            return Err(QueryError::InvalidColumnCount);
103        }
104        let mut evaluation_vec = vec![Zero::zero(); output_length];
105        compute_evaluation_vector(&mut evaluation_vec, evaluation_point);
106        let mut offset: usize = 0;
107        let mut res = Vec::with_capacity(self.num_columns as usize);
108
109        for field in column_result_fields {
110            let mut val = S::zero();
111            for entry in evaluation_vec.iter().take(output_length) {
112                let (x, sz) = match field.data_type() {
113                    ColumnType::Boolean => decode_and_convert::<bool, S>(&self.data[offset..]),
114                    ColumnType::Uint8 => decode_and_convert::<u8, S>(&self.data[offset..]),
115                    ColumnType::TinyInt => decode_and_convert::<i8, S>(&self.data[offset..]),
116                    ColumnType::SmallInt => decode_and_convert::<i16, S>(&self.data[offset..]),
117                    ColumnType::Int => decode_and_convert::<i32, S>(&self.data[offset..]),
118                    ColumnType::BigInt => decode_and_convert::<i64, S>(&self.data[offset..]),
119                    ColumnType::Int128 => decode_and_convert::<i128, S>(&self.data[offset..]),
120                    ColumnType::Decimal75(_, _) | ColumnType::Scalar => {
121                        decode_and_convert::<S, S>(&self.data[offset..])
122                    }
123
124                    ColumnType::VarChar => decode_and_convert::<&str, S>(&self.data[offset..]),
125                    ColumnType::VarBinary => {
126                        let (raw_bytes, used) =
127                            decode_and_convert::<&[u8], &[u8]>(&self.data[offset..])?;
128                        let x = S::from_byte_slice_via_hash(raw_bytes);
129                        Ok((x, used))
130                    }
131                    ColumnType::TimestampTZ(_, _) => {
132                        decode_and_convert::<i64, S>(&self.data[offset..])
133                    }
134                }?;
135                val += *entry * x;
136                offset += sz;
137            }
138            res.push(val);
139        }
140        if offset != self.data.len() {
141            return Err(QueryError::MiscellaneousEvaluationError);
142        }
143
144        Ok(res)
145    }
146
147    #[allow(
148        clippy::missing_panics_doc,
149        reason = "Assertions ensure preconditions are met, eliminating the possibility of panic."
150    )]
151    /// Convert the intermediate query result into a final query result
152    ///
153    /// The result is essentially an `OwnedTable` type.
154    pub fn to_owned_table<S: Scalar>(
155        &self,
156        column_result_fields: &[ColumnField],
157    ) -> Result<OwnedTable<S>, QueryError> {
158        if column_result_fields.len() != self.num_columns() {
159            return Err(QueryError::InvalidColumnCount);
160        }
161
162        let n = self.table_length();
163        let mut offset: usize = 0;
164
165        let owned_table = OwnedTable::try_new(
166            column_result_fields
167                .iter()
168                .map(|field| match field.data_type() {
169                    ColumnType::Boolean => {
170                        let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
171                        offset += num_read;
172                        Ok((field.name(), OwnedColumn::Boolean(col)))
173                    }
174                    ColumnType::Uint8 => {
175                        let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
176                        offset += num_read;
177                        Ok((field.name(), OwnedColumn::Uint8(col)))
178                    }
179                    ColumnType::TinyInt => {
180                        let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
181                        offset += num_read;
182                        Ok((field.name(), OwnedColumn::TinyInt(col)))
183                    }
184                    ColumnType::SmallInt => {
185                        let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
186                        offset += num_read;
187                        Ok((field.name(), OwnedColumn::SmallInt(col)))
188                    }
189                    ColumnType::Int => {
190                        let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
191                        offset += num_read;
192                        Ok((field.name(), OwnedColumn::Int(col)))
193                    }
194                    ColumnType::BigInt => {
195                        let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
196                        offset += num_read;
197                        Ok((field.name(), OwnedColumn::BigInt(col)))
198                    }
199                    ColumnType::Int128 => {
200                        let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
201                        offset += num_read;
202                        Ok((field.name(), OwnedColumn::Int128(col)))
203                    }
204                    ColumnType::VarChar => {
205                        let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
206                        offset += num_read;
207                        Ok((field.name(), OwnedColumn::VarChar(col)))
208                    }
209                    ColumnType::VarBinary => {
210                        // Manually specify the item type: `&[u8]`
211                        let (decoded_slices, num_read) =
212                            decode_multiple_elements::<&[u8]>(&self.data[offset..], n)?;
213                        offset += num_read;
214
215                        // Convert those slices to owned `Vec<u8>`
216                        let col_vec = decoded_slices.into_iter().map(<[u8]>::to_vec).collect();
217
218                        Ok((field.name(), OwnedColumn::VarBinary(col_vec)))
219                    }
220                    ColumnType::Scalar => {
221                        let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
222                        offset += num_read;
223                        Ok((field.name(), OwnedColumn::Scalar(col)))
224                    }
225                    ColumnType::Decimal75(precision, scale) => {
226                        let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
227                        offset += num_read;
228                        Ok((field.name(), OwnedColumn::Decimal75(precision, scale, col)))
229                    }
230                    ColumnType::TimestampTZ(tu, tz) => {
231                        let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
232                        offset += num_read;
233                        Ok((field.name(), OwnedColumn::TimestampTZ(tu, tz, col)))
234                    }
235                })
236                .collect::<Result<_, QueryError>>()?,
237        )?;
238
239        assert_eq!(offset, self.data.len());
240        assert_eq!(owned_table.num_columns(), self.num_columns());
241
242        Ok(owned_table)
243    }
244}
245
246impl<S: Scalar> From<Table<'_, S>> for ProvableQueryResult {
247    fn from(table: Table<S>) -> Self {
248        let num_rows = table.num_rows();
249        let columns = table
250            .into_inner()
251            .into_iter()
252            .map(|(_, col)| col)
253            .collect::<Vec<_>>();
254        Self::new(num_rows as u64, &columns)
255    }
256}