// proof_of_sql/sql/proof/provable_query_result.rs

use super::{decode_and_convert, decode_multiple_elements, ProvableResultColumn, QueryError};
use crate::base::{
    database::{Column, ColumnField, ColumnType, OwnedColumn, OwnedTable, Table},
    polynomial::compute_evaluation_vector,
    scalar::{Scalar, ScalarExt},
};
use alloc::{vec, vec::Vec};
use num_traits::Zero;
use serde::{Deserialize, Serialize};
10
/// An intermediate form of a query result that can be transformed
/// to either the finalized query result form or a query error
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct ProvableQueryResult {
    // Number of columns serialized into `data`. Untrusted: this struct is
    // deserialized from untrusted input, so no invariant ties it to `data`.
    num_columns: u64,
    // Number of rows per column. Untrusted for the same reason as above.
    pub(crate) table_length: u64,
    // Columns serialized back-to-back, in field order; the byte layout is only
    // meaningful relative to a caller-supplied list of `ColumnField` types.
    data: Vec<u8>,
}
19
20// TODO: Handle truncation properly. The `expect(clippy::cast_possible_truncation)` is a temporary fix and should be replaced with proper logic to manage possible truncation scenarios.
21impl ProvableQueryResult {
22    #[expect(clippy::cast_possible_truncation)]
23    /// The number of columns in the result
24    #[must_use]
25    pub fn num_columns(&self) -> usize {
26        self.num_columns as usize
27    }
28    /// A mutable reference to the number of columns in the result. Because the struct is deserialized from untrusted data, it
29    /// cannot maintain any invariant on its data members; hence, this function is available to allow for easy manipulation for testing.
30    #[cfg(test)]
31    pub fn num_columns_mut(&mut self) -> &mut u64 {
32        &mut self.num_columns
33    }
34
35    #[expect(clippy::cast_possible_truncation)]
36    /// The number of rows in the result
37    #[must_use]
38    pub fn table_length(&self) -> usize {
39        self.table_length as usize
40    }
41    /// A mutable reference to the underlying encoded data of the result. Because the struct is deserialized from untrusted data, it
42    /// cannot maintain any invariant on its data members; hence, this function is available to allow for easy manipulation for testing.
43    #[cfg(test)]
44    pub fn data_mut(&mut self) -> &mut Vec<u8> {
45        &mut self.data
46    }
47    /// This function is available to allow for easy creation for testing.
48    #[cfg(test)]
49    #[must_use]
50    pub fn new_from_raw_data(num_columns: u64, table_length: u64, data: Vec<u8>) -> Self {
51        Self {
52            num_columns,
53            table_length,
54            data,
55        }
56    }
57
58    /// Form intermediate query result from index rows and result columns
59    /// # Panics
60    ///
61    /// Will panic if `table_length` is somehow larger than the length of some column
62    /// which should never happen.
63    #[must_use]
64    pub fn new<'a, S: Scalar>(table_length: u64, columns: &'a [Column<'a, S>]) -> Self {
65        assert!(columns
66            .iter()
67            .all(|column| table_length == column.len() as u64));
68        let mut sz = 0;
69        for col in columns {
70            sz += col.num_bytes(table_length);
71        }
72        let mut data = vec![0u8; sz];
73        let mut sz = 0;
74        for col in columns {
75            sz += col.write(&mut data[sz..], table_length);
76        }
77        ProvableQueryResult {
78            num_columns: columns.len() as u64,
79            table_length,
80            data,
81        }
82    }
83
84    #[expect(clippy::cast_possible_truncation)]
85    /// Given an evaluation vector, compute the evaluation of the intermediate result
86    /// columns as spare multilinear extensions
87    ///
88    /// # Panics
89    /// This function will panic if the length of `evaluation_point` does not match `self.num_columns`.
90    /// It will also panic if the `data` array is not properly formatted for the expected column types.
91    pub fn evaluate<S: Scalar>(
92        &self,
93        evaluation_point: &[S],
94        output_length: usize,
95        column_result_fields: &[ColumnField],
96    ) -> Result<Vec<S>, QueryError> {
97        if self.num_columns as usize != column_result_fields.len() {
98            return Err(QueryError::InvalidColumnCount);
99        }
100        let mut evaluation_vec = vec![Zero::zero(); output_length];
101        compute_evaluation_vector(&mut evaluation_vec, evaluation_point);
102        let mut offset: usize = 0;
103        let mut res = Vec::with_capacity(self.num_columns as usize);
104
105        for field in column_result_fields {
106            let mut val = S::zero();
107            for entry in evaluation_vec.iter().take(output_length) {
108                let (x, sz) = match field.data_type() {
109                    ColumnType::Boolean => decode_and_convert::<bool, S>(&self.data[offset..]),
110                    ColumnType::Uint8 => decode_and_convert::<u8, S>(&self.data[offset..]),
111                    ColumnType::TinyInt => decode_and_convert::<i8, S>(&self.data[offset..]),
112                    ColumnType::SmallInt => decode_and_convert::<i16, S>(&self.data[offset..]),
113                    ColumnType::Int => decode_and_convert::<i32, S>(&self.data[offset..]),
114                    ColumnType::BigInt => decode_and_convert::<i64, S>(&self.data[offset..]),
115                    ColumnType::Int128 => decode_and_convert::<i128, S>(&self.data[offset..]),
116                    ColumnType::Decimal75(_, _) | ColumnType::Scalar => {
117                        decode_and_convert::<S, S>(&self.data[offset..])
118                    }
119
120                    ColumnType::VarChar => decode_and_convert::<&str, S>(&self.data[offset..]),
121                    ColumnType::VarBinary => {
122                        let (raw_bytes, used) =
123                            decode_and_convert::<&[u8], &[u8]>(&self.data[offset..])?;
124                        let x = S::from_byte_slice_via_hash(raw_bytes);
125                        Ok((x, used))
126                    }
127                    ColumnType::TimestampTZ(_, _) => {
128                        decode_and_convert::<i64, S>(&self.data[offset..])
129                    }
130                }?;
131                val += *entry * x;
132                offset += sz;
133            }
134            res.push(val);
135        }
136        if offset != self.data.len() {
137            return Err(QueryError::MiscellaneousEvaluationError);
138        }
139
140        Ok(res)
141    }
142
143    #[expect(
144        clippy::missing_panics_doc,
145        reason = "Assertions ensure preconditions are met, eliminating the possibility of panic."
146    )]
147    /// Convert the intermediate query result into a final query result
148    ///
149    /// The result is essentially an `OwnedTable` type.
150    pub fn to_owned_table<S: Scalar>(
151        &self,
152        column_result_fields: &[ColumnField],
153    ) -> Result<OwnedTable<S>, QueryError> {
154        if column_result_fields.len() != self.num_columns() {
155            return Err(QueryError::InvalidColumnCount);
156        }
157
158        let n = self.table_length();
159        let mut offset: usize = 0;
160
161        let owned_table = OwnedTable::try_new(
162            column_result_fields
163                .iter()
164                .map(|field| match field.data_type() {
165                    ColumnType::Boolean => {
166                        let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
167                        offset += num_read;
168                        Ok((field.name(), OwnedColumn::Boolean(col)))
169                    }
170                    ColumnType::Uint8 => {
171                        let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
172                        offset += num_read;
173                        Ok((field.name(), OwnedColumn::Uint8(col)))
174                    }
175                    ColumnType::TinyInt => {
176                        let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
177                        offset += num_read;
178                        Ok((field.name(), OwnedColumn::TinyInt(col)))
179                    }
180                    ColumnType::SmallInt => {
181                        let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
182                        offset += num_read;
183                        Ok((field.name(), OwnedColumn::SmallInt(col)))
184                    }
185                    ColumnType::Int => {
186                        let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
187                        offset += num_read;
188                        Ok((field.name(), OwnedColumn::Int(col)))
189                    }
190                    ColumnType::BigInt => {
191                        let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
192                        offset += num_read;
193                        Ok((field.name(), OwnedColumn::BigInt(col)))
194                    }
195                    ColumnType::Int128 => {
196                        let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
197                        offset += num_read;
198                        Ok((field.name(), OwnedColumn::Int128(col)))
199                    }
200                    ColumnType::VarChar => {
201                        let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
202                        offset += num_read;
203                        Ok((field.name(), OwnedColumn::VarChar(col)))
204                    }
205                    ColumnType::VarBinary => {
206                        // Manually specify the item type: `&[u8]`
207                        let (decoded_slices, num_read) =
208                            decode_multiple_elements::<&[u8]>(&self.data[offset..], n)?;
209                        offset += num_read;
210
211                        // Convert those slices to owned `Vec<u8>`
212                        let col_vec = decoded_slices.into_iter().map(<[u8]>::to_vec).collect();
213
214                        Ok((field.name(), OwnedColumn::VarBinary(col_vec)))
215                    }
216                    ColumnType::Scalar => {
217                        let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
218                        offset += num_read;
219                        Ok((field.name(), OwnedColumn::Scalar(col)))
220                    }
221                    ColumnType::Decimal75(precision, scale) => {
222                        let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
223                        offset += num_read;
224                        Ok((field.name(), OwnedColumn::Decimal75(precision, scale, col)))
225                    }
226                    ColumnType::TimestampTZ(tu, tz) => {
227                        let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
228                        offset += num_read;
229                        Ok((field.name(), OwnedColumn::TimestampTZ(tu, tz, col)))
230                    }
231                })
232                .collect::<Result<_, QueryError>>()?,
233        )?;
234
235        assert_eq!(offset, self.data.len());
236        assert_eq!(owned_table.num_columns(), self.num_columns());
237
238        Ok(owned_table)
239    }
240}
241
242impl<S: Scalar> From<Table<'_, S>> for ProvableQueryResult {
243    fn from(table: Table<S>) -> Self {
244        let num_rows = table.num_rows();
245        let columns = table
246            .into_inner()
247            .into_iter()
248            .map(|(_, col)| col)
249            .collect::<Vec<_>>();
250        Self::new(num_rows as u64, &columns)
251    }
252}